From 237289d6cd81b56287e4c6deaf6dcd3be47096cb Mon Sep 17 00:00:00 2001 From: hinashi Date: Thu, 4 Jul 2024 12:32:19 +0900 Subject: [PATCH 1/2] Added validate function for custom views --- entity/tests/test_api_v2.py | 6 +++++- entry/api_v2/serializers.py | 23 +++++++++++++++-------- entry/tests/test_api_v2.py | 6 +++++- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/entity/tests/test_api_v2.py b/entity/tests/test_api_v2.py index ad199ad3b..77ba48110 100644 --- a/entity/tests/test_api_v2.py +++ b/entity/tests/test_api_v2.py @@ -3265,7 +3265,7 @@ def side_effect(handler_name, entity_name, user, *args): resp = self.client.post( "/entity/api/v2/%s/entries/" % self.entity.id, json.dumps(params), "application/json" ) - self.assertEqual(resp.status_code, status.HTTP_202_ACCEPTED) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertTrue(mock_call_custom.called) def side_effect(handler_name, entity_name, user, *args): @@ -3273,6 +3273,10 @@ def side_effect(handler_name, entity_name, user, *args): self.assertEqual(entity_name, self.entity.name) self.assertEqual(user, self.user) + if handler_name == "validate_entry": + self.assertEqual(args[0], params["name"]) + self.assertEqual(args[1], params["attrs"]) + if handler_name == "before_create_entry_v2": self.assertEqual( args[0], {**params, "schema": self.entity, "created_user": self.user} diff --git a/entry/api_v2/serializers.py b/entry/api_v2/serializers.py index 4222ba480..a8d6a18bf 100644 --- a/entry/api_v2/serializers.py +++ b/entry/api_v2/serializers.py @@ -253,14 +253,15 @@ def validate_name(self, name: str): raise InvalidValueError("Names containing tab characters cannot be specified.") return name - def _validate(self, schema: Entity, attrs: list[dict[str, Any]]): + def _validate(self, schema: Entity, name: str, attrs: list[dict[str, Any]]): + user: User | None = None + if "request" in self.context: + user = self.context["request"].user + if "_user" in self.context: + user = self.context["_user"] + # In create case, check attrs mandatory attribute if not self.instance: - user: User | None = None - if "request" in self.context: - user = self.context["request"].user - if "_user" in self.context: - user = self.context["_user"] if user is None: raise RequiredParameterError("user is required") @@ -289,6 +290,10 @@ def _validate(self, schema: Entity, attrs: list[dict[str, Any]]): if not is_valid: raise IncorrectTypeError("attrs id(%s) - %s" % (attr["id"], msg)) + # check custom validate + if custom_view.is_custom("validate_entry", schema.name): + custom_view.call_custom("validate_entry", schema.name, user, name, attrs) + @extend_schema_field({}) class AttributeValueField(serializers.Field): @@ -321,7 +326,7 @@ class Meta: fields = ["id", "name", "schema", "attrs", "created_user"] def validate(self, params): - self._validate(params["schema"], params.get("attrs", [])) + self._validate(params["schema"], params["name"], params.get("attrs", [])) return params def create(self, validated_data: EntryCreateData): @@ -407,7 +412,9 @@ class Meta: } def validate(self, params): - self._validate(self.instance.schema, params.get("attrs", [])) + self._validate( + self.instance.schema, params.get("name", self.instance.name), params.get("attrs", []) + ) return params def update(self, entry: Entry, validated_data: EntryUpdateData): diff --git a/entry/tests/test_api_v2.py b/entry/tests/test_api_v2.py index 7a126f162..fce5c9c2b 100644 --- a/entry/tests/test_api_v2.py +++ b/entry/tests/test_api_v2.py @@ -1063,7 +1063,7 @@ def side_effect(handler_name, entity_name, user, *args): resp = self.client.put( "/entry/api/v2/%s/" % entry.id, json.dumps(params), "application/json" ) - self.assertEqual(resp.status_code, status.HTTP_202_ACCEPTED) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertTrue(mock_call_custom.called) def side_effect(handler_name, entity_name, user, *args): @@ -1071,6 +1071,10 @@ def side_effect(handler_name, entity_name, user, *args): self.assertEqual(user, self.user) # Check specified parameters are expected + if handler_name == "validate_entry": + self.assertEqual(args[0], params["name"]) + self.assertEqual(args[1], params["attrs"]) + if handler_name == "before_update_entry_v2": self.assertEqual(args[0], params) return args[0] From aeaf967ade40380823f03091d49e83004ad3aee3 Mon Sep 17 00:00:00 2001 From: hinashi Date: Thu, 4 Jul 2024 12:57:55 +0900 Subject: [PATCH 2/2] Added schema name to parameters --- entity/tests/test_api_v2.py | 5 +++-- entry/api_v2/serializers.py | 2 +- entry/tests/test_api_v2.py | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/entity/tests/test_api_v2.py b/entity/tests/test_api_v2.py index 77ba48110..7e995d703 100644 --- a/entity/tests/test_api_v2.py +++ b/entity/tests/test_api_v2.py @@ -3274,8 +3274,9 @@ def side_effect(handler_name, entity_name, user, *args): self.assertEqual(user, self.user) if handler_name == "validate_entry": - self.assertEqual(args[0], params["name"]) - self.assertEqual(args[1], params["attrs"]) + self.assertEqual(args[0], self.entity.name) + self.assertEqual(args[1], params["name"]) + self.assertEqual(args[2], params["attrs"]) if handler_name == "before_create_entry_v2": self.assertEqual( diff --git a/entry/api_v2/serializers.py b/entry/api_v2/serializers.py index a8d6a18bf..85a02a858 100644 --- a/entry/api_v2/serializers.py +++ b/entry/api_v2/serializers.py @@ -292,7 +292,7 @@ def _validate(self, schema: Entity, name: str, attrs: list[dict[str, Any]]): # check custom validate if custom_view.is_custom("validate_entry", schema.name): - custom_view.call_custom("validate_entry", schema.name, user, name, attrs) + custom_view.call_custom("validate_entry", schema.name, user, schema.name, name, attrs) @extend_schema_field({}) diff --git a/entry/tests/test_api_v2.py b/entry/tests/test_api_v2.py index fce5c9c2b..b52f8013f 100644 --- a/entry/tests/test_api_v2.py +++ b/entry/tests/test_api_v2.py @@ -1072,8 +1072,9 @@ def side_effect(handler_name, entity_name, user, *args): # Check specified parameters are expected if handler_name == "validate_entry": - self.assertEqual(args[0], params["name"]) - self.assertEqual(args[1], params["attrs"]) + self.assertEqual(args[0], self.entity.name) + self.assertEqual(args[1], params["name"]) + self.assertEqual(args[2], params["attrs"]) if handler_name == "before_update_entry_v2": self.assertEqual(args[0], params)