Reputation: 3847
I am working with the REINFORCE algorithm in PyTorch. I noticed that the batched predictions of my simple network with a Softmax output layer don't sum to 1 per sample (not even close to 1), while predictions made one sample at a time do. I am attaching a minimal working example so that you can reproduce it. What am I missing here?
import numpy as np
import torch
obs_size = 9
HIDDEN_SIZE = 9
n_actions = 2
np.random.seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(obs_size, HIDDEN_SIZE),
    torch.nn.ReLU(),
    torch.nn.Linear(HIDDEN_SIZE, n_actions),
    torch.nn.Softmax(dim=0)
)
state_transitions = np.random.rand(3, obs_size)
state_batch = torch.Tensor(state_transitions)
pred_batch = model(state_batch) # WRONG PREDICTIONS!
print('wrong predictions:\n', *pred_batch.detach().numpy())
# [0.34072137 0.34721774] [0.30972624 0.30191955] [0.3495524 0.3508627]
# DOES NOT SUM TO 1 !!!
pred_batch = [model(s).detach().numpy() for s in state_batch] # CORRECT PREDICTIONS
print('correct predictions:\n', *pred_batch)
# [0.5955179 0.40448207] [0.6574412 0.34255883] [0.624833 0.37516695]
# DOES SUM TO 1 AS EXPECTED
Upvotes: 1
Views: 1328
Reputation: 24181
Although PyTorch lets us get away with it, passing a single 1-D sample does not give the model an input with the right dimensionality. The model maps one input to one output, but PyTorch's nn.Module and its subclasses are designed to do so for multiple samples at the same time. To accommodate multiple samples, modules expect the zeroth dimension of the input to be the number of samples in the batch.
That your model works on each individual sample is an implementation nicety: for a 1-D input, dimension 0 happens to be the action dimension. For a batched input of shape (batch, n_actions), however, dim=0 is the batch dimension, so the softmax is computed across samples instead of within each sample. That is why each column of your "wrong" output sums to 1 (for example 0.3407 + 0.3097 + 0.3496 ≈ 1) while the rows do not.
nn.Softmax requires us to specify the dimension along which the softmax function is applied:
softmax = nn.Softmax(dim=1)
In this case, we have two input vectors in two rows (just like when we work with batches), so we initialize nn.Softmax to operate along dimension 1.
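To make the difference between the two dimensions concrete, here is a minimal sketch on a hand-written tensor. The values in `logits` and the variable names are made up for illustration and are not taken from the question's model.

```python
import torch

# Two samples (rows) with three scores each; arbitrary illustration data.
logits = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 2.0, 5.0]])

across_batch = torch.softmax(logits, dim=0)   # normalizes each column over the samples
within_sample = torch.softmax(logits, dim=1)  # normalizes each row over the classes

col_sums = across_batch.sum(dim=0)        # every column sums to 1
row_sums_wrong = across_batch.sum(dim=1)  # rows of the dim=0 result do not sum to 1
row_sums = within_sample.sum(dim=1)       # every row sums to 1

print("dim=0, column sums:", col_sums)
print("dim=0, row sums:   ", row_sums_wrong)
print("dim=1, row sums:   ", row_sums)
```

With dim=0 the values are valid probabilities only when read down a column, which is the same symptom as in the question's batched output.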
Change torch.nn.Softmax(dim=0) to torch.nn.Softmax(dim=1) to get the expected results for batched input.
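One caveat: with dim=1, calling the model on a single 1-D sample (as in the question's list comprehension) raises an IndexError, because a 1-D tensor has no dimension 1. A possible variation, sketched below, is to use dim=-1, which always refers to the last (action) dimension and so handles both cases. Note that dim=-1 and the torch.manual_seed call are my additions, not part of the original code or the fix above.

```python
import numpy as np
import torch

obs_size, HIDDEN_SIZE, n_actions = 9, 9, 2
np.random.seed(0)
torch.manual_seed(0)  # assumption: added so the random weights are reproducible too

model = torch.nn.Sequential(
    torch.nn.Linear(obs_size, HIDDEN_SIZE),
    torch.nn.ReLU(),
    torch.nn.Linear(HIDDEN_SIZE, n_actions),
    torch.nn.Softmax(dim=-1),  # last dimension = actions, for 1-D and 2-D inputs
)

state_batch = torch.tensor(np.random.rand(3, obs_size), dtype=torch.float32)

with torch.no_grad():
    batch_probs = model(state_batch)  # shape (3, n_actions)
    # Same states, fed through the model one sample at a time.
    single_probs = torch.stack([model(s) for s in state_batch])

print("batched:\n", batch_probs)
print("row sums:", batch_probs.sum(dim=1))
print("one at a time:\n", single_probs)
```

The batched and per-sample predictions should now agree up to floating-point error, and every row sums to 1.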
Upvotes: 1