Michael

Reputation: 109

Batch normalization makes training worse

I am trying to implement batch normalization in PyTorch, using a simple fully connected neural network to approximate a given function.

The code is below. The results show that the network without batch normalization reaches a lower training loss than the one with batch normalization, i.e. batch normalization actually makes training worse here. Could someone explain this result? Thanks!

import matplotlib.pyplot as plt
import numpy as np
import torch

class Net(torch.nn.Module):
    
    def __init__(self, num_inputs, num_outputs, hidden_size=256, is_bn=True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.is_bn = is_bn
        
        # no bias is needed before batch normalization: BN's learnable shift (beta) makes it redundant
        if self.is_bn:
            self.linear1 = torch.nn.Linear(num_inputs, hidden_size, bias=False)
            self.linear2 = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        else:            
            self.linear1 = torch.nn.Linear(num_inputs, hidden_size)
            self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
                
        self.linear3 = torch.nn.Linear(hidden_size, num_outputs)
        
        if self.is_bn:
            self.bn1 = torch.nn.BatchNorm1d(hidden_size)
            self.bn2 = torch.nn.BatchNorm1d(hidden_size)

        self.activation = torch.nn.ReLU()
        
    def forward(self, inputs):
        x = inputs
        if self.is_bn:
            x = self.activation(self.bn1(self.linear1(x)))
            x = self.activation(self.bn2(self.linear2(x)))
        else:
            x = self.activation(self.linear1(x))
            x = self.activation(self.linear2(x))
        out = self.linear3(x)        
        return out


torch.manual_seed(0)    # reproducible

Nx = 100
x = torch.linspace(-1., 1., Nx)
x = torch.reshape(x, (Nx, 1))
y = torch.sin(3*x)

fcn_bn, fcn_no_bn = Net(num_inputs=1, num_outputs=1, is_bn=True), Net(num_inputs=1, num_outputs=1, is_bn=False)

criterion = torch.nn.MSELoss()
optimizer_bn = torch.optim.Adam(fcn_bn.parameters(), lr=0.001)
optimizer_no_bn = torch.optim.Adam(fcn_no_bn.parameters(), lr=0.001)

total_epoch = 5000

# record loss history    
loss_history_bn = np.zeros(total_epoch)
loss_history_no_bn = np.zeros(total_epoch)

fcn_bn.train()
fcn_no_bn.train()
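# note: training is full-batch -- every epoch the whole dataset (Nx = 100 points) is passed as a single batch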
for epoch in range(total_epoch):
        
    optimizer_bn.zero_grad()
    loss = criterion(fcn_bn(x), y)    
    loss_history_bn[epoch] = loss.item()
    loss.backward()
    optimizer_bn.step()

    optimizer_no_bn.zero_grad()
    loss = criterion(fcn_no_bn(x), y)    
    loss_history_no_bn[epoch] = loss.item()
    loss.backward()
    optimizer_no_bn.step()
    
    if epoch%1000 == 0:
        print("epoch: %d; MSE (with bn): %.2e; MSE (without bn): %.2e"%(epoch, loss_history_bn[epoch], loss_history_no_bn[epoch]))
        
fcn_bn.eval()
fcn_no_bn.eval()

plt.figure()
plt.semilogy(np.arange(total_epoch), loss_history_bn, label='neural network (with bn)')
plt.semilogy(np.arange(total_epoch), loss_history_no_bn, label='neural network (without bn)')
plt.legend()

plt.figure()
plt.plot(x, y, '-', label='exact')
plt.plot(x, fcn_bn(x).detach(), 'o', markersize=2, label='neural network (with bn)')
plt.plot(x, fcn_no_bn(x).detach(), 'o', markersize=2, label='neural network (without bn)')
plt.legend()

plt.figure()
plt.plot(x, np.abs(fcn_bn(x).detach() - y), 'o', markersize=2, label='neural network (with bn)')
plt.plot(x, np.abs(fcn_no_bn(x).detach() - y), 'o', markersize=2, label='neural network (without bn)')
plt.legend()

plt.show()

The result is as follows:

epoch: 0; MSE (with bn): 3.99e-01; MSE (without bn): 4.84e-01
epoch: 1000; MSE (with bn): 4.70e-05; MSE (without bn): 1.27e-06
epoch: 2000; MSE (with bn): 1.81e-04; MSE (without bn): 7.93e-07
epoch: 3000; MSE (with bn): 2.73e-04; MSE (without bn): 7.45e-07
epoch: 4000; MSE (with bn): 4.04e-04; MSE (without bn): 5.68e-07

Upvotes: 1

Views: 1604

Answers (1)

sim

Reputation: 1257

To provide an alternative view to the answer Khalid linked in the comments, which focuses more on generalization performance than on training loss, consider this:

Batch Normalization has been postulated to have a regularizing effect. Luo et al. analyze BN as a decomposition into population normalization and gamma decay, and they observe training loss curves similar to yours when comparing BN to no BN (note, however, that they use vanilla SGD rather than Adam). As also outlined in Khalid's link, several factors affect BN. The batch size, for example, should be large enough for a robust estimate of the population statistics, yet increasing the batch size can also hurt generalization (see Luo et al.'s paper: the gist is that smaller batch sizes yield noisier estimates of the population statistics, which essentially perturbs the input).
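If you want to probe that batch-size effect with your own Net class, a minimal sketch could look like the following. The mini_batch_train helper, the DataLoader wrapping, and the batch_size values are my illustrative choices, not part of your original code; it assumes the Net class, x and y from your question are in scope.

from torch.utils.data import DataLoader, TensorDataset

def mini_batch_train(model, x, y, batch_size, epochs=2000, lr=1e-3):
    # illustrative helper (not from the original post): trains `model` with
    # mini-batches, so BN estimates its statistics from `batch_size` samples at a time
    loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
    return model

# e.g. compare a small and a large batch size (values are arbitrary examples):
# mini_batch_train(Net(num_inputs=1, num_outputs=1, is_bn=True), x, y, batch_size=10)
# mini_batch_train(Net(num_inputs=1, num_outputs=1, is_bn=True), x, y, batch_size=100)  # = full batch, as in your post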

In your case I would not intuitively have expected a big difference (given how your data is set up: you train on the full dataset as a single batch), but perhaps someone deeper into the theoretical analysis of BN can provide further insight.
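One thing you can check directly with your trained fcn_bn is how different its output is in train() mode (BN normalizes with the statistics of the current batch) versus eval() mode (BN normalizes with its accumulated running estimates). A minimal diagnostic sketch, assuming fcn_bn and x from your post:

fcn_bn.train()
with torch.no_grad():
    out_batch_stats = fcn_bn(x)      # BN uses the statistics of this batch

fcn_bn.eval()
with torch.no_grad():
    out_running_stats = fcn_bn(x)    # BN uses the running estimates

print("max |train-mode - eval-mode| output gap: %.2e"
      % (out_batch_stats - out_running_stats).abs().max().item())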

Upvotes: 2
