Comprehensive Guide to Datasets and Dataloaders in PyTorch
The full guide to creating custom datasets and dataloaders for different models in PyTorch

Before you can build a machine learning model, you need to load your data into a dataset. Luckily, PyTorch has many utilities to help with this entire process (if you are not familiar with PyTorch, I recommend refreshing on the basics first).

PyTorch has good documentation to help with this process, but I have not found any comprehensive documentation or tutorials on custom datasets. I’m going to start with basic premade datasets and then work my way up to creating datasets from scratch for different models!

What is a Dataset and Dataloader?

Before we dive into code for different use cases, let’s understand the difference between the two terms. Generally, you first create your dataset and then create a dataloader. A dataset contains the features and labels from each data point that will be fed into the model. A dataloader is a PyTorch iterable that wraps a dataset and makes it easy to load data with added features such as batching and shuffling. The constructor looks like this (exact defaults vary between PyTorch versions):

```python
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
```

The most common arguments in the dataloader are batch_size, shuffle (usually only for the training data), num_workers (to load the data with multiple processes), and pin_memory (to put the fetched data tensors in pinned memory and enable faster data transfer to CUDA-enabled GPUs).

PyTorch’s documentation recommends relying on pin_memory=True rather than returning CUDA tensors from worker processes, because of multiprocessing complications with CUDA.

Loading a Premade Dataset

If your dataset is downloaded from an online source or stored locally, creating the dataset is extremely simple.
I think PyTorch has good documentation on this, so I will be brief.

If you know the dataset is either from PyTorch or PyTorch-compatible, simply add the necessary imports and call the dataset of choice:

```python
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

# 'path' is the directory where the dataset is stored
data = datasets.CIFAR10('path', train=True, transform=ToTensor())
```

Each dataset has its own arguments (see the torchvision datasets documentation). In general, they are the path the dataset is stored at, a boolean indicating whether it needs to be downloaded (conveniently called download), whether it is the training or testing split, and any transforms that need to be applied.

Transforms

At the end of the last section I mentioned that transforms can be applied to a dataset, but what actually is a transform?

A transform is a method of manipulating data, here used to preprocess an image. There are many different facets to transforms. The most common transform, ToTensor(), converts a PIL image or NumPy array to a tensor (the format a model expects as input). Other transforms built into torchvision (torchvision.transforms) include flipping, rotating, cropping, normalizing, and shifting images. These are typically used so the model can generalize better and doesn’t overfit to the training data. Data augmentations can also be used to artificially increase the size of the dataset if needed.

Beware that most torchvision transforms only accept Pillow images or tensors (not NumPy arrays). To convert from NumPy, either create a torch tensor or use the following:

```python
from PIL import Image

# assume arr is a numpy array
# you may need to normalize and cast arr to np.uint8 depending on format
img = Image.fromarray(arr)
```

Transforms can be chained using torchvision.transforms.Compose. You can combine as many transforms as needed for the dataset.
An example is shown below:

```python
from torchvision import transforms

dataset_transform = transforms.Compose([
    transforms.RandomResizedCrop(256),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```

Be sure to pass the composed transform to the dataset’s transform argument so that it is applied when the dataloader fetches samples.

Creating a Custom Dataset

In most cases of developing your own model, you will need a custom dataset. A common use case is transfer learning, where you apply your own dataset to a pretrained model.

There are 3 required parts to a PyTorch dataset class: initialization, length, and retrieving an element.

__init__: To initialize the dataset, pass in the raw and labeled data. The best practice is to pass in the raw image data and labeled data separately.

__len__: Return the length of the dataset. Before creating the dataset, check that the raw and labeled data are the same size.

__getitem__: This is where all the data handling occurs to return the raw and labeled data at a given index (idx). If any transforms need to be applied, the data must be converted to a tensor and transformed.
If the initialization contained a path to the dataset, the path must be opened and the data accessed and preprocessed before it can be returned.

Example dataset for a semantic segmentation model:

```python
import torch
from torch.utils.data import Dataset
from torchvision import transforms


class ExampleDataset(Dataset):
    """Example dataset"""

    def __init__(self, raw_img, data_mask, transform=None):
        self.raw_img = raw_img
        self.data_mask = data_mask
        self.transform = transform

    def __len__(self):
        return len(self.raw_img)

    def __getitem__(self, idx):
        # allow tensor indices as well as plain integers
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.raw_img[idx]
        mask = self.data_mask[idx]
        sample = {'image': image, 'mask': mask}

        # the transform must accept and return the sample dict
        if self.transform:
            sample = self.transform(sample)

        return sample
```

It is important to look at the input of the first layer of the model (especially for a pretrained model) to make sure the shape of the data matches the input shape. If not, you may need to adjust the dimensions. This is common if the input image is a greyscale n x n array but the model requires a channel dimension (for example, 1 x 256 x 256 for a 256 x 256 image).

After the dataset and dataloader are applied, the data should be in NCHW format (batch size, channels, height, width). Reformatting can be done in the __getitem__ method before the data is passed to the model.

Splitting the Dataset

While creating the dataset, you may want to split it into training, validation, and testing datasets. This can be done using a built-in PyTorch function and specifying the sizes. Make sure the split sizes add up to the total length of the dataset.

```python
from torch.utils.data import random_split

train, val, test = random_split(dataset, [train_size, val_size, test_size])
```

Data Labels

The data labels differ depending on the model: classification, object detection, or segmentation. A classification label is a class index if the problem is multiclass or a binary number if it is binary. An object detection label is a bounding box given as coordinates.
A semantic segmentation label is a mask matching the size of the raw image data (binary when there are only two classes). An instance segmentation label contains a separate mask for each object in the raw image data.

Creating a dataset is a foundational aspect of model development. A faulty dataset causes many errors downstream when training or evaluating the model. The most common errors to watch out for are shape or type mismatches. By following this guide and referring to the PyTorch docs, you should have a working dataset!

References

Datasets & DataLoaders - PyTorch Tutorials 2.3.0+cu121 documentation
Writing Custom Datasets, DataLoaders and Transforms - PyTorch Tutorials 2.3.0+cu121 documentation
Transforming and augmenting images - Torchvision 0.18 documentation
Compose - Torchvision main documentation

Comprehensive Guide to Datasets and Dataloaders in PyTorch was originally published in Towards Data Science on Medium.
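
To tie the sections above together, here is a minimal end-to-end sketch, not a definitive implementation. It assumes synthetic in-memory tensors in place of real data. The class name ToySegmentationDataset, the 32 x 32 image size, the batch size of 4, and the 14/3/3 split are all hypothetical choices made for illustration. It shows a custom dataset that adds the channel dimension in __getitem__, a split whose sizes add up to the dataset length, and a dataloader that yields NCHW batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader, random_split


class ToySegmentationDataset(Dataset):
    """Hypothetical dataset of greyscale images and binary masks held in memory."""

    def __init__(self, raw_img, data_mask):
        # Images and masks must line up one-to-one.
        assert len(raw_img) == len(data_mask), "images and masks differ in length"
        self.raw_img = raw_img
        self.data_mask = data_mask

    def __len__(self):
        return len(self.raw_img)

    def __getitem__(self, idx):
        # Add a channel dimension: (H, W) -> (1, H, W).
        image = self.raw_img[idx].unsqueeze(0).float()
        mask = self.data_mask[idx].unsqueeze(0).float()
        return image, mask


# Synthetic stand-in data (an assumption): 20 greyscale 32 x 32 images
# with random binary masks.
images = torch.rand(20, 32, 32)
masks = torch.rand(20, 32, 32) > 0.5

dataset = ToySegmentationDataset(images, masks)

# The split sizes must add up to len(dataset); the seed makes the split repeatable.
train_set, val_set, test_set = random_split(
    dataset, [14, 3, 3], generator=torch.Generator().manual_seed(0)
)

# Shuffle only the training data, as is usual.
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)

# The default collate function stacks samples along a new batch dimension.
batch_images, batch_masks = next(iter(train_loader))
print(batch_images.shape)  # batch, channel, height, width
```

Returning a tuple rather than a dict keeps the default collate behaviour easy to unpack. The dict style used in the ExampleDataset above batches the same way, with one stacked tensor per key.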