//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//    * Improve cost-modeling, e.g. choose different number of rows/columns
//      for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

using namespace PatternMatch;
#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));
/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}
/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}
/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0     1     2     3
// 0  v_0_0 v_0_1 v_0_2 v_0_3
// 1  v_1_0 v_1_1 v_1_2 v_1_3
// 2  v_2_0 v_2_1 v_2_2 v_2_3
// 3  v_3_0 v_3_1 v_3_2 v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//        v_0_0  v_0_1 {v_0_2 {v_0_3
//               Col 1  Col 2
//                 |      |
//        v_1_0 |v_1_1 |v_1_2 |v_1_3
//        v_2_0 |v_2_1 |v_2_2 |v_2_3
//        v_3_0 {v_3_1 {v_3_2  v_3_3
//
static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                                unsigned NumElements, Type *EltType,
                                IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}
/// Holds the shape (number of rows and columns) of a matrix value.
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  bool operator==(const ShapeInfo &other) {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  unsigned getStride() const { return IsColumnMajor ? NumRows : NumColumns; }

  unsigned getNumVectors() const {
    return IsColumnMajor ? NumColumns : NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
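// Worked example (added for illustration): with the default column-major
// layout, ShapeInfo(4, 2) describes a 4x2 matrix stored as 2 column vectors
// of 4 elements each, so getStride() == 4 and getNumVectors() == 2, while
// t() yields the transposed ShapeInfo(2, 4).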
static bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;

  switch (I->getOpcode()) {
  case Instruction::FAdd:
  case Instruction::FSub:
  case Instruction::FMul: // Scalar multiply.
  case Instruction::FNeg:
  case Instruction::Add:
  case Instruction::Mul:
  case Instruction::Sub:
    return true;
  default:
    return false;
  }
}
/// Return the ShapeInfo for the result of \p I, if it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const DenseMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N))))
    // Flip dimensions.
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }

  if (isUniformShape(I)) {
    // Find the first operand that has a known shape and use that.
    for (auto &Op : I->operands()) {
      auto OpShape = ShapeMap.find(Op.get());
      if (OpShape != ShapeMap.end())
        return OpShape->second;
    }
  }
  return std::nullopt;
}
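// Example (illustration only, added commentary): for a call such as
//   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v6f64.v6f64(
//            <6 x double> %a, <6 x double> %b, i32 2, i32 3, i32 2)
// the three shape arguments 2, 3, 2 are matched above and the result shape
// is ShapeInfo(2, 2), i.e. a 2x2 matrix embedded in the flat <4 x double>.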
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major layout).
///    The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
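  // Illustrative example of the overall flow (added commentary, not upstream
  // code): for a chain like
  //   %t = call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %m, i32 2, i32 2)
  //   store <4 x double> %t, ptr %p
  // step 1 records a 2x2 shape for %t and for the store, and step 2 lowers
  // both directly on the two column vectors, so the flat <4 x double> value
  // never needs to be re-materialized between them.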
  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications. This is the number of
    /// transposes that we failed to make "free" via such optimizations.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
341 /// Wrapper class representing a matrix as a set of vectors, either in row or 342 /// column major layout. All vectors must have the same vector type. 348bool IsColumnMajor =
true;
355 MatrixTy(
unsigned NumRows,
unsigned NumColumns,
Type *EltTy)
358unsignedD = isColumnMajor() ? NumColumns : NumRows;
359for (
unsigned J = 0; J <
D; ++J)
361 EltTy, isColumnMajor() ? NumRows : NumColumns)));
364Value *getVector(
unsigned i)
const{
return Vectors[i]; }
365Value *getColumn(
unsigned i)
const{
366assert(isColumnMajor() &&
"only supported for column-major matrixes");
369Value *getRow(
unsigned i)
const{
370assert(!isColumnMajor() &&
"only supported for row-major matrixes");
374void setVector(
unsigned i,
Value *V) { Vectors[i] =
V; }
376Type *getElementType()
const{
return getVectorTy()->getElementType(); }
378unsigned getNumVectors()
const{
380return getNumColumns();
384unsigned getNumColumns()
const{
386return Vectors.
size();
388assert(Vectors.
size() > 0 &&
"Cannot call getNumRows without columns");
389return cast<FixedVectorType>(Vectors[0]->
getType())->getNumElements();
392unsigned getNumRows()
const{
393if (isColumnMajor()) {
394assert(Vectors.
size() > 0 &&
"Cannot call getNumRows without columns");
395return cast<FixedVectorType>(Vectors[0]->
getType())->getNumElements();
397return Vectors.
size();
402assert(isColumnMajor() &&
"only supported for column-major matrixes");
407return cast<VectorType>(Vectors[0]->
getType());
412"columns() only supported for column-major matrixes");
420 /// Embed the vectors of the matrix into a flat vector by concatenating 423return Vectors.
size() == 1 ? Vectors[0]
427 MatrixTy &addNumLoads(
unsignedN) {
432void setNumLoads(
unsignedN) { OpInfo.NumLoads =
N; }
434 MatrixTy &addNumStores(
unsignedN) {
435 OpInfo.NumStores +=
N;
439 MatrixTy &addNumExposedTransposes(
unsignedN) {
440 OpInfo.NumExposedTransposes +=
N;
444 MatrixTy &addNumComputeOps(
unsignedN) {
445 OpInfo.NumComputeOps +=
N;
449unsigned getNumStores()
const{
return OpInfo.NumStores; }
450unsigned getNumLoads()
const{
return OpInfo.NumLoads; }
451unsigned getNumComputeOps()
const{
return OpInfo.NumComputeOps; }
453const OpInfoTy &getOpInfo()
const{
return OpInfo; }
455bool isColumnMajor()
const{
return IsColumnMajor; }
457unsigned getStride()
const{
460return getNumColumns();
463 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 464 /// matrix is column-major, the result vector is extracted from a column 465 /// vector, otherwise from a row vector. 468Value *Vec = isColumnMajor() ? getColumn(J) : getRow(
I);
469assert(cast<FixedVectorType>(Vec->
getType())->getNumElements() >=
471"Extracted vector will contain poison values");
478 /// Maps instructions to their shape information. The shape information 479 /// describes the shape to be used while lowering. This matches the shape of 480 /// the result value of the instruction, with the only exceptions being store 481 /// instructions and the matrix_column_major_store intrinsics. For those, the 482 /// shape information indicates that those instructions should be lowered 483 /// using shape information as well. Note that extra care is needed when 484 /// erasing or RAUW'ing a value that is present in ShapeMap. If the 485 /// replacement is also a matrix operation, use 486 /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to 487 /// ShapeMap. We don't use ValueMap, as there are also cases where we do not 488 /// want to add shape information for a replacement instruction. When directly 489 /// erasing a value with an entry in ShapeMap, use 490 /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated 494 /// List of instructions to remove. While lowering, we are not replacing all 495 /// users of a lowered instruction, if shape information is available and 496 /// those need to be removed after we finished lowering. 499 /// Map from instructions to their produced column matrix. 506if (isa<FPMathOperator>(*Inst))
519unsigned getNumOps(
Type *VT) {
520assert(isa<VectorType>(VT) &&
"Expected vector type");
522 cast<FixedVectorType>(VT)->getNumElements());
525 /// Is this the minimal version executed in the backend pipelines. 526bool isMinimal()
const{
530 /// Return the estimated number of vector ops required for an operation on 532unsigned getNumOps(
Type *ST,
unsignedN) {
533return std::ceil((
ST->getPrimitiveSizeInBits() *
N).getFixedValue() /
539 /// Return the set of vectors that a matrix value is lowered to. 541 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 542 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 544 MatrixTy getMatrix(
Value *MatrixVal,
const ShapeInfo &SI,
547assert(VType &&
"MatrixVal must be a vector type");
549SI.NumRows *
SI.NumColumns &&
550"The vector size must match the number of matrix elements");
552// Check if we lowered MatrixVal using shape information. In that case, 553// return the existing matrix, if it matches the requested shape 554// information. If there is a mis-match, embed the result in a flat 555// vector and split it later. 556auto Found = Inst2ColumnMatrix.
find(MatrixVal);
557if (Found != Inst2ColumnMatrix.
end()) {
558 MatrixTy &
M = Found->second;
559// Return the found matrix, if its shape matches the requested shape 561if (
SI.NumRows ==
M.getNumRows() &&
SI.NumColumns ==
M.getNumColumns())
564 MatrixVal =
M.embedInVector(Builder);
567// Otherwise split MatrixVal. 569for (
unsigned MaskStart = 0;
570 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
571 MaskStart +=
SI.getStride()) {
581 /// If \p V already has a known shape return false. Otherwise set the shape 582 /// for instructions that support it. 583bool setShapeInfo(
Value *V, ShapeInfo Shape) {
584assert(Shape &&
"Shape not set");
585if (isa<UndefValue>(V) || !supportsShapeInfo(V))
588auto SIter = ShapeMap.
find(V);
589if (SIter != ShapeMap.
end()) {
591 SIter->second.NumColumns != Shape.NumColumns)) {
592errs() <<
"Conflicting shapes (" << SIter->second.NumRows <<
"x" 593 << SIter->second.NumColumns <<
" vs " << Shape.NumRows <<
"x" 594 << Shape.NumColumns <<
") for " << *
V <<
"\n";
596"Matrix shape verification failed, compilation aborted!");
600 << SIter->second.NumRows <<
" " 601 << SIter->second.NumColumns <<
" for " << *V <<
"\n");
606LLVM_DEBUG(
dbgs() <<
" " << Shape.NumRows <<
" x " << Shape.NumColumns
607 <<
" for " << *V <<
"\n");
611 /// Returns true if shape information can be used for \p V. The supported 612 /// instructions must match the instructions that can be lowered by this pass. 613bool supportsShapeInfo(
Value *V) {
620switch (
II->getIntrinsicID()) {
621case Intrinsic::matrix_multiply:
622case Intrinsic::matrix_transpose:
623case Intrinsic::matrix_column_major_load:
624case Intrinsic::matrix_column_major_store:
629return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
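  // Example of forward propagation (illustrative comment only): given
  //   %a = call <6 x double> @llvm.matrix.column.major.load(...)  ; 2 x 3
  //   %b = fadd <6 x double> %a, %c
  // the load seeds the work list with a 2x3 shape, and because fadd is a
  // uniform-shape op, %b (and in turn its users) inherit that 2x3 shape.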
632 /// Propagate the shape information of instructions to their users. 633 /// The work list contains instructions for which we can compute the shape, 634 /// either based on the information provided by matrix intrinsics or known 635 /// shapes of operands. 639// Pop an element for which we guaranteed to have at least one of the 640// operand shapes. Add the shape for this and then add users to the work 643while (!WorkList.
empty()) {
646// New entry, set the value and insert operands 647bool Propagate =
false;
648if (
auto SI = computeShapeInfoForInst(Inst, ShapeMap))
649 Propagate = setShapeInfo(Inst, *SI);
662 /// Propagate the shape to operands of instructions with shape information. 663 /// \p Worklist contains the instruction for which we already know the shape. 668auto pushInstruction = [](
Value *
V,
674// Pop an element with known shape. Traverse the operands, if their shape 675// derives from the result shape and is unknown, add it and add them to the 678while (!WorkList.
empty()) {
681size_t BeforeProcessingV = WorkList.
size();
682if (!isa<Instruction>(V))
690if (
match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
693if (setShapeInfo(MatrixA, {
M,
N}))
694 pushInstruction(MatrixA, WorkList);
696if (setShapeInfo(MatrixB, {
N,
K}))
697 pushInstruction(MatrixB, WorkList);
699 }
elseif (
match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
702if (setShapeInfo(MatrixA, {
M,
N}))
703 pushInstruction(MatrixA, WorkList);
704 }
elseif (
match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
707if (setShapeInfo(MatrixA, {
M,
N})) {
708 pushInstruction(MatrixA, WorkList);
710 }
elseif (isa<LoadInst>(V) ||
711match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
712// Nothing to do, no matrix input. 713 }
elseif (isa<StoreInst>(V)) {
714// Nothing to do. We forward-propagated to this so we would just 715// backward propagate to an instruction with an already known shape. 716 }
elseif (isUniformShape(V)) {
717// Propagate to all operands. 718 ShapeInfo Shape = ShapeMap[
V];
719for (
Use &U : cast<Instruction>(V)->operands()) {
720if (setShapeInfo(
U.get(), Shape))
721 pushInstruction(
U.get(), WorkList);
724// After we discovered new shape info for new instructions in the 725// worklist, we use their users as seeds for the next round of forward 727for (
size_tI = BeforeProcessingV;
I != WorkList.
size();
I++)
729if (isa<Instruction>(U) && V != U)
730 NewWorkList.
push_back(cast<Instruction>(U));
735 /// (Op0 op Op1)^T -> Op0^T op Op1^T 736 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use 737 /// them on both sides of \p Operation. 739Value *Op0, ShapeInfo Shape0,
Value *Op1, ShapeInfo Shape1,
744 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->
getName() +
"_t");
745// We are being run after shape prop, add shape for newly created 746// instructions so that we lower them later. 747 setShapeInfo(T0, Shape0.t());
749 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->
getName() +
"_t");
750 setShapeInfo(T1, Shape1.t());
751returnOperation(T0, Shape0.t(), T1, Shape1.t());
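  // For illustration (added commentary, not upstream code): sinking a
  // transpose through a multiply uses this helper as
  //   (A * B)^T  ==>  B^T * A^T
  // where A is RxK and B is KxC, so the callback receives B^T with shape CxK
  // and A^T with shape KxR and builds the new multiply in transposed form.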
754 /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst 756void eraseFromParentAndRemoveFromShapeMap(
Instruction *Inst) {
757auto Iter = ShapeMap.
find(Inst);
758if (Iter != ShapeMap.
end())
759 ShapeMap.
erase(Iter);
763 /// Erase \p V from \p BB and move \II forward to avoid invalidating 767auto *Inst = cast<Instruction>(V);
768// Still used, don't erase. 773 eraseFromParentAndRemoveFromShapeMap(Inst);
  /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
  /// entry for \p Old and replace all uses of \p Old with \p New.
  // We need to remove Old from the ShapeMap otherwise RAUW will replace it
  // with New. We should only add New if it supportsShapeInfo so we insert
  // it conditionally instead.
  auto S = ShapeMap.
find(&Old);
783if (S != ShapeMap.
end()) {
785if (supportsShapeInfo(New))
791 /// Sink a top-level transpose inside matmuls and adds. 792 /// This creates and erases instructions as needed, and returns the newly 793 /// created instruction while updating the iterator to avoid invalidation. If 794 /// this returns nullptr, no new instruction was created. 802if (!
match(&
I, m_Intrinsic<Intrinsic::matrix_transpose>(
806// Transpose of a transpose is a nop 808if (
match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(TATA)))) {
809 updateShapeAndReplaceAllUsesWith(
I, TATA);
810 eraseFromParentAndMove(&
I,
II, BB);
811 eraseFromParentAndMove(TA,
II, BB);
817 updateShapeAndReplaceAllUsesWith(
I, TA);
818 eraseFromParentAndMove(&
I,
II, BB);
822// (A * B)^t -> B^t * A^t 824if (
match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
827auto NewInst = distributeTransposes(
828 TAMB, {
K,
C}, TAMA, {
R,
K}, Builder,
829 [&](
Value *T0, ShapeInfo Shape0,
Value *
T1, ShapeInfo Shape1) {
832 Shape1.NumColumns,
"mmul");
834 updateShapeAndReplaceAllUsesWith(
I, NewInst);
835 eraseFromParentAndMove(&
I,
II, BB);
836 eraseFromParentAndMove(TA,
II, BB);
  // Same as above, but with a mul, which occurs when multiplied
  // with a scalar.
  // (A * k)^t -> A^t * k
  // We know that the transposed operand is of shape RxC.
  // And when multiplied with a scalar, the shape is preserved.
  auto NewInst = distributeTransposes(
850 TAMA, {
R,
C}, TAMB, {
R,
C}, Builder,
851 [&](
Value *T0, ShapeInfo Shape0,
Value *
T1, ShapeInfo Shape1) {
852bool IsFP =
I.getType()->isFPOrFPVectorTy();
853auto *
Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1,
"mmul")
854 : LocalBuilder.CreateMul(T0, T1,
"mmul");
856 setShapeInfo(Result, Shape0);
859 updateShapeAndReplaceAllUsesWith(
I, NewInst);
860 eraseFromParentAndMove(&
I,
II, BB);
861 eraseFromParentAndMove(TA,
II, BB);
865// (A + B)^t -> A^t + B^t 869auto NewInst = distributeTransposes(
870 TAMA, {
R,
C}, TAMB, {
R,
C}, Builder,
871 [&](
Value *T0, ShapeInfo Shape0,
Value *
T1, ShapeInfo Shape1) {
872bool IsFP =
I.getType()->isFPOrFPVectorTy();
873auto *
Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1,
"madd")
874 : LocalBuilder.CreateAdd(T0, T1,
"madd");
877 setShapeInfo(Result, Shape0);
880 updateShapeAndReplaceAllUsesWith(
I, NewInst);
881 eraseFromParentAndMove(&
I,
II, BB);
882 eraseFromParentAndMove(TA,
II, BB);
890// Erase dead Instructions after lifting transposes from binops. 893 eraseFromParentAndRemoveFromShapeMap(&
T);
895 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(
A));
896if (
A !=
B &&
B->use_empty())
897 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(
B));
902// A^t * B ^t -> (B * A)^t 903if (
match(&
I, m_Intrinsic<Intrinsic::matrix_multiply>(
906match(
A, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(AT))) &&
911BT, AT,
C->getZExtValue(),
K->getZExtValue(),
R->getZExtValue());
912 setShapeInfo(M, {
C,
R});
915 updateShapeAndReplaceAllUsesWith(
I, NewInst);
918// A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If 919// the shape of the second transpose is different, there's a shape conflict 920// which gets resolved by picking the shape of the first operand. 922match(
A, m_Intrinsic<Intrinsic::matrix_transpose>(
924match(
B, m_Intrinsic<Intrinsic::matrix_transpose>(
927auto *
Add = Builder.CreateFAdd(AT,
BT,
"mfadd");
929Instruction *NewInst = MBuilder.CreateMatrixTranspose(
930Add,
R->getZExtValue(),
C->getZExtValue(),
"mfadd_t");
931 updateShapeAndReplaceAllUsesWith(
I, NewInst);
932assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
933 computeShapeInfoForInst(&
I, ShapeMap) &&
934"Shape of new instruction doesn't match original shape.");
936if (
auto *AddI = dyn_cast<Instruction>(
Add)) {
937 setShapeInfo(AddI, {
R,
C});
939 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
941"Shape of updated addition doesn't match cached shape.");
946 /// Try moving transposes in order to fold them away or into multiplies. 947void optimizeTransposes() {
948// First sink all transposes inside matmuls and adds, hoping that we end up 949// with NN, NT or TN variants. 953// We may remove II. By default continue on the next/prev instruction. 960// If we have a TT matmul or a TT add, lift the transpose. We may be able 961// to fold into consuming multiply or add. 972// Initially only the shape of matrix intrinsics is known. 973// Initialize the work list with ops carrying shape information. 980switch (
II->getIntrinsicID()) {
981case Intrinsic::matrix_multiply:
982case Intrinsic::matrix_transpose:
983case Intrinsic::matrix_column_major_load:
984case Intrinsic::matrix_column_major_store:
992// Avoid unnecessary work if there are no matrix intrinsics in the function. 1003// Propagate shapes until nothing changes any longer. 1004while (!WorkList.
empty()) {
1005 WorkList = propagateShapeForward(WorkList);
1006 WorkList = propagateShapeBackward(WorkList);
1010 optimizeTransposes();
1012dbgs() <<
"Dump after matrix transpose optimization:\n";
1022// First, collect all instructions with shape information and candidates for 1023// fusion (currently only matrix multiplies). 1025for (
auto *BB : RPOT)
1027if (
match(&
I, m_Intrinsic<Intrinsic::lifetime_end>()))
1028 LifetimeEnds.
push_back(cast<IntrinsicInst>(&
I));
1029if (ShapeMap.
find(&
I) == ShapeMap.
end())
1031if (
match(&
I, m_Intrinsic<Intrinsic::matrix_multiply>()))
1032 MaybeFusableInsts.
push_back(cast<CallInst>(&
I));
1036// Second, try to lower any dot products 1038for (
CallInst *CI : MaybeFusableInsts)
1039 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1041// Third, try to fuse candidates. 1042for (
CallInst *CI : MaybeFusableInsts)
1044 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
1046 Changed = !FusedInsts.
empty();
1048// Fourth, lower remaining instructions with shape information. 1050if (FusedInsts.
count(Inst))
1055if (
CallInst *CInst = dyn_cast<CallInst>(Inst))
1056 Changed |= VisitCallInst(CInst);
1060if (
auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1061 Changed |= VisitBinaryOperator(BinOp);
1062if (
auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1063 Changed |= VisitUnaryOperator(UnOp);
1065 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1067 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1071 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1072 RemarkGen.emitRemarks();
1075// Delete the instructions backwards, as it has a reduced likelihood of 1076// having to update as many def-use and use-def chains. 1078// Because we add to ToRemove during fusion we can't guarantee that defs 1079// are before uses. Change uses to poison temporarily as these should get 1082// For verification, we keep track of where we changed uses to poison in 1083// PoisonedInsts and then check that we in fact remove them. 1085for (
auto *Inst :
reverse(ToRemove)) {
1087if (
auto *Poisoned = dyn_cast<Instruction>(
U.getUser()))
1088 PoisonedInsts.
insert(Poisoned);
1092 PoisonedInsts.
erase(Inst);
1094if (!PoisonedInsts.
empty()) {
1095// If we didn't remove all poisoned instructions, it's a hard error. 1096dbgs() <<
"Poisoned but present instructions:\n";
1097for (
auto *
I : PoisonedInsts)
1105 /// Replace intrinsic calls 1106bool VisitCallInst(
CallInst *Inst) {
1111case Intrinsic::matrix_multiply:
1112 LowerMultiply(Inst);
1114case Intrinsic::matrix_transpose:
1115 LowerTranspose(Inst);
1117case Intrinsic::matrix_column_major_load:
1118 LowerColumnMajorLoad(Inst);
1120case Intrinsic::matrix_column_major_store:
1121 LowerColumnMajorStore(Inst);
1129 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 1130 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 1131 /// ConstantInt, reduce the initial alignment based on the byte offset. For 1132 /// non-ConstantInt strides, return the common alignment of the initial 1133 /// alignment and the element size in bytes. 1136Align InitialAlign =
DL.getValueOrABITypeAlignment(
A, ElementTy);
1140TypeSize ElementSizeInBits =
DL.getTypeSizeInBits(ElementTy);
1141if (
auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1143 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1149 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 1152bool IsVolatile, ShapeInfo Shape,
IRBuilder<> &Builder) {
1153auto *VType = cast<VectorType>(Ty);
1154Type *EltTy = VType->getElementType();
1158for (
unsignedI = 0, E = Shape.getNumVectors();
I < E; ++
I) {
1161 Stride, Shape.getStride(), EltTy, Builder);
1163 VecTy,
GEP, getAlignForIndex(
I, Stride, EltTy, MAlign),
1164 IsVolatile,
"col.load");
1168returnResult.addNumLoads(getNumOps(
Result.getVectorTy()) *
1172 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 1173 /// starting at \p MatrixPtr[I][J]. 1176 ShapeInfo ResultShape,
Type *EltTy,
1184 ResultShape.NumColumns);
1186return loadMatrix(TileTy, TileStart,
Align,
1187 Builder.
getInt64(MatrixShape.getStride()), IsVolatile,
1188 ResultShape, Builder);
1191 /// Lower a load instruction with shape information. 1193bool IsVolatile, ShapeInfo Shape) {
1195 finalizeLowering(Inst,
1201 /// Lowers llvm.matrix.column.major.load. 1203 /// The intrinsic loads a matrix from memory using a stride between columns. 1204void LowerColumnMajorLoad(
CallInst *Inst) {
1206"Intrinsic only supports column-major layout!");
1211 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1214 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1215 /// MatrixPtr[I][J]. 1216void storeMatrix(
const MatrixTy &StoreVal,
Value *MatrixPtr,
1217MaybeAlign MAlign,
bool IsVolatile, ShapeInfo MatrixShape,
1224 StoreVal.getNumColumns());
1226 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1227 Builder.
getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1230 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1232 MatrixTy storeMatrix(
Type *Ty, MatrixTy StoreVal,
Value *
Ptr,
1235auto VType = cast<VectorType>(Ty);
1237for (
auto Vec :
enumerate(StoreVal.vectors())) {
1242 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1244 getAlignForIndex(Vec.index(), Stride,
1245 VType->getElementType(),
1249return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1250 StoreVal.getNumVectors());
1253 /// Lower a store instruction with shape information. 1255Value *Stride,
bool IsVolatile, ShapeInfo Shape) {
1257auto StoreVal = getMatrix(
Matrix, Shape, Builder);
1258 finalizeLowering(Inst,
1259 storeMatrix(
Matrix->getType(), StoreVal,
Ptr,
A, Stride,
1260 IsVolatile, Builder),
  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(
CallInst *Inst) {
1269"Intrinsic only supports column-major layout!");
1275 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1278// Set elements I..I+NumElts-1 to Block 1282// First, bring Block to the same size as Col 1283unsigned BlockNumElts =
1284 cast<FixedVectorType>(
Block->getType())->getNumElements();
1285unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1286assert(NumElts >= BlockNumElts &&
"Too few elements for current block");
1291// If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1295for (i = 0; i <
I; i++)
1298unsigned VecNumElts =
1299 cast<FixedVectorType>(Col->getType())->getNumElements();
1300for (; i <
I + BlockNumElts; i++)
1301Mask.push_back(i -
I + VecNumElts);
1303for (; i < VecNumElts; i++)
1311unsigned &NumComputeOps) {
1312 NumComputeOps += getNumOps(
A->getType());
1317if (AllowContraction) {
1318// Use fmuladd for floating point operations and let the backend decide 1319// if that's profitable. 1323 NumComputeOps += getNumOps(
A->getType());
1328 NumComputeOps += getNumOps(
A->getType());
1333 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1334 /// users with shape information, there's nothing to do: they will use the 1335 /// cached value when they are lowered. For other users, \p Matrix is 1336 /// flattened and the uses are updated to use it. Also marks \p Inst for 1340auto inserted = Inst2ColumnMatrix.
insert(std::make_pair(Inst,
Matrix));
1342assert(inserted.second &&
"multiple matrix lowering mapping");
1345Value *Flattened =
nullptr;
1347if (ShapeMap.
find(
U.getUser()) == ShapeMap.
end()) {
1349 Flattened =
Matrix.embedInVector(Builder);
  /// Special case for MatMul lowering. Prevents scalar loads of row-major
  /// vectors. Lowers to vector reduction add instead of sequential add if
  /// reassociation is enabled.
  void lowerDotProduct(
CallInst *MatMul,
1367if (LShape.NumRows != 1 || RShape.NumColumns != 1)
// not a dot product
  // Floating point reductions require reassociation.
  auto CanBeFlattened = [](
Value *
Op) {
1386m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1387 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1390// Returns the cost benefit of using \p Op with the dot product lowering. If 1391// the returned cost is < 0, the argument is cheaper to use in the 1392// dot-product lowering. 1393auto GetCostForArg = [
this, &CanBeFlattened](
Value *
Op,
unsignedN) {
1397if (!isa<Instruction>(
Op))
1403if (!CanBeFlattened(
Op)) {
1405// Roughly estimate the cost for embedding the columns into a vector. 1406for (
unsignedI = 1;
I <
N; ++
I)
1420return NewCost - OriginalCost;
1423if (
match(
Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1424// The transpose can be skipped for the dot product lowering, roughly 1425// estimate the savings as the cost of embedding the columns in a 1428for (
unsignedI = 1;
I <
N; ++
I)
1443// Iterate over LHS and operations feeding LHS and check if it is profitable 1444// to flatten the visited ops. For each op, we compute the difference 1445// between the flattened and matrix versions. 1451while (!WorkList.
empty()) {
1457if (OpCost + LHSCost >= LHSCost)
1462if (
auto *
I = dyn_cast<Instruction>(
Op))
1463 WorkList.
append(
I->op_begin(),
I->op_end());
1466// We compare the costs of a vector.reduce.add to sequential add. 1467int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1468int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1472 IsIntVec ? std::nullopt : std::optional(FMF)) +
1476 (LShape.NumColumns - 1) +
1478 (LShape.NumColumns);
1482 FusedInsts.
insert(MatMul);
1484auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1486// Matmul must be the only user of loads because we don't use LowerLoad 1487// for row vectors (LowerLoad results in scalar loads and shufflevectors 1488// instead of single vector load). 1489if (!CanBeFlattened(
Op))
1493auto It = ShapeMap.
find(
Op);
1494if (It != ShapeMap.
end()) {
1495 It->second = It->second.t();
1500 FusedInsts.insert(cast<Instruction>(
Op));
1501// If vector uses the builtin load, lower to a LoadInst 1503if (
match(
Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1506Op->replaceAllUsesWith(NewLoad);
1507 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(
Op));
1509 }
elseif (
match(
Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1512Op->replaceAllUsesWith(Arg);
1517for (
auto *V : ToFlatten)
1522// Insert mul/fmul and llvm.vector.reduce.fadd 1531 ConstantFP::get(cast<VectorType>(
LHS->
getType())->getElementType(),
1537// pack scalar back into a matrix and then replace matmul inst 1541 FusedInsts.insert(MatMul);
1545 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1548 /// We can fold a transpose into the operand that is used to extract scalars. 1549 /// This is the first operands with row-major and the second with 1550 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1551 /// operand is transposed. 1552void emitMatrixMultiply(MatrixTy &Result,
const MatrixTy &
A,
1555constunsigned VF = std::max<unsigned>(
1558Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1560unsignedR =
Result.getNumRows();
1561unsignedC =
Result.getNumColumns();
1562unsignedM =
A.getNumColumns();
1564bool IsFP =
Result.getElementType()->isFloatingPointTy();
1565assert(
A.isColumnMajor() ==
B.isColumnMajor() &&
1566Result.isColumnMajor() ==
A.isColumnMajor() &&
1567"operands must agree on matrix layout");
1568unsigned NumComputeOps = 0;
1572if (
A.isColumnMajor()) {
1573// Multiply columns from the first operand with scalars from the second 1574// operand. Then move along the K axes and accumulate the columns. With 1575// this the adds can be vectorized without reassociation. 1576for (
unsigned J = 0; J <
C; ++J) {
1578// If Result is zero, we don't need to accumulate in the K==0 iteration. 1579bool isSumZero = isa<ConstantAggregateZero>(
Result.getColumn(J));
1582// Gradually lower the vectorization factor to cover the remainder. 1588for (
unsigned K = 0;
K <
M; ++
K) {
1591B.getColumn(IsScalarMatrixTransposed ? K : J),
1592 IsScalarMatrixTransposed ? J : K);
1595 createMulAdd(isSumZero && K == 0 ?
nullptr : Sum, L,
Splat,
1603// Multiply rows from the second operand with scalars from the first 1604// operand. Then move along the K axes and accumulate the rows. With this 1605// the adds can be vectorized without reassociation. 1606for (
unsignedI = 0;
I <
R; ++
I) {
1608bool isSumZero = isa<ConstantAggregateZero>(
Result.getRow(
I));
1610// Gradually lower the vectorization factor to cover the remainder. 1615for (
unsigned K = 0;
K <
M; ++
K) {
1618A.getVector(IsScalarMatrixTransposed ? K :
I),
1619 IsScalarMatrixTransposed ?
I : K);
1622 createMulAdd(isSumZero && K == 0 ?
nullptr : Sum,
Splat, R,
1630Result.addNumComputeOps(NumComputeOps);
1633 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1634 /// copying it to a new location. This new or otherwise the original location 1641// If we can statically determine noalias we're good. 1643returnLoad->getPointerOperand();
1645// Create code to check if the memory locations of the Load and Store 1646// overlap and if they do, copy Load's operand to a new buffer. 1648// First, create new blocks for 2n part of the check and the copy. 1650// FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1651// DT. Manually collect dominator tree updates, to avoid unnecessary work, 1652// as we adjust Check0 and Check1's branches. 1659nullptr,
"alias_cont");
1667// Check if the loaded memory location begins before the end of the store 1668// location. If the condition holds, they might overlap, otherwise they are 1669// guaranteed to not overlap. 1675const_cast<Value *
>(StoreLoc.
Ptr), IntPtrTy,
"store.begin");
1677 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.
Size.
getValue()),
1678"store.end",
true,
true);
1680 IntPtrTy,
"load.begin");
1684// Check if the store begins before the end of the load location. If the 1685// condition holds, they alias, otherwise they are guaranteed to not 1690 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.
Size.
getValue()),
1691"load.end",
true,
true);
1695// Copy load operand to new alloca. 1697auto *VT = cast<FixedVectorType>(
Load->getType());
1698// Use an array type for the alloca, to avoid potentially huge alignment 1699// requirements for large vector types. 1700auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1708PHI->addIncoming(
Load->getPointerOperand(), Check0);
1709PHI->addIncoming(
Load->getPointerOperand(), Check1);
1710PHI->addIncoming(Alloca, Copy);
1721bool isFusionProfitable(
CallInst *MatMul) {
1728constunsignedR = LShape.NumRows;
1729constunsignedC = RShape.NumColumns;
1730constunsignedM = LShape.NumColumns;
1731auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1733constunsigned VF = std::max<unsigned>(
1739// Cost model for tiling 1741// For tiling to be beneficial, we need reuse either along the R or 1742// the C axis. We vectorize along the R axis so that means at least 1744// TODO: Also consider cost of copying if operands alias. 1745if (R <= VF &&
C == 1)
1747// Then we need enough elements to exceed the number of vector 1748// registers we have. Note that this is an oversimplification since 1749// fusing also takes some extra loads which may exceed the number of 1750// reloads necessary. 1751unsigned Op0Regs = (
R + VF - 1) / VF * M;
1752unsigned Op1Regs = (
M + VF - 1) / VF *
C;
1753return Op0Regs + Op1Regs >
1757 MatrixTy getZeroMatrix(
Type *EltType,
unsigned R,
unsignedC) {
1760for (
unsignedI = 0;
I <
C; ++
I)
1765void createTiledLoops(
CallInst *MatMul,
Value *LPtr, ShapeInfo LShape,
1767auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1769// Create the main tiling loop nest. 1777BasicBlock *InnerBody = TI.CreateTiledLoops(Start,
End, Builder, DTU, *LI);
1781 MatrixTy TileResult;
1782// Insert in the inner loop header. 1784// Create PHI nodes for the result columns to accumulate across iterations. 1789 TI.RowLoop.Header->getSingleSuccessor());
1790 TileResult.addVector(Phi);
1794// Insert in the inner loop body, which computes 1795// Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1797// Load tiles of the operands. 1799 loadMatrix(LPtr, {},
false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1802 loadMatrix(RPtr, {},
false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1804 emitMatrixMultiply(TileResult,
A,
B, Builder,
true,
false,
1805 getFastMathFlags(MatMul));
1806// Store result after the inner loop is done. 1808 storeMatrix(TileResult,
Store->getPointerOperand(),
Store->getAlign(),
1809Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1810 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1812for (
unsignedI = 0;
I < TileResult.getNumVectors();
I++)
1813 ColumnPhis[
I]->addIncoming(TileResult.getVector(
I), TI.KLoop.Latch);
1815// Force unrolling of a few iterations of the inner loop, to make sure there 1816// is enough work per iteration. 1817// FIXME: The unroller should make this decision directly instead, but 1818// currently the cost-model is not up to the task. 1819unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns /
TileSize);
1821"llvm.loop.unroll.count", InnerLoopUnrollCount);
1828"Tiling only supported for column-major matrixes at the moment!");
1829if (!isFusionProfitable(MatMul))
1835constunsignedR = LShape.NumRows;
1836constunsignedC = RShape.NumColumns;
1837constunsignedM = LShape.NumColumns;
1838auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1840Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1841Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1845 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1850constunsigned TileR = std::min(R -
I,
unsigned(
TileSize));
1851constunsigned TileC = std::min(
C - J,
unsigned(
TileSize));
1852 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1855constunsigned TileM = std::min(M - K,
unsigned(
TileSize));
1859 {TileR, TileM}, EltType, Builder);
1863 {TileM, TileC}, EltType, Builder);
1864 emitMatrixMultiply(Res,
A,
B, Builder,
true,
false,
1865 getFastMathFlags(MatMul));
1867 storeMatrix(Res, CPtr,
Store->getAlign(),
Store->isVolatile(), {R, M},
1873// Mark eliminated instructions as fused and remove them. 1874 FusedInsts.
insert(Store);
1875 FusedInsts.
insert(MatMul);
1876 eraseFromParentAndRemoveFromShapeMap(Store);
1877 eraseFromParentAndRemoveFromShapeMap(MatMul);
1879 FusedInsts.
insert(LoadOp0);
1880 eraseFromParentAndRemoveFromShapeMap(LoadOp0);
1882if (LoadOp1 != LoadOp0 && LoadOp1->
hasNUses(0)) {
1883 FusedInsts.
insert(LoadOp1);
1884 eraseFromParentAndRemoveFromShapeMap(LoadOp1);
1888 /// Try to lower matrix multiply chains by fusing operations. 1890 /// Call finalizeLowering on lowered instructions. Instructions that are 1891 /// completely eliminated by fusion are added to \p FusedInsts. 1893 LowerMatrixMultiplyFused(
CallInst *MatMul,
1899assert(AA && LI &&
"Analyses should be available");
1904// We can fold the transpose into the operand that is used to fetch scalars. 1908 :
match(
A, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(
T)))) {
1910auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
1913constunsignedR = LShape.NumRows;
1914constunsignedM = LShape.NumColumns;
1915constunsignedC = RShape.NumColumns;
1922 MA = getMatrix(
A, ShapeInfo(R, M), Builder);
1923 MB = getMatrix(
T, ShapeInfo(
C, M), Builder);
1926 MA = getMatrix(
T, ShapeInfo(R, M), Builder);
1927 MB = getMatrix(
B, ShapeInfo(
C, M), Builder);
1931// Initialize the output 1934 emitMatrixMultiply(Result, MA, MB, Builder,
false,
true,
1935 getFastMathFlags(MatMul));
1937 FusedInsts.
insert(MatMul);
1939 FusedInsts.
insert(cast<Instruction>(Transpose));
1940ToRemove.push_back(cast<Instruction>(Transpose));
1941// TODO: add a fake entry for the folded instruction so that this is 1942// included in the expression in the remark. 1943 Inst2ColumnMatrix[Transpose] = MatrixTy(M,
C, EltType);
1945 finalizeLowering(MatMul, Result, Builder);
1952// Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering 1953// since the single store user will be lowered as part of this. 1954auto *LoadOp0 = dyn_cast<LoadInst>(
A);
1955auto *LoadOp1 = dyn_cast<LoadInst>(
B);
1957if (LoadOp0 && LoadOp1 && Store) {
1958// The store address must dominate the MatMul instruction, otherwise 1959// we create invalid IR. 1963for (
unsignedI = 0;
I != WorkList.
size(); ++
I) {
1964Value *Current = WorkList[
I];
1965auto *CurrI = dyn_cast<Instruction>(Current);
1968if (isa<PHINode>(CurrI))
1972if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1975 WorkList.
insert(CurrI->op_begin(), CurrI->op_end());
1984// Deal with lifetime.end calls that might be between Load0/Load1 and the 1985// store. To avoid introducing loads to dead objects (i.e. after the 1986// lifetime has been termined by @llvm.lifetime.end), either sink them 1987// after the store if in the same block, or remove the lifetime.end marker 1988// otherwise. This might pessimize further optimizations, by extending the 1989// lifetime of the object until the function returns, but should be 1990// conservatively correct. 1994bool FusableOpsInSameBlock = LoadOp0->
getParent() == StoreParent &&
1996for (
unsignedIdx = 0;
Idx != LifetimeEnds.
size();) {
1999// If the lifetime.end is guaranteed to be before the loads or after the 2000// store, it won't interfere with fusion. 2005// If all fusable ops are in the same block and the lifetime.end is in a 2006// different block, it won't interfere with fusion. 2007if (FusableOpsInSameBlock &&
End->getParent() != StoreParent)
2010// If the loads don't alias the lifetime.end, it won't interfere with 2018// If both lifetime.end and the store are in the same block, extend the 2019// lifetime until after the store, so the new lifetime covers the loads 2020// we introduce later. 2021if (
End->getParent() == StoreParent) {
2022End->moveAfter(Store);
2026// Otherwise remove the conflicting lifetime.end marker. 2033 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2038 /// Lowers llvm.matrix.multiply. 2039void LowerMultiply(
CallInst *MatMul) {
2041auto *EltType = cast<VectorType>(MatMul->
getType())->getElementType();
2045const MatrixTy &Lhs = getMatrix(MatMul->
getArgOperand(0), LShape, Builder);
2046const MatrixTy &Rhs = getMatrix(MatMul->
getArgOperand(1), RShape, Builder);
2047assert(Lhs.getElementType() == Rhs.getElementType() &&
2048"Matrix multiply argument element types do not match.");
2050constunsignedR = LShape.NumRows;
2051constunsignedC = RShape.NumColumns;
2052assert(LShape.NumColumns == RShape.NumRows);
2054// Initialize the output 2056assert(Lhs.getElementType() ==
Result.getElementType() &&
2057"Matrix multiply result element type does not match arguments.");
2059 emitMatrixMultiply(Result, Lhs, Rhs, Builder,
false,
false,
2060 getFastMathFlags(MatMul));
2061 finalizeLowering(MatMul, Result, Builder);
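  // Illustrative note (added commentary): for a column-major 2x3 * 3x2
  // multiply, emitMatrixMultiply produces 2 result columns; column J is
  // accumulated over K = 0,1,2 as column K of the LHS scaled by a splat of
  // element (K, J) of the RHS, using fmuladd when contraction is allowed.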
2064 /// Lowers llvm.matrix.transpose. 2065void LowerTranspose(
CallInst *Inst) {
2071 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
2073constunsigned NewNumVecs =
2074 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2075constunsigned NewNumElts =
2076 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
2078for (
unsignedI = 0;
I < NewNumVecs; ++
I) {
2079// Build a single result vector. First initialize it. 2082// Go through the old elements and insert it into the resulting vector. 2083for (
auto J :
enumerate(InputMatrix.vectors())) {
2085// Row and column indices are transposed. 2089Result.addVector(ResultVector);
2092// TODO: Improve estimate of operations needed for transposes. Currently we 2093// just count the insertelement/extractelement instructions, but do not 2094// account for later simplifications/combines. 2097Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2098 .addNumExposedTransposes(1),
2102 /// Lower load instructions, if shape information is available. 2104autoI = ShapeMap.
find(Inst);
2105if (
I == ShapeMap.
end())
2116autoI = ShapeMap.
find(StoredVal);
2117if (
I == ShapeMap.
end())
2126 /// Lower binary operators, if shape information is available. 2128autoI = ShapeMap.
find(Inst);
2129if (
I == ShapeMap.
end())
2136 ShapeInfo &Shape =
I->second;
2139 MatrixTy
A = getMatrix(Lhs, Shape, Builder);
2140 MatrixTy
B = getMatrix(Rhs, Shape, Builder);
2141assert(
A.isColumnMajor() ==
B.isColumnMajor() &&
2142Result.isColumnMajor() ==
A.isColumnMajor() &&
2143"operands must agree on matrix layout");
2147// Helper to perform binary op on vectors. 2150case Instruction::Add:
2152case Instruction::Mul:
2154case Instruction::Sub:
2156case Instruction::FAdd:
2158case Instruction::FMul:
2160case Instruction::FSub:
2167for (
unsignedI = 0;
I < Shape.getNumVectors(); ++
I)
2168Result.addVector(BuildVectorOp(
A.getVector(
I),
B.getVector(
I)));
2170 finalizeLowering(Inst,
2171Result.addNumComputeOps(getNumOps(
Result.getVectorTy()) *
2177 /// Lower unary operators, if shape information is available. 2179autoI = ShapeMap.
find(Inst);
2180if (
I == ShapeMap.
end())
2186 ShapeInfo &Shape =
I->second;
2189 MatrixTy
M = getMatrix(
Op, Shape, Builder);
2193// Helper to perform unary op on vectors. 2194auto BuildVectorOp = [&Builder, Inst](
Value *
Op) {
2196case Instruction::FNeg:
2203for (
unsignedI = 0;
I < Shape.getNumVectors(); ++
I)
2204Result.addVector(BuildVectorOp(
M.getVector(
I)));
2206 finalizeLowering(Inst,
2207Result.addNumComputeOps(getNumOps(
Result.getVectorTy()) *
  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
2217unsigned LengthToBreak = 100;
2220unsigned LineLength = 0;
2223 /// Mapping from instructions to matrixes. It is used to identify 2224 /// matrix instructions. 2227 /// Mapping from values to the leaves of all expressions that the value is 2231 /// Set of matrix expressions in the scope of a given DISubprogram. 2234 /// Leaf node of the expression to linearize. 2237 /// Used to keep track of sub-expressions that get reused while linearizing 2238 /// the expression. Re-used sub-expressions are marked as (reused). 2246 : Stream(Str),
DL(
DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2247 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2251for (
unsigned i = 0; i <
N; i++)
2260void maybeIndent(
unsigned Indent) {
2261if (LineLength >= LengthToBreak)
2269 LineLength += S.
size();
2273Value *getUnderlyingObjectThroughLoads(
Value *V) {
2275return getUnderlyingObjectThroughLoads(
Ptr);
2276elseif (
V->getType()->isPointerTy())
2281 /// Returns true if \p V is a matrix value in the given subprogram. 2282bool isMatrix(
Value *V)
const{
return ExprsInSubprogram.
count(V); }
2284 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to 2287autoM = Inst2Matrix.
find(V);
2288if (M == Inst2Matrix.
end())
2291SS <<
M->second.getNumRows();
2293SS <<
M->second.getNumColumns();
2297 /// Write the called function name. Handles calls to llvm.matrix.* 2298 /// specially: we write the name, followed by the dimensions of the input 2299 /// matrixes, followed by the scalar type name. 2302write(
"<no called fn>");
2305if (!
Name.starts_with(
"llvm.matrix")) {
2309auto *
II = cast<IntrinsicInst>(CI);
2316switch (
II->getIntrinsicID()) {
2317case Intrinsic::matrix_multiply:
2318 prettyPrintMatrixType(
II->getOperand(0), SS);
2320 prettyPrintMatrixType(
II->getOperand(1), SS);
2321SS <<
"." << *
II->getType()->getScalarType();
2323case Intrinsic::matrix_transpose:
2324 prettyPrintMatrixType(
II->getOperand(0), SS);
2325SS <<
"." << *
II->getType()->getScalarType();
2327case Intrinsic::matrix_column_major_load:
2328 prettyPrintMatrixType(
II, SS);
2329SS <<
"." << *
II->getType()->getScalarType();
2331case Intrinsic::matrix_column_major_store:
2332 prettyPrintMatrixType(
II->getOperand(0), SS);
2333SS <<
"." << *
II->getOperand(0)->getType()->getScalarType();
2342unsigned getNumShapeArgs(
CallInst *CI)
const{
2344switch (
II->getIntrinsicID()) {
2345case Intrinsic::matrix_multiply:
2347case Intrinsic::matrix_transpose:
2349case Intrinsic::matrix_column_major_load:
2350case Intrinsic::matrix_column_major_store:
2359 /// Special printing for values: for pointers, we print if they refer to an 2360 /// (function) external address or a stack address, for other values we 2361 /// either print the constant or "scalar"/"matrix" for other values. 2363V = getUnderlyingObjectThroughLoads(V);
2364if (
V->getType()->isPointerTy()) {
2365if (isa<AllocaInst>(V)) {
2366 Stream <<
"stack addr";
2372if (!
V->getName().empty()) {
2373 Stream <<
" %" <<
V->getName() <<
"";
2374 LineLength +=
V->getName().size() + 2;
2382if (
auto *CI = dyn_cast<ConstantInt>(V))
2383 TmpStream << CI->getValue();
2384elseif (isa<Constant>(V))
2385 TmpStream <<
"constant";
2388 TmpStream <<
"matrix";
2390 TmpStream <<
"scalar";
2392 Tmp = std::string(
StringRef(Tmp).trim());
2393 LineLength += Tmp.size();
2397 /// Linearize expression \p Expr starting at an indentation of \p Indent. 2398 /// Expressions that are re-used multiple times are prefixed with (reused) 2399 /// at the re-used root instruction. 2400void linearizeExpr(
Value *Expr,
unsigned Indent,
bool ParentReused,
2402auto *
I = cast<Instruction>(Expr);
2403 maybeIndent(Indent);
2406// Is Expr shared with other expression leaves? 2407bool ExprShared =
false;
2409// Deal with shared subtrees. Mark them as shared, if required. 2411autoSI = Shared.find(Expr);
2412assert(SI != Shared.end() &&
SI->second.count(Leaf));
2417DebugLocDL = cast<Instruction>(S)->getDebugLoc();
2418write(
"shared with remark at line " + std::to_string(
DL.getLine()) +
2419" column " + std::to_string(
DL.getCol()) +
" (");
2421 ExprShared =
SI->second.size() > 1;
2424bool Reused = !ReusedExprs.
insert(Expr).second;
2425if (Reused && !ParentReused)
2428if (
auto *CI = dyn_cast<CallInst>(
I)) {
2432 }
elseif (isa<BitCastInst>(Expr)) {
2433// Special case bitcasts, which are used to materialize matrixes from 2438 Ops.
append(
I->value_op_begin(),
I->value_op_end());
2439write(std::string(
I->getOpcodeName()));
2442write(std::string(
"("));
2444unsigned NumOpsToBreak = 1;
2445if (
match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2449if (Ops.size() > NumOpsToBreak)
2452 maybeIndent(Indent + 1);
2454 linearizeExpr(
Op, Indent + 1, Reused, ExprShared);
2457if (
Op != Ops.back())
2464const std::string &getResult() {
2469 /// Generate remarks for matrix operations in a function. To generate remarks 2470 /// for matrix expressions, the following approach is used: 2471 /// 1. Use the inlined-at debug information to group matrix operations to the 2472 /// DISubprograms they are contained in. 2473 /// 2. Collect leaves of matrix expressions (done in 2474 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression 2475// mapping. Leaves are lowered matrix instructions without other matrix 2476// users (like stores) in the current subprogram. 2477 /// 3. For each leaf, create a remark containing a linearizied version of the 2478 /// matrix expression. The expression is linearized by a recursive 2479 /// bottom-up traversal of the matrix operands, starting at a leaf. Note 2480 /// that multiple leaves can share sub-expressions. Shared subexpressions 2481 /// are explicitly marked as shared(). 2482structRemarkGenerator {
2490 : Inst2Matrix(Inst2Matrix), ORE(ORE),
Func(
Func),
2491DL(
Func.getDataLayout()) {}
2493 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are 2494 /// instructions in Inst2Matrix returning void or without any users in 2495 /// \p ExprsInSubprogram. Currently that should only include stores. 2499for (
auto *Expr : ExprsInSubprogram)
2502 return ExprsInSubprogram.count(U);
2508 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf 2509 /// to all visited expressions in \p Shared. Limit the matrix operations to 2510 /// the ones in \p ExprsInSubprogram. 2515if (!ExprsInSubprogram.
count(V))
2518 Shared[
V].insert(Leaf);
2520for (
Value *
Op : cast<Instruction>(V)->operand_values())
2521 collectSharedInfo(Leaf,
Op, ExprsInSubprogram, Shared);
2524 /// Calculate the number of exclusive and shared op counts for expression 2525 /// starting at \p V. Expressions used multiple times are counted once. 2526 /// Limit the matrix operations to the ones in \p ExprsInSubprogram. 2527 std::pair<OpInfoTy, OpInfoTy>
2531if (!ExprsInSubprogram.
count(Root))
2534// Already counted this expression. Stop. 2535if (!ReusedExprs.
insert(Root).second)
2538 OpInfoTy SharedCount;
2541autoI = Shared.find(Root);
2542auto CM = Inst2Matrix.
find(Root);
2543if (
I->second.size() == 1)
2544 Count = CM->second.getOpInfo();
2546 SharedCount = CM->second.getOpInfo();
2548for (
Value *
Op : cast<Instruction>(Root)->operand_values()) {
2549autoC = sumOpInfos(
Op, ReusedExprs, ExprsInSubprogram, Shared);
2551 SharedCount +=
C.second;
2553return {Count, SharedCount};
2560// Map matrix operations to their containting subprograms, by traversing 2561// the inlinedAt chain. If the function does not have a DISubprogram, we 2562// only map them to the containing function. 2564for (
constauto &KV : Inst2Matrix) {
2565if (
Func.getSubprogram()) {
2566auto *
I = cast<Instruction>(KV.first);
2569 Subprog2Exprs[
getSubprogram(Context->getScope())].push_back(
2574 Subprog2Exprs[
nullptr].push_back(KV.first);
2577for (
auto &KV : Subprog2Exprs) {
2580auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2583for (
Value *Leaf : Leaves)
2584 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2586// Generate remarks for each leaf. 2587for (
auto *L : Leaves) {
2589DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2590DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2600 OpInfoTy Counts, SharedCounts;
2601 std::tie(Counts, SharedCounts) =
2602 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2607 Rem <<
"Lowered with ";
2608 Rem <<
ore::NV(
"NumStores", Counts.NumStores) <<
" stores, " 2609 <<
ore::NV(
"NumLoads", Counts.NumLoads) <<
" loads, " 2610 <<
ore::NV(
"NumComputeOps", Counts.NumComputeOps)
2612 <<
ore::NV(
"NumExposedTransposes", Counts.NumExposedTransposes)
2613 <<
" exposed transposes";
2615if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2616 SharedCounts.NumComputeOps > 0) {
2617 Rem <<
",\nadditionally " 2618 <<
ore::NV(
"NumStores", SharedCounts.NumStores) <<
" stores, " 2619 <<
ore::NV(
"NumLoads", SharedCounts.NumLoads) <<
" loads, " 2620 <<
ore::NV(
"NumFPOps", SharedCounts.NumComputeOps)
2622 <<
" are shared with other expressions";
2625 Rem << (
"\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2636 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2637 Lin.linearizeExpr(L, 0,
false,
false);
2638return Lin.getResult();
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
  if (Minimal)
    OS << "minimal";
  OS << '>';
}
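// Usage sketch (added for exposition; not part of this file): the lowering
// remarks emitted by RemarkGenerator can be inspected by running the pass
// through opt with remarks enabled, e.g.
//
//   opt -passes=lower-matrix-intrinsics \
//       -pass-remarks=lower-matrix-intrinsics -S in.ll -o out.ll
//
// The remark pass name matches DEBUG_TYPE ("lower-matrix-intrinsics"), so the
// same string is used to filter -pass-remarks output.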