DataLoader and DataSets

PyTorch provides helper classes for loading, shuffling, and augmenting data. This section covers how to use them.

Data loading in PyTorch can be separated into two parts:

  • Data must be wrapped in a subclass of the Dataset parent class, in which the methods __getitem__ and __len__ must be overridden. Note that at this point the data is not loaded into memory. PyTorch will only load what is needed.

  • Use a DataLoader that will actually read the data and put it into memory.
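
The two parts can be sketched with a toy dataset. This is only an illustration, not the driverless car loader: SquaresDataset is a hypothetical name, and its items are computed on demand to mimic the way a real dataset would read one sample at a time.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i squared)."""

    def __init__(self, n):
        self.n = n  # only the size is stored; no samples are materialized here

    def __getitem__(self, index):
        # Built on demand, the way a real dataset would read one file from disk
        x = torch.tensor([float(index)])
        y = torch.tensor([float(index) ** 2])
        return x, y

    def __len__(self):
        return self.n


dataset = SquaresDataset(10)
# num_workers=0 keeps loading in the main process for this small example
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)
first_x, first_y = next(iter(loader))
print(len(dataset), first_x.shape, first_y.squeeze(1).tolist())
```

The DataLoader only calls __len__ to plan the batches and __getitem__ to fetch the samples of the batch it is currently building.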

The example shown here is going to be used to load data from our driverless car demo.

Dataset parent class

Let's create a class that inherits from the Dataset class. It provides methods to fetch individual items and to report the number of items, without loading the whole dataset into memory.

import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
FOLDER_DATASET = "./Track_1_Wheel_Test/"
plt.ion()

class DriveData(Dataset):
    def __init__(self, folder_dataset, transform=None):
        self.transform = transform
        # Per-instance lists; class-level lists would be shared by every instance
        self._xs = []
        self._ys = []
        # Read the index file: each line holds an image path and a steering label
        with open(os.path.join(folder_dataset, "data.txt")) as f:
            for line in f:
                fields = line.split()
                if len(fields) < 2:
                    continue  # skip blank or malformed lines
                # Image path
                self._xs.append(os.path.join(folder_dataset, fields[0]))
                # Steering wheel label
                self._ys.append(np.float32(fields[1]))

    # Override to give PyTorch access to any image on the dataset
    def __getitem__(self, index):
        # The image is opened only when it is requested
        img = Image.open(self._xs[index])
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Convert image and label to torch tensors
        if not isinstance(img, torch.Tensor):
            # np.array copies the pixels, giving from_numpy a writable buffer
            img = torch.from_numpy(np.array(img))
        label = torch.from_numpy(np.asarray(self._ys[index]).reshape([1, 1]))
        return img, label

    # Override to give PyTorch size of dataset
    def __len__(self):
        return len(self._xs)

Instantiating the dataset and passing it to the DataLoader

dset_train = DriveData(FOLDER_DATASET)
train_loader = DataLoader(dset_train, batch_size=10, shuffle=True, num_workers=1)

PyTorch now handles the shuffling and the loading of your data for you. When num_workers is greater than 0, the batches are loaded in parallel by worker processes.

# Get a batch of training data
imgs, steering_angle = next(iter(train_loader))
print('Batch shape:',imgs.numpy().shape)
plt.imshow(imgs.numpy()[0,:,:,:])
plt.show()
plt.imshow(imgs.numpy()[-1,:,:,:])
plt.show()

# If you want the batch on a for-loop
# for batch_idx, (data, target) in enumerate(train_loader):
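
The commented-out loop can be sketched as follows. To keep the example runnable without the driving data, it uses a TensorDataset of random tensors as a stand-in; the shapes (25 fake images of 8x8x3) are arbitrary assumptions, not the real dataset dimensions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data (an assumption for illustration): 25 fake HxWxC images and labels
images = torch.rand(25, 8, 8, 3)
angles = torch.rand(25, 1)
loader = DataLoader(TensorDataset(images, angles), batch_size=10, shuffle=True)

batch_sizes = []
for batch_idx, (data, target) in enumerate(loader):
    # data: [batch, 8, 8, 3], target: [batch, 1]
    batch_sizes.append(data.shape[0])
    print(batch_idx, tuple(data.shape), tuple(target.shape))

# One pass over the loader is one epoch; the last batch holds the remainder
# unless drop_last=True is passed to DataLoader.
print(batch_sizes)
```

With 25 samples and batch_size=10, the loop runs three times and the final batch contains the 5 leftover samples.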

Transformation

PyTorch also has a mechanism for applying simple transformations to the images, provided by the transforms module of the companion torchvision package.

References:

https://www.kaggle.com/mratsim/starting-kit-for-pytorch-deep-learning
https://github.com/pytorch/tutorials/issues/78