Reputation: 45
Hi everyone, and please excuse my limited programming knowledge. I have two arrays like:
A = ([[ 0.10111977, 0.5511177 , 0.49532397, 0.42136468, 0.43345532],
     [ 0.3812068 , 0.97679566, 0.20473656, 0.40256096, 0.32423426],
     [ 0.2387294 , 0.88714084, 0.01064819, 0.48275173, 0.78234234]])
B = ([[ 0.10111977, 0.5511177 , 0.49532397],
     [ 0.2387294 , 0.88714084, 0.01064819]])
(they actually have many thousands of rows; these are just to demonstrate the problem). I'd like to compare the two in order to find which of the rows in B are also present in A, and to copy the corresponding rows of A into a new array that would look like:
C = ([[ 0.10111977, 0.5511177 , 0.49532397, 0.42136468, 0.43345532],
     [ 0.2387294 , 0.88714084, 0.01064819, 0.48275173, 0.78234234]])
The easy (brute force) solution I tried is to do something like:
C = []
for rowB in B:
    for rowA in A:
        if rowA[0] == rowB[0] and rowA[1] == rowB[1] and rowA[2] == rowB[2]:
            C.append(rowA)
            break
Now this works, but as I said my datasets are huge and it takes forever. Is there an easier/faster way to do this? I have thought of interpolation, but I don't see how it could be applied to this data.
Upvotes: 1
Views: 1901
Reputation: 301
You can use set logic:
setA & setB will return all of the items that are present in both A and B:
a = set(list1)
b = set(list2)
c = a & b
c will now contain matches!
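Note that NumPy rows (and Python lists) are not hashable, so for the arrays in the question each row would first have to be converted to something hashable such as a tuple. A minimal sketch of the idea, using made-up stand-in values rather than the question's data:

# Sketch only: plain set intersection on rows converted to tuples.
# list1 / list2 are illustrative stand-ins, not names from the question.
list1 = [(0.1, 0.55, 0.49), (0.38, 0.97, 0.20), (0.23, 0.88, 0.01)]
list2 = [(0.1, 0.55, 0.49), (0.23, 0.88, 0.01)]

a = set(list1)
b = set(list2)
c = a & b
print(c)   # the two common rows (set ordering is arbitrary)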
Edit: I did not see the NumPy reference at first. If you search the docs you can find the method you are looking for:
http://docs.scipy.org/doc/numpy/reference/generated/numpy.intersect1d.html#numpy.intersect1d
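A small usage illustration (my own example, not data from the question); keep in mind that intersect1d works on flattened 1-D input, so matching whole rows would still require encoding each row as a single comparable value first:

import numpy as np

# intersect1d returns the sorted, unique values present in both inputs;
# it flattens its arguments, so this compares individual values, not rows.
x = np.array([0.10111977, 0.5511177, 0.49532397, 0.42136468])
y = np.array([0.5511177, 0.01064819, 0.49532397])
print(np.intersect1d(x, y))   # -> [0.49532397 0.5511177]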
Upvotes: 1
Reputation: 35125
This is a version with better time complexity (O(n) on average, according to https://wiki.python.org/moin/TimeComplexity):
import numpy as np

def common_rows(A, B):
    # Hash each row of B as a tuple for O(1) average-time membership tests
    items = set(tuple(row) for row in B)
    # Keep the rows of A whose first three columns form a row of B
    return np.array([row for row in A if tuple(row[:3]) in items])
n = 10000
A = np.random.rand(n, 5)
B = np.random.rand(n, 3)
# Make some common rows
B[123,:] = A[5775,:3]
B[1443,:] = A[85,:3]
print("-- Expected:")
print(B[123])
print(B[1443])
print("-- Got:")
print(common_rows(A, B))
NumPy doesn't have a set data structure, so each row is converted here to a Python tuple. This is somewhat inefficient, but it should still be faster than the nested-loop approach for large n.
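For reference, here is how the function behaves on the small arrays from the question (a usage sketch; the values are copied from the question):

A = np.array([[ 0.10111977, 0.5511177 , 0.49532397, 0.42136468, 0.43345532],
              [ 0.3812068 , 0.97679566, 0.20473656, 0.40256096, 0.32423426],
              [ 0.2387294 , 0.88714084, 0.01064819, 0.48275173, 0.78234234]])
B = np.array([[ 0.10111977, 0.5511177 , 0.49532397],
              [ 0.2387294 , 0.88714084, 0.01064819]])

print(common_rows(A, B))
# [[ 0.10111977  0.5511177   0.49532397  0.42136468  0.43345532]
#  [ 0.2387294   0.88714084  0.01064819  0.48275173  0.78234234]]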
Upvotes: 0