
Datasets & DataLoaders

Created On: Feb 09, 2021 | Last Updated: Sep 24, 2025 | Last Verified: Nov 05, 2024

Code for processing data samples can get messy and hard to maintain. Ideally, our dataset code should be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives, `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`, that allow you to use pre-loaded datasets as well as your own data. `Dataset` stores the samples and their corresponding labels, and `DataLoader` wraps an iterable around the `Dataset` to enable easy access to the samples.
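Here is a minimal sketch of how the two primitives fit together. The `SquaresDataset` class is hypothetical and exists only for illustration. A `Dataset` only has to answer two questions: "how many samples are there?" and "what is sample i?". `DataLoader` then turns it into an iterable of batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the feature tensor [i] with label i**2."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; DataLoader relies on this to plan its batches.
        return self.n

    def __getitem__(self, idx):
        # Return one (feature, label) pair for the given index.
        x = torch.tensor([float(idx)])
        y = idx ** 2
        return x, y

dataset = SquaresDataset(5)
loader = DataLoader(dataset, batch_size=2)  # shuffle defaults to False

for features, labels in loader:
    print(features.shape, labels.tolist())
# torch.Size([2, 1]) [0, 1]
# torch.Size([2, 1]) [4, 9]
# torch.Size([1, 1]) [16]
```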

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass `torch.utils.data.Dataset` and implement functions specific to the particular data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets.

Loading a Dataset

Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST Dataset with the following parameters:
  • `root` is the path where the train/test data is stored,

  • `train` specifies whether to load the training or the test split,

  • `download=True` downloads the data from the internet if it’s not available at `root`,

  • `transform` and `target_transform` specify the feature and label transformations.

```python
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
```
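`transform` and `target_transform` accept any callable. As a hedged sketch, a hypothetical `one_hot_label` function (not part of TorchVision) could be passed as `target_transform` to turn each integer class index into a one-hot vector:

```python
import torch

def one_hot_label(label: int, num_classes: int = 10) -> torch.Tensor:
    """Hypothetical target_transform: map a class index to a one-hot float vector."""
    vec = torch.zeros(num_classes, dtype=torch.float)
    vec[label] = 1.0  # mark the position of the true class
    return vec

print(one_hot_label(3))  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

# It would be passed like any other transform, e.g. (not run here, since it downloads data):
# datasets.FashionMNIST(root="data", train=True, download=True,
#                       transform=ToTensor(), target_transform=one_hot_label)
```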

Iterating and Visualizing the Dataset

We can index `Datasets` manually like a list: `training_data[index]`. We use `matplotlib` to visualize some samples in our training data.

```python
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    # Pick a random sample index and fetch the (image, label) pair.
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```
[Figure: a 3×3 grid of randomly sampled training images, each titled with its class label]
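The same list-like indexing works for any map-style `Dataset`. The sketch below uses `torch.utils.data.TensorDataset` with random tensors as a hypothetical stand-in for `training_data`, so it runs without downloading anything:

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical stand-in for training_data: 6 random 1x28x28 "images" and integer labels.
images = torch.rand(6, 1, 28, 28)
labels = torch.tensor([0, 1, 2, 3, 4, 5])
toy_data = TensorDataset(images, labels)

print(len(toy_data))             # 6
img, label = toy_data[2]         # same (image, label) pattern as training_data[index]
print(img.shape, label.item())   # torch.Size([1, 28, 28]) 2
print(img.squeeze().shape)       # torch.Size([28, 28]): squeeze() drops the channel dim for imshow

# Random access, as in the plotting loop above.
sample_idx = torch.randint(len(toy_data), size=(1,)).item()
img, label = toy_data[sample_idx]
```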

Creating a Custom Dataset for your files

A custom Dataset class must implement three functions: `__init__`, `__len__`, and `__getitem__`. Take a look at this implementation. The FashionMNIST images are stored in a directory `img_dir`, and their labels are stored separately in a CSV file `annotations_file`.

In the next sections, we’ll break down what’s happening in each of these functions.

```python
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import decode_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # The annotations CSV has no header row, so read every line as data.
        self.img_labels = pd.read_csv(annotations_file, header=None)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = decode_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
```

__init__

The `__init__` function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next section).

The labels.csv file has no header row and looks like:

```
tshirt1.jpg,0
tshirt2.jpg,0
......
ankleboot999.jpg,9
```

```python
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    # labels.csv has no header row, so every line is read as data.
    self.img_labels = pd.read_csv(annotations_file, header=None)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
```
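By default, `pd.read_csv` treats the first row of a file as column names. For a header-less file like labels.csv, passing `header=None` keeps that first row as a sample. The columns are then numbered 0 (file name) and 1 (label), which is what `iloc[idx, 0]` and `iloc[idx, 1]` rely on. This sketch reads the rows shown above from an in-memory string instead of a real file:

```python
import io
import pandas as pd

# The rows shown above (file name, class index), held in memory instead of on disk.
csv_text = "tshirt1.jpg,0\ntshirt2.jpg,0\nankleboot999.jpg,9\n"

# header=None: the first line is data, not column names.
img_labels = pd.read_csv(io.StringIO(csv_text), header=None)

print(len(img_labels))        # 3
print(img_labels.iloc[0, 0])  # tshirt1.jpg  (column 0: file name)
print(img_labels.iloc[2, 1])  # 9            (column 1: label)

# Without header=None, the first row is consumed as the header and one sample is lost.
print(len(pd.read_csv(io.StringIO(csv_text))))  # 2
```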

__len__

The __len__ function returns the number of samples in our dataset.

Example:

```python
def __len__(self):
    return len(self.img_labels)
```
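`len(dataset)` calls `__len__`, and a `DataLoader` built on a map-style dataset uses that count to decide how many batches an epoch contains: `ceil(len(dataset) / batch_size)`, or the floor when `drop_last=True`. The `FixedSizeDataset` class below is hypothetical and exists only to show this:

```python
import math
import torch
from torch.utils.data import Dataset, DataLoader

class FixedSizeDataset(Dataset):
    """Hypothetical dataset with n samples, used only to show what __len__ drives."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor([idx]), idx

ds = FixedSizeDataset(130)
loader = DataLoader(ds, batch_size=64)
print(len(ds))      # 130: len() calls __len__
print(len(loader))  # 3 == math.ceil(130 / 64); the last batch holds only 2 samples
print(len(DataLoader(ds, batch_size=64, drop_last=True)))  # 2: the incomplete batch is dropped
```

By the same arithmetic, a 60,000-sample training set with `batch_size=64` yields `ceil(60000 / 64) = 938` batches per epoch.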

__getitem__

The `__getitem__` function loads and returns a sample from the dataset at the given index `idx`. Based on the index, it identifies the image’s location on disk, converts that to a tensor using `decode_image`, retrieves the corresponding label from the CSV data in `self.img_labels`, calls the transform functions on them (if applicable), and returns the tensor image and corresponding label in a tuple.

```python
def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = decode_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
```
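To trace those steps without image files on disk, here is a sketch of a hypothetical `InMemoryImageDataset`. It mirrors `CustomImageDataset` but looks images up in a dict of stand-in `uint8` tensors instead of calling `decode_image`. The lambdas passed as transforms are illustrative choices, not part of the tutorial:

```python
import pandas as pd
import torch
from torch.utils.data import Dataset

class InMemoryImageDataset(Dataset):
    """Hypothetical variant of CustomImageDataset that reads images from a dict
    instead of decoding files, so the __getitem__ steps can run anywhere."""

    def __init__(self, img_labels, images, transform=None, target_transform=None):
        self.img_labels = img_labels  # DataFrame: column 0 = file name, column 1 = label
        self.images = images          # dict: file name -> uint8 image tensor (C, H, W)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        name = self.img_labels.iloc[idx, 0]   # locate the image by file name
        image = self.images[name]             # "load" it (stands in for decode_image)
        label = self.img_labels.iloc[idx, 1]  # look up its label
        if self.transform:                    # optional feature transform
            image = self.transform(image)
        if self.target_transform:             # optional label transform
            label = self.target_transform(label)
        return image, label

labels_df = pd.DataFrame([["a.png", 0], ["b.png", 9]])
images = {
    "a.png": torch.zeros(1, 28, 28, dtype=torch.uint8),
    "b.png": torch.full((1, 28, 28), 255, dtype=torch.uint8),
}
ds = InMemoryImageDataset(
    labels_df,
    images,
    transform=lambda img: img.float() / 255,  # scale uint8 pixels to [0, 1]
    target_transform=lambda y: int(y),        # numpy integer -> Python int
)
image, label = ds[1]
print(image.dtype, image.max().item(), label)  # torch.float32 1.0 9
```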

Preparing your data for training with DataLoaders

The `Dataset` retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s `multiprocessing` to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity for us in an easy API.

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
```
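To see what batching does to the samples, here is a small sketch on a hypothetical 10-sample `TensorDataset` shaped like FashionMNIST. The default collate function stacks the per-sample tensors along a new leading batch dimension, and the final batch is smaller when the dataset size is not a multiple of `batch_size`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 10 samples shaped like FashionMNIST images (1x28x28) with integer labels.
toy_data = TensorDataset(torch.rand(10, 1, 28, 28), torch.arange(10))

loader = DataLoader(toy_data, batch_size=4)
for features, labels in loader:
    # Per-sample (1, 28, 28) tensors are stacked into (batch, 1, 28, 28).
    print(tuple(features.shape), labels.tolist())
# (4, 1, 28, 28) [0, 1, 2, 3]
# (4, 1, 28, 28) [4, 5, 6, 7]
# (2, 1, 28, 28) [8, 9]
```

The multiprocessing mentioned above is enabled with the `num_workers` argument (for example `num_workers=2`), which loads batches in worker subprocesses. On platforms that start workers with `spawn`, the script's entry point then needs the usual `if __name__ == "__main__":` guard.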

Iterate through the DataLoader

We have loaded that dataset into the `DataLoader` and can iterate through the dataset as needed. Each iteration below returns a batch of `train_features` and `train_labels` (containing `batch_size=64` features and labels respectively). Because we specified `shuffle=True`, after we iterate over all batches the data is shuffled (for finer-grained control over the data loading order, take a look at Samplers).

```python
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
```
```
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 5
```
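With `shuffle=True`, `DataLoader` draws indices through a `RandomSampler`. Each full pass visits every sample exactly once, usually in a new order each epoch. Passing a seeded `torch.Generator` through the `generator` argument makes that order reproducible. The sketch below checks both properties on a hypothetical 8-sample dataset; `epoch_order` is an illustrative helper, not a PyTorch function:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

toy_data = TensorDataset(torch.arange(8))  # hypothetical: sample i is just the value i

def epoch_order(loader):
    # One full pass (epoch): collect sample values in the order the loader yields them.
    return [x for (batch,) in loader for x in batch.tolist()]

loader = DataLoader(toy_data, batch_size=3, shuffle=True)
first, second = epoch_order(loader), epoch_order(loader)
# Every epoch covers each sample exactly once, whatever the order.
print(sorted(first) == list(range(8)), sorted(second) == list(range(8)))  # True True

# Two loaders seeded identically produce the same shuffled order.
order_a = epoch_order(DataLoader(toy_data, batch_size=3, shuffle=True,
                                 generator=torch.Generator().manual_seed(0)))
order_b = epoch_order(DataLoader(toy_data, batch_size=3, shuffle=True,
                                 generator=torch.Generator().manual_seed(0)))
print(order_a == order_b)  # True
```

Note that `next(iter(train_dataloader))` in the example above fetches only the first batch; a training loop iterates over the loader as `epoch_order` does.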

Further Reading
