lancedb/lance-distributed-training
- Create the Lance dataset of FOOD101:

```bash
python create_datasets/classficiation.py
```

- Train using the map-style dataset:

```bash
torchrun --nproc-per-node=2 lance_map_style.py --batch_size 128
```

- Train using the iterable dataset:

```bash
torchrun --nproc-per-node=2 lance_iterable.py --batch_size 128
```

There are two ways to load data for training models using Lance's PyTorch integration:
- Iterable-style dataset (`LanceDataset`): suitable for streaming. Works with Lance's built-in distributed samplers.
- Map-style dataset (`SafeLanceDataset`): suitable as the default choice unless you have a specific reason to use an iterable dataset.
A key difference between the two is where the data transformation happens:

- In the iterable-style dataset (`LanceDataset`), the data transformation (decoding bytes, applying transforms, stacking) must happen before the `DataLoader` receives the data. This is done inside the `to_tensor_fn` (here, `decode_tensor_image`).

If your dataset contains a binary field, it can't be converted to a tensor directly, so you need to handle it in a custom `to_tensor_fn`. This plays the same role as `collate_fn` when using a map-style dataset.
Example: Decoding images from `LanceDataset` using a custom `to_tensor_fn`
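```python
import io

import torch
from PIL import Image
from lance.torch.data import LanceDataset


def decode_tensor_image(batch, **kwargs):
    """Decode a pyarrow RecordBatch of encoded images into a dict of tensors."""
    images = []
    labels = []
    for item in batch.to_pylist():
        # Raw image bytes can't become a tensor directly, so decode them first.
        img = Image.open(io.BytesIO(item["image"])).convert("RGB")
        # _food101_transform: torchvision transform assumed to be defined elsewhere in the script
        img = _food101_transform(img)
        images.append(img)
        labels.append(item["label"])
    return {
        "image": torch.stack(images),
        "label": torch.tensor(labels, dtype=torch.long),
    }


# dataset_path, batch_size and sampler are assumed to come from the training script.
ds = LanceDataset(
    dataset_path,
    to_tensor_fn=decode_tensor_image,
    batch_size=batch_size,
    sampler=sampler,
)
```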
- In the map-style dataset (`SafeLanceDataset`), the `DataLoader`'s workers fetch the raw data, and the transformation happens later in the `collate_fn`.
Example: Decoding images from SafeLanceDataset using `collate_fn`
```python
import io

import torch
from PIL import Image
from torchvision import transforms
from lance.torch.data import SafeLanceDataset, get_safe_loader

# Build the transform once at module level instead of on every batch.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def collate_fn(batch_of_dicts):
    """
    Collates a list of dictionaries from SafeLanceDataset into a single batch.
    This function handles decoding the image bytes and applying transforms.
    """
    images = []
    labels = []
    for item in batch_of_dicts:
        image_bytes = item["image"]
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        images.append(transform(img))
        labels.append(item["label"])
    return {
        "image": torch.stack(images),
        "label": torch.tensor(labels, dtype=torch.long),
    }


# dataset (a SafeLanceDataset), sampler and args are assumed to come from the training script.
loader = get_safe_loader(
    dataset,
    batch_size=args.batch_size,
    sampler=sampler,
    shuffle=(sampler is None),  # a DataLoader can't shuffle when a sampler is given
    num_workers=args.num_workers,
    collate_fn=collate_fn,
    pin_memory=True,
    persistent_workers=True,  # PyTorch's DataLoader requires num_workers > 0 for this
)
```
These are some rules of thumb for deciding whether to use a map-style or an iterable dataset.

Map-style dataset (`SafeLanceDataset`):

- Standard datasets (default choice): use this for any dataset that is a finite collection of data points that can be indexed. This covers almost all standard use cases, such as image classification (ImageNet, CIFAR) or text classification where each file or line is a sample.
- When you need high performance: this is the most straightforward way to get the full performance benefit of PyTorch's `DataLoader` with `num_workers > 0`. The `DataLoader` can give each worker a list of indices to fetch in parallel, which is extremely efficient.

In short, this should be your default choice unless you have a specific reason to use an iterable dataset. Lance doesn't materialise indexes in memory, so you can almost always use the map style.
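The point about the `DataLoader` handing indices to workers can be pictured with a simplified model. This is a sketch, not PyTorch's actual implementation: it assumes index batches are dispatched to workers round-robin (roughly what the multiprocessing loader does while all workers are active), and the function names are hypothetical.

```python
from typing import Dict, List


def make_index_batches(indices: List[int], batch_size: int) -> List[List[int]]:
    """Group sampler indices into batches; the last batch may be smaller."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


def assign_to_workers(batches: List[List[int]], num_workers: int) -> Dict[int, List[List[int]]]:
    """Hand index batches to workers round-robin. Each worker would then call
    dataset[i] for every index in the batches it receives."""
    jobs: Dict[int, List[List[int]]] = {w: [] for w in range(num_workers)}
    for n, batch in enumerate(batches):
        jobs[n % num_workers].append(batch)
    return jobs


batches = make_index_batches(list(range(10)), batch_size=3)
print(assign_to_workers(batches, num_workers=2))
# {0: [[0, 1, 2], [6, 7, 8]], 1: [[3, 4, 5], [9]]}
```

Because each worker only needs a list of indices, no coordination between workers is required to avoid duplicate or missing samples.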
Iterable-style dataset (`LanceDataset`):

This type of dataset works like a Python generator or a data stream. Core idea: it only knows how to give you the next item when you iterate over it (`__iter__`). It often has no known length, and you cannot ask for the N-th item directly.
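The difference between the two access patterns can be shown with two toy classes. These are plain-Python illustrations (the class names are hypothetical), not Lance or PyTorch classes: the map-style one implements `__len__` and `__getitem__`, while the iterable one only implements `__iter__`.

```python
from typing import Iterator, List


class ToyMapDataset:
    """Map-style: knows its length and can return any item by index."""

    def __init__(self, rows: List[str]):
        self.rows = rows

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> str:
        return self.rows[idx]


class ToyIterableDataset:
    """Iterable-style: can only hand out the next item in the stream."""

    def __init__(self, rows: List[str]):
        self.rows = rows

    def __iter__(self) -> Iterator[str]:
        # A real stream might read from disk or the network here.
        for row in self.rows:
            yield row


m = ToyMapDataset(["img0", "img1", "img2"])
print(len(m), m[2])  # 3 img2

stream = iter(ToyIterableDataset(["img0", "img1", "img2"]))
print(next(stream))  # img0
```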
The iterable `LanceDataset` works with the following Lance samplers:

- `FullScanSampler`: Not DDP-aware. Intentionally designed for each process to scan the entire dataset. It is useful for single-GPU evaluation or when every process needs to see all the data for some reason (which is rare in DDP training).
- `ShardedBatchSampler`: DDP-aware. Splits the total set of batches evenly among GPUs. Provides perfect workload balance.
- `ShardedFragmentSampler`: DDP-aware. Splits the list of fragments among GPUs. Can result in an unbalanced workload if fragments have uneven sizes, which needs to be handled to prevent synchronization errors.
`FullScanSampler` is the simplest sampler. It inherits from `FragmentSampler` and implements the `iter_fragments` method. Its implementation is a single loop that gets all fragments from the dataset and yields each one sequentially.

`FullScanSampler` is not DDP-aware. It contains no logic that checks for a `rank` or `world_size`. Consequently, when used in a distributed setting, every single process (GPU) will scan the entire dataset.
- Use Case: This sampler is primarily intended for single-GPU scenarios, such as validation, inference, or debugging, where you need one process to read all the data. It is not suitable for distributed training.
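To see why this is a problem under DDP, here is a plain-Python model of the full-scan behaviour. It does not use Lance's API: fragments are represented only by hypothetical row counts, and the helper names are made up for illustration. Since the sampler ignores `rank` and `world_size`, every rank reads every row.

```python
from typing import Dict, List


def full_scan_fragments(fragment_rows: List[int]) -> List[int]:
    """Model of FullScanSampler.iter_fragments: yield every fragment id, in order.
    There is no rank or world_size input, so every process gets the same list."""
    return list(range(len(fragment_rows)))


def rows_read_per_rank(fragment_rows: List[int], world_size: int) -> Dict[int, int]:
    """Rows each rank reads when all ranks use the full-scan logic."""
    return {
        rank: sum(fragment_rows[f] for f in full_scan_fragments(fragment_rows))
        for rank in range(world_size)
    }


# Hypothetical dataset with three fragments, scanned by two ranks.
print(rows_read_per_rank([100, 100, 50], world_size=2))  # {0: 250, 1: 250}
```

With two GPUs, each epoch therefore processes the whole dataset twice.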
`ShardedFragmentSampler` also inherits from `FragmentSampler` and works by dividing the list of fragments among the available processes. Its `iter_fragments` method gets the full list of fragments and then yields only the ones corresponding to its assigned `rank`:
- Rank 0 gets fragments 0, 2, 4, ...
- Rank 1 gets fragments 1, 3, 5, ...
and so on
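The round-robin assignment above, and how it turns into different batch counts per rank, can be sketched in plain Python. This is an illustration rather than Lance's code: fragment sizes are hypothetical, and it assumes each rank batches its rows contiguously across all of its fragments.

```python
import math
from typing import List


def sharded_fragments(num_fragments: int, rank: int, world_size: int) -> List[int]:
    """Round-robin assignment: rank r gets fragments r, r + world_size, ..."""
    return list(range(rank, num_fragments, world_size))


def batches_for_rank(fragment_rows: List[int], rank: int, world_size: int,
                     batch_size: int) -> int:
    """Batches one rank would produce, assuming its rows are batched
    contiguously across all of its fragments."""
    rows = sum(fragment_rows[f]
               for f in sharded_fragments(len(fragment_rows), rank, world_size))
    return math.ceil(rows / batch_size)


# Hypothetical dataset: three full fragments and a smaller last one.
fragment_rows = [1000, 1000, 1000, 400]
for rank in range(2):
    print(rank, sharded_fragments(len(fragment_rows), rank, 2),
          batches_for_rank(fragment_rows, rank, 2, batch_size=128))
# 0 [0, 2] 16
# 1 [1, 3] 11
```

Rank 1 runs out of batches five steps before rank 0, which is exactly the kind of mismatch described below.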
This sampler is DDP-aware, but it operates at the fragment level.
- Pros: It can be very I/O efficient. Since each process is assigned whole fragments, it can read them in long, sequential blocks. The documentation notes this is more efficient for large datasets.
- Cons: It can lead to workload imbalance. Lance datasets can have fragments of varying sizes (e.g., the last fragment is often smaller). If one rank is assigned fragments that have more total rows than another rank, it will have more batches to process. This imbalance can lead to DDP deadlocks if not handled with padding.

Without such handling, `ShardedFragmentSampler` can cause workload imbalance and eventually error out.
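One way to apply the padding mentioned above is to make every rank run the same number of steps by cycling through its own batches until it reaches the largest per-rank count. This is a minimal sketch of that idea, not Lance's implementation; `pad_batches` is a hypothetical helper, and in real DDP code the maximum would have to be agreed on across ranks, for example with `torch.distributed.all_reduce(..., op=torch.distributed.ReduceOp.MAX)`.

```python
from itertools import cycle, islice
from typing import List, Sequence


def pad_batches(my_batches: Sequence[int], global_max: int) -> List[int]:
    """Cycle through this rank's own batches until there are global_max of them,
    so every rank runs the same number of training steps (and collectives).
    global_max is assumed to be the maximum batch count across all ranks."""
    if not my_batches:
        raise ValueError("cannot pad a rank that has no batches")
    return list(islice(cycle(my_batches), global_max))


# Hypothetical batch ids: one rank has 16 batches, the other only 11.
rank0 = pad_batches(list(range(16)), global_max=16)
rank1 = pad_batches(list(range(11)), global_max=16)
print(len(rank0), len(rank1))  # 16 16
print(rank1[11:])  # [0, 1, 2, 3, 4]
```

Padding repeats a few samples on the shorter rank; the alternative is to stop all ranks at the minimum count and skip the extra batches.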
Example: DDP error due to imbalanced fragment sampler
```text
Epoch 1/10: 300it [07:12, 1.44s/it, loss=1.07]
[Epoch 0] Loss: 980.4352, Epoch Time: 432.61s
Epoch 2/10: 133it [03:17, 1.48s/it, loss=5.98]
Epoch 2/10: 300it [07:24, 1.48s/it, loss=2.49]
[Epoch 1] Loss: 1200.9648, Epoch Time: 444.51s
Epoch 3/10: 300it [07:22, 1.48s/it, loss=3.24]
[Epoch 2] Loss: 1324.9992, Epoch Time: 442.84s
Epoch 4/10: 300it [07:23, 1.48s/it, loss=3.69]
[Epoch 3] Loss: 1371.6891, Epoch Time: 443.10s
Epoch 5/10: 300it [07:23, 1.48s/it, loss=3.91]
[Epoch 4] Loss: 1384.9732, Epoch Time: 443.12s, Val Acc: 0.0196
Epoch 6/10: 300it [07:24, 1.48s/it, loss=3.94]
[Epoch 5] Loss: 1388.0216, Epoch Time: 444.14s
Epoch 7/10: 300it [07:24, 1.48s/it, loss=4]
[Epoch 6] Loss: 1388.9526, Epoch Time: 444.02s
Epoch 8/10: 300it [07:24, 1.48s/it, loss=3.99]
[Epoch 7] Loss: 1388.8115, Epoch Time: 444.43s
Epoch 9/10: 300it [07:24, 1.48s/it, loss=2.29]
[Epoch 8] Loss: 1314.3089, Epoch Time: 444.65s
Epoch 10/10: 240it [05:55, 1.47s/it, loss=5.46]
[rank0]:[E709 17:05:38.162555850 ProcessGroupNCCL.cpp:632] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=20585, OpType=ALLREDUCE, NumelIn=1259621, NumelOut=1259621, Timeout(ms)=600000) ran for 600000 milliseconds before timing out.
[rank0]:[E709 17:05:38.162814866 ProcessGroupNCCL.cpp:2271] [PG ID 0 PG GUID 0(default_pg) Rank 0] failure detected by watchdog at work sequence id: 20585 PG status: last enqueued work: 20589, last completed work: 20584
[rank0]:[E709 17:05:38.162832798 ProcessGroupNCCL.cpp:670] Stack trace of the failed collective not found, potentially because FlightRecorder is disabled. You can enable it by setting TORCH_NCCL_TRACE_BUFFER_SIZE to a non-zero value.
[rank0]:[E709 17:05:38.162895613 ProcessGroupNCCL.cpp:2106] [PG ID 0 PG GUID 0(default_pg) Rank 0] First PG on this rank to signal dumping.
[rank0]:[E709 17:05:38.482119928 ProcessGroupNCCL.cpp:1746] [PG ID 0 PG GUID 0(default_pg) Rank 0] Received a dump signal due to a collective timeout from this local rank and we will try our best to dump the debug info. Last enqueued NCCL work: 20589, last completed NCCL work: 20584.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc.
[rank0]:[E709 17:05:38.482326987 ProcessGroupNCCL.cpp:1536] [PG ID 0 PG GUID 0(default_pg) Rank 0] ProcessGroupNCCL preparing to dump debug info. Include stack trace: 1
Epoch 10/10: 241it [15:55, 181.19s/it, loss=5.09]
[rank0]:[E709 17:05:39.081662161 ProcessGroupNCCL.cpp:684] [Rank 0] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank0]:[E709 17:05:39.081690629 ProcessGroupNCCL.cpp:698] [Rank 0] To avoid data inconsistency, we are taking the entire process down.
[rank0]:[E709 17:05:39.083402482 ProcessGroupNCCL.cpp:1899] [PG ID 0 PG GUID 0(default_pg) Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=20585, OpType=ALLREDUCE, NumelIn=1259621, NumelOut=1259621, Timeout(ms)=600000) ran for 600000 milliseconds before timing out.
Exception raised from checkTimeout at /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:635 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7f92e62535e8 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x23d (0x7f92e756ea6d in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0xc80 (0x7f92e75707f0 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f92e7571efd in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xd8198 (0x7f92d7559198 in /opt/conda/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x7ea7 (0x7f933d48dea7 in /usr/lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x3f (0x7f933d25eadf in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG ID 0 PG GUID 0(default_pg) Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=20585, OpType=ALLREDUCE, NumelIn=1259621, NumelOut=1259621, Timeout(ms)=600000) ran for 600000 milliseconds before timing out.
Exception raised from checkTimeout at /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:635 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7f92e62535e8 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x23d (0x7f92e756ea6d in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0xc80 (0x7f92e75707f0 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f92e7571efd in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xd8198 (0x7f92d7559198 in /opt/conda/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x7ea7 (0x7f933d48dea7 in /usr/lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x3f (0x7f933d25eadf in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1905 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7f92e62535e8 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x11b4abe (0x7f92e7540abe in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xe07bed (0x7f92e7193bed in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0xd8198 (0x7f92d7559198 in /opt/conda/bin/../lib/libstdc++.so.6)
frame #4: <unknown function> + 0x7ea7 (0x7f933d48dea7 in /usr/lib/x86_64-linux-gnu/libpthread.so.0)
frame #5: clone + 0x3f (0x7f933d25eadf in /usr/lib/x86_64-linux-gnu/libc.so.6)

E0709 17:05:39.816000 56204 site-packages/torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: -6) local_rank: 0 (pid: 56213) of binary: /opt/conda/bin/python3.10
Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in main
    run(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 883, in run
    elastic_launch(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 139, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
trainer.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-07-09_17:05:39
  host      : distributed-training.us-central1-a.c.lance-dev-ayush.internal
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 56213)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 56213
============================================================
```
`ShardedBatchSampler` provides perfectly balanced sharding by operating at the batch level, not the fragment level. It calculates row ranges for each batch and deals those ranges out to the different ranks.
This logic gives interleaved batches to each process:
- Rank 0 gets row ranges for Batch 0, Batch 2, Batch 4, ...
- Rank 1 gets row ranges for Batch 1, Batch 3, Batch 5, ...
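The interleaved dealing of row ranges can be sketched as follows. This is a plain-Python model of the idea, not Lance's implementation, and `sharded_batch_ranges` is a hypothetical name. Each tuple is a half-open `(start, end)` row range.

```python
from typing import List, Tuple


def sharded_batch_ranges(total_rows: int, batch_size: int, rank: int,
                         world_size: int) -> List[Tuple[int, int]]:
    """Split rows [0, total_rows) into batch-sized half-open ranges, then deal
    them out round-robin: rank r gets batches r, r + world_size, ..."""
    ranges = [
        (start, min(start + batch_size, total_rows))
        for start in range(0, total_rows, batch_size)
    ]
    return ranges[rank::world_size]


print(sharded_batch_ranges(10, 3, rank=0, world_size=2))  # [(0, 3), (6, 9)]
print(sharded_batch_ranges(10, 3, rank=1, world_size=2))  # [(3, 6), (9, 10)]

# 3400 rows with batch_size=128 give 27 batches: 14 for rank 0, 13 for rank 1.
print([len(sharded_batch_ranges(3400, 128, r, 2)) for r in range(2)])  # [14, 13]
```

Because batches rather than fragments are dealt out, per-rank batch counts in this model differ by at most one, regardless of how rows are spread over fragments.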
This sampler is DDP-aware and is the safest choice for balanced distributed training.
- Pros: It guarantees that every process receives almost the exact same number of batches, preventing workload imbalance and DDP deadlocks.
- Cons: It can be slightly less I/O efficient than `ShardedFragmentSampler`. To construct a batch, it may need to perform a specific range read from a fragment, which can be less optimal than reading the entire fragment at once.
You cannot use the Lance samplers (like `ShardedBatchSampler` or `ShardedFragmentSampler`) with a map-style dataset.
The two systems are fundamentally incompatible by design:
- Lance samplers are designed to work inside the iterable `LanceDataset`. They don't generate indices. Instead, they directly control how the `lance` file scanner reads and yields entire batches of data. They are tightly coupled to the `LanceDataset`'s streaming (`__iter__`) mechanism.
- PyTorch's `DistributedSampler` works by generating a list of indices (e.g., `[10, 5, 22]`). The `DataLoader` then takes these indices and fetches each item individually from a map-style dataset using its `__getitem__` method (e.g., `dataset[10]`).
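For comparison with the Lance samplers, the index-based approach of `DistributedSampler` can be sketched in plain Python. This follows our reading of PyTorch's implementation with `shuffle=False` and `drop_last=False` (with `shuffle=True` it first permutes the indices using a generator seeded with `seed + epoch`); `distributed_indices` is a hypothetical helper, not a PyTorch function.

```python
import math
from typing import List


def distributed_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Sketch of DistributedSampler's assignment with shuffle=False, drop_last=False."""
    if dataset_len == 0:
        return []
    indices = list(range(dataset_len))
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    padding = total_size - len(indices)
    # Pad by wrapping around to the start so every rank gets the same count.
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes every num_replicas-th index, starting at position r.
    return indices[rank:total_size:num_replicas]


for r in range(3):
    print(r, distributed_indices(10, num_replicas=3, rank=r))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]
```

The output is just a list of integers per rank, which is what a map-style `DataLoader` needs and what the Lance samplers do not produce.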
Because the Lance samplers don't produce the indices that a map-style `DataLoader` needs, you cannot use them together. You have to choose one of two paths:

- Torch iterable-style loader: use the iterable `LanceDataset` with a Lance sampler. Benefit: uses Lance's native, optimized sampling. You must handle `num_workers > 0` as mentioned here.
- Map-style loader: use a map-style dataset (like the `LanceMapDataset` we built, or torchvision's) with PyTorch's `DistributedSampler`.
All the scripts log to the same project on the wandb dashboard. Run the training scripts and check the dashboard for training time per epoch and other metrics.

The same training loop, implemented on an `ImageFolder` dataset via torchvision, is provided in the `torch_version/` folder.
About
Examples and guides for distributed training with lance