nemo_rl.algorithms.distillation#
Module Contents#
Classes#

| Class | Description |
| --- | --- |
| MasterConfig | Main configuration structure. |

Functions#

| Function | Description |
| --- | --- |
| check_vocab_equality | Check whether the student and teacher tokenizer vocabularies are equal. |
| setup | Main entry point for the distillation algorithm. |
| distillation_train | Run the distillation training algorithm. |
| validate | Run validation on the validation dataset. |
Data#
API#
- nemo_rl.algorithms.distillation.TokenizerType#
‘TypeVar(…)’
- class nemo_rl.algorithms.distillation.DistillationConfig#

  Bases: typing.TypedDict

  - num_prompts_per_step: int#
  - num_generations_per_prompt: int#
  - max_rollout_turns: int#
  - max_num_steps: int#
  - max_num_epochs: int#
  - val_batch_size: int#
  - val_period: int#
  - val_at_start: bool#
  - max_val_samples: int#
  - topk_logits_k: int#
  - seed: int#
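Since `DistillationConfig` is a `TypedDict`, it is constructed as a plain dict. A minimal sketch follows; every number is an illustrative placeholder, not a recommended default:

```python
# Illustrative DistillationConfig values; all numbers below are made-up
# placeholders to show the shape, not tuned defaults.
distillation_cfg = {
    "num_prompts_per_step": 32,      # prompts sampled per training step
    "num_generations_per_prompt": 4, # rollouts generated per prompt
    "max_rollout_turns": 1,
    "max_num_steps": 1000,
    "max_num_epochs": 1,
    "val_batch_size": 32,
    "val_period": 50,                # validate every 50 steps
    "val_at_start": False,
    "max_val_samples": 256,
    "topk_logits_k": 64,             # top-k teacher logits kept per position
    "seed": 42,
}
```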
- class nemo_rl.algorithms.distillation.DistillationSaveState#

  Bases: typing.TypedDict

  - total_steps: int#
  - current_epoch: int#
  - current_step: int#
  - val_reward: NotRequired[float]#
  - consumed_samples: int#
  - total_valid_tokens: int#
- nemo_rl.algorithms.distillation._default_distillation_save_state() → nemo_rl.algorithms.distillation.DistillationSaveState#
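A plausible sketch of what this default factory does (the library's actual defaults may differ): zero-initialize every counter and omit the `NotRequired` `val_reward` key until validation has produced one.

```python
from typing import TypedDict


class _SaveStateRequired(TypedDict):
    total_steps: int
    current_epoch: int
    current_step: int
    consumed_samples: int
    total_valid_tokens: int


class DistillationSaveState(_SaveStateRequired, total=False):
    val_reward: float  # NotRequired: absent until a validation pass has run


def default_distillation_save_state() -> DistillationSaveState:
    # Zero-initialize all counters; val_reward is intentionally omitted.
    return {
        "total_steps": 0,
        "current_epoch": 0,
        "current_step": 0,
        "consumed_samples": 0,
        "total_valid_tokens": 0,
    }
```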
- class nemo_rl.algorithms.distillation.MasterConfig#

  Bases: typing.TypedDict

  Main configuration structure.

  - policy: nemo_rl.models.policy.PolicyConfig#
  - teacher: nemo_rl.models.policy.PolicyConfig#
  - env: dict[str, Any]#
  - data: nemo_rl.data.DataConfig#
  - distillation: nemo_rl.algorithms.distillation.DistillationConfig#
  - logger: nemo_rl.utils.logger.LoggerConfig#
  - checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#
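`MasterConfig` groups the sub-configs under fixed keys. In the sketch below the nested values are empty placeholders; a real run populates each from the corresponding `*Config` TypedDict, typically loaded from YAML:

```python
from typing import Any

# Shape-only sketch of MasterConfig; each placeholder dict stands in for a
# fully populated sub-config in a real run.
master_config: dict[str, Any] = {
    "policy": {},         # student PolicyConfig
    "teacher": {},        # teacher PolicyConfig
    "env": {},            # dict[str, Any] of environment settings
    "data": {},           # DataConfig
    "distillation": {},   # DistillationConfig
    "logger": {},         # LoggerConfig
    "checkpointing": {},  # CheckpointingConfig
}
```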
- nemo_rl.algorithms.distillation.check_vocab_equality(
  - tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
  - student_model_name: str,
  - teacher_model_name: str,
  )#

  Check whether the vocabulary of the student tokenizer and the teacher tokenizer are equal.
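At its core this check compares two token-to-id mappings. A standalone sketch is below; with Hugging Face tokenizers one would pass each tokenizer's `get_vocab()` dict, but that interface is an assumption here, not something this page specifies:

```python
def vocabs_equal(student_vocab: dict, teacher_vocab: dict) -> bool:
    """True only if both mappings agree on every token AND its id."""
    return student_vocab == teacher_vocab


def vocab_diff(student_vocab: dict, teacher_vocab: dict) -> set:
    """Tokens present in exactly one of the two vocabularies."""
    return set(student_vocab) ^ set(teacher_vocab)
```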
- nemo_rl.algorithms.distillation.setup(
  - master_config: nemo_rl.algorithms.distillation.MasterConfig,
  - tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
  - train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
  - val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
  )#

  Main entry point for the distillation algorithm.

  - Returns:
    A tuple of (student_policy, teacher_policy, student_generation, train_dataloader, val_dataloader, loss_fn, logger, checkpointer, distillation_save_state, master_config).
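Since `setup` returns a 10-tuple, callers typically unpack it positionally. The stub below stands in for the real call (which needs a full `MasterConfig`, tokenizer, and datasets) and only demonstrates the documented unpacking order:

```python
def setup_stub():
    # Stand-in for setup(master_config, tokenizer, train_dataset, val_dataset);
    # returns placeholder strings in the documented order.
    return ("student_policy", "teacher_policy", "student_generation",
            "train_dataloader", "val_dataloader", "loss_fn",
            "logger", "checkpointer", "save_state", "master_config")


(student_policy, teacher_policy, student_generation,
 train_dataloader, val_dataloader, loss_fn,
 logger, checkpointer, distillation_save_state, master_config) = setup_stub()
```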
- nemo_rl.algorithms.distillation.distillation_train(
  - student_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
  - teacher_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
  - student_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
  - dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
  - val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
  - tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
  - loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossFn,
  - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
  - val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
  - logger: nemo_rl.utils.logger.Logger,
  - checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
  - distillation_save_state: nemo_rl.algorithms.distillation.DistillationSaveState,
  - master_config: nemo_rl.algorithms.distillation.MasterConfig,
  )#

  Run the distillation training algorithm.
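How the `DistillationConfig` fields bound this loop can be sketched independently of the library: steps accrue until `max_num_steps`, and validation fires every `val_period` steps (plus optionally at step 0 when `val_at_start` is set). This is a behavioral sketch under those assumptions, not the library's actual control flow:

```python
def plan_run(max_num_steps: int, val_period: int, val_at_start: bool):
    """Return (final step count, steps at which validation would run)."""
    val_steps = [0] if val_at_start else []
    step = 0
    while step < max_num_steps:
        step += 1                          # one training step
        if val_period > 0 and step % val_period == 0:
            val_steps.append(step)         # periodic validation
    return step, val_steps
```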
- nemo_rl.algorithms.distillation.validate(
  - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
  - val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
  - tokenizer,
  - val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
  - step: int,
  - master_config: nemo_rl.algorithms.distillation.MasterConfig,
  )#

  Run validation on the validation dataset.
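Per the config fields above, validation evaluates at most `max_val_samples` examples in batches of `val_batch_size`. The capping-and-batching this implies can be sketched as follows (an assumption about the behavior, not the library's code):

```python
from itertools import islice


def take_val_batches(samples, max_val_samples: int, val_batch_size: int):
    # Cap the validation set at max_val_samples, then chunk into batches.
    capped = list(islice(samples, max_val_samples))
    return [capped[i:i + val_batch_size]
            for i in range(0, len(capped), val_batch_size)]
```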