Enquest

Reputation: 101

django-tenants and Django REST framework testing

I'm trying to switch to test-driven development, but for that I need to understand unit testing :)

I have the following problem. I'm using Django REST framework and django-tenants together. So far so good. However, to test anything you need to create a tenant:

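from django_tenants.test.cases import TenantTestCase
from django_tenants.test.client import TenantClient


class Test1(TenantTestCase):
    def setUp(self):
        super().setUp()
        self.client = TenantClient(self.tenant)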

Once you do so, your tenant is set up.

But if I want to test the API, the client is a TenantClient rather than DRF's APIClient, and the test case is a TenantTestCase rather than an APITestCase.

So my question is: how do you combine the two?

https://django-tenants.readthedocs.io/en/latest/test.html

Upvotes: 1

Views: 1045

Answers (2)

Ahmed Shehab

Reputation: 1857

Let's leverage Python's multiple inheritance:


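# client
from rest_framework.test import APIClient
from django_tenants.test.client import TenantClient


class APITenantClient(TenantClient, APIClient):
    # TenantClient comes first in the MRO: its get/post/put/patch/delete set
    # HTTP_HOST to the tenant's primary domain, then delegate to APIClient.
    pass


# test case
from django.test import override_settings
from rest_framework.test import APITestCase
from django_tenants.test.cases import TenantTestCase


# Decorate the class rather than setUp: on a method, override_settings
# only applies while that one method is running.
@override_settings(ROOT_URLCONF='app.tenant_urls')
class ApiTest(APITestCase, TenantTestCase):
    def setUp(self) -> None:
        super().setUp()
        self.client = APITenantClient(self.tenant, enforce_csrf_checks=False)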
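Why the order of the base classes matters: Python builds the method resolution order (MRO) left to right, so `TenantClient.get` runs first, adds `HTTP_HOST`, and its `super().get(...)` call continues into `APIClient.get`. The sketch below reproduces that chain with small stand-in classes so it runs without Django or a database. `BaseClient`, `HostClient`, `JsonClient` and `HostJsonClient` are hypothetical names, not the real Django, DRF or django-tenants classes, and `tenant.test.com` is just a sample domain.

```python
class BaseClient:
    """Stand-in for django.test.Client: returns what it was asked to send."""

    def get(self, path, data=None, **extra):
        return {"path": path, "data": data, **extra}


class HostClient(BaseClient):
    """Stand-in for TenantClient: defaults the Host header to the tenant domain."""

    def __init__(self, domain, **defaults):
        super().__init__(**defaults)
        self.domain = domain

    def get(self, path, data=None, **extra):
        extra.setdefault("HTTP_HOST", self.domain)  # keep a caller-supplied host
        return super().get(path, data, **extra)


class JsonClient(BaseClient):
    """Stand-in for APIClient: adds API-specific behaviour (here, a format flag)."""

    def get(self, path, data=None, format=None, **extra):
        if format is not None:
            extra["format_used"] = format
        return super().get(path, data, **extra)


class HostJsonClient(HostClient, JsonClient):
    """Same shape as APITenantClient(TenantClient, APIClient)."""


client = HostJsonClient("tenant.test.com")
print([cls.__name__ for cls in HostJsonClient.__mro__])
# ['HostJsonClient', 'HostClient', 'JsonClient', 'BaseClient', 'object']
print(client.get("/api/items/", format="json"))
# {'path': '/api/items/', 'data': None, 'HTTP_HOST': 'tenant.test.com', 'format_used': 'json'}
```

Because each `get` calls `super()` instead of a named parent, both stand-ins contribute their behaviour regardless of which one is listed first; listing the host-setting class first simply makes it the outermost wrapper.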

Read more: https://django-tenants.readthedocs.io/en/latest/test.html

Upvotes: 2

leotrubach

Reputation: 1597

I ended up with this class:

from rest_framework.test import APIClient


class TenantAPIClient(APIClient):
    """APIClient that sends every request to the tenant's primary domain."""

    def __init__(self, tenant, enforce_csrf_checks=False, **defaults):
        super().__init__(enforce_csrf_checks, **defaults)
        self.tenant = tenant

    def _with_tenant_host(self, extra):
        # Default the Host header to the tenant's domain unless the caller set one.
        if "HTTP_HOST" not in extra:
            extra["HTTP_HOST"] = self.tenant.get_primary_domain().domain
        return extra

    # `follow` is forwarded so that follow=True actually follows redirects.
    def get(self, path, data=None, follow=False, **extra):
        return super().get(path, data, follow=follow, **self._with_tenant_host(extra))

    def post(self, path, data=None, follow=False, **extra):
        return super().post(path, data, follow=follow, **self._with_tenant_host(extra))

    def patch(self, path, data=None, follow=False, **extra):
        return super().patch(path, data, follow=follow, **self._with_tenant_host(extra))

    def put(self, path, data=None, follow=False, **extra):
        return super().put(path, data, follow=follow, **self._with_tenant_host(extra))

    def delete(self, path, data=None, follow=False, **extra):
        return super().delete(path, data, follow=follow, **self._with_tenant_host(extra))

and then, in the test case:

from django_tenants.test.cases import TenantTestCase


class TestContentDeletion(TenantTestCase):
    def setUp(self):
        super().setUp()
        self.c = TenantAPIClient(self.tenant)
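Both clients work because django-tenants picks the schema for a request from the request's hostname, which the test client takes from `HTTP_HOST`. Without that header, Django's test client falls back to the server name `testserver`, which matches no tenant domain. A rough stand-in for that lookup, assuming a plain dict in place of the tenant Domain table (`DOMAINS` and `resolve_schema` are hypothetical names; the real middleware queries the database and has its own handling for unknown hosts):

```python
from typing import Mapping

# Hypothetical domain -> schema table standing in for the tenant Domain model.
DOMAINS = {
    "tenant.test.com": "test",
    "public.example.com": "public",
}


def resolve_schema(meta: Mapping[str, str], domains: Mapping[str, str] = DOMAINS) -> str:
    """Pick a schema name from a META-style dict using its Host header."""
    # With no HTTP_HOST, Django's test client reports "testserver" as the host.
    host = meta.get("HTTP_HOST", "testserver")
    hostname = host.split(":")[0]  # drop a port such as ":8000"
    try:
        return domains[hostname]
    except KeyError:
        raise LookupError(f"no tenant domain matches {hostname!r}") from None


print(resolve_schema({"HTTP_HOST": "tenant.test.com:8000"}))  # test
```

This is why both answers only add `HTTP_HOST` when the caller has not already passed one: a test can still target a different domain explicitly.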

Upvotes: 2
