PyTorch Beginner's Tutorial (8) - Few-shot Learning for Image Classification




Overview

In this article, we will train a Siamese Network (a similarity network) using the Omniglot dataset. This network can be used to assess the similarity between two images, which helps implement few-shot learning.

We will train the network on the Omniglot training set and use the validation set to construct the Support Set. Specifically, we will select 5 samples from each of 10 chosen categories in the validation set, which gives us a 10-way 5-shot few-shot learning setup (a support set of 10 × 5 = 50 images in total).

Environment Setup

The environment used in this article is as follows:

```
python==3.8.5
torch==1.10.2
torchvision==0.11.3
numpy==1.22.3
matplotlib==3.2.2
```

Import the required packages for this article:

```python
import random

import torch
import torchvision
from torch import nn
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
```
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
```
```
Device: cuda
```

Loading the Dataset

Here, we use the torchvision.datasets.Omniglot class provided by torchvision to load the dataset. Passing background=True loads the "background" split (964 character classes drawn from 30 alphabets), which we use for training; background=False loads the evaluation split (659 classes from 20 alphabets), which we use for validation.

```python
train_dataset = torchvision.datasets.Omniglot('./dataset', background=True, transform=transforms.ToTensor(), download=True)
validation_dataset = torchvision.datasets.Omniglot('./dataset', background=False, transform=transforms.ToTensor(), download=True)
```

Once the dataset is successfully loaded, let's take a quick look:

```python
image, target = train_dataset.__getitem__(0)
print("image size:", image.size())
print("target:", target)
```
```
image size: torch.Size([1, 105, 105])
target: 0
```

The Omniglot dataset consists of grayscale handwritten-character images, similar in spirit to the MNIST handwritten digits. The target is the category encoded as a number; the actual character it corresponds to doesn't matter for our purposes.
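We can also confirm the scale of the two splits with a quick check (the counts follow from the split sizes mentioned above: 964 and 659 classes, with 20 images each):

```python
print("train samples:", len(train_dataset))            # 19280 = 964 classes x 20 images
print("validation samples:", len(validation_dataset))  # 13180 = 659 classes x 20 images
```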

Let's quickly plot one image for a better understanding:

```python
plt.imshow(image.squeeze(), cmap='gray')
```



Data Processing

In a Siamese network, the model is given a pair of images at a time, and it is tasked with determining whether the two images belong to the same category. In this section, we need to define a function that generates a batch of image pairs, where half of the pairs belong to the same category and the other half belong to different categories.

We will start by retrieving all the targets and labels from the training set:

```python
# Collect the target of every sample, then the sorted array of unique labels
# (np.unique guarantees sorted order, unlike iterating over a Python set)
all_targets = np.array([train_dataset.__getitem__(i)[1] for i in range(len(train_dataset))])
all_labels = np.unique(all_targets)
```
```python
print("all_targets:", all_targets)
print("all_labels:", all_labels)
```
```
all_targets: [  0   0   0 ... 963 963 963]
all_labels: [  0   1   2 ... 961 962 963]
```

With these two arrays ready, we can now construct our sampling function. Its job is to return a batch of image pairs, half belonging to the same category (referred to as positive samples) and half to different categories (referred to as negative samples).

```python
def sample_batch(batch_size):
    """
    Sample some data pairs from the train_dataset. Half positive samples, half negative samples.
    """

    # Select labels for the first half of the batch; each label yields one positive pair.
    positive_labels = np.random.choice(all_labels, batch_size // 2)
    # For each of these labels, select two distinct images of that category.
    batch = []
    for label in positive_labels:
        labels_indexes = np.argwhere(all_targets == label)
        pair = np.random.choice(labels_indexes.flatten(), 2, replace=False)
        batch.append((pair[0], pair[1], 1))  # Same category, so the target is 1.

    # Build the negative pairs: two distinct labels per pair (replace=False
    # guarantees they differ), then one random image from each label.
    for _ in range(batch_size // 2):
        label1, label2 = np.random.choice(all_labels, 2, replace=False)
        sample1 = np.random.choice(np.argwhere(all_targets == label1).flatten())
        sample2 = np.random.choice(np.argwhere(all_targets == label2).flatten())
        batch.append((sample1, sample2, 0))  # Different categories, so the target is 0.

    """
    After performing the above steps, the final batch looks like this:
        (734, 736, 1),
        (127, 132, 1),
        ...
        (859, 173, 0),
        ...
    where the first two numbers are the indices of the sample pair in the dataset, and the third is 1 if the two samples belong to the same category and 0 otherwise.
    Next, we shuffle the batch and then fetch the corresponding images from the dataset to assemble the final tensors.
    """
    random.shuffle(batch)

    sample1_list = []
    sample2_list = []
    target_list = []
    for sample1, sample2, target in batch:
        sample1_list.append(train_dataset.__getitem__(sample1)[0])
        sample2_list.append(train_dataset.__getitem__(sample2)[0])
        target_list.append(target)
    sample1 = torch.stack(sample1_list)
    sample2 = torch.stack(sample2_list)
    targets = torch.LongTensor(target_list)
    return sample1, sample2, targets
```

After completing the sampling function, let's quickly test it.

```python
sample1, sample2, targets = sample_batch(16)
```
```python
print("sample1:", sample1.size())
print("sample2:", sample1.size())
print("targets:", targets)
```
```
sample1: torch.Size([16, 1, 105, 105])
sample2: torch.Size([16, 1, 105, 105])
targets: tensor([1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1])
```

In this example, sample1[0] and sample2[0] form a pair, and targets[0] is their label, indicating whether they belong to the same category.
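To sanity-check the pairing, we can plot one pair next to its label (a quick optional check using the matplotlib import from earlier):

```python
# Show the first sampled pair side by side, titled with its target
fig, axes = plt.subplots(1, 2)
axes[0].imshow(sample1[0].squeeze(), cmap='gray')
axes[1].imshow(sample2[0].squeeze(), cmap='gray')
fig.suptitle(f"same category: {targets[0].item() == 1}")
plt.show()
```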

Once the data is ready, we can start building the model.

Model Construction

The model we need to build is quite simple. Its function is to take two images as input and output whether these two images belong to the same category. Since this is a binary classification problem, we process the final output using the Sigmoid function:

```python
class SimilarityModel(nn.Module):

    def __init__(self):
        super(SimilarityModel, self).__init__()

        # Define a convolutional layer for feature extraction
        self.conv = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(4, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(),
        )

        # Define a linear layer to determine whether the two images are of the same category
        self.sim = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2592, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, sample1, sample2):
        # Use the convolutional layer to extract features from sample1
        sample1_features = self.conv(sample1)
        # Use the convolutional layer to extract features from sample2
        sample2_features = self.conv(sample2)
        # Feed |features1 - features2| (the element-wise absolute difference of the two feature maps) into the linear layers to score their similarity
        return self.sim(torch.abs(sample1_features - sample2_features))
```
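The 2592 in the first linear layer is just the flattened size of the convolutional output for a 105×105 input. If you tweak the architecture, a quick way to recover this number is to push a dummy tensor through the conv stack (a throwaway check, not part of the pipeline):

```python
# Trace the conv output shape with an all-zeros dummy input
with torch.no_grad():
    dummy_features = SimilarityModel().conv(torch.zeros(1, 1, 105, 105))
print(dummy_features.shape)  # torch.Size([1, 32, 9, 9]) -> 32 * 9 * 9 = 2592
```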
```python
model = SimilarityModel()
model = model.to(device)
```

With the model defined, let's run a quick sanity check.

```python
model(sample1.to(device), sample2.to(device))
```
```
tensor([[0.5004],
        [0.5005],
        [0.5003],
        [0.5005],
        [0.5000],
        ...
        [0.5002],
        [0.5003],
        [0.5000],
        [0.5002]], device='cuda:0', grad_fn=<SigmoidBackward0>)
```

As we can see, since the model hasn't been trained yet, the outputs all hover around 0.5: the randomly initialized weights produce logits close to zero, and sigmoid(0) = 0.5.

Training the Model

Now let's start training the model. It's not much different from a typical binary classification problem.

```python
model = model.train()
```
```python
criteria = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
```
```python
batch_size = 512
# Stop training if the loss hasn't reached a new minimum for this many iterations
early_stop = 1500
# Keep track of the minimum loss
min_loss = 100.
# Record the episode of the last minimum loss
last_episode = 0
# Update parameters until the loss stops decreasing
for episode in range(100000):
    # Sample a batch of data, half positive and half negative samples
    sample1, sample2, targets = sample_batch(batch_size)
    # Feed the data into the model to determine if they belong to the same class
    outputs = model(sample1.to(device), sample2.to(device))
    # Calculate the loss using BCELoss
    loss = criteria(outputs.flatten(), targets.to(device).float())
    loss.backward()
    # Update the model parameters
    optimizer.step()
    optimizer.zero_grad()

    # If this loss is smaller than the minimum so far, record it and checkpoint the model
    if loss < min_loss:
        min_loss = loss.item()
        last_episode = episode
        torch.save(model, 'best_model.pt')

    # Stop training if the loss hasn't improved for {early_stop} consecutive times
    if episode - last_episode > early_stop:
        break

    # Print logs every 50 episodes
    if episode % 50 == 0:
        print(f"episode {episode}, loss {loss}")

print("Finish Training.")
```
```
episode 0, loss 0.6931208372116089
episode 50, loss 0.6730687618255615
episode 100, loss 0.6514454483985901
episode 150, loss 0.6112750768661499
...
episode 2600, loss 0.22310321033000946
episode 2650, loss 0.24409082531929016
episode 2700, loss 0.3104301393032074
```
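Note that the logged loss bounces around near the end (0.223 at episode 2600, 0.310 at episode 2700) because each value comes from a single random batch, which makes the early-stopping criterion above rather noisy. A steadier alternative is to smooth the loss with an exponential moving average and early-stop on that instead; a minimal sketch, where `update_ema` is a helper introduced here for illustration:

```python
def update_ema(prev_ema, new_loss, alpha=0.98):
    """Exponential moving average of the loss; a steadier early-stopping
    signal than the raw single-batch loss used above."""
    return new_loss if prev_ema is None else alpha * prev_ema + (1 - alpha) * new_loss

# Usage inside the training loop (ema_loss starts as None before the loop):
#     ema_loss = update_ema(ema_loss, loss.item())
#     ...then compare ema_loss, rather than loss, against min_loss.
```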

Model Validation

Before starting the validation, let's load the best model from earlier:

```python
model = torch.load('best_model.pt')
model = model.to(device)
model = model.eval()  # inference mode; good practice even though this model has no dropout/batchnorm
```

Now, let's evaluate the performance of our model using the validation dataset we created earlier. The dataset contains categories that the model has never seen before. This is not just about unseen images; the model has never encountered these categories at all. If you don't believe it, you can check the directory dataset/images_evaluation.

First, let's take a look at the categories in the validation set:

```python
all_targets = np.array([validation_dataset.__getitem__(i)[1] for i in range(len(validation_dataset))])
all_labels = np.unique(all_targets)
```
```python
print("sample size:", len(all_targets))
print("all_targets:", all_targets)
print("all_labels:", all_labels)
```
```
sample size: 13180
all_targets: [  0   0   0 ... 658 658 658]
all_labels: [  0   1   2 ... 656 657 658]
```

As we can see, the validation set contains 13,180 samples across 659 categories (labels 0 through 658), with 20 samples per category. For each category, we select 5 samples as the support set for the model to reference, and use the remaining 15 to measure the model's performance.

Predicting over all 659 categories would make this a 659-way 5-shot task, which would not only be very slow but would also yield low accuracy. To keep things simple, we will validate the model on just 10 categories.

```python
# The dataset is ordered by class, so the samples of labels 0-9 occupy the first
# 200 positions; filtering this way keeps array positions aligned with dataset indices
all_targets = all_targets[all_targets < 10]
all_labels = all_labels[:10]
print("sample size:", len(all_targets))
print("all_targets:", all_targets)
print("all_labels:", all_labels)
```
```
sample size: 200
all_targets: [0 0 0 0 0 0 0 0 ... 9 9 9 9 9 9]
all_labels: [0 1 2 3 4 5 6 7 8 9]
```
```python
support_set = []
validation_set = []
# Iterate over all labels, selecting the first 5 images for the support set and the rest for the validation set
for label in all_labels:
    label_indexes = np.argwhere(all_targets == label)
    support_set.append((label_indexes[:5].flatten().tolist()))
    validation_set += label_indexes[5:].flatten().tolist()
```
```python
print("support set:", support_set[:5])
print("validation set:", validation_set[:5])
print("validation size:", len(validation_set))
```
```
support set: [[0, 1, 2, 3, 4], [20, 21, 22, 23, 24], [40, 41, 42, 43, 44], [60, 61, 62, 63, 64], [80, 81, 82, 83, 84]]
validation set: [5, 6, 7, 8, 9]
validation size: 150
```

Next, we need to define a prediction function that, given an image, outputs that image's predicted target. The idea: compare the image with each category in the support set and assign it to the category with the highest similarity. Since each category has 5 support images, we average the 5 similarity scores to estimate how likely it is that the image belongs to that category.

```python
def predict(image):
    sim_list = [] # Stores the similarity between the image and each category
    # Iterate over each category; indexes hold the indices of the 5 images in the current category
    for indexes in support_set:
        # Retrieve the images from the validation dataset corresponding to the indexes
        tensor_list = []
        for i in indexes:
            tensor_list.append(validation_dataset[i][0])
        support_tensor = torch.stack(tensor_list)
        # With the 5 support images stacked, compare the query image against them
        # and average the 5 similarity scores (no_grad skips gradient tracking during inference)
        with torch.no_grad():
            sim = model(image.repeat(5, 1, 1, 1).to(device), support_tensor.to(device)).mean()
        sim_list.append(sim)

    # Find the category with the highest similarity; this is the predicted result
    result_index = torch.stack(sim_list).argmax().item()
    return all_labels[result_index]
```

Let's try the predict function:

```python
predict(validation_dataset.__getitem__(validation_set[0])[0])
```
```
0
```
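The predict function above re-reads the 5 support images through the dataset on every call, which is part of why full 659-way validation is slow. One easy speedup is to stack the support tensors once up front; a sketch built on the names defined above, where `support_tensors` and `predict_fast` are introduced here for illustration:

```python
# Precompute the support tensors once; support_tensors[k] holds the 5 images of class k
support_tensors = [
    torch.stack([validation_dataset[i][0] for i in indexes]).to(device)
    for indexes in support_set
]

def predict_fast(image):
    # Same logic as predict(), without re-reading the dataset on each call
    image = image.repeat(5, 1, 1, 1).to(device)
    with torch.no_grad():
        sims = [model(image, s).mean() for s in support_tensors]
    return all_labels[torch.stack(sims).argmax().item()]
```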

Finally, we validate the data from the validation set one by one and compute the accuracy.

```python
total = 0
total_correct = 0

# Shuffle the validation set so the running accuracy shown in the progress bar is representative from the start
random.shuffle(validation_set)
progress = tqdm(validation_set)

for i in progress:
    image, label = validation_dataset.__getitem__(i)
    predict_label = predict(image)

    total += 1
    if predict_label == label:
        total_correct += 1

    progress.set_postfix({
            "accuracy": str("%.3f" % (total_correct / total))
        })
```
```
100%|██████████| 150/150 [00:06<00:00, 21.66it/s, accuracy=0.700]
```
```python
print("Accuracy:", total_correct / total)
```
```
Accuracy: 0.7
```

In the end, we achieved an accuracy of 70% on these 150 validation samples, well above the 10% random-guess baseline for a 10-way task. It's not especially high, but it shows that the model works on categories it has never seen.

Since the model architecture was put together on the fly and is relatively small, performance drops sharply when predicting over all 659 classes, to around 15% accuracy. Interested readers can try it out and optimize the model.
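If you want to reproduce the full 659-way experiment, one way is to rebuild the label arrays without the 10-class filter and then rerun the support/validation split and the evaluation loop above; a minimal sketch:

```python
# Rebuild the label arrays over the whole evaluation split (no 10-class filter),
# then rerun the support_set / validation_set construction and the evaluation loop
all_targets = np.array([validation_dataset.__getitem__(i)[1] for i in range(len(validation_dataset))])
all_labels = np.unique(all_targets)
```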
