[Alchemy Trick] The Principle and Implementation of EMA

When training deep learning models, the same model can come out with noticeably different results from run to run; that is the "alchemy" side of this craft. Using a few tricks makes it much easier to get close to current SOTA results, and many popular open-source codebases already ship with quite a few of them, so they are worth studying. This section introduces one such trick: EMA.

1. Principle

EMA stands for exponential moving average. Its formula is very simple, as shown below:

\( \theta_{\text{EMA},\, t+1} = \lambda \cdot \theta_{\text{EMA},\, t} + (1 - \lambda) \cdot \theta_{t} \)

Here \( \theta_{t} \) denotes the network parameters at step \( t \), and \( \theta_{\text{EMA},\, t} \) denotes the moving-averaged parameters at step \( t \); the moving average at step \( t+1 \) is then simply a weighted blend of the two. \( \lambda \) is usually taken close to 1, for example 0.9995, and the closer it is to 1 the stronger the averaging effect.
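As a quick sanity check, here is a toy scalar version of this update in Python (the numbers are made up purely for illustration): with \( \lambda = 0.9995 \), each step moves the averaged value only slightly toward the current parameter, which is exactly the smoothing effect described above.

lam = 0.9995
theta_ema = 1.0   # averaged value at step t
theta = 3.0       # raw parameter value, held fixed here for simplicity
for step in range(3):
    theta_ema = lam * theta_ema + (1 - lam) * theta
    print(step, theta_ema)
# prints roughly 1.0010, 1.0020, 1.0030 -- the average creeps slowly toward theta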

Note that there are effectively two models here. The base model has its parameters updated by the usual forward and backward passes; the other model is a moving-averaged copy of the base model that never takes part in forward/backward propagation itself, and is updated only from the base model's parameters.
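To make the "two models" picture concrete, here is a minimal hand-rolled sketch (not the official API; the decay value and module sizes are only examples): the shadow copy never receives gradients and is only blended with the base model's weights after each optimizer step.

import copy
import torch

@torch.no_grad()
def ema_update(ema_model, model, lam=0.9995):
    # the EMA model is updated purely from the base model's parameters,
    # never by backpropagation
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(lam).add_(p.detach(), alpha=1 - lam)

model = torch.nn.Linear(4, 2)
ema_model = copy.deepcopy(model)          # the shadow (averaged) model
for p in ema_model.parameters():
    p.requires_grad_(False)               # it never takes part in backward

# inside the training loop, right after optimizer.step():
ema_update(ema_model, model)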

Why does EMA help? Roughly speaking, a validation set is used to measure accuracy during training, but validation accuracy does not track the test set exactly; in the later stages of training the model is often just fluctuating around its best test accuracy, so a moving average of the weights tends to be more reliable than any single snapshot. If you are interested, take a look at a few papers on the topic: Paper 1, Paper 2, Paper 3.

2. Implementation

PyTorch has in fact already implemented this for us, so rather than risk the bugs that come with rolling our own, let's read the official code. The class is called AveragedModel, and its code is shown below.
All we need to do is supply the avg_fn argument, which specifies how the averaging is performed (a short usage sketch follows the code).

import itertools
from copy import deepcopy

import torch
from torch.nn import Module


class AveragedModel(Module):
    """
    You can also use custom averaging functions with `avg_fn` parameter.
    If no averaging function is provided, the default is to compute
    equally-weighted average of the weights.
    """
    def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        if avg_fn is None:
            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update_parameters(self, model):
        self_param = (
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers else self.parameters()
        )
        model_param = (
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers else model.parameters()
        )
        for p_swa, p_model in zip(self_param, model_param):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1


@torch.no_grad()
def update_bn(loader, model, device=None):
    r"""Updates BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model.

    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.
    """
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.running_mean = torch.zeros_like(module.running_mean)
            module.running_var = torch.ones_like(module.running_var)
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    for module in momenta.keys():
        module.momentum = None
        module.num_batches_tracked *= 0

    for input in loader:
        if isinstance(input, (list, tuple)):
            input = input[0]
        if device is not None:
            input = input.to(device)

        model(input)

    for bn_module in momenta.keys():
        bn_module.momentum = momenta[bn_module]
    model.train(was_training)

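As a reference for how the class above is meant to be used: avg_fn receives the current averaged parameter, the incoming parameter, and the number of models averaged so far, and returns the new average. A hedged usage sketch follows (the model, loader and decay value below are placeholders, not from the original post):

import torch
from torch.optim.swa_utils import AveragedModel, update_bn

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))

# custom averaging: a plain EMA with decay 0.999 instead of the default
# equally-weighted running mean
ema_model = AveragedModel(
    model,
    avg_fn=lambda avg_p, p, n: 0.999 * avg_p + (1 - 0.999) * p,
)

# ... in the training loop, after each optimizer.step():
ema_model.update_parameters(model)

# with the default use_buffers=False the BatchNorm statistics are not averaged,
# so they can be re-estimated once training is done:
# update_bn(train_loader, ema_model)     # train_loader is a placeholder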
Here, again following the official example code, is the moving-average implementation. ExponentialMovingAverage inherits from AveragedModel and overrides __init__. A more direct approach would simply pass the ema_avg function to AveragedModel as an argument (a sketch of that alternative follows the class below); writing it as a subclass is probably just for readability, so that a lone ema_avg function isn't left floating around.

class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.

    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``

    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.
    """

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=True)
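For comparison, the "more direct" alternative mentioned above would look roughly like this. It is only a sketch: model, device and the decay value are assumed to be defined elsewhere, and the use_buffers argument must exist in your PyTorch version.

decay = 0.9995

def ema_avg(avg_model_param, model_param, num_averaged):
    return decay * avg_model_param + (1 - decay) * model_param

model_ema = torch.optim.swa_utils.AveragedModel(
    model, device=device, avg_fn=ema_avg, use_buffers=True
)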

How do we use it? It is fairly simple: first, create a moving-average model from the current model.

model_ema = utils.ExponentialMovingAverage(model, device=device, decay=ema_decay) 

Then run the forward and backward passes of the base model as usual; once its update is finished, update the parameters of the moving-average model.

output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model_ema.update_parameters(model)
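After training, the averaged weights are what you typically evaluate and save. A rough sketch (val_loader and the checkpoint path are placeholders): calling model_ema(...) dispatches to the averaged module through AveragedModel.forward.

model_ema.eval()
with torch.inference_mode():
    for image, target in val_loader:
        output = model_ema(image)      # forward pass through the averaged weights
        # ... accumulate metrics here ...

# both the raw and the averaged weights can be checkpointed
torch.save(
    {"model": model.state_dict(), "model_ema": model_ema.state_dict()},
    "checkpoint.pth",
)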
