Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Feature: Validate with all properties required. #146

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions openapi_core/schema/media_types/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def deserialize(self, value):
deserializer = self.get_dererializer()
return deserializer(value)

def unmarshal(self, value, custom_formatters=None):
def unmarshal(self, value, custom_formatters=None,
require_all_props=False):
if not self.schema:
return value

Expand All @@ -42,12 +43,19 @@ def unmarshal(self, value, custom_formatters=None):
raise InvalidMediaTypeValue(exc)

try:
unmarshalled = self.schema.unmarshal(deserialized, custom_formatters=custom_formatters)
unmarshalled = self.schema.unmarshal(
deserialized,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)
except OpenAPISchemaError as exc:
raise InvalidMediaTypeValue(exc)

try:
return self.schema.validate(
unmarshalled, custom_formatters=custom_formatters)
unmarshalled,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)
except OpenAPISchemaError as exc:
raise InvalidMediaTypeValue(exc)
9 changes: 7 additions & 2 deletions openapi_core/schema/parameters/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def get_value(self, request):

return location[self.name]

def unmarshal(self, value, custom_formatters=None):
def unmarshal(self, value, custom_formatters=None,
require_all_props=False):
if self.deprecated:
warnings.warn(
"{0} parameter is deprecated".format(self.name),
Expand All @@ -112,13 +113,17 @@ def unmarshal(self, value, custom_formatters=None):
unmarshalled = self.schema.unmarshal(
deserialized,
custom_formatters=custom_formatters,
require_all_props=require_all_props,
strict=False,
)
except OpenAPISchemaError as exc:
raise InvalidParameterValue(self.name, exc)

try:
return self.schema.validate(
unmarshalled, custom_formatters=custom_formatters)
unmarshalled,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)
except OpenAPISchemaError as exc:
raise InvalidParameterValue(self.name, exc)
139 changes: 100 additions & 39 deletions openapi_core/schema/schemas/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,24 +155,42 @@ def get_all_required_properties_names(self):

return set(required)

def get_cast_mapping(self, custom_formatters=None, strict=True):
def get_cast_mapping(self, custom_formatters=None, strict=True,
require_all_props=False):
primitive_unmarshallers = self.get_primitive_unmarshallers(
custom_formatters=custom_formatters)
custom_formatters=custom_formatters,
require_all_props=require_all_props
)

primitive_unmarshallers_partial = dict(
(t, functools.partial(u, type_format=self.format, strict=strict))
(t, functools.partial(
u,
type_format=self.format,
strict=strict,
require_all_props=require_all_props)
)
for t, u in primitive_unmarshallers.items()
)

pass_defaults = lambda f: functools.partial(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I needed to refactor that part since the parameters were not passed in right with that lambda function here.

f, custom_formatters=custom_formatters, strict=strict)
complex_unmarshallers = {
SchemaType.ANY: self._unmarshal_any,
SchemaType.ARRAY: self._unmarshal_collection,
SchemaType.OBJECT: self._unmarshal_object,
}

complex_unmarshallers_partial = dict(
(t, functools.partial(
u,
custom_formatters=custom_formatters,
strict=strict,
require_all_props=require_all_props)
)
for t, u in complex_unmarshallers.items()
)

mapping = self.DEFAULT_CAST_CALLABLE_GETTER.copy()
mapping.update(primitive_unmarshallers_partial)
mapping.update({
SchemaType.ANY: pass_defaults(self._unmarshal_any),
SchemaType.ARRAY: pass_defaults(self._unmarshal_collection),
SchemaType.OBJECT: pass_defaults(self._unmarshal_object),
})
mapping.update(complex_unmarshallers_partial)

return defaultdict(lambda: lambda x: x, mapping)

Expand All @@ -183,7 +201,8 @@ def are_additional_properties_allowed(self, one_of_schema=None):
one_of_schema.additional_properties is not False)
)

def cast(self, value, custom_formatters=None, strict=True):
def cast(self, value, custom_formatters=None, strict=True,
require_all_props=False):
"""Cast value to schema type"""
if value is None:
if not self.nullable:
Expand All @@ -195,7 +214,8 @@ def cast(self, value, custom_formatters=None, strict=True):
"Value {value} not in enum choices: {type}", value, self.enum)

cast_mapping = self.get_cast_mapping(
custom_formatters=custom_formatters, strict=strict)
custom_formatters=custom_formatters, strict=strict,
require_all_props=require_all_props)

