How to get all descendants of a node, including the node itself, with Django treebeard?

I have a Category model that extends MP_Node from Django treebeard, as shown below:

# "models.py"

from django.db import models
from treebeard.mp_tree import MP_Node

class Category(MP_Node):
    name = models.CharField(max_length=50)
    node_order_by = ('name',)

    def __str__(self):
        return self.name

With get_descendants(), I can get all descendants of a category, not including the category itself, as shown below:

categories = Category.objects.get(name="Food").get_descendants()
print(categories)
# <MP_NodeQuerySet [<Category: Meat>, <Category: Fish>]>

But when I tried to include the category itself with get_descendants(include_self=True), I got the error shown below:

categories = Category.objects.get(name="Food").get_descendants(include_self=True)
print(categories) # Error

TypeError: get_descendants() got an unexpected keyword argument 'include_self'

With Django mptt, get_descendants(include_self=True) does return the category together with all its descendants, as shown below. *I switched from Django mptt to Django treebeard because Django mptt is unmaintained and gives some errors:

categories = Category.objects.get(name="Food").get_descendants(include_self=True)
print(categories)
# <TreeQuerySet [<Category: Food>, <Category: Meat>, <Category: Fish>]>

So, how can I get all descendants of a category, including the category itself, with Django treebeard?

Upvotes: 0

Views: 272

Answers (1)

You can override get_descendants() in your model and add include_self=False as a keyword argument, as shown below. *The body is based on the original get_descendants() in treebeard's mp_tree module, with one extra branch for include_self:

from django.db import models
from treebeard.mp_tree import MP_Node, get_result_class

class Category(MP_Node):
    name = models.CharField(max_length=50)
    node_order_by = ('name',)

    def get_descendants(self, include_self=False):
        if include_self:
            return self.__class__.get_tree(self)
        if self.is_leaf():
            return get_result_class(self.__class__).objects.none()
        return self.__class__.get_tree(self).exclude(pk=self.pk)

    def __str__(self):
        return self.name

Then, you can get all descendants of a category, including the category itself, as shown below:

categories = Category.objects.get(name="Food").get_descendants(include_self=True)
print(categories)
# <MP_NodeQuerySet [<Category: Food>, <Category: Meat>, <Category: Fish>]>
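
To see why the override works, it may help to look at the idea behind a materialized-path tree: each row stores a path string, and a node's subtree is every row whose path starts with that node's path, which includes the node itself. The following is a minimal plain-Python sketch of that idea, not treebeard's actual code. The ROWS data, the path values, and both helper functions are hypothetical stand-ins made up for illustration:

```python
# Hypothetical rows standing in for the Category table: (name, path).
# The path values are invented for this sketch.
ROWS = [
    ("Food", "0001"),
    ("Meat", "00010001"),
    ("Fish", "00010002"),
    ("Drink", "0002"),
]


def get_tree(rows, parent_path):
    """Return the parent row and every row below it, ordered by path."""
    return sorted(
        (row for row in rows if row[1].startswith(parent_path)),
        key=lambda row: row[1],
    )


def get_descendants(rows, parent_path, include_self=False):
    """Mirror the override: keep the parent row only when include_self is True."""
    tree = get_tree(rows, parent_path)
    if include_self:
        return tree
    # Same role as .exclude(pk=self.pk) in the override above.
    return [row for row in tree if row[1] != parent_path]


print([name for name, _ in get_descendants(ROWS, "0001", include_self=True)])
print([name for name, _ in get_descendants(ROWS, "0001")])
```

The first call prints the parent followed by its children, and the second prints the children only, which is the same distinction that include_self makes in the override.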

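In addition, you can get all descendants of a category, including the category itself, without overriding anything by using get_tree(parent=None). *get_tree() is a class method, so it can be called on the Category class itself or, as in the examples below, on any Category object; the object it is called on does not affect the result: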

categories = Category.objects.all()[2].get_tree(Category.objects.get(name="Food"))
print(categories)
# <MP_NodeQuerySet [<Category: Food>, <Category: Meat>, <Category: Fish>]>

categories = Category.objects.first().get_tree(Category.objects.get(name="Food"))
print(categories)
# <MP_NodeQuerySet [<Category: Food>, <Category: Meat>, <Category: Fish>]>

categories = Category.objects.last().get_tree(Category.objects.get(name="Food"))
print(categories)
# <MP_NodeQuerySet [<Category: Food>, <Category: Meat>, <Category: Fish>]>

Upvotes: 0
