                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
-parser.add_argument('--accel', action='store_true', help='Enables accelerated training')
+parser.add_argument('--accel', action='store_true',
+                    help='Enables accelerated training')
+parser.add_argument('--use-optimizer', action='store_true',
+                    help='Uses AdamW optimizer for gradient updating')
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
@@ -104,6 +107,8 @@ def batchify(data, bsz):
 model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
 
 criterion = nn.NLLLoss()
+if args.use_optimizer:
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
 
 ###############################################################################
 # Training code
@@ -167,7 +172,10 @@ def train():
         data, targets = get_batch(train_data, i)
         # Starting each batch, we detach the hidden state from how it was previously produced.
         # If we didn't, the model would try backpropagating all the way to start of the dataset.
-        model.zero_grad()
+        if args.use_optimizer:
+            optimizer.zero_grad()
+        else:
+            model.zero_grad()
         if args.model == 'Transformer':
             output = model(data)
             output = output.view(-1, ntokens)
@@ -179,8 +187,11 @@ def train():
 
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-        for p in model.parameters():
-            p.data.add_(p.grad, alpha=-lr)
+        if args.use_optimizer:
+            optimizer.step()
+        else:
+            for p in model.parameters():
+                p.data.add_(p.grad, alpha=-lr)
 
         total_loss += loss.item()
 
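For reference, the following is a minimal standalone sketch (not part of the patch) of the two update paths the new --use-optimizer flag selects between. It uses a stand-in nn.Linear model and synthetic data; use_optimizer and lr are placeholders for args.use_optimizer and args.lr from the script above.

import torch
import torch.nn as nn

# Stand-in model and data; the real script trains the RNN/Transformer language model.
model = nn.Linear(10, 2)
data = torch.randn(4, 10)
targets = torch.randint(0, 2, (4,))

use_optimizer = True                    # stands in for args.use_optimizer
lr = 1e-3 if use_optimizer else 20.0    # AdamW typically wants a much smaller lr than the manual update

if use_optimizer:
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# One training step, mirroring the structure of the patched train() loop.
if use_optimizer:
    optimizer.zero_grad()
else:
    model.zero_grad()

loss = nn.functional.cross_entropy(model(data), targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)

if use_optimizer:
    optimizer.step()                    # AdamW applies the parameter update
else:
    with torch.no_grad():               # manual SGD-style update, as in the unpatched loop
        for p in model.parameters():
            p.data.add_(p.grad, alpha=-lr)

One practical note: the example's default --lr is tuned for the manual, annealed SGD-style update and is far larger than what AdamW normally expects, so a much smaller learning rate (on the order of 1e-3) is usually appropriate when the flag is set.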