How to switch dataloader every n training steps #12415


Hi everyone,
In my current setup, I would like to change the dataloader during a training epoch:

This is what I would like to achieve:
Step 1. Train on dataset 1 for n batches.
Step 2. Train on dataset 2 for n batches.
Step 3. Go to step 1.

I found this solution on the old forum, but it only switches the dataset after each epoch.

Here is my current attempt at switching it every n batches:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class SimpleModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ...
        self.batch_size = ...
        self.change_every_n_batch = 20

    def train_dataloader(self):
        # Pick the dataset based on how many switch windows have elapsed.
        self.current_dataset = (self.global_step // self.change_every_n_batch) % 2
        if self.current_dataset == 0:
            dataset = Dataset1()
        elif self.current_dataset == 1:
            dataset = Dataset2()
        dataloader = DataLoader(dataset, batch_size=self.batch_size)
        return dataloader

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # When the window flips, ask the trainer to rebuild the train dataloader.
        new_dataset = (self.global_step // self.change_every_n_batch) % 2
        if new_dataset != self.current_dataset:
            self.trainer.reset_train_dataloader(self)
```

train_dataloader() is called as expected every 20 batches by on_train_batch_end(), but the returned dataloader does not seem to be used during the training loop.

Any idea what could be going wrong? Or do you have a solution for what I want to achieve?

Thanks!


Replies: 2 comments 1 reply


hey @matprst!

you can set:

  • limit_train_batches=n. This ensures that every training epoch runs for only n batches.
  • reload_dataloaders_every_n_epochs=1. This ensures that the train dataloader is reloaded after every epoch.

and inside train_dataloader, flip the dataset on each reload. Something like:

```python
def train_dataloader(self):
    # Alternate between the two datasets on every reload.
    if self.some_flag:
        dataset = Dataset1()
    else:
        dataset = Dataset2()
    self.some_flag = not self.some_flag
    return DataLoader(dataset, batch_size=self.batch_size)
```
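
For completeness, a minimal sketch of the matching Trainer setup with n=20 (the max_epochs value is illustrative):

```python
trainer = pl.Trainer(
    limit_train_batches=20,               # treat every 20 batches as one "epoch"
    reload_dataloaders_every_n_epochs=1,  # rebuild train_dataloader after each one
    max_epochs=100,                       # illustrative; set your real budget
)
trainer.fit(SimpleModule())
```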
1 reply
@wenhaoli-xmu

Hi, I want to know how to manually reset the train_dataloader instead of calling this function periodically.
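
For reference, a hedged sketch of one way to trigger a manual reset using the pre-2.0 API that the original post already uses (`should_switch` is a hypothetical predicate, and `trainer.reset_train_dataloader` has been reworked in newer Lightning releases):

```python
def on_train_batch_end(self, outputs, batch, batch_idx):
    # Manually rebuild the train dataloader whenever a custom condition holds.
    # NOTE: reset_train_dataloader() is the Lightning 1.x entry point used in
    # the original post; newer versions handle dataloader reloading differently.
    if self.should_switch(batch_idx):  # hypothetical predicate, define yourself
        self.trainer.reset_train_dataloader(self)
```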

Answer selected by matprst

Works like a charm, and much cleaner than what I thought! Thanks for the reply!

I realise now that, since I am using iterable datasets (they are large and don't fit into memory), reloading restarts each iterable from the beginning rather than continuing where it stopped (or at least returning a random batch).

That is a separate problem with the dataset itself, so I will consider the question answered.
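
For anyone hitting the same restart issue, a minimal sketch of one workaround: wrap the stream in an IterableDataset that remembers how many items it has yielded and skips past them on the next reload. The class name and factory argument are illustrative, the skip is O(offset) so it only suits streams where re-reading the prefix is cheap, and single-process loading (num_workers=0) is assumed so the counter persists:

```python
import itertools
from torch.utils.data import IterableDataset


class ResumableStream(IterableDataset):
    """Resume iteration where the previous dataloader reload stopped."""

    def __init__(self, make_stream):
        self.make_stream = make_stream  # factory returning a fresh underlying iterable
        self.offset = 0                 # number of items already served

    def __iter__(self):
        # Recreate the stream and fast-forward past items consumed by earlier reloads.
        for item in itertools.islice(iter(self.make_stream()), self.offset, None):
            self.offset += 1
            yield item
```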

Labels: data handling (generic data-related topic)
3 participants: @matprst, @rohitgr7, @wenhaoli-xmu
