Daniel

Reputation: 53

How to get the symmetric difference of more than 2 lists?

I want to get all exclusive elements between all my lists. So if I have 3 lists like:

list1 = [1, 3, 2]
list2 = ["a", 1, 3]
list3 = [2, 0]

My output should be:

['a', 0]

I tried to do symmetric differencing with all of the lists like:

set(list1) ^ set(list2) ^ set(list3)

But this doesn't work well.

Also I tried:

def exclusive(*lista):
    excl = set(lista[0])
    # start at index 1: XOR-ing lista[0] into itself would cancel it out
    for idx in range(1, len(lista)):
        excl ^= set(lista[idx])
    return excl

That works the same as the first method, but it doesn't produce what I want.

Then I tried (set(list1) ^ set(list2)) ^ (set(list2) ^ set(list3)) and found that it's not the same as what I first tried.

EDIT:

I gave 3 lists as an example, but the function should take an undefined number of arguments.

Upvotes: 4

Views: 2114

Answers (3)

benvc

Reputation: 15130

This can be done primarily with set operations, but I prefer the simplicity of the answer from @pault. To get the exclusive elements of an arbitrary number of sets, collect the intersections of every pair of sets, then take the symmetric difference between that combined intersection and the union of all sets. Because the combined intersection is a subset of the union, this leaves only the elements that belong to a single set.

from itertools import combinations

def symdiff(*sets):
    union = set()
    union.update(*sets)

    intersect = set()
    for a, b in combinations(sets, 2):
        intersect.update(a.intersection(b))

    return intersect.symmetric_difference(union)

distincts = symdiff(set([1, 3, 2]), set(['a', 1, 3]), set([2, 0]))
print(distincts)
# {0, 'a'}

The following is a better example input, where a simple sequential symmetric difference of the sets would not give the same result.

distincts = symdiff(set([1, 3, 2, 0]), set(['a', 1, 3, 0]), set([2, 0]))
print(distincts)
# {'a'}
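
The pairwise loop above does work proportional to the number of set pairs. As a rough alternative sketch (not part of the original answer; `exclusive_elements` is a hypothetical name), the same result can probably be built in a single pass by tracking which elements have been seen at all and which have been seen in more than one input:

```python
def exclusive_elements(*sets):
    """Return the elements that belong to exactly one of the given inputs."""
    seen = set()      # elements found in at least one input
    repeated = set()  # elements found in two or more inputs
    for s in sets:
        s = set(s)            # accept any iterable, ignore duplicates inside it
        repeated |= seen & s  # anything seen before is no longer exclusive
        seen |= s
    return seen - repeated


print(exclusive_elements({1, 3, 2, 0}, {"a", 1, 3, 0}, {2, 0}))
# {'a'}
```

This assumes the elements are hashable, as in the examples here.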

Upvotes: 0

pault

Reputation: 43544

You could also take a non-set approach using collections.Counter:

from itertools import chain
from collections import Counter

res = [k for k, v in Counter(chain(list1, list2, list3)).items() if v==1]
print(res)
#['a', 0]

Use itertools.chain to flatten your lists together and use Counter to count the occurrences. Keep only those where the count is 1.
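
Since the question asks for a function that takes any number of lists, the one-liner above could be wrapped roughly like this. This is only a sketch: the name `exclusive` is borrowed from the question, and removing duplicates within each list before counting is an assumption about the intended behaviour (without it, a value repeated inside a single list would be dropped).

```python
from collections import Counter
from itertools import chain


def exclusive(*lists):
    # dict.fromkeys(lst) drops duplicates within one list while keeping
    # first-seen order (assumption: a repeat inside a single list should
    # not disqualify the value).
    deduped = (dict.fromkeys(lst) for lst in lists)
    counts = Counter(chain.from_iterable(deduped))
    # keep only the values that occur in exactly one list
    return [k for k, v in counts.items() if v == 1]


print(exclusive([1, 3, 2], ["a", 1, 3], [2, 0]))
# ['a', 0]
```

The output order shown relies on Python 3.7+, where dicts (and therefore Counter) preserve insertion order.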


Update: Here is a better example that demonstrates why the other methods do not work.

list1 = [1, 3, 2]
list2 = ["a", 1, 3]
list3 = [2, 0]
list4 = [1, 4]
all_lists = [list1, list2, list3, list4]

Based on your criteria, the correct answer is:

print([k for k, v in Counter(chain(*all_lists)).items() if v==1])
#['a', 0, 4]

Using reduce(set.symmetric_difference, ...):

from functools import reduce  # reduce is not a builtin in Python 3

sets = map(set, all_lists)
print(reduce(set.symmetric_difference, sets))
#{0, 1, 4, 'a'}

Using the symmetric difference minus the intersection:

set1 = set(list1)
set2 = set(list2)
set3 = set(list3)
set4 = set(list4)

print((set1 ^ set2 ^ set3 ^ set4) - (set1 & set2 & set3 & set4))
#{0, 1, 4, 'a'}
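
The reason 1 survives in both outputs is that a chained symmetric difference keeps every element that belongs to an odd number of the sets, not only those that belong to exactly one. A small check of that claim on the four example lists (the variable names here are illustrative only):

```python
from collections import Counter
from functools import reduce
from itertools import chain

all_sets = [{1, 3, 2}, {"a", 1, 3}, {2, 0}, {1, 4}]

# Chained XOR of all sets, as in the reduce() example.
chained = reduce(set.symmetric_difference, all_sets)

# Count how many of the sets each element belongs to.
membership = Counter(chain.from_iterable(all_sets))
odd = {k for k, v in membership.items() if v % 2 == 1}
exactly_one = {k for k, v in membership.items() if v == 1}

print(chained == odd)     # True
print(odd - exactly_one)  # {1}
```

Here 1 appears in three of the four sets, so it passes the odd-count test even though it is not exclusive to any one list.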

Upvotes: 5

blhsing

Reputation: 107094

You should subtract the intersection of the 3 sets from the symmetric difference of the 3 sets in order to get the exclusive items:

set1 = set(list1)
set2 = set(list2)
set3 = set(list3)

(set1 ^ set2 ^ set3) - (set1 & set2 & set3)

so that given:

list1 = [1,3,2]
list2 = ["a",1,3]
list3 = [2,0,1]

this returns:

{0, 'a'}

whereas your attempt of set1 ^ set2 ^ set3 would incorrectly return:

{0, 1, 'a'}

Upvotes: 0
