Rasoul

Reputation: 3847

Batch inference of softmax does not sum to 1

I am working with the REINFORCE algorithm in PyTorch. I noticed that the batch predictions of my simple network with a final Softmax layer don't sum to 1 (not even close to 1). I am attaching a minimal working example so that you can reproduce the problem. What am I missing here?

import numpy as np
import torch

obs_size = 9
HIDDEN_SIZE = 9
n_actions = 2

np.random.seed(0)  # note: only NumPy is seeded; the network weights are still randomly initialized, so exact numbers vary between runs

model = torch.nn.Sequential(
        torch.nn.Linear(obs_size, HIDDEN_SIZE),
        torch.nn.ReLU(),
        torch.nn.Linear(HIDDEN_SIZE, n_actions),
        torch.nn.Softmax(dim=0)
    )

state_transitions = np.random.rand(3, obs_size)  # batch of 3 random states, one per row

state_batch = torch.Tensor(state_transitions)
pred_batch = model(state_batch)  # WRONG PREDICTIONS!
print('wrong predictions:\n', *pred_batch.detach().numpy())
# [0.34072137 0.34721774] [0.30972624 0.30191955] [0.3495524 0.3508627]
# DOES NOT SUM TO 1 !!!

pred_batch = [model(s).detach().numpy() for s in state_batch]  # CORRECT PREDICTIONS
print('correct predictions:\n', *pred_batch)
# [0.5955179  0.40448207] [0.6574412  0.34255883] [0.624833   0.37516695]
# DOES SUM TO 1 AS EXPECTED

Upvotes: 1

Views: 1328

Answers (1)

iacob

Reputation: 24181

Although PyTorch lets us get away with it, we don't actually provide the model with an input of the right dimensionality. Our model takes one sample and produces one output, but nn.Module and its subclasses are designed to process multiple samples at the same time. To accommodate this, modules expect the zeroth dimension of the input to be the number of samples in the batch.
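For instance, a linear layer maps a single 9-dimensional sample to 2 outputs, and a batch of 3 such samples (stacked along dimension 0) to a 3x2 output, one row per sample. A minimal sketch (shapes only, illustrative values):

import torch

layer = torch.nn.Linear(9, 2)
single = torch.rand(9)      # one sample
batch = torch.rand(3, 9)    # batch of 3 samples stacked along dim 0

print(layer(single).shape)  # torch.Size([2])
print(layer(batch).shape)   # torch.Size([3, 2]), one row of outputs per sample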

That your model works on each individual sample is an implementation nicety. You have specified the wrong dimension for the softmax (across the batch instead of across the action scores), so when given a batched input it computes the softmax across samples instead of within each sample:

nn.Softmax requires us to specify the dimension along which the softmax function is applied:

softmax = nn.Softmax(dim=1)

Here each row of the batch holds one sample's action scores (in your case, three rows of two scores each), so we initialize nn.Softmax to operate along dimension 1.
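To see the difference concretely, here is a small sketch with made-up logits for 3 samples and 2 actions. With dim=0 each column sums to 1 (normalizing across the batch); with dim=1 each row sums to 1 (normalizing within a sample):

import torch

logits = torch.tensor([[1.0, 2.0],
                       [0.5, 0.5],
                       [3.0, 1.0]])  # hypothetical scores: 3 samples, 2 actions

wrong = torch.softmax(logits, dim=0)  # normalizes down each column
right = torch.softmax(logits, dim=1)  # normalizes along each row

print(wrong.sum(dim=0))  # tensor([1., 1.]), columns sum to 1
print(wrong.sum(dim=1))  # rows do NOT sum to 1
print(right.sum(dim=1))  # tensor([1., 1., 1.]), each sample's probabilities sum to 1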

Change torch.nn.Softmax(dim=0) to torch.nn.Softmax(dim=1) and each batched prediction will sum to 1 as expected.
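Applied to the model from the question (a sketch; the exact probabilities will differ between runs because the weights are randomly initialized):

import numpy as np
import torch

obs_size, HIDDEN_SIZE, n_actions = 9, 9, 2

model = torch.nn.Sequential(
    torch.nn.Linear(obs_size, HIDDEN_SIZE),
    torch.nn.ReLU(),
    torch.nn.Linear(HIDDEN_SIZE, n_actions),
    torch.nn.Softmax(dim=1),  # softmax within each sample, not across the batch
)

state_batch = torch.Tensor(np.random.rand(3, obs_size))
pred_batch = model(state_batch)
print(pred_batch.sum(dim=1))  # each row now sums to 1

As a side note, torch.nn.Softmax(dim=-1) normalizes over the last dimension whether or not a batch dimension is present, so the same module then works for both your batched and your per-sample calls.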

Upvotes: 1
