Automatic Mixed Precision#
Created On: Sep 15, 2020 | Last Updated: Jan 30, 2025 | Last Verified: Nov 05, 2024
Author: Michael Carilli
torch.cuda.amp provides convenience methods for mixed precision, where some operations use the torch.float32 (float) datatype and other operations use torch.float16 (half). Some ops, like linear layers and convolutions, are much faster in float16 or bfloat16. Other ops, like reductions, often require the dynamic range of float32. Mixed precision tries to match each op to its appropriate datatype, which can reduce your network’s runtime and memory footprint.
Ordinarily, “automatic mixed precision training” uses torch.autocast and torch.amp.GradScaler together.
This recipe measures the performance of a simple network in default precision, then walks through adding autocast and GradScaler to run the same network in mixed precision with improved performance.
You may download and run this recipe as a standalone Python script. The only requirements are PyTorch 1.6 or later and a CUDA-capable GPU.
Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere). This recipe should show significant (2-3X) speedup on those architectures. On earlier architectures (Kepler, Maxwell, Pascal), you may observe a modest speedup. Run nvidia-smi to display your GPU’s architecture.
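If you prefer to check from Python, the following is a minimal sketch that queries the active GPU’s name and compute capability with standard torch.cuda calls (Tensor Cores are available on compute capability 7.0 and later):

import torch

if torch.cuda.is_available():
    # Report the device name and compute capability of GPU 0.
    name = torch.cuda.get_device_name(0)
    major, minor = torch.cuda.get_device_capability(0)
    print(f"GPU: {name}, compute capability {major}.{minor}")
    print("Tensor Cores available:", major >= 7)
else:
    print("No CUDA-capable GPU detected.")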
import torch, time, gc

# Timing utilities
start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()
    start_time = time.time()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))
A simple network#
The following sequence of linear layers and ReLUs should show a speedup with mixed precision.
def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*tuple(layers)).cuda()
batch_size, in_size, out_size, and num_layers are chosen to be large enough to saturate the GPU with work. Typically, mixed precision provides the greatest speedup when the GPU is saturated. Small networks may be CPU bound, in which case mixed precision won’t improve performance. Sizes are also chosen such that linear layers’ participating dimensions are multiples of 8, to permit Tensor Core usage on Tensor Core-capable GPUs (see Troubleshooting below).
Exercise: Vary participating sizes and see how the mixed precision speedup changes.
batch_size = 512  # Try, for example, 128, 256, 513.
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 50
epochs = 3

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
data = [torch.randn(batch_size, in_size) for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size) for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()
Default Precision#
Without torch.cuda.amp, the following simple network executes all ops in default precision (torch.float32):
net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        output = net(input)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        opt.zero_grad()  # set_to_none=True here can modestly improve performance
end_timer_and_print("Default precision:")
Adding torch.autocast#
Instances of torch.autocast serve as context managers that allow regions of your script to run in mixed precision.
In these regions, CUDA ops run in a dtype chosen by autocast to improve performance while maintaining accuracy. See the Autocast Op Reference for details on what precision autocast chooses for each op, and under what circumstances.
for epoch in range(0):  # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under ``autocast``.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # output is float16 because linear layers ``autocast`` to float16.
            assert output.dtype is torch.float16

            loss = loss_fn(output, target)
            # loss is float32 because ``mse_loss`` layers ``autocast`` to float32.
            assert loss.dtype is torch.float32

        # Exits ``autocast`` before backward().
        # Backward passes under ``autocast`` are not recommended.
        # Backward ops run in the same ``dtype`` ``autocast`` chose for corresponding forward ops.
        loss.backward()
        opt.step()
        opt.zero_grad()  # set_to_none=True here can modestly improve performance
Adding GradScaler#
Gradient scaling helps prevent gradients with small magnitudes from flushing to zero (“underflowing”) when training with mixed precision.
torch.amp.GradScaler performs the steps of gradient scaling conveniently.
# Constructs a ``scaler`` once, at the beginning of the convergence run, using default arguments.
# If your network fails to converge with default ``GradScaler`` arguments, please file an issue.
# The same ``GradScaler`` instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh ``GradScaler`` instance. ``GradScaler`` instances are lightweight.
scaler = torch.amp.GradScaler("cuda")

for epoch in range(0):  # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls ``backward()`` on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
        # If these gradients do not contain ``inf``s or ``NaN``s, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)

        # Updates the scale for next iteration.
        scaler.update()

        opt.zero_grad()  # set_to_none=True here can modestly improve performance
All together: “Automatic Mixed Precision”#
(The following also demonstrates enabled, an optional convenience argument to autocast and GradScaler. If False, autocast and GradScaler’s calls become no-ops. This allows switching between default precision and mixed precision without if/else statements.)
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()  # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")
Inspecting/modifying gradients (e.g., clipping)#
All gradients produced by scaler.scale(loss).backward() are scaled. If you wish to modify or inspect the parameters’ .grad attributes between backward() and scaler.step(optimizer), you should unscale them first using scaler.unscale_(optimizer).
for epoch in range(0):  # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned parameters in-place
        scaler.unscale_(opt)

        # Since the gradients of optimizer's assigned parameters are now unscaled, clips as usual.
        # You may use the same value for max_norm here as you would without gradient scaling.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad()  # set_to_none=True here can modestly improve performance
Saving/Resuming#
To save/resume Amp-enabled runs with bitwise accuracy, use scaler.state_dict and scaler.load_state_dict.
When saving, save the scaler state dict alongside the usual model and optimizer state dicts. Do this either at the beginning of an iteration before any forward passes, or at the end of an iteration after scaler.update().
checkpoint = {"model": net.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
# Write checkpoint as desired, e.g.,
# torch.save(checkpoint, "filename")
When resuming, load the scaler state dict alongside the model and optimizer state dicts. Read the checkpoint as desired, for example:
dev = torch.cuda.current_device()
checkpoint = torch.load("filename",
                        map_location=lambda storage, loc: storage.cuda(dev))
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])
If a checkpoint was created from a run without Amp, and you want to resume training with Amp, load model and optimizer states from the checkpoint as usual. The checkpoint won’t contain a saved scaler state, so use a fresh instance of GradScaler.
If a checkpoint was created from a run with Amp and you want to resume training without Amp, load model and optimizer states from the checkpoint as usual, and ignore the saved scaler state.
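As an illustration, here is a minimal sketch of the first case (resuming a non-Amp checkpoint with Amp enabled); the filename is hypothetical and the checkpoint is assumed to use the same dictionary layout as above, minus the "scaler" entry:

# Hypothetical checkpoint saved by a run without Amp: no "scaler" entry is present.
checkpoint = torch.load("filename_no_amp", map_location="cuda")
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])

# Resume training with Amp using a fresh scaler; there is no saved scaler state to restore.
scaler = torch.amp.GradScaler("cuda")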
Inference/Evaluation#
autocast may be used by itself to wrap inference or evaluation forward passes. GradScaler is not necessary.
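For example, a minimal evaluation sketch reusing net, data, targets, and loss_fn from above:

net.eval()
with torch.no_grad():
    for input, target in zip(data, targets):
        # ``autocast`` alone is enough here: there is no backward pass, so no gradients to scale.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)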
Advanced topics#
See the Automatic Mixed Precision Examples for advanced use cases including:
Gradient accumulation
Gradient penalty/double backward
Networks with multiple models, optimizers, or losses
Multiple GPUs (torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel)
Custom autograd functions (subclasses of torch.autograd.Function)
If you perform multiple convergence runs in the same script, each run should use a dedicated fresh GradScaler instance. GradScaler instances are lightweight.
If you’re registering a custom C++ op with the dispatcher, see the autocast section of the dispatcher tutorial.
Troubleshooting#
Speedup with Amp is minor#
Your network may fail to saturate the GPU(s) with work, and is therefore CPU bound. Amp’s effect on GPU performance won’t matter.
A rough rule of thumb to saturate the GPU is to increase batch and/or network size(s) as much as you can without running OOM.
Try to avoid excessive CPU-GPU synchronization (.item() calls, or printing values from CUDA tensors).
Try to avoid sequences of many small CUDA ops (coalesce these into a few large CUDA ops if you can).
Your network may be GPU compute bound (lots of matmuls/convolutions) but your GPU does not have Tensor Cores. In this case a reduced speedup is expected.
The matmul dimensions are not Tensor Core-friendly. Make sure matmuls’ participating sizes are multiples of 8. (For NLP models with encoders/decoders, this can be subtle. Also, convolutions used to have similar size constraints for Tensor Core use, but for cuDNN versions 7.3 and later, no such constraints exist. See here for guidance.) A small sizing helper is sketched below.
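As an illustrative sketch (the helper name is hypothetical, not part of this recipe), sizes can be rounded up to the nearest multiple of 8 before building a model:

def pad_to_multiple(size, multiple=8):
    # Round ``size`` up to the nearest multiple of ``multiple``.
    return ((size + multiple - 1) // multiple) * multiple

print(pad_to_multiple(4096))  # 4096: already Tensor Core-friendly
print(pad_to_multiple(513))   # 520: 513 alone would miss the Tensor Core fast path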
Loss is inf/NaN#
First, check if your network fits an advanced use case. See also Prefer binary_cross_entropy_with_logits over binary_cross_entropy.
If you’re confident your Amp usage is correct, you may need to file an issue, but before doing so, it’s helpful to gather the following information:
Disable autocast or GradScaler individually (by passing enabled=False to their constructor) and see if infs/NaNs persist.
If you suspect part of your network (e.g., a complicated loss function) overflows, run that forward region in float32 and see if infs/NaNs persist. The last code snippet in the autocast docstring (https://pytorch.org/docs/stable/amp.html#torch.autocast) shows how to force a subregion to run in float32 by locally disabling autocast and casting the subregion’s inputs, as sketched below.
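A minimal sketch of that pattern, reusing names from the training loops above, where suspect_loss_fn is a hypothetical placeholder for whatever subregion you suspect of overflowing:

with torch.autocast(device_type=device, dtype=torch.float16):
    output = net(input)

    # Locally disable ``autocast`` and cast the inputs to float32 so the suspect
    # subregion runs in full precision.
    with torch.autocast(device_type=device, enabled=False):
        loss = suspect_loss_fn(output.float(), target.float())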
Type mismatch error (may manifest as CUDNN_STATUS_BAD_PARAM)#
Autocast tries to cover all ops that benefit from or require casting. Ops that receive explicit coverage are chosen based on numerical properties, but also on experience. If you see a type mismatch error in an autocast-enabled forward region or a backward pass following that region, it’s possible autocast missed an op.
Please file an issue with the error backtrace. Setting export TORCH_SHOW_CPP_STACKTRACES=1 before running your script provides fine-grained information on which backend op is failing.