Watchduck

Reputation: 1156

How to avoid code duplication with similar Django model methods?

The following model contains two almost identical methods, list_ancestors and list_descendants. What would be a good way to write this logic only once?

from django.db import models

class Node(models.Model):
    name = models.CharField(max_length=120, blank=True, null=True)
    parents = models.ManyToManyField('self', blank=True, symmetrical=False)

    def list_parents(self):
        return self.parents.all()

    def list_children(self):
        return Node.objects.filter(parents=self.id)

    def list_ancestors(self):
        parents = self.list_parents()
        ancestors = set(parents)
        for p in parents:
            ancestors |= set(p.list_ancestors())  # set union
        return list(ancestors)

    def list_descendants(self):
        children = self.list_children()
        descendants = set(children)
        for c in children:
            descendants |= set(c.list_descendants())  # set union
        return list(descendants)

    def __str__(self):
        # name is nullable, and __str__ must return a str, never None
        return self.name or ''

EDIT: The solution derived from the answers below:

def list_withindirect(self, arg):
    direct = getattr(self, arg)()
    withindirect = set(direct)
    for d in direct:
        withindirect |= set(d.list_withindirect(arg))
    return list(withindirect)

def list_ancestors(self):
    return self.list_withindirect('list_parents')

def list_descendants(self):
    return self.list_withindirect('list_children')
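To try the EDIT solution without a database, here is a minimal sketch on a plain-Python stand-in for the model. The `Node` class below is hypothetical: it keeps parents and children in ordinary lists, and `add_parent` is an illustrative helper, not part of the Django model. The diamond-shaped graph shows that a node reachable by two paths is reported only once.

```python
class Node:
    # Hypothetical stand-in for the Django model: no ORM, plain lists instead
    # of a ManyToManyField.
    def __init__(self, name):
        self.name = name
        self.parents = []
        self.children = []

    def add_parent(self, parent):
        # Keep both directions in sync, as the M2M relation would.
        self.parents.append(parent)
        parent.children.append(self)

    def list_parents(self):
        return list(self.parents)

    def list_children(self):
        return list(self.children)

    def list_withindirect(self, arg):
        # Look the method up by name on *this* node, so each recursion
        # level uses its own parents/children.
        direct = getattr(self, arg)()
        withindirect = set(direct)
        for d in direct:
            withindirect |= set(d.list_withindirect(arg))
        return list(withindirect)

    def list_ancestors(self):
        return self.list_withindirect('list_parents')

    def list_descendants(self):
        return self.list_withindirect('list_children')

    def __repr__(self):
        return self.name


# Diamond: root -> a, root -> b, a -> leaf, b -> leaf
root, a, b, leaf = Node('root'), Node('a'), Node('b'), Node('leaf')
a.add_parent(root)
b.add_parent(root)
leaf.add_parent(a)
leaf.add_parent(b)

print(sorted(n.name for n in leaf.list_ancestors()))    # ['a', 'b', 'root']
print(sorted(n.name for n in root.list_descendants()))  # ['a', 'b', 'leaf']
```

Like the original model methods, this recursion assumes the graph has no cycles; a cycle would recurse until Python raises RecursionError.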

Upvotes: 1

Views: 124

Answers (2)

2ps

Reputation: 15936

Pass the method name as a string and call getattr on each object to get the method bound to that object.

def list_withindirect(self, fn1):
    direct = getattr(self, fn1)()
    withindirect = set(direct)
    for d in direct:
        withindirect |= set(d.list_withindirect(fn1))

    return list(withindirect)

def list_ancestors(self):
    return self.list_withindirect('list_parents')
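Why the string works: getattr(obj, name) returns a method bound to obj, so the same string looked up on a different node gives that node's method. A small sketch, using a hypothetical minimal stand-in class rather than the Django model:

```python
class Node:
    # Hypothetical minimal stand-in for the Django model.
    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)

    def list_parents(self):
        return self.parents


root = Node('root')
child = Node('child', parents=[root])

# getattr(obj, 'list_parents') returns a method bound to obj ...
bound = getattr(child, 'list_parents')
print(bound.__self__ is child)          # True
# ... so calling it is the same as child.list_parents()
print(bound() == child.list_parents())  # True
# The same string looked up on another object binds to that object instead.
print(getattr(root, 'list_parents')())  # []
```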

Upvotes: 2

Nikolay Prokopyev

Reputation: 1312

This looks like a bound vs. unbound method problem.

When you initially pass self.list_parents to self.list_withindirect(list_direct), everything is OK.

But when you recursively pass the same self.list_parents on to d.list_withindirect (i.e. to each related node d), you accidentally populate your direct variable with the parents of the topmost caller object instead of those of d.

It can be resolved, for example, by using getattr, as in the answer by 2ps (update: the error in his original code was pointed out in the comments there).
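A sketch of the problem described above, again on a hypothetical plain-Python stand-in for the model. The method name list_withindirect_bound is made up to show the buggy variant. Passing a bound method reuses the topmost node's parents at every level, so the recursion never ends. Passing the plain function Node.list_parents and calling it on the current node is another way (besides getattr) to apply it to each node in turn.

```python
class Node:
    # Hypothetical stand-in; parents held in a plain list, not a ManyToManyField.
    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)

    def list_parents(self):
        return self.parents

    def list_withindirect_bound(self, list_direct):
        # BUG: list_direct is a bound method (e.g. leaf.list_parents), so it
        # always returns the parents of the node it was bound to, not of self.
        direct = list_direct()
        result = set(direct)
        for d in direct:
            result |= set(d.list_withindirect_bound(list_direct))
        return list(result)

    def list_withindirect(self, list_direct):
        # Fix: pass the plain function (Node.list_parents) and apply it to
        # the current node, so each level uses its own parents.
        direct = list_direct(self)
        result = set(direct)
        for d in direct:
            result |= set(d.list_withindirect(list_direct))
        return list(result)


root = Node('root')
mid = Node('mid', parents=[root])
leaf = Node('leaf', parents=[mid])

# Works: each recursion level calls Node.list_parents on its own node.
print(sorted(n.name for n in leaf.list_withindirect(Node.list_parents)))
# ['mid', 'root']

# Bound method: every level gets leaf's parents ([mid]) again, so the
# recursion never bottoms out.
try:
    leaf.list_withindirect_bound(leaf.list_parents)
    hit_limit = False
except RecursionError:
    hit_limit = True
print(hit_limit)  # True
```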

Upvotes: 0
