GuikunChen

Reputation: 60

Passing Parameters in Python/PyTorch Method Calls

I am confused about how parameters are passed when calling tensor methods such as unsqueeze() and expand().

The code is shown below:

import torch

# Version 1: unsqueeze() and expand() chained on one line
def gather_feature(fmap, index):
    # fmap.shape (B, k1, 1), index.shape (B, k2)
    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)  # this works

    fmap = fmap.gather(dim=1, index=index)  # dim=1: out[i][j][k] = fmap[i][index[i][j][k]][k]
    return fmap

# Version 2: the same two calls split across two lines
def gather_feature(fmap, index):
    # fmap.shape (B, k1, 1), index.shape (B, k2)
    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape))
    index = index.expand(*index.shape, dim)  # raises an error

    fmap = fmap.gather(dim=1, index=index)  # dim=1: out[i][j][k] = fmap[i][index[i][j][k]][k]
    return fmap

Once index.unsqueeze() has run, the shape of index becomes (B, k2, 1).

If the index.shape passed to expand() is (B, k2, 1), an error is raised.

However, if I write these two method calls on one line, namely index.unsqueeze().expand(), the index.shape passed to expand() seems to be (B, k2).
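A minimal reproduction of both versions, assuming concrete sizes B=2, k1=5, k2=3 and one feature channel. The sizes and the function names gather_chained/gather_split are hypothetical, chosen only so the two behaviours can be run side by side:

```python
import torch

B, k1, k2 = 2, 5, 3  # hypothetical sizes for illustration

fmap = torch.randn(B, k1, 1)           # shape (B, k1, 1)
index = torch.randint(0, k1, (B, k2))  # shape (B, k2)

def gather_chained(fmap, index):
    dim = fmap.size(-1)
    # *index.shape is read from the original (B, k2) tensor
    index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)
    return fmap.gather(dim=1, index=index)

def gather_split(fmap, index):
    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape))  # name now refers to a (B, k2, 1) tensor
    # expand(B, k2, 1, dim): a 3-D tensor asked to become an incompatible 4-D shape
    index = index.expand(*index.shape, dim)
    return fmap.gather(dim=1, index=index)

out = gather_chained(fmap, index)
print(out.shape)  # torch.Size([2, 3, 1])

split_failed = False
try:
    gather_split(fmap, index)
except RuntimeError as err:
    split_failed = True
    print("split version raised:", err)
```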

Is index.shape computed and stored before .unsqueeze() is performed?

If so, .unsqueeze() would not affect the index.shape passed to .expand().

That is my guess, but I cannot come up with another explanation.

Thank you for your time.

Upvotes: 1

Views: 53

Answers (1)

rohit_r

Reputation: 713

Consider the first case:

index.unsqueeze(len(index.shape)).expand(*index.shape, dim)

This is equivalent to the following code:

A = index.unsqueeze(len(index.shape))
index = A.expand(*index.shape, dim)
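A quick check that the two-step form with the temporary A produces exactly the same tensor as the one-liner, assuming a hypothetical index of shape (2, 3) and dim = 1:

```python
import torch

index = torch.randint(0, 5, (2, 3))  # hypothetical (B, k2) = (2, 3)
dim = 1

# One-liner from the question
chained = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)

# The same thing in two steps, using a temporary name A
A = index.unsqueeze(len(index.shape))
split = A.expand(*index.shape, dim)  # index is still the original (2, 3) tensor here

print(index.shape)                  # torch.Size([2, 3]), unchanged
print(torch.equal(chained, split))  # True
```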

Note that the tensor index is not changed by the first line: unsqueeze() returns a new tensor and leaves index as it was. So when A.expand(*index.shape, dim) is then executed, the original shape of index, (B, k2), is used.
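The evaluation order can be traced in plain Python. This sketch uses a hypothetical Traced class standing in for a tensor (it is not a PyTorch type); it logs every shape read and method call. It shows that the shape is not stored before unsqueeze() runs: unsqueeze() is called first, and only then are the arguments to expand() evaluated, but the name index still points at the original, unchanged object.

```python
class Traced:
    """Hypothetical stand-in for a tensor: records every shape read and method call."""

    def __init__(self, name, shape, log):
        self.name = name
        self._shape = shape
        self.log = log

    @property
    def shape(self):
        self.log.append(f"read {self.name}.shape -> {self._shape}")
        return self._shape

    def unsqueeze(self, d):
        self.log.append(f"call {self.name}.unsqueeze({d})")
        # Like Tensor.unsqueeze, build and return a NEW object; self is untouched
        new_shape = self._shape[:d] + (1,) + self._shape[d:]
        return Traced("A", new_shape, self.log)

    def expand(self, *sizes):
        self.log.append(f"call {self.name}.expand{sizes}")
        return self


log = []
index = Traced("index", (2, 3), log)
index.unsqueeze(len(index.shape)).expand(*index.shape, 1)
for step in log:
    print(step)
# read index.shape -> (2, 3)
# call index.unsqueeze(2)
# read index.shape -> (2, 3)
# call A.expand(2, 3, 1)
```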

In the second case, however, index = index.unsqueeze(len(index.shape)) rebinds the name index to the new, unsqueezed tensor. So in the next line *index.shape unpacks the new shape (B, k2, 1), and expand() receives (B, k2, 1, dim).
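If you prefer the two-line form, here is a sketch of two ways to make it work, assuming index is 2-D (B, k2) as in the question. The function names are hypothetical. One option captures the shape before rebinding index; the other passes -1 to expand(), which PyTorch treats as "keep this dimension's size":

```python
import torch

def gather_feature_saved_shape(fmap, index):
    dim = fmap.size(-1)
    shape = index.shape                  # capture (B, k2) before rebinding index
    index = index.unsqueeze(len(shape))  # (B, k2, 1)
    index = index.expand(*shape, dim)    # (B, k2, dim)
    return fmap.gather(dim=1, index=index)

def gather_feature_minus_one(fmap, index):
    dim = fmap.size(-1)
    index = index.unsqueeze(-1)          # (B, k2, 1); -1 appends a new last dim
    index = index.expand(-1, -1, dim)    # -1 keeps B and k2 unchanged
    return fmap.gather(dim=1, index=index)

fmap = torch.randn(2, 5, 4)              # (B, k1, dim) with dim = 4
index = torch.randint(0, 5, (2, 3))      # (B, k2)
a = gather_feature_saved_shape(fmap, index)
b = gather_feature_minus_one(fmap, index)
print(a.shape, torch.equal(a, b))        # torch.Size([2, 3, 4]) True
```

The -1 variant avoids reading index.shape at all, so it does not matter which tensor the name index refers to at that point.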

Upvotes: 1
