Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial type hints and run mypy in CI #339

Merged
merged 6 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[mypy]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
alanjds marked this conversation as resolved.
Show resolved Hide resolved
fast_module_lookup = true
ignore_missing_imports = true
implicit_reexport = false
local_partial_types = true
show_column_numbers = true
show_error_codes = true
strict_equality = true
warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true
5 changes: 5 additions & 0 deletions requirements-tox.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ flake8==6.1.0
ipdb==0.13.13
pytz==2024.1

# Type checking requirements
mypy==1.10.0
django-stubs==5.0.0
djangorestframework-stubs==3.15.0
Comment on lines +9 to +12
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Should these dependencies be included in requirements-tox.txt, as is now?
  2. Should I create separate requirements-mypy.txt?
  3. Should I embed them into tox.ini?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am fine with any of the options, yet less files seems preferable.


# wheel for PyPI installs
wheel==0.43.0

Expand Down
31 changes: 19 additions & 12 deletions rest_framework_nested/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,31 @@
These fields allow you to specify the style that should be used to represent
model relationships with hyperlinks.
"""
from __future__ import annotations
from functools import reduce
from typing import Any, TypeVar, Generic

import rest_framework.relations
from rest_framework.relations import ObjectDoesNotExist, ObjectValueError, ObjectTypeError
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import Model
from rest_framework.relations import HyperlinkedRelatedField, ObjectValueError, ObjectTypeError
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request


class NestedHyperlinkedRelatedField(rest_framework.relations.HyperlinkedRelatedField):
T_Model = TypeVar('T_Model', bound=Model)


class NestedHyperlinkedRelatedField(HyperlinkedRelatedField, Generic[T_Model]):
lookup_field = 'pk'
parent_lookup_kwargs = {
'parent_pk': 'parent__pk'
}

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.parent_lookup_kwargs = kwargs.pop('parent_lookup_kwargs', self.parent_lookup_kwargs)
super().__init__(*args, **kwargs)

def get_url(self, obj, view_name, request, format):
def get_url(self, obj: Model, view_name: str, request: Request, format: str | None) -> str | None:
intgr marked this conversation as resolved.
Show resolved Hide resolved
"""
Given an object, return the URL that hyperlinks to the object.

Expand All @@ -46,7 +53,7 @@ def get_url(self, obj, view_name, request, format):

try:
# use the Django ORM to lookup this value, e.g., obj.parent.pk
lookup_value = reduce(getattr, [obj] + lookups)
lookup_value = reduce(getattr, [obj] + lookups) # type: ignore[operator,arg-type]
except AttributeError:
# Not nested. Act like a standard HyperlinkedRelatedField
return super().get_url(obj, view_name, request, format)
Expand All @@ -56,7 +63,7 @@ def get_url(self, obj, view_name, request, format):

return self.reverse(view_name, kwargs=kwargs, request=request, format=format)

def get_object(self, view_name, view_args, view_kwargs):
def get_object(self, view_name: str, view_args: list[Any], view_kwargs: dict[str, Any]) -> T_Model:
"""
Return the object corresponding to a matched URL.

Expand All @@ -74,14 +81,14 @@ def get_object(self, view_name, view_args, view_kwargs):

return self.get_queryset().get(**kwargs)

def use_pk_only_optimization(self):
def use_pk_only_optimization(self) -> bool:
return False

def to_internal_value(self, data):
def to_internal_value(self, data: Any) -> T_Model:
try:
return super().to_internal_value(data)
except ValidationError as err:
if err.detail[0].code != 'no_match':
if err.detail[0].code != 'no_match': # type: ignore[union-attr,index]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is a bit iffy. err.detail isn't guaranteed to be a list of ErrorDetail. It can be an arbitrarily nested combination of list, dict and ErrorDetail.

raise

# data is probable the lookup value, not the resource URL
Expand All @@ -91,8 +98,8 @@ def to_internal_value(self, data):
self.fail('does_not_exist')


class NestedHyperlinkedIdentityField(NestedHyperlinkedRelatedField):
def __init__(self, view_name=None, **kwargs):
class NestedHyperlinkedIdentityField(NestedHyperlinkedRelatedField[T_Model]):
def __init__(self, view_name: str | None = None, **kwargs: Any) -> None:
assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True
kwargs['source'] = '*'
Expand Down
33 changes: 24 additions & 9 deletions rest_framework_nested/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
urlpatterns = router.urls
"""

from __future__ import annotations
import sys
import re
from rest_framework.routers import SimpleRouter, DefaultRouter # noqa: F401
from typing import Any

from rest_framework.routers import SimpleRouter, DefaultRouter, DynamicRoute, Route
from rest_framework.viewsets import ViewSetMixin


if sys.version_info[0] < 3:
Expand All @@ -45,7 +49,16 @@ class LookupMixin:


class NestedMixin:
def __init__(self, parent_router, parent_prefix, *args, **kwargs):
trailing_slash: str
routes: list[Route | DynamicRoute]
intgr marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
parent_router: SimpleRouter | DefaultRouter | NestedMixin,
parent_prefix: str,
*args: Any,
**kwargs: Any
) -> None:
self.parent_router = parent_router
self.parent_prefix = parent_prefix
self.nest_count = getattr(parent_router, 'nest_count', 0) + 1
Expand All @@ -69,19 +82,21 @@ def __init__(self, parent_router, parent_prefix, *args, **kwargs):
# we set our trailing slash to just '/', leading to inconsistent behavior.
self.trailing_slash = parent_router.trailing_slash

parent_registry = [registered for registered
in self.parent_router.registry
if registered[0] == self.parent_prefix]
parent_registry: list[tuple[str, type[ViewSetMixin], str]] = [
intgr marked this conversation as resolved.
Show resolved Hide resolved
registered for registered
in self.parent_router.registry # type: ignore[union-attr]
if registered[0] == self.parent_prefix
]
try:
parent_registry = parent_registry[0]
parent_prefix, parent_viewset, parent_basename = parent_registry
parent_registry_item = parent_registry[0]
parent_prefix, parent_viewset, parent_basename = parent_registry_item
except:
raise RuntimeError('parent registered resource not found')

self.check_valid_name(self.nest_prefix)

nested_routes = []
parent_lookup_regex = parent_router.get_lookup_regex(parent_viewset, self.nest_prefix)
parent_lookup_regex = parent_router.get_lookup_regex(parent_viewset, self.nest_prefix) # type: ignore[union-attr]

self.parent_regex = f'{parent_prefix}/{parent_lookup_regex}/'
# If there is no parent prefix, the first part of the url is probably
Expand All @@ -105,7 +120,7 @@ def __init__(self, parent_router, parent_prefix, *args, **kwargs):

self.routes = nested_routes

def check_valid_name(self, value):
def check_valid_name(self, value: str) -> None:
if IDENTIFIER_REGEX.match(value) is None:
raise ValueError(f"lookup argument '{value}' needs to be valid python identifier")

Expand Down
6 changes: 4 additions & 2 deletions rest_framework_nested/runtests/runcoverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
from typing import NoReturn
from coverage import coverage

# fix sys path so we don't need to setup PYTHONPATH
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework_nested.runtests.settings'


def main():
def main() -> NoReturn:
"""Run the tests for rest_framework and generate a coverage report."""

cov = coverage()
Expand All @@ -26,6 +27,7 @@ def main():
from django.test.utils import get_runner
TestRunner = get_runner(settings)

failures: int
if hasattr(TestRunner, 'func_name'):
# Pre 1.2 test runners were just functions,
# and did not support the 'failfast' option.
Expand All @@ -34,7 +36,7 @@ def main():
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
failures = TestRunner(['tests'])
failures = TestRunner(['tests']) # type: ignore[assignment,arg-type]
intgr marked this conversation as resolved.
Show resolved Hide resolved
else:
test_runner = TestRunner()
failures = test_runner.run_tests(['tests'])
Expand Down
5 changes: 3 additions & 2 deletions rest_framework_nested/runtests/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
from typing import NoReturn

# fix sys path so we don't need to setup PYTHONPATH
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
Expand All @@ -15,7 +16,7 @@
from django.test.utils import get_runner


def usage():
def usage() -> str:
return """
Usage: python runtests.py [UnitTestClass].[method]

Expand All @@ -25,7 +26,7 @@ def usage():
"""


def main():
def main() -> NoReturn:
TestRunner = get_runner(settings)
import ipdb
ipdb.set_trace()
Expand Down
4 changes: 3 additions & 1 deletion rest_framework_nested/runtests/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

# Django settings for testproject project.

DEBUG = True
Expand Down Expand Up @@ -83,7 +85,7 @@
# Don't forget to use absolute paths, not relative paths.
)

