Travis Black

Reputation: 715

2d numpy, efficiently assign nonzero indices [row,col] the minimum of [row,row] and [col,col]

I am using a matrix to count combinations of relationships between nodes in a graph (the details are irrelevant).

I have an N*N adjacency matrix whose rows and columns correspond to nodes. Position [5,7] holds how many times node 5 appears with node 7, and so does [7,5]. Position [3,3] holds how many times node 3 appears in total.

On each loop iteration I have to reduce my matrix. I take a vector of size N and subtract it from the matrix diagonal, which reduces the total count of each node: [1,1], [2,2], [3,3], and so on.
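
As a rough sketch of that diagonal step (the 3-node matrix and the decrement vector `dec` below are made-up values for illustration, not my real data):

```python
import numpy as np

# Hypothetical 3-node count matrix; the diagonal holds per-node totals.
matrix = np.array([[3, 2, 0],
                   [2, 5, 1],
                   [0, 1, 4]])

# Hypothetical per-node decrement vector of size N.
dec = np.array([1, 2, 0])

# Subtract dec[k] from matrix[k, k] for every k; off-diagonal entries are untouched.
matrix[np.diag_indices_from(matrix)] -= dec
```

After this, only the diagonal has changed, and the off-diagonal counts are what I still need to update.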

Hopefully I am making sense so far. Here is my question.

At this point, I have modified the diagonal of my matrix. Now, I want to modify every position [i,j] where

i != j and matrix[i,j] != 0  

I want to modify it such that:

matrix[i,j] = min(matrix[i,i], matrix[j,j])

Now of course I could just iterate through every index pair (i,j) and do what I wrote just above. But that is slow. I am hoping there is some clever math or numpy trick to make this much faster.

Thank you!

Upvotes: 0

Views: 36

Answers (1)

DSM

Reputation: 353449

First, before doing any optimizations you should profile: there's no point in trying to be clever about something that only takes tens of milliseconds in the entire life of your program, or that only amounts to a small fraction of the total runtime.

That said, you can vectorize the taking of the minimum by taking advantage of broadcasting:

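import numpy as np

def slow(arr):
    # Reference implementation: visit every entry in a Python loop.
    out = arr.copy()
    for (i, j), x in np.ndenumerate(arr):
        if i != j and x != 0:
            out[i, j] = min(arr[i, i], arr[j, j])
    return out

def fast(arr):
    diag = arr.diagonal()
    # Broadcasting (n,) against (n, 1) gives mins[i, j] = min(diag[i], diag[j]).
    mins = np.minimum(diag, diag[:, None])
    # Replace nonzero entries with the pairwise minimum; zeros stay zero.
    out = np.where(arr != 0, mins, arr)
    # Write the original diagonal back so it is left unchanged.
    out[np.diag_indices_from(arr)] = diag
    return out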

which gives me

In [61]: a = np.random.randint(0, 10, (100, 100))

In [62]: (slow(a) == fast(a)).all()
Out[62]: True

In [63]: %timeit slow(a)
11.9 ms ± 188 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [64]: %timeit fast(a)
62.8 µs ± 916 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
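
To see what the broadcasting step is doing, here is a small worked sketch on a made-up 3x3 matrix (the values are chosen only for illustration):

```python
import numpy as np

# Small symmetric count matrix; the diagonal holds per-node totals.
a = np.array([[3, 2, 0],
              [2, 5, 1],
              [0, 1, 4]])

diag = a.diagonal()  # array([3, 5, 4])

# diag has shape (3,) and diag[:, None] has shape (3, 1), so broadcasting
# produces a (3, 3) array with mins[i, j] = min(diag[i], diag[j]).
mins = np.minimum(diag, diag[:, None])

# Nonzero entries take the pairwise minimum; zero entries stay zero.
result = np.where(a != 0, mins, a)
print(result)
# [[3 3 0]
#  [3 5 4]
#  [0 4 4]]
```

Note that `mins[i, i]` is just `diag[i]`, so the diagonal comes through unchanged here; writing it back explicitly, as `fast` does, simply makes that intent obvious.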

Upvotes: 2
