Introduction

This notebook trains and evaluates a PyTorch CNN model on the task of classifying images from the CIFAR-10 dataset (https://www.cs.toronto.edu/~kriz/cifar.html). I made it with the intent of learning PyTorch, and it is based on the "Training a Classifier" tutorial from the PyTorch website (https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html). The source code can be found at https://github.com/CarlFredriksson/image_classification_using_pytorch.

Let us start by importing some modules.

In [5]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

Prepare Data

The output of the torchvision datasets consists of PILImage images with pixel values in the range [0, 1]. We need to transform the images to PyTorch tensors and normalize their pixel values to the range [-1, 1]. The RandomHorizontalFlip transform is a form of data augmentation that effectively increases the size of the training set by randomly flipping images horizontally with a given probability (0.5 by default) every time they are loaded. Increasing the size of the training set generally helps reduce overfitting, and adding RandomHorizontalFlip increased the final test accuracy. However, I tried a couple of the other data augmentation transforms and did not see an improvement in performance.

In [8]:
# Prepare training data
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

# Prepare test data
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print('Number of training examples:', len(train_set))
print('Number of test examples:', len(test_set))
Files already downloaded and verified
Files already downloaded and verified
Number of training examples: 50000
Number of test examples: 10000

Display Data

Let us display some images from the training set to see what we are working with.

In [9]:
def imshow(img):
    img = img / 2 + 0.5 # Un-normalize
    np_img = img.numpy()
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.show()
In [10]:
# Show some random training images
NUM_IMAGES_TO_SHOW = 5
data_iter = iter(train_loader)
images, labels = next(data_iter)  # data_iter.next() in older PyTorch versions
imshow(torchvision.utils.make_grid(images[:NUM_IMAGES_TO_SHOW]))
print('Labels:', ' '.join(f'{classes[labels[i]]}' for i in range(NUM_IMAGES_TO_SHOW)))
Labels: car cat dog truck deer

Create Model

I tried several simple CNN models and ended up with the one below. It does not achieve state-of-the-art performance by any means, but it is decent. The model uses the common pattern of convolutional layers (plus ReLU activations) that increase the number of channels but do not change the spatial dimensions (height and width), followed by max pooling layers for dimensionality reduction (downsampling). A famous architecture that uses this pattern is VGG16, introduced in (Simonyan & Zisserman, 2014, https://arxiv.org/abs/1409.1556). However, VGG16 is much deeper and has multiple convolutional layers per pooling layer. At the time of writing this notebook there is no option to specify padding='SAME' as in TensorFlow, which means that we have to do the padding calculations ourselves. Let $H_{in}, W_{in}$ be the height and width of the input and $H_{out}, W_{out}$ be the height and width of the output. Then

$$ H_{out} = \left\lfloor \frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1 \right\rfloor \\ W_{out} = \left\lfloor \frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1 \right\rfloor $$

When padding[0] = padding[1], kernel_size[0] = kernel_size[1], dilation[0] = dilation[1] = 1 (default), and stride[0] = stride[1] = 1 (default), the equations can be reduced to

$$ H_{out} = H_{in} + 2 \times \text{padding} - \text{kernel\_size} + 1 \\ W_{out} = W_{in} + 2 \times \text{padding} - \text{kernel\_size} + 1 $$
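To sanity-check these calculations, the general formula can be expressed as a small helper function (conv_output_size is my own name for this illustration, not part of PyTorch):

```python
def conv_output_size(size_in, kernel_size, padding=0, stride=1, dilation=1):
    """Output height/width of a conv or pooling layer along one dimension."""
    return (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# kernel_size=5 with padding=2 preserves the spatial size ("same"-style padding)
print(conv_output_size(32, kernel_size=5, padding=2))  # 32
# A 2x2 max pooling layer with stride 2 halves the size
print(conv_output_size(32, kernel_size=2, stride=2))   # 16
```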

A dropout layer is added between the first and second fully connected layers to reduce overfitting. Note that there is no need to add a softmax layer after the last fully connected layer: torch.nn.CrossEntropyLoss() expects raw logits and applies log-softmax internally during training, and softmax is not needed during inference because argmax is unaffected by applying it. If CUDA is available, the network is moved to the GPU by net.to(device). This can radically reduce training time, depending on what GPU you have available. I trained the model for 100 epochs using a batch size of 128 on both an "Intel Core i7-4790S" CPU and an "NVIDIA GeForce RTX 2070 SUPER" GPU. Training on the CPU takes about 2.5 hours, while training on the GPU takes about 20 minutes.
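As a standalone illustration of why softmax can be skipped at inference time: softmax is monotonic, so it never changes which class has the highest score.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.2, -0.3, 2.5, 0.1])  # raw outputs of the last layer
probs = F.softmax(logits, dim=0)              # probabilities summing to 1
# argmax is identical whether or not softmax is applied
print(torch.argmax(logits).item(), torch.argmax(probs).item())  # 2 2
```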

In [36]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5, padding=2)
        self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(p=0.2)
        self.fc1 = nn.Linear(64 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x size: N x 3 x 32 x 32
        x = F.relu(self.conv1(x))  # N x 16 x 32 x 32
        x = self.pool(x) # N x 16 x 16 x 16
        x = F.relu(self.conv2(x)) # N x 32 x 16 x 16
        x = self.pool(x) # N x 32 x 8 x 8
        x = F.relu(self.conv3(x)) # N x 64 x 8 x 8
        x = self.pool(x) # N x 64 x 4 x 4
        x = x.view(-1, 64 * 4 * 4) # N x 1024
        x = F.relu(self.fc1(x)) # N x 128
        x = self.dropout(x)
        x = F.relu(self.fc2(x)) # N x 64
        x = self.fc3(x) # N x 10
        return x
In [37]:
# Create model
net = Net()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net.to(device)
print('Device:', device)
Device: cuda:0
In [38]:
# Create optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

Train Model

Now it is time to train the model. One thing to note is that we need to move the inputs and labels to the same device the network is running on by calling to(device). Another thing to remember is to zero the gradients each iteration by calling optimizer.zero_grad(). This is necessary because loss.backward() accumulates (sums) gradients.
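The accumulation behavior can be demonstrated with a toy scalar parameter (a standalone sketch, unrelated to the model above):

```python
import torch

w = torch.tensor([3.0], requires_grad=True)
(2 * w).sum().backward()
print(w.grad)  # tensor([2.]) -- d(2w)/dw = 2

# Calling backward again WITHOUT zeroing accumulates into w.grad
(2 * w).sum().backward()
print(w.grad)  # tensor([4.]) -- 2 + 2, not 2

w.grad.zero_()  # what optimizer.zero_grad() does for every parameter
(2 * w).sum().backward()
print(w.grad)  # tensor([2.]) again
```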

In [39]:
def compute_test_accuracy():
    num_total = len(test_set)
    num_correct = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            predictions = torch.argmax(outputs, 1)  # no .data needed inside no_grad
            num_correct += (predictions == labels).sum().item()
    return num_correct / num_total
In [40]:
print('Number of iterations per epoch:', len(train_loader))
Number of iterations per epoch: 391
In [41]:
# Train model
NUM_EPOCHS = 100
loss_per_epoch = np.zeros(NUM_EPOCHS)
accuracy_per_epoch = np.zeros(NUM_EPOCHS)
test_accuracy_per_epoch = np.zeros(NUM_EPOCHS)
print('Starting training')
start_time = time.perf_counter()
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        inputs, labels = data[0].to(device), data[1].to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = net(inputs)
        predictions = torch.argmax(outputs.detach(), 1)  # detach is preferred over .data
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_per_epoch[epoch] += loss.item() / len(train_loader)
        accuracy_per_epoch[epoch] += (predictions == labels).sum().item() / len(train_set)
    # Compute test accuracy
    net.train(False)
    test_accuracy_per_epoch[epoch] = compute_test_accuracy()
    net.train(True)

    # Print statistics
    print('Epoch [{}/{}] loss: {:.5f} accuracy: {:.2f}% test accuracy: {:.2f}%'.format(
        epoch, NUM_EPOCHS - 1, loss_per_epoch[epoch],
        100 * accuracy_per_epoch[epoch],
        100 * test_accuracy_per_epoch[epoch]))
elapsed_time = time.perf_counter() - start_time
print(f'Finished training - elapsed_time: {elapsed_time:.0f} sec')
Starting training
Epoch [0/99] loss: 2.30430 accuracy: 10.12% test accuracy: 10.10%
Epoch [1/99] loss: 2.29794 accuracy: 13.22% test accuracy: 16.57%
Epoch [2/99] loss: 2.27545 accuracy: 17.33% test accuracy: 19.37%
Epoch [3/99] loss: 2.11230 accuracy: 21.62% test accuracy: 26.13%
Epoch [4/99] loss: 1.97091 accuracy: 26.78% test accuracy: 31.26%
Epoch [5/99] loss: 1.84752 accuracy: 31.80% test accuracy: 35.64%
Epoch [6/99] loss: 1.73976 accuracy: 35.57% test accuracy: 37.70%
Epoch [7/99] loss: 1.66805 accuracy: 38.19% test accuracy: 40.39%
Epoch [8/99] loss: 1.61438 accuracy: 40.21% test accuracy: 43.29%
Epoch [9/99] loss: 1.56588 accuracy: 42.24% test accuracy: 44.45%
Epoch [10/99] loss: 1.51483 accuracy: 44.48% test accuracy: 46.25%
Epoch [11/99] loss: 1.47795 accuracy: 45.98% test accuracy: 47.05%
Epoch [12/99] loss: 1.43904 accuracy: 47.42% test accuracy: 49.05%
Epoch [13/99] loss: 1.40022 accuracy: 49.15% test accuracy: 51.44%
Epoch [14/99] loss: 1.36441 accuracy: 50.41% test accuracy: 52.47%
Epoch [15/99] loss: 1.33119 accuracy: 51.75% test accuracy: 53.77%
Epoch [16/99] loss: 1.29663 accuracy: 52.92% test accuracy: 55.02%
Epoch [17/99] loss: 1.26588 accuracy: 54.31% test accuracy: 56.14%
Epoch [18/99] loss: 1.23510 accuracy: 55.41% test accuracy: 56.78%
Epoch [19/99] loss: 1.20778 accuracy: 56.48% test accuracy: 56.05%
Epoch [20/99] loss: 1.17762 accuracy: 57.75% test accuracy: 59.06%
Epoch [21/99] loss: 1.15223 accuracy: 58.52% test accuracy: 60.47%
Epoch [22/99] loss: 1.12348 accuracy: 59.71% test accuracy: 60.39%
Epoch [23/99] loss: 1.09939 accuracy: 60.67% test accuracy: 62.08%
Epoch [24/99] loss: 1.07624 accuracy: 61.51% test accuracy: 62.75%
Epoch [25/99] loss: 1.05219 accuracy: 62.61% test accuracy: 63.92%
Epoch [26/99] loss: 1.03124 accuracy: 63.22% test accuracy: 64.55%
Epoch [27/99] loss: 1.00860 accuracy: 64.12% test accuracy: 63.67%
Epoch [28/99] loss: 0.99193 accuracy: 64.81% test accuracy: 65.61%
Epoch [29/99] loss: 0.97042 accuracy: 65.38% test accuracy: 64.61%
Epoch [30/99] loss: 0.95229 accuracy: 65.95% test accuracy: 66.88%
Epoch [31/99] loss: 0.93602 accuracy: 66.72% test accuracy: 66.41%
Epoch [32/99] loss: 0.92103 accuracy: 67.40% test accuracy: 67.47%
Epoch [33/99] loss: 0.90594 accuracy: 67.76% test accuracy: 68.24%
Epoch [34/99] loss: 0.88815 accuracy: 68.53% test accuracy: 68.55%
Epoch [35/99] loss: 0.87620 accuracy: 69.00% test accuracy: 68.00%
Epoch [36/99] loss: 0.86055 accuracy: 69.56% test accuracy: 68.67%
Epoch [37/99] loss: 0.84617 accuracy: 70.08% test accuracy: 69.18%
Epoch [38/99] loss: 0.82969 accuracy: 70.77% test accuracy: 69.99%
Epoch [39/99] loss: 0.81974 accuracy: 71.15% test accuracy: 69.63%
Epoch [40/99] loss: 0.81121 accuracy: 71.24% test accuracy: 70.06%
Epoch [41/99] loss: 0.79935 accuracy: 71.70% test accuracy: 70.80%
Epoch [42/99] loss: 0.78336 accuracy: 72.28% test accuracy: 71.78%
Epoch [43/99] loss: 0.77000 accuracy: 72.87% test accuracy: 70.56%
Epoch [44/99] loss: 0.75890 accuracy: 73.36% test accuracy: 71.80%
Epoch [45/99] loss: 0.74887 accuracy: 73.53% test accuracy: 72.23%
Epoch [46/99] loss: 0.74319 accuracy: 73.91% test accuracy: 72.45%
Epoch [47/99] loss: 0.72988 accuracy: 74.33% test accuracy: 72.57%
Epoch [48/99] loss: 0.71950 accuracy: 74.51% test accuracy: 72.62%
Epoch [49/99] loss: 0.70928 accuracy: 74.81% test accuracy: 73.06%
Epoch [50/99] loss: 0.69906 accuracy: 75.47% test accuracy: 73.71%
Epoch [51/99] loss: 0.69411 accuracy: 75.68% test accuracy: 74.11%
Epoch [52/99] loss: 0.68380 accuracy: 76.05% test accuracy: 73.84%
Epoch [53/99] loss: 0.67348 accuracy: 76.31% test accuracy: 73.81%
Epoch [54/99] loss: 0.66093 accuracy: 76.85% test accuracy: 74.58%
Epoch [55/99] loss: 0.66023 accuracy: 76.81% test accuracy: 74.32%
Epoch [56/99] loss: 0.64836 accuracy: 77.06% test accuracy: 74.57%
Epoch [57/99] loss: 0.63803 accuracy: 77.63% test accuracy: 75.02%
Epoch [58/99] loss: 0.63673 accuracy: 77.64% test accuracy: 74.69%
Epoch [59/99] loss: 0.62224 accuracy: 78.22% test accuracy: 74.67%
Epoch [60/99] loss: 0.61626 accuracy: 78.29% test accuracy: 74.73%
Epoch [61/99] loss: 0.61031 accuracy: 78.68% test accuracy: 74.54%
Epoch [62/99] loss: 0.60648 accuracy: 78.73% test accuracy: 75.34%
Epoch [63/99] loss: 0.59909 accuracy: 78.98% test accuracy: 75.36%
Epoch [64/99] loss: 0.59078 accuracy: 79.31% test accuracy: 75.94%
Epoch [65/99] loss: 0.57954 accuracy: 79.71% test accuracy: 75.02%
Epoch [66/99] loss: 0.57334 accuracy: 79.96% test accuracy: 76.12%
Epoch [67/99] loss: 0.57036 accuracy: 80.06% test accuracy: 76.03%
Epoch [68/99] loss: 0.56312 accuracy: 80.20% test accuracy: 76.42%
Epoch [69/99] loss: 0.55298 accuracy: 80.56% test accuracy: 76.36%
Epoch [70/99] loss: 0.55083 accuracy: 80.67% test accuracy: 76.55%
Epoch [71/99] loss: 0.54023 accuracy: 81.18% test accuracy: 75.54%
Epoch [72/99] loss: 0.53319 accuracy: 81.13% test accuracy: 75.72%
Epoch [73/99] loss: 0.52956 accuracy: 81.46% test accuracy: 76.30%
Epoch [74/99] loss: 0.52344 accuracy: 81.64% test accuracy: 76.37%
Epoch [75/99] loss: 0.51407 accuracy: 81.90% test accuracy: 76.04%
Epoch [76/99] loss: 0.51370 accuracy: 82.02% test accuracy: 76.50%
Epoch [77/99] loss: 0.50465 accuracy: 82.10% test accuracy: 76.97%
Epoch [78/99] loss: 0.49947 accuracy: 82.42% test accuracy: 77.06%
Epoch [79/99] loss: 0.49529 accuracy: 82.56% test accuracy: 76.31%
Epoch [80/99] loss: 0.48774 accuracy: 82.76% test accuracy: 77.01%
Epoch [81/99] loss: 0.48161 accuracy: 83.02% test accuracy: 76.65%
Epoch [82/99] loss: 0.47302 accuracy: 83.29% test accuracy: 77.69%
Epoch [83/99] loss: 0.47186 accuracy: 83.17% test accuracy: 76.69%
Epoch [84/99] loss: 0.46560 accuracy: 83.57% test accuracy: 76.63%
Epoch [85/99] loss: 0.46313 accuracy: 83.68% test accuracy: 77.24%
Epoch [86/99] loss: 0.45530 accuracy: 83.82% test accuracy: 77.22%
Epoch [87/99] loss: 0.45082 accuracy: 84.07% test accuracy: 77.20%
Epoch [88/99] loss: 0.44217 accuracy: 84.40% test accuracy: 77.42%
Epoch [89/99] loss: 0.43704 accuracy: 84.33% test accuracy: 76.67%
Epoch [90/99] loss: 0.43138 accuracy: 84.76% test accuracy: 77.66%
Epoch [91/99] loss: 0.43230 accuracy: 84.64% test accuracy: 77.37%
Epoch [92/99] loss: 0.42329 accuracy: 85.02% test accuracy: 77.89%
Epoch [93/99] loss: 0.41373 accuracy: 85.36% test accuracy: 78.21%
Epoch [94/99] loss: 0.40631 accuracy: 85.63% test accuracy: 77.46%
Epoch [95/99] loss: 0.40739 accuracy: 85.55% test accuracy: 76.99%
Epoch [96/99] loss: 0.40414 accuracy: 85.56% test accuracy: 77.78%
Epoch [97/99] loss: 0.39867 accuracy: 85.80% test accuracy: 77.66%
Epoch [98/99] loss: 0.39524 accuracy: 85.97% test accuracy: 77.50%
Epoch [99/99] loss: 0.38791 accuracy: 86.26% test accuracy: 77.70%
Finished training - elapsed_time: 1246 sec
In [42]:
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].plot(np.arange(NUM_EPOCHS), loss_per_epoch)
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Cross entropy loss')
ax[1].plot(np.arange(NUM_EPOCHS), 100 * accuracy_per_epoch, label='Training accuracy')
ax[1].plot(np.arange(NUM_EPOCHS), 100 * test_accuracy_per_epoch, label='Test accuracy')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Accuracy %')
ax[1].legend()
plt.show()

Save and Load Model

You obviously do not need to load the saved model if you run the whole notebook, but I included the cell for future reference and to be able to test a saved model without having to train a new one.

In [43]:
# Save trained model
save_path = './nn_cifar10.pth'
torch.save(net.state_dict(), save_path)
In [44]:
# Load saved model
net = Net()
net.load_state_dict(torch.load(save_path))
net.to(device)
print(net.training)
True
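One thing worth knowing when reusing a checkpoint on another machine: a state_dict saved from a GPU model cannot be loaded directly on a CPU-only machine unless you pass map_location to torch.load. A small self-contained sketch, using a throwaway nn.Linear instead of the notebook's Net:

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
path = os.path.join(tempfile.gettempdir(), 'tmp_model.pth')
torch.save(model.state_dict(), path)

# map_location='cpu' remaps any stored CUDA tensors to the CPU on load
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path, map_location='cpu'))
print(torch.equal(model.weight, restored.weight))  # True
```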

Test Model

Let us test the model on a few images, compute the test accuracy, and compute the test accuracy per class. We need to make sure that the model is in evaluation mode to disable the dropout layer.
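To see the difference between the two modes concretely, here is a small standalone check of how nn.Dropout behaves (independent of the model above):

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)  # roughly half the elements zeroed, survivors scaled by 1/(1-p) = 2

drop.eval()
eval_out = drop(x)   # identity: dropout is disabled in eval mode
print(torch.equal(eval_out, x))  # True
```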

In [45]:
# Make sure that the model is in eval mode
net.eval()
print(net.training)
False
In [46]:
# Show some test images
data_iter = iter(test_loader)
images, labels = next(data_iter)  # data_iter.next() in older PyTorch versions
imshow(torchvision.utils.make_grid(images[:NUM_IMAGES_TO_SHOW]))
print('Labels:', ' '.join(f'{classes[labels[i]]}' for i in range(NUM_IMAGES_TO_SHOW)))
Labels: cat ship ship plane frog
In [47]:
# Predict classes for the random test images
outputs = net(images.to(device))
predictions = torch.argmax(outputs, 1)
print('predictions:', ' '.join(f'{classes[predictions[i]]}' for i in range(NUM_IMAGES_TO_SHOW)))
predictions: cat ship ship plane frog
In [48]:
print('Accuracy of the model on the {} test images: {}%'.format(
    len(test_set), 100 * compute_test_accuracy()))
Accuracy of the model on the 10000 test images: 77.7%
In [49]:
# Check the accuracy per class
class_correct = np.zeros(len(classes))
class_total = np.zeros(len(classes))
with torch.no_grad():
    for data in test_loader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        predictions = torch.argmax(outputs, 1)
        predictions_correct = (predictions == labels).squeeze()
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += predictions_correct[i].item()
            class_total[label] += 1
for i in range(len(classes)):
    print('Accuracy of {}: [{:.0f}/{:.0f}] {:.1f}%'.format(
        classes[i], class_correct[i], class_total[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane: [816/1000] 81.6%
Accuracy of car: [877/1000] 87.7%
Accuracy of bird: [699/1000] 69.9%
Accuracy of cat: [668/1000] 66.8%
Accuracy of deer: [748/1000] 74.8%
Accuracy of dog: [581/1000] 58.1%
Accuracy of frog: [822/1000] 82.2%
Accuracy of horse: [827/1000] 82.7%
Accuracy of ship: [872/1000] 87.2%
Accuracy of truck: [860/1000] 86.0%