From b49cc843f45bfae6e375b610d231bffd66a92af6 Mon Sep 17 00:00:00 2001 From: Nikos Sklikas Date: Mon, 14 Feb 2022 15:26:40 +0200 Subject: [PATCH 1/5] Refactor client authn --- src/oidcop/client_authn.py | 359 +++++++++++--------- src/oidcop/endpoint.py | 9 +- src/oidcop/exception.py | 4 + src/oidcop/oauth2/add_on/dpop.py | 4 +- src/oidcop/oauth2/authorization.py | 1 + src/oidcop/oauth2/introspection.py | 8 + src/oidcop/oidc/authorization.py | 1 + src/oidcop/oidc/session.py | 17 +- src/oidcop/oidc/userinfo.py | 4 +- src/oidcop/server.py | 20 +- tests/test_02_client_authn.py | 119 ++++--- tests/test_20_endpoint.py | 3 +- tests/test_23_oidc_registration_endpoint.py | 2 + tests/test_24_oauth2_token_endpoint.py | 4 +- tests/test_26_oidc_userinfo_endpoint.py | 3 +- tests/test_30_oidc_end_session.py | 3 +- tests/test_31_oauth2_introspection.py | 6 +- tests/test_32_oidc_read_registration.py | 2 + tests/test_35_oidc_token_endpoint.py | 4 +- 19 files changed, 308 insertions(+), 265 deletions(-) diff --git a/src/oidcop/client_authn.py b/src/oidcop/client_authn.py index a34e6329..2ffdba8f 100755 --- a/src/oidcop/client_authn.py +++ b/src/oidcop/client_authn.py @@ -1,5 +1,6 @@ import base64 import logging +from collections import OrderedDict from typing import Callable from typing import Optional from typing import Union @@ -20,55 +21,65 @@ from oidcop import sanitize from oidcop.endpoint_context import EndpointContext from oidcop.exception import BearerTokenAuthenticationError +from oidcop.exception import ClientAuthenticationError from oidcop.exception import InvalidClient -from oidcop.exception import MultipleUsage -from oidcop.exception import NotForMe +from oidcop.exception import InvalidToken from oidcop.exception import ToOld +from oidcop.exception import UnAuthorizedClient from oidcop.exception import UnknownClient -from oidcop.util import importer logger = logging.getLogger(__name__) __author__ = "roland hedberg" -class AuthnFailure(Exception): - pass - - -class NoMatchingKey(Exception): - pass - - -class UnknownOrNoAuthnMethod(Exception): - pass - - -class WrongAuthnMethod(Exception): - pass - - class ClientAuthnMethod(object): - def __init__(self, server_get): + tag = None + + @classmethod + def _verify( + cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + ): """ - :param server_get: A method that can be used to get general server information. + Verify authentication information in a request + :param kwargs: + :return: """ - self.server_get = server_get + raise NotImplementedError() - def verify(self, **kwargs): + @classmethod + def verify( + cls, + endpoint_context, + request=None, + authorization_token=None, + endpoint=None, + get_client_id_from_token=None, + **kwargs, + ): """ Verify authentication information in a request :param kwargs: :return: """ - raise NotImplementedError() + res = cls._verify( + endpoint_context, + request=request, + authorization_token=authorization_token, + endpoint=endpoint, + get_client_id_from_token=get_client_id_from_token, + **kwargs, + ) + res["method"] = cls.tag + return res - def is_usable(self, request=None, authorization_info=None): + @classmethod + def is_usable(cls, endpoint_context, request=None, authorization_token=None): """ Verify that this authentication method is applicable. :param request: The request - :param authorization_info: Other authorization information + :param authorization_token: The authorization token :return: True/False """ raise NotImplementedError() @@ -76,7 +87,7 @@ def is_usable(self, request=None, authorization_info=None): def basic_authn(authorization_token): if not authorization_token.startswith("Basic "): - raise AuthnFailure("Wrong type of authorization token") + raise ClientAuthenticationError("Wrong type of authorization token") _tok = as_bytes(authorization_token[6:]) # Will raise ValueError type exception if not base64 encoded @@ -88,6 +99,25 @@ def basic_authn(authorization_token): raise ValueError("Illegal token") +class NoneAuthn(ClientAuthnMethod): + """ + Used for public clients, that don't require any form of authentication other + than their client_id + """ + + tag = "none" + + @classmethod + def is_usable(cls, endpoint_context, request=None, authorization_token=None): + return request and "client_id" in request + + @classmethod + def _verify( + cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + ): + return {"client_id": request["client_id"]} + + class ClientSecretBasic(ClientAuthnMethod): """ Clients that have received a client_secret value from the Authorization @@ -97,21 +127,22 @@ class ClientSecretBasic(ClientAuthnMethod): tag = "client_secret_basic" - def is_usable(self, request=None, authorization_token=None): + @classmethod + def is_usable(cls, endpoint_context, request=None, authorization_token=None): if authorization_token is not None and authorization_token.startswith("Basic "): return True return False - def verify(self, authorization_token, **kwargs): + @classmethod + def _verify( + cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + ): client_info = basic_authn(authorization_token) - if ( - self.server_get("endpoint_context").cdb[client_info["id"]]["client_secret"] - == client_info["secret"] - ): + if endpoint_context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: return {"client_id": client_info["id"]} else: - raise AuthnFailure() + raise ClientAuthenticationError() class ClientSecretPost(ClientSecretBasic): @@ -124,21 +155,22 @@ class ClientSecretPost(ClientSecretBasic): tag = "client_secret_post" - def is_usable(self, request=None, authorization_token=None): + @classmethod + def is_usable(cls, endpoint_context, request=None, authorization_token=None): if request is None: return False if "client_id" in request and "client_secret" in request: return True return False - def verify(self, request, **kwargs): - if ( - self.server_get("endpoint_context").cdb[request["client_id"]]["client_secret"] - == request["client_secret"] - ): + @classmethod + def _verify( + cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + ): + if endpoint_context.cdb[request["client_id"]]["client_secret"] == request["client_secret"]: return {"client_id": request["client_id"]} else: - raise AuthnFailure("secrets doesn't match") + raise ClientAuthenticationError("secrets doesn't match") class BearerHeader(ClientSecretBasic): @@ -146,13 +178,30 @@ class BearerHeader(ClientSecretBasic): tag = "bearer_header" - def is_usable(self, request=None, authorization_token=None): + @classmethod + def is_usable(cls, endpoint_context, request=None, authorization_token=None): if authorization_token is not None and authorization_token.startswith("Bearer "): return True return False - def verify(self, authorization_token, **kwargs): - return {"token": authorization_token.split(" ", 1)[1]} + @classmethod + def _verify( + cls, + endpoint_context, + request=None, + authorization_token=None, + endpoint=None, + get_client_id_from_token=None, + **kwargs, + ): + token = authorization_token.split(" ", 1)[1] + try: + client_id = get_client_id_from_token(endpoint_context, token, request) + except ToOld: + raise BearerTokenAuthenticationError("Expired token") + except KeyError: + raise BearerTokenAuthenticationError("Unknown token") + return {"token": token, "client_id": client_id} class BearerBody(ClientSecretPost): @@ -162,15 +211,19 @@ class BearerBody(ClientSecretPost): tag = "bearer_body" - def is_usable(self, request=None, authorization_token=None): + @classmethod + def is_usable(cls, endpoint_context, request=None, authorization_token=None): if request is not None and "access_token" in request: return True return False - def verify(self, request, **kwargs): + @classmethod + def _verify( + cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + ): _token = request.get("access_token") if _token is None: - raise AuthnFailure("No access token") + raise ClientAuthenticationError("No access token") res = {"token": _token} _client_id = request.get("client_id") @@ -180,28 +233,31 @@ def verify(self, request, **kwargs): class JWSAuthnMethod(ClientAuthnMethod): - def is_usable(self, request=None, authorization_token=None): + @classmethod + def is_usable(cls, endpoint_context, request=None, authorization_token=None): if request is None: return False if "client_assertion" in request: return True return False - def verify(self, request, key_type, **kwargs): - _context = self.server_get("endpoint_context") - _jwt = JWT(_context.keyjar, msg_cls=JsonWebToken) + @classmethod + def _verify(cls, endpoint_context, request=None, endpoint=None, key_type=None, **kwargs): + _jwt = JWT(endpoint_context.keyjar, msg_cls=JsonWebToken) try: ca_jwt = _jwt.unpack(request["client_assertion"]) except (Invalid, MissingKey, BadSignature) as err: logger.info("%s" % sanitize(err)) - raise AuthnFailure("Could not verify client_assertion.") + raise ClientAuthenticationError("Could not verify client_assertion.") _sign_alg = ca_jwt.jws_header.get("alg") if _sign_alg and _sign_alg.startswith("HS"): if key_type == "private_key": raise AttributeError("Wrong key type") - keys = _context.keyjar.get("sig", "oct", ca_jwt["iss"], ca_jwt.jws_header.get("kid")) - _secret = _context.cdb[ca_jwt["iss"]].get("client_secret") + keys = endpoint_context.keyjar.get( + "sig", "oct", ca_jwt["iss"], ca_jwt.jws_header.get("kid") + ) + _secret = endpoint_context.cdb[ca_jwt["iss"]].get("client_secret") if _secret and keys[0].key != as_bytes(_secret): raise AttributeError("Oct key used for signing not client_secret") else: @@ -211,26 +267,25 @@ def verify(self, request, key_type, **kwargs): authtoken = sanitize(ca_jwt.to_dict()) logger.debug("authntoken: {}".format(authtoken)) - _endpoint = kwargs.get("endpoint") - if _endpoint is None or not _endpoint: - if _context.issuer in ca_jwt["aud"]: + if endpoint is None or not endpoint: + if endpoint_context.issuer in ca_jwt["aud"]: pass else: - raise NotForMe("Not for me!") + raise InvalidToken("Not for me!") else: - if set(ca_jwt["aud"]).intersection(_endpoint.allowed_target_uris()): + if set(ca_jwt["aud"]).intersection(endpoint.allowed_target_uris()): pass else: - raise NotForMe("Not for me!") + raise InvalidToken("Not for me!") # If there is a jti use it to make sure one-time usage is true _jti = ca_jwt.get("jti") if _jti: _key = "{}:{}".format(ca_jwt["iss"], _jti) - if _key in _context.jti_db: - raise MultipleUsage("Have seen this token once before") + if _key in endpoint_context.jti_db: + raise InvalidToken("Have seen this token once before") else: - _context.jti_db[_key] = utc_time_sans_frac() + endpoint_context.jti_db[_key] = utc_time_sans_frac() request[verified_claim_name("client_assertion")] = ca_jwt client_id = kwargs.get("client_id") or ca_jwt["iss"] @@ -248,10 +303,12 @@ class ClientSecretJWT(JWSAuthnMethod): tag = "client_secret_jwt" - def verify(self, request=None, **kwargs): - res = JWSAuthnMethod.verify(self, request, key_type="client_secret", **kwargs) + @classmethod + def _verify(cls, endpoint_context, request=None, **kwargs): + res = JWSAuthnMethod.verify( + endpoint_context, request=request, key_type="client_secret", **kwargs + ) # Verify that a HS alg was used - res["method"] = self.tag return res @@ -262,37 +319,40 @@ class PrivateKeyJWT(JWSAuthnMethod): tag = "private_key_jwt" - def verify(self, request=None, **kwargs): - res = JWSAuthnMethod.verify(self, request, key_type="private_key", **kwargs) + @classmethod + def _verify(cls, endpoint_context, request=None, **kwargs): + res = JWSAuthnMethod.verify( + endpoint_context, request=request, key_type="private_key", **kwargs + ) # Verify that an RS or ES alg was used ? - res["method"] = self.tag return res class RequestParam(ClientAuthnMethod): tag = "request_param" - def is_usable(self, request=None, authorization_info=None): + @classmethod + def is_usable(cls, endpoint_context, request=None, authorization_token=None): if request and "request" in request: return True - def verify(self, request=None, **kwargs): - _context = self.server_get("endpoint_context") - _jwt = JWT(_context.keyjar, msg_cls=JsonWebToken) + @classmethod + def _verify(cls, endpoint_context, request=None, **kwargs): + _jwt = JWT(endpoint_context.keyjar, msg_cls=JsonWebToken) try: _jwt = _jwt.unpack(request["request"]) except (Invalid, MissingKey, BadSignature) as err: logger.info("%s" % sanitize(err)) - raise AuthnFailure("Could not verify client_assertion.") + raise ClientAuthenticationError("Could not verify client_assertion.") # If there is a jti use it to make sure one-time usage is true _jti = _jwt.get("jti") if _jti: _key = "{}:{}".format(_jwt["iss"], _jti) - if _key in _context.jti_db: - raise MultipleUsage("Have seen this token once before") + if _key in endpoint_context.jti_db: + raise InvalidToken("Have seen this token once before") else: - _context.jti_db[_key] = utc_time_sans_frac() + endpoint_context.jti_db[_key] = utc_time_sans_frac() request[verified_claim_name("client_assertion")] = _jwt client_id = kwargs.get("client_id") or _jwt["iss"] @@ -300,16 +360,17 @@ def verify(self, request=None, **kwargs): return {"client_id": client_id, "jwt": _jwt} -CLIENT_AUTHN_METHOD = { - "client_secret_basic": ClientSecretBasic, - "client_secret_post": ClientSecretPost, - "bearer_header": BearerHeader, - "bearer_body": BearerBody, - "client_secret_jwt": ClientSecretJWT, - "private_key_jwt": PrivateKeyJWT, - "request_param": RequestParam, - "none": None, -} +# We use OrderedDict in order to ensure that the `none` method is used last +CLIENT_AUTHN_METHOD = OrderedDict( + client_secret_basic=ClientSecretBasic, + client_secret_post=ClientSecretPost, + bearer_header=BearerHeader, + bearer_body=BearerBody, + client_secret_jwt=ClientSecretJWT, + private_key_jwt=PrivateKeyJWT, + request_param=RequestParam, + none=NoneAuthn, +) TYPE_METHOD = [(JWT_BEARER, JWSAuthnMethod)] @@ -348,30 +409,28 @@ def verify_client( authorization_token = None auth_info = {} - _methods = getattr(endpoint, "client_authn_method", []) + allowed_methods = getattr(endpoint, "client_authn_method") + if not allowed_methods: + allowed_methods = list(CLIENT_AUTHN_METHOD.keys()) - for _method in _methods: - if _method is None: + for _method in (CLIENT_AUTHN_METHOD[meth] for meth in allowed_methods): + if not _method.is_usable( + endpoint_context, request=request, authorization_token=authorization_token + ): continue - if _method.is_usable(request, authorization_token): - try: - auth_info = _method.verify( - request=request, - authorization_token=authorization_token, - endpoint=endpoint, - ) - except Exception as err: - logger.warning("Verifying auth using {} failed: {}".format(_method.tag, err)) - else: - if "method" not in auth_info: - auth_info["method"] = _method.tag - break - - if not auth_info: - if None in _methods: - auth_info = {"method": "none", "client_id": request.get("client_id")} - else: - return auth_info + try: + auth_info = _method.verify( + endpoint_context, + request=request, + authorization_token=authorization_token, + endpoint=endpoint, + get_client_id_from_token=get_client_id_from_token, + ) + break + except (BearerTokenAuthenticationError, ClientAuthenticationError): + raise + except Exception as err: + logger.info("Verifying auth using {} failed: {}".format(_method.tag, err)) if also_known_as: client_id = also_known_as[auth_info.get("client_id")] @@ -379,58 +438,34 @@ def verify_client( else: client_id = auth_info.get("client_id") - _token = auth_info.get("token") - - if client_id: - if client_id not in endpoint_context.cdb: - raise UnknownClient("Unknown Client ID") - - _cinfo = endpoint_context.cdb[client_id] - if isinstance(_cinfo, str): - if _cinfo not in endpoint_context.cdb: - raise UnknownClient("Unknown Client ID") - - if not valid_client_info(_cinfo): - logger.warning("Client registration has timed out or " "client secret is expired.") - raise InvalidClient("Not valid client") - - # store what authn method was used - if auth_info.get("method"): - _request_type = request.__class__.__name__ - _used_authn_method = endpoint_context.cdb[client_id].get("auth_method") - if _used_authn_method: - endpoint_context.cdb[client_id]["auth_method"][_request_type] = auth_info["method"] - else: - endpoint_context.cdb[client_id]["auth_method"] = { - _request_type: auth_info["method"] - } - elif not client_id and get_client_id_from_token: - if not _token: - logger.warning("No token") - raise BearerTokenAuthenticationError("No token") - - try: - # get_client_id_from_token is a callback... Do not abuse for code readability. - auth_info["client_id"] = get_client_id_from_token(endpoint_context, _token, request) - except ToOld: - raise BearerTokenAuthenticationError("Expired token") - except KeyError: - raise BearerTokenAuthenticationError("Unknown token") - - return auth_info - - -def client_auth_setup(auth_set, server_get): - res = [] - - for item in auth_set: - if item is None or item.lower() == "none": - res.append(None) + if client_id not in endpoint_context.cdb: + raise UnknownClient("Unknown Client ID") + + _cinfo = endpoint_context.cdb[client_id] + + if not valid_client_info(_cinfo): + logger.warning("Client registration has timed out or " "client secret is expired.") + raise InvalidClient("Not valid client") + + # Validate that the used method is allowed for this client/endpoint + client_allowed_methods = _cinfo.get(f"{endpoint.endpoint_name}_client_authn_method") + if client_allowed_methods is not None and _method.tag not in client_allowed_methods: + logger.info( + f"Allowed methods for client: {client_id} at endpoint: {endpoint.name} are: " + f"`{', '.join(client_allowed_methods)}`" + ) + raise UnAuthorizedClient( + f"Authentication method: {_method.tag} not allowed for client: {client_id} in " + f"endpoint: {endpoint.name}" + ) + + # store what authn method was used + if auth_info.get("method"): + _request_type = request.__class__.__name__ + _used_authn_method = _cinfo.get("auth_method") + if _used_authn_method: + endpoint_context.cdb[client_id]["auth_method"][_request_type] = auth_info["method"] else: - _cls = CLIENT_AUTHN_METHOD.get(item) - if _cls: - res.append(_cls(server_get)) - else: - res.append(importer(item)(server_get)) + endpoint_context.cdb[client_id]["auth_method"] = {_request_type: auth_info["method"]} - return res + return auth_info diff --git a/src/oidcop/endpoint.py b/src/oidcop/endpoint.py index 369940b1..b128e7db 100755 --- a/src/oidcop/endpoint.py +++ b/src/oidcop/endpoint.py @@ -12,7 +12,6 @@ from oidcmsg.oidc import RegistrationRequest from oidcop import sanitize -from oidcop.client_authn import client_auth_setup from oidcop.client_authn import verify_client from oidcop.construct import construct_endpoint_info from oidcop.endpoint_context import EndpointContext @@ -115,13 +114,11 @@ def __init__(self, server_get: Callable, **kwargs): _methods = kwargs.get("client_authn_method") self.client_authn_method = [] if _methods: - self.client_authn_method = client_auth_setup(_methods, server_get) + self.client_authn_method = _methods elif _methods is not None: # [] or '' or something not None but regarded as nothing. - self.client_authn_method = [None] # Ignore default value + self.client_authn_method = ["none"] # Ignore default value elif self.default_capabilities: - _methods = self.default_capabilities.get("client_authn_method") - if _methods: - self.client_authn_method = client_auth_setup(_methods, server_get) + self.client_authn_method = self.default_capabilities.get("client_authn_method") self.endpoint_info = construct_endpoint_info(self.default_capabilities, **kwargs) # This is for matching against aud in JWTs diff --git a/src/oidcop/exception.py b/src/oidcop/exception.py index eb500cfb..dcb5d69c 100755 --- a/src/oidcop/exception.py +++ b/src/oidcop/exception.py @@ -70,6 +70,10 @@ class InvalidClient(ClientAuthenticationError): pass +class InvalidToken(ClientAuthenticationError): + pass + + class UnAuthorizedClient(ClientAuthenticationError): pass diff --git a/src/oidcop/oauth2/add_on/dpop.py b/src/oidcop/oauth2/add_on/dpop.py index 472c5ea6..afcab59b 100644 --- a/src/oidcop/oauth2/add_on/dpop.py +++ b/src/oidcop/oauth2/add_on/dpop.py @@ -9,9 +9,9 @@ from oidcmsg.message import SINGLE_REQUIRED_STRING from oidcmsg.message import Message -from oidcop.client_authn import AuthnFailure from oidcop.client_authn import ClientAuthnMethod from oidcop.client_authn import basic_authn +from oidcop.exception import ClientAuthenticationError class DPoPProof(Message): @@ -167,4 +167,4 @@ def verify(self, authorization_info, **kwargs): if _context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: return {"client_id": client_info["id"]} else: - raise AuthnFailure() + raise ClientAuthenticationError() diff --git a/src/oidcop/oauth2/authorization.py b/src/oidcop/oauth2/authorization.py index 86b7ce82..f01c569d 100755 --- a/src/oidcop/oauth2/authorization.py +++ b/src/oidcop/oauth2/authorization.py @@ -286,6 +286,7 @@ class Authorization(Endpoint): name = "authorization" default_capabilities = { "claims_parameter_supported": True, + "client_authn_method": ["request_param", "none"], "request_parameter_supported": True, "request_uri_parameter_supported": True, "response_types_supported": ["code", "token", "code token"], diff --git a/src/oidcop/oauth2/introspection.py b/src/oidcop/oauth2/introspection.py index c298c12d..7b39d116 100644 --- a/src/oidcop/oauth2/introspection.py +++ b/src/oidcop/oauth2/introspection.py @@ -20,6 +20,14 @@ class Introspection(Endpoint): response_format = "json" endpoint_name = "introspection_endpoint" name = "introspection" + default_capabilities = { + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ] + } def __init__(self, server_get, **kwargs): Endpoint.__init__(self, server_get, **kwargs) diff --git a/src/oidcop/oidc/authorization.py b/src/oidcop/oidc/authorization.py index fb4d73e8..800adb90 100755 --- a/src/oidcop/oidc/authorization.py +++ b/src/oidcop/oidc/authorization.py @@ -77,6 +77,7 @@ class Authorization(authorization.Authorization): name = "authorization" default_capabilities = { "claims_parameter_supported": True, + "client_authn_method": ["request_param", "none"], "request_parameter_supported": True, "request_uri_parameter_supported": True, "response_types_supported": [ diff --git a/src/oidcop/oidc/session.py b/src/oidcop/oidc/session.py index 79a7f9dc..aa011e57 100644 --- a/src/oidcop/oidc/session.py +++ b/src/oidcop/oidc/session.py @@ -25,7 +25,6 @@ from oidcmsg.oidc.session import EndSessionRequest from oidcop import rndstr -from oidcop.client_authn import UnknownOrNoAuthnMethod from oidcop.endpoint import Endpoint from oidcop.endpoint_context import add_path from oidcop.oauth2.authorization import verify_uri @@ -361,18 +360,14 @@ def parse_request(self, request, http_info=None, **kwargs): request = {} # Verify that the client is allowed to do this - try: - auth_info = self.client_authentication(request, http_info, **kwargs) - except UnknownOrNoAuthnMethod: + auth_info = self.client_authentication(request, http_info, **kwargs) + if not auth_info: pass + elif isinstance(auth_info, ResponseMessage): + return auth_info else: - if not auth_info: - pass - elif isinstance(auth_info, ResponseMessage): - return auth_info - else: - request["client_id"] = auth_info["client_id"] - request["access_token"] = auth_info["token"] + request["client_id"] = auth_info["client_id"] + request["access_token"] = auth_info["token"] if isinstance(request, dict): _context = self.server_get("endpoint_context") diff --git a/src/oidcop/oidc/userinfo.py b/src/oidcop/oidc/userinfo.py index 6e6ce978..c71d7a7c 100755 --- a/src/oidcop/oidc/userinfo.py +++ b/src/oidcop/oidc/userinfo.py @@ -13,7 +13,7 @@ from oidcmsg.oauth2 import ResponseMessage from oidcop.endpoint import Endpoint -from oidcop.token.exception import UnknownToken +from oidcop.exception import ClientAuthenticationError from oidcop.util import OAUTH2_NOCACHE_HEADERS logger = logging.getLogger(__name__) @@ -172,7 +172,7 @@ def parse_request(self, request, http_info=None, **kwargs): # Verify that the client is allowed to do this try: auth_info = self.client_authentication(request, http_info, **kwargs) - except (ValueError, UnknownToken) as e: + except ClientAuthenticationError as e: return self.error_cls(error="invalid_token", error_description=e.args[0]) if isinstance(auth_info, ResponseMessage): diff --git a/src/oidcop/server.py b/src/oidcop/server.py index 1972aff6..de2a2876 100644 --- a/src/oidcop/server.py +++ b/src/oidcop/server.py @@ -6,7 +6,6 @@ from oidcmsg.impexp import ImpExp from oidcop import authz -from oidcop.client_authn import client_auth_setup from oidcop.configure import ASConfiguration from oidcop.configure import OPConfiguration from oidcop.endpoint import Endpoint @@ -94,23 +93,8 @@ def __init__( # Must be done after userinfo self.do_login_hint_lookup() - for endpoint_name, endpoint_conf in self.endpoint.items(): - _endpoint = self.endpoint[endpoint_name] - _methods = _endpoint.kwargs.get("client_authn_method") - - self.client_authn_method = [] - if _methods: - _endpoint.client_authn_method = client_auth_setup(_methods, self.server_get) - elif _methods is not None: # [] or '' or something not None but regarded as nothing. - _endpoint.client_authn_method = [None] # Ignore default value - elif _endpoint.default_capabilities: - _methods = _endpoint.default_capabilities.get("client_authn_method") - if _methods: - _endpoint.client_authn_method = client_auth_setup( - auth_set=_methods, server_get=self.server_get - ) - - _endpoint.server_get = self.server_get + for endpoint_name, _ in self.endpoint.items(): + self.endpoint[endpoint_name].server_get = self.server_get _token_endp = self.endpoint.get("token") if _token_endp: diff --git a/tests/test_02_client_authn.py b/tests/test_02_client_authn.py index 243454fe..7852b659 100755 --- a/tests/test_02_client_authn.py +++ b/tests/test_02_client_authn.py @@ -9,7 +9,6 @@ from cryptojwt.utils import as_unicode from oidcop import JWT_BEARER -from oidcop.client_authn import AuthnFailure from oidcop.client_authn import BearerBody from oidcop.client_authn import BearerHeader from oidcop.client_authn import ClientSecretBasic @@ -19,8 +18,8 @@ from oidcop.client_authn import PrivateKeyJWT from oidcop.client_authn import basic_authn from oidcop.client_authn import verify_client -from oidcop.exception import MultipleUsage -from oidcop.exception import NotForMe +from oidcop.exception import ClientAuthenticationError +from oidcop.exception import InvalidToken from oidcop.oidc.authorization import Authorization from oidcop.oidc.registration import Registration from oidcop.oidc.token import Token @@ -85,10 +84,11 @@ def get_client_id_from_token(endpoint_context, token, request=None): class TestClientSecretBasic: @pytest.fixture(autouse=True) - def create_method(self): + def setup(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.method = ClientSecretBasic(server.server_get) + self.endpoint_context = server.endpoint_context + self.method = ClientSecretBasic def test_client_secret_basic(self): _token = "{}:{}".format(client_id, client_secret) @@ -96,13 +96,13 @@ def test_client_secret_basic(self): authz_token = "Basic {}".format(token) - assert self.method.is_usable(authorization_token=authz_token) - authn_info = self.method.verify(authorization_token=authz_token) + assert self.method.is_usable(self.endpoint_context, authorization_token=authz_token) + authn_info = self.method.verify(self.endpoint_context, authorization_token=authz_token) assert authn_info["client_id"] == client_id def test_wrong_type(self): - assert self.method.is_usable(authorization_token="Foppa toffel") is False + assert self.method.is_usable(self.endpoint_context, authorization_token="Foppa toffel") is False def test_csb_wrong_secret(self): _token = "{}:{}".format(client_id, "pillow") @@ -110,10 +110,10 @@ def test_csb_wrong_secret(self): authz_token = "Basic {}".format(token) - assert self.method.is_usable(authorization_token=authz_token) + assert self.method.is_usable(self.endpoint_context, authorization_token=authz_token) - with pytest.raises(AuthnFailure): - self.method.verify(authorization_token=authz_token) + with pytest.raises(ClientAuthenticationError): + self.method.verify(self.endpoint_context, authorization_token=authz_token) class TestClientSecretPost: @@ -121,21 +121,22 @@ class TestClientSecretPost: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.method = ClientSecretPost(server.server_get) + self.endpoint_context = server.endpoint_context + self.method = ClientSecretPost def test_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} - assert self.method.is_usable(request=request) - authn_info = self.method.verify(request) + assert self.method.is_usable(self.endpoint_context, request=request) + authn_info = self.method.verify(self.endpoint_context, request) assert authn_info["client_id"] == client_id def test_client_secret_post_wrong_secret(self): request = {"client_id": client_id, "client_secret": "pillow"} - assert self.method.is_usable(request=request) - with pytest.raises(AuthnFailure): - self.method.verify(request) + assert self.method.is_usable(self.endpoint_context, request=request) + with pytest.raises(ClientAuthenticationError): + self.method.verify(self.endpoint_context, request) class TestClientSecretJWT: @@ -143,7 +144,8 @@ class TestClientSecretJWT: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.method = ClientSecretJWT(server.server_get) + self.endpoint_context = server.endpoint_context + self.method = ClientSecretJWT def test_client_secret_jwt(self): client_keyjar = KeyJar() @@ -157,8 +159,8 @@ def test_client_secret_jwt(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - assert self.method.is_usable(request=request) - authn_info = self.method.verify(request) + assert self.method.is_usable(self.endpoint_context, request=request) + authn_info = self.method.verify(self.endpoint_context, request=request) assert authn_info["client_id"] == client_id assert "jwt" in authn_info @@ -169,7 +171,9 @@ class TestPrivateKeyJWT: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.method = PrivateKeyJWT(server.server_get) + self.server = server + self.endpoint_context = server.endpoint_context + self.method = PrivateKeyJWT def test_private_key_jwt(self): # Own dynamic keys @@ -178,7 +182,7 @@ def test_private_key_jwt(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.method.server_get("endpoint_context").keyjar.import_jwks(_jwks, client_id) + self.endpoint_context.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True @@ -186,8 +190,8 @@ def test_private_key_jwt(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - assert self.method.is_usable(request=request) - authn_info = self.method.verify(request=request) + assert self.method.is_usable(self.endpoint_context, request=request) + authn_info = self.method.verify(self.endpoint_context, request=request) assert authn_info["client_id"] == client_id assert "jwt" in authn_info @@ -199,27 +203,27 @@ def test_private_key_jwt_reusage_other_endpoint(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.method.server_get("endpoint_context").keyjar.import_jwks(_jwks, client_id) + self.endpoint_context.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True - _assertion = _jwt.pack({"aud": [self.method.server_get("endpoint", "token").full_path]}) + _assertion = _jwt.pack({"aud": [self.server.server_get("endpoint", "token").full_path]}) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} # This should be OK - assert self.method.is_usable(request=request) - self.method.verify(request=request, endpoint=self.method.server_get("endpoint", "token")) + assert self.method.is_usable(self.endpoint_context, request=request) + self.method.verify(self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "token")) # This should NOT be OK - with pytest.raises(NotForMe): + with pytest.raises(InvalidToken): self.method.verify( - request, endpoint=self.method.server_get("endpoint", "authorization") + self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "authorization") ) # This should NOT be OK because this is the second time the token appears - with pytest.raises(MultipleUsage): - self.method.verify(request, endpoint=self.method.server_get("endpoint", "token")) + with pytest.raises(InvalidToken): + self.method.verify(self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "token")) def test_private_key_jwt_auth_endpoint(self): # Own dynamic keys @@ -228,19 +232,19 @@ def test_private_key_jwt_auth_endpoint(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.method.server_get("endpoint_context").keyjar.import_jwks(_jwks, client_id) + self.endpoint_context.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True _assertion = _jwt.pack( - {"aud": [self.method.server_get("endpoint", "authorization").full_path]} + {"aud": [self.server.server_get("endpoint", "authorization").full_path]} ) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - assert self.method.is_usable(request=request) + assert self.method.is_usable(self.endpoint_context, request=request) authn_info = self.method.verify( - request=request, endpoint=self.method.server_get("endpoint", "authorization"), + self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "authorization"), ) assert authn_info["client_id"] == client_id @@ -252,15 +256,18 @@ class TestBearerHeader: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.method = BearerHeader(server.server_get) + self.server = server + self.endpoint_context = server.endpoint_context + self.method = BearerHeader def test_bearerheader(self): authorization_info = "Bearer 1234567890" - assert self.method.verify(authorization_token=authorization_info) == {"token": "1234567890"} + get_client_id_from_token = lambda *_: "client_id" + assert self.method.verify(self.endpoint_context, authorization_token=authorization_info, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_header", "client_id": "client_id"} def test_bearerheader_wrong_type(self): authorization_info = "Thrower 1234567890" - assert self.method.is_usable(authorization_token=authorization_info) is False + assert self.method.is_usable(self.endpoint_context, authorization_token=authorization_info) is False class TestBearerBody: @@ -268,16 +275,18 @@ class TestBearerBody: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.method = BearerBody(server.server_get) + self.server = server + self.endpoint_context = server.endpoint_context + self.method = BearerBody def test_bearer_body(self): request = {"access_token": "1234567890"} - assert self.method.verify(request) == {"token": "1234567890"} + assert self.method.verify(self.endpoint_context, request) == {"token": "1234567890", "method": "bearer_body"} def test_bearer_body_no_token(self): request = {} - with pytest.raises(AuthnFailure): - self.method.verify(request=request) + with pytest.raises(ClientAuthenticationError): + self.method.verify(self.endpoint_context, request=request) class TestJWSAuthnMethod: @@ -285,7 +294,9 @@ class TestJWSAuthnMethod: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.method = JWSAuthnMethod(server.server_get) + self.server = server + self.endpoint_context = server.endpoint_context + self.method = JWSAuthnMethod def test_jws_authn_method_wrong_key(self): client_keyjar = KeyJar() @@ -299,7 +310,7 @@ def test_jws_authn_method_wrong_key(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} with pytest.raises(NoSuitableSigningKeys): - self.method.verify(request=request, key_type="private_key") + self.method.verify(self.endpoint_context, request=request, key_type="private_key") def test_jws_authn_method_aud_iss(self): client_keyjar = KeyJar() @@ -314,7 +325,7 @@ def test_jws_authn_method_aud_iss(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - assert self.method.verify(request=request, key_type="client_secret") + assert self.method.verify(self.endpoint_context, request=request, key_type="client_secret") def test_jws_authn_method_aud_token_endpoint(self): client_keyjar = KeyJar() @@ -331,8 +342,9 @@ def test_jws_authn_method_aud_token_endpoint(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} assert self.method.verify( + self.endpoint_context, request=request, - endpoint=self.method.server_get("endpoint", "token"), + endpoint=self.server.server_get("endpoint", "token"), key_type="client_secret", ) @@ -351,8 +363,8 @@ def test_jws_authn_method_aud_not_me(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - with pytest.raises(NotForMe): - self.method.verify(request=request, key_type="client_secret") + with pytest.raises(InvalidToken): + self.method.verify(self.endpoint_context, request=request, key_type="client_secret") def test_jws_authn_method_aud_userinfo_endpoint(self): client_keyjar = KeyJar() @@ -368,8 +380,9 @@ def test_jws_authn_method_aud_userinfo_endpoint(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} assert self.method.verify( + self.endpoint_context, request=request, - endpoint=self.method.server_get("endpoint", "userinfo"), + endpoint=self.server.server_get("endpoint", "userinfo"), key_type="client_secret", ) @@ -386,7 +399,7 @@ def test_basic_auth_wrong_label(): _token = "{}:{}".format(client_id, client_secret) token = as_unicode(base64.b64encode(as_bytes(_token))) - with pytest.raises(AuthnFailure): + with pytest.raises(ClientAuthenticationError): basic_authn("Expanded {}".format(token)) @@ -579,10 +592,10 @@ def test_verify_client_authorization_none(self): def test_verify_client_registration_none(self): # This is when no special auth method is configured - request = {"redirect_uris": ["https://example.com/cb"]} + request = {"redirect_uris": ["https://example.com/cb"], "client_id": "client_id"} res = verify_client( self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "registration"), ) - assert res == {} + assert res == {"client_id": "client_id", "method": "none"} diff --git a/tests/test_20_endpoint.py b/tests/test_20_endpoint.py index 241d44c4..f5a4d99b 100755 --- a/tests/test_20_endpoint.py +++ b/tests/test_20_endpoint.py @@ -17,7 +17,7 @@ {"type": "EC", "crv": "P-256", "use": ["sig"]}, ] -REQ = Message(foo="bar", hej="hopp") +REQ = Message(foo="bar", hej="hopp", client_id="client_id") EXAMPLE_MSG = { "name": "Jane Doe", @@ -65,6 +65,7 @@ def create_endpoint(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + server.endpoint_context.cdb["client_id"] = {} self.endpoint_context = server.endpoint_context self.endpoint = server.server_get("endpoint", "") diff --git a/tests/test_23_oidc_registration_endpoint.py b/tests/test_23_oidc_registration_endpoint.py index 27e32c8f..f2013d74 100755 --- a/tests/test_23_oidc_registration_endpoint.py +++ b/tests/test_23_oidc_registration_endpoint.py @@ -45,6 +45,7 @@ ] MSG = { + "client_id": "client_id", "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -152,6 +153,7 @@ def create_endpoint(self): "template_dir": "template", } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + server.endpoint_context.cdb["client_id"] = {} self.endpoint = server.server_get("endpoint", "registration") def test_parse(self): diff --git a/tests/test_24_oauth2_token_endpoint.py b/tests/test_24_oauth2_token_endpoint.py index bdcd83e1..07599f82 100644 --- a/tests/test_24_oauth2_token_endpoint.py +++ b/tests/test_24_oauth2_token_endpoint.py @@ -15,7 +15,7 @@ from oidcop.authz import AuthzHandling from oidcop.client_authn import verify_client from oidcop.configure import ASConfiguration -from oidcop.exception import UnAuthorizedClient +from oidcop.exception import InvalidToken from oidcop.oauth2.authorization import Authorization from oidcop.oauth2.token import Token from oidcop.server import Server @@ -320,7 +320,7 @@ def test_process_request_using_private_key_jwt(self): _resp = self.token_endpoint.process_request(request=_req) # 2nd time used - with pytest.raises(UnAuthorizedClient): + with pytest.raises(InvalidToken): self.token_endpoint.parse_request(_token_request) def test_do_refresh_access_token(self): diff --git a/tests/test_26_oidc_userinfo_endpoint.py b/tests/test_26_oidc_userinfo_endpoint.py index 3ed37cc4..a7a42a03 100755 --- a/tests/test_26_oidc_userinfo_endpoint.py +++ b/tests/test_26_oidc_userinfo_endpoint.py @@ -565,7 +565,8 @@ def test_userinfo_claims_post(self): code = self._mint_code(grant, session_id) access_token = self._mint_token("access_token", grant, session_id, code) - _req = self.endpoint.parse_request({"access_token": access_token.value}) + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} + _req = self.endpoint.parse_request({}, http_info=http_info) args = self.endpoint.process_request(_req) assert args res = self.endpoint.do_response(request=_req, **args) diff --git a/tests/test_30_oidc_end_session.py b/tests/test_30_oidc_end_session.py index 8dc0e433..54de8b70 100644 --- a/tests/test_30_oidc_end_session.py +++ b/tests/test_30_oidc_end_session.py @@ -288,8 +288,7 @@ def test_end_session_endpoint_with_cookie(self): _session_info = self.session_manager.get_session_info_by_token(_code) cookie = self._create_cookie(_session_info["session_id"]) http_info = {"cookie": [cookie]} - _req_args = self.session_endpoint.parse_request({"state": "1234567"}, http_info=http_info) - resp = self.session_endpoint.process_request(_req_args, http_info=http_info) + resp = self.session_endpoint.process_request({"state": "foo"}, http_info=http_info) # returns a signed JWT to be put in a verification web page shown to # the user diff --git a/tests/test_31_oauth2_introspection.py b/tests/test_31_oauth2_introspection.py index fcfb23cf..a13d7b67 100644 --- a/tests/test_31_oauth2_introspection.py +++ b/tests/test_31_oauth2_introspection.py @@ -18,7 +18,7 @@ from oidcop.authn_event import create_authn_event from oidcop.authz import AuthzHandling from oidcop.client_authn import verify_client -from oidcop.exception import UnAuthorizedClient +from oidcop.exception import UnknownClient from oidcop.oauth2.authorization import Authorization from oidcop.oauth2.introspection import Introspection from oidcop.oidc.token import Token @@ -240,7 +240,7 @@ def _get_access_token(self, areq): def test_parse_no_authn(self): access_token = self._get_access_token(AUTH_REQ) - with pytest.raises(UnAuthorizedClient): + with pytest.raises(UnknownClient): self.introspection_endpoint.parse_request({"token": access_token.value}) def test_parse_with_client_auth_in_req(self): @@ -271,7 +271,7 @@ def test_parse_with_wrong_client_authn(self): _basic_authz = "Basic {}".format(_basic_token) http_info = {"headers": {"authorization": _basic_authz}} - with pytest.raises(UnAuthorizedClient): + with pytest.raises(UnknownClient): self.introspection_endpoint.parse_request( {"token": access_token.value}, http_info=http_info ) diff --git a/tests/test_32_oidc_read_registration.py b/tests/test_32_oidc_read_registration.py index c021c14c..c199187c 100644 --- a/tests/test_32_oidc_read_registration.py +++ b/tests/test_32_oidc_read_registration.py @@ -48,6 +48,7 @@ } msg = { + "client_id": "client_1", "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -117,6 +118,7 @@ def create_endpoint(self): server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.registration_endpoint = server.server_get("endpoint", "registration") self.registration_api_endpoint = server.server_get("endpoint", "registration_read") + server.endpoint_context.cdb["client_1"] = {} def test_do_response(self): _req = self.registration_endpoint.parse_request(CLI_REQ.to_json()) diff --git a/tests/test_35_oidc_token_endpoint.py b/tests/test_35_oidc_token_endpoint.py index 73f19e9a..33801d4f 100755 --- a/tests/test_35_oidc_token_endpoint.py +++ b/tests/test_35_oidc_token_endpoint.py @@ -18,7 +18,7 @@ from oidcop.client_authn import verify_client from oidcop.configure import OPConfiguration from oidcop.cookie_handler import CookieHandler -from oidcop.exception import UnAuthorizedClient +from oidcop.exception import InvalidToken from oidcop.oidc import userinfo from oidcop.oidc.authorization import Authorization from oidcop.oidc.provider_config import ProviderConfiguration @@ -334,7 +334,7 @@ def test_process_request_using_private_key_jwt(self): _resp = self.token_endpoint.process_request(request=_req) # 2nd time used - with pytest.raises(UnAuthorizedClient): + with pytest.raises(InvalidToken): self.token_endpoint.parse_request(_token_request) def test_do_refresh_access_token(self): From 44eb00f6f54ff7c7c653c2d2e5a5989da73eb65f Mon Sep 17 00:00:00 2001 From: Nikos Sklikas Date: Tue, 15 Feb 2022 13:30:42 +0200 Subject: [PATCH 2/5] Document per client authn method --- docs/source/contents/conf.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/contents/conf.rst b/docs/source/contents/conf.rst index 30ed142f..5b338576 100644 --- a/docs/source/contents/conf.rst +++ b/docs/source/contents/conf.rst @@ -831,6 +831,24 @@ The usage rules for each token type. E.g.:: } } + +------------------- +client_authn_method +------------------- + +A list with the client authentication methods that are allowed for this client. + +This can be overriden per endpoint by adding the prefix `{endpoint_name}_`. +E.g to define `client_authn_method` for a client only for the introspection +endpoint we need to add to the client metadata:: + + { + "introspection_endpoint_client_authn_method": ["client_secret_basic", "client_secret_post"] + } + +NOTE: The client authentication methods defined per client MUST be a subset of the +endpoint's authentication methods, else they are ignored. + -------------- pkce_essential -------------- From bc4bc512dc463a00b3b84dcd7aa82bd7596dac0a Mon Sep 17 00:00:00 2001 From: Nikos Sklikas Date: Tue, 15 Feb 2022 13:30:49 +0200 Subject: [PATCH 3/5] Add tests --- src/oidcop/client_authn.py | 10 ++++--- tests/test_02_client_authn.py | 40 +++++++++++++++++++++++++++ tests/test_31_oauth2_introspection.py | 6 ++-- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/oidcop/client_authn.py b/src/oidcop/client_authn.py index 2ffdba8f..bd89b096 100755 --- a/src/oidcop/client_authn.py +++ b/src/oidcop/client_authn.py @@ -432,11 +432,13 @@ def verify_client( except Exception as err: logger.info("Verifying auth using {} failed: {}".format(_method.tag, err)) + client_id = auth_info.get("client_id") + if client_id is None: + raise ClientAuthenticationError("Failed to verify client") + if also_known_as: - client_id = also_known_as[auth_info.get("client_id")] + client_id = also_known_as[client_id] auth_info["client_id"] = client_id - else: - client_id = auth_info.get("client_id") if client_id not in endpoint_context.cdb: raise UnknownClient("Unknown Client ID") @@ -448,7 +450,7 @@ def verify_client( raise InvalidClient("Not valid client") # Validate that the used method is allowed for this client/endpoint - client_allowed_methods = _cinfo.get(f"{endpoint.endpoint_name}_client_authn_method") + client_allowed_methods = _cinfo.get(f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method")) if client_allowed_methods is not None and _method.tag not in client_allowed_methods: logger.info( f"Allowed methods for client: {client_id} at endpoint: {endpoint.name} are: " diff --git a/tests/test_02_client_authn.py b/tests/test_02_client_authn.py index 7852b659..4daad060 100755 --- a/tests/test_02_client_authn.py +++ b/tests/test_02_client_authn.py @@ -426,6 +426,46 @@ def create_method(self): self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = self.server.server_get("endpoint_context") + def test_verify_per_client_per_endpoint(self): + self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["none"] + + request = {"client_id": client_id} + res = verify_client( + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "token"), + ) + assert res == {"method": "none", "client_id": client_id} + + def test_verify_per_client_per_endpoint(self): + self.server.endpoint_context.cdb[client_id]["registration_endpoint_client_authn_method"] = ["none"] + self.server.endpoint_context.cdb[client_id]["token_endpoint_client_authn_method"] = ["client_secret_post"] + + request = {"client_id": client_id} + res = verify_client( + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "registration"), + ) + assert res == {"method": "none", "client_id": client_id} + + with pytest.raises(ClientAuthenticationError) as e: + verify_client( + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "token"), + ) + assert e.value.args[0] == "Failed to verify client" + + request = {"client_id": client_id, "client_secret": client_secret} + res = verify_client( + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "token"), + ) + assert set(res.keys()) == {"method", "client_id"} + assert res["method"] == "client_secret_post" + + def test_verify_client_client_secret_post(self): + request = {"client_id": client_id, "client_secret": client_secret} + res = verify_client( + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "token"), + ) + assert set(res.keys()) == {"method", "client_id"} + assert res["method"] == "client_secret_post" + def test_verify_client_jws_authn_method(self): client_keyjar = KeyJar() client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) diff --git a/tests/test_31_oauth2_introspection.py b/tests/test_31_oauth2_introspection.py index a13d7b67..6c89ff37 100644 --- a/tests/test_31_oauth2_introspection.py +++ b/tests/test_31_oauth2_introspection.py @@ -18,7 +18,7 @@ from oidcop.authn_event import create_authn_event from oidcop.authz import AuthzHandling from oidcop.client_authn import verify_client -from oidcop.exception import UnknownClient +from oidcop.exception import ClientAuthenticationError from oidcop.oauth2.authorization import Authorization from oidcop.oauth2.introspection import Introspection from oidcop.oidc.token import Token @@ -240,7 +240,7 @@ def _get_access_token(self, areq): def test_parse_no_authn(self): access_token = self._get_access_token(AUTH_REQ) - with pytest.raises(UnknownClient): + with pytest.raises(ClientAuthenticationError): self.introspection_endpoint.parse_request({"token": access_token.value}) def test_parse_with_client_auth_in_req(self): @@ -271,7 +271,7 @@ def test_parse_with_wrong_client_authn(self): _basic_authz = "Basic {}".format(_basic_token) http_info = {"headers": {"authorization": _basic_authz}} - with pytest.raises(UnknownClient): + with pytest.raises(ClientAuthenticationError): self.introspection_endpoint.parse_request( {"token": access_token.value}, http_info=http_info ) From a3e46a99c343c601eb986e34cd1266ad6bc37eb4 Mon Sep 17 00:00:00 2001 From: Nikos Sklikas Date: Mon, 21 Feb 2022 15:52:03 +0200 Subject: [PATCH 4/5] Remove class method decorator --- src/oidcop/client_authn.py | 122 +++++++++++++++++------------ src/oidcop/endpoint_context.py | 2 + src/oidcop/oauth2/authorization.py | 2 +- src/oidcop/oidc/authorization.py | 2 +- src/oidcop/server.py | 5 ++ tests/test_02_client_authn.py | 117 +++++++++++++++++---------- 6 files changed, 154 insertions(+), 96 deletions(-) diff --git a/src/oidcop/client_authn.py b/src/oidcop/client_authn.py index bd89b096..3b0b19ec 100755 --- a/src/oidcop/client_authn.py +++ b/src/oidcop/client_authn.py @@ -27,6 +27,7 @@ from oidcop.exception import ToOld from oidcop.exception import UnAuthorizedClient from oidcop.exception import UnknownClient +from oidcop.util import importer logger = logging.getLogger(__name__) @@ -36,9 +37,14 @@ class ClientAuthnMethod(object): tag = None - @classmethod + def __init__(self, server_get): + """ + :param server_get: A method that can be used to get general server information. + """ + self.server_get = server_get + def _verify( - cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + self, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs ): """ Verify authentication information in a request @@ -47,10 +53,8 @@ def _verify( """ raise NotImplementedError() - @classmethod def verify( - cls, - endpoint_context, + self, request=None, authorization_token=None, endpoint=None, @@ -62,19 +66,18 @@ def verify( :param kwargs: :return: """ - res = cls._verify( - endpoint_context, + res = self._verify( + self.server_get("endpoint_context"), request=request, authorization_token=authorization_token, endpoint=endpoint, get_client_id_from_token=get_client_id_from_token, **kwargs, ) - res["method"] = cls.tag + res["method"] = self.tag return res - @classmethod - def is_usable(cls, endpoint_context, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None): """ Verify that this authentication method is applicable. @@ -100,24 +103,39 @@ def basic_authn(authorization_token): class NoneAuthn(ClientAuthnMethod): + """ + Used for testing purposes + """ + + tag = "none" + + def is_usable(self, request=None, authorization_token=None): + return request is not None + + def _verify( + self, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + ): + return {"client_id": request.get("client_id")} + + +class PublicAuthn(ClientAuthnMethod): """ Used for public clients, that don't require any form of authentication other than their client_id """ - tag = "none" + tag = "public" - @classmethod - def is_usable(cls, endpoint_context, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None): return request and "client_id" in request - @classmethod def _verify( - cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + self, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs ): return {"client_id": request["client_id"]} + class ClientSecretBasic(ClientAuthnMethod): """ Clients that have received a client_secret value from the Authorization @@ -127,15 +145,13 @@ class ClientSecretBasic(ClientAuthnMethod): tag = "client_secret_basic" - @classmethod - def is_usable(cls, endpoint_context, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None): if authorization_token is not None and authorization_token.startswith("Basic "): return True return False - @classmethod def _verify( - cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + self, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs ): client_info = basic_authn(authorization_token) @@ -155,17 +171,15 @@ class ClientSecretPost(ClientSecretBasic): tag = "client_secret_post" - @classmethod - def is_usable(cls, endpoint_context, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None): if request is None: return False if "client_id" in request and "client_secret" in request: return True return False - @classmethod def _verify( - cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + self, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs ): if endpoint_context.cdb[request["client_id"]]["client_secret"] == request["client_secret"]: return {"client_id": request["client_id"]} @@ -178,15 +192,13 @@ class BearerHeader(ClientSecretBasic): tag = "bearer_header" - @classmethod - def is_usable(cls, endpoint_context, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None): if authorization_token is not None and authorization_token.startswith("Bearer "): return True return False - @classmethod def _verify( - cls, + self, endpoint_context, request=None, authorization_token=None, @@ -211,15 +223,13 @@ class BearerBody(ClientSecretPost): tag = "bearer_body" - @classmethod - def is_usable(cls, endpoint_context, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None): if request is not None and "access_token" in request: return True return False - @classmethod def _verify( - cls, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs + self, endpoint_context, request=None, authorization_token=None, endpoint=None, **kwargs ): _token = request.get("access_token") if _token is None: @@ -233,16 +243,14 @@ def _verify( class JWSAuthnMethod(ClientAuthnMethod): - @classmethod - def is_usable(cls, endpoint_context, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None): if request is None: return False if "client_assertion" in request: return True return False - @classmethod - def _verify(cls, endpoint_context, request=None, endpoint=None, key_type=None, **kwargs): + def _verify(self, endpoint_context, request=None, endpoint=None, key_type=None, **kwargs): _jwt = JWT(endpoint_context.keyjar, msg_cls=JsonWebToken) try: ca_jwt = _jwt.unpack(request["client_assertion"]) @@ -303,9 +311,8 @@ class ClientSecretJWT(JWSAuthnMethod): tag = "client_secret_jwt" - @classmethod - def _verify(cls, endpoint_context, request=None, **kwargs): - res = JWSAuthnMethod.verify( + def _verify(self, endpoint_context, request=None, **kwargs): + res = super()._verify( endpoint_context, request=request, key_type="client_secret", **kwargs ) # Verify that a HS alg was used @@ -319,9 +326,8 @@ class PrivateKeyJWT(JWSAuthnMethod): tag = "private_key_jwt" - @classmethod - def _verify(cls, endpoint_context, request=None, **kwargs): - res = JWSAuthnMethod.verify( + def _verify(self, endpoint_context, request=None, **kwargs): + res = super()._verify( endpoint_context, request=request, key_type="private_key", **kwargs ) # Verify that an RS or ES alg was used ? @@ -331,13 +337,11 @@ def _verify(cls, endpoint_context, request=None, **kwargs): class RequestParam(ClientAuthnMethod): tag = "request_param" - @classmethod - def is_usable(cls, endpoint_context, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None): if request and "request" in request: return True - @classmethod - def _verify(cls, endpoint_context, request=None, **kwargs): + def _verify(self, endpoint_context, request=None, **kwargs): _jwt = JWT(endpoint_context.keyjar, msg_cls=JsonWebToken) try: _jwt = _jwt.unpack(request["request"]) @@ -360,8 +364,7 @@ def _verify(cls, endpoint_context, request=None, **kwargs): return {"client_id": client_id, "jwt": _jwt} -# We use OrderedDict in order to ensure that the `none` method is used last -CLIENT_AUTHN_METHOD = OrderedDict( +CLIENT_AUTHN_METHOD = dict( client_secret_basic=ClientSecretBasic, client_secret_post=ClientSecretPost, bearer_header=BearerHeader, @@ -369,6 +372,7 @@ def _verify(cls, endpoint_context, request=None, **kwargs): client_secret_jwt=ClientSecretJWT, private_key_jwt=PrivateKeyJWT, request_param=RequestParam, + public=PublicAuthn, none=NoneAuthn, ) @@ -409,18 +413,18 @@ def verify_client( authorization_token = None auth_info = {} + methods = endpoint_context.client_authn_method allowed_methods = getattr(endpoint, "client_authn_method") if not allowed_methods: - allowed_methods = list(CLIENT_AUTHN_METHOD.keys()) + allowed_methods = list(methods.keys()) - for _method in (CLIENT_AUTHN_METHOD[meth] for meth in allowed_methods): + for _method in (methods[meth] for meth in allowed_methods): if not _method.is_usable( - endpoint_context, request=request, authorization_token=authorization_token + request=request, authorization_token=authorization_token ): continue try: auth_info = _method.verify( - endpoint_context, request=request, authorization_token=authorization_token, endpoint=endpoint, @@ -432,6 +436,9 @@ def verify_client( except Exception as err: logger.info("Verifying auth using {} failed: {}".format(_method.tag, err)) + if auth_info.get("method") == "none": + return auth_info + client_id = auth_info.get("client_id") if client_id is None: raise ClientAuthenticationError("Failed to verify client") @@ -471,3 +478,16 @@ def verify_client( endpoint_context.cdb[client_id]["auth_method"] = {_request_type: auth_info["method"]} return auth_info + + +def client_auth_setup(server_get, auth_set=None): + if auth_set is None: + auth_set = {} + auth_set = {**CLIENT_AUTHN_METHOD, **auth_set} + res = {} + + for name, cls in auth_set.items(): + if isinstance(cls, str): + cls = importer(cls) + res[name] = cls(server_get) + return res diff --git a/src/oidcop/endpoint_context.py b/src/oidcop/endpoint_context.py index fd135395..d09eca79 100755 --- a/src/oidcop/endpoint_context.py +++ b/src/oidcop/endpoint_context.py @@ -116,6 +116,7 @@ class EndpointContext(OidcContext): "symkey": "", "token_args_methods": [], # "userinfo": UserInfo, + "client_authn_method": {}, } def __init__( @@ -169,6 +170,7 @@ def __init__( self.template_handler = None self.token_args_methods = [] self.userinfo = None + self.client_authn_method = {} for param in [ "issuer", diff --git a/src/oidcop/oauth2/authorization.py b/src/oidcop/oauth2/authorization.py index f01c569d..f735c29b 100755 --- a/src/oidcop/oauth2/authorization.py +++ b/src/oidcop/oauth2/authorization.py @@ -286,7 +286,7 @@ class Authorization(Endpoint): name = "authorization" default_capabilities = { "claims_parameter_supported": True, - "client_authn_method": ["request_param", "none"], + "client_authn_method": ["request_param", "public"], "request_parameter_supported": True, "request_uri_parameter_supported": True, "response_types_supported": ["code", "token", "code token"], diff --git a/src/oidcop/oidc/authorization.py b/src/oidcop/oidc/authorization.py index 800adb90..81f8b53a 100755 --- a/src/oidcop/oidc/authorization.py +++ b/src/oidcop/oidc/authorization.py @@ -77,7 +77,7 @@ class Authorization(authorization.Authorization): name = "authorization" default_capabilities = { "claims_parameter_supported": True, - "client_authn_method": ["request_param", "none"], + "client_authn_method": ["request_param", "public"], "request_parameter_supported": True, "request_uri_parameter_supported": True, "response_types_supported": [ diff --git a/src/oidcop/server.py b/src/oidcop/server.py index de2a2876..76ed9981 100644 --- a/src/oidcop/server.py +++ b/src/oidcop/server.py @@ -8,6 +8,7 @@ from oidcop import authz from oidcop.configure import ASConfiguration from oidcop.configure import OPConfiguration +from oidcop.client_authn import client_auth_setup from oidcop.endpoint import Endpoint from oidcop.endpoint_context import EndpointContext from oidcop.endpoint_context import init_service @@ -93,6 +94,7 @@ def __init__( # Must be done after userinfo self.do_login_hint_lookup() + self.do_client_authn_methods() for endpoint_name, _ in self.endpoint.items(): self.endpoint[endpoint_name].server_get = self.server_get @@ -166,3 +168,6 @@ def do_login_hint_lookup(self): self.endpoint_context.login_hint_lookup = init_service(_conf) self.endpoint_context.login_hint_lookup.userinfo = _userinfo + + def do_client_authn_methods(self): + self.endpoint_context.client_authn_method = client_auth_setup(self.server_get, self.conf.get("client_authn_method")) diff --git a/tests/test_02_client_authn.py b/tests/test_02_client_authn.py index 4daad060..878f59c6 100755 --- a/tests/test_02_client_authn.py +++ b/tests/test_02_client_authn.py @@ -1,4 +1,5 @@ import base64 +from unittest.mock import MagicMock import pytest from cryptojwt.jws.exception import NoSuitableSigningKeys @@ -9,6 +10,7 @@ from cryptojwt.utils import as_unicode from oidcop import JWT_BEARER +from oidcop.client_authn import ClientAuthnMethod from oidcop.client_authn import BearerBody from oidcop.client_authn import BearerHeader from oidcop.client_authn import ClientSecretBasic @@ -88,7 +90,7 @@ def setup(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - self.method = ClientSecretBasic + self.method = ClientSecretBasic(server.server_get) def test_client_secret_basic(self): _token = "{}:{}".format(client_id, client_secret) @@ -96,13 +98,13 @@ def test_client_secret_basic(self): authz_token = "Basic {}".format(token) - assert self.method.is_usable(self.endpoint_context, authorization_token=authz_token) - authn_info = self.method.verify(self.endpoint_context, authorization_token=authz_token) + assert self.method.is_usable(authorization_token=authz_token) + authn_info = self.method.verify(authorization_token=authz_token) assert authn_info["client_id"] == client_id def test_wrong_type(self): - assert self.method.is_usable(self.endpoint_context, authorization_token="Foppa toffel") is False + assert self.method.is_usable(authorization_token="Foppa toffel") is False def test_csb_wrong_secret(self): _token = "{}:{}".format(client_id, "pillow") @@ -110,10 +112,10 @@ def test_csb_wrong_secret(self): authz_token = "Basic {}".format(token) - assert self.method.is_usable(self.endpoint_context, authorization_token=authz_token) + assert self.method.is_usable(authorization_token=authz_token) with pytest.raises(ClientAuthenticationError): - self.method.verify(self.endpoint_context, authorization_token=authz_token) + self.method.verify(authorization_token=authz_token) class TestClientSecretPost: @@ -122,21 +124,21 @@ def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - self.method = ClientSecretPost + self.method = ClientSecretPost(server.server_get) def test_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} - assert self.method.is_usable(self.endpoint_context, request=request) - authn_info = self.method.verify(self.endpoint_context, request) + assert self.method.is_usable(request=request) + authn_info = self.method.verify(request) assert authn_info["client_id"] == client_id def test_client_secret_post_wrong_secret(self): request = {"client_id": client_id, "client_secret": "pillow"} - assert self.method.is_usable(self.endpoint_context, request=request) + assert self.method.is_usable(request=request) with pytest.raises(ClientAuthenticationError): - self.method.verify(self.endpoint_context, request) + self.method.verify(request) class TestClientSecretJWT: @@ -145,7 +147,7 @@ def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - self.method = ClientSecretJWT + self.method = ClientSecretJWT(server.server_get) def test_client_secret_jwt(self): client_keyjar = KeyJar() @@ -159,8 +161,8 @@ def test_client_secret_jwt(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - assert self.method.is_usable(self.endpoint_context, request=request) - authn_info = self.method.verify(self.endpoint_context, request=request) + assert self.method.is_usable(request=request) + authn_info = self.method.verify(request=request) assert authn_info["client_id"] == client_id assert "jwt" in authn_info @@ -173,7 +175,7 @@ def create_method(self): server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server = server self.endpoint_context = server.endpoint_context - self.method = PrivateKeyJWT + self.method = PrivateKeyJWT(server.server_get) def test_private_key_jwt(self): # Own dynamic keys @@ -190,8 +192,8 @@ def test_private_key_jwt(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - assert self.method.is_usable(self.endpoint_context, request=request) - authn_info = self.method.verify(self.endpoint_context, request=request) + assert self.method.is_usable(request=request) + authn_info = self.method.verify(request=request) assert authn_info["client_id"] == client_id assert "jwt" in authn_info @@ -212,18 +214,18 @@ def test_private_key_jwt_reusage_other_endpoint(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} # This should be OK - assert self.method.is_usable(self.endpoint_context, request=request) - self.method.verify(self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "token")) + assert self.method.is_usable(request=request) + self.method.verify(request=request, endpoint=self.server.server_get("endpoint", "token")) # This should NOT be OK with pytest.raises(InvalidToken): self.method.verify( - self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "authorization") + request=request, endpoint=self.server.server_get("endpoint", "authorization") ) # This should NOT be OK because this is the second time the token appears with pytest.raises(InvalidToken): - self.method.verify(self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "token")) + self.method.verify(request=request, endpoint=self.server.server_get("endpoint", "token")) def test_private_key_jwt_auth_endpoint(self): # Own dynamic keys @@ -242,9 +244,9 @@ def test_private_key_jwt_auth_endpoint(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - assert self.method.is_usable(self.endpoint_context, request=request) + assert self.method.is_usable(request=request) authn_info = self.method.verify( - self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "authorization"), + request=request, endpoint=self.server.server_get("endpoint", "authorization"), ) assert authn_info["client_id"] == client_id @@ -258,16 +260,16 @@ def create_method(self): server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server = server self.endpoint_context = server.endpoint_context - self.method = BearerHeader + self.method = BearerHeader(server.server_get) def test_bearerheader(self): authorization_info = "Bearer 1234567890" get_client_id_from_token = lambda *_: "client_id" - assert self.method.verify(self.endpoint_context, authorization_token=authorization_info, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_header", "client_id": "client_id"} + assert self.method.verify(authorization_token=authorization_info, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_header", "client_id": "client_id"} def test_bearerheader_wrong_type(self): authorization_info = "Thrower 1234567890" - assert self.method.is_usable(self.endpoint_context, authorization_token=authorization_info) is False + assert self.method.is_usable(authorization_token=authorization_info) is False class TestBearerBody: @@ -277,16 +279,16 @@ def create_method(self): server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server = server self.endpoint_context = server.endpoint_context - self.method = BearerBody + self.method = BearerBody(server.server_get) def test_bearer_body(self): request = {"access_token": "1234567890"} - assert self.method.verify(self.endpoint_context, request) == {"token": "1234567890", "method": "bearer_body"} + assert self.method.verify(request) == {"token": "1234567890", "method": "bearer_body"} def test_bearer_body_no_token(self): request = {} with pytest.raises(ClientAuthenticationError): - self.method.verify(self.endpoint_context, request=request) + self.method.verify(request=request) class TestJWSAuthnMethod: @@ -296,7 +298,7 @@ def create_method(self): server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server = server self.endpoint_context = server.endpoint_context - self.method = JWSAuthnMethod + self.method = JWSAuthnMethod(server.server_get) def test_jws_authn_method_wrong_key(self): client_keyjar = KeyJar() @@ -310,7 +312,7 @@ def test_jws_authn_method_wrong_key(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} with pytest.raises(NoSuitableSigningKeys): - self.method.verify(self.endpoint_context, request=request, key_type="private_key") + self.method.verify(request=request, key_type="private_key") def test_jws_authn_method_aud_iss(self): client_keyjar = KeyJar() @@ -325,7 +327,7 @@ def test_jws_authn_method_aud_iss(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - assert self.method.verify(self.endpoint_context, request=request, key_type="client_secret") + assert self.method.verify(request=request, key_type="client_secret") def test_jws_authn_method_aud_token_endpoint(self): client_keyjar = KeyJar() @@ -342,7 +344,6 @@ def test_jws_authn_method_aud_token_endpoint(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} assert self.method.verify( - self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "token"), key_type="client_secret", @@ -364,7 +365,7 @@ def test_jws_authn_method_aud_not_me(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} with pytest.raises(InvalidToken): - self.method.verify(self.endpoint_context, request=request, key_type="client_secret") + self.method.verify(request=request, key_type="client_secret") def test_jws_authn_method_aud_userinfo_endpoint(self): client_keyjar = KeyJar() @@ -380,7 +381,6 @@ def test_jws_authn_method_aud_userinfo_endpoint(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} assert self.method.verify( - self.endpoint_context, request=request, endpoint=self.server.server_get("endpoint", "userinfo"), key_type="client_secret", @@ -426,24 +426,24 @@ def create_method(self): self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = self.server.server_get("endpoint_context") - def test_verify_per_client_per_endpoint(self): - self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["none"] + def test_verify_per_client(self): + self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["public"] request = {"client_id": client_id} res = verify_client( - self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "token"), + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "registration"), ) - assert res == {"method": "none", "client_id": client_id} + assert res == {"method": "public", "client_id": client_id} def test_verify_per_client_per_endpoint(self): - self.server.endpoint_context.cdb[client_id]["registration_endpoint_client_authn_method"] = ["none"] + self.server.endpoint_context.cdb[client_id]["registration_endpoint_client_authn_method"] = ["public"] self.server.endpoint_context.cdb[client_id]["token_endpoint_client_authn_method"] = ["client_secret_post"] request = {"client_id": client_id} res = verify_client( self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "registration"), ) - assert res == {"method": "none", "client_id": client_id} + assert res == {"method": "public", "client_id": client_id} with pytest.raises(ClientAuthenticationError) as e: verify_client( @@ -630,7 +630,7 @@ def test_verify_client_authorization_none(self): assert res["method"] == "none" assert res["client_id"] == "client_id" - def test_verify_client_registration_none(self): + def test_verify_client_registration_public(self): # This is when no special auth method is configured request = {"redirect_uris": ["https://example.com/cb"], "client_id": "client_id"} res = verify_client( @@ -638,4 +638,35 @@ def test_verify_client_registration_none(self): request, endpoint=self.server.server_get("endpoint", "registration"), ) - assert res == {"client_id": "client_id", "method": "none"} + assert res == {"client_id": "client_id", "method": "public"} + + def test_verify_client_registration_none(self): + # This is when no special auth method is configured + request = {"redirect_uris": ["https://example.com/cb"]} + res = verify_client( + self.endpoint_context, + request, + endpoint=self.server.server_get("endpoint", "registration"), + ) + assert res == {"client_id": None, "method": "none"} + + +def test_client_auth_setup(): + + class Mock: + is_usable = MagicMock(return_value=True) + verify = MagicMock(return_value={"method": "custom", "client_id": client_id}) + + mock = Mock() + conf = dict(CONF) + conf["client_authn_method"] = {"custom": MagicMock(return_value=mock)} + conf["endpoint"]["registration"]["kwargs"]["client_authn_method"] = ["custom"] + server = Server(conf=conf, keyjar=KEYJAR) + server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + + request = {"redirect_uris": ["https://example.com/cb"]} + res = verify_client(server.endpoint_context, request, endpoint=server.server_get("endpoint", "registration")) + + assert res == {"client_id": "client_id", "method": "custom"} + mock.is_usable.assert_called_once() + mock.verify.assert_called_once() From 03fc9188dcf44059fdfd0bfeddca11a06587d0fe Mon Sep 17 00:00:00 2001 From: Nikos Sklikas Date: Mon, 21 Feb 2022 15:52:32 +0200 Subject: [PATCH 5/5] Add documentation --- docs/source/contents/conf.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/source/contents/conf.rst b/docs/source/contents/conf.rst index 5b338576..00b76063 100644 --- a/docs/source/contents/conf.rst +++ b/docs/source/contents/conf.rst @@ -119,6 +119,30 @@ code_challenge_method The allowed code_challenge methods. The supported code challenge methods are: ``plain, S256, S384, S512`` +------------------- +client_authn_method +------------------- + +A dictionary with the allowed client authentication methods. The keys are the methods' +names and and the values must be either a class or a path to a python class that will +be imported and used to validate the client. The class should inherit from +`oidcop.client_authn.ClientAuthnMethod` and it must implement the methods +`is_usable` and `_verify`. You can then define which of these methods are allowed per +endpoint by defining a list with the names of the methods allowed in the endpoint's +capabilities. This can be overriden per client by defining `client_authn_method` +in the client's metadata. + +Defaults to: + - none: `oidcop.client_authn.NoneAuthn`, no client authentication. Never use this in production + - public: `oidcop.client_authn.PublicAuthn`, used for public clients, requires only a valid`client_id` in the request + - client_secret_basic: `oidcop.client_authn.ClientSecretBasic`, see https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 + - client_secret_post: `oidcop.client_authn.ClientSecretPost`, see https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 + - bearer_header: `oidcop.client_authn.BearerHeader`, see https://datatracker.ietf.org/doc/html/rfc6750#section-2.1 + - bearer_body: `oidcop.client_authn.BearerBody`, see https://datatracker.ietf.org/doc/html/rfc6750#section-2.2 + - client_secret_jwt: `oidcop.client_authn.ClientSecretJWT`, see https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication + - private_key_jwt: `oidcop.client_authn.PrivateKeyJWT`, see https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication + - request_param: `oidcop.client_authn.RequestParam`, see https://openid.net/specs/openid-connect-core-1_0.html#JWTRequests + -------------- authentication --------------