From 6f8734ef6d6ddc932cdd4ce92df96ecd09d0e08b Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 5 Oct 2017 18:50:43 -0400 Subject: [PATCH 1/4] Change 'related params' handling into related data --- rest_framework_filters/filterset.py | 48 +++++++++++----- tests/test_filterset.py | 89 +++++++++++++++++++++-------- 2 files changed, 100 insertions(+), 37 deletions(-) diff --git a/rest_framework_filters/filterset.py b/rest_framework_filters/filterset.py index 1f76a4b..d776ea2 100644 --- a/rest_framework_filters/filterset.py +++ b/rest_framework_filters/filterset.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from django.db.models.constants import LOOKUP_SEP +from django.http.request import QueryDict from django_filters import filterset, rest_framework from django_filters.utils import get_model_field @@ -179,32 +180,53 @@ def get_param_filter_name(cls, param): return name @classmethod - def get_related_filter_param(cls, param): + def get_related_data(cls, data): """ - Get a tuple of (filter name, related param). + Given the query data, return a map of {related filter: {related: data}}. + The related data is used as the `data` argument for related FilterSet + initialization. - ex:: + Note that the related data dictionaries will be a QueryDict, regardless + of the type of the original data dict. - >>> FilterSet.get_related_filter_param('author__email__foobar') - ('author', 'email__foobar') + ex:: - >>> FilterSet.get_related_filter_param('author') - (None, None) + >>> NoteFilter.get_related_data({ + >>> 'author__email': 'foo', + >>> 'author__name': 'bar', + >>> 'name': 'baz', + >>> }) + OrderedDict([ + ('author', ) + ]) """ related_filters = cls.related_filters.keys() + related_data = OrderedDict() + data = data.copy() # get a copy of the original data # preference more specific filters. eg, `note__author` over `note`. for name in reversed(sorted(related_filters)): # we need to match against '__' to prevent eager matching against # like names. eg, note vs note2. Exact matches are handled above. - if param.startswith("%s%s" % (name, LOOKUP_SEP)): - # strip param + LOOKUP_SET from param - related_param = param[len(name) + len(LOOKUP_SEP):] - return name, related_param + related_prefix = "%s%s" % (name, LOOKUP_SEP) + + related = QueryDict('', mutable=True) + for param in list(data): + if param.startswith(related_prefix): + value = data.pop(param) + param = param[len(related_prefix):] + + # handle QueryDict & dict values + if not isinstance(value, (list, tuple)): + related[param] = value + else: + related.setlist(param, value) + + if related: + related_data[name] = related - # not a related param - return None, None + return related_data @classmethod def get_filter_subset(cls, params): diff --git a/tests/test_filterset.py b/tests/test_filterset.py index 7ea0780..a9a17d7 100644 --- a/tests/test_filterset.py +++ b/tests/test_filterset.py @@ -1,5 +1,6 @@ import sys +from django.http.request import QueryDict from django.test import TestCase from django_filters.filters import BaseInFilter from rest_framework.test import APIRequestFactory @@ -253,22 +254,31 @@ class Meta: self.assertEqual('note2', name) -class GetRelatedFilterParamTests(TestCase): +class GetRelatedDataTests(TestCase): def test_regular_filter(self): - name, param = NoteFilterWithRelated.get_related_filter_param('title') - self.assertIsNone(name) - self.assertIsNone(param) + params = NoteFilterWithRelated.get_related_data({'title': ''}) + self.assertEqual(params, {}) def test_related_filter_exact(self): - name, param = NoteFilterWithRelated.get_related_filter_param('author') - self.assertIsNone(name) - self.assertIsNone(param) - - def test_related_filter_param(self): - name, param = NoteFilterWithRelated.get_related_filter_param('author__email') - self.assertEqual('author', name) - self.assertEqual('email', param) + params = NoteFilterWithRelated.get_related_data({'author': ''}) + self.assertEqual(params, {}) + + def test_related_filters(self): + params = NoteFilterWithRelated.get_related_data({'author__email': ''}) + self.assertEqual(params, {'author': {'email': ['']}}) + + def test_multiple_related_filters(self): + params = NoteFilterWithRelated.get_related_data({ + 'author__username': '', + 'author__is_active': '', + 'author__email': '', + }) + self.assertEqual(params, {'author': { + 'email': [''], + 'is_active': [''], + 'username': [''], + }}) def test_name_hiding(self): class PostFilterNameHiding(PostFilter): @@ -280,21 +290,52 @@ class Meta: model = Post fields = [] - name, param = PostFilterNameHiding.get_related_filter_param('note__author__email') - self.assertEqual('note__author', name) - self.assertEqual('email', param) + params = PostFilterNameHiding.get_related_data({'note__author__email': ''}) + self.assertEqual(params, {'note__author': {'email': ['']}}) + + params = PostFilterNameHiding.get_related_data({'note__title': ''}) + self.assertEqual(params, {'note': {'title': ['']}}) + + params = PostFilterNameHiding.get_related_data({'note2__title': ''}) + self.assertEqual(params, {'note2': {'title': ['']}}) + + params = PostFilterNameHiding.get_related_data({'note2__author': ''}) + self.assertEqual(params, {'note2': {'author': ['']}}) + + # combined + params = PostFilterNameHiding.get_related_data({ + 'note__author__email': '', + 'note__title': '', + 'note2__title': '', + 'note2__author': '', + }) + + self.assertEqual(params, { + 'note__author': {'email': ['']}, + 'note': {'title': ['']}, + 'note2': { + 'title': [''], + 'author': [''], + }, + }) + + def test_querydict(self): + self.assertEqual( + QueryDict('a=1&a=2&b=3'), + {'a': ['1', '2'], 'b': ['3']} + ) - name, param = PostFilterNameHiding.get_related_filter_param('note__title') - self.assertEqual('note', name) - self.assertEqual('title', param) + result = {'note': { + 'author__email': ['a'], + 'title': ['b', 'c'], + }} - name, param = PostFilterNameHiding.get_related_filter_param('note2__title') - self.assertEqual('note2', name) - self.assertEqual('title', param) + query = QueryDict('note__author__email=a¬e__title=b¬e__title=c') + self.assertEqual(PostFilter.get_related_data(query), result) - name, param = PostFilterNameHiding.get_related_filter_param('note2__author') - self.assertEqual('note2', name) - self.assertEqual('author', param) + # QueryDict-like dictionary w/ multiple values for a param (a la m2m) + query = {'note__author__email': 'a', 'note__title': ['b', 'c']} + self.assertEqual(PostFilter.get_related_data(query), result) class GetFilterSubsetTests(TestCase): From a80e7bdb71e6a51d69c4a01aa0fd97d3c2646b75 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Fri, 6 Oct 2017 17:13:12 -0400 Subject: [PATCH 2/4] Extract related filterset handling from filters --- rest_framework_filters/filterset.py | 58 +++++++++++++++-------------- tests/test_filterset.py | 4 +- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/rest_framework_filters/filterset.py b/rest_framework_filters/filterset.py index d776ea2..d41daef 100644 --- a/rest_framework_filters/filterset.py +++ b/rest_framework_filters/filterset.py @@ -77,6 +77,7 @@ def __init__(self, data=None, queryset=None, *, request=None, prefix=None, **kwa super(FilterSet, self).__init__(data, queryset, request=request, prefix=prefix, **kwargs) + self.related_filtersets = self.get_related_filtersets() self.request_filters = self.get_request_filters() @classmethod @@ -92,23 +93,9 @@ def get_fields(cls): def get_request_filters(self): """ - Build a set of filters based on the request data. The resulting set - will walk `RelatedFilter`s to recursively build the set of filters. + Build a set of filters based on the request data. This currently + includes only filter exclusion/negation. """ - # build param data for related filters: {rel: {param: value}} - related_data = OrderedDict( - [(name, OrderedDict()) for name in self.__class__.related_filters] - ) - for param, value in self.data.items(): - filter_name, related_param = self.get_related_filter_param(param) - - # skip non lookup/related keys - if filter_name is None: - continue - - if filter_name in related_data: - related_data[filter_name][related_param] = value - # build the compiled set of all filters requested_filters = OrderedDict() for filter_name, f in self.filters.items(): @@ -127,19 +114,22 @@ def get_request_filters(self): f_copy.exclude = not f.exclude requested_filters[exclude_name] = f_copy + return requested_filters - # include filters from related subsets - if isinstance(f, filters.RelatedFilter) and filter_name in related_data: - subset_data = related_data[filter_name] - filterset = f.filterset(data=subset_data, request=self.request) + def get_related_filtersets(self): + related_filtersets = OrderedDict() + related_data = self.get_related_data(self.data) - # modify filter names to account for relationship - for related_name, related_f in filterset.get_request_filters().items(): - related_name = LOOKUP_SEP.join([filter_name, related_name]) - related_f.field_name = LOOKUP_SEP.join([f.field_name, related_f.field_name]) - requested_filters[related_name] = related_f + for related_name, subset_data in related_data.items(): + f = self.filters[related_name] + related_filtersets[f.field_name] = f.filterset( + data=subset_data, + queryset=f.get_queryset(self.request), + request=self.request, + prefix=self.form_prefix, + ) - return requested_filters + return related_filtersets @classmethod def get_param_filter_name(cls, param): @@ -254,8 +244,22 @@ def override_filters(self): def filter_queryset(self, queryset): with self.override_filters(): - return super(FilterSet, self).filter_queryset(queryset) + queryset = super(FilterSet, self).filter_queryset(queryset) + queryset = self.filter_related_filtersets(queryset) + return queryset def get_form_class(self): with self.override_filters(): return super(FilterSet, self).get_form_class() + + def filter_related_filtersets(self, queryset): + """ + Filter the provided `qs` by the `related_filtersets`. It is recommended + that you override this method to change the filtering behavior across + relationships. + """ + for field_name, related_filterset in self.related_filtersets.items(): + lookup_expr = LOOKUP_SEP.join([field_name, 'in']) + queryset = queryset.filter(**{lookup_expr: related_filterset.qs}) + + return queryset diff --git a/tests/test_filterset.py b/tests/test_filterset.py index a9a17d7..5b9887c 100644 --- a/tests/test_filterset.py +++ b/tests/test_filterset.py @@ -458,9 +458,9 @@ def test_related_exclude(self): } filterset = BlogPostFilter(GET, queryset=BlogPost.objects.all()) - requested_filters = filterset.request_filters + requested_filters = filterset.related_filtersets['tags'].request_filters - self.assertTrue(requested_filters['tags__name__contains!'].exclude) + self.assertTrue(requested_filters['name__contains!'].exclude) def test_exclusion_results(self): GET = { From be308b7813988db0ea06e7c8ea036c5fbcdb4696 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Tue, 24 Oct 2017 10:24:49 -0400 Subject: [PATCH 3/4] Add Form.clean to mix in related errors --- rest_framework_filters/filterset.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/rest_framework_filters/filterset.py b/rest_framework_filters/filterset.py index d41daef..04d6f3e 100644 --- a/rest_framework_filters/filterset.py +++ b/rest_framework_filters/filterset.py @@ -250,7 +250,17 @@ def filter_queryset(self, queryset): def get_form_class(self): with self.override_filters(): - return super(FilterSet, self).get_form_class() + class Form(super(FilterSet, self).get_form_class()): + def clean(form): + cleaned_data = super(Form, form).clean() + + for field_name, related_filterset in self.related_filtersets.items(): + for key, error in related_filterset.form.errors.items(): + self.form.errors[LOOKUP_SEP.join([field_name, key])] = error + + return cleaned_data + + return Form def filter_related_filtersets(self, queryset): """ From eb652e078f491cf7d93827420d7db3c5aeee2ae7 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Tue, 24 Oct 2017 10:59:20 -0400 Subject: [PATCH 4/4] Use subquery for related filter --- rest_framework_filters/filterset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rest_framework_filters/filterset.py b/rest_framework_filters/filterset.py index 04d6f3e..30c0462 100644 --- a/rest_framework_filters/filterset.py +++ b/rest_framework_filters/filterset.py @@ -2,6 +2,7 @@ from collections import OrderedDict from contextlib import contextmanager +from django.db.models import Subquery from django.db.models.constants import LOOKUP_SEP from django.http.request import QueryDict from django_filters import filterset, rest_framework @@ -270,6 +271,7 @@ def filter_related_filtersets(self, queryset): """ for field_name, related_filterset in self.related_filtersets.items(): lookup_expr = LOOKUP_SEP.join([field_name, 'in']) - queryset = queryset.filter(**{lookup_expr: related_filterset.qs}) + subquery = Subquery(related_filterset.qs.values('pk')) + queryset = queryset.filter(**{lookup_expr: subquery}) return queryset