Serialization#
NeMo 2.0 offers the option to capture the initialization arguments for an experiment’s trainer, model, and dataloader. This feature enables precise reconstruction of these objects, facilitating easy reproducibility of experiments.
IOMixin#
Serialization is performed using the IOMixin class. This class captures the arguments passed to a class's __init__ method, which allows for exact restoration of a trainer, model, and datamodule from a given experiment. The following is a simple example:
```python
from nemo.lightning import io

ckpt = io.TrainerContext(model, trainer, extra={"datamodule": data})

## dump the current state
ckpt.io_dump(save_dir)

## restore the serialized state
loaded = io.load_context(save_dir)

## model, trainer and dataloader will be reinitialized using the same args as before
model = loaded.model
trainer = loaded.trainer
datamodule = loaded.extra["datamodule"]
```
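To make the mechanism concrete, here is a minimal, self-contained sketch of the init-arg-capture idea. This is not NeMo's actual IOMixin implementation; the class and helper names (CaptureInitMixin, ToyModel, reconstruct) are hypothetical and exist only for illustration:

```python
# Minimal sketch (NOT NeMo's implementation): a mixin that records the
# arguments passed to __init__ so the object can later be rebuilt with
# the exact same arguments.
import inspect


class CaptureInitMixin:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        original_init = cls.__init__

        def wrapped_init(self, *args, **kw):
            # bind positional/keyword args to parameter names and store them
            bound = inspect.signature(original_init).bind(self, *args, **kw)
            bound.apply_defaults()
            captured = dict(bound.arguments)
            captured.pop("self")
            self.__io_args__ = captured
            original_init(self, *args, **kw)

        cls.__init__ = wrapped_init


class ToyModel(CaptureInitMixin):
    def __init__(self, hidden_size, num_layers=2):
        self.hidden_size = hidden_size
        self.num_layers = num_layers


def reconstruct(obj):
    """Re-instantiate an object from its captured init arguments."""
    return type(obj)(**obj.__io_args__)


model = ToyModel(hidden_size=128)
clone = reconstruct(model)
assert clone.hidden_size == 128 and clone.num_layers == 2
```

Because the captured arguments are bound to parameter names, defaults that were never passed explicitly (num_layers above) are recorded too, so the reconstructed object matches the original even if the class's defaults later change.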
Saving these initialization states can be done automatically via ModelCheckpoint's enable_nemo_ckpt_io argument. If enable_nemo_ckpt_io=True, IOMixin's io_dump functionality will be invoked to save the trainer, model, and dataloader initialization states. These states can then be restored using the io.load_context function. Note that this feature is independent of checkpoint loading; once the objects have been instantiated, the weights from a previous run still need to be restored from the checkpoint. An example workflow is as follows:
First, run some training and save a checkpoint:
```python
import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning import io

trainer = nl.Trainer(...)
model = llm.GPTModel(...)
datamodule = llm.PreTrainingDataModule(...)
optim = nl.MegatronOptimizerModule(...)
checkpoint_callback = nl.ModelCheckpoint(
    ...
    enable_nemo_ckpt_io=True,
    ...
)
nemo_logger = nl.NeMoLogger(
    ...
    explicit_log_dir='explicit_dir_test',
    ckpt=checkpoint_callback,
    ...
)
resume = nl.AutoResume(
    resume_if_exists=True,
    resume_ignore_no_checkpoint=True,
)
llm.train(
    model=model,
    data=datamodule,
    trainer=trainer,
    log=nemo_logger,
    resume=resume,
    tokenizer='data',
    optim=optim,
)
```
In the above example, ModelCheckpoint, NeMoLogger, and AutoResume are responsible for setting up the logging and checkpointing directories and for determining when to save and restore checkpoints. More information about these classes can be found in the logging and checkpointing doc.
Once the initialization states have been saved, we can resume the trainer, model, and datamodule from the serialized path. Note that everything not captured by io_dump (e.g. the checkpoint callback, logger, and resume object) must be reinitialized. Doing so ensures that the logging and checkpointing directories are set up correctly and that the appropriate model weights are restored after reinitialization.
```python
import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning import io

loaded = io.load_context("explicit_dir_test/<PATH TO LATEST CHECKPOINT>")
model = loaded.model
trainer = loaded.trainer
datamodule = loaded.extra["datamodule"]
optim = nl.MegatronOptimizerModule(...)  ## optimizer needs to be reinitialized
checkpoint_callback = nl.ModelCheckpoint(
    ...
    enable_nemo_ckpt_io=True,
    ...
)
nemo_logger = nl.NeMoLogger(
    ...
    explicit_log_dir='explicit_dir_test',
    ckpt=checkpoint_callback,
    ...
)
resume = nl.AutoResume(  ## handles resuming of the latest checkpoint in `explicit_dir_test`
    resume_if_exists=True,
    resume_ignore_no_checkpoint=True,
)
llm.train(
    model=model,
    data=datamodule,
    trainer=trainer,
    log=nemo_logger,
    resume=resume,
    tokenizer='data',
    optim=optim,
)
```
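Conceptually, the dump/restore cycle above persists each object's class and captured init arguments and then re-instantiates them. The following is a hedged, self-contained sketch of that idea only; it is not NeMo's io_dump/load_context implementation, and the names dump_context, load_context, and ToyDataModule here are hypothetical stand-ins:

```python
# Toy sketch (NOT NeMo's implementation): persist class identity plus captured
# init kwargs to disk, then rebuild the objects from that record.
import importlib
import json
import tempfile
from pathlib import Path


def dump_context(objects, save_dir):
    """objects: mapping of name -> (class, init_kwargs). Writes one JSON file."""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    record = {
        name: {"module": cls.__module__, "qualname": cls.__qualname__, "kwargs": kwargs}
        for name, (cls, kwargs) in objects.items()
    }
    (save_dir / "context.json").write_text(json.dumps(record))


def load_context(save_dir):
    """Reinstantiate every object recorded by dump_context."""
    record = json.loads((Path(save_dir) / "context.json").read_text())
    restored = {}
    for name, entry in record.items():
        module = importlib.import_module(entry["module"])
        cls = getattr(module, entry["qualname"])  # assumes a top-level class
        restored[name] = cls(**entry["kwargs"])
    return restored


class ToyDataModule:
    def __init__(self, batch_size=32):
        self.batch_size = batch_size


save_dir = tempfile.mkdtemp()
dump_context({"datamodule": (ToyDataModule, {"batch_size": 8})}, save_dir)
restored = load_context(save_dir)
assert restored["datamodule"].batch_size == 8
```

This also illustrates why the optimizer, logger, and resume objects in the real workflow must be rebuilt by hand: anything whose init arguments were not recorded has no entry to restore from.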