Movatterモバイル変換


[0]ホーム

URL:


LLVM 20.0.0git
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// Fix bitcasted functions.
11///
12/// WebAssembly requires caller and callee signatures to match, however in LLVM,
13/// some amount of slop is vaguely permitted. Detect mismatch by looking for
14/// bitcasts of functions and rewrite them to use wrapper functions instead.
15///
16/// This doesn't catch all cases, such as when a function's address is taken in
17/// one place and casted in another, but it works for many common cases.
18///
19/// Note that LLVM already optimizes away function bitcasts in common cases by
20/// dropping arguments as needed, so this pass only ends up getting used in less
21/// common cases.
22///
23//===----------------------------------------------------------------------===//
24
25#include "WebAssembly.h"
26#include "llvm/IR/Constants.h"
27#include "llvm/IR/Instructions.h"
28#include "llvm/IR/Module.h"
29#include "llvm/IR/Operator.h"
30#include "llvm/Pass.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/Support/raw_ostream.h"
33using namespacellvm;
34
35#define DEBUG_TYPE "wasm-fix-function-bitcasts"
36
37namespace{
38classFixFunctionBitcasts final :publicModulePass {
39StringRefgetPassName() const override{
40return"WebAssembly Fix Function Bitcasts";
41 }
42
43voidgetAnalysisUsage(AnalysisUsage &AU) const override{
44 AU.setPreservesCFG();
45ModulePass::getAnalysisUsage(AU);
46 }
47
48boolrunOnModule(Module &M)override;
49
50public:
51staticcharID;
52 FixFunctionBitcasts() :ModulePass(ID) {}
53};
54}// End anonymous namespace
55
56char FixFunctionBitcasts::ID = 0;
57INITIALIZE_PASS(FixFunctionBitcasts,DEBUG_TYPE,
58"Fix mismatching bitcasts for WebAssembly",false,false)
59
60ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
61returnnew FixFunctionBitcasts();
62}
63
64// Recursively descend the def-use lists from V to find non-bitcast users of
65// bitcasts of V.
66staticvoidfindUses(Value *V,Function &F,
67SmallVectorImpl<std::pair<CallBase *, Function *>> &Uses) {
68for (User *U : V->users()) {
69if (auto *BC = dyn_cast<BitCastOperator>(U))
70findUses(BC,F,Uses);
71elseif (auto *A = dyn_cast<GlobalAlias>(U))
72findUses(A,F,Uses);
73elseif (auto *CB = dyn_cast<CallBase>(U)) {
74Value *Callee = CB->getCalledOperand();
75if (Callee != V)
76// Skip calls where the function isn't the callee
77continue;
78if (CB->getFunctionType() ==F.getValueType())
79// Skip uses that are immediately called
80continue;
81Uses.push_back(std::make_pair(CB, &F));
82 }
83 }
84}
85
86// Create a wrapper function with type Ty that calls F (which may have a
87// different type). Attempt to support common bitcasted function idioms:
88// - Call with more arguments than needed: arguments are dropped
89// - Call with fewer arguments than needed: arguments are filled in with poison
90// - Return value is not needed: drop it
91// - Return value needed but not present: supply a poison value
92//
93// If the all the argument types of trivially castable to one another (i.e.
94// I32 vs pointer type) then we don't create a wrapper at all (return nullptr
95// instead).
96//
97// If there is a type mismatch that we know would result in an invalid wasm
98// module then generate wrapper that contains unreachable (i.e. abort at
99// runtime). Such programs are deep into undefined behaviour territory,
100// but we choose to fail at runtime rather than generate and invalid module
101// or fail at compiler time. The reason we delay the error is that we want
102// to support the CMake which expects to be able to compile and link programs
103// that refer to functions with entirely incorrect signatures (this is how
104// CMake detects the existence of a function in a toolchain).
105//
106// For bitcasts that involve struct types we don't know at this stage if they
107// would be equivalent at the wasm level and so we can't know if we need to
108// generate a wrapper.
109staticFunction *createWrapper(Function *F,FunctionType *Ty) {
110Module *M =F->getParent();
111
112Function *Wrapper =Function::Create(Ty, Function::PrivateLinkage,
113F->getName() +"_bitcast", M);
114Wrapper->setAttributes(F->getAttributes());
115BasicBlock *BB =BasicBlock::Create(M->getContext(),"body",Wrapper);
116constDataLayout &DL = BB->getDataLayout();
117
118// Determine what arguments to pass.
119SmallVector<Value *, 4> Args;
120Function::arg_iterator AI =Wrapper->arg_begin();
121Function::arg_iterator AE =Wrapper->arg_end();
122FunctionType::param_iterator PI =F->getFunctionType()->param_begin();
123FunctionType::param_iterator PE =F->getFunctionType()->param_end();
124bool TypeMismatch =false;
125bool WrapperNeeded =false;
126
127Type *ExpectedRtnType =F->getFunctionType()->getReturnType();
128Type *RtnType = Ty->getReturnType();
129
130if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
131 (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
132 (ExpectedRtnType != RtnType))
133 WrapperNeeded =true;
134
135for (; AI != AE && PI != PE; ++AI, ++PI) {
136Type *ArgType = AI->getType();
137Type *ParamType = *PI;
138
139if (ArgType == ParamType) {
140 Args.push_back(&*AI);
141 }else {
142if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType,DL)) {
143Instruction *PtrCast =
144CastInst::CreateBitOrPointerCast(AI, ParamType,"cast");
145 PtrCast->insertInto(BB, BB->end());
146 Args.push_back(PtrCast);
147 }elseif (ArgType->isStructTy() || ParamType->isStructTy()) {
148LLVM_DEBUG(dbgs() <<"createWrapper: struct param type in bitcast: "
149 <<F->getName() <<"\n");
150 WrapperNeeded =false;
151 }else {
152LLVM_DEBUG(dbgs() <<"createWrapper: arg type mismatch calling: "
153 <<F->getName() <<"\n");
154LLVM_DEBUG(dbgs() <<"Arg[" << Args.size() <<"] Expected: "
155 << *ParamType <<" Got: " << *ArgType <<"\n");
156 TypeMismatch =true;
157break;
158 }
159 }
160 }
161
162if (WrapperNeeded && !TypeMismatch) {
163for (; PI != PE; ++PI)
164 Args.push_back(PoisonValue::get(*PI));
165if (F->isVarArg())
166for (; AI != AE; ++AI)
167 Args.push_back(&*AI);
168
169CallInst *Call =CallInst::Create(F, Args,"", BB);
170
171Type *ExpectedRtnType =F->getFunctionType()->getReturnType();
172Type *RtnType = Ty->getReturnType();
173// Determine what value to return.
174if (RtnType->isVoidTy()) {
175ReturnInst::Create(M->getContext(), BB);
176 }elseif (ExpectedRtnType->isVoidTy()) {
177LLVM_DEBUG(dbgs() <<"Creating dummy return: " << *RtnType <<"\n");
178ReturnInst::Create(M->getContext(),PoisonValue::get(RtnType), BB);
179 }elseif (RtnType == ExpectedRtnType) {
180ReturnInst::Create(M->getContext(), Call, BB);
181 }elseif (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
182DL)) {
183Instruction *Cast =
184CastInst::CreateBitOrPointerCast(Call, RtnType,"cast");
185 Cast->insertInto(BB, BB->end());
186ReturnInst::Create(M->getContext(), Cast, BB);
187 }elseif (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
188LLVM_DEBUG(dbgs() <<"createWrapper: struct return type in bitcast: "
189 <<F->getName() <<"\n");
190 WrapperNeeded =false;
191 }else {
192LLVM_DEBUG(dbgs() <<"createWrapper: return type mismatch calling: "
193 <<F->getName() <<"\n");
194LLVM_DEBUG(dbgs() <<"Expected: " << *ExpectedRtnType
195 <<" Got: " << *RtnType <<"\n");
196 TypeMismatch =true;
197 }
198 }
199
200if (TypeMismatch) {
201// Create a new wrapper that simply contains `unreachable`.
202Wrapper->eraseFromParent();
203Wrapper =Function::Create(Ty, Function::PrivateLinkage,
204F->getName() +"_bitcast_invalid", M);
205Wrapper->setAttributes(F->getAttributes());
206BasicBlock *BB =BasicBlock::Create(M->getContext(),"body",Wrapper);
207newUnreachableInst(M->getContext(), BB);
208Wrapper->setName(F->getName() +"_bitcast_invalid");
209 }elseif (!WrapperNeeded) {
210LLVM_DEBUG(dbgs() <<"createWrapper: no wrapper needed: " <<F->getName()
211 <<"\n");
212Wrapper->eraseFromParent();
213returnnullptr;
214 }
215LLVM_DEBUG(dbgs() <<"createWrapper: " <<F->getName() <<"\n");
216returnWrapper;
217}
218
219// Test whether a main function with type FuncTy should be rewritten to have
220// type MainTy.
221staticboolshouldFixMainFunction(FunctionType *FuncTy,FunctionType *MainTy) {
222// Only fix the main function if it's the standard zero-arg form. That way,
223// the standard cases will work as expected, and users will see signature
224// mismatches from the linker for non-standard cases.
225return FuncTy->getReturnType() == MainTy->getReturnType() &&
226 FuncTy->getNumParams() == 0 &&
227 !FuncTy->isVarArg();
228}
229
230bool FixFunctionBitcasts::runOnModule(Module &M) {
231LLVM_DEBUG(dbgs() <<"********** Fix Function Bitcasts **********\n");
232
233Function *Main =nullptr;
234CallInst *CallMain =nullptr;
235SmallVector<std::pair<CallBase *, Function *>, 0>Uses;
236
237// Collect all the places that need wrappers.
238for (Function &F : M) {
239// Skip to fix when the function is swiftcc because swiftcc allows
240// bitcast type difference for swiftself and swifterror.
241if (F.getCallingConv() ==CallingConv::Swift)
242continue;
243findUses(&F,F,Uses);
244
245// If we have a "main" function, and its type isn't
246// "int main(int argc, char *argv[])", create an artificial call with it
247// bitcasted to that type so that we generate a wrapper for it, so that
248// the C runtime can call it.
249if (F.getName() =="main") {
250 Main = &F;
251LLVMContext &C =M.getContext();
252Type *MainArgTys[] = {Type::getInt32Ty(C), PointerType::get(C, 0)};
253FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
254/*isVarArg=*/false);
255if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
256LLVM_DEBUG(dbgs() <<"Found `main` function with incorrect type: "
257 << *F.getFunctionType() <<"\n");
258Value *Args[] = {PoisonValue::get(MainArgTys[0]),
259PoisonValue::get(MainArgTys[1])};
260 CallMain =CallInst::Create(MainTy, Main, Args,"call_main");
261Uses.push_back(std::make_pair(CallMain, &F));
262 }
263 }
264 }
265
266DenseMap<std::pair<Function *, FunctionType *>,Function *> Wrappers;
267
268for (auto &UseFunc :Uses) {
269CallBase *CB = UseFunc.first;
270Function *F = UseFunc.second;
271FunctionType *Ty = CB->getFunctionType();
272
273auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty),nullptr));
274if (Pair.second)
275 Pair.first->second =createWrapper(F, Ty);
276
277Function *Wrapper = Pair.first->second;
278if (!Wrapper)
279continue;
280
281 CB->setCalledOperand(Wrapper);
282 }
283
284// If we created a wrapper for main, rename the wrapper so that it's the
285// one that gets called from startup.
286if (CallMain) {
287 Main->setName("__original_main");
288auto *MainWrapper =
289 cast<Function>(CallMain->getCalledOperand()->stripPointerCasts());
290delete CallMain;
291if (Main->isDeclaration()) {
292// The wrapper is not needed in this case as we don't need to export
293// it to anyone else.
294 MainWrapper->eraseFromParent();
295 }else {
296// Otherwise give the wrapper the same linkage as the original main
297// function, so that it can be called from the same places.
298 MainWrapper->setName("main");
299 MainWrapper->setLinkage(Main->getLinkage());
300 MainWrapper->setVisibility(Main->getVisibility());
301 }
302 }
303
304returntrue;
305}
Wrapper
amdgpu aa AMDGPU Address space based Alias Analysis Wrapper
Definition:AMDGPUAliasAnalysis.cpp:31
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition:ARMSLSHardening.cpp:73
A
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Constants.h
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Debug.h
LLVM_DEBUG
#define LLVM_DEBUG(...)
Definition:Debug.h:106
Module.h
Module.h This file contains the declarations for the Module class.
Operator.h
Instructions.h
F
#define F(x, y, z)
Definition:MD5.cpp:55
INITIALIZE_PASS
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition:PassSupport.h:38
Pass.h
Uses
Remove Loads Into Fake Uses
Definition:RemoveLoadsIntoFakeUses.cpp:75
findUses
static void findUses(Value *V, Function &F, SmallVectorImpl< std::pair< CallBase *, Function * > > &Uses)
Definition:WebAssemblyFixFunctionBitcasts.cpp:66
shouldFixMainFunction
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy)
Definition:WebAssemblyFixFunctionBitcasts.cpp:221
createWrapper
static Function * createWrapper(Function *F, FunctionType *Ty)
Definition:WebAssemblyFixFunctionBitcasts.cpp:109
DEBUG_TYPE
#define DEBUG_TYPE
Definition:WebAssemblyFixFunctionBitcasts.cpp:35
WebAssembly.h
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end.
FunctionType
Definition:ItaniumDemangle.h:823
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition:PassAnalysisSupport.h:47
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition:Pass.cpp:256
llvm::Argument
This class represents an incoming formal argument to a Function.
Definition:Argument.h:31
llvm::BasicBlock
LLVM Basic Block Representation.
Definition:BasicBlock.h:61
llvm::BasicBlock::end
iterator end()
Definition:BasicBlock.h:474
llvm::BasicBlock::Create
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition:BasicBlock.h:213
llvm::BasicBlock::getDataLayout
const DataLayout & getDataLayout() const
Get the data layout of the module this basic block belongs to.
Definition:BasicBlock.cpp:296
llvm::CallBase
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition:InstrTypes.h:1112
llvm::CallBase::getCalledOperand
Value * getCalledOperand() const
Definition:InstrTypes.h:1334
llvm::CallBase::getFunctionType
FunctionType * getFunctionType() const
Definition:InstrTypes.h:1199
llvm::CallBase::setCalledOperand
void setCalledOperand(Value *V)
Definition:InstrTypes.h:1377
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition:Instructions.h:1479
llvm::CallInst::Create
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Definition:Instructions.h:1514
llvm::CastInst::isBitOrNoopPointerCastable
static bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op.
Definition:Instructions.cpp:3122
llvm::CastInst::CreateBitOrPointerCast
static CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPTr cast instruction.
Definition:Instructions.cpp:3047
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition:DataLayout.h:63
llvm::DenseMapBase::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition:DenseMap.h:211
llvm::DenseMap
Definition:DenseMap.h:727
llvm::FunctionType::param_iterator
Type::subtype_iterator param_iterator
Definition:DerivedTypes.h:128
llvm::Function
Definition:Function.h:63
llvm::Function::Create
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition:Function.h:173
llvm::GlobalValue::getVisibility
VisibilityTypes getVisibility() const
Definition:GlobalValue.h:249
llvm::GlobalValue::isDeclaration
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition:Globals.cpp:296
llvm::GlobalValue::getLinkage
LinkageTypes getLinkage() const
Definition:GlobalValue.h:547
llvm::Instruction
Definition:Instruction.h:68
llvm::Instruction::eraseFromParent
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition:Instruction.cpp:94
llvm::Instruction::insertInto
InstListType::iterator insertInto(BasicBlock *ParentBB, InstListType::iterator It)
Inserts an unlinked instruction into ParentBB at position It and returns the iterator of the inserted...
Definition:Instruction.cpp:123
llvm::LLVMContext
This is an important class for using LLVM in a threaded context.
Definition:LLVMContext.h:67
llvm::ModulePass
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition:Pass.h:251
llvm::ModulePass::runOnModule
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
llvm::Module
A Module instance is used to store all the information related to an LLVM module.
Definition:Module.h:65
llvm::Pass::getAnalysisUsage
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition:Pass.cpp:98
llvm::Pass::getPassName
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition:Pass.cpp:81
llvm::PoisonValue::get
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition:Constants.cpp:1878
llvm::ReturnInst::Create
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
Definition:Instructions.h:2965
llvm::SmallVectorImpl
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition:SmallVector.h:573
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition:SmallVector.h:1196
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition:StringRef.h:51
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition:Type.h:45
llvm::Type::isStructTy
bool isStructTy() const
True if this is an instance of StructType.
Definition:Type.h:258
llvm::Type::getInt32Ty
static IntegerType * getInt32Ty(LLVMContext &C)
llvm::Type::isVoidTy
bool isVoidTy() const
Return true if this is 'void'.
Definition:Type.h:139
llvm::UnreachableInst
This function has undefined behavior.
Definition:Instructions.h:4461
llvm::User
Definition:User.h:44
llvm::Value
LLVM Value Representation.
Definition:Value.h:74
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition:Value.h:255
llvm::Value::setName
void setName(const Twine &Name)
Change the name of the value.
Definition:Value.cpp:377
llvm::Value::stripPointerCasts
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition:Value.cpp:694
unsigned
llvm::AMDGPU::HSAMD::Kernel::Key::Args
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
Definition:AMDGPUMetadata.h:395
llvm::ARM::ProfileKind::M
@ M
llvm::CallingConv::Swift
@ Swift
Calling convention for Swift.
Definition:CallingConv.h:69
llvm::CallingConv::C
@ C
The default llvm calling convention, compatible with C.
Definition:CallingConv.h:34
llvm::CallingConv::ID
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition:CallingConv.h:24
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition:AddressRanges.h:18
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition:Debug.cpp:163
llvm::createWebAssemblyFixFunctionBitcasts
ModulePass * createWebAssemblyFixFunctionBitcasts()
raw_ostream.h

Generated on Fri Jul 18 2025 14:44:25 for LLVM by doxygen 1.9.6
[8]ページ先頭

©2009-2025 Movatter.jp