diff --git a/src/oidcop/oauth2/token.py b/src/oidcop/oauth2/token.py index c06199fc..0a39dd21 100755 --- a/src/oidcop/oauth2/token.py +++ b/src/oidcop/oauth2/token.py @@ -118,11 +118,16 @@ def process_request(self, req: Union[Message, dict], **kwargs): return self.error_cls(error="invalid_request", error_description="Missing code") _session_info = _mngr.get_session_info_by_token(_access_code, grant=True) - if _session_info["client_id"] != req["client_id"]: - logger.debug("{} owner of token".format(_session_info["client_id"])) + client_id = _session_info["client_id"] + if client_id != req["client_id"]: + logger.debug("{} owner of token".format(client_id)) logger.warning("Client using token it was not given") return self.error_cls(error="invalid_grant", error_description="Wrong client") + if "grant_types_supported" in _context.cdb[client_id]: + grant_types_supported = _context.cdb[client_id].get("grant_types_supported") + else: + grant_types_supported = _context.provider_info["grant_types_supported"] grant = _session_info["grant"] _based_on = grant.get_token(_access_code) @@ -162,7 +167,11 @@ def process_request(self, req: Union[Message, dict], **kwargs): if token.expires_at: _response["expires_in"] = token.expires_at - utc_time_sans_frac() - if issue_refresh and "refresh_token" in _supports_minting: + if ( + issue_refresh + and "refresh_token" in _supports_minting + and "refresh_token" in grant_types_supported + ): try: refresh_token = self._mint_token( token_class="refresh_token", diff --git a/src/oidcop/oidc/token.py b/src/oidcop/oidc/token.py index 61c9e878..bf4f3e2b 100755 --- a/src/oidcop/oidc/token.py +++ b/src/oidcop/oidc/token.py @@ -45,11 +45,16 @@ def process_request(self, req: Union[Message, dict], **kwargs): _session_info = _mngr.get_session_info_by_token(_access_code, grant=True) logger.debug(f"Session info: {_session_info}") - if _session_info["client_id"] != req["client_id"]: - logger.debug("{} owner of token".format(_session_info["client_id"])) + client_id = _session_info["client_id"] + if client_id != req["client_id"]: + logger.debug("{} owner of token".format(client_id)) logger.warning("{} using token it was not given".format(req["client_id"])) return self.error_cls(error="invalid_grant", error_description="Wrong client") + if "grant_types_supported" in _context.cdb[client_id]: + grant_types_supported = _context.cdb[client_id].get("grant_types_supported") + else: + grant_types_supported = _context.provider_info["grant_types_supported"] grant = _session_info["grant"] token_type = "Bearer" @@ -110,7 +115,11 @@ def process_request(self, req: Union[Message, dict], **kwargs): if token.expires_at: _response["expires_in"] = token.expires_at - utc_time_sans_frac() - if issue_refresh and "refresh_token" in _supports_minting: + if ( + issue_refresh + and "refresh_token" in _supports_minting + and "refresh_token" in grant_types_supported + ): try: refresh_token = self._mint_token( token_class="refresh_token", diff --git a/tests/test_24_oauth2_token_endpoint.py b/tests/test_24_oauth2_token_endpoint.py index 0f2bf2fc..bab5ad46 100644 --- a/tests/test_24_oauth2_token_endpoint.py +++ b/tests/test_24_oauth2_token_endpoint.py @@ -376,21 +376,7 @@ def test_refresh_grant_disallowed_per_client(self): _req = self.token_endpoint.parse_request(_token_request) _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) - _request = REFRESH_TOKEN_REQ.copy() - _request["refresh_token"] = _resp["response_args"]["refresh_token"] - - _token_value = _resp["response_args"]["refresh_token"] - _session_info = self.session_manager.get_session_info_by_token(_token_value) - _token = self.session_manager.find_token(_session_info["session_id"], _token_value) - _token.usage_rules["supports_minting"] = ["access_token", "refresh_token"] - - _req = self.token_endpoint.parse_request(_request.to_json()) - - assert isinstance(_req, TokenErrorResponse) - assert _req.to_dict() == { - "error": "invalid_request", - "error_description": "Unsupported grant_type: refresh_token", - } + assert "refresh_token" not in _resp def test_do_2nd_refresh_access_token(self): areq = AUTH_REQ.copy()