模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解

中文 | English

SWA简介

SWA,全程为“Stochastic Weight Averaging”(随机权重平均)。它是一种深度学习中提高模型泛化能力的一种常用技巧。

其思路为:对于模型的权重,不直接使用最后的权重,而是将之前的权重做个平均

该方法适用于深度学习,不限领域、不限Optimzer,可以和多种技巧同时使用。

SWA公式

我们的模型参数记为:$\theta=\{w_0, w_1, w_2, \cdots, w_n\}$, $n$ 为模型总参数量。

对于模型的训练,会在epoch结束后保存一个副本,第 $t$ 个epoch的模型参数记为 $\theta_t$。

则我们模型的最终参数为:

$$ \begin{aligned} \bar{\theta} = \frac{1}{T} \sum^T_{t=1}\theta_t \end{aligned} $$

其中 $T$ 表示我们有 $T$ 个不同个模型参数的副本。

该公式的意思就是将前面t个模型的权重取平均,然后作为最终的模型参数。

注意事项:

  1. 通常只在一个epoch结束后保存模型参数副本。
  2. 并不是每个epoch都要保存模型副本。通常会从模型开始很好地收敛后再开始保存模型参数副本。

SWA常见参数

通常我们在使用SWA时会有如下的超参数:

  1. SWA Start:从第几个epoch再开始保存模型副本。若在模型还不能很好的收敛时就开始保存模型参数副本,可能会损害模型的性能。
  2. SWA Learning Rate:在SWA期间采用学习率。例如,我们设置在第20个epoch开始进行SWA,则在第20个epoch后就会采用你指定的SWA Learning Rate,而不是之前的。

Pytorch Lightning的SWA源码分析

本节展示一下Pytorch Lightning中对SWA的实现,以便更清晰的认识SWA。

在开始看代码前,明确几个在Pytorch Lightning实现中的几个重要的概念:

  1. 平均模型(self._average_model):Pytorch Lightning会将平均的后的模型存入该变量中。
  2. pl_module:该变量为当前的模型。
```python
class StochasticWeightAveraging(Callback):
    def __init__(
        self,
        swa_lrs: Union[float, List[float]], # swa的学习率
        # swa_epoch_start: 从第0.8位置的epoch开始,例如一共100个epoch,那就从第81个epoch开始swa。
        #                  若指定整数,则会从指定的epoch开始swa。
        swa_epoch_start: Union[int, float] = 0.8, 
        annealing_epochs: int = 10,   # 模拟退火的epoch数。SWALR学习策略用的参数
        annealing_strategy: str = "cos", # 模拟退火策略。SWALR学习策略用的参数
        avg_fn: Optional[_AVG_FN] = None, # 平局函数,做模型参数平均时使用的函数,通常不需要指定。会使用默认的。
        device: Optional[Union[torch.device, str]] = torch.device("cpu"), # 平均后的model存在哪个device上
    ):
    ...

    def on_train_epoch_start(self, ...): # 在每个epoch开始前执行
        if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
            # 初始化SWA,在整个SWA过程中只执行一遍
            self._initialized = True
            ...

            # 使用原来的optimizer
            optimizer = trainer.optimizers[0]

            ...

            # 使用SWALR学习率策略(SWA Learning Scheduler),后面会讲
            self._swa_scheduler = cast(
                LRScheduler,
                SWALR(
                    optimizer,
                    swa_lr=self._swa_lrs,  # type: ignore[arg-type]
                    anneal_epochs=self._annealing_epochs,
                    anneal_strategy=self._annealing_strategy,
                    last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
                ),
            )

        # end if, 初始化代码结束。

        # 接下来是SWA在epoch开始前的处理逻辑
        if (self.swa_start <= trainer.current_epoch <= self.swa_end):
            # 在SWA期间,每个epoch开始前将当前的模型参数更新到“平均模型”上。
            self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)

        if trainer.current_epoch == self.swa_end + 1:
            # 到最后结束的时候,将平均模型的参数迁移到模型上。
            self.transfer_weights(self._average_model, pl_module)

    @staticmethod
    def update_parameters(
        average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN
    ) -> None:
        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
            device = p_swa.device
            p_swa_ = p_swa.detach()
            p_model_ = p_model.detach().to(device)
            src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
            p_swa_.copy_(src)
        n_averaged += 1

    @staticmethod
    def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)     
```

从上述Pytorch Lightning对SWA实现的源码中我们可以获得以下信息:

  1. 使用SWA需要指定SWA学习率从哪个epoch开始这两个最重要的参数。
  2. 在开始SWA后,将会使用新的“swa_lrs”学习率和新的“SWALR”学习率策略。(但在“退火”期间,会参考模型原本的学习率)
  3. 每个epoch开始前,会把上一个epoch学习到的模型参数更新到“平均模型”上。
  4. SWA期间,使用的Optimizer和之前一样。例如你模型训练时用的是Adam,则SWA期间也用Adam。

SWALR

在上面我们提到了Pytorch Lightning实现中,在SWA期间使用的是SWALR。

SWALR使用的是“模拟退火”策略,简单来说就是:学习率是从原本的学习率逐渐过度到SWA学习率的。例如,原本你使用的学习率是0.1,指定的SWA学习率为0.01,从第20个epoch开始进行SWA。那么并不是到第20个epoch后学习率立刻从0.1变到0.01,而是从0.1逐渐过度到0.01,过度的epoch数就是指定的annealing_epochs参数,而过度时减小的策略就是annealing_strategy参数。

