LLVM 20.0.0git
LowerMatrixIntrinsics.cpp
1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
15// * Improve cost-modeling, e.g. choose different number of rows/columns
16// for tiles, consider cost of copies on alias.
17//
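// For illustration (an assumed input, not a case taken from this file), a call
// such as
//   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
//            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
// (a 2x2 by 2x2 multiply) is rewritten into plain shufflevector, fmul and fadd
// instructions on the operand columns, so no matrix intrinsics remain.
//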
18//===----------------------------------------------------------------------===//
19
20#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21#include "llvm/ADT/PostOrderIterator.h"
22#include "llvm/ADT/ScopeExit.h"
23#include "llvm/ADT/SmallSet.h"
24#include "llvm/ADT/SmallVector.h"
25#include "llvm/Analysis/AliasAnalysis.h"
26#include "llvm/Analysis/DomTreeUpdater.h"
27#include "llvm/Analysis/LoopInfo.h"
28#include "llvm/Analysis/OptimizationRemarkEmitter.h"
29#include "llvm/Analysis/TargetTransformInfo.h"
30#include "llvm/Analysis/ValueTracking.h"
31#include "llvm/Analysis/VectorUtils.h"
32#include "llvm/IR/CFG.h"
33#include "llvm/IR/DataLayout.h"
34#include "llvm/IR/DebugInfoMetadata.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
37#include "llvm/IR/Instructions.h"
38#include "llvm/IR/IntrinsicInst.h"
39#include "llvm/IR/MatrixBuilder.h"
40#include "llvm/IR/PatternMatch.h"
41#include "llvm/Support/Alignment.h"
42#include "llvm/Support/CommandLine.h"
43#include "llvm/Support/Debug.h"
44#include "llvm/Transforms/Utils/BasicBlockUtils.h"
45#include "llvm/Transforms/Utils/LoopUtils.h"
46#include "llvm/Transforms/Utils/MatrixUtils.h"
47
48#include <cmath>
49
50 using namespace llvm;
51 using namespace PatternMatch;
52
53 #define DEBUG_TYPE "lower-matrix-intrinsics"
54
55 static cl::opt<bool>
56     FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
57                cl::desc("Enable/disable fusing matrix instructions."));
58 // TODO: Allow and use non-square tiles.
59 static cl::opt<unsigned> TileSize(
60     "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
61     cl::desc(
62         "Tile size for matrix instruction fusion using square-shaped tiles."));
63 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
64                                   cl::Hidden,
65                                   cl::desc("Generate loop nest for tiling."));
66 static cl::opt<bool> ForceFusion(
67     "force-fuse-matrix", cl::init(false), cl::Hidden,
68     cl::desc("Force matrix instruction fusion even if not profitable."));
69 static cl::opt<bool> AllowContractEnabled(
70     "matrix-allow-contract", cl::init(false), cl::Hidden,
71     cl::desc("Allow the use of FMAs if available and profitable. This may "
72              "result in different results, due to less rounding error."));
73
74 static cl::opt<bool>
75     VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
76                     cl::desc("Enable/disable matrix shape verification."),
77                     cl::init(false));
78
79 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
80
81 static cl::opt<MatrixLayoutTy> MatrixLayout(
82     "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
83     cl::desc("Sets the default matrix layout"),
84     cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
85                           "Use column-major layout"),
86                clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
87                           "Use row-major layout")));
88
89 static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
90                                             cl::init(false));
91
92 /// Helper function that either returns \p Scope itself, if it is a
93 /// subprogram, or the subprogram attached to a local scope.
94 static DISubprogram *getSubprogram(DIScope *Scope) {
95   if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
96     return Subprogram;
97   return cast<DILocalScope>(Scope)->getSubprogram();
98 }
99
100 /// Return true if V is a splat of a value (which is used when multiplying a
101 /// matrix with a scalar).
102 static bool isSplat(Value *V) {
103   if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
104     return SV->isZeroEltSplat();
105   return false;
106 }
107
108 /// Match any mul operation (fp or integer).
109 template <typename LTy, typename RTy>
110 auto m_AnyMul(const LTy &L, const RTy &R) {
111   return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
112 }
113
114 /// Match any add operation (fp or integer).
115 template <typename LTy, typename RTy>
116 auto m_AnyAdd(const LTy &L, const RTy &R) {
117   return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
118 }
119
120namespace{
121
122// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
123// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
124// assuming \p Stride elements between the starts of two consecutive vectors.
125// \p Stride must be >= \p NumElements.
126// For column-major matrixes, the function computes the address of a column
127// vectors and \p NumElements must be set to the number of elements in a column
128// (= number of rows of the matrix). For row-major matrixes, the function
129// computes the address of a row vector and \p NumElements must be set to the
130// number of elements in a row (= number of columns of the matrix).
131//
132// Consider a 4x4 matrix in column-major layout like below
133//
134// 0 1 2 3
135// 0 v_0_0 v_0_1 v_0_2 v_0_3
136// 1 v_1_0 v_1_1 v_1_2 v_1_3
137// 2 v_2_0 v_2_1 v_2_2 v_2_3
138// 3 v_3_0 v_3_1 v_3_2 v_3_3
139
140// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
141// we need a pointer to the first element of the submatrix as base pointer.
142// Then we can use computeVectorAddr to compute the addresses for the columns
143// of the sub-matrix.
144//
145// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
146// -> just returns Base
147// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
148// -> returns Base + (1 * 4)
149// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
150// -> returns Base + (2 * 4)
151//
152// The graphic below illustrates the number of elements in a column (marked
153// with |) and the number of skipped elements (marked with {).
154//
155// v_0_0 v_0_1 {v_0_2 {v_0_3
156// Base Col 1 Col 2
157// | | |
158// v_1_0 |v_1_1 |v_1_2 |v_1_3
159// v_2_0 |v_2_1 |v_2_2 |v_2_3
160// v_3_0 {v_3_1 {v_3_2 v_3_3
161//
162Value *computeVectorAddr(Value *BasePtr,Value *VecIdx,Value *Stride,
163unsigned NumElements,Type *EltType,
164IRBuilder<> &Builder) {
165
166assert((!isa<ConstantInt>(Stride) ||
167 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
168"Stride must be >= the number of elements in the result vector.");
169
170// Compute the start of the vector with index VecIdx as VecIdx * Stride.
171Value *VecStart = Builder.CreateMul(VecIdx, Stride,"vec.start");
172
173// Get pointer to the start of the selected vector. Skip GEP creation,
174// if we select vector 0.
175if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
176 VecStart =BasePtr;
177else
178 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart,"vec.gep");
179
180return VecStart;
181}
182
183namespace{
184structShapeInfo {
185unsigned NumRows;
186unsigned NumColumns;
187
188bool IsColumnMajor;
189
190 ShapeInfo(unsigned NumRows = 0,unsigned NumColumns = 0)
191 : NumRows(NumRows), NumColumns(NumColumns),
192 IsColumnMajor(MatrixLayout ==MatrixLayoutTy::ColumnMajor) {}
193
194 ShapeInfo(Value *NumRows,Value *NumColumns)
195 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
196cast<ConstantInt>(NumColumns)->getZExtValue()) {}
197
198booloperator==(const ShapeInfo &other) {
199return NumRows == other.NumRows && NumColumns == other.NumColumns;
200 }
201booloperator!=(const ShapeInfo &other) {return !(*this == other); }
202
203 /// Returns true if shape-information is defined, meaning both dimensions
204 /// are != 0.
205operatorbool() const{
206assert(NumRows == 0 || NumColumns != 0);
207return NumRows != 0;
208 }
209
210unsigned getStride() const{
211if (IsColumnMajor)
212return NumRows;
213return NumColumns;
214 }
215
216unsigned getNumVectors() const{
217if (IsColumnMajor)
218return NumColumns;
219return NumRows;
220 }
221
222 /// Returns the transposed shape.
223 ShapeInfo t() const{return ShapeInfo(NumColumns, NumRows); }
224};
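// Illustrative example: for a column-major 4x2 matrix, ShapeInfo{4, 2} has
// getStride() == 4 (elements per column vector), getNumVectors() == 2 (number
// of column vectors), and t() returns the transposed 2x4 shape.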
225}// namespace
226
227staticbool isUniformShape(Value *V) {
228Instruction *I = dyn_cast<Instruction>(V);
229if (!I)
230returntrue;
231
232switch (I->getOpcode()) {
233case Instruction::FAdd:
234case Instruction::FSub:
235case Instruction::FMul:// Scalar multiply.
236case Instruction::FNeg:
237case Instruction::Add:
238case Instruction::Mul:
239case Instruction::Sub:
240returntrue;
241default:
242returnfalse;
243 }
244}
245
246/// Return the ShapeInfo for the result of \p I, if it can be determined.
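/// For example, a llvm.matrix.multiply call with dimension arguments M, N and
/// K multiplies an MxN by an NxK matrix, so the result shape is MxK; a
/// transpose with dimensions M and N yields an NxM result.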
247static std::optional<ShapeInfo>
248computeShapeInfoForInst(Instruction *I,
249constDenseMap<Value *, ShapeInfo> &ShapeMap) {
250Value *M;
251Value *N;
252Value *K;
253if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
254m_Value(),m_Value(),m_Value(M),m_Value(N),m_Value(K))))
255return ShapeInfo(M, K);
256if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(),m_Value(M),
257m_Value(N)))) {
258// Flip dimensions.
259return ShapeInfo(N, M);
260 }
261if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
262m_Value(),m_Value(),m_Value(),m_Value(),m_Value(M),
263m_Value(N))))
264return ShapeInfo(N, M);
265if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
266m_Value(),m_Value(),m_Value(),m_Value(M),m_Value(N))))
267return ShapeInfo(M,N);
268Value *MatrixA;
269if (match(I,m_Store(m_Value(MatrixA),m_Value()))) {
270auto OpShape = ShapeMap.find(MatrixA);
271if (OpShape != ShapeMap.end())
272return OpShape->second;
273 }
274
275if (isUniformShape(I)) {
276// Find the first operand that has a known shape and use that.
277for (auto &Op :I->operands()) {
278auto OpShape = ShapeMap.find(Op.get());
279if (OpShape != ShapeMap.end())
280return OpShape->second;
281 }
282 }
283return std::nullopt;
284}
285
286/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
287///
288/// Currently, the lowering for each matrix intrinsic is done as follows:
289/// 1. Propagate the shape information from intrinsics to connected
290/// instructions.
291/// 2. Lower instructions with shape information (assuming column-major layout).
292/// The lowering works similarly using row-major layout.
293/// 2.1. Get column vectors for each argument. If we already lowered the
294/// definition of an argument, use the produced column vectors directly.
295/// If not, split the operand vector containing an embedded matrix into
296/// a set of column vectors.
297/// 2.2. Lower the instruction in terms of column major operations, which
298/// yields a set of column vectors containing result matrix. Note that we
299/// lower all instructions that have shape information. Besides the
300/// intrinsics, this includes stores for example.
301/// 2.3. Update uses of the lowered instruction. If we have shape information
302/// for a user, there is nothing to do, as we will look up the result
303/// column matrix when lowering the user. For other uses, we embed the
304/// result matrix in a flat vector and update the use.
305/// 2.4. Cache the result column matrix for the instruction we lowered.
306/// 3. After we lowered all instructions in a function, remove the now
307/// obsolete instructions.
308///
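/// As a small illustration for the default column-major layout: a multiply of
/// a 2x3 matrix by a 3x2 matrix is lowered by splitting the flat operands into
/// three <2 x ty> and two <3 x ty> column vectors, combining them with vector
/// (f)mul/(f)add, and caching the two <2 x ty> result columns for later users.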
309classLowerMatrixIntrinsics {
310Function &Func;
311constDataLayout &DL;
312constTargetTransformInfo &TTI;
313FunctionAnalysisManager *AM;
314AliasAnalysis *AA =nullptr;
315DominatorTree *DT =nullptr;
316LoopInfo *LI =nullptr;
317OptimizationRemarkEmitter *ORE =nullptr;
318
319 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
320structOpInfoTy {
321 /// Number of stores emitted to generate this matrix.
322unsigned NumStores = 0;
323 /// Number of loads emitted to generate this matrix.
324unsigned NumLoads = 0;
325 /// Number of compute operations emitted to generate this matrix.
326unsigned NumComputeOps = 0;
327 /// Most of the time transposes can be fused with matrix multiplies or can
328 /// be folded away via algebraic simplifications. This is the number of
329 /// transposes that we failed to make "free" via such optimizations.
330unsigned NumExposedTransposes = 0;
331
332 OpInfoTy &operator+=(const OpInfoTy &RHS) {
333 NumStores +=RHS.NumStores;
334 NumLoads +=RHS.NumLoads;
335 NumComputeOps +=RHS.NumComputeOps;
336 NumExposedTransposes +=RHS.NumExposedTransposes;
337return *this;
338 }
339 };
340
341 /// Wrapper class representing a matrix as a set of vectors, either in row or
342 /// column major layout. All vectors must have the same vector type.
343classMatrixTy {
344SmallVector<Value *, 16> Vectors;
345
346 OpInfoTy OpInfo;
347
348bool IsColumnMajor =true;
349
350public:
351 MatrixTy() : IsColumnMajor(MatrixLayout ==MatrixLayoutTy::ColumnMajor) {}
352 MatrixTy(ArrayRef<Value *> Vectors)
353 : Vectors(Vectors),
354 IsColumnMajor(MatrixLayout ==MatrixLayoutTy::ColumnMajor) {}
355 MatrixTy(unsigned NumRows,unsigned NumColumns,Type *EltTy)
356 : IsColumnMajor(MatrixLayout ==MatrixLayoutTy::ColumnMajor) {
357
358unsignedD = isColumnMajor() ? NumColumns : NumRows;
359for (unsigned J = 0; J <D; ++J)
360 addVector(PoisonValue::get(FixedVectorType::get(
361 EltTy, isColumnMajor() ? NumRows : NumColumns)));
362 }
363
364Value *getVector(unsigned i) const{return Vectors[i]; }
365Value *getColumn(unsigned i) const{
366assert(isColumnMajor() &&"only supported for column-major matrixes");
367return Vectors[i];
368 }
369Value *getRow(unsigned i) const{
370assert(!isColumnMajor() &&"only supported for row-major matrixes");
371return Vectors[i];
372 }
373
374void setVector(unsigned i,Value *V) { Vectors[i] =V; }
375
376Type *getElementType() const{return getVectorTy()->getElementType(); }
377
378unsigned getNumVectors() const{
379if (isColumnMajor())
380return getNumColumns();
381return getNumRows();
382 }
383
384unsigned getNumColumns() const{
385if (isColumnMajor())
386return Vectors.size();
387else {
388assert(Vectors.size() > 0 &&"Cannot call getNumRows without columns");
389return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
390 }
391 }
392unsigned getNumRows() const{
393if (isColumnMajor()) {
394assert(Vectors.size() > 0 &&"Cannot call getNumRows without columns");
395return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
396 }else
397return Vectors.size();
398 }
399
400void addVector(Value *V) { Vectors.push_back(V); }
401VectorType *getColumnTy() {
402assert(isColumnMajor() &&"only supported for column-major matrixes");
403return getVectorTy();
404 }
405
406VectorType *getVectorTy() const{
407return cast<VectorType>(Vectors[0]->getType());
408 }
409
410iterator_range<SmallVector<Value *, 8>::iterator> columns() {
411assert(isColumnMajor() &&
412"columns() only supported for column-major matrixes");
413returnmake_range(Vectors.begin(), Vectors.end());
414 }
415
416iterator_range<SmallVector<Value *, 8>::iterator>vectors() {
417returnmake_range(Vectors.begin(), Vectors.end());
418 }
419
420 /// Embed the vectors of the matrix into a flat vector by concatenating
421 /// them.
422Value *embedInVector(IRBuilder<> &Builder) const{
423return Vectors.size() == 1 ? Vectors[0]
424 :concatenateVectors(Builder, Vectors);
425 }
426
427 MatrixTy &addNumLoads(unsignedN) {
428 OpInfo.NumLoads +=N;
429return *this;
430 }
431
432void setNumLoads(unsignedN) { OpInfo.NumLoads =N; }
433
434 MatrixTy &addNumStores(unsignedN) {
435 OpInfo.NumStores +=N;
436return *this;
437 }
438
439 MatrixTy &addNumExposedTransposes(unsignedN) {
440 OpInfo.NumExposedTransposes +=N;
441return *this;
442 }
443
444 MatrixTy &addNumComputeOps(unsignedN) {
445 OpInfo.NumComputeOps +=N;
446return *this;
447 }
448
449unsigned getNumStores() const{return OpInfo.NumStores; }
450unsigned getNumLoads() const{return OpInfo.NumLoads; }
451unsigned getNumComputeOps() const{return OpInfo.NumComputeOps; }
452
453const OpInfoTy &getOpInfo() const{return OpInfo; }
454
455bool isColumnMajor() const{return IsColumnMajor; }
456
457unsigned getStride() const{
458if (isColumnMajor())
459return getNumRows();
460return getNumColumns();
461 }
462
463 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
464 /// matrix is column-major, the result vector is extracted from a column
465 /// vector, otherwise from a row vector.
466Value *extractVector(unsignedI,unsigned J,unsigned NumElts,
467IRBuilder<> &Builder) const{
468Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
469assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
470 NumElts &&
471"Extracted vector will contain poison values");
472return Builder.CreateShuffleVector(
473 Vec,createSequentialMask(isColumnMajor() ?I : J, NumElts, 0),
474"block");
475 }
476 };
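 // Illustrative example: a 2x3 column-major MatrixTy holds three <2 x ty>
 // column vectors, so getNumVectors() == 3, getNumRows() == 2, getStride() == 2,
 // and embedInVector() concatenates the columns into a single <6 x ty> value.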
477
478 /// Maps instructions to their shape information. The shape information
479 /// describes the shape to be used while lowering. This matches the shape of
480 /// the result value of the instruction, with the only exceptions being store
481 /// instructions and the matrix_column_major_store intrinsics. For those, the
482 /// shape information indicates that those instructions should be lowered
483 /// using shape information as well. Note that extra care is needed when
484 /// erasing or RAUW'ing a value that is present in ShapeMap. If the
485 /// replacement is also a matrix operation, use
486 /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
487 /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
488 /// want to add shape information for a replacement instruction. When directly
489 /// erasing a value with an entry in ShapeMap, use
490 /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
491 /// accordingly.
492DenseMap<Value *, ShapeInfo> ShapeMap;
493
494 /// List of instructions to remove. While lowering, we do not replace all
495 /// users of a lowered instruction when shape information is available; those
496 /// users need to be removed after lowering finishes.
497SmallVector<Instruction *, 16>ToRemove;
498
499 /// Map from instructions to their produced column matrix.
500MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
501
502private:
503staticFastMathFlags getFastMathFlags(Instruction *Inst) {
504FastMathFlags FMF;
505
506if (isa<FPMathOperator>(*Inst))
507 FMF = Inst->getFastMathFlags();
508
509 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
510
511return FMF;
512 }
513
514public:
515 LowerMatrixIntrinsics(Function &F,TargetTransformInfo &TTI,
516FunctionAnalysisManager *AM)
517 :Func(F),DL(F.getDataLayout()),TTI(TTI), AM(AM) {}
518
519unsigned getNumOps(Type *VT) {
520assert(isa<VectorType>(VT) &&"Expected vector type");
521return getNumOps(VT->getScalarType(),
522 cast<FixedVectorType>(VT)->getNumElements());
523 }
524
525 /// Is this the minimal version executed in the backend pipelines?
526bool isMinimal() const{
527return !DT;
528 }
529
530 /// Return the estimated number of vector ops required for an operation on
531 /// \p VT * N.
532unsigned getNumOps(Type *ST,unsignedN) {
533return std::ceil((ST->getPrimitiveSizeInBits() *N).getFixedValue() /
534double(TTI.getRegisterBitWidth(
535TargetTransformInfo::RGK_FixedWidthVector)
536 .getFixedValue()));
537 }
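 // Worked example (assuming 128-bit fixed-width vector registers): an
 // operation on <4 x float> needs ceil(4 * 32 / 128) = 1 vector op, while an
 // operation on <8 x double> needs ceil(8 * 64 / 128) = 4.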
538
539 /// Return the set of vectors that a matrix value is lowered to.
540 ///
541 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
542 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
543 /// into vectors.
544 MatrixTy getMatrix(Value *MatrixVal,const ShapeInfo &SI,
545IRBuilder<> &Builder) {
546VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
547assert(VType &&"MatrixVal must be a vector type");
548assert(cast<FixedVectorType>(VType)->getNumElements() ==
549SI.NumRows *SI.NumColumns &&
550"The vector size must match the number of matrix elements");
551
552// Check if we lowered MatrixVal using shape information. In that case,
553// return the existing matrix, if it matches the requested shape
554// information. If there is a mis-match, embed the result in a flat
555// vector and split it later.
556auto Found = Inst2ColumnMatrix.find(MatrixVal);
557if (Found != Inst2ColumnMatrix.end()) {
558 MatrixTy &M = Found->second;
559// Return the found matrix, if its shape matches the requested shape
560// information
561if (SI.NumRows ==M.getNumRows() &&SI.NumColumns ==M.getNumColumns())
562returnM;
563
564 MatrixVal =M.embedInVector(Builder);
565 }
566
567// Otherwise split MatrixVal.
568SmallVector<Value *, 16> SplitVecs;
569for (unsigned MaskStart = 0;
570 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
571 MaskStart +=SI.getStride()) {
572Value *V = Builder.CreateShuffleVector(
573 MatrixVal,createSequentialMask(MaskStart,SI.getStride(), 0),
574"split");
575 SplitVecs.push_back(V);
576 }
577
578return {SplitVecs};
579 }
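 // Illustrative example: asking for a flat <6 x double> value with shape 2x3
 // in column-major layout produces three <2 x double> "split" shuffles, one
 // per column, unless a matching lowered matrix is already cached.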
580
581 /// If \p V already has a known shape return false. Otherwise set the shape
582 /// for instructions that support it.
583bool setShapeInfo(Value *V, ShapeInfo Shape) {
584assert(Shape &&"Shape not set");
585if (isa<UndefValue>(V) || !supportsShapeInfo(V))
586returnfalse;
587
588auto SIter = ShapeMap.find(V);
589if (SIter != ShapeMap.end()) {
590if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
591 SIter->second.NumColumns != Shape.NumColumns)) {
592errs() <<"Conflicting shapes (" << SIter->second.NumRows <<"x"
593 << SIter->second.NumColumns <<" vs " << Shape.NumRows <<"x"
594 << Shape.NumColumns <<") for " << *V <<"\n";
595report_fatal_error(
596"Matrix shape verification failed, compilation aborted!");
597 }
598
599LLVM_DEBUG(dbgs() <<" not overriding existing shape: "
600 << SIter->second.NumRows <<" "
601 << SIter->second.NumColumns <<" for " << *V <<"\n");
602returnfalse;
603 }
604
605 ShapeMap.insert({V, Shape});
606LLVM_DEBUG(dbgs() <<" " << Shape.NumRows <<" x " << Shape.NumColumns
607 <<" for " << *V <<"\n");
608returntrue;
609 }
610
611 /// Returns true if shape information can be used for \p V. The supported
612 /// instructions must match the instructions that can be lowered by this pass.
613bool supportsShapeInfo(Value *V) {
614Instruction *Inst = dyn_cast<Instruction>(V);
615if (!Inst)
616returnfalse;
617
618IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
619if (II)
620switch (II->getIntrinsicID()) {
621case Intrinsic::matrix_multiply:
622case Intrinsic::matrix_transpose:
623case Intrinsic::matrix_column_major_load:
624case Intrinsic::matrix_column_major_store:
625returntrue;
626default:
627returnfalse;
628 }
629return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
630 }
631
632 /// Propagate the shape information of instructions to their users.
633 /// The work list contains instructions for which we can compute the shape,
634 /// either based on the information provided by matrix intrinsics or known
635 /// shapes of operands.
636SmallVector<Instruction *, 32>
637 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
638SmallVector<Instruction *, 32> NewWorkList;
639// Pop an element for which we are guaranteed to have at least one of the
640// operand shapes. Add the shape for this and then add users to the work
641// list.
642LLVM_DEBUG(dbgs() <<"Forward-propagate shapes:\n");
643while (!WorkList.empty()) {
644Instruction *Inst = WorkList.pop_back_val();
645
646// New entry, set the value and insert operands
647bool Propagate =false;
648if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
649 Propagate = setShapeInfo(Inst, *SI);
650
651if (Propagate) {
652 NewWorkList.push_back(Inst);
653for (auto *User : Inst->users())
654if (ShapeMap.count(User) == 0)
655 WorkList.push_back(cast<Instruction>(User));
656 }
657 }
658
659return NewWorkList;
660 }
661
662 /// Propagate the shape to operands of instructions with shape information.
663 /// \p Worklist contains the instruction for which we already know the shape.
664SmallVector<Instruction *, 32>
665 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
666SmallVector<Instruction *, 32> NewWorkList;
667
668auto pushInstruction = [](Value *V,
669SmallVectorImpl<Instruction *> &WorkList) {
670Instruction *I = dyn_cast<Instruction>(V);
671if (I)
672 WorkList.push_back(I);
673 };
674// Pop an element with known shape. Traverse its operands; if an operand's
675// shape derives from the result shape and is still unknown, set it and add
676// the operand to the worklist.
677LLVM_DEBUG(dbgs() <<"Backward-propagate shapes:\n");
678while (!WorkList.empty()) {
679Value *V = WorkList.pop_back_val();
680
681size_t BeforeProcessingV = WorkList.size();
682if (!isa<Instruction>(V))
683continue;
684
685Value *MatrixA;
686Value *MatrixB;
687Value *M;
688Value *N;
689Value *K;
690if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
691m_Value(MatrixA),m_Value(MatrixB),m_Value(M),
692m_Value(N),m_Value(K)))) {
693if (setShapeInfo(MatrixA, {M,N}))
694 pushInstruction(MatrixA, WorkList);
695
696if (setShapeInfo(MatrixB, {N,K}))
697 pushInstruction(MatrixB, WorkList);
698
699 }elseif (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
700m_Value(MatrixA),m_Value(M),m_Value(N)))) {
701// Flip dimensions.
702if (setShapeInfo(MatrixA, {M,N}))
703 pushInstruction(MatrixA, WorkList);
704 }elseif (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
705m_Value(MatrixA),m_Value(),m_Value(),m_Value(),
706m_Value(M),m_Value(N)))) {
707if (setShapeInfo(MatrixA, {M,N})) {
708 pushInstruction(MatrixA, WorkList);
709 }
710 }elseif (isa<LoadInst>(V) ||
711match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
712// Nothing to do, no matrix input.
713 }elseif (isa<StoreInst>(V)) {
714// Nothing to do. We forward-propagated to this so we would just
715// backward propagate to an instruction with an already known shape.
716 }elseif (isUniformShape(V)) {
717// Propagate to all operands.
718 ShapeInfo Shape = ShapeMap[V];
719for (Use &U : cast<Instruction>(V)->operands()) {
720if (setShapeInfo(U.get(), Shape))
721 pushInstruction(U.get(), WorkList);
722 }
723 }
724// After we discovered new shape info for new instructions in the
725// worklist, we use their users as seeds for the next round of forward
726// propagation.
727for (size_tI = BeforeProcessingV;I != WorkList.size();I++)
728for (User *U : WorkList[I]->users())
729if (isa<Instruction>(U) && V != U)
730 NewWorkList.push_back(cast<Instruction>(U));
731 }
732return NewWorkList;
733 }
734
735 /// (Op0 op Op1)^T -> Op0^T op Op1^T
736 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
737 /// them on both sides of \p Operation.
738Instruction *distributeTransposes(
739Value *Op0, ShapeInfo Shape0,Value *Op1, ShapeInfo Shape1,
740MatrixBuilder &Builder,
741function_ref<Instruction *(Value *, ShapeInfo,Value *, ShapeInfo)>
742Operation) {
743Value *T0 = Builder.CreateMatrixTranspose(
744 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() +"_t");
745// We are being run after shape prop, add shape for newly created
746// instructions so that we lower them later.
747 setShapeInfo(T0, Shape0.t());
748Value *T1 = Builder.CreateMatrixTranspose(
749 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() +"_t");
750 setShapeInfo(T1, Shape1.t());
751returnOperation(T0, Shape0.t(), T1, Shape1.t());
752 }
753
754 /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
755 /// itself.
756void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
757auto Iter = ShapeMap.find(Inst);
758if (Iter != ShapeMap.end())
759 ShapeMap.erase(Iter);
760 Inst->eraseFromParent();
761 }
762
763 /// Erase \p V from \p BB and move \p II forward to avoid invalidating
764 /// iterators.
765void eraseFromParentAndMove(Value *V,BasicBlock::reverse_iterator &II,
766BasicBlock &BB) {
767auto *Inst = cast<Instruction>(V);
768// Still used, don't erase.
769if (!Inst->use_empty())
770return;
771if (II != BB.rend() && Inst == &*II)
772 ++II;
773 eraseFromParentAndRemoveFromShapeMap(Inst);
774 }
775
776 /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
777 /// entry for \p Old and replace all uses of \p Old with \p New.
778void updateShapeAndReplaceAllUsesWith(Instruction &Old,Value *New) {
779// We need to remove Old from the ShapeMap otherwise RAUW will replace it
780// with New. We should only add New if it supportsShapeInfo, so we insert
781// it conditionally instead.
782auto S = ShapeMap.find(&Old);
783if (S != ShapeMap.end()) {
784 ShapeMap.erase(S);
785if (supportsShapeInfo(New))
786 ShapeMap.insert({New, S->second});
787 }
788 Old.replaceAllUsesWith(New);
789 }
790
791 /// Sink a top-level transpose inside matmuls and adds.
792 /// This creates and erases instructions as needed, and returns the newly
793 /// created instruction while updating the iterator to avoid invalidation. If
794 /// this returns nullptr, no new instruction was created.
795Instruction *sinkTranspose(Instruction &I,BasicBlock::reverse_iterator &II) {
796BasicBlock &BB = *I.getParent();
797IRBuilder<>IB(&I);
798MatrixBuilder Builder(IB);
799
800Value *TA, *TAMA, *TAMB;
801ConstantInt *R, *K, *C;
802if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
803m_Value(TA),m_ConstantInt(R),m_ConstantInt(C))))
804returnnullptr;
805
806// Transpose of a transpose is a nop
807Value *TATA;
808if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
809 updateShapeAndReplaceAllUsesWith(I, TATA);
810 eraseFromParentAndMove(&I,II, BB);
811 eraseFromParentAndMove(TA,II, BB);
812returnnullptr;
813 }
814
815// k^T -> k
816if (isSplat(TA)) {
817 updateShapeAndReplaceAllUsesWith(I, TA);
818 eraseFromParentAndMove(&I,II, BB);
819returnnullptr;
820 }
821
822// (A * B)^t -> B^t * A^t
823// RxK KxC CxK KxR
824if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
825m_Value(TAMA),m_Value(TAMB),m_ConstantInt(R),
826m_ConstantInt(K),m_ConstantInt(C)))) {
827auto NewInst = distributeTransposes(
828 TAMB, {K,C}, TAMA, {R,K}, Builder,
829 [&](Value *T0, ShapeInfo Shape0,Value *T1, ShapeInfo Shape1) {
830return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
831 Shape0.NumColumns,
832 Shape1.NumColumns,"mmul");
833 });
834 updateShapeAndReplaceAllUsesWith(I, NewInst);
835 eraseFromParentAndMove(&I,II, BB);
836 eraseFromParentAndMove(TA,II, BB);
837return NewInst;
838 }
839
840// Same as above, but with a mul, which occurs when multiplied
841// with a scalar.
842// (A * k)^t -> A^t * k
843// R x C RxC
844if (match(TA,m_AnyMul(m_Value(TAMA),m_Value(TAMB))) &&
845 (isSplat(TAMA) ||isSplat(TAMB))) {
846IRBuilder<> LocalBuilder(&I);
847// We know that the transposed operand is of shape RxC.
848// And when multiplied with a scalar, the shape is preserved.
849auto NewInst = distributeTransposes(
850 TAMA, {R,C}, TAMB, {R,C}, Builder,
851 [&](Value *T0, ShapeInfo Shape0,Value *T1, ShapeInfo Shape1) {
852bool IsFP =I.getType()->isFPOrFPVectorTy();
853auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1,"mmul")
854 : LocalBuilder.CreateMul(T0, T1,"mmul");
855auto *Result = cast<Instruction>(Mul);
856 setShapeInfo(Result, Shape0);
857returnResult;
858 });
859 updateShapeAndReplaceAllUsesWith(I, NewInst);
860 eraseFromParentAndMove(&I,II, BB);
861 eraseFromParentAndMove(TA,II, BB);
862return NewInst;
863 }
864
865// (A + B)^t -> A^t + B^t
866// RxC RxC CxR CxR
867if (match(TA,m_AnyAdd(m_Value(TAMA),m_Value(TAMB)))) {
868IRBuilder<> LocalBuilder(&I);
869auto NewInst = distributeTransposes(
870 TAMA, {R,C}, TAMB, {R,C}, Builder,
871 [&](Value *T0, ShapeInfo Shape0,Value *T1, ShapeInfo Shape1) {
872bool IsFP =I.getType()->isFPOrFPVectorTy();
873auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1,"madd")
874 : LocalBuilder.CreateAdd(T0, T1,"madd");
875
876auto *Result = cast<Instruction>(Add);
877 setShapeInfo(Result, Shape0);
878returnResult;
879 });
880 updateShapeAndReplaceAllUsesWith(I, NewInst);
881 eraseFromParentAndMove(&I,II, BB);
882 eraseFromParentAndMove(TA,II, BB);
883return NewInst;
884 }
885
886returnnullptr;
887 }
888
889void liftTranspose(Instruction &I) {
890// Erase dead Instructions after lifting transposes from binops.
891auto CleanupBinOp = [this](Instruction &T,Value *A,Value *B) {
892if (T.use_empty())
893 eraseFromParentAndRemoveFromShapeMap(&T);
894if (A->use_empty())
895 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
896if (A !=B &&B->use_empty())
897 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
898 };
899
900Value *A, *B, *AT, *BT;
901ConstantInt *R, *K, *C;
902// A^t * B^t -> (B * A)^t
903if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
904m_Value(A),m_Value(B),m_ConstantInt(R),
905m_ConstantInt(K),m_ConstantInt(C))) &&
906match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
907match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
908IRBuilder<>IB(&I);
909MatrixBuilder Builder(IB);
910Value *M = Builder.CreateMatrixMultiply(
911BT, AT,C->getZExtValue(),K->getZExtValue(),R->getZExtValue());
912 setShapeInfo(M, {C,R});
913Instruction *NewInst = Builder.CreateMatrixTranspose(M,C->getZExtValue(),
914R->getZExtValue());
915 updateShapeAndReplaceAllUsesWith(I, NewInst);
916 CleanupBinOp(I,A,B);
917 }
918// A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose. If
919// the shape of the second transpose is different, there's a shape conflict
920// which gets resolved by picking the shape of the first operand.
921elseif (match(&I,m_FAdd(m_Value(A),m_Value(B))) &&
922match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
923m_Value(AT),m_ConstantInt(R),m_ConstantInt(C))) &&
924match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
925m_Value(BT),m_ConstantInt(),m_ConstantInt()))) {
926IRBuilder<> Builder(&I);
927auto *Add = Builder.CreateFAdd(AT,BT,"mfadd");
928MatrixBuilder MBuilder(Builder);
929Instruction *NewInst = MBuilder.CreateMatrixTranspose(
930Add,R->getZExtValue(),C->getZExtValue(),"mfadd_t");
931 updateShapeAndReplaceAllUsesWith(I, NewInst);
932assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
933 computeShapeInfoForInst(&I, ShapeMap) &&
934"Shape of new instruction doesn't match original shape.");
935 CleanupBinOp(I,A,B);
936if (auto *AddI = dyn_cast<Instruction>(Add)) {
937 setShapeInfo(AddI, {R,C});
938assert(
939 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
940 ShapeMap[AddI] &&
941"Shape of updated addition doesn't match cached shape.");
942 }
943 }
944 }
945
946 /// Try moving transposes in order to fold them away or into multiplies.
947void optimizeTransposes() {
948// First sink all transposes inside matmuls and adds, hoping that we end up
949// with NN, NT or TN variants.
950for (BasicBlock &BB :reverse(Func)) {
951for (autoII = BB.rbegin();II != BB.rend();) {
952Instruction &I = *II;
953// We may remove II. By default continue on the next/prev instruction.
954 ++II;
955if (Instruction *NewInst = sinkTranspose(I,II))
956II = std::next(BasicBlock::reverse_iterator(NewInst));
957 }
958 }
959
960// If we have a TT matmul or a TT add, lift the transpose. We may be able
961// to fold into consuming multiply or add.
962for (BasicBlock &BB : Func) {
963for (Instruction &I :llvm::make_early_inc_range(BB)) {
964 liftTranspose(I);
965 }
966 }
967 }
968
969bool Visit() {
970SmallVector<Instruction *, 32> WorkList;
971
972// Initially only the shape of matrix intrinsics is known.
973// Initialize the work list with ops carrying shape information.
974for (BasicBlock &BB : Func)
975for (Instruction &Inst : BB) {
976IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
977if (!II)
978continue;
979
980switch (II->getIntrinsicID()) {
981case Intrinsic::matrix_multiply:
982case Intrinsic::matrix_transpose:
983case Intrinsic::matrix_column_major_load:
984case Intrinsic::matrix_column_major_store:
985 WorkList.push_back(&Inst);
986break;
987default:
988break;
989 }
990 }
991
992// Avoid unnecessary work if there are no matrix intrinsics in the function.
993if (WorkList.empty())
994returnfalse;
995
996if (AM) {
997 ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func);
998 AA = &AM->getResult<AAManager>(Func);
999 DT = &AM->getResult<DominatorTreeAnalysis>(Func);
1000 LI = &AM->getResult<LoopAnalysis>(Func);
1001 }
1002
1003// Propagate shapes until nothing changes any longer.
1004while (!WorkList.empty()) {
1005 WorkList = propagateShapeForward(WorkList);
1006 WorkList = propagateShapeBackward(WorkList);
1007 }
1008
1009if (!isMinimal()) {
1010 optimizeTransposes();
1011if (PrintAfterTransposeOpt) {
1012dbgs() <<"Dump after matrix transpose optimization:\n";
1013Func.print(dbgs());
1014 }
1015 }
1016
1017bool Changed =false;
1018SmallVector<CallInst *, 16> MaybeFusableInsts;
1019SmallVector<Instruction *, 16> MatrixInsts;
1020SmallVector<IntrinsicInst *, 16> LifetimeEnds;
1021
1022// First, collect all instructions with shape information and candidates for
1023// fusion (currently only matrix multiplies).
1024ReversePostOrderTraversal<Function *> RPOT(&Func);
1025for (auto *BB : RPOT)
1026for (Instruction &I : *BB) {
1027if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
1028 LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
1029if (ShapeMap.find(&I) == ShapeMap.end())
1030continue;
1031if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
1032 MaybeFusableInsts.push_back(cast<CallInst>(&I));
1033 MatrixInsts.push_back(&I);
1034 }
1035
1036// Second, try to lower any dot products
1037SmallPtrSet<Instruction *, 16> FusedInsts;
1038for (CallInst *CI : MaybeFusableInsts)
1039 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1040
1041// Third, try to fuse candidates.
1042for (CallInst *CI : MaybeFusableInsts)
1043if (!FusedInsts.contains(CI))
1044 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
1045
1046 Changed = !FusedInsts.empty();
1047
1048// Fourth, lower remaining instructions with shape information.
1049for (Instruction *Inst : MatrixInsts) {
1050if (FusedInsts.count(Inst))
1051continue;
1052
1053IRBuilder<> Builder(Inst);
1054
1055if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1056 Changed |= VisitCallInst(CInst);
1057
1058Value *Op1;
1059Value *Op2;
1060if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1061 Changed |= VisitBinaryOperator(BinOp);
1062if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1063 Changed |= VisitUnaryOperator(UnOp);
1064if (match(Inst,m_Load(m_Value(Op1))))
1065 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1066elseif (match(Inst,m_Store(m_Value(Op1),m_Value(Op2))))
1067 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1068 }
1069
1070if (ORE) {
1071 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1072 RemarkGen.emitRemarks();
1073 }
1074
1075// Delete the instructions backwards, as this reduces the likelihood of
1076// having to update as many def-use and use-def chains.
1077//
1078// Because we add to ToRemove during fusion we can't guarantee that defs
1079// are before uses. Change uses to poison temporarily as these should get
1080// removed as well.
1081//
1082// For verification, we keep track of where we changed uses to poison in
1083// PoisonedInsts and then check that we in fact remove them.
1084SmallSet<Instruction *, 16> PoisonedInsts;
1085for (auto *Inst :reverse(ToRemove)) {
1086for (Use &U :llvm::make_early_inc_range(Inst->uses())) {
1087if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
1088 PoisonedInsts.insert(Poisoned);
1089U.set(PoisonValue::get(Inst->getType()));
1090 }
1091 Inst->eraseFromParent();
1092 PoisonedInsts.erase(Inst);
1093 }
1094if (!PoisonedInsts.empty()) {
1095// If we didn't remove all poisoned instructions, it's a hard error.
1096dbgs() <<"Poisoned but present instructions:\n";
1097for (auto *I : PoisonedInsts)
1098dbgs() << *I <<"\n";
1099llvm_unreachable("Poisoned but instruction not removed");
1100 }
1101
1102return Changed;
1103 }
1104
1105 /// Replace intrinsic calls
1106bool VisitCallInst(CallInst *Inst) {
1107if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1108returnfalse;
1109
1110switch (Inst->getCalledFunction()->getIntrinsicID()) {
1111case Intrinsic::matrix_multiply:
1112 LowerMultiply(Inst);
1113break;
1114case Intrinsic::matrix_transpose:
1115 LowerTranspose(Inst);
1116break;
1117case Intrinsic::matrix_column_major_load:
1118 LowerColumnMajorLoad(Inst);
1119break;
1120case Intrinsic::matrix_column_major_store:
1121 LowerColumnMajorStore(Inst);
1122break;
1123default:
1124returnfalse;
1125 }
1126returntrue;
1127 }
1128
1129 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1130 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1131 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1132 /// non-ConstantInt strides, return the common alignment of the initial
1133 /// alignment and the element size in bytes.
1134Align getAlignForIndex(unsignedIdx,Value *Stride,Type *ElementTy,
1135MaybeAlignA) const{
1136Align InitialAlign =DL.getValueOrABITypeAlignment(A, ElementTy);
1137if (Idx == 0)
1138return InitialAlign;
1139
1140TypeSize ElementSizeInBits =DL.getTypeSizeInBits(ElementTy);
1141if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
1142uint64_t StrideInBytes =
1143 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1144returncommonAlignment(InitialAlign,Idx * StrideInBytes);
1145 }
1146returncommonAlignment(InitialAlign, ElementSizeInBits / 8);
1147 }
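 // Worked example (assumed values): with an initial alignment of 16, double
 // elements and a constant stride of 5 elements, vector 1 starts 40 bytes in,
 // so its alignment is commonAlignment(16, 40) = 8.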
1148
1149 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1150 /// vectors.
1151 MatrixTy loadMatrix(Type *Ty,Value *Ptr,MaybeAlign MAlign,Value *Stride,
1152bool IsVolatile, ShapeInfo Shape,IRBuilder<> &Builder) {
1153auto *VType = cast<VectorType>(Ty);
1154Type *EltTy = VType->getElementType();
1155Type *VecTy =FixedVectorType::get(EltTy, Shape.getStride());
1156Value *EltPtr =Ptr;
1157 MatrixTyResult;
1158for (unsignedI = 0, E = Shape.getNumVectors();I < E; ++I) {
1159Value *GEP = computeVectorAddr(
1160 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(),I),
1161 Stride, Shape.getStride(), EltTy, Builder);
1162Value *Vector = Builder.CreateAlignedLoad(
1163 VecTy,GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
1164 IsVolatile,"col.load");
1165
1166Result.addVector(Vector);
1167 }
1168returnResult.addNumLoads(getNumOps(Result.getVectorTy()) *
1169Result.getNumVectors());
1170 }
1171
1172 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1173 /// starting at \p MatrixPtr[I][J].
1174 MatrixTy loadMatrix(Value *MatrixPtr,MaybeAlignAlign,bool IsVolatile,
1175 ShapeInfo MatrixShape,Value *I,Value *J,
1176 ShapeInfo ResultShape,Type *EltTy,
1177IRBuilder<> &Builder) {
1178
1179Value *Offset = Builder.CreateAdd(
1180 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())),I);
1181
1182Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr,Offset);
1183auto *TileTy =FixedVectorType::get(EltTy, ResultShape.NumRows *
1184 ResultShape.NumColumns);
1185
1186return loadMatrix(TileTy, TileStart,Align,
1187 Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1188 ResultShape, Builder);
1189 }
1190
1191 /// Lower a load instruction with shape information.
1192voidLowerLoad(Instruction *Inst,Value *Ptr,MaybeAlignAlign,Value *Stride,
1193bool IsVolatile, ShapeInfo Shape) {
1194IRBuilder<> Builder(Inst);
1195 finalizeLowering(Inst,
1196 loadMatrix(Inst->getType(),Ptr,Align, Stride, IsVolatile,
1197 Shape, Builder),
1198 Builder);
1199 }
1200
1201 /// Lowers llvm.matrix.column.major.load.
1202 ///
1203 /// The intrinsic loads a matrix from memory using a stride between columns.
1204void LowerColumnMajorLoad(CallInst *Inst) {
1205assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1206"Intrinsic only supports column-major layout!");
1207Value *Ptr = Inst->getArgOperand(0);
1208Value *Stride = Inst->getArgOperand(1);
1209LowerLoad(Inst,Ptr, Inst->getParamAlign(0), Stride,
1210 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1211 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1212 }
1213
1214 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1215 /// MatrixPtr[I][J].
1216void storeMatrix(const MatrixTy &StoreVal,Value *MatrixPtr,
1217MaybeAlign MAlign,bool IsVolatile, ShapeInfo MatrixShape,
1218Value *I,Value *J,Type *EltTy,IRBuilder<> &Builder) {
1219Value *Offset = Builder.CreateAdd(
1220 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())),I);
1221
1222Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr,Offset);
1223auto *TileTy =FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1224 StoreVal.getNumColumns());
1225
1226 storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1227 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1228 }
1229
1230 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1231 /// vectors.
1232 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal,Value *Ptr,
1233MaybeAlign MAlign,Value *Stride,bool IsVolatile,
1234IRBuilder<> &Builder) {
1235auto VType = cast<VectorType>(Ty);
1236Value *EltPtr =Ptr;
1237for (auto Vec :enumerate(StoreVal.vectors())) {
1238Value *GEP = computeVectorAddr(
1239 EltPtr,
1240 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1241 Vec.index()),
1242 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1243 Builder.CreateAlignedStore(Vec.value(),GEP,
1244 getAlignForIndex(Vec.index(), Stride,
1245 VType->getElementType(),
1246 MAlign),
1247 IsVolatile);
1248 }
1249return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1250 StoreVal.getNumVectors());
1251 }
1252
1253 /// Lower a store instruction with shape information.
1254voidLowerStore(Instruction *Inst,Value *Matrix,Value *Ptr,MaybeAlignA,
1255Value *Stride,bool IsVolatile, ShapeInfo Shape) {
1256IRBuilder<> Builder(Inst);
1257auto StoreVal = getMatrix(Matrix, Shape, Builder);
1258 finalizeLowering(Inst,
1259 storeMatrix(Matrix->getType(), StoreVal,Ptr,A, Stride,
1260 IsVolatile, Builder),
1261 Builder);
1262 }
1263
1264 /// Lowers llvm.matrix.column.major.store.
1265 ///
1266 /// The intrinsic stores a matrix back to memory using a stride between columns.
1267void LowerColumnMajorStore(CallInst *Inst) {
1268assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1269"Intrinsic only supports column-major layout!");
1270Value *Matrix = Inst->getArgOperand(0);
1271Value *Ptr = Inst->getArgOperand(1);
1272Value *Stride = Inst->getArgOperand(2);
1273LowerStore(Inst,Matrix,Ptr, Inst->getParamAlign(1), Stride,
1274 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1275 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1276 }
1277
1278// Set elements I..I+NumElts-1 to Block
1279Value *insertVector(Value *Col,unsignedI,Value *Block,
1280IRBuilder<> &Builder) {
1281
1282// First, bring Block to the same size as Col
1283unsigned BlockNumElts =
1284 cast<FixedVectorType>(Block->getType())->getNumElements();
1285unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1286assert(NumElts >= BlockNumElts &&"Too few elements for current block");
1287
1288Block = Builder.CreateShuffleVector(
1289Block,createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1290
1291// If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1292// 8, 4, 5, 6
1293SmallVector<int, 16>Mask;
1294unsigned i;
1295for (i = 0; i <I; i++)
1296Mask.push_back(i);
1297
1298unsigned VecNumElts =
1299 cast<FixedVectorType>(Col->getType())->getNumElements();
1300for (; i <I + BlockNumElts; i++)
1301Mask.push_back(i -I + VecNumElts);
1302
1303for (; i < VecNumElts; i++)
1304Mask.push_back(i);
1305
1306return Builder.CreateShuffleVector(Col,Block, Mask);
1307 }
1308
1309Value *createMulAdd(Value *Sum,Value *A,Value *B,bool UseFPOp,
1310IRBuilder<> &Builder,bool AllowContraction,
1311unsigned &NumComputeOps) {
1312 NumComputeOps += getNumOps(A->getType());
1313if (!Sum)
1314return UseFPOp ? Builder.CreateFMul(A,B) : Builder.CreateMul(A,B);
1315
1316if (UseFPOp) {
1317if (AllowContraction) {
1318// Use fmuladd for floating point operations and let the backend decide
1319// if that's profitable.
1320return Builder.CreateIntrinsic(Intrinsic::fmuladd,A->getType(),
1321 {A, B, Sum});
1322 }
1323 NumComputeOps += getNumOps(A->getType());
1324Value *Mul = Builder.CreateFMul(A,B);
1325return Builder.CreateFAdd(Sum,Mul);
1326 }
1327
1328 NumComputeOps += getNumOps(A->getType());
1329Value *Mul = Builder.CreateMul(A,B);
1330return Builder.CreateAdd(Sum,Mul);
1331 }
1332
1333 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1334 /// users with shape information, there's nothing to do: they will use the
1335 /// cached value when they are lowered. For other users, \p Matrix is
1336 /// flattened and the uses are updated to use it. Also marks \p Inst for
1337 /// deletion.
1338void finalizeLowering(Instruction *Inst, MatrixTyMatrix,
1339IRBuilder<> &Builder) {
1340auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst,Matrix));
1341 (void)inserted;
1342assert(inserted.second &&"multiple matrix lowering mapping");
1343
1344ToRemove.push_back(Inst);
1345Value *Flattened =nullptr;
1346for (Use &U :llvm::make_early_inc_range(Inst->uses())) {
1347if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1348if (!Flattened)
1349 Flattened =Matrix.embedInVector(Builder);
1350U.set(Flattened);
1351 }
1352 }
1353 }
1354
1355 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1356 /// vectors. Lowers to a vector reduction add instead of sequential adds if
1357 /// reassociation is enabled.
1358void lowerDotProduct(CallInst *MatMul,
1359SmallPtrSet<Instruction *, 16> &FusedInsts,
1360FastMathFlags FMF) {
1361if (FusedInsts.contains(MatMul) ||
1362MatrixLayout != MatrixLayoutTy::ColumnMajor)
1363return;
1364 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1365 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1366
1367if (LShape.NumRows != 1 || RShape.NumColumns != 1)// not a dot product
1368return;
1369
1370Value *LHS = MatMul->getArgOperand(0);
1371Value *RHS = MatMul->getArgOperand(1);
1372
1373Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1374bool IsIntVec =ElementType->isIntegerTy();
1375
1376// Floating point reductions require reassociation.
1377if (!IsIntVec && !FMF.allowReassoc())
1378return;
1379
1380auto CanBeFlattened = [](Value *Op) {
1381if (match(Op,m_BinOp()))
1382returntrue;
1383returnmatch(
1384Op,m_OneUse(m_CombineOr(
1385m_Load(m_Value()),
1386m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1387 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1388m_Value(),m_SpecificInt(1))))));
1389 };
1390// Returns the cost benefit of using \p Op with the dot product lowering. If
1391// the returned cost is < 0, the argument is cheaper to use in the
1392// dot-product lowering.
1393auto GetCostForArg = [this, &CanBeFlattened](Value *Op,unsignedN) {
1394if (ShapeMap.find(Op) == ShapeMap.end())
1395returnInstructionCost::getInvalid();
1396
1397if (!isa<Instruction>(Op))
1398returnInstructionCost(0);
1399
1400FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1401Type *EltTy = VecTy->getElementType();
1402
1403if (!CanBeFlattened(Op)) {
1404InstructionCost EmbedCost(0);
1405// Roughly estimate the cost for embedding the columns into a vector.
1406for (unsignedI = 1;I <N; ++I)
1407 EmbedCost +=
1408TTI.getShuffleCost(TTI::SK_Splice,FixedVectorType::get(EltTy, 1),
1409 {},TTI::TCK_RecipThroughput);
1410return EmbedCost;
1411 }
1412
1413if (match(Op,m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1414InstructionCost OriginalCost =
1415TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1416 EltTy) *
1417N;
1418InstructionCost NewCost =TTI.getArithmeticInstrCost(
1419 cast<Instruction>(Op)->getOpcode(), VecTy);
1420return NewCost - OriginalCost;
1421 }
1422
1423if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1424// The transpose can be skipped for the dot product lowering, roughly
1425// estimate the savings as the cost of embedding the columns in a
1426// vector.
1427InstructionCost EmbedCost(0);
1428for (unsignedI = 1;I <N; ++I)
1429 EmbedCost -=
1430TTI.getShuffleCost(TTI::SK_Splice,FixedVectorType::get(EltTy, 1),
1431 {},TTI::TCK_RecipThroughput);
1432return EmbedCost;
1433 }
1434
1435// Costs for loads.
1436if (N == 1)
1437returnInstructionCost(0);
1438
1439returnTTI.getMemoryOpCost(Instruction::Load, VecTy,Align(1), 0) -
1440N *TTI.getMemoryOpCost(Instruction::Load, EltTy,Align(1), 0);
1441 };
1442
1443// Iterate over LHS and operations feeding LHS and check if it is profitable
1444// to flatten the visited ops. For each op, we compute the difference
1445// between the flattened and matrix versions.
1446SmallPtrSet<Value *, 4> Seen;
1447SmallVector<Value *> WorkList;
1448SmallVector<Value *> ToFlatten;
1449 WorkList.push_back(LHS);
1450InstructionCost LHSCost(0);
1451while (!WorkList.empty()) {
1452Value *Op = WorkList.pop_back_val();
1453if (!Seen.insert(Op).second)
1454continue;
1455
1456InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1457if (OpCost + LHSCost >= LHSCost)
1458continue;
1459
1460 LHSCost += OpCost;
1461 ToFlatten.push_back(Op);
1462if (auto *I = dyn_cast<Instruction>(Op))
1463 WorkList.append(I->op_begin(),I->op_end());
1464 }
1465
1466// We compare the costs of a vector.reduce.add to sequential add.
1467int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1468int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1469InstructionCost ReductionCost =
1470TTI.getArithmeticReductionCost(
1471 AddOpCode, cast<VectorType>(LHS->getType()),
1472 IsIntVec ? std::nullopt : std::optional(FMF)) +
1473TTI.getArithmeticInstrCost(MulOpCode,LHS->getType());
1474InstructionCost SequentialAddCost =
1475TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1476 (LShape.NumColumns - 1) +
1477TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1478 (LShape.NumColumns);
1479if ((LHSCost + ReductionCost - SequentialAddCost) >InstructionCost(0))
1480return;
1481
1482 FusedInsts.insert(MatMul);
1483IRBuilder<> Builder(MatMul);
1484auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1485this](Value *Op) {
1486// Matmul must be the only user of loads because we don't use LowerLoad
1487// for row vectors (LowerLoad results in scalar loads and shufflevectors
1488// instead of single vector load).
1489if (!CanBeFlattened(Op))
1490return;
1491
1492if (match(Op,m_BinOp())) {
1493auto It = ShapeMap.find(Op);
1494if (It != ShapeMap.end()) {
1495 It->second = It->second.t();
1496return;
1497 }
1498 }
1499
1500 FusedInsts.insert(cast<Instruction>(Op));
1501// If vector uses the builtin load, lower to a LoadInst
1502Value *Arg;
1503if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1504m_Value(Arg)))) {
1505auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1506Op->replaceAllUsesWith(NewLoad);
1507 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
1508return;
1509 }elseif (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1510m_Value(Arg)))) {
1511ToRemove.push_back(cast<Instruction>(Op));
1512Op->replaceAllUsesWith(Arg);
1513return;
1514 }
1515 };
1516
1517for (auto *V : ToFlatten)
1518 FlattenArg(V);
1519
1520LHS = MatMul->getArgOperand(0);
1521
1522// Insert mul/fmul and llvm.vector.reduce.fadd
1523Value *Mul =
1524 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1525
1526Value *Result;
1527if (IsIntVec)
1528Result = Builder.CreateAddReduce(Mul);
1529else {
1530Result = Builder.CreateFAddReduce(
1531 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1532 0.0),
1533Mul);
1534 cast<Instruction>(Result)->setFastMathFlags(FMF);
1535 }
1536
1537// pack scalar back into a matrix and then replace matmul inst
1538Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1539 Result,uint64_t(0));
1540 MatMul->replaceAllUsesWith(Result);
1541 FusedInsts.insert(MatMul);
1542ToRemove.push_back(MatMul);
1543 }
1544
1545 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1546 /// addition.
1547 ///
1548 /// We can fold a transpose into the operand that is used to extract scalars.
1549 /// This is the first operand for row-major and the second for column-major
1550 /// layout. If \p IsScalarMatrixTransposed we assume the appropriate
1551 /// operand is transposed.
1552void emitMatrixMultiply(MatrixTy &Result,const MatrixTy &A,
1553const MatrixTy &B,IRBuilder<> &Builder,bool IsTiled,
1554bool IsScalarMatrixTransposed,FastMathFlags FMF) {
1555constunsigned VF = std::max<unsigned>(
1556TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1557 .getFixedValue() /
1558Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1559 1U);
1560unsignedR =Result.getNumRows();
1561unsignedC =Result.getNumColumns();
1562unsignedM =A.getNumColumns();
1563
1564bool IsFP =Result.getElementType()->isFloatingPointTy();
1565assert(A.isColumnMajor() ==B.isColumnMajor() &&
1566Result.isColumnMajor() ==A.isColumnMajor() &&
1567"operands must agree on matrix layout");
1568unsigned NumComputeOps = 0;
1569
1570 Builder.setFastMathFlags(FMF);
1571
1572if (A.isColumnMajor()) {
1573// Multiply columns from the first operand with scalars from the second
1574// operand. Then move along the K axes and accumulate the columns. With
1575// this the adds can be vectorized without reassociation.
1576for (unsigned J = 0; J <C; ++J) {
1577unsignedBlockSize = VF;
1578// If Result is zero, we don't need to accumulate in the K==0 iteration.
1579bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1580
1581for (unsignedI = 0;I <R;I +=BlockSize) {
1582// Gradually lower the vectorization factor to cover the remainder.
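// E.g. with R == 7 and VF == 4 (illustrative values), the rows are covered
// by blocks of 4, 2 and 1 elements.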
1583while (I +BlockSize > R)
1584BlockSize /= 2;
1585
1586Value *Sum = IsTiled ?Result.extractVector(I, J,BlockSize, Builder)
1587 :nullptr;
1588for (unsigned K = 0;K <M; ++K) {
1589Value *L =A.extractVector(I, K,BlockSize, Builder);
1590Value *RH = Builder.CreateExtractElement(
1591B.getColumn(IsScalarMatrixTransposed ? K : J),
1592 IsScalarMatrixTransposed ? J : K);
1593Value *Splat = Builder.CreateVectorSplat(BlockSize, RH,"splat");
1594 Sum =
1595 createMulAdd(isSumZero && K == 0 ?nullptr : Sum, L,Splat,
1596 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1597 }
1598Result.setVector(J,
1599insertVector(Result.getVector(J),I, Sum, Builder));
1600 }
1601 }
1602 }else {
1603// Multiply rows from the second operand with scalars from the first
1604// operand. Then move along the K axes and accumulate the rows. With this
1605// the adds can be vectorized without reassociation.
1606for (unsignedI = 0;I <R; ++I) {
1607unsignedBlockSize = VF;
1608bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1609for (unsigned J = 0; J <C; J +=BlockSize) {
1610// Gradually lower the vectorization factor to cover the remainder.
1611while (J +BlockSize >C)
1612BlockSize /= 2;
1613
1614Value *Sum =nullptr;
1615for (unsigned K = 0;K <M; ++K) {
1616Value *R =B.extractVector(K, J,BlockSize, Builder);
1617Value *LH = Builder.CreateExtractElement(
1618A.getVector(IsScalarMatrixTransposed ? K :I),
1619 IsScalarMatrixTransposed ?I : K);
1620Value *Splat = Builder.CreateVectorSplat(BlockSize, LH,"splat");
1621 Sum =
1622 createMulAdd(isSumZero && K == 0 ?nullptr : Sum,Splat, R,
1623 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1624 }
1625Result.setVector(I,
1626insertVector(Result.getVector(I), J, Sum, Builder));
1627 }
1628 }
1629 }
1630Result.addNumComputeOps(NumComputeOps);
1631 }
1632
1633 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1634 /// copying it to a new location. Returns the new location if a copy was
1635 /// made, otherwise the original location.
1636Value *getNonAliasingPointer(LoadInst *Load,StoreInst *Store,
1637CallInst *MatMul) {
1638MemoryLocation StoreLoc =MemoryLocation::get(Store);
1639MemoryLocation LoadLoc =MemoryLocation::get(Load);
1640
1641// If we can statically determine noalias we're good.
1642if (AA->isNoAlias(LoadLoc, StoreLoc))
1643returnLoad->getPointerOperand();
1644
1645// Create code to check if the memory locations of the Load and Store
1646// overlap and if they do, copy Load's operand to a new buffer.
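// The generated control flow looks roughly like this (block names match the
// SplitBlock calls below):
//
//   check0:      if (load.begin < store.end) br alias_cont else no_alias
//   alias_cont:  if (store.begin < load.end) br copy       else no_alias
//   copy:        copy the loaded memory into a fresh alloca
//   no_alias:    PHI selecting either the original pointer or the copy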
1647
1648// First, create new blocks for the second part of the check and the copy.
1649BasicBlock *Check0 = MatMul->getParent();
1650// FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1651// DT. Manually collect dominator tree updates, to avoid unnecessary work,
1652// as we adjust Check0 and Check1's branches.
1653SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1654for (BasicBlock *Succ :successors(Check0))
1655 DTUpdates.push_back({DT->Delete, Check0, Succ});
1656
1657BasicBlock *Check1 =
1658SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1659nullptr,"alias_cont");
1660BasicBlock *Copy =
1661SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1662nullptr,"copy");
1663BasicBlock *Fusion =
1664SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1665nullptr,"no_alias");
1666
1667// Check if the loaded memory location begins before the end of the store
1668// location. If the condition holds, they might overlap, otherwise they are
1669// guaranteed to not overlap.
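// This is the standard half-open interval overlap test:
//   [LoadBegin, LoadEnd) and [StoreBegin, StoreEnd) overlap
//   iff LoadBegin < StoreEnd && StoreBegin < LoadEnd.
// The first comparison is emitted here, the second one in alias_cont.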
1670IRBuilder<> Builder(MatMul);
1671 Check0->getTerminator()->eraseFromParent();
1672 Builder.SetInsertPoint(Check0);
1673Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1674Value *StoreBegin = Builder.CreatePtrToInt(
1675const_cast<Value *>(StoreLoc.Ptr), IntPtrTy,"store.begin");
1676Value *StoreEnd = Builder.CreateAdd(
1677 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1678"store.end",true,true);
1679Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1680 IntPtrTy,"load.begin");
1681 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1682 Fusion);
1683
1684// Check if the store begins before the end of the load location. If the
1685// condition holds, they alias, otherwise they are guaranteed to not
1686// overlap.
1687 Check1->getTerminator()->eraseFromParent();
1688 Builder.SetInsertPoint(Check1, Check1->begin());
1689Value *LoadEnd = Builder.CreateAdd(
1690 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1691"load.end",true,true);
1692 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1693 Fusion);
1694
1695// Copy load operand to new alloca.
1696 Builder.SetInsertPoint(Copy,Copy->begin());
1697auto *VT = cast<FixedVectorType>(Load->getType());
1698// Use an array type for the alloca, to avoid potentially huge alignment
1699// requirements for large vector types.
1700auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1701AllocaInst *Alloca =
1702 Builder.CreateAlloca(ArrayTy,Load->getPointerAddressSpace());
1703
1704 Builder.CreateMemCpy(Alloca, Alloca->getAlign(),Load->getPointerOperand(),
1705Load->getAlign(), LoadLoc.Size.getValue());
1706 Builder.SetInsertPoint(Fusion, Fusion->begin());
1707PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1708PHI->addIncoming(Load->getPointerOperand(), Check0);
1709PHI->addIncoming(Load->getPointerOperand(), Check1);
1710PHI->addIncoming(Alloca, Copy);
1711
1712// Adjust DT.
1713 DTUpdates.push_back({DT->Insert, Check0, Check1});
1714 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1715 DTUpdates.push_back({DT->Insert, Check1,Copy});
1716 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1717 DT->applyUpdates(DTUpdates);
1718returnPHI;
1719 }
1720
1721bool isFusionProfitable(CallInst *MatMul) {
1722if (ForceFusion)
1723returntrue;
1724
1725 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1726 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1727
1728constunsignedR = LShape.NumRows;
1729constunsignedC = RShape.NumColumns;
1730constunsignedM = LShape.NumColumns;
1731auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1732
1733constunsigned VF = std::max<unsigned>(
1734TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1735 .getFixedValue() /
1736 EltType->getPrimitiveSizeInBits().getFixedValue(),
1737 1U);
1738
1739// Cost model for tiling
1740//
1741// For tiling to be beneficial, we need reuse either along the R or
1742// the C axis. We vectorize along the R axis so that means at least
1743// 3 elements.
1744// TODO: Also consider cost of copying if operands alias.
1745if (R <= VF &&C == 1)
1746returnfalse;
1747// Then we need enough elements to exceed the number of vector
1748// registers we have. Note that this is an oversimplification since
1749// fusing also takes some extra loads which may exceed the number of
1750// reloads necessary.
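// Illustrative example (numbers not from any particular target): with
// VF == 4, a 16x16 * 16x16 multiply needs ceil(16/4) * 16 == 64 registers
// per operand, 128 in total, which exceeds a typical register file and makes
// tiling worthwhile.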
1751unsigned Op0Regs = (R + VF - 1) / VF * M;
1752unsigned Op1Regs = (M + VF - 1) / VF *C;
1753return Op0Regs + Op1Regs >
1754TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1755 }
1756
1757 MatrixTy getZeroMatrix(Type *EltType,unsigned R,unsignedC) {
1758 MatrixTy Res;
1759auto *ColumType =FixedVectorType::get(EltType, R);
1760for (unsignedI = 0;I <C; ++I)
1761 Res.addVector(ConstantAggregateZero::get(ColumType));
1762return Res;
1763 }
1764
1765void createTiledLoops(CallInst *MatMul,Value *LPtr, ShapeInfo LShape,
1766Value *RPtr, ShapeInfo RShape,StoreInst *Store) {
1767auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1768
1769// Create the main tiling loop nest.
1770TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns,TileSize);
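// In the generated nest, each innermost iteration below loads one
// TileSize x TileSize tile of each operand and accumulates their product
// into the result-tile PHIs created below; the accumulated tile is stored
// once the K loop has finished.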
1771DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1772Instruction *InsertI = cast<Instruction>(MatMul);
1773BasicBlock *Start = InsertI->getParent();
1774BasicBlock *End =
1775SplitBlock(InsertI->getParent(), InsertI, DT, LI,nullptr,"continue");
1776IRBuilder<> Builder(MatMul);
1777BasicBlock *InnerBody = TI.CreateTiledLoops(Start,End, Builder, DTU, *LI);
1778
1779Type *TileVecTy =
1780FixedVectorType::get(MatMul->getType()->getScalarType(),TileSize);
1781 MatrixTy TileResult;
1782// Insert in the inner loop header.
1783 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1784// Create PHI nodes for the result columns to accumulate across iterations.
1785SmallVector<PHINode *, 4> ColumnPhis;
1786for (unsignedI = 0;I <TileSize;I++) {
1787auto *Phi = Builder.CreatePHI(TileVecTy, 2,"result.vec." +Twine(I));
1788Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1789 TI.RowLoop.Header->getSingleSuccessor());
1790 TileResult.addVector(Phi);
1791 ColumnPhis.push_back(Phi);
1792 }
1793
1794// Insert in the inner loop body, which computes
1795// Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1796 Builder.SetInsertPoint(InnerBody->getTerminator());
1797// Load tiles of the operands.
1798 MatrixTyA =
1799 loadMatrix(LPtr, {},false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1800 {TileSize,TileSize}, EltType, Builder);
1801 MatrixTyB =
1802 loadMatrix(RPtr, {},false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1803 {TileSize,TileSize}, EltType, Builder);
1804 emitMatrixMultiply(TileResult,A,B, Builder,true,false,
1805 getFastMathFlags(MatMul));
1806// Store result after the inner loop is done.
1807 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1808 storeMatrix(TileResult,Store->getPointerOperand(),Store->getAlign(),
1809Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1810 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1811
1812for (unsignedI = 0;I < TileResult.getNumVectors();I++)
1813 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1814
1815// Force unrolling of a few iterations of the inner loop, to make sure there
1816// is enough work per iteration.
1817// FIXME: The unroller should make this decision directly instead, but
1818// currently the cost-model is not up to the task.
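// E.g. for LShape.NumColumns == 64 and the default TileSize of 4
// (illustrative values), this requests an unroll count of min(10, 16) == 10.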
1819unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns /TileSize);
1820addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1821"llvm.loop.unroll.count", InnerLoopUnrollCount);
1822 }
1823
1824void emitSIMDTiling(CallInst *MatMul,LoadInst *LoadOp0,LoadInst *LoadOp1,
1825StoreInst *Store,
1826SmallPtrSetImpl<Instruction *> &FusedInsts) {
1827assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1828"Tiling only supported for column-major matrixes at the moment!");
1829if (!isFusionProfitable(MatMul))
1830return;
1831
1832 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1833 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1834
1835constunsignedR = LShape.NumRows;
1836constunsignedC = RShape.NumColumns;
1837constunsignedM = LShape.NumColumns;
1838auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1839
1840Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1841Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1842Value *CPtr =Store->getPointerOperand();
1843
1844if (TileUseLoops && (R %TileSize == 0 &&C %TileSize == 0))
1845 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1846else {
1847IRBuilder<> Builder(Store);
1848for (unsigned J = 0; J <C; J +=TileSize)
1849for (unsignedI = 0;I <R;I +=TileSize) {
1850constunsigned TileR = std::min(R -I,unsigned(TileSize));
1851constunsigned TileC = std::min(C - J,unsigned(TileSize));
1852 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1853
1854for (unsigned K = 0;K <M;K +=TileSize) {
1855constunsigned TileM = std::min(M - K,unsigned(TileSize));
1856 MatrixTyA =
1857 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1858 LShape, Builder.getInt64(I), Builder.getInt64(K),
1859 {TileR, TileM}, EltType, Builder);
1860 MatrixTyB =
1861 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1862 RShape, Builder.getInt64(K), Builder.getInt64(J),
1863 {TileM, TileC}, EltType, Builder);
1864 emitMatrixMultiply(Res,A,B, Builder,true,false,
1865 getFastMathFlags(MatMul));
1866 }
1867 storeMatrix(Res, CPtr,Store->getAlign(),Store->isVolatile(), {R, M},
1868 Builder.getInt64(I), Builder.getInt64(J), EltType,
1869 Builder);
1870 }
1871 }
1872
1873// Mark eliminated instructions as fused and remove them.
1874 FusedInsts.insert(Store);
1875 FusedInsts.insert(MatMul);
1876 eraseFromParentAndRemoveFromShapeMap(Store);
1877 eraseFromParentAndRemoveFromShapeMap(MatMul);
1878if (LoadOp0->hasNUses(0)) {
1879 FusedInsts.insert(LoadOp0);
1880 eraseFromParentAndRemoveFromShapeMap(LoadOp0);
1881 }
1882if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1883 FusedInsts.insert(LoadOp1);
1884 eraseFromParentAndRemoveFromShapeMap(LoadOp1);
1885 }
1886 }
1887
1888 /// Try to lower matrix multiply chains by fusing operations.
1889 ///
1890 /// Call finalizeLowering on lowered instructions. Instructions that are
1891 /// completely eliminated by fusion are added to \p FusedInsts.
1892void
1893 LowerMatrixMultiplyFused(CallInst *MatMul,
1894SmallPtrSetImpl<Instruction *> &FusedInsts,
1895SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
1896if (!FuseMatrix || !DT)
1897return;
1898
1899assert(AA && LI &&"Analyses should be available");
1900
1901Value *A = MatMul->getArgOperand(0);
1902Value *B = MatMul->getArgOperand(1);
1903
1904// We can fold the transpose into the operand that is used to fetch scalars.
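// E.g. with the default column-major layout, for A * transpose(T) the
// scalars of the second operand are read directly out of T, so the
// transpose never has to be materialized.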
1905Value *T;
1906if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1907 ?match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1908 :match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1909IRBuilder<> Builder(MatMul);
1910auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1911 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1912 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1913constunsignedR = LShape.NumRows;
1914constunsignedM = LShape.NumColumns;
1915constunsignedC = RShape.NumColumns;
1916
1917 MatrixTy MA;
1918 MatrixTy MB;
1919
1920Value *Transpose;
1921if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1922 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1923 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1924 Transpose =B;
1925 }else {
1926 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1927 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1928 Transpose =A;
1929 }
1930
1931// Initialize the output
1932 MatrixTyResult(R,C, EltType);
1933
1934 emitMatrixMultiply(Result, MA, MB, Builder,false,true,
1935 getFastMathFlags(MatMul));
1936
1937 FusedInsts.insert(MatMul);
1938if (Transpose->hasOneUse()) {
1939 FusedInsts.insert(cast<Instruction>(Transpose));
1940ToRemove.push_back(cast<Instruction>(Transpose));
1941// TODO: add a fake entry for the folded instruction so that this is
1942// included in the expression in the remark.
1943 Inst2ColumnMatrix[Transpose] = MatrixTy(M,C, EltType);
1944 }
1945 finalizeLowering(MatMul, Result, Builder);
1946return;
1947 }
1948
1949if (!MatMul->hasOneUse() ||MatrixLayout != MatrixLayoutTy::ColumnMajor)
1950return;
1951
1952// Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1953// since the single store user will be lowered as part of this.
1954auto *LoadOp0 = dyn_cast<LoadInst>(A);
1955auto *LoadOp1 = dyn_cast<LoadInst>(B);
1956auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1957if (LoadOp0 && LoadOp1 && Store) {
1958// The store address must dominate the MatMul instruction, otherwise
1959// we create invalid IR.
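// Only simple, side-effect-free address computations are hoisted above the
// multiply; encountering a PHI or anything that may read memory or have
// side effects bails out of the fusion.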
1960SetVector<Value *> WorkList;
1961 WorkList.insert(Store->getOperand(1));
1962SmallVector<Instruction *> ToHoist;
1963for (unsignedI = 0;I != WorkList.size(); ++I) {
1964Value *Current = WorkList[I];
1965auto *CurrI = dyn_cast<Instruction>(Current);
1966if (!CurrI)
1967continue;
1968if (isa<PHINode>(CurrI))
1969return;
1970if (DT->dominates(CurrI, MatMul))
1971continue;
1972if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1973return;
1974 ToHoist.push_back(CurrI);
1975 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1976 }
1977
1978sort(ToHoist, [this](Instruction *A,Instruction *B) {
1979return DT->dominates(A,B);
1980 });
1981for (Instruction *I : ToHoist)
1982I->moveBefore(MatMul->getIterator());
1983
1984// Deal with lifetime.end calls that might be between Load0/Load1 and the
1985// store. To avoid introducing loads to dead objects (i.e. after the
1986// lifetime has been terminated by @llvm.lifetime.end), either sink them
1987// after the store if in the same block, or remove the lifetime.end marker
1988// otherwise. This might pessimize further optimizations, by extending the
1989// lifetime of the object until the function returns, but should be
1990// conservatively correct.
1991MemoryLocation Load0Loc =MemoryLocation::get(LoadOp0);
1992MemoryLocation Load1Loc =MemoryLocation::get(LoadOp1);
1993BasicBlock *StoreParent =Store->getParent();
1994bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
1995 LoadOp1->getParent() == StoreParent;
1996for (unsignedIdx = 0;Idx != LifetimeEnds.size();) {
1997IntrinsicInst *End = LifetimeEnds[Idx];
1998auto Inc =make_scope_exit([&Idx]() {Idx++; });
1999// If the lifetime.end is guaranteed to be before the loads or after the
2000// store, it won't interfere with fusion.
2001if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
2002continue;
2003if (DT->dominates(Store,End))
2004continue;
2005// If all fusable ops are in the same block and the lifetime.end is in a
2006// different block, it won't interfere with fusion.
2007if (FusableOpsInSameBlock &&End->getParent() != StoreParent)
2008continue;
2009
2010// If the loads don't alias the lifetime.end, it won't interfere with
2011// fusion.
2012MemoryLocation EndLoc =MemoryLocation::getForArgument(End, 1,nullptr);
2013if (!EndLoc.Ptr)
2014continue;
2015if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
2016continue;
2017
2018// If both lifetime.end and the store are in the same block, extend the
2019// lifetime until after the store, so the new lifetime covers the loads
2020// we introduce later.
2021if (End->getParent() == StoreParent) {
2022End->moveAfter(Store);
2023continue;
2024 }
2025
2026// Otherwise remove the conflicting lifetime.end marker.
2027ToRemove.push_back(End);
2028std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
2029 LifetimeEnds.pop_back();
2030 Inc.release();
2031 }
2032
2033 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2034return;
2035 }
2036 }
2037
2038 /// Lowers llvm.matrix.multiply.
2039void LowerMultiply(CallInst *MatMul) {
2040IRBuilder<> Builder(MatMul);
2041auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
2042 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2043 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2044
2045const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
2046const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
2047assert(Lhs.getElementType() == Rhs.getElementType() &&
2048"Matrix multiply argument element types do not match.");
2049
2050constunsignedR = LShape.NumRows;
2051constunsignedC = RShape.NumColumns;
2052assert(LShape.NumColumns == RShape.NumRows);
2053
2054// Initialize the output
2055 MatrixTyResult(R,C, EltType);
2056assert(Lhs.getElementType() ==Result.getElementType() &&
2057"Matrix multiply result element type does not match arguments.");
2058
2059 emitMatrixMultiply(Result, Lhs, Rhs, Builder,false,false,
2060 getFastMathFlags(MatMul));
2061 finalizeLowering(MatMul, Result, Builder);
2062 }
2063
2064 /// Lowers llvm.matrix.transpose.
2065void LowerTranspose(CallInst *Inst) {
2066 MatrixTyResult;
2067IRBuilder<> Builder(Inst);
2068Value *InputVal = Inst->getArgOperand(0);
2069VectorType *VectorTy = cast<VectorType>(InputVal->getType());
2070 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
2071 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
2072
2073constunsigned NewNumVecs =
2074 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2075constunsigned NewNumElts =
2076 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
2077
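// For illustration: a column-major 2x3 input is held as 3 vectors of 2
// elements; the transpose produces 2 result vectors of 3 elements, where
// result vector I collects element I of every input vector.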
2078for (unsignedI = 0;I < NewNumVecs; ++I) {
2079// Build a single result vector. First initialize it.
2080Value *ResultVector =PoisonValue::get(
2081FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
2082// Go through the old elements and insert them into the resulting vector.
2083for (auto J :enumerate(InputMatrix.vectors())) {
2084Value *Elt = Builder.CreateExtractElement(J.value(),I);
2085// Row and column indices are transposed.
2086 ResultVector =
2087 Builder.CreateInsertElement(ResultVector, Elt, J.index());
2088 }
2089Result.addVector(ResultVector);
2090 }
2091
2092// TODO: Improve estimate of operations needed for transposes. Currently we
2093// just count the insertelement/extractelement instructions, but do not
2094// account for later simplifications/combines.
2095 finalizeLowering(
2096 Inst,
2097Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2098 .addNumExposedTransposes(1),
2099 Builder);
2100 }
2101
2102 /// Lower load instructions, if shape information is available.
2103bool VisitLoad(LoadInst *Inst,Value *Ptr,IRBuilder<> &Builder) {
2104autoI = ShapeMap.find(Inst);
2105if (I == ShapeMap.end())
2106returnfalse;
2107
2108LowerLoad(Inst,Ptr, Inst->getAlign(),
2109 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2110I->second);
2111returntrue;
2112 }
2113
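 /// Lower store instructions, if shape information is available.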
2114bool VisitStore(StoreInst *Inst,Value *StoredVal,Value *Ptr,
2115IRBuilder<> &Builder) {
2116autoI = ShapeMap.find(StoredVal);
2117if (I == ShapeMap.end())
2118returnfalse;
2119
2120LowerStore(Inst, StoredVal,Ptr, Inst->getAlign(),
2121 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2122I->second);
2123returntrue;
2124 }
2125
2126 /// Lower binary operators, if shape information is available.
2127bool VisitBinaryOperator(BinaryOperator *Inst) {
2128autoI = ShapeMap.find(Inst);
2129if (I == ShapeMap.end())
2130returnfalse;
2131
2132Value *Lhs = Inst->getOperand(0);
2133Value *Rhs = Inst->getOperand(1);
2134
2135IRBuilder<> Builder(Inst);
2136 ShapeInfo &Shape =I->second;
2137
2138 MatrixTyResult;
2139 MatrixTyA = getMatrix(Lhs, Shape, Builder);
2140 MatrixTyB = getMatrix(Rhs, Shape, Builder);
2141assert(A.isColumnMajor() ==B.isColumnMajor() &&
2142Result.isColumnMajor() ==A.isColumnMajor() &&
2143"operands must agree on matrix layout");
2144
2145 Builder.setFastMathFlags(getFastMathFlags(Inst));
2146
2147// Helper to perform binary op on vectors.
2148auto BuildVectorOp = [&Builder, Inst](Value *LHS,Value *RHS) {
2149switch (Inst->getOpcode()) {
2150case Instruction::Add:
2151return Builder.CreateAdd(LHS, RHS);
2152case Instruction::Mul:
2153return Builder.CreateMul(LHS, RHS);
2154case Instruction::Sub:
2155return Builder.CreateSub(LHS, RHS);
2156case Instruction::FAdd:
2157return Builder.CreateFAdd(LHS, RHS);
2158case Instruction::FMul:
2159return Builder.CreateFMul(LHS, RHS);
2160case Instruction::FSub:
2161return Builder.CreateFSub(LHS, RHS);
2162default:
2163llvm_unreachable("Unsupported binary operator for matrix");
2164 }
2165 };
2166
2167for (unsignedI = 0;I < Shape.getNumVectors(); ++I)
2168Result.addVector(BuildVectorOp(A.getVector(I),B.getVector(I)));
2169
2170 finalizeLowering(Inst,
2171Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2172Result.getNumVectors()),
2173 Builder);
2174returntrue;
2175 }
2176
2177 /// Lower unary operators, if shape information is available.
2178bool VisitUnaryOperator(UnaryOperator *Inst) {
2179autoI = ShapeMap.find(Inst);
2180if (I == ShapeMap.end())
2181returnfalse;
2182
2183Value *Op = Inst->getOperand(0);
2184
2185IRBuilder<> Builder(Inst);
2186 ShapeInfo &Shape =I->second;
2187
2188 MatrixTyResult;
2189 MatrixTyM = getMatrix(Op, Shape, Builder);
2190
2191 Builder.setFastMathFlags(getFastMathFlags(Inst));
2192
2193// Helper to perform unary op on vectors.
2194auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2195switch (Inst->getOpcode()) {
2196case Instruction::FNeg:
2197return Builder.CreateFNeg(Op);
2198default:
2199llvm_unreachable("Unsupported unary operator for matrix");
2200 }
2201 };
2202
2203for (unsignedI = 0;I < Shape.getNumVectors(); ++I)
2204Result.addVector(BuildVectorOp(M.getVector(I)));
2205
2206 finalizeLowering(Inst,
2207Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2208Result.getNumVectors()),
2209 Builder);
2210returntrue;
2211 }
2212
2213 /// Helper to linearize a matrix expression tree into a string. Currently
2214 /// matrix expressions are linearized by starting at an expression leaf and
2215 /// linearizing bottom up.
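 ///
 /// The resulting string is embedded in optimization remarks and looks
 /// roughly like the following (illustrative; the exact format is produced by
 /// the write/writeFnName helpers below):
 ///   store(
 ///    multiply.2x6.6x2.double(
 ///     load(addr %A),
 ///     load(addr %B)),
 ///    addr %C)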
2216structExprLinearizer {
2217unsigned LengthToBreak = 100;
2218 std::string Str;
2219raw_string_ostream Stream;
2220unsigned LineLength = 0;
2221constDataLayout &DL;
2222
2223 /// Mapping from instructions to matrixes. It is used to identify
2224 /// matrix instructions.
2225constMapVector<Value *, MatrixTy> &Inst2Matrix;
2226
2227 /// Mapping from values to the leaves of all expressions that the value is
2228 /// part of.
2229constDenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2230
2231 /// Set of matrix expressions in the scope of a given DISubprogram.
2232constSmallSetVector<Value *, 32> &ExprsInSubprogram;
2233
2234 /// Leaf node of the expression to linearize.
2235Value *Leaf;
2236
2237 /// Used to keep track of sub-expressions that get reused while linearizing
2238 /// the expression. Re-used sub-expressions are marked as (reused).
2239SmallPtrSet<Value *, 8> ReusedExprs;
2240
2241 ExprLinearizer(constDataLayout &DL,
2242constMapVector<Value *, MatrixTy> &Inst2Matrix,
2243constDenseMap<Value *,SmallPtrSet<Value *, 2>> &Shared,
2244constSmallSetVector<Value *, 32> &ExprsInSubprogram,
2245Value *Leaf)
2246 : Stream(Str),DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2247 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2248
2249voidindent(unsignedN) {
2250 LineLength +=N;
2251for (unsigned i = 0; i <N; i++)
2252 Stream <<" ";
2253 }
2254
2255void lineBreak() {
2256 Stream <<"\n";
2257 LineLength = 0;
2258 }
2259
2260void maybeIndent(unsigned Indent) {
2261if (LineLength >= LengthToBreak)
2262 lineBreak();
2263
2264if (LineLength == 0)
2265indent(Indent);
2266 }
2267
2268voidwrite(StringRef S) {
2269 LineLength += S.size();
2270 Stream << S;
2271 }
2272
2273Value *getUnderlyingObjectThroughLoads(Value *V) {
2274if (Value *Ptr =getPointerOperand(V))
2275return getUnderlyingObjectThroughLoads(Ptr);
2276elseif (V->getType()->isPointerTy())
2277returngetUnderlyingObject(V);
2278returnV;
2279 }
2280
2281 /// Returns true if \p V is a matrix value in the given subprogram.
2282bool isMatrix(Value *V) const{return ExprsInSubprogram.count(V); }
2283
2284 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2285 /// \p SS.
2286void prettyPrintMatrixType(Value *V,raw_string_ostream &SS) {
2287autoM = Inst2Matrix.find(V);
2288if (M == Inst2Matrix.end())
2289SS <<"unknown";
2290else {
2291SS <<M->second.getNumRows();
2292SS <<"x";
2293SS <<M->second.getNumColumns();
2294 }
2295 }
2296
2297 /// Write the called function name. Handles calls to llvm.matrix.*
2298 /// specially: we write the name, followed by the dimensions of the input
2299 /// matrixes, followed by the scalar type name.
2300void writeFnName(CallInst *CI) {
2301if (!CI->getCalledFunction())
2302write("<no called fn>");
2303else {
2304StringRefName = CI->getCalledFunction()->getName();
2305if (!Name.starts_with("llvm.matrix")) {
2306write(Name);
2307return;
2308 }
2309auto *II = cast<IntrinsicInst>(CI);
2310write(Intrinsic::getBaseName(II->getIntrinsicID())
2311 .drop_front(StringRef("llvm.matrix.").size()));
2312write(".");
2313 std::string Tmp;
2314raw_string_ostreamSS(Tmp);
2315
2316switch (II->getIntrinsicID()) {
2317case Intrinsic::matrix_multiply:
2318 prettyPrintMatrixType(II->getOperand(0), SS);
2319SS <<".";
2320 prettyPrintMatrixType(II->getOperand(1), SS);
2321SS <<"." << *II->getType()->getScalarType();
2322break;
2323case Intrinsic::matrix_transpose:
2324 prettyPrintMatrixType(II->getOperand(0), SS);
2325SS <<"." << *II->getType()->getScalarType();
2326break;
2327case Intrinsic::matrix_column_major_load:
2328 prettyPrintMatrixType(II, SS);
2329SS <<"." << *II->getType()->getScalarType();
2330break;
2331case Intrinsic::matrix_column_major_store:
2332 prettyPrintMatrixType(II->getOperand(0), SS);
2333SS <<"." << *II->getOperand(0)->getType()->getScalarType();
2334break;
2335default:
2336llvm_unreachable("Unhandled case");
2337 }
2338write(Tmp);
2339 }
2340 }
2341
2342unsigned getNumShapeArgs(CallInst *CI) const{
2343if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2344switch (II->getIntrinsicID()) {
2345case Intrinsic::matrix_multiply:
2346return 3;
2347case Intrinsic::matrix_transpose:
2348return 2;
2349case Intrinsic::matrix_column_major_load:
2350case Intrinsic::matrix_column_major_store:
2351return 3;
2352default:
2353return 0;
2354 }
2355 }
2356return 0;
2357 }
2358
2359 /// Special printing for values: for pointers, we print whether they refer to
2360 /// an external (function) address or a stack address; for other values we
2361 /// either print the constant or "scalar"/"matrix".
2362voidwrite(Value *V) {
2363V = getUnderlyingObjectThroughLoads(V);
2364if (V->getType()->isPointerTy()) {
2365if (isa<AllocaInst>(V)) {
2366 Stream <<"stack addr";
2367 LineLength +=StringRef("stack addr").size();
2368 }else {
2369 Stream <<"addr";
2370 LineLength +=StringRef("addr").size();
2371 }
2372if (!V->getName().empty()) {
2373 Stream <<" %" <<V->getName() <<"";
2374 LineLength +=V->getName().size() + 2;
2375 }
2376return;
2377 }
2378
2379 std::string Tmp;
2380raw_string_ostream TmpStream(Tmp);
2381
2382if (auto *CI = dyn_cast<ConstantInt>(V))
2383 TmpStream << CI->getValue();
2384elseif (isa<Constant>(V))
2385 TmpStream <<"constant";
2386else {
2387if (isMatrix(V))
2388 TmpStream <<"matrix";
2389else
2390 TmpStream <<"scalar";
2391 }
2392 Tmp = std::string(StringRef(Tmp).trim());
2393 LineLength += Tmp.size();
2394 Stream << Tmp;
2395 }
2396
2397 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2398 /// Expressions that are re-used multiple times are prefixed with (reused)
2399 /// at the re-used root instruction.
2400void linearizeExpr(Value *Expr,unsigned Indent,bool ParentReused,
2401bool ParentShared) {
2402auto *I = cast<Instruction>(Expr);
2403 maybeIndent(Indent);
2404SmallVector<Value *, 8> Ops;
2405
2406// Is Expr shared with other expression leaves?
2407bool ExprShared =false;
2408
2409// Deal with shared subtrees. Mark them as shared, if required.
2410if (!ParentShared) {
2411autoSI = Shared.find(Expr);
2412assert(SI != Shared.end() &&SI->second.count(Leaf));
2413
2414for (Value *S :SI->second) {
2415if (S == Leaf)
2416continue;
2417DebugLocDL = cast<Instruction>(S)->getDebugLoc();
2418write("shared with remark at line " + std::to_string(DL.getLine()) +
2419" column " + std::to_string(DL.getCol()) +" (");
2420 }
2421 ExprShared =SI->second.size() > 1;
2422 }
2423
2424bool Reused = !ReusedExprs.insert(Expr).second;
2425if (Reused && !ParentReused)
2426write("(reused) ");
2427
2428if (auto *CI = dyn_cast<CallInst>(I)) {
2429 writeFnName(CI);
2430
2431 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2432 }elseif (isa<BitCastInst>(Expr)) {
2433// Special case bitcasts, which are used to materialize matrixes from
2434// non-matrix ops.
2435write("matrix");
2436return;
2437 }else {
2438 Ops.append(I->value_op_begin(),I->value_op_end());
2439write(std::string(I->getOpcodeName()));
2440 }
2441
2442write(std::string("("));
2443
2444unsigned NumOpsToBreak = 1;
2445if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2446 NumOpsToBreak = 2;
2447
2448for (Value *Op : Ops) {
2449if (Ops.size() > NumOpsToBreak)
2450 lineBreak();
2451
2452 maybeIndent(Indent + 1);
2453if (isMatrix(Op))
2454 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2455else
2456write(Op);
2457if (Op != Ops.back())
2458write(", ");
2459 }
2460
2461write(")");
2462 }
2463
2464const std::string &getResult() {
2465return Str;
2466 }
2467 };
2468
2469 /// Generate remarks for matrix operations in a function. To generate remarks
2470 /// for matrix expressions, the following approach is used:
2471 /// 1. Use the inlined-at debug information to group matrix operations to the
2472 /// DISubprograms they are contained in.
2473 /// 2. Collect leaves of matrix expressions (done in
2474 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2475 /// mapping. Leaves are lowered matrix instructions without other matrix
2476 /// users (like stores) in the current subprogram.
2477 /// 3. For each leaf, create a remark containing a linearized version of the
2478 /// matrix expression. The expression is linearized by a recursive
2479 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2480 /// that multiple leaves can share sub-expressions. Shared subexpressions
2481 /// are explicitly marked as shared().
2482structRemarkGenerator {
2483constMapVector<Value *, MatrixTy> &Inst2Matrix;
2484OptimizationRemarkEmitter &ORE;
2485Function &Func;
2486constDataLayout &DL;
2487
2488 RemarkGenerator(constMapVector<Value *, MatrixTy> &Inst2Matrix,
2489OptimizationRemarkEmitter &ORE,Function &Func)
2490 : Inst2Matrix(Inst2Matrix), ORE(ORE),Func(Func),
2491DL(Func.getDataLayout()) {}
2492
2493 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2494 /// instructions in Inst2Matrix returning void or without any users in
2495 /// \p ExprsInSubprogram. Currently that should only include stores.
2496SmallVector<Value *, 4>
2497 getExpressionLeaves(constSmallSetVector<Value *, 32> &ExprsInSubprogram) {
2498SmallVector<Value *, 4> Leaves;
2499for (auto *Expr : ExprsInSubprogram)
2500if (Expr->getType()->isVoidTy() ||
2501 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2502 return ExprsInSubprogram.count(U);
2503 }))
2504 Leaves.push_back(Expr);
2505return Leaves;
2506 }
2507
2508 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2509 /// to all visited expressions in \p Shared. Limit the matrix operations to
2510 /// the ones in \p ExprsInSubprogram.
2511void collectSharedInfo(Value *Leaf,Value *V,
2512constSmallSetVector<Value *, 32> &ExprsInSubprogram,
2513DenseMap<Value *,SmallPtrSet<Value *, 2>> &Shared) {
2514
2515if (!ExprsInSubprogram.count(V))
2516return;
2517
2518 Shared[V].insert(Leaf);
2519
2520for (Value *Op : cast<Instruction>(V)->operand_values())
2521 collectSharedInfo(Leaf,Op, ExprsInSubprogram, Shared);
2522 }
2523
2524 /// Calculate the number of exclusive and shared op counts for expression
2525 /// starting at \p V. Expressions used multiple times are counted once.
2526 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2527 std::pair<OpInfoTy, OpInfoTy>
2528 sumOpInfos(Value *Root,SmallPtrSetImpl<Value *> &ReusedExprs,
2529constSmallSetVector<Value *, 32> &ExprsInSubprogram,
2530DenseMap<Value *,SmallPtrSet<Value *, 2>> &Shared) const{
2531if (!ExprsInSubprogram.count(Root))
2532return {};
2533
2534// Already counted this expression. Stop.
2535if (!ReusedExprs.insert(Root).second)
2536return {};
2537
2538 OpInfoTy SharedCount;
2539 OpInfoTy Count;
2540
2541autoI = Shared.find(Root);
2542auto CM = Inst2Matrix.find(Root);
2543if (I->second.size() == 1)
2544 Count = CM->second.getOpInfo();
2545else
2546 SharedCount = CM->second.getOpInfo();
2547
2548for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2549autoC = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2550 Count +=C.first;
2551 SharedCount +=C.second;
2552 }
2553return {Count, SharedCount};
2554 }
2555
2556void emitRemarks() {
2557if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2558return;
2559
2560// Map matrix operations to their containing subprograms, by traversing
2561// the inlinedAt chain. If the function does not have a DISubprogram, we
2562// only map them to the containing function.
2563MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2564for (constauto &KV : Inst2Matrix) {
2565if (Func.getSubprogram()) {
2566auto *I = cast<Instruction>(KV.first);
2567DILocation *Context =I->getDebugLoc();
2568while (Context) {
2569 Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
2570 KV.first);
2571 Context =DebugLoc(Context).getInlinedAt();
2572 }
2573 }else {
2574 Subprog2Exprs[nullptr].push_back(KV.first);
2575 }
2576 }
2577for (auto &KV : Subprog2Exprs) {
2578SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2579 KV.second.end());
2580auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2581
2582DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2583for (Value *Leaf : Leaves)
2584 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2585
2586// Generate remarks for each leaf.
2587for (auto *L : Leaves) {
2588
2589DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2590DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2591while (Context) {
2592if (getSubprogram(Context->getScope()) == KV.first) {
2593 Loc = Context;
2594break;
2595 }
2596 Context =DebugLoc(Context).getInlinedAt();
2597 }
2598
2599SmallPtrSet<Value *, 8> ReusedExprs;
2600 OpInfoTy Counts, SharedCounts;
2601 std::tie(Counts, SharedCounts) =
2602 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2603
2604OptimizationRemark Rem(DEBUG_TYPE,"matrix-lowered", Loc,
2605 cast<Instruction>(L)->getParent());
2606
2607 Rem <<"Lowered with ";
2608 Rem <<ore::NV("NumStores", Counts.NumStores) <<" stores, "
2609 <<ore::NV("NumLoads", Counts.NumLoads) <<" loads, "
2610 <<ore::NV("NumComputeOps", Counts.NumComputeOps)
2611 <<" compute ops, "
2612 <<ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2613 <<" exposed transposes";
2614
2615if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2616 SharedCounts.NumComputeOps > 0) {
2617 Rem <<",\nadditionally "
2618 <<ore::NV("NumStores", SharedCounts.NumStores) <<" stores, "
2619 <<ore::NV("NumLoads", SharedCounts.NumLoads) <<" loads, "
2620 <<ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2621 <<" compute ops"
2622 <<" are shared with other expressions";
2623 }
2624
2625 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2626 ORE.emit(Rem);
2627 }
2628 }
2629 }
2630
2631 std::string
2632 linearize(Value *L,
2633constDenseMap<Value *,SmallPtrSet<Value *, 2>> &Shared,
2634constSmallSetVector<Value *, 32> &ExprsInSubprogram,
2635constDataLayout &DL) {
2636 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2637 Lin.linearizeExpr(L, 0,false,false);
2638return Lin.getResult();
2639 }
2640 };
2641};
2642}// namespace
2643
2644PreservedAnalysesLowerMatrixIntrinsicsPass::run(Function &F,
2645FunctionAnalysisManager &AM) {
2646auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2647
2648 LowerMatrixIntrinsics LMT(F,TTI, Minimal ?nullptr : &AM);
2649if (LMT.Visit()) {
2650PreservedAnalyses PA;
2651if (!Minimal) {
2652 PA.preserve<LoopAnalysis>();
2653 PA.preserve<DominatorTreeAnalysis>();
2654 }
2655return PA;
2656 }
2657returnPreservedAnalyses::all();
2658}
2659
2660voidLowerMatrixIntrinsicsPass::printPipeline(
2661raw_ostream &OS,function_ref<StringRef(StringRef)> MapClassName2PassName) {
2662static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2663OS, MapClassName2PassName);
2664OS <<'<';
2665if (Minimal)
2666OS <<"minimal";
2667OS <<'>';
2668}
PHI
Rewrite undef for PHI
Definition:AMDGPURewriteUndefForPHI.cpp:100
ToRemove
ReachingDefAnalysis InstSet & ToRemove
Definition:ARMLowOverheadLoops.cpp:531
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition:ARMSLSHardening.cpp:73
AliasAnalysis.h
Alignment.h
getParent
static const Function * getParent(const Value *V)
Definition:BasicAliasAnalysis.cpp:863
BasicBlockUtils.h
BT
BitTracker BT
Definition:BitTracker.cpp:73
B
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
A
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
D
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
CommandLine.h
clEnumValN
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Definition:CommandLine.h:686
DataLayout.h
Idx
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
Definition:DeadArgumentElimination.cpp:353
DebugInfoMetadata.h
Debug.h
LLVM_DEBUG
#define LLVM_DEBUG(...)
Definition:Debug.h:106
DomTreeUpdater.h
Name
std::string Name
Definition:ELFObjHandler.cpp:77
End
bool End
Definition:ELF_riscv.cpp:480
GEP
Hexagon Common GEP
Definition:HexagonCommonGEP.cpp:170
vectors
hexagon Hexagon specific predictive commoning for HVX vectors
Definition:HexagonVectorLoopCarriedReuse.cpp:218
IRBuilder.h
CFG.h
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Function.h
IntrinsicInst.h
users
iv users
Definition:IVUsers.cpp:48
Instructions.h
isZero
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition:Lint.cpp:557
Matrix
Live Register Matrix
Definition:LiveRegMatrix.cpp:44
LoopInfo.h
LoopUtils.h
getSubprogram
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
Definition:LowerMatrixIntrinsics.cpp:94
ForceFusion
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
VerifyShapeInfo
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
isSplat
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
Definition:LowerMatrixIntrinsics.cpp:102
TileUseLoops
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
FuseMatrix
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
m_AnyAdd
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
Definition:LowerMatrixIntrinsics.cpp:116
AllowContractEnabled
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
MatrixLayoutTy
MatrixLayoutTy
Definition:LowerMatrixIntrinsics.cpp:79
MatrixLayoutTy::RowMajor
@ RowMajor
MatrixLayoutTy::ColumnMajor
@ ColumnMajor
m_AnyMul
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
Definition:LowerMatrixIntrinsics.cpp:110
PrintAfterTransposeOpt
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
DEBUG_TYPE
#define DEBUG_TYPE
Definition:LowerMatrixIntrinsics.cpp:53
TileSize
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
MatrixLayout
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
LowerMatrixIntrinsics.h
F
#define F(x, y, z)
Definition:MD5.cpp:55
I
#define I(x, y, z)
Definition:MD5.cpp:58
MatrixBuilder.h
MatrixUtils.h
T1
#define T1
Definition:Mips16ISelLowering.cpp:340
II
uint64_t IntrinsicInst * II
Definition:NVVMIntrRange.cpp:51
OptimizationRemarkEmitter.h
Operation
PowerPC Reduce CR logical Operation
Definition:PPCReduceCRLogicals.cpp:735
PatternMatch.h
PostOrderIterator.h
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
getNumElements
static unsigned getNumElements(Type *Ty)
Definition:SLPVectorizer.cpp:254
extractVector
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
Definition:SROA.cpp:2550
insertVector
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
Definition:SROA.cpp:2572
OS
raw_pwrite_stream & OS
Definition:SampleProfWriter.cpp:51
ScopeExit.h
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
SmallSet.h
This file defines the SmallSet class.
SmallVector.h
This file defines the SmallVector class.
getType
static SymbolRef::Type getType(const Symbol *Sym)
Definition:TapiFile.cpp:39
BlockSize
static const int BlockSize
Definition:TarWriter.cpp:33
Ptr
@ Ptr
Definition:TargetLibraryInfo.cpp:77
TargetTransformInfo.h
This pass exposes codegen information to IR-level passes.
getOpcode
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition:VPlanSLP.cpp:191
ValueTracking.h
VectorUtils.h
LowerStore
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition:X86ISelLowering.cpp:25165
LowerLoad
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
Definition:X86ISelLowering.cpp:25252
RHS
Value * RHS
Definition:X86PartialReduction.cpp:74
LHS
Value * LHS
Definition:X86PartialReduction.cpp:73
Mul
BinaryOperator * Mul
Definition:X86PartialReduction.cpp:68
T
VectorType
Definition:ItaniumDemangle.h:1173
bool
llvm::AAManager
A manager for alias analyses.
Definition:AliasAnalysis.h:933
llvm::AAResults
Definition:AliasAnalysis.h:314
llvm::AAResults::isNoAlias
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
Definition:AliasAnalysis.h:368
llvm::AllocaInst
an instruction to allocate memory on the stack
Definition:Instructions.h:63
llvm::AllocaInst::getAlign
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Definition:Instructions.h:124
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition:PassManager.h:253
llvm::AnalysisManager::getResult
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition:PassManager.h:410
llvm::ArrayRef
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition:ArrayRef.h:41
llvm::BasicBlock
LLVM Basic Block Representation.
Definition:BasicBlock.h:61
llvm::BasicBlock::begin
iterator begin()
Instruction iterator methods.
Definition:BasicBlock.h:451
llvm::BasicBlock::rbegin
reverse_iterator rbegin()
Definition:BasicBlock.h:467
llvm::BasicBlock::reverse_iterator
InstListType::reverse_iterator reverse_iterator
Definition:BasicBlock.h:179
llvm::BasicBlock::rend
reverse_iterator rend()
Definition:BasicBlock.h:469
llvm::BasicBlock::getTerminator
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition:BasicBlock.h:240
llvm::BinaryOperator
Definition:InstrTypes.h:170
llvm::BinaryOperator::getOpcode
BinaryOps getOpcode() const
Definition:InstrTypes.h:370
llvm::CallBase::getCalledFunction
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition:InstrTypes.h:1341
llvm::CallBase::arg_begin
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
Definition:InstrTypes.h:1261
llvm::CallBase::getParamAlign
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition:InstrTypes.h:1748
llvm::CallBase::getArgOperand
Value * getArgOperand(unsigned i) const
Definition:InstrTypes.h:1286
llvm::CallBase::arg_end
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
Definition:InstrTypes.h:1267
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition:Instructions.h:1479
llvm::ConstantAggregateZero::get
static ConstantAggregateZero * get(Type *Ty)
Definition:Constants.cpp:1672
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition:Constants.h:83
llvm::DILocalScope::getSubprogram
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Definition:DebugInfoMetadata.cpp:1051
llvm::DILocation
Debug location.
Definition:DebugInfoMetadata.h:1988
llvm::DIScope
Base class for scope-like contexts.
Definition:DebugInfoMetadata.h:519
llvm::DISubprogram
Subprogram description.
Definition:DebugInfoMetadata.h:1710
llvm::DWARFExpression::Operation
This class represents an Operation in the Expression.
Definition:DWARFExpression.h:32
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition:DataLayout.h:63
llvm::DebugLoc
A debug info location.
Definition:DebugLoc.h:33
llvm::DebugLoc::getInlinedAt
DILocation * getInlinedAt() const
Definition:DebugLoc.cpp:39
llvm::DenseMapBase::find
iterator find(const_arg_type_t< KeyT > Val)
Definition:DenseMap.h:156
llvm::DenseMapBase::erase
bool erase(const KeyT &Val)
Definition:DenseMap.h:321
llvm::DenseMapBase::count
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition:DenseMap.h:152
llvm::DenseMapBase::end
iterator end()
Definition:DenseMap.h:84
llvm::DenseMapBase::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition:DenseMap.h:211
llvm::DenseMap
Definition:DenseMap.h:727
llvm::DomTreeUpdater
Definition:DomTreeUpdater.h:30
llvm::DominatorTreeAnalysis
Analysis pass which computes a DominatorTree.
Definition:Dominators.h:279
llvm::DominatorTreeBase::applyUpdates
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
Definition:GenericDomTree.h:612
llvm::DominatorTreeBase::Delete
static constexpr UpdateKind Delete
Definition:GenericDomTree.h:253
llvm::DominatorTreeBase::Insert
static constexpr UpdateKind Insert
Definition:GenericDomTree.h:252
llvm::DominatorTree
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition:Dominators.h:162
llvm::DominatorTree::dominates
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition:Dominators.cpp:122
llvm::FastMathFlags
Convenience struct for specifying and reasoning about fast-math flags.
Definition:FMF.h:20
llvm::FastMathFlags::setAllowContract
void setAllowContract(bool B=true)
Definition:FMF.h:91
llvm::FastMathFlags::allowReassoc
bool allowReassoc() const
Flag queries.
Definition:FMF.h:65
llvm::FastMathFlags::allowContract
bool allowContract() const
Definition:FMF.h:70
llvm::FixedVectorType
Class to represent fixed width SIMD vectors.
Definition:DerivedTypes.h:563
llvm::FixedVectorType::get
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition:Type.cpp:791
llvm::Function
Definition:Function.h:63
llvm::Function::getIntrinsicID
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
Definition:Function.h:251
llvm::Function::isIntrinsic
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Definition:Function.h:256
llvm::IRBuilderBase::CreateFAddReduce
CallInst * CreateFAddReduce(Value *Acc, Value *Src)
Create a sequential vector fadd reduction intrinsic of the source vector.
Definition:IRBuilder.cpp:402
llvm::IRBuilderBase::CreateICmpULT
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Definition:IRBuilder.h:2286
llvm::IRBuilderBase::CreateFSub
Value * CreateFSub(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition:IRBuilder.h:1595
llvm::IRBuilderBase::CreateInsertElement
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition:IRBuilder.h:2511
llvm::IRBuilderBase::CreateAlloca
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition:IRBuilder.h:1781
llvm::IRBuilderBase::CreateExtractElement
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition:IRBuilder.h:2499
llvm::IRBuilderBase::CreateAlignedLoad
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition:IRBuilder.h:1815
llvm::IRBuilderBase::CreateFAdd
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition:IRBuilder.h:1576
llvm::IRBuilderBase::CreateVectorSplat
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
Definition:IRBuilder.cpp:1163
llvm::IRBuilderBase::CreateAddReduce
CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
Definition:IRBuilder.cpp:412
llvm::IRBuilderBase::getIntPtrTy
IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)
Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...
Definition:IRBuilder.h:594
llvm::IRBuilderBase::setFastMathFlags
void setFastMathFlags(FastMathFlags NewFMF)
Set the fast-math flags to be used with generated fp-math operators.
Definition:IRBuilder.h:330
llvm::IRBuilderBase::CreateGEP
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Definition:IRBuilder.h:1874
llvm::IRBuilderBase::getInt64
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition:IRBuilder.h:510
llvm::IRBuilderBase::CreateIntrinsic
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Definition:IRBuilder.cpp:900
llvm::IRBuilderBase::CreatePHI
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition:IRBuilder.h:2435
llvm::IRBuilderBase::CreateSub
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition:IRBuilder.h:1387
llvm::IRBuilderBase::getIntN
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
Definition:IRBuilder.h:516
llvm::IRBuilderBase::CreateCondBr
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition:IRBuilder.h:1164
llvm::IRBuilderBase::CreateLoad
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition:IRBuilder.h:1798
llvm::IRBuilderBase::CreateShuffleVector
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition:IRBuilder.h:2533
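For illustration, a shuffle with a constant index mask can slice a flat vector. The helper extractColumn0 below is hypothetical and assumes a column-major 2x2 matrix stored as a 4-element vector; it uses the single-operand shuffle overload that takes an integer mask:
```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// Hypothetical helper: pull the first column (lanes 0 and 1) out of a
// 2x2 column-major matrix held in a flat 4-element vector.
static Value *extractColumn0(IRBuilderBase &B, Value *Flat4) {
  return B.CreateShuffleVector(Flat4, ArrayRef<int>{0, 1}, "col0");
}
```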
llvm::IRBuilderBase::CreateAdd
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition:IRBuilder.h:1370
llvm::IRBuilderBase::CreatePtrToInt
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Definition:IRBuilder.h:2142
llvm::IRBuilderBase::SetInsertPoint
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition:IRBuilder.h:199
llvm::IRBuilderBase::CreateAlignedStore
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Definition:IRBuilder.h:1834
llvm::IRBuilderBase::CreateFMul
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Definition:IRBuilder.h:1614
llvm::IRBuilderBase::CreateFNeg
Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition:IRBuilder.h:1742
llvm::IRBuilderBase::CreateMemCpy
CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *TBAAStructTag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)
Create and insert a memcpy between the specified pointers.
Definition:IRBuilder.h:677
llvm::IRBuilderBase::CreateMul
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition:IRBuilder.h:1404
llvm::IRBuilder
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition:IRBuilder.h:2705
llvm::InstructionCost
Definition:InstructionCost.h:29
llvm::InstructionCost::getInvalid
static InstructionCost getInvalid(CostType Val=0)
Definition:InstructionCost.h:73
llvm::Instruction
Definition:Instruction.h:68
llvm::Instruction::setFastMathFlags
void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...
Definition:Instruction.cpp:601
llvm::Instruction::eraseFromParent
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition:Instruction.cpp:94
llvm::Instruction::getFastMathFlags
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
Definition:Instruction.cpp:651
llvm::IntrinsicInst
A wrapper class for inspecting calls to intrinsic functions.
Definition:IntrinsicInst.h:48
llvm::LoadInst
An instruction for reading from memory.
Definition:Instructions.h:176
llvm::LoadInst::isVolatile
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Definition:Instructions.h:205
llvm::LoadInst::getAlign
Align getAlign() const
Return the alignment of the access that is being performed.
Definition:Instructions.h:211
llvm::LocationSize::getValue
TypeSize getValue() const
Definition:MemoryLocation.h:170
llvm::LoopAnalysis
Analysis pass that exposes the LoopInfo for a function.
Definition:LoopInfo.h:566
llvm::LoopInfoBase::getLoopFor
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Definition:GenericLoopInfo.h:606
llvm::LoopInfo
Definition:LoopInfo.h:407
llvm::LowerMatrixIntrinsicsPass::run
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition:LowerMatrixIntrinsics.cpp:2644
llvm::LowerMatrixIntrinsicsPass::printPipeline
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
Definition:LowerMatrixIntrinsics.cpp:2660
llvm::MapVector
This class implements a map that also provides access to all stored values in a deterministic order.
Definition:MapVector.h:36
llvm::MapVector::end
iterator end()
Definition:MapVector.h:71
llvm::MapVector::find
iterator find(const KeyT &Key)
Definition:MapVector.h:167
llvm::MapVector::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition:MapVector.h:141
llvm::MatrixBuilder
Definition:MatrixBuilder.h:33
llvm::MatrixBuilder::CreateMatrixTranspose
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
Definition:MatrixBuilder.h:110
llvm::MatrixBuilder::CreateMatrixMultiply
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrices LHS and RHS.
Definition:MatrixBuilder.h:126
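Since these two MatrixBuilder calls produce exactly the llvm.matrix.* intrinsics this pass lowers, a small sketch may help; the helper emitMulThenTranspose and the 4x2 times 2x3 shapes are made-up examples:
```cpp
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MatrixBuilder.h"

using namespace llvm;

// Hypothetical helper: emit transpose(A * B) for a 4x2 matrix A and a
// 2x3 matrix B. Both calls produce intrinsics that LowerMatrixIntrinsics
// later expands into vector operations.
static Value *emitMulThenTranspose(IRBuilderBase &B, Value *A, Value *BMat) {
  MatrixBuilder MB(B);
  Value *Prod = MB.CreateMatrixMultiply(A, BMat, /*LHSRows=*/4,
                                        /*LHSColumns=*/2, /*RHSColumns=*/3);
  // The product is 4x3, so the transpose takes Rows=4, Columns=3.
  return MB.CreateMatrixTranspose(Prod, /*Rows=*/4, /*Columns=*/3);
}
```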
llvm::MemoryLocation
Representation for a specific memory location.
Definition:MemoryLocation.h:227
llvm::MemoryLocation::get
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
Definition:MemoryLocation.cpp:35
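As a sketch of how such locations feed alias queries during fusion-style transforms (the helper mayOverlap is illustrative, not part of this file):
```cpp
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Hypothetical helper: conservatively report whether a load and a store may
// access overlapping memory, the kind of check needed before reordering
// matrix loads and stores around each other.
static bool mayOverlap(AAResults &AA, LoadInst *LI, StoreInst *SI) {
  MemoryLocation LoadLoc = MemoryLocation::get(LI);
  MemoryLocation StoreLoc = MemoryLocation::get(SI);
  return !AA.isNoAlias(LoadLoc, StoreLoc);
}
```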
llvm::MemoryLocation::Size
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
Definition:MemoryLocation.h:244
llvm::MemoryLocation::Ptr
const Value * Ptr
The address of the start of the location.
Definition:MemoryLocation.h:235
llvm::MemoryLocation::getForArgument
static MemoryLocation getForArgument(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo *TLI)
Return a location representing a particular argument of a call.
Definition:MemoryLocation.cpp:159
llvm::OptimizationRemarkEmitterAnalysis
Definition:OptimizationRemarkEmitter.h:164
llvm::OptimizationRemarkEmitter
The optimization diagnostic interface.
Definition:OptimizationRemarkEmitter.h:32
llvm::OptimizationRemarkEmitter::allowExtraAnalysis
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to produce fewer false positi...
Definition:OptimizationRemarkEmitter.h:97
llvm::OptimizationRemarkEmitter::emit
void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Definition:OptimizationRemarkEmitter.cpp:79
llvm::OptimizationRemark
Diagnostic information for applied optimization remarks.
Definition:DiagnosticInfo.h:762
llvm::PHINode
Definition:Instructions.h:2600
llvm::PoisonValue::get
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition:Constants.cpp:1878
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition:Analysis.h:111
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition:Analysis.h:117
llvm::PreservedAnalyses::preserve
void preserve()
Mark an analysis as preserved.
Definition:Analysis.h:131
llvm::ReversePostOrderTraversal
Definition:PostOrderIterator.h:299
llvm::SetVector
A vector that has set insertion semantics.
Definition:SetVector.h:57
llvm::SetVector::size
size_type size() const
Determine the number of elements in the SetVector.
Definition:SetVector.h:98
llvm::SetVector::count
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition:SetVector.h:264
llvm::SetVector::insert
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition:SetVector.h:162
llvm::SmallPtrSetImplBase::empty
bool empty() const
Definition:SmallPtrSet.h:93
llvm::SmallPtrSetImpl
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition:SmallPtrSet.h:363
llvm::SmallPtrSetImpl::count
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition:SmallPtrSet.h:452
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition:SmallPtrSet.h:384
llvm::SmallPtrSetImpl::contains
bool contains(ConstPtrType Ptr) const
Definition:SmallPtrSet.h:458
llvm::SmallPtrSet
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition:SmallPtrSet.h:519
llvm::SmallSetVector
A SetVector that performs no allocations if smaller than a certain size.
Definition:SetVector.h:370
llvm::SmallSet
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition:SmallSet.h:132
llvm::SmallSet::empty
bool empty() const
Definition:SmallSet.h:168
llvm::SmallSet::erase
bool erase(const T &V)
Definition:SmallSet.h:193
llvm::SmallSet::insert
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition:SmallSet.h:181
llvm::SmallVectorBase::empty
bool empty() const
Definition:SmallVector.h:81
llvm::SmallVectorBase::size
size_t size() const
Definition:SmallVector.h:78
llvm::SmallVectorImpl
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition:SmallVector.h:573
llvm::SmallVectorImpl::pop_back_val
T pop_back_val()
Definition:SmallVector.h:673
llvm::SmallVectorImpl::append
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition:SmallVector.h:683
llvm::SmallVectorTemplateBase::pop_back
void pop_back()
Definition:SmallVector.h:425
llvm::SmallVectorTemplateBase::push_back
void push_back(const T &Elt)
Definition:SmallVector.h:413
llvm::SmallVectorTemplateCommon::end
iterator end()
Definition:SmallVector.h:269
llvm::SmallVectorTemplateCommon::begin
iterator begin()
Definition:SmallVector.h:267
llvm::SmallVectorTemplateCommon::back
reference back()
Definition:SmallVector.h:308
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition:SmallVector.h:1196
llvm::StoreInst
An instruction for storing to memory.
Definition:Instructions.h:292
llvm::StoreInst::getAlign
Align getAlign() const
Definition:Instructions.h:333
llvm::StoreInst::isVolatile
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition:Instructions.h:325
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition:StringRef.h:51
llvm::StringRef::drop_front
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
Definition:StringRef.h:609
llvm::StringRef::size
constexpr size_t size() const
size - Get the string size.
Definition:StringRef.h:150
llvm::TargetIRAnalysis
Analysis pass providing the TargetTransformInfo.
Definition:TargetTransformInfo.h:3194
llvm::TargetTransformInfo
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
Definition:TargetTransformInfo.h:212
llvm::TargetTransformInfo::getRegisterBitWidth
TypeSize getRegisterBitWidth(RegisterKind K) const
Definition:TargetTransformInfo.cpp:776
llvm::TargetTransformInfo::getMemoryOpCost
InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
Definition:TargetTransformInfo.cpp:1125
llvm::TargetTransformInfo::getArithmeticReductionCost
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
Definition:TargetTransformInfo.cpp:1215
llvm::TargetTransformInfo::getRegisterClassForType
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
Definition:TargetTransformInfo.cpp:767
llvm::TargetTransformInfo::TCK_RecipThroughput
@ TCK_RecipThroughput
Reciprocal throughput.
Definition:TargetTransformInfo.h:264
llvm::TargetTransformInfo::getArithmeticInstrCost
InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr, const TargetLibraryInfo *TLibInfo=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
Definition:TargetTransformInfo.cpp:940
llvm::TargetTransformInfo::RGK_FixedWidthVector
@ RGK_FixedWidthVector
Definition:TargetTransformInfo.h:1180
llvm::TargetTransformInfo::getShuffleCost
InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *Tp, ArrayRef< int > Mask={}, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr) const
Definition:TargetTransformInfo.cpp:976
llvm::TargetTransformInfo::getNumberOfRegisters
unsigned getNumberOfRegisters(unsigned ClassID) const
Definition:TargetTransformInfo.cpp:759
llvm::TargetTransformInfo::SK_Splice
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
Definition:TargetTransformInfo.h:1111
llvm::Twine
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition:Twine.h:81
llvm::TypeSize
Definition:TypeSize.h:334
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition:Type.h:45
llvm::Type::getScalarSizeInBits
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
llvm::Type::getPrimitiveSizeInBits
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
llvm::Type::isVoidTy
bool isVoidTy() const
Return true if this is 'void'.
Definition:Type.h:139
llvm::Type::getScalarType
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition:Type.h:355
llvm::UnaryOperator
Definition:InstrTypes.h:100
llvm::UnaryOperator::getOpcode
UnaryOps getOpcode() const
Definition:InstrTypes.h:153
llvm::Use
A Use represents the edge between a Value definition and its users.
Definition:Use.h:43
llvm::User
Definition:User.h:44
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition:User.h:228
llvm::Value
LLVM Value Representation.
Definition:Value.h:74
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition:Value.h:255
llvm::Value::user_begin
user_iterator user_begin()
Definition:Value.h:397
llvm::Value::hasOneUse
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition:Value.h:434
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition:Value.cpp:534
llvm::Value::users
iterator_range< user_iterator > users()
Definition:Value.h:421
llvm::Value::hasNUses
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition:Value.cpp:149
llvm::Value::use_empty
bool use_empty() const
Definition:Value.h:344
llvm::Value::uses
iterator_range< use_iterator > uses()
Definition:Value.h:376
llvm::Value::getName
StringRef getName() const
Return a constant reference to the value's name.
Definition:Value.cpp:309
llvm::VectorType::getElementType
Type * getElementType() const
Definition:DerivedTypes.h:460
llvm::cl::opt
Definition:CommandLine.h:1423
llvm::details::FixedOrScalableQuantity::getFixedValue
constexpr ScalarTy getFixedValue() const
Definition:TypeSize.h:202
llvm::function_ref
An efficient, type-erasing, non-owning reference to a callable.
Definition:STLFunctionalExtras.h:37
llvm::ilist_detail::node_parent_access::getParent
const ParentTy * getParent() const
Definition:ilist_node.h:32
llvm::ilist_node_impl::getIterator
self_iterator getIterator()
Definition:ilist_node.h:132
llvm::iterator_range
A range adaptor for a pair of iterators.
Definition:iterator_range.h:42
llvm::raw_ostream
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition:raw_ostream.h:52
llvm::raw_string_ostream
A raw_ostream that writes to an std::string.
Definition:raw_ostream.h:661
uint64_t
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition:ErrorHandling.h:143
llvm::AArch64PACKey::IB
@ IB
Definition:AArch64BaseInfo.h:876
llvm::ARM_MB::ST
@ ST
Definition:ARMBaseInfo.h:73
llvm::ARM::ProfileKind::M
@ M
llvm::BitmaskEnumDetail::Mask
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition:BitmaskEnum.h:125
llvm::CallingConv::C
@ C
The default llvm calling convention, compatible with C.
Definition:CallingConv.h:34
llvm::Intrinsic::getBaseName
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....
Definition:Intrinsics.cpp:42
llvm::M68k::MemAddrModeKind::U
@ U
llvm::M68k::MemAddrModeKind::V
@ V
llvm::M68k::MemAddrModeKind::K
@ K
llvm::M68k::MemAddrModeKind::L
@ L
llvm::PatternMatch::m_Store
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
Definition:PatternMatch.h:1930
llvm::PatternMatch::m_Add
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
Definition:PatternMatch.h:1102
llvm::PatternMatch::m_BinOp
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
Definition:PatternMatch.h:100
llvm::PatternMatch::m_SpecificInt
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
Definition:PatternMatch.h:982
llvm::PatternMatch::m_FMul
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
Definition:PatternMatch.h:1174
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition:PatternMatch.h:49
llvm::PatternMatch::m_ConstantInt
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition:PatternMatch.h:168
llvm::PatternMatch::m_FAdd
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
Definition:PatternMatch.h:1108
llvm::PatternMatch::m_Mul
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
Definition:PatternMatch.h:1168
llvm::PatternMatch::m_OneUse
OneUse_match< T > m_OneUse(const T &SubPattern)
Definition:PatternMatch.h:67
llvm::PatternMatch::m_Load
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
Definition:PatternMatch.h:1923
llvm::PatternMatch::m_Value
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition:PatternMatch.h:92
llvm::PatternMatch::m_CombineOr
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition:PatternMatch.h:239
llvm::RISCVFenceField::R
@ R
Definition:RISCVBaseInfo.h:373
llvm::SIEncodingFamily::SI
@ SI
Definition:SIDefines.h:36
llvm::SPII::Store
@ Store
Definition:SparcInstrInfo.h:33
llvm::SPII::Load
@ Load
Definition:SparcInstrInfo.h:32
llvm::X86AS::SS
@ SS
Definition:X86.h:212
llvm::X86II::TA
@ TA
Definition:X86BaseInfo.h:738
llvm::cl::Hidden
@ Hidden
Definition:CommandLine.h:137
llvm::cl::values
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
Definition:CommandLine.h:711
llvm::cl::init
initializer< Ty > init(const Ty &Val)
Definition:CommandLine.h:443
llvm::codeview::EncodedFramePtrReg::BasePtr
@ BasePtr
llvm::codeview::FrameCookieKind::Copy
@ Copy
llvm::dxil::ElementType
ElementType
The element type of an SRV or UAV resource.
Definition:DXILABI.h:58
llvm::ms_demangle::IntrinsicFunctionKind::New
@ New
llvm::ms_demangle::QualifierMangleMode::Result
@ Result
llvm::ore::NV
DiagnosticInfoOptimizationBase::Argument NV
Definition:OptimizationRemarkEmitter.h:135
llvm::rdf::Phi
NodeAddr< PhiNode * > Phi
Definition:RDFGraph.h:390
llvm::rdf::Func
NodeAddr< FuncNode * > Func
Definition:RDFGraph.h:393
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition:AddressRanges.h:18
llvm::Offset
@ Offset
Definition:DWP.cpp:480
llvm::PseudoProbeType::Block
@ Block
llvm::size
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition:STLExtras.h:1697
llvm::make_scope_exit
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition:ScopeExit.h:59
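A minimal usage sketch (the flag and function name are invented for illustration):
```cpp
#include "llvm/ADT/ScopeExit.h"

// Hypothetical helper: make_scope_exit runs its callable when the guard is
// destroyed, so the flag is reset on every exit path, including early
// returns.
static void processRegion(bool &InProgress) {
  InProgress = true;
  auto Reset = llvm::make_scope_exit([&] { InProgress = false; });
  // ... work that may return early; InProgress is reset regardless ...
}
```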
llvm::enumerate
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition:STLExtras.h:2448
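For example (hypothetical helper, not from this pass), pairing elements with their indices avoids a manual counter:
```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

// Hypothetical helper: print each value together with its position, e.g.
// when dumping the per-column vectors of a lowered matrix.
static void dumpIndexed(const SmallVectorImpl<int> &Vals) {
  for (const auto &En : enumerate(Vals))
    errs() << En.index() << ": " << En.value() << "\n";
}
```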
llvm::successors
auto successors(const MachineBasicBlock *BB)
Definition:MachineBasicBlock.h:1376
llvm::operator!=
bool operator!=(uint64_t V1, const APInt &V2)
Definition:APInt.h:2082
llvm::make_range
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
Definition:iterator_range.h:77
llvm::operator+=
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & operator+=(DynamicAPInt &A, int64_t B)
Definition:DynamicAPInt.h:518
llvm::getUnderlyingObject
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
Definition:ValueTracking.cpp:6768
llvm::make_early_inc_range
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition:STLExtras.h:657
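A short sketch of the usual pattern (the helper is illustrative): iterating with early increment lets the body erase the current element safely:
```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Hypothetical helper: delete trivially dead, non-terminator instructions.
// make_early_inc_range advances the iterator before the body runs, so
// erasing the current instruction does not invalidate the traversal.
static void eraseTriviallyDead(BasicBlock &BB) {
  for (Instruction &I : make_early_inc_range(BB))
    if (!I.isTerminator() && I.use_empty() && !I.mayHaveSideEffects())
      I.eraseFromParent();
}
```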
llvm::concatenateVectors
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
Definition:VectorUtils.cpp:1095
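A brief sketch (hypothetical helper) of joining two column vectors back into one flat value:
```cpp
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// Hypothetical helper: stitch two equally sized column vectors with the
// same element type into a single flat vector, the inverse of splitting a
// matrix into columns.
static Value *flattenTwoColumns(IRBuilderBase &B, Value *Col0, Value *Col1) {
  Value *Cols[] = {Col0, Col1};
  return concatenateVectors(B, Cols);
}
```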
llvm::operator==
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
Definition:AddressRanges.h:153
llvm::getPointerOperand
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
Definition:Instructions.h:4998
llvm::addStringMetadataToLoop
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
Definition:LoopUtils.cpp:214
llvm::any_of
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition:STLExtras.h:1746
llvm::reverse
auto reverse(ContainerTy &&C)
Definition:STLExtras.h:420
llvm::write
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, OnCuIndexOverflow OverflowOptValue)
Definition:DWP.cpp:625
llvm::sort
void sort(IteratorTy Start, IteratorTy End)
Definition:STLExtras.h:1664
llvm::ComplexDeinterleavingOperation::Splat
@ Splat
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition:Debug.cpp:163
llvm::report_fatal_error
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition:Error.cpp:167
llvm::errs
raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
Definition:raw_ostream.cpp:907
llvm::RecurKind::Add
@ Add
Sum of integers.
llvm::Op
DWARFExpression::Operation Op
Definition:DWARFExpression.cpp:22
llvm::cast
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition:Casting.h:565
llvm::SplitBlock
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Definition:BasicBlockUtils.cpp:1084
llvm::commonAlignment
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition:Alignment.h:212
llvm::VFParamKind::Vector
@ Vector
llvm::createSequentialMask
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
Definition:VectorUtils.cpp:1040
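As an illustrative sketch (the helper name and the window parameters are assumptions), a sequential mask pairs naturally with a shuffle:
```cpp
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// Hypothetical helper: build the mask {2, 3, 4, 5} and use it to extract a
// 4-element window starting at lane 2 from the concatenation of V1 and V2.
static Value *extractWindow(IRBuilderBase &B, Value *V1, Value *V2) {
  SmallVector<int, 16> Mask =
      createSequentialMask(/*Start=*/2, /*NumInts=*/4, /*NumUndefs=*/0);
  return B.CreateShuffleVector(V1, V2, Mask, "window");
}
```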
std::swap
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition:BitVector.h:860
N
#define N
llvm::Align
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition:Alignment.h:39
llvm::BitTracker
Definition:BitTracker.h:35
llvm::MaybeAlign
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition:Alignment.h:117
llvm::PassInfoMixin
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition:PassManager.h:69
llvm::TileInfo
A helper struct to create IR loop nests for tiling in IR of the following form: for ColumnLoop....
Definition:MatrixUtils.h:31
llvm::cl::desc
Definition:CommandLine.h:409
llvm::indent
Definition:raw_ostream.h:781

Generated on Thu Jul 17 2025 17:13:19 for LLVM by doxygen 1.9.6