Reputation: 3058
I have two lists of custom objects, train and test, of sizes 9,904 and 7,223, respectively.
The elements in each of these lists are unique.
I want to find the elements that exist in both lists. Currently I'm using the following approach but it's painfully slow:
overlap = [e for e in test if e in train]
Is there a faster way to achieve this?
Upvotes: 2
Views: 1636
Reputation: 835
You can use NumPy's intersect1d():
import random
import numpy as np
train = [random.randint(1,51) for var in range(1,9000)] #Your list
test = [random.randint(1,51) for var in range(1,9000)] #Your list
train = np.array(train) #Converting list into numpy's array
test = np.array(test)
overlap = np.intersect1d(train, test)
print(overlap)
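One thing worth knowing about intersect1d() is that it returns a sorted array of unique values, so neither duplicates nor the original order survive. A minimal sketch of that behaviour, using made-up sample values rather than the real data:

```python
import numpy as np

# Hypothetical sample data with duplicates and no particular order
train = np.array([5, 3, 3, 9, 1])
test = np.array([9, 3, 7, 3])

overlap = np.intersect1d(train, test)
print(overlap)  # sorted, unique common values: [3 9]

# These inputs are known to be duplicate-free, so assume_unique=True
# lets NumPy skip its own de-duplication step
unique_overlap = np.intersect1d(
    np.array([4, 2, 8]), np.array([8, 4, 6]), assume_unique=True
)
print(unique_overlap)  # [4 8]
```

Since the question states that the elements of each list are unique, assume_unique=True may apply there, provided the objects can be stored in an array and sorted.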
Upvotes: 0
Reputation: 37193
@yatu suggested using sets because this makes membership tests much faster: with a list the interpreter has to look at each element in turn, whereas a set (or a dict, though that's irrelevant here) uses hashing. You could simply replace the lists with their set equivalents (obtained by applying the set() constructor to each list), and you should see some speedup.
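If the order of test needs to be kept, one option is to leave the comprehension from the question as it is and convert only train to a set. A small sketch with made-up string data standing in for the real objects:

```python
# Hypothetical sample data standing in for the real lists
train = ["a", "b", "c", "d"]
test = ["d", "x", "b", "y"]

train_set = set(train)  # built once; membership tests are O(1) on average

# Same comprehension as in the question, but testing against the set
overlap = [e for e in test if e in train_set]
print(overlap)  # ['d', 'b'] (the order of test is kept)
```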
There are, however, specific methods to determine the intersection and the union of two sets. As long as ordering isn't important*, here's how you might do it:
train_set = set(train) # Use frozenset if no mutation is required
test_set = set(test)
common_elements = train_set & test_set # or, equivalently
common_elements = train_set.intersection(test_set)
* Sets are unordered, so the result does not preserve the order of either list. Only dictionaries are guaranteed to keep insertion order, as of Python 3.7.
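Because the question is about custom objects, note that set operations rely on __eq__ and __hash__ being consistent. By default, instances compare and hash by identity, so two separately created but "equal" objects would not match. A minimal sketch, assuming a hypothetical Sample class with made-up fields; a frozen dataclass generates both methods from the fields:

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen=True with the default eq=True generates __eq__ and __hash__
class Sample:
    # Hypothetical fields; the real objects will differ
    name: str
    value: int


train = [Sample("a", 1), Sample("b", 2), Sample("c", 3)]
test = [Sample("b", 2), Sample("z", 9), Sample("c", 3)]

common_elements = set(train) & set(test)
print(common_elements == {Sample("b", 2), Sample("c", 3)})  # True
```

For a hand-written class, defining __eq__ and a matching __hash__ over the same fields should have the same effect.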
Upvotes: 1
Reputation: 1524
To complete @Jeff's answer, we can compare the time of computation for the two methods:
import numpy as np
import time
test = np.random.randint(1,50000,10000)
train = np.random.randint(1,50000,10000)
start_list = time.time()
overlap = [e for e in test if e in train]
end_list = time.time()
print("with list comprehension: " + str(end_list - start_list))
set_test = set(test)
set_train = set(train)
start_set = time.time()
overlap = set_test.intersection(set_train)
end_set = time.time()
print("with sets: " + str(end_set - start_set))
We get the output:
with list comprehension: 0.08894968032836914
with sets: 0.0003533363342285156
So, in this run, the method with sets is around 250 times faster. Note that the timing does not include building the two sets.
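For a comparison closer to the question, here is a sketch that uses plain Python lists of unique elements (sizes borrowed from the question, integers standing in for the custom objects) and includes the cost of building the set in the timed section. Timings vary by machine, so none are quoted here.

```python
import random
import time

random.seed(0)  # fixed seed so runs are repeatable
# Unique elements, with the list sizes from the question
train = random.sample(range(50000), 9904)
test = random.sample(range(50000), 7223)

start = time.perf_counter()
list_result = [e for e in test if e in train]  # linear scan of train per element
list_time = time.perf_counter() - start

start = time.perf_counter()
set_result = set(test).intersection(train)  # set construction is timed too
set_time = time.perf_counter() - start

# Both approaches should find the same elements
assert set(list_result) == set_result

print("with list comprehension:", list_time)
print("with sets (including construction):", set_time)
```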
Upvotes: 5
Reputation: 251
set_test = set(test)
set_train = set(train)
overlap = set_test.intersection(set_train)
Upvotes: 1