Getting Started#
Created On: Jun 16, 2025 | Last Updated On: Jun 16, 2025
Before you read this section, make sure to read the torch.compiler overview.

Let's start by looking at a simple torch.compile example that demonstrates how to use torch.compile for inference. The example uses torch.cos() and torch.sin(), which are examples of pointwise operators: they operate element by element on a tensor. This example might not show significant performance gains, but it should help you form an intuitive understanding of how you can use torch.compile in your own programs.
Note
To run this script, you need to have at least one GPU on your machine. If you do not have a GPU, you can remove the .to(device="cuda:0") code in the snippet below and it will run on CPU. You can also set the device to xpu:0 to run on Intel® GPUs.
```python
import torch

def fn(x):
    a = torch.cos(x)
    b = torch.sin(a)
    return b

new_fn = torch.compile(fn, backend="inductor")
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn(input_tensor)
```
A more famous pointwise operator you might want to use would be something like torch.relu(). Pointwise ops in eager mode are suboptimal because each one needs to read a tensor from memory, make some changes, and then write those changes back. The single most important optimization that inductor performs is fusion. In the example above we can turn 2 reads (x, a) and 2 writes (a, b) into 1 read (x) and 1 write (b), which is crucial especially for newer GPUs where the bottleneck is memory bandwidth (how quickly you can send data to a GPU) rather than compute (how quickly your GPU can crunch floating point operations).
Another major optimization that inductor provides is automatic support for CUDA graphs. CUDA graphs help eliminate the overhead of launching individual kernels from a Python program, which is especially relevant for newer GPUs.
TorchDynamo supports many different backends, but TorchInductor specifically works by generating Triton kernels. Let's save our example above into a file called example.py. We can inspect the generated Triton kernels by running TORCH_COMPILE_DEBUG=1 python example.py. As the script executes, you should see DEBUG messages printed to the terminal. Closer to the end of the log, you should see a path to a folder that contains torchinductor_<your_username>. In that folder, you can find the output_code.py file that contains the generated kernel code, similar to the following:
```python
@pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xnumel = 10000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask, other=0.0)
    tmp1 = tl.cos(tmp0)
    tmp2 = tl.sin(tmp1)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
```
Note
The above code snippet is an example. Depending on your hardware, you might see different code generated.
You can verify that the fusion of cos and sin did actually occur: both operations appear within a single Triton kernel, and the temporary values are held in registers with very fast access rather than being written back to memory.
Read more on Triton's performance here. Because the code is written in Python, it's fairly easy to understand even if you have not written all that many CUDA kernels.
Next, let's try a real model like resnet50 from the PyTorch hub.
```python
import torch

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
opt_model = torch.compile(model, backend="inductor")
opt_model(torch.randn(1, 3, 64, 64))
```
And that is not the only available backend; you can run torch.compiler.list_backends() in a REPL to see all the available backends. Try out cudagraphs next as inspiration.
Using a pretrained model#
PyTorch users frequently leverage pretrained models from transformers or TIMM, and one of the design goals of TorchDynamo and TorchInductor is to work out of the box with any model that people would like to author.
Let’s download a pretrained model directly from the HuggingFace hub and optimizeit:
```python
import torch
from transformers import BertTokenizer, BertModel

# Copy pasted from here https://huggingface.co/bert-base-uncased
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
model = torch.compile(model, backend="inductor")  # This is the only line of code that we changed
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
output = model(**encoded_input)
```
If you remove the to(device="cuda:0") from the model and encoded_input, then TorchInductor will generate C++ kernels optimized for running on your CPU. You can inspect both the Triton and C++ kernels for BERT. They are more complex than the trigonometry example we tried above, but you can similarly skim through them and see if you understand how PyTorch works.
Similarly, let's try out a TIMM example:
```python
import timm
import torch

model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2)
opt_model = torch.compile(model, backend="inductor")
opt_model(torch.randn(64, 3, 7, 7))
```
Next Steps#
In this section, we reviewed a few inference examples and developed a basic understanding of how torch.compile works. Here is what to check out next: