Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit e4fe67f

Browse files
Revert "[MPS] Make fused rms_norm traceable (#150661)"
This reverts commit 682f09e. Reverted #150661 on behalf of https://github.com/malfet because the decomp has started to fail again ([comment](#150661 (comment)))
1 parent 32c79da · commit e4fe67f

File tree

6 files changed

+24
-20
lines changed

6 files changed

+24
-20
lines changed

‎aten/src/ATen/native/layer_norm.cpp‎

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
#include <ATen/ops/empty_like.h>
1717
#include <ATen/ops/empty_like_native.h>
1818
#include <ATen/ops/layer_norm_native.h>
19-
#include <ATen/ops/_fused_rms_norm.h>
2019
#include <ATen/ops/native_batch_norm.h>
2120
#include <ATen/ops/native_layer_norm.h>
2221
#include <ATen/ops/native_layer_norm_backward_native.h>
@@ -28,6 +27,7 @@
2827
#endif
2928

3029
#ifdef USE_MPS
30+
#include <ATen/native/mps/operations/RMSNorm.h>
3131
#include <c10/core/GradMode.h>
3232
#endif
3333

@@ -281,7 +281,7 @@ Tensor rms_norm_symint(
281281

282282
if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) {
283283
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
284-
return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val);
284+
return mps::rms_norm_mps_kernel(input.contiguous(), normalized_shape, weight.contiguous(), eps_val);
285285
}
286286
}
287287
#endif
Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
#pragma once
2+
3+
#include <ATen/core/Tensor.h>
4+
#include <c10/core/SymIntArrayRef.h>
5+
6+
namespace at::native::mps {
7+
8+
Tensor rms_norm_mps_kernel(
9+
    const Tensor& input,
10+
    c10::SymIntArrayRef normalized_shape,
11+
    const Tensor& weight,
12+
    const double eps);
13+
14+
} // namespace at::native::mps

‎aten/src/ATen/native/mps/operations/RMSNorm.mm‎

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,24 +4,27 @@
44
#include <ATen/Functions.h>
55
#include <ATen/NativeFunctions.h>
66
#else
7-
#include <ATen/ops/_fused_rms_norm_native.h>
87
#include <ATen/ops/empty_like.h>
98
#endif
109
#include <ATen/native/mps/OperationUtils.h>
10+
#include <ATen/native/mps/operations/RMSNorm.h>
1111
#include <fmt/format.h>
1212

13-
namespace at::native {
14-
using namespace mps;
13+
namespace at::native::mps {
1514

1615
#ifndef PYTORCH_JIT_COMPILE_SHADERS
1716
static auto& lib = MetalShaderLibrary::getBundledLibrary();
1817
#else
1918
#include <ATen/native/mps/RMSNorm_metallib.h>
2019
#endif
2120

22-
Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) {
21+
Tensor rms_norm_mps_kernel(const Tensor& input,
22+
                           c10::SymIntArrayRef normalized_shape,
23+
                           const Tensor& weight,
24+
                           const double eps) {
2325
  TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors");
2426
  auto output = at::empty_like(input);
27+
  const int normalized_ndim = normalized_shape.size();
2528
  const auto input_shape = input.sizes();
2629
  const auto input_ndim = input.dim();
2730
  const int axis = input_ndim - normalized_ndim;
@@ -61,4 +64,4 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c
6164
  return output;
6265
}
6366

64-
} // namespace at::native
67+
} // namespace at::native::mps

‎aten/src/ATen/native/native_functions.yaml‎

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3301,10 +3301,6 @@
33013301
dispatch:
33023302
CompositeImplicitAutograd: rms_norm_symint
33033303

3304-
- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor
3305-
dispatch:
3306-
MPS: _fused_rms_norm_mps
3307-
33083304
- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
33093305
variants: function, method
33103306
dispatch:

‎test/inductor/test_mps_basic.py‎

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -152,14 +152,6 @@ def inc_(x):
152152

153153
        self.common(inc_, (torch.rand(1024),))
154154

155-
    def test_rms_norm_nograd(self):
156-
        # Regression test for https://github.com/pytorch/pytorch/issues/150629
157-
        def fn(x, w):
158-
            with torch.no_grad():
159-
                return torch.nn.functional.rms_norm(x, x.shape, w)
160-
161-
        self.common(fn, (torch.rand(10), torch.ones(10)))
162-
163155

164156
if __name__ == "__main__":
165157
    from torch._dynamo.test_case import run_tests

‎torch/_inductor/lowering.py‎

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2625,7 +2625,6 @@ def is_aligned(x):
26252625
make_fallback(aten.exponential.default, warn=False)  # (fails accuracy on test_torch.py)
26262626
make_fallback(aten._pdist_forward)  # Has decomp. Needs benchmarks
26272627
make_fallback(aten.soft_margin_loss_backward, warn=False)  # py_impl?
2628-
make_fallback(aten._fused_rms_norm, warn=False)  # (MPS-only and faster than decomp)
26292628

26302629

26312630
# 1.5) Easy or Impossible

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp