```python
from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad
```
`AdamWarmup` extends the AMSGrad optimizer with a learning-rate warmup stage.

```python
class AdamWarmup(AMSGrad):
```
* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `optimized_update` is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
* `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
* `warmup` is the number of warmup steps
* `defaults` is a dictionary of default values for parameter groups. This is useful when you want to extend the class `AdamWarmup`.

```python
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False, warmup=0, defaults=None):
        # Add `warmup` to the default hyper-parameters for parameter groups
        # and initialize the parent AMSGrad optimizer.
        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
```
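For illustration, here is a minimal usage sketch. It assumes the class is importable as `labml_nn.optimizers.adam_warmup.AdamWarmup` (the import path is an assumption); the tiny linear model and hyper-parameter values are illustrative only.

```python
import torch
from labml_nn.optimizers.adam_warmup import AdamWarmup  # assumed import path

# A hypothetical model; any set of `nn.Module` parameters works the same way.
model = torch.nn.Linear(16, 4)

# Warm the learning rate up over the first 2,000 steps, then keep it at 1e-3.
optimizer = AdamWarmup(model.parameters(), lr=1e-3, betas=(0.9, 0.999), warmup=2_000)

for step in range(10):
    inputs = torch.randn(8, 16)
    loss = model(inputs).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```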
`get_lr` computes the learning rate for the current training step: it increases linearly from $0$ to $\alpha$ over the first `warmup` steps and stays constant at $\alpha$ afterwards.

```python
    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
        # If we are in the warmup stage
        if group['warmup'] > state['step']:
            # A linearly increasing learning rate from $0$ to $\alpha$
            return 1e-8 + state['step'] * group['lr'] / group['warmup']
        else:
            # Constant learning rate $\alpha$
            return group['lr']
```
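To make the schedule concrete, here is a small standalone sketch of the same rule; `warmup_lr` is a hypothetical helper, not part of the library, and the hyper-parameter values are illustrative.

```python
def warmup_lr(step: int, lr: float, warmup: int) -> float:
    """Same rule as `get_lr`: linear warmup, then a constant rate."""
    if warmup > step:
        return 1e-8 + step * lr / warmup
    return lr

# With lr=1e-3 and warmup=2_000:
# step 0      -> 1e-8     (practically zero)
# step 500    -> ~2.5e-4  (a quarter of the way through warmup)
# step 2_000  -> 1e-3     (warmup finished, constant from here on)
for step in (0, 500, 2_000, 10_000):
    print(step, warmup_lr(step, lr=1e-3, warmup=2_000))
```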