Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@
# Collect ALL middleware for debug printing - include internal _registered_api_adapter
all_middlewares = router_middlewares + self.middlewares + [_registered_api_adapter]
print("\nMiddleware Stack:")
print("=================")

Check failure on line 527 in aws_lambda_powertools/event_handler/api_gateway.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "=================" 4 times.

See more on https://sonarcloud.io/project/issues?id=aws-powertools_powertools-lambda-python&issues=AZ27EUEMuMdUgzJGgySj&open=AZ27EUEMuMdUgzJGgySj&pullRequest=8170
print("\n".join(getattr(item, "__name__", "Unknown") for item in all_middlewares))
print("=================")

Expand Down Expand Up @@ -613,6 +613,63 @@

self._middleware_stack_built = True

async def call_async(
self,
router_middlewares: list[Callable],
app: ApiGatewayResolver,
route_arguments: dict[str, str],
) -> dict | tuple | Response:
from aws_lambda_powertools.event_handler.middlewares.async_utils import (
AsyncMiddlewareFrame,
_registered_api_adapter_async,
)

all_middlewares: list[Callable[..., Any]] = []

route_validation_enabled = (
self.enable_validation if self.enable_validation is not None else app._enable_validation
)

if route_validation_enabled and not hasattr(app, "_request_validation_middleware"):
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import (
OpenAPIRequestValidationMiddleware,
OpenAPIResponseValidationMiddleware,
)

app._request_validation_middleware = OpenAPIRequestValidationMiddleware()
app._response_validation_middleware = OpenAPIResponseValidationMiddleware(
validation_serializer=app._serializer,
has_response_validation_error=app._has_response_validation_error,
)

if route_validation_enabled and hasattr(app, "_request_validation_middleware"):
all_middlewares.append(app._request_validation_middleware)

all_middlewares.extend(router_middlewares + self.middlewares)

if route_validation_enabled and hasattr(app, "_response_validation_middleware"):
all_middlewares.append(app._response_validation_middleware)

all_middlewares.append(_registered_api_adapter_async)

logger.debug(f"Building async middleware stack: {all_middlewares}")

if app._debug:
print(f"\nProcessing Route (async):::{self.func.__name__} ({app.context['_path']})")
print("\nAsync Middleware Stack:")
print("=================")
print("\n".join(getattr(item, "__name__", "Unknown") for item in all_middlewares))
print("=================")

app.append_context(_route_args=route_arguments)

# Build async chain from inside-out (not cached — avoids state conflicts with sync cache)
next_handler: Callable = self.func
for handler in reversed(all_middlewares):
next_handler = AsyncMiddlewareFrame(current_middleware=handler, next_middleware=next_handler)

return await next_handler(app)

@property
def dependant(self) -> Dependant:
if self._dependant is None:
Expand Down Expand Up @@ -2509,6 +2566,94 @@

return response

async def _resolve_async(self) -> ResponseBuilder:
method = self.current_event.http_method.upper()
path = self._remove_prefix(self.current_event.path)

registered_routes = self._static_routes + self._dynamic_routes

for route in registered_routes:
if method != route.method:
continue
match_results: Match | None = route.rule.match(path)
if match_results:
logger.debug("Found a registered route. Calling async function")
self.append_context(_route=route, _path=path)

route_keys = self._convert_matches_into_route_keys(match_results)
return await self._call_route_async(route, route_keys)

return await self._handle_not_found_async(method=method, path=path)

async def _call_route_async(self, route: Route, route_arguments: dict[str, str]) -> ResponseBuilder:
try:
self._reset_processed_stack()

response = await route.call_async(
router_middlewares=self._router_middlewares,
app=self,
route_arguments=route_arguments,
)

return self._response_builder_class(
response=self._to_response(response), # type: ignore[arg-type]
serializer=self._serializer,
route=route,
)
except Exception as exc:
response_builder = self._call_exception_handler(exc, route)
if response_builder:
return response_builder

logger.exception(exc)
if self._debug:
return self._response_builder_class(
response=Response(
status_code=500,
content_type=content_types.TEXT_PLAIN,
body="".join(traceback.format_exc()),
),
serializer=self._serializer,
route=route,
)

raise

