thepandaatemyface

Reputation: 5267

Take the intersection of an arbitrary number of lists in python

Suppose I have a list of lists whose elements are all of the same type (I'll use ints in this example):

[range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]

What would be a nice and/or efficient way to take the intersection of these lists (so you would get every element that is in each of the lists)? For the example that would be:

[0, 12, 24, 36, 48, 60, 72, 84, 96]

Upvotes: 4

Views: 1816

Answers (7)

Zev Averbach

Reputation: 1134

Here's a one-liner using the good old all() built-in function:

data = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
list(num for num in data[0]
     if all(num in range_ for range_ in data[1:]))

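I think this is more readable than the set-based answers, and it keeps the order (and any duplicates) of the first list. Whether it is also faster depends on the inner containers: on Python 3 range objects, "in" is a constant-time check for integers, so it performs well on this example, but on plain lists each "in" is a linear scan, and converting to sets is usually faster for larger data sets.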
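To keep the readability and ordering of the all() version without paying for linear scans on plain lists, one option is to convert only the remaining lists to sets once, up front. This is a minimal sketch; the helper name ordered_intersection is hypothetical, and it assumes the elements are hashable.

```python
def ordered_intersection(lists):
    """Return the elements of lists[0] that appear in every other list,
    keeping lists[0]'s order and duplicates (hypothetical helper name)."""
    if not lists:
        return []
    # Build each membership structure once so every `in` check is O(1) on average.
    others = [set(lst) for lst in lists[1:]]
    return [x for x in lists[0] if all(x in s for s in others)]


data = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
print(ordered_intersection(data))  # [0, 12, 24, 36, 48, 60, 72, 84, 96]
```

Unlike the pure set approaches, the result is a list in the first list's order, so no sorting step is needed for this example.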

Upvotes: 1

Andrew Jaffe

Reputation: 27077

You can treat them as sets and use set.intersection():

from functools import reduce  # reduce is no longer a builtin in Python 3

lists = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
sets = [set(l) for l in lists]

isect = reduce(lambda x, y: x.intersection(y), sets)
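One edge case worth knowing: reduce() with no initial value raises TypeError when given an empty sequence, so the one-liner above fails if there are no lists at all. A small guarded version might look like this; the name intersect_all is hypothetical, and returning an empty set for "no lists" is a pragmatic choice, not the only sensible one.

```python
from functools import reduce


def intersect_all(lists):
    """Intersect any number of iterables (hypothetical helper name).

    reduce() without an initial value raises TypeError on an empty
    sequence, so the no-input case is handled explicitly.
    """
    sets = [set(l) for l in lists]
    if not sets:
        return set()  # design choice: treat "no lists" as "nothing in common"
    return reduce(lambda x, y: x.intersection(y), sets)


lists = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
print(sorted(intersect_all(lists)))  # [0, 12, 24, 36, 48, 60, 72, 84, 96]
print(intersect_all([]))             # set()
```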

Upvotes: 0

inspectorG4dget

Reputation: 113915

l = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
l = [set(i) for i in l]               # convert every list to a set
intersect = l[0].intersection(l[1])   # start with the first two sets
for i in l[2:]:                       # then narrow by each remaining set
    intersect = intersect.intersection(i)
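A variant of this loop can stop early: once the running intersection is empty it can never grow again, so the remaining lists need not be read. The sketch below (the helper name intersect_early_exit is hypothetical) also uses set.intersection_update(), which narrows the set in place and accepts any iterable, so the other lists do not have to be converted to sets first.

```python
def intersect_early_exit(lists):
    """Loop-based intersection that stops once the running result is empty
    (hypothetical helper name)."""
    if not lists:
        return set()
    result = set(lists[0])
    for other in lists[1:]:
        if not result:
            break  # an empty intersection cannot grow again, so skip the rest
        # intersection_update() narrows result in place and accepts any iterable
        result.intersection_update(other)
    return result


l = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
print(sorted(intersect_early_exit(l)))              # [0, 12, 24, 36, 48, 60, 72, 84, 96]
print(intersect_early_exit([[1, 2], [3], [1, 2]]))  # set()
```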

Upvotes: 0

Mike Graham

Reputation: 76683

Use sets, which have an intersection method.

>>> s = set()
>>> s.add(4)
>>> s.add(5)
>>> s
{4, 5}
>>> t = set([2, 4, 9])
>>> s.intersection(t)
{4}

For your example, something like

>>> data = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
>>> sets = map(set, data)
>>> print(sorted(set.intersection(*sets)))
[0, 12, 24, 36, 48, 60, 72, 84, 96]
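A related detail: set.intersection() accepts arbitrary iterables for all arguments after the first, so only the first list strictly needs to be converted to a set. The & operator is stricter and requires sets on both sides. A short sketch illustrating both points:

```python
data = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]

# Only the first argument needs to be a set: intersection() accepts
# any iterables for the remaining arguments, so data[1:] is passed as-is.
common = set(data[0]).intersection(*data[1:])
print(sorted(common))  # [0, 12, 24, 36, 48, 60, 72, 84, 96]

# The & operator, by contrast, raises TypeError unless both operands are sets.
try:
    set(data[0]) & data[1]
except TypeError:
    print("& needs a set on both sides")
```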

Upvotes: 9

thepandaatemyface

Reputation: 5267

I'm going to answer my own question:

lists = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]

out = set(lists[0])
for l in lists[1:]:
    out = set(l).intersection(out)

print(out)

or

print(list(out))
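Since sets have no defined order, list(out) returns the elements in whatever order the set happens to store them. To get the ascending list shown in the question, sorting is the reliable option; a minimal sketch:

```python
lists = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]

out = set(lists[0])
for l in lists[1:]:
    out = set(l).intersection(out)

# sorted() returns a new list in ascending order, unlike list(out),
# whose order follows the set's internal layout.
result = sorted(out)
print(result)  # [0, 12, 24, 36, 48, 60, 72, 84, 96]
```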

Upvotes: 1

Samir Talwar

Reputation: 14330

Convert them to sets and use the set.intersection method, reducing over the list of sets:

from functools import reduce  # reduce is no longer a builtin in Python 3

xs = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
reduce(set.intersection, [set(x) for x in xs])

reduce (imported from functools on Python 3) is a functional-programming tool that iterates through an iterable, applies the given function to the first two elements, then to that result and the next element, then to that result and the next, and so on, until a single value remains.
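To make the folding concrete, the sketch below computes the same intersection three ways for the four sets in the question: with reduce, with the nested calls reduce expands to, and with an equivalent explicit loop.

```python
from functools import reduce

xs = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
sets = [set(x) for x in xs]

# reduce(f, [a, b, c, d]) computes f(f(f(a, b), c), d); written out by hand:
by_hand = set.intersection(
    set.intersection(set.intersection(sets[0], sets[1]), sets[2]), sets[3]
)

# The same left fold as an explicit loop:
acc = sets[0]
for s in sets[1:]:
    acc = set.intersection(acc, s)

assert reduce(set.intersection, sets) == by_hand == acc
print(sorted(acc))  # [0, 12, 24, 36, 48, 60, 72, 84, 96]
```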

Upvotes: 3

Dan Loewenherz

Reputation: 11236

I think the built-in set type should do the trick.

>>> from functools import reduce
>>> elements = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
>>> sets = map(set, elements)
>>> result = sorted(reduce(lambda x, y: x & y, sets))
>>> print(result)
[0, 12, 24, 36, 48, 60, 72, 84, 96]
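The lambda here only wraps the & operator, so the standard library's operator.and_ (which computes x & y) can be passed to reduce directly. A sketch of that variant:

```python
from functools import reduce
import operator

elements = [range(100)[::4], range(100)[::3], range(100)[::2], range(100)[::1]]
sets = map(set, elements)

# operator.and_(x, y) is equivalent to x & y, so it can replace the lambda.
result = sorted(reduce(operator.and_, sets))
print(result)  # [0, 12, 24, 36, 48, 60, 72, 84, 96]
```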

Upvotes: 4
