Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions invokeai/app/api/auth_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.auth.token_service import TokenData, verify_token
from invokeai.backend.util.logging import logging

logger = logging.getLogger(__name__)

# HTTP Bearer token security scheme
security = HTTPBearer(auto_error=False)
Expand Down Expand Up @@ -61,6 +64,45 @@ async def get_current_user(
return token_data


async def get_current_user_or_default(
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
) -> TokenData:
"""Get current authenticated user from Bearer token, or return a default system user if not authenticated.

This dependency is useful for endpoints that should work in both authenticated and non-authenticated contexts.
In single-user mode or when authentication is not provided, it returns a TokenData for the 'system' user.

Args:
credentials: The HTTP authorization credentials containing the Bearer token

Returns:
TokenData containing user information from the token, or system user if no credentials
"""
if credentials is None:
# Return system user for unauthenticated requests (single-user mode or backwards compatibility)
logger.debug("No authentication credentials provided, using system user")
return TokenData(user_id="system", email="system@system.invokeai", is_admin=False)

token = credentials.credentials
token_data = verify_token(token)

if token_data is None:
# Invalid token - still fall back to system user for backwards compatibility
logger.warning("Invalid or expired token provided, falling back to system user")
return TokenData(user_id="system", email="system@system.invokeai", is_admin=False)

# Verify user still exists and is active
user_service = ApiDependencies.invoker.services.users
user = user_service.get(token_data.user_id)

if user is None or not user.is_active:
# User doesn't exist or is inactive - fall back to system user
logger.warning(f"User {token_data.user_id} does not exist or is inactive, falling back to system user")
return TokenData(user_id="system", email="system@system.invokeai", is_admin=False)

return token_data


async def require_admin(
current_user: Annotated[TokenData, Depends(get_current_user)],
) -> TokenData:
Expand All @@ -82,4 +124,5 @@ async def require_admin(

# Type aliases for convenient use in route dependencies
CurrentUser = Annotated[TokenData, Depends(get_current_user)]
CurrentUserOrDefault = Annotated[TokenData, Depends(get_current_user_or_default)]
AdminUser = Annotated[TokenData, Depends(require_admin)]
24 changes: 14 additions & 10 deletions invokeai/app/api/routers/client_state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter

from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.util.logging import logging

Expand All @@ -13,15 +14,16 @@
response_model=str | None,
)
async def get_client_state_by_key(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
key: str = Query(..., description="Key to get"),
) -> str | None:
"""Gets the client state"""
"""Gets the client state for the current user (or system user if not authenticated)"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(current_user.user_id, key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
raise HTTPException(status_code=500, detail="Error getting client state")


@client_state_router.post(
Expand All @@ -30,13 +32,14 @@ async def get_client_state_by_key(
response_model=str,
)
async def set_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
key: str = Query(..., description="Key to set"),
value: str = Body(..., description="Stringified value to set"),
) -> str:
"""Sets the client state"""
"""Sets the client state for the current user (or system user if not authenticated)"""
try:
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(current_user.user_id, key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
Expand All @@ -48,11 +51,12 @@ async def set_client_state(
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
) -> None:
"""Deletes the client state"""
"""Deletes the client state for the current user (or system user if not authenticated)"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
ApiDependencies.invoker.services.client_state_persistence.delete(current_user.user_id)
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
This class defines the interface for persisting client data per user.
"""

@abstractmethod
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
def set_by_key(self, user_id: str, key: str, value: str) -> str:
"""
Set a key-value pair for the client.

Args:
user_id (str): The user ID to set state for.
key (str): The key to set.
value (str): The value to set for the key.

Expand All @@ -22,11 +23,12 @@ def set_by_key(self, queue_id: str, key: str, value: str) -> str:
pass

@abstractmethod
def get_by_key(self, queue_id: str, key: str) -> str | None:
def get_by_key(self, user_id: str, key: str) -> str | None:
"""
Get the value for a specific key of the client.

Args:
user_id (str): The user ID to get state for.
key (str): The key to retrieve the value for.

Returns:
Expand All @@ -35,8 +37,11 @@ def get_by_key(self, queue_id: str, key: str) -> str | None:
pass

@abstractmethod
def delete(self, queue_id: str) -> None:
def delete(self, user_id: str) -> None:
"""
Delete all client state.
Delete all client state for a user.

Args:
user_id (str): The user ID to delete state for.
"""
pass
Original file line number Diff line number Diff line change
@@ -1,65 +1,55 @@
import json

from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase


class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
SQLite implementation for client state persistence.
This class stores client state data per user to prevent data leakage between users.
"""

def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1

def start(self, invoker: Invoker) -> None:
self._invoker = invoker

def _get(self) -> dict[str, str] | None:
def set_by_key(self, user_id: str, key: str, value: str) -> str:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
INSERT INTO client_state (user_id, key, value)
VALUES (?, ?, ?)
ON CONFLICT(user_id, key) DO UPDATE
SET value = excluded.value;
""",
(user_id, key, value),
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])

def set_by_key(self, queue_id: str, key: str, value: str) -> str:
state = self._get() or {}
state.update({key: value})
return value

def get_by_key(self, user_id: str, key: str) -> str | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
"""
SELECT value FROM client_state
WHERE user_id = ? AND key = ?
""",
(json.dumps(state),),
(user_id, key),
)
row = cursor.fetchone()
if row is None:
return None
return row[0]

return value

def get_by_key(self, queue_id: str, key: str) -> str | None:
state = self._get()
if state is None:
return None
return state.get(key, None)

def delete(self, queue_id: str) -> None:
def delete(self, user_id: str) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
DELETE FROM client_state
WHERE user_id = ?
""",
(user_id,),
)
2 changes: 2 additions & 0 deletions invokeai/app/services/shared/sqlite/sqlite_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


Expand Down Expand Up @@ -73,6 +74,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_23(app_config=config, logger=logger))
migrator.register_migration(build_migration_24(app_config=config, logger=logger))
migrator.register_migration(build_migration_25())
migrator.register_migration(build_migration_26())
migrator.run_migrations()

return db
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Migration 26: Add user_id to client_state table for multi-user support.

This migration updates the client_state table to support per-user state isolation:
- Drops the single-row constraint (CHECK(id = 1))
- Adds user_id column
- Creates unique constraint on (user_id, key) pairs
- Migrates existing data to 'system' user
"""

import json
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration26Callback:
"""Migration to add per-user client state support."""

def __call__(self, cursor: sqlite3.Cursor) -> None:
self._update_client_state_table(cursor)

def _update_client_state_table(self, cursor: sqlite3.Cursor) -> None:
"""Restructure client_state table to support per-user storage."""
# Check if client_state table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='client_state';")
if cursor.fetchone() is None:
# Table doesn't exist, create it with the new schema
cursor.execute(
"""
CREATE TABLE client_state (
user_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP),
PRIMARY KEY (user_id, key),
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);")
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE user_id = OLD.user_id AND key = OLD.key;
END;
"""
)
return

# Table exists with old schema - migrate it
# Get existing data
cursor.execute("SELECT data FROM client_state WHERE id = 1;")
row = cursor.fetchone()
existing_data = {}
if row is not None:
try:
existing_data = json.loads(row[0])
except (json.JSONDecodeError, TypeError):
# If data is corrupt, just start fresh
pass

# Drop the old table
cursor.execute("DROP TABLE IF EXISTS client_state;")

# Create new table with per-user schema
cursor.execute(
"""
CREATE TABLE client_state (
user_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP),
PRIMARY KEY (user_id, key),
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
"""
)

cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);")

cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE user_id = OLD.user_id AND key = OLD.key;
END;
"""
)

# Migrate existing data to 'system' user
# The 'system' user is created by migration 25, so it's guaranteed to exist at this point
for key, value in existing_data.items():
cursor.execute(
"""
INSERT INTO client_state (user_id, key, value)
VALUES ('system', ?, ?);
""",
(key, value),
)


def build_migration_26() -> Migration:
"""Builds the migration object for migrating from version 25 to version 26.

This migration adds per-user client state support to prevent data leakage between users.
"""
return Migration(
from_version=25,
to_version=26,
callback=Migration26Callback(),
)
Loading