diff --git a/src/oidcop/__init__.py b/src/oidcop/__init__.py index 728a997..38400fc 100644 --- a/src/oidcop/__init__.py +++ b/src/oidcop/__init__.py @@ -1,6 +1,6 @@ import secrets -__version__ = "2.4.0" +__version__ = "2.4.1" DEF_SIGN_ALG = { "id_token": "RS256", diff --git a/src/oidcop/oauth2/add_on/extra_args.py b/src/oidcop/oauth2/add_on/extra_args.py index 68a8a84..ddfd3d4 100644 --- a/src/oidcop/oauth2/add_on/extra_args.py +++ b/src/oidcop/oauth2/add_on/extra_args.py @@ -1,3 +1,10 @@ +from oidcmsg.oauth2 import AccessTokenResponse +from oidcmsg.oauth2 import AuthorizationResponse +from oidcmsg.oauth2 import TokenExchangeResponse +from oidcmsg.oauth2 import TokenIntrospectionResponse +from oidcmsg.oidc import OpenIDSchema + + def pre_construct(response_args, request, endpoint_context, **kwargs): """ Add extra arguments to the request. @@ -11,12 +18,25 @@ def pre_construct(response_args, request, endpoint_context, **kwargs): _extra = endpoint_context.add_on.get("extra_args") if _extra: - for arg, _param in _extra.items(): - _val = endpoint_context.get(_param) + if isinstance(response_args, AuthorizationResponse): + _args = _extra.get("authorization", {}) + elif isinstance(response_args, AccessTokenResponse): + _args = _extra.get('accesstoken', {}) + elif isinstance(response_args, TokenExchangeResponse): + _args = _extra.get('token_exchange', {}) + elif isinstance(response_args, TokenIntrospectionResponse): + _args = _extra.get('token_introspection', {}) + elif isinstance(response_args, OpenIDSchema): + _args = _extra.get('userinfo', {}) + else: + _args = {} + + for arg, _param in _args.items(): + _val = getattr(endpoint_context, _param) if _val: - request[arg] = _val + response_args[arg] = _val - return request + return response_args def add_support(endpoint, **kwargs): diff --git a/tests/test_61_add_on.py b/tests/test_61_add_on.py index e6c2b8e..0b4b7e6 100644 --- a/tests/test_61_add_on.py +++ b/tests/test_61_add_on.py @@ -1,4 +1,5 @@ import os +from urllib.parse import urlparse from cryptojwt.jwk.ec import ECKey from cryptojwt.jwk.ec import new_ec_key @@ -6,6 +7,7 @@ from cryptojwt.key_jar import init_key_jar from oidcmsg.oauth2 import AccessTokenRequest from oidcmsg.oauth2 import AuthorizationRequest +from oidcmsg.oauth2 import AuthorizationResponse from oidcmsg.time_util import utc_time_sans_frac from oidcop.authn_event import create_authn_event @@ -155,15 +157,10 @@ def test_process_request(self): _context = self.endpoint.server_get("endpoint_context") assert _context.add_on["extra_args"] == {'authorization': {'iss': 'issuer'}} - _pr_resp = self.endpoint.parse_request(AUTH_REQ.to_dict()) - _resp = self.endpoint.process_request(_pr_resp) - assert set(_resp.keys()) == { - "response_args", - "fragment_enc", - "return_uri", - "cookie", - "session_id", - } - - assert 'iss' in _resp["response_args"] - assert _resp["response_args"]["iss"] == _context.issuer + _pr_resp = self.endpoint.parse_request(AUTH_REQ) + _args = self.endpoint.process_request(_pr_resp) + _resp = self.endpoint.do_response(request=AUTH_REQ, **_args) + parse_res = urlparse(_resp["response"]) + _payload = AuthorizationResponse().from_urlencoded(parse_res.query) + assert 'iss' in _payload + assert _payload["iss"] == _context.issuer