
Automatic Mixed Precision package - torch.amp#

Created On: Jun 12, 2025 | Last Updated On: Jun 12, 2025

torch.amp provides convenience methods for mixed precision, where some operations use the torch.float32 (float) datatype and other operations use a lower precision floating point datatype (lower_precision_fp): torch.float16 (half) or torch.bfloat16. Some ops, like linear layers and convolutions, are much faster in lower_precision_fp. Other ops, like reductions, often require the dynamic range of float32. Mixed precision tries to match each op to its appropriate datatype.

Ordinarily, “automatic mixed precision training” with datatype of torch.float16 uses torch.autocast and torch.amp.GradScaler together, as shown in the Automatic Mixed Precision examples and Automatic Mixed Precision recipe. However, torch.autocast and torch.GradScaler are modular, and may be used separately if desired. As shown in the CPU example section of torch.autocast, “automatic mixed precision training/inference” on CPU with datatype of torch.bfloat16 only uses torch.autocast.

Warning

torch.cuda.amp.autocast(args...) and torch.cpu.amp.autocast(args...) are deprecated. Please use torch.amp.autocast("cuda", args...) or torch.amp.autocast("cpu", args...) instead. torch.cuda.amp.GradScaler(args...) and torch.cpu.amp.GradScaler(args...) are deprecated. Please use torch.amp.GradScaler("cuda", args...) or torch.amp.GradScaler("cpu", args...) instead.
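For example, migrating from the deprecated entry points is a one-line change (a minimal sketch; the "cuda" device string is illustrative):

import torch

# Deprecated
with torch.cuda.amp.autocast():
    ...

scaler = torch.cuda.amp.GradScaler()

# Preferred, device-agnostic form
with torch.amp.autocast("cuda"):
    ...

scaler = torch.amp.GradScaler("cuda")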

torch.autocast and torch.cpu.amp.autocast are new in version 1.10.

Autocasting#

torch.amp.autocast_mode.is_autocast_available(device_type)[source]#

Return a bool indicating if autocast is available on device_type.

Parameters:

device_type (str) – Device type to use. Possible values are: ‘cuda’, ‘cpu’, ‘mtia’, ‘maia’, ‘xpu’, and so on. The type is the same as the type attribute of a torch.device. Thus, you may obtain the device type of a tensor using Tensor.device.type.

Return type:

bool
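A quick usage sketch (the device strings below are only examples):

import torch

for device_type in ("cpu", "cuda"):
    # True if autocast is implemented for this device type
    print(device_type, torch.amp.autocast_mode.is_autocast_available(device_type))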

class torch.autocast(device_type, dtype=None, enabled=True, cache_enabled=None)[source]#

Instances of autocast serve as context managers or decorators that allow regions of your script to run in mixed precision.

In these regions, ops run in an op-specific dtype chosen by autocast to improve performance while maintaining accuracy. See the Autocast Op Reference for details.

When entering an autocast-enabled region, Tensors may be of any type. You should not call half() or bfloat16() on your model(s) or inputs when using autocasting.

autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for the corresponding forward ops.

Example for CUDA Devices:

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass (model + loss)
    with torch.autocast(device_type="cuda"):
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    optimizer.step()

See the Automatic Mixed Precision examples for usage (along with gradient scaling) in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

autocast can also be used as a decorator, e.g., on the forward method of your model:

class AutocastModel(nn.Module):
    ...

    @torch.autocast(device_type="cuda")
    def forward(self, input):
        ...

Floating-point Tensors produced in an autocast-enabled region may be float16. After returning to an autocast-disabled region, using them with floating-point Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s) produced in the autocast region back to float32 (or another dtype if desired). If a Tensor from the autocast region is already float32, the cast is a no-op and incurs no additional overhead.

CUDA Example:

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

with torch.autocast(device_type="cuda"):
    # torch.mm is on autocast's list of ops that should run in float16.
    # Inputs are float32, but the op runs in float16 and produces float16 output.
    # No manual casts are required.
    e_float16 = torch.mm(a_float32, b_float32)
    # Also handles mixed input types
    f_float16 = torch.mm(d_float32, e_float16)

# After exiting autocast, calls f_float16.float() to use with d_float32
g_float32 = torch.mm(d_float32, f_float16.float())

CPU Training Example:

# Creates model and optimizer in default precision
model = Net()
optimizer = optim.SGD(model.parameters(), ...)

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            output = model(input)
            loss = loss_fn(output, target)

        loss.backward()
        optimizer.step()

CPU Inference Example:

# Creates model in default precision
model = Net().eval()

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    for input in data:
        # Runs the forward pass with autocasting.
        output = model(input)

CPU Inference Example with Jit Trace:

class TestModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.fc1(x)

input_size = 2
num_classes = 2
model = TestModel(input_size, num_classes).eval()

# For now, we suggest disabling the JIT autocast pass; see this issue:
# https://github.com/pytorch/pytorch/issues/75956
torch._C._jit_set_autocast_mode(False)

with torch.cpu.amp.autocast(cache_enabled=False):
    model = torch.jit.trace(model, torch.randn(1, input_size))
model = torch.jit.freeze(model)

# Model runs
for _ in range(3):
    model(torch.randn(1, input_size))

Type mismatch errors in an autocast-enabled region are a bug; if this is what you observe, please file an issue.

autocast(enabled=False) subregions can be nested in autocast-enabled regions. Locally disabling autocast can be useful, for example, if you want to force a subregion to run in a particular dtype. Disabling autocast gives you explicit control over the execution type. In the subregion, inputs from the surrounding region should be cast to dtype before use:

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

with torch.autocast(device_type="cuda"):
    e_float16 = torch.mm(a_float32, b_float32)
    with torch.autocast(device_type="cuda", enabled=False):
        # Calls e_float16.float() to ensure float32 execution
        # (necessary because e_float16 was created in an autocasted region)
        f_float32 = torch.mm(c_float32, e_float16.float())

    # No manual casts are required when re-entering the autocast-enabled region.
    # torch.mm again runs in float16 and produces float16 output, regardless of input types.
    g_float16 = torch.mm(d_float32, f_float32)

The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator must be invoked in that thread. This affects torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel when used with more than one GPU per process (see Working with Multiple GPUs).
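As a minimal sketch of the thread-local behavior (assuming a CUDA device is available; the worker function is illustrative), each thread must enter its own autocast context:

import threading

import torch

results = {}

def worker(x):
    # The main thread's autocast state does not carry over to this thread;
    # the context manager must be entered here.
    with torch.autocast(device_type="cuda"):
        results["out"] = torch.mm(x, x)

a = torch.rand((8, 8), device="cuda")
t = threading.Thread(target=worker, args=(a,))
t.start()
t.join()
print(results["out"].dtype)  # torch.float16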

Parameters:
  • device_type (str, required) – Device type to use. Possible values are: ‘cuda’, ‘cpu’, ‘mtia’, ‘maia’, ‘xpu’, and ‘hpu’. The type is the same as the type attribute of a torch.device. Thus, you may obtain the device type of a tensor using Tensor.device.type.

  • enabled (bool, optional) – Whether autocasting should be enabled in the region. Default: True

  • dtype (torch.dtype, optional) – Data type for ops run in autocast. If dtype is None, it uses the default value given by get_autocast_dtype() (torch.float16 for CUDA and torch.bfloat16 for CPU). Default: None

  • cache_enabled (bool, optional) – Whether the weight cache inside autocast should be enabled. Default: True

torch.amp.custom_fwd(fwd=None, *, device_type, cast_inputs=None)[source]#

Create a helper decorator for forward methods of custom autograd functions.

Autograd functions are subclasses of torch.autograd.Function. See the example page for more detail.

Parameters:
  • device_type (str) – Device type to use. ‘cuda’, ‘cpu’, ‘mtia’, ‘maia’, ‘xpu’ and so on. The type is the same as the type attribute of a torch.device. Thus, you may obtain the device type of a tensor using Tensor.device.type.

  • cast_inputs (torch.dtype or None, optional, default=None) – If not None, when forward runs in an autocast-enabled region, casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors are not affected), then executes forward with autocast disabled. If None, forward’s internal ops execute with the current autocast state.

Note

If the decorated forward is called outside an autocast-enabled region, custom_fwd is a no-op and cast_inputs has no effect.

torch.amp.custom_bwd(bwd=None, *, device_type)[source]#

Create a helper decorator for backward methods of custom autograd functions.

Autograd functions are subclasses of torch.autograd.Function. Ensures that backward executes with the same autocast state as forward. See the example page for more detail.

Parameters:

device_type (str) – Device type to use. ‘cuda’, ‘cpu’, ‘mtia’, ‘maia’, ‘xpu’ and so on. The type is the same as the type attribute of a torch.device. Thus, you may obtain the device type of a tensor using Tensor.device.type.
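A minimal sketch of a custom autograd function decorated with both helpers, following the pattern described on the Automatic Mixed Precision examples page (the class name MyMM is illustrative):

import torch

class MyMM(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        # Inside an autocast-enabled region, incoming floating-point Tensors
        # are cast to float32 and the body runs with autocast disabled.
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        # Runs with the same autocast state that forward used.
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)

# Usage inside an autocast region (no manual casts needed):
# with torch.autocast(device_type="cuda"):
#     out = MyMM.apply(x, y)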

class torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True)[source]#

See torch.autocast.

torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast("cuda", args...) instead.

torch.cuda.amp.custom_fwd(fwd=None, *, cast_inputs=None)[source]#

torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead.

torch.cuda.amp.custom_bwd(bwd)[source]#

torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead.

class torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True)[source]#

See torch.autocast. torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast("cpu", args...) instead.

Gradient Scaling#

If the forward pass for a particular op has float16 inputs, the backward pass for that op will produce float16 gradients. Gradient values with small magnitudes may not be representable in float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.

To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are then scaled by the same factor. In other words, gradient values have a larger magnitude, so they don’t flush to zero.

Each parameter’s gradient (.grad attribute) should be unscaled before the optimizer updates the parameters, so the scale factor does not interfere with the learning rate.
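Putting these pieces together, a minimal sketch of a typical CUDA training loop that combines autocast with torch.amp.GradScaler (assuming model, optimizer, loss_fn, and data are defined as in the earlier examples):

scaler = torch.amp.GradScaler("cuda")

for input, target in data:
    optimizer.zero_grad()

    with torch.autocast(device_type="cuda"):
        output = model(input)
        loss = loss_fn(output, target)

    # Scales the loss and calls backward() on the scaled loss to create scaled gradients
    scaler.scale(loss).backward()

    # Unscales the gradients, then calls (or skips) optimizer.step()
    scaler.step(optimizer)

    # Updates the scale factor for the next iteration
    scaler.update()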

Note

AMP/fp16 may not work for every model! For example, most bf16-pretrained models cannot operate in the fp16 numerical range (maximum ~65504) and will cause gradients to overflow rather than underflow. In this case, the scale factor may decrease below 1 in an attempt to bring gradients into a range representable in fp16. While one might expect the scale to always be above 1, our GradScaler does NOT make this guarantee, in order to maintain performance. If you encounter NaNs in your loss or gradients when running with AMP/fp16, verify that your model is compatible.

class torch.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)[source]#

See torch.amp.GradScaler. torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler("cuda", args...) instead.

class torch.cpu.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)[source]#

See torch.amp.GradScaler. torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler("cpu", args...) instead.

Autocast Op Reference#

Op Eligibility#

Ops that run in float64 or non-floating-point dtypes are not eligible, and will run in these types whether or not autocast is enabled.

Only out-of-place ops and Tensor methods are eligible. In-place variants and calls that explicitly supply an out=... Tensor are allowed in autocast-enabled regions, but won’t go through autocasting. For example, in an autocast-enabled region a.addmm(b, c) can autocast, but a.addmm_(b, c) and a.addmm(b, c, out=d) cannot. For best performance and stability, prefer out-of-place ops in autocast-enabled regions.

Ops called with an explicit dtype=... argument are not eligible, and will produce output that respects the dtype argument.
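For example, a minimal sketch (assuming a CUDA device is available):

import torch

a = torch.rand((8, 8), device="cuda")
b = torch.rand((8, 8), device="cuda")

with torch.autocast(device_type="cuda"):
    c = a.mm(b)                            # eligible: out-of-place, autocasts to float16
    d = torch.sum(a, dtype=torch.float32)  # explicit dtype=... is respected; no autocasting
    a.addmm_(b, c.float())                 # in-place: allowed, but does not go through autocasting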

CUDA Op-Specific Behavior#

The following lists describe the behavior of eligible ops in autocast-enabled regions. These ops always go through autocasting whether they are invoked as part of a torch.nn.Module, as a function, or as a torch.Tensor method. If functions are exposed in multiple namespaces, they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting. They run in the type defined by their inputs. However, autocasting may still change the type in which unlisted ops run if they’re downstream from autocasted ops.

If an op is unlisted, we assume it’s numerically stable in float16. If you believe an unlisted op is numerically unstable in float16, please file an issue.

CUDA Ops that can autocast to float16#

__matmul__, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, prelu, RNNCell

CUDA Ops that can autocast to float32#

__pow__, __rdiv__, __rpow__, __rtruediv__, acos, asin, binary_cross_entropy_with_logits, cosh, cosine_embedding_loss, cdist, cosine_similarity, cross_entropy, cumprod, cumsum, dist, erfinv, exp, expm1, group_norm, hinge_embedding_loss, kl_div, l1_loss, layer_norm, log, log_softmax, log10, log1p, log2, margin_ranking_loss, mse_loss, multilabel_margin_loss, multi_margin_loss, nll_loss, norm, normalize, pdist, poisson_nll_loss, pow, prod, reciprocal, rsqrt, sinh, smooth_l1_loss, soft_margin_loss, softmax, softmin, softplus, sum, renorm, tan, triplet_margin_loss

CUDA Ops that promote to the widest input type#

These ops don’t require a particular dtype for stability, but take multiple inputs and require that the inputs’ dtypes match. If all of the inputs are float16, the op runs in float16. If any of the inputs is float32, autocast casts all inputs to float32 and runs the op in float32.

addcdiv, addcmul, atan2, bilinear, cross, dot, grid_sample, index_put, scatter_add, tensordot

Some ops not listed here (e.g., binary ops like add) natively promote inputs without autocasting’s intervention. If inputs are a mixture of float16 and float32, these ops run in float32 and produce float32 output, regardless of whether autocast is enabled.
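For example, a minimal sketch (assuming a CUDA device is available; atan2 is on the promote list above):

import torch

a_float32 = torch.rand((8, 8), device="cuda")

with torch.autocast(device_type="cuda"):
    b_float16 = torch.mm(a_float32, a_float32)  # mm autocasts to float16
    # atan2 receives one float16 and one float32 input; autocast promotes
    # both to float32 and runs the op in float32.
    c_float32 = torch.atan2(b_float16, a_float32)

print(c_float32.dtype)  # torch.float32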

Prefer binary_cross_entropy_with_logits over binary_cross_entropy#

The backward passes of torch.nn.functional.binary_cross_entropy() (and torch.nn.BCELoss, which wraps it) can produce gradients that aren’t representable in float16. In autocast-enabled regions, the forward input may be float16, which means the backward gradient must be representable in float16 (autocasting float16 forward inputs to float32 doesn’t help, because that cast must be reversed in backward). Therefore, binary_cross_entropy and BCELoss raise an error in autocast-enabled regions.

Many models use a sigmoid layer right before the binary cross entropy layer. In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits() or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogitsLoss are safe to autocast.
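For example, a minimal sketch (assuming model, input, and target are defined elsewhere, and that model emits raw scores without a final sigmoid):

criterion = torch.nn.BCEWithLogitsLoss()

with torch.autocast(device_type="cuda"):
    logits = model(input)             # raw scores; no explicit sigmoid layer
    loss = criterion(logits, target)  # safe to autocast; the sigmoid is applied internally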

XPU Op-Specific Behavior (Experimental)#

The following lists describe the behavior of eligible ops in autocast-enabled regions. These ops always go through autocasting whether they are invoked as part of a torch.nn.Module, as a function, or as a torch.Tensor method. If functions are exposed in multiple namespaces, they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting. They run in the type defined by their inputs. However, autocasting may still change the type in which unlisted ops run if they’re downstream from autocasted ops.

If an op is unlisted, we assume it’s numerically stable in float16. If you believe an unlisted op is numerically unstable in float16, please file an issue.

XPU Ops that can autocast to float16#

addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, RNNCell

XPU Ops that can autocast to float32#

__pow__, __rdiv__, __rpow__, __rtruediv__, binary_cross_entropy_with_logits, cosine_embedding_loss, cosine_similarity, cumsum, dist, exp, group_norm, hinge_embedding_loss, kl_div, l1_loss, layer_norm, log, log_softmax, margin_ranking_loss, nll_loss, normalize, poisson_nll_loss, pow, reciprocal, rsqrt, soft_margin_loss, softmax, softmin, sum, triplet_margin_loss

XPU Ops that promote to the widest input type#

These ops don’t require a particular dtype for stability, but take multiple inputs and require that the inputs’ dtypes match. If all of the inputs are float16, the op runs in float16. If any of the inputs is float32, autocast casts all inputs to float32 and runs the op in float32.

bilinear, cross, grid_sample, index_put, scatter_add, tensordot

Some ops not listed here (e.g., binary ops like add) natively promote inputs without autocasting’s intervention. If inputs are a mixture of float16 and float32, these ops run in float32 and produce float32 output, regardless of whether autocast is enabled.

CPU Op-Specific Behavior#

The following lists describe the behavior of eligible ops in autocast-enabled regions. These ops always go through autocasting whether they are invoked as part of a torch.nn.Module, as a function, or as a torch.Tensor method. If functions are exposed in multiple namespaces, they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting. They run in the type defined by their inputs. However, autocasting may still change the type in which unlisted ops run if they’re downstream from autocasted ops.

If an op is unlisted, we assume it’s numerically stable in bfloat16. If you believe an unlisted op is numerically unstable in bfloat16, please file an issue. float16 shares the same op lists as bfloat16.

CPU Ops that can autocast to bfloat16#

conv1d, conv2d, conv3d, bmm, mm, linalg_vecdot, baddbmm, addmm, addbmm, linear, matmul, _convolution, conv_tbc, mkldnn_rnn_layer, conv_transpose1d, conv_transpose2d, conv_transpose3d, prelu, scaled_dot_product_attention, _native_multi_head_attention

CPU Ops that can autocast to float32#

avg_pool3d, binary_cross_entropy, grid_sampler, grid_sampler_2d, _grid_sampler_2d_cpu_fallback, grid_sampler_3d, polar, prod, quantile, nanquantile, stft, cdist, trace, view_as_complex, cholesky, cholesky_inverse, cholesky_solve, inverse, lu_solve, orgqr, ormqr, pinverse, max_pool3d, max_unpool2d, max_unpool3d, adaptive_avg_pool3d, reflection_pad1d, reflection_pad2d, replication_pad1d, replication_pad2d, replication_pad3d, mse_loss, cosine_embedding_loss, nll_loss, nll_loss2d, hinge_embedding_loss, poisson_nll_loss, cross_entropy_loss, l1_loss, huber_loss, margin_ranking_loss, soft_margin_loss, triplet_margin_loss, multi_margin_loss, ctc_loss, kl_div, multilabel_margin_loss, binary_cross_entropy_with_logits, fft_fft, fft_ifft, fft_fft2, fft_ifft2, fft_fftn, fft_ifftn, fft_rfft, fft_irfft, fft_rfft2, fft_irfft2, fft_rfftn, fft_irfftn, fft_hfft, fft_ihfft, linalg_cond, linalg_matrix_rank, linalg_solve, linalg_cholesky, linalg_svdvals, linalg_eigvals, linalg_eigvalsh, linalg_inv, linalg_householder_product, linalg_tensorinv, linalg_tensorsolve, fake_quantize_per_tensor_affine, geqrf, _lu_with_info, qr, svd, triangular_solve, fractional_max_pool2d, fractional_max_pool3d, adaptive_max_pool3d, multilabel_margin_loss_forward, linalg_qr, linalg_cholesky_ex, linalg_svd, linalg_eig, linalg_eigh, linalg_lstsq, linalg_inv_ex

CPU Ops that promote to the widest input type#

These ops don’t require a particular dtype for stability, but take multiple inputs and require that the inputs’ dtypes match. If all of the inputs are bfloat16, the op runs in bfloat16. If any of the inputs is float32, autocast casts all inputs to float32 and runs the op in float32.

cat, stack, index_copy

Some ops not listed here (e.g., binary ops like add) natively promote inputs without autocasting’s intervention. If inputs are a mixture of bfloat16 and float32, these ops run in float32 and produce float32 output, regardless of whether autocast is enabled.
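For example, a minimal sketch on CPU (cat is on the promote list above):

import torch

a_bf16 = torch.rand(4, dtype=torch.bfloat16)
b_f32 = torch.rand(4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # cat promotes to the widest input dtype, so the result is float32 here.
    c = torch.cat([a_bf16, b_f32])

print(c.dtype)  # torch.float32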