Ashutosh Baheti

Reputation: 420

PyTorch tensor to NumPy gives "()" as shape

I have a PyTorch tensor:

span_end = tensor([[[13]]])

I do the following:

span_end = span_end.view(1).squeeze().data.numpy()
print(type(span_end))
print(span_end.shape)

This gives me the following output:

<class 'numpy.ndarray'>
()

Later, when I try to access the 0th element of span_end, I get an IndexError because the shape is somehow empty. What am I doing wrong here?

Upvotes: 0

Views: 228

Answers (1)

Jmonsky

Reputation: 1519

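tensor.squeeze() removes every dimension of size 1. Here all of the dimensions have size 1, so the result is a 0-dimensional (scalar) tensor. Converting that to NumPy gives a 0-d array whose shape is (), and a 0-d array has no axis to index with [0], hence the IndexError.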
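To see where the dimensions go, here is a small sketch that prints the shape after each step of the question's chain (it leaves out .data, which does not change the shape). The exact IndexError message differs between NumPy versions, so it is printed rather than quoted.

```python
import torch

span_end = torch.tensor([[[13]]])
print(span_end.shape)                    # torch.Size([1, 1, 1])
print(span_end.view(1).shape)            # torch.Size([1])
print(span_end.view(1).squeeze().shape)  # torch.Size([]): no dimensions left

arr = span_end.view(1).squeeze().numpy()
print(arr.shape, arr.ndim)               # () 0

# Indexing a 0-d array with [0] fails because there is no axis 0
try:
    arr[0]
except IndexError as e:
    print("IndexError:", e)

# The value is still there: read it with an empty index or .item()
print(arr[()], arr.item())               # 13 13
```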

Removing the squeeze() call fixes it:

import torch

span_end = torch.tensor([[[13]]])
# view(1) reshapes to (1,); without squeeze() that single dimension is kept
span_end = span_end.view(1).numpy()
print(type(span_end))
print(span_end.shape)
print(span_end[0])

Outputs:

<class 'numpy.ndarray'>
(1,)
13
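Depending on what span_end is used for later, a few other standard calls may fit better. This is a sketch of some alternatives, not a single recommended fix: .item() returns a plain Python number but raises an error if the tensor has more than one element; reshape(-1) always gives a 1-d array; squeeze(dim) removes only the dimension you name; and np.atleast_1d repairs a 0-d array on the NumPy side.

```python
import numpy as np
import torch

span_end = torch.tensor([[[13]]])

# Option 1: only the number is needed -> plain Python int
value = span_end.item()
print(value)                 # 13

# Option 2: always get a 1-d array, whatever the element count
flat = span_end.reshape(-1).numpy()
print(flat.shape, flat[0])   # (1,) 13

# Option 3: squeeze only chosen dimensions, leaving the last one in place
partial = span_end.squeeze(0).squeeze(0).numpy()
print(partial.shape)         # (1,)

# Option 4: already have a 0-d array? Promote it back to 1-d in NumPy
arr0 = span_end.squeeze().numpy()
fixed = np.atleast_1d(arr0)
print(fixed.shape, fixed[0]) # (1,) 13
```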

Upvotes: 2
