diff --git a/invokeai/app/api/auth_dependencies.py b/invokeai/app/api/auth_dependencies.py index f5537890b63..a7b01931929 100644 --- a/invokeai/app/api/auth_dependencies.py +++ b/invokeai/app/api/auth_dependencies.py @@ -7,6 +7,9 @@ from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.auth.token_service import TokenData, verify_token +from invokeai.backend.util.logging import logging + +logger = logging.getLogger(__name__) # HTTP Bearer token security scheme security = HTTPBearer(auto_error=False) @@ -61,6 +64,45 @@ async def get_current_user( return token_data +async def get_current_user_or_default( + credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)], +) -> TokenData: + """Get current authenticated user from Bearer token, or return a default system user if not authenticated. + + This dependency is useful for endpoints that should work in both authenticated and non-authenticated contexts. + In single-user mode or when authentication is not provided, it returns a TokenData for the 'system' user. + + Args: + credentials: The HTTP authorization credentials containing the Bearer token + + Returns: + TokenData containing user information from the token, or system user if no credentials + """ + if credentials is None: + # Return system user for unauthenticated requests (single-user mode or backwards compatibility) + logger.debug("No authentication credentials provided, using system user") + return TokenData(user_id="system", email="system@system.invokeai", is_admin=False) + + token = credentials.credentials + token_data = verify_token(token) + + if token_data is None: + # Invalid token - still fall back to system user for backwards compatibility + logger.warning("Invalid or expired token provided, falling back to system user") + return TokenData(user_id="system", email="system@system.invokeai", is_admin=False) + + # Verify user still exists and is active + user_service = ApiDependencies.invoker.services.users + user = user_service.get(token_data.user_id) + + if user is None or not user.is_active: + # User doesn't exist or is inactive - fall back to system user + logger.warning(f"User {token_data.user_id} does not exist or is inactive, falling back to system user") + return TokenData(user_id="system", email="system@system.invokeai", is_admin=False) + + return token_data + + async def require_admin( current_user: Annotated[TokenData, Depends(get_current_user)], ) -> TokenData: @@ -82,4 +124,5 @@ async def require_admin( # Type aliases for convenient use in route dependencies CurrentUser = Annotated[TokenData, Depends(get_current_user)] +CurrentUserOrDefault = Annotated[TokenData, Depends(get_current_user_or_default)] AdminUser = Annotated[TokenData, Depends(require_admin)] diff --git a/invokeai/app/api/routers/client_state.py b/invokeai/app/api/routers/client_state.py index 188225760c7..2e34ea9fe6b 100644 --- a/invokeai/app/api/routers/client_state.py +++ b/invokeai/app/api/routers/client_state.py @@ -1,6 +1,7 @@ from fastapi import Body, HTTPException, Path, Query from fastapi.routing import APIRouter +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.backend.util.logging import logging @@ -13,15 +14,16 @@ response_model=str | None, ) async def get_client_state_by_key( - queue_id: str = Path(description="The queue id to perform this operation on"), + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), key: str = Query(..., description="Key to get"), ) -> str | None: - """Gets the client state""" + """Gets the client state for the current user (or system user if not authenticated)""" try: - return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key) + return ApiDependencies.invoker.services.client_state_persistence.get_by_key(current_user.user_id, key) except Exception as e: logging.error(f"Error getting client state: {e}") - raise HTTPException(status_code=500, detail="Error setting client state") + raise HTTPException(status_code=500, detail="Error getting client state") @client_state_router.post( @@ -30,13 +32,14 @@ async def get_client_state_by_key( response_model=str, ) async def set_client_state( - queue_id: str = Path(description="The queue id to perform this operation on"), + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), key: str = Query(..., description="Key to set"), value: str = Body(..., description="Stringified value to set"), ) -> str: - """Sets the client state""" + """Sets the client state for the current user (or system user if not authenticated)""" try: - return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value) + return ApiDependencies.invoker.services.client_state_persistence.set_by_key(current_user.user_id, key, value) except Exception as e: logging.error(f"Error setting client state: {e}") raise HTTPException(status_code=500, detail="Error setting client state") @@ -48,11 +51,12 @@ async def set_client_state( responses={204: {"description": "Client state deleted"}}, ) async def delete_client_state( - queue_id: str = Path(description="The queue id to perform this operation on"), + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), ) -> None: - """Deletes the client state""" + """Deletes the client state for the current user (or system user if not authenticated)""" try: - ApiDependencies.invoker.services.client_state_persistence.delete(queue_id) + ApiDependencies.invoker.services.client_state_persistence.delete(current_user.user_id) except Exception as e: logging.error(f"Error deleting client state: {e}") raise HTTPException(status_code=500, detail="Error deleting client state") diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py index 193561ef898..99ad71bc8b7 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py @@ -4,15 +4,16 @@ class ClientStatePersistenceABC(ABC): """ Base class for client persistence implementations. - This class defines the interface for persisting client data. + This class defines the interface for persisting client data per user. """ @abstractmethod - def set_by_key(self, queue_id: str, key: str, value: str) -> str: + def set_by_key(self, user_id: str, key: str, value: str) -> str: """ Set a key-value pair for the client. Args: + user_id (str): The user ID to set state for. key (str): The key to set. value (str): The value to set for the key. @@ -22,11 +23,12 @@ def set_by_key(self, queue_id: str, key: str, value: str) -> str: pass @abstractmethod - def get_by_key(self, queue_id: str, key: str) -> str | None: + def get_by_key(self, user_id: str, key: str) -> str | None: """ Get the value for a specific key of the client. Args: + user_id (str): The user ID to get state for. key (str): The key to retrieve the value for. Returns: @@ -35,8 +37,11 @@ def get_by_key(self, queue_id: str, key: str) -> str | None: pass @abstractmethod - def delete(self, queue_id: str) -> None: + def delete(self, user_id: str) -> None: """ - Delete all client state. + Delete all client state for a user. + + Args: + user_id (str): The user ID to delete state for. """ pass diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py index 36f22d96760..643db306857 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py @@ -1,5 +1,3 @@ -import json - from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase @@ -7,59 +5,51 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC): """ - Base class for client persistence implementations. - This class defines the interface for persisting client data. + SQLite implementation for client state persistence. + This class stores client state data per user to prevent data leakage between users. """ def __init__(self, db: SqliteDatabase) -> None: super().__init__() self._db = db - self._default_row_id = 1 def start(self, invoker: Invoker) -> None: self._invoker = invoker - def _get(self) -> dict[str, str] | None: + def set_by_key(self, user_id: str, key: str, value: str) -> str: with self._db.transaction() as cursor: cursor.execute( - f""" - SELECT data FROM client_state - WHERE id = {self._default_row_id} """ + INSERT INTO client_state (user_id, key, value) + VALUES (?, ?, ?) + ON CONFLICT(user_id, key) DO UPDATE + SET value = excluded.value; + """, + (user_id, key, value), ) - row = cursor.fetchone() - if row is None: - return None - return json.loads(row[0]) - def set_by_key(self, queue_id: str, key: str, value: str) -> str: - state = self._get() or {} - state.update({key: value}) + return value + def get_by_key(self, user_id: str, key: str) -> str | None: with self._db.transaction() as cursor: cursor.execute( - f""" - INSERT INTO client_state (id, data) - VALUES ({self._default_row_id}, ?) - ON CONFLICT(id) DO UPDATE - SET data = excluded.data; + """ + SELECT value FROM client_state + WHERE user_id = ? AND key = ? """, - (json.dumps(state),), + (user_id, key), ) + row = cursor.fetchone() + if row is None: + return None + return row[0] - return value - - def get_by_key(self, queue_id: str, key: str) -> str | None: - state = self._get() - if state is None: - return None - return state.get(key, None) - - def delete(self, queue_id: str) -> None: + def delete(self, user_id: str) -> None: with self._db.transaction() as cursor: cursor.execute( - f""" - DELETE FROM client_state - WHERE id = {self._default_row_id} """ + DELETE FROM client_state + WHERE user_id = ? + """, + (user_id,), ) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 54a0450084a..ecf769a9cf4 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -28,6 +28,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -73,6 +74,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_23(app_config=config, logger=logger)) migrator.register_migration(build_migration_24(app_config=config, logger=logger)) migrator.register_migration(build_migration_25()) + migrator.register_migration(build_migration_26()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_26.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_26.py new file mode 100644 index 00000000000..8f37404a81b --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_26.py @@ -0,0 +1,120 @@ +"""Migration 26: Add user_id to client_state table for multi-user support. + +This migration updates the client_state table to support per-user state isolation: +- Drops the single-row constraint (CHECK(id = 1)) +- Adds user_id column +- Creates unique constraint on (user_id, key) pairs +- Migrates existing data to 'system' user +""" + +import json +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration26Callback: + """Migration to add per-user client state support.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_client_state_table(cursor) + + def _update_client_state_table(self, cursor: sqlite3.Cursor) -> None: + """Restructure client_state table to support per-user storage.""" + # Check if client_state table exists + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='client_state';") + if cursor.fetchone() is None: + # Table doesn't exist, create it with the new schema + cursor.execute( + """ + CREATE TABLE client_state ( + user_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP), + PRIMARY KEY (user_id, key), + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ); + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);") + cursor.execute( + """ + CREATE TRIGGER tg_client_state_updated_at + AFTER UPDATE ON client_state + FOR EACH ROW + BEGIN + UPDATE client_state + SET updated_at = CURRENT_TIMESTAMP + WHERE user_id = OLD.user_id AND key = OLD.key; + END; + """ + ) + return + + # Table exists with old schema - migrate it + # Get existing data + cursor.execute("SELECT data FROM client_state WHERE id = 1;") + row = cursor.fetchone() + existing_data = {} + if row is not None: + try: + existing_data = json.loads(row[0]) + except (json.JSONDecodeError, TypeError): + # If data is corrupt, just start fresh + pass + + # Drop the old table + cursor.execute("DROP TABLE IF EXISTS client_state;") + + # Create new table with per-user schema + cursor.execute( + """ + CREATE TABLE client_state ( + user_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP), + PRIMARY KEY (user_id, key), + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ); + """ + ) + + cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);") + + cursor.execute( + """ + CREATE TRIGGER tg_client_state_updated_at + AFTER UPDATE ON client_state + FOR EACH ROW + BEGIN + UPDATE client_state + SET updated_at = CURRENT_TIMESTAMP + WHERE user_id = OLD.user_id AND key = OLD.key; + END; + """ + ) + + # Migrate existing data to 'system' user + # The 'system' user is created by migration 25, so it's guaranteed to exist at this point + for key, value in existing_data.items(): + cursor.execute( + """ + INSERT INTO client_state (user_id, key, value) + VALUES ('system', ?, ?); + """, + (key, value), + ) + + +def build_migration_26() -> Migration: + """Builds the migration object for migrating from version 25 to version 26. + + This migration adds per-user client state support to prevent data leakage between users. + """ + return Migration( + from_version=25, + to_version=26, + callback=Migration26Callback(), + ) diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts index 9e67770b436..fdb25b37d2c 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts @@ -68,10 +68,26 @@ const getIdbKey = (key: string) => { return `${IDB_STORAGE_PREFIX}${key}`; }; +// Helper to get auth headers for client_state requests +const getAuthHeaders = (): Record => { + const headers: Record = {}; + // Safe access to localStorage (not available in Node.js test environment) + if (typeof window !== 'undefined' && window.localStorage) { + const token = localStorage.getItem('auth_token'); + if (token) { + headers['Authorization'] = `Bearer ${token}`; + } + } + return headers; +}; + const getItem = async (key: string) => { try { const url = getUrl('get_by_key', key); - const res = await fetch(url, { method: 'GET' }); + const res = await fetch(url, { + method: 'GET', + headers: getAuthHeaders(), + }); if (!res.ok) { throw new Error(`Response status: ${res.status}`); } @@ -130,7 +146,11 @@ const setItem = async (key: string, value: string) => { } log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`); const url = getUrl('set_by_key', key); - const res = await fetch(url, { method: 'POST', body: value }); + const res = await fetch(url, { + method: 'POST', + body: value, + headers: getAuthHeaders(), + }); if (!res.ok) { throw new Error(`Response status: ${res.status}`); } @@ -158,7 +178,10 @@ export const clearStorage = async () => { try { persistRefCount++; const url = getUrl('delete'); - const res = await fetch(url, { method: 'POST' }); + const res = await fetch(url, { + method: 'POST', + headers: getAuthHeaders(), + }); if (!res.ok) { throw new Error(`Response status: ${res.status}`); } diff --git a/invokeai/frontend/web/src/features/auth/components/LoginPage.tsx b/invokeai/frontend/web/src/features/auth/components/LoginPage.tsx index 19ccf0949aa..e62b1289d06 100644 --- a/invokeai/frontend/web/src/features/auth/components/LoginPage.tsx +++ b/invokeai/frontend/web/src/features/auth/components/LoginPage.tsx @@ -52,13 +52,14 @@ export const LoginPage = memo(() => { is_active: result.user.is_active || true, }; dispatch(setCredentials({ token: result.token, user })); - // Navigate to main app after successful login - navigate('/app', { replace: true }); + // Force a page reload to ensure all user-specific state is loaded from server + // This is important for multiuser isolation to prevent state leakage + window.location.href = '/app'; } catch { // Error is handled by RTK Query and displayed via error state } }, - [email, password, rememberMe, login, dispatch, navigate] + [email, password, rememberMe, login, dispatch] ); const handleEmailChange = useCallback((e: ChangeEvent) => { diff --git a/invokeai/frontend/web/src/features/auth/store/authSlice.ts b/invokeai/frontend/web/src/features/auth/store/authSlice.ts index bcf932ca32d..6ac65ef03ce 100644 --- a/invokeai/frontend/web/src/features/auth/store/authSlice.ts +++ b/invokeai/frontend/web/src/features/auth/store/authSlice.ts @@ -21,9 +21,17 @@ const zAuthState = z.object({ type User = z.infer; type AuthState = z.infer; +// Helper to safely access localStorage (not available in test environment) +const getStoredAuthToken = (): string | null => { + if (typeof window !== 'undefined' && window.localStorage) { + return localStorage.getItem('auth_token'); + } + return null; +}; + const initialState: AuthState = { - isAuthenticated: !!localStorage.getItem('auth_token'), - token: localStorage.getItem('auth_token'), + isAuthenticated: !!getStoredAuthToken(), + token: getStoredAuthToken(), user: null, isLoading: false, }; @@ -38,13 +46,17 @@ const authSlice = createSlice({ state.token = action.payload.token; state.user = action.payload.user; state.isAuthenticated = true; - localStorage.setItem('auth_token', action.payload.token); + if (typeof window !== 'undefined' && window.localStorage) { + localStorage.setItem('auth_token', action.payload.token); + } }, logout: (state) => { state.token = null; state.user = null; state.isAuthenticated = false; - localStorage.removeItem('auth_token'); + if (typeof window !== 'undefined' && window.localStorage) { + localStorage.removeItem('auth_token'); + } }, setLoading: (state, action: PayloadAction) => { state.isLoading = action.payload; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 0190aba602b..b4b328704e0 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -6,6 +6,7 @@ import { deepClone } from 'common/util/deepClone'; import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple'; import { isPlainObject } from 'es-toolkit'; import { clamp } from 'es-toolkit/compat'; +import { logout } from 'features/auth/store/authSlice'; import type { AspectRatioID, InfillMethod, ParamsState, RgbaColor } from 'features/controlLayers/store/types'; import { ASPECT_RATIO_MAP, @@ -401,6 +402,12 @@ const slice = createSlice({ }, paramsReset: (state) => resetState(state), }, + extraReducers(builder) { + // Reset params state on logout to prevent user data leakage when switching users + builder.addCase(logout, () => { + return getInitialParamsState(); + }); + }, }); const applyClipSkip = (state: { clipSkip: number }, model: ParameterModel | null, clipSkip: number) => { diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 6562358551e..5ce98f98379 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2053,7 +2053,7 @@ export type paths = { }; /** * Get Client State By Key - * @description Gets the client state + * @description Gets the client state for the current user (or system user if not authenticated) */ get: operations["get_client_state_by_key"]; put?: never; @@ -2075,7 +2075,7 @@ export type paths = { put?: never; /** * Set Client State - * @description Sets the client state + * @description Sets the client state for the current user (or system user if not authenticated) */ post: operations["set_client_state"]; delete?: never; @@ -2095,7 +2095,7 @@ export type paths = { put?: never; /** * Delete Client State - * @description Deletes the client state + * @description Deletes the client state for the current user (or system user if not authenticated) */ post: operations["delete_client_state"]; delete?: never; @@ -30860,7 +30860,7 @@ export interface operations { }; header?: never; path: { - /** @description The queue id to perform this operation on */ + /** @description The queue id (ignored, kept for backwards compatibility) */ queue_id: string; }; cookie?: never; @@ -30895,7 +30895,7 @@ export interface operations { }; header?: never; path: { - /** @description The queue id to perform this operation on */ + /** @description The queue id (ignored, kept for backwards compatibility) */ queue_id: string; }; cookie?: never; @@ -30931,7 +30931,7 @@ export interface operations { query?: never; header?: never; path: { - /** @description The queue id to perform this operation on */ + /** @description The queue id (ignored, kept for backwards compatibility) */ queue_id: string; }; cookie?: never; diff --git a/tests/app/routers/test_client_state_multiuser.py b/tests/app/routers/test_client_state_multiuser.py new file mode 100644 index 00000000000..2b67e8c0165 --- /dev/null +++ b/tests/app/routers/test_client_state_multiuser.py @@ -0,0 +1,296 @@ +"""Tests for multiuser client state functionality.""" + +from typing import Any + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.api_app import app +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.users.users_common import UserCreateRequest + + +@pytest.fixture +def client(): + """Create a test client.""" + return TestClient(app) + + +class MockApiDependencies(ApiDependencies): + """Mock API dependencies for testing.""" + + invoker: Invoker + + def __init__(self, invoker: Invoker) -> None: + self.invoker = invoker + + +def setup_test_user( + mock_invoker: Invoker, email: str, display_name: str, password: str = "TestPass123", is_admin: bool = False +) -> str: + """Helper to create a test user and return user_id.""" + user_service = mock_invoker.services.users + user_data = UserCreateRequest( + email=email, + display_name=display_name, + password=password, + is_admin=is_admin, + ) + user = user_service.create(user_data) + return user.user_id + + +def get_user_token(client: TestClient, email: str, password: str = "TestPass123") -> str: + """Helper to login and get a user token.""" + response = client.post( + "/api/v1/auth/login", + json={ + "email": email, + "password": password, + "remember_me": False, + }, + ) + assert response.status_code == 200 + return response.json()["token"] + + +@pytest.fixture +def admin_token(monkeypatch: Any, mock_invoker: Invoker, client: TestClient): + """Get an admin token for testing.""" + # Mock ApiDependencies for auth and client_state routers + monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Create admin user + setup_test_user(mock_invoker, "admin@test.com", "Admin User", is_admin=True) + + return get_user_token(client, "admin@test.com") + + +@pytest.fixture +def user1_token(monkeypatch: Any, mock_invoker: Invoker, client: TestClient, admin_token: str): + """Get a token for test user 1.""" + # Create a regular user + setup_test_user(mock_invoker, "user1@test.com", "User One", is_admin=False) + + return get_user_token(client, "user1@test.com") + + +@pytest.fixture +def user2_token(monkeypatch: Any, mock_invoker: Invoker, client: TestClient, admin_token: str): + """Get a token for test user 2.""" + # Create another regular user + setup_test_user(mock_invoker, "user2@test.com", "User Two", is_admin=False) + + return get_user_token(client, "user2@test.com") + + +def test_get_client_state_without_auth_uses_system_user(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that getting client state without authentication uses the system user.""" + # Mock ApiDependencies + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set a value for the system user directly + mock_invoker.services.client_state_persistence.set_by_key("system", "test_key", "system_value") + + # Get without authentication - should return system user's value + response = client.get("/api/v1/client_state/default/get_by_key?key=test_key") + assert response.status_code == status.HTTP_200_OK + assert response.json() == "system_value" + + +def test_set_client_state_without_auth_uses_system_user(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that setting client state without authentication uses the system user.""" + # Mock ApiDependencies + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set without authentication - should set for system user + response = client.post( + "/api/v1/client_state/default/set_by_key?key=test_key", + json="unauthenticated_value", + ) + assert response.status_code == status.HTTP_200_OK + + # Verify it was set for system user + value = mock_invoker.services.client_state_persistence.get_by_key("system", "test_key") + assert value == "unauthenticated_value" + + +def test_delete_client_state_without_auth_uses_system_user(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that deleting client state without authentication uses the system user.""" + # Mock ApiDependencies + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set a value for system user + mock_invoker.services.client_state_persistence.set_by_key("system", "test_key", "system_value") + + # Delete without authentication - should delete system user's data + response = client.post("/api/v1/client_state/default/delete") + assert response.status_code == status.HTTP_200_OK + + # Verify it was deleted for system user + value = mock_invoker.services.client_state_persistence.get_by_key("system", "test_key") + assert value is None + + +def test_set_and_get_client_state(client: TestClient, admin_token: str): + """Test that authenticated users can set and get their client state.""" + # Set a value + set_response = client.post( + "/api/v1/client_state/default/set_by_key?key=test_key", + json="test_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert set_response.status_code == status.HTTP_200_OK + assert set_response.json() == "test_value" + + # Get the value back + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=test_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.status_code == status.HTTP_200_OK + assert get_response.json() == "test_value" + + +def test_client_state_isolation_between_users(client: TestClient, user1_token: str, user2_token: str): + """Test that client state is isolated between different users.""" + # User 1 sets a value + user1_set_response = client.post( + "/api/v1/client_state/default/set_by_key?key=shared_key", + json="user1_value", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert user1_set_response.status_code == status.HTTP_200_OK + + # User 2 sets a different value for the same key + user2_set_response = client.post( + "/api/v1/client_state/default/set_by_key?key=shared_key", + json="user2_value", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert user2_set_response.status_code == status.HTTP_200_OK + + # User 1 should still see their own value + user1_get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=shared_key", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert user1_get_response.status_code == status.HTTP_200_OK + assert user1_get_response.json() == "user1_value" + + # User 2 should see their own value + user2_get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=shared_key", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert user2_get_response.status_code == status.HTTP_200_OK + assert user2_get_response.json() == "user2_value" + + +def test_get_nonexistent_key_returns_null(client: TestClient, admin_token: str): + """Test that getting a nonexistent key returns null.""" + response = client.get( + "/api/v1/client_state/default/get_by_key?key=nonexistent_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + +def test_delete_client_state(client: TestClient, admin_token: str): + """Test that users can delete their own client state.""" + # Set some values + client.post( + "/api/v1/client_state/default/set_by_key?key=key1", + json="value1", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + client.post( + "/api/v1/client_state/default/set_by_key?key=key2", + json="value2", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + # Verify values exist + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=key1", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() == "value1" + + # Delete all client state + delete_response = client.post( + "/api/v1/client_state/default/delete", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert delete_response.status_code == status.HTTP_200_OK + + # Verify values are gone + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=key1", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() is None + + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=key2", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() is None + + +def test_update_existing_key(client: TestClient, admin_token: str): + """Test that updating an existing key works correctly.""" + # Set initial value + client.post( + "/api/v1/client_state/default/set_by_key?key=update_key", + json="initial_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + # Update the value + update_response = client.post( + "/api/v1/client_state/default/set_by_key?key=update_key", + json="updated_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert update_response.status_code == status.HTTP_200_OK + + # Verify the updated value + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=update_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.status_code == status.HTTP_200_OK + assert get_response.json() == "updated_value" + + +def test_complex_json_values(client: TestClient, admin_token: str): + """Test that complex JSON values can be stored and retrieved.""" + import json + + complex_dict = {"params": {"model": "test-model", "steps": 50}, "prompt": "a beautiful landscape"} + complex_value = json.dumps(complex_dict) + + # Set complex value + set_response = client.post( + "/api/v1/client_state/default/set_by_key?key=complex_key", + json=complex_value, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert set_response.status_code == status.HTTP_200_OK + + # Get it back + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=complex_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.status_code == status.HTTP_200_OK + assert get_response.json() == complex_value diff --git a/tests/conftest.py b/tests/conftest.py index 84e66b0501d..980a99611ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage from invokeai.app.services.boards.boards_default import BoardService from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage from invokeai.app.services.images.images_default import ImageService @@ -64,7 +65,7 @@ def mock_services() -> InvocationServices: workflow_thumbnails=None, # type: ignore model_relationship_records=None, # type: ignore model_relationships=None, # type: ignore - client_state_persistence=None, # type: ignore + client_state_persistence=ClientStatePersistenceSqlite(db=db), users=UserService(db), )