5 | 5 | #include <ATen/native/LinearAlgebraUtils.h>
6 | 6 | #include <ATen/native/TensorFactories.h>
7 | 7 | #include <ATen/native/mps/OperationUtils.h>
 | 8 | +#include <fmt/format.h>
8 | 9 |
9 | 10 | #ifndef AT_PER_OPERATOR_HEADERS |
10 | 11 | #include <ATen/Functions.h>
26 | 27 | #include <ATen/native/mps/TriangularOps_metallib.h>
27 | 28 | #endif |
28 | 29 |
29 | | -TORCH_IMPL_FUNC(triu_mps_out) |
30 | | -(const Tensor& self, int64_t k, const Tensor& output) {
31 | | -  using namespace mps;
32 | | -  using CachedGraph = MPSUnaryCachedGraph;
33 | | -
34 | | -  if (self.numel() == 0) {
35 | | -    return;
36 | | -  }
37 | | -  auto stream = getCurrentMPSStream();
38 | | -
39 | | -  @autoreleasepool {
40 | | -    std::string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
41 | | -    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
42 | | -      MPSGraphTensor* outputTensor = nil;
43 | | -      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
44 | | -
45 | | -      auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32];
46 | | -
47 | | -      if (k > 0) {
48 | | -        auto diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32];
49 | | -        auto onesTensor = [mpsGraph constantWithScalar:1 shape:inputTensor.shape dataType:MPSDataTypeInt32];
50 | | -        auto maskTensor = [mpsGraph bandPartWithTensor:onesTensor
51 | | -                                        numLowerTensor:minusOneTensor
52 | | -                                        numUpperTensor:diagMinusOneTensor
53 | | -                                                  name:nil];
54 | | -        outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor
55 | | -                                       truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:inputTensor.dataType]
56 | | -                                      falsePredicateTensor:inputTensor
57 | | -                                                      name:nil];
58 | | -      } else {
59 | | -        auto minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32];
60 | | -        outputTensor = [mpsGraph bandPartWithTensor:inputTensor
61 | | -                                     numLowerTensor:minusDiagTensor
62 | | -                                     numUpperTensor:minusOneTensor
63 | | -                                               name:nil];
64 | | -      }
65 | | -
66 | | -      newCachedGraph->inputTensor_ = inputTensor;
67 | | -      newCachedGraph->outputTensor_ = outputTensor;
68 | | -    });
69 | | -
70 | | -    auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
71 | | -    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
72 | | -    runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder);
 | 30 | +template <typename T>
 | 31 | +static std::vector<T> reverse_array(const IntArrayRef& arr) {
 | 32 | +  std::vector<T> rc(arr.size());
 | 33 | +  for (const auto& i : c10::irange(arr.size())) {
 | 34 | +    rc[i] = arr[arr.size() - 1 - i];
73 | 35 |   }
 | 36 | +  return rc;
74 | 37 | } |
75 | 38 |
76 | | -TORCH_IMPL_FUNC(tril_mps_out) |
77 | | -(const Tensor& self, int64_t k, const Tensor& output) {
 | 39 | +static void triu_tril_impl(const Tensor& self, int64_t k, const Tensor& out, const std::string& name) {
78 | 40 |   using namespace mps;
79 | | -  using CachedGraph = MPSUnaryCachedGraph;
80 | | -
81 | 41 |   if (self.numel() == 0) {
82 | 42 |     return;
83 | 43 |   }
84 | | - |
 | 44 | +  auto sizes = reverse_array<uint32_t>(self.sizes());
 | 45 | +  auto inp_strides = reverse_array<int32_t>(self.strides());
 | 46 | +  auto out_strides = reverse_array<int32_t>(out.strides());
 | 47 | +  std::array<int, 2> k_ndim = {int(k), int(self.ndimension())};
 | 48 | +  const bool inplace = self.is_same(out);
 | 49 | +  const auto kernel_name =
 | 50 | +      fmt::format("{}{}_{}_{}", name, inplace ? "_inplace" : "", "int", scalarToMetalTypeString(self));
 | 51 | +  auto triuPSO = lib.getPipelineStateForFunc(kernel_name);
 | 52 | +  uint32_t max_threads_per_group = [triuPSO maxTotalThreadsPerThreadgroup];
85 | 53 |   auto stream = getCurrentMPSStream();
86 | | - |
87 | | -  @autoreleasepool {
88 | | -    std::string key = "tril_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
89 | | -    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
90 | | -      MPSGraphTensor* outputTensor = nil;
91 | | -
92 | | -      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
93 | | -      auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32];
94 | | -
95 | | -      if (k >= 0) {
96 | | -        auto diagTensor = [mpsGraph constantWithScalar:k dataType:MPSDataTypeInt32];
97 | | -        outputTensor = [mpsGraph bandPartWithTensor:inputTensor
98 | | -                                     numLowerTensor:minusOneTensor
99 | | -                                     numUpperTensor:diagTensor
100 | | -                                               name:nil];
 | 54 | +  dispatch_sync_with_rethrow(stream->queue(), ^() {
 | 55 | +    @autoreleasepool {
 | 56 | +      auto computeEncoder = stream->commandEncoder();
 | 57 | +      [computeEncoder setComputePipelineState:triuPSO];
 | 58 | +      if (inplace) {
 | 59 | +        mtl_setArgs(computeEncoder, self, inp_strides, sizes, k_ndim);
101 | 60 |       } else {
102 | | -        auto negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k - 1) dataType:MPSDataTypeInt32];
103 | | -        auto complementTensor = [mpsGraph bandPartWithTensor:inputTensor
104 | | -                                               numLowerTensor:negDiagMinusOneTensor
105 | | -                                               numUpperTensor:minusOneTensor
106 | | -                                                         name:nil];
107 | | -        auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:getMPSDataType(self)];
108 | | -        auto mask = [mpsGraph equalWithPrimaryTensor:complementTensor secondaryTensor:zeroTensor name:nil];
109 | | -        outputTensor = [mpsGraph selectWithPredicateTensor:mask
110 | | -                                       truePredicateTensor:inputTensor
111 | | -                                      falsePredicateTensor:zeroTensor
112 | | -                                                      name:nil];
 | 61 | +        mtl_setArgs(computeEncoder, out, self, out_strides, inp_strides, sizes, k_ndim);
113 | 62 |       }
 | 63 | +      [computeEncoder dispatchThreads:MTLSizeMake(sizes[0], sizes[1], self.numel() / (sizes[0] * sizes[1]))
 | 64 | +               threadsPerThreadgroup:MTLSizeMake(std::min(max_threads_per_group, sizes[0]), 1, 1)];
 | 65 | +    }
 | 66 | +  });
| 67 | +} |
114 | 68 |
115 | | -      newCachedGraph->inputTensor_ = inputTensor;
116 | | -      newCachedGraph->outputTensor_ = outputTensor;
117 | | -    });
118 | | -
119 | | -    auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
120 | | -    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
| 69 | +TORCH_IMPL_FUNC(triu_mps_out) |
 | 70 | +(const Tensor& self, int64_t k, const Tensor& output) {
 | 71 | +  triu_tril_impl(self, k, output, "triu");
| 72 | +} |
121 | 73 |
122 | | -    runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder);
123 | | -  }
| 74 | +TORCH_IMPL_FUNC(tril_mps_out) |
 | 75 | +(const Tensor& self, int64_t k, const Tensor& output) {
 | 76 | +  triu_tril_impl(self, k, output, "tril");
124 | 77 | } |
125 | 78 |
126 | 79 | Tensor tril_indices_mps(int64_t row,