
Transfer Learning for Computer Vision Tutorial

Author: Sasank Chilamkurthy

In this tutorial, you will learn how to train a convolutional neural network for image classification using transfer learning. You can read more about transfer learning in the cs231n notes.

Quoting these notes,

In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.

These two major transfer learning scenarios look as follows:

  • Finetuning the convnet: Instead of random initialization, we initialize the network with a pretrained network, such as one trained on the ImageNet 1000-class dataset. The rest of the training proceeds as usual.
  • ConvNet as fixed feature extractor: Here, we freeze the weights of the whole network except those of the final fully connected layer. This last fully connected layer is replaced with a new one with random weights, and only this layer is trained.
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

cudnn.benchmark = True
plt.ion()   # interactive mode

Load Data

We will use torchvision and torch.utils.data packages for loading the data.

The problem we’re going to solve today is to train a model to classify ants and bees. We have about 120 training images each for ants and bees, and 75 validation images for each class. A dataset this small is usually not enough to generalize from when training from scratch. Since we are using transfer learning, we should be able to generalize reasonably well.

This dataset is a very small subset of ImageNet.

Note

Download the data from here and extract it to the current directory.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
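
The ImageFolder calls above expect one subdirectory per class under each split. As a rough sketch, assuming the extracted archive contains ants and bees folders under train and val (the folder names are an assumption based on the two classes named above), the snippet below builds an empty copy of that layout in a temporary directory and derives the class list by sorting the subdirectory names, which is the convention ImageFolder follows. The helper find_classes is hypothetical and only mimics that convention; it is not part of torchvision's public API as used in this tutorial.

```python
import os
import tempfile


def find_classes(split_dir):
    """Return (classes, class_to_idx) from the subfolders of split_dir.

    Hypothetical helper mirroring the ImageFolder convention:
    class names are the sorted subdirectory names.
    """
    classes = sorted(
        entry.name for entry in os.scandir(split_dir) if entry.is_dir()
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as root:
    # Assumed layout: <root>/hymenoptera_data/{train,val}/{ants,bees}/
    layout_dir = os.path.join(root, "hymenoptera_data")
    for split in ["train", "val"]:
        for label in ["bees", "ants"]:  # creation order does not matter
            os.makedirs(os.path.join(layout_dir, split, label))

    classes, class_to_idx = find_classes(os.path.join(layout_dir, "train"))

print(classes)
print(class_to_idx)
```

Because the names are sorted, the integer labels produced by the data loader index directly into class_names.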

Visualize a few images

Let’s visualize a few training images so as to understand the data augmentations.

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
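
The imshow helper has to undo transforms.Normalize before plotting. The following is a minimal NumPy sketch of that round trip, using a random array as a stand-in for a real image tensor (the array shape and the random data are assumptions for illustration): normalization computes (x - mean) / std per channel, and imshow applies std * x + mean after moving channels last.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Stand-in for a ToTensor() output: channels first, values in [0, 1)
rng = np.random.default_rng(0)
image_chw = rng.random((3, 4, 4))

# What Normalize does: per-channel (x - mean) / std
normalized = (image_chw - mean[:, None, None]) / std[:, None, None]

# What imshow does: channels last, then invert the normalization
image_hwc = normalized.transpose((1, 2, 0))
restored = np.clip(std * image_hwc + mean, 0, 1)

roundtrip_ok = np.allclose(restored, image_chw.transpose((1, 2, 0)))
print(restored.shape, roundtrip_ok)
```

The transpose is needed because PyTorch stores images as (channels, height, width), while matplotlib expects (height, width, channels).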

Training the model

Now, let’s write a general function to train a model. Here, we will illustrate:

  • Scheduling the learning rate
  • Saving the best model

In the following, parameter scheduler is an LR scheduler object from torch.optim.lr_scheduler.

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
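
The copy.deepcopy call in train_model matters because state_dict() returns tensors that share storage with the live model. The sketch below illustrates this on a tiny nn.Linear layer, which is only a stand-in for the real network; the in-place update simulates further training steps.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(2, 1)  # stand-in for the real network

live_view = model.state_dict()                # shares storage with the model
snapshot = copy.deepcopy(model.state_dict())  # independent copy of the weights

# Simulate later training epochs changing the weights in place
with torch.no_grad():
    model.weight.add_(1.0)

view_followed_model = torch.equal(live_view["weight"], model.weight)
snapshot_unchanged = not torch.equal(snapshot["weight"], model.weight)

# Restore the saved weights, as train_model does at the end
model.load_state_dict(snapshot)
restored = torch.equal(model.weight, snapshot["weight"])

print(view_followed_model, snapshot_unchanged, restored)
```

Without the deep copy, the "best" weights would silently track whatever the model looks like after the final epoch.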

Visualizing the model predictions

Here is a generic function to display predictions for a few images:

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)
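
Both train_model and visualize_model turn raw network outputs into class predictions with torch.max(outputs, 1). As a small worked example on hand-written scores (the numbers below are made up for illustration), torch.max along dimension 1 returns the largest score in each row together with its column index, and that index is the predicted class.

```python
import torch

# Made-up scores for a batch of 4 images and 2 classes
outputs = torch.tensor([[2.0, -1.0],
                        [0.1, 0.3],
                        [-0.5, 1.5],
                        [1.2, 0.7]])
labels = torch.tensor([0, 1, 0, 0])

# Second return value is the index of the largest score in each row
max_scores, preds = torch.max(outputs, 1)

# Same bookkeeping as in train_model
running_corrects = torch.sum(preds == labels)
accuracy = running_corrects.double() / labels.size(0)

print(preds.tolist(), accuracy.item())
```

Here three of the four predictions match the labels, so the accuracy is 0.75.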

Finetuning the convnet

Load a pretrained model and reset final fully connected layer.

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Train and evaluate

Training should take around 15-25 minutes on a CPU. On a GPU, it takes about a minute.

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Out:

Epoch 0/24
----------
train Loss: 0.6104 Acc: 0.7008
val Loss: 0.6322 Acc: 0.7124

Epoch 1/24
----------
train Loss: 0.4168 Acc: 0.8402
val Loss: 0.3723 Acc: 0.8105

Epoch 2/24
----------
train Loss: 0.6424 Acc: 0.7090
val Loss: 0.8405 Acc: 0.6340

Epoch 3/24
----------
train Loss: 0.4585 Acc: 0.8074
val Loss: 0.2838 Acc: 0.8889

Epoch 4/24
----------
train Loss: 0.4374 Acc: 0.8361
val Loss: 0.2964 Acc: 0.8824

Epoch 5/24
----------
train Loss: 0.5119 Acc: 0.7951
val Loss: 0.2815 Acc: 0.8954

Epoch 6/24
----------
train Loss: 0.5043 Acc: 0.8402
val Loss: 0.2270 Acc: 0.9346

Epoch 7/24
----------
train Loss: 0.2356 Acc: 0.9057
val Loss: 0.2393 Acc: 0.9150

Epoch 8/24
----------
train Loss: 0.2722 Acc: 0.8893
val Loss: 0.2385 Acc: 0.9281

Epoch 9/24
----------
train Loss: 0.3548 Acc: 0.8689
val Loss: 0.2182 Acc: 0.9346

Epoch 10/24
----------
train Loss: 0.3702 Acc: 0.8648
val Loss: 0.2113 Acc: 0.9281

Epoch 11/24
----------
train Loss: 0.2236 Acc: 0.9098
val Loss: 0.2301 Acc: 0.9281

Epoch 12/24
----------
train Loss: 0.3456 Acc: 0.8566
val Loss: 0.2457 Acc: 0.9281

Epoch 13/24
----------
train Loss: 0.2549 Acc: 0.9016
val Loss: 0.2738 Acc: 0.9085

Epoch 14/24
----------
train Loss: 0.3553 Acc: 0.8607
val Loss: 0.2935 Acc: 0.8824

Epoch 15/24
----------
train Loss: 0.2945 Acc: 0.8770
val Loss: 0.2390 Acc: 0.9216

Epoch 16/24
----------
train Loss: 0.2350 Acc: 0.9016
val Loss: 0.2376 Acc: 0.9216

Epoch 17/24
----------
train Loss: 0.3603 Acc: 0.8238
val Loss: 0.2450 Acc: 0.9281

Epoch 18/24
----------
train Loss: 0.2551 Acc: 0.8852
val Loss: 0.2362 Acc: 0.9281

Epoch 19/24
----------
train Loss: 0.3012 Acc: 0.8689
val Loss: 0.2332 Acc: 0.9281

Epoch 20/24
----------
train Loss: 0.3427 Acc: 0.8689
val Loss: 0.2539 Acc: 0.9281

Epoch 21/24
----------
train Loss: 0.3070 Acc: 0.8525
val Loss: 0.2490 Acc: 0.9281

Epoch 22/24
----------
train Loss: 0.2947 Acc: 0.8770
val Loss: 0.2319 Acc: 0.9281

Epoch 23/24
----------
train Loss: 0.2480 Acc: 0.9016
val Loss: 0.2276 Acc: 0.9281

Epoch 24/24
----------
train Loss: 0.3127 Acc: 0.8648
val Loss: 0.2424 Acc: 0.9150

Training complete in 1m 7s
Best val Acc: 0.934641
visualize_model(model_ft)

ConvNet as fixed feature extractor

Here, we need to freeze the entire network except the final layer. Setting requires_grad = False on the parameters freezes them, so that their gradients are not computed in backward().

You can read more about this in the documentation here.
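
As a quick check of this behavior, the sketch below freezes a tiny two-layer network and then replaces its head. TinyNet is a hypothetical stand-in for ResNet-18, chosen so the example runs instantly; only the fc attribute name is borrowed from the real model.

```python
import torch
from torch import nn

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained network with an fc head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(torch.relu(self.backbone(x)))


model = TinyNet()

# Freeze everything that exists so far
for param in model.parameters():
    param.requires_grad = False

# A newly constructed layer has requires_grad=True by default
model.fc = nn.Linear(model.fc.in_features, 2)

loss = model(torch.randn(5, 8)).sum()
loss.backward()

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
backbone_has_grad = model.backbone.weight.grad is not None
fc_has_grad = model.fc.weight.grad is not None

print(trainable, backbone_has_grad, fc_has_grad)
```

After backward(), only the new head holds gradients, which is why passing just the head's parameters to the optimizer is enough.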

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
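
To make the decay schedule concrete, the sketch below records the learning rate seen in each of 25 epochs. It uses a single dummy parameter in place of a model, which is an assumption for illustration; the optimizer and scheduler settings match the ones above.

```python
import math

import torch
from torch import nn, optim
from torch.optim import lr_scheduler

# A single dummy parameter stands in for the model's parameters
dummy = nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([dummy], lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(25):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used during this epoch
    optimizer.step()    # no gradients here, so this is a no-op update
    scheduler.step()    # called once per epoch, as in train_model

# Closed form of the same schedule: lr * gamma ** (epoch // step_size)
expected = [0.001 * 0.1 ** (epoch // 7) for epoch in range(25)]
matches = all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(lrs, expected))

print(lrs[0], lrs[7], lrs[14], lrs[21], matches)
```

Epochs 0-6 train at 0.001, epochs 7-13 at roughly 1e-4, epochs 14-20 at roughly 1e-5, and the remaining epochs at roughly 1e-6.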

Train and evaluate

On a CPU, this will take about half the time of the previous scenario. This is expected, as gradients don’t need to be computed for most of the network. However, the forward pass still needs to be computed.

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

Out:

Epoch 0/24
----------
train Loss: 0.8283 Acc: 0.6680
val Loss: 0.2313 Acc: 0.9412

Epoch 1/24
----------
train Loss: 0.3411 Acc: 0.8566
val Loss: 0.2130 Acc: 0.9085

Epoch 2/24
----------
train Loss: 0.4557 Acc: 0.7951
val Loss: 0.2120 Acc: 0.9346

Epoch 3/24
----------
train Loss: 0.3622 Acc: 0.8402
val Loss: 0.2383 Acc: 0.9085

Epoch 4/24
----------
train Loss: 0.5417 Acc: 0.7623
val Loss: 0.1973 Acc: 0.9412

Epoch 5/24
----------
train Loss: 0.4069 Acc: 0.8402
val Loss: 0.2052 Acc: 0.9477

Epoch 6/24
----------
train Loss: 0.4072 Acc: 0.7869
val Loss: 0.2844 Acc: 0.8824

Epoch 7/24
----------
train Loss: 0.3897 Acc: 0.8361
val Loss: 0.1749 Acc: 0.9477

Epoch 8/24
----------
train Loss: 0.3548 Acc: 0.8197
val Loss: 0.1757 Acc: 0.9477

Epoch 9/24
----------
train Loss: 0.3292 Acc: 0.8484
val Loss: 0.1728 Acc: 0.9477

Epoch 10/24
----------
train Loss: 0.4108 Acc: 0.8115
val Loss: 0.1748 Acc: 0.9412

Epoch 11/24
----------
train Loss: 0.4019 Acc: 0.8033
val Loss: 0.1815 Acc: 0.9477

Epoch 12/24
----------
train Loss: 0.4609 Acc: 0.7869
val Loss: 0.1933 Acc: 0.9281

Epoch 13/24
----------
train Loss: 0.4383 Acc: 0.7828
val Loss: 0.1774 Acc: 0.9477

Epoch 14/24
----------
train Loss: 0.2799 Acc: 0.8730
val Loss: 0.1831 Acc: 0.9412

Epoch 15/24
----------
train Loss: 0.3141 Acc: 0.8443
val Loss: 0.1811 Acc: 0.9477

Epoch 16/24
----------
train Loss: 0.2609 Acc: 0.9139
val Loss: 0.1956 Acc: 0.9346

Epoch 17/24
----------
train Loss: 0.3234 Acc: 0.8279
val Loss: 0.1788 Acc: 0.9477

Epoch 18/24
----------
train Loss: 0.3325 Acc: 0.8607
val Loss: 0.1541 Acc: 0.9477

Epoch 19/24
----------
train Loss: 0.3555 Acc: 0.8361
val Loss: 0.1735 Acc: 0.9477

Epoch 20/24
----------
train Loss: 0.3300 Acc: 0.8443
val Loss: 0.1767 Acc: 0.9608

Epoch 21/24
----------
train Loss: 0.3155 Acc: 0.8607
val Loss: 0.1737 Acc: 0.9477

Epoch 22/24
----------
train Loss: 0.3697 Acc: 0.8443
val Loss: 0.1803 Acc: 0.9281

Epoch 23/24
----------
train Loss: 0.2814 Acc: 0.8648
val Loss: 0.1667 Acc: 0.9477

Epoch 24/24
----------
train Loss: 0.3470 Acc: 0.8525
val Loss: 0.2084 Acc: 0.9281

Training complete in 0m 40s
Best val Acc: 0.960784
visualize_model(model_conv)

plt.ioff()
plt.show()

Further Learning

If you would like to learn more about the applications of transfer learning, check out our Quantized Transfer Learning for Computer Vision Tutorial.

Total running time of the script: (1 minute 53.724 seconds)
