Ameen Ali

Reputation: 315

Converting a Python list to a PyTorch tensor

I have a problem converting a Python list of numbers to a PyTorch tensor. This is my code:

caption_feat = [int(x) if x < 11660 else 3 for x in caption_feat]

Printing caption_feat gives: [1, 9903, 7876, 9971, 2770, 2435, 10441, 9370, 2]
I convert it like this: tmp2 = torch.Tensor(caption_feat). Now printing tmp2 gives: tensor([1.0000e+00, 9.9030e+03, 7.8760e+03, 9.9710e+03, 2.7700e+03, 2.4350e+03, 1.0441e+04, 9.3700e+03, 2.0000e+00])
However, I expected to get something like tensor([1., 9903., ..., 9971., ...]). Any idea?

Upvotes: 18

Views: 85769

Answers (4)

Devanshi

Reputation: 211

torch.Tensor always builds a float32 tensor, so one option is to convert the Python list directly and then cast the result to an integer dtype. For example:

import torch

a_list = [3, 23, 53, 32, 53]
a_tensor = torch.Tensor(a_list)  # torch.Tensor creates a float32 tensor
print(a_tensor.int())            # .int() casts it to int32

Output:

tensor([ 3, 23, 53, 32, 53], dtype=torch.int32)
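One caveat with going through torch.Tensor first: float32 represents integers exactly only up to 2**24 (16777216). The caption ids in the question are far below that, so .int() recovers them exactly, but larger integers can be silently rounded. The sketch below illustrates this; the value 16777217 and the names big, via_float and direct are illustrative choices, not part of the original answer.

```python
import torch

big = [16777217]  # 2**24 + 1, not exactly representable in float32

via_float = torch.Tensor(big).int()            # float32 first, then cast to int32
direct = torch.tensor(big, dtype=torch.int64)  # built as int64 directly, no float step

print(via_float.item())  # 16777216 (rounded during the float32 step)
print(direct.item())     # 16777217 (exact)
```

Building the tensor with an integer dtype from the start, as the other answers do, avoids the float round trip entirely.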

Upvotes: 20

GarAust89

Reputation: 357

A simple option is to convert your list to a NumPy array with the dtype you want, then call torch.from_numpy on the new array.

Toy example:

import numpy as np
import torch
some_list = [1, 10, 100, 9999, 99999]
tensor = torch.from_numpy(np.array(some_list, dtype=np.int64))  # np.int was removed in NumPy 1.24

Another option, as others have suggested, is to specify the dtype when you create the tensor:

torch.tensor(some_list, dtype=torch.int)

Both work. Note that torch.int is torch.int32; use torch.int64 (torch.long) if you need 64-bit integers.
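One practical difference between the two options, sketched below with the same toy list: torch.from_numpy shares memory with the NumPy array, whereas torch.tensor copies the data. The names arr, shared and copied are illustrative.

```python
import numpy as np
import torch

some_list = [1, 10, 100, 9999, 99999]

arr = np.array(some_list, dtype=np.int64)
shared = torch.from_numpy(arr)                     # shares memory with arr
copied = torch.tensor(some_list, dtype=torch.int)  # independent copy, int32

arr[0] = -1  # modify the NumPy array in place

print(shared[0].item())            # -1: the change is visible through the tensor
print(copied[0].item())            # 1: the copy is unaffected
print(shared.dtype, copied.dtype)  # torch.int64 torch.int32
```

If the array might be modified later and you want an independent tensor, torch.tensor (or torch.from_numpy(...).clone()) is the safer choice.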

Upvotes: 1

ychnh

Reputation: 197

Try

torch.IntTensor(caption_feat)

The other tensor types are listed here: https://pytorch.org/docs/stable/tensors.html
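For reference, torch.IntTensor is one of the legacy typed constructors and produces an int32 tensor, the same result as torch.tensor(..., dtype=torch.int32); torch.LongTensor is the int64 counterpart. A quick check on the list from the question (the variable names are illustrative):

```python
import torch

caption_feat = [1, 9903, 7876, 9971, 2770, 2435, 10441, 9370, 2]

legacy = torch.IntTensor(caption_feat)                  # legacy typed constructor
modern = torch.tensor(caption_feat, dtype=torch.int32)  # factory function with explicit dtype
as_long = torch.LongTensor(caption_feat)                # int64 counterpart

print(legacy.dtype, as_long.dtype)  # torch.int32 torch.int64
print(torch.equal(legacy, modern))  # True
```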

Upvotes: 0

Dishin H Goyani

Reputation: 7693

If all elements are integers, you can create an integer tensor by specifying the dtype:

>>> a_list = [1, 9903, 7876, 9971, 2770, 2435, 10441, 9370, 2]
>>> tmp2 = torch.tensor(a_list, dtype=torch.int)
>>> tmp2
tensor([    1,  9903,  7876,  9971,  2770,  2435, 10441,  9370,     2],
       dtype=torch.int32)
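If you leave out dtype, torch.tensor infers it from the data: a list of Python ints becomes torch.int64, while a list containing any float becomes the default float dtype (torch.float32 unless it has been changed). A minimal check, where the mixed list is an illustrative example:

```python
import torch

a_list = [1, 9903, 7876, 9971, 2770, 2435, 10441, 9370, 2]

inferred = torch.tensor(a_list)  # no dtype given, inferred from Python ints
mixed = torch.tensor([1, 2.5])   # one float in the list

print(inferred.dtype)  # torch.int64
print(mixed.dtype)     # torch.float32
```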

By contrast, torch.Tensor returns a torch.float32 tensor, which is why the numbers are printed in scientific notation:

>>> tmp2 = torch.Tensor(a_list)
>>> tmp2
tensor([1.0000e+00, 9.9030e+03, 7.8760e+03, 9.9710e+03, 2.7700e+03, 2.4350e+03,
        1.0441e+04, 9.3700e+03, 2.0000e+00])
>>> tmp2.dtype
torch.float32
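If a float tensor is actually what you want, the scientific notation is only a display choice: the stored values are unchanged. As far as I know, PyTorch picks scientific notation automatically when the values span several orders of magnitude (here 1 up to 10441), and torch.set_printoptions(sci_mode=False) turns it off. A minimal sketch:

```python
import torch

a_list = [1, 9903, 7876, 9971, 2770, 2435, 10441, 9370, 2]
tmp2 = torch.Tensor(a_list)

# The values are stored exactly; only the printed format differs.
print(tmp2[1].item())  # 9903.0

torch.set_printoptions(sci_mode=False)  # disable scientific notation globally
print(tmp2)
torch.set_printoptions(sci_mode=None)   # restore the automatic choice
```

Since set_printoptions changes a global setting, resetting it afterwards keeps other printed tensors unaffected.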

Upvotes: 5
