I'm confused about how arguments are passed when chaining tensor methods.
Here are two versions of the same function:
```python
# Version 1: works
def gather_feature(fmap, index):
    # fmap.shape: (B, k1, 1), index.shape: (B, k2)
    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)  # this works
    # out[i][j][k] = input[i][index[i][j][k]][k]  (gather along dim=1)
    fmap = fmap.gather(dim=1, index=index)
    return fmap


# Version 2: raises an error
def gather_feature(fmap, index):
    # fmap.shape: (B, k1, 1), index.shape: (B, k2)
    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape))
    index = index.expand(*index.shape, dim)  # raises an error
    # out[i][j][k] = input[i][index[i][j][k]][k]  (gather along dim=1)
    fmap = fmap.gather(dim=1, index=index)
    return fmap
```
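The `out[i][j][k] = input[i][index[i][j][k]][k]` comment describes `gather` along dim=1. As a rough illustration, here is a pure-Python sketch of that rule on nested lists. `gather_dim1` is a hypothetical helper, not a PyTorch function, and it assumes a 3-D index whose sizes fit within the input.

```python
def gather_dim1(inp, index):
    """Pure-Python sketch of gather along dim=1 for 3-D nested lists:
    out[i][j][k] = inp[i][index[i][j][k]][k]."""
    return [
        [
            [inp[i][index[i][j][k]][k] for k in range(len(index[i][j]))]
            for j in range(len(index[i]))
        ]
        for i in range(len(index))
    ]


# fmap: shape (B=1, k1=3, C=2); idx: shape (B=1, k2=2, C=2)
fmap = [[[10, 11], [20, 21], [30, 31]]]
idx = [[[2, 2], [0, 0]]]
print(gather_dim1(fmap, idx))  # [[[30, 31], [10, 11]]]
```

It also suggests why `gather_feature` expands `index` to (B, k2, dim): each row index is repeated across the last dimension so that a whole feature vector is picked out.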
Once index.unsqueeze() has run, the shape of index changes to (B, k2, 1).
If the index.shape passed to expand() is (B, k2, 1), an error is raised.
However, when the two calls are chained on one line, as index.unsqueeze().expand(), the index.shape passed to expand() seems to be (B, k2).
Is index.shape computed and stored before .unsqueeze() runs, so that .unsqueeze() doesn't affect the index.shape passed to .expand()?
That is my guess, but I can't come up with another explanation.
Thank you for your time.
Consider the first case:

```python
index.unsqueeze(len(index.shape)).expand(*index.shape, dim)
```

This is equivalent to the following code:

```python
A = index.unsqueeze(len(index.shape))
index = A.expand(*index.shape, dim)
```
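Note that the tensor `index` has not changed after the first line executes: `unsqueeze` returns a new tensor (a view of the same data) and leaves `index` as it was, and nothing has been assigned to the name `index` yet. So when `A.expand(*index.shape, dim)` executes, `*index.shape` still unpacks the original shape (B, k2). In other words, `index.shape` is not computed ahead of `.unsqueeze()`; it is read afterwards, but `index` still refers to the original tensor.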
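To see the evaluation order concretely, here is a small pure-Python sketch. `Traced` is a hypothetical stand-in for a tensor (not part of PyTorch) that logs every read of `.shape` and every method call; like `torch.Tensor.unsqueeze`, its `unsqueeze` returns a new object instead of modifying the original.

```python
calls = []  # records the order in which shape reads and method calls happen


class Traced:
    """Hypothetical tensor stand-in that logs method calls and shape reads."""

    def __init__(self, name, shape):
        self.name = name
        self._shape = tuple(shape)

    @property
    def shape(self):
        calls.append(f"read {self.name}.shape -> {self._shape}")
        return self._shape

    def unsqueeze(self, dim):
        calls.append(f"{self.name}.unsqueeze({dim})")
        new_shape = list(self._shape)
        new_shape.insert(dim, 1)
        return Traced("A", new_shape)  # new object; self is not modified

    def expand(self, *sizes):
        calls.append(f"{self.name}.expand{sizes}")
        return Traced("out", sizes)


index = Traced("index", (2, 3))
out = index.unsqueeze(len(index.shape)).expand(*index.shape, 4)
for step in calls:
    print(step)
# read index.shape -> (2, 3)
# index.unsqueeze(2)
# read index.shape -> (2, 3)
# A.expand(2, 3, 4)
```

The log shows that `index.shape` is read the second time only after `unsqueeze` has run, and it still reports (2, 3) because the name `index` was never rebound.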
In the second case, however, `index = index.unsqueeze(len(index.shape))` rebinds `index` to the unsqueezed tensor. So in the next statement, `*index.shape` unpacks the new shape (B, k2, 1), and `expand()` is asked for (B, k2, 1, dim), which is what raises the error.
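The second case and one possible fix can be reproduced with another hypothetical stand-in. `Shaped` only tracks a shape tuple, and its `expand` uses a simplified version of PyTorch's rule (trailing dimensions are aligned, and each existing size must be 1 or equal the requested size), so its error message differs from PyTorch's.

```python
class Shaped:
    """Hypothetical stand-in for a tensor that only tracks its shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def unsqueeze(self, dim):
        # Returns a new object, like torch.Tensor.unsqueeze; self is untouched.
        new_shape = list(self.shape)
        new_shape.insert(dim, 1)
        return Shaped(new_shape)

    def expand(self, *sizes):
        # Simplified expand rule: align trailing dimensions, and every
        # existing size must either be 1 or equal the requested size.
        offset = len(sizes) - len(self.shape)
        if offset < 0:
            raise RuntimeError("expand: fewer sizes than tensor dimensions")
        for i, old in enumerate(self.shape):
            new = sizes[offset + i]
            if old != 1 and old != new:
                raise RuntimeError(f"size {old} cannot expand to {new}")
        return Shaped(sizes)


B, k2, dim = 2, 3, 4

# Second case: `index` is rebound first, so *index.shape is (B, k2, 1).
index = Shaped((B, k2))
index = index.unsqueeze(len(index.shape))
try:
    index.expand(*index.shape, dim)  # asks for (2, 3, 1, 4)
    failed = False
except RuntimeError as err:
    failed = True
    print("error:", err)

# One possible fix: keep the original shape in its own variable.
index = Shaped((B, k2))
orig_shape = index.shape
index = index.unsqueeze(len(orig_shape))
index = index.expand(*orig_shape, dim)
print(index.shape)  # (2, 3, 4)
```

In real PyTorch code, another option is to pass -1 to `expand` for dimensions that should keep their size, e.g. `index.unsqueeze(-1).expand(-1, -1, dim)`, which avoids reading `index.shape` at all.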