
Adafactor

class torch.optim.Adafactor(params, lr=0.01, beta2_decay=-0.8, eps=(None, 0.001), d=1.0, weight_decay=0.0, *, foreach=None, maximize=False)

Implements Adafactor algorithm.

\begin{aligned}
&\textbf{input} : \gamma \text{ (lr)}, \: \tau \text{ (}\beta_2\text{ decay)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}, \\
&\hspace{15mm} \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{ (clipping threshold)}, \\
&\hspace{15mm} \lambda \text{ (weight decay)}, \: \textit{maximize} \\
&\textbf{initialize} : R_0 \leftarrow 0 \text{ (second moment row factor)}, \\
&\hspace{23mm} C_0 \leftarrow 0 \text{ (second moment col factor)}, \\
&\hspace{23mm} \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\
&\hspace{5mm}\rho_t \leftarrow \min\left(\gamma, \frac{1}{\sqrt{t}}\right) \\
&\hspace{5mm}\alpha_t \leftarrow \max(\epsilon_2, \text{RMS}(\theta_{t-1}))\rho_t \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
&\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\
&\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1} + (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
&\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1} + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
&\hspace{10mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{\max(1^\top_n \cdot R_t, \epsilon_1)} \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1} + (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\
&\hspace{5mm}U_t \leftarrow \frac{G_t}{\max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
&\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{\max\left(1, \frac{\text{RMS}(U_t)}{d}\right)} \\
&\hspace{5mm}\theta_t \leftarrow \theta_t - \alpha_t \widehat{U}_t \\
&\textbf{return} \: \theta_t
\end{aligned}

For further details regarding the algorithm, we refer to Adafactor: Adaptive Learning Rates with Sublinear Memory Cost.
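The pseudocode can be traced with a small sketch. The function `adafactor_step` and its plain-dict `state` argument are hypothetical helpers, not part of `torch.optim`. The sketch follows the single-parameter update line by line, with these assumptions:

- It uses row and column means instead of sums, as the note below discusses.
- It falls back to the dtype's machine epsilon when `eps1` is `None`.

```python
import torch


def adafactor_step(theta, grad, state, lr=1e-2, beta2_decay=-0.8,
                   eps1=None, eps2=1e-3, d=1.0, weight_decay=0.0):
    """Return theta_t given theta_{t-1} and G_t (hypothetical sketch).

    `state` is a plain dict that carries the step count and the
    second-moment statistics from one call to the next.
    """
    if eps1 is None:
        eps1 = torch.finfo(theta.dtype).eps   # assumed fallback for eps1
    t = state.get("step", 0) + 1
    state["step"] = t

    beta2_t = 1.0 - t ** beta2_decay          # \hat{beta}_{2_t} = 1 - t^tau
    rho_t = min(lr, t ** -0.5)                # relative step size
    rms_theta = theta.pow(2).mean().sqrt().item()
    alpha_t = max(eps2, rms_theta) * rho_t    # scaled by RMS(theta_{t-1})

    theta = theta * (1.0 - lr * weight_decay)  # decoupled weight decay
    g2 = grad * grad
    if grad.dim() > 1:
        # Factored second moment: one statistic per row and per column.
        row = g2.mean(dim=-1, keepdim=True)   # shape (n, 1)
        col = g2.mean(dim=-2, keepdim=True)   # shape (1, m)
        R = beta2_t * state.get("R", torch.zeros_like(row)) + (1 - beta2_t) * row
        C = beta2_t * state.get("C", torch.zeros_like(col)) + (1 - beta2_t) * col
        state["R"], state["C"] = R, C
        V = (R @ C) / R.mean(dim=-2, keepdim=True).clamp(min=eps1)
    else:
        # Vectors keep an unfactored running average of G_t * G_t.
        V = beta2_t * state.get("V", torch.zeros_like(g2)) + (1 - beta2_t) * g2
        state["V"] = V

    U = grad / V.sqrt().clamp(min=eps1)
    rms_u = U.pow(2).mean().sqrt().item()
    U_hat = U / max(1.0, rms_u / d)           # update clipping with threshold d
    return theta - alpha_t * U_hat
```

On the first step, $\widehat{\beta}_{2_1} = 1 - 1^{\tau} = 0$. The running averages therefore start from the current squared gradient alone.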

Parameters
  • params (iterable) – iterable of parameters or named_parameters to optimize, or iterable of dicts defining parameter groups. When using named_parameters, all parameters in all groups should be named.

  • lr (float, Tensor, optional) – unlike other optimizers, Adafactor does not require a learning rate, and Noam Shazeer and Mitchell Stern do not use lr at all. Deviating from the paper, this implementation uses lr for applying weight decay and as the maximum value for the relative step size rho_t. Note that in the paper, a constant of 0.01 is used as the maximum value for the relative step size, and so we set 0.01 as the default value. (default: 1e-2)

  • beta2_decay (float, optional) – the decay rate of beta2, i.e. the exponent $\tau$ in $\widehat{\beta}_{2_t} = 1 - t^{\tau}$. beta2 standardly refers to the coefficient used for computing the running average of the squared gradient. (default: -0.8)

  • eps (Tuple[float, float], optional) – epsilon1 is the term added to the denominator of the update calculation to improve numerical stability. This use of epsilon1 deviates from the algorithm written in the paper! See the note below for more details. epsilon2 is the term used to avoid having too small a weight update when applying parameter scaling. (default: (None, 1e-3))

  • d (float, optional) – the clipping threshold, used to avoid larger-than-desired updates. (default: 1.0)

  • weight_decay (float, optional) – weight decay coefficient (default: 0.0)

  • foreach (bool, optional) – whether the foreach implementation of the optimizer is used. The foreach implementation uses ~ sizeof(params) more peak memory than the for-loop version, because its intermediates are a tensorlist rather than a single tensor. Adafactor is commonly used when memory is prohibitive, so it defaults to the slower single-tensor for-loop implementation unless this flag is explicitly True. This behavior is contrary to other optimizers, which will attempt to default to foreach on CUDA for faster runtime. (default: None)

  • maximize (bool, optional) – maximize the objective with respect to the params, instead of minimizing (default: False)

Note

The implementation of Adafactor subtly differs from Noam Shazeer and Mitchell Stern, and from implementations in some other frameworks, in its use of the learning rate and $\epsilon_1$.

Regarding the learning rate hyperparameter: Noam Shazeer and Mitchell Stern do not use lr at all, as the stated algorithm uses $\rho_t$ and update clipping to affect the step size.

This implementation allows lr to influence the maximum value for $\rho_t$:

\rho_t = \min\left(\text{lr}, \frac{1}{\sqrt{t}}\right)

This differs from Noam Shazeer and Mitchell Stern, who use a constant of 0.01 as the maximum value of $\rho_t$:

\rho_t = \min\left(0.01, \frac{1}{\sqrt{t}}\right)
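The sketch below evaluates $\rho_t = \min(\text{lr}, 1/\sqrt{t})$ for a few steps to make the difference concrete. `relative_step_size` is a hypothetical helper. With the default lr of 0.01, the cap coincides with the paper's constant: $\rho_t$ stays at 0.01 until $t = 10^4$ and then decays as $1/\sqrt{t}$. A smaller lr simply lowers the cap.

```python
import math


def relative_step_size(t: int, lr: float) -> float:
    """rho_t = min(lr, 1 / sqrt(t)), as in the update rule above."""
    return min(lr, 1.0 / math.sqrt(t))


for t in (1, 100, 10_000, 40_000):
    # Default lr matches the paper's constant; lr=1e-3 lowers the ceiling.
    print(t, relative_step_size(t, 1e-2), relative_step_size(t, 1e-3))
```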

Noam Shazeer and Mitchell Stern do not enforce an opinion on how weight decay should be computed, and so we use the learning rate as a coefficient for decoupled weight decay, similar to what is suggested in Decoupled Weight Decay Regularization.

Regarding the use of $\epsilon_1$: the implementation attempts to replicate the presumed intention of Noam Shazeer and Mitchell Stern to use $\epsilon_1$ as a stabilizing term when the squared gradient becomes small.

This stabilization can be written as

\begin{aligned}
&\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1} + (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
&\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1} + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
&\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{\max(1^\top_n \cdot R_t, \epsilon_1)} \\
&\hspace{5mm}U_t \leftarrow \frac{G_t}{\max(\sqrt{\widehat{V}_t}, \epsilon_1)}
\end{aligned}

where the row and column factors of the squared gradient, $R_t$ and $C_t$, are left alone, and we apply $\epsilon_1$ at the final calculation of the variance estimate $\widehat{V}_t$ and for the update $U_t$.

This is in contrast to Noam Shazeer and Mitchell Stern and other frameworks, which apply $\epsilon_1$ to both the row and column factors of the squared gradient, but not in the calculations after:

\begin{aligned}
&\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1} + (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\
&\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1} + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\
&\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\
&\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}}
\end{aligned}
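The two placements of $\epsilon_1$ can be compared on a single first step. There $\widehat{\beta}_{2_t} = 0$, so the factors are just the row and column sums of $G_t \odot G_t$. The helper names below are hypothetical. The example uses a gradient with an all-zero row, where the unstabilized ratio would divide zero by zero.

```python
import torch


def update_clamped(G, eps1):
    """eps1 applied to V_hat's denominator and to U_t (this implementation)."""
    G2 = G * G
    R = G2.sum(dim=1, keepdim=True)              # R_t = (G * G) 1_m
    C = G2.sum(dim=0, keepdim=True)              # C_t = 1_n^T (G * G)
    V = (R @ C) / R.sum().clamp(min=eps1)
    return G / V.sqrt().clamp(min=eps1)


def update_additive(G, eps1):
    """eps1 added to every entry of G * G before factoring (paper)."""
    G2 = G * G + eps1
    R = G2.sum(dim=1, keepdim=True)
    C = G2.sum(dim=0, keepdim=True)
    V = (R @ C) / R.sum()
    return G / V.sqrt()


G = torch.tensor([[0.0, 0.0], [1.0, 2.0]])       # first row carries no signal
print(update_clamped(G, 1e-6))
print(update_additive(G, 1e-6))
print(update_clamped(G, 0.0))                    # no epsilon: 0 / 0 in row 0
```

Both stabilized variants give essentially the same update here. Only the unstabilized one produces `nan` in the all-zero row.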

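You may note that Noam Shazeer and Mitchell Stern describe using the sum of squared gradients, while this implementation uses the mean instead. This choice is mathematically equivalent, since the averaging factors cancel in the ratio that defines $\widehat{V}_t$ (provided the denominator $1^\top_n \cdot R_t$ is also replaced by a mean), and it allows for greater numerical stability for large sums.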
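A quick numerical check of this equivalence follows. It is a sketch: the random 5×3 gradient is arbitrary, and the denominator is averaged along with the factors so that the $1/m$ and $1/n$ scalings cancel.

```python
import torch

torch.manual_seed(0)
G2 = torch.randn(5, 3).square()              # G_t * G_t, shape (n, m)

# Sums: row/column sums, normalized by the total 1_n^T R_t.
R_sum = G2.sum(dim=1, keepdim=True)
C_sum = G2.sum(dim=0, keepdim=True)
V_sum = (R_sum @ C_sum) / R_sum.sum()

# Means: R scaled by 1/m, C by 1/n, and the normalizer is the mean of R.
R_mean = G2.mean(dim=1, keepdim=True)
C_mean = G2.mean(dim=0, keepdim=True)
V_mean = (R_mean @ C_mean) / R_mean.mean()

print(torch.allclose(V_sum, V_mean))
```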

add_param_group(param_group)

Add a param group to the Optimizer's param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters

param_group (dict) – Specifies what Tensors should be optimized along with group-specific optimization options.

load_state_dict(state_dict)

Load the optimizer state.

Parameters

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

Warning

Make sure this method is called after initializing torch.optim.lr_scheduler.LRScheduler, as calling it beforehand will overwrite the loaded learning rates.

Note

The names of the parameters (if they exist under the “param_names” key of each param group in state_dict()) will not affect the loading process. To use the parameters’ names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in the loaded state dict's param_groups, they will be saved and override the current names, if present, in the optimizer state. If they do not exist in the loaded state dict, the optimizer's param_names will remain unchanged.

Example

>>> import torch
>>> model = torch.nn.Linear(10, 10)
>>> optim = torch.optim.SGD(model.parameters(), lr=3e-4)
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(
...     optim,
...     start_factor=0.1,
...     end_factor=1,
...     total_iters=20,
... )
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
...     optim,
...     T_max=80,
...     eta_min=3e-5,
... )
>>> scheduler = torch.optim.lr_scheduler.SequentialLR(
...     optim,
...     schedulers=[scheduler1, scheduler2],
...     milestones=[20],
... )
>>> scheduler.load_state_dict(torch.load("./save_seq.pt"))
>>> # now load the optimizer checkpoint after loading the LRScheduler
>>> optim.load_state_dict(torch.load("./save_optim.pt"))

register_load_state_dict_post_hook(hook, prepend=False)

Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used.

The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.

Parameters
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post-hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook, prepend=False)

Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict in place or optionally return a new one. If a state_dict is returned, it will be used to be loaded into the optimizer.

The hook will be called with arguments self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.

Parameters
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre-hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_state_dict_post_hook(hook, prepend=False)

Register a state dict post-hook which will be called after state_dict() is called.

It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict in place or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.

Parameters
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post-hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_state_dict_pre_hook(hook, prepend=False)

Register a state dict pre-hook which will be called before state_dict() is called.

It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.

Parameters
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre-hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_step_post_hook(hook)

Register an optimizer step post hook which will be called after optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None

The optimizer argument is the optimizer instance being used.

Parameters

hook (Callable) – The user defined hook to be registered.

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_step_pre_hook(hook)

Register an optimizer step pre hook which will be called before optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None or modified args and kwargs

The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.

Parameters

hook (Callable) – The user defined hook to be registered.

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

state_dict()

Return the state of the optimizer as a dict.

It contains two entries:

  • state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.

  • param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters(), the names content will also be saved in the state dict.

NOTE: The parameter IDs may look like indices, but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.

A returned state dict might look something like:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0],
            'param_names': ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3],
            'param_names': ['param1', 'layer.weight', 'layer.bias']  (optional)
        }
    ]
}

Return type

dict[str, Any]

step(closure=None)

Perform a single optimization step.

Parameters

closure (Callable, optional) – A closure that reevaluates the model and returns the loss.

zero_grad(set_to_none=True)

Reset the gradients of all optimized torch.Tensors.

Parameters

set_to_none (bool,optional) –

Instead of setting to zero, set the grads to None. Default: True

This will in general have a lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:

  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.

  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.

  3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).