Autograd mechanics#
This note will present an overview of how autograd works and records the operations. It's not strictly necessary to understand all this, but we recommend getting familiar with it, as it will help you write more efficient, cleaner programs, and can aid you in debugging.
How autograd encodes the history#
Autograd is a reverse automatic differentiation system. Conceptually, autograd records a graph recording all of the operations that created the data as you execute operations, giving you a directed acyclic graph whose leaves are the input tensors and roots are the output tensors. By tracing this graph from roots to leaves, you can automatically compute the gradients using the chain rule.
Internally, autograd represents this graph as a graph of Function objects (really expressions), which can be apply()-ed to compute the result of evaluating the graph. When computing the forward pass, autograd simultaneously performs the requested computations and builds up a graph representing the function that computes the gradient (the .grad_fn attribute of each torch.Tensor is an entry point into this graph). When the forward pass is completed, we evaluate this graph in the backwards pass to compute the gradients.
An important thing to note is that the graph is recreated from scratch at every iteration, and this is exactly what allows for using arbitrary Python control flow statements, that can change the overall shape and size of the graph at every iteration. You don't have to encode all possible paths before you launch the training - what you run is what you differentiate.
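For example, a minimal sketch of this dynamic behavior (the loop bound here is arbitrary):

```python
import torch

x = torch.randn(3, requires_grad=True)

for step in range(2):
    y = x
    # Arbitrary Python control flow: the number of recorded operations
    # can differ on every iteration; autograd records what actually ran.
    while y.norm() < 1000:
        y = y * 2
    y.sum().backward()  # differentiates the graph built by this iteration only
    x.grad = None
```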
Saved tensors#
Some operations need intermediary results to be saved during the forward pass in order to execute the backward pass. For example, the function $x \mapsto x^2$ saves the input $x$ to compute the gradient.
When defining a custom Python Function, you can use save_for_backward() to save tensors during the forward pass and saved_tensors to retrieve them during the backward pass. See Extending PyTorch for more information.
For operations that PyTorch defines (e.g. torch.pow()), tensors are automatically saved as needed. You can explore (for educational or debugging purposes) which tensors are saved by a certain grad_fn by looking for its attributes starting with the prefix _saved.
```python
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))  # True
print(x is y.grad_fn._saved_self)      # True
```
In the previous code, y.grad_fn._saved_self refers to the same Tensor object as x. But that may not always be the case. For instance:
```python
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)      # False
```
Under the hood, to prevent reference cycles, PyTorch has packed the tensor upon saving and unpacked it into a different tensor for reading. Here, the tensor you get from accessing y.grad_fn._saved_result is a different tensor object than y (but they still share the same storage).
Whether a tensor will be packed into a different tensor object depends on whether it is an output of its own grad_fn, which is an implementation detail subject to change and that users should not rely on.
You can control how PyTorch does packing / unpacking with Hooks for saved tensors.
Gradients for non-differentiable functions#
The gradient computation using Automatic Differentiation is only valid when each elementary function being used is differentiable. Unfortunately many of the functions we use in practice do not have this property (relu or sqrt at 0, for example). To try and reduce the impact of functions that are non-differentiable, we define the gradients of the elementary operations by applying the following rules in order:
1. If the function is differentiable and thus a gradient exists at the current point, use it.
2. If the function is convex (at least locally), use the sub-gradient of minimum norm.
3. If the function is concave (at least locally), use the super-gradient of minimum norm (consider -f(x) and apply the previous point).
4. If the function is defined, define the gradient at the current point by continuity (note that inf is possible here, for example for sqrt(0)). If multiple values are possible, pick one arbitrarily.
5. If the function is not defined (sqrt(-1), log(-1) or most functions when the input is NaN, for example) then the value used as the gradient is arbitrary (we might also raise an error but that is not guaranteed). Most functions will use NaN as the gradient, but for performance reasons, some functions will use other values (log(-1), for example).
6. If the function is not a deterministic mapping (i.e. it is not a mathematical function), it will be marked as non-differentiable. This will make it error out in the backward if used on tensors that require grad outside of a no_grad environment.
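These conventions can be observed directly; for instance, relu gets the minimum-norm subgradient 0 at its kink, while sqrt's gradient at 0 is defined by continuity as inf:

```python
import torch

x = torch.tensor(0.0, requires_grad=True)
torch.relu(x).backward()
print(x.grad)  # tensor(0.): minimum-norm subgradient at the kink (rule 2)

y = torch.tensor(0.0, requires_grad=True)
y.sqrt().backward()
print(y.grad)  # tensor(inf): gradient defined by continuity (rule 4)
```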
Division by Zero in Autograd#
When performing division by zero in PyTorch (e.g., x / 0), the forward pass will produce inf values following IEEE-754 floating point arithmetic. While these inf values can be masked out before computing the final loss (e.g., via indexing or masking), the autograd system still tracks and differentiates through the full computation graph, including the division by zero operation.
During backpropagation, this can lead to problematic gradient expressions. For example:
```python
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div          # Results in [inf, 1]
mask = div != 0      # [False, True]
loss = y[mask].sum()
loss.backward()
print(x.grad)        # [nan, 1], not [0, 1]
```
In this example, even though we only use the masked output (which excludes the division by zero), autograd still computes gradients through the full computation graph, including the division by zero operation. This results in nan gradients for the masked elements, which can cause training instability.
To avoid this issue, there are several recommended approaches:
Mask before division:
```python
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
mask = div != 0
safe = torch.zeros_like(x)
safe[mask] = x[mask] / div[mask]
loss = safe.sum()
loss.backward()  # Produces safe gradients [0, 1]
```
Use MaskedTensor (experimental API):
```python
from torch.masked import as_masked_tensor

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div
mask = div != 0
loss = as_masked_tensor(y, mask).sum()
loss.backward()  # Cleanly handles "undefined" vs "zero" gradients
```
The key principle is to prevent the division by zero operation from being recorded in the computation graph, rather than masking its results after the fact. This ensures that autograd only computes gradients through valid operations.
This behavior is important to keep in mind when working with operations that might produce inf or nan values, as masking the outputs does not prevent the problematic gradients from being computed.
Locally disabling gradient computation#
There are several mechanisms available from Python to locally disable gradient computation:
To disable gradients across entire blocks of code, there are context managers like no-grad mode and inference mode. For more fine-grained exclusion of subgraphs from gradient computation, you can set the requires_grad field of a tensor.
Below, in addition to discussing the mechanisms above, we also describe evaluation mode (nn.Module.eval()), a method that is not used to disable gradient computation but, because of its name, is often mixed up with the three.
Setting requires_grad#
requires_grad is a flag, defaulting to false unless wrapped in an nn.Parameter, that allows for fine-grained exclusion of subgraphs from gradient computation. It takes effect in both the forward and backward passes:
- During the forward pass, an operation is only recorded in the backward graph if at least one of its input tensors requires grad.
- During the backward pass (.backward()), only leaf tensors with requires_grad=True will have gradients accumulated into their .grad fields.
It is important to note that even though every tensor has this flag, setting it only makes sense for leaf tensors (tensors that do not have a grad_fn, e.g., an nn.Module's parameters). Non-leaf tensors (tensors that do have a grad_fn) are tensors that have a backward graph associated with them. Thus their gradients will be needed as an intermediary result to compute the gradient for a leaf tensor that requires grad. From this definition, it is clear that all non-leaf tensors will automatically have requires_grad=True.
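A minimal sketch of both effects:

```python
import torch

x = torch.randn(3, requires_grad=True)  # leaf that requires grad
w = torch.randn(3)                       # leaf with requires_grad=False
y = x * w                                # non-leaf: requires_grad=True automatically
y.sum().backward()
print(x.grad)      # gradients accumulated into the leaf's .grad field
print(w.grad)      # None: w does not require grad
print(y.grad_fn)   # <MulBackward0 ...>: y has a backward graph
```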
Setting requires_grad should be the main way you control which parts of the model are part of the gradient computation, for example, if you need to freeze parts of your pretrained model during model fine-tuning.

To freeze parts of your model, simply apply .requires_grad_(False) to the parameters that you don't want updated. And as described above, since computations that use these parameters as inputs would not be recorded in the forward pass, they won't have their .grad fields updated in the backward pass because they won't be part of the backward graph in the first place, as desired.
Because this is such a common pattern, requires_grad can also be set at the module level with nn.Module.requires_grad_(). When applied to a module, .requires_grad_() takes effect on all of the module's parameters (which have requires_grad=True by default).
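For example, a minimal sketch of the freezing pattern (the tiny backbone/head model here is hypothetical):

```python
import torch
from torch import nn

backbone = nn.Linear(10, 10)   # stands in for a pretrained feature extractor
head = nn.Linear(10, 2)        # the part we want to train

backbone.requires_grad_(False)  # freezes all of the module's parameters

loss = head(backbone(torch.randn(4, 10))).sum()
loss.backward()
print(backbone.weight.grad)          # None: never entered the backward graph
print(head.weight.grad is not None)  # True: still being trained
```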
Grad Modes#
Apart from setting requires_grad there are also three grad modes that can be selected from Python that can affect how computations in PyTorch are processed by autograd internally: default mode (grad mode), no-grad mode, and inference mode, all of which can be toggled via context managers and decorators.
| Mode | Excludes operations from being recorded in backward graph | Skips additional autograd tracking overhead | Tensors created while the mode is enabled can be used in grad-mode later | Examples |
|---|---|---|---|---|
| default | | | ✓ | Forward pass |
| no-grad | ✓ | | ✓ | Optimizer updates |
| inference | ✓ | ✓ | | Data processing, model evaluation |
Default Mode (Grad Mode)#
The “default mode” is the mode we are implicitly in when no other modes like no-grad and inference mode are enabled. To be contrasted with “no-grad mode”, the default mode is also sometimes called “grad mode”.
The most important thing to know about the default mode is that it is the only mode in which requires_grad takes effect. requires_grad is always overridden to be False in both of the other two modes.
No-grad Mode#
Computations in no-grad mode behave as if none of the inputs require grad. In other words, computations in no-grad mode are never recorded in the backward graph even if there are inputs that have requires_grad=True.
Enable no-grad mode when you need to perform operations that should not be recorded by autograd, but you'd still like to use the outputs of these computations in grad mode later. This context manager makes it convenient to disable gradients for a block of code or function without having to temporarily set tensors to have requires_grad=False, and then back to True.
For example, no-grad mode might be useful when writing an optimizer: when performing the training update you'd like to update parameters in-place without the update being recorded by autograd. You also intend to use the updated parameters for computations in grad mode in the next forward pass.

The implementations in torch.nn.init also rely on no-grad mode when initializing the parameters so as to avoid autograd tracking when updating the initialized parameters in-place.
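A minimal sketch of this optimizer pattern:

```python
import torch

w = torch.randn(3, requires_grad=True)
(w * w).sum().backward()

with torch.no_grad():
    w -= 0.1 * w.grad  # in-place update, not recorded by autograd
    w.grad.zero_()

loss = (w * w).sum()   # w is still usable in grad mode afterwards
```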
Inference Mode#
Inference mode is the extreme version of no-grad mode. Just like in no-grad mode, computations in inference mode are not recorded in the backward graph, but enabling inference mode will allow PyTorch to speed up your model even more. This better runtime comes with a drawback: tensors created in inference mode will not be able to be used in computations to be recorded by autograd after exiting inference mode.

Enable inference mode when you are performing computations that do not have interactions with autograd, AND you don't plan on using the tensors created in inference mode in any computation that is to be recorded by autograd later.

It is recommended that you try out inference mode in the parts of your code that do not require autograd tracking (e.g., data processing and model evaluation). If it works out of the box for your use case it's a free performance win. If you run into errors after enabling inference mode, check that you are not using tensors created in inference mode in computations that are recorded by autograd after exiting inference mode. If you cannot avoid such use in your case, you can always switch back to no-grad mode.
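A minimal sketch (the clone() workaround shown in the comments is one way to obtain a normal tensor from an inference tensor):

```python
import torch

with torch.inference_mode():
    y = torch.randn(5) * 2  # not tracked: reduced overhead

w = torch.randn(5, requires_grad=True)
# (w * y).sum().backward() would raise a RuntimeError, because inference
# tensors cannot be saved for backward. Cloning y outside inference mode
# yields a normal tensor that autograd can record:
loss = (w * y.clone()).sum()
loss.backward()
```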
For details on inference mode please see Inference Mode.

For implementation details of inference mode see RFC-0011-InferenceMode.
Evaluation Mode (nn.Module.eval())#
Evaluation mode is not a mechanism to locally disable gradient computation. It is included here anyway because it is sometimes mistaken for such a mechanism.
Functionally, module.eval() (or equivalently module.train(False)) is completely orthogonal to no-grad mode and inference mode. How model.eval() affects your model depends entirely on the specific modules used in your model and whether they define any training-mode specific behavior.
You are responsible for calling model.eval() and model.train() if your model relies on modules such as torch.nn.Dropout and torch.nn.BatchNorm2d that may behave differently depending on training mode, for example, to avoid updating your BatchNorm running statistics on validation data.

It is recommended that you always use model.train() when training and model.eval() when evaluating your model (validation/testing) even if you aren't sure your model has training-mode specific behavior, because a module you are using might be updated to behave differently in training and eval modes.
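A minimal sketch of the recommended pattern, combined with no-grad mode (the two mechanisms compose freely):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

model.eval()            # dropout becomes the identity; no effect on autograd
with torch.no_grad():   # orthogonal mechanism: also skip graph recording
    out = model(torch.randn(2, 4))
model.train()           # restore training-mode behavior before training again
```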
In-place operations with autograd#
Supporting in-place operations in autograd is a hard matter, and we discourage their use in most cases. Autograd's aggressive buffer freeing and reuse makes it very efficient and there are very few occasions when in-place operations lower memory usage by any significant amount. Unless you're operating under heavy memory pressure, you might never need to use them.
There are two main reasons that limit the applicability of in-place operations:
1. In-place operations can potentially overwrite values required to compute gradients.
2. Every in-place operation requires the implementation to rewrite the computational graph. Out-of-place versions simply allocate new objects and keep references to the old graph, while in-place operations require changing the creator of all inputs to the Function representing this operation. This can be tricky, especially if there are many Tensors that reference the same storage (e.g. created by indexing or transposing), and in-place functions will raise an error if the storage of modified inputs is referenced by any other Tensor.
In-place correctness checks#
Every tensor keeps a version counter, which is incremented every time the tensor is marked dirty in any operation. When a Function saves any tensors for backward, the version counter of their containing Tensor is saved as well. Once you access self.saved_tensors the current version is checked against the saved value, and if it is greater, an error is raised. This ensures that if you're using in-place functions and not seeing any errors, you can be sure that the computed gradients are correct.
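For example, the following sketch trips the version counter check:

```python
import torch

base = torch.randn(3, requires_grad=True)
x = base.clone()
y = x.sin()     # sin saves its input x for the backward pass
x.add_(1.0)     # bumps x's version counter past the saved value
y.sum().backward()  # RuntimeError: a variable needed for gradient
                    # computation has been modified by an inplace operation
```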
Multithreaded Autograd#
The autograd engine is responsible for running all the backward operations necessary to compute the backward pass. This section describes all the details that can help you make the best use of it in a multithreaded environment. (This is relevant only for PyTorch 1.6+, as the behavior in previous versions was different.)

Users can train their model with multithreading code (e.g. Hogwild training) without blocking on the concurrent backward computations. Example code:
```python
import threading

import torch

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)
    # forward
    y = (x + 3) * (x + 4) * 0.5
    # backward
    y.sum().backward()
    # potential optimizer update

# Users write their own threading code to drive the train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn, args=())
    p.start()
    threads.append(p)

for p in threads:
    p.join()
```
Note some behaviors that users should be aware of:
Concurrency on CPU#
When you run backward() or grad() via the Python or C++ API in multiple threads on CPU, you should expect to see extra concurrency instead of all the backward calls being serialized in a specific order during execution (the behavior before PyTorch 1.6).
Non-determinism#
If you are calling backward() from multiple threads concurrently and have shared inputs (i.e. Hogwild CPU training), then non-determinism should be expected. This occurs because parameters are automatically shared across threads; as a result, multiple threads may access and try to accumulate the same .grad attribute during gradient accumulation. This is technically not safe, and it might result in race conditions, making the result invalid to use.

Users developing multithreaded models featuring shared parameters should have the threading model in mind and should understand the issues described above.

The functional API torch.autograd.grad() may be used to calculate the gradients instead of backward() to avoid non-determinism.
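For example, a sketch reusing the computation from train_fn above:

```python
import torch

x = torch.ones(5, 5, requires_grad=True)
y = ((x + 3) * (x + 4) * 0.5).sum()
# grad() returns gradients directly instead of accumulating them into a
# shared .grad attribute, sidestepping the race described above.
(grad_x,) = torch.autograd.grad(y, (x,))
```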
Graph retaining#
Part of the autograd graph can be shared between threads, e.g. when the first part of the forward runs in a single thread and the second part runs in multiple threads. In this case, different threads executing grad() or backward() on the same graph might destroy the graph on the fly in one thread, causing the other thread to crash. Autograd will error out, similar to calling backward() twice without retain_graph=True, to let the user know they should use retain_graph=True.
Thread Safety on Autograd Node#
Since Autograd allows the caller thread to drive its backward execution for potential parallelism, it's important that we ensure thread safety on CPU with parallel backward() calls that share part or whole of the GraphTask.

Custom Python autograd.Functions are automatically thread safe because of the GIL. For built-in C++ Autograd Nodes (e.g. AccumulateGrad, CopySlices) and custom autograd::Functions, the Autograd Engine uses thread mutex locking to ensure thread safety on autograd Nodes that might have state writes/reads.
No thread safety on C++ hooks#
Autograd relies on the user to write thread-safe C++ hooks. If you want a hook to be correctly applied in a multithreaded environment, you will need to write proper thread locking code to ensure the hooks are thread safe.
Autograd for Complex Numbers#
The short version:
When you use PyTorch to differentiate any function $f(z)$ with complex domain and/or codomain, the gradients are computed under the assumption that the function is a part of a larger real-valued loss function $g(\text{input}) = L$. The gradient computed is $\frac{\partial L}{\partial z^*}$ (note the conjugation of $z$), the negative of which is precisely the direction of steepest descent used in the Gradient Descent algorithm. Thus, there is a viable path to making the existing optimizers work out of the box with complex parameters.
This convention matches TensorFlow's convention for complex differentiation, but is different from JAX (which computes $\frac{\partial L}{\partial z}$).
If you have a real-to-real function which internally uses complex operations, the convention here doesn't matter: you will always get the same result that you would have gotten if it had been implemented with only real operations.
If you are curious about the mathematical details, or want to know how to define complex derivatives in PyTorch, read on.
What are complex derivatives?#
The mathematical definition of complex-differentiability takes the limit definition of a derivative and generalizes it to operate on complex numbers. Consider a function $f: \mathbb{C} \to \mathbb{C}$,

$$f(z = x + yj) = u(x, y) + v(x, y)j$$

where $u$ and $v$ are two-variable real-valued functions and $j$ is the imaginary unit.
Using the derivative definition, we can write:

$$f'(z) = \lim_{h \to 0, h \in \mathbb{C}} \frac{f(z + h) - f(z)}{h}$$

In order for this limit to exist, not only must $u$ and $v$ be real differentiable, but $f$ must also satisfy the Cauchy-Riemann equations. In other words: the limit computed for real and imaginary steps ($h$) must be equal. This is a more restrictive condition.
The complex differentiable functions are commonly known as holomorphic functions. They are well behaved, have all the nice properties that you've seen from real differentiable functions, but are practically of no use in the optimization world. For optimization problems, only real valued objective functions are used in the research community since complex numbers are not part of any ordered field and so having a complex valued loss does not make much sense.

It also turns out that no interesting real-valued objective fulfills the Cauchy-Riemann equations. So the theory of holomorphic functions cannot be used for optimization, and most people therefore use Wirtinger calculus.
Wirtinger Calculus comes into the picture …#
So, we have this great theory of complex differentiability and holomorphic functions, and we can't use any of it at all, because many of the commonly used functions are not holomorphic. What's a poor mathematician to do? Well, Wirtinger observed that even if $f(z)$ isn't holomorphic, one could rewrite it as a two variable function $f(z, z^*)$ which is always holomorphic. This is because the real and imaginary components of $z$ can be expressed in terms of $z$ and $z^*$ as:

$$\mathrm{Re}(z) = \frac{z + z^*}{2}, \qquad \mathrm{Im}(z) = \frac{z - z^*}{2j}$$
Wirtinger calculus suggests to study $f(z, z^*)$ instead, which is guaranteed to be holomorphic if $f$ was real differentiable (another way to think of it is as a change of coordinate system, from $f(x, y)$ to $f(z, z^*)$). This function has partial derivatives $\frac{\partial}{\partial z}$ and $\frac{\partial}{\partial z^*}$. We can use the chain rule to establish a relationship between these partial derivatives and the partial derivatives with respect to the real and imaginary components of $z$:

$$\frac{\partial}{\partial x} = \frac{\partial z}{\partial x} \frac{\partial}{\partial z} + \frac{\partial z^*}{\partial x} \frac{\partial}{\partial z^*} = \frac{\partial}{\partial z} + \frac{\partial}{\partial z^*}$$

$$\frac{\partial}{\partial y} = \frac{\partial z}{\partial y} \frac{\partial}{\partial z} + \frac{\partial z^*}{\partial y} \frac{\partial}{\partial z^*} = 1j \left(\frac{\partial}{\partial z} - \frac{\partial}{\partial z^*}\right)$$

From the above equations, we get:

$$\frac{\partial}{\partial z} = \frac{1}{2} \left(\frac{\partial}{\partial x} - 1j \frac{\partial}{\partial y}\right), \qquad \frac{\partial}{\partial z^*} = \frac{1}{2} \left(\frac{\partial}{\partial x} + 1j \frac{\partial}{\partial y}\right)$$

which is the classic definition of Wirtinger calculus that you would find on Wikipedia.
There are a lot of beautiful consequences of this change.
For one, the Cauchy-Riemann equations translate into simply saying that $\frac{\partial f}{\partial z^*} = 0$ (that is to say, the function $f$ can be written entirely in terms of $z$, without making reference to $z^*$).
Another important (and somewhat counterintuitive) result, as we'll see later, is that when we do optimization on a real-valued loss, the step we should take while making a variable update is given by $\frac{\partial L}{\partial z^*}$ (not $\frac{\partial L}{\partial z}$).
For more reading, check out: https://arxiv.org/pdf/0906.4835.pdf
How is Wirtinger Calculus useful in optimization?#
Researchers in audio and other fields, more commonly, use gradient descent to optimize real valued loss functions with complex variables. Typically, these people treat the real and imaginary values as separate channels that can be updated. For a step size $\alpha/2$ and loss $L$, we can write the following equations in $\mathbb{R}^2$:

$$x_{n+1} = x_n - \frac{\alpha}{2} \frac{\partial L}{\partial x}, \qquad y_{n+1} = y_n - \frac{\alpha}{2} \frac{\partial L}{\partial y}$$

How do these equations translate into complex space $\mathbb{C}$?

$$z_{n+1} = x_{n+1} + y_{n+1} j = z_n - \frac{\alpha}{2} \left(\frac{\partial L}{\partial x} + \frac{\partial L}{\partial y} j\right) = z_n - \alpha \frac{\partial L}{\partial z^*}$$

Something very interesting has happened: Wirtinger calculus tells us that we can simplify the complex variable update formula above to only refer to the conjugate Wirtinger derivative $\frac{\partial L}{\partial z^*}$, giving us exactly the step we take in optimization.
Because the conjugate Wirtinger derivative gives us exactly the correct step for a real valued loss function, PyTorch gives you this derivative when you differentiate a function with a real valued loss.
How does PyTorch compute the conjugate Wirtinger derivative?#
Typically, our derivative formulas take in grad_output as an input, representing the incoming vector-Jacobian product that we've already computed, aka, $\frac{\partial L}{\partial s^*}$, where $L$ is the loss of the entire computation (producing a real loss) and $s$ is the output of our function. The goal here is to compute $\frac{\partial L}{\partial z^*}$, where $z$ is the input of the function. It turns out that in the case of real loss, we can get away with only calculating $\frac{\partial L}{\partial s^*}$, even though the chain rule implies that we also need to have access to $\frac{\partial L}{\partial s}$. If you want to skip this derivation, look at the last equation in this section and then skip to the next section.
Let's continue working with $f: \mathbb{C} \to \mathbb{C}$ defined as $f(z) = f(x + yj) = s = u + vj$. As discussed above, autograd's gradient convention is centered around optimization for real valued loss functions, so let's assume $f$ is a part of a larger real valued loss function $g$. Using the chain rule, we can write:

$$\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} \frac{\partial v}{\partial z^*} \tag{1}$$
Now using the Wirtinger derivative definition, we can write:

$$\frac{\partial L}{\partial s} = \frac{1}{2} \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right), \qquad \frac{\partial L}{\partial s^*} = \frac{1}{2} \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right)$$

It should be noted here that since $u$ and $v$ are real functions, and $L$ is real by our assumption that $f$ is a part of a real valued function, we have:

$$\left(\frac{\partial L}{\partial s}\right)^* = \frac{\partial L}{\partial s^*} \tag{2}$$

i.e., $\frac{\partial L}{\partial s}$ equals $\mathit{grad\_output}^*$.
Solving the above equations for $\frac{\partial L}{\partial u}$ and $\frac{\partial L}{\partial v}$, we get:

$$\frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}, \qquad \frac{\partial L}{\partial v} = 1j \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \tag{3}$$
Substituting (3) in (1), we get:

$$\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial s} \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} \frac{\partial s^*}{\partial z^*}$$

Using (2), and noting that $\frac{\partial s^*}{\partial z^*} = \left(\frac{\partial s}{\partial z}\right)^*$, we get:

$$\frac{\partial L}{\partial z^*} = \mathit{grad\_output}^* \cdot \frac{\partial s}{\partial z^*} + \mathit{grad\_output} \cdot \left(\frac{\partial s}{\partial z}\right)^* \tag{4}$$
This last equation is the important one for writing your own gradients, as it decomposes our derivative formula into a simpler one that is easy to compute by hand.
How can I write my own derivative formula for a complex function?#
The above boxed equation gives us the general formula for all derivatives on complex functions. However, we still need to compute $\frac{\partial s}{\partial z}$ and $\frac{\partial s}{\partial z^*}$. There are two ways you could do this:
- The first way is to just use the definition of Wirtinger derivatives directly and calculate $\frac{\partial s}{\partial z}$ and $\frac{\partial s}{\partial z^*}$ by using $\frac{\partial s}{\partial x}$ and $\frac{\partial s}{\partial y}$ (which you can compute in the normal way).
- The second way is to use the change of variables trick and rewrite $f(z)$ as a two variable function $f(z, z^*)$, and compute the conjugate Wirtinger derivatives by treating $z$ and $z^*$ as independent variables. This is often easier; for example, if the function in question is holomorphic, only $z$ will be used (and $\frac{\partial s}{\partial z^*}$ will be zero).
Let's consider the function

$$f(z = x + yj) = c \cdot z = c \cdot (x + yj)$$

as an example, where $c \in \mathbb{R}$.

Using the first way to compute the Wirtinger derivatives, we have:

$$\frac{\partial s}{\partial x} = c, \qquad \frac{\partial s}{\partial y} = c \cdot j$$

$$\frac{\partial s}{\partial z} = \frac{1}{2} \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) = \frac{1}{2} (c - (c \cdot j) j) = c$$

$$\frac{\partial s}{\partial z^*} = \frac{1}{2} \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) = \frac{1}{2} (c + (c \cdot j) j) = 0$$
Using (4), and grad_output = 1.0 (which is the default grad output value used when backward() is called on a scalar output in PyTorch), we get:

$$\frac{\partial L}{\partial z^*} = 1 \cdot 0 + 1 \cdot c = c$$
Using the second way to compute Wirtinger derivatives, we directly get:

$$\frac{\partial s}{\partial z} = c, \qquad \frac{\partial s}{\partial z^*} = 0$$
And using (4) again, we get $\frac{\partial L}{\partial z^*} = c$. As you can see, the second way involves fewer calculations and comes in handy for quicker derivations.
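You can verify this result numerically. Note that for a complex output, backward() needs an explicit grad_output (here 1, matching the derivation above):

```python
import torch

c = 2.0
z = torch.tensor(1.0 + 1.0j, requires_grad=True)
s = c * z
s.backward(torch.tensor(1.0 + 0.0j))  # grad_output = 1
print(z.grad)  # tensor(2.+0.j): the conjugate Wirtinger derivative equals c
```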
What about cross-domain functions?#
Some functions map from complex inputs to real outputs, or vice versa. These functions form a special case of (4), which we can derive using the chain rule:
For $f: \mathbb{C} \to \mathbb{R}$, we get:

$$\frac{\partial L}{\partial z^*} = 2 \cdot \mathit{grad\_output} \cdot \frac{\partial s}{\partial z^*}$$
For $f: \mathbb{R} \to \mathbb{C}$, we get:

$$\frac{\partial L}{\partial z^*} = 2 \cdot \mathrm{Re}\left(\mathit{grad\_output}^* \cdot \frac{\partial s}{\partial z^*}\right)$$
Hooks for saved tensors#
You can control how saved tensors are packed / unpacked by defining a pair of pack_hook / unpack_hook hooks. The pack_hook function should take a tensor as its single argument but can return any python object (e.g. another tensor, a tuple, or even a string containing a filename). The unpack_hook function takes as its single argument the output of pack_hook and should return a tensor to be used in the backward pass. The tensor returned by unpack_hook only needs to have the same content as the tensor passed as input to pack_hook. In particular, any autograd-related metadata can be ignored as it will be overwritten during unpacking.
An example of such a pair is:
```python
import os
import uuid

import torch

# tmp_dir is assumed to be an existing directory for temporary files.
class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name)
```
Notice that the unpack_hook should not delete the temporary file because it might be called multiple times: the temporary file should be alive for as long as the returned SelfDeletingTempFile object is alive. In the above example, we prevent leaking the temporary file by removing it when it is no longer needed (on deletion of the SelfDeletingTempFile object).
Note
We guarantee that pack_hook will only be called once but unpack_hook can be called as many times as the backward pass requires it and we expect it to return the same data each time.
Warning
Performing inplace operations on the input of any of the functions is forbidden as they may lead to unexpected side-effects. PyTorch will throw an error if the input to a pack hook is modified inplace but does not catch the case where the input to an unpack hook is modified inplace.
Registering hooks for a saved tensor#
You can register a pair of hooks on a saved tensor by calling the register_hooks() method on a SavedTensor object. Those objects are exposed as attributes of a grad_fn and start with the _raw_saved_ prefix.
```python
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
```
The pack_hook method is called as soon as the pair is registered. The unpack_hook method is called each time the saved tensor needs to be accessed, either by means of y.grad_fn._saved_self or during the backward pass.
Warning
If you maintain a reference to a SavedTensor after the saved tensors have been released (i.e. after backward has been called), calling its register_hooks() is forbidden. PyTorch will throw an error most of the time but it may fail to do so in some cases and undefined behavior may arise.
Registering default hooks for saved tensors#
Alternatively, you can use the context-manager saved_tensors_hooks to register a pair of hooks which will be applied to all saved tensors that are created in that context.
Example:
```python
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x.detach()
    temp_file = SelfDeletingTempFile()
    torch.save(x, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class Model(nn.Module):
    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            # ... compute output
            output = x
        return output

model = Model()
net = nn.DataParallel(model)
```
The hooks defined with this context manager are thread-local. Hence, the following code will not produce the desired effects because the hooks do not go through DataParallel.
```python
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = net(input)
```
Note that using those hooks disables all the optimizations in place to reduce Tensor object creation. For example:
```python
with torch.autograd.graph.saved_tensors_hooks(lambda x: x.detach(), lambda x: x):
    x = torch.randn(5, requires_grad=True)
    y = x * x
```
Without the hooks, x, y.grad_fn._saved_self and y.grad_fn._saved_other all refer to the same tensor object. With the hooks, PyTorch will pack and unpack x into two new tensor objects that share the same storage with the original x (no copy performed).
Backward Hooks execution#
This section will discuss when different hooks fire or don't fire. Then it will discuss the order in which they are fired. The hooks that will be covered are: backward hooks registered to Tensor via torch.Tensor.register_hook(), post-accumulate-grad hooks registered to Tensor via torch.Tensor.register_post_accumulate_grad_hook(), post-hooks registered to Node via torch.autograd.graph.Node.register_hook(), and pre-hooks registered to Node via torch.autograd.graph.Node.register_prehook().
Whether a particular hook will be fired#
Hooks registered to a Tensor via torch.Tensor.register_hook() are executed when gradients are being computed for that Tensor. (Note that this does not require the Tensor's grad_fn to be executed. For example, if the Tensor is passed as part of the inputs argument to torch.autograd.grad(), the Tensor's grad_fn may not be executed, but the hook registered to that Tensor will always be executed.)
Hooks registered to a Tensor via torch.Tensor.register_post_accumulate_grad_hook() are executed after the gradients have been accumulated for that Tensor, meaning the Tensor's grad field has been set. Whereas hooks registered via torch.Tensor.register_hook() are run as gradients are being computed, hooks registered via torch.Tensor.register_post_accumulate_grad_hook() are only triggered once the Tensor's grad field is updated by autograd at the end of the backward pass. Thus, post-accumulate-grad hooks can only be registered for leaf Tensors. Registering a hook via torch.Tensor.register_post_accumulate_grad_hook() on a non-leaf Tensor will error, even if you call backward(retain_graph=True).
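A minimal sketch contrasting the two Tensor hook types:

```python
import torch

x = torch.randn(3, requires_grad=True)
x.register_hook(lambda grad: print("grad being computed:", grad))
x.register_post_accumulate_grad_hook(
    lambda t: print("accumulated into .grad:", t.grad)  # .grad is set by now
)
(x * 2).sum().backward()  # the plain hook fires first, then the post-accumulate one
```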
Hooks registered to torch.autograd.graph.Node using torch.autograd.graph.Node.register_hook() or torch.autograd.graph.Node.register_prehook() are only fired if the Node they were registered to is executed.

Whether a particular Node is executed may depend on whether the backward pass was called with torch.autograd.grad() or torch.autograd.backward(). Specifically, you should be aware of these differences when you register a hook on a Node corresponding to a Tensor that you are passing to torch.autograd.grad() or torch.autograd.backward() as part of the inputs argument.

If you are using torch.autograd.backward(), all of the above mentioned hooks will be executed, whether or not you specified the inputs argument. This is because .backward() executes all Nodes, even if they correspond to a Tensor specified as an input. (Note that the execution of this additional Node corresponding to Tensors passed as inputs is usually unnecessary, but done anyway. This behavior is subject to change; you should not depend on it.)

On the other hand, if you are using torch.autograd.grad(), the backward hooks registered to Nodes that correspond to the Tensors passed to inputs may not be executed, because those Nodes will not be executed unless there is another input that depends on the gradient result of this Node.
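A sketch of this difference (as noted above, the exact set of executed Nodes is subject to change, so treat this as illustrative):

```python
import torch

x = torch.randn(3, requires_grad=True)
t = x * 2
out = t.sum()
# Node post-hook signature: hook(grad_inputs, grad_outputs)
t.grad_fn.register_hook(lambda gi, go: print("Mul node executed"))

torch.autograd.grad(out, inputs=(t,), retain_graph=True)  # hook may not fire
torch.autograd.backward(out, inputs=(t,))                 # hook fires
```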
The order in which the different hooks are fired#
The order in which things happen is:

1. hooks registered to Tensor are executed
2. pre-hooks registered to Node are executed (if Node is executed)
3. the .grad field is updated for Tensors that retain_grad
4. Node is executed (subject to rules above)
5. for leaf Tensors that have .grad accumulated, post-accumulate-grad hooks are executed
6. post-hooks registered to Node are executed (if Node is executed)
If multiple hooks of the same type are registered on the same Tensor or Node, they are executed in the order in which they are registered. Hooks that are executed later can observe the modifications to the gradient made by earlier hooks.
Special hooks#
torch.autograd.graph.register_multi_grad_hook() is implemented using hooks registered to Tensors. Each individual Tensor hook is fired following the Tensor hook ordering defined above and the registered multi-grad hook is called when the last Tensor gradient is computed.
torch.nn.modules.module.register_module_full_backward_hook() is implemented using hooks registered to Node. As the forward is computed, hooks are registered to the grad_fn corresponding to the inputs and outputs of the module. Because a module may take multiple inputs and return multiple outputs, a dummy custom autograd Function is first applied to the inputs of the module before forward is called, and to the outputs of the module before the output of forward is returned, to ensure that those Tensors share a single grad_fn, which we can then attach our hooks to.
Behavior of Tensor hooks when Tensor is modified in-place#
Usually hooks registered to a Tensor receive the gradient of the outputs with respect to that Tensor, where the value of the Tensor is taken to be its value at the time backward is computed.

However, if you register hooks to a Tensor, and then modify that Tensor in-place, hooks registered before the in-place modification similarly receive gradients of the outputs with respect to the Tensor, but the value of the Tensor is taken to be its value before the in-place modification.

If you prefer the behavior in the former case, you should register them to the Tensor after all in-place modifications to it have been made. For example:
```python
def fn(grad):
    # any hook; here we just inspect the incoming gradient
    print(grad)

t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
```
Furthermore, it can be helpful to know that under the hood, when hooks are registered to a Tensor, they actually become permanently bound to the grad_fn of that Tensor, so if that Tensor is then modified in-place, even though the Tensor now has a new grad_fn, hooks registered before it was modified in-place will continue to be associated with the old grad_fn, e.g. they will fire when that Tensor's old grad_fn is reached in the graph by the autograd engine.