Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 342aa8f

Browse files
malfet authored and pytorchbot committed
[MPS] Reimplement tri[ul] as Metal shaders (#157179)
And add in-place flavor, as it is currently broken for non-contig tensors. Pull Request resolved: #157179. Approved by: https://github.com/dcci (cherry picked from commit a1e4f1f)
1 parent a6c044a, commit 342aa8f

File tree

4 files changed

+157
-85
lines changed

4 files changed

+157
-85
lines changed

‎aten/src/ATen/native/mps/kernels/TriangularOps.metal‎

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,119 @@
11
#include<metal_stdlib>
2+
23
usingnamespacemetal;
4+
5+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6+
// Predicate telling whether the element at (row, col) is kept by a
// triangular selection with diagonal offset `k`.
//   upper == true  : kept when col - row >= k (on or above the k-th diagonal)
//   upper == false : kept when col - row <= k (on or below the k-th diagonal)
// Only the two explicit specializations below are defined.
template <bool upper>
inline bool triul_mask(int row, int col, int k);

template <>
inline bool triul_mask<true>(int row, int col, int k) {
  return col - row >= k;
}

template <>
inline bool triul_mask<false>(int row, int col, int k) {
  return col - row <= k;
}
16+
17+
// Converts a 3-D grid position into a linear element offset.
//
// strides / sizes : per-dimension arrays; index 0 pairs with pos.x and
//                   index 1 with pos.y.
// pos             : grid position of the current thread.
// ndim            : number of valid entries in `strides` / `sizes`.
//
// For ndim < 4 the offset is a direct dot product (pos.z is used only when
// ndim == 3). For ndim >= 4, pos.z is treated as a flattened index over
// dimensions 2..ndim-1 and is unravelled using `sizes`.
template <typename IndexType>
inline IndexType compute_offs(
    constant IndexType* strides,
    constant uint* sizes,
    uint3 pos,
    int ndim) {
  auto offs = pos.x * strides[0] + pos.y * strides[1];
  if (ndim < 4) {
    return ndim == 3 ? offs + pos.z * strides[2] : offs;
  }
  // Peel one dimension at a time off the flattened z index.
  auto idx = pos.z;
  for (int i = 2; i < ndim; ++i) {
    offs += strides[i] * (idx % sizes[i]);
    idx /= sizes[i];
  }
  return offs;
}
34+
35+
// In-place triu/tril: zeroes every element of `self` that falls outside the
// selected triangle and leaves the remaining elements untouched.
//
// self    : tensor data, modified in place.
// strides : per-dimension strides, indexed as expected by compute_offs.
// sizes   : per-dimension sizes, indexed as expected by compute_offs.
// k_ndim  : .x is the diagonal offset, .y is the number of dimensions.
// pos     : thread position; pos.x is passed as the column and pos.y as the
//           row to triul_mask.
template <typename T, typename IndexType, bool upper>
kernel void triul_inplace(
    device T* self,
    constant IndexType* strides,
    constant uint* sizes,
    constant int2& k_ndim,
    uint3 pos [[thread_position_in_grid]]) {
  // Elements inside the triangle keep their value, so there is nothing to do.
  if (triul_mask<upper>(pos.y, pos.x, k_ndim.x)) {
    return;
  }
  auto offs = compute_offs(strides, sizes, pos, k_ndim.y);
  self[offs] = 0;
}
48+
49+
// Out-of-place triu/tril: writes either the input element or zero to `out`,
// depending on whether the element lies inside the selected triangle.
//
// out / inp                 : output and input tensor data.
// out_strides / inp_strides : per-dimension strides of each tensor; the two
//                             offsets are computed independently.
// sizes                     : per-dimension sizes shared by both tensors.
// k_ndim                    : .x is the diagonal offset, .y is the number of
//                             dimensions.
// pos                       : thread position; pos.x is passed as the column
//                             and pos.y as the row to triul_mask.
template <typename T, typename IndexType, bool upper>
kernel void triul(
    device T* out,
    device T* inp,
    constant IndexType* out_strides,
    constant IndexType* inp_strides,
    constant uint* sizes,
    constant int2& k_ndim,
    uint3 pos [[thread_position_in_grid]]) {
  auto out_offs = compute_offs(out_strides, sizes, pos, k_ndim.y);
  if (!triul_mask<upper>(pos.y, pos.x, k_ndim.x)) {
    // Outside the triangle: no need to read the input at all.
    out[out_offs] = 0;
    return;
  }
  auto inp_offs = compute_offs(inp_strides, sizes, pos, k_ndim.y);
  out[out_offs] = inp[inp_offs];
}
66+
67+
// Explicitly instantiates the four triangular kernels for one
// (element type, index type) pair and gives each a host-visible name:
//   triu_inplace_<IDX_TYPE>_<DTYPE>, tril_inplace_<IDX_TYPE>_<DTYPE>,
//   triu_<IDX_TYPE>_<DTYPE>,         tril_<IDX_TYPE>_<DTYPE>
// The invocation supplies the trailing semicolon.
#define INSTANTIATE_TRIUL_KERNELS(DTYPE, IDX_TYPE)                           \
  template [[host_name("triu_inplace_" #IDX_TYPE "_" #DTYPE)]] kernel void   \
  triul_inplace<DTYPE, IDX_TYPE, true>(                                      \
      device DTYPE * self,                                                   \
      constant IDX_TYPE * strides,                                           \
      constant uint * sizes,                                                 \
      constant int2 & k_ndim,                                                \
      uint3 pos [[thread_position_in_grid]]);                                \
  template [[host_name("tril_inplace_" #IDX_TYPE "_" #DTYPE)]] kernel void   \
  triul_inplace<DTYPE, IDX_TYPE, false>(                                     \
      device DTYPE * self,                                                   \
      constant IDX_TYPE * strides,                                           \
      constant uint * sizes,                                                 \
      constant int2 & k_ndim,                                                \
      uint3 pos [[thread_position_in_grid]]);                                \
  template [[host_name("triu_" #IDX_TYPE "_" #DTYPE)]] kernel void           \
  triul<DTYPE, IDX_TYPE, true>(                                              \
      device DTYPE * out,                                                    \
      device DTYPE * inp,                                                    \
      constant IDX_TYPE * out_strides,                                       \
      constant IDX_TYPE * inp_strides,                                       \
      constant uint * sizes,                                                 \
      constant int2 & k_ndim,                                                \
      uint3 pos [[thread_position_in_grid]]);                                \
  template [[host_name("tril_" #IDX_TYPE "_" #DTYPE)]] kernel void           \
  triul<DTYPE, IDX_TYPE, false>(                                             \
      device DTYPE * out,                                                    \
      device DTYPE * inp,                                                    \
      constant IDX_TYPE * out_strides,                                       \
      constant IDX_TYPE * inp_strides,                                       \
      constant uint * sizes,                                                 \
      constant int2 & k_ndim,                                                \
      uint3 pos [[thread_position_in_grid]])
100+
101+
// Scalar floating-point element types; bfloat is only instantiated when the
// Metal language version is at least 3.1.
INSTANTIATE_TRIUL_KERNELS(float, int);
INSTANTIATE_TRIUL_KERNELS(half, int);
#if __METAL_VERSION__ >= 310
INSTANTIATE_TRIUL_KERNELS(bfloat, int);
#endif

// Two-component vector element types.
INSTANTIATE_TRIUL_KERNELS(float2, int);
INSTANTIATE_TRIUL_KERNELS(half2, int);

// Integer and boolean element types.
INSTANTIATE_TRIUL_KERNELS(long, int);
INSTANTIATE_TRIUL_KERNELS(int, int);
INSTANTIATE_TRIUL_KERNELS(short, int);
INSTANTIATE_TRIUL_KERNELS(char, int);
INSTANTIATE_TRIUL_KERNELS(uchar, int);
INSTANTIATE_TRIUL_KERNELS(bool, int);
116+
3117
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4118

5119
// To find the max integer that does not exceed the root of an int64_t variable,

‎aten/src/ATen/native/mps/operations/TriangularOps.mm‎

Lines changed: 36 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include<ATen/native/LinearAlgebraUtils.h>
66
#include<ATen/native/TensorFactories.h>
77
#include<ATen/native/mps/OperationUtils.h>
8+
#include<fmt/format.h>
89

910
#ifndef AT_PER_OPERATOR_HEADERS
1011
#include<ATen/Functions.h>
@@ -26,101 +27,53 @@
2627
#include<ATen/native/mps/TriangularOps_metallib.h>
2728
#endif
2829

29-
TORCH_IMPL_FUNC(triu_mps_out)
30-
(const Tensor& self,int64_t k,const Tensor& output) {
31-
usingnamespacemps;
32-
using CachedGraph = MPSUnaryCachedGraph;
33-
34-
if (self.numel() ==0) {
35-
return;
36-
}
37-
auto stream =getCurrentMPSStream();
38-
39-
@autoreleasepool {
40-
std::string key ="triu_mps_out" +mps::getTensorsStringKey({self}) +":" +std::to_string(k);
41-
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph,auto newCachedGraph) {
42-
MPSGraphTensor* outputTensor =nil;
43-
auto inputTensor =mpsGraphRankedPlaceHolder(mpsGraph, self);
44-
45-
auto minusOneTensor = [mpsGraphconstantWithScalar:-1dataType:MPSDataTypeInt32];
46-
47-
if (k >0) {
48-
auto diagMinusOneTensor = [mpsGraphconstantWithScalar:(k -1)dataType:MPSDataTypeInt32];
49-
auto onesTensor = [mpsGraphconstantWithScalar:1shape:inputTensor.shapedataType:MPSDataTypeInt32];
50-
auto maskTensor = [mpsGraphbandPartWithTensor:onesTensor
51-
numLowerTensor:minusOneTensor
52-
numUpperTensor:diagMinusOneTensor
53-
name:nil];
54-
outputTensor = [mpsGraphselectWithPredicateTensor:maskTensor
55-
truePredicateTensor:[mpsGraphconstantWithScalar:0dataType:inputTensor.dataType]
56-
falsePredicateTensor:inputTensor
57-
name:nil];
58-
}else {
59-
auto minusDiagTensor = [mpsGraphconstantWithScalar:(-k)dataType:MPSDataTypeInt32];
60-
outputTensor = [mpsGraphbandPartWithTensor:inputTensor
61-
numLowerTensor:minusDiagTensor
62-
numUpperTensor:minusOneTensor
63-
name:nil];
64-
}
65-
66-
newCachedGraph->inputTensor_ = inputTensor;
67-
newCachedGraph->outputTensor_ = outputTensor;
68-
});
69-
70-
auto selfPlaceholder =Placeholder(cachedGraph->inputTensor_, self);
71-
auto outputPlaceholder =Placeholder(cachedGraph->outputTensor_, output);
72-
runMPSGraph(stream, cachedGraph->graph(),dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder);
30+
// Returns a copy of `arr` with its elements in reverse order, each converted
// to T. The conversion from the array's 64-bit elements is a deliberate
// narrowing, so it is spelled out with static_cast.
template <typename T>
static std::vector<T> reverse_array(const IntArrayRef& arr) {
  const auto count = arr.size();
  std::vector<T> rc(count);
  for (const auto i : c10::irange(count)) {
    rc[i] = static_cast<T>(arr[count - 1 - i]);
  }
  return rc;
}
7538

76-
TORCH_IMPL_FUNC(tril_mps_out)
77-
(const Tensor& self,int64_t k,const Tensor& output) {
39+
// Shared implementation of triu/tril on MPS using the Metal kernels.
//
// self : input tensor.
// k    : diagonal offset.
// out  : output tensor; when it is the same tensor as `self` the in-place
//        kernel variant is dispatched.
// name : "triu" or "tril"; used as the prefix of the kernel name.
//
// Sizes and strides are passed to the kernel in reversed order, so index 0
// is the last dimension (columns) and index 1 the one before it (rows).
static void triu_tril_impl(const Tensor& self, int64_t k, const Tensor& out, const std::string& name) {
  using namespace mps;
  if (self.numel() == 0) {
    return;
  }
  auto sizes = reverse_array<uint32_t>(self.sizes());
  auto inp_strides = reverse_array<int32_t>(self.strides());
  auto out_strides = reverse_array<int32_t>(out.strides());
  const auto cols = static_cast<int64_t>(sizes[0]);
  const auto rows = static_cast<int64_t>(sizes[1]);
  // The kernel takes the diagonal as a 32-bit int. col - row always lies in
  // (-rows, cols), so clamping k to [-rows, cols] yields the same mask while
  // preventing a large 64-bit diagonal from wrapping when narrowed.
  const int64_t clamped_k = std::max(-rows, std::min(k, cols));
  std::array<int, 2> k_ndim = {static_cast<int>(clamped_k), static_cast<int>(self.ndimension())};
  const bool inplace = self.is_same(out);
  const auto kernel_name =
      fmt::format("{}{}_{}_{}", name, inplace ? "_inplace" : "", "int", scalarToMetalTypeString(self));
  auto triuPSO = lib.getPipelineStateForFunc(kernel_name);
  uint32_t max_threads_per_group = [triuPSO maxTotalThreadsPerThreadgroup];
  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto computeEncoder = stream->commandEncoder();
      [computeEncoder setComputePipelineState:triuPSO];
      if (inplace) {
        mtl_setArgs(computeEncoder, self, inp_strides, sizes, k_ndim);
      } else {
        mtl_setArgs(computeEncoder, out, self, out_strides, inp_strides, sizes, k_ndim);
      }
      // Grid is (cols, rows, number of matrices); the matrix count uses a
      // 64-bit product so rows * cols cannot wrap in 32-bit arithmetic.
      [computeEncoder dispatchThreads:MTLSizeMake(sizes[0], sizes[1], self.numel() / (rows * cols))
                threadsPerThreadgroup:MTLSizeMake(std::min(max_threads_per_group, sizes[0]), 1, 1)];
    }
  });
}
11468

115-
newCachedGraph->inputTensor_ = inputTensor;
116-
newCachedGraph->outputTensor_ = outputTensor;
117-
});
118-
119-
auto selfPlaceholder =Placeholder(cachedGraph->inputTensor_, self);
120-
auto outputPlaceholder =Placeholder(cachedGraph->outputTensor_, output);
69+
// triu on MPS: forwards to the shared implementation, selecting the "triu"
// kernel family.
TORCH_IMPL_FUNC(triu_mps_out)
(const Tensor& self, int64_t k, const Tensor& output) {
  triu_tril_impl(self, k, output, "triu");
}
12173

122-
runMPSGraph(stream, cachedGraph->graph(),dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder);
123-
}
74+
// tril on MPS: forwards to the shared implementation, selecting the "tril"
// kernel family.
TORCH_IMPL_FUNC(tril_mps_out)
(const Tensor& self, int64_t k, const Tensor& output) {
  triu_tril_impl(self, k, output, "tril");
}
12578

12679
Tensortril_indices_mps(int64_t row,

‎test/test_mps.py‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7146,6 +7146,11 @@ def helper(shape, diag=0):
71467146
helper((2, 8, 4, 5), diag=-1)
71477147
helper((2, 8, 4, 5), diag=-2)
71487148
helper((2, 8, 4, 5), diag=-3)
7149+
# Test inplace
7150+
x_mps = torch.arange(9.0, device='mps').reshape(3, 3).t().triu()
7151+
x_cpu = torch.arange(9.0, device='cpu').reshape(3, 3).t().triu()
7152+
self.assertEqual(x_cpu, x_mps)
7153+
self.assertEqual(x_cpu.stride(), x_mps.stride())
71497154

71507155
# Test inverse
71517156
def test_inverse(self):

‎torch/testing/_internal/common_mps.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def mps_ops_modifier(
157157
"tensor_split",
158158
"transpose",
159159
"transpose_copy",
160+
"tril",
161+
"triu",
160162
"true_divide",
161163
"T",
162164
"unbind",
@@ -283,8 +285,6 @@ def mps_ops_modifier(
283285
"trace",
284286
"trapz",
285287
"trapezoid",
286-
"tril",
287-
"triu",
288288
"vstack",
289289
"where",
290290
"byte",

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp