Reputation: 4813
I have a problem vectorizing some code in pytorch.
A numpy solution would also help, but a pytorch solution would be better.
I'm going to use array and Tensor interchangeably.
The problem I am facing is this:
Given a 2D float array X of size (n, x) and a boolean 2D array A of size (n, n), compute the mean over the rows in X indexed by each row of A.
The problem is that the rows in A contain a variable number of True indices.
Example (numpy):
import numpy as np
A = np.array([[0, 1, 0, 0, 0, 0],
              [1, 0, 1, 0, 0, 0],
              [0, 1, 0, 0, 0, 0],
              [0, 0, 0, 0, 1, 0],
              [0, 0, 0, 1, 0, 0],
              [0, 1, 1, 1, 0, 0]])
X = np.arange(6 * 3, dtype=np.float32).reshape(6, 3)
# Compute the mean in numpy with a for loop
means_np = np.array([X[A.astype(bool)[i]].mean(axis=0) for i in np.arange(len(A))])
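For reference, with the arrays above this works out to means_np = [[3, 4, 5], [3, 4, 5], [3, 4, 5], [12, 13, 14], [9, 10, 11], [6, 7, 8]] (e.g. row 5 of A selects X[1], X[2] and X[3], whose mean is [6, 7, 8]).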
So this example works, but this formulation has three problems:
1. The for loop is slow for larger A and X. I need to loop over a few tens of thousands of indices.
2. It can happen that A[i] contains no True indices. This results in np.mean(np.array([])), which is NaN. I want this to be 0 instead (see the snippet after this list).
3. Implementing it this way in pytorch results in SIGFPE (floating point error) during the backward pass of backpropagation through this function. The cause is rows where nothing is selected.
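A minimal demonstration of problem 2, reusing the X from the example above (the all-False mask stands in for an empty row of A):
import numpy as np
X = np.arange(6 * 3, dtype=np.float32).reshape(6, 3)
empty_mask = np.zeros(6, dtype=bool)  # a row of A with no True indices
print(X[empty_mask].mean(axis=0))     # [nan nan nan], plus a RuntimeWarning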
The workaround that I am using now is (also see code below): set the diagonal of A to True so that there is always at least one element to select, sum the selected rows of X and subtract X[i] from that sum (the diagonal is guaranteed to be False in the beginning), and divide by the number of True elements minus 1, clamped to at least 1 in each row. This works, is differentiable in pytorch and does not produce NaN, but I still need a loop over all indices.
How can I get rid of this loop?
This is my current pytorch code:
import torch
A = torch.from_numpy(A).bool()
X = torch.from_numpy(X)
A[np.diag_indices(len(A))] = True  # Set the diagonal to True
means = [(X[A[i]].sum(dim=0) - X[i]) / torch.clamp(A[i].sum().float() - 1, min=1.)  # Compute the mean safely
         for i in range(len(A))]  # Get rid of the loop somehow
means = torch.stack(means)
I don't mind if your version looks completely different, as long as it is differentiable and produces the same result.
Upvotes: 4
Views: 1333
Reputation: 221774
We can leverage matrix-multiplication: A.dot(X) sums, for each row of A, the rows of X that it selects, so dividing by the per-row count of selected rows gives the mean -
c = A.sum(1, keepdims=True)
means_np = np.where(c == 0, 0, A.dot(X) / c)
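A quick sanity check against the loop version from the question (its example has no empty rows, so the outputs should match exactly):
loop_means = np.array([X[A.astype(bool)[i]].mean(axis=0) for i in range(len(A))])
print(np.allclose(means_np, loop_means))  # True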
We can optimize it further by converting A to float32 dtype if it's not already so and if the loss of precision is okay there (integer matmul doesn't go through the optimized BLAS routines, while float32 matmul does), as shown below -
In [57]: np.random.seed(0)
In [58]: A = np.random.randint(0,2,(1000,1000))
In [59]: X = np.random.rand(1000,1000).astype(np.float32)
In [60]: %timeit A.dot(X)
10 loops, best of 3: 27 ms per loop
In [61]: %timeit A.astype(np.float32).dot(X)
100 loops, best of 3: 10.2 ms per loop
In [62]: np.allclose(A.dot(X), A.astype(np.float32).dot(X))
Out[62]: True
Thus, use A.astype(np.float32).dot(X) to replace A.dot(X).
Alternatively, to handle the case where the row-sum is zero without needing np.where, we could assign any non-zero value, say 1, into c and then simply divide by it, like so -
c = A.sum(1, keepdims=True)
c[c == 0] = 1
means_np = A.dot(X) / c
This would also avoid the warning that we would otherwise get in those zero row-sum cases, because with np.where the division A.dot(X)/c is still evaluated for the zero rows.
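Since the question asked for pytorch: the same matrix-multiplication idea carries over directly. A minimal sketch, assuming A and X are the numpy arrays from the question, using the clamp trick from above so that empty rows come out as 0 -
import torch

A_t = torch.from_numpy(A).float()             # float mask so we can matmul
X_t = torch.from_numpy(X)
c = A_t.sum(dim=1, keepdim=True)              # number of selected rows per row
means = (A_t @ X_t) / torch.clamp(c, min=1.)  # empty rows yield 0 instead of NaN
This stays differentiable with respect to X_t, since it is just a matmul followed by a division by a clamped count.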
Upvotes: 1