Dealing with Recompilations
Created On: Jul 29, 2025 | Last Updated On: Jul 29, 2025
Recompilations are necessary for torch.compile soundness, but can result in significantly increased compile time. Thus, minimizing recompilations while preserving soundness is essential for reducing compile time.
You can view recompilations and their reasons using tlparse or TORCH_LOGS=recompiles.
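For example, assuming your script is named train.py (a placeholder name), the typical workflow looks roughly like the following sketch: log recompile reasons directly to stderr, or dump structured trace logs and render them with tlparse.

```shell
# Print each recompilation and its guard failures to stderr
TORCH_LOGS=recompiles python train.py

# Alternatively, dump structured trace logs and render an HTML report with tlparse
TORCH_TRACE=/tmp/tracedir python train.py
tlparse /tmp/tracedir
```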
Is Dynamic Shapes Enabled?
In the example below, we recompile due to mismatched shapes:
```python
@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3))
fn(torch.ones(4))
```
Recompiling function fn in /tmp/ipykernel_507/2479206322.py:1 triggered by the following guard failure(s): - 0/0: tensor 'x' size mismatch at index 0. expected 3, actual 4
tensor([2., 2., 2., 2.])
Make sure that the dynamic option of torch.compile is not set to False. The default option, dynamic=None, will only attempt dynamic shapes after the first compilation. You can set dynamic=True to compile as dynamically as possible up front:
```python
@torch.compile(dynamic=True)
def gn(x):
    return x + 1

gn(torch.ones(3))
gn(torch.ones(4))
```
tensor([2., 2., 2., 2.])
For more information on dynamic shapes, including dealing with errors/recompilations due to dynamic shapes, see the dynamic shapes manual.
Wrapping Constants with Tensors
By default, int/float variables are treated as constants and are guarded on their exact value. In the example below, we get a recompilation for each function call.
```python
@torch.compile
def fn(x, c):
    return x + c

for i in range(5):
    fn(torch.ones(i), 0.5 + i)
```
Recompiling function fn in /tmp/ipykernel_507/3647755280.py:1 triggered by the following guard failure(s): - 2/0: c == 0.5 # return x + c # mp/ipykernel_507/3647755280.py:3 in fn
Recompiling function fn in /tmp/ipykernel_507/3647755280.py:1 triggered by the following guard failure(s): - 2/1: tensor 'x' size mismatch at index 0. expected 1, actual 2 - 2/0: c == 0.5 # return x + c # mp/ipykernel_507/3647755280.py:3 in fn
In particular, for LR schedulers, initializing with a constant can lead to recompilations:
```python
mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def gn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(5):
    gn(torch.ones(3, 3))
```
Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
Recompiling function step in /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adam.py:213 triggered by the following guard failure(s): - 7/0: self.param_groups[0]['lr'] == 0.01 # for group in self.param_groups: # optim/adam.py:228 in step
In both examples, we can wrap float variables in tensors in order to prevent recompilations.
```python
# first example
for i in range(5):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
for i in range(5):
    gn(torch.ones(3, 3))
```
Recompiling function fn in /tmp/ipykernel_507/3647755280.py:1 triggered by the following guard failure(s): - 0/0: tensor 'x' size mismatch at index 0. expected 0, actual 1
Recompiling function fn in /tmp/ipykernel_507/3647755280.py:1 triggered by the following guard failure(s): - 0/1: tensor 'x' size mismatch at index 0. expected 1, actual 2 - 0/0: tensor 'x' size mismatch at index 0. expected 0, actual 2
Changing the Cache Size Limit
There is a limit to how many times a function can be recompiled, determined by torch._dynamo.config.cache_size_limit and torch._dynamo.config.accumulated_cache_size_limit (the exact difference between these two values is detailed in torch/_dynamo/cache_size.py). If the Dynamo cache limit is hit, then all future compilation attempts will result in the function being skipped (run eagerly). Dynamo will still attempt to use previously compiled bytecode for future function calls, if the guards pass. Note that in the case of a recompilation limit hit, all nested function calls WILL be skipped (Dynamo will try to use previously compiled bytecode for the nested functions). Dynamo will also issue a warning containing the affected function and which limit was hit.

In the example below, each function call results in a recompile attempt. When we hit the cache size limit (by default, 8), we stop attempting to recompile. (Note that we set dynamic=False for demonstration purposes to force recompilation every time.)
```python
@torch.compile(dynamic=False)
def fn(x):
    return x + 1

# recompile every time due to dynamic=False
for i in range(1, 10):
    fn(torch.ones(i))
```
Recompiling function fn in /tmp/ipykernel_507/3054308037.py:1 triggered by the following guard failure(s): - 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 2
Recompiling function fn in /tmp/ipykernel_507/3054308037.py:1 triggered by the following guard failure(s): - 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 3 - 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 3
Recompiling function fn in /tmp/ipykernel_507/3054308037.py:1 triggered by the following guard failure(s): - 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 4 - 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 4 - 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 4
Recompiling function fn in /tmp/ipykernel_507/3054308037.py:1 triggered by the following guard failure(s): - 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 5 - 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 5 - 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 5 - 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 5
Recompiling function fn in /tmp/ipykernel_507/3054308037.py:1 triggered by the following guard failure(s): - 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 6 - 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 6 - 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 6 - 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 6 - 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 6
Recompiling function fn in /tmp/ipykernel_507/3054308037.py:1 triggered by the following guard failure(s): - 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 7 - 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 7 - 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 7 - 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 7 - 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 7 - 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 7
Recompiling function fn in /tmp/ipykernel_507/3054308037.py:1 triggered by the following guard failure(s): - 8/6: tensor 'x' size mismatch at index 0. expected 7, actual 8 - 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 8 - 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 8 - 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 8 - 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 8 - 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 8 - 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 8
Recompiling function fn in /tmp/ipykernel_507/3054308037.py:1 triggered by the following guard failure(s): - 8/7: tensor 'x' size mismatch at index 0. expected 8, actual 9 - 8/6: tensor 'x' size mismatch at index 0. expected 7, actual 9 - 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 9 - 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 9 - 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 9 - 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 9 - 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 9 - 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 9
torch._dynamo hit config.recompile_limit (8) function: 'fn' (/tmp/ipykernel_507/3054308037.py:1) last reason: 8/7: tensor 'x' size mismatch at index 0. expected 8, actual 9
To log all recompilation reasons, use TORCH_LOGS="recompiles".
To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit. If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit.
```python
torch._dynamo.config.cache_size_limit = 16

@torch.compile(dynamic=False)
def gn(x):
    return x + 1

for i in range(1, 10):
    gn(torch.ones(i))
```
Recompiling function gn in /tmp/ipykernel_507/887097224.py:2 triggered by the following guard failure(s): - 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 2
Recompiling function gn in /tmp/ipykernel_507/887097224.py:2 triggered by the following guard failure(s): - 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 3 - 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 3
Recompiling function gn in /tmp/ipykernel_507/887097224.py:2 triggered by the following guard failure(s): - 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 4 - 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 4 - 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 4
Recompiling function gn in /tmp/ipykernel_507/887097224.py:2 triggered by the following guard failure(s): - 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 5 - 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 5 - 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 5 - 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 5
Recompiling function gn in /tmp/ipykernel_507/887097224.py:2 triggered by the following guard failure(s): - 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 6 - 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 6 - 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 6 - 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 6 - 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 6
Recompiling function gn in /tmp/ipykernel_507/887097224.py:2 triggered by the following guard failure(s): - 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 7 - 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 7 - 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 7 - 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 7 - 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 7 - 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 7
Recompiling function gn in /tmp/ipykernel_507/887097224.py:2 triggered by the following guard failure(s): - 9/6: tensor 'x' size mismatch at index 0. expected 7, actual 8 - 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 8 - 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 8 - 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 8 - 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 8 - 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 8 - 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 8
Recompiling function gn in /tmp/ipykernel_507/887097224.py:2 triggered by the following guard failure(s): - 9/7: tensor 'x' size mismatch at index 0. expected 8, actual 9 - 9/6: tensor 'x' size mismatch at index 0. expected 7, actual 9 - 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 9 - 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 9 - 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 9 - 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 9 - 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 9 - 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 9
Graph Breaking to Reduce Recompilation Costs
If a large graph is recompiling and causing high compile time, you can intentionally introduce a graph break in order to reduce recompilation costs, at the expense of introducing a performance hit.
```python
def very_large_function(x):
    return x + 1

@torch.compile(dynamic=False)
def fn(x, c):
    y = very_large_function(x)  # recompiled every time
    return y + c

for i in range(1, 5):
    fn(torch.ones(3), i)

@torch.compile(dynamic=False)
def gn(x, c):
    y = very_large_function(x)  # compiled only once
    torch._dynamo.graph_break()
    return y + c  # recompiled every time

for i in range(1, 5):
    gn(torch.ones(3), i)
```
Recompiling function fn in /tmp/ipykernel_507/2876112129.py:4 triggered by the following guard failure(s): - 10/0: c == 1 # return y + c # mp/ipykernel_507/2876112129.py:7 in fn
Recompiling function fn in /tmp/ipykernel_507/2876112129.py:4 triggered by the following guard failure(s): - 10/1: c == 2 # return y + c # mp/ipykernel_507/2876112129.py:7 in fn - 10/0: c == 1 # return y + c # mp/ipykernel_507/2876112129.py:7 in fn
Recompiling function fn in /tmp/ipykernel_507/2876112129.py:4 triggered by the following guard failure(s): - 10/2: c == 3 # return y + c # mp/ipykernel_507/2876112129.py:7 in fn - 10/1: c == 2 # return y + c # mp/ipykernel_507/2876112129.py:7 in fn - 10/0: c == 1 # return y + c # mp/ipykernel_507/2876112129.py:7 in fn
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_507/2876112129.py:15 triggered by the following guard failure(s): - 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_507/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_507/2876112129.py:15 triggered by the following guard failure(s): - 12/1: c == 2 # return y + c # recompiled every time # mp/ipykernel_507/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15 - 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_507/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_507/2876112129.py:15 triggered by the following guard failure(s): - 12/2: c == 3 # return y + c # recompiled every time # mp/ipykernel_507/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15 - 12/1: c == 2 # return y + c # recompiled every time # mp/ipykernel_507/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15 - 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_507/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15