Batch Norm layer

Introduction

In this chapter we will learn about the batch norm layer. Previously we said that feature scaling makes the job of gradient descent easier. Now we will extend this idea and normalize the activations of every fully connected or convolution layer during training. This also means that while training we select a batch and calculate its mean and standard deviation.

You can think of batch-norm as a kind of adaptive (or learnable) pre-processing block with trainable parameters, which also means that we need to back-propagate through it.

The original batch-norm paper (Ioffe & Szegedy, 2015) is listed in the references.

Here is the list of advantages of using Batch-Norm:

  1. Improves gradient flow; used on very deep models (ResNet needs this)

  2. Allows higher learning rates

  3. Reduces dependency on initialization

  4. Gives some kind of regularization (it even makes Dropout less important, but keep using it)

  5. As a rule of thumb, if you use Dropout+BatchNorm you don't need L2 regularization

It basically forces your activations (Conv, FC outputs) to have unit standard deviation and zero mean.

To each learning batch of data we apply the following normalization.

$$\hat{x}^{(k)}=\frac{x^{(k)}-E[x^{(k)}]}{\sqrt{VAR[x^{(k)}]}}$$

The output of the batch norm layer has $\gamma^{(k)}, \beta^{(k)}$ as parameters. Those parameters will be learned to best represent your activations, and they allow a learnable (scale and shift) factor:

$$y^{(k)}=\gamma^{(k)}\,\hat{x}^{(k)}+\beta^{(k)}$$

Now summarizing the operations:
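
In equations, these are the four operations from the original batch-norm paper: the mini-batch mean, the mini-batch variance, the normalization, and the scale-and-shift:

$$\mu_B=\frac{1}{m}\sum_{i=1}^{m} x_i \qquad\qquad \sigma_B^2=\frac{1}{m}\sum_{i=1}^{m}(x_i-\mu_B)^2$$

$$\hat{x}_i=\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}} \qquad\qquad y_i=\gamma\,\hat{x}_i+\beta$$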

Here, ϵ\epsilonϵ is a small number, 1e-5.

Where to use the Batch-Norm layer

The batch norm layer is used after linear layers (i.e. FC, conv) and before the non-linear layers (ReLU). There are actually two batch-norm implementations, one for FC layers and another for conv layers (spatial batch-norm). The good news is that the spatial batch-norm just calls the normal batch-norm after some reshapes.
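
As an illustration (not from the original text), this ordering in PyTorch, which is covered later in the appendix, would look like the snippet below; the layer sizes are arbitrary.

import torch.nn as nn

# FC network: linear -> batch-norm -> non-linearity
mlp = nn.Sequential(
    nn.Linear(256, 128),
    nn.BatchNorm1d(128),   # normal batch-norm for FC outputs
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Conv network: conv -> spatial batch-norm -> non-linearity
cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),    # spatial batch-norm for conv outputs
    nn.ReLU(),
)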

Test time

At prediction time the batch norm works differently: the mean/std are not computed from the current batch. Instead, during training we build an estimate of the mean/std of the whole dataset (population) for each batch norm layer in the model. One approach to estimating the population mean and variance during training is to use an exponential moving average:

$$S_t=\alpha\,S_{t-1}+(1-\alpha)\,Y_t$$

Where:

  • $S_t, S_{t-1}$: current and previous estimates
  • $\alpha$: the degree of weighting decrease, a constant smoothing factor between 0 and 1
  • $Y_t$: the current value (could be mean or std) that we're trying to estimate

Normally when we implement this layer we have some kind of flag that detects whether we're training or testing.
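
A minimal sketch of that flag behaviour (the batchnorm_apply name and the state dictionary are hypothetical; momentum plays the role of the $\alpha$ above):

import numpy as np

def batchnorm_apply(x, gamma, beta, state, training, momentum=0.9, eps=1e-5):
    # state holds 'running_mean' and 'running_var', updated only while training
    if training:
        mu, var = x.mean(axis=0), x.var(axis=0)
        state['running_mean'] = momentum * state['running_mean'] + (1 - momentum) * mu
        state['running_var'] = momentum * state['running_var'] + (1 - momentum) * var
    else:
        # At prediction time use the population estimates built during training
        mu, var = state['running_mean'], state['running_var']
    xhat = (x - mu) / np.sqrt(var + eps)
    return gamma * xhat + beta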

Backpropagation

As mentioned earlier we need to know how to backpropagate through the batch-norm layer. First, as we did with other layers, we need to create the computation graph; after this step we calculate the derivative of each node with respect to its inputs.

Computation Graph

In order to find the partial derivatives during back-propagation, it is better to visualize the algorithm as a computation graph:
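
Concretely, the graph is just the chain of forward steps (the same Step 1-9 used in the code below):

  1. mu = (1/N) * sum(x)
  2. xmu = x - mu
  3. sq = xmu^2
  4. var = (1/N) * sum(sq)
  5. sqrtvar = sqrt(var + eps)
  6. ivar = 1/sqrtvar
  7. xhat = xmu * ivar
  8. gammax = gamma * xhat
  9. out = gammax + beta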

New nodes

By inspecting this graph we have some new nodes ($\frac{1}{N}\sum_{i=1}^{N} X(i)$, $x^2$, $\sqrt{x+\epsilon}$, $\frac{1}{x}$). To simplify things you can use Wolfram Alpha to find the derivatives; for how to backpropagate the other nodes refer to the Back-propagation chapter.

Block 1/x

In other words:

$$\frac{\partial\left(\frac{1}{x}\right)}{\partial x}=-\frac{1}{x^2} \therefore dx=-\frac{1}{x_{cache}^2}\cdot dout$$

Where:

  • $x_{cache}$: the cached (or saved) input from the forward propagation
  • $dout$: the previous block's gradient
  • $\epsilon$: some small number, e.g. 0.00005

Block x^2

In other words:

$$\frac{\partial\left(x^2\right)}{\partial x}=2x \therefore dx=2\,x_{cache}\cdot dout$$

Block Summation

Like the SUM block, this block copies the input gradient dout equally to all of its inputs. So for every element of X we divide by N and multiply by dout.
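
Although the original text does not write this one out, the same pattern applied to the mean node gives (matching Step 4 and Step 1 of the backward code below):

$$\frac{\partial}{\partial X(j)}\left(\frac{1}{N}\sum_{i=1}^{N} X(i)\right)=\frac{1}{N} \therefore dX=\frac{1}{N}\cdot ones(N,D)\cdot dout$$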

Implementation

As a reference you can also find tutorials implementing batch-norm with Tensorflow or manually in Python.

Python Forward Propagation
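
A minimal NumPy sketch of the forward pass (the batchnorm_forward name and the (N, D) input layout are our own choices), mirroring the Step 1-9 structure of the Matlab code below:

import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x: (N, D) mini-batch; gamma, beta: (D,) scale and shift parameters
    N, D = x.shape

    # Step 1-2: mini-batch mean and mean subtraction
    mu = x.mean(axis=0)
    xmu = x - mu

    # Step 3-4: squared deviations and variance
    sq = xmu ** 2
    var = sq.mean(axis=0)

    # Step 5-6: add eps for numerical stability, take sqrt, then invert
    sqrtvar = np.sqrt(var + eps)
    ivar = 1.0 / sqrtvar

    # Step 7: normalize
    xhat = xmu * ivar

    # Step 8-9: scale by gamma and shift by beta
    out = gamma * xhat + beta

    # Cache everything needed for the backward pass
    cache = (xhat, gamma, xmu, ivar, sqrtvar, var, eps)
    return out, cache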

Python Backward Propagation
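
And a matching sketch of the backward pass, reusing the cache returned by the forward sketch above and following the same Step 9 to Step 1 order as the Matlab code:

import numpy as np

def batchnorm_backward(dout, cache):
    xhat, gamma, xmu, ivar, sqrtvar, var, eps = cache
    N, D = dout.shape

    # Step 9: shift (beta)
    dbeta = dout.sum(axis=0)
    dgammax = dout

    # Step 8: scale (gamma)
    dgamma = np.sum(dgammax * xhat, axis=0)
    dxhat = dgammax * gamma

    # Step 7
    divar = np.sum(dxhat * xmu, axis=0)
    dxmu1 = dxhat * ivar

    # Step 6
    dsqrtvar = -1.0 / (sqrtvar ** 2) * divar

    # Step 5
    dvar = 0.5 / np.sqrt(var + eps) * dsqrtvar

    # Step 4
    dsq = np.ones((N, D)) * dvar / N

    # Step 3
    dxmu2 = 2 * xmu * dsq

    # Step 2
    dx1 = dxmu1 + dxmu2
    dmu = -np.sum(dxmu1 + dxmu2, axis=0)

    # Step 1
    dx2 = np.ones((N, D)) * dmu / N

    dx = dx1 + dx2
    return dx, dgamma, dbeta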

Matlab version forward propagation

function [activations] = ForwardPropagation(obj, input, weights, bias)
    obj.previousInput = input;
    % Tensor format (rows,cols,channels, batch) on matlab
    % Get batch size            
    lenSizeActivations = length(size(input));
    [~,D] = size(input);
    if (lenSizeActivations < 3)
        N = size(input,1);
    else
        N = size(input,ndims(input));
    end

    % Initialize for the first time running_mean and running_var
    if isempty(obj.running_mean)
        obj.running_mean = zeros(1,D);
        obj.running_var = zeros(1,D);
    end

    if (obj.isTraining)                                
        % Step1: Calculate mean on the batch
        mu = (1/N) * sum(input,1);

        % Step2: Subtract the mean from each column
        obj.xmu = input - repmat(mu,N,1);

        % Step3: Calculate denominator
        sq = obj.xmu .^ 2;

        % Step4: Calculate variance
        obj.var = (1/N) * sum(sq,1);

        % Step5: add eps for numerical stability, then sqrt
        obj.sqrtvar = sqrt(obj.var + obj.eps);

        % Step6: Invert the square root
        obj.ivar = 1./obj.sqrtvar;

        %Step7: Do normalization
        obj.xhat = obj.xmu .* repmat(obj.ivar,N,1);

        %Step8: Now the two transformation steps: first scale by gamma
        gammax = repmat(weights,N,1) .* obj.xhat;

        % Step9: Adjust with bias (Batchnorm output)
        activations = gammax + repmat(bias,N,1); 

        % Calculate running mean and variance to be used later on
        % prediction
        obj.running_mean = (obj.momentum .* obj.running_mean) + (1.0 - obj.momentum) * mu;
        obj.running_var = (obj.momentum .* obj.running_var) + (1.0 - obj.momentum) .* obj.var;
    else
        xbar = (input - repmat(obj.running_mean,N,1)) ./ repmat(sqrt(obj.running_var + obj.eps),N,1);
        activations = (repmat(weights,N,1) .* xbar) + repmat(bias,N,1);
    end

    % Store stuff for backpropagation
    obj.activations = activations;
    obj.weights = weights;
    obj.biases = bias;
end

Matlab version backward propagation

function [gradient] = BackwardPropagation(obj, dout)
    dout = dout.input;
    lenSizeActivations = length(size(obj.previousInput));
    [~,D] = size(obj.previousInput);
    if (lenSizeActivations < 3)
        N = size(obj.previousInput,1);
    else
        N = size(obj.previousInput,ndims(obj.previousInput));
    end

    % Step9:
    dbeta = sum(dout, 1);
    dgammax = dout;

    % Step8:
    dgamma = sum(dgammax.*obj.xhat, 1);
    dxhat = dgammax .* repmat(obj.weights,N,1);

    % Step7:
    divar = sum(dxhat.* obj.xmu, 1);
    dxmu1 = dxhat .* repmat(obj.ivar,N,1);

    % Step6:
    dsqrtvar = -1 ./ (obj.sqrtvar.^2) .* divar;

    % Step 5:
    dvar = 0.5 * 1 ./sqrt(obj.var+obj.eps) .* dsqrtvar;

    % Step 4:
    dsq = 1 ./ N * ones(N,D) .* repmat(dvar,N,1);

    % Step 3:
    dxmu2 = 2 .* obj.xmu .* dsq;

    % Step 2:
    dx1 = (dxmu1 + dxmu2);
    dmu = -1 .* sum(dxmu1+dxmu2, 1);

    % Step 1:
    dx2 = 1. /N .* ones(N,D) .* repmat(dmu,N,1);

    gradient.input = dx1+dx2;
    gradient.weight = dgamma;
    gradient.bias = dbeta;        
end

Spatial batchnorm

As mentioned before, the spatial batchnorm is used between CONV and ReLU layers. To implement the spatial batchnorm we just call the normal batchnorm with the input reshaped and permuted. Below we present the Matlab version of the forward and backward propagation of the spatial batchnorm.

% It's just a call to the normal batchnorm but with some
% permute/reshape on the input signal
function [activations] = ForwardPropagation(obj, input, weights, bias)
obj.previousInput = input;
[H,W,C,N] = size(input);

% Permute the dimensions to the following format 
% (cols, channel, rows, batch)    
% On python was: x.transpose((0,2,3,1))
% Python tensor format:
% (batch(0), channel(1), rows(2), cols(3))
% Matlab tensor format:
% (rows(1), cols(2), channel(3), batch(4))
inputTransposed = permute(input,[2,3,1,4]);                                    

% Flat the input (On python the reshape is row-major)           
inputFlat = reshape_row_major(inputTransposed,[(numel(inputTransposed) / C),C]);

% Call the forward propagation of normal batchnorm
activations = obj.normalBatchNorm.ForwardPropagation(inputFlat, weights, bias);

% Reshape/transpose back the signal, on python was (N,H,W,C)
activations_reshape = reshape_row_major(activations, [W,C,H,N]);
% On python was transpose(0,3,1,2)
activations = permute(activations_reshape,[3 1 2 4]);

% Store stuff for backpropagation
obj.activations = activations;
obj.weights = weights;
obj.biases = bias;
end

Now for the backpropagation we just reshape and permute again.

function [gradient] = BackwardPropagation(obj, dout)
% Observe that we use the same reshape/permutes from forward
% propagation
dout = dout.input;
[H,W,C,N] = size(dout);
% On python was: x.transpose((0,2,3,1))
dout_transp = permute(dout,[2,3,1,4]);

% Flat the input            
dout_flat = reshape_row_major(dout_transp,[(numel(dout_transp) / C),C]);

% Call the backward propagation of normal batchnorm
gradDout.input = dout_flat;
gradient = obj.normalBatchNorm.BackwardPropagation(gradDout);

% Reshape/transpose back the signal, on python was (N,H,W,C)
gradient.input = reshape_row_major(gradient.input, [W,C,H,N]);
% On python was transpose(0,3,1,2)
gradient.input = permute(gradient.input,[3 1 2 4]);

end
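
For comparison, the same permute/reshape trick sketched in NumPy, reusing the hypothetical batchnorm_forward and batchnorm_backward from the Python sections above (Python tensor format: (batch, channel, rows, cols)):

def spatial_batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x: (N, C, H, W) -> move channels last and flatten to (N*H*W, C)
    N, C, H, W = x.shape
    x_flat = x.transpose(0, 2, 3, 1).reshape(-1, C)
    out_flat, cache = batchnorm_forward(x_flat, gamma, beta, eps)
    # Reshape/transpose back to (N, C, H, W)
    return out_flat.reshape(N, H, W, C).transpose(0, 3, 1, 2), cache

def spatial_batchnorm_backward(dout, cache):
    # Same reshapes applied to the incoming gradient
    N, C, H, W = dout.shape
    dout_flat = dout.transpose(0, 2, 3, 1).reshape(-1, C)
    dx_flat, dgamma, dbeta = batchnorm_backward(dout_flat, cache)
    return dx_flat.reshape(N, H, W, C).transpose(0, 3, 1, 2), dgamma, dbeta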

References

  • Ioffe & Szegedy, "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" - https://arxiv.org/abs/1502.03167
  • https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
  • http://cs231n.github.io/neural-networks-2/

Next Chapter

In the next chapter we will learn how to optimize our model weights.
