Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions flask_parameter_validation/parameter_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import re
from inspect import signature
from flask import request
from flask import request, Response
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest
from .exceptions import (InvalidParameterTypeError, MissingInputError,
Expand Down Expand Up @@ -40,8 +40,13 @@ def __call__(self, f):
}
fn_list[fsig] = fdocs

@functools.wraps(f)
async def nested_func(**kwargs):
def nested_func_helper(**kwargs):
"""
Validates the inputs of a Flask route or returns an error. Returns
are wrapped in a dictionary with a flag to let nested_func() know
if it should unpack the resulting dictionary of inputs as kwargs,
or just return the error message.
"""
# Step 1 - Get expected input details as dict
expected_inputs = signature(f).parameters

Expand All @@ -54,7 +59,7 @@ async def nested_func(**kwargs):
try:
json_input = request.json
except BadRequest:
return {"error": "Could not parse JSON."}, 400
return {"error": ({"error": "Could not parse JSON."}, 400), "validated": False}

# Step 3 - Extract list of parameters expected to be lists (otherwise all values are converted to lists)
expected_list_params = []
Expand All @@ -79,18 +84,32 @@ async def nested_func(**kwargs):
try:
new_input = self.validate(expected, request_inputs)
except (MissingInputError, ValidationError) as e:
return {"error": str(e)}, 400
return {"error": ({"error": str(e)}, 400), "validated": False}
else:
try:
new_input = self.validate(expected, request_inputs)
except Exception as e:
return self.custom_error_handler(e)
return {"error": self.custom_error_handler(e), "validated": False}
validated_inputs[expected.name] = new_input

if asyncio.iscoroutinefunction(f):
return await f(**validated_inputs)
else:
return f(**validated_inputs)
return {"inputs": validated_inputs, "validated": True}

if asyncio.iscoroutinefunction(f):
# If the view function is async, return and await a coroutine
@functools.wraps(f)
async def nested_func(**kwargs):
validated_inputs = nested_func_helper(**kwargs)
if validated_inputs["validated"]:
return await f(**validated_inputs["inputs"])
return validated_inputs["error"]
else:
# If the view function is not async, return a function
@functools.wraps(f)
def nested_func(**kwargs):
validated_inputs = nested_func_helper(**kwargs)
if validated_inputs["validated"]:
return f(**validated_inputs["inputs"])
return validated_inputs["error"]

nested_func.__name__ = f.__name__
return nested_func
Expand Down
22 changes: 22 additions & 0 deletions flask_parameter_validation/test/test_file_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,28 @@ def test_required_file(client):
assert "error" in r.json


def test_required_file_decorator(client):
url = "/file/decorator/required"
# Test that we receive a success response if a file is provided
r = client.post(url, data={"v": (resources / "test.json").open("rb")})
assert "success" in r.json
assert r.json["success"]
# Test that we receive an error if a file is not provided
r = client.post(url)
assert "error" in r.json


def test_required_file_async_decorator(client):
url = "/file/async_decorator/required"
# Test that we receive a success response if a file is provided
r = client.post(url, data={"v": (resources / "test.json").open("rb")})
assert "success" in r.json
assert r.json["success"]
# Test that we receive an error if a file is not provided
r = client.post(url)
assert "error" in r.json


def test_optional_file(client):
url = "/file/optional"
# Test that we receive a success response if a file is provided
Expand Down
22 changes: 22 additions & 0 deletions flask_parameter_validation/test/test_form_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,28 @@ def test_required_str(client):
assert "error" in r.json


def test_required_str_decorator(client):
url = "/form/str/decorator/required"
# Test that present input yields input value
r = client.post(url, data={"v": "v"})
assert "v" in r.json
assert r.json["v"] == "v"
# Test that missing input yields error
r = client.post(url)
assert "error" in r.json


def test_required_str_async_decorator(client):
url = "/form/str/async_decorator/required"
# Test that present input yields input value
r = client.post(url, data={"v": "v"})
assert "v" in r.json
assert r.json["v"] == "v"
# Test that missing input yields error
r = client.post(url)
assert "error" in r.json


def test_optional_str(client):
url = "/form/str/optional"
# Test that missing input yields None
Expand Down
48 changes: 48 additions & 0 deletions flask_parameter_validation/test/test_json_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,54 @@ def test_dict_default(client):
assert opt == r.json["opt"]


def test_dict_default_decorator(client):
url = "/json/dict/decorator/default"
# Test that present dict yields input values
n_opt = {"e": "f"}
opt = {"g": "h"}
r = client.post(url, json={"n_opt": n_opt, "opt": opt})
assert "n_opt" in r.json
assert "opt" in r.json
assert type(r.json["n_opt"]) is dict
assert type(r.json["opt"]) is dict
assert n_opt == r.json["n_opt"]
assert opt == r.json["opt"]
# Test that missing dict yields default values
n_opt = {"a": "b"}
opt = {"c": "d"}
r = client.post(url)
assert "n_opt" in r.json
assert "opt" in r.json
assert type(r.json["n_opt"]) is dict
assert type(r.json["opt"]) is dict
assert n_opt == r.json["n_opt"]
assert opt == r.json["opt"]


def test_dict_default_async_decorator(client):
url = "/json/dict/async_decorator/default"
# Test that present dict yields input values
n_opt = {"e": "f"}
opt = {"g": "h"}
r = client.post(url, json={"n_opt": n_opt, "opt": opt})
assert "n_opt" in r.json
assert "opt" in r.json
assert type(r.json["n_opt"]) is dict
assert type(r.json["opt"]) is dict
assert n_opt == r.json["n_opt"]
assert opt == r.json["opt"]
# Test that missing dict yields default values
n_opt = {"a": "b"}
opt = {"c": "d"}
r = client.post(url)
assert "n_opt" in r.json
assert "opt" in r.json
assert type(r.json["n_opt"]) is dict
assert type(r.json["opt"]) is dict
assert n_opt == r.json["n_opt"]
assert opt == r.json["opt"]


def test_dict_func(client):
url = "/json/dict/func"
# Test that dict passing func yields input value
Expand Down
Loading