
torch.utils.checkpoint#

Created On: Jun 16, 2025 | Last Updated On: Oct 29, 2025

Note

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward propagation. This can cause persistent states like the RNG state to be more advanced than they would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout, for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.
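
For example, when the checkpointed region uses dropout but bit-exact agreement with a non-checkpointed run is not needed, the RNG bookkeeping can be skipped. A minimal sketch (the module and sizes here are illustrative, not from the text above):

>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint
>>>
>>> layer = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5), nn.ReLU())
>>> x = torch.randn(4, 16, requires_grad=True)
>>>
>>> # preserve_rng_state=False omits stashing/restoring the RNG state, so the
>>> # dropout mask drawn during recomputation may differ from the original forward.
>>> out = checkpoint(layer, x, use_reentrant=False, preserve_rng_state=False)
>>> out.sum().backward()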

The stashing logic saves and restores the RNG state for CPU and one other device type (the device type is inferred from the Tensor arguments, excluding CPU tensors, by _infer_device_type) passed to the run_fn. If there are multiple devices, device state will only be saved for devices of a single device type, and the remaining devices will be ignored. Consequently, if any checkpointed functions involve randomness, this may result in incorrect gradients. (Note that if CUDA devices are among the devices detected, CUDA will be prioritized; otherwise, the first device encountered will be selected.) If there are no CPU tensors, the default device type state (the default value is cuda, and it can be set to another device by DefaultDeviceType) will be saved and restored. However, the logic has no way to anticipate if the user will move Tensors to a new device within the run_fn itself. Therefore, if you move Tensors to a new device ("new" meaning not belonging to the set of [current device + devices of Tensor arguments]) within run_fn, deterministic output compared to non-checkpointed passes is never guaranteed.
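
If your non-CPU tensors live on a backend other than CUDA and no such tensors are passed as arguments, the device type whose state is stashed can be changed via DefaultDeviceType. A short sketch, assuming a build with the target backend ("xpu" is only an example here):

>>> from torch.utils.checkpoint import DefaultDeviceType
>>>
>>> # The default is "cuda"; switch the stashed device-type state to another backend.
>>> DefaultDeviceType.set_device_type("xpu")
>>> DefaultDeviceType.get_device_type()
'xpu'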

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, early_stop=True, **kwargs)[source]#

Checkpoint a model or part of the model.

Activation checkpointing is a technique that trades compute for memory. By default, tensors computed during the forward pass are kept alive until they are used in gradient computations in the backward pass. To reduce this memory usage, tensors produced in the passed function are not kept alive until the backward pass. Instead, any passed tensors in args are kept alive, and the unsaved tensors are recomputed by re-invoking function in the backward pass as needed for gradient computation. Activation checkpointing can be applied to any part of a model – this is sometimes described as "checkpointing" that part of the model.
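
A minimal sketch of checkpointing one block of a model (the modules and sizes are illustrative):

>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint
>>>
>>> block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
>>> head = nn.Linear(128, 10)
>>> x = torch.randn(32, 128, requires_grad=True)
>>>
>>> # Activations inside `block` are not kept alive after the forward pass;
>>> # `block` is re-run during backward to recompute them as needed.
>>> h = checkpoint(block, x, use_reentrant=False)
>>> loss = head(h).sum()
>>> loss.backward()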

There are currently two checkpointing implementations available, determined by the use_reentrant parameter. It is recommended that you use use_reentrant=False. Please refer to the note below for a discussion of their differences.

Warning

If the function invocation during the backward pass differs from the forward pass, e.g., due to a global variable, the checkpointed version may not be equivalent, potentially causing an error to be raised or leading to silently incorrect gradients.

Warning

The use_reentrant parameter should be passed explicitly. In version 2.9 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True variant, please refer to the note below for important considerations and potential limitations.

Note

The reentrant variant of checkpoint (use_reentrant=True) and the non-reentrant variant of checkpoint (use_reentrant=False) differ in the following ways:

  • Non-reentrant checkpoint stops recomputation as soon as all needed intermediate activations have been recomputed. This feature is enabled by default, but can be disabled with set_checkpoint_early_stop(). Reentrant checkpoint always recomputes function in its entirety during the backward pass.

  • The reentrant variant does not record the autograd graph during the forward pass, as it runs with the forward pass under torch.no_grad(). The non-reentrant version does record the autograd graph, allowing one to perform backward on the graph within checkpointed regions.

  • The reentrant checkpoint only supports the torch.autograd.backward() API for the backward pass without its inputs argument, while the non-reentrant version supports all ways of performing the backward pass (see the sketch following this list).

  • At least one input and output must have requires_grad=True for the reentrant variant. If this condition is unmet, the checkpointed part of the model will not have gradients. The non-reentrant version does not have this requirement.

  • The reentrant version does not consider tensors in nested structures (e.g., custom objects, lists, dicts, etc.) as participating in autograd, while the non-reentrant version does.

  • The reentrant checkpoint does not support checkpointed regions with detached tensors from the computational graph, whereas the non-reentrant version does. For the reentrant variant, if the checkpointed segment contains tensors detached using detach() or with torch.no_grad(), the backward pass will raise an error. This is because checkpoint makes all the outputs require gradients, and this causes issues when a tensor is defined to have no gradient in the model. To avoid this, detach the tensors outside of the checkpoint function.
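
As a hedged illustration of the backward-API and keyword-argument differences (the function fn below is made up for this sketch):

>>> import torch
>>> from torch.utils.checkpoint import checkpoint
>>>
>>> def fn(x, scale=1.0):
>>>     return (x * scale).sin().sum()
>>>
>>> x = torch.randn(8, requires_grad=True)
>>> # Keyword arguments to the checkpointed function are supported with use_reentrant=False.
>>> out = checkpoint(fn, x, scale=2.0, use_reentrant=False)
>>> # torch.autograd.grad works with the non-reentrant variant; the reentrant
>>> # variant only supports torch.autograd.backward() without its inputs argument.
>>> (grad,) = torch.autograd.grad(out, x)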

Parameters:
  • function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in an LSTM, if the user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden

  • args – tuple containing inputs to the function

Keyword Arguments:
  • preserve_rng_state (bool, optional) – Omit stashing and restoring the RNG state during each checkpoint. Note that under torch.compile, this flag doesn't take effect and we always preserve RNG state. Default: True

  • use_reentrant (bool) – specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed explicitly. In version 2.9 we will raise an exception if use_reentrant is not passed. If use_reentrant=False, checkpoint will use an implementation that does not require reentrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad and support for keyword arguments input into the checkpointed function.

  • context_fn (Callable, optional) – A callable returning a tuple of two context managers. The function and its recomputation will be run under the first and second context managers respectively. This argument is only supported if use_reentrant=False.

  • determinism_check (str, optional) – A string specifying the determinism check to perform. By default it is set to "default", which compares the shapes, dtypes, and devices of the recomputed tensors against those of the saved tensors. To turn off this check, specify "none". Currently these are the only two supported values. Please open an issue if you would like to see more determinism checks. This argument is only supported if use_reentrant=False; if use_reentrant=True, the determinism check is always disabled.

  • debug (bool, optional) – If True, error messages will also include a trace of the operators run during the original forward computation as well as the recomputation. This argument is only supported if use_reentrant=False.

  • early_stop (bool, optional) – If True, non-reentrant checkpoint stops recomputation as soon as it has computed all needed Tensors. This argument is ignored if use_reentrant=True. Can be overridden globally using the set_checkpoint_early_stop() context manager (see the sketch below). Default: True.

Returns:

Output of running function on *args
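
A short sketch of the early_stop interaction (the checkpointed function here is arbitrary), in which the global set_checkpoint_early_stop() context manager overrides the per-call argument:

>>> import torch
>>> from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop
>>>
>>> x = torch.randn(4, 4, requires_grad=True)
>>>
>>> # Disable early stopping for every checkpoint in this scope, regardless of
>>> # the early_stop argument passed to individual checkpoint() calls.
>>> with set_checkpoint_early_stop(False):
>>>     out = checkpoint(torch.sin, x, use_reentrant=False)
>>> out.sum().backward()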

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]#

Checkpoint a sequential model to save memory.

Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model into various segments and checkpoint each segment. All segments except the last will not store the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.
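
For instance (a hedged sketch; the model and segment count are illustrative), a four-layer Sequential split into two segments saves only each segment's input, and the checkpointed segments are re-run during the backward pass:

>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint_sequential
>>>
>>> model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
>>> x = torch.randn(8, 64, requires_grad=True)
>>>
>>> # segments=2 splits the four modules into two chunks of two; all chunks
>>> # except the last are checkpointed.
>>> out = checkpoint_sequential(model, 2, x, use_reentrant=False)
>>> out.sum().backward()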

Warning

The use_reentrant parameter should be passed explicitly. In version 2.9 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True variant, please see checkpoint() for the important considerations and limitations of this variant. It is recommended that you use use_reentrant=False.

Parameters:
  • functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.

  • segments – Number of chunks to create in the model

  • input – A Tensor that is input to functions

  • preserve_rng_state (bool, optional) – Omit stashing and restoring the RNG state during each checkpoint. Default: True

  • use_reentrant (bool) – specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed explicitly. In version 2.9 we will raise an exception if use_reentrant is not passed. If use_reentrant=False, checkpoint will use an implementation that does not require reentrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad and support for keyword arguments input into the checkpointed function.

Returns:

Output of running functions sequentially on *inputs

Example

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)

torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]#

Context manager that sets whether checkpoint should print additional debug information when running. See the debug flag for checkpoint() for more information. Note that when set, this context manager overrides the value of debug passed to checkpoint. To defer to the local setting, pass None to this context.

Parameters:

enabled (bool) – Whether checkpoint should print debug information. Default is None.
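
A short usage sketch (the checkpointed function here is arbitrary):

>>> import torch
>>> from torch.utils.checkpoint import checkpoint, set_checkpoint_debug_enabled
>>>
>>> x = torch.randn(4, 4, requires_grad=True)
>>>
>>> # Force debug traces on for every checkpoint in this scope, overriding the
>>> # per-call debug flag; traces are attached to determinism-check error messages.
>>> with set_checkpoint_debug_enabled(True):
>>>     out = checkpoint(torch.sin, x, use_reentrant=False)
>>> out.sum().backward()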

class torch.utils.checkpoint.CheckpointPolicy(value)[source]#

Enum for specifying the policy for checkpointing during backpropagation.

The following policies are supported:

  • {MUST,PREFER}_SAVE: The operation's output will be saved during the forward pass and will not be recomputed during the backward pass

  • {MUST,PREFER}_RECOMPUTE: The operation's output will not be saved during the forward pass and will be recomputed during the backward pass

Use MUST_* over PREFER_* to indicate that the policy should not be overridden by other subsystems like torch.compile.

Note

A policy function that always returns PREFER_RECOMPUTE is equivalent to vanilla checkpointing.

A policy function that returns PREFER_SAVE for every op is NOT equivalent to not using checkpointing. Using such a policy would save additional tensors, not limited to ones that are actually needed for gradient computation.

class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source]#

Context passed to policy function during selective checkpointing.

This class is used to pass relevant metadata to the policy function during selective checkpointing. The metadata includes whether the current invocation of the policy function is during recomputation or not.

Example

>>> def policy_fn(ctx, op, *args, **kwargs):
>>>     print(ctx.is_recompute)
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )

torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source]#

Helper to avoid recomputing certain ops during activation checkpointing.

Use this with torch.utils.checkpoint.checkpoint to control which operations are recomputed during the backward pass.

Parameters:
  • policy_fn_or_list (Callable or List) –

    • If a policy function is provided, it should accept a SelectiveCheckpointContext, the OpOverload, args and kwargs to the op, and return a CheckpointPolicy enum value indicating whether the execution of the op should be recomputed or not.

    • If a list of operations is provided, it is equivalent to a policy returning CheckpointPolicy.MUST_SAVE for the specified operations and CheckpointPolicy.PREFER_RECOMPUTE for all other operations.

  • allow_cache_entry_mutation (bool, optional) – By default, an error is raised if any tensors cached by selective activation checkpoint are mutated in order to ensure correctness. If set to True, this check is disabled.

Returns:

A tuple of two context managers.

Example

>>> import functools
>>>
>>> x = torch.rand(10, 10, requires_grad=True)
>>> y = torch.rand(10, 10, requires_grad=True)
>>>
>>> ops_to_save = [
>>>     torch.ops.aten.mm.default,
>>> ]
>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>     if op in ops_to_save:
>>>         return CheckpointPolicy.MUST_SAVE
>>>     else:
>>>         return CheckpointPolicy.PREFER_RECOMPUTE
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> # or equivalently
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
>>>
>>> def fn(x, y):
>>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )