Reputation: 340
In my network I want to compute both the forward pass and the backward pass during the forward pass. To do this, I have to manually define a backward method for every layer that appears in the forward pass.
For the activation functions that's easy, and it also worked well for the linear and conv layers, but I'm really struggling with BatchNorm, since the BatchNorm paper only discusses the 1D case.
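To give an idea of the easy cases: for ReLU, the backward pass is just a mask on the incoming gradient, roughly like this (simplified signature, only for illustration; input and grad_output are torch tensors):

def backward_relu(input, grad_output):
    # ReLU passes the gradient through wherever the input was positive
    return grad_output * (input > 0).to(grad_output.dtype)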
So far, my implementation of the BatchNorm2d backward looks like this:
def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    beta = layer.bias
    avg = layer.running_mean
    var = layer.running_var
    eps = layer.eps
    B = input.shape[0]
    # avg, var, gamma and beta are of shape [channel_size]
    # while input, output, grad_output are of shape [batch_size, channel_size, w, h]
    # for my calculations I have to reshape avg, var, gamma and beta to [batch_size, channel_size, w, h] by repeating the channel values over the whole image and batches
    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B  # dL_dxi_hat / sqrt()
    dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta
When I check my gradients with torch.autograd.grad(), I notice that dL_dgamma and dL_dbeta are correct, but dL_dxi is incorrect (by a lot), and I can't find my mistake. Where is it?
For reference, here is the definition of BatchNorm:
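(as in the batch normalization paper, Ioffe & Szegedy 2015, for a mini-batch of size m)

$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_B)^2$$

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$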
And here are the formulas for the derivatives for the 1D case:
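$$\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i} \cdot \gamma$$

$$\frac{\partial \ell}{\partial \sigma_B^2} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_B) \cdot \left(-\frac{1}{2}\right)(\sigma_B^2 + \epsilon)^{-3/2}$$

$$\frac{\partial \ell}{\partial \mu_B} = \left(\sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_B^2 + \epsilon}}\right) + \frac{\partial \ell}{\partial \sigma_B^2} \cdot \frac{\sum_{i=1}^{m} -2(x_i - \mu_B)}{m}$$

$$\frac{\partial \ell}{\partial x_i} = \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma_B^2 + \epsilon}} + \frac{\partial \ell}{\partial \sigma_B^2} \cdot \frac{2(x_i - \mu_B)}{m} + \frac{\partial \ell}{\partial \mu_B} \cdot \frac{1}{m}$$

$$\frac{\partial \ell}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i} \cdot \hat{x}_i, \qquad \frac{\partial \ell}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i}$$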
Upvotes: 3
Views: 1309
Reputation: 1308
def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    gamma = gamma.view(1, -1, 1, 1)  # edit
    # beta = layer.bias
    # avg = layer.running_mean
    # var = layer.running_var
    eps = layer.eps
    B = input.shape[0] * input.shape[2] * input.shape[3]  # edit
    # add new
    mean = input.mean(dim=(0, 2, 3), keepdim=True)
    variance = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (input - mean) / (torch.sqrt(variance + eps))
    dL_dxi_hat = grad_output * gamma
    # dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    # dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dvar = (-0.5 * dL_dxi_hat * (input - mean)).sum((0, 2, 3), keepdim=True) * ((variance + eps) ** -1.5)  # edit
    dL_davg = (-1.0 / torch.sqrt(variance + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + (dL_dvar * (-2.0 * (input - mean)).sum((0, 2, 3), keepdim=True) / B)  # edit
    dL_dxi = (dL_dxi_hat / torch.sqrt(variance + eps)) + (2.0 * dL_dvar * (input - mean) / B) + (dL_davg / B)  # dL_dxi_hat / sqrt()
    # dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dgamma = (grad_output * x_hat).sum((0, 2, 3), keepdim=True)  # edit
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta
The problems with your original code (the fixes are marked # edit and # add new above):
1. Because gamma is of shape [channel_size], you need to reshape it to [1, gamma.shape[0], 1, 1].
2. B = input.shape[0] * input.shape[2] * input.shape[3]: each channel is normalized over the batch and the spatial dimensions, so B has to count all of those elements, not just the batch size.
3. running_mean and running_var are only used in test/inference mode; we don't use them in training (you can find this in the paper). The mean and variance you need are computed from the input: you can store mean, variance and x_hat = (x - mean) / sqrt(variance + eps) in your layer object, or re-compute them as I did in the code above (# add new). Then use them in place of the running statistics in the formulas for dL_dvar, dL_davg and dL_dxi.
4. Your dL_dgamma is incorrect because you multiplied grad_output by output instead of by x_hat; it should be grad_output * x_hat.
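If you want to double-check the result against autograd, something like this works (just a sketch; the channel count, input shape and loss are arbitrary placeholders):

import torch
import torch.nn as nn

layer = nn.BatchNorm2d(3)                        # training mode, so batch statistics are used
x = torch.randn(8, 3, 4, 4, requires_grad=True)
y = layer(x)
loss = (y ** 2).sum()                            # any scalar loss is fine for the check

# grad_output is dL/dy, taken from autograd so both paths see the same upstream gradient
grad_output, = torch.autograd.grad(loss, y, retain_graph=True)
dx_ref, dgamma_ref, dbeta_ref = torch.autograd.grad(loss, (x, layer.weight, layer.bias))

dx, dgamma, dbeta = backward_batchnorm2d(x, y, grad_output, layer)
print(torch.allclose(dx, dx_ref, atol=1e-5))
print(torch.allclose(dgamma.flatten(), dgamma_ref, atol=1e-5))
print(torch.allclose(dbeta.flatten(), dbeta_ref, atol=1e-5))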
Upvotes: 7