
Reducing AoT cold start compilation time with regional compilation#

Author: Sayak Paul, Charles Bensimon, Angela Yi

In the regional compilation recipe, we showed how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for just-in-time (JIT) compilation.

This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you are not familiar with AOTInductor and torch.export, we recommend checking out this tutorial.

Prerequisites#

  • PyTorch 2.6 or later

  • Familiarity with regional compilation

  • Familiarity with AOTInductor and torch.export

Setup#

Before we begin, we need to install torch if it is not already available.

pip install torch

Steps#

In this recipe, we will follow the same steps as the regional compilation recipe mentioned above:

  1. Import all necessary libraries.

  2. Define and initialize a neural network with repeated regions.

  3. Measure the compilation time of the full model and the regional compilation with AoT.

First, let’s import the necessary libraries for loading our data:

import torch

torch.set_grad_enabled(False)

from time import perf_counter

Defining the Neural Network#

We will use the same neural network structure as the regional compilation recipe.

We will use a network composed of repeated layers. This mimics a large language model, which is typically composed of many Transformer blocks. In this recipe, we will create a Layer using the nn.Module class as a proxy for a repeated region. We will then create a Model which is composed of 64 instances of this Layer class.

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside of the scope of ``torch.compile``.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

Compiling the model ahead-of-time#

Since we are compiling the model ahead-of-time, we need to prepare representative input examples that we expect the model to see during actual deployments.

Let’s create an instance of Model and pass it some sample input data.

model = Model().cuda()
input = torch.randn(10, 10, device="cuda")
output = model(input)
print(f"{output.shape=}")
output.shape=torch.Size([10, 10])

Now, let’s compile our model ahead-of-time. We will use the input created above and pass it to torch.export. This will yield a torch.export.ExportedProgram which we can then compile.
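The compile step itself is shown below. torch._inductor.aoti_compile_and_package compiles the exported program, packages the result into an artifact on disk, and returns its path; this mirrors the full-model branch of aot_compile_load_model defined later in this recipe.

path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model, args=(input,))
)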

/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:321: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.

We can load from this path and use it to perform inference.

compiled_binary = torch._inductor.aoti_load_package(path)
output_compiled = compiled_binary(input)
print(f"{output_compiled.shape=}")
output_compiled.shape=torch.Size([10, 10])
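As a quick sanity check (not part of the original script), we can confirm that the AoT-compiled binary matches the eager model numerically:

# Compare the compiled output against the eager output computed earlier.
torch.testing.assert_close(output_compiled, output)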

Compiling regions of the model ahead-of-time#

Compiling model regions ahead-of-time, on the other hand, requires a few key changes.

Since the compute pattern is shared by all the blocks that are repeated in a model (Layer instances in this case), we can compile a single block and let the inductor reuse it across all of them.

model = Model().cuda()
path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model.layers[0], args=(input,)),
    inductor_configs={
        # compile artifact w/o saving params in the artifact
        "aot_inductor.package_constants_in_so": False,
    },
)
/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.

An exported program (torch.export.ExportedProgram) contains the Tensor computation, a state_dict containing the tensor values of all lifted parameters and buffers, alongside other metadata. We set aot_inductor.package_constants_in_so to False so that the model parameters are not serialized in the generated artifact.
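If you are curious, you can inspect the exported program before compiling it (an optional, illustrative check); the lifted parameters and buffers mentioned above live in its state_dict:

# Export a single block and inspect what it captured.
ep = torch.export.export(model.layers[0], args=(input,))
print(ep)                          # the captured Tensor computation (graph)
print(list(ep.state_dict.keys()))  # lifted parameters and buffers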

Now, when loading the compiled binary, we can reuse the existing parameters of each block. This lets us take advantage of the compiled binary obtained above.

for layer in model.layers:
    compiled_layer = torch._inductor.aoti_load_package(path)
    compiled_layer.load_constants(
        layer.state_dict(), check_full_update=True, user_managed=True
    )
    layer.forward = compiled_layer

output_regional_compiled = model(input)
print(f"{output_regional_compiled.shape=}")
output_regional_compiled.shape=torch.Size([10, 10])

Just like JIT regional compilation, compiling regions within a model ahead-of-time leads to significantly reduced cold start times. The actual numbers will vary from model to model.

Even though full model compilation offers the widest scope of optimizations, in practice, and depending on the type of model, we have seen regional compilation (both JIT and AoT) provide similar speed benefits while drastically reducing cold start times.

Measuring compilation time#

Next, let’s measure the compilation time of the full model and the regional compilation.

def measure_compile_time(input, regional=False):
    start = perf_counter()
    model = aot_compile_load_model(regional=regional)
    torch.cuda.synchronize()
    end = perf_counter()
    # make sure the model works.
    _ = model(input)
    return end - start


def aot_compile_load_model(regional=False) -> torch.nn.Module:
    input = torch.randn(10, 10, device="cuda")
    model = Model().cuda()

    inductor_configs = {}
    if regional:
        inductor_configs = {"aot_inductor.package_constants_in_so": False}

    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        path = torch._inductor.aoti_compile_and_package(
            torch.export.export(
                model.layers[0] if regional else model,
                args=(input,),
            ),
            inductor_configs=inductor_configs,
        )

        if regional:
            for layer in model.layers:
                compiled_layer = torch._inductor.aoti_load_package(path)
                compiled_layer.load_constants(
                    layer.state_dict(), check_full_update=True, user_managed=True
                )
                layer.forward = compiled_layer
        else:
            model = torch._inductor.aoti_load_package(path)
        return model


input = torch.randn(10, 10, device="cuda")

full_model_compilation_latency = measure_compile_time(input, regional=False)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_compile_time(input, regional=True)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency
/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
Full model compilation time = 11.35 seconds
/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
Regional compilation time = 4.89 seconds

A model may also contain layers that are incompatible with compilation. In that case, full compilation produces a fragmented computation graph, which can degrade latency. Regional compilation can be beneficial in such cases.

Conclusion#

This recipe shows how to reduce the cold start time when compiling your model ahead-of-time. This approach is effective when your model has repeated blocks, which is typically the case for large generative models. We used this recipe on various models to speed up real-time performance. Learn more here.

Total running time of the script: (0 minutes 41.550 seconds)