Skip to content

Commit

Permalink
Merge pull request #1209 from dmm-com/feature/custom_validate
Browse files Browse the repository at this point in the history
Added validate function for custom views
  • Loading branch information
userlocalhost authored Jul 4, 2024
2 parents e8bec0f + aeaf967 commit a3c6af1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
7 changes: 6 additions & 1 deletion entity/tests/test_api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3265,14 +3265,19 @@ def side_effect(handler_name, entity_name, user, *args):
resp = self.client.post(
"/entity/api/v2/%s/entries/" % self.entity.id, json.dumps(params), "application/json"
)
self.assertEqual(resp.status_code, status.HTTP_202_ACCEPTED)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
self.assertTrue(mock_call_custom.called)

def side_effect(handler_name, entity_name, user, *args):
# Check specified parameters are expected
self.assertEqual(entity_name, self.entity.name)
self.assertEqual(user, self.user)

if handler_name == "validate_entry":
self.assertEqual(args[0], self.entity.name)
self.assertEqual(args[1], params["name"])
self.assertEqual(args[2], params["attrs"])

if handler_name == "before_create_entry_v2":
self.assertEqual(
args[0], {**params, "schema": self.entity, "created_user": self.user}
Expand Down
23 changes: 15 additions & 8 deletions entry/api_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,15 @@ def validate_name(self, name: str):
raise InvalidValueError("Names containing tab characters cannot be specified.")
return name

def _validate(self, schema: Entity, attrs: list[dict[str, Any]]):
def _validate(self, schema: Entity, name: str, attrs: list[dict[str, Any]]):
user: User | None = None
if "request" in self.context:
user = self.context["request"].user
if "_user" in self.context:
user = self.context["_user"]

# In create case, check attrs mandatory attribute
if not self.instance:
user: User | None = None
if "request" in self.context:
user = self.context["request"].user
if "_user" in self.context:
user = self.context["_user"]
if user is None:
raise RequiredParameterError("user is required")

Expand Down Expand Up @@ -289,6 +290,10 @@ def _validate(self, schema: Entity, attrs: list[dict[str, Any]]):
if not is_valid:
raise IncorrectTypeError("attrs id(%s) - %s" % (attr["id"], msg))

# check custom validate
if custom_view.is_custom("validate_entry", schema.name):
custom_view.call_custom("validate_entry", schema.name, user, schema.name, name, attrs)


@extend_schema_field({})
class AttributeValueField(serializers.Field):
Expand Down Expand Up @@ -321,7 +326,7 @@ class Meta:
fields = ["id", "name", "schema", "attrs", "created_user"]

def validate(self, params):
self._validate(params["schema"], params.get("attrs", []))
self._validate(params["schema"], params["name"], params.get("attrs", []))
return params

def create(self, validated_data: EntryCreateData):
Expand Down Expand Up @@ -407,7 +412,9 @@ class Meta:
}

def validate(self, params):
self._validate(self.instance.schema, params.get("attrs", []))
self._validate(
self.instance.schema, params.get("name", self.instance.name), params.get("attrs", [])
)
return params

def update(self, entry: Entry, validated_data: EntryUpdateData):
Expand Down
7 changes: 6 additions & 1 deletion entry/tests/test_api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,14 +1063,19 @@ def side_effect(handler_name, entity_name, user, *args):
resp = self.client.put(
"/entry/api/v2/%s/" % entry.id, json.dumps(params), "application/json"
)
self.assertEqual(resp.status_code, status.HTTP_202_ACCEPTED)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
self.assertTrue(mock_call_custom.called)

def side_effect(handler_name, entity_name, user, *args):
self.assertEqual(entity_name, self.entity.name)
self.assertEqual(user, self.user)

# Check specified parameters are expected
if handler_name == "validate_entry":
self.assertEqual(args[0], self.entity.name)
self.assertEqual(args[1], params["name"])
self.assertEqual(args[2], params["attrs"])

if handler_name == "before_update_entry_v2":
self.assertEqual(args[0], params)
return args[0]
Expand Down

0 comments on commit a3c6af1

Please sign in to comment.