Trigger any handler on any built-in or custom event.
```python
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch / 2)

@trainer.on(Events.ITERATION_COMPLETED(every=2))
def print_output(engine):
    print(engine.state.output)
```
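Custom events are user-defined. Below is a minimal sketch of registering and firing them; the `BackpropEvents` names and the `update` function are illustrative, not part of Ignite:

```python
from ignite.engine import Engine, EventEnum

class BackpropEvents(EventEnum):
    # hypothetical user-defined events fired around the backward pass
    BACKWARD_STARTED = "backward_started"
    BACKWARD_COMPLETED = "backward_completed"

def update(engine, batch):
    engine.fire_event(BackpropEvents.BACKWARD_STARTED)
    # ... loss.backward() would go here ...
    engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)

trainer = Engine(update)
trainer.register_events(*BackpropEvents)

@trainer.on(BackpropEvents.BACKWARD_COMPLETED)
def on_backward_completed(engine):
    print("backward pass finished at iteration", engine.state.iteration)
```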
Checkpointing, early stopping, profiling, parameter scheduling, learning rate finder, and more.
```python
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, EarlyStopping, PiecewiseLinear

model = nn.Linear(3, 3)
trainer = Engine(lambda engine, batch: None)

# model checkpoint handler
checkpoint = ModelCheckpoint('/tmp/ckpts', 'training')
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, {'model': model})

# early stopping handler
def score_function(engine):
    val_loss = engine.state.metrics['acc']
    return val_loss

es = EarlyStopping(3, score_function, trainer)
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
evaluator.add_event_handler(Events.COMPLETED, es)

# Piecewise linear parameter scheduler (assumes an `optimizer` is defined)
scheduler = PiecewiseLinear(optimizer, 'lr', [(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
```
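The snippet above assumes that an `evaluator` engine and an `optimizer` already exist. A minimal sketch of how such an evaluator could be wired up, reusing `model` and `trainer` from above and assuming a `val_loader` for the validation dataset:

```python
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

# hypothetical evaluator computing the 'acc' metric used by the score_function above
evaluator = create_supervised_evaluator(model, metrics={'acc': Accuracy()})

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    # run one pass over the (assumed) validation loader after every training epoch
    evaluator.run(val_loader)
```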
Speed up the training on CPUs, GPUs, and TPUs.
```python
import ignite.distributed as idist

def training(local_rank, *args, **kwargs):
    dataloader_train = idist.auto_dataloader(dataset, ...)
    model = ...
    model = idist.auto_model(model)
    optimizer = ...
    optimizer = idist.auto_optim(optimizer)

backend = 'nccl'  # or 'gloo', 'horovod', 'xla-tpu'
with idist.Parallel(backend) as parallel:
    parallel.run(training)
```
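As a hypothetical single-node sketch, `idist.Parallel` can also spawn the worker processes itself when `nproc_per_node` is given; the `gloo` backend and the toy `training` function here are illustrative:

```python
import ignite.distributed as idist

def training(local_rank, config):
    # each spawned process reports its rank and the world size
    print(idist.get_rank(), idist.get_world_size())

# spawn 2 processes on the current machine (assumption: a CPU-only run with the gloo backend)
with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
    parallel.run(training, {})
```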
Distributed-ready, out-of-the-box metrics to easily evaluate models.
```python
from ignite.engine import Engine
from ignite.metrics import Accuracy

trainer = Engine(...)
acc = Accuracy()
acc.attach(trainer, 'accuracy')
state = trainer.run(data)
print(f"Accuracy: {state.metrics['accuracy']}")
```
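Metrics are not tied to an `Engine`; as a quick sketch, they can also be updated and computed manually. The tensors below are made-up examples:

```python
import torch
from ignite.metrics import Accuracy

acc = Accuracy()
acc.reset()
# update with (y_pred, y) pairs, then compute the aggregated value
y_pred = torch.tensor([[0.8, 0.2], [0.1, 0.9]])
y_true = torch.tensor([0, 1])
acc.update((y_pred, y_true))
print(acc.compute())  # 1.0
```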
Tensorboard, MLFlow, WandB, Neptune, and more.
```python
from ignite.engine import Engine, Events
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger

trainer = Engine(...)

# Create a tensorboard logger
with TensorboardLogger(log_dir="experiments/tb_logs") as tb_logger:
    # Attach the logger to the trainer to log training loss at each iteration
    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda loss: {"loss": loss}
    )
```
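The same `tb_logger` can attach further handlers. A hedged sketch, assuming an `optimizer` and a validation `evaluator` exist alongside the trainer above, that also logs the learning rate and validation metrics:

```python
from ignite.contrib.handlers import global_step_from_engine

# log the optimizer's learning rate at every iteration (assumes `optimizer` is defined)
tb_logger.attach_opt_params_handler(
    trainer, event_name=Events.ITERATION_STARTED, optimizer=optimizer, param_name="lr"
)

# log all validation metrics when the (assumed) `evaluator` finishes,
# using the trainer's iteration as the global step
tb_logger.attach_output_handler(
    evaluator,
    event_name=Events.COMPLETED,
    tag="validation",
    metric_names="all",
    global_step_transform=global_step_from_engine(trainer),
)
```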