Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ dev = [
[tool.ruff]
line-length = 120
target-version = "py39"
exclude = ["src/kognic/auth/_version.py"]

[tool.ruff.lint]
select = ["E", "F", "B", "W", "I001", "PTH"]
4 changes: 4 additions & 0 deletions src/kognic/auth/credentials.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass
Expand All @@ -9,3 +11,5 @@ class ApiCredentials:
user_id: int
issuer: str
name: str = "API Credentials"
created: Optional[datetime] = None
expires: Optional[datetime] = None
69 changes: 55 additions & 14 deletions src/kognic/auth/credentials_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Union

Expand All @@ -17,6 +19,34 @@
]


def _parse_datetime(s: str) -> datetime:
"""Parse an ISO 8601 datetime string into a timezone-aware datetime.

Handles Z suffix and sub-microsecond precision (truncated to microseconds).
"""
s = s.replace("Z", "+00:00")
s = re.sub(r"(\.\d{6})\d+", r"\1", s)
return datetime.fromisoformat(s)


def _parse_optional_datetime(s: Optional[str]) -> Optional[datetime]:
"""Parse an optional datetime string, returning None if absent or unparseable."""
if s is None:
return None
try:
return _parse_datetime(s)
except Exception:
return None


def _check_expiry(creds: ApiCredentials) -> None:
"""Raise ValueError if the credentials have an expires field that is in the past."""
if creds.expires is None:
return
if datetime.now(timezone.utc) >= creds.expires:
raise ValueError(f"Credentials expired at {creds.expires.isoformat()}")


def parse_credentials(path: Union[str, os.PathLike, dict]) -> ApiCredentials:
if isinstance(path, dict):
credentials = path
Expand All @@ -41,6 +71,8 @@ def parse_credentials(path: Union[str, os.PathLike, dict]) -> ApiCredentials:
user_id=credentials["userId"],
issuer=credentials["issuer"],
name=credentials.get("name", "API Credentials"),
created=_parse_optional_datetime(credentials.get("created")),
expires=_parse_optional_datetime(credentials.get("expires")),
)


Expand Down Expand Up @@ -120,6 +152,25 @@ def resolve_any_credentials(auth: ANY_AUTH_TYPE) -> ApiCredentials:
return creds


def _resolve_credentials(
auth: ANY_AUTH_TYPE = None, client_id: Optional[str] = None, client_secret: Optional[str] = None
) -> Optional[ApiCredentials]:
"""
Resolve credentials from either an auth input (which can be a variety of types)
or from explicit client_id and client_secret parameters.
Falls back to environment variables if neither are provided.
Returns the full ApiCredentials object, or None if no credentials are found.
"""
if client_id is not None and client_secret is not None:
if auth is not None:
raise ValueError("Choose either auth or client_id+client_secret")
return _anonymous_credentials(client_id, client_secret)
elif auth is not None:
return resolve_any_credentials(auth)

return get_credentials_from_system()


