Reputation: 91
I'm very new to PyTorch, and I have encountered the "Index tensor must have the same number of dimensions as input tensor" error when running my neural network. It happens when I call torch.gather(). Could someone help me understand torch.gather() and explain the cause of this error?
Here is the code where the error occurs:
def learn(batch, optim, net, target_net, gamma, global_step, target_update):
    my_loss = []
    optim.zero_grad()
    state, action, next_state, reward, done, next_action = batch
    qval = net(state.float())
    loss_a = torch.gather(qval, 3, action.view(-1,1,1,1)).squeeze()  # Error happens here!
    loss_b = reward + gamma * torch.max(target_net(next_state.float()).cuda(), dim=3).values * (1 - done.int())
    loss_val = torch.sum(torch.abs(loss_a - loss_b))
    loss_val /= 128
    my_loss.append(loss_val.item())
    loss_val.backward()
    optim.step()
    if global_step % target_update == 0:
        target_network.load_state_dict(q_network.state_dict())
In case it is helpful, here is the sampling function that builds the batch the action comes from:
def sample_batch(memory, batch_size):
    indices = np.random.randint(0, len(memory), (batch_size,))
    state = torch.stack([memory[i][0] for i in indices])
    action = torch.tensor([memory[i][1] for i in indices], dtype=torch.long)
    next_state = torch.stack([memory[i][2] for i in indices])
    reward = torch.tensor([memory[i][3] for i in indices], dtype=torch.float)
    done = torch.tensor([memory[i][4] for i in indices], dtype=torch.float)
    next_action = torch.tensor([memory[i][5] for i in indices], dtype=torch.long)
    return state, action, next_state, reward, done, next_action
When I print out the shapes of 'qval', 'action', and 'action.view(-1,1,1,1)', this is the output:
qval torch.Size([10, 225])
act view torch.Size([10, 1, 1, 1])
action shape torch.Size([10])
Any explanation as to what is causing this error is appreciated! I want to understand more what is going on in the code as well as how to fix the problem. Thanks!
Upvotes: 5
Views: 10257
Reputation: 3573
torch.gather is described in the PyTorch documentation. If we take your code, this line
torch.gather(qval, 3, action.view(-1,1,1,1))
is equivalent to
act_view = action.view(10, 1, 1, 1)
out = torch.zeros_like(act_view)
for i in range(10):
    for j in range(1):
        for k in range(1):
            for p in range(1):
                out[i, j, k, p] = qval[i, j, k, act_view[i, j, k, p]]
return out
which obviously makes very little sense. In particular, qval is not 4-D and thus cannot be indexed like this. The number of for loops is determined by the shape of your input tensors, and they must all have the same number of dimensions for this to work (this is what your error tells you, by the way). Here, qval is 2-D and act_view is 4-D.
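To illustrate what gather looks like when the dimensions do match, here is a minimal, self-contained sketch on a 2-D tensor (the values and shapes are made up purely for illustration); the explicit loop and torch.gather produce the same result:
import torch

qval = torch.arange(12.0).view(3, 4)   # pretend Q-values, shape (3, 4)
idx = torch.tensor([[2], [0], [3]])    # index tensor, also 2-D, shape (3, 1)

# gather along dim=1: out[i, j] = qval[i, idx[i, j]]
out = torch.gather(qval, 1, idx)

# the equivalent explicit loop
expected = torch.zeros_like(idx, dtype=qval.dtype)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        expected[i, j] = qval[i, idx[i, j]]

print(torch.equal(out, expected))  # True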
I'm not sure what you wanted to do with this, but if you can explain your goal and remove all the useless stuff in your example (mostly the training and backprop related code) to get a minimal reproducible example, I could help you further in finding the correct way to do it :)
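For what it's worth, if the goal is the usual DQN-style step of picking the predicted Q-value of each sampled action (an assumption on my part), the index only needs one extra dimension so that it matches the 2-D qval, and you gather along dim=1 rather than dim=3:
# Assumes qval has shape (batch, n_actions) = (10, 225) and action has shape (10,)
loss_a = torch.gather(qval, 1, action.view(-1, 1)).squeeze(1)   # shape (10,)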
Upvotes: 2