How to switch data provider during training
In this example, we will see how to easily switch the data provider during training using set_data().
Basic Setup
Required Dependencies
!pip install pytorch-ignite
Import
from ignite.engine import Engine, Events
Data Providers
data1 = [1, 2, 3]
data2 = [11, 12, 13]
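Both providers are plain Python lists. The engine starts a new pass over its data whenever the current pass runs out; in the output further down, data1 restarts at iteration 4. For that reason it is safest to use a provider that supports repeated iteration, such as a list or a PyTorch DataLoader. The stdlib check below uses hypothetical names (one_shot_batches, data_list, data_gen) and shows why a one-shot generator behaves differently.

```python
def one_shot_batches():
    # A generator can be consumed only once.
    yield from [1, 2, 3]


data_list = [1, 2, 3]
data_gen = one_shot_batches()

# Each iter() call on a list returns a fresh iterator, so every pass sees all items...
first_pass_list, second_pass_list = list(iter(data_list)), list(iter(data_list))

# ...while a generator is its own iterator: after one pass it stays exhausted.
first_pass_gen, second_pass_gen = list(iter(data_gen)), list(iter(data_gen))

print(second_pass_list)  # [1, 2, 3]
print(second_pass_gen)   # []
```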
Create dummy trainer
Let’s create a dummy train_step which prints the current iteration and batch of data.
def train_step(engine, batch):
    print(f"Iter[{engine.state.iteration}] Current datapoint = ", batch)

trainer = Engine(train_step)
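Printing is fine for a demo. If you want to check the sequence programmatically, a variant of train_step can record each (iteration, batch) pair and return the batch. In Ignite, the process function's return value is stored in engine.state.output, which is why the final state printed at the end reports output as NoneType for the printing version. The sketch below calls the function with a SimpleNamespace stand-in for the engine. That stand-in is a testing assumption, not Ignite's Engine. With Ignite installed, the function would be passed to Engine(train_step) exactly as above.

```python
from types import SimpleNamespace

seen = []  # (iteration, batch) pairs collected by the step function


def train_step(engine, batch):
    # Record instead of printing so the sequence can be checked after the run.
    seen.append((engine.state.iteration, batch))
    # Ignite stores whatever the process function returns in engine.state.output.
    return batch


# Stand-alone check with a hypothetical stand-in for the engine (not Ignite's Engine).
fake_engine = SimpleNamespace(state=SimpleNamespace(iteration=1))
result = train_step(fake_engine, 1)
print(seen, result)  # [(1, 1)] 1
```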
Attach handler to switch data
Now we have to decide when to switch the data provider: after an epoch, after an iteration, or on some custom condition. Below, we switch the data after a specific iteration. To do so, we attach a handler to trainer that is executed once after switch_iteration and calls set_data() so that when:
- iteration <= switch_iteration, the batch is from data1
- iteration > switch_iteration, the batch is from data2
switch_iteration = 5

@trainer.on(Events.ITERATION_COMPLETED(once=switch_iteration))
def switch_dataloader():
    print("<------- Switch Data ------->")
    trainer.set_data(data2)
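The once=switch_iteration filter fires the handler a single time. For a custom condition, Ignite events also accept an event_filter callable that takes (engine, event) and returns a bool. A predicate that kept returning True would call set_data() at every matching event and restart data2 each time, so the sketch below latches after its first hit. The loss threshold, the assumption that the process function returns a scalar loss, and the name make_switch_filter are all hypothetical. The engine is faked with SimpleNamespace so the predicate can be checked without Ignite.

```python
from types import SimpleNamespace


def make_switch_filter(threshold):
    """Build an event_filter-style predicate (engine, event) -> bool.

    It returns True only once: the first time engine.state.output (assumed here to be
    a scalar loss returned by the process function) drops below `threshold`.
    """
    fired = False

    def _filter(engine, event):
        nonlocal fired
        loss = engine.state.output
        if not fired and loss is not None and loss < threshold:
            fired = True  # latch so that set_data() is not called again on later events
            return True
        return False

    return _filter


# With Ignite, the predicate would be attached roughly like this (not executed here):
#   @trainer.on(Events.ITERATION_COMPLETED(event_filter=make_switch_filter(0.5)))
#   def switch_dataloader():
#       trainer.set_data(data2)
# An epoch-based switch could instead use Events.EPOCH_COMPLETED(once=2).

# Quick check with a fake engine whose output follows a made-up loss curve.
switch_filter = make_switch_filter(0.5)
engine = SimpleNamespace(state=SimpleNamespace(output=None))
decisions = []
for event, loss in enumerate([0.9, 0.6, 0.4, 0.3], start=1):
    engine.state.output = loss
    decisions.append(switch_filter(engine, event))

print(decisions)  # [False, False, True, False]
```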
Finally, we run the trainer for 5 epochs.
trainer.run(data1, max_epochs=5)
Iter[1] Current datapoint = 1
Iter[2] Current datapoint = 2
Iter[3] Current datapoint = 3
Iter[4] Current datapoint = 1
Iter[5] Current datapoint = 2
<------- Switch Data ------->
Iter[6] Current datapoint = 11
Iter[7] Current datapoint = 12
Iter[8] Current datapoint = 13
Iter[9] Current datapoint = 11
Iter[10] Current datapoint = 12
Iter[11] Current datapoint = 13
Iter[12] Current datapoint = 11
Iter[13] Current datapoint = 12
Iter[14] Current datapoint = 13
Iter[15] Current datapoint = 11
State:
	iteration: 15
	epoch: 5
	epoch_length: 3
	max_epochs: 5
	output: <class 'NoneType'>
	batch: 11
	metrics: <class 'dict'>
	dataloader: <class 'list'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
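Two details of this output are worth spelling out. First, iteration 6 receives 11, the first element of data2, because set_data() starts a fresh pass over the new provider. Second, according to the set_data() documentation, the switch does not modify the epoch length, so epoch_length stays 3 (here data2 happens to have the same length anyway). The pure-Python model below is a simplified, hypothetical stand-in for Engine.run() with the one-shot handler. simulate_run is not part of Ignite and does not reproduce its internals. It only mimics the behaviour visible above and reproduces the same batch sequence and final counters.

```python
def simulate_run(data1, data2, switch_iteration, max_epochs):
    """Simplified model of Engine.run(data1) plus a one-shot set_data(data2) handler."""
    epoch_length = len(data1)  # fixed when the run starts; the switch does not change it
    dataloader = data1
    it = iter(dataloader)
    iteration, epoch = 0, 0
    trace = []
    for epoch in range(1, max_epochs + 1):
        for _ in range(epoch_length):
            try:
                batch = next(it)
            except StopIteration:  # data exhausted: start a new pass over the current provider
                it = iter(dataloader)
                batch = next(it)
            iteration += 1
            trace.append((iteration, batch))
            if iteration == switch_iteration:  # mimics ITERATION_COMPLETED(once=switch_iteration)
                dataloader = data2  # swap the provider...
                it = iter(dataloader)  # ...and restart iteration from its first batch
    return trace, {"iteration": iteration, "epoch": epoch, "epoch_length": epoch_length}


trace, state = simulate_run([1, 2, 3], [11, 12, 13], switch_iteration=5, max_epochs=5)
print([batch for _, batch in trace])
# [1, 2, 3, 1, 2, 11, 12, 13, 11, 12, 13, 11, 12, 13, 11]
print(state)  # {'iteration': 15, 'epoch': 5, 'epoch_length': 3}
```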