Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] Rework related filtering #197

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 79 additions & 41 deletions rest_framework_filters/filterset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from collections import OrderedDict
from contextlib import contextmanager

from django.db.models import Subquery
from django.db.models.constants import LOOKUP_SEP
from django.http.request import QueryDict
from django_filters import filterset, rest_framework
from django_filters.utils import get_model_field

Expand Down Expand Up @@ -76,6 +78,7 @@ def __init__(self, data=None, queryset=None, *, request=None, prefix=None, **kwa

super(FilterSet, self).__init__(data, queryset, request=request, prefix=prefix, **kwargs)

self.related_filtersets = self.get_related_filtersets()
self.request_filters = self.get_request_filters()

@classmethod
Expand All @@ -91,23 +94,9 @@ def get_fields(cls):

def get_request_filters(self):
"""
Build a set of filters based on the request data. The resulting set
will walk `RelatedFilter`s to recursively build the set of filters.
Build a set of filters based on the request data. This currently
includes only filter exclusion/negation.
"""
# build param data for related filters: {rel: {param: value}}
related_data = OrderedDict(
[(name, OrderedDict()) for name in self.__class__.related_filters]
)
for param, value in self.data.items():
filter_name, related_param = self.get_related_filter_param(param)

# skip non lookup/related keys
if filter_name is None:
continue

if filter_name in related_data:
related_data[filter_name][related_param] = value

# build the compiled set of all filters
requested_filters = OrderedDict()
for filter_name, f in self.filters.items():
Expand All @@ -126,19 +115,22 @@ def get_request_filters(self):
f_copy.exclude = not f.exclude

requested_filters[exclude_name] = f_copy
return requested_filters

# include filters from related subsets
if isinstance(f, filters.RelatedFilter) and filter_name in related_data:
subset_data = related_data[filter_name]
filterset = f.filterset(data=subset_data, request=self.request)
def get_related_filtersets(self):
related_filtersets = OrderedDict()
related_data = self.get_related_data(self.data)

# modify filter names to account for relationship
for related_name, related_f in filterset.get_request_filters().items():
related_name = LOOKUP_SEP.join([filter_name, related_name])
related_f.field_name = LOOKUP_SEP.join([f.field_name, related_f.field_name])
requested_filters[related_name] = related_f
for related_name, subset_data in related_data.items():
f = self.filters[related_name]
related_filtersets[f.field_name] = f.filterset(
data=subset_data,
queryset=f.get_queryset(self.request),
request=self.request,
prefix=self.form_prefix,
)

return requested_filters
return related_filtersets

@classmethod
def get_param_filter_name(cls, param):
Expand Down Expand Up @@ -179,32 +171,53 @@ def get_param_filter_name(cls, param):
return name

@classmethod
def get_related_filter_param(cls, param):
def get_related_data(cls, data):
"""
Get a tuple of (filter name, related param).
Given the query data, return a map of {related filter: {related: data}}.
The related data is used as the `data` argument for related FilterSet
initialization.

ex::
Note that the related data dictionaries will be a QueryDict, regardless
of the type of the original data dict.

>>> FilterSet.get_related_filter_param('author__email__foobar')
('author', 'email__foobar')
ex::

>>> FilterSet.get_related_filter_param('author')
(None, None)
>>> NoteFilter.get_related_data({
>>> 'author__email': 'foo',
>>> 'author__name': 'bar',
>>> 'name': 'baz',
>>> })
OrderedDict([
('author', <QueryDict: {'email': ['foo'], 'name': ['bar']}>)
])

"""
related_filters = cls.related_filters.keys()
related_data = OrderedDict()
data = data.copy() # get a copy of the original data

# preference more specific filters. eg, `note__author` over `note`.
for name in reversed(sorted(related_filters)):
# we need to match against '__' to prevent eager matching against
# like names. eg, note vs note2. Exact matches are handled above.
if param.startswith("%s%s" % (name, LOOKUP_SEP)):
# strip param + LOOKUP_SET from param
related_param = param[len(name) + len(LOOKUP_SEP):]
return name, related_param
related_prefix = "%s%s" % (name, LOOKUP_SEP)

# not a related param
return None, None
related = QueryDict('', mutable=True)
for param in list(data):
if param.startswith(related_prefix):
value = data.pop(param)
param = param[len(related_prefix):]

# handle QueryDict & dict values
if not isinstance(value, (list, tuple)):
related[param] = value
else:
related.setlist(param, value)

if related:
related_data[name] = related

return related_data

@classmethod
def get_filter_subset(cls, params):
Expand Down Expand Up @@ -232,8 +245,33 @@ def override_filters(self):

def filter_queryset(self, queryset):
with self.override_filters():
return super(FilterSet, self).filter_queryset(queryset)
queryset = super(FilterSet, self).filter_queryset(queryset)
queryset = self.filter_related_filtersets(queryset)
return queryset

def get_form_class(self):
with self.override_filters():
return super(FilterSet, self).get_form_class()
class Form(super(FilterSet, self).get_form_class()):
def clean(form):
cleaned_data = super(Form, form).clean()

for field_name, related_filterset in self.related_filtersets.items():
for key, error in related_filterset.form.errors.items():
self.form.errors[LOOKUP_SEP.join([field_name, key])] = error

return cleaned_data

return Form

def filter_related_filtersets(self, queryset):
"""
Filter the provided `qs` by the `related_filtersets`. It is recommended
that you override this method to change the filtering behavior across
relationships.
"""
for field_name, related_filterset in self.related_filtersets.items():
lookup_expr = LOOKUP_SEP.join([field_name, 'in'])
subquery = Subquery(related_filterset.qs.values('pk'))
queryset = queryset.filter(**{lookup_expr: subquery})

return queryset
93 changes: 67 additions & 26 deletions tests/test_filterset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys

from django.http.request import QueryDict
from django.test import TestCase
from django_filters.filters import BaseInFilter
from rest_framework.test import APIRequestFactory
Expand Down Expand Up @@ -253,22 +254,31 @@ class Meta:
self.assertEqual('note2', name)


class GetRelatedFilterParamTests(TestCase):
class GetRelatedDataTests(TestCase):

def test_regular_filter(self):
name, param = NoteFilterWithRelated.get_related_filter_param('title')
self.assertIsNone(name)
self.assertIsNone(param)
params = NoteFilterWithRelated.get_related_data({'title': ''})
self.assertEqual(params, {})

def test_related_filter_exact(self):
name, param = NoteFilterWithRelated.get_related_filter_param('author')
self.assertIsNone(name)
self.assertIsNone(param)

def test_related_filter_param(self):
name, param = NoteFilterWithRelated.get_related_filter_param('author__email')
self.assertEqual('author', name)
self.assertEqual('email', param)
params = NoteFilterWithRelated.get_related_data({'author': ''})
self.assertEqual(params, {})

def test_related_filters(self):
params = NoteFilterWithRelated.get_related_data({'author__email': ''})
self.assertEqual(params, {'author': {'email': ['']}})

def test_multiple_related_filters(self):
params = NoteFilterWithRelated.get_related_data({
'author__username': '',
'author__is_active': '',
'author__email': '',
})
self.assertEqual(params, {'author': {
'email': [''],
'is_active': [''],
'username': [''],
}})

def test_name_hiding(self):
class PostFilterNameHiding(PostFilter):
Expand All @@ -280,21 +290,52 @@ class Meta:
model = Post
fields = []

name, param = PostFilterNameHiding.get_related_filter_param('note__author__email')
self.assertEqual('note__author', name)
self.assertEqual('email', param)
params = PostFilterNameHiding.get_related_data({'note__author__email': ''})
self.assertEqual(params, {'note__author': {'email': ['']}})

params = PostFilterNameHiding.get_related_data({'note__title': ''})
self.assertEqual(params, {'note': {'title': ['']}})

params = PostFilterNameHiding.get_related_data({'note2__title': ''})
self.assertEqual(params, {'note2': {'title': ['']}})

params = PostFilterNameHiding.get_related_data({'note2__author': ''})
self.assertEqual(params, {'note2': {'author': ['']}})

# combined
params = PostFilterNameHiding.get_related_data({
'note__author__email': '',
'note__title': '',
'note2__title': '',
'note2__author': '',
})

self.assertEqual(params, {
'note__author': {'email': ['']},
'note': {'title': ['']},
'note2': {
'title': [''],
'author': [''],
},
})

def test_querydict(self):
self.assertEqual(
QueryDict('a=1&a=2&b=3'),
{'a': ['1', '2'], 'b': ['3']}
)

name, param = PostFilterNameHiding.get_related_filter_param('note__title')
self.assertEqual('note', name)
self.assertEqual('title', param)
result = {'note': {
'author__email': ['a'],
'title': ['b', 'c'],
}}

name, param = PostFilterNameHiding.get_related_filter_param('note2__title')
self.assertEqual('note2', name)
self.assertEqual('title', param)
query = QueryDict('note__author__email=a&note__title=b&note__title=c')
self.assertEqual(PostFilter.get_related_data(query), result)

name, param = PostFilterNameHiding.get_related_filter_param('note2__author')
self.assertEqual('note2', name)
self.assertEqual('author', param)
# QueryDict-like dictionary w/ multiple values for a param (a la m2m)
query = {'note__author__email': 'a', 'note__title': ['b', 'c']}
self.assertEqual(PostFilter.get_related_data(query), result)


class GetFilterSubsetTests(TestCase):
Expand Down Expand Up @@ -417,9 +458,9 @@ def test_related_exclude(self):
}

filterset = BlogPostFilter(GET, queryset=BlogPost.objects.all())
requested_filters = filterset.request_filters
requested_filters = filterset.related_filtersets['tags'].request_filters

self.assertTrue(requested_filters['tags__name__contains!'].exclude)
self.assertTrue(requested_filters['name__contains!'].exclude)

def test_exclusion_results(self):
GET = {
Expand Down