PyTorch Beginner's Tutorial (8) - Few-shot Learning for Image Classification
Overview
In this article, we will train a Siamese Network (a similarity network) using the Omniglot dataset. This network can be used to assess the similarity between two images, which helps implement few-shot learning.
We will train the neural network using the Omniglot training set and use the validation set to construct the Support Set. Specifically, we will select 5 samples from each category in the validation set to form the Support Set. We will choose a total of 10 categories, which gives us a 10-way 5-shot few-shot learning setup.
Environment Setup
The environment used in this article is as follows:
```
python==3.8.5
torch==1.10.2
torchvision==0.11.3
numpy==1.22.3
matplotlib==3.2.2
```
Import the required packages for this article:
```python
import random

import torch
import torchvision
from torch import nn
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
```
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
```
```
Device: cuda
```
Loading the Dataset
Here, we use the `torchvision.datasets.Omniglot` class provided by torchvision to load the dataset. Setting `background=True` selects the "background" split, which we use for training; `background=False` selects the "evaluation" split, which we use for validation.
```python
train_dataset = torchvision.datasets.Omniglot(
    './dataset', background=True, transform=transforms.ToTensor(), download=True)
validation_dataset = torchvision.datasets.Omniglot(
    './dataset', background=False, transform=transforms.ToTensor(), download=True)
```
Once the dataset is successfully loaded, let's take a quick look:
```python
image, target = train_dataset[0]
print("image size:", image.size())
print("target:", target)
```
```
image size: torch.Size([1, 105, 105])
target: 0
```
The Omniglot dataset consists of grayscale images of handwritten characters, similar to MNIST. The `target` is an integer category index; the actual category name doesn't matter for our purposes.
Let's quickly plot one image for a better understanding:
```python
plt.imshow(image.squeeze(), cmap='gray')
```
Data Processing
In a Siamese network, the model is given a pair of images at a time, and it is tasked with determining whether the two images belong to the same category. In this section, we need to define a function that generates a batch of image pairs, where half of the pairs belong to the same category and the other half belong to different categories.
We will start by retrieving the target of every image in the training set, along with the distinct labels:
```python
# Target (category index) of every training image, in dataset order
all_targets = np.array([train_dataset[i][1] for i in range(len(train_dataset))])
# Sorted array of the distinct category indices
all_labels = np.unique(all_targets)
```
```python
print("all_targets:", all_targets)
print("all_labels:", all_labels)
```
```
all_targets: [ 0 0 0 ... 963 963 963]
all_labels: [ 0 1 2 ... 959 960 961 962 963]
```
With these two arrays ready, we can now write the sampling function. It returns a batch of image pairs, half from the same category (positive samples) and half from different categories (negative samples).
```python
def sample_batch(batch_size):
    """
    Sample a batch of image pairs from train_dataset.
    Half are positive pairs (same category), half are negative pairs.
    """
    batch = []

    # Positive pairs: pick batch_size // 2 labels, then two different images of each label.
    positive_labels = np.random.choice(all_labels, batch_size // 2)
    for label in positive_labels:
        label_indexes = np.argwhere(all_targets == label).flatten()
        # replace=False keeps an image from being paired with itself
        pair = np.random.choice(label_indexes, 2, replace=False)
        batch.append((pair[0], pair[1], 1))  # same category, so the target is 1

    # Negative pairs: pick two different labels, then one image of each label.
    for _ in range(batch_size // 2):
        label1, label2 = np.random.choice(all_labels, 2, replace=False)
        index1 = np.random.choice(np.argwhere(all_targets == label1).flatten())
        index2 = np.random.choice(np.argwhere(all_targets == label2).flatten())
        batch.append((index1, index2, 0))  # different categories, so the target is 0

    # At this point the batch looks like:
    #   (734, 736, 1), (127, 132, 1), ... (859, 173, 0), ...
    # The first two numbers are the dataset indices of the paired samples;
    # 1 means they belong to the same category, 0 means they do not.
    # Next, shuffle the batch and fetch the actual images from the dataset.
    random.shuffle(batch)

    sample1_list = []
    sample2_list = []
    target_list = []
    for index1, index2, target in batch:
        sample1_list.append(train_dataset[index1][0])
        sample2_list.append(train_dataset[index2][0])
        target_list.append(target)

    sample1 = torch.stack(sample1_list)
    sample2 = torch.stack(sample2_list)
    targets = torch.LongTensor(target_list)
    return sample1, sample2, targets
```
After completing the sample function, let’s quickly test it.
```python
sample1, sample2, targets = sample_batch(16)
```
```python
print("sample1:", sample1.size())
print("sample2:", sample2.size())
print("targets:", targets)
```
```
sample1: torch.Size([16, 1, 105, 105])
sample2: torch.Size([16, 1, 105, 105])
targets: tensor([1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1])
```
In this example, `sample1[0]` and `sample2[0]` form a pair, and `targets[0]` is their label, indicating whether they belong to the same category.
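The pair construction can also be checked in isolation on toy data. The following is a minimal sketch using only the standard library; the helper name `sample_pair_indices` and the toy targets are illustrative assumptions, not part of the tutorial's code. It builds index triples the same way (half positive, half negative) and makes two properties explicit: a positive pair uses two different images, and a negative pair uses two different labels.

```python
import random
from collections import defaultdict


def sample_pair_indices(targets, batch_size, rng=random):
    """Return (index1, index2, same_class) triples, half positive and half negative."""
    # Group dataset indices by their label
    by_label = defaultdict(list)
    for index, label in enumerate(targets):
        by_label[label].append(index)
    labels = list(by_label)

    pairs = []
    for _ in range(batch_size // 2):
        # Positive pair: two different images of one class
        label = rng.choice(labels)
        first, second = rng.sample(by_label[label], 2)
        pairs.append((first, second, 1))
        # Negative pair: one image from each of two different classes
        label_a, label_b = rng.sample(labels, 2)
        pairs.append((rng.choice(by_label[label_a]), rng.choice(by_label[label_b]), 0))
    rng.shuffle(pairs)
    return pairs


# Toy data: 4 classes with 3 images each (hypothetical, not Omniglot)
toy_targets = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
toy_pairs = sample_pair_indices(toy_targets, batch_size=8)
print(toy_pairs)
```

Sampling without replacement (`rng.sample`) is what rules out self-pairs and same-label negatives; sampling with replacement would occasionally produce both.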
Once the data is ready, we can start building the model.
Model Construction
The model we need to build is quite simple. Its function is to take two images as input and output whether these two images belong to the same category. Since this is a binary classification problem, we process the final output using the Sigmoid function:
```python
class SimilarityModel(nn.Module):

    def __init__(self):
        super().__init__()

        # Convolutional layers for feature extraction
        self.conv = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(4, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(),
        )

        # Linear layers that decide whether the two images are of the same category
        self.sim = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2592, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, sample1, sample2):
        # Extract features from both images with the same (shared) convolutional layers
        sample1_features = self.conv(sample1)
        sample2_features = self.conv(sample2)
        # Feed |sample1 - sample2| into the linear layers to score their similarity
        return self.sim(torch.abs(sample1_features - sample2_features))
```
```python
model = SimilarityModel()
model = model.to(device)
```
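The `2592` in `nn.Linear(2592, 256)` is the flattened size of the last feature map for a 105×105 input. As a quick check, the sketch below traces the spatial size through the layers using the standard output-size formula for unpadded convolutions and pooling; the helper names `conv_out` and `feature_map_size` are made up for this illustration.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output side length of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def feature_map_size(size=105):
    # Three stages of Conv2d(kernel_size=3) followed by MaxPool2d(kernel_size=2, stride=2)
    for _ in range(3):
        size = conv_out(size, kernel_size=3)
        size = conv_out(size, kernel_size=2, stride=2)
    # Final Conv2d(kernel_size=3) without pooling
    return conv_out(size, kernel_size=3)


side = feature_map_size(105)      # 105 -> 103 -> 51 -> 49 -> 24 -> 22 -> 11 -> 9
flat_features = 32 * side * side  # 32 output channels in the last convolution
print(side, flat_features)        # 9 2592
```

If the input resolution or any kernel size changes, this number changes too, and the first linear layer has to be resized to match.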
With the model defined, let's give it a first try.
```python
model(sample1.to(device), sample2.to(device))
```
```
tensor([[0.5004],
        [0.5005],
        [0.5003],
        [0.5005],
        [0.5000],
        ...
        [0.5002],
        [0.5003],
        [0.5000],
        [0.5002]], device='cuda:0', grad_fn=<SigmoidBackward0>)
```
As we can see, since the model hasn't been trained yet, the output values are around 50%.
Training the Model
Now let's start training the model. It's not much different from a typical binary classification problem.
```python
model = model.train()
```
```python
criteria = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
```
```python
batch_size = 512
# Stop training if the loss hasn't reached a new minimum for 1500 consecutive episodes
early_stop = 1500
# Keep track of the minimum loss
min_loss = 100.
# Record the episode of the last minimum loss
last_episode = 0

# Update parameters until the loss stops decreasing
for episode in range(100000):
    # Sample a batch of data, half positive and half negative samples
    sample1, sample2, targets = sample_batch(batch_size)
    # Feed the data into the model to determine if they belong to the same class
    outputs = model(sample1.to(device), sample2.to(device))
    # Calculate the loss using BCELoss
    loss = criteria(outputs.flatten(), targets.to(device).float())
    loss.backward()
    # Update the model parameters
    optimizer.step()
    optimizer.zero_grad()

    # If this loss is the smallest so far, update the record and save the model
    if loss.item() < min_loss:
        min_loss = loss.item()
        last_episode = episode
        torch.save(model, 'best_model.pt')

    # Stop training if the loss hasn't improved for more than early_stop episodes
    if episode - last_episode > early_stop:
        break

    # Print logs every 50 episodes
    if episode % 50 == 0:
        print(f"episode {episode}, loss {loss.item()}")

print("Finish Training.")
```
```
episode 0, loss 0.6931208372116089
episode 50, loss 0.6730687618255615
episode 100, loss 0.6514454483985901
episode 150, loss 0.6112750768661499
...
episode 2600, loss 0.22310321033000946
episode 2650, loss 0.24409082531929016
episode 2700, loss 0.3104301393032074
```
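The first logged loss, about 0.6931, is what binary cross-entropy gives when the model outputs 0.5 for every pair, which matches the untrained outputs seen earlier. A minimal sketch of that arithmetic in plain Python (the function below is a hand-written per-sample formula for illustration, not the `nn.BCELoss` implementation):

```python
import math


def binary_cross_entropy(prediction, target):
    """Per-sample binary cross-entropy for a probability and a 0/1 target."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


# A prediction of 0.5 costs the same whether the pair is positive or negative
loss_positive = binary_cross_entropy(0.5, 1)
loss_negative = binary_cross_entropy(0.5, 0)
print(loss_positive, loss_negative, math.log(2))
```

Both values equal ln 2 ≈ 0.6931, so any loss below that means the model is doing better than an uninformed guess on the training pairs.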
Model Validation
Before starting the validation, let's load the best model from earlier:
```python
model = torch.load('best_model.pt')
model = model.to(device)
model = model.eval()
```
Now, let's evaluate the model on the validation dataset we loaded earlier. It contains categories the model has never seen: not just unseen images, but entire character classes that were absent from training. If you want to verify this, check the `images_evaluation` directory under `./dataset`.
First, let's take a look at the categories in the validation set:
```python
all_targets = np.array([validation_dataset[i][1] for i in range(len(validation_dataset))])
all_labels = np.unique(all_targets)
```
```python
print("sample size:", len(all_targets))
print("all_targets:", all_targets)
print("all_labels:", all_labels)
```
```
sample size: 13180
all_targets: [ 0 0 0 ... 658 658 658]
all_labels: [ 0 1 2 3 ... 655 656 657 658]
```
As we can see, the validation set contains 13,180 images across 659 categories (labels 0 through 658), with 20 samples per category. For each category, we select 5 samples as the support set for the model to reference, and the remaining 15 samples are used to validate the model's performance.
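The counts are easy to verify by hand. The short sketch below, using only numbers already shown in the output above, confirms that labels 0 through 658 amount to 659 classes and works out how many support and validation images a 10-way 5-shot split yields; the variable names are chosen for this illustration.

```python
samples_per_class = 20
first_label, last_label = 0, 658
num_classes = last_label - first_label + 1    # labels are inclusive on both ends
total_images = num_classes * samples_per_class

n_way, k_shot = 10, 5
support_size = n_way * k_shot                      # images the model may reference
query_size = n_way * (samples_per_class - k_shot)  # images left for validation

print(num_classes, total_images, support_size, query_size)
```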
If we were to predict across all 659 categories, this would be a 659-way 5-shot task. That would make prediction very slow and the accuracy low. To keep things simple, we reduce the number of categories and validate the model on only the first 10.
```python
all_targets = all_targets[all_targets < 10]
all_labels = all_labels[:10]
print("sample size:", len(all_targets))
print("all_targets:", all_targets)
print("all_labels:", all_labels)
```
```
sample size: 200
all_targets: [0 0 0 0 0 0 0 0 .... 9 9 9 9 9 9]
all_labels: [0 1 2 3 4 5 6 7 8 9]
```
```python
support_set = []
validation_set = []
# For each label, put the first 5 image indices in the support set and the rest in the validation set
for label in all_labels:
    label_indexes = np.argwhere(all_targets == label)
    support_set.append(label_indexes[:5].flatten().tolist())
    validation_set += label_indexes[5:].flatten().tolist()
```
```python
print("support set:", support_set[:5])
print("validation set:", validation_set[:5])
print("validation size:", len(validation_set))
```
```
support set: [[0, 1, 2, 3, 4], [20, 21, 22, 23, 24], [40, 41, 42, 43, 44], [60, 61, 62, 63, 64], [80, 81, 82, 83, 84]]
validation set: [5, 6, 7, 8, 9]
validation size: 150
```
Next, we need to define a prediction function that, given an image, outputs the target of that image. The idea of the function is: compare the image with all categories in the support set, and see which category has the highest similarity. The image will be classified into that category. Since each category has 5 images, we can calculate the average similarity to determine the likelihood that the image belongs to that category.
```python
def predict(image):
    sim_list = []  # Stores the similarity between the image and each category
    # Iterate over each category; indexes holds the indices of the 5 support images in that category
    for indexes in support_set:
        # Retrieve the support images from the validation dataset
        tensor_list = []
        for i in indexes:
            tensor_list.append(validation_dataset[i][0])
        support_tensor = torch.stack(tensor_list)
        # Compare the image against each support image of this category and average the similarities.
        # No gradients are needed for inference.
        with torch.no_grad():
            query = image.repeat(len(indexes), 1, 1, 1).to(device)
            sim = model(query, support_tensor.to(device)).mean()
        sim_list.append(sim)

    # The category with the highest average similarity is the prediction
    result_index = torch.stack(sim_list).argmax().item()
    return all_labels[result_index]
```
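The decision rule in `predict` (average the similarity scores per category, then take the argmax) can be illustrated without the model. The sketch below uses made-up similarity scores and a hypothetical helper name, `classify_by_mean_similarity`; it is not part of the tutorial's code.

```python
def classify_by_mean_similarity(scores_per_class):
    """Return (best_label, mean_scores) given a mapping {label: [similarity, ...]}."""
    mean_scores = {
        label: sum(scores) / len(scores)
        for label, scores in scores_per_class.items()
    }
    best_label = max(mean_scores, key=mean_scores.get)
    return best_label, mean_scores


# Made-up scores for a 3-way 5-shot case (illustrative values only)
toy_scores = {
    0: [0.20, 0.35, 0.10, 0.25, 0.30],
    1: [0.80, 0.60, 0.95, 0.70, 0.75],
    2: [0.90, 0.10, 0.20, 0.15, 0.05],
}
best_label, mean_scores = classify_by_mean_similarity(toy_scores)
print(best_label, mean_scores)
```

In this toy case, category 2 has one high score among its five support images, but its average is low, so category 1 is chosen. Averaging over the support images makes the prediction less sensitive to a single misleading match.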
Let's try the `predict` function:
```python
predict(validation_dataset[validation_set[0]][0])
```
```
0
```
Finally, we validate the data from the validation set one by one and compute the accuracy.
```python
total = 0
total_correct = 0

# Shuffle the validation set so the running accuracy is meaningful early on
random.shuffle(validation_set)
progress = tqdm(validation_set)
for i in progress:
    image, label = validation_dataset[i]
    predict_label = predict(image)

    total += 1
    if predict_label == label:
        total_correct += 1

    progress.set_postfix({
        "accuracy": f"{total_correct / total:.3f}"
    })
```
```
100%|██████████| 150/150 [00:06<00:00, 21.66it/s, accuracy=0.700]
```
```python
print("Accuracy:", total_correct / total)
```
```
Accuracy: 0.7
```
In the end, we achieved an accuracy of 70% on these 150 validation samples. That is not especially high, but it shows that the model works.
Since the model architecture was designed on the fly and is relatively small, accuracy drops sharply when predicting across all 659 classes, to around 15%. Interested readers can try it out and optimize the model.
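To put both numbers in context, it helps to compare them with random guessing, which is correct 1 time in N for an N-way task. The sketch below does that comparison using the accuracies reported in this article (the full-set figure is approximate, and 659 is the class count for labels 0 through 658); `chance_accuracy` is a helper name made up for this illustration.

```python
def chance_accuracy(n_way):
    """Expected accuracy of a uniform random guess among n_way classes."""
    return 1.0 / n_way


# Accuracies as reported in this article (the 659-way figure is approximate)
reported = {10: 0.70, 659: 0.15}

for n_way, accuracy in reported.items():
    baseline = chance_accuracy(n_way)
    print(f"{n_way}-way: chance {baseline:.4f}, reported {accuracy:.2f}, "
          f"ratio {accuracy / baseline:.0f}x")
```

Measured this way, the 10-way result is about 7 times better than chance, while the full-set result, although much lower in absolute terms, is roughly 99 times better than chance.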