RunOrVeith

Reputation: 4813

How to vectorize computation of mean over specific set of indices given as matrix rows?

I have a problem vectorizing some code in pytorch. A numpy solution would also help, but a pytorch solution would be better. I'm going to use array and Tensor interchangeably.

The problem I am facing is this:

Given an 2D float array X of size (n, x), and a boolean 2D array A of size (n, n), compute the mean over rows in X indexed by rows in A. The problem is that the rows in A contain a variable number of True indices.

Example (numpy):

import numpy as np
A = np.array([[0, 1, 0, 0, 0, 0],
              [1, 0, 1, 0, 0, 0],
              [0, 1, 0, 0, 0, 0],
              [0, 0, 0, 0, 1, 0],
              [0, 0, 0, 1, 0, 0],
              [0, 1, 1, 1, 0, 0]])
X = np.arange(6 * 3, dtype=np.float32).reshape(6, 3)

# Compute the mean in numpy with a for loop
means_np = np.array([X[A.astype(np.bool)[i]].mean(axis=0) for i in np.arange(len(A))])

So this example works, but this formulation has 3 problems:

  1. The for loop is slow for larger A and X. I need to loop over a few tens of thousands of indices.

  2. It can happen that A[i] contains no True indices. This results in np.mean(np.array([])), which is NaN. I want this to be 0 instead (a short demonstration follows this list).

  3. Implementing it this way in pytorch results in SIGFPE (floating point error) during the backward pass of backpropagation through this function. The cause is the case where nothing is selected.
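To make point 2 concrete, here is a tiny demonstration of the empty-selection case (using the X from above; empty_row is just an illustrative all-False mask):

# An all-False row selects nothing; the mean of an empty slice is NaN
# (numpy also emits a "Mean of empty slice" RuntimeWarning)
empty_row = np.zeros(6, dtype=bool)
print(X[empty_row].mean(axis=0))  # [nan nan nan]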

The workaround that I am using now is to set the diagonal of A to 1 (so every row selects at least itself), sum the selected rows, subtract X[i] back out, and divide by the selection count minus one, clamped to at least 1 (also see code below):

This works, is differentiable in pytorch and does not produce NaN, but I still need a loop over all indices. How can I get rid of this loop?

This is my current pytorch code:

import torch
A = torch.from_numpy(A).byte()  # boolean mask as a ByteTensor
X = torch.from_numpy(X)
A[np.diag_indices(len(A))] = 1  # Set the diagonal to 1 so every row selects at least itself
means = [(X[A[i]].sum(dim=0) - X[i]) / torch.clamp(A[i].sum().float() - 1, min=1.)  # Compute the mean safely
         for i in range(len(A))]  # Get rid of the loop somehow
means = torch.stack(means)

I don't mind if your version looks completely different, as long as it is differentiable and produces the same result.

Upvotes: 4

Views: 1333

Answers (1)

Divakar

Reputation: 221774

We can leverage matrix-multiplication - A.dot(X) gives, for each row of A, the sum of the rows of X selected by that row, and dividing by the per-row counts c turns those sums into means -

c = A.sum(1,keepdims=True)
means_np = np.where(c==0,0,A.dot(X)/c)
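As a quick sanity check against the loop version from the question (a small usage sketch; the example A above happens to have no all-zero rows, so the loop version produces no NaNs here):

# Compare the vectorized result with the question's loop version
loop_means = np.array([X[A.astype(bool)[i]].mean(axis=0) for i in range(len(A))])
print(np.allclose(means_np, loop_means))  # True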

We can optimize it further by converting A to float32 dtype, if it isn't already and if the loss of precision is acceptable there, as shown below -

In [57]: np.random.seed(0)

In [58]: A = np.random.randint(0,2,(1000,1000))

In [59]: X = np.random.rand(1000,1000).astype(np.float32)

In [60]: %timeit A.dot(X)
10 loops, best of 3: 27 ms per loop

In [61]: %timeit A.astype(np.float32).dot(X)
100 loops, best of 3: 10.2 ms per loop

In [62]: np.allclose(A.dot(X), A.astype(np.float32).dot(X))
Out[62]: True

Thus, use A.astype(np.float32).dot(X) to replace A.dot(X).
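For clarity, a sketch of the first snippet with that cast folded in (same computation, only the dtype of A changes):

c = A.sum(1,keepdims=True)
means_np = np.where(c==0, 0, A.astype(np.float32).dot(X)/c)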

Alternatively, to handle the case where the row-sum is zero (which is what forces the use of np.where), we could assign any non-zero value, say 1, into c and then simply divide by it, like so -

c = A.sum(1,keepdims=True)
c[c==0] = 1
means_np = A.dot(X)/c

This would also avoid the divide-by-zero warning that the np.where approach would otherwise trigger in those zero row-sum cases (the division is evaluated before np.where selects).
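Since the question asks for something differentiable in pytorch, a minimal sketch of this clamp-based variant in pytorch could look like the following (a port of the formula above, assuming A and X are the numpy arrays from the question; it does not reproduce the diagonal trick from the question's workaround):

import torch
A_t = torch.from_numpy(A).float()                      # matmul needs a floating dtype
X_t = torch.from_numpy(X)
c = torch.clamp(A_t.sum(dim=1, keepdim=True), min=1.)  # rows with no True entries divide by 1
means = A_t.mm(X_t) / c                                 # all-zero rows of A give 0, not NaN

Since this is just a matrix product and a division, backpropagation through it never hits the empty-selection case.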

Upvotes: 1
