Jh123

Reputation: 93

How do I check which rows of one small array exist in another larger one?


Given the following setup:

import numpy as np

batch_size = 4
final_batch = np.empty((1,2), dtype=int)  # placeholder first row, dropped later
a = np.array(range(10))
b = np.array(range(10,20))
edges = np.array([[0,11],[0,12],[1,11],[1,12],[0,17]])


c1 = np.random.choice(a,batch_size).reshape(-1,1)
c2 = np.random.choice(b,batch_size).reshape(-1,1)
samples = np.append(c1,c2,axis=1)

Now samples may contain pairs that already exist in edges. I want to keep calling np.random.choice and only add pairs to final_batch if they don't already exist in edges. The simple way to do this would be to draw them one by one in a loop:

while len(final_batch)<batch_size+1:
    c1 = np.random.choice(a,1).reshape(-1,1)
    c2 = np.random.choice(b,1).reshape(-1,1)
    pair = np.append(c1,c2,axis=1)
    # keep the pair only if it matches no row of edges
    if not (pair == edges).all(axis=1).any():
        final_batch = np.append(final_batch,pair,axis=0)

final_batch = final_batch[1:]  # drop the first row

But a, b and edges can all be huge, and the batch size will be 10k. Since it's much faster to sample many elements at once, I wanted to see if there is a faster way. Something like:

while len(final_batch)<batch_size+1:
    c1 = np.random.choice(a,batch_size).reshape(-1,1)
    c2 = np.random.choice(b,batch_size).reshape(-1,1)
    samples = np.append(c1,c2,axis=1)
    final_batch.append(samples NOT IN edges)  # pseudocode: keep only the rows of samples that are not in edges
     

Note that a and b are disjoint, so c1 and c2 never share a value. I feel like I should be able to use this somehow.

Upvotes: 0

Views: 51

Answers (1)

Olivier Gauthé

Reputation: 402

If I understand your question correctly, you are looking for something like:

samples = np.empty((10, 2), dtype=int)
samples[:,0] = np.random.choice(a, 10)
samples[:,1] = np.random.choice(b, 10)
# compare every sample with every edge: True where a sample differs from all rows of edges
new_indices = (samples != edges[:,None]).any(axis=2).all(axis=0)
new_samples = samples[new_indices]

That is, I generate 10 new samples and then check whether each one matches a row of edges. This is not optimal in the number of operations, since it keeps checking for equality even after a match has been found, but it is vectorized with numpy, which is usually faster than stopping as early as possible.
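
To connect this back to the loop in the question, here is a minimal sketch of batched rejection sampling built on the same mask. The helper names not_in_edges and sample_non_edges and the rng parameter are made up for illustration and are not part of numpy. The sketch assumes a, b and edges look like the ones in the question, that at least some (a, b) pairs are not in edges (otherwise the loop never ends), and that duplicates inside the final batch are acceptable.

```python
import numpy as np


def not_in_edges(samples, edges):
    """Boolean mask: True where a row of `samples` matches no row of `edges`."""
    # Broadcasts to shape (len(edges), len(samples), 2), as in the snippet above.
    return (samples != edges[:, None]).any(axis=2).all(axis=0)


def sample_non_edges(a, b, edges, batch_size, rng=None):
    """Draw whole batches of (a, b) pairs until batch_size pairs outside `edges` are kept."""
    rng = np.random.default_rng() if rng is None else rng
    kept = []      # list of arrays of accepted pairs
    n_kept = 0     # total number of accepted pairs so far
    while n_kept < batch_size:
        # Sample a full batch at once instead of one pair per iteration.
        samples = np.column_stack((rng.choice(a, batch_size),
                                   rng.choice(b, batch_size)))
        fresh = samples[not_in_edges(samples, edges)]
        kept.append(fresh)
        n_kept += len(fresh)
    # Concatenate once at the end and trim the surplus from the last round.
    return np.concatenate(kept)[:batch_size]


if __name__ == "__main__":
    a = np.arange(10)
    b = np.arange(10, 20)
    edges = np.array([[0, 11], [0, 12], [1, 11], [1, 12], [0, 17]])
    final_batch = sample_non_edges(a, b, edges, batch_size=4)
    print(final_batch.shape)
```

Collecting the accepted rows in a list and concatenating once avoids calling np.append repeatedly, which copies the whole array each time. Keep in mind that the broadcast comparison builds a temporary boolean array of shape (len(edges), batch_size, 2), so with a very large edges array memory use may become the limiting factor.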

Upvotes: 2