async def _handle_not_found_async(self, method: str, path: str) -> ResponseBuilder:
logger.debug(f"No match found for path {path} and method {method}")

def not_found_handler():
_headers: dict[str, Any] = {}

if self._cors and method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with empty response")
_headers["Access-Control-Allow-Methods"] = CORSConfig.build_allow_methods(self._cors_methods)
return Response(status_code=204, content_type=None, headers=_headers, body="")

custom_not_found_handler = self.exception_handler_manager.lookup_exception_handler(NotFoundError)
if custom_not_found_handler:
return custom_not_found_handler(NotFoundError())

return Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=_headers,
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
)

route = Route(
rule=self._compile_regex(r".*"),
method=method,
path=path,
func=not_found_handler,
cors=self._cors_enabled,
compress=False,
)

self.append_context(_route=route, _path=path)

return await self._call_route_async(route=route, route_arguments={})

def __call__(self, event, context) -> Any:
return self.resolve(event, context)

Expand Down
6 changes: 3 additions & 3 deletions aws_lambda_powertools/event_handler/http_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _get_base_path(self) -> str:
"""Return the base path for HTTP resolver (no stage prefix)."""
return ""

async def _resolve_async(self) -> dict:
async def _resolve_async(self) -> dict: # type: ignore[override]
"""Async version of resolve that supports async handlers."""
method = self.current_event.http_method.upper()
path = self._remove_prefix(self.current_event.path)
Expand All @@ -258,7 +258,7 @@ async def _resolve_async(self) -> dict:
# Handle not found
return await self._handle_not_found_async()

async def _call_route_async(self, route: Route, route_arguments: dict[str, str]) -> dict:
async def _call_route_async(self, route: Route, route_arguments: dict[str, str]) -> dict: # type: ignore[override]
"""Call route handler, supporting both sync and async handlers."""
from aws_lambda_powertools.event_handler.api_gateway import ResponseBuilder

Expand Down Expand Up @@ -323,7 +323,7 @@ async def final_handler(app):

return await next_handler(self)

async def _handle_not_found_async(self) -> dict:
async def _handle_not_found_async(self, method: str = "", path: str = "") -> dict: # type: ignore[override]
"""Handle 404 responses, using custom not_found handler if registered."""
from http import HTTPStatus

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio

from aws_lambda_powertools.event_handler.api_gateway import (
APIGatewayHttpResolver,
BaseRouter,
)
from tests.functional.utils import load_event

API_RESTV2_EVENT = load_event("apiGatewayProxyV2Event_GET.json")


def _setup_app(app, event):
BaseRouter.current_event = app._to_proxy_event(event)
BaseRouter.lambda_context = {}


class TestResolveAsyncValidation:
def test_validation_middleware_created_and_used(self):
# GIVEN a resolver with validation enabled and an async handler
app = APIGatewayHttpResolver(enable_validation=True)

@app.get("/my/path")
async def get_lambda() -> dict:
await asyncio.sleep(0)
return {"message": "validated"}

# WHEN calling _resolve_async
_setup_app(app, API_RESTV2_EVENT)
result = asyncio.run(app._resolve_async())

# THEN the validation middlewares are created and the response is valid
response = result.build(app.current_event, app._cors)
assert response["statusCode"] == 200
assert hasattr(app, "_request_validation_middleware")
assert hasattr(app, "_response_validation_middleware")

def test_validation_middleware_lazy_created_for_per_route_validation(self):
# GIVEN a resolver WITHOUT global validation, but a route WITH enable_validation=True
app = APIGatewayHttpResolver()
assert not hasattr(app, "_request_validation_middleware")

@app.get("/my/path", enable_validation=True)
async def get_lambda() -> dict:
await asyncio.sleep(0)
return {"message": "lazy validated"}

# WHEN calling _resolve_async (triggers lazy creation in Route.call_async)
_setup_app(app, API_RESTV2_EVENT)
result = asyncio.run(app._resolve_async())

# THEN validation middlewares are lazily created on the app
response = result.build(app.current_event, app._cors)
assert response["statusCode"] == 200
assert hasattr(app, "_request_validation_middleware")
assert hasattr(app, "_response_validation_middleware")
Loading
Loading