PyTorch Beginner's Tutorial (3) - Object Classification with a Simple CNN
Overview
In this post, we’ll walk through building a basic Convolutional Neural Network (CNN) from scratch using the CIFAR-10 dataset, which includes 10 classes and images sized 32x32.
--- Ready? Let’s go! ---
Data Preprocessing
First, import the necessary libraries:
```python
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
```
Let’s verify the PyTorch version we’re using. Here, it’s version 1.10.2:
```python
torch.__version__
```
```
'1.10.2'
```

Next, we'll define the transformation pipeline and set the batch size:

```python
# If memory is limited, consider reducing batch_size.
# Generally, a larger batch_size makes training more stable and faster, but larger isn't always better.
batch_size = 16

transform = transforms.Compose(
    [transforms.ToTensor(),  # Convert images to Tensor format
     # Normalize images. The first argument is the per-channel mean, the second the per-channel standard deviation (std).
     # The three 0.5s correspond to the RGB channels: each channel is normalized with mean 0.5 and std 0.5.
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
```
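To make the effect of `Normalize` concrete, here is a tiny standalone check (not part of the training pipeline): it applies the formula `(x - mean) / std` per channel, so the `[0, 1]` range produced by `ToTensor` becomes `[-1, 1]`:

```python
# Illustrative only: Normalize computes (x - mean) / std for each channel,
# so 0.0 -> -1.0, 0.5 -> 0.0 and 1.0 -> 1.0.
x = torch.tensor([[[0.0, 0.5, 1.0]]] * 3)  # a tiny fake 3-channel "image", shape (3, 1, 3)
norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(norm(x))  # every channel becomes [-1.0, 0.0, 1.0]
```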
Now, let’s set up the dataset. Here, we’ll use the official CIFAR10 dataset.
```python
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

# CIFAR10 has 10 classes
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
```
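As an optional sanity check, CIFAR-10 ships with 50,000 training images and 10,000 test images, so with `batch_size = 16` the train loader yields 3,125 batches per epoch:

```python
# Optional sanity check on the dataset sizes
print(len(trainset), len(testset))  # 50000 10000
print(len(trainloader))             # 3125 batches per epoch with batch_size=16
```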
Next, let's display a few images:
```python
def imshow(img):
    # Reverse the normalization for display purposes
    img = img / 2 + 0.5  # unnormalize
    # Convert to NumPy format, since plt doesn't accept tensors directly
    npimg = img.numpy()
    # plt.imshow expects images in (height, width, channels) format,
    # while img is (channels, height, width), so we use transpose
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
```
```python
# Retrieve one batch of images and their corresponding labels
dataiter = iter(trainloader)
"""
images: a tensor of shape (16, 3, 32, 32)
    16 - batch_size, 3 - RGB channels, 32x32 - image size
labels: a tensor of shape (16,), with the label for each image
"""
images, labels = next(dataiter)  # next(dataiter) is the portable spelling; dataiter.next() only works on older versions

"""
make_grid: arranges a batch of images into a grid. nrow=8 specifies the number of columns,
so the 16 images are shown in a 2x8 grid, which is handy for displaying multiple images at once.
"""
imshow(torchvision.utils.make_grid(images, nrow=8))

# Display the labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
```
```
truck cat   truck dog   deer  horse plane horse truck horse frog  ship  truck bird  frog  deer
```
Defining the CNN Classification Model
```python
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        """
        Define the convolutional layers. For details, see the nn.Conv2d documentation:
        https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        The key nn.Conv2d parameters are:
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size: size of the convolution kernel
            stride: step size (default is 1)
            padding: padding (default is 0, meaning no padding)
        Note: The "2d" in Conv2d indicates that the data is 2-dimensional, such as images
        (height x width). Similarly, Conv1d is used for 1-dimensional data like text or
        signals, and Conv3d for 3-dimensional data like videos.
        """
        self.classifier = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            # Activation function
            nn.ReLU(),
            # Downsampling with max pooling
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # After the convolutions, flatten the data:
            # reshapes the tensor from (batch_size, c, h, w) to (batch_size, c*h*w) for the fully connected layers
            nn.Flatten(),
            # Final fully connected layers
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
            # Note that Softmax is not applied here.
            # CrossEntropyLoss already includes the softmax computation;
            # applying Softmax here as well would apply it twice and hurt learning.
        )

    def forward(self, x):
        return self.classifier(x)
```
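The `16 * 5 * 5` in the first `nn.Linear` comes from tracing the spatial size: 32x32 shrinks to 28x28 after the first 5x5 convolution, 14x14 after pooling, 10x10 after the second convolution, and 5x5 after the second pool, with 16 channels. A quick, purely illustrative way to confirm this is to push a dummy batch through the convolutional part:

```python
# Illustrative shape check: the flattened feature size should be 16 * 5 * 5 = 400
features = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),   # 32x32 -> 28x28 -> 14x14
    nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),  # 14x14 -> 10x10 -> 5x5
    nn.Flatten(),
)
dummy = torch.zeros(1, 3, 32, 32)
print(features(dummy).shape)  # torch.Size([1, 400])
```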
```python
net = Net()

# Use CrossEntropyLoss as the loss function, the standard choice for multi-class classification
criterion = nn.CrossEntropyLoss()
# Use plain SGD with momentum as the optimizer
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
```
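To see why the model does not end with a Softmax layer, here is a small standalone check (not part of the tutorial pipeline) showing that `CrossEntropyLoss` on raw logits matches `NLLLoss` applied to log-softmax outputs:

```python
# Illustrative only: CrossEntropyLoss(logits, target) == NLLLoss(log_softmax(logits), target)
logits = torch.randn(4, 10)          # fake raw scores for a batch of 4
target = torch.tensor([3, 1, 0, 7])  # fake class labels
ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(nn.functional.log_softmax(logits, dim=1), target)
print(torch.allclose(ce, nll))       # True
```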
Training the Network
Now let's train the network. Since it's relatively small, training is done directly on the CPU.
```python
# One pass over all training samples is one epoch.
# For simplicity, we train for only 20 epochs.
epochs = 20

for epoch in range(epochs):
    # Track the running loss
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        # trainloader yields a tuple: the first element is the batch of images, the second the labels
        inputs, labels = data

        # Clear gradients from the previous step
        optimizer.zero_grad()

        # Forward pass
        outputs = net(inputs)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()

        # Accumulate the loss and print it every 2000 batches
        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
```
```
[1,  2000] loss: 2.162
[2,  2000] loss: 1.603
[3,  2000] loss: 1.397
...
[18,  2000] loss: 0.704
[19,  2000] loss: 0.670
[20,  2000] loss: 0.648
Finished Training
```
After 20 epochs, the loss has decreased to 0.648. Since it is still decreasing, you could train for more epochs to squeeze out further improvements.
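If training more epochs becomes slow on the CPU and a GPU is available, the usual pattern is to move both the model and each batch onto the device. A minimal sketch (illustrative; the results in this post were produced on the CPU):

```python
# Minimal GPU-training sketch: move the model once, then move every batch inside the loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)

for i, data in enumerate(trainloader):
    inputs, labels = data[0].to(device), data[1].to(device)  # batch must live on the same device as the model
    optimizer.zero_grad()
    loss = criterion(net(inputs), labels)
    loss.backward()
    optimizer.step()
```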
Testing the Model
Testing the overall accuracy of the model:
```python
correct = 0  # Number of correct predictions
total = 0    # Total number of predictions

# No gradient computation is needed for evaluation
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # Forward pass
        outputs = net(images)
        """
        outputs has shape (16, 10): 16 is the batch size and 10 the number of classes.
        Each row holds the raw scores (before Softmax) for each class of one image.
        torch.max is used to find the highest score per image.
        It takes the tensor and a dimension (dim); dim=1 means take the maximum along the class dimension.
        torch.max returns both values and indices: values are the highest scores,
        indices are the class indices of those scores. Here we only need the indices,
        so the values are ignored.
        """
        _, predicted = torch.max(outputs, 1)
        # Update the total count
        total += labels.size(0)
        # Update the correct count
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
```
```
Accuracy of the network on the 10000 test images: 64 %
```
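If the behaviour of `torch.max(outputs, 1)` is still unclear, a tiny standalone example (illustrative only) shows the two return values:

```python
# Illustrative only: torch.max along dim=1 returns (values, indices)
scores = torch.tensor([[0.1, 2.3, -0.5],
                       [1.7, 0.2,  0.9]])
values, indices = torch.max(scores, 1)
print(values)   # tensor([2.3000, 1.7000]) - highest score in each row
print(indices)  # tensor([1, 0])           - predicted class index for each row
```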
Next, calculate the accuracy for each class:
```python
# Track correct and total predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# No gradient computation needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # Collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

# Print the accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
```
```
Accuracy for class: plane is 70.1 %
Accuracy for class: car   is 73.1 %
Accuracy for class: bird  is 49.7 %
Accuracy for class: cat   is 45.1 %
Accuracy for class: deer  is 57.6 %
Accuracy for class: dog   is 52.3 %
Accuracy for class: frog  is 79.3 %
Accuracy for class: horse is 65.2 %
Accuracy for class: ship  is 81.7 %
Accuracy for class: truck is 71.3 %
```
References
- PyTorch official CNN classification example: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html