Dynamic Compilation Control with torch.compiler.set_stance#
Author: William Wen
torch.compiler.set_stance is a torch.compiler API that enables you to change the behavior of torch.compile across different calls to your model without having to reapply torch.compile to your model.
This recipe provides some examples of how to use torch.compiler.set_stance.
Prerequisites#
torch>=2.6
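You can quickly verify that your installed build is recent enough to provide torch.compiler.set_stance. This is a small sanity-check sketch, not part of the original recipe:

import torch

# torch.compiler.set_stance is available starting with torch 2.6.
print(torch.__version__)
assert hasattr(torch.compiler, "set_stance"), "this recipe requires torch >= 2.6"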
Description#
torch.compiler.set_stance can be used as a decorator, context manager, or raw function to change the behavior of torch.compile across different calls to your model.
In the example below, the "force_eager" stance ignores all torch.compile directives.
import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1


inp = torch.zeros(3)
print(foo(inp))  # compiled, prints 1
tensor([1., 1., 1.])
Sample decorator usage
@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)


print(bar(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])
Sample context manager usage
with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])
Sample raw function usage
torch.compiler.set_stance("force_eager")
print(foo(inp))  # not compiled, prints -1
torch.compiler.set_stance("default")
print(foo(inp))  # compiled, prints 1
tensor([-1., -1., -1.])
tensor([1., 1., 1.])
The torch.compile stance can only be changed outside of any torch.compile region. Attempting to do otherwise will result in an error.
@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)


@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)
Attempt to trace forbidden callable <function set_stance at 0x7f0d27cd0940>

from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 85, in baz
    with torch.compiler.set_stance("force_eager"):

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Attempt to trace forbidden callable <function inner at 0x7f0d4d9292d0>

from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 103, in outer
    return inner(x)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Other stances include:
- "default": The default stance, used for normal compilation.
- "eager_on_recompile": Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used.
- "fail_on_recompile": Raise an error when recompiling a function.
See the torch.compiler.set_stance doc page for more stances and options. More stances/options may also be added in the future.
Examples#
Preventing recompilation#
Some models do not expect any recompilations - for example, you may always have inputs with the same shape. Since recompilations may be expensive, we may wish to error out when we attempt to recompile so we can detect and fix recompilation cases. The "fail_on_recompile" stance can be used for this.
@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py', function name: 'my_big_model', line number: 0
triggered by the following guard failure(s):
- 3/0: tensor 'x' size mismatch at index 0. expected 3, actual 4
If erroring out is too disruptive, we can use "eager_on_recompile" instead, which will cause torch.compile to fall back to eager instead of erroring out. This may be useful if we don't expect recompilations to happen frequently, but when one is required, we'd rather pay the cost of running eagerly over the cost of recompilation.
@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([1., 1., 1.])
Measuring performance gains#
torch.compiler.set_stance can be used to compare eager vs. compiled performance without having to define a separate eager model.
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
eager: 0.00025190401077270505
compiled: 8.70399996638298e-05
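The timed helper above uses CUDA events, so it assumes a CUDA-capable GPU. If you are running on CPU only, a rough wall-clock variant can be substituted; the sketch below uses time.perf_counter and the hypothetical name timed_cpu (it is less precise than CUDA event timing, but the comparison pattern is the same):

import time


def timed_cpu(fn):
    # Rough wall-clock timing; less accurate than CUDA events but has no GPU requirement.
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


with torch.compiler.set_stance("force_eager"):
    print("eager (cpu):", timed_cpu(lambda: my_gigantic_model(*inps))[1])

print("compiled (cpu):", timed_cpu(lambda: my_gigantic_model(*inps))[1])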
Crashing sooner#
Running an eager iteration first before a compiled iteration using the "force_eager" stance can help us catch errors unrelated to torch.compile before attempting a very long compile.
@torch.compile
def my_humongous_model(x):
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)
sin() takes 1 positional argument but 2 were given
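If you use this pattern regularly, the eager sanity check can be folded into a small helper. The sketch below uses the hypothetical name sanity_check_then_run (it is not part of the torch.compiler API): it runs one iteration with the compiler forced off to surface plain eager-mode errors, then calls the compiled model normally.

def sanity_check_then_run(model, *args):
    # Run once with compilation disabled to catch errors unrelated to torch.compile
    # before paying for a potentially long compilation.
    with torch.compiler.set_stance("force_eager"):
        model(*args)
    # The stance is restored when the context manager exits, so this call compiles as usual.
    return model(*args)


@torch.compile
def my_other_model(x):
    return torch.cos(x) + 1


print(sanity_check_then_run(my_other_model, torch.randn(3)))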
Conclusion#
In this recipe, we have learned how to use the torch.compiler.set_stance API to modify the behavior of torch.compile across different calls to a model without needing to reapply it. The recipe demonstrates using torch.compiler.set_stance as a decorator, context manager, or raw function to control compilation stances like "force_eager", "default", "eager_on_recompile", and "fail_on_recompile".
For more information, see the torch.compiler.set_stance API documentation.