Tree

Reputation: 31431

How does Python's map work with torch.tensor?

I am new to Python, so I am trying to understand this line from the PyTorch tutorial:

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)

I understand how map works on a single list:

def sqr(a):
    return a * a

a = [1, 2, 3, 4]    

a = map(sqr, a)
print(list(a))

Here I need list(a) to convert the map object back to a list.

What I don't understand is how it works on multiple variables.

If I try this:

def sqr(a):
    return a * a


a = [1, 2, 3, 4]
b = [1, 3, 5, 7]

a, b = map(sqr, (a, b))
print(list(a))
print(list(b))

I get an error: TypeError: can't multiply sequence by non-int of type 'list'

Please clarify this for me. Thank you.

Upvotes: 1

Views: 13347

Answers (1)

ndrwnaguib

Reputation: 6135

map works on a single list the same way it works on a list/tuple of lists: it takes each element of the given iterable, whatever that element is, and passes it to the function. For a tuple of lists, each element is a whole list.
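A small sketch of this, using the built-ins len and sum as stand-ins for any function: map over the tuple (a, b) makes exactly two calls, and each call receives an entire list rather than a number.

```python
a = [1, 2, 3, 4]
b = [1, 3, 5, 7]

# The tuple (a, b) has two elements, each of which is a whole list,
# so map calls the function twice: once with a, once with b.
print(list(map(len, (a, b))))  # [4, 4]
print(list(map(sum, (a, b))))  # [10, 16]
```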

The reason torch.tensor works is that it accepts a list as input.

If you unfold the line you provided:

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)

it's the same as doing:

x_train, y_train, x_valid, y_valid = [torch.tensor(x_train), torch.tensor(y_train), torch.tensor(x_valid), torch.tensor(y_valid)]
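This also explains why the tutorial line needs no list(...) call: tuple unpacking iterates the map object itself. The sketch below assumes a hypothetical to_tensor function as a stand-in for torch.tensor so it runs without PyTorch; any function that accepts a whole list behaves the same way.

```python
# Hypothetical stand-in for torch.tensor: it accepts a whole list.
def to_tensor(data):
    return tuple(data)

x_train, y_train = [1, 2], [3, 4]

# Unpacking consumes the map iterator, calling to_tensor once per list.
x_train, y_train = map(to_tensor, (x_train, y_train))
print(x_train, y_train)  # (1, 2) (3, 4)

# The explicit equivalent, one call per variable:
x2, y2 = [to_tensor([1, 2]), to_tensor([3, 4])]
print(x2, y2)  # (1, 2) (3, 4)
```

Unpacking only works when the number of names on the left matches the number of elements in the tuple; otherwise Python raises a ValueError.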

On the other hand, your sqr function does not accept lists. It expects a scalar to square, but your a and b are lists. map(sqr, (a, b)) calls sqr(a), so a * a tries to multiply a list by a list, which raises the TypeError.
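A minimal reproduction of that failure, calling sqr directly with a whole list (the same call map makes for the first element of (a, b)):

```python
def sqr(a):
    return a * a

a = [1, 2, 3, 4]

try:
    sqr(a)  # list * list is not defined for Python lists
except TypeError as exc:
    print(exc)  # can't multiply sequence by non-int of type 'list'

print(sqr(3))  # 9: a scalar works as intended
print(a * 2)   # [1, 2, 3, 4, 1, 2, 3, 4]: list * int means repetition
```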

However, if you change sqr to:

def sqr(a):
    return [s * s for s in a]


a = [1, 2, 3, 4]
b = [1, 3, 5, 7]

a, b = map(sqr, (a, b))

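or, as suggested by @Jean, keep the original scalar sqr and map it over each list separately with a, b = (map(sqr, x) for x in (a, b)). The generator expression must be parenthesized on the right-hand side.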

It will work.
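A runnable sketch of both fixes side by side. The name sqr_list is hypothetical, used here only to keep the list version and the original scalar sqr apart in one snippet.

```python
def sqr_list(values):
    # Squares every element of one list (the modified sqr above).
    return [s * s for s in values]

def sqr(x):
    # The original scalar version from the question.
    return x * x

a = [1, 2, 3, 4]
b = [1, 3, 5, 7]

# Option 1: the function handles a whole list, so one map over (a, b) suffices.
a1, b1 = map(sqr_list, (a, b))
print(a1, b1)  # [1, 4, 9, 16] [1, 9, 25, 49]

# Option 2: keep the scalar function and map it over each list separately.
a2, b2 = (list(map(sqr, x)) for x in (a, b))
print(a2, b2)  # [1, 4, 9, 16] [1, 9, 25, 49]
```

In option 2 without the inner list(...), a2 and b2 would be map objects, so you would still call list() on them when printing, as in the question.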

Upvotes: 5
