Commit a8d6afb

Disabling amp context when invoking compiler (#138659)

Disabling amp context when invoking compiler (#138624)

Fix for #133974

Pull Request resolved: #138624
Approved by: https://github.com/bdhirsh, https://github.com/drisspg

(cherry picked from commit 5942b29)

Co-authored-by: eellison <elias.ellison@gmail.com>

1 parent f31b8bb   commit a8d6afb
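The change is conceptually small: AOTAutograd traces the joint graph while autocast is active, so the casts AMP inserts are already explicit in the graph handed to the backend, and invoking the compiler with autocast still enabled risks applying the policy a second time. A minimal sketch of the pattern the commit adopts, written against a hypothetical backend_compile callable rather than the actual AOTAutograd internals:

# Illustrative sketch only, not the PyTorch source. `backend_compile` is a
# hypothetical stand-in for a backend callable such as aot_config.fw_compiler.
import torch


def compile_without_autocast(backend_compile, graph_module, example_inputs):
    # The traced graph already contains the casts autocast inserted during tracing,
    # so the backend is invoked with autocast disabled to avoid re-applying AMP.
    # torch._C._DisableAutocast is the internal context manager used in the diff below.
    with torch._C._DisableAutocast():
        return backend_compile(graph_module, example_inputs)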

File tree

2 files changed: +63 -24 lines changed

test/inductor/test_cpu_repro.py

Lines changed: 41 additions & 0 deletions

@@ -3941,6 +3941,47 @@ def forward(self, x):
         x = torch.randn(1, 4, 2, 2)
         self.common(fn, (x,))
 
+    @parametrize("is_inference", (True, False))
+    def test_disabled_amp(self, is_inference):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.all_head_size = 12 * 64
+                self.dense = nn.Linear(self.all_head_size, self.all_head_size)
+
+            def forward(self, q, k, v):
+                context_layer = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=None, dropout_p=0.2
+                )
+                context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+                new_context_layer_shape = context_layer.size()[:-2] + (
+                    self.all_head_size,
+                )
+                context_layer = context_layer.view(new_context_layer_shape)
+                return self.dense(context_layer)
+
+        mod = M().to(torch.bfloat16).eval()
+
+        q = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        k = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        v = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        inputs = (
+            q,
+            k,
+            v,
+        )
+        compiler_mode = torch.compile(mod)
+        from torch.nn.attention import sdpa_kernel, SDPBackend
+
+        context = contextlib.nullcontext if not is_inference else torch.no_grad
+        with config.patch(
+            {"fallback_random": True}
+        ), torch.cpu.amp.autocast(), context(), sdpa_kernel(SDPBackend.MATH):
+            torch.manual_seed(0)
+            eager = mod(*inputs)
+            torch.manual_seed(0)
+            self.assertEqual(compiler_mode(*inputs), eager)
+
     @requires_vectorization
     def test_vec_indirect_load_cse_cache(self):
         # https://github.com/pytorch/pytorch/issues/123502
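Note on the test above: it runs the same bfloat16 SDPA module eagerly and under torch.compile inside a torch.cpu.amp.autocast() region and asserts the outputs match. Since scaled_dot_product_attention is called with dropout_p=0.2, the comparison is only meaningful if the eager and compiled runs consume the same RNG stream; as I read it, that is what config.patch({"fallback_random": True}) together with the repeated torch.manual_seed(0) calls provides. A small standalone illustration of the seeding part (not part of the commit):

# Dropout is stochastic, so eager and compiled outputs can only be compared
# exactly when both draws come from the same RNG state. Re-seeding before each
# run reproduces the same mask; Inductor's fallback_random option keeps the
# compiled run on the eager RNG implementation so the streams line up.
import torch
import torch.nn.functional as F

x = torch.ones(8)

torch.manual_seed(0)
first = F.dropout(x, p=0.2, training=True)

torch.manual_seed(0)
second = F.dropout(x, p=0.2, training=True)

assert torch.equal(first, second)  # identical seed -> identical dropout mask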

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py

Lines changed: 22 additions & 24 deletions

@@ -555,7 +555,9 @@ def aot_dispatch_autograd(
             ),
         )
 
-    with track_graph_compiling(aot_config, "forward"):
+    # AMP is already traced out in joint graph. we do not wish to reapply it accidentally
+    # in the compiler.
+    with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
         # flat_args at this point might still be subclasses-
         # make sure to pass the unwrapped fake tensors into the compiler!
         adjusted_flat_args = joint_inputs[0]

@@ -620,7 +622,7 @@ def aot_dispatch_autograd(
     # NB: It's important to compile backwards ahead of time, as this may
     # add extra guards which we need to apply to the Dynamo cache at
     # forwards
-    with track_graph_compiling(aot_config, "backward"):
+    with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
         placeholder_list = fx_placeholder_vals(bw_module)
 
         forward_saved_for_backwards_strides = None

@@ -672,28 +674,24 @@ def aot_dispatch_autograd(
 
         compiled_bw_func = None
         if num_symints_saved_for_bw > 0:
-            context = torch._C._DisableAutocast if disable_amp else nullcontext
-            with context():
-                try:
-                    compiled_bw_func = aot_config.bw_compiler(
-                        bw_module, placeholder_list
-                    )
-                except Exception as e:
-                    exc = e
-                    trace_structured(
-                        "artifact",
-                        metadata_fn=lambda: {
-                            "name": "eager_compile_backwards_failure",
-                            "encoding": "string",
-                        },
-                        payload_fn=lambda: "\n".join(
-                            traceback.format_exception(exc)
-                        ),
-                    )
-                    log.warning(
-                        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
-                        exc_info=True,
-                    )
+            try:
+                compiled_bw_func = aot_config.bw_compiler(
+                    bw_module, placeholder_list
+                )
+            except Exception as e:
+                exc = e
+                trace_structured(
+                    "artifact",
+                    metadata_fn=lambda: {
+                        "name": "eager_compile_backwards_failure",
+                        "encoding": "string",
+                    },
+                    payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
+                )
+                log.warning(
+                    "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
+                    exc_info=True,
+                )
         # Compiled autograd will run the bw_module in the backward pass,
         # so recompilation need happen anyway if the backward pass is ever
         # called.
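The comment added in the first hunk is the heart of the change: by the time the forward and backward compilers run, AMP's behavior has already been traced into the joint graph, so compiling under a live autocast context would layer the casting policy on top of casts that are already in the graph. A rough illustration of what "AMP is already traced out" means, using make_fx as a stand-in for the joint-graph trace (an assumption for illustration, not the actual AOTAutograd code path):

# Tracing while autocast is active records the casts autocast performs as explicit
# nodes in the captured graph. make_fx stands in here for AOTAutograd's joint trace.
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x, w):
    return torch.mm(x, w)


with torch.autocast("cpu", dtype=torch.bfloat16):
    gm = make_fx(f)(torch.randn(8, 8), torch.randn(8, 8))

# The printed graph should contain the bfloat16 casts autocast inserted during
# tracing; running the backend compiler under autocast as well would risk applying
# that same policy a second time, which is what the _DisableAutocast wrapping avoids.
print(gm.graph)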

0 commit comments