Joel

How to calculate the mean along the second axis up to a row-specific column in PyTorch?

I am looking for a fast way to calculate the mean of each row of a 2D matrix, but only up to and including a specific column. The remaining values in each row can be ignored, and the column differs from row to row.

In NumPy, it could be coded as follows. However, I am hoping to find a PyTorch solution that avoids the for loop and does not break the gradients.

import numpy as np

arr = np.linspace(0, 10, 15).reshape(3,5)
cols = [2,0,4]

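# Overwrite every entry after column col with NaN so that nanmean skips it
for row, col in enumerate(cols):
    arr[row, col+1:] = np.nan

result = np.nanmean(arr, axis=1)  # mean of each row, ignoring NaN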
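For comparison, here is a minimal loop-free NumPy sketch of the same computation. It assumes cols holds one inclusive end column per row. Broadcasting a column-index range against cols builds a boolean mask, and np.where then fills the excluded entries with NaN instead of looping. The name mask is only illustrative.

```python
import numpy as np

arr = np.linspace(0, 10, 15).reshape(3, 5)
cols = np.array([2, 0, 4])

# mask[i, j] is True when j <= cols[i]; shapes (5,) and (3, 1) broadcast to (3, 5)
mask = np.arange(arr.shape[1]) <= cols[:, None]

# Put NaN in the excluded positions without a loop, then average as before
result = np.nanmean(np.where(mask, arr, np.nan), axis=1)
print(result)
```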

Any suggestions?

Edit: The best solution I have found so far (with arr as a torch tensor):

result = torch.stack([arr[i, 0:cols[i]+1].mean() for i in range(len(arr))])

But I would still like to avoid the for loop.
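Below is a self-contained version of the loop above. It is a minimal sketch that assumes arr is a float64 tensor created with gradient tracking enabled. Calling backward() on the sum of the row means shows that gradients flow through the slicing. Each kept entry in row i receives 1 / (cols[i] + 1), and every ignored entry receives zero.

```python
import torch

# clone() gives a standalone leaf tensor that can track gradients
arr = torch.linspace(0, 10, 15, dtype=torch.float64).reshape(3, 5).clone().requires_grad_()
cols = [2, 0, 4]

# One slice-and-mean per row, then stack the scalars into a (3,) tensor
result = torch.stack([arr[i, 0:cols[i] + 1].mean() for i in range(len(arr))])

# Backpropagate the sum of the row means to inspect the gradient w.r.t. arr
result.sum().backward()
print(result)
print(arr.grad)
```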

Answers (1)

Quang Hoang

Try creating a boolean mask that is True up to and including each row's column:

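import torch

t = torch.tensor(arr)  # use arr from before the NaN loop: NaN * 0 is still NaN
# mask[i, j] is True for j <= cols[i]; shapes (5,) and (3, 1) broadcast to (3, 5)
mask = torch.arange(t.shape[1]) <= torch.tensor(cols).unsqueeze(-1)

# Masked sum divided by the number of kept entries in each row
result = (t*mask).sum(1)/mask.sum(1)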

Output:

tensor([0.7143, 3.5714, 8.5714], dtype=torch.float64)
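The mask approach can be wrapped in a reusable function. This is a minimal sketch and not part of the original answer. The helper names masked_row_mean and cumsum_row_mean are hypothetical. The first helper uses torch.where instead of multiplication, so entries past each row's column are dropped even when they are NaN. The second helper uses a different technique: it takes a running sum with cumsum and reads it off at each row's column with gather. Both helpers are differentiable.

```python
import torch


def masked_row_mean(t: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
    """Mean of t[i, :cols[i] + 1] for every row i, without a Python loop."""
    # mask[i, j] is True for j <= cols[i]; shapes (n_cols,) and (n_rows, 1) broadcast
    mask = torch.arange(t.shape[1], device=t.device) <= cols.unsqueeze(-1)
    # torch.where selects kept entries instead of multiplying excluded ones by 0
    kept = torch.where(mask, t, torch.zeros_like(t))
    return kept.sum(dim=1) / mask.sum(dim=1)


def cumsum_row_mean(t: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
    """Alternative: read the running row sum at each row's end column."""
    # csum[i, j] = t[i, 0] + ... + t[i, j]
    csum = t.cumsum(dim=1)
    # Pick csum[i, cols[i]] for every row, then divide by the element count
    totals = csum.gather(1, cols.unsqueeze(1)).squeeze(1)
    return totals / (cols + 1)


t = torch.linspace(0, 10, 15, dtype=torch.float64).reshape(3, 5).clone().requires_grad_()
cols = torch.tensor([2, 0, 4])

result = masked_row_mean(t, cols)
result.sum().backward()  # gradients reach t; excluded entries get 0
print(result)
print(t.grad)

# A copy with NaN after each row's column, as produced by the question's loop
t_nan = t.detach().clone()
t_nan[~(torch.arange(5) <= cols.unsqueeze(-1))] = float("nan")
print(masked_row_mean(t_nan, cols))  # the NaN entries are ignored
print(cumsum_row_mean(t.detach(), cols))
```

The cumsum variant avoids building a full mask but sums every prefix of each row. The mask variant is the more direct translation of the answer above.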