def resolve_credentials(
auth: ANY_AUTH_TYPE = None, client_id: Optional[str] = None, client_secret: Optional[str] = None
) -> tuple[Optional[str], Optional[str]]:
Expand All @@ -132,20 +183,10 @@ def resolve_credentials(
:param client_secret:
:return:
"""
has_credentials_tuple = client_id is not None and client_secret is not None

if has_credentials_tuple:
if auth is not None:
raise ValueError("Choose either auth or client_id+client_secret")
return client_id, client_secret
elif auth is not None:
creds = resolve_any_credentials(auth)
return creds.client_id, creds.client_secret

creds = get_credentials_from_system()
if creds:
return creds.client_id, creds.client_secret
return None, None
creds = _resolve_credentials(auth, client_id, client_secret)
if creds is None:
return None, None
return creds.client_id, creds.client_secret


if __name__ == "__main__":
Expand Down
9 changes: 7 additions & 2 deletions src/kognic/auth/httpx/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from kognic.auth import DEFAULT_HOST, DEFAULT_TOKEN_ENDPOINT_RELPATH
from kognic.auth.base.auth_client import AuthClient
from kognic.auth.credentials_parser import resolve_credentials
from kognic.auth.credentials_parser import _check_expiry, _resolve_credentials

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,7 +43,12 @@ def __init__(
self.host = host
self.token_url = f"{host}{token_endpoint}"

client_id, client_secret = resolve_credentials(auth)
creds = _resolve_credentials(auth)
if creds:
_check_expiry(creds)

client_id = creds.client_id if creds else None
client_secret = creds.client_secret if creds else None

self._oauth_client = _AsyncFixedClient(
client_id=client_id,
Expand Down
4 changes: 4 additions & 0 deletions src/kognic/auth/internal/credentials_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def save_credentials(creds: ApiCredentials, profile: str = DEFAULT_PROFILE) -> N
"issuer": creds.issuer,
"name": creds.name,
}
if creds.created is not None:
data["created"] = creds.created.isoformat()
if creds.expires is not None:
data["expires"] = creds.expires.isoformat()
kr.set_password(SERVICE_NAME, profile, json.dumps(data))
log.debug("Saved credentials to keyring for profile=%s", profile)

Expand Down
8 changes: 6 additions & 2 deletions src/kognic/auth/requests/auth_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from kognic.auth import DEFAULT_HOST, DEFAULT_TOKEN_ENDPOINT_RELPATH
from kognic.auth.base.auth_client import AuthClient
from kognic.auth.credentials_parser import ANY_AUTH_TYPE, resolve_credentials
from kognic.auth.credentials_parser import ANY_AUTH_TYPE, _check_expiry, _resolve_credentials

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,7 +76,11 @@ def __init__(
self.host = host
self.token_url = f"{host}{token_endpoint}"

client_id, client_secret = resolve_credentials(auth, client_id, client_secret)
creds = _resolve_credentials(auth, client_id, client_secret)
if creds:
_check_expiry(creds)
client_id = creds.client_id if creds else None
client_secret = creds.client_secret if creds else None
self._client_id = client_id
self._on_token_updated = on_token_updated

Expand Down
11 changes: 7 additions & 4 deletions src/kognic/auth/requests/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from kognic.auth import DEFAULT_HOST, DEFAULT_TOKEN_ENDPOINT_RELPATH
from kognic.auth._sunset import SunsetHandler, default_sunset_handler, handle_sunset
from kognic.auth._user_agent import get_user_agent
from kognic.auth.credentials_parser import ANY_AUTH_TYPE, resolve_credentials
from kognic.auth.credentials_parser import ANY_AUTH_TYPE, _resolve_credentials
from kognic.auth.env_config import DEFAULT_ENV_CONFIG_FILE_PATH, load_kognic_env_config
from kognic.auth.internal.token_cache import TokenCache
from kognic.auth.requests.auth_session import RequestsAuthSession
Expand Down Expand Up @@ -167,9 +167,10 @@ def make_token_provider(
Returns:
Configured RequestsAuthSession
"""
client_id, client_secret = resolve_credentials(auth)
credentials = _resolve_credentials(auth)
client_id = credentials.client_id if credentials else None
return RequestsAuthSession(
auth=(client_id, client_secret),
auth=credentials,
host=auth_host,
token_endpoint=auth_token_endpoint,
initial_token=token_cache.load(auth_host, client_id) if (token_cache and client_id) else None,
Expand All @@ -192,7 +193,9 @@ def _get_shared_provider(
Providers are keyed by (client_id, auth_host, auth_token_endpoint, cache_type) and held
weakly, so they are GC'd once no BaseApiClient instances reference them.
"""
client_id, client_secret = resolve_credentials(auth)
credentials = _resolve_credentials(auth)
client_id = credentials.client_id if credentials else None
client_secret = credentials.client_secret if credentials else None

if not client_id or not client_secret:
return RequestsAuthSession(auth=auth, host=auth_host, token_endpoint=auth_token_endpoint)
Expand Down
23 changes: 14 additions & 9 deletions tests/test_base_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from unittest.mock import MagicMock, patch

from kognic.auth._sunset import DATETIME_FMT, handle_sunset
from kognic.auth.credentials import ApiCredentials


def _creds(client_id: str, client_secret: str) -> ApiCredentials:
return ApiCredentials(client_id=client_id, client_secret=client_secret, email="", user_id=0, issuer="")


class TestSunsetHeaderHandling(unittest.TestCase):
Expand Down Expand Up @@ -213,7 +218,7 @@ def _make_clients(self, mock_session, n=2, **kwargs):

@patch("kognic.auth.requests.base_client.requests.Session")
@patch("kognic.auth.requests.base_client.RequestsAuthSession")
@patch("kognic.auth.requests.base_client.resolve_credentials", return_value=("id1", "secret1"))
@patch("kognic.auth.requests.base_client._resolve_credentials", return_value=_creds("id1", "secret1"))
def test_same_credentials_share_provider(self, _resolve, mock_ras, mock_session):
self._make_clients(mock_session, n=2, auth=("id1", "secret1"))

Expand All @@ -223,8 +228,8 @@ def test_same_credentials_share_provider(self, _resolve, mock_ras, mock_session)
@patch("kognic.auth.requests.base_client.requests.Session")
@patch("kognic.auth.requests.base_client.RequestsAuthSession")
@patch(
"kognic.auth.requests.base_client.resolve_credentials",
side_effect=lambda auth, *a, **kw: auth,
"kognic.auth.requests.base_client._resolve_credentials",
side_effect=lambda auth, *a, **kw: _creds(*auth),
)
def test_different_credentials_get_different_providers(self, _resolve, mock_ras, mock_session):
from kognic.auth.requests.base_client import BaseApiClient
Expand All @@ -238,7 +243,7 @@ def test_different_credentials_get_different_providers(self, _resolve, mock_ras,

@patch("kognic.auth.requests.base_client.requests.Session")
@patch("kognic.auth.requests.base_client.RequestsAuthSession")
@patch("kognic.auth.requests.base_client.resolve_credentials", return_value=("id1", "secret1"))
@patch("kognic.auth.requests.base_client._resolve_credentials", return_value=_creds("id1", "secret1"))
def test_different_auth_host_gets_different_provider(self, _resolve, mock_ras, mock_session):
from kognic.auth.requests.base_client import BaseApiClient

Expand All @@ -251,7 +256,7 @@ def test_different_auth_host_gets_different_provider(self, _resolve, mock_ras, m

@patch("kognic.auth.requests.base_client.requests.Session")
@patch("kognic.auth.requests.base_client.RequestsAuthSession")
@patch("kognic.auth.requests.base_client.resolve_credentials", return_value=("id1", "secret1"))
@patch("kognic.auth.requests.base_client._resolve_credentials", return_value=_creds("id1", "secret1"))
def test_cache_type_is_part_of_pool_key(self, _resolve, mock_ras, mock_session):
from kognic.auth.internal.token_cache import FileTokenCache
from kognic.auth.requests.base_client import BaseApiClient
Expand All @@ -265,7 +270,7 @@ def test_cache_type_is_part_of_pool_key(self, _resolve, mock_ras, mock_session):

@patch("kognic.auth.requests.base_client.requests.Session")
@patch("kognic.auth.requests.base_client.RequestsAuthSession")
@patch("kognic.auth.requests.base_client.resolve_credentials", return_value=("id1", "secret1"))
@patch("kognic.auth.requests.base_client._resolve_credentials", return_value=_creds("id1", "secret1"))
def test_explicit_token_provider_bypasses_pool(self, mock_resolve, mock_ras, mock_session):
from kognic.auth.requests.base_client import BaseApiClient, _provider_pool

Expand All @@ -279,7 +284,7 @@ def test_explicit_token_provider_bypasses_pool(self, mock_resolve, mock_ras, moc

@patch("kognic.auth.requests.base_client.requests.Session")
@patch("kognic.auth.requests.base_client.RequestsAuthSession")
@patch("kognic.auth.requests.base_client.resolve_credentials", return_value=("id1", "secret1"))
@patch("kognic.auth.requests.base_client._resolve_credentials", return_value=_creds("id1", "secret1"))
def test_pool_entry_alive_while_client_referenced(self, _resolve, mock_ras, mock_session):
from kognic.auth.requests.base_client import (
DEFAULT_HOST,
Expand Down Expand Up @@ -308,8 +313,8 @@ def test_provider_gc_when_all_clients_deleted(self):

# Use side_effect so each call returns a fresh object with no external strong references
with patch(
"kognic.auth.requests.base_client.resolve_credentials",
return_value=("id-gc", "secret-gc"),
"kognic.auth.requests.base_client._resolve_credentials",
return_value=_creds("id-gc", "secret-gc"),
):
with patch(
"kognic.auth.requests.base_client.RequestsAuthSession",
Expand Down
51 changes: 51 additions & 0 deletions tests/test_credentials_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import json
import os
import unittest
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import patch

from kognic.auth.credentials_parser import (
ApiCredentials,
_check_expiry,
get_credentials_from_env,
parse_credentials,
resolve_credentials,
Expand Down Expand Up @@ -56,6 +58,27 @@ def test_parse_missing_key_raises(self):
parse_credentials(incomplete)
self.assertIn("email", str(ctx.exception))

def test_parse_stores_created_and_expires(self):
data = {
**VALID_CREDENTIALS_DICT,
"created": "2026-01-01T00:00:00.000000Z",
"expires": "2099-01-01T00:00:00.000000Z",
}
creds = parse_credentials(data)
self.assertEqual(creds.created, datetime(2026, 1, 1, tzinfo=timezone.utc))
self.assertEqual(creds.expires, datetime(2099, 1, 1, tzinfo=timezone.utc))

def test_parse_does_not_check_expiry(self):
"""parse_credentials should not raise even if expires is in the past."""
data = {**VALID_CREDENTIALS_DICT, "expires": "2000-01-01T00:00:00.000000Z"}
creds = parse_credentials(data)
self.assertEqual(creds.expires, datetime(2000, 1, 1, tzinfo=timezone.utc))

def test_parse_malformed_expires_returns_none(self):
data = {**VALID_CREDENTIALS_DICT, "expires": "not-a-date"}
creds = parse_credentials(data)
self.assertIsNone(creds.expires)


class TestGetCredentialsFromEnv(unittest.TestCase):
@patch.dict(os.environ, {}, clear=True)
Expand Down Expand Up @@ -210,5 +233,33 @@ def test_auth_keyring_uri_not_found_raises(self, _):
self.assertIn("missing-profile", str(ctx.exception))


def _make_creds(**kwargs) -> ApiCredentials:
return ApiCredentials(
client_id="id",
client_secret="secret",
email="a@b.com",
user_id=1,
issuer="issuer",
**kwargs,
)


class TestCheckExpiry(unittest.TestCase):
def test_no_expires_field(self):
_check_expiry(_make_creds()) # should not raise

def test_future_expires(self):
_check_expiry(_make_creds(expires=datetime(2099, 1, 1, tzinfo=timezone.utc))) # should not raise

def test_expired_raises(self):
with self.assertRaises(ValueError) as ctx:
_check_expiry(_make_creds(expires=datetime(2000, 1, 1, tzinfo=timezone.utc)))
self.assertIn("expired", str(ctx.exception))

def test_expired_with_time(self):
with self.assertRaises(ValueError):
_check_expiry(_make_creds(expires=datetime(2000, 6, 15, 12, 34, 56, 123456, tzinfo=timezone.utc)))


if __name__ == "__main__":
unittest.main()