
torch.utils.data#

Created On: Jun 13, 2025 | Last Updated On: Jun 13, 2025

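At the heart of the PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for map-style and iterable-style datasets, customizing data loading order, automatic batching, single- and multi-process data loading, and automatic memory pinning.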

These options are configured by the constructor arguments of a DataLoader, which has the signature:

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None, *, prefetch_factor=2,
               persistent_workers=False)

The sections below describe in detail the effects and usage of these options.

Dataset Types#

The most important argument of the DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets: map-style datasets and iterable-style datasets.

Map-style datasets#

A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.

For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.

See Dataset for more details.
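
The two protocols can be sketched with a toy dataset. This is a minimal illustration only: SquaresDataset is a hypothetical name, not part of PyTorch, and a real dataset would typically read samples from disk instead of computing them.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: index i maps to the pair (i, i**2)."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        # Number of samples; used by samplers and len(dataloader).
        return self.size

    def __getitem__(self, idx: int):
        # Raising IndexError keeps plain iteration over the dataset finite.
        if not 0 <= idx < self.size:
            raise IndexError(idx)
        return torch.tensor(idx), torch.tensor(idx ** 2)


dataset = SquaresDataset(5)
x, y = dataset[3]
print(len(dataset), x.item(), y.item())  # 5 3 9
```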

Iterable-style datasets#

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of dataset is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

For example, such a dataset, when iter(dataset) is called, could return a stream of data read from a database, a remote server, or even logs generated in real time.

See IterableDataset for more details.

Note

When using an IterableDataset with multi-process data loading, the same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See the IterableDataset documentation for how to achieve this.

Data Loading Order and Sampler#

For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample each time).

The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.

A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.

A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via the batch_size and drop_last arguments. See the next section for more details on this.

Note

Neither sampler nor batch_sampler is compatible with iterable-style datasets, since such datasets have no notion of a key or an index.
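
The relationship between a sampler and a batch sampler can be sketched with the built-in SequentialSampler and BatchSampler classes. The seven-element dataset is an arbitrary choice for illustration.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(7))

# A sampler yields one index at a time ...
sampler = SequentialSampler(dataset)
# ... and a batch sampler groups those indices into lists.
batch_sampler = BatchSampler(sampler, batch_size=3, drop_last=False)
index_batches = list(batch_sampler)
print(index_batches)  # [[0, 1, 2], [3, 4, 5], [6]]

# Passing batch_sampler makes the loader fetch exactly those index lists.
loader = DataLoader(dataset, batch_sampler=batch_sampler)
value_batches = [batch[0].tolist() for batch in loader]
print(value_batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```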

Loading Batched and Non-Batched Data#

DataLoader supports automatically collating individual fetched data samples into batches via the arguments batch_size, drop_last, batch_sampler, and collate_fn (which has a default function).

Automatic batching (default)#

This is the most common case, and corresponds to fetching a minibatch of data and collating them into batched samples, i.e., containing Tensors with one dimension being the batch dimension (usually the first).

When batch_size (default 1) is not None, the data loader yields batched samples instead of individual samples. The batch_size and drop_last arguments are used to specify how the data loader obtains batches of dataset keys. For map-style datasets, users can alternatively specify batch_sampler, which yields a list of keys at a time.

Note

The batch_size and drop_last arguments essentially are used to construct a batch_sampler from sampler. For map-style datasets, the sampler is either provided by the user or constructed based on the shuffle argument. For iterable-style datasets, the sampler is a dummy infinite one. See the Data Loading Order and Sampler section for more details on samplers.

Note

When fetching from iterable-style datasets with multi-processing, the drop_last argument drops the last non-full batch of each worker’s dataset replica.

After fetching a list of samples using the indices from the sampler, the function passed as the collate_fn argument is used to collate lists of samples into batches.

In this case, loading from a map-style dataset is roughly equivalent to:

    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])

and loading from an iterable-style dataset is roughly equivalent to:

    dataset_iter = iter(dataset)
    for indices in batch_sampler:
        yield collate_fn([next(dataset_iter) for _ in indices])

A custom collate_fn can be used to customize collation, e.g., padding sequential data to the max length of a batch. See the Working with collate_fn section for more about collate_fn.
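
As an illustration of the padding case, the sketch below pads variable-length sequences to the longest one in each batch using torch.nn.utils.rnn.pad_sequence. The toy data and the name pad_collate are assumptions made for this example; a plain Python list is used as the map-style dataset because it already supports indexing and len().

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical toy data: four 1-D sequences of different lengths.
sequences = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4]),
    torch.tensor([5, 6]),
    torch.tensor([7, 8, 9, 10]),
]


def pad_collate(batch):
    """Pad each sequence to the longest one in the batch; also return the true lengths."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths


loader = DataLoader(sequences, batch_size=2, collate_fn=pad_collate)
batches = list(loader)
print(batches[0][0].tolist())  # [[1, 2, 3], [4, 0, 0]]
print(batches[0][1].tolist())  # [3, 1]
```

Returning the lengths alongside the padded tensor is a design choice of this sketch; it lets downstream code mask out the padding.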

Disable automatic batching#

In certain cases, users may want to handle batching manually in dataset code, or simply load individual samples. For example, it could be cheaper to directly load batched data (e.g., bulk reads from a database or reading continuous chunks of memory), or the batch size is data dependent, or the program is designed to work on individual samples. Under these scenarios, it’s likely better to not use automatic batching (where collate_fn is used to collate the samples), but let the data loader directly return each member of the dataset object.

When both batch_size and batch_sampler are None (the default value for batch_sampler is already None), automatic batching is disabled. Each sample obtained from the dataset is processed with the function passed as the collate_fn argument.

When automatic batching is disabled, the default collate_fn simply converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched.

In this case, loading from a map-style dataset is roughly equivalent to:

    for index in sampler:
        yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent to:

    for data in iter(dataset):
        yield collate_fn(data)

See the Working with collate_fn section for more about collate_fn.
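
A minimal sketch of disabled automatic batching, assuming a dataset whose items are already batches. ChunkDataset is a hypothetical class written for this example.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ChunkDataset(Dataset):
    """Hypothetical dataset whose items are already batches (contiguous chunks)."""

    def __init__(self, data: torch.Tensor, chunk_size: int) -> None:
        # torch.split returns a tuple of views, the last one possibly shorter.
        self.chunks = torch.split(data, chunk_size)

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.chunks[idx]


data = torch.arange(10)
# batch_size=None disables automatic batching: each item is yielded as-is.
loader = DataLoader(ChunkDataset(data, chunk_size=4), batch_size=None)
shapes = [tuple(chunk.shape) for chunk in loader]
print(shapes)  # [(4,), (4,), (2,)]
```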

Working with collate_fn#

The use of collate_fn is slightly different when automatic batching is enabled or disabled.

When automatic batching is disabled, collate_fn is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default collate_fn simply converts NumPy arrays into PyTorch tensors.

When automatic batching is enabled, collate_fn is called with a list of data samples at each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes the behavior of the default collate_fn (default_collate()).

For instance, if each data sample consists of a 3-channel image and an integral class label, i.e., each element of the dataset returns a tuple (image, class_index), the default collate_fn collates a list of such tuples into a single tuple of a batched image tensor and a batched class label Tensor. In particular, the default collate_fn has the following properties:

  • It always prepends a new dimension as the batch dimension.

  • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.

  • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values cannot be converted into Tensors). The same holds for lists, tuples, namedtuples, etc.

Users may use a customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.

If you run into a situation where the outputs of DataLoader have dimensions or types that are different from your expectation, you may want to check your collate_fn.
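
The properties above can be checked by calling default_collate directly on a small list of samples. The image size and the dictionary keys below are arbitrary choices for illustration.

```python
import numpy as np
import torch
from torch.utils.data import default_collate

# Four hypothetical samples: a 3-channel 8x8 "image" (NumPy array) and an integer label.
samples = [(np.zeros((3, 8, 8), dtype=np.float32), label) for label in range(4)]

images, labels = default_collate(samples)
print(images.shape)     # torch.Size([4, 3, 8, 8]): a new leading batch dimension
print(labels.tolist())  # [0, 1, 2, 3]

# Dictionaries keep their keys; values that cannot become Tensors stay as lists.
dict_batch = default_collate([{"x": 1.0, "name": "a"}, {"x": 2.0, "name": "b"}])
print(dict_batch["x"].tolist(), dict_batch["name"])  # [1.0, 2.0] ['a', 'b']
```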

Single- and Multi-process Data Loading#

A DataLoader uses single-process data loading by default.

Within a Python process, the Global Interpreter Lock (GIL) prevents truly parallelizing Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument num_workers to a positive integer.

Single-process data loading (default)#

In this mode, data fetching is done in the same process in which a DataLoader is initialized. Therefore, data loading may block computing. However, this mode may be preferred when the resources used for sharing data among processes (e.g., shared memory, file descriptors) are limited, or when the entire dataset is small and can be loaded entirely in memory. Additionally, single-process loading often shows more readable error traces and thus is useful for debugging.

Multi-process data loading#

Setting the argument num_workers to a positive integer will turn on multi-process data loading with the specified number of loader worker processes.

Warning

After several iterations, the loader worker processes will consume the same amount of CPU memory as the parent process for all Python objects in the parent process which are accessed from the worker processes. This can be problematic if the Dataset contains a lot of data (e.g., you are loading a very large list of filenames at Dataset construction time) and/or you are using a lot of workers (overall memory usage is number of workers * size of parent process). The simplest workaround is to replace Python objects with non-refcounted representations such as Pandas, NumPy or PyArrow objects. Check out issue #13246 for more details on why this occurs and example code for how to work around these problems.

In this mode, each time an iterator of a DataLoader is created (e.g., when you call enumerate(dataloader)), num_workers worker processes are created. At this point, the dataset, collate_fn, and worker_init_fn are passed to each worker, where they are used to initialize and fetch data. This means that dataset access together with its internal IO and transforms (including collate_fn) runs in the worker process.

torch.utils.data.get_worker_info() returns various useful information in a worker process (including the worker id, dataset replica, initial seed, etc.), and returns None in the main process. Users may use this function in dataset code and/or worker_init_fn to individually configure each dataset replica, and to determine whether the code is running in a worker process. For example, this can be particularly helpful in sharding the dataset.

For map-style datasets, the main process generates the indices using sampler and sends them to the workers. So any shuffle randomization is done in the main process, which guides loading by assigning indices to load.

For iterable-style datasets, since each worker process gets a replica of the dataset object, naive multi-process loading will often result in duplicated data. Using torch.utils.data.get_worker_info() and/or worker_init_fn, users may configure each replica independently. (See the IterableDataset documentation for how to achieve this.) For similar reasons, in multi-process loading, the drop_last argument drops the last non-full batch of each worker’s iterable-style dataset replica.

Workers are shut down once the end of the iteration is reached, or when the iterator is garbage collected.

Warning

It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing). Instead, we recommend using automatic memory pinning (i.e., setting pin_memory=True), which enables fast data transfer to CUDA-enabled GPUs.

Platform-specific behaviors#

Since workers rely on Python multiprocessing, worker launch behavior is different on Windows compared to Unix.

  • On Unix, fork() is the default multiprocessing start method. Using fork(), child workers typically can access the dataset and Python argument functions directly through the cloned address space.

  • On Windows or macOS, spawn() is the default multiprocessing start method. Using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.

This separate serialization means that you should take two steps to ensure you are compatible with Windows while using multi-process data loading:

  • Wrap most of your main script’s code within an if __name__ == '__main__': block, to make sure it doesn’t run again (most likely generating an error) when each worker process is launched. You can place your dataset and DataLoader instance creation logic here, as it doesn’t need to be re-executed in workers.

  • Make sure that any custom collate_fn, worker_init_fn or dataset code is declared as top-level definitions, outside of the __main__ check. This ensures that they are available in worker processes. (This is needed since functions are pickled as references only, not bytecode.)
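
The two steps above can be sketched as a script skeleton. RangeDataset, collate_sum and run are hypothetical names chosen for this example. The collate function returns a plain Python int so that no tensors need to be shared between processes.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Hypothetical dataset, defined at module top level so workers can import it."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor(idx)


def collate_sum(batch):
    """Top-level collate function (picklable by reference): sum of the batch as an int."""
    return int(torch.stack(batch).sum())


def run(num_workers: int) -> list:
    loader = DataLoader(
        RangeDataset(8), batch_size=4, num_workers=num_workers, collate_fn=collate_sum
    )
    return list(loader)


if __name__ == "__main__":
    # Loader creation and iteration stay inside the guard, so a spawned worker
    # that re-imports this file does not execute them again.
    print(run(num_workers=2))  # [6, 22]
```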

Randomness in multi-process data loading#

By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by the main process using its RNG (thereby consuming an RNG state mandatorily) or a specified generator. However, seeds for other libraries may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See this section in the FAQ.)

In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.
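
One way to seed other libraries from the per-worker PyTorch seed is sketched below. The function name seed_worker and the choice of NumPy and the standard random module are assumptions for this example. The modulo is needed because NumPy's legacy seeding only accepts 32-bit values.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id: int) -> None:
    """Seed NumPy and `random` from the PyTorch seed of the current process."""
    # Inside a worker, torch.initial_seed() equals base_seed + worker_id.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# A dedicated generator makes base_seed (and the shuffle order) reproducible.
generator = torch.Generator()
generator.manual_seed(0)

loader = DataLoader(
    TensorDataset(torch.arange(8)),
    batch_size=2,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=generator,
)
```

Workers are only started when the loader is iterated, so constructing it as above does not launch any processes.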

Memory Pinning#

Host to GPU copies are much faster when they originate from pinned (page-locked) memory. See Use pinned memory buffers for more details on when and how to use pinned memory generally.

For data loading, passing pin_memory=True to a DataLoader will automatically put the fetched data Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled GPUs.

The default memory pinning logic only recognizes Tensors and maps and iterables containing Tensors. By default, if the pinning logic sees a batch that is a custom type (which will occur if you have a collate_fn that returns a custom batch type), or if each element of your batch is a custom type, the pinning logic will not recognize them, and it will return that batch (or those elements) without pinning the memory. To enable memory pinning for custom batch or data type(s), define a pin_memory() method on your custom type(s).

See the example below.

Example:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class SimpleCustomBatch:
        def __init__(self, data):
            transposed_data = list(zip(*data))
            self.inp = torch.stack(transposed_data[0], 0)
            self.tgt = torch.stack(transposed_data[1], 0)

        # custom memory pinning method on custom type
        def pin_memory(self):
            self.inp = self.inp.pin_memory()
            self.tgt = self.tgt.pin_memory()
            return self

    def collate_wrapper(batch):
        return SimpleCustomBatch(batch)

    inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
    tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
    dataset = TensorDataset(inps, tgts)

    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                        pin_memory=True)

    for batch_ndx, sample in enumerate(loader):
        print(sample.inp.is_pinned())
        print(sample.tgt.is_pinned())

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)[source]#

Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.

The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

See the torch.utils.data documentation page for more details.

Parameters
  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

  • sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.

  • batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

  • worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

  • multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)

  • generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and by multiprocessing to generate base_seed for workers. (default: None)

  • prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (The default value depends on the value set for num_workers. If num_workers=0, the default is None. Otherwise, if num_workers > 0, the default is 2.)

  • persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers’ Dataset instances alive. (default: False)

  • pin_memory_device (str, optional) – Deprecated, the current accelerator will be used as the device if pin_memory=True.

  • in_order (bool, optional) – If False, the data loader will not enforce that batches are returned in a first-in, first-out order. Only applies when num_workers > 0. (default: True)
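
The effect of drop_last on the batch sizes and on len(dataloader) can be illustrated with a ten-sample map-style dataset and batch_size=3 (arbitrary values chosen for this sketch):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))

keep_loader = DataLoader(dataset, batch_size=3, drop_last=False)
drop_loader = DataLoader(dataset, batch_size=3, drop_last=True)

keep_sizes = [len(batch[0]) for batch in keep_loader]
drop_sizes = [len(batch[0]) for batch in drop_loader]

print(len(keep_loader), keep_sizes)  # 4 [3, 3, 3, 1]
print(len(drop_loader), drop_sizes)  # 3 [3, 3, 3]
```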

Warning

If the spawn start method is used, worker_init_fn cannot be an unpicklable object, e.g., a lambda function. See Multiprocessing best practices for more details related to multiprocessing in PyTorch.

Warning

The len(dataloader) heuristic is based on the length of the sampler used. When dataset is an IterableDataset, it instead returns an estimate based on len(dataset) / batch_size, with proper rounding depending on drop_last, regardless of multi-process loading configurations. This represents the best guess PyTorch can make because PyTorch trusts user dataset code to correctly handle multi-process loading to avoid duplicate data.

However, if sharding results in multiple workers having incomplete last batches, this estimate can still be inaccurate, because (1) an otherwise complete batch can be broken into multiple ones and (2) more than one batch worth of samples can be dropped when drop_last is set. Unfortunately, PyTorch cannot detect such cases in general.

See Dataset Types for more details on these two types of datasets and how IterableDataset interacts with Multi-process data loading.

Warning

Setting in_order to False can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data.

class torch.utils.data.Dataset[source]#

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally override __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__() to speed up batched sample loading. This method accepts a list of indices of the samples in a batch and returns a list of samples.

Note

DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
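
A minimal sketch of the non-integral case, assuming string keys. KeyedDataset is a hypothetical class, and the list of keys is passed directly as the sampler, relying on the rule that sampler can be any iterable with __len__ implemented.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class KeyedDataset(Dataset):
    """Hypothetical map-style dataset addressed by string keys."""

    def __init__(self, table: dict) -> None:
        self.table = table

    def __len__(self) -> int:
        return len(self.table)

    def __getitem__(self, key: str) -> int:
        return self.table[key]


table = {"a": 1, "b": 2, "c": 3}
dataset = KeyedDataset(table)

# The default index sampler would yield 0, 1, 2, which are not valid keys here,
# so the keys themselves are supplied as the sampler.
loader = DataLoader(dataset, sampler=list(table), batch_size=2)
batches = [batch.tolist() for batch in loader]
print(batches)  # [[1, 2], [3]]
```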

class torch.utils.data.IterableDataset[source]#

An iterable Dataset.

All datasets that represent an iterable of data samples should subclass it. Such a form of dataset is particularly useful when data come from a stream.

All subclasses should override __iter__(), which would return an iterator of samples in this dataset.

When a subclass is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset’s __iter__() method or the DataLoader’s worker_init_fn option to modify each copy’s behavior.

Example 1: splitting workload across all workers in __iter__():

>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers, each of the first four workers fetches one item
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

Example 2: splitting workload across all workers using worker_init_fn:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

class torch.utils.data.TensorDataset(*tensors)[source]#

Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

Parameters

*tensors (Tensor) – tensors that have the same size of the first dimension.
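
A short usage sketch with arbitrary example tensors, showing that indexing returns one slice per wrapped tensor:

```python
import torch
from torch.utils.data import TensorDataset

features = torch.arange(12, dtype=torch.float32).view(4, 3)
labels = torch.tensor([0, 1, 0, 1])

# Both tensors have 4 entries along the first dimension.
dataset = TensorDataset(features, labels)

x, y = dataset[2]
print(len(dataset), x.tolist(), y.item())  # 4 [6.0, 7.0, 8.0] 0
```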

class torch.utils.data.StackDataset(*args, **kwargs)[source]#

Dataset as a stacking of multiple datasets.

This class is useful to assemble different parts of complex input data, given as datasets.

Example

>>> images = ImageDataset()
>>> texts = TextDataset()
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {"image": images[0], "text": texts[0]}

Parameters
  • *args (Dataset) – Datasets for stacking returned as tuple.

  • **kwargs (Dataset) – Datasets for stacking returned as dict.

class torch.utils.data.ConcatDataset(datasets)[source]#

Dataset as a concatenation of multiple datasets.

This class is useful to assemble different existing datasets.

Parameters

datasets (sequence) – List of datasets to be concatenated
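
A short usage sketch with two small example datasets, showing that indices continue from the first dataset into the second:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

first = TensorDataset(torch.tensor([0, 1, 2]))
second = TensorDataset(torch.tensor([10, 11]))

combined = ConcatDataset([first, second])

# Index 3 is past the end of `first`, so it maps to second[0].
print(len(combined), combined[3][0].item())  # 5 10
```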

class torch.utils.data.ChainDataset(datasets)[source]#

Dataset for chaining multiple IterableDatasets.

This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient.

Parameters

datasets (iterable of IterableDataset) – datasets to be chained together

class torch.utils.data.Subset(dataset, indices)[source]#

Subset of a dataset at specified indices.

Parameters
  • dataset (Dataset) – The whole Dataset

  • indices (sequence) – Indices in the whole set selected for subset
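
A short usage sketch with arbitrary example values. The subset is addressed with its own positions 0..len-1, each of which maps to the corresponding entry of indices:

```python
import torch
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.tensor([10, 20, 30, 40, 50]))

# Position k of the subset refers to dataset[indices[k]].
subset = Subset(dataset, indices=[4, 0, 2])
values = [subset[i][0].item() for i in range(len(subset))]
print(values)  # [50, 10, 30]
```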

torch.utils.data._utils.collate.collate(batch,*,collate_fn_map=None)[source]#

General collate function that handles collection type of element within each batch.

The function also opens a function registry to deal with specific element types. default_collate_fn_map provides default collate functions for tensors, NumPy arrays, numbers and strings.

Parameters
  • batch – a single batch to be collated

  • collate_fn_map (Optional[dict[Union[type, tuple[type, ...]], Callable]]) – Optional dictionary mapping from element type to the corresponding collate function. If the element type isn’t present in this dictionary, this function will go through each key of the dictionary in insertion order to invoke the corresponding collate function if the element type is a subclass of the key.

Examples

>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle batch of tensors
...     return torch.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {torch.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

Note

Each collate function requires a positional argument for batch and a keyword argument for the dictionary of collate functions as collate_fn_map.

torch.utils.data.default_collate(batch)[source]#

Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.

The exact output type can be a torch.Tensor, a Sequence of torch.Tensor, a Collection of torch.Tensor, or left unchanged, depending on the input type. This is used as the default function for collation when batch_size or batch_sampler is defined in DataLoader.

Here is the general input type (based on the type of the element within the batch) to output type mapping:

  • torch.Tensor -> torch.Tensor (with an added outer dimension batch size)

  • NumPy Arrays -> torch.Tensor

  • float -> torch.Tensor

  • int -> torch.Tensor

  • str -> str (unchanged)

  • bytes -> bytes (unchanged)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

Parameters

batch – a single batch to be collated

Examples

>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(["a", "b", "c"])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> from collections import namedtuple
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # Two options to extend `default_collate` to handle specific type
>>> # Option 1: Write custom collate function and invoke `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # Some custom condition
...         return ...
...     else:  # Fall back to `default_collate`
...         return default_collate(batch)
>>> # Option 2: In-place modify `default_collate_fn_map`
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> # `dict.update` takes a mapping, so the new entry is passed as a dict
>>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
>>> default_collate(batch)  # Handle `CustomType` automatically

torch.utils.data.default_convert(data)[source]#

Convert each NumPy array element into a torch.Tensor.

If the input is a Sequence, Collection, or Mapping, it tries to convert each element inside to a torch.Tensor. If the input is not a NumPy array, it is left unchanged. This is used as the default function for collation when both batch_sampler and batch_size are NOT defined in DataLoader.

The general input type to output type mapping is similar to that of default_collate(). See the description there for more details.

Parameters

data – a single data point to be converted

Examples

>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]

torch.utils.data.get_worker_info()[source]#

Returns the information about the current DataLoader iterator worker process.

When called in a worker, this returns an object guaranteed to have the following attributes:

  • id: the current worker id.

  • num_workers: the total number of workers.

  • seed: the random seed set for the current worker. This value is determined by main process RNG and the worker id. See DataLoader’s documentation for more details.

  • dataset: the copy of the dataset object in this process. Note that this will be a different object in a different process than the one in the main process.

When called in the main process, this returns None.

Note

When used in a worker_init_fn passed to DataLoader, this method can be useful to set up each worker process differently, for instance, using worker_id to configure the dataset object to only read a specific fraction of a sharded dataset, or using seed to seed other libraries used in dataset code.

Return type

Optional[WorkerInfo]
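One common way to use this information is inside an iterable-style dataset, so that each worker reads only its own slice. The following is a minimal sketch under that assumption; `RangeStream` is a hypothetical dataset defined only for this example.

```python
import math

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class RangeStream(IterableDataset):  # hypothetical example dataset
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Main process (num_workers=0): iterate over the full range.
            lo, hi = self.start, self.end
        else:
            # Worker process: take a contiguous slice based on the worker id.
            per_worker = math.ceil((self.end - self.start) / info.num_workers)
            lo = self.start + info.id * per_worker
            hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))


# Single-process loading: get_worker_info() returns None inside __iter__.
single = list(DataLoader(RangeStream(0, 7), batch_size=None, num_workers=0))
```

With `num_workers` set to a positive value, each worker would take the second branch and yield a disjoint slice, so no sample is duplicated across workers. The order in which samples from different workers arrive is not the original range order.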

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

Randomly split a dataset into non-overlapping new datasets of given lengths.

If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.

After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.

Optionally fix the generator for reproducible results, e.g.:

Example

>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
Parameters
  • dataset (Dataset) – Dataset to be split

  • lengths (sequence) – lengths or fractions of splits to be produced

  • generator (Generator) – Generator used for the random permutation.

Return type

list[torch.utils.data.dataset.Subset[~_T]]
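The length rule for fractional splits can be worked through by hand. The sketch below uses a hypothetical helper, `split_lengths`, which mirrors the floor-then-round-robin rule as described above. It is not the library's own code, and it is compared against the sizes that `random_split` actually returns.

```python
import math

import torch
from torch.utils.data import random_split


def split_lengths(n, fractions):
    """Hypothetical helper: floor each share, then hand out the remainder."""
    lengths = [math.floor(frac * n) for frac in fractions]
    remainder = n - sum(lengths)
    for i in range(remainder):
        # One extra item per split, in round-robin order.
        lengths[i % len(lengths)] += 1
    return lengths


# 11 items split 50/50: floor gives [5, 5], the leftover item goes to the first split.
expected = split_lengths(11, [0.5, 0.5])

generator = torch.Generator().manual_seed(42)
parts = random_split(range(11), [0.5, 0.5], generator=generator)
actual = [len(p) for p in parts]

# The splits do not overlap and together cover the whole dataset.
all_items = sorted(item for part in parts for item in part)
```

Here both `expected` and `actual` come out as `[6, 5]`.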

class torch.utils.data.Sampler(data_source=None)

Base class for all Samplers.

Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.

Parameters

data_source (Dataset) – This argument is not used and will be removed in 2.2.0. You may still have a custom implementation that utilizes it.

Example

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

Note

The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.

class torch.utils.data.SequentialSampler(data_source)

Samples elements sequentially, always in the same order.

Parameters

data_source (Dataset) – dataset to sample from

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)

Samples elements randomly. If sampling without replacement, samples are drawn from a shuffled dataset.

If sampling with replacement, the user can specify num_samples, the number of samples to draw.

Parameters
  • data_source (Dataset) – dataset to sample from

  • replacement (bool) – samples are drawn on-demand with replacement if True, default=False

  • num_samples (int) – number of samples to draw, default=`len(dataset)`.

  • generator (Generator) – Generator used in sampling.
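A short sketch of the two modes follows, using a hypothetical five-element list as the data source and a seeded generator so that runs are repeatable. The concrete indices drawn depend on the seed and are not shown.

```python
import torch
from torch.utils.data import RandomSampler

data = list(range(5))  # hypothetical toy data source

# Without replacement: one random permutation of all indices.
perm = list(RandomSampler(data, generator=torch.Generator().manual_seed(0)))

# With replacement: draw num_samples indices, repeats allowed.
draws = list(
    RandomSampler(
        data,
        replacement=True,
        num_samples=8,
        generator=torch.Generator().manual_seed(0),
    )
)
```

In the first case every index appears exactly once. In the second, eight indices are drawn from a data source of five, so at least some must repeat.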

class torch.utils.data.SubsetRandomSampler(indices, generator=None)

Samples elements randomly from a given list of indices, without replacement.

Parameters
  • indices (sequence) – a sequence of indices

  • generator (Generator) – Generator used in sampling.
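One possible use is a train/validation split over a single dataset object. The sketch below assumes a hypothetical ten-element toy dataset and a fixed index split chosen only for illustration.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

dataset = list(range(10))  # hypothetical toy dataset
train_idx, val_idx = list(range(8)), [8, 9]

# Each loader visits only its own indices, in a random order.
train_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(val_idx))

train_seen = sorted(int(x) for batch in train_loader for x in batch)
val_seen = sorted(int(x) for batch in val_loader for x in batch)
```

Since a `sampler` is given, `shuffle` is left at its default of `False`; the randomness comes from the sampler itself.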

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

Samples elements from[0,..,len(weights)-1] with given probabilities (weights).

Parameters
  • weights (sequence) – a sequence of weights, not necessarily summing up to one

  • num_samples (int) – number of samples to draw

  • replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that once a sample index is drawn, it cannot be drawn again.

  • generator (Generator) – Generator used in sampling.

Example

>>> list(
...     WeightedRandomSampler(
...         [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
...     )
... )
[4, 4, 1, 4, 5]
>>> list(
...     WeightedRandomSampler(
...         [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
...     )
... )
[0, 1, 4, 3, 2]
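A typical application is rebalancing an imbalanced dataset by weighting each sample with the inverse frequency of its class. The sketch below assumes a hypothetical label list with a 90/10 class split. The weighting scheme is one reasonable choice, not something the sampler prescribes.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

labels = [0] * 90 + [1] * 10  # hypothetical imbalanced labels
counts = Counter(labels)

# One weight per sample: rarer classes get proportionally larger weights.
weights = [1.0 / counts[y] for y in labels]

sampler = WeightedRandomSampler(
    weights,
    num_samples=len(labels),
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)
drawn = [labels[i] for i in sampler]
```

With these weights each class carries the same total weight, so the drawn labels are roughly balanced in expectation even though the minority class has only ten distinct samples.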
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

Wraps another sampler to yield a mini-batch of indices.

Parameters
  • sampler (Sampler or Iterable) – Base sampler. Can be any iterable object

  • batch_size (int) – Size of mini-batch.

  • drop_last (bool) – If True, the sampler will drop the last batch if its size would be less than batch_size

Example

>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
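To connect this to data loading, the following sketch passes a BatchSampler to a DataLoader through the `batch_sampler` argument. The ten-element list is a hypothetical toy dataset.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

dataset = list(range(10))  # hypothetical toy dataset
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# batch_sampler is mutually exclusive with batch_size, shuffle, sampler, and drop_last,
# so those arguments are left at their defaults here.
loader = DataLoader(dataset, batch_sampler=batch_sampler)
batches = [batch.tolist() for batch in loader]
```

Each list of indices yielded by the batch sampler is fetched and collated into one batch, so the loader yields three batches of sizes 4, 4, and 2.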
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)

Sampler that restricts data loading to a subset of the dataset.

It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.

Note

The dataset is assumed to be of constant size, and any instance of it is assumed to always return the same elements in the same order.

Parameters
  • dataset (Dataset) – Dataset used for sampling.

  • num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.

  • rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

  • shuffle (bool, optional) – If True (default), sampler will shuffle the indices.

  • seed (int, optional) – random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.

  • drop_last (bool, optional) – if True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.

Warning

In distributed mode, calling the set_epoch() method at the beginning of each epoch, before creating the DataLoader iterator, is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.

Example:

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)
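The effect of `drop_last` and `set_epoch()` can be inspected without launching multiple processes by passing `num_replicas` and `rank` explicitly. This is a minimal sketch with a hypothetical five-element dataset and two simulated replicas; `shard` is a helper name introduced only for this example.

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(5))  # hypothetical toy dataset


def shard(rank, drop_last):
    # Explicit num_replicas and rank avoid needing an initialized process group.
    sampler = DistributedSampler(
        dataset, num_replicas=2, rank=rank, shuffle=False, drop_last=drop_last
    )
    return list(sampler)


# drop_last=False: indices are padded (index 0 is repeated) so both ranks get 3 samples.
padded = [shard(0, False), shard(1, False)]

# drop_last=True: the tail is dropped so both ranks get 2 samples.
trimmed = [shard(0, True), shard(1, True)]

# With shuffling, the ordering is a function of seed and epoch, so it is reproducible.
shuffled = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)
shuffled.set_epoch(3)
first_pass = list(shuffled)
shuffled.set_epoch(3)
second_pass = list(shuffled)
```

In this sketch `padded` is `[[0, 2, 4], [1, 3, 0]]` and `trimmed` is `[[0, 2], [1, 3]]`. The two shuffled passes match because the epoch is the same, which is why the epoch must be advanced to obtain a new ordering.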