这里不使用难以理解的源码或数学,而是来通过几组实验来直观的观察一下SWALR策略下的学习率的变化来进行解释:


在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

上述实验为:模型训练过程中学习率随epoch的变化,横坐标为epoch,纵坐标为这个epoch使用的学习率。其中图上的几个参数分别为:

  • model_lr:模型一开始使用的学习率。
  • swa_lr:用户指定的swa学习率
  • swa_epoch_start:从第几个epoch开始swa
  • annealing_epoch:模拟退火的epoch数
  • annealing_strategy:模拟退火策略。目前仅支持“cos”和“linear”两种。


在这里插入图片描述

例如对于图一意思就是:模型一开始在Optimizer上指定的学习率是0.1,SWA学习率为0.001,从第2个epoch开始进行SWA,总共进行10(annealing_epochs) 个epoch将学习率从0.1逐渐过度到0.001,学习率调整使用cos策略。

从上述图中很容易得出以下结论:

  1. 所谓的SWALR学习率策略就是让学习率从原来的学习率逐渐过度到swa学习率。过度的epoch数就是annealing_epoch
  2. 若你指定的swa学习率和之前的是一样的,那么SWALR相当于什么都没做。(图二)
  3. 若你指定的swa学习率比之前的学习率高,那么学习率就会逐渐升高(图三)。不过通常不会这么做,通常swa_lr要比model_lr小才对,因为到后面模型都稳定了,不能再用更高的学习率了。
  4. 若annealing_epoch数较小,那么“退火”速度较快,即从model_lr到swa_lr的过度速度就较快(图四),反正则慢。
  5. “cos”退火策略下学习率变化是先慢,然后快,最后再慢。(图五),而“linear”实现线性策略变化速度是一样的。(图六)

实验环境与代码如下:

```
lightning==2.0.1
pytorch==1.13.0
```

实验代码如下:

```python
import torch
import torch.nn as nn

import lightning.pytorch as pl
from lightning.pytorch.callbacks import StochasticWeightAveraging

from matplotlib import pyplot as plt

import numpy as np

def plot_swa_lr_curve(model_lr,  # 模型的学习率
                      swa_lr,  # swa的学习率
                      swa_epoch_start=2,  # 从哪个epoch开始swa
                      annealing_epochs=10,  # 模拟退火的epoch数
                      annealing_strategy='cos'  # 模拟退火策略
                      ):
    lrs = []

    # 定义一个简单的模型,用于测试
    class SimpleModel(pl.LightningModule):

        def __init__(self):
            super(SimpleModel, self).__init__()
            self.linear = nn.Linear(1, 1)

        def training_step(self, batch, batch_idx, *args, **kwargs):
            return nn.functional.mse_loss(self.linear(torch.rand(4, 1)), torch.rand(4, 1))

        def configure_optimizers(self):
            # 使用model_lr作为测试模型的学习率
            return torch.optim.SGD(self.parameters(), lr=model_lr)

    # 重写一下StochasticWeightAveraging,用于记录学习率变化
    class MyStochasticWeightAveraging(StochasticWeightAveraging):

        def on_train_epoch_start(self, *args, **kwargs):
            super().on_train_epoch_start(*args, **kwargs)
            if hasattr(self._swa_scheduler, "_last_lr"):
                # 记录lr的变化
                lrs.append(self._swa_scheduler._last_lr[0])
            else:
                lrs.append(model_lr)

    # 定义trainer进行训练
    trainer = pl.Trainer(
        callbacks=[MyStochasticWeightAveraging(swa_lrs=swa_lr, swa_epoch_start=swa_epoch_start,
                                               annealing_epochs=annealing_epochs,
                                               annealing_strategy=annealing_strategy)],
        max_epochs=20,
        num_sanity_val_steps=0,
        enable_progress_bar=False,  # Use custom progress bar
        accelerator='cpu',
    )

    # 训练模型
    trainer.fit(SimpleModel(), train_dataloaders=range(10))

    plt.plot(np.arange(1, len(lrs)+1).astype(dtype=np.str), lrs)
    plt.xlabel("epoch")
    plt.ylabel("learning rate")
    plt.text(0.7, 0.9, "model_lr: %s" % model_lr, fontsize=11, transform=plt.gca().transAxes)
    plt.text(0.7, 0.8, "swa_lr: %s" % swa_lr, fontsize=11, transform=plt.gca().transAxes)
    plt.text(0.6, 0.7, "swa_epoch_start: %s" % swa_epoch_start, fontsize=11, transform=plt.gca().transAxes)
    plt.text(0.6, 0.6, "annealing_epochs: %s" % annealing_epochs, fontsize=11, transform=plt.gca().transAxes)
    plt.text(0.6, 0.5, "annealing_strategy: %s" % annealing_strategy, fontsize=11, transform=plt.gca().transAxes)    
    plt.show()

    print("lrs:", lrs)  # 输出lr的变化
    return lrs

plot_swa_lr_curve(0.1, 0.001)
```





参考资料

Averaging Weights Leads to Wider Optima and Better Generalization(原论文): https://arxiv.org/abs/1803.05407

PyTorch 1.6 now includes Stochastic Weight Averaging: https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

Next Post Previous Post
No Comment
Add Comment
comment url