Reputation: 420
I have a PyTorch tensor:
span_end = tensor([[[13]]])
I do the following:
span_end = span_end.view(1).squeeze().data.numpy()
print(type(span_end))
print(span_end.shape)
This gives me the following output:
<class 'numpy.ndarray'>
()
Later, when I try to access the 0th element of span_end, I get an IndexError, apparently because the shape is empty. What am I doing wrong here?
Upvotes: 0
Views: 228
Reputation: 1519
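tensor.squeeze() removes every dimension of size 1. After .view(1) your tensor has shape (1,), so its only dimension has size 1 and squeeze() removes it, leaving a 0-dimensional tensor. Converting that to NumPy gives an array with shape (), which cannot be indexed with [0].
Removing the squeeze() call fixes it: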
import torch
span_end = torch.tensor([[[13]]])
span_end = span_end.view(1).numpy()
print(type(span_end))
print(span_end.shape)
print(span_end[0])
Outputs:
<class 'numpy.ndarray'>
(1,)
13
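Here is a minimal sketch, assuming reasonably recent PyTorch and NumPy versions, that traces the shape at each step and shows a few other ways to reach the value if you do want a scalar. It uses `.item()`, indexing a 0-d array with an empty tuple, and `np.atleast_1d`. These are alternatives to the fix above, not part of the original answer.

```python
import numpy as np
import torch

span_end = torch.tensor([[[13]]])
print(span_end.shape)                    # torch.Size([1, 1, 1])
print(span_end.view(1).shape)            # torch.Size([1])
print(span_end.view(1).squeeze().shape)  # torch.Size([]), i.e. a 0-d tensor

# The 0-d result from the question: shape () and no axis to index with [0]
scalar_arr = span_end.view(1).squeeze().numpy()
try:
    scalar_arr[0]
except IndexError as err:
    print("IndexError:", err)

# Option 1: pull out a plain Python number (works for any one-element tensor)
value = span_end.item()
print(value)  # 13

# Option 2: index a 0-d array with an empty tuple
print(scalar_arr[()])  # 13

# Option 3: promote the result back to at least one dimension
arr = np.atleast_1d(scalar_arr)
print(arr.shape, arr[0])  # (1,) 13
```

If all you need is the number, `.item()` is usually the simplest choice, since it skips the NumPy conversion entirely.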
Upvotes: 2