Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 53 additions & 16 deletions flask_parameter_validation/parameter_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from inspect import signature

from flask import request
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest

from .exceptions import (InvalidParameterTypeError, MissingInputError,
Expand Down Expand Up @@ -43,17 +44,20 @@ def nested_func(**kwargs):
# Step 1 - Combine all flask input types to one dict
json_input = None
if request.headers.get("Content-Type") is not None:
if re.search("application/[^+]*[+]?(json);?", request.headers.get("Content-Type")):
if re.search(
"application/[^+]*[+]?(json);?", request.headers.get("Content-Type")
):
try:
json_input = request.json
except BadRequest:
return {"error": "Could not parse JSON."}, 400

request_inputs = {
Route: kwargs.copy(),
Json: json_input or {},
Query: request.args.to_dict(),
Form: request.form.to_dict(),
File: request.files.to_dict(),
Query: self._to_dict_with_lists(request.args, split_strings=True),
Form: self._to_dict_with_lists(request.form),
File: self._to_dict_with_lists(request.files),
}
# Step 2 - Get expected input details as dict
expected_inputs = signature(f).parameters
Expand All @@ -79,6 +83,25 @@ def nested_func(**kwargs):
nested_func.__name__ = f.__name__
return nested_func

def _to_dict_with_lists(
self, multi_dict: ImmutableMultiDict, split_strings: bool = False
) -> dict:
# If a dict has duplicate keys, they should instead be stored as a list under the same key
dict_with_lists = {}
# Iterate over all keys and values in ImmutableMultiDict
for key, values in multi_dict.lists():
list_values = []
# If split strings, split each value by comma
for value in values:
if split_strings:
list_values.extend(value.split(","))
else:
list_values.append(value)
# Create tuple of all values
dict_with_lists[key] = list_values

return dict_with_lists

def validate(self, expected_input, all_request_inputs):
"""
Validate that a given expected input exists in the requested input collection
Expand All @@ -103,17 +126,23 @@ def validate(self, expected_input, all_request_inputs):
raise InvalidParameterTypeError(expected_delivery_type)

# Validate that user supplied input in expected delivery type (unless specified as Optional)
user_input = all_request_inputs[expected_delivery_type.__class__].get(expected_name)
user_input = all_request_inputs[expected_delivery_type.__class__].get(
expected_name
)
if user_input is None:
# If default is given, set and continue
if expected_delivery_type.default is not None:
user_input = expected_delivery_type.default
else:
# Optionals are Unions with a NoneType, so we should check if None is part of Union __args__ (if exist)
if hasattr(expected_input_type, "__args__") and type(None) in expected_input_type.__args__:
if (
hasattr(expected_input_type, "__args__") and type(None) in expected_input_type.__args__
):
return user_input
else:
raise MissingInputError(expected_name, expected_delivery_type.__class__)
raise MissingInputError(
expected_name, expected_delivery_type.__class__
)

# Skip validation if typing.Any is given
if expected_input_type_str.startswith("typing.Any"):
Expand Down Expand Up @@ -152,12 +181,16 @@ def validate(self, expected_input, all_request_inputs):
# Perform automatic type conversion for parameter types (i.e. "true" -> True)
for count, value in enumerate(user_inputs):
try:
user_inputs[count] = expected_delivery_type.convert(value, expected_input_types)
user_inputs[count] = expected_delivery_type.convert(
value, expected_input_types
)
except ValueError as e:
raise ValidationError(str(e), expected_name, expected_input_type)

# Validate that user type(s) match expected type(s)
validation_success = all(type(inp) in expected_input_types for inp in user_inputs)
validation_success = all(
type(inp) in expected_input_types for inp in user_inputs
)

# Validate that if lists are required, lists are given
if expected_input_type_str.startswith("typing.List"):
Expand All @@ -166,13 +199,17 @@ def validate(self, expected_input, all_request_inputs):

# Error if types don't match
if not validation_success:
if hasattr(original_expected_input_type, "__name__") and not original_expected_input_type_str.startswith(
"typing."
):
if hasattr(
original_expected_input_type, "__name__"
) and not original_expected_input_type_str.startswith("typing."):
type_name = original_expected_input_type.__name__
else:
type_name = original_expected_input_type_str
raise ValidationError(f"must be type '{type_name}'", expected_name, original_expected_input_type)
raise ValidationError(
f"must be type '{type_name}'",
expected_name,
original_expected_input_type,
)

# Validate parameter-specific requirements are met
try:
Expand All @@ -181,6 +218,6 @@ def validate(self, expected_input, all_request_inputs):
raise ValidationError(str(e), expected_name, expected_input_type)

# Return input back to parent function
if len(user_inputs) == 1:
return user_inputs[0]
return user_inputs
if expected_input_type_str.startswith("typing.List"):
return user_inputs
return user_inputs[0]