---
title: Batch Norm Folding
keywords: fastai
sidebar: home_sidebar
summary: "Fold the batchnorm and the conv layers together to reduce computation"
description: "Fold the batchnorm and the conv layers together to reduce computation"
nb_path: "nbs/06_bn_folding.ipynb"
---

Batch Normalization is a technique that normalizes the input of each layer to make training faster and more stable. In practice, it is an extra layer that we generally add after the computation layer and before the non-linearity.

It consists of 2 steps:

  1. Normalize the batch by first subtracting its mean $\mu$, then dividing it by its standard deviation $\sigma$.
  2. Further scale the result by a factor $\gamma$ and shift it by a factor $\beta$. These are the learnable parameters of the batch normalization layer; they allow the network to undo the normalization when a mean of $0$ and a standard deviation of $1$ is not the distribution it actually needs.
$$ \begin{aligned}\mu_{\mathcal{B}} & \leftarrow \frac{1}{m} \sum_{i=1}^{m} x_{i} \\ \sigma_{\mathcal{B}}^{2} & \leftarrow \frac{1}{m} \sum_{i=1}^{m}\left(x_{i}-\mu_{\mathcal{B}}\right)^{2} \\ \widehat{x}_{i} & \leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}} \\ y_{i} & \leftarrow \gamma \widehat{x}_{i}+\beta \equiv \mathrm{BN}_{\gamma, \beta}\left(x_{i}\right) \end{aligned}$$
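
For reference, here is a minimal PyTorch sketch of these two steps (not part of fasterai), checked against `nn.BatchNorm1d` in training mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 8)                    # a batch of 32 samples with 8 features
eps = 1e-5

# Step 1: normalize with the batch statistics
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mu) / torch.sqrt(var + eps)

# Step 2: scale by gamma and shift by beta (here at their initial values)
gamma, beta = torch.ones(8), torch.zeros(8)
y = gamma * x_hat + beta

# Matches PyTorch's BatchNorm1d in training mode (freshly initialized: gamma=1, beta=0)
bn = nn.BatchNorm1d(8, eps=eps)
bn.train()
assert torch.allclose(y, bn(x), atol=1e-5)
```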

Because of its effectiveness at stabilizing and speeding up training, batch normalization is now widely used. But how useful is it at inference time?

Once training has ended, each batch normalization layer possesses a specific set of $\gamma$ and $\beta$, but also of $\mu$ and $\sigma$, the last two being estimated with an exponential moving average over the training batches. This means that, at inference time, the batch normalization layer acts as a simple linear transformation of what comes out of the previous layer, often a convolution.

Since a convolution is itself a linear transformation, the two operations can be merged into a single one!

This removes some unnecessary parameters and also reduces the number of operations to be performed at inference time.
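
To see this concretely, here is a small standalone sketch (plain PyTorch, not fasterai) showing that an eval-mode `BatchNorm2d` reduces to a per-channel scale and shift built from its running statistics:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(16)

# Let the running statistics move away from their defaults, as training would
bn.train()
for _ in range(10):
    bn(torch.randn(8, 16, 4, 4))

bn.eval()
x = torch.randn(1, 16, 4, 4)

# Precompute the equivalent scale and shift, one value per channel
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
shift = bn.bias - bn.running_mean * scale

expected = x * scale[None, :, None, None] + shift[None, :, None, None]
assert torch.allclose(bn(x), expected, atol=1e-6)
```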

With a little bit of math, we can easily rearrange the terms of the convolution to take the batch normalization into account.

As a little reminder, the convolution operation followed by the batch normalization operation can be expressed, for an input $x$, as:

{% raw %} $$\begin{aligned} z &=W * x+b \\ \mathrm{out} &=\gamma \cdot \frac{z-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta \end{aligned}$$ {% endraw %}

So, if we rearrange the $W$ and $b$ of the convolution to take the parameters of the batch normalization into account, as follows:

{% raw %} $$\begin{aligned} W_{\text{fold}} &=\gamma \cdot \frac{W}{\sqrt{\sigma^{2}+\epsilon}} \\ b_{\text{fold}} &=\gamma \cdot \frac{b-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta \end{aligned}$$ {% endraw %}

then $\mathrm{out} = W_{\text{fold}} * x + b_{\text{fold}}$: the convolution alone produces exactly the output of the original convolution followed by the batch normalization (note that $\gamma$, $\beta$, $\mu$ and $\sigma$ are per output channel, so $W$ is rescaled channel by channel), and the batch normalization layer can be removed.
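
Before turning to fasterai, here is a hand-rolled sketch of this folding for a `Conv2d` followed by a `BatchNorm2d`. It is only an illustration under simplified assumptions (default groups and dilation), not fasterai's implementation:

```python
import torch
import torch.nn as nn

def fold_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    "Return a single Conv2d equivalent to `conv` followed by `bn` in eval mode."
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    std = torch.sqrt(bn.running_var + bn.eps)             # sqrt(sigma^2 + eps), per channel
    b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    with torch.no_grad():
        fused.weight.copy_(conv.weight * (bn.weight / std)[:, None, None, None])  # W_fold
        fused.bias.copy_(bn.weight * (b - bn.running_mean) / std + bn.bias)       # b_fold
    return fused

# Sanity check: same output as conv -> bn in eval mode, up to float error
conv, bn = nn.Conv2d(3, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16)
nn.init.normal_(bn.running_mean); nn.init.uniform_(bn.running_var, 0.5, 1.5)
bn.eval()
x = torch.randn(1, 3, 8, 8)
assert torch.allclose(fold_conv_bn(conv, bn)(x), bn(conv(x)), atol=1e-5)
```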

And here is how to do it with fasterai!

{% raw %}
from fastai.vision.all import *    # fastai imports used for the data and training below

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()
{% endraw %} {% raw %}
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
{% endraw %} {% raw %}
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
{% endraw %} {% raw %}
learn.fit_one_cycle(5)

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|-------|
| 0 | 0.615641 | 0.596630 | 0.688092 | 00:10 |
| 1 | 0.582679 | 0.558671 | 0.689445 | 00:10 |
| 2 | 0.529308 | 0.517995 | 0.744926 | 00:10 |
| 3 | 0.481804 | 0.449941 | 0.784168 | 00:10 |
| 4 | 0.400030 | 0.414093 | 0.800406 | 00:10 |
{% endraw %} {% raw %}

class BN_Folder[source]

BN_Folder()

{% endraw %} {% raw %}
{% endraw %} {% raw %}
bn = BN_Folder()
{% endraw %} {% raw %}
new_model = bn.fold(learn.model)
{% endraw %}

Each batch norm layer has been replaced by an Identity layer, and the weights and biases of the preceding convolutions have been modified accordingly.

{% raw %}
new_model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): Identity()
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
        (1): Identity()
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
        (1): Identity()
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
        (1): Identity()
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=2, bias=True)
)
{% endraw %}
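
As an extra check (not shown in the original notebook), we can compare the two models on the same input: with both in eval mode, the outputs should only differ by the floating-point rounding error introduced by the folding.

```python
# Hypothetical sanity check; both models are assumed to live on the GPU, as above
xb = torch.randn(16, 3, 64, 64).cuda()
learn.model.eval(); new_model.eval()
with torch.no_grad():
    print((learn.model(xb) - new_model(xb)).abs().max())   # tiny, on the order of float32 rounding error
```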

We can see that the new model possesses fewer parameters:

{% raw %}
count_parameters(learn.model)
11177538
{% endraw %} {% raw %}
count_parameters(new_model)
11172738
{% endraw %}
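
The difference is exactly $11177538 - 11172738 = 4800$ parameters: each folded batch norm layer loses its $2C$ learnable parameters ($\gamma$ and $\beta$ for $C$ channels) while the convolution before it gains a bias of size $C$, a net saving of $C$ parameters per layer. Summing the channel counts of the $20$ batch norm layers of ResNet-18 gives $64 + 4 \times 64 + 5 \times 128 + 5 \times 256 + 5 \times 512 = 4800$.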

But it is also faster to run!

{% raw %}
x,y = dls.one_batch()
{% endraw %} {% raw %}
%%timeit
learn.model(x[0][None].cuda())
5.59 ms ± 547 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
{% endraw %} {% raw %}
%%timeit
new_model(x[0][None].cuda())
4.14 ms ± 446 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
{% endraw %}

But most importantly, it has the exact same performance as before:

{% raw %}
new_learn = Learner(dls, new_model, metrics=accuracy)
{% endraw %} {% raw %}
new_learn.validate()
(#2) [0.4140927791595459,0.8004059791564941]
{% endraw %}