Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions invokeai/app/api/auth_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""FastAPI dependencies for authentication."""

from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.auth.token_service import TokenData, verify_token

# HTTP Bearer token security scheme
security = HTTPBearer(auto_error=False)


async def get_current_user(
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
) -> TokenData:
"""Get current authenticated user from Bearer token.

Note: This function accesses ApiDependencies.invoker.services.users directly,
which is the established pattern in this codebase. The ApiDependencies.invoker
is initialized in the FastAPI lifespan context before any requests are handled.

Args:
credentials: The HTTP authorization credentials containing the Bearer token

Returns:
TokenData containing user information from the token

Raises:
HTTPException: If token is missing, invalid, or expired (401 Unauthorized)
"""
if credentials is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)

token = credentials.credentials
token_data = verify_token(token)

if token_data is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired authentication token",
headers={"WWW-Authenticate": "Bearer"},
)

# Verify user still exists and is active
user_service = ApiDependencies.invoker.services.users
user = user_service.get(token_data.user_id)

if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account is inactive or does not exist",
headers={"WWW-Authenticate": "Bearer"},
)

return token_data


async def require_admin(
current_user: Annotated[TokenData, Depends(get_current_user)],
) -> TokenData:
"""Require admin role for the current user.

Args:
current_user: The current authenticated user's token data

Returns:
The token data if user is an admin

Raises:
HTTPException: If user does not have admin privileges (403 Forbidden)
"""
if not current_user.is_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
return current_user


# Type aliases for convenient use in route dependencies
CurrentUser = Annotated[TokenData, Depends(get_current_user)]
AdminUser = Annotated[TokenData, Depends(require_admin)]
3 changes: 3 additions & 0 deletions invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.users.users_default import UserService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
Expand Down Expand Up @@ -155,6 +156,7 @@ def initialize(
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
users = UserService(db=db)

services = InvocationServices(
board_image_records=board_image_records,
Expand Down Expand Up @@ -186,6 +188,7 @@ def initialize(
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
users=users,
)

ApiDependencies.invoker = Invoker(services)
Expand Down
201 changes: 201 additions & 0 deletions invokeai/app/api/routers/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Authentication endpoints."""

from datetime import timedelta
from typing import Annotated

from fastapi import APIRouter, Body, HTTPException, status
from pydantic import BaseModel, Field, field_validator

from invokeai.app.api.auth_dependencies import CurrentUser
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.auth.token_service import TokenData, create_access_token
from invokeai.app.services.users.users_common import UserCreateRequest, UserDTO, validate_email_with_special_domains

auth_router = APIRouter(prefix="/v1/auth", tags=["authentication"])

# Token expiration constants (in days)
TOKEN_EXPIRATION_NORMAL = 1 # 1 day for normal login
TOKEN_EXPIRATION_REMEMBER_ME = 7 # 7 days for "remember me" login


class LoginRequest(BaseModel):
"""Request body for user login."""

email: str = Field(description="User email address")
password: str = Field(description="User password")
remember_me: bool = Field(default=False, description="Whether to extend session duration")

@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email address, allowing special-use domains."""
return validate_email_with_special_domains(v)


class LoginResponse(BaseModel):
"""Response from successful login."""

token: str = Field(description="JWT access token")
user: UserDTO = Field(description="User information")
expires_in: int = Field(description="Token expiration time in seconds")


class SetupRequest(BaseModel):
"""Request body for initial admin setup."""

email: str = Field(description="Admin email address")
display_name: str | None = Field(default=None, description="Admin display name")
password: str = Field(description="Admin password")

@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email address, allowing special-use domains."""
return validate_email_with_special_domains(v)


class SetupResponse(BaseModel):
"""Response from successful admin setup."""

success: bool = Field(description="Whether setup was successful")
user: UserDTO = Field(description="Created admin user information")


class LogoutResponse(BaseModel):
"""Response from logout."""

success: bool = Field(description="Whether logout was successful")


@auth_router.post("/login", response_model=LoginResponse)
async def login(
request: Annotated[LoginRequest, Body(description="Login credentials")],
) -> LoginResponse:
"""Authenticate user and return access token.

Args:
request: Login credentials (email and password)

Returns:
LoginResponse containing JWT token and user information

Raises:
HTTPException: 401 if credentials are invalid or user is inactive
"""
user_service = ApiDependencies.invoker.services.users
user = user_service.authenticate(request.email, request.password)

if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)

if not user.is_active:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User account is disabled")

# Create token with appropriate expiration
expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME if request.remember_me else TOKEN_EXPIRATION_NORMAL)
token_data = TokenData(
user_id=user.user_id,
email=user.email,
is_admin=user.is_admin,
)
token = create_access_token(token_data, expires_delta)

return LoginResponse(
token=token,
user=user,
expires_in=int(expires_delta.total_seconds()),
)


@auth_router.post("/logout", response_model=LogoutResponse)
async def logout(
current_user: CurrentUser,
) -> LogoutResponse:
"""Logout current user.

Currently a no-op since we use stateless JWT tokens. For token invalidation in
future implementations, consider:
- Token blacklist: Store invalidated tokens in Redis/database with expiration
- Token versioning: Add version field to user record, increment on logout
- Short-lived tokens: Use refresh token pattern with token rotation
- Session storage: Track active sessions server-side for revocation

Args:
current_user: The authenticated user (validates token)

Returns:
LogoutResponse indicating success
"""
# TODO: Implement token invalidation when server-side session management is added
# For now, this is a no-op since we use stateless JWT tokens
return LogoutResponse(success=True)


@auth_router.get("/me", response_model=UserDTO)
async def get_current_user_info(
current_user: CurrentUser,
) -> UserDTO:
"""Get current authenticated user's information.

Args:
current_user: The authenticated user's token data

Returns:
UserDTO containing user information

Raises:
HTTPException: 404 if user is not found (should not happen normally)
"""
user_service = ApiDependencies.invoker.services.users
user = user_service.get(current_user.user_id)

if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

return user


@auth_router.post("/setup", response_model=SetupResponse)
async def setup_admin(
request: Annotated[SetupRequest, Body(description="Admin account details")],
) -> SetupResponse:
"""Set up initial administrator account.

This endpoint can only be called once, when no admin user exists. It creates
the first admin user for the system.

Args:
request: Admin account details (email, display_name, password)

Returns:
SetupResponse containing the created admin user

Raises:
HTTPException: 400 if admin already exists or password is weak
"""
user_service = ApiDependencies.invoker.services.users

# Check if any admin exists
if user_service.has_admin():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Administrator account already configured",
)

# Create admin user - this will validate password strength
try:
user_data = UserCreateRequest(
email=request.email,
display_name=request.display_name,
password=request.password,
is_admin=True,
)
user = user_service.create_admin(user_data)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

return SetupResponse(success=True, user=user)
3 changes: 3 additions & 0 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.api.routers import (
app_info,
auth,
board_images,
boards,
client_state,
Expand Down Expand Up @@ -121,6 +122,8 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):


# Include all routers
# Authentication router should be first so it's registered before protected routes
app.include_router(auth.auth_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
Expand Down
3 changes: 3 additions & 0 deletions invokeai/app/services/invocation_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.urls.urls_base import UrlServiceBase
from invokeai.app.services.users.users_base import UserServiceBase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_base import WorkflowThumbnailServiceBase
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
users: "UserServiceBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
Expand Down Expand Up @@ -105,3 +107,4 @@ def __init__(
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence
self.users = users
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,15 @@ def _update_style_presets_table(self, cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_style_presets_is_public ON style_presets(is_public);")

def _create_system_user(self, cursor: sqlite3.Cursor) -> None:
"""Create system user for backward compatibility."""
"""Create system user for backward compatibility.

The system user is NOT an admin - it's just used to own existing data
from before multi-user support was added. Real admin users should be
created through the /auth/setup endpoint.
"""
cursor.execute("""
INSERT OR IGNORE INTO users (user_id, email, display_name, password_hash, is_admin, is_active)
VALUES ('system', 'system@system.invokeai', 'System', '', TRUE, TRUE);
VALUES ('system', 'system@system.invokeai', 'System', '', FALSE, TRUE);
""")


Expand Down
Loading
Loading