torch.utils.checkpoint
Created On: Jun 16, 2025 | Last Updated On: Oct 29, 2025
Note
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward propagation. This can cause persistent states like the RNG state to be more advanced than they would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout, for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of the checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.
The stashing logic saves and restores the RNG state for the CPU and for one other device type (the device type is inferred by _infer_device_type from the Tensor arguments, excluding CPU tensors) passed to the run_fn. If there are multiple devices, device state is only saved for devices of a single device type, and the remaining devices are ignored. Consequently, if any checkpointed functions involve randomness, this may result in incorrect gradients. (Note that if CUDA devices are among the devices detected, CUDA is prioritized; otherwise, the first device type encountered is selected.) If there are no CPU tensors, the default device type state (the default is cuda, and it can be set to another device type via DefaultDeviceType) is saved and restored. However, the logic has no way to anticipate whether the user will move Tensors to a new device within the run_fn itself. Therefore, if you move Tensors to a new device ("new" meaning not belonging to the set of [current device + devices of Tensor arguments]) within run_fn, deterministic output compared to non-checkpointed passes is never guaranteed.
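A minimal sketch of the trade-off described in this note, assuming a small dropout module defined here purely for illustration: pass preserve_rng_state=False when bitwise-deterministic recomputation of the random ops is not required.
>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint
>>>
>>> drop = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
>>> x = torch.randn(4, 16, requires_grad=True)
>>>
>>> # Default: RNG state is stashed and restored so the recomputed dropout mask matches
>>> out = checkpoint(drop, x, use_reentrant=False)
>>>
>>> # Skip the RNG bookkeeping; the recomputed dropout may use a different mask
>>> out = checkpoint(drop, x, use_reentrant=False, preserve_rng_state=False)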
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, early_stop=True, **kwargs)[source]
Checkpoint a model or part of the model.
Activation checkpointing is a technique that trades compute for memory. By default, tensors computed during the forward pass are kept alive until they are used in gradient computations in the backward pass. To reduce this memory usage, tensors produced inside the passed function are not kept alive until the backward pass. Instead, only the tensors passed in args are kept alive, and the unsaved tensors are recomputed by re-invoking function during the backward pass, as needed for gradient computation. Activation checkpointing can be applied to any part of a model – this is sometimes described as "checkpointing" that part of the model.
There are currently two checkpointing implementations available, determined by the use_reentrant parameter. It is recommended that you use use_reentrant=False. Please refer to the note below for a discussion of their differences.
Warning
If the function invocation during the backward pass differs from the forward pass, e.g., due to a global variable, the checkpointed version may not be equivalent, potentially causing an error to be raised or leading to silently incorrect gradients.
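A hedged illustration of this pitfall (the helper fn and the global tensor w below are hypothetical, not part of the API): if the checkpointed function reads a global that changes between the forward and backward passes, the recomputation no longer matches the original forward.
>>> import torch
>>> from torch.utils.checkpoint import checkpoint
>>>
>>> w = torch.randn(4)  # hypothetical global read inside the checkpointed function
>>>
>>> def fn(x):
>>>     return torch.sin(x * w)
>>>
>>> x = torch.randn(4, requires_grad=True)
>>> out = checkpoint(fn, x, use_reentrant=False)
>>> w = torch.randn(4)  # rebinding the global before backward...
>>> out.sum().backward()  # ...recomputation now reads the new w; gradients can be silently wrong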
Warning
The use_reentrant parameter should be passed explicitly. In version 2.9 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True variant, please refer to the note below for important considerations and potential limitations.
Note
The reentrant variant of checkpoint (use_reentrant=True) and the non-reentrant variant of checkpoint (use_reentrant=False) differ in the following ways:
- Non-reentrant checkpoint stops recomputation as soon as all needed intermediate activations have been recomputed. This feature is enabled by default, but can be disabled with set_checkpoint_early_stop(). Reentrant checkpoint always recomputes function in its entirety during the backward pass.
- The reentrant variant does not record the autograd graph during the forward pass, as it runs with the forward pass under torch.no_grad(). The non-reentrant version does record the autograd graph, allowing one to perform backward on the graph within checkpointed regions.
- The reentrant checkpoint only supports the torch.autograd.backward() API for the backward pass without its inputs argument, while the non-reentrant version supports all ways of performing the backward pass.
- At least one input and output must have requires_grad=True for the reentrant variant. If this condition is unmet, the checkpointed part of the model will not have gradients. The non-reentrant version does not have this requirement.
- The reentrant version does not consider tensors in nested structures (e.g., custom objects, lists, dicts, etc.) as participating in autograd, while the non-reentrant version does.
- The reentrant checkpoint does not support checkpointed regions with tensors detached from the computational graph, whereas the non-reentrant version does. For the reentrant variant, if the checkpointed segment contains tensors detached using detach() or with torch.no_grad(), the backward pass will raise an error. This is because checkpoint makes all the outputs require gradients, and this causes issues when a tensor is defined to have no gradient in the model. To avoid this, detach the tensors outside of the checkpoint function.
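A small sketch of the backward-API difference listed in the note above (the tiny model and tensor shapes are arbitrary): with use_reentrant=False, torch.autograd.grad works on checkpointed outputs.
>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint
>>>
>>> net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
>>> x = torch.randn(2, 8, requires_grad=True)
>>>
>>> # The non-reentrant variant records the autograd graph, so torch.autograd.grad is supported
>>> out = checkpoint(net, x, use_reentrant=False)
>>> (grad_x,) = torch.autograd.grad(out.sum(), x)
>>>
>>> # The reentrant variant (use_reentrant=True) would require out.sum().backward() instead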
- Parameters:
function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in an LSTM, if the user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden.
args – tuple containing inputs to the function
- Keyword Arguments:
preserve_rng_state (bool, optional) – If False, omit stashing and restoring the RNG state during each checkpoint. Note that under torch.compile, this flag has no effect and the RNG state is always preserved. Default: True
use_reentrant (bool) – specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed explicitly. In version 2.9 we will raise an exception if use_reentrant is not passed. If use_reentrant=False, checkpoint will use an implementation that does not require reentrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad and support for keyword arguments input into the checkpointed function.
context_fn (Callable, optional) – A callable returning a tuple of two context managers. The function and its recomputation will be run under the first and second context managers respectively. This argument is only supported if use_reentrant=False.
determinism_check (str, optional) – A string specifying the determinism check to perform. By default it is set to "default", which compares the shapes, dtypes, and devices of the recomputed tensors against those of the saved tensors. To turn off this check, specify "none". Currently these are the only two supported values. Please open an issue if you would like to see more determinism checks. This argument is only supported if use_reentrant=False; if use_reentrant=True, the determinism check is always disabled.
debug (bool, optional) – If True, error messages will also include a trace of the operators run during the original forward computation as well as the recomputation. This argument is only supported if use_reentrant=False.
early_stop (bool, optional) – If True, non-reentrant checkpoint stops recomputation as soon as it has computed all needed Tensors. This argument is ignored if use_reentrant=True. Can be overridden globally using the set_checkpoint_early_stop() context manager. Default: True.
- Returns:
Output of running function on *args
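A minimal usage sketch of the non-reentrant variant described above (the small nn.Sequential block and tensor shapes are arbitrary, chosen only for illustration):
>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint
>>>
>>> block = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32))
>>> x = torch.randn(8, 32, requires_grad=True)
>>>
>>> # Intermediate activations inside block are not kept alive after forward;
>>> # they are recomputed by re-running block during the backward pass
>>> out = checkpoint(block, x, use_reentrant=False)
>>> out.sum().backward()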
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]
Checkpoint a sequential model to save memory.
Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model into various segments and checkpoint each segment. All segments except the last will not store the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.
Warning
The use_reentrant parameter should be passed explicitly. In version 2.9 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True variant, please see checkpoint() for the important considerations and limitations of this variant. It is recommended that you use use_reentrant=False.
- Parameters:
functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.
segments – Number of chunks to create in the model
input – A Tensor that is input to functions
preserve_rng_state (bool, optional) – If False, omit stashing and restoring the RNG state during each checkpoint. Default: True
use_reentrant (bool) – specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. If use_reentrant=False, checkpoint will use an implementation that does not require reentrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad and support for keyword arguments input into the checkpointed function.
- Returns:
Output of running functions sequentially on *inputs
Example
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
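A slightly more complete sketch of the same pattern, with arbitrary layer sizes, an explicit segment count, and use_reentrant passed explicitly as recommended above:
>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint_sequential
>>>
>>> model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(6)])
>>> input_var = torch.randn(16, 64, requires_grad=True)
>>>
>>> # Split the six sub-modules into three checkpointed segments of two modules each
>>> out = checkpoint_sequential(model, 3, input_var, use_reentrant=False)
>>> out.sum().backward()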
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]
Context manager that sets whether checkpoint should print additional debug information when running. See the debug flag for checkpoint() for more information. Note that when set, this context manager overrides the value of debug passed to checkpoint. To defer to the local setting, pass None to this context.
- Parameters:
enabled (bool) – Whether checkpoint should print debug information. Default is None.
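A short usage sketch, with a hypothetical checkpointed function fn; the context manager overrides whatever debug value is passed to checkpoint, so a determinism-check failure during backward will include operator traces:
>>> import torch
>>> from torch.utils.checkpoint import checkpoint, set_checkpoint_debug_enabled
>>>
>>> def fn(x):
>>>     return torch.sin(x) * x
>>>
>>> x = torch.randn(4, requires_grad=True)
>>> with set_checkpoint_debug_enabled(True):
>>>     out = checkpoint(fn, x, use_reentrant=False)
>>>     out.sum().backward()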
- class torch.utils.checkpoint.CheckpointPolicy(value)[source]
Enum for specifying the policy for checkpointing during backpropagation.
The following policies are supported:
- {MUST,PREFER}_SAVE: The operation's output will be saved during the forward pass and will not be recomputed during the backward pass
- {MUST,PREFER}_RECOMPUTE: The operation's output will not be saved during the forward pass and will be recomputed during the backward pass
Use MUST_* over PREFER_* to indicate that the policy should not be overridden by other subsystems like torch.compile.
Note
A policy function that always returns PREFER_RECOMPUTE is equivalent to vanilla checkpointing.
A policy function that returns PREFER_SAVE for every op is NOT equivalent to not using checkpointing. Using such a policy would save additional tensors, not limited to ones that are actually needed for gradient computation.
- class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source]
Context passed to policy function during selective checkpointing.
This class is used to pass relevant metadata to the policy function during selective checkpointing. The metadata includes whether the current invocation of the policy function is during recomputation or not.
Example
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>     print(ctx.is_recompute)
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )
- torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source]
Helper to avoid recomputing certain ops during activation checkpointing.
Use this with torch.utils.checkpoint.checkpoint to control which operations are recomputed during the backward pass.
- Parameters:
policy_fn_or_list (Callable or List) –
If a policy function is provided, it should accept a SelectiveCheckpointContext, the OpOverload, args and kwargs to the op, and return a CheckpointPolicy enum value indicating whether the execution of the op should be recomputed or not.
If a list of operations is provided, it is equivalent to a policy returning CheckpointPolicy.MUST_SAVE for the specified operations and CheckpointPolicy.PREFER_RECOMPUTE for all other operations.
allow_cache_entry_mutation (bool, optional) – By default, an error is raised if any tensors cached by selective activation checkpoint are mutated in order to ensure correctness. If set to True, this check is disabled.
- Returns:
A tuple of two context managers.
Example
>>> import functools
>>>
>>> x = torch.rand(10, 10, requires_grad=True)
>>> y = torch.rand(10, 10, requires_grad=True)
>>>
>>> ops_to_save = [
>>>     torch.ops.aten.mm.default,
>>> ]
>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>     if op in ops_to_save:
>>>         return CheckpointPolicy.MUST_SAVE
>>>     else:
>>>         return CheckpointPolicy.PREFER_RECOMPUTE
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> # or equivalently
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
>>>
>>> def fn(x, y):
>>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )