I am new to Python, so I am trying to understand this line from the PyTorch tutorial.
x_train, y_train, x_valid, y_valid = map(
torch.tensor, (x_train, y_train, x_valid, y_valid)
)
I understand how map works on a single list:
def sqr(a):
    return a * a
a = [1, 2, 3, 4]
a = map(sqr, a)
print(list(a))
Here I need list(a) to convert the map object back to a list.
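A minimal sketch of the single-list case: `map(sqr, a)` calls `sqr` once per element, so it produces the same values as the equivalent list comprehension. The names `mapped` and `comprehension` are only for illustration.

```python
def sqr(a):
    # Square a single number
    return a * a

a = [1, 2, 3, 4]

# map() returns a lazy iterator; list() consumes it into a list
mapped = list(map(sqr, a))

# The equivalent list comprehension calls sqr once per element
comprehension = [sqr(x) for x in a]

print(mapped)         # [1, 4, 9, 16]
print(comprehension)  # [1, 4, 9, 16]
```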
But what I don't understand is how it works on multiple variables.
If I try this:
def sqr(a):
    return a * a
a = [1, 2, 3, 4]
b = [1, 3, 5, 7]
a, b = map(sqr, (a, b))
print(list(a))
print(list(b))
I get an error: TypeError: can't multiply sequence by non-int of type 'list'
Please clarify this for me. Thank you.
map works on a single list the same way it works on a list/tuple of lists: it fetches each element of the given input, whatever that element is.
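A small sketch of this behaviour, using a hypothetical helper `show` that records what `map` hands it: with a tuple of two lists, the function is called twice and receives each whole list, not the numbers inside it.

```python
# A tuple holding two lists: map() sees exactly two elements
pair = ([1, 2, 3, 4], [1, 3, 5, 7])

# Record every argument the mapped function receives
received = []

def show(item):
    received.append(item)
    return len(item)

lengths = list(map(show, pair))

print(received)  # [[1, 2, 3, 4], [1, 3, 5, 7]]
print(lengths)   # [4, 4]
```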
The reason torch.tensor works is that it accepts a list as input.
If you unfold the following line you provided:
x_train, y_train, x_valid, y_valid = map(
torch.tensor, (x_train, y_train, x_valid, y_valid)
)
it's the same as doing:
x_train, y_train, x_valid, y_valid = [torch.tensor(x_train), torch.tensor(y_train), torch.tensor(x_valid), torch.tensor(y_valid)]
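To check this equivalence without PyTorch installed, the sketch below uses the built-in `tuple` as a hypothetical stand-in for `torch.tensor` (an assumption for illustration; any one-argument callable behaves the same way). The names `convert`, `m1` to `m4` and `e1` to `e4` are illustrative only.

```python
# Hypothetical stand-in for torch.tensor so the sketch runs without PyTorch
convert = tuple

x_train, y_train = [1, 2], [0, 1]
x_valid, y_valid = [3], [1]

# Unpacking a map object: it yields exactly four results, one per input
m1, m2, m3, m4 = map(convert, (x_train, y_train, x_valid, y_valid))

# Explicit form: one call per variable
e1, e2, e3, e4 = [convert(x_train), convert(y_train), convert(x_valid), convert(y_valid)]

print((m1, m2, m3, m4) == (e1, e2, e3, e4))  # True
```

The unpacking only succeeds because the map yields exactly as many results as there are names on the left; a mismatch raises ValueError.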
On the other hand, your sqr function does not accept lists. It expects a scalar to square, which is not the case for your a and b: they are lists.
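A minimal sketch reproducing the error: `map(sqr, (a, b))` makes two calls, `sqr(a)` and `sqr(b)`, and the first one already fails because it multiplies a list by a list. The variable `error_message` is only for illustration.

```python
def sqr(a):
    return a * a

a = [1, 2, 3, 4]
b = [1, 3, 5, 7]

# map(sqr, (a, b)) calls sqr(a) first, i.e. [1, 2, 3, 4] * [1, 2, 3, 4]
error_message = None
try:
    sqr(a)
except TypeError as exc:
    error_message = str(exc)

print(error_message)  # can't multiply sequence by non-int of type 'list'

# A list times an int is allowed, but it repeats the list instead of squaring it
print(a * 2)  # [1, 2, 3, 4, 1, 2, 3, 4]
```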
However, if you change sqr to:
def sqr(a):
    return [s * s for s in a]
a = [1, 2, 3, 4]
b = [1, 3, 5, 7]
a, b = map(sqr, (a, b))
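or, as suggested by @Jean, keep the original scalar sqr and build one map object per list: a, b = (map(sqr, x) for x in (a, b))
Either version will work. In the second one, a and b are map objects, so you still need list(a) and list(b) to see the values.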
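Both fixes side by side, as a runnable sketch. The list version is renamed to the hypothetical `sqr_list` here so that it can coexist with the original scalar `sqr`.

```python
def sqr_list(values):
    # Square every number in one list
    return [s * s for s in values]

def sqr(n):
    # Square a single number (the original function)
    return n * n

a = [1, 2, 3, 4]
b = [1, 3, 5, 7]

# Fix 1: the function accepts a whole list; map() passes each list in turn
a1, b1 = map(sqr_list, (a, b))
print(a1)  # [1, 4, 9, 16]
print(b1)  # [1, 9, 25, 49]

# Fix 2: keep the scalar sqr and build one lazy map object per list
a2, b2 = (map(sqr, x) for x in (a, b))
a2, b2 = list(a2), list(b2)  # map objects are lazy; convert each to a list once
print(a2)  # [1, 4, 9, 16]
print(b2)  # [1, 9, 25, 49]
```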