|
6 | 6 | import urllib |
7 | 7 | from typing import Any, Dict |
8 | 8 |
|
9 | | -from asgi_correlation_id import CorrelationIdMiddleware, correlation_id # noqa |
10 | 9 | from asgiref.typing import ( |
11 | 10 | ASGI3Application, |
12 | 11 | ASGIReceiveCallable, |
|
15 | 14 | HTTPScope, |
16 | 15 | ) |
17 | 16 |
|
18 | | -from ..logging import JsonLogFormatter |
| 17 | +from ..logging import JsonLogFormatter, get_or_generate_request_id, request_id_context |
| 18 | + |
| 19 | + |
| 20 | +class RequestIdMiddleware: |
| 21 | + def __init__( |
| 22 | + self, |
| 23 | + app: ASGI3Application, |
| 24 | + ) -> None: |
| 25 | + self.app = app |
| 26 | + |
| 27 | + async def __call__( |
| 28 | + self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable |
| 29 | + ) -> None: |
| 30 | + if scope["type"] != "http": |
| 31 | + return await self.app(scope, receive, send) |
| 32 | + |
| 33 | + headers = {} |
| 34 | + for name, value in scope["headers"]: |
| 35 | + header_key = name.decode("latin1").lower() |
| 36 | + header_val = value.decode("latin1") |
| 37 | + headers[header_key] = header_val |
| 38 | + |
| 39 | + request_id_context.set(get_or_generate_request_id(headers)) |
| 40 | + await self.app(scope, receive, send) |
19 | 41 |
|
20 | 42 |
|
21 | 43 | class MozlogRequestSummaryLogger: |
@@ -75,7 +97,7 @@ def _format(self, scope: HTTPScope, info) -> Dict[str, Any]: |
75 | 97 | "code": info["response"]["status"], |
76 | 98 | "lang": info["request_headers"].get("accept-language"), |
77 | 99 | "t": int(request_duration_ms), |
78 | | - "rid": correlation_id.get(), |
| 100 | + "rid": request_id_context.get(), |
79 | 101 | } |
80 | 102 |
|
81 | 103 | if getattr(scope["app"].state, "DOCKERFLOW_SUMMARY_LOG_QUERYSTRING", False): |
|
0 commit comments