Reputation: 2191
I’m trying to generate time-series data with an LSTM and a Mixture Density Network as described in https://arxiv.org/pdf/1308.0850.pdf Here is a link to my implementation: https://github.com/NeoVand/MDNLSTM The repository contains a toy dataset to train the network. On training, the LSTM layer returns nan for its hidden state after one iteration. A similar issue is reported here. For your convenience, here is the code:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import numpy.random as npr
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ts = torch.load('LDS_Toy_Data.pt')
def detach(states):
return [state.detach() for state in states]
class MDNLSTM(nn.Module):
def __init__(self, d_obs, d_lat=2, n_gaussians=2, n_layers=1):
super(MDNLSTM, self).__init__()
self.d_obs = d_obs
self.d_lat = d_lat
self.n_gaussians = n_gaussians
self.n_layers = n_layers
self.lstm = nn.LSTM(d_obs, d_lat, n_layers, batch_first=True)
self.fcPi = nn.Linear(d_lat, n_gaussians*d_obs)
self.fcMu = nn.Linear(d_lat, n_gaussians*d_obs)
self.fcSigma = nn.Linear(d_lat, n_gaussians*d_obs)
def get_mixture_coef(self, y):
time_steps = y.size(1)
pi, mu, sigma = self.fcPi(y), self.fcMu(y), self.fcSigma(y)
pi = pi.view(-1, time_steps, self.n_gaussians, self.d_obs)
mu = mu.view(-1, time_steps, self.n_gaussians, self.d_obs)
sigma = sigma.view(-1, time_steps, self.n_gaussians, self.d_obs)
pi = F.softmax(pi, 2)
sigma = torch.exp(sigma)
return pi, mu, sigma
def forward(self, x, h):
y, (h, c) = self.lstm(x, h)
#print(h)
pi, mu, sigma = self.get_mixture_coef(y)
return (pi, mu, sigma), (h, c)
def init_hidden(self, bsz):
return (torch.zeros(self.n_layers, bsz, self.d_lat).to(device),
torch.zeros(self.n_layers, bsz, self.d_lat).to(device))
def mdn_loss_fn(y, pi, mu, sigma):
m = torch.distributions.Normal(loc=mu, scale=sigma)
loss = torch.exp(m.log_prob(y))
loss = torch.sum(loss * pi, dim=2)
loss = -torch.log(loss)
return loss.mean()
def criterion(y, pi, mu, sigma):
y = y.unsqueeze(2)
return mdn_loss_fn(y, pi, mu, sigma)
DOBS = 10
DLAT = 2
INSTS = 100
seqlen = 30
epochs = 200
mdnlstm = MDNLSTM(DOBS, DLAT).to(device)
optimizer = torch.optim.Adam(mdnlstm.parameters())
z = torch.from_numpy(ts[:INSTS,:,:]).float().to(device)
# hiddens=[]
# Train the model
for epoch in range(epochs):
# Set initial hidden and cell states
hidden = mdnlstm.init_hidden(INSTS)
for i in range(0, z.size(1) - seqlen, seqlen):
# Get mini-batch inputs and targets
inputs = z[:, i:i+seqlen, :]
targets = z[:, (i+1):(i+1)+seqlen, :]
hidden = detach(hidden)
# hiddens.append(hidden)
(pi, mu, sigma), hidden = mdnlstm(inputs, hidden)
loss = criterion(targets, pi, mu, sigma)
mdnlstm.zero_grad()
loss.backward()
optimizer.step()
if epoch % 100 == 0:
print ('Epoch [{}/{}], Loss: {:.4f}'
.format(epoch, epochs, loss.item()))
I would appreciate any help on this.
Upvotes: 0
Views: 2479
Reputation: 2191
The issue was caused by the log-sum-exp operation not being done in a stable way. Here is an implementation of a weighted log-sum-exp trick that I used and could fix the problem:
def weighted_logsumexp(x,w, dim=None, keepdim=False):
if dim is None:
x, dim = x.view(-1), 0
xm, _ = torch.max(x, dim, keepdim=True)
x = torch.where(
# to prevent nasty nan's
(xm == float('inf')) | (xm == float('-inf')),
xm,
xm + torch.log(torch.sum(torch.exp(x - xm)*w, dim, keepdim=True)))
return x if keepdim else x.squeeze(dim)
and using that implemented the stable loss function:
def mdn_loss_stable(y,pi,mu,sigma):
m = torch.distributions.Normal(loc=mu, scale=sigma)
m_lp_y = m.log_prob(y)
loss = -weighted_logsumexp(m_lp_y,pi,dim=2)
return loss.mean()
This worked like a charm. In general, the problem is that torch won't report under-flows.
Upvotes: 2