Skip to content

Commit

Permalink
Add initial type hints and CI infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
intgr committed May 9, 2024
1 parent 4b1350d commit 4faca57
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 44 deletions.
14 changes: 14 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[mypy]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
fast_module_lookup = true
ignore_missing_imports = true
implicit_reexport = false
local_partial_types = true
show_column_numbers = true
show_error_codes = true
strict_equality = true
warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true
5 changes: 5 additions & 0 deletions requirements-tox.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ flake8==6.1.0
ipdb==0.13.13
pytz==2024.1

# Type checking requirements
mypy==1.10.0
django-stubs==5.0.0
djangorestframework-stubs==3.15.0

# wheel for PyPI installs
wheel==0.43.0

Expand Down
31 changes: 19 additions & 12 deletions rest_framework_nested/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,31 @@
These fields allow you to specify the style that should be used to represent
model relationships with hyperlinks.
"""
from __future__ import annotations
from functools import reduce
from typing import Any, TypeVar, Generic

import rest_framework.relations
from rest_framework.relations import ObjectDoesNotExist, ObjectValueError, ObjectTypeError
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import Model
from rest_framework.relations import HyperlinkedRelatedField, ObjectValueError, ObjectTypeError
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request


class NestedHyperlinkedRelatedField(rest_framework.relations.HyperlinkedRelatedField):
T_Model = TypeVar('T_Model', bound=Model)


class NestedHyperlinkedRelatedField(HyperlinkedRelatedField, Generic[T_Model]):
lookup_field = 'pk'
parent_lookup_kwargs = {
'parent_pk': 'parent__pk'
}

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.parent_lookup_kwargs = kwargs.pop('parent_lookup_kwargs', self.parent_lookup_kwargs)
super().__init__(*args, **kwargs)

def get_url(self, obj, view_name, request, format):
def get_url(self, obj: Model, view_name: str, request: Request, format: str | None) -> str | None:
"""
Given an object, return the URL that hyperlinks to the object.
Expand All @@ -46,7 +53,7 @@ def get_url(self, obj, view_name, request, format):

try:
# use the Django ORM to lookup this value, e.g., obj.parent.pk
lookup_value = reduce(getattr, [obj] + lookups)
lookup_value = reduce(getattr, [obj] + lookups) # type: ignore[operator,arg-type]
except AttributeError:
# Not nested. Act like a standard HyperlinkedRelatedField
return super().get_url(obj, view_name, request, format)
Expand All @@ -56,7 +63,7 @@ def get_url(self, obj, view_name, request, format):

return self.reverse(view_name, kwargs=kwargs, request=request, format=format)

def get_object(self, view_name, view_args, view_kwargs):
def get_object(self, view_name: str, view_args: list[Any], view_kwargs: dict[str, Any]) -> T_Model:
"""
Return the object corresponding to a matched URL.
Expand All @@ -74,14 +81,14 @@ def get_object(self, view_name, view_args, view_kwargs):

return self.get_queryset().get(**kwargs)

def use_pk_only_optimization(self):
def use_pk_only_optimization(self) -> bool:
return False

def to_internal_value(self, data):
def to_internal_value(self, data: Any) -> T_Model:
try:
return super().to_internal_value(data)
except ValidationError as err:
if err.detail[0].code != 'no_match':
if err.detail[0].code != 'no_match': # type: ignore[union-attr,index]
raise

# data is probable the lookup value, not the resource URL
Expand All @@ -91,8 +98,8 @@ def to_internal_value(self, data):
self.fail('does_not_exist')


class NestedHyperlinkedIdentityField(NestedHyperlinkedRelatedField):
def __init__(self, view_name=None, **kwargs):
class NestedHyperlinkedIdentityField(NestedHyperlinkedRelatedField[T_Model]):
def __init__(self, view_name: str | None = None, **kwargs: Any) -> None:
assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True
kwargs['source'] = '*'
Expand Down
33 changes: 24 additions & 9 deletions rest_framework_nested/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
urlpatterns = router.urls
"""

from __future__ import annotations
import sys
import re
from rest_framework.routers import SimpleRouter, DefaultRouter # noqa: F401
from typing import Any

from rest_framework.routers import SimpleRouter, DefaultRouter, DynamicRoute, Route
from rest_framework.viewsets import ViewSetMixin


if sys.version_info[0] < 3:
Expand All @@ -45,7 +49,16 @@ class LookupMixin:


class NestedMixin:
def __init__(self, parent_router, parent_prefix, *args, **kwargs):
trailing_slash: str
routes: list[Route | DynamicRoute]

