I am using a matrix to count combinations of relationships between nodes in a graph (the details are irrelevant).
I have an N×N adjacency matrix whose rows and columns correspond to nodes. Entry [5,7] is how many times node 5 appears together with node 7, and so is [7,5], since the matrix is symmetric. A diagonal entry such as [3,3] is how many times node 3 appears in total.
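For concreteness, here is one way such a matrix could arise. This is a sketch under an assumption the question does not state: each observation is a hypothetical "group" of node indices that appear together. With a 0/1 incidence matrix `B` (groups × nodes), `B.T @ B` gives pairwise co-occurrence counts off the diagonal and per-node totals on the diagonal.

```python
import numpy as np

# Hypothetical input: each inner list holds node indices observed together.
groups = [[0, 1], [1, 2], [0, 1, 2], [2]]
n_nodes = 3

# Incidence matrix: B[g, v] == 1 if node v appears in group g.
B = np.zeros((len(groups), n_nodes), dtype=int)
for g, members in enumerate(groups):
    B[g, members] = 1

# counts[i, j]: number of groups containing both i and j (symmetric).
# counts[i, i]: number of groups containing node i at all.
counts = B.T @ B
print(counts.tolist())  # [[2, 2, 1], [2, 3, 2], [1, 2, 3]]
```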
On each iteration I have to reduce the matrix: I take a vector of length N and subtract it from the matrix diagonal. This reduces the total count of each node, i.e. the entries [1,1], [2,2], [3,3], and so on.
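The diagonal update itself does not need a loop. A minimal sketch, assuming the counts live in an integer NumPy array and `v` (a name chosen here for illustration) holds the per-node amounts to subtract; `np.fill_diagonal(counts, counts.diagonal() - v)` would be an equivalent alternative.

```python
import numpy as np

counts = np.array([[2, 2, 1],
                   [2, 3, 2],
                   [1, 2, 3]])
v = np.array([1, 2, 1])  # hypothetical per-node reductions, length N

# Index the diagonal directly: counts[k, k] -= v[k] for every k.
# Off-diagonal entries are left untouched.
counts[np.diag_indices_from(counts)] -= v
print(counts.diagonal().tolist())  # [1, 1, 2]
```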
Hopefully I am making sense so far. Here is my question.
At this point I have modified the diagonal of the matrix. Now I want to update every position [i,j] where
i != j and matrix[i,j] != 0
so that:
matrix[i,j] = min(matrix[i,i], matrix[j,j])
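To make the rule concrete, here is a small worked example with a plain double loop; the 3×3 values are made up for illustration. Because the rule only reads diagonal entries and only writes off-diagonal ones, updating the array in place is safe here.

```python
import numpy as np

# Made-up matrix whose diagonal has already been reduced.
m = np.array([[1, 2, 0],
              [2, 1, 2],
              [0, 2, 2]])

n = m.shape[0]
for i in range(n):
    for j in range(n):
        # Replace only non-zero off-diagonal entries.
        if i != j and m[i, j] != 0:
            m[i, j] = min(m[i, i], m[j, j])

print(m.tolist())  # [[1, 1, 0], [1, 1, 1], [0, 1, 2]]
```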
Of course, I could iterate over every index pair (i,j) and apply the rule above, but that is slow. I am hoping there is some clever math or NumPy trick to make this much faster.
Thank you!
First, before doing any optimizations you should profile: there's no point in trying to be clever about something that only takes tens of milliseconds in the entire life of your program, or that only amounts to a small fraction of the total runtime.
That said, you can vectorize the taking of the minimum by taking advantage of broadcasting:
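```python
import numpy as np

def slow(arr):
    # Reference implementation: visit every entry one at a time.
    out = arr.copy()
    for (i, j), _ in np.ndenumerate(arr):
        if i != j and arr[i,j] != 0:
            out[i,j] = min(arr[i, i], arr[j, j])
    return out

def fast(arr):
    diag = arr.diagonal()
    # Broadcasting (N,) against (N, 1): mins[i, j] == min(diag[i], diag[j]).
    mins = np.minimum(diag, diag[:, None])
    # Keep zeros as zeros; replace every other entry with the pairwise minimum.
    out = np.where(arr != 0, mins, arr)
    # Make sure the diagonal keeps its original values.
    out[np.diag_indices_from(arr)] = diag
    return out
```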
which gives me
```
In [61]: a = np.random.randint(0, 10, (100, 100))
In [62]: (slow(a) == fast(a)).all()
Out[62]: True
In [63]: %timeit slow(a)
11.9 ms ± 188 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [64]: %timeit fast(a)
62.8 µs ± 916 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
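To see why `fast` works, it helps to look at the broadcasting step on its own: `diag[:, None]` has shape (N, 1) and `diag` has shape (N,), so `np.minimum` broadcasts them to an (N, N) array whose entry [i, j] is min(diag[i], diag[j]); `np.minimum.outer(diag, diag)` builds the same array. The sketch below checks this on a small made-up matrix and also shows an in-place variant (my own choice, not part of the answer above) that overwrites only the non-zero off-diagonal entries through a boolean mask instead of allocating a new result with `np.where`.

```python
import numpy as np

m = np.array([[1, 2, 0],
              [2, 1, 2],
              [0, 2, 2]])
diag = m.diagonal()

# (N,) against (N, 1) broadcasts to (N, N): mins[i, j] == min(diag[i], diag[j]).
mins = np.minimum(diag, diag[:, None])
assert mins.shape == (3, 3)
assert (mins == np.minimum.outer(diag, diag)).all()

# In-place variant: mask the non-zero off-diagonal entries and overwrite only those.
mask = m != 0
np.fill_diagonal(mask, False)
m[mask] = mins[mask]
print(m.tolist())  # [[1, 1, 0], [1, 1, 1], [0, 1, 2]]
```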