hamsterdam

Reputation: 589

Solr search results with Django REST Framework

I'm using Django REST Framework for our web API and Solr to power a search. Currently, in a subclass of ListAPIView, I override get_queryset() to get a QuerySet with the Solr search results:

class ClipList(generics.ListAPIView):
    """
    List all Clips. 

    Permissions: IsAuthenticatedOrReadOnly

    Parameters:
    query -- Search for Clips. EX: clips/?query=aaron%20rodgers
    """
    model = Clip
    serializer_class = ClipSerializer
    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)

    def get_queryset(self):
        params = self.request.GET
        query = params.get('query', None)

        queryset = Clip.objects.all()

        if query is not None:
            conn = solr.Solr(settings.SOLR_URL)
            sh = solr.SearchHandler(conn, "/select")
            response = sh(query)
            ids = []

            for result in response.results:
                ids.append(result['id'])

            # filter the initial queryset based on the Clip identifiers from the Solr response
            # PROBLEM: This does not preserve the order of the results as it equates to
            # `SELECT * FROM clips WHERE id in (75, 89, 106, 45)`.
            # SQL is not guaranteed to return the results in the same order used in the WHERE clause.
            # There is no way (that I'm aware of) to do this in SQL.
            queryset = queryset.filter(pk__in=ids)

        return queryset

However, as explained in the comments, this does not preserve the order of the results. I realize I could build a Python list of Clip objects in the right order, but I would then lose the lazy evaluation of the Django QuerySet, and the results are likely to be large and will be paginated.
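
For reference, the in-memory reordering mentioned above could look roughly like the sketch below. The Clip namedtuple and the order_by_ids helper are hypothetical stand-ins for illustration, not part of Django or of my code. It also shows why laziness is lost: every object has to be loaded before it can be rearranged.

```python
from collections import namedtuple

# Hypothetical stand-in for a Clip model instance; only `pk` matters here.
Clip = namedtuple("Clip", ["pk", "title"])


def order_by_ids(objects, ids):
    """Return `objects` rearranged to follow `ids`, skipping ids with no match."""
    by_pk = {obj.pk: obj for obj in objects}
    return [by_pk[pk] for pk in ids if pk in by_pk]


# The database may return rows in any order; Solr ranked them as `ranked_ids`.
rows = [Clip(45, "d"), Clip(75, "a"), Clip(89, "b"), Clip(106, "c")]
ranked_ids = [75, 89, 106, 45]
ordered = order_by_ids(rows, ranked_ids)
print([clip.pk for clip in ordered])
```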

I looked into Haystack, and my understanding is that the code above using Haystack would look like:

    def get_queryset(self):
        params = self.request.GET
        query = params.get('query', None)

        search_queryset = SearchQuerySet().filter(content=query)

        return search_queryset

This is super simple and will maintain the order of the results, but Django REST Framework does not serialize SearchQuerySets.

Is there a method in REST Framework that I can override that would allow for the serialization of SearchQuerySets? Or is there a way to maintain the ranked results order without using Haystack or Python lists?

Upvotes: 4

Views: 2892

Answers (1)

hamsterdam

Reputation: 589

We came up with this solution. It mimics the idioms of Django and REST Framework.

In brief (the code and comments give the in-depth explanation): instead of using a regular Django Page to power pagination, we have a FakePage that is initialized with a total count that is not coupled to the object list. That is the hack. It is a hack, but a very simple one. The alternative for us would have been to reimplement a version of QuerySet, and for us the simple solution always wins. Despite being simple, it is reusable and performant.
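
To make the counting concrete, here is a small stand-alone sketch, separate from our actual code, of how the page total follows from the total hit count rather than from the length of the current page. The function names are made up for illustration.

```python
def num_pages(total_count, per_page):
    """Number of pages needed to show total_count items, rounding up."""
    if per_page <= 0:
        return 0
    # Ceiling division without floats: -(-a // b) rounds a / b upwards.
    return -(-total_count // per_page)


def has_next(page_number, total_count, per_page):
    """True if a page follows the 1-based page_number."""
    return page_number < num_pages(total_count, per_page)


# 25 hits at 10 per page need 3 pages, and page 3 is the last one.
print(num_pages(25, 10), has_next(3, 25, 10))
```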

Using the fake page, we have an abstract SearchListModelMixin class that knows how to get a serializer that uses our fake page. The mixin is applied to concrete view classes like SearchListCreateAPIView. The code is below. If other people have a need, I can explain more thoroughly.

# Imports assumed by the code below (Django and REST Framework 2.x era,
# with the solrpy client providing the `solr` module).
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.paginator import EmptyPage
from django.utils import six
from rest_framework import generics, pagination, permissions, serializers
from rest_framework.response import Response
import solr

class FakePage(object):
    """
    Fake page used by Django paginator.
    Required for wrapping responses from Solr.
    """

    def __init__(self, object_list, number, total_count):
        """
        Create fake page instance.

        Args:
            object_list: list of objects, represented by a page
            number: 1-based page number
            total_count: total count of objects (in all pages)
        """
        self.object_list = object_list
        self.number = number
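        # The page size is inferred from the length of the list, so it is
        # only accurate for full pages (the last page may be shorter).
        self.per_page = len(object_list)
        if self.per_page > 0:
            # Round up so that a partial last page is counted.
            self.num_pages = (total_count + self.per_page - 1) // self.per_page
        else:
            self.num_pages = 0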
        self.total_count = total_count

    def __repr__(self):
        return '<Page %s of %s>' % (self.number, self.num_pages)

    def __len__(self):
        return len(self.object_list)

    def __getitem__(self, index):
        if not isinstance(index, (slice,) + six.integer_types):
            raise TypeError
        # The object_list is converted to a list so that if it was a QuerySet
        # it won't be a database hit per __getitem__.
        if not isinstance(self.object_list, list):
            self.object_list = list(self.object_list)
        return self.object_list[index]

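    # Note: total_count is exposed as the instance attribute set in __init__;
    # a method of the same name would be shadowed by it, so none is defined.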

    def has_next(self):
        return self.number < self.num_pages

    def has_previous(self):
        return self.number > 1

    def has_other_pages(self):
        return self.has_previous() or self.has_next()

    def next_page_number(self):
        if self.has_next():
            return self.number + 1
        raise EmptyPage('Next page does not exist')

    def previous_page_number(self):
        if self.has_previous():
            return self.number - 1
        raise EmptyPage('Previous page does not exist')

    def start_index(self):
        """
        Returns the 1-based index of the first object on this page,
        relative to total objects in the paginator.
        """
        # Special case, return zero if no items.
        if self.total_count == 0:
            return 0
        return (self.per_page * (self.number - 1)) + 1

    def end_index(self):
        """
        Returns the 1-based index of the last object on this page,
        relative to total objects found (hits).
        """
        # Special case for the last page because there can be orphans.
        if self.number == self.num_pages:
            return self.total_count
        return self.number * self.per_page

class SearchListModelMixin(object):
    """
    List search results or a queryset.
    """
    # Set this attribute to make a custom URL parameter for the query.
    # EX: ../widgets?query=something
    query_param = 'query'

    # Set this attribute to define the type of search.
    # This determines the source of the query value.
    # For example, for regular text search,
    # the query comes from a URL param. However, for a related search,
    # the query is a unique identifier in the URL itself.
    # Note: SearchType (defined further down) must be declared before this
    # class when everything lives in one module.
    search_type = SearchType.QUERY

    def search_list(self, request, *args, **kwargs):
        # Get the query from the URL parameters dictionary.
        query = self.request.GET.get(self.query_param, None)

        # If there is no query use default REST Framework behavior.
        if query is None and self.search_type == SearchType.QUERY:
            return self.list(request, *args, **kwargs)

        # Get the page of objects and the total count of the results.
        if hasattr(self, 'get_search_results'):
            self.object_list, total_count = self.get_search_results()
            if not isinstance(self.object_list, list) or not isinstance(total_count, int):
                raise ImproperlyConfigured("'%s.get_search_results()' must return (list, int)"
                                    % self.__class__.__name__)
        else:
            raise ImproperlyConfigured("'%s' must define 'get_search_results()'"
                                    % self.__class__.__name__)

        # Normally, we would be serializing a QuerySet,
        # which is lazily evaluated and has the entire result set.
        # Here, we just have a Python list containing only the elements for the
        # requested page. Thus, we must generate a fake page,
        # simulating a Django page in order to fully support pagination.
        # Otherwise, the `count` field would be equal to the page size.
        page = FakePage(self.object_list,
            int(self.request.GET.get(self.page_kwarg, 1)),
            total_count)

        # Prepare a SearchPaginationSerializer
        # with the object_serializer_class
        # set to the serializer_class of the APIView.
        serializer = self.get_search_pagination_serializer(page)

        return Response(serializer.data)

    def get_search_pagination_serializer(self, page):
        """
        Return a serializer instance to use with paginated search data.
        """
        class SerializerClass(SearchPaginationSerializer):
            class Meta:
                object_serializer_class = self.get_serializer_class()

        pagination_serializer_class = SerializerClass
        context = self.get_serializer_context()
        return pagination_serializer_class(instance=page, context=context)

    def get_solr_results(self, sort=None):
        """
        This method is optional. It encapsulates logic for a Solr request
        for a list of ids using pagination. Another method could be provided to
        do the search request. 
        """
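        # Name the mixin explicitly; super(self.__class__, self) resolves
        # differently depending on the runtime class of the view.
        queryset = super(SearchListModelMixin, self).get_queryset()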

        conn = solr.Solr(settings.SOLR_URL)

        # Create a SearchHandler and get the query with respect to 
        # the type of search.
        if self.search_type == SearchType.QUERY:
            query = self.request.GET.get(self.query_param, None)
            sh = solr.SearchHandler(conn, "/select")
        elif self.search_type == SearchType.RELATED:
            query = str(self.kwargs[self.lookup_field])
            sh = solr.SearchHandler(conn, "/morelikethis")

        # Get the pagination information and populate kwargs.
        page_num = int(self.request.GET.get(self.page_kwarg, 1))
        per_page = self.get_paginate_by()
        offset = (page_num - 1) * per_page
        kwargs = {'rows': per_page, 'start': offset}
        if sort:
            kwargs['sort'] = sort

        # Perform the Solr request and build a list of results.
        # For now, this only gets the id field, but this could be 
        # customized later.
        response = sh(query, 'id', **kwargs)
        results = [int(r['id']) for r in response.results]

        # Get a dictionary representing a page of objects.
        # The dict maps each pk to its object and carries no ranking order.
        object_dict = queryset.in_bulk(results)

        # Build the sorted list of objects.
        sorted_objects = []
        for pk in results:
            obj = object_dict.get(pk, None)
            if obj:
                sorted_objects.append(obj)
        return sorted_objects, response.numFound

class SearchType(object):
    """
    This enum-like class defines types of Solr searches
    that can be used in APIViews.
    """
    QUERY = 1
    RELATED = 2


class SearchPaginationSerializer(pagination.BasePaginationSerializer):
    count = serializers.Field(source='total_count')
    next = pagination.NextPageField(source='*')
    previous = pagination.PreviousPageField(source='*')


class SearchListCreateAPIView(SearchListModelMixin, generics.ListCreateAPIView):
    """
    Concrete view for listing search results, a queryset, or creating a model instance.
    """
    def get(self, request, *args, **kwargs):
        return self.search_list(request, *args, **kwargs)


class SearchListAPIView(SearchListModelMixin, generics.ListAPIView):
    """
    Concrete view for listing search results or a queryset.
    """
    def get(self, request, *args, **kwargs):
        return self.search_list(request, *args, **kwargs)


class WidgetList(SearchListCreateAPIView):
    """
    List all Widgets or create a new Widget.
    """
    model = Widget
    queryset = Widget.objects.all()
    serializer_class = WidgetSerializer
    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)
    search_type = SearchType.QUERY  # this is the default, but explicitly declared here

    def get_queryset(self):
        """
        The method implemented for database powered results.
        I'm simply illustrating the default behavior here.
        """
        queryset = super(WidgetList, self).get_queryset()
        return queryset

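    def get_search_results(self):
        """
        The method implemented for search engine powered results.
        """
        # `solr_sort` is not defined in this snippet; it is assumed to be a
        # Solr sort expression defined elsewhere. Call get_solr_results()
        # without it to keep Solr's default ranking.
        return self.get_solr_results(solr_sort)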

class RelatedList(SearchListAPIView):
    """
    List all related Widgets.
    """
    model = Widget
    queryset = Widget.objects.all()
    serializer_class = WidgetSerializer
    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)
    search_type = SearchType.RELATED

    def get_search_results(self):
        return self.get_solr_results()
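
As a footnote to get_solr_results(), the rows/start window passed to Solr is plain offset arithmetic on the 1-based page number. A stand-alone sketch follows; the helper name solr_window is hypothetical and not part of the solr library.

```python
def solr_window(page_num, per_page):
    """Translate a 1-based page number into Solr's `rows` and `start` values."""
    if page_num < 1:
        raise ValueError("page_num is 1-based and must be at least 1")
    return {"rows": per_page, "start": (page_num - 1) * per_page}


# Page 3 at 10 results per page skips the first 20 hits.
print(solr_window(3, 10))
```

The resulting dictionary has the same shape as the kwargs built in get_solr_results(), so it could be unpacked into the search handler call in the same way.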

Upvotes: 5
