
DebugMode: Recording Dispatched Operations and Numerical Debugging#

Authors: Pian Pawakapan, Shangdi Yu

What you will learn
  • How to capture dispatched ops for eager and torch.compile runs

  • How to use tensor hashes and stack traces in DebugMode to pinpoint numerical divergence

Prerequisites
  • PyTorch 2.10 or later

Overview#

DebugMode (torch.utils._debug_mode.DebugMode) is a TorchDispatchMode that intercepts PyTorch runtime calls and emits a hierarchical log of operations. It is particularly useful when you need to understand what actually runs, in eager mode or under torch.compile, or when you need to pinpoint numerical divergence between two runs.

Key capabilities:

  • Runtime logging – Records dispatched operations and TorchInductor-compiled Triton kernels.

  • Tensor hashing – Attaches deterministic hashes to inputs/outputs, enabling diffing of runs to locate numerical divergences.

  • Dispatch hooks – Allows registration of custom hooks to annotate calls.

Note

This recipe describes a prototype feature. Prototype features are typically at an early stage for feedback and testing and are subject to change.

Quick start#

The snippet below captures a small eager workload and prints the debug string:

```python
import torch
from torch.utils._debug_mode import DebugMode


def run_once():
    x = torch.randn(8, 8)
    y = torch.randn(8, 8)
    return torch.mm(torch.relu(x), y)


with DebugMode() as debug_mode:
    out = run_once()

print("DebugMode output:")
print(debug_mode.debug_string())
```
```
DebugMode output:
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]
    aten::relu(t: f32[8, 8])  ->  t: f32[8, 8]
    aten::mm(t: f32[8, 8], t: f32[8, 8])  ->  t: f32[8, 8]
```

Getting more metadata#

For most investigations, you'll want to enable stack traces, tensor IDs, and tensor hashing. These features provide metadata to correlate operations back to model code.

DebugMode.log_tensor_hashes decorates the log with hashes for every call. The hash_tensor hash function uses torch.hash_tensor, which returns 0 for tensors whose elements are all the same. The norm hash function uses the norm with p=1. With both functions, especially norm, closeness in tensor numerics corresponds to closeness in hashes, which makes the hashes fairly interpretable. The default hash_fn is norm.
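The two hash functions can be reproduced directly outside of DebugMode. This is a minimal sketch: torch.hash_tensor returns 0 for constant tensors, and the norm hash corresponds to a p=1 vector norm, so numerically close tensors produce close hash values.

```python
import torch

# "hash_tensor" uses torch.hash_tensor, which returns 0 when all elements
# are identical, so it cannot distinguish different constant tensors.
const = torch.full((4, 4), 3.0)
if hasattr(torch, "hash_tensor"):  # available in recent PyTorch builds
    print(torch.hash_tensor(const))

# The "norm" hash is a p=1 vector norm, so tensors that are close
# numerically also have close hashes.
x = torch.ones(4, 4)
h1 = torch.linalg.vector_norm(x, ord=1).item()
h2 = torch.linalg.vector_norm(x + 1e-3, ord=1).item()
print(h1, h2)  # 16.0 vs. roughly 16.016
```

Because the norm hash varies smoothly with the tensor values, a small numerical drift shows up as a small hash difference, while hash_tensor behaves more like a checksum.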

```python
with (
    DebugMode(
        # record_stack_trace is only supported for eager in PyTorch 2.10
        record_stack_trace=True,
        record_ids=True,
    ) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm"],  # this is the default
        hash_inputs=True,
    ),
):
    result = run_once()

print("DebugMode output with more metadata:")
print(debug_mode.debug_string(show_stack_trace=True))
```
```
DebugMode output with more metadata:
    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:59 in run_once, code: x = torch.randn(8, 8)
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t$0: f32[8, 8]  # {'hash': (48.68385210155975,)}
    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:60 in run_once, code: y = torch.randn(8, 8)
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t$1: f32[8, 8]  # {'hash': (54.39866405725479,)}
    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:61 in run_once, code: return torch.mm(torch.relu(x), y)
    aten::relu(t$0: f32[8, 8])  ->  t$2: f32[8, 8]  # {'input_hash': (((48.68385210155975,),), {}), 'hash': (26.45614107651636,)}
    aten::mm(t$2: f32[8, 8], t$1: f32[8, 8])  ->  t$3: f32[8, 8]  # {'input_hash': (((26.45614107651636,), (54.39866405725479,)), {}), 'hash': (128.11866921186447,)}
```

Each line follows op(args) -> outputs. When record_ids is enabled, tensors are suffixed with $<id> and DTensors are labeled dt.

Log Triton kernels#

Though Triton kernels are not dispatched, DebugMode has custom logic that logs their inputs and outputs.

Inductor-generated Triton kernels show up with a [triton] prefix. Pre/post hash annotations report buffer hashes around each kernel call, which is helpful when isolating incorrect kernels.

```python
def f(x):
    return torch.mm(torch.relu(x), x.T)


x = torch.randn(3, 3, device="cuda")

with (
    DebugMode(record_output=True) as debug_mode,
    DebugMode.log_tensor_hashes(hash_inputs=True),
):
    a = torch.compile(f)(x)

print("Triton in DebugMode logs:")
print(debug_mode.debug_string())
```
```
Triton in DebugMode logs:
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((9.95613181591034,), {'dtype': None}), 'hash': 9.95613181591034}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((9.95613181591034, None), {}), 'hash': 9.95613181591034}
    aten::_local_scalar_dense(t: f64[])  ->  9.95613181591034  # {'input_hash': ((9.95613181591034,), {}), 'hash': None}
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((7.037307012826204,), {'dtype': None}), 'hash': 7.037307012826204}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((7.037307012826204, None), {}), 'hash': 7.037307012826204}
    aten::_local_scalar_dense(t: f64[])  ->  7.037307012826204  # {'input_hash': ((7.037307012826204,), {}), 'hash': None}
    [triton] triton_poi_fused_relu_0(in_ptr0=t: f32[3, 3], out_ptr0=t: f32[3, 3], xnumel=9)
    # pre-kernel hashes: {in_ptr0: 9.95613181591034, out_ptr0: 7.037307012826204}
    # post-kernel hashes: {in_ptr0: 9.95613181591034, out_ptr0: 8.674386978149414}
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((9.95613181591034,), {'dtype': None}), 'hash': 9.95613181591034}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((9.95613181591034, None), {}), 'hash': 9.95613181591034}
    aten::_local_scalar_dense(t: f64[])  ->  9.95613181591034  # {'input_hash': ((9.95613181591034,), {}), 'hash': None}
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((8.674386978149414,), {'dtype': None}), 'hash': 8.674386978149414}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((8.674386978149414, None), {}), 'hash': 8.674386978149414}
    aten::_local_scalar_dense(t: f64[])  ->  8.674386978149414  # {'input_hash': ((8.674386978149414,), {}), 'hash': None}
    aten::mm.out(t: f32[3, 3], t: f32[3, 3], out=t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((8.674386978149414, 9.95613181591034), {'out': 3.6893488147419103e+19}), 'hash': 23.458925664424896}
```

Numerical debugging with tensor hashes#

If you have numerical divergence between modes, you can use DebugMode to find where the divergence originates. In the example below, all tensor hashes are identical between eager mode and compiled mode. If any hash were different, that would be where the numerical divergence comes from.

```python
def run_model(model, data, *, compile_with=None):
    if compile_with is not None:
        model = torch.compile(model, backend=compile_with)
    with (
        DebugMode(record_output=True) as dm,
        DebugMode.log_tensor_hashes(hash_inputs=True),
    ):
        dm_out = model(*data)
    return dm, dm_out


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).mm(x.T)


inputs = (torch.randn(4, 4),)
dm_eager, _ = run_model(Toy(), inputs)
dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager")

print("Eager mode:")
print(dm_eager.debug_string())
print("Compiled aot_eager mode:")
print(dm_compiled.debug_string())
```
```
Eager mode:
    aten::relu(t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((11.804871261119843,), {}), 'hash': 6.392621695995331}
    aten::permute(t: f32[4, 4], [1, 0])  ->  t: f32[4, 4]  # {'input_hash': ((11.804871261119843, [None, None]), {}), 'hash': 11.804871261119843}
    aten::mm(t: f32[4, 4], t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((6.392621695995331, 11.804871261119843), {}), 'hash': 11.2179414331913}
Compiled aot_eager mode:
    aten::relu(t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((11.804871261119843,), {}), 'hash': 6.392621695995331}
    aten::permute(t: f32[4, 4], [1, 0])  ->  t: f32[4, 4]  # {'input_hash': ((11.804871261119843, [None, None]), {}), 'hash': 11.804871261119843}
    aten::mm(t: f32[4, 4], t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((6.392621695995331, 11.804871261119843), {}), 'hash': 11.2179414331913}
```

Now let's look at an example where the tensor hashes differ. We intentionally register a wrong decomposition that decomposes cosine into sine. This will cause numerical divergence.

```python
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import get_nop_func


def wrong_decomp(x):
    return torch.sin(x)


decomp_table = {torch.ops.aten.cos.default: wrong_decomp}
backend = aot_autograd(
    fw_compiler=get_nop_func(),
    bw_compiler=get_nop_func(),
    decompositions=decomp_table,
)


def f(x):
    y = x.relu()
    z = torch.cos(x)
    return y + z


x = torch.randn(3, 3)

with DebugMode(record_output=True) as dm_eager, DebugMode.log_tensor_hashes(hash_inputs=True):
    f(x)

with DebugMode(record_output=True) as dm_compiled, DebugMode.log_tensor_hashes(hash_inputs=True):
    torch.compile(f, backend=backend)(x)

print("Eager:")
print(dm_eager.debug_string(show_stack_trace=True))
print()
print("Compiled with wrong decomposition:")
print(dm_compiled.debug_string())
```
```
Eager:
    aten::relu(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((6.847492925822735,), {}), 'hash': 2.654794916510582}
    aten::cos(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((6.847492925822735,), {}), 'hash': 5.881347939372063}
    aten::add.Tensor(t: f32[3, 3], t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((2.654794916510582, 5.881347939372063), {}), 'hash': 8.536142885684967}

Compiled with wrong decomposition:
    aten::relu(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((6.847492925822735,), {}), 'hash': 2.654794916510582}
    aten::sin(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((6.847492925822735,), {}), 'hash': 5.609869331121445}
    aten::add.Tensor(t: f32[3, 3], t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((2.654794916510582, 5.609869331121445), {}), 'hash': 8.264664381742477}
```

In the eager log we have aten::cos, but in the compiled log we have aten::sin. Moreover, the output hashes differ between eager and compiled mode. Diffing the two logs shows that the first numerical divergence appears at the aten::cos call.
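When logs are long, this kind of diff can be automated with a line-by-line comparison of the two debug strings. The helper below is a sketch, not part of the DebugMode API; first_hash_divergence is a hypothetical name.

```python
def first_hash_divergence(eager_log: str, compiled_log: str):
    """Return the first pair of log lines that differ, or None if identical.

    Intended for strings returned by DebugMode.debug_string(); the first
    mismatching pair points at the op where the numerics start to diverge.
    """
    for eager_line, compiled_line in zip(
        eager_log.splitlines(), compiled_log.splitlines()
    ):
        if eager_line != compiled_line:
            return eager_line, compiled_line
    return None


# Stand-in logs (hashes abbreviated) mirroring the example above:
eager = "aten::relu(...)  # hash: 2.654\naten::cos(...)  # hash: 5.881"
compiled = "aten::relu(...)  # hash: 2.654\naten::sin(...)  # hash: 5.609"
print(first_hash_divergence(eager, compiled))
```

A pairwise zip works here because the two logs have the same number of ops; if the op sequences differ in length, a proper diff (for example Python's difflib) is more robust.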

Custom dispatch hooks#

Hooks allow you to annotate each call with custom metadata such as GPU memory usage. log_hook returns a mapping that is rendered inline with the debug string.

```python
MB = 1024 * 1024.0


def memory_hook(func, types, args, kwargs, result):
    mem = torch.cuda.memory_allocated() / MB if torch.cuda.is_available() else 0.0
    peak = torch.cuda.max_memory_allocated() / MB if torch.cuda.is_available() else 0.0
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"}


with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=memory_hook),
):
    run_once()

print("DebugMode output with memory usage:")
print(dm.debug_string())
```
```
DebugMode output with memory usage:
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '14.128 MB'}
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}
    aten::relu(t: f32[8, 8])  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}
    aten::mm(t: f32[8, 8], t: f32[8, 8])  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}
```

Module boundaries#

record_nn_module=True inserts [nn.Mod] markers that show which module executed each set of operations. As of PyTorch 2.10 it only works in eager mode, but support for compiled modes is under development.

```python
class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))


mod = Bar()
inp = torch.randn(4, 4)

with DebugMode(record_nn_module=True, record_output=False) as debug_mode:
    _ = mod(inp)

print("DebugMode output with stack traces and module boundaries:")
print(debug_mode.debug_string(show_stack_trace=True))
```
```
DebugMode output with stack traces and module boundaries:
  [nn.Mod] Bar
    [nn.Mod] Bar.abc
      [nn.Mod] Bar.abc.l1
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
      [nn.Mod] Bar.abc.l2
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
    [nn.Mod] Bar.xyz
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
```

Conclusion#

In this tutorial, we saw how DebugMode gives you a lightweight, runtime-only view of what PyTorch actually executed, whether you are running eager code or compiled graphs. By layering tensor hashing, Triton kernel logging, and custom dispatch hooks, you can quickly track down numerical differences. This is especially helpful when debugging bit-wise equivalence between runs.

Total running time of the script: (0 minutes 2.483 seconds)