Reputation: 109
I am trying to use batch normalization in PyTorch with a simple fully connected neural network to approximate a given function.
The code is below. The results show that the network without batch normalization performs better than the one with batch normalization, i.e., batch normalization actually makes training worse. Could someone explain this result? Thanks!
import matplotlib.pyplot as plt
import numpy as np
import torch


class Net(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs, hidden_size=256, is_bn=True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.is_bn = is_bn
        # no bias is needed in the linear layers that are followed by batch normalization
        if self.is_bn:
            self.linear1 = torch.nn.Linear(num_inputs, hidden_size, bias=False)
            self.linear2 = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        else:
            self.linear1 = torch.nn.Linear(num_inputs, hidden_size)
            self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
        self.linear3 = torch.nn.Linear(hidden_size, num_outputs)
        if self.is_bn:
            self.bn1 = torch.nn.BatchNorm1d(hidden_size)
            self.bn2 = torch.nn.BatchNorm1d(hidden_size)
        self.activation = torch.nn.ReLU()

    def forward(self, inputs):
        x = inputs
        if self.is_bn:
            x = self.activation(self.bn1(self.linear1(x)))
            x = self.activation(self.bn2(self.linear2(x)))
        else:
            x = self.activation(self.linear1(x))
            x = self.activation(self.linear2(x))
        out = self.linear3(x)
        return out


torch.manual_seed(0)  # reproducible

# training data: y = sin(3x) sampled on [-1, 1]
Nx = 100
x = torch.linspace(-1., 1., Nx)
x = torch.reshape(x, (Nx, 1))
y = torch.sin(3*x)

# two identical networks, one with and one without batch normalization
fcn_bn, fcn_no_bn = Net(num_inputs=1, num_outputs=1, is_bn=True), Net(num_inputs=1, num_outputs=1, is_bn=False)
criterion = torch.nn.MSELoss()
optimizer_bn = torch.optim.Adam(fcn_bn.parameters(), lr=0.001)
optimizer_no_bn = torch.optim.Adam(fcn_no_bn.parameters(), lr=0.001)
total_epoch = 5000

# record loss history
loss_history_bn = np.zeros(total_epoch)
loss_history_no_bn = np.zeros(total_epoch)

fcn_bn.train()
fcn_no_bn.train()
for epoch in range(total_epoch):
    # full-batch gradient step for the network with batch normalization
    optimizer_bn.zero_grad()
    loss = criterion(fcn_bn(x), y)
    loss_history_bn[epoch] = loss.item()
    loss.backward()
    optimizer_bn.step()

    # full-batch gradient step for the network without batch normalization
    optimizer_no_bn.zero_grad()
    loss = criterion(fcn_no_bn(x), y)
    loss_history_no_bn[epoch] = loss.item()
    loss.backward()
    optimizer_no_bn.step()

    if epoch % 1000 == 0:
        print("epoch: %d; MSE (with bn): %.2e; MSE (without bn): %.2e"
              % (epoch, loss_history_bn[epoch], loss_history_no_bn[epoch]))

fcn_bn.eval()
fcn_no_bn.eval()

# training loss curves
plt.figure()
plt.semilogy(np.arange(total_epoch), loss_history_bn, label='neural network (with bn)')
plt.semilogy(np.arange(total_epoch), loss_history_no_bn, label='neural network (without bn)')
plt.legend()

# predictions vs. exact function
plt.figure()
plt.plot(x, y, '-', label='exact')
plt.plot(x, fcn_bn(x).detach(), 'o', markersize=2, label='neural network (with bn)')
plt.plot(x, fcn_no_bn(x).detach(), 'o', markersize=2, label='neural network (without bn)')
plt.legend()

# pointwise absolute error
plt.figure()
plt.plot(x, np.abs(fcn_bn(x).detach() - y), 'o', markersize=2, label='neural network (with bn)')
plt.plot(x, np.abs(fcn_no_bn(x).detach() - y), 'o', markersize=2, label='neural network (without bn)')
plt.legend()
plt.show()
The result is as follows:
epoch: 0; MSE (with bn): 3.99e-01; MSE (without bn): 4.84e-01
epoch: 1000; MSE (with bn): 4.70e-05; MSE (without bn): 1.27e-06
epoch: 2000; MSE (with bn): 1.81e-04; MSE (without bn): 7.93e-07
epoch: 3000; MSE (with bn): 2.73e-04; MSE (without bn): 7.45e-07
epoch: 4000; MSE (with bn): 4.04e-04; MSE (without bn): 5.68e-07
Upvotes: 1
Views: 1604
Reputation: 1257
To provide an alternative view to the answer Khalid linked in the comments, which focuses more on generalization performance than on training loss, consider this:
Batch normalization has been postulated to have a regularizing effect. Luo et al. analyze BN as a decomposition into population normalization and gamma decay, and they observe training loss curves similar to yours when comparing BN to no BN (note, however, that they use vanilla SGD rather than Adam). Several things affect BN, as also outlined in Khalid's link. The batch size, for example, should on the one hand be large enough for a robust estimate of the population parameters, yet on the other hand generalization performance can drop as the batch size grows (see Luo et al.'s paper; the gist is that smaller batches yield noisier estimates of the population parameters, which essentially perturbs the input).
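As a rough illustration of that last point, here is a minimal sketch (not taken from Luo et al.; the batch sizes and the synthetic values below are arbitrary choices for illustration) that measures how much the per-batch mean and standard deviation, which BN uses in training mode, fluctuate for different batch sizes:
import torch

torch.manual_seed(0)

# toy pre-activation values for a single hidden unit
values = torch.randn(10000) * 2.0 + 1.0  # population mean ~1.0, std ~2.0

for batch_size in [4, 16, 64, 256]:
    means, stds = [], []
    for _ in range(1000):
        # draw a random batch and record the statistics BN would use in training mode
        idx = torch.randint(0, values.numel(), (batch_size,))
        batch = values[idx]
        means.append(batch.mean())
        stds.append(batch.std(unbiased=False))
    means, stds = torch.stack(means), torch.stack(stds)
    # the spread of these estimates is the noise BN injects during training
    print("batch size %4d: std of batch means = %.3f, std of batch stds = %.3f"
          % (batch_size, means.std().item(), stds.std().item()))
The spread of these estimates shrinks roughly like 1/sqrt(batch size), which is the sense in which larger batches give more robust population statistics, while smaller batches effectively inject noise into the normalized activations.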
In your case, I would not intuitively have expected a big difference (given how your data is set up), but perhaps someone more versed in the theoretical analysis of BN can still provide insights.
Upvotes: 2