Reputation: 61
I'm trying to implement a neural network with aleatoric uncertainty estimation for regression in PyTorch, following
Kendall et al.: "What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?" (Link).
However, while the predicted regression values fit the ground truth quite well, the predicted variance looks weird and the loss becomes negative during training.
The paper suggests having two outputs, mean and variance, instead of predicting only the regression value. More precisely, it suggests predicting the mean and log(variance) for numerical stability. Therefore, my network looks as follows:
class ReferenceResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fcl1 = nn.Linear(1, 32)
        self.fcl2 = nn.Linear(32, 64)
        self.fcl3 = nn.Linear(64, 128)
        self.fcl_mean = nn.Linear(128,1)
        self.fcl_var = nn.Linear(128,1)
    def forward(self, x):
        x = torch.tanh(self.fcl1(x))
        x = torch.tanh(self.fcl2(x))
        x = torch.tanh(self.fcl3(x))
        mean = self.fcl_mean(x)
        log_var = self.fcl_var(x)
        return mean, log_var
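According to the paper, given these outputs, the corresponding loss function consists of a residual regression part and a regularization term:
$$\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \left( \frac{1}{2} \exp(-s_i)\, \lVert y_i - \hat{y}_i \rVert^2 + \frac{1}{2} s_i \right)$$
where $s_i$ is the log(variance) predicted by the network.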
I implemented this loss-function accordingly:
def loss_function(pred_mean, pred_log_var, y):
    return 1/len(pred_mean)*(0.5 * torch.exp(-pred_log_var)*torch.sqrt(torch.pow(y-pred_mean, 2))+0.5*pred_log_var).sum()
I tried this code on a self-generated toy dataset (see image with results). However, the loss becomes negative during training, and when I plot the predicted variance over the dataset after training, it does not really make sense to me, even though the corresponding mean values fit the ground truth quite well:
I already figured out that the negative loss comes from the regularization term, since logarithms are negative for values between 0 and 1. However, I don't believe the absolute value of the regularization term is supposed to grow larger than the regression part. Does anyone know the reason for this and how I can prevent it? And why does my variance look so weird? For reproduction, my full code is as follows:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data.dataset import TensorDataset
from torchvision import datasets, transforms
import math
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ReferenceRegNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fcl1 = nn.Linear(1, 32)
        self.fcl2 = nn.Linear(32, 64)
        self.fcl3 = nn.Linear(64, 128)
        self.fcl_mean = nn.Linear(128,1)
        self.fcl_var = nn.Linear(128,1)
    def forward(self, x):
        x = torch.tanh(self.fcl1(x))
        x = torch.tanh(self.fcl2(x))
        x = torch.tanh(self.fcl3(x))
        mean = self.fcl_mean(x)
        log_var = self.fcl_var(x)
        return mean, log_var
def toy_function(x):
    return math.sin(x/15-4)+2 + math.sin(x/10-5)
def loss_function(x_mean, x_log_var, y):
    return 1/len(x_mean)*(0.5 * torch.exp(-x_log_var)*torch.sqrt(torch.pow(y-x_mean, 2))+0.5*x_log_var).sum()
BATCH_SIZE = 10
EVAL_BATCH_SIZE = 10
CLASSES = 1
TRAIN_EPOCHS = 50
# generate toy dataset: A train-set in form of a complex sin-curve
x_train_data = np.array([])
y_train_data = np.array([])
for repeat in range(2):
    for i in range(50, 150):
        for j in range(100):
            sampled_x = i+np.random.randint(101)/100
            sampled_y = toy_function(sampled_x)+np.random.normal(0,0.2)
            x_train_data = np.append(x_train_data, sampled_x)
            y_train_data = np.append(y_train_data, sampled_y)
x_eval_data = list(np.arange(50.0, 150.0, 0.1))
y_eval_data = [toy_function(x) for x in x_eval_data]
LOADER_KWARGS = {'num_workers': 0, 'pin_memory': False} if torch.cuda.is_available() else {}
train_set = TensorDataset(torch.Tensor(x_train_data),torch.Tensor(y_train_data))
eval_set = TensorDataset(torch.Tensor(x_eval_data), torch.Tensor(y_eval_data))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, **LOADER_KWARGS)
eval_loader = torch.utils.data.DataLoader(eval_set, batch_size=EVAL_BATCH_SIZE, shuffle=False, **LOADER_KWARGS)
TRAIN_SIZE = len(train_loader.dataset)
EVAL_SIZE = len(eval_loader.dataset)
assert (TRAIN_SIZE % BATCH_SIZE) == 0
assert (EVAL_SIZE % EVAL_BATCH_SIZE) == 0
net = ReferenceRegNet().to(DEVICE)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
losses = {}
# train network
for epoch in range(1,TRAIN_EPOCHS+1):
    net.train()
    mean_epoch_loss = 0
    mean_epoch_mse = 0
    # train batches
    for batch_idx, (data, target) in enumerate(tqdm(train_loader), start=1):
        data, target = (data.to(DEVICE)).unsqueeze(dim=1), (target.to(DEVICE)).unsqueeze(dim=1)
        optimizer.zero_grad()
        output_means, output_log_var = net(data)
        target_np = target.detach().cpu().numpy()
        output_means_np = output_means.detach().cpu().numpy()
        loss = loss_function(output_means, output_log_var, target)
        loss_value = loss.item() # get raw float-value out of loss-tensor
        mean_epoch_loss += loss_value
        # optimize network
        loss.backward()
        optimizer.step()
    mean_epoch_loss = mean_epoch_loss / len(train_loader)
    losses.update({epoch:mean_epoch_loss})
    print("Epoch " + str(epoch) + ": Train-Loss = " + str(mean_epoch_loss))
    net.eval()
    with torch.no_grad():
        mean_loss = 0
        mean_mse = 0
        for data, target in eval_loader:
            data, target = (data.to(DEVICE)).unsqueeze(dim=1), (target.to(DEVICE)).unsqueeze(dim=1)
            output_means, output_log_var = net(data) # perform prediction
            target_np = target.detach().cpu().numpy()
            output_means_np = output_means.detach().cpu().numpy()
            mean_loss += loss_function(output_means, output_log_var, target).item()
        mean_loss = mean_loss/len(eval_loader)
        #print("Epoch " + str(epoch) + ": Eval-loss = " + str(mean_loss))
fig = plt.figure(figsize=(40,12)) # create a 40x12 inch figure
ax = fig.add_subplot(1,3,1)
ax.set_title("regression value")
ax.set_xlabel("x")
ax.set_ylabel("regression mean")
ax.plot(x_train_data, y_train_data, 'x', color='black')
ax.plot(x_eval_data, y_eval_data, color='red')
pred_means_list = []
output_vars_list_train = []
output_vars_list_test = []
for x_test in sorted(x_train_data):
    x_test = (torch.Tensor([x_test]).to(DEVICE))
    pred_means, output_log_vars = net.forward(x_test)
    pred_means_list.append(pred_means.detach().cpu())
    output_vars_list_train.append(torch.exp(output_log_vars).detach().cpu())
ax.plot(sorted(x_train_data), pred_means_list, color='blue', label = 'training_perform')
pred_means_list = []
for x_test in x_eval_data:
    x_test = (torch.Tensor([x_test]).to(DEVICE))
    pred_means, output_log_vars = net.forward(x_test)
    pred_means_list.append(pred_means.detach().cpu())
    output_vars_list_test.append(torch.exp(output_log_vars).detach().cpu())
ax.plot(sorted(x_eval_data), pred_means_list, color='green', label = 'eval_perform')
plt.tight_layout()
plt.legend()
ax = fig.add_subplot(1,3,2)
ax.set_title("variance")
ax.set_xlabel("x")
ax.set_ylabel("regression var")
ax.plot(sorted(x_train_data), output_vars_list_train, label = 'training data')
ax.plot(x_eval_data, output_vars_list_test, label = 'test data')
plt.tight_layout()
plt.legend()
ax = fig.add_subplot(1,3,3)
ax.set_title("training loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
lists = sorted(losses.items())
epoch, loss = zip(*lists)
ax.plot(epoch, loss, label = 'loss')
plt.tight_layout()
plt.legend()
plt.savefig('ref_test.png')
Upvotes: 5
Views: 1045
Reputation: 441
TLDR: The optimization drives the loss to a minimum where the gradient becomes zero, regardless of what the nominal loss value is.
A comprehensive explanation by K.Frank:
A smaller loss – algebraically less positive or algebraically more negative – means (or should mean) better predictions. The optimization step uses some version of gradient descent to make your loss smaller. The overall level of the loss doesn’t matter as far as the optimization goes. The gradient tells the optimizer how to change the model parameters to reduce the loss, and it doesn’t care about the overall level of the loss.
An example from the same source:
Consider, for example, optimizing with lossA = MSELoss. Now imagine optimizing with lossB = lossA - 17.2. The 17.2 doesn’t really change anything at all. It is true that “perfect” predictions will yield lossB = -17.2 rather than zero. (lossA will, of course, be zero for “perfect” predictions.) But who cares?
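The quoted point can be checked with a small sketch. It is not taken from the thread: the one-parameter model, the data, and the helper names (`loss_a`, `loss_b`, `numerical_grad`, `gradient_descent`) are all made up for illustration, and it uses plain Python with a finite-difference gradient rather than PyTorch autograd. Both losses should lead gradient descent to the same parameter, even though their final values differ by the constant.

```python
def loss_a(w, xs, ys):
    """Mean squared error of the one-parameter model y = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def loss_b(w, xs, ys):
    """The same loss shifted by a constant, as in the quoted example."""
    return loss_a(w, xs, ys) - 17.2


def numerical_grad(loss, w, xs, ys, eps=1e-6):
    """Central finite-difference estimate of d(loss)/dw."""
    return (loss(w + eps, xs, ys) - loss(w - eps, xs, ys)) / (2 * eps)


def gradient_descent(loss, xs, ys, w=0.0, lr=0.05, steps=200):
    """Plain gradient descent on the single parameter w."""
    for _ in range(steps):
        w -= lr * numerical_grad(loss, w, xs, ys)
    return w


# Toy data generated by y = 2 * x, so w = 2 is a perfect fit.
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]

w_a = gradient_descent(loss_a, xs, ys)
w_b = gradient_descent(loss_b, xs, ys)

print("fitted w:", w_a, w_b)
print("final loss values:", loss_a(w_a, xs, ys), loss_b(w_b, xs, ys))
```

The constant cancels in the gradient, so the optimizer takes the same steps in both runs; only the reported loss value is shifted.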
In your example, you are right: the negative loss value comes from the logarithmic term. This is completely fine; it means that your training is dominated by high-confidence loss terms (small predicted variance, hence a negative log-variance). Regarding the high variance values, I can't comment much, but it should be fine since the loss curve drops as expected.
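To get a rough feeling for how negative the loss can become, here is a back-of-the-envelope sketch. It assumes the squared-residual form of the per-sample loss from the paper, 0.5 * exp(-s) * r^2 + 0.5 * s; the `loss_function` posted in the question uses the absolute residual via `torch.sqrt(torch.pow(...))`, which is a different objective, so the numbers below do not apply to it directly. For a fixed expected squared residual v, setting the derivative with respect to s to zero gives s = log(v) and a minimum of 0.5 + 0.5 * log(v), which is negative whenever v < 1/e. The helper name `nll_term` is hypothetical; the noise level 0.2 is the standard deviation used to generate the toy data.

```python
import math


def nll_term(sq_residual, log_var):
    """Per-sample loss 0.5 * exp(-s) * r^2 + 0.5 * s (squared-residual form)."""
    return 0.5 * math.exp(-log_var) * sq_residual + 0.5 * log_var


# Noise standard deviation used for the toy data in the question.
noise_std = 0.2
expected_sq_residual = noise_std ** 2  # what a well-fitted mean leaves over

# Analytic minimizer: d/ds = -0.5 * exp(-s) * v + 0.5 = 0  =>  s = log(v)
best_log_var = math.log(expected_sq_residual)
min_loss = nll_term(expected_sq_residual, best_log_var)

# Sanity check of the analytic result with a coarse grid search over s.
grid = [i / 100 for i in range(-600, 201)]
grid_best = min(grid, key=lambda s: nll_term(expected_sq_residual, s))

print("optimal log-variance:", best_log_var)
print("loss at the optimum:", min_loss)
print("grid-search minimizer:", grid_best)
```

Under these assumptions, a negative loss is what a well-calibrated model would produce on this data, since the noise variance of 0.04 is well below 1/e.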
Upvotes: 2