How to switch dataloader every n training steps #12415
-
Hi everyone, this is what I would like to achieve: switch the training dataloader every n training steps. I found this solution on the old forum, but it only switches the dataset after each epoch. Here is my current attempt at switching it every n batches:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class SimpleModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ...
        self.batch_size = ...
        self.change_every_n_batch = 20

    def train_dataloader(self):
        # Pick the dataset based on how many switch intervals have elapsed.
        self.current_dataset = (self.global_step // self.change_every_n_batch) % 2
        if self.current_dataset == 0:
            dataset = Dataset1()
        elif self.current_dataset == 1:
            dataset = Dataset2()
        dataloader = DataLoader(dataset, batch_size=self.batch_size)
        return dataloader

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # When a switch boundary is crossed, ask the trainer to rebuild the dataloader.
        new_dataset = (self.global_step // self.change_every_n_batch) % 2
        if new_dataset != self.current_dataset:
            self.trainer.reset_train_dataloader(self)
```
Any idea what could be going wrong? Or do you have a solution for what I want to achieve? Thanks!
hey @matprst!

you can set:

- `limit_train_batches=n`. This will ensure that every training epoch progresses for only n batches.
- `reload_dataloaders_every_n_epochs=1`. This will ensure that the train dataloader is reloaded after every epoch.

and inside `train_dataloader`, flip the dataloader on each reload. Something like:

```python
def train_dataloader(self):
    # self.some_flag should be initialised (e.g. to True) in __init__.
    if self.some_flag:
        dataset = Dataset1()
    else:
        dataset = Dataset2()
    # Flip the flag so the next reload returns the other dataset.
    self.some_flag = not self.some_flag
    return DataLoader(dataset, batch_size=self.batch_size)
```
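For concreteness, here is a minimal sketch of the Trainer configuration this suggestion implies. The value 20 is only an example matching `change_every_n_batch` from the original post, and `max_epochs` is an arbitrary stopping point:

```python
import pytorch_lightning as pl

# Cap each "epoch" at n batches and rebuild (i.e. flip) the train
# dataloader at every epoch boundary, so the dataset switches every n steps.
trainer = pl.Trainer(
    limit_train_batches=20,               # n = 20, matching change_every_n_batch above
    reload_dataloaders_every_n_epochs=1,  # reload train_dataloader after each epoch
    max_epochs=100,                       # arbitrary; pick your real budget
)
trainer.fit(SimpleModule())
```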
-
Hi, I want to know how to manually reset the train dataloader.
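A minimal sketch based on the question's own code: `Trainer.reset_train_dataloader` is the call the original post already uses, so a manual reset from inside a hook would look like the following. Whether a reset actually takes effect mid-epoch depends on the Lightning version, which is presumably why the epoch-based workaround above was suggested instead:

```python
def on_train_batch_end(self, outputs, batch, batch_idx):
    # Manually rebuild the train dataloader; `reset_train_dataloader`
    # is the method used in the original post, and its mid-epoch
    # behaviour varies across Lightning versions.
    if (self.global_step + 1) % self.change_every_n_batch == 0:
        self.trainer.reset_train_dataloader(self)
```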
-
Works like a charm, and much cleaner than what I thought! Thanks for the reply!

I realise now that since I am using iterable datasets (they are large and don't fit into memory), the reloading restarts the iterable from the beginning rather than continuing where it stopped (or at least returning a random batch). This is another problem with the dataset, so I will consider the question answered.
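One possible workaround for that restart behaviour, sketched under explicit assumptions: `ResumableStream` and `make_stream` below are hypothetical names, not Lightning or PyTorch API, and the approach only works with `num_workers=0`, since worker processes receive copies of the dataset. The idea is to keep the live iterator on the dataset object, so a rebuilt DataLoader resumes the same stream instead of starting over:

```python
from torch.utils.data import IterableDataset

class ResumableStream(IterableDataset):
    """Hypothetical wrapper that keeps one live iterator across
    DataLoader rebuilds, so iteration resumes where it stopped."""

    def __init__(self, make_stream):
        self.make_stream = make_stream  # factory returning a fresh sample iterable
        self._it = None                 # persistent iterator, shared across rebuilds

    def __iter__(self):
        if self._it is None:
            self._it = iter(self.make_stream())
        while True:
            try:
                yield next(self._it)
            except StopIteration:
                # Underlying stream exhausted: start a new pass.
                self._it = iter(self.make_stream())
```

Because the stream never ends, this pairs naturally with `limit_train_batches=n` from the accepted answer: construct the two `ResumableStream` objects once in `__init__` and return a `DataLoader` over the appropriate one from `train_dataloader`, and each reload continues where the corresponding stream stopped.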