diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f897d76e..a4c80a16 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,4 +31,27 @@ jobs: echo env - name: Test with tox - run: tox + run: tox run --skip-env=py311-mypy + + mypy: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: 3.11 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox tox-gh-actions + - name: Know your environment + run: | + cd $GITHUB_WORKSPACE + echo + ls -F + echo + env + - name: Test with tox + run: tox run -e py311-mypy diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..a335fabe --- /dev/null +++ b/mypy.ini @@ -0,0 +1,14 @@ +[mypy] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +fast_module_lookup = true +ignore_missing_imports = true +implicit_reexport = false +local_partial_types = true +show_column_numbers = true +show_error_codes = true +strict_equality = true +warn_redundant_casts = true +warn_unreachable = true +warn_unused_ignores = true diff --git a/requirements-tox.txt b/requirements-tox.txt index 7d18d19d..3fd09740 100644 --- a/requirements-tox.txt +++ b/requirements-tox.txt @@ -6,6 +6,11 @@ flake8==6.1.0 ipdb==0.13.13 pytz==2024.1 +# Type checking requirements +mypy==1.10.0 +django-stubs==5.0.0 +djangorestframework-stubs==3.15.0 + # wheel for PyPI installs wheel==0.43.0 diff --git a/rest_framework_nested/relations.py b/rest_framework_nested/relations.py index 3d684f00..8919b05a 100644 --- a/rest_framework_nested/relations.py +++ b/rest_framework_nested/relations.py @@ -4,24 +4,32 @@ These fields allow you to specify the style that should be used to represent model relationships with hyperlinks. """ +from __future__ import annotations + from functools import reduce +from typing import Any, Generic, TypeVar -import rest_framework.relations -from rest_framework.relations import ObjectDoesNotExist, ObjectValueError, ObjectTypeError +from django.core.exceptions import ObjectDoesNotExist +from django.db.models import Model +from rest_framework.relations import HyperlinkedRelatedField, ObjectTypeError, ObjectValueError from rest_framework.exceptions import ValidationError +from rest_framework.request import Request + + +T_Model = TypeVar('T_Model', bound=Model) -class NestedHyperlinkedRelatedField(rest_framework.relations.HyperlinkedRelatedField): +class NestedHyperlinkedRelatedField(HyperlinkedRelatedField, Generic[T_Model]): lookup_field = 'pk' parent_lookup_kwargs = { 'parent_pk': 'parent__pk' } - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.parent_lookup_kwargs = kwargs.pop('parent_lookup_kwargs', self.parent_lookup_kwargs) super().__init__(*args, **kwargs) - def get_url(self, obj, view_name, request, format): + def get_url(self, obj: Model, view_name: str, request: Request, format: str | None) -> str | None: """ Given an object, return the URL that hyperlinks to the object. @@ -46,7 +54,7 @@ def get_url(self, obj, view_name, request, format): try: # use the Django ORM to lookup this value, e.g., obj.parent.pk - lookup_value = reduce(getattr, [obj] + lookups) + lookup_value = reduce(getattr, [obj] + lookups) # type: ignore[operator,arg-type] except AttributeError: # Not nested. Act like a standard HyperlinkedRelatedField return super().get_url(obj, view_name, request, format) @@ -56,7 +64,7 @@ def get_url(self, obj, view_name, request, format): return self.reverse(view_name, kwargs=kwargs, request=request, format=format) - def get_object(self, view_name, view_args, view_kwargs): + def get_object(self, view_name: str, view_args: list[Any], view_kwargs: dict[str, Any]) -> T_Model: """ Return the object corresponding to a matched URL. @@ -74,14 +82,14 @@ def get_object(self, view_name, view_args, view_kwargs): return self.get_queryset().get(**kwargs) - def use_pk_only_optimization(self): + def use_pk_only_optimization(self) -> bool: return False - def to_internal_value(self, data): + def to_internal_value(self, data: Any) -> T_Model: try: return super().to_internal_value(data) except ValidationError as err: - if err.detail[0].code != 'no_match': + if err.detail[0].code != 'no_match': # type: ignore[union-attr,index] raise # data is probable the lookup value, not the resource URL @@ -91,8 +99,8 @@ def to_internal_value(self, data): self.fail('does_not_exist') -class NestedHyperlinkedIdentityField(NestedHyperlinkedRelatedField): - def __init__(self, view_name=None, **kwargs): +class NestedHyperlinkedIdentityField(NestedHyperlinkedRelatedField[T_Model]): + def __init__(self, view_name: str | None = None, **kwargs: Any) -> None: assert view_name is not None, 'The `view_name` argument is required.' kwargs['read_only'] = True kwargs['source'] = '*' diff --git a/rest_framework_nested/routers.py b/rest_framework_nested/routers.py index 106eb829..ae826e67 100644 --- a/rest_framework_nested/routers.py +++ b/rest_framework_nested/routers.py @@ -25,10 +25,13 @@ urlpatterns = router.urls """ +from __future__ import annotations + import sys import re -from rest_framework.routers import SimpleRouter, DefaultRouter # noqa: F401 +from typing import Any +from rest_framework.routers import DefaultRouter, SimpleRouter if sys.version_info[0] < 3: IDENTIFIER_REGEX = re.compile(r"^[^\d\W]\w*$") @@ -45,7 +48,13 @@ class LookupMixin: class NestedMixin: - def __init__(self, parent_router, parent_prefix, *args, **kwargs): + def __init__( + self, + parent_router: SimpleRouter | DefaultRouter | NestedMixin, + parent_prefix: str, + *args: Any, + **kwargs: Any + ) -> None: self.parent_router = parent_router self.parent_prefix = parent_prefix self.nest_count = getattr(parent_router, 'nest_count', 0) + 1 @@ -67,21 +76,23 @@ def __init__(self, parent_router, parent_prefix, *args, **kwargs): # behavior is ALWAYS consistent with the parent. If we didn't, we might create # a situation where the parent's trailing slash is truthy (but not '/') and # we set our trailing slash to just '/', leading to inconsistent behavior. - self.trailing_slash = parent_router.trailing_slash + self.trailing_slash = parent_router.trailing_slash # type: ignore[has-type] - parent_registry = [registered for registered - in self.parent_router.registry - if registered[0] == self.parent_prefix] + parent_registry = [ + registered for registered + in self.parent_router.registry # type: ignore[union-attr] + if registered[0] == self.parent_prefix + ] try: - parent_registry = parent_registry[0] - parent_prefix, parent_viewset, parent_basename = parent_registry + parent_registry_item = parent_registry[0] + parent_prefix, parent_viewset, parent_basename = parent_registry_item except: raise RuntimeError('parent registered resource not found') self.check_valid_name(self.nest_prefix) nested_routes = [] - parent_lookup_regex = parent_router.get_lookup_regex(parent_viewset, self.nest_prefix) + parent_lookup_regex = parent_router.get_lookup_regex(parent_viewset, self.nest_prefix) # type: ignore[union-attr] self.parent_regex = f'{parent_prefix}/{parent_lookup_regex}/' # If there is no parent prefix, the first part of the url is probably @@ -93,7 +104,7 @@ def __init__(self, parent_router, parent_prefix, *args, **kwargs): if hasattr(parent_router, 'parent_regex'): self.parent_regex = parent_router.parent_regex + self.parent_regex - for route in self.routes: + for route in self.routes: # type: ignore[has-type] route_contents = route._asdict() # This will get passed through .format in a little bit, so we need @@ -105,12 +116,12 @@ def __init__(self, parent_router, parent_prefix, *args, **kwargs): self.routes = nested_routes - def check_valid_name(self, value): + def check_valid_name(self, value: str) -> None: if IDENTIFIER_REGEX.match(value) is None: raise ValueError(f"lookup argument '{value}' needs to be valid python identifier") -class NestedSimpleRouter(NestedMixin, SimpleRouter): +class NestedSimpleRouter(NestedMixin, SimpleRouter): # type: ignore[misc] """ Create a NestedSimpleRouter nested within `parent_router` Args: @@ -131,7 +142,7 @@ class NestedSimpleRouter(NestedMixin, SimpleRouter): pass -class NestedDefaultRouter(NestedMixin, DefaultRouter): +class NestedDefaultRouter(NestedMixin, DefaultRouter): # type: ignore[misc] """ Create a NestedDefaultRouter nested within `parent_router` Args: diff --git a/rest_framework_nested/runtests/runcoverage.py b/rest_framework_nested/runtests/runcoverage.py index e4b05d90..ac7c1da3 100755 --- a/rest_framework_nested/runtests/runcoverage.py +++ b/rest_framework_nested/runtests/runcoverage.py @@ -8,6 +8,7 @@ # http://code.djangoproject.com/svn/django/trunk/tests/runtests.py import os import sys +from typing import NoReturn from coverage import coverage # fix sys path so we don't need to setup PYTHONPATH @@ -15,7 +16,7 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework_nested.runtests.settings' -def main(): +def main() -> NoReturn: """Run the tests for rest_framework and generate a coverage report.""" cov = coverage() @@ -26,6 +27,7 @@ def main(): from django.test.utils import get_runner TestRunner = get_runner(settings) + failures: int if hasattr(TestRunner, 'func_name'): # Pre 1.2 test runners were just functions, # and did not support the 'failfast' option. @@ -34,7 +36,7 @@ def main(): 'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.', DeprecationWarning ) - failures = TestRunner(['tests']) + failures = TestRunner(['tests']) # type: ignore[assignment,arg-type] else: test_runner = TestRunner() failures = test_runner.run_tests(['tests']) diff --git a/rest_framework_nested/runtests/runtests.py b/rest_framework_nested/runtests/runtests.py index 46804675..78a12b6c 100755 --- a/rest_framework_nested/runtests/runtests.py +++ b/rest_framework_nested/runtests/runtests.py @@ -5,6 +5,7 @@ # http://code.djangoproject.com/svn/django/trunk/tests/runtests.py import os import sys +from typing import NoReturn # fix sys path so we don't need to setup PYTHONPATH sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) @@ -15,7 +16,7 @@ from django.test.utils import get_runner -def usage(): +def usage() -> str: return """ Usage: python runtests.py [UnitTestClass].[method] @@ -25,7 +26,7 @@ def usage(): """ -def main(): +def main() -> NoReturn: TestRunner = get_runner(settings) import ipdb ipdb.set_trace() diff --git a/rest_framework_nested/runtests/settings.py b/rest_framework_nested/runtests/settings.py index 812e3df6..eecc90cd 100644 --- a/rest_framework_nested/runtests/settings.py +++ b/rest_framework_nested/runtests/settings.py @@ -1,3 +1,5 @@ +from __future__ import annotations + # Django settings for testproject project. DEBUG = True @@ -83,7 +85,7 @@ # Don't forget to use absolute paths, not relative paths. ) -INSTALLED_APPS = ( +INSTALLED_APPS = [ 'django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', @@ -96,7 +98,7 @@ # 'rest_framework.authtoken', 'rest_framework_nested', 'rest_framework_nested.tests', -) +] # OAuth is optional and won't work if there is no oauth_provider & oauth2 try: @@ -105,19 +107,19 @@ except ImportError: pass else: - INSTALLED_APPS += ( + INSTALLED_APPS += [ 'oauth_provider', - ) + ] try: import provider # noqa: F401 except ImportError: pass else: - INSTALLED_APPS += ( + INSTALLED_APPS += [ 'provider', 'provider.oauth2', - ) + ] # guardian is optional try: @@ -130,9 +132,9 @@ 'django.contrib.auth.backends.ModelBackend', # default 'guardian.backends.ObjectPermissionBackend', ) - INSTALLED_APPS += ( + INSTALLED_APPS += [ 'guardian', - ) + ] STATIC_URL = '/static/' diff --git a/rest_framework_nested/runtests/urls.py b/rest_framework_nested/runtests/urls.py index 08c42da9..98747c4f 100644 --- a/rest_framework_nested/runtests/urls.py +++ b/rest_framework_nested/runtests/urls.py @@ -1,6 +1,6 @@ """ Blank URLConf just to keep runtests.py happy. """ -from rest_framework.compat import patterns +from rest_framework.compat import patterns # type: ignore[attr-defined] urlpatterns = patterns('',) diff --git a/rest_framework_nested/serializers.py b/rest_framework_nested/serializers.py index 1a315a2d..c7779d66 100644 --- a/rest_framework_nested/serializers.py +++ b/rest_framework_nested/serializers.py @@ -1,4 +1,11 @@ +from __future__ import annotations + +from typing import Any, TypeVar + import rest_framework.serializers +from django.db.models import Model +from rest_framework.fields import Field +from rest_framework.utils.model_meta import RelationInfo from rest_framework_nested.relations import NestedHyperlinkedIdentityField, NestedHyperlinkedRelatedField try: from rest_framework.utils.field_mapping import get_nested_relation_kwargs @@ -8,7 +15,10 @@ # if version too old. -class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedModelSerializer): +T_Model = TypeVar('T_Model', bound=Model) + + +class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedModelSerializer[T_Model]): """ A type of `ModelSerializer` that uses hyperlinked relationships with compound keys instead of primary key relationships. Specifically: @@ -25,11 +35,11 @@ class NestedHyperlinkedModelSerializer(rest_framework.serializers.HyperlinkedMod serializer_url_field = NestedHyperlinkedIdentityField serializer_related_field = NestedHyperlinkedRelatedField - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.parent_lookup_kwargs = kwargs.pop('parent_lookup_kwargs', self.parent_lookup_kwargs) super().__init__(*args, **kwargs) - def build_url_field(self, field_name, model_class): + def build_url_field(self, field_name: str, model_class: T_Model) -> tuple[type[Field], dict[str, Any]]: field_class, field_kwargs = super().build_url_field( field_name, model_class @@ -38,7 +48,9 @@ def build_url_field(self, field_name, model_class): return field_class, field_kwargs - def build_nested_field(self, field_name, relation_info, nested_depth): + def build_nested_field( + self, field_name: str, relation_info: RelationInfo, nested_depth: int + ) -> tuple[type[Field], dict[str, Any]]: """ Create nested fields for forward and reverse relationships. """ diff --git a/rest_framework_nested/viewsets.py b/rest_framework_nested/viewsets.py index 41e9b27a..c920483b 100644 --- a/rest_framework_nested/viewsets.py +++ b/rest_framework_nested/viewsets.py @@ -1,24 +1,33 @@ +from __future__ import annotations + import contextlib +from typing import Any, Generic, Iterator, TypeVar from django.core.exceptions import ImproperlyConfigured +from django.db.models import Model, QuerySet +from django.http import HttpRequest, QueryDict +from rest_framework.request import Request +from rest_framework.serializers import BaseSerializer + +T_Model = TypeVar('T_Model', bound=Model) @contextlib.contextmanager -def _force_mutable(querydict: dict) -> dict: +def _force_mutable(querydict: QueryDict | dict[str, Any]) -> Iterator[QueryDict | dict[str, Any]]: """ Takes a HttpRequest querydict from Django and forces it to be mutable. Reverts the initial state back on exit, if any. """ initial_mutability = getattr(querydict, '_mutable', None) if initial_mutability is not None: - querydict._mutable = True + querydict._mutable = True # type: ignore[union-attr] yield querydict if initial_mutability is not None: - querydict._mutable = initial_mutability + querydict._mutable = initial_mutability # type: ignore[union-attr] -class NestedViewSetMixin: - def _get_parent_lookup_kwargs(self) -> dict: +class NestedViewSetMixin(Generic[T_Model]): + def _get_parent_lookup_kwargs(self) -> dict[str, str]: """ Locates and returns the `parent_lookup_kwargs` dict informing how the kwargs in the URL maps to the parents of the model instance @@ -29,7 +38,7 @@ def _get_parent_lookup_kwargs(self) -> dict: parent_lookup_kwargs = getattr(self, 'parent_lookup_kwargs', None) if not parent_lookup_kwargs: - serializer_class = self.get_serializer_class() + serializer_class: type[BaseSerializer[T_Model]] = self.get_serializer_class() # type: ignore[attr-defined] parent_lookup_kwargs = getattr(serializer_class, 'parent_lookup_kwargs', None) if not parent_lookup_kwargs: @@ -39,35 +48,35 @@ def _get_parent_lookup_kwargs(self) -> dict: return parent_lookup_kwargs - def get_queryset(self): + def get_queryset(self) -> QuerySet[T_Model]: """ Filter the `QuerySet` based on its parents as defined in the `serializer_class.parent_lookup_kwargs` or `viewset.parent_lookup_kwargs` """ - queryset = super().get_queryset() + queryset = super().get_queryset() # type: ignore[misc] if getattr(self, 'swagger_fake_view', False): return queryset - orm_filters = {} + orm_filters: dict[str, Any] = {} parent_lookup_kwargs = self._get_parent_lookup_kwargs() for query_param, field_name in parent_lookup_kwargs.items(): - orm_filters[field_name] = self.kwargs[query_param] + orm_filters[field_name] = self.kwargs[query_param] # type: ignore[attr-defined] return queryset.filter(**orm_filters) - def initialize_request(self, request, *args, **kwargs): + def initialize_request(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Request: """ Adds the parent params from URL inside the children data available """ - request = super().initialize_request(request, *args, **kwargs) - + drf_request: Request = super().initialize_request(request, *args, **kwargs) # type: ignore[misc] + if getattr(self, 'swagger_fake_view', False): - return request + return drf_request for url_kwarg, fk_filter in self._get_parent_lookup_kwargs().items(): # fk_filter is alike 'grandparent__parent__pk' parent_arg = fk_filter.partition('__')[0] - for querydict in [request.data, request.query_params]: + for querydict in [drf_request.data, drf_request.query_params]: with _force_mutable(querydict): querydict[parent_arg] = kwargs[url_kwarg] - return request + return drf_request diff --git a/tests/conftest.py b/tests/conftest.py index ad165c8b..0504abba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,7 @@ def pytest_configure(): 'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', ), - INSTALLED_APPS=( + INSTALLED_APPS=[ 'django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', @@ -34,7 +34,7 @@ def pytest_configure(): 'rest_framework.authtoken', 'tests', 'tests.serializers', - ), + ], PASSWORD_HASHERS=( 'django.contrib.auth.hashers.SHA1PasswordHasher', 'django.contrib.auth.hashers.PBKDF2PasswordHasher', @@ -51,19 +51,19 @@ def pytest_configure(): except ImportError: pass else: - settings.INSTALLED_APPS += ( + settings.INSTALLED_APPS += [ 'oauth_provider', - ) + ] try: import provider # NOQA except ImportError: pass else: - settings.INSTALLED_APPS += ( + settings.INSTALLED_APPS += [ 'provider', 'provider.oauth2', - ) + ] # guardian is optional try: @@ -76,9 +76,9 @@ def pytest_configure(): 'django.contrib.auth.backends.ModelBackend', 'guardian.backends.ObjectPermissionBackend', ) - settings.INSTALLED_APPS += ( + settings.INSTALLED_APPS += [ 'guardian', - ) + ] try: import django diff --git a/tox.ini b/tox.ini index 3771bff5..c51bbc32 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,7 @@ [tox] envlist = py{38,39,310,311}-django{3.2,4.1,4.2}-drf3.14 + py311-mypy [gh-actions] python = @@ -30,3 +31,8 @@ deps = commands = mkdocs build deps = mkdocs>=1.3.0 + +[testenv:py311-mypy] +commands = mypy rest_framework_nested +deps = + -rrequirements-tox.txt