if self.type is not SchemaType.STRING and value == '':
return None
Expand All @@ -210,12 +230,14 @@ def cast(self, value, custom_formatters=None, strict=True):
raise InvalidSchemaValue(
"Failed to cast value {value} to type {type}", value, self.type)

def unmarshal(self, value, custom_formatters=None, strict=True):
def unmarshal(self, value, custom_formatters=None, strict=True,
require_all_props=False):
"""Unmarshal parameter from the value."""
if self.deprecated:
warnings.warn("The schema is deprecated", DeprecationWarning)

casted = self.cast(value, custom_formatters=custom_formatters, strict=strict)
casted = self.cast(value, custom_formatters=custom_formatters, strict=strict,
require_all_props=require_all_props)

if casted is None and not self.required:
return None
Expand All @@ -242,7 +264,8 @@ def get_primitive_unmarshallers(self, **options):

return unmarshallers

def _unmarshal_any(self, value, custom_formatters=None, strict=True):
def _unmarshal_any(self, value, custom_formatters=None, strict=True,
require_all_props=False):
types_resolve_order = [
SchemaType.OBJECT, SchemaType.ARRAY, SchemaType.BOOLEAN,
SchemaType.INTEGER, SchemaType.NUMBER, SchemaType.STRING,
Expand All @@ -252,7 +275,10 @@ def _unmarshal_any(self, value, custom_formatters=None, strict=True):
result = None
for subschema in self.one_of:
try:
casted = subschema.cast(value, custom_formatters)
casted = subschema.cast(
value, custom_formatters,
require_all_props=require_all_props
)
except (OpenAPISchemaError, TypeError, ValueError):
continue
else:
Expand All @@ -277,7 +303,8 @@ def _unmarshal_any(self, value, custom_formatters=None, strict=True):

raise NoValidSchema(value)

def _unmarshal_collection(self, value, custom_formatters=None, strict=True):
def _unmarshal_collection(self, value, custom_formatters=None, strict=True,
require_all_props=False):
if not isinstance(value, (list, tuple)):
raise InvalidSchemaValue("Value {value} is not of type {type}", value, self.type)

Expand All @@ -287,11 +314,13 @@ def _unmarshal_collection(self, value, custom_formatters=None, strict=True):
f = functools.partial(
self.items.unmarshal,
custom_formatters=custom_formatters, strict=strict,
require_all_props=require_all_props
)
return list(map(f, value))

def _unmarshal_object(self, value, model_factory=None,
custom_formatters=None, strict=True):
custom_formatters=None, strict=True,
require_all_props=False):
if not isinstance(value, (dict, )):
raise InvalidSchemaValue("Value {value} is not of type {type}", value, self.type)

Expand All @@ -302,7 +331,8 @@ def _unmarshal_object(self, value, model_factory=None,
for one_of_schema in self.one_of:
try:
found_props = self._unmarshal_properties(
value, one_of_schema, custom_formatters=custom_formatters)
value, one_of_schema, custom_formatters=custom_formatters,
require_all_props=False)
except OpenAPISchemaError:
pass
else:
Expand All @@ -315,12 +345,14 @@ def _unmarshal_object(self, value, model_factory=None,

else:
properties = self._unmarshal_properties(
value, custom_formatters=custom_formatters)
value, custom_formatters=custom_formatters,
require_all_props=require_all_props)

return model_factory.create(properties, name=self.model)

def _unmarshal_properties(self, value, one_of_schema=None,
custom_formatters=None, strict=True):
custom_formatters=None, strict=True,
require_all_props=False):
all_props = self.get_all_properties()
all_props_names = self.get_all_properties_names()
all_req_props_names = self.get_all_required_properties_names()
Expand All @@ -344,7 +376,10 @@ def _unmarshal_properties(self, value, one_of_schema=None,
for prop_name in extra_props:
prop_value = value[prop_name]
properties[prop_name] = self.additional_properties.unmarshal(
prop_value, custom_formatters=custom_formatters)
prop_value,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)

for prop_name, prop in iteritems(all_props):
try:
Expand All @@ -357,12 +392,16 @@ def _unmarshal_properties(self, value, one_of_schema=None,
prop_value = prop.default
try:
properties[prop_name] = prop.unmarshal(
prop_value, custom_formatters=custom_formatters)
prop_value,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)
except OpenAPISchemaError as exc:
raise InvalidSchemaProperty(prop_name, exc)

