PyTorch Beginner's Tutorial (6) - Using a GAN to Generate Simple Anime Character Avatars



Overview

This post is inspired by GAN Assignment 06 (HW06) from Professor Hung-yi Lee's 2021 course (GitHub link here). The goal is to train a GAN that generates anime character faces. This is a beginner-level guide, so we're using the most basic GAN model, which produces somewhat blurry anime faces. After training for just 40 epochs, the generated results look like this:


[Generated anime faces after 40 epochs]

Global Parameters

First, let's import the necessary packages:

```python
import os
import sys

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm 
from torch.utils.tensorboard import SummaryWriter
```

Then, set some global parameters:

```python
batch_size = 64
num_workers = 2
n_epoch = 100
z_dim = 100 # Dimension of the noise vector
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model save path: if mounted on Google Drive (in Google Colab), the model will be saved there
ckpt_dir = 'drive/MyDrive/models' 
faces_path = "faces"  # Directory for the dataset

print("Device: ", device) # Display the device to ensure training is on GPU, not CPU
```
```
Device:  cuda
```

Dataset

The dataset consists of a collection of anime character portraits. You can download it using the following link:

```
https://pan.baidu.com/s/1zsJJJapFLr1zWWhgGol-aA (extraction code: 2k4z)
```

After downloading, extract the contents to the current directory so that the final structure appears as:

```
faces/
├── 1.jpg
├── 2.jpg
├── 3.jpg
...
```
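
If you want to confirm the extraction worked before moving on, a quick file count is enough. This is a minimal sketch that simply assumes the images ended up in ./faces:

```python
import os

# Count the files in the dataset directory (assumed to be ./faces)
print(len(os.listdir("faces")))  # for this dataset the count should be around 71,314
```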

Utility Function

Here, I’ve defined a utility function to clear outputs during training. This helps keep the output area manageable by preventing excessive clutter.

```python
def clear_output():
    """
    Clears the output area, whether running in a terminal or a Jupyter notebook.
    """
    os.system('cls' if os.name == 'nt' else 'clear')
    if 'ipykernel' in sys.modules:
        from IPython.display import clear_output as clear
        clear()
```

Data Preprocessing

We’ll define the Dataset class, resizing each portrait image to 64x64 pixels and standardizing the images:

```python
class CrypkoDataset(Dataset):
    def __init__(self, img_path='./faces'):
        self.fnames = [img_path + '/' + img for img in os.listdir(img_path)]

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            # Normalize each channel with mean 0.5 and std 0.5, i.e. y = (x - 0.5) / 0.5,
            # mapping pixel values from [0, 1] to [-1, 1] to match the Generator's Tanh output.
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

        self.num_samples = len(self.fnames)


    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = torchvision.io.read_image(fname)
        img = self.transform(img)
        return img


    def __len__(self):
        return self.num_samples
```
```python
dataset = CrypkoDataset(faces_path)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
```

Let's run a quick test to check each method:

```python
dataset.__getitem__(0).size(), len(dataset)
```
```
(torch.Size([3, 64, 64]), 71314)
```

As we can see, the images are successfully resized to 64x64 pixels, with a total of 71,314 images.

Now, let's display a few images to see the results:

```python
images = [(dataset[i] + 1) / 2 for i in range(16)]  # Select 16 images
grid_img = torchvision.utils.make_grid(images, nrow=4)  # Arrange them into a 4x4 grid
plt.figure(figsize=(6, 6))
plt.imshow(grid_img.permute(1, 2, 0))  # plt expects channels at the end, so we use permute
plt.show()
```



A brief explanation of (dataset[i] + 1) / 2: Since we previously applied normalization, where y = (x - 0.5) / 0.5, this line reverses it to retrieve the original x values by applying x = 0.5y + 0.5 = (y + 1) / 2.
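
If you want to convince yourself that this round trip is exact, a minimal standalone check (independent of the dataset) looks like this:

```python
import torch

x = torch.rand(3, 64, 64)               # a fake "image" with values in [0, 1]
y = (x - 0.5) / 0.5                      # what transforms.Normalize(0.5, 0.5) does: values now in [-1, 1]
x_recovered = (y + 1) / 2                # the de-normalization used above
print(torch.allclose(x, x_recovered))    # True
```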

Model Definition

With the dataset ready, we can define our model. A GAN requires both a Generator and a Discriminator. The Generator is responsible for creating images, while the Discriminator evaluates whether the images are generated or real. Here, we use a DCGAN (Deep Convolutional GAN) for this purpose.

Generator

```python
class Generator(nn.Module):
    """
    Input Shape: (N, in_dim), where N is the batch size, and in_dim is the dimension of the random vector.
    Output Shape: (N, 3, 64, 64), generating N color images of size 64x64.
    """

    def __init__(self, in_dim, dim=64):
        super(Generator, self).__init__()

        def dconv_bn_relu(in_dim, out_dim):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU()
            )

        # 1. First, use a linear layer to project the random vector so it can be reshaped (in forward) into a 4x4 feature map with dim*8 channels.
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU()
        )

        # 2. Then, apply a series of transposed convolutions to progressively upscale the image.
        # Each layer reduces the number of channels, ultimately producing a 3-channel, 64x64 image.
        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 8, dim * 4),
            dconv_bn_relu(dim * 4, dim * 2),
            dconv_bn_relu(dim * 2, dim),
            nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1, 4, 4)
        y = self.l2_5(y)
        return y
```
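
A note on the transposed convolution parameters: with kernel size 5, stride 2, padding 2, and output_padding 1, each layer exactly doubles the spatial size. PyTorch computes the output size of ConvTranspose2d (with dilation 1) as out = (in - 1) * stride - 2 * padding + kernel_size + output_padding, so the feature maps grow 4 → 8 → 16 → 32 → 64. A quick check with the values used above:

```python
def deconv_out(size, kernel=5, stride=2, padding=2, output_padding=1):
    # Output size formula for nn.ConvTranspose2d with dilation = 1
    return (size - 1) * stride - 2 * padding + kernel + output_padding

size = 4
for _ in range(4):      # l1 produces a 4x4 map; l2_5 upsamples it four times
    size = deconv_out(size)
    print(size)         # 8, 16, 32, 64
```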

Discriminator

```python
class Discriminator(nn.Module):
    """
    Input shape: (N, 3, 64, 64), where N represents the number of 64x64 color images.
    Output shape: (N,), where each value represents the probability that the Discriminator believes 
    the corresponding image is real, with values closer to 1 indicating a stronger belief.
    """

    def __init__(self, in_dim=3, dim=64): # Here, in_dim refers to the image channels, so it is set to 3.
        super(Discriminator, self).__init__()

        def conv_bn_lrelu(in_dim, out_dim):
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2),
            )

        # This sequence of convolutional layers progressively reduces the image down to a single value.
        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim, 5, 2, 2),
            nn.LeakyReLU(0.2),
            conv_bn_lrelu(dim, dim * 2),
            conv_bn_lrelu(dim * 2, dim * 4),
            conv_bn_lrelu(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, 1, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.ls(x)
        y = y.view(-1)
        return y
```
```python
G = Generator(in_dim=z_dim)
D = Discriminator()
G = G.to(device)
D = D.to(device)
```
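
Before wiring up the loss and optimizers, a quick forward pass can confirm the shapes promised in the docstrings. This is just a sanity-check sketch reusing the z_dim and device defined above:

```python
with torch.no_grad():
    z_test = torch.randn(4, z_dim).to(device)  # a small batch of random noise vectors
    fake = G(z_test)
    print(fake.shape)      # torch.Size([4, 3, 64, 64])
    print(D(fake).shape)   # torch.Size([4])
```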

Since the Discriminator is handling a binary classification problem, Binary Cross Entropy is used here.

```python
criterion = nn.BCELoss()
```
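
To make the loss concrete: for a predicted probability p and a label y (1 for real, 0 for fake), BCE is -[y*log(p) + (1-y)*log(1-p)], averaged over the batch. A small standalone check:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
probs  = torch.tensor([0.9, 0.2])   # Discriminator outputs (probabilities, thanks to the final Sigmoid)
labels = torch.tensor([1.0, 0.0])   # 1 = real, 0 = fake
print(bce(probs, labels))           # -(log(0.9) + log(0.8)) / 2 ≈ 0.1643
```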
```python
opt_D = torch.optim.Adam(D.parameters(), lr=learning_rate)
opt_G = torch.optim.Adam(G.parameters(), lr=learning_rate)
```

Training the Model

We’ll use TensorBoard to log the changes in loss and to visualize the generated images:

```python
writer = SummaryWriter()
```

Now you can start TensorBoard with the following command:

```
tensorboard --logdir runs
```

Begin training:

```python
steps = 0
log_after_step = 50 # Log the loss every 50 steps

# z vector for evaluation phase
z_sample = Variable(torch.randn(100, z_dim)).to(device)

for e, epoch in enumerate(range(n_epoch)):
    total_loss_D = 0
    total_loss_G = 0

    for i, data in enumerate(tqdm(dataloader, desc='Epoch {}: '.format(e))):
        imgs = data
        imgs = imgs.to(device)

        # Recalculate batch_size in case the last batch is smaller
        batch_size = imgs.size(0)

        # ============================================
        #  Train the Discriminator
        # ============================================
        # 1. Generate a batch of random noise vectors z
        z = Variable(torch.randn(batch_size, z_dim)).to(device)
        # 2. Get real images
        r_imgs = Variable(imgs).to(device)
        # 3. Generate fake images using the Generator
        f_imgs = G(z)

        # Create labels: real images are labeled 1, fake images are labeled 0
        r_label = torch.ones((batch_size, )).to(device)
        f_label = torch.zeros((batch_size, )).to(device)

        # Use the Discriminator to classify real and fake images
        r_logit = D(r_imgs.detach())
        f_logit = D(f_imgs.detach())

        # Calculate the Discriminator's loss
        r_loss = criterion(r_logit, r_label)
        f_loss = criterion(f_logit, f_label)
        loss_D = (r_loss + f_loss) / 2
        total_loss_D += loss_D.item()  # .item() keeps a plain float so we don't hold onto the graph

        # Backpropagate for the Discriminator
        D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # ============================================
        # Train the Generator
        # ============================================
        # 1. Generate N fake images
        z = Variable(torch.randn(batch_size, z_dim)).to(device)
        f_imgs = G(z)

        # 2. Pass the fake images through the Discriminator
        f_logit = D(f_imgs)

        # 3. Calculate the loss; the Generator wants the fake images to look real, so we pair f_logit with r_label
        loss_G = criterion(f_logit, r_label)
        total_loss_G += loss_G.item()

        # Backpropagate for the Generator
        G.zero_grad()
        loss_G.backward()
        opt_G.step()

        steps += 1

        if steps % log_after_step == 0:
            # Log the average loss over the last `log_after_step` steps, then reset the running totals
            writer.add_scalars("loss", {
                "Loss_D": total_loss_D / log_after_step,
                "Loss_G": total_loss_G / log_after_step
            }, global_step=steps)
            total_loss_D = 0
            total_loss_G = 0

    # Clear previous output
    clear_output()

    # After each epoch, generate a sample image to check progress
    G.eval()

    # Generate and de-standardize images, then save them to the logs directory
    f_imgs_sample = (G(z_sample).data + 1) / 2.0
    if not os.path.exists('logs'):
        os.makedirs('logs')
    filename = os.path.join('logs', f'Epoch_{epoch + 1:03d}.jpg')
    # Save the generated images
    torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
    print(f' | Save some samples to {filename}.')

    # Display the generated images
    grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()

    # Log the generated images in Tensorboard
    writer.add_image("Generated_Images", grid_img, global_step=steps)

    # Switch the Generator back to training mode
    G.train()

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    # Save the model every 5 epochs
    if (e + 1) % 5 == 0 or e == 0:
        # Save the checkpoints.
        torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G_{}.pth'.format(steps)))
        torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D_{}.pth'.format(steps)))
```


[Generated samples during training]

I stopped the training after reaching the 40th epoch.

Now, let’s take a look at the TensorBoard panel:


[TensorBoard loss curves]

The red line represents the loss of the Generator, while the blue line represents the loss of the Discriminator. From the graph, we can observe two key points:

  1. The Discriminator's loss is significantly lower than the Generator's loss: This is a common issue with GANs. The Discriminator's task is relatively straightforward compared to that of the Generator—after all, distinguishing between real and fake images is easier than learning to generate convincing images. As a result, the Discriminator might not provide enough useful feedback to the Generator, hindering its ability to converge.
  2. The Generator’s loss fluctuates continuously (as does the Discriminator’s): Such fluctuations are expected because the Generator's goal is to "trick" the Discriminator. Initially, it fails (higher loss), but then the Generator improves, managing to deceive the Discriminator (lower loss). However, the Discriminator also improves, leading the Generator to fail again, and so the cycle continues. If the Generator’s loss consistently declines, it might indicate that it’s too easy for the Generator to deceive the Discriminator, suggesting that the Discriminator needs improvement.

Throughout training, I also documented the progression of the generated images, which show a clear improvement in quality over time.


[Progression of generated images over the course of training]

Using the Model

Once training is complete, let’s test out the model:

```python
G.eval()
inputs = torch.randn(1, z_dim).to(device)  # sample from the same standard normal distribution used during training
outputs = G(inputs)
outputs = (outputs.data + 1) / 2.0
plt.figure(figsize=(5, 5))
plt.imshow(outputs[0].cpu().permute(1, 2, 0))
plt.show()
```


[A single generated face]

… Well, it’s not perfect, but at least it still resembles a girl!
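
If you come back to the model later, you can rebuild the Generator and reload one of the checkpoints saved during training. A minimal sketch, assuming a checkpoint file such as G_5000.pth exists in ckpt_dir (the step number in the filename depends on your run):

```python
G = Generator(in_dim=z_dim).to(device)
# 'G_5000.pth' is only an example name; use one of the files actually present in ckpt_dir
G.load_state_dict(torch.load(os.path.join(ckpt_dir, 'G_5000.pth'), map_location=device))
G.eval()

with torch.no_grad():
    z = torch.randn(64, z_dim).to(device)
    samples = (G(z) + 1) / 2.0                        # de-normalize back to [0, 1]
    grid = torchvision.utils.make_grid(samples.cpu(), nrow=8)

plt.figure(figsize=(8, 8))
plt.imshow(grid.permute(1, 2, 0))
plt.show()
```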
