diff --git a/python/ai-server/src/ai_server/server_resources/server_client.py b/python/ai-server/src/ai_server/server_resources/server_client.py index 6c9be57..fdb0684 100644 --- a/python/ai-server/src/ai_server/server_resources/server_client.py +++ b/python/ai-server/src/ai_server/server_resources/server_client.py @@ -1,4 +1,4 @@ -from typing import Any, List, Dict, Union, Optional, Set, Generator, Tuple +from typing import Any, List, Dict, Union, Optional, Set, Generator, Tuple, Callable import requests import json import pandas as pd @@ -7,6 +7,8 @@ from urllib.parse import urlparse, unquote from pathlib import Path import os +import time +from datetime import datetime, timedelta logger: logging.Logger = logging.getLogger(__name__) @@ -14,8 +16,9 @@ class ServerClient: """ServerClient to make calls to a ai server instance - Example: + Examples: + Method 1: User Access Key Authentication ```python >>> import ai_server @@ -23,7 +26,43 @@ class ServerClient: >>> loginKeys = {"secretKey":"","accessKey":""} # create connection object by passing in the secret key, access key and base url for the api - >>> server_connection = ai_server.ServerClient(access_key=loginKeys['accessKey'], secret_key=loginKeys['secretKey'], base='') + >>> server_connection = ai_server.ServerClient( + ... access_key=loginKeys['accessKey'], + ... secret_key=loginKeys['secretKey'], + ... base='' + ... ) + ``` + + Method 2: Bearer Token Authentication + ```python + >>> import ai_server + + # create connection with bearer token from IdP + >>> server_connection = ai_server.ServerClient( + ... bearer_token='', + ... bearer_token_provider='', + ... base='' + ... ) + ``` + + Method 3: OAuth Device Code Flow + ```python + >>> import ai_server + + # create connection using OAuth device code flow + # tokens are cached by default - you'll only need to login once + >>> server_connection = ai_server.ServerClient( + ... base='', + ... use_device_code_flow=True, + ... device_code_client_id='', + ... device_code_auth_url='https://oauth.example.com/device/code', + ... device_code_token_url='https://oauth.example.com/token', + ... device_code_scope='openid profile email', # optional + ... device_code_callback=lambda user_code, uri: print(f"Go to {uri} and enter {user_code}"), # optional + ... cache_token=True # default - caches token to avoid repeated logins + ... ) + # The user will be prompted to visit a URL and enter a code to authorize (only on first run) + # Subsequent runs will use the cached token automatically ``` """ @@ -37,6 +76,14 @@ def __init__( secret_key: Optional[str] = None, bearer_token: Optional[str] = None, bearer_token_provider: Optional[str] = None, + use_device_code_flow: bool = False, + device_code_client_id: Optional[str] = None, + device_code_auth_url: Optional[str] = None, + device_code_token_url: Optional[str] = None, + device_code_scope: Optional[str] = None, + device_code_callback: Optional[Callable[[str, str], None]] = None, + cache_token: bool = True, + token_cache_path: Optional[str] = None, ) -> None: """ Args: @@ -50,6 +97,22 @@ def __init__( A token provided from a successful login of a user to an existing IdP bearer_token_provider (`Optional[str]`): The existing IdP login type as recognized by the AuthProvider enum list + use_device_code_flow (`bool`): + Whether to use OAuth 2.0 device authorization grant flow for authentication + device_code_client_id (`Optional[str]`): + The OAuth client ID for device code flow + device_code_auth_url (`Optional[str]`): + The authorization endpoint URL for device code flow (e.g., https://oauth.example.com/device/code) + device_code_token_url (`Optional[str]`): + The token endpoint URL for device code flow (e.g., https://oauth.example.com/token) + device_code_scope (`Optional[str]`): + The OAuth scope to request during device code flow + device_code_callback (`Optional[Callable[[str, str], None]]`): + Optional callback function that receives (user_code, verification_uri) to display to the user + cache_token (`bool`): + Whether to cache authentication tokens locally to avoid repeated logins (default: True) + token_cache_path (`Optional[str]`): + Path to store cached tokens. Defaults to ~/.ai_server/tokens.json if not specified """ # set the base url as an instance attribute self.main_url: str = base @@ -69,18 +132,57 @@ def __init__( # set the secret key as an instance attribute self.bearer_token_provider: str = bearer_token_provider + # Device code flow attributes + self.use_device_code_flow: bool = use_device_code_flow + self.device_code_client_id: str = device_code_client_id + self.device_code_auth_url: str = device_code_auth_url + self.device_code_token_url: str = device_code_token_url + self.device_code_scope: str = device_code_scope + self.device_code_callback: Optional[Callable[[str, str], None]] = ( + device_code_callback + ) + + # Token caching attributes + self.cache_token: bool = cache_token + if token_cache_path: + self.token_cache_path: str = token_cache_path + else: + # Default to ~/.ai_server/tokens.json + home_dir = Path.home() + cache_dir = home_dir / ".ai_server" + self.token_cache_path: str = str(cache_dir / "tokens.json") + + # Token expiration tracking + self.token_expires_at: Optional[datetime] = None + self.refresh_token: Optional[str] = None + useUserAccessKey = (self.access_key is not None and self.access_key != "") and ( self.secret_key is not None and self.secret_key != "" ) - useBearerToken = ( - self.bearer_token is not None and self.bearer_token != "" - ) and ( - self.bearer_token_provider is not None and self.bearer_token_provider != "" + useBearerToken = self.bearer_token is not None and self.bearer_token != "" + + useDeviceCodeFlow = ( + self.use_device_code_flow + and ( + self.device_code_client_id is not None + and self.device_code_client_id != "" + ) + and ( + self.device_code_auth_url is not None + and self.device_code_auth_url != "" + ) + and ( + self.device_code_token_url is not None + and self.device_code_token_url != "" + ) ) - if not useUserAccessKey and not useBearerToken: + if not useUserAccessKey and not useBearerToken and not useDeviceCodeFlow: raise Exception( - "Must provide either access_key and secret_key for user access login or provide bearer_token and bearer_token_provider for login using your IdP access key" + "Must provide either:\n" + " 1) access_key and secret_key for user access login, or\n" + " 2) bearer_token (with optional bearer_token_provider) for IdP login, or\n" + " 3) use_device_code_flow=True with device_code_client_id, device_code_auth_url, and device_code_token_url for OAuth device code flow" ) # TODO provide definitons for all of these attributes @@ -88,8 +190,12 @@ def __init__( self.auth_headers: Dict = {} if useUserAccessKey: self.loginUserAccessKey() - else: + elif useBearerToken: self.loginBearerToken() + else: + # For device code flow, try loading cached token first + if not self._load_token_cache(): + self.loginDeviceCodeFlow() # This will hold CSRF and any other required headers (merge into all calls) self.required_headers = {} @@ -134,8 +240,9 @@ def loginBearerToken(self): """ headers = { "Authorization": f"Bearer {self.bearer_token}", - "Bearer-Provider": self.bearer_token_provider, } + if self.bearer_token_provider is not None and self.bearer_token_provider != "": + headers["Bearer-Provider"] = self.bearer_token_provider self.auth_headers: Dict = headers.copy() # make sure user is authenticated @@ -147,6 +254,294 @@ def loginBearerToken(self): # display the cookies logger.info(self.cookies) + def loginDeviceCodeFlow(self): + """ + Authenticate using OAuth 2.0 Device Authorization Grant flow. + + This flow is designed for devices with limited input capabilities or no browser. + It involves: + 1. Requesting a device code from the authorization server + 2. Displaying a user code and verification URI to the user + 3. Polling the token endpoint until authorization is complete + 4. Storing the access token for subsequent API calls + """ + # Step 1: Request device code + device_code_payload = { + "client_id": self.device_code_client_id, + } + + if self.device_code_scope: + device_code_payload["scope"] = self.device_code_scope + + try: + device_response = requests.post( + self.device_code_auth_url, + data=device_code_payload, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + device_response.raise_for_status() + device_data = device_response.json() + except Exception as e: + raise AuthenticationError(f"Failed to request device code: {str(e)}") + + # Extract device code response data + device_code = device_data.get("device_code") + user_code = device_data.get("user_code") + verification_uri = device_data.get("verification_uri") + verification_uri_complete = device_data.get("verification_uri_complete") + expires_in = device_data.get("expires_in", 900) # Default 15 minutes + interval = device_data.get("interval", 5) # Default 5 seconds polling interval + + if not device_code or not user_code or not verification_uri: + raise AuthenticationError( + "Invalid device code response from authorization server" + ) + + # Step 2: Display user code and verification URI + display_uri = verification_uri_complete or verification_uri + + if self.device_code_callback: + # Use custom callback if provided + self.device_code_callback(user_code, display_uri) + else: + # Default: print to console + logger.info("=" * 60) + logger.info("DEVICE CODE AUTHENTICATION") + logger.info("=" * 60) + logger.info(f"Please visit: {display_uri}") + logger.info(f"And enter code: {user_code}") + logger.info("=" * 60) + print(f"\nPlease visit: {display_uri}") + print(f"And enter code: {user_code}\n") + + # Step 3: Poll the token endpoint + token_payload = { + "client_id": self.device_code_client_id, + "device_code": device_code, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + } + + start_time = time.time() + access_token = None + + while time.time() - start_time < expires_in: + time.sleep(interval) + + try: + token_response = requests.post( + self.device_code_token_url, + data=token_payload, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + token_data = token_response.json() + + # Check for errors + error = token_data.get("error") + + if error == "authorization_pending": + # User hasn't authorized yet, continue polling + logger.debug("Authorization pending, continuing to poll...") + continue + elif error == "slow_down": + # Increase polling interval + interval += 5 + logger.debug(f"Slowing down polling interval to {interval} seconds") + continue + elif error == "expired_token": + raise AuthenticationError( + "Device code has expired. Please try again." + ) + elif error == "access_denied": + raise AuthenticationError("User denied the authorization request.") + elif error: + raise AuthenticationError( + f"OAuth error: {error} - {token_data.get('error_description', '')}" + ) + + # Success! We have an access token + if token_response.status_code == 200: + access_token = token_data.get("access_token") + if access_token: + logger.info( + "Successfully obtained access token via device code flow" + ) + break + + except requests.exceptions.RequestException as e: + logger.warning(f"Error during token polling: {str(e)}") + continue + + if not access_token: + raise AuthenticationError( + "Failed to obtain access token: timeout or authorization not completed" + ) + + # Store the access token in bearer token format + self.bearer_token = access_token + + # Set authorization headers + headers = { + "Authorization": f"Bearer {access_token}", + } + + # If there's a provider specified, include it + if self.bearer_token_provider: + headers["Bearer-Provider"] = self.bearer_token_provider + + self.auth_headers: Dict = headers.copy() + + # Verify the token works + response, is_logged_in = self.is_session_login(self.auth_headers) + if not is_logged_in: + raise AuthenticationError("Access token obtained but authentication failed") + + self.cookies = response.cookies + logger.info("Device code flow authentication completed successfully") + + # Cache the token if enabled + if self.cache_token: + expires_in = token_data.get("expires_in") + if expires_in: + self.token_expires_at = datetime.now() + timedelta(seconds=expires_in) + + refresh_token = token_data.get("refresh_token") + if refresh_token: + self.refresh_token = refresh_token + + self._save_token_cache() + + def _get_cache_key(self) -> str: + """Generate a unique cache key based on connection parameters""" + # Use base URL + client ID as the cache key + import hashlib + + key_parts = [self.main_url] + if self.device_code_client_id: + key_parts.append(self.device_code_client_id) + key_string = "|".join(key_parts) + return hashlib.sha256(key_string.encode()).hexdigest()[:16] + + def _save_token_cache(self) -> None: + """Save the current token to cache file""" + if not self.cache_token or not self.bearer_token: + return + + try: + # Ensure cache directory exists + cache_path = Path(self.token_cache_path) + cache_path.parent.mkdir(parents=True, exist_ok=True) + + # Load existing cache or create new + if cache_path.exists(): + with open(cache_path, "r") as f: + cache_data = json.load(f) + else: + cache_data = {} + + # Store token with metadata + cache_key = self._get_cache_key() + cache_data[cache_key] = { + "access_token": self.bearer_token, + "bearer_token_provider": self.bearer_token_provider, + "expires_at": ( + self.token_expires_at.isoformat() if self.token_expires_at else None + ), + "refresh_token": self.refresh_token, + "base_url": self.main_url, + } + + # Write cache file + with open(cache_path, "w") as f: + json.dump(cache_data, f, indent=2) + + logger.info(f"Token cached successfully at {self.token_cache_path}") + except Exception as e: + logger.warning(f"Failed to cache token: {str(e)}") + + def _load_token_cache(self) -> bool: + """ + Load and validate token from cache. + + Returns: + bool: True if valid cached token was loaded, False otherwise + """ + if not self.cache_token: + return False + + try: + cache_path = Path(self.token_cache_path) + if not cache_path.exists(): + logger.debug("No token cache file found") + return False + + with open(cache_path, "r") as f: + cache_data = json.load(f) + + cache_key = self._get_cache_key() + if cache_key not in cache_data: + logger.debug("No cached token found for this connection") + return False + + token_data = cache_data[cache_key] + + # Check if token is expired + if token_data.get("expires_at"): + expires_at = datetime.fromisoformat(token_data["expires_at"]) + if datetime.now() >= expires_at: + logger.info("Cached token has expired") + return False + + # Load token data + self.bearer_token = token_data.get("access_token") + self.bearer_token_provider = token_data.get("bearer_token_provider") + self.refresh_token = token_data.get("refresh_token") + if token_data.get("expires_at"): + self.token_expires_at = datetime.fromisoformat(token_data["expires_at"]) + + # Validate the token by attempting to use it + headers = { + "Authorization": f"Bearer {self.bearer_token}", + } + if self.bearer_token_provider: + headers["Bearer-Provider"] = self.bearer_token_provider + + self.auth_headers = headers.copy() + + response, is_logged_in = self.is_session_login(self.auth_headers) + if is_logged_in: + self.cookies = response.cookies + logger.info("Successfully loaded and validated cached token") + return True + else: + logger.info("Cached token is invalid") + return False + + except Exception as e: + logger.warning(f"Failed to load token cache: {str(e)}") + return False + + def clear_token_cache(self) -> None: + """Clear the cached token for this connection""" + try: + cache_path = Path(self.token_cache_path) + if not cache_path.exists(): + return + + with open(cache_path, "r") as f: + cache_data = json.load(f) + + cache_key = self._get_cache_key() + if cache_key in cache_data: + del cache_data[cache_key] + + with open(cache_path, "w") as f: + json.dump(cache_data, f, indent=2) + + logger.info("Token cache cleared") + except Exception as e: + logger.warning(f"Failed to clear token cache: {str(e)}") + def reconnect(self): self.logout() # Call __init__ to reconnect to the server on session timeout @@ -156,6 +551,14 @@ def reconnect(self): self.secret_key, self.bearer_token, self.bearer_token_provider, + self.use_device_code_flow, + self.device_code_client_id, + self.device_code_auth_url, + self.device_code_token_url, + self.device_code_scope, + self.device_code_callback, + self.cache_token, + self.token_cache_path, ) def is_session_login(