self._validate_properties(properties, one_of_schema=one_of_schema,
custom_formatters=custom_formatters)
custom_formatters=custom_formatters,
require_all_props=require_all_props)

return properties

Expand All @@ -380,7 +419,8 @@ def default(x, **kw):

return defaultdict(lambda: default, mapping)

def validate(self, value, custom_formatters=None):
def validate(self, value, custom_formatters=None,
require_all_props=False):
if value is None:
if not self.nullable:
raise InvalidSchemaValue("Null value for non-nullable schema of type {type}", value, self.type)
Expand All @@ -396,11 +436,16 @@ def validate(self, value, custom_formatters=None):
# structure validation
validator_mapping = self.get_validator_mapping()
validator_callable = validator_mapping[self.type]
validator_callable(value, custom_formatters=custom_formatters)
validator_callable(
value,
custom_formatters=custom_formatters,
require_all_props=False
)

return value

def _validate_collection(self, value, custom_formatters=None):
def _validate_collection(self, value, custom_formatters=None,
require_all_props=False):
if self.items is None:
raise UndefinedItemsSchema(self.type)

Expand Down Expand Up @@ -428,10 +473,12 @@ def _validate_collection(self, value, custom_formatters=None):
raise OpenAPISchemaError("Value may not contain duplicate items")

f = functools.partial(self.items.validate,
custom_formatters=custom_formatters)
custom_formatters=custom_formatters,
require_all_props=require_all_props)
return list(map(f, value))

def _validate_number(self, value, custom_formatters=None):
def _validate_number(self, value, custom_formatters=None,
require_all_props=False):
if self.minimum is not None:
if self.exclusive_minimum and value <= self.minimum:
raise InvalidSchemaValue(
Expand All @@ -453,7 +500,8 @@ def _validate_number(self, value, custom_formatters=None):
"Value {value} is not a multiple of {type}",
value, self.multiple_of)

def _validate_string(self, value, custom_formatters=None):
def _validate_string(self, value, custom_formatters=None,
require_all_props=False):
try:
schema_format = SchemaFormat(self.format)
except ValueError:
Expand Down Expand Up @@ -502,16 +550,19 @@ def _validate_string(self, value, custom_formatters=None):

return True

def _validate_object(self, value, custom_formatters=None):
def _validate_object(self, value, custom_formatters=None,
require_all_props=False):
properties = value.__dict__

if self.one_of:
valid_one_of_schema = None
for one_of_schema in self.one_of:
try:
self._validate_properties(
properties, one_of_schema,
custom_formatters=custom_formatters)
properties, one_of_schema,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)
except OpenAPISchemaError:
pass
else:
Expand All @@ -523,8 +574,11 @@ def _validate_object(self, value, custom_formatters=None):
raise NoOneOfSchema(self.type)

else:
self._validate_properties(properties,
custom_formatters=custom_formatters)
self._validate_properties(
properties,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)

if self.min_properties is not None:
if self.min_properties < 0:
Expand Down Expand Up @@ -554,7 +608,8 @@ def _validate_object(self, value, custom_formatters=None):
return True

def _validate_properties(self, value, one_of_schema=None,
custom_formatters=None):
custom_formatters=None,
require_all_props=False):
all_props = self.get_all_properties()
all_props_names = self.get_all_properties_names()
all_req_props_names = self.get_all_required_properties_names()
Expand All @@ -577,19 +632,25 @@ def _validate_properties(self, value, one_of_schema=None,
for prop_name in extra_props:
prop_value = value[prop_name]
self.additional_properties.validate(
prop_value, custom_formatters=custom_formatters)
prop_value,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)

for prop_name, prop in iteritems(all_props):
try:
prop_value = value[prop_name]
except KeyError:
if prop_name in all_req_props_names:
if (prop_name in all_req_props_names) or require_all_props:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@reviewers This is the important line. All other changes are passing around the flag.

raise MissingSchemaProperty(prop_name)
if not prop.nullable and not prop.default:
continue
prop_value = prop.default
try:
prop.validate(prop_value, custom_formatters=custom_formatters)
prop.validate(prop_value,
custom_formatters=custom_formatters,
require_all_props=require_all_props
)
except OpenAPISchemaError as exc:
raise InvalidSchemaProperty(prop_name, original_exception=exc)

Expand Down
Loading