# Deep Learning T04 - Batch Normalization

In this exercise notebook, you will implement the Batch Normalization (BN) operator in PyTorch, verify that your implementation passes a selection of tests, and answer and discuss a few questions.

Implementing BN involves a number of steps:

1. Inputting the mini-batch $\{x_i\}$ together with the saved moving averages of the activations' mean and variance
2. Defining learnable parameters $\gamma$ and $\beta$
3. Calculating the mini-batch mean $\mu_\mathcal{B}$
4. Calculating the mini-batch variance $\sigma^2_\mathcal{B}$
5. Normalizing $x_i$ to have zero mean and unit variance across the batch dimension: $\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}}$
6. Scaling and shifting using the learnable parameters: $z_i = \gamma \odot \hat{x}_i + \beta$
7. Updating the moving averages of the mean and variance
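The seven steps can be traced on a toy FC input. The sketch below is a standalone illustration, not the class you are asked to complete; the tensor sizes, the initial running statistics (mean 0, variance 1) and the names `moving_mu`/`moving_var` are assumptions. It follows this notebook's conventions: the unbiased variance (the default of `torch.var`, which the tests also use) and `momentum` weighting the *old* running value.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: a mini-batch of 4 samples with 3 features (FC-style input).
x = torch.randn(4, 3)
eps, momentum = 1e-5, 0.9

# Step 2: learnable scale and shift, shaped (1, L) so they broadcast over the batch.
gamma = torch.ones(1, 3, requires_grad=True)
beta = torch.zeros(1, 3, requires_grad=True)

# Running statistics (not learnable); initial values are an assumption.
moving_mu = torch.zeros(1, 3)
moving_var = torch.ones(1, 3)

# Steps 3-4: per-feature statistics over the batch dimension.
mu = x.mean(dim=0, keepdim=True)   # shape (1, 3)
var = x.var(dim=0, keepdim=True)   # unbiased, torch.var's default

# Step 5: normalize.
x_hat = (x - mu) / torch.sqrt(var + eps)

# Step 6: scale and shift.
z = gamma * x_hat + beta

# Step 7: update the running statistics outside the autograd graph.
with torch.no_grad():
    moving_mu = momentum * moving_mu + (1 - momentum) * mu
    moving_var = momentum * moving_var + (1 - momentum) * var

print(x_hat.mean(dim=0))  # approximately zero for every feature
print(x_hat.var(dim=0))   # approximately one for every feature
```

PyTorch's built-in `nn.BatchNorm1d`/`nn.BatchNorm2d` use the opposite convention, where `momentum` (default 0.1) weights the *new* batch statistic, and they normalize with the biased variance, so their outputs will not match this sketch exactly.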

At test time, the mini-batch operations in steps 3 and 4 are replaced with moving averages of the mean and variance that are computed during training.
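A minimal sketch of test-time behaviour, assuming made-up running statistics and a hypothetical helper `bn_eval`: once the stored statistics replace the mini-batch ones, each sample is normalized independently of whatever else happens to be in the batch.

```python
import torch

torch.manual_seed(0)
eps = 1e-5

# Made-up running statistics standing in for values accumulated during training.
moving_mu = torch.tensor([[0.5, -1.0, 2.0]])   # shape (1, L)
moving_var = torch.tensor([[1.0, 4.0, 0.25]])  # shape (1, L)
gamma = torch.ones(1, 3)
beta = torch.zeros(1, 3)

def bn_eval(x):
    # Hypothetical helper: the stored statistics replace the batch mean and variance.
    x_hat = (x - moving_mu) / torch.sqrt(moving_var + eps)
    return gamma * x_hat + beta

batch = torch.randn(8, 3)
full = bn_eval(batch)
single = bn_eval(batch[:1])  # the same first sample, processed on its own

# Each output depends only on its own input, not on the rest of the batch.
print(torch.allclose(full[:1], single))  # True
```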

Additionally, we will only implement BN for the following inputs:

* Output of a fully connected layer (shape: (batch_size, L))

    * Mean and variance are computed over the batch dimension, retaining the **feature dimension, L**

* Output of a convolutional layer (shape: (batch_size, C, H, W))

    * Mean and variance are computed over the batch and spatial dimensions, retaining the **channel dimension, C**
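The reduction dimensions above map directly onto `mean`/`var` calls. This sketch borrows the shapes of the test inputs (everything else is illustrative) and shows how `keepdim=True` leaves singleton axes so that the statistics broadcast back against the input.

```python
import torch

x_fc = torch.randn(64, 128)           # (batch_size, L)
x_conv = torch.randn(64, 16, 32, 32)  # (batch_size, C, H, W)

# FC: reduce over the batch dimension only, keeping L.
mu_fc = x_fc.mean(dim=0, keepdim=True)  # shape (1, 128)
var_fc = x_fc.var(dim=0, keepdim=True)  # shape (1, 128)

# Conv: reduce over the batch and both spatial dimensions, keeping C.
mu_conv = x_conv.mean(dim=(0, 2, 3), keepdim=True)  # shape (1, 16, 1, 1)
var_conv = x_conv.var(dim=(0, 2, 3), keepdim=True)  # shape (1, 16, 1, 1)

# The singleton axes broadcast the per-channel statistics over N, H and W.
x_conv_hat = (x_conv - mu_conv) / torch.sqrt(var_conv + 1e-5)
print(mu_fc.shape, mu_conv.shape, x_conv_hat.shape)
# torch.Size([1, 128]) torch.Size([1, 16, 1, 1]) torch.Size([64, 16, 32, 32])
```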

Note: passing the tests is not a guarantee that your implementation is perfect.
Check the model answer when it is released to confirm that your implementation is correct. **Try not to look at the test code unless you get completely stuck, as this may give hints about how to complete the exercise.**



**Questions**:
1. Why do we use $\epsilon$ in the denominator when normalizing the input?
2. Why do we use moving averages of the mean and variance at test time to compute the BN operator?
3. Is there any need for a bias term in the layer preceding the BN operator? Explain.
4. What kind of modifications to the learning rate might BN enable?
5. How might BN affect the network's sensitivity to weight initialization?
6. How might the batch size relate to BN's suggested regularization effect?

Complete the following skeleton by **replacing any lines where variables are set to placeholder `torch.zeros(1)` or `torch.ones(1)` and adding any other code you need**, and verify that the tests in the cell below pass. **Do not modify any other provided code or variable names** as this may break the tests.

In [1]:
import torch
import torch.nn as nn

class BatchNorm(nn.Module):
    def __init__(self, shape, eps=1e-5, momentum=0.9):
        """
        shape: Expected shape of input
        eps: epsilon used in normalization step
        momentum: momentum value used to update moving averages
        """
        super().__init__()
        if len(shape) not in (2, 4):
            raise ValueError("Invalid input shape!")

        self.eps = eps
        self.momentum = momentum

        # PART 1: define PyTorch learnable parameters (hint: use the `shape`
        # argument; the parameter shapes should depend on whether the input comes
        # from an FC or a Conv layer). NOTE: the tests expect broadcasting to be
        # used in PART 8 below, so define your parameters accordingly.
        # UPDATE:
        self.gamma = torch.ones(1)
        self.beta = torch.ones(1)

        # PART 2: initialize moving avg variables (hint: these are NOT learnable 
        # parameters)
        # UPDATE:
        self.moving_mu = torch.zeros(1)
        self.moving_sigma = torch.zeros(1)

    def forward(self, x):
        # Test
        if not torch.is_grad_enabled():
            # PART 3: Test time normalization operation; use self.eps as epsilon
            # UPDATE:
            x_hat = torch.zeros(1)

            # Logging code for tests; ignore:
            self._tmp_x_hat_test = x_hat
        
        # Training
        else:
            if len(x.shape) == 2:
                # PART 4: Compute mean and var for FC input (retaining feature dim)
                # UPDATE:
                mean = torch.zeros(1)
                var = torch.zeros(1)

            elif len(x.shape) == 4:
                # PART 5: Compute mean and var for Conv input (retaining channel dim)
                # UPDATE (hint: use `keepdim` flag to use broadcasting later):
                mean = torch.zeros(1)
                var = torch.zeros(1)
            else:
                raise ValueError("Incorrect input shape!")
            
            # Logging code for tests; ignore:
            self._tmp_mean = mean
            self._tmp_var = var

            # PART 6: Training time normalization operation; use self.eps as epsilon
            # UPDATE:
            x_hat = torch.zeros(1)

            # Logging code for tests; ignore:
            self._tmp_x_hat_train = x_hat
            
            # PART 7: Updating moving averages; use self.momentum to calculate
            # contribution to update (hint: be careful about unnecessary 
            # autograd computation tracking)
            # UPDATE:
            self.moving_mu = torch.zeros(1)
            self.moving_sigma = torch.zeros(1)

            # Logging code for tests; ignore:
            self._tmp_moving_mu = self.moving_mu
            self._tmp_moving_sigma = self.moving_sigma

        # PART 8: Scale and shift x_hat using learnable parameters to compute output
        # UPDATE:
        z = torch.zeros(1)

        return z

In [2]:
#@title To run the tests, first run your code in the cell above to define your BN implementation, and then run this block. Do not read the code in this block until the solutions are released!
torch.manual_seed(0)

test_conv_input = torch.randn(64, 16, 32, 32)
test_fc_input = torch.randn(64, 128)

def test_is_param():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)
    if not bn_fc.gamma.requires_grad:
        return False
    if not bn_fc.beta.requires_grad:
        return False
    if not bn_conv.gamma.requires_grad:
        return False
    if not bn_conv.beta.requires_grad:
        return False
    return True

def test_param_shapes():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)
    if bn_fc.gamma.shape != (1, 128):
        return False
    if bn_fc.beta.shape != (1, 128):
        return False
    if bn_conv.gamma.shape != (1, 16, 1, 1):
        return False
    if bn_conv.beta.shape != (1, 16, 1, 1):
        return False
    return True

def test_ma_shapes():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)
    if bn_fc.moving_mu.shape != (1, 128):
        return False
    if bn_fc.moving_sigma.shape != (1, 128):
        return False
    if bn_conv.moving_mu.shape != (1, 16, 1, 1):
        return False
    if bn_conv.moving_sigma.shape != (1, 16, 1, 1):
        return False
    return True

def test_test_norm():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)
    bn_fc.moving_mu = torch.randn((1, 128))
    bn_fc.moving_sigma = torch.rand((1, 128)) + 0.01
    bn_conv.moving_mu = torch.randn((1, 16, 1, 1))
    bn_conv.moving_sigma = torch.rand((1, 16, 1, 1)) + 0.01
    try:
        with torch.no_grad():
            bn_fc(test_fc_input)
            bn_conv(test_conv_input)

            tmp_fc_xhat = bn_fc._tmp_x_hat_test
            expected = (test_fc_input - bn_fc.moving_mu) / torch.sqrt(bn_fc.moving_sigma + bn_fc.eps)
            if not torch.allclose(expected, tmp_fc_xhat):
                return False

            tmp_conv_xhat = bn_conv._tmp_x_hat_test
            expected = (test_conv_input - bn_conv.moving_mu) / torch.sqrt(bn_conv.moving_sigma + bn_conv.eps)
            if not torch.allclose(expected, tmp_conv_xhat):
                return False
    except:
        return False

    return True

def test_train_mean():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)
    try:
        bn_fc(test_fc_input)
        bn_conv(test_conv_input)

        bn_fc_mean = bn_fc._tmp_mean
        expected_fc = test_fc_input.mean(dim=0)
                
        bn_conv_mean = bn_conv._tmp_mean
        expected_conv = test_conv_input.mean(dim=(0, 2, 3), keepdim=True)

        if not torch.allclose(bn_fc_mean, expected_fc):
            return False
        
        if not torch.allclose(bn_conv_mean, expected_conv):
            return False

    except Exception as e:
        print("Test 4 - trace: ", e)
        return False
    
    return True

def test_train_var():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)
    try:
        bn_fc(test_fc_input)
        bn_conv(test_conv_input)

        bn_fc_var = bn_fc._tmp_var
        expected_fc = test_fc_input.var(dim=0)
                
        bn_conv_var = bn_conv._tmp_var
        expected_conv = test_conv_input.var(dim=(0, 2, 3), keepdim=True)

        if not torch.allclose(bn_fc_var, expected_fc):
            return False
        
        if not torch.allclose(bn_conv_var, expected_conv):
            return False
    except Exception as e:
        print("Test 5 - trace: ", e)
        return False

    return True

def test_train_norm():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)

    try:
        bn_fc(test_fc_input)
        bn_conv(test_conv_input)
        
        expected_fc_mean = test_fc_input.mean(dim=0)
        expected_fc_var = test_fc_input.var(dim=0)
        bn_fc_xhat = bn_fc._tmp_x_hat_train
        expected_fc = (test_fc_input - expected_fc_mean) / torch.sqrt(expected_fc_var + bn_fc.eps)
        
        expected_conv_mean = test_conv_input.mean(dim=(0, 2, 3), keepdim=True)
        expected_conv_var = test_conv_input.var(dim=(0, 2, 3), keepdim=True)                
        bn_conv_xhat = bn_conv._tmp_x_hat_train
        expected_conv = (test_conv_input - expected_conv_mean) / torch.sqrt(expected_conv_var + bn_conv.eps)

        if not torch.allclose(bn_fc_xhat, expected_fc):
            return False
        if not torch.allclose(bn_conv_xhat, expected_conv):
            return False
    except Exception as e:
        print("Test 6 - trace: ", e)
        return False
    
    return True

def test_train_mov_mean():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)

    cached_fc_mean = torch.randn((1, 128))
    cached_conv_mean = torch.randn((1, 16, 1, 1))   

    bn_fc.moving_mu = cached_fc_mean
    bn_conv.moving_mu = cached_conv_mean

    expected_fc_mean = test_fc_input.mean(dim=0).data
    expected_conv_mean = test_conv_input.mean(dim=(0, 2, 3), keepdim=True).data

    try:
        bn_fc(test_fc_input)
        bn_conv(test_conv_input)

        bn_fc_moving_mu = bn_fc._tmp_moving_mu
        bn_conv_moving_mu = bn_conv._tmp_moving_mu

        expected_fc = bn_fc.momentum * cached_fc_mean + (1.0 - bn_fc.momentum) * expected_fc_mean
        expected_conv = bn_conv.momentum * cached_conv_mean + (1.0 - bn_conv.momentum) * expected_conv_mean

        if not torch.allclose(expected_fc, bn_fc_moving_mu):
            return False
        if not torch.allclose(expected_conv, bn_conv_moving_mu):
            return False

    except Exception as e:
        print("Test 7 - trace: ", e)
        return False

    return True

def test_train_mov_var():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)

    cached_fc_sigma = torch.rand((1, 128)) + 0.01
    cached_conv_sigma = torch.rand((1, 16, 1, 1)) + 0.01
    
    bn_fc.moving_sigma = cached_fc_sigma
    bn_conv.moving_sigma = cached_conv_sigma

    expected_fc_var = test_fc_input.var(dim=0)
    expected_conv_var = test_conv_input.var(dim=(0, 2, 3), keepdim=True)

    try:
        bn_fc(test_fc_input)
        bn_conv(test_conv_input)

        bn_fc_moving_sigma = bn_fc._tmp_moving_sigma
        bn_conv_moving_sigma = bn_conv._tmp_moving_sigma

        expected_fc = bn_fc.momentum * cached_fc_sigma + (1.0 - bn_fc.momentum) * expected_fc_var
        expected_conv = bn_conv.momentum * cached_conv_sigma + (1.0 - bn_conv.momentum) * expected_conv_var

        if not torch.allclose(expected_fc, bn_fc_moving_sigma):
            return False
        if not torch.allclose(expected_conv, bn_conv_moving_sigma):
            return False

    except Exception as e:
        print("Test 8 - trace: ", e)
        return False

    return True

def test_scale_and_shift():
    bn_fc = BatchNorm(test_fc_input.shape)
    bn_conv = BatchNorm(test_conv_input.shape)

    try:
        res_fc = bn_fc(test_fc_input)
        res_conv = bn_conv(test_conv_input)

        expected_fc = bn_fc.gamma * bn_fc._tmp_x_hat_train + bn_fc.beta
        expected_conv = bn_conv.gamma * bn_conv._tmp_x_hat_train + bn_conv.beta

        if not torch.allclose(expected_fc, res_fc):
            return False
        if not torch.allclose(expected_conv, res_conv):
            return False

        if res_fc.shape != test_fc_input.shape:
            return False
        if res_conv.shape != test_conv_input.shape:
            return False

    except Exception as e:
        print("Test 9 - trace: ", e)
        return False

    return True

def run_tests():
    results = {
        "0 - Learnable parameters set correctly\t" : test_is_param(),
        "1 - Learnable parameter shapes correct\t" : test_param_shapes(),
        "2 - Moving average shapes correct\t" : test_ma_shapes(),
        "3 - Test time normalization\t\t" : test_test_norm(),
        "4 - Mean computation\t\t\t" : test_train_mean(),
        "5 - Variance computation\t\t" : test_train_var(),
        "6 - Train time normalization\t\t" : test_train_norm(),
        "7 - Train time mean moving average\t" : test_train_mov_mean(),
        "8 - Train time variance moving average\t" : test_train_mov_var(),
        "9 - Final scale and shift\t\t" : test_scale_and_shift()        
    }
    total = sum([v for v in results.values()])
    print()
    print("{}/10 TESTS PASSED:".format(total))
    print("#################")
    for k, v in sorted(results.items()):
        print("{}: {}".format(k, "Pass" if v else "*FAIL*"))

run_tests()


0/10 TESTS PASSED:
#################
0 - Learnable parameters set correctly	: *FAIL*
1 - Learnable parameter shapes correct	: *FAIL*
2 - Moving average shapes correct	: *FAIL*
3 - Test time normalization		: *FAIL*
4 - Mean computation			: *FAIL*
5 - Variance computation		: *FAIL*
6 - Train time normalization		: *FAIL*
7 - Train time mean moving average	: *FAIL*
8 - Train time variance moving average	: *FAIL*
9 - Final scale and shift		: *FAIL*
