Andrea Aquino

Reputation: 132

Get a list of lists of the sets which intersect each other in python

Given a list of sets, I would like to get a list of lists of the sets that intersect each other. That is, I want a list of lists such that, within each list in the output, every set has a non-empty intersection with at least one other set in the same list.

I hope I was able to explain my problem. Hopefully the following example and the rest of the post should clarify it even more.

Given,

sets = [
    set([1,3]), # A
    set([2,3,5]), # B
    set([21,22]), # C
    set([1,9]), # D
    set([5]), # E
    set([18,21]), # F
]

My desired output is:

[
    [
        set([1,3]), # A, shares elements with B
        set([2,3,5]), # B, shares elements with A 
        set([1,9]), # D, shares elements with A
        set([5]), # E shares elements with B
    ],
    [
        set([21,22]), # C shares elements with F
        set([18,21]), # F shares elements with C
    ]
]

The order of the sets in the output does NOT matter.

I would like to achieve this goal with a very fast algorithm. Performance is my first requirement.

At the moment my solution creates a graph with as many nodes as there are sets in the original list. It then adds an edge between the nodes that represent sets A and B iff these sets have a non-empty intersection. Finally, it computes the connected components of that graph, which gives me the expected result.

I am wondering if there is a faster way of doing this with an algorithm which does not involve graphs.

Best, Andrea

Upvotes: 2

Views: 420

Answers (2)

beetea

Reputation: 308

I've only seen solutions that make multiple passes over the sets (typically O(N^2)). So, out of curiosity, I wanted to see if this could be done with only one pass over the sets (i.e. O(N)).

Using the input set from the original question, here's an example of how this algorithm iterates through each set in the list.

[
    set([1,3]), # A
    set([2,3,5]), # B
    set([21,22]), # C
    set([1,9]), # D
    set([5]), # E
    set([18,21]), # F
]

We start at Set A (set([1,3])) and assume it is part of a new result list L1. We also maintain a Pointer, C1, to our current result list (which is L1 at the moment). The purpose of this Pointer will be explained later:

L1 = [set([1,3]),]
C1 = Pointer(L1)

For each integer in Set A, we also update our mapping M:

{
    1: C1, # Pointer(L1)
    3: C1, # Pointer(L1)
}

The next item is Set B (set([2,3,5])) and we assume it is part of a new result list L2.

L2 = [set([2,3,5]),]
C2 = Pointer(L2)

Again, we iterate the members of this set and update our mapping M. Explicitly, we see that 2 is not in M and we update it such that we have:

{
    1: C1, # Pointer(L1)
    3: C1, # Pointer(L1)
    2: C2, # Pointer(L2)
}

However, when we see 3, we notice that it already points to another result list, L1. This indicates that there's an intersection and we need to do two things:

  1. Update C2 to point to L1
  2. Add Set B to L1

We should end up with M looking like:

{
    1: C1, # Pointer(L1)
    3: C1, # Pointer(L1)
    2: C2, # Pointer(L1)
}

And L1 should now be:

[set([1,3]), set([2,3,5])]

This is why we need the Pointer class: we can't just assign C2 to a new Pointer() instance. If we do that, then M would look like:

{
    1: C1, # Pointer(L1)
    3: C1, # Pointer(L1)
    2: C2, # Pointer(L2) <-- this should be pointing to L1
}

Instead, we do something like:

C2.set(L1)
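To make the point about indirection concrete, here is a minimal sketch. `Ref` is a hypothetical stand-in for the Pointer class (not the code from this answer), and the lists hold placeholder strings instead of sets. It only illustrates why rebinding a name is not enough while redirecting a shared box is.

```python
class Ref(object):
    """Hypothetical stand-in for Pointer: a mutable box around a target."""
    def __init__(self, target):
        self.target = target

L1 = ["A"]
L2 = ["B"]

# Without indirection: the dict stores the list itself, so rebinding a
# local name later does not change what the dict entry refers to.
plain = {2: L2}
current = L2
current = L1              # rebinds the name only
assert plain[2] is L2     # the entry still refers to the old list

# With indirection: the dict stores the box, and every holder of the box
# sees the new target once the box is redirected.
C2 = Ref(L2)
boxed = {2: C2}
C2.target = L1            # plays the role of C2.set(L1) in the walkthrough
assert boxed[2].target is L1
```

Every key in M that was mapped to C2 follows the redirect automatically, which is what the walkthrough relies on.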

Stepping ahead, after we've processed Set B and Set C, the state of M should look like:

{
    1: C1, # Pointer(L1)
    3: C1, # Pointer(L1)
    2: C2, # Pointer(L1)
    5: C2, # Pointer(L1)
    21: C3, # Pointer(L3)
    22: C3, # Pointer(L3)
}

And the result lists:

[set([1,3]), set([2,3,5])] # L1
[set([21,22]), ]           # L3

[set([2,3,5]), ]           # L2 (but nothing in M points to it)

When we look at Set D, we again assume it's part of a new result list, L4, and create the corresponding Pointer, C4, which points to L4. However, when we see that Set D contains 1, which already maps to L1, we change C4 to point to L1 and add Set D to L1.

The state of M is:

{
    1: C1, # Pointer(L1)
    3: C1, # Pointer(L1)
    2: C2, # Pointer(L1)
    5: C2, # Pointer(L1)
    21: C3, # Pointer(L3)
    22: C3, # Pointer(L3)
    9: C4, # Pointer(L1)
}

And L1 is now:

[set([1,3]), set([2,3,5]), set([1,9])]

Stepping through all the way to the end, the state of M will look like:

{
    1: C1, # Pointer(L1)
    3: C1, # Pointer(L1)
    2: C2, # Pointer(L1)
    5: C2, # Pointer(L1)
    21: C3, # Pointer(L3)
    22: C3, # Pointer(L3)
    9: C4, # Pointer(L1)
    18: C6, # Pointer(L3)
}

And the lists look like:

[set([1,3]), set([2,3,5]), set([1,9]), set([5])] # L1
[set([21,22]), set([18, 21])] # L3

At this point, you can iterate the Pointers in M and find all the unique lists that are being referenced and that will be your result.
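That final step could look roughly like the sketch below. `Box` and `unique_lists` are hypothetical names used only for illustration, and the mapping is a hand-built copy of part of the state shown above. The idea is to de-duplicate by object identity, since two distinct lists could compare equal.

```python
class Box(object):
    """Hypothetical stand-in for Pointer: just holds a reference in .o."""
    def __init__(self, o):
        self.o = o

def unique_lists(mapping):
    """Return each distinct list referenced by the boxes in mapping, once."""
    seen = {}
    for box in mapping.values():
        # Key by identity, not equality: two different lists may compare equal.
        seen.setdefault(id(box.o), box.o)
    return list(seen.values())

L1 = [{1, 3}, {2, 3, 5}]
L3 = [{21, 22}]
C1, C2, C3 = Box(L1), Box(L1), Box(L3)   # C1 and C2 share a target
M = {1: C1, 3: C1, 2: C2, 5: C2, 21: C3, 22: C3}

groups = unique_lists(M)   # two lists: L1 and L3
```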

In the original code, I tried to avoid this last step by maintaining the list of unique result sets on-the-fly. But, I've removed that logic to keep the code simpler since the performance gain probably isn't that great anyway.


Here's the updated code with the fix for the bug mentioned by Andrea in the comments.

class Pointer(object):
    """
    Implements a pointer to an object. The actual object can be accessed
    through self.get().
    """
    def __init__(self, o): self.o = o
    def __str__(self): return self.o.__str__()
    def __repr__(self): return '<Pointer to %s>' % id(self.o)

    # These two methods make instances hashed on self.o rather than self
    # so Pointers to the same object will be considered the same in
    # dicts, sets, etc.
    def __eq__(self, x): return x.o is self.o
    def __hash__(self): return id(self.o)

    def get(self): return self.o
    def set(self, obj): self.o = obj.o
    def id(self): return id(self.o)

def connected_sets(s):
    M = {}
    for C in s:
        # We assume C belongs to a new result list
        P = Pointer(list())
        added = False # to check if we've already added C to a result list
        for item in C:
            # There was an intersection detected, so the P points to this
            # intersecting set. Also add this C to the intersecting set.
            if item in M:
                # The two pointers point to different objects.
                if P.id() != M[item].id():
                    # If this was a regular assignment of a set to a new set, this
                    # wouldn't work since references to the old set would not be
                    # updated.
                    P.set(M[item])
                    M[item].o.append(C)
                    added = True
            # Otherwise, we can add it to P which, at this point, could point to
            # a stand-alone set or an intersecting set.
            else:
                if not added:
                    P.o.append(C)
                    added = True
                M[item] = P

    return M

if __name__ == '__main__':
    sets = [
        set([1,3]), # A
        set([2,3,5]), # B
        set([21,22]), # C
        set([1,9]), # D
        set([5]), # E
        set([18,21]), # F
    ]

    #sets = [
    #    set([1,5]), # A
    #    set([1,5]), # B
    #]

    M = connected_sets(sets)
    import pprint
    pprint.pprint(
        # create list of unique lists referenced in M
        [x.o for x in set(M.values())]
        )

And the output is:

[[set([1, 3]), set([2, 3, 5]), set([1, 9]), set([5])],
 [set([21, 22]), set([18, 21])]]

It seems to work for me, but I'd be curious how it performs compared to your solution. Using the timeit module, I've compared this to the other answer, the one that uses networkx, and my script is almost 3x faster on the same data set.
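One case worth testing with any single-pass approach is a set that links two groups that were separate until then, for example [{1}, {2}, {1, 2}]: the two existing groups have to be merged, not just extended. A common way to handle that is a disjoint-set (union-find) structure. The following is a minimal sketch under that assumption, not the code from this answer; `connected_sets_uf`, `find`, `parent` and `owner` are names I made up, and I have not benchmarked it.

```python
def connected_sets_uf(sets):
    """Group sets linked, directly or transitively, by shared elements."""
    parent = list(range(len(sets)))   # parent[i] == i means i is a root

    def find(i):
        # Walk up to the root, halving the path along the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    owner = {}                        # element -> index of first set seen with it
    for i, s in enumerate(sets):
        for item in s:
            if item in owner:
                # Shared element: merge this set's group with the owner's group.
                ri, rj = find(i), find(owner[item])
                if ri != rj:
                    parent[ri] = rj
            else:
                owner[item] = i

    # Collect the sets under their group's root.
    groups = {}
    for i, s in enumerate(sets):
        groups.setdefault(find(i), []).append(s)
    return list(groups.values())

example = [{1, 3}, {2, 3, 5}, {21, 22}, {1, 9}, {5}, {18, 21}]
result = connected_sets_uf(example)
bridged = connected_sets_uf([{1}, {2}, {1, 2}])
```

Each set is visited once while building the groups, and a set that intersects nothing ends up in a group of its own, so filter out single-element groups if you don't want those.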

Upvotes: 1

Abhijit

Reputation: 63787

As @MartijnPieters rightly said, the problem calls for graphs, and networkx would be at your rescue.

Salient Points

  1. Nodes of the graph should be sets
  2. An edge exists between two nodes iff the corresponding sets intersect
  3. From the resultant graph, find all connected components

Implementation

def intersecting_sets(sets):
    import networkx as nx
    G = nx.Graph()
    # Nodes of the graph should be hashable. Build a list rather than a
    # lazy map object, because it is iterated more than once below.
    sets = [frozenset(s) for s in sets]
    for to_node in sets:
        for from_node in sets:
            # of course you don't want a self loop
            # and are only interested in intersecting nodes
            if to_node != from_node and to_node & from_node:
                G.add_edge(to_node, from_node)
    # and remember to convert the frozen sets back to sets
    return [[set(s) for s in comp] for comp in nx.connected_components(G)]

Output

>>> intersecting_sets(sets)
[[set([2, 3, 5]), set([1, 3]), set([5]), set([1, 9])], [set([21, 22]), set([18, 21])]]
>>> pprint.pprint(intersecting_sets(sets))
[[set([2, 3, 5]), set([1, 3]), set([5]), set([1, 9])],
 [set([21, 22]), set([18, 21])]]
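If you'd rather not depend on networkx, or want to avoid comparing every pair of sets, the same three steps can be sketched with the standard library only. This is an illustrative variant, not the implementation above: `intersecting_sets_bfs` and `containing` are hypothetical names. It indexes the sets by element and then walks each connected component breadth-first.

```python
from collections import defaultdict, deque

def intersecting_sets_bfs(sets):
    sets = [frozenset(s) for s in sets]

    # Element index: positions of the sets that contain each element.
    containing = defaultdict(list)
    for i, s in enumerate(sets):
        for item in s:
            containing[item].append(i)

    seen = set()
    components = []
    for start in range(len(sets)):
        if start in seen:
            continue
        seen.add(start)
        queue = deque([start])
        component = []
        while queue:
            i = queue.popleft()
            component.append(set(sets[i]))
            for item in sets[i]:
                # Neighbours are the sets sharing this element. Popping the
                # entry means each element's list is scanned only once.
                for j in containing.pop(item, ()):
                    if j not in seen:
                        seen.add(j)
                        queue.append(j)
        components.append(component)
    return components

example = [{1, 3}, {2, 3, 5}, {21, 22}, {1, 9}, {5}, {18, 21}]
result = intersecting_sets_bfs(example)
```

Unlike the networkx version above, which only adds nodes through edges, this sketch also returns sets that intersect nothing, each as a single-element group. Drop those if you only want sets that intersect something.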

Upvotes: 1
