diff --git a/modeltranslation/manager.py b/modeltranslation/manager.py index 449e532d..a7cfad60 100644 --- a/modeltranslation/manager.py +++ b/modeltranslation/manager.py @@ -15,13 +15,15 @@ from django.core.exceptions import FieldDoesNotExist from django.db import models from django.db.backends.utils import CursorWrapper -from django.db.models import Field, Model +from django.db.models import Field, Model, F from django.db.models.expressions import Col +from django.db.models.functions import Concat, ConcatPair from django.db.models.lookups import Lookup from django.db.models.query import QuerySet, ValuesIterable from django.db.models.utils import create_namedtuple_class from django.utils.tree import Node +from modeltranslation._typing import Self, AutoPopulate from modeltranslation.fields import TranslationField from modeltranslation.thread_context import auto_populate_mode from modeltranslation.utils import ( @@ -30,7 +32,6 @@ get_language, resolution_order, ) -from modeltranslation._typing import Self, AutoPopulate _C2F_CACHE: dict[tuple[type[Model], str], Field] = {} _F2TM_CACHE: dict[type[Model], dict[str, type[Model]]] = {} @@ -513,6 +514,27 @@ def dates(self, field_name: str, *args: Any, **kwargs: Any) -> Self: new_key = rewrite_lookup_key(self.model, field_name) return super().dates(new_key, *args, **kwargs) + def _rewrite_concat(self, concat: Concat | ConcatPair): + new_source_expressions = [] + for exp in concat.source_expressions: + if isinstance(exp, (Concat, ConcatPair)): + exp = self._rewrite_concat(exp) + if isinstance(exp, F): + exp = self._rewrite_f(exp) + new_source_expressions.append(exp) + concat.set_source_expressions(new_source_expressions) + return concat + + def annotate(self, *args: Any, **kwargs: Any) -> Self: + if not self._rewrite: + return super().annotate(*args, **kwargs) + for key, val in list(kwargs.items()): + if isinstance(val, models.F): + kwargs[key] = self._rewrite_f(val) + if isinstance(val, Concat): + kwargs[key] = self._rewrite_concat(val) + return super().annotate(*args, **kwargs) + class FallbackValuesIterable(ValuesIterable): queryset: MultilingualQuerySet[Model] diff --git a/modeltranslation/tests/tests.py b/modeltranslation/tests/tests.py index 555b1ed4..cf5f4733 100644 --- a/modeltranslation/tests/tests.py +++ b/modeltranslation/tests/tests.py @@ -18,7 +18,7 @@ from django.core.management.base import CommandError from django.db import IntegrityError from django.db.models import CharField, Count, F, Q, TextField, Value -from django.db.models.functions import Cast +from django.db.models.functions import Cast, Concat from django.test import TestCase, TransactionTestCase from django.test.utils import override_settings from django.utils.translation import get_language, override, trans_real @@ -3670,6 +3670,45 @@ def test_distinct(self): assert titles_for_en == (("title_1_en", "desc_1_en"), ("title_2_en", "desc_1_en")) assert titles_for_de == (("title_1_de", "desc_1_de"), ("title_2_de", "desc_1_de")) + def test_annotate(self): + """Test if annotating is language-aware.""" + test = models.TestModel.objects.create(title_en="title_en", title_de="title_de") + + assert "en" == get_language() + assert ( + models.TestModel.objects.annotate(custom_title=F("title")).values_list( + "custom_title", flat=True + )[0] + == "title_en" + ) + with override("de"): + assert ( + models.TestModel.objects.annotate(custom_title=F("title")).values_list( + "custom_title", flat=True + )[0] + == "title_de" + ) + assert ( + models.TestModel.objects.annotate( + custom_title=Concat(F("title"), Value("value1"), Value("value2")) + ).values_list("custom_title", flat=True)[0] + == "title_devalue1value2" + ) + assert ( + models.TestModel.objects.annotate( + custom_title=Concat(F("title"), Concat(F("title"), Value("value"))) + ).values_list("custom_title", flat=True)[0] + == "title_detitle_devalue" + ) + models.ForeignKeyModel.objects.create(test=test) + models.ForeignKeyModel.objects.create(test=test) + assert ( + models.TestModel.objects.annotate(Count("test_fks")).values_list( + "test_fks__count", flat=True + )[0] + == 2 + ) + class TranslationModelFormTest(ModeltranslationTestBase): def test_fields(self):