1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7//===----------------------------------------------------------------------===// 9// This file implements the SPIRVTargetLowering class. 11//===----------------------------------------------------------------------===// 22#include "llvm/IR/IntrinsicsSPIRV.h" 24#define DEBUG_TYPE "spirv-lower" 30// This code avoids CallLowering fail inside getVectorTypeBreakdown 31// on v3i1 arguments. Maybe we need to return 1 for all types. 32// TODO: remove it once this case is supported by the default implementation. 45// This code avoids CallLowering fail inside getVectorTypeBreakdown 46// on v3i1 arguments. Maybe we need to return i32 for all types. 47// TODO: remove it once this case is supported by the default implementation. 60unsigned Intrinsic)
const{
63case Intrinsic::spv_load:
66case Intrinsic::spv_store: {
67if (
I.getNumOperands() >= AlignIdx + 1) {
68auto *AlignOp = cast<ConstantInt>(
I.getOperand(AlignIdx));
69Info.align =
Align(AlignOp->getZExtValue());
72 cast<ConstantInt>(
I.getOperand(AlignIdx - 1))->getZExtValue());
74// TODO: take into account opaque pointers (don't use getElementType). 75// MVT::getVT(PtrTy->getElementType()); 85std::pair<unsigned, const TargetRegisterClass *>
91return std::make_pair(0u, RC);
94 RC = VT.
isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
96 RC = VT.
isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
98 RC = &SPIRV::iIDRegClass;
100return std::make_pair(0u, RC);
105return TypeInst && TypeInst->
getOpcode() == SPIRV::OpFunctionParameter
124I.getOperand(OpIdx).setReg(NewReg);
131 SPIRV::StorageClass::StorageClass SC =
132static_cast<SPIRV::StorageClass::StorageClass
>(
138 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
142// Insert a bitcast before the instruction to keep SPIR-V code valid 143// when there is a type mismatch between results and operand types. 150Register OpReg =
I.getOperand(OpIdx).getReg();
153if (!ResType || !OpType || OpType->
getOpcode() != SPIRV::OpTypePointer)
155// Get operand's pointee type 160// Check if we need a bitcast to make a statement valid 162bool IsEqualTypes = IsSameMF ? ElemType == ResType
166// There is a type mismatch between results and operand types 167// and we insert a bitcast before the instruction to keep SPIR-V code valid 172"insert validation bitcast: incompatible result and operand types");
176// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer 177// that doesn't point to OpTypeEvent. 182constexprunsigned OpIdx = 2;
184Register OpReg =
I.getOperand(OpIdx).getReg();
187if (!OpType || OpType->
getOpcode() != SPIRV::OpTypePointer)
190if (!ElemType || ElemType->
getOpcode() == SPIRV::OpTypeEvent)
192// Insert a bitcast before the instruction to keep SPIR-V code valid. 203Register PtrReg =
I.getOperand(0).getReg();
208if (!PonteeElemType || PonteeElemType->
getOpcode() == SPIRV::OpTypeVoid ||
209 (PonteeElemType->
getOpcode() == SPIRV::OpTypeInt &&
212// To keep the code valid a bitcast must be inserted 213 SPIRV::StorageClass::StorageClass SC =
214static_cast<SPIRV::StorageClass::StorageClass
>(
229Register OpReg =
I.getOperand(OpIdx).getReg();
232if (!OpType || OpType->
getOpcode() != SPIRV::OpTypePointer)
235if (!ElemType || ElemType->
getOpcode() != SPIRV::OpTypeStruct ||
238// It's a structure-wrapper around another type with a single member field. 243unsigned MemberTypeOp = MemberType->
getOpcode();
244if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
245 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
247// It's a structure-wrapper around a valid type. Insert a bitcast before the 248// instruction to keep SPIR-V code valid. 249 SPIRV::StorageClass::StorageClass SC =
250static_cast<SPIRV::StorageClass::StorageClass
>(
257// Insert a bitcast before the function call instruction to keep SPIR-V code 258// valid when there is a type mismatch between actual and expected types of an 260// %formal = OpFunctionParameter %formal_type 262// %res = OpFunctionCall %ty %fun %actual ... 263// implies that %actual is of %formal_type, and in case of opaque pointers. 264// We may need to insert a bitcast to ensure this. 270if (FunDef->
getOpcode() != SPIRV::OpFunction)
274 FunDef && FunDef->
getOpcode() == SPIRV::OpFunctionParameter &&
279 DefPtrType && DefPtrType->
getOpcode() == SPIRV::OpTypePointer
285// validatePtrTypes() works in the context if the call site 286// When we process historical records about forward calls 287// we need to switch context to the (forward) call site and 288// then restore it back to the current machine function. 298// Ensure there is no mismatch between actual and expected arg types: calls 299// with a processed definition. Return Function pointer if it's a forward 300// call (ahead of definition), and nullptr otherwise. 306constFunction *
F = dyn_cast<Function>(GV);
316// Ensure there is no mismatch between actual and expected arg types: calls 317// ahead of a processed definition. 325 &FunCall->getParent()->getParent()->getRegInfo();
330// Validation of an access chain. 334if (BaseTypeInst && BaseTypeInst->
getOpcode() == SPIRV::OpTypePointer) {
341// TODO: the logic of inserting additional bitcast's is to be moved 342// to pre-IRTranslation passes eventually 344// finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp) 345// We'd like to avoid the needless second processing pass. 346if (ProcessedMF.find(&MF) != ProcessedMF.end())
358switch (
MI.getOpcode()) {
359case SPIRV::OpAtomicLoad:
360case SPIRV::OpAtomicExchange:
361case SPIRV::OpAtomicCompareExchange:
362case SPIRV::OpAtomicCompareExchangeWeak:
363case SPIRV::OpAtomicIIncrement:
364case SPIRV::OpAtomicIDecrement:
365case SPIRV::OpAtomicIAdd:
366case SPIRV::OpAtomicISub:
367case SPIRV::OpAtomicSMin:
368case SPIRV::OpAtomicUMin:
369case SPIRV::OpAtomicSMax:
370case SPIRV::OpAtomicUMax:
371case SPIRV::OpAtomicAnd:
372case SPIRV::OpAtomicOr:
373case SPIRV::OpAtomicXor:
374// for the above listed instructions 375// OpAtomicXXX <ResType>, ptr %Op, ... 376// implies that %Op is a pointer to <ResType> 378// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType> 382case SPIRV::OpAtomicStore:
383// OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj> 384// implies that %Op points to the <Obj>'s type 389// OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type 393case SPIRV::OpPtrCastToGeneric:
394case SPIRV::OpGenericCastToPtr:
397case SPIRV::OpPtrAccessChain:
398case SPIRV::OpInBoundsPtrAccessChain:
399if (
MI.getNumOperands() == 4)
403case SPIRV::OpFunctionCall:
404// ensure there is no mismatch between actual and expected arg types: 405// calls with a processed definition 406if (
MI.getNumOperands() > 3)
410case SPIRV::OpFunction:
411// ensure there is no mismatch between actual and expected arg types: 412// calls ahead of a processed definition 416// ensure that LLVM IR add/sub instructions result in logical SPIR-V 417// instructions when applied to bool type 427// ensure that LLVM IR bitwise instructions result in logical SPIR-V 428// instructions when applied to bool type 429case SPIRV::OpBitwiseOrS:
430case SPIRV::OpBitwiseOrV:
435case SPIRV::OpBitwiseAndS:
436case SPIRV::OpBitwiseAndV:
441case SPIRV::OpBitwiseXorS:
442case SPIRV::OpBitwiseXorV:
447case SPIRV::OpLifetimeStart:
448case SPIRV::OpLifetimeStop:
449if (
MI.getOperand(1).getImm() > 0)
452case SPIRV::OpGroupAsyncCopy:
456case SPIRV::OpGroupWaitEvents:
457// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent> 460case SPIRV::OpConstantI: {
462if (
Type->getOpcode() != SPIRV::OpTypeInt &&
MI.getOperand(2).isImm() &&
463MI.getOperand(2).getImm() == 0) {
464// Validate the null constant of a target extension type 466for (
unsigned i =
MI.getNumOperands() - 1; i > 1; --i)
471// Phi refers to a type definition that goes after the Phi 472// instruction, so that the virtual register definition of the type 473// doesn't dominate all uses. Let's place the type definition 474// instruction at the end of the predecessor. 480case SPIRV::OpExtInst: {
482if (!
MI.getOperand(2).isImm() || !
MI.getOperand(3).isImm() ||
483MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
485switch (
MI.getOperand(3).getImm()) {
486case SPIRV::OpenCLExtInst::frexp:
487case SPIRV::OpenCLExtInst::lgamma_r:
488case SPIRV::OpenCLExtInst::remquo: {
489// The last operand must be of a pointer to i32 or vector of i32 494assert(RetType &&
"Expected return type");
496 STI,
MRI, GR,
MI,
MI.getNumOperands() - 1,
497 RetType->
getOpcode() != SPIRV::OpTypeVector
502case SPIRV::OpenCLExtInst::fract:
503case SPIRV::OpenCLExtInst::modf:
504case SPIRV::OpenCLExtInst::sincos:
505// The last operand must be of a pointer to the base type represented 506// by the previous operand. 507assert(
MI.getOperand(
MI.getNumOperands() - 2).isReg() &&
510 STI,
MRI, GR,
MI,
MI.getNumOperands() - 1,
512MI.getOperand(
MI.getNumOperands() - 2).getReg()));
514case SPIRV::OpenCLExtInst::prefetch:
515// Expected `ptr` type is a pointer to float, integer or vector, but 516// the pontee value can be wrapped into a struct. 517assert(
MI.getOperand(
MI.getNumOperands() - 2).isReg() &&
520MI.getNumOperands() - 2);
532 ProcessedMF.insert(&MF);
unsigned const MachineRegisterInfo * MRI
MachineBasicBlock MachineBasicBlock::iterator MBBI
Analysis containing CSE Info
unsigned const TargetRegisterInfo * TRI
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, Register OpReg, unsigned OpIdx, SPIRVType *NewPtrType)
static void validateLifetimeStart(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx)
Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg)
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
void validateFunCallMachineDef(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall, MachineInstr *FunDef)
void validateForwardCalls(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunDef)
const Function * validateFunCall(const SPIRVSubtarget &STI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall)
static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVType *ResType, const Type *ResTy=nullptr)
static SPIRVType * createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, SPIRVType *OpType, bool ReuseType, bool EmitIR, SPIRVType *ResType, const Type *ResTy)
This class represents a function call, abstracting a target machine's calling convention.
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
This is an important class for using LLVM in a threaded context.
bool isVector() const
Return true if this is a vector value type.
bool isInteger() const
Return true if this is an integer or a vector integer type.
bool isFloatingPoint() const
Return true if this is a FP or a vector FP type.
MachineInstr * remove_instr(MachineInstr *I)
Remove the possibly bundled instruction from the instruction list without deleting it.
instr_iterator insert(instr_iterator I, MachineInstr *M)
Insert MI into the instruction list before I, possibly inside a bundle.
iterator getFirstTerminator()
Returns an iterator to the first terminator instruction of this basic block.
pred_iterator pred_begin()
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
Helper class to build MachineInstr.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
bool constrainAllUses(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI) const
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Retuns the total number of operands.
const MachineOperand & getOperand(unsigned i) const
Flags
Flags values. These may be or'd together.
const GlobalValue * getGlobal() const
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
Wrapper class representing virtual and physical registers.
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
void addForwardCall(const Function *F, MachineInstr *MI)
const Type * getTypeForSPIRVType(const SPIRVType *Ty) const
bool isBitcastCompatible(const SPIRVType *Type1, const SPIRVType *Type2) const
const MachineInstr * getFunctionDefinition(const Function *F)
SPIRVType * getPointeeType(SPIRVType *PtrType)
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
SmallPtrSet< MachineInstr *, 8 > * getForwardCalls(const Function *F)
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const
MachineFunction * setCurrentFunc(MachineFunction &MF)
SPIRVType * getOrCreateSPIRVPointerType(SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SClass=SPIRV::StorageClass::Function)
SPIRVType * getOrCreateSPIRVVectorType(SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder)
SPIRVType * getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder)
const Function * getFunctionByDefinition(const MachineInstr *MI)
const SPIRVInstrInfo * getInstrInfo() const override
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
const SPIRVRegisterInfo * getRegisterInfo() const override
const RegisterBankInfo * getRegBankInfo() const override
unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const override
Return the number of registers that this ValueType will eventually require.
unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain targets require unusual breakdowns of certain types.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
void finalizeLowering(MachineFunction &MF) const override
Execute target specific actions to finalize target lowering.
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
StringRef - Represent a constant reference to a string, i.e.
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
static TargetExtType * get(LLVMContext &Context, StringRef Name, ArrayRef< Type * > Types={}, ArrayRef< unsigned > Ints={})
Return a target extension type having the specified name and optional type and integer parameters.
virtual void finalizeLowering(MachineFunction &MF) const
Execute target specific actions to finalize target lowering.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
The instances of the Type class are immutable: once they are created, they are never changed.
static IntegerType * getInt8Ty(LLVMContext &C)
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
This is an optimization pass for GlobalISel generic memory operations.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF)
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
This struct is a compact representation of a valid (non-zero power of two) alignment.
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
bool isVector() const
Return true if this is a vector value type.
EVT getVectorElementType() const
Given a vector type, return the type of each element.
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
bool isInteger() const
Return true if this is an integer or a vector integer type.