nemo_rl.algorithms.distillation#

Module Contents#

Classes#

Functions#

_default_distillation_save_state

check_vocab_equality

Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal.

setup

Main entry point for the distillation algorithm.

distillation_train

Run the distillation training algorithm.

validate

Run validation on the validation dataset.

Data#

API#

nemo_rl.algorithms.distillation.TokenizerType#

'TypeVar(…)'

class nemo_rl.algorithms.distillation.DistillationConfig#

Bases: typing.TypedDict

num_prompts_per_step: int#

num_generations_per_prompt: int#

max_rollout_turns: int#

max_num_steps: int#

max_num_epochs: int#

val_batch_size: int#

val_period: int#

val_at_start: bool#

max_val_samples: int#

topk_logits_k: int#

seed: int#
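Because `DistillationConfig` is a `TypedDict`, it is a plain `dict` at runtime. The sketch below mirrors the fields listed above; the concrete values are illustrative only, not project defaults:

```python
from typing import TypedDict

# Mirror of the DistillationConfig fields documented above.
class DistillationConfig(TypedDict):
    num_prompts_per_step: int
    num_generations_per_prompt: int
    max_rollout_turns: int
    max_num_steps: int
    max_num_epochs: int
    val_batch_size: int
    val_period: int
    val_at_start: bool
    max_val_samples: int
    topk_logits_k: int
    seed: int

# Illustrative values; real runs take these from the YAML config.
config: DistillationConfig = {
    "num_prompts_per_step": 32,
    "num_generations_per_prompt": 4,
    "max_rollout_turns": 1,
    "max_num_steps": 1000,
    "max_num_epochs": 1,
    "val_batch_size": 64,
    "val_period": 50,
    "val_at_start": True,
    "max_val_samples": 256,
    "topk_logits_k": 64,
    "seed": 42,
}
```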

class nemo_rl.algorithms.distillation.DistillationSaveState#

Bases: typing.TypedDict

total_steps: int#

current_epoch: int#

current_step: int#

val_reward: NotRequired[float]#

consumed_samples: int#

total_valid_tokens: int#

nemo_rl.algorithms.distillation._default_distillation_save_state() → nemo_rl.algorithms.distillation.DistillationSaveState#

class nemo_rl.algorithms.distillation.MasterConfig#

Bases: typing.TypedDict

Main configuration structure.


policy: nemo_rl.models.policy.PolicyConfig#

teacher: nemo_rl.models.policy.PolicyConfig#

loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossConfig#

env: dict[str, Any]#

data: nemo_rl.data.DataConfig#

distillation: nemo_rl.algorithms.distillation.DistillationConfig#

logger: nemo_rl.utils.logger.LoggerConfig#

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#
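`MasterConfig` groups all sub-configs under fixed top-level keys (in practice it is typically loaded from a YAML file). A hypothetical sketch of its shape, with the sub-configs abbreviated to plain dicts since their fields are documented elsewhere:

```python
# Hypothetical MasterConfig shape. Each value stands in for the corresponding
# sub-config type (PolicyConfig, DataConfig, etc.); contents are illustrative.
master_config = {
    "policy": {"model_name": "student-model"},    # student PolicyConfig
    "teacher": {"model_name": "teacher-model"},   # teacher PolicyConfig
    "loss_fn": {},        # DistillationLossConfig
    "env": {},            # per-task environment settings
    "data": {},           # DataConfig
    "distillation": {"max_num_steps": 1000, "seed": 42},  # DistillationConfig
    "logger": {},         # LoggerConfig
    "cluster": {},        # ClusterConfig
    "checkpointing": {},  # CheckpointingConfig
}

# The nine top-level keys match the MasterConfig fields above.
expected_keys = {
    "policy", "teacher", "loss_fn", "env", "data",
    "distillation", "logger", "cluster", "checkpointing",
}
assert set(master_config) == expected_keys
```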

nemo_rl.algorithms.distillation.check_vocab_equality(
    tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
    student_model_name: str,
    teacher_model_name: str,
) → None#

Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal.
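The essence of the check is a comparison of the two tokenizers' vocabularies. A minimal sketch, assuming Hugging Face-style tokenizers that expose `get_vocab()` (a token → id dict); the real function additionally loads the teacher tokenizer from `teacher_model_name` and handles the mismatch case itself:

```python
def vocabs_equal(student_tokenizer, teacher_tokenizer) -> bool:
    """True iff both tokenizers map exactly the same tokens to the same ids."""
    return student_tokenizer.get_vocab() == teacher_tokenizer.get_vocab()

# Stand-in tokenizer objects for illustration only.
class _FakeTokenizer:
    def __init__(self, vocab):
        self._vocab = vocab

    def get_vocab(self):
        return dict(self._vocab)

student = _FakeTokenizer({"hello": 0, "world": 1})
teacher = _FakeTokenizer({"hello": 0, "world": 1})
mismatched = _FakeTokenizer({"hello": 0, "world": 2})  # same tokens, different id
```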

nemo_rl.algorithms.distillation.setup(
    master_config: nemo_rl.algorithms.distillation.MasterConfig,
    tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
    train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
    val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
) → tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, Optional[nemo_rl.models.generation.interfaces.GenerationInterface], torchdata.stateful_dataloader.StatefulDataLoader, Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DistillationLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.distillation.DistillationSaveState, nemo_rl.algorithms.distillation.MasterConfig]#

Main entry point for the distillation algorithm.

Returns:

tuple of (student_policy, teacher_policy, student_generation, train_dataloader, val_dataloader, loss_fn, logger, checkpointer, distillation_save_state, master_config)
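Callers typically unpack the 10-element return tuple positionally. A sketch of the unpacking order, with a stub standing in for the real `setup()` (which requires a full config, tokenizer, and datasets):

```python
# Stub returning ten placeholders in the documented order; the real setup()
# returns the actual policy, dataloader, logger, etc. objects.
def setup_stub():
    return (
        "student_policy", "teacher_policy", "student_generation",
        "train_dataloader", "val_dataloader", "loss_fn",
        "logger", "checkpointer", "distillation_save_state", "master_config",
    )

(
    student_policy,
    teacher_policy,
    student_generation,
    train_dataloader,
    val_dataloader,
    loss_fn,
    logger,
    checkpointer,
    distillation_save_state,
    master_config,
) = setup_stub()
```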

nemo_rl.algorithms.distillation.distillation_train(
    student_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
    teacher_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
    student_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
    dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
    val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
    tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
    loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossFn,
    task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
    val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
    logger: nemo_rl.utils.logger.Logger,
    checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
    distillation_save_state: nemo_rl.algorithms.distillation.DistillationSaveState,
    master_config: nemo_rl.algorithms.distillation.MasterConfig,
) → None#

Run the distillation training algorithm.

nemo_rl.algorithms.distillation.validate(
    policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
    val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
    tokenizer,
    val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
    step: int,
    master_config: nemo_rl.algorithms.distillation.MasterConfig,
) → tuple[dict[str, Any], dict[str, Any]]#

Run validation on the validation dataset.