

Reducing torch.compile cold start compilation time with regional compilation

Created On: Oct 10, 2024 | Last Updated: Oct 16, 2024 | Last Verified: Oct 10, 2024

Author: Animesh Jain

As deep learning models get larger, the compilation time of these models also increases. This extended compilation time can result in a large startup time in inference services or wasted resources in large-scale training. This recipe shows an example of how to reduce the cold start compilation time by choosing to compile a repeated region of the model instead of the entire model.

Prerequisites

  • PyTorch 2.5 or later

Setup

Before we begin, we need to install torch if it is not already available.

pip install torch

Note

This feature is available starting with the 2.5 release. If you are using version 2.4, you can enable the configuration flag torch._dynamo.config.inline_inbuilt_nn_modules = True to prevent recompilations during regional compilation. In version 2.5, this flag is enabled by default.
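For example, on a 2.4 installation the flag can be set near the top of the script; the version check below is only an illustration, and on 2.5 and later no action is needed:

import torch

# Only needed on PyTorch 2.4; inlining of built-in nn.Modules is on by default in 2.5+.
if torch.__version__.startswith("2.4"):
    torch._dynamo.config.inline_inbuilt_nn_modules = True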

from time import perf_counter

Steps

In this recipe, we will follow these steps:

  1. Import all necessary libraries.

  2. Define and initialize a neural network with repeated regions.

  3. Understand the difference between the full model and the regional compilation.

  4. Measure the compilation time of the full model and the regional compilation.

First, let’s import the necessary libraries:

import torch
import torch.nn as nn

Next, let’s define and initialize a neural network with repeated regions.

Typically, neural networks are composed of repeated layers. For example, a large language model is composed of many Transformer blocks. In this recipe, we will create a Layer using the nn.Module class as a proxy for a repeated region. We will then create a Model which is composed of 64 instances of this Layer class.

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList(
                [torch.compile(Layer()) for _ in range(64)]
            )
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside of the scope of `torch.compile`.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

Next, let’s review the difference between the full model and the regional compilation.

In full model compilation, the entire model is compiled as a whole. This is the common approach most users take with torch.compile. In this example, we apply torch.compile to the Model object. This will effectively inline the 64 layers, producing a large graph to compile. You can look at the full graph by running this recipe with TORCH_LOGS=graph_code.

model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)
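For example, if the code in this recipe is saved as a script (your_script.py is just a placeholder name), the generated graph code can be printed with:

TORCH_LOGS=graph_code python your_script.py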

Regional compilation, on the other hand, compiles a region of the model. By strategically choosing to compile a repeated region of the model, we can compile a much smaller graph and then reuse the compiled graph for all the regions. In this example, torch.compile is applied only to the layers and not the full model.

regional_compiled_model = Model(apply_regional_compilation=True).cuda()

Applying compilation to a repeated region, instead of the full model, leads to large savings in compile time. Here, we compile just a single Layer instance and then reuse it 64 times in the Model object.

Note that with repeated regions, some part of the model might not be compiled. For example, the self.linear in the Model is outside of the scope of regional compilation.
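One way to sanity-check this split is to inspect the submodules: torch.compile wraps an nn.Module in an OptimizedModule wrapper. The class used below lives in an internal torch._dynamo namespace, so treat this as an illustrative check rather than a stable API:

from torch._dynamo.eval_frame import OptimizedModule

# The repeated layers are wrapped by torch.compile, but self.linear is left eager.
print(isinstance(regional_compiled_model.layers[0], OptimizedModule))  # True
print(isinstance(regional_compiled_model.linear, OptimizedModule))     # False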

Also, note that there is a tradeoff between performance speedup and compile time. Full model compilation involves a larger graph and, theoretically, offers more scope for optimizations. However, for practical purposes and depending on the model, we have observed many cases with minimal speedup differences between the full model and regional compilation.

Next, let’s measure the compilation time of the full model and the regional compilation.

torch.compile is a JIT compiler, which means that it compiles on the first invocation. In the code below, we measure the total time spent in the first invocation. While this method is not precise, it provides a good estimate since the majority of the time is spent in compilation.

def measure_latency(fn, input):
    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start


input = torch.randn(10, 10, device="cuda")

full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:321: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
Full model compilation time = 11.16 seconds
Regional compilation time = 0.90 seconds
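To see how the runtime tradeoff mentioned earlier plays out for your own model, you can also compare steady-state latency once both variants are warm. The helper below is a rough sketch, not part of the original recipe; it reuses the input tensor and compiled models defined above:

def measure_warm_latency(fn, input, iters=100):
    # One extra call so any remaining compilation happens outside the timed region.
    fn(input)
    torch.cuda.synchronize()
    start = perf_counter()
    for _ in range(iters):
        fn(input)
    torch.cuda.synchronize()
    return (perf_counter() - start) / iters


print(f"Full model warm latency = {measure_warm_latency(full_compiled_model, input):.6f} seconds")
print(f"Regional compilation warm latency = {measure_warm_latency(regional_compiled_model, input):.6f} seconds")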

Conclusion

This recipe shows how to control the cold start compilation time if your model has repeated regions. This approach requires user modifications to apply torch.compile to the repeated regions instead of the more commonly used full model compilation. We are continually working on reducing cold start compilation time.

Total running time of the script: (0 minutes 13.093 seconds)