jerry

Reputation: 27

What is torch.randn((1, 5))?

I'm confused as to why there are double parentheses instead of just torch.randn(1, 5).

Is torch.randn(1,5) the same thing as torch.randn((1,5))?

Upvotes: 1

Views: 745

Answers (2)

trsvchn

Reputation: 8981

You can use both variants, randn((1, 2)) and randn(1, 2), because of the Python star syntax in the signature:

torch.randn(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor

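The first *size collects all positional arguments into a tuple: calling randn(1, 2) gives size == (1, 2). Calling randn((1, 2)) gives size == ((1, 2),), and PyTorch accepts that too, because a single list or tuple argument is treated as the whole shape.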
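A minimal pure-Python sketch of the packing step alone. show_size is a hypothetical function, not part of PyTorch; it only shows what *size receives in each case, which is why PyTorch must additionally unwrap a single tuple argument:

```python
def show_size(*size):
    # *size packs every positional argument into one tuple
    return size

print(show_size(1, 5))    # (1, 5)
print(show_size((1, 5)))  # ((1, 5),) : a tuple that holds the passed tuple
```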

The second, bare * makes every parameter that follows it keyword-only. This avoids ambiguous calls such as randn(1, 2, None, torch.strided, "cuda", True) and forces you to write randn(1, 2, out=None, dtype=None, layout=torch.strided, device="cuda", requires_grad=True) instead.
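A small pure-Python sketch of keyword-only parameters, assuming a hypothetical make_tensor that only copies randn's parameter layout (it creates no tensor). Extra positional values end up in *size and never reach device, so the options can be set only by name:

```python
def make_tensor(*size, dtype=None, device=None, requires_grad=False):
    # Parameters after *size are keyword-only: positional values cannot reach them
    return {"size": size, "dtype": dtype, "device": device,
            "requires_grad": requires_grad}

# "cuda" is swallowed by *size; device keeps its default of None
print(make_tensor(1, 2, "cuda"))
# Passing the options by keyword is the only way to set them
print(make_tensor(1, 2, device="cuda", requires_grad=True))
```

In this hypothetical version the stray "cuda" silently lands in size, whereas the real torch.randn validates the shape and rejects it.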

Upvotes: 1

Thang Pham

Reputation: 1026

You should check the definition of this function in the torch.randn documentation:

size (int...) – a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.

>>> import torch
>>> a = torch.randn(1,5)
>>> b = torch.randn((1,5))
>>> a.shape == b.shape
True

Therefore, you can use either form: a and b both have shape (1, 5). Their values differ only because each call draws new random numbers.
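The documented rule ("a variable number of arguments or a collection like a list or tuple") can be sketched in pure Python. normalize_size is a hypothetical helper that mirrors the documented behaviour; it is not PyTorch's actual argument parser, which is implemented in C++:

```python
from collections.abc import Sequence


def normalize_size(*size):
    """Hypothetical helper: accept a shape as varargs or as one list/tuple."""
    # A single sequence argument is taken to be the whole shape
    if len(size) == 1 and isinstance(size[0], Sequence):
        size = size[0]
    shape = tuple(size)
    # Reject anything that is not a plain integer dimension
    if not all(isinstance(d, int) for d in shape):
        raise TypeError(f"size must contain only ints, got {shape!r}")
    return shape


print(normalize_size(1, 5))    # (1, 5)
print(normalize_size((1, 5)))  # (1, 5)
print(normalize_size([1, 5]))  # (1, 5)
```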

Upvotes: 3