def __init__(
self,
parent_router: SimpleRouter | DefaultRouter | NestedMixin,
parent_prefix: str,
*args: Any,
**kwargs: Any
) -> None:
self.parent_router = parent_router
self.parent_prefix = parent_prefix
self.nest_count = getattr(parent_router, 'nest_count', 0) + 1
Expand All @@ -69,19 +82,21 @@ def __init__(self, parent_router, parent_prefix, *args, **kwargs):
# we set our trailing slash to just '/', leading to inconsistent behavior.
self.trailing_slash = parent_router.trailing_slash

parent_registry = [registered for registered
in self.parent_router.registry
if registered[0] == self.parent_prefix]
parent_registry: list[tuple[str, type[ViewSetMixin], str]] = [
registered for registered
in self.parent_router.registry # type: ignore[union-attr]
if registered[0] == self.parent_prefix
]
try:
parent_registry = parent_registry[0]
parent_prefix, parent_viewset, parent_basename = parent_registry
parent_registry_item = parent_registry[0]
parent_prefix, parent_viewset, parent_basename = parent_registry_item
except:
raise RuntimeError('parent registered resource not found')

self.check_valid_name(self.nest_prefix)

nested_routes = []
parent_lookup_regex = parent_router.get_lookup_regex(parent_viewset, self.nest_prefix)
parent_lookup_regex = parent_router.get_lookup_regex(parent_viewset, self.nest_prefix) # type: ignore[union-attr]

self.parent_regex = f'{parent_prefix}/{parent_lookup_regex}/'
# If there is no parent prefix, the first part of the url is probably
Expand All @@ -105,7 +120,7 @@ def __init__(self, parent_router, parent_prefix, *args, **kwargs):

self.routes = nested_routes

def check_valid_name(self, value):
def check_valid_name(self, value: str) -> None:
if IDENTIFIER_REGEX.match(value) is None:
raise ValueError(f"lookup argument '{value}' needs to be valid python identifier")

Expand Down
6 changes: 4 additions & 2 deletions rest_framework_nested/runtests/runcoverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
from typing import NoReturn
from coverage import coverage

# fix sys path so we don't need to setup PYTHONPATH
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework_nested.runtests.settings'


def main():
def main() -> NoReturn:
"""Run the tests for rest_framework and generate a coverage report."""

cov = coverage()
Expand All @@ -26,6 +27,7 @@ def main():
from django.test.utils import get_runner
TestRunner = get_runner(settings)

failures: int
if hasattr(TestRunner, 'func_name'):
# Pre 1.2 test runners were just functions,
# and did not support the 'failfast' option.
Expand All @@ -34,7 +36,7 @@ def main():
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
failures = TestRunner(['tests'])
failures = TestRunner(['tests']) # type: ignore[assignment,arg-type]
else:
test_runner = TestRunner()
failures = test_runner.run_tests(['tests'])
Expand Down
5 changes: 3 additions & 2 deletions rest_framework_nested/runtests/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
from typing import NoReturn

# fix sys path so we don't need to setup PYTHONPATH
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
Expand All @@ -15,7 +16,7 @@
from django.test.utils import get_runner


def usage():
def usage() -> str:
return """
Usage: python runtests.py [UnitTestClass].[method]
Expand All @@ -25,7 +26,7 @@ def usage():
"""


def main():
def main() -> NoReturn:
TestRunner = get_runner(settings)
import ipdb
ipdb.set_trace()
Expand Down
4 changes: 3 additions & 1 deletion rest_framework_nested/runtests/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

# Django settings for testproject project.

DEBUG = True
Expand Down Expand Up @@ -83,7 +85,7 @@
# Don't forget to use absolute paths, not relative paths.
)

