Reputation: 9798
I am building a family tree application in Django where I need to represent and query marriages symmetrically. Each marriage should have only one record, and the relationship should include both partners without duplicating data. Here's the relevant model structure:
class Person(models.Model):
    """A person in the family tree (the question's original model)."""

    first_name = models.CharField(max_length=100)
    last_name = models.CharField(max_length=100)
    # Self-referential many-to-many whose rows live in the Marriage model.
    # NOTE(review): no through_fields are given, so Django presumably treats
    # Marriage.person1 as the "from" side and Marriage.person2 as the "to"
    # side. That would explain the one-directional results in the question's
    # example (p1.spouses finds p2, p2.spouses is empty) — confirm.
    # NOTE(review): Django is believed to ignore related_name on a symmetrical
    # self-referential relation, so "partners" may have no effect — confirm.
    spouses = models.ManyToManyField(
        'self', through="Marriage", symmetrical=True, related_name="partners"
    )
class Marriage(models.Model):
    """Through model for ``Person.spouses``; intended to hold one row per marriage."""

    # NOTE(review): the declaration order of these two FKs presumably decides
    # which one Django uses as the source side of Person.spouses — confirm.
    person1 = models.ForeignKey(Person, on_delete=models.CASCADE, related_name="marriages_as_person1")
    person2 = models.ForeignKey(Person, on_delete=models.CASCADE, related_name="marriages_as_person2")
    # Both dates are optional (nullable in the DB, not required in forms).
    start_date = models.DateField(null=True, blank=True)
    end_date = models.DateField(null=True, blank=True)
I want each marriage to be stored as a single record, yet still be queryable from either partner.
Here’s the code I’m using to query spouses:
# Minimal reproduction: two new people joined by a single Marriage row.
p1, p2 = Person.objects.create(), Person.objects.create()
Marriage.objects.create(person2=p2, person1=p1)

# Look up spouses from each side of that one row.
p1.spouses.all()  # observed: contains p2
p2.spouses.all()  # observed: empty
However, the relationship is only visible from one side. I need it to be symmetrical: if p1's spouses are queried, the result should contain p2, and if p2's spouses are queried, it should contain p1.
My use case is to return, via DRF, a list of Person objects where each one has `pids` (partner IDs) as a list of IDs, like this:
[
  {
    "id": 1,
    "full_name": "John",
    "pids": [2]
  },
  {
    "id": 2,
    "full_name": "Mary",
    "pids": [1]
  }
]
My current serializer method is:
def get_pids(self, obj):
    """
    Return the IDs of everyone married to ``obj``, in ascending order.

    A marriage is stored as one ``Marriage`` row, with ``obj`` in either
    ``person1`` or ``person2``. Both directions are therefore read through
    the ``marriages_as_person1`` / ``marriages_as_person2`` related names,
    taking the partner from the opposite column each time.

    This uses two queries regardless of how many spouses ``obj`` has. Only
    direct partners are returned, never the partners of partners.
    """
    # Rows where obj is person1 name the spouse in person2, and vice versa.
    as_person1 = obj.marriages_as_person1.values_list('person2_id', flat=True)
    as_person2 = obj.marriages_as_person2.values_list('person1_id', flat=True)
    # Sorted so the API output is stable between requests.
    return sorted(set(as_person1) | set(as_person2))
Upvotes: 0
Views: 43
Reputation: 1030
I would change your structure. I don't think a ManyToManyField should be used here; the design below should work better for your case. It also guarantees that only one marriage record can exist for any two people.
class Person(models.Model):
    """A person in the family tree.

    No spouse field is stored here; marriages are recorded only in
    ``Marriage`` rows that reference this model.
    """

    first_name = models.CharField(max_length=100)
    last_name = models.CharField(max_length=100)
class Marriage(models.Model):
    """A marriage between two people, stored as exactly one row per couple.

    ``unique_key`` holds both partner IDs in ascending order, joined by
    ``_``. The pairs (A, B) and (B, A) therefore map to the same key, and
    the ``unique_marriage_pair`` constraint rejects a second row for the
    same couple.
    """

    person1 = models.ForeignKey(
        Person,
        on_delete=models.CASCADE,
        # '+' suppresses the reverse accessor on Person; spouses are found by
        # querying Marriage on both columns instead.
        related_name='+',
    )
    person2 = models.ForeignKey(
        Person,
        on_delete=models.CASCADE,
        related_name='+',
    )
    # Order-independent "<smaller id>_<larger id>" key, maintained by save().
    # NOTE: bulk_create() and QuerySet.update() do not call save(), so they
    # will not fill or refresh this field.
    unique_key = models.CharField(max_length=50)
    start_date = models.DateField(null=True, blank=True)
    end_date = models.DateField(null=True, blank=True)

    class Meta:
        constraints = [
            models.UniqueConstraint(
                fields=['unique_key'],
                name='unique_marriage_pair',
            )
        ]

    def save(
        self,
        force_insert=False,
        force_update=False,
        using=None,
        update_fields=None,
    ):
        """Refresh ``unique_key`` from the current partners, then save.

        The arguments are forwarded to ``Model.save``. Their defaults match
        Django's own: ``force_insert`` must be a bool or tuple, and Django
        5.0+ raises ``TypeError`` for ``None``.
        """
        # Recompute on every save, not only on creation, so that changing a
        # partner cannot leave a stale key behind.
        self.unique_key = '_'.join(
            map(str, sorted([self.person1_id, self.person2_id]))
        )
        partner_fields = {'person1', 'person1_id', 'person2', 'person2_id'}
        if update_fields is not None and partner_fields.intersection(update_fields):
            # A partial save that changes a partner must also persist the key.
            update_fields = {*update_fields, 'unique_key'}
        return super().save(
            force_insert=force_insert,
            force_update=force_update,
            using=using,
            update_fields=update_fields,
        )
This lets you find all spouse IDs of a person in a single query, for example:
def get_pids(self, obj: Person) -> list[int]:
    """Collect the IDs of all spouses of ``obj`` using a single query.

    A marriage row mentions ``obj`` in either ``person1`` or ``person2``.
    Both IDs of every matching row are gathered, and ``obj``'s own ID is
    then removed, leaving only the partners.
    """
    me = obj.id
    involving_me = models.Q(person1_id=me) | models.Q(person2_id=me)
    pairs = Marriage.objects.filter(involving_me).values_list(
        'person1_id', 'person2_id'
    )
    seen: set[int] = set()
    for pair in pairs:
        seen.update(pair)
    return list(seen - {me})
Alternatively, if you use PostgreSQL, you can fetch the spouse IDs for a whole list of people in two queries:
from collections.abc import Iterable, Iterator
from typing import TypeAlias
from django.contrib.postgres.aggregates import ArrayAgg
Result: TypeAlias = Iterator[tuple[int, list[int]]]


def get_pids_for_few_persons(person_ids: Iterable[int]) -> Result:
    """Yield ``(person_id, spouse_ids)`` for the requested people.

    Uses two queries and PostgreSQL's ``ArrayAgg``. ``person_ids`` may be
    any iterable, including a one-shot generator. It is materialized once
    because it feeds two separate querysets; a generator would otherwise be
    exhausted by the first.

    People with no marriage are not yielded, so callers should fall back to
    an empty list for them. Results are yielded in ascending
    ``person_id`` order.
    """
    ids = list(person_ids)
    if not ids:
        return

    def get_queryset(search_key: str, aggregate_key: str):
        # Group the marriages where the person sits in `search_key` and
        # collect the partners from the opposite column into one array.
        return (
            Marriage.objects
            .filter(models.Q(**{search_key + '__in': ids}))
            .values(search_key)
            .annotate(pids=ArrayAgg(aggregate_key))
            .values_list(search_key, 'pids')
        )

    as_person1 = dict(get_queryset(
        search_key='person1_id',
        aggregate_key='person2_id',
    ))
    as_person2 = dict(get_queryset(
        search_key='person2_id',
        aggregate_key='person1_id',
    ))
    # A person may appear on either side, so merge both lookups.
    for person_id in sorted(as_person1.keys() | as_person2.keys()):
        yield person_id, as_person1.get(person_id, []) + as_person2.get(person_id, [])
p1 = Person.objects.create()
p2 = Person.objects.create()
p3 = Person.objects.create()
p4 = Person.objects.create()  # stays unmarried and is not queried below
Marriage.objects.create(person1=p1, person2=p2)
Marriage.objects.create(person1=p3, person2=p1)
# Pass the real primary keys rather than assuming the DB assigned 1, 2, 3.
result = list(get_pids_for_few_persons((p1.pk, p2.pk, p3.pk)))
# If p1, p2, p3 received ids 1, 2, 3: [(1, [2, 3]), (2, [1]), (3, [1])]
Upvotes: 0