Commit 28d16ff

Add AdamW optimizer support for World Language Model example (#1380)

1 parent 993a98a · commit 28d16ff
2 files changed: +20 −6 lines changed

word_language_model/README.md

Lines changed: 5 additions & 2 deletions
@@ -8,9 +8,11 @@ python main.py --accel --epochs 6           # Train a LSTM on Wikitext-2.
 python main.py --accel --epochs 6 --tied    # Train a tied LSTM on Wikitext-2.
 python main.py --accel --tied               # Train a tied LSTM on Wikitext-2 for 40 epochs.
 python main.py --accel --epochs 6 --model Transformer --lr 5
-                                            # Train a Transformer model on Wikitext-2.
+                                            # Train a Transformer model on Wikitext-2.
+python main.py --accel --epochs 6 --model Transformer --use-optimizer --lr 0.001
+                                            # Train a Transformer model with AdamW optimizer on Wikitext-2.
 
-python generate.py --accel                  # Generate samples from the default model checkpoint.
+python generate.py --accel                  # Generate samples from the default model checkpoint.
 ```
 
 > [!NOTE]
@@ -45,6 +47,7 @@ optional arguments:
                         path to export the final model in onnx format
   --nhead NHEAD         the number of heads in the encoder/decoder of the transformer model
   --dry-run             verify the code and the model
+  --use-optimizer       specify whether to use an AdamW optimizer
 ```
 
 With these arguments, a variety of models can be tested.
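A side note on the learning rates above (editorial commentary, not part of the commit): the example's default update path is a plain SGD step, so the Transformer command runs with --lr 5 (annealed during training), while AdamW rescales each parameter's step by running gradient moments and is therefore invoked with the much smaller --lr 0.001. A minimal sketch of that scale difference, using a hypothetical toy tensor rather than the repo's model:

```python
# Hypothetical illustration: compare the manual SGD-style update
# (p.data.add_(p.grad, alpha=-lr) with lr=5) against one AdamW step
# with lr=1e-3 on the same gradient.
import torch

w = torch.zeros(3, requires_grad=True)
opt = torch.optim.AdamW([w], lr=1e-3)

loss = (w - torch.ones(3)).pow(2).sum()
loss.backward()                      # dL/dw = -2 for every element

manual_delta = -5.0 * w.grad         # what the default path would apply
opt.step()                           # AdamW's first step is roughly +/- lr

print(manual_delta)                  # tensor([10., 10., 10.])
print(w.data)                        # approximately tensor([0.001, 0.001, 0.001])
```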

word_language_model/main.py

Lines changed: 15 additions & 4 deletions
@@ -47,7 +47,10 @@
                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
-parser.add_argument('--accel', action='store_true', help='Enables accelerated training')
+parser.add_argument('--accel', action='store_true',
+                    help='Enables accelerated training')
+parser.add_argument('--use-optimizer', action='store_true',
+                    help='Uses AdamW optimizer for gradient updating')
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
@@ -104,6 +107,8 @@ def batchify(data, bsz):
 model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
 
 criterion = nn.NLLLoss()
+if args.use_optimizer:
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
 
 ###############################################################################
 # Training code
@@ -167,7 +172,10 @@ def train():
         data, targets = get_batch(train_data, i)
         # Starting each batch, we detach the hidden state from how it was previously produced.
         # If we didn't, the model would try backpropagating all the way to start of the dataset.
-        model.zero_grad()
+        if args.use_optimizer:
+            optimizer.zero_grad()
+        else:
+            model.zero_grad()
         if args.model == 'Transformer':
             output = model(data)
             output = output.view(-1, ntokens)
@@ -179,8 +187,11 @@
 
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-        for p in model.parameters():
-            p.data.add_(p.grad, alpha=-lr)
+        if args.use_optimizer:
+            optimizer.step()
+        else:
+            for p in model.parameters():
+                p.data.add_(p.grad, alpha=-lr)
 
         total_loss += loss.item()