INSTALLED_APPS = (
INSTALLED_APPS: tuple[str, ...] = (
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
Expand Down
2 changes: 1 addition & 1 deletion rest_framework_nested/runtests/urls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Blank URLConf just to keep runtests.py happy.
"""
from rest_framework.compat import patterns
from rest_framework.compat import patterns # type: ignore[attr-defined]

urlpatterns = patterns('',)
20 changes: 16 additions & 4 deletions rest_framework_nested/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from __future__ import annotations

from typing import Any, TypeVar

import rest_framework.serializers
from django.db.models import Model
from rest_framework.fields import Field
from rest_framework.utils.model_meta import RelationInfo
from rest_framework_nested.relations import NestedHyperlinkedIdentityField, NestedHyperlinkedRelatedField
try:
from rest_framework.utils.field_mapping import get_nested_relation_kwargs
Expand All @@ -8,7 +15,10 @@
# if version too old.


class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedModelSerializer):
T_Model = TypeVar('T_Model', bound=Model)


class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedModelSerializer[T_Model]):
"""
A type of `ModelSerializer` that uses hyperlinked relationships with compound keys instead
of primary key relationships. Specifically:
Expand All @@ -25,11 +35,11 @@ class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedMod
serializer_url_field = NestedHyperlinkedIdentityField
serializer_related_field = NestedHyperlinkedRelatedField

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.parent_lookup_kwargs = kwargs.pop('parent_lookup_kwargs', self.parent_lookup_kwargs)
super().__init__(*args, **kwargs)

def build_url_field(self, field_name, model_class):
def build_url_field(self, field_name: str, model_class: T_Model) -> tuple[type[Field], dict[str, Any]]:
field_class, field_kwargs = super().build_url_field(
field_name,
model_class
Expand All @@ -38,7 +48,9 @@ def build_url_field(self, field_name, model_class):

return field_class, field_kwargs

def build_nested_field(self, field_name, relation_info, nested_depth):
def build_nested_field(
self, field_name: str, relation_info: RelationInfo, nested_depth: int
) -> tuple[type[Field], dict[str, Any]]:
"""
Create nested fields for forward and reverse relationships.
"""
Expand Down
34 changes: 22 additions & 12 deletions rest_framework_nested/viewsets.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
from __future__ import annotations

import contextlib
from typing import Any, Generator, Generic, TypeVar, cast

from django.core.exceptions import ImproperlyConfigured
from django.db.models import Model, QuerySet
from django.http import HttpRequest, QueryDict
from rest_framework.generics import GenericAPIView
from rest_framework.request import Request
from rest_framework.viewsets import ViewSetMixin

T_Model = TypeVar('T_Model', bound=Model)


@contextlib.contextmanager
def _force_mutable(querydict: dict) -> dict:
def _force_mutable(querydict: QueryDict | dict[str, Any]) -> Generator[QueryDict | dict[str, Any], None, None]:
"""
Takes a HttpRequest querydict from Django and forces it to be mutable.
Reverts the initial state back on exit, if any.
"""
initial_mutability = getattr(querydict, '_mutable', None)
if initial_mutability is not None:
querydict._mutable = True
querydict._mutable = True # type: ignore[union-attr]
yield querydict
if initial_mutability is not None:
querydict._mutable = initial_mutability
querydict._mutable = initial_mutability # type: ignore[union-attr]


class NestedViewSetMixin:
def _get_parent_lookup_kwargs(self) -> dict:
class NestedViewSetMixin(Generic[T_Model]):
def _get_parent_lookup_kwargs(self) -> dict[str, str]:
"""
Locates and returns the `parent_lookup_kwargs` dict informing
how the kwargs in the URL maps to the parents of the model instance
Expand All @@ -29,7 +39,7 @@ def _get_parent_lookup_kwargs(self) -> dict:
parent_lookup_kwargs = getattr(self, 'parent_lookup_kwargs', None)

if not parent_lookup_kwargs:
serializer_class = self.get_serializer_class()
serializer_class = cast(GenericAPIView, self).get_serializer_class()
parent_lookup_kwargs = getattr(serializer_class, 'parent_lookup_kwargs', None)

if not parent_lookup_kwargs:
Expand All @@ -39,28 +49,28 @@ def _get_parent_lookup_kwargs(self) -> dict:

return parent_lookup_kwargs

def get_queryset(self):
def get_queryset(self) -> QuerySet[T_Model]:
"""
Filter the `QuerySet` based on its parents as defined in the
`serializer_class.parent_lookup_kwargs` or `viewset.parent_lookup_kwargs`
"""
queryset = super().get_queryset()
queryset = super().get_queryset() # type: ignore[misc]

if getattr(self, 'swagger_fake_view', False):
return queryset

orm_filters = {}
parent_lookup_kwargs = self._get_parent_lookup_kwargs()
for query_param, field_name in parent_lookup_kwargs.items():
orm_filters[field_name] = self.kwargs[query_param]
orm_filters[field_name] = cast(ViewSetMixin, self).kwargs[query_param]
return queryset.filter(**orm_filters)

def initialize_request(self, request, *args, **kwargs):
def initialize_request(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Request:
"""
Adds the parent params from URL inside the children data available
"""
request = super().initialize_request(request, *args, **kwargs)
request = cast(ViewSetMixin, super()).initialize_request(request, *args, **kwargs)

if getattr(self, 'swagger_fake_view', False):
return request

Expand Down
Loading

0 comments on commit 4faca57

Please sign in to comment.