
How to load a model from model-only checkpoints? #1418

Answered by ShuhuaGao
ShuhuaGao asked this question in Q&A

By default, Catalyst saves the best model and the last model as checkpoints. For example, I have the following files in the `checkpoints` folder:

```
model.0049.pth
model.best.pth
model.last.pth
model.storage.json
```

I then use the following code to read the checkpoint.

```python
import os

from catalyst import utils

checkpoint = utils.load_checkpoint(path=os.path.join(model_dir, 'model.best.pth'))
checkpoint.keys()
# the output is: odict_keys(['model.0.weight', 'model.0.bias', 'model.2.weight', 'model.2.bias', 'model.4.weight', 'model.4.bias'])
```

The resulting `checkpoint` contains the weights and biases of each linear layer (the underlying net is just an MLP).
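The key names hint at how the network is organized: the `model.` prefix comes from a submodule stored in an attribute called `model`, and the missing indices 1 and 3 are where parameter-free layers such as activations sit. Below is a minimal sketch of one architecture that would produce exactly these keys, assuming three `nn.Linear` layers with `nn.ReLU` in between; the class name `Net` and the layer sizes are hypothetical.

```python
from torch import nn


class Net(nn.Module):
    # Hypothetical sizes; the real ones must match the trained checkpoint.
    def __init__(self, in_features: int = 8, hidden: int = 16, out_features: int = 1):
        super().__init__()
        # The attribute name "model" becomes the "model." prefix in state_dict keys.
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden),   # index 0 -> model.0.weight, model.0.bias
            nn.ReLU(),                        # index 1 -> no parameters, so no keys
            nn.Linear(hidden, hidden),        # index 2 -> model.2.weight, model.2.bias
            nn.ReLU(),                        # index 3 -> no parameters, so no keys
            nn.Linear(hidden, out_features),  # index 4 -> model.4.weight, model.4.bias
        )

    def forward(self, x):
        return self.model(x)


net = Net()
print(list(net.state_dict().keys()))
# ['model.0.weight', 'model.0.bias', 'model.2.weight', 'model.2.bias', 'model.4.weight', 'model.4.bias']
```

Shapes matter as well as names: `load_state_dict` will reject a checkpoint whose tensors do not match the layer sizes of the freshly built model.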

Now the question is how I can restore the network model from `checkpoint`. I can't use `utils.unpack_checkpoint(checkpoint, best_model)`, since it depends on the `model_state_dict` key, which the checkpoint above does not have.

Is there a function in `utils` that handles this, or do I have to do it manually with PyTorch methods?
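When the same code may receive either kind of file, a plain-Python check on the top-level keys is enough to tell them apart: a full checkpoint nests the weights under `model_state_dict`, while a model-only file is the state_dict itself. This is only a sketch; the helper name `extract_state_dict` is hypothetical and is not part of Catalyst's `utils`.

```python
from typing import Any, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the model weights from either checkpoint layout (hypothetical helper).

    A full checkpoint stores the weights under "model_state_dict";
    a model-only checkpoint *is* the state_dict, so it is returned unchanged.
    """
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint
```

Usage would then be `model.load_state_dict(extract_state_dict(checkpoint))`, which works for both the model-only file above and a full checkpoint.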


Sorry for my ignorance.

In the model-only saving mode, each `model.*.pth` file contains exactly the `state_dict` of a PyTorch model. Assuming `model` has already been defined with the same architecture, restore the trained weights like this:

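```python
import os

import torch

# `model` must already be built with the same architecture that was trained.
model.load_state_dict(torch.load(os.path.join(model_dir, 'model.best.pth')))
```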
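The answer can be checked end to end with a round trip: save a state_dict, build a fresh instance of the same architecture, and load the weights back. A sketch, assuming a small hypothetical `nn.Sequential` MLP (`make_mlp`) and a temporary directory standing in for the real `checkpoints` folder; `map_location` and `strict` are standard arguments of `torch.load` and `load_state_dict`.

```python
import os
import tempfile

import torch
from torch import nn


def make_mlp() -> nn.Module:
    # Hypothetical architecture; in practice it must match the one that was trained.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


trained = make_mlp()

with tempfile.TemporaryDirectory() as model_dir:
    path = os.path.join(model_dir, "model.best.pth")
    # A model-only checkpoint is just the state_dict written with torch.save.
    torch.save(trained.state_dict(), path)

    restored = make_mlp()
    # map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine.
    state_dict = torch.load(path, map_location="cpu")
    # strict=True (the default) raises if any key is missing or unexpected.
    restored.load_state_dict(state_dict, strict=True)
    restored.eval()  # inference mode: disables dropout and batch-norm updates
```

After this, `trained` and `restored` hold identical weights and produce identical outputs for the same input.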

Answer selected by ShuhuaGao
Category: Q&A
Labels: None yet
Participants: @ShuhuaGao
