
(beta) Compiling the optimizer with torch.compile#

Created On: Jan 24, 2024 | Last Updated: Jan 29, 2024 | Last Verified: Nov 05, 2024

Author: Michael Lazos

The optimizer is a key algorithm for training any deep learning model. Since it is responsible for updating every model parameter, it can often become the bottleneck in training performance for large models. In this recipe, we will apply torch.compile to the optimizer to observe the GPU performance improvement.

Note

This tutorial requires PyTorch 2.2.0 or later.
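If you want to guard against older installs, a minimal version check can be run before the rest of the recipe. This sketch only assumes the standard torch.__version__ string and is not part of the original recipe:

import torch

# Parse the major/minor components of torch.__version__ (e.g. "2.2.0+cu121")
# and stop early if the installed PyTorch is older than 2.2.0.
if tuple(int(x) for x in torch.__version__.split(".")[:2]) < (2, 2):
    raise RuntimeError("This recipe requires PyTorch 2.2.0 or later.")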

Model Setup#

For this example, we’ll use a simple sequence of linear layers. Since we are only benchmarking the optimizer, the choice of model doesn’t matter because optimizer performance is a function of the number of parameters.

Depending on what machine you are using, your exact results may vary.

import torch

model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
output = model(input)
output.sum().backward()

Setting up and running the optimizer benchmark#

In this example, we’ll use the Adam optimizer and create a helper function to wrap step() in torch.compile().

Note

torch.compile is only supported on CUDA devices with compute capability >= 7.0.

# exit cleanly if we are on a device that doesn't support torch.compile
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)


opt = torch.optim.Adam(model.parameters(), lr=0.01)


@torch.compile(fullgraph=False)
def fn():
    opt.step()


# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


# Warmup runs to compile the function
for _ in range(5):
    fn()

eager_runtime = benchmark_torch_function_in_microseconds(opt.step)
compiled_runtime = benchmark_torch_function_in_microseconds(fn)

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")

Sample Results:

  • Eager runtime: 747.2437149845064us

  • Compiled runtime: 392.07384741178us
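Once compiled, the wrapped step can be reused like any other callable. The sketch below shows one way to call it from a toy training loop; it reuses the model, opt, and compiled fn defined above, and the random input and sum() "loss" are placeholders for illustration only, not part of this recipe:

# A minimal training-loop sketch that reuses `model`, `opt`, and the
# compiled `fn` from above. The random batch and sum() loss are placeholders.
for _ in range(10):
    inp = torch.rand(1024, device="cuda")  # placeholder input batch
    loss = model(inp).sum()                # placeholder loss
    opt.zero_grad()
    loss.backward()
    fn()                                   # compiled opt.step()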

See Also#

  • For an in-depth technical overview, see Compiling the optimizer with PT2.