torch.utils.data#
Created On: Jun 13, 2025 | Last Updated On: Jun 13, 2025
At the heart of the PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for

- map-style and iterable-style datasets,
- customizing data loading order,
- automatic batching,
- single- and multi-process data loading,
- automatic memory pinning.

These options are configured by the constructor arguments of a DataLoader, which has signature:

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None, *, prefetch_factor=2,
               persistent_workers=False)

The sections below describe in detail the effects and usage of these options.
Dataset Types#
The most important argument of the DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

- map-style datasets,
- iterable-style datasets.
Map-style datasets#
A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.
For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.
See Dataset for more details.
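As a minimal sketch of the two protocols described above, the hypothetical SquaresDataset below (the class name and its contents are illustrative assumptions, not part of PyTorch) maps an integer index to a pair of tensors:

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: index i maps to the pair (i, i**2)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Number of samples; samplers and len(DataLoader) rely on this.
        return self.size

    def __getitem__(self, idx):
        # Reject out-of-range keys so plain iteration also terminates.
        if not 0 <= idx < self.size:
            raise IndexError(idx)
        x = torch.tensor(float(idx))
        return x, x ** 2


dataset = SquaresDataset(5)
sample = dataset[3]  # a tuple of two 0-dim tensors
```

A real dataset would typically read a file or a database row inside __getitem__() instead of computing the value.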
Iterable-style datasets#
An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of dataset is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.
For example, such a dataset, when iter(dataset) is called on it, could return a stream of data read from a database, a remote server, or even logs generated in real time.
See IterableDataset for more details.
Note
When using an IterableDataset with multi-process data loading, the same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See the IterableDataset documentation for how to achieve this.
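The following sketch illustrates the __iter__() protocol in single-process mode. LineStream is a hypothetical class that stands in for a real stream (an in-memory list replaces the database or log source):

```python
from torch.utils.data import DataLoader, IterableDataset


class LineStream(IterableDataset):
    """Hypothetical stream: yields the length of each line it is given."""

    def __init__(self, lines):
        self.lines = lines

    def __iter__(self):
        # No index is involved; samples are produced in stream order.
        for line in self.lines:
            yield len(line)


stream = LineStream(["a", "bb", "ccc"])
# Single-process loading, so no replica configuration is needed here.
batches = [batch.tolist() for batch in DataLoader(stream, batch_size=2)]
```

With three samples and batch_size=2, the last batch holds a single element because drop_last defaults to False.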
Data Loading Order and Sampler#
For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample each time).
The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.
A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.
A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via the batch_size and drop_last arguments. See the next section for more details on this.
Note
Neither sampler nor batch_sampler is compatible with iterable-style datasets, since such datasets have no notion of a key or an index.
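To make the relationship between sampler and batch_sampler concrete, here is a small sketch. It uses SequentialSampler and BatchSampler, two sampler classes that ship with torch.utils.data but are not described in this section; the tiny TensorDataset is an illustrative assumption:

```python
import torch
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    SequentialSampler,
    TensorDataset,
)

dataset = TensorDataset(torch.arange(7))

# A sampler yields one index at a time ...
sampler = SequentialSampler(dataset)
# ... and a batch sampler groups those indices into lists.
batch_sampler = BatchSampler(sampler, batch_size=3, drop_last=False)
index_batches = list(batch_sampler)

# batch_size, shuffle, sampler and drop_last must be left at their defaults
# when batch_sampler is given.
loader = DataLoader(dataset, batch_sampler=batch_sampler)
values = [batch[0].tolist() for batch in loader]
```

Because the dataset stores the values 0 to 6 at the matching indices, the loaded values mirror the index batches.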
Loading Batched and Non-Batched Data#
DataLoader supports automatically collating individual fetched data samples into batches via the arguments batch_size, drop_last, batch_sampler, and collate_fn (which has a default function).
Automatic batching (default)#
This is the most common case, and corresponds to fetching a minibatch of data and collating them into batched samples, i.e., containing Tensors with one dimension being the batch dimension (usually the first).
When batch_size (default 1) is not None, the data loader yields batched samples instead of individual samples. The batch_size and drop_last arguments are used to specify how the data loader obtains batches of dataset keys. For map-style datasets, users can alternatively specify batch_sampler, which yields a list of keys at a time.
Note
The batch_size and drop_last arguments essentially are used to construct a batch_sampler from sampler. For map-style datasets, the sampler is either provided by the user or constructed based on the shuffle argument. For iterable-style datasets, the sampler is a dummy infinite one. See the "Data Loading Order and Sampler" section for more details on samplers.
Note
When fetching from iterable-style datasets with multi-processing, the drop_last argument drops the last non-full batch of each worker’s dataset replica.
After fetching a list of samples using the indices from the sampler, the function passed as the collate_fn argument is used to collate lists of samples into batches.
In this case, loading from a map-style dataset is roughly equivalent to:

    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])

and loading from an iterable-style dataset is roughly equivalent to:

    dataset_iter = iter(dataset)
    for indices in batch_sampler:
        yield collate_fn([next(dataset_iter) for _ in indices])

A custom collate_fn can be used to customize collation, e.g., padding sequential data to the max length of a batch. See the "Working with collate_fn" section for more about collate_fn.
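The padding use case mentioned above could look like the sketch below. The function name pad_collate and the toy sequences are assumptions made for illustration; pad_sequence comes from torch.nn.utils.rnn, and a plain Python list serves as the map-style dataset since it already implements __getitem__() and __len__():

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Variable-length 1-D sequences.
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4]), torch.tensor([5, 6])]


def pad_collate(batch):
    """Pad every sequence in the batch to the longest one and keep the lengths."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths


loader = DataLoader(sequences, batch_size=2, collate_fn=pad_collate)
padded, lengths = next(iter(loader))
```

Returning the original lengths alongside the padded tensor lets downstream code mask out the padding.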
Disable automatic batching#
In certain cases, users may want to handle batching manually in dataset code, or simply load individual samples. For example, it could be cheaper to directly load batched data (e.g., bulk reads from a database or reading contiguous chunks of memory), or the batch size is data dependent, or the program is designed to work on individual samples. Under these scenarios, it’s likely better to not use automatic batching (where collate_fn is used to collate the samples), but let the data loader directly return each member of the dataset object.
When both batch_size and batch_sampler are None (the default value for batch_sampler is already None), automatic batching is disabled. Each sample obtained from the dataset is processed with the function passed as the collate_fn argument.
When automatic batching is disabled, the default collate_fn simply converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched.
In this case, loading from a map-style dataset is roughly equivalent to:

    for index in sampler:
        yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent to:

    for data in iter(dataset):
        yield collate_fn(data)

See the "Working with collate_fn" section for more about collate_fn.
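A sketch of the "dataset does its own batching" scenario follows. ChunkedRows is a hypothetical dataset whose every index already returns a whole chunk of rows, so batch_size=None is passed to switch automatic batching off:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ChunkedRows(Dataset):
    """Hypothetical dataset whose items are already batches (chunks of rows)."""

    def __init__(self, rows, chunk_size):
        self.rows = rows
        self.chunk_size = chunk_size

    def __len__(self):
        # Number of chunks, rounding up for a final partial chunk.
        return (len(self.rows) + self.chunk_size - 1) // self.chunk_size

    def __getitem__(self, chunk_idx):
        start = chunk_idx * self.chunk_size
        return self.rows[start:start + self.chunk_size]


rows = torch.arange(10.0).view(5, 2)
# batch_size=None disables automatic batching: each item is yielded as-is.
loader = DataLoader(ChunkedRows(rows, chunk_size=2), batch_size=None)
shapes = [tuple(chunk.shape) for chunk in loader]
```

No extra batch dimension is prepended here; the chunk shape is exactly what __getitem__() returned.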
Working with collate_fn#
The use of collate_fn is slightly different when automatic batching is enabled or disabled.
When automatic batching is disabled, collate_fn is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default collate_fn simply converts NumPy arrays into PyTorch tensors.
When automatic batching is enabled, collate_fn is called with a list of data samples each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes the behavior of the default collate_fn (default_collate()).
For instance, if each data sample consists of a 3-channel image and an integral class label, i.e., each element of the dataset returns a tuple (image, class_index), the default collate_fn collates a list of such tuples into a single tuple of a batched image tensor and a batched class label Tensor. In particular, the default collate_fn has the following properties:

- It always prepends a new dimension as the batch dimension.
- It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
- It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values cannot be converted into Tensors). The same holds for lists, tuples, namedtuples, etc.

Users may use a customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.
If you run into a situation where the outputs of DataLoader have dimensions or types that differ from your expectation, you may want to check your collate_fn.
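The (image, class_index) case described above can be checked directly by calling default_collate() on a hand-built list of samples. The 4x4 image size below is an arbitrary choice for illustration:

```python
import torch
from torch.utils.data import default_collate

# Two samples, each a (3-channel image, integer class label) tuple.
samples = [(torch.zeros(3, 4, 4), 1), (torch.ones(3, 4, 4), 0)]

images, labels = default_collate(samples)
image_shape = tuple(images.shape)  # batch dimension is prepended
label_list = labels.tolist()       # Python ints became one tensor
```

Inspecting shapes this way is a quick first step when a DataLoader yields batches with unexpected dimensions.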
Single- and Multi-process Data Loading#
A DataLoader uses single-process data loading by default.
Within a Python process, the Global Interpreter Lock (GIL) prevents truly parallelizing Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument num_workers to a positive integer.
Single-process data loading (default)#
In this mode, data fetching is done in the same process in which the DataLoader is initialized. Therefore, data loading may block computing. However, this mode may be preferred when the resources used for sharing data among processes (e.g., shared memory, file descriptors) are limited, or when the entire dataset is small and can be loaded entirely in memory. Additionally, single-process loading often shows more readable error traces and thus is useful for debugging.
Multi-process data loading#
Setting the argument num_workers to a positive integer will turn on multi-process data loading with the specified number of loader worker processes.
Warning
After several iterations, the loader worker processes will consume the same amount of CPU memory as the parent process for all Python objects in the parent process which are accessed from the worker processes. This can be problematic if the Dataset contains a lot of data (e.g., you are loading a very large list of filenames at Dataset construction time) and/or you are using a lot of workers (overall memory usage is number of workers * size of parent process). The simplest workaround is to replace Python objects with non-refcounted representations such as Pandas, NumPy or PyArrow objects. Check out issue #13246 for more details on why this occurs and example code for how to work around these problems.
In this mode, each time an iterator of a DataLoader is created (e.g., when you call enumerate(dataloader)), num_workers worker processes are created. At this point, the dataset, collate_fn, and worker_init_fn are passed to each worker, where they are used to initialize and fetch data. This means that dataset access, together with its internal IO and transforms (including collate_fn), runs in the worker process.
torch.utils.data.get_worker_info() returns various useful information in a worker process (including the worker id, dataset replica, initial seed, etc.), and returns None in the main process. Users may use this function in dataset code and/or worker_init_fn to individually configure each dataset replica, and to determine whether the code is running in a worker process. For example, this can be particularly helpful in sharding the dataset.
For map-style datasets, the main process generates the indices using sampler and sends them to the workers. So any shuffle randomization is done in the main process, which guides loading by assigning indices to load.
For iterable-style datasets, since each worker process gets a replica of the dataset object, naive multi-process loading will often result in duplicated data. Using torch.utils.data.get_worker_info() and/or worker_init_fn, users may configure each replica independently. (See the IterableDataset documentation for how to achieve this.) For similar reasons, in multi-process loading, the drop_last argument drops the last non-full batch of each worker’s iterable-style dataset replica.
Workers are shut down once the end of the iteration is reached, or when the iterator is garbage collected.
Warning
It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing). Instead, we recommend using automatic memory pinning (i.e., setting pin_memory=True), which enables fast data transfer to CUDA-enabled GPUs.
Platform-specific behaviors#
Since workers rely on Python multiprocessing, worker launch behavior is different on Windows compared to Unix.

- On Unix, fork() is the default multiprocessing start method. Using fork(), child workers typically can access the dataset and Python argument functions directly through the cloned address space.
- On Windows or MacOS, spawn() is the default multiprocessing start method. Using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.

This separate serialization means that you should take two steps to ensure you are compatible with Windows while using multi-process data loading:

- Wrap most of your main script’s code within an if __name__ == '__main__': block, to make sure it doesn’t run again (most likely generating an error) when each worker process is launched. You can place your dataset and DataLoader instance creation logic here, as it doesn’t need to be re-executed in workers.
- Make sure that any custom collate_fn, worker_init_fn or dataset code is declared as top-level definitions, outside of the __main__ check. This ensures that they are available in worker processes. (This is needed since functions are pickled as references only, not bytecode.)
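A script layout that follows both steps might look like the sketch below. The names double_collate and main are hypothetical; the point is only that the collate function lives at module level while loader creation and iteration happen behind the __main__ guard:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def double_collate(batch):
    # Top-level definition, so it can be pickled by reference under spawn().
    return torch.stack([x for (x,) in batch]) * 2


def main(num_workers=2):
    dataset = TensorDataset(torch.arange(6))
    loader = DataLoader(
        dataset,
        batch_size=3,
        num_workers=num_workers,
        collate_fn=double_collate,
    )
    return [batch.tolist() for batch in loader]


if __name__ == "__main__":
    # Only the launching process runs this; spawned workers re-import the
    # module but skip this block.
    print(main())
```

Defining double_collate as a lambda or inside main() would still work with fork(), but is likely to fail with a pickling error under spawn().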
Randomness in multi-process data loading#
By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by the main process using its RNG (thereby mandatorily consuming an RNG state) or a specified generator. However, seeds for other libraries may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See the "My data loader workers return identical random numbers" section in the FAQ.)
In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.
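One way to apply this is sketched below. The function name seed_worker is an assumption, and Python's random module is used only as a stand-in for whatever other library the dataset code relies on; the same pattern applies to any library that accepts an integer seed:

```python
import random

import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Inside a worker, torch.initial_seed() is base_seed + worker_id.
    # Reduce it to 32 bits, since many libraries only accept seeds that small.
    worker_seed = torch.initial_seed() % 2**32
    random.seed(worker_seed)


# A dedicated generator makes the shuffle order and base_seed reproducible.
generator = torch.Generator()
generator.manual_seed(0)

loader = DataLoader(
    TensorDataset(torch.arange(8)),
    batch_size=4,
    shuffle=True,
    num_workers=0,  # set to a positive integer to actually launch workers
    worker_init_fn=seed_worker,
    generator=generator,
)
batches = [batch[0].tolist() for batch in loader]
```

With num_workers=0 the worker_init_fn is never invoked; it only takes effect once worker processes are created.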
Memory Pinning#
Host to GPU copies are much faster when they originate from pinned (page-locked) memory. See Use pinned memory buffers for more details on when and how to use pinned memory generally.
For data loading, passing pin_memory=True to a DataLoader will automatically put the fetched data Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled GPUs.
The default memory pinning logic only recognizes Tensors and maps and iterables containing Tensors. By default, if the pinning logic sees a batch that is a custom type (which will occur if you have a collate_fn that returns a custom batch type), or if each element of your batch is a custom type, the pinning logic will not recognize them, and it will return that batch (or those elements) without pinning the memory. To enable memory pinning for custom batch or data type(s), define a pin_memory() method on your custom type(s).
See the example below.
Example:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class SimpleCustomBatch:
        def __init__(self, data):
            transposed_data = list(zip(*data))
            self.inp = torch.stack(transposed_data[0], 0)
            self.tgt = torch.stack(transposed_data[1], 0)

        # custom memory pinning method on custom type
        def pin_memory(self):
            self.inp = self.inp.pin_memory()
            self.tgt = self.tgt.pin_memory()
            return self

    def collate_wrapper(batch):
        return SimpleCustomBatch(batch)

    inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
    tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
    dataset = TensorDataset(inps, tgts)

    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                        pin_memory=True)

    for batch_ndx, sample in enumerate(loader):
        print(sample.inp.is_pinned())
        print(sample.tgt.is_pinned())
- class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)#
Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.
The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
See the torch.utils.data documentation page for more details.
- Parameters
dataset (Dataset) – dataset from which to load the data.
batch_size (int, optional) – how many samples per batch to load (default: 1).
shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and by multiprocessing to generate base_seed for workers. (default: None)
prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (The default value depends on the value of num_workers: if num_workers=0 the default is None; otherwise, if num_workers > 0, the default is 2.)
persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive. (default: False)
pin_memory_device (str, optional) – Deprecated, the current accelerator will be used as the device if pin_memory=True.
in_order (bool, optional) – If False, the data loader will not enforce that batches are returned in a first-in, first-out order. Only applies when num_workers > 0. (default: True)
Warning
If the spawn start method is used, worker_init_fn cannot be an unpicklable object, e.g., a lambda function. See Multiprocessing best practices for more details related to multiprocessing in PyTorch.
Warning
The len(dataloader) heuristic is based on the length of the sampler used. When dataset is an IterableDataset, it instead returns an estimate based on len(dataset) / batch_size, with proper rounding depending on drop_last, regardless of multi-process loading configurations. This represents the best guess PyTorch can make because PyTorch trusts user dataset code in correctly handling multi-process loading to avoid duplicate data.
However, if sharding results in multiple workers having incomplete last batches, this estimate can still be inaccurate, because (1) an otherwise complete batch can be broken into multiple ones and (2) more than one batch worth of samples can be dropped when drop_last is set. Unfortunately, PyTorch cannot detect such cases in general.
See Dataset Types for more details on these two types of datasets and how IterableDataset interacts with Multi-process data loading.
Warning
See Reproducibility, and My data loader workers return identical random numbers, and Randomness in multi-process data loading notes for random seed related questions.
Warning
Setting in_order to False can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data.
- class torch.utils.data.Dataset#
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally override __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__() to speed up loading of batched samples. This method accepts a list of indices of the samples in a batch and returns a list of samples.
Note
DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
- class torch.utils.data.IterableDataset#
An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it. This form of dataset is particularly useful when data come from a stream.
All subclasses should override __iter__(), which would return an iterator of samples in this dataset.
When a subclass is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset’s __iter__() method or the DataLoader’s worker_init_fn option to modify each copy’s behavior.
Example 1: splitting workload across all workers in __iter__():

    >>> import math
    >>> class MyIterableDataset(torch.utils.data.IterableDataset):
    ...     def __init__(self, start, end):
    ...         super().__init__()
    ...         assert end > start, "this example only works with end > start"
    ...         self.start = start
    ...         self.end = end
    ...
    ...     def __iter__(self):
    ...         worker_info = torch.utils.data.get_worker_info()
    ...         if worker_info is None:  # single-process data loading, return the full iterator
    ...             iter_start = self.start
    ...             iter_end = self.end
    ...         else:  # in a worker process
    ...             # split workload
    ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
    ...             worker_id = worker_info.id
    ...             iter_start = self.start + worker_id * per_worker
    ...             iter_end = min(iter_start + per_worker, self.end)
    ...         return iter(range(iter_start, iter_end))
    ...
    >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
    >>> ds = MyIterableDataset(start=3, end=7)
    >>> # Single-process loading
    >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
    [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
    >>> # Multi-process loading with two worker processes
    >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
    >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
    [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
    >>> # With even more workers
    >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
    [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
Example 2: splitting workload across all workers using worker_init_fn:

    >>> import math
    >>> class MyIterableDataset(torch.utils.data.IterableDataset):
    ...     def __init__(self, start, end):
    ...         super().__init__()
    ...         assert end > start, "this example only works with end > start"
    ...         self.start = start
    ...         self.end = end
    ...
    ...     def __iter__(self):
    ...         return iter(range(self.start, self.end))
    ...
    >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
    >>> ds = MyIterableDataset(start=3, end=7)
    >>> # Single-process loading
    >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
    [3, 4, 5, 6]
    >>>
    >>> # Directly doing multi-process loading yields duplicate data
    >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
    [3, 3, 4, 4, 5, 5, 6, 6]
    >>> # Define a `worker_init_fn` that configures each dataset copy differently
    >>> def worker_init_fn(worker_id):
    ...     worker_info = torch.utils.data.get_worker_info()
    ...     dataset = worker_info.dataset  # the dataset copy in this worker process
    ...     overall_start = dataset.start
    ...     overall_end = dataset.end
    ...     # configure the dataset to only process the split workload
    ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    ...     worker_id = worker_info.id
    ...     dataset.start = overall_start + worker_id * per_worker
    ...     dataset.end = min(dataset.start + per_worker, overall_end)
    ...
    >>> # Multi-process loading with the custom `worker_init_fn`
    >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
    >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
    [3, 5, 4, 6]
    >>> # With even more workers
    >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
    [3, 4, 5, 6]
- class torch.utils.data.TensorDataset(*tensors)#
Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
- Parameters
*tensors (Tensor) – tensors that have the same size of the first dimension.
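A short sketch of the indexing behavior, using made-up feature and label tensors that share a first dimension of size 4:

```python
import torch
from torch.utils.data import TensorDataset

features = torch.arange(12.0).view(4, 3)  # 4 samples, 3 features each
labels = torch.tensor([0, 1, 0, 1])       # 4 labels

dataset = TensorDataset(features, labels)
# Indexing returns one slice per wrapped tensor, as a tuple.
x, y = dataset[2]
```

Here x is the third row of features and y is the matching label.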
- class torch.utils.data.StackDataset(*args, **kwargs)#
Dataset as a stacking of multiple datasets.
This class is useful to assemble different parts of complex input data, given as datasets.
Example

    >>> images = ImageDataset()
    >>> texts = TextDataset()
    >>> tuple_stack = StackDataset(images, texts)
    >>> tuple_stack[0] == (images[0], texts[0])
    >>> dict_stack = StackDataset(image=images, text=texts)
    >>> dict_stack[0] == {"image": images[0], "text": texts[0]}
- class torch.utils.data.ConcatDataset(datasets)#
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
- Parameters
datasets (sequence) – List of datasets to be concatenated
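As an illustration, the sketch below concatenates two small, made-up TensorDataset objects and indexes across the boundary between them:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

first = TensorDataset(torch.tensor([0, 1, 2]))
second = TensorDataset(torch.tensor([10, 11]))

combined = ConcatDataset([first, second])
total = len(combined)
# Index 3 falls past the end of `first`, so it resolves to second[0].
value = combined[3][0].item()
```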
- class torch.utils.data.ChainDataset(datasets)#
Dataset for chaining multiple IterableDatasets.
This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient.
- Parameters
datasets (iterable of IterableDataset) – datasets to be chained together
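A minimal sketch of chaining, where RangeStream is a hypothetical IterableDataset that simply yields a range of integers:

```python
from torch.utils.data import ChainDataset, IterableDataset


class RangeStream(IterableDataset):
    """Hypothetical stream yielding the integers in [start, end)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))


chained = ChainDataset([RangeStream(0, 3), RangeStream(10, 12)])
# The second stream starts only after the first one is exhausted.
items = list(chained)
```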
- class torch.utils.data.Subset(dataset, indices)#
Subset of a dataset at specified indices.
- Parameters
dataset (Dataset) – The whole Dataset
indices (sequence) – Indices in the whole set selected for subset
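For instance, the following sketch selects the even-indexed samples of a small, made-up dataset:

```python
import torch
from torch.utils.data import Subset, TensorDataset

base = TensorDataset(torch.arange(10))
evens = Subset(base, indices=[0, 2, 4, 6, 8])

size = len(evens)
# Position 1 in the subset maps to index 2 in the whole dataset.
value = evens[1][0].item()
```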
- torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)#
General collate function that handles collection type of element within each batch.
The function also opens a function registry to deal with specific element types. default_collate_fn_map provides default collate functions for tensors, NumPy arrays, numbers and strings.
- Parameters
batch – a single batch to be collated
collate_fn_map (Optional[dict[Union[type, tuple[type, ...]], Callable]]) – Optional dictionary mapping from element type to the corresponding collate function. If the element type isn’t present in this dictionary, this function will go through each key of the dictionary in the insertion order to invoke the corresponding collate function if the element type is a subclass of the key.
Examples

    >>> def collate_tensor_fn(batch, *, collate_fn_map):
    ...     # Extend this function to handle batch of tensors
    ...     return torch.stack(batch, 0)
    >>> def custom_collate(batch):
    ...     collate_map = {torch.Tensor: collate_tensor_fn}
    ...     return collate(batch, collate_fn_map=collate_map)
    >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
    >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

Note
Each collate function requires a positional argument for batch and a keyword argument for the dictionary of collate functions as collate_fn_map.
- torch.utils.data.default_collate(batch)#
Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
The exact output type can be a torch.Tensor, a Sequence of torch.Tensor, a Collection of torch.Tensor, or left unchanged, depending on the input type. This is used as the default function for collation when batch_size or batch_sampler is defined in DataLoader.
Here is the general input type (based on the type of the element within the batch) to output type mapping:

- torch.Tensor -> torch.Tensor (with an added outer dimension batch size)
- NumPy Arrays -> torch.Tensor
- float -> torch.Tensor
- int -> torch.Tensor
- str -> str (unchanged)
- bytes -> bytes (unchanged)
- Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]
- NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
- Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

- Parameters
batch – a single batch to be collated
Examples

    >>> # Example with a batch of `int`s:
    >>> default_collate([0, 1, 2, 3])
    tensor([0, 1, 2, 3])
    >>> # Example with a batch of `str`s:
    >>> default_collate(["a", "b", "c"])
    ['a', 'b', 'c']
    >>> # Example with `Map` inside the batch:
    >>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
    {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
    >>> # Example with `NamedTuple` inside the batch:
    >>> Point = namedtuple("Point", ["x", "y"])
    >>> default_collate([Point(0, 0), Point(1, 1)])
    Point(x=tensor([0, 1]), y=tensor([0, 1]))
    >>> # Example with `Tuple` inside the batch:
    >>> default_collate([(0, 1), (2, 3)])
    [tensor([0, 2]), tensor([1, 3])]
    >>> # Example with `List` inside the batch:
    >>> default_collate([[0, 1], [2, 3]])
    [tensor([0, 2]), tensor([1, 3])]
    >>> # Two options to extend `default_collate` to handle specific type
    >>> # Option 1: Write custom collate function and invoke `default_collate`
    >>> def custom_collate(batch):
    ...     elem = batch[0]
    ...     if isinstance(elem, CustomType):  # Some custom condition
    ...         return ...
    ...     else:  # Fall back to `default_collate`
    ...         return default_collate(batch)
    >>> # Option 2: In-place modify `default_collate_fn_map`
    >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
    ...     return ...
    >>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
    >>> default_collate(batch)  # Handle `CustomType` automatically
- torch.utils.data.default_convert(data)#
Convert each NumPy array element into a torch.Tensor.
If the input is a Sequence, Collection, or Mapping, it tries to convert each element inside to a torch.Tensor. If the input is not a NumPy array, it is left unchanged. This is used as the default function for collation when both batch_sampler and batch_size are NOT defined in DataLoader.
The general input type to output type mapping is similar to that of default_collate(). See the description there for more details.
- Parameters
data – a single data point to be converted
Examples

    >>> # Example with `int`
    >>> default_convert(0)
    0
    >>> # Example with NumPy array
    >>> default_convert(np.array([0, 1]))
    tensor([0, 1])
    >>> # Example with NamedTuple
    >>> Point = namedtuple("Point", ["x", "y"])
    >>> default_convert(Point(0, 0))
    Point(x=0, y=0)
    >>> default_convert(Point(np.array(0), np.array(0)))
    Point(x=tensor(0), y=tensor(0))
    >>> # Example with List
    >>> default_convert([np.array([0, 1]), np.array([2, 3])])
    [tensor([0, 1]), tensor([2, 3])]
- torch.utils.data.get_worker_info()[source]#
Returns the information about the current DataLoader iterator worker process.
When called in a worker, this returns an object guaranteed to have the following attributes:
id: the current worker id.
num_workers: the total number of workers.
seed: the random seed set for the current worker. This value is determined by main process RNG and the worker id. See DataLoader’s documentation for more details.
dataset: the copy of the dataset object in this process. Note that this will be a different object in a different process than the one in the main process.
When called in the main process, this returns None.
Note
When used in a worker_init_fn passed over to DataLoader, this method can be useful to set up each worker process differently, for instance, using the worker id to configure the dataset object to only read a specific fraction of a sharded dataset, or using seed to seed other libraries used in dataset code.
- Return type
Optional[WorkerInfo]
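The sharding idea described in the note can be sketched as follows. The names `shard_bounds` and `RangeStream` are hypothetical, and splitting a contiguous integer range evenly is only one way to divide the work. The block runs with `num_workers=0`, where `get_worker_info()` returns None; with `num_workers > 0`, each worker would take the `else` branch and read only its own slice.

```python
import math

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


def shard_bounds(start, end, worker_id, num_workers):
    """Return the [lo, hi) slice of [start, end) assigned to one worker."""
    per_worker = math.ceil((end - start) / num_workers)
    lo = start + worker_id * per_worker
    return lo, min(lo + per_worker, end)


class RangeStream(IterableDataset):
    """Streams the integers in [start, end), split across workers if any."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Main process (single-process loading): read the full range.
            lo, hi = self.start, self.end
        else:
            # Worker process: read only this worker's slice.
            lo, hi = shard_bounds(self.start, self.end, info.id, info.num_workers)
        return iter(range(lo, hi))


# Single-process loading: nothing is sharded.
single = list(DataLoader(RangeStream(0, 10), batch_size=None))
# The slices three workers would receive for the same range.
parts = [shard_bounds(0, 10, w, 3) for w in range(3)]
print(single, parts)
```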
- torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source]#
Randomly split a dataset into non-overlapping new datasets of given lengths.
If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.
After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.
Optionally fix the generator for reproducible results, e.g.:
Example
>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
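- Parameters
dataset (Dataset) – Dataset to be split
lengths (sequence) – lengths or fractions of splits to be produced
generator (Generator) – Generator used for the random permutation.
- Return type
List[Subset]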
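The rounding rule for fractional lengths can be checked with a small sketch. The fractions and the 10-element range below are illustrative values chosen so that one item is left over after flooring.

```python
import torch
from torch.utils.data import random_split

gen = torch.Generator().manual_seed(0)

# floor(0.25 * 10) = 2, floor(0.25 * 10) = 2, floor(0.5 * 10) = 5 -> 9 items.
# The single leftover item goes to the first split (round-robin): [3, 2, 5].
splits = random_split(range(10), [0.25, 0.25, 0.5], generator=gen)
sizes = [len(s) for s in splits]

# Every index appears in exactly one split.
all_indices = sorted(i for s in splits for i in s.indices)
print(sizes, all_indices)
```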
- class torch.utils.data.Sampler(data_source=None)[source]#
Base class for all Samplers.
Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.
- Parameters
data_source (Dataset) – This argument is not used and will be removed in 2.2.0. You may still have custom implementation that utilizes it.
Example
>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()
Note
The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.
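A custom sampler is typically handed to a DataLoader through the `sampler` argument. The following is a minimal sketch under that assumption. `LengthSortedSampler` and the `words` list are hypothetical, and the sort uses plain Python rather than `torch.argsort` to keep the ordering easy to verify by hand.

```python
from typing import Iterator, List

from torch.utils.data import DataLoader, Sampler


class LengthSortedSampler(Sampler[int]):
    """Yields dataset indices ordered by the length of each string."""

    def __init__(self, data: List[str]) -> None:
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self) -> Iterator[int]:
        order = sorted(range(len(self.data)), key=lambda i: len(self.data[i]))
        return iter(order)


words = ["ccc", "a", "dddd", "bb"]
loader = DataLoader(words, sampler=LengthSortedSampler(words), batch_size=2)

# The loader groups the sampler's indices into batches of two.
batches = [list(batch) for batch in loader]
num_batches = len(loader)  # relies on the sampler's __len__
print(batches, num_batches)
```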
- class torch.utils.data.SequentialSampler(data_source)[source]#
Samples elements sequentially, always in the same order.
- Parameters
data_source (Dataset) – dataset to sample from
- class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source]#
Samples elements randomly. If sampling without replacement, elements are drawn from a shuffled dataset.
If sampling with replacement, the user can specify num_samples to draw.
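The two modes can be compared with a short sketch. The 5-element `data` list and the choice of 12 draws are illustrative.

```python
import torch
from torch.utils.data import RandomSampler

data = list(range(5))
gen = torch.Generator().manual_seed(0)

# Without replacement: one pass yields a permutation of all indices.
perm = list(RandomSampler(data, generator=gen))

# With replacement: num_samples may exceed len(data), and indices may repeat.
draws = list(RandomSampler(data, replacement=True, num_samples=12, generator=gen))
print(perm, draws)
```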
- class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source]#
Samples elements randomly from a given list of indices, without replacement.
- Parameters
indices (sequence) – a sequence of indices
generator (Generator) – Generator used in sampling.
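One common use, sketched here under assumed data, is to carve a training and a validation loader out of a single dataset by giving each loader a disjoint index list. The 8/2 split and the toy `TensorDataset` are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
train_idx, val_idx = list(range(8)), [8, 9]
gen = torch.Generator().manual_seed(0)

train_loader = DataLoader(
    dataset, batch_size=4, sampler=SubsetRandomSampler(train_idx, generator=gen)
)
val_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(val_idx))

# Each loader only ever sees its own indices, in random order.
train_seen = sorted(v for (x,) in train_loader for v in x.tolist())
val_seen = sorted(v for (x,) in val_loader for v in x.tolist())
print(train_seen, val_seen)
```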
- class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source]#
Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
- Parameters
weights (sequence) – a sequence of weights, not necessarily summing up to one
num_samples (int) – number of samples to draw
replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
generator (Generator) – Generator used in sampling.
Example
>>> list(
...     WeightedRandomSampler(
...         [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
...     )
... )
[4, 4, 1, 4, 5]
>>> list(
...     WeightedRandomSampler(
...         [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
...     )
... )
[0, 1, 4, 3, 2]
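A frequent application is rebalancing classes. The sketch below assumes a toy label vector with a 9:1 imbalance and weights each sample by the inverse of its class count. That weighting scheme is an assumption made for illustration, not something the sampler requires.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Nine samples of class 0 and a single sample of class 1.
labels = torch.tensor([0] * 9 + [1])
class_counts = torch.bincount(labels)

# Inverse-frequency weights: each class ends up with the same total weight.
weights = 1.0 / class_counts[labels].float()

gen = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(weights, num_samples=1000, replacement=True, generator=gen)

drawn = labels[list(sampler)]
minority_fraction = (drawn == 1).float().mean().item()
print(minority_fraction)  # close to 0.5 rather than 0.1
```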
- class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source]#
Wraps another sampler to yield a mini-batch of indices.
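- Parameters
sampler (Sampler or Iterable) – Base sampler. Can be any iterable object
batch_size (int) – Size of mini-batch.
drop_last (bool) – If True, the sampler will drop the last batch if its size would be less than batch_size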
Example
>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
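A BatchSampler can also be passed to a DataLoader through the `batch_sampler` argument, in which case `batch_size`, `shuffle`, `sampler`, and `drop_last` are left at their defaults. A minimal sketch with an illustrative 10-element dataset:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# The loader fetches one list of indices per step and collates those samples.
loader = DataLoader(dataset, batch_sampler=batch_sampler)
batches = [x.tolist() for (x,) in loader]
print(batches)
```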
- class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source]#
Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.
Note
The dataset is assumed to be of constant size, and any instance of it is assumed to always return the same elements in the same order.
- Parameters
dataset (Dataset) – Dataset used for sampling.
num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.
rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
shuffle (bool, optional) – If True (default), sampler will shuffle the indices.
seed (int, optional) – random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.
drop_last (bool, optional) – if True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.
Warning
In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used.
Example:
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)
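The effect of `drop_last` on how indices are divided can be inspected without launching a distributed job, by passing `num_replicas` and `rank` explicitly. This is a sketch for inspection only: the 10-element dataset and three simulated ranks are illustrative, and a real job would let each process construct only its own sampler.

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))

# drop_last=False: 10 indices are padded to 12 by repeating the first two,
# so every rank receives 4 indices.
padded = [
    list(DistributedSampler(dataset, num_replicas=3, rank=r, shuffle=False))
    for r in range(3)
]

# drop_last=True: the tail is dropped so that 9 indices split evenly.
trimmed = [
    list(DistributedSampler(dataset, num_replicas=3, rank=r, shuffle=False, drop_last=True))
    for r in range(3)
]
print(padded)
print(trimmed)
```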