Adam Optimizer with Warmup

This extends the AMSGrad optimizer and adds a warmup stage.

from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad

Adam Optimizer with Warmup

This class extends the AMSGrad optimizer defined in amsgrad.py.

class AdamWarmup(AMSGrad):

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate $\alpha$
  • betas is a tuple of $(\beta_1, \beta_2)$
  • eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
  • warmup is the number of warmup steps
  • defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class AdamWarmup.
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
             weight_decay: WeightDecay = WeightDecay(),
             optimized_update: bool = True,
             amsgrad=False, warmup=0, defaults=None):
defaults = {} if defaults is None else defaults
defaults.update(dict(warmup=warmup))
super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
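
For context, here is a minimal usage sketch of the constructor on a toy PyTorch model. The import path labml_nn.optimizers.adam_warmup and the warmup=1000 / lr=1e-3 values are illustrative assumptions, not part of the original text.

# Minimal usage sketch (assumed import path; hyperparameter values are illustrative only)
import torch
from torch import nn
from labml_nn.optimizers.adam_warmup import AdamWarmup  # assumed module path

model = nn.Linear(10, 1)
# Warm the learning rate up over the first 1000 optimizer steps, then keep it at 1e-3
optimizer = AdamWarmup(model.parameters(), lr=1e-3, warmup=1_000)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()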

Get learning-rate

$$\alpha \min \Big(1, \frac{t}{w}\Big)$$

where $t$ is the current step and $w$ is the number of warmup steps.

def get_lr(self, state: Dict[str, any], group: Dict[str, any]):

If we are in the warmup stage

if group['warmup'] > state['step']:

A linearly increasing learning rate from $0$ to $\alpha$

return 1e-8 + state['step'] * group['lr'] / group['warmup']
else:

Constant learning rate $\alpha$

return group['lr']
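
As an illustration of the schedule that get_lr implements, the standalone sketch below reproduces the same rule outside the optimizer; the lr and warmup values are made up for the example.

# Standalone sketch of the same warmup rule (illustrative values, not library code)
lr, warmup = 1e-3, 5

def scheduled_lr(step: int) -> float:
    # Linear warmup from ~0 to lr over `warmup` steps, then constant
    if warmup > step:
        return 1e-8 + step * lr / warmup
    return lr

for t in range(8):
    print(t, scheduled_lr(t))
# The printed rate rises from ~0 at t=0 to 1e-3 at t=5 and stays constant afterwards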
