FooBar

Reputation: 16508

Vectorizing this non-unique-key operation

I have an original array called test whose values are not unique. Along with it, I have rows, the set of values that should receive a non-zero output, and data, which holds the output for each of those values. From these I want to build an output vector.

import numpy as np

rows = np.array([3, 4])
test = np.array([1, 3, 3, 4, 5])
data = np.array([-1, 2])

My expected output is a vector of shape test.shape.

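Each element of output is defined as follows: output[i] is data[j] if test[i] == rows[j], and 0 if test[i] matches no value in rows.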

In other words, the following generates my output.

output = np.zeros(test.shape)
for i, val in enumerate(rows):
    output[test == val] = data[i]

Is there any way of vectorizing this?

Upvotes: 3

Views: 63

Answers (2)

Paul Panzer

Reputation: 53089

Here is a method that only works if test and rows consist of integers that are not too large (and non-negative, although that restriction can be relaxed if need be). When it applies, though, it is fast:

>>> rows = np.array([3, 4])
>>> test = np.array([1, 3, 3, 4, 5])
>>> data = np.array([-1, 2])
>>>
>>> limit = 1<<20
>>> assert all(np.issubdtype(a.dtype, np.integer) for a in (rows, test))
>>> assert np.all(rows>=0) and np.all(test>=0)
>>> mx = np.maximum(np.max(rows), np.max(test)) + 1
>>> assert mx <= limit
>>> lookup = np.empty((mx,), data.dtype)
>>> lookup[test] = 0
>>> lookup[rows] = data
>>> result = lookup[test]
>>> result
array([ 0, -1, -1,  2,  0])
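The answer notes that the non-negativity requirement can be relaxed. One minimal way to do that, sketched here as a hypothetical helper called lookup_map, shifts every key by the smallest value found in rows and test so that table indices start at 0. It also zero-fills the whole table with np.zeros instead of writing zeros only at the test positions, which is simpler at the cost of initializing every slot. The limit default mirrors the 1<<20 used above and is an assumption, not a hard rule.

```python
import numpy as np


def lookup_map(rows, test, data, limit=1 << 20):
    """Dense lookup-table version of the question's loop (hypothetical helper).

    Negative integers are handled by shifting all keys by the smallest value,
    so the table index always starts at 0. Assumes non-empty inputs.
    """
    rows, test, data = np.asarray(rows), np.asarray(test), np.asarray(data)
    # Keys are used as array indices, so they must be integers
    if not (np.issubdtype(rows.dtype, np.integer)
            and np.issubdtype(test.dtype, np.integer)):
        raise TypeError("rows and test must have an integer dtype")
    lo = min(rows.min(), test.min())
    size = max(rows.max(), test.max()) - lo + 1
    # One table slot per integer in [lo, hi]; refuse very wide key ranges
    if size > limit:
        raise ValueError("key range too large for a dense lookup table")
    # Zero-filled table: keys that do not appear in rows map to 0
    lookup = np.zeros(size, dtype=data.dtype)
    lookup[rows - lo] = data
    return lookup[test - lo]


print(lookup_map([3, 4], [1, 3, 3, 4, 5], [-1, 2]))  # [ 0 -1 -1  2  0]
print(lookup_map([-2, 4], [-2, 0, 4, -2], [7, 9]))   # [7 0 9 7]
```

The memory cost grows with the span between the smallest and largest key, not with the number of elements, which is why the range check is kept.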

Upvotes: 0

Divakar

Reputation: 221654

Here's a vectorized approach based on np.searchsorted. It assumes rows is sorted in ascending order, as it is in the example:

# Get the insertion positions of test's values in the sorted rows array
idx = np.searchsorted(rows, test)

# Set out-of-bounds (invalid) positions to some dummy index, say 0
idx[idx==len(rows)] = 0

# Get the invalid mask by indexing the rows array
# with those indices and looking for mismatches
invalid_mask = rows[idx] != test

# Index data to get the output, then set the invalid places to 0
out = data[idx]
out[invalid_mask] = 0

If you prefer one-liners, the last couple of lines have two alternatives:

out = data[idx] * (rows[idx] == test) # skips using `invalid_mask`

out = np.where(invalid_mask, 0, data[idx])
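The searchsorted approach needs rows in ascending order. A minimal sketch for unsorted rows, assuming the values in rows are unique and rows is non-empty, passes an argsort permutation through searchsorted's sorter parameter and maps the found positions back to indices into rows and data. searchsorted_map is a hypothetical name, and the random check at the end compares it with the question's loop.

```python
import numpy as np


def searchsorted_map(rows, test, data):
    """Vectorized version of the question's loop for possibly unsorted rows.

    Hypothetical helper; assumes the values in rows are unique and non-empty.
    """
    rows, test, data = np.asarray(rows), np.asarray(test), np.asarray(data)
    # Permutation that would sort rows; searchsorted uses it via sorter=
    order = np.argsort(rows)
    pos = np.searchsorted(rows, test, sorter=order)
    # Positions past the end cannot be matches; point them at a valid slot
    pos[pos == len(rows)] = 0
    # Translate positions in sorted order back to indices into rows/data
    idx = order[pos]
    # Keep data only where the located row really equals the test value
    return np.where(rows[idx] == test, data[idx], 0)


print(searchsorted_map([4, 3], [1, 3, 3, 4, 5], [2, -1]))  # [ 0 -1 -1  2  0]

# Cross-check against the loop from the question on random, unsorted keys
rng = np.random.default_rng(0)
rows = rng.permutation(50)[:10]          # 10 unique keys in random order
data = rng.integers(-100, 100, size=10)
test = rng.integers(0, 60, size=1000)    # includes values absent from rows

expected = np.zeros(test.shape, dtype=data.dtype)
for i, val in enumerate(rows):
    expected[test == val] = data[i]

assert np.array_equal(searchsorted_map(rows, test, data), expected)
```

If rows can contain duplicates, the loop lets the last matching entry win, and this sketch does not reproduce that rule, so deduplicate first.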

Upvotes: 2
