
Automatic Mixed Precision examples

Created On: Feb 13, 2020 | Last Updated On: Sep 13, 2024

Ordinarily, “automatic mixed precision training” means training with torch.autocast and torch.amp.GradScaler together.

Instances of torch.autocast enable autocasting for chosen regions. Autocasting automatically chooses the precision for operations to improve performance while maintaining accuracy.

Instances of torch.amp.GradScaler help perform the steps of gradient scaling conveniently. Gradient scaling improves convergence for networks with float16 gradients (float16 is the default autocast dtype on CUDA and XPU) by minimizing gradient underflow, as explained in the gradient scaling section of the torch.amp documentation.
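To see why underflow matters, the framework-free sketch below rounds a single gradient value to IEEE half precision using Python's struct module (format "e"), with and without multiplying by a loss scale first. It is an illustration, not how GradScaler is implemented: the gradient value 1e-8 is hypothetical, 2**16 is GradScaler's default init_scale, and the helper to_fp16 is an assumption that maps float16 overflow to inf (struct itself raises OverflowError in that case).

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Finite values beyond the float16 range become +/-inf, mimicking a float16
    overflow; struct itself would raise OverflowError for them.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


grad = 1e-8          # hypothetical small gradient value
scale = 2.0 ** 16    # GradScaler's default initial scale

unscaled_fp16 = to_fp16(grad)          # flushes to 0.0: the update is lost
scaled_fp16 = to_fp16(grad * scale)    # about 6.55e-4, representable in float16
recovered = scaled_fp16 / scale        # unscaling in higher precision

print(unscaled_fp16)  # 0.0
print(recovered)      # approximately 1e-8
```

A value such as 1e-8 lies below half of float16's smallest subnormal (2**-24, about 6e-8), so it rounds to zero. Multiplied by the scale, it lands in float16's normal range, and dividing by the same scale afterwards recovers it to within float16 rounding error.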

torch.autocast and torch.amp.GradScaler are modular. In the samples below, each is used as its individual documentation suggests.

(Samples here are illustrative. See the Automatic Mixed Precision recipe for a runnable walkthrough.)

Typical Mixed Precision Training

```python
import torch
from torch import optim
from torch.amp import GradScaler, autocast

# Net, loss_fn, data, and epochs are placeholders for your own definitions.

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()
```
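The scaler.step and scaler.update calls in the loop above follow a dynamic loss-scaling rule: skip the step when unscaled gradients contain infs or NaNs, back off the scale after such an iteration, and grow it after a run of clean iterations. The sketch below models that rule on plain Python floats. ToyScaler and its step/update signatures are hypothetical; only the default constants (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000) mirror GradScaler's documented defaults.

```python
import math


class ToyScaler:
    """Hypothetical, framework-free model of dynamic loss scaling."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0   # consecutive iterations without inf/NaN
        self._found_inf = False

    def step(self, scaled_grads, apply_update):
        """Unscale grads and call apply_update(grads) only if all are finite."""
        grads = [g / self.scale for g in scaled_grads]
        self._found_inf = not all(math.isfinite(g) for g in grads)
        if not self._found_inf:
            apply_update(grads)
        return not self._found_inf   # True if the step was taken

    def update(self):
        """Back off after inf/NaN; grow after growth_interval clean iterations."""
        if self._found_inf:
            self.scale *= self.backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self.scale *= self.growth_factor
                self._growth_tracker = 0


# Usage with a short growth_interval so growth is visible quickly.
scaler = ToyScaler(growth_interval=3)
taken = []
first = scaler.step([float("inf"), 1.0], taken.append)  # skipped: inf gradient
scaler.update()                                         # scale 65536.0 -> 32768.0
for _ in range(3):
    scaler.step([2.0**16], taken.append)                # applies [2.0] each time
    scaler.update()                                     # third clean update grows it
```

Backing off immediately but growing only after many clean iterations keeps the scale near the largest value that does not overflow, which is why skipped steps stay rare.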

Working with Unscaled Gradients

All gradients produced by scaler.scale(loss).backward() are scaled. If you wish to modify or inspect the parameters’ .grad attributes between backward() and scaler.step(optimizer), you should unscale them first. For example, gradient clipping manipulates a set of gradients such that their global norm (see torch.nn.utils.clip_grad_norm_()) or maximum magnitude (see torch.nn.utils.clip_grad_value_()) is at most some user-imposed threshold. If you attempted to clip without unscaling, the gradients’ norm/maximum magnitude would also be scaled, so your requested threshold (which was meant to be the threshold for unscaled gradients) would be invalid.
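A framework-free sketch makes the problem concrete. clip_by_global_norm is a hypothetical helper that follows the rule used by torch.nn.utils.clip_grad_norm_ (multiply every gradient by max_norm / (total_norm + eps) when that factor is below 1); the gradient values and the scale of 1024 are arbitrary.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale grads so their global L2 norm is at most max_norm.

    Returns (possibly rescaled grads, norm measured before clipping).
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + eps)
    if coef < 1.0:
        grads = [g * coef for g in grads]
    return grads, total_norm


true_grads = [0.3, 0.4]   # true (unscaled) global norm is 0.5
scale = 1024.0            # hypothetical loss scale
max_norm = 1.0

scaled = [g * scale for g in true_grads]

# Wrong: clipping scaled grads measures a norm of 512 against max_norm = 1.
wrong, wrong_norm = clip_by_global_norm(scaled, max_norm)
wrong_unscaled = [g / scale for g in wrong]   # what the optimizer would apply

# Right: unscale first (the role of scaler.unscale_), then clip as usual.
right, right_norm = clip_by_global_norm([g / scale for g in scaled], max_norm)
```

Clipping the scaled gradients shrinks them by roughly 512x, so after unscaling the applied update has norm about 1/1024 instead of 0.5. Unscaling first leaves these gradients untouched, because their true norm is already below max_norm.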

scaler.unscale_(optimizer) unscales gradients held by optimizer’s assigned parameters. If your model or models contain other parameters that were assigned to another optimizer (say optimizer2), you may call scaler.unscale_(optimizer2) separately to unscale those parameters’ gradients as well.

Gradient clipping

Calling scaler.unscale_(optimizer) before clipping enables you to clip unscaled gradients as usual:

```python
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()
```

scaler records that scaler.unscale_(optimizer) was already called for this optimizer this iteration, so scaler.step(optimizer) knows not to redundantly unscale gradients before (internally) calling optimizer.step().

Warning

unscale_ should only be called once per optimizer per step call, and only after all gradients for that optimizer’s assigned parameters have been accumulated. Calling unscale_ twice for a given optimizer between consecutive step calls triggers a RuntimeError.
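The bookkeeping described above can be modeled as a small per-optimizer state machine. The sketch below is hypothetical: ToyUnscaleTracker, OptState, and the dict-based "optimizer" are stand-ins invented for illustration, and the error messages are not GradScaler's. It shows that step skips unscaling when unscale_ already ran, and that a second unscale_ in the same iteration is rejected.

```python
from enum import Enum, auto


class OptState(Enum):
    READY = auto()      # nothing done yet this iteration
    UNSCALED = auto()   # unscale_ has run for this optimizer
    STEPPED = auto()    # step has run for this optimizer


class ToyUnscaleTracker:
    """Hypothetical sketch of per-optimizer unscale_/step bookkeeping.

    An "optimizer" here is a plain dict: {"grads": [floats], "steps": int}.
    """

    def __init__(self, scale=2.0**16):
        self.scale = scale
        self._states = {}  # id(optimizer) -> OptState

    def _state(self, opt):
        return self._states.get(id(opt), OptState.READY)

    def unscale_(self, opt):
        if self._state(opt) is not OptState.READY:
            raise RuntimeError("unscale_ called twice, or after step, in one iteration")
        opt["grads"] = [g / self.scale for g in opt["grads"]]
        self._states[id(opt)] = OptState.UNSCALED

    def step(self, opt):
        if self._state(opt) is OptState.STEPPED:
            raise RuntimeError("step called twice in one iteration")
        if self._state(opt) is OptState.READY:
            self.unscale_(opt)      # unscale only if the user has not already
        opt["steps"] += 1           # stand-in for optimizer.step()
        self._states[id(opt)] = OptState.STEPPED

    def update(self):
        self._states.clear()        # the next iteration starts from READY


tracker = ToyUnscaleTracker(scale=4.0)
opt = {"grads": [8.0, -4.0], "steps": 0}
tracker.unscale_(opt)               # grads become [2.0, -1.0]
tracker.step(opt)                   # grads are not divided a second time
grads_after_step = list(opt["grads"])
tracker.update()

# Next iteration: a second unscale_ before step is rejected.
opt["grads"] = [8.0]
tracker.unscale_(opt)
try:
    tracker.unscale_(opt)
    second_call_error = None
except RuntimeError as err:
    second_call_error = str(err)
```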

Working with Scaled Gradients

Gradient accumulation

Gradient accumulation adds gradients over an effective batch of size batch_per_iter * iters_to_accumulate (* num_procs if distributed). The scale should be calibrated for the effective batch, which means inf/NaN checking, step skipping if inf/NaN grads are found, and scale updates should occur at effective-batch granularity. Also, grads should remain scaled, and the scale factor should remain constant, while grads for a given effective batch are accumulated. If grads are unscaled (or the scale factor changes) before accumulation is complete, the next backward pass will add scaled grads to unscaled grads (or to grads scaled by a different factor), after which it’s impossible to recover the accumulated unscaled grads that step must apply.

Therefore, if you want to unscale_ grads (e.g., to allow clipping unscaled grads), call unscale_ just before step, after all (scaled) grads for the upcoming step have been accumulated. Also, only call update at the end of iterations where you called step for a full effective batch:

```python
scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss / iters_to_accumulate

        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

        if (i + 1) % iters_to_accumulate == 0:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
```
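The requirement that the scale stay constant across an effective batch can be checked with scalar arithmetic. The sketch below uses hypothetical per-micro-batch gradients for a single parameter and models loss / iters_to_accumulate followed by scaler.scale(loss).backward() as a multiply-and-add; accumulate is a hypothetical helper, not a PyTorch API.

```python
# Per-micro-batch unscaled gradients for one parameter (hypothetical values).
micro_grads = [0.25, -0.5, 1.0, 0.75]
iters_to_accumulate = len(micro_grads)


def accumulate(scales):
    """Sum scaled micro-batch grads; scales[i] is the scale used on backward i."""
    acc = 0.0
    for g, s in zip(micro_grads, scales):
        # loss / iters_to_accumulate, then backward on the scaled loss
        acc += (g / iters_to_accumulate) * s
    return acc


# Gradient of the loss averaged over the effective batch.
target = sum(micro_grads) / iters_to_accumulate

# Constant scale for the whole effective batch: one division recovers target.
constant = accumulate([1024.0] * 4) / 1024.0

# Scale halved mid-accumulation (update() called too early): the sum mixes
# two scales, so dividing by either one gives the wrong gradient.
changed = accumulate([1024.0, 1024.0, 512.0, 512.0])
```

With a constant scale the result equals target (0.375) exactly; after the mid-accumulation change, dividing by 1024 gives 0.15625 and dividing by 512 gives 0.3125.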

Gradient penalty

A gradient penalty implementation commonly creates gradients using torch.autograd.grad(), combines them to create the penalty value, and adds the penalty value to the loss.

Here’s an ordinary example of an L2 penalty without gradient scaling or autocasting:

```python
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)

        # Creates gradients
        grad_params = torch.autograd.grad(outputs=loss,
                                          inputs=model.parameters(),
                                          create_graph=True)

        # Computes the penalty term and adds it to the loss
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        loss = loss + grad_norm

        loss.backward()

        # clip gradients here, if desired

        optimizer.step()
```

To implement a gradient penalty with gradient scaling, the outputs Tensor(s) passed to torch.autograd.grad() should be scaled. The resulting gradients will therefore be scaled, and should be unscaled before being combined to create the penalty value.

Also, the penalty term computation is part of the forward pass, and therefore should be inside an autocast context.

Here’s how that looks for the same L2 penalty:

```python
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                 inputs=model.parameters(),
                                                 create_graph=True)

        # Creates unscaled grad_params before computing the penalty. scaled_grad_params are
        # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
        inv_scale = 1. / scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]

        # Computes the penalty term and adds it to the loss
        with autocast(device_type='cuda', dtype=torch.float16):
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

        # Applies scaling to the backward call as usual.
        # Accumulates leaf gradients that are correctly scaled.
        scaler.scale(loss).backward()

        # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()
```
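Why the multiplication by inv_scale matters can be seen with plain numbers. In the sketch below, the gradient values and the scale of 256 are hypothetical, scaled_grad_params stands in for what torch.autograd.grad returns when given scaler.scale(loss), and l2_penalty is a hypothetical helper computing the same L2 norm as the loop above.

```python
import math


def l2_penalty(grads):
    """Global L2 norm of a list of scalar gradients."""
    return math.sqrt(sum(g * g for g in grads))


true_grads = [3.0, 4.0]    # hypothetical unscaled gradients
S = 256.0                  # hypothetical current loss scale

# Gradients of the scaled loss are scaled by S.
scaled_grad_params = [S * g for g in true_grads]

# Correct: unscale by ordinary multiplication before forming the penalty.
inv_scale = 1.0 / S
grad_params = [p * inv_scale for p in scaled_grad_params]
penalty_correct = l2_penalty(grad_params)

# Wrong: a penalty built from the scaled gradients is S times too large.
penalty_wrong = l2_penalty(scaled_grad_params)
```

Here penalty_correct is 5.0 while penalty_wrong is 1280.0; because the scale changes during training, the unscaled version is also the only one whose weight in the loss stays fixed.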

Working with Multiple Models, Losses, and Optimizers

If your network has multiple losses, you must call scaler.scale on each of them individually. If your network has multiple optimizers, you may call scaler.unscale_ on any of them individually, and you must call scaler.step on each of them individually.

However, scaler.update should only be called once, after all optimizers used this iteration have been stepped:

```python
scaler = torch.amp.GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer0.zero_grad()
        optimizer1.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output0 = model0(input)
            output1 = model1(input)
            loss0 = loss_fn(2 * output0 + 3 * output1, target)
            loss1 = loss_fn(3 * output0 - 5 * output1, target)

        # (retain_graph here is unrelated to amp, it's present because in this
        # example, both backward() calls share some sections of graph.)
        scaler.scale(loss0).backward(retain_graph=True)
        scaler.scale(loss1).backward()

        # You can choose which optimizers receive explicit unscaling, if you
        # want to inspect or modify the gradients of the params they own.
        scaler.unscale_(optimizer0)

        scaler.step(optimizer0)
        scaler.step(optimizer1)

        scaler.update()
```

Each optimizer checks its gradients for infs/NaNs and makes an independent decision whether or not to skip the step. This may result in one optimizer skipping the step while the other one does not. Since step skipping occurs rarely (every several hundred iterations), this should not impede convergence. If you observe poor convergence after adding gradient scaling to a multiple-optimizer model, please report a bug.
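The sketch below models one iteration with two optimizers where only one sees a non-finite gradient. amp_iteration and its arguments are hypothetical. Each step decision is made per optimizer, while the single update at the end backs off the shared scale if any optimizer found an inf or NaN, which is why update must wait until every optimizer has stepped. Scale growth is omitted for brevity.

```python
import math


def any_nonfinite(grads):
    return not all(math.isfinite(g) for g in grads)


def amp_iteration(scale, grads_per_optimizer, backoff_factor=0.5):
    """One iteration with several optimizers (hypothetical sketch).

    grads_per_optimizer maps an optimizer name to its *scaled* grads.
    Returns (names of optimizers whose step ran, scale after update).
    """
    stepped = []
    found_inf_any = False
    for name, scaled in grads_per_optimizer.items():
        # Models scaler.step(opt): unscale, check, maybe skip.
        found_inf = any_nonfinite([g / scale for g in scaled])
        found_inf_any = found_inf_any or found_inf
        if not found_inf:
            stepped.append(name)      # this optimizer's step() runs
    # Models the single scaler.update() call after all steps.
    new_scale = scale * backoff_factor if found_inf_any else scale
    return stepped, new_scale


stepped, new_scale = amp_iteration(
    1024.0, {"optimizer0": [2.0, float("nan")], "optimizer1": [512.0]})
```

Here optimizer1 steps while optimizer0 skips, and the scale is halved to 512.0 for the next iteration.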

Working with Multiple GPUs

The issues described here only affect autocast. GradScaler’s usage is unchanged.

DataParallel in a single process

Even though torch.nn.DataParallel spawns threads to run the forward pass on each device, the autocast state is propagated to each thread, so the following works:

```python
model = MyModel()
dp_model = nn.DataParallel(model)

# Sets autocast in the main thread
with autocast(device_type='cuda', dtype=torch.float16):
    # dp_model's internal threads will autocast.
    output = dp_model(input)
    # loss_fn also autocast
    loss = loss_fn(output)
```

DistributedDataParallel, one GPU per process

torch.nn.parallel.DistributedDataParallel’s documentation recommends one GPU per process for best performance. In this case, DistributedDataParallel does not spawn threads internally, so usages of autocast and GradScaler are not affected.

DistributedDataParallel, multiple GPUs per process

Here torch.nn.parallel.DistributedDataParallel may spawn a side thread to run the forward pass on each device, like torch.nn.DataParallel. To ensure autocast is enabled in those side threads, apply autocast as part of your model’s forward method.
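Applying autocast inside forward helps because state that is local to a thread does not automatically reach threads spawned later. The framework-free sketch below models that with threading.local; the flag, enable_flag, and the two forward functions are hypothetical stand-ins, not PyTorch's autocast implementation.

```python
import contextlib
import threading

_local = threading.local()  # hypothetical stand-in for thread-local autocast state


def flag_enabled():
    return getattr(_local, "enabled", False)


@contextlib.contextmanager
def enable_flag():
    """Enable the flag for the current thread only, restoring it on exit."""
    prev = flag_enabled()
    _local.enabled = True
    try:
        yield
    finally:
        _local.enabled = prev


def forward_relying_on_caller(results, i):
    results[i] = flag_enabled()          # sees only this thread's state


def forward_setting_flag_itself(results, i):
    with enable_flag():                  # analogous to autocast inside forward
        results[i] = flag_enabled()


def run_in_side_threads(fn, n=2):
    results = [None] * n
    threads = [threading.Thread(target=fn, args=(results, i)) for i in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


with enable_flag():                      # enabled in the main thread only
    relying = run_in_side_threads(forward_relying_on_caller)
    setting = run_in_side_threads(forward_setting_flag_itself)
```

The side threads that rely on the caller see the flag disabled, while the ones that set it inside their own forward see it enabled.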

Autocast and Custom Autograd Functions

If your network uses custom autograd functions (subclasses of torch.autograd.Function), changes are required for autocast compatibility if any function

  • takes multiple floating-point Tensor inputs,

  • wraps any autocastable op (see the Autocast Op Reference), or

  • requires a particular dtype (for example, if it wraps CUDA extensions that were only compiled for dtype).

In all cases, if you’re importing the function and can’t alter its definition, a safe fallback is to disable autocast and force execution in float32 (or dtype) at any points of use where errors occur:

```python
with autocast(device_type='cuda', dtype=torch.float16):
    ...
    with autocast(device_type='cuda', dtype=torch.float16, enabled=False):
        output = imported_function(input1.float(), input2.float())
```

If you’re the function’s author (or can alter its definition), a better solution is to use the torch.amp.custom_fwd() and torch.amp.custom_bwd() decorators as shown in the relevant case below.

Functions with multiple inputs or autocastable ops

Apply custom_fwd and custom_bwd to forward and backward respectively, passing only device_type (no cast_inputs). These ensure forward executes with the current autocast state and backward executes with the same autocast state as forward (which can prevent type mismatch errors):

```python
import torch
from torch.amp import custom_bwd, custom_fwd

class MyMM(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)
```

Now MyMM can be invoked anywhere, without disabling autocast or manually casting inputs:

```python
mymm = MyMM.apply

with autocast(device_type='cuda', dtype=torch.float16):
    output = mymm(input1, input2)
```

Functions that need a particular dtype

Consider a custom function that requires torch.float32 inputs. Apply custom_fwd(device_type='cuda', cast_inputs=torch.float32) to forward and custom_bwd(device_type='cuda') to backward. If forward runs in an autocast-enabled region, the decorators cast floating-point Tensor inputs on the device type given by device_type (CUDA in this example) to float32, and locally disable autocast during forward and backward:

```python
class MyFloat32Func(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(ctx, input):
        ctx.save_for_backward(input)
        ...
        return fwd_output

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad):
        ...
```

Now MyFloat32Func can be invoked anywhere, without manually disabling autocast or casting inputs:

```python
func = MyFloat32Func.apply

with autocast(device_type='cuda', dtype=torch.float16):
    # func will run in float32, regardless of the surrounding autocast state
    output = func(input)
```