
RAdam

class torch.optim.RAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, decoupled_weight_decay=False, *, foreach=None, maximize=False, capturable=False, differentiable=False)

Implements RAdam algorithm.

\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2 \text{ (betas)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}, \: \lambda \text{ (weight decay)}, \: \textit{maximize}, \\
&\hspace{13mm} \epsilon \text{ (epsilon)}, \: \textit{decoupled\_weight\_decay} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, \: v_0 \leftarrow 0 \text{ (second moment)}, \\
&\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) - 1 \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{6mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{12mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{6mm}\textbf{else} \\
&\hspace{12mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{6mm}\theta_t \leftarrow \theta_{t-1} \\
&\hspace{6mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\
&\hspace{18mm}\theta_t \leftarrow \theta_t - \gamma \lambda \theta_t \\
&\hspace{12mm}\textbf{else} \\
&\hspace{18mm}g_t \leftarrow g_t + \lambda \theta_t \\
&\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g^2_t \\
&\hspace{6mm}\widehat{m_t} \leftarrow m_t / \big(1 - \beta_1^t\big) \\
&\hspace{6mm}\rho_t \leftarrow \rho_{\infty} - 2 t \beta^t_2 / \big(1 - \beta_2^t\big) \\
&\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
&\hspace{12mm}l_t \leftarrow \frac{\sqrt{1 - \beta^t_2}}{\sqrt{v_t} + \epsilon} \\
&\hspace{12mm}r_t \leftarrow \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\rho_{\infty}}{(\rho_{\infty} - 4)(\rho_{\infty} - 2)\rho_t}} \\
&\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t \\
&\hspace{6mm}\textbf{else} \\
&\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\textbf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt}
\end{aligned}
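
To make the pseudocode concrete, here is a line-by-line transcription for a single scalar parameter in plain Python. It is a reading aid only, not the tensor implementation behind torch.optim.RAdam (which also handles the foreach, capturable and differentiable modes); the helper names radam_scalar and first_rectified_step are hypothetical and introduced just for this sketch.

```python
import math


def radam_scalar(grad_fn, theta0, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.0, decoupled_weight_decay=False,
                 maximize=False, steps=1):
    """Hypothetical scalar transcription of the RAdam pseudocode (sketch only)."""
    beta1, beta2 = betas
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    theta, m, v = float(theta0), 0.0, 0.0  # theta_0, m_0, v_0
    for t in range(1, steps + 1):
        # g_t is the gradient at theta_{t-1}, negated when maximizing
        g = -grad_fn(theta) if maximize else grad_fn(theta)
        if weight_decay != 0:
            if decoupled_weight_decay:
                theta -= lr * weight_decay * theta  # AdamW style: decay the weight
            else:
                g += weight_decay * theta  # Adam style: L2 term added to g_t
        m = beta1 * m + (1 - beta1) * g      # first moment m_t
        v = beta2 * v + (1 - beta2) * g * g  # second moment v_t
        m_hat = m / (1 - beta1 ** t)         # bias-corrected first moment
        rho_t = rho_inf - 2 * t * beta2 ** t / (1 - beta2 ** t)
        if rho_t > 5:
            # Adaptive step, scaled by the rectification term r_t
            l_t = math.sqrt(1 - beta2 ** t) / (math.sqrt(v) + eps)
            r_t = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
                            / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
            theta -= lr * m_hat * r_t * l_t
        else:
            # Variance not yet tractable: plain bias-corrected momentum step
            theta -= lr * m_hat
    return theta


def first_rectified_step(beta2=0.999):
    """Smallest t with rho_t > 5, i.e. the first step taking the adaptive branch."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    t = 1
    while rho_inf - 2 * t * beta2 ** t / (1 - beta2 ** t) <= 5:
        t += 1
    return t


# Minimise f(theta) = theta**2, whose gradient is 2 * theta.
print(radam_scalar(lambda th: 2 * th, 1.0, lr=0.1, steps=1))  # 0.8
print(first_rectified_step(0.999))  # 6
```

With the default beta2 = 0.999, the rectified branch only switches on at t = 6; for the first five steps the update is the bias-corrected momentum step from the else branch.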

For further details regarding the algorithm, we refer to On the Variance of the Adaptive Learning Rate and Beyond.

This implementation provides an option to use either the original weight_decay implementation as in Adam (where weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False (default), it uses the original Adam-style weight decay; otherwise, it uses the AdamW style, which corresponds more closely to the authors’ implementation in the RAdam paper. Further information about decoupled weight decay can be found in Decoupled Weight Decay Regularization.

Parameters
  • params (iterable) – iterable of parameters or named_parameters to optimize, or iterable of dicts defining parameter groups. When using named_parameters, all parameters in all groups should be named

  • lr (float, Tensor, optional) – learning rate (default: 1e-3)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • decoupled_weight_decay (bool, optional) – whether to decouple the weight decay as in AdamW to obtain RAdamW. If True, the algorithm accumulates weight decay in neither the momentum nor the variance. (default: False)

  • foreach (bool, optional) – whether the foreach implementation of the optimizer is used. If unspecified by the user (so foreach is None), we will try to use foreach over the for-loop implementation on CUDA, since it is usually significantly more performant. Note that the foreach implementation uses roughly sizeof(params) more peak memory than the for-loop version, because its intermediates are a tensorlist rather than a single tensor. If memory is prohibitive, batch fewer parameters through the optimizer at a time or set this flag to False (default: None)

  • maximize (bool, optional) – maximize the objective with respect to the params, instead of minimizing (default: False)

  • capturable (bool, optional) – whether this instance is safe to capture in a graph, whether for CUDA graphs or for torch.compile support. Tensors are only capturable when on supported accelerators. Passing True can impair ungraphed performance, so if you don’t intend to graph capture this instance, leave it False (default: False)

  • differentiable (bool, optional) – whether autograd should occur through the optimizer step in training. Otherwise, the step() function runs in a torch.no_grad() context. Setting to True can impair performance, so leave it False if you don’t intend to run autograd through this instance (default: False)
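
The parameters above come together in an ordinary training loop. The sketch below fits a toy linear model (the data, learning rates and step count are arbitrary illustrative choices) and uses two parameter groups: the second overrides lr, while every other option is inherited from the constructor arguments.

```python
import torch

torch.manual_seed(0)

# Toy regression problem: learn y = 3x + 1 (illustrative data only).
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 3 * x + 1
model = torch.nn.Linear(1, 1)

# Two parameter groups: the bias gets its own lr, the weight uses the default.
optimizer = torch.optim.RAdam(
    [
        {"params": [model.weight]},
        {"params": [model.bias], "lr": 1e-2},
    ],
    lr=5e-2,
)

loss_fn = torch.nn.MSELoss()
initial_loss = loss_fn(model(x), y).item()
for _ in range(300):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

final_loss = loss_fn(model(x), y).item()
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```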

add_param_group(param_group)

Add a param group to the Optimizer’s param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters

param_group (dict) – Specifies what Tensors should be optimized along with group-specific optimization options.
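
A minimal sketch of the fine-tuning pattern described above, assuming a hypothetical model split into a frozen backbone and a trainable head. Options missing from the new group’s dict are filled from the optimizer’s constructor defaults, which is why the second group reports the default betas.

```python
import torch

# Hypothetical model: a backbone that starts frozen and a trainable head.
backbone = torch.nn.Linear(8, 8)
head = torch.nn.Linear(8, 2)
for p in backbone.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.RAdam(head.parameters(), lr=1e-3)

# Later in training: unfreeze the backbone with a smaller learning rate.
for p in backbone.parameters():
    p.requires_grad_(True)
optimizer.add_param_group({"params": backbone.parameters(), "lr": 1e-4})

# betas, eps, ... were not given, so they come from the constructor defaults.
print(len(optimizer.param_groups), optimizer.param_groups[1]["betas"])  # 2 (0.9, 0.999)
```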

load_state_dict(state_dict)

Load the optimizer state.

Parameters

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

Warning

Make sure this method is called after initializing torch.optim.lr_scheduler.LRScheduler, as calling it beforehand will overwrite the loaded learning rates.

Note

The names of the parameters (if they exist under the “param_names” key of each param group in state_dict()) will not affect the loading process. To use the parameters’ names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in the loaded state dict’s param_groups, they will be saved and override the current names, if present, in the optimizer state. If they do not exist in the loaded state dict, the optimizer’s param_names will remain unchanged.

Example

>>> model = torch.nn.Linear(10, 10)
>>> optim = torch.optim.SGD(model.parameters(), lr=3e-4)
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(
...     optim,
...     start_factor=0.1,
...     end_factor=1,
...     total_iters=20,
... )
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
...     optim,
...     T_max=80,
...     eta_min=3e-5,
... )
>>> lr = torch.optim.lr_scheduler.SequentialLR(
...     optim,
...     schedulers=[scheduler1, scheduler2],
...     milestones=[20],
... )
>>> lr.load_state_dict(torch.load("./save_seq.pt"))
>>> # now load the optimizer checkpoint after loading the LRScheduler
>>> optim.load_state_dict(torch.load("./save_optim.pt"))
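
The example above uses files on disk and an SGD optimizer. The following sketch instead round-trips an RAdam state through an in-memory io.BytesIO buffer after one step (the model and input shapes are arbitrary), which is a convenient way to check that per-parameter state such as the step counter survives save and load.

```python
import io

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.RAdam(model.parameters(), lr=1e-3)

# Take one step so the optimizer has per-parameter state to save.
model(torch.randn(2, 4)).sum().backward()
opt.step()

# Serialize into an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save(opt.state_dict(), buffer)
buffer.seek(0)

# A fresh optimizer over the same parameters, restored from the buffer.
restored = torch.optim.RAdam(model.parameters(), lr=1e-3)
restored.load_state_dict(torch.load(buffer))
print(restored.state_dict()["state"][0]["step"])  # tensor(1.)
```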

register_load_state_dict_post_hook(hook, prepend=False)

Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used.

The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.

Parameters
  • hook (Callable) – The user-defined hook to be registered.

  • prepend (bool) – If True, the provided post-hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook, prepend=False)

Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict in place or optionally return a new one. If a state_dict is returned, that dict is the one loaded into the optimizer.

The hook will be called with arguments self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.

Parameters
  • hook (Callable) – The user-defined hook to be registered.

  • prepend (bool) – If True, the provided pre-hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_state_dict_post_hook(hook, prepend=False)

Register a state dict post-hook which will be called after state_dict() is called.

It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict in place or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.

Parameters
  • hook (Callable) – The user-defined hook to be registered.

  • prepend (bool) – If True, the provided post-hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_state_dict_pre_hook(hook, prepend=False)

Register a state dict pre-hook which will be called before state_dict() is called.

It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.

Parameters
  • hook (Callable) – The user-defined hook to be registered.

  • prepend (bool) – If True, the provided pre-hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_step_post_hook(hook)

Register an optimizer step post-hook which will be called after optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None

The optimizer argument is the optimizer instance being used.

Parameters

hook (Callable) – The user-defined hook to be registered.

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle
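
As a sketch, a step post-hook can log the learning rate after every update; record_lr and lr_history are hypothetical names, and the random inputs are placeholders.

```python
import torch

model = torch.nn.Linear(3, 1)
opt = torch.optim.RAdam(model.parameters(), lr=1e-3)
lr_history = []  # hypothetical log, for illustration only


def record_lr(optimizer, args, kwargs):
    # Called after every optimizer.step(); args/kwargs are those passed to step().
    lr_history.append(optimizer.param_groups[0]["lr"])


opt.register_step_post_hook(record_lr)
for _ in range(3):
    opt.zero_grad()
    model(torch.randn(4, 3)).pow(2).mean().backward()
    opt.step()
print(lr_history)  # [0.001, 0.001, 0.001]
```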

register_step_pre_hook(hook)

Register an optimizer step pre-hook which will be called before optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None or modified args and kwargs

The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.

Parameters

hook (Callable) – The user-defined hook to be registered.

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle
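
One practical pre-hook is gradient clipping, so that every call to step() clips first without changing the training loop. This sketch uses torch.nn.utils.clip_grad_norm_ with an arbitrary threshold MAX_NORM; the hook returns None, so step() runs with its original arguments.

```python
import torch

model = torch.nn.Linear(3, 1)
opt = torch.optim.RAdam(model.parameters(), lr=1e-3)
MAX_NORM = 1.0  # illustrative threshold


def clip_before_step(optimizer, args, kwargs):
    # Clip the gradients of every parameter this optimizer manages.
    params = [p for group in optimizer.param_groups for p in group["params"]]
    torch.nn.utils.clip_grad_norm_(params, MAX_NORM)
    # Returning None leaves step()'s args and kwargs unchanged.


opt.register_step_pre_hook(clip_before_step)

# Deliberately large loss so the raw gradient norm exceeds MAX_NORM.
(1000 * model(torch.ones(2, 3))).sum().backward()
opt.step()

# Gradients are not cleared by step(), so the clipped values are still visible.
grad_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(float(grad_norm) <= MAX_NORM + 1e-4)  # True
```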

state_dict()

Return the state of the optimizer as a dict.

It contains two entries:

  • state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.

  • param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters(), the names content will also be saved in the state dict.

NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.

A returned state dict might look something like:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0],
            'param_names': ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3],
            'param_names': ['param1', 'layer.weight', 'layer.bias']  (optional)
        }
    ]
}

Return type

dict[str,Any]
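
To see what this structure looks like for RAdam specifically, this sketch takes one step on a small linear layer and inspects the result. In the current implementation each per-parameter entry holds a step counter plus the moment buffers exp_avg and exp_avg_sq (m_t and v_t in the algorithm above); treat these key names as implementation details that may change between releases.

```python
import torch

model = torch.nn.Linear(4, 2)  # parameters: weight (ID 0) and bias (ID 1)
opt = torch.optim.RAdam(model.parameters())
model(torch.randn(3, 4)).sum().backward()
opt.step()

sd = opt.state_dict()
print(sd["param_groups"][0]["params"])  # [0, 1]
print(sorted(sd["state"][0]))           # ['exp_avg', 'exp_avg_sq', 'step']
```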

step(closure=None)

Perform a single optimization step.

Parameters

closure (Callable, optional) – A closure that reevaluates the model and returns the loss.
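
A sketch of the closure form of step(), with placeholder random data: the closure clears the gradients, recomputes the loss, backpropagates and returns the loss, which step() then passes back to the caller.

```python
import torch

model = torch.nn.Linear(2, 1)
opt = torch.optim.RAdam(model.parameters(), lr=1e-2)
x, y = torch.randn(8, 2), torch.randn(8, 1)  # placeholder data


def closure():
    # Re-evaluate the model, recompute gradients and return the loss.
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


loss = opt.step(closure)  # step() returns whatever the closure returned
print(loss.item())
```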

zero_grad(set_to_none=True)

Reset the gradients of all optimized torch.Tensors.

Parameters

set_to_none (bool, optional) –

Instead of setting to zero, set the grads to None. Default: True

This will in general have a lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:

  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.

  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.

  3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).
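
The third point above can be observed directly. In this sketch (lr and weight_decay are arbitrary values), a zero-filled gradient still produces an update because the L2 term λθ is added to it, whereas a None gradient makes step() skip the parameter altogether.

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.RAdam([p], lr=0.1, weight_decay=0.1)

p.sum().backward()
opt.zero_grad(set_to_none=False)  # .grad becomes a tensor of zeros
zeroed_grad = p.grad.clone()
opt.step()                        # still steps: weight decay acts on the zero grad
after_zero_grad = p.detach().clone()

opt.zero_grad()                   # default set_to_none=True: .grad becomes None
opt.step()                        # p has no grad, so it is skipped entirely
after_none_grad = p.detach().clone()

print(zeroed_grad)                                    # tensor([0., 0., 0.])
print(torch.equal(after_zero_grad, torch.ones(3)))    # False
print(torch.equal(after_none_grad, after_zero_grad))  # True
```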