ds_user

Reputation: 2179

Selecting only the subsets from a set of sets in Python

I am trying to remove every superset from my set of sets (that is, any set that contains another set in the collection) and keep only the remaining subsets. I have written the code below, but it takes a long time to run because I am handling a large dataset. Could someone suggest another option?

For example, if I have a set of frozensets like this:

skt = {{D},{E,D,M},{E,M}}

I need an output like

skt = {{D},{E,M}}

My code is,

for item in skt.copy():
    for other_item in skt.difference([item]):
        if item >= other_item:
            skt.remove(item)
            break

Thanks in advance.

Upvotes: 2

Views: 1774

Answers (3)

MariusSiuram

Reputation: 3634

My understanding of the problem is:

For a list L of sets, return the sets that have no proper subset in L

skt = {{D},{E,D,M},{E,M}}
out = {{D}, {E,M}}

and

skt = {{D}, {E,G}, {E,H}, {D,E,F}, {E,F,G}}
out = {{D}, {E,G}, {E,H}}

If that's correct then (in my head, and I may be wrong) the worst case always forces you to check all the pairs. You could make some improvements, such as not iterating over elements that have already been deleted, or checking each pair only once, in both directions, and updating accordingly. itertools.product may be useful, but it doesn't update itself when you delete an element, so I'm not sure how efficient that would be.

Slightly more optimized code might be:

skt = {frozenset({1}), frozenset({1,2,3}), frozenset({2,3}), frozenset({4}), 
       frozenset({5,7}), frozenset({5,8}), frozenset({5,6,7}), 
       frozenset({6,7,8})}

newset = set()

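def check(elem):
    # elem has already been popped from skt, so it must not be removed
    # again here (skt.remove(elem) would raise KeyError).
    to_delete = []
    ret = True
    for y in skt:
        if elem > y:
            # elem is a proper superset of y: reject elem
            ret = False
            break
        if y > elem:
            # y is a proper superset of elem: drop y from skt
            to_delete.append(y)
    for d in to_delete:
        skt.remove(d)
    return ret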

while skt:
    checking = skt.pop()
    if check(checking):
        newset.add(checking)
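
The "check each pair only once, in both directions" idea mentioned above could also be sketched with itertools.combinations instead of itertools.product. This is only an illustration: the function name minimal_sets_pairwise is made up here, and the worst case is still one comparison per pair. Pairs involving an already rejected set are skipped, which is safe because any superset of a rejected set also contains one of the surviving minimal sets.

```python
from itertools import combinations

def minimal_sets_pairwise(skt):
    """Return the sets in skt that have no proper subset in skt."""
    removed = set()
    # combinations() yields each unordered pair exactly once
    for a, b in combinations(skt, 2):
        # Skip pairs where one side has already been rejected
        if a in removed or b in removed:
            continue
        if a < b:        # b is a proper superset of a
            removed.add(b)
        elif b < a:      # a is a proper superset of b
            removed.add(a)
    return set(skt) - removed

skt = {frozenset({1}), frozenset({1, 2, 3}), frozenset({2, 3}), frozenset({4}),
       frozenset({5, 7}), frozenset({5, 8}), frozenset({5, 6, 7}),
       frozenset({6, 7, 8})}
result = minimal_sets_pairwise(skt)
```

Unlike the loop above, this version leaves skt untouched and only builds the set of rejected items.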

Upvotes: 1

Anton Savin

Reputation: 41301

At least a minor optimization can be done: don't copy the set, but rather build a new one:

newset = set()
for x in skt:
    if not any(y < x for y in skt):
        newset.add(x)

Or in one line:

newset = set(x for x in skt if not any(y < x for y in skt))

UPDATE:

You can pre-calculate, for each element, the set of sets containing that element, and then check each set only against the sets that share at least one element with it:

from collections import defaultdict

# Map each element to the sets that contain it
setsForElement = defaultdict(set)
for s in skt:
    for element in s:
        setsForElement[element].add(s)

newset = set(s for s in skt if not any(setForElement < s for element in s for setForElement in setsForElement[element]))

# The last line is equivalent to:
newset = set()
for s in skt:
    good = True
    for element in s:
        # s is rejected as soon as one indexed set is a proper subset of it
        if any(setForElement < s for setForElement in setsForElement[element]):
            good = False
            break

    if good:
        newset.add(s)

It may save you some time, depending on your dataset. Of course, in the worst case (for example, if your dataset is the power set of some set), the complexity will again be O(N^2) set comparisons. Come to think of it, it can be even worse than the straightforward algorithm, because you may check the same set multiple times.
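
One way to avoid checking the same set multiple times is to merge the index entries into a single candidate set before comparing. The following is a rough sketch under that assumption, not a measured improvement; the name minimal_sets_indexed is hypothetical. It also special-cases an empty frozenset, which the per-element index can never see because it has no elements.

```python
from collections import defaultdict

def minimal_sets_indexed(skt):
    """Return the sets in skt that have no proper subset in skt."""
    # The empty set is a proper subset of every other set
    if frozenset() in skt:
        return {frozenset()}

    # Map each element to the sets that contain it
    sets_for_element = defaultdict(set)
    for s in skt:
        for element in s:
            sets_for_element[element].add(s)

    result = set()
    for s in skt:
        # Union of all sets sharing an element with s; each appears once
        candidates = set().union(*(sets_for_element[e] for e in s))
        if not any(c < s for c in candidates):
            result.add(s)
    return result

example = {frozenset("D"), frozenset("EDM"), frozenset("EM")}
result = minimal_sets_indexed(example)
```

Building the union costs time and memory of its own, so whether this beats the plain index lookup depends on how much the index entries overlap.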

Upvotes: 3

Chris Martin

Reputation: 30736

This approach is essentially the same as yours, but it runs in order of ascending cardinality. The advantage could be significant, depending on your data (if there are some small sets that can knock out a lot of others in the early iterations).

from collections import defaultdict

def foo(skt):

    # Index the sets by cardinality
    index = defaultdict(set)
    for s in skt:
        index[len(s)].add(s)

    # Largest cardinality present (default=0 avoids a ValueError on empty input)
    max_size = max(index, default=0)

    # For each cardinality i, starting with the lowest
    for i in range(max_size + 1):

        # For each cardinality j > i (because supersets must be larger)
        for j in range(i + 1, max_size + 1):

            # Remove j-sized supersets
            for y in [y for y in index[j] if any(y >= x for x in index[i])]:
                index[j].remove(y)

    # Flatten the index
    return set(x for xs in index.values() for x in xs)
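
The same ascending-cardinality idea can be written more compactly by sorting on len and comparing each set only against the sets kept so far. This is a sketch of a variation, not the code above; minimal_sets_sorted is a made-up name. It relies on the fact that any set with a proper subset in the input also has a minimal one, which is smaller and therefore already in kept.

```python
def minimal_sets_sorted(skt):
    """Return the sets in skt that have no proper subset in skt."""
    kept = []
    # Smaller sets come first, so every possible subset of s is seen before s
    for s in sorted(skt, key=len):
        # skt has no duplicates, so k <= s here means k is a proper subset
        if not any(k <= s for k in kept):
            kept.append(s)
    return set(kept)

skt = {frozenset({1}), frozenset({1, 2, 3}), frozenset({2, 3}), frozenset({4}),
       frozenset({5, 7}), frozenset({5, 8}), frozenset({5, 6, 7}),
       frozenset({6, 7, 8})}
result = minimal_sets_sorted(skt)
```

The number of comparisons is bounded by the input size times the number of surviving sets, so it helps most when few sets survive.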

Upvotes: 1