INSTALLED_APPS = (
INSTALLED_APPS: tuple[str, ...] = (
intgr marked this conversation as resolved.
Show resolved Hide resolved
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
Expand Down
2 changes: 1 addition & 1 deletion rest_framework_nested/runtests/urls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Blank URLConf just to keep runtests.py happy.
"""
from rest_framework.compat import patterns
from rest_framework.compat import patterns # type: ignore[attr-defined]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rest_framework.compat.patterns doesn't actually exist, this crashes with ImportError.


urlpatterns = patterns('',)
20 changes: 16 additions & 4 deletions rest_framework_nested/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from __future__ import annotations

from typing import Any, TypeVar

import rest_framework.serializers
from django.db.models import Model
from rest_framework.fields import Field
from rest_framework.utils.model_meta import RelationInfo
from rest_framework_nested.relations import NestedHyperlinkedIdentityField, NestedHyperlinkedRelatedField
try:
from rest_framework.utils.field_mapping import get_nested_relation_kwargs
Expand All @@ -8,7 +15,10 @@
# if version too old.


class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedModelSerializer):
T_Model = TypeVar('T_Model', bound=Model)


class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedModelSerializer[T_Model]):
"""
A type of `ModelSerializer` that uses hyperlinked relationships with compound keys instead
of primary key relationships. Specifically:
Expand All @@ -25,11 +35,11 @@ class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedMod
serializer_url_field = NestedHyperlinkedIdentityField
serializer_related_field = NestedHyperlinkedRelatedField

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.parent_lookup_kwargs = kwargs.pop('parent_lookup_kwargs', self.parent_lookup_kwargs)
super().__init__(*args, **kwargs)

def build_url_field(self, field_name, model_class):
def build_url_field(self, field_name: str, model_class: T_Model) -> tuple[type[Field], dict[str, Any]]:
field_class, field_kwargs = super().build_url_field(
field_name,
model_class
Expand All @@ -38,7 +48,9 @@ def build_url_field(self, field_name, model_class):

return field_class, field_kwargs

def build_nested_field(self, field_name, relation_info, nested_depth):
def build_nested_field(
self, field_name: str, relation_info: RelationInfo, nested_depth: int
) -> tuple[type[Field], dict[str, Any]]:
"""
Create nested fields for forward and reverse relationships.
"""
Expand Down
34 changes: 22 additions & 12 deletions rest_framework_nested/viewsets.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
from __future__ import annotations

import contextlib
from typing import Any, Generator, Generic, TypeVar, cast

from django.core.exceptions import ImproperlyConfigured
from django.db.models import Model, QuerySet
from django.http import HttpRequest, QueryDict
from rest_framework.generics import GenericAPIView
from rest_framework.request import Request
from rest_framework.viewsets import ViewSetMixin

T_Model = TypeVar('T_Model', bound=Model)


@contextlib.contextmanager
def _force_mutable(querydict: dict) -> dict:
def _force_mutable(querydict: QueryDict | dict[str, Any]) -> Generator[QueryDict | dict[str, Any], None, None]:
intgr marked this conversation as resolved.
Show resolved Hide resolved
"""
Takes a HttpRequest querydict from Django and forces it to be mutable.
Reverts the initial state back on exit, if any.
"""
initial_mutability = getattr(querydict, '_mutable', None)
if initial_mutability is not None:
querydict._mutable = True
querydict._mutable = True # type: ignore[union-attr]
yield querydict
if initial_mutability is not None:
querydict._mutable = initial_mutability
querydict._mutable = initial_mutability # type: ignore[union-attr]


class NestedViewSetMixin:
def _get_parent_lookup_kwargs(self) -> dict:
class NestedViewSetMixin(Generic[T_Model]):
def _get_parent_lookup_kwargs(self) -> dict[str, str]:
"""
Locates and returns the `parent_lookup_kwargs` dict informing
how the kwargs in the URL maps to the parents of the model instance
Expand All @@ -29,7 +39,7 @@ def _get_parent_lookup_kwargs(self) -> dict:
parent_lookup_kwargs = getattr(self, 'parent_lookup_kwargs', None)

if not parent_lookup_kwargs:
serializer_class = self.get_serializer_class()
serializer_class = cast(GenericAPIView, self).get_serializer_class()
intgr marked this conversation as resolved.
Show resolved Hide resolved
parent_lookup_kwargs = getattr(serializer_class, 'parent_lookup_kwargs', None)

if not parent_lookup_kwargs:
Expand All @@ -39,28 +49,28 @@ def _get_parent_lookup_kwargs(self) -> dict:

return parent_lookup_kwargs

def get_queryset(self):
def get_queryset(self) -> QuerySet[T_Model]:
"""
Filter the `QuerySet` based on its parents as defined in the
`serializer_class.parent_lookup_kwargs` or `viewset.parent_lookup_kwargs`
"""
queryset = super().get_queryset()
queryset = super().get_queryset() # type: ignore[misc]
intgr marked this conversation as resolved.
Show resolved Hide resolved

if getattr(self, 'swagger_fake_view', False):
return queryset

orm_filters = {}
parent_lookup_kwargs = self._get_parent_lookup_kwargs()
for query_param, field_name in parent_lookup_kwargs.items():
orm_filters[field_name] = self.kwargs[query_param]
orm_filters[field_name] = cast(ViewSetMixin, self).kwargs[query_param]
intgr marked this conversation as resolved.
Show resolved Hide resolved
return queryset.filter(**orm_filters)

def initialize_request(self, request, *args, **kwargs):
def initialize_request(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Request:
"""
Adds the parent params from URL inside the children data available
"""
request = super().initialize_request(request, *args, **kwargs)
request = cast(ViewSetMixin, super()).initialize_request(request, *args, **kwargs)
intgr marked this conversation as resolved.
Show resolved Hide resolved

if getattr(self, 'swagger_fake_view', False):
return request

Expand Down
Loading
Loading