diff --git a/backend/app/config/settings.py b/backend/app/config/settings.py index 124621a4a..5332a8cbd 100644 --- a/backend/app/config/settings.py +++ b/backend/app/config/settings.py @@ -1,8 +1,31 @@ +import os +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +# Security Settings +SECRET_KEY = os.getenv("SECRET_KEY", "dev-secret-key-change-in-production") +API_KEY = os.getenv("API_KEY", "dev-api-key-change-in-production") +ALGORITHM = os.getenv("ALGORITHM", "HS256") +ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30")) + +# CORS Settings +ALLOWED_ORIGINS = os.getenv( + "ALLOWED_ORIGINS", "http://localhost:1420,tauri://localhost" +).split(",") + +# Rate Limiting +RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")) +RATE_LIMIT_PER_HOUR = int(os.getenv("RATE_LIMIT_PER_HOUR", "1000")) + # Model Exports Path MODEL_EXPORTS_PATH = "app/models/ONNX_Exports" # Microservice URLs -SYNC_MICROSERVICE_URL = "http://localhost:8001/api/v1" +SYNC_MICROSERVICE_URL = os.getenv( + "SYNC_MICROSERVICE_URL", "http://localhost:8001/api/v1" +) CONFIDENCE_PERCENT = 0.6 # Object Detection Models: @@ -20,6 +43,6 @@ TEST_INPUT_PATH = "tests/inputs" TEST_OUTPUT_PATH = "tests/outputs" -DATABASE_PATH = "app/database/PictoPy.db" +DATABASE_PATH = os.getenv("DATABASE_PATH", "app/database/PictoPy.db") THUMBNAIL_IMAGES_PATH = "./images/thumbnails" IMAGES_PATH = "./images" diff --git a/backend/app/middleware/__init__.py b/backend/app/middleware/__init__.py new file mode 100644 index 000000000..34c3d5e45 --- /dev/null +++ b/backend/app/middleware/__init__.py @@ -0,0 +1,19 @@ +""" +Middleware package for PictoPy API. +""" + +from .auth import ( + create_access_token, + verify_token, + verify_api_key, + get_current_user, + get_current_user_optional, +) + +__all__ = [ + "create_access_token", + "verify_token", + "verify_api_key", + "get_current_user", + "get_current_user_optional", +] diff --git a/backend/app/middleware/auth.py b/backend/app/middleware/auth.py new file mode 100644 index 000000000..4c0d9e3d7 --- /dev/null +++ b/backend/app/middleware/auth.py @@ -0,0 +1,155 @@ +""" +Authentication middleware for PictoPy API. +Supports both JWT tokens and API key authentication. +""" + +from datetime import datetime, timedelta +from typing import Optional + +from fastapi import Depends, HTTPException, status, Header +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from jose import JWTError, jwt + +from app.config.settings import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES, API_KEY + +# Security schemes +security = HTTPBearer(auto_error=False) + + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """ + Create a JWT access token. + + Args: + data: Data to encode in the token + expires_delta: Token expiration time + + Returns: + Encoded JWT token + """ + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def verify_token(token: str) -> dict: + """ + Verify and decode a JWT token. + + Args: + token: JWT token to verify + + Returns: + Decoded token payload + + Raises: + HTTPException: If token is invalid or expired + """ + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + return payload + except JWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def verify_api_key(x_api_key: Optional[str] = Header(None)) -> bool: + """ + Verify API key from header for Tauri application. + + Args: + x_api_key: API key from X-API-Key header + + Returns: + True if API key is valid + + Raises: + HTTPException: If API key is invalid or missing + """ + if not x_api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key is missing", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + if x_api_key != API_KEY: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid API key", + ) + + return True + + +async def get_current_user_optional( + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + x_api_key: Optional[str] = Header(None), +) -> Optional[dict]: + """ + Get current user from JWT token or API key (optional authentication). + Used for endpoints that work with or without authentication. + + Args: + credentials: HTTP Bearer credentials + x_api_key: API key from header + + Returns: + User data if authenticated, None otherwise + """ + # Check API key first (for Tauri app) + if x_api_key and x_api_key == API_KEY: + return {"authenticated_via": "api_key", "client": "tauri"} + + # Check JWT token + if credentials and credentials.credentials: + try: + payload = verify_token(credentials.credentials) + return payload + except HTTPException: + return None + + return None + + +async def get_current_user( + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + x_api_key: Optional[str] = Header(None), +) -> dict: + """ + Get current user from JWT token or API key (required authentication). + Used for protected endpoints that require authentication. + + Args: + credentials: HTTP Bearer credentials + x_api_key: API key from header + + Returns: + User data + + Raises: + HTTPException: If authentication fails + """ + # Check API key first (for Tauri app) + if x_api_key and x_api_key == API_KEY: + return {"authenticated_via": "api_key", "client": "tauri"} + + # Check JWT token + if credentials and credentials.credentials: + payload = verify_token(credentials.credentials) + return payload + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/backend/app/routes/auth.py b/backend/app/routes/auth.py new file mode 100644 index 000000000..06b9d0589 --- /dev/null +++ b/backend/app/routes/auth.py @@ -0,0 +1,104 @@ +""" +Authentication routes for PictoPy API. +""" + +from datetime import timedelta +from typing import Optional + +from fastapi import APIRouter, HTTPException, status, Header +from pydantic import BaseModel + +from app.middleware.auth import create_access_token, verify_api_key +from app.config.settings import ACCESS_TOKEN_EXPIRE_MINUTES, API_KEY + +router = APIRouter() + + +class TokenRequest(BaseModel): + """Request model for token generation.""" + + client_id: str + api_key: str + + +class TokenResponse(BaseModel): + """Response model for token generation.""" + + access_token: str + token_type: str + expires_in: int + + +class AuthStatusResponse(BaseModel): + """Response model for auth status check.""" + + authenticated: bool + auth_method: Optional[str] = None + message: str + + +@router.post( + "/token", + response_model=TokenResponse, + summary="Generate JWT Token", + description="Generate a JWT access token using API key authentication. Used for testing or future web interface.", +) +async def generate_token(request: TokenRequest): + """ + Generate a JWT access token. + + Args: + request: Token request containing client_id and api_key + + Returns: + Access token and expiration info + + Raises: + HTTPException: If API key is invalid + """ + # Verify API key + if request.api_key != API_KEY: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid API key", + ) + + # Create access token + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": request.client_id, "client": "web"}, expires_delta=access_token_expires + ) + + return TokenResponse( + access_token=access_token, + token_type="bearer", + expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60, # in seconds + ) + + +@router.get( + "/status", + response_model=AuthStatusResponse, + summary="Check Authentication Status", + description="Check if the provided API key is valid.", +) +async def check_auth_status(x_api_key: Optional[str] = Header(None)): + """ + Check authentication status. + + Args: + x_api_key: API key from X-API-Key header + + Returns: + Authentication status + """ + if x_api_key and x_api_key == API_KEY: + return AuthStatusResponse( + authenticated=True, + auth_method="api_key", + message="Successfully authenticated via API key", + ) + + return AuthStatusResponse( + authenticated=False, auth_method=None, message="Not authenticated" + ) diff --git a/backend/main.py b/backend/main.py index 1abfd8fdc..00066a832 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,10 +7,14 @@ import json from uvicorn import Config, Server -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.trustedhost import TrustedHostMiddleware from contextlib import asynccontextmanager from concurrent.futures import ProcessPoolExecutor +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded from app.database.faces import db_create_faces_table from app.database.images import db_create_images_table from app.database.face_clusters import db_create_clusters_table @@ -26,7 +30,13 @@ from app.routes.images import router as images_router from app.routes.face_clusters import router as face_clusters_router from app.routes.user_preferences import router as user_preferences_router +from app.routes.auth import router as auth_router from fastapi.openapi.utils import get_openapi +from app.config.settings import ALLOWED_ORIGINS, RATE_LIMIT_PER_MINUTE + + +# Initialize rate limiter +limiter = Limiter(key_func=get_remote_address, default_limits=[f"{RATE_LIMIT_PER_MINUTE}/minute"]) @asynccontextmanager @@ -42,8 +52,9 @@ async def lifespan(app: FastAPI): db_create_album_images_table() db_create_metadata_table() microservice_util_start_sync_service() - # Create ProcessPoolExecutor and attach it to app.state - app.state.executor = ProcessPoolExecutor(max_workers=1) + # Create ProcessPoolExecutor with optimal worker count + max_workers = max(1, multiprocessing.cpu_count() - 1) + app.state.executor = ProcessPoolExecutor(max_workers=max_workers) try: yield @@ -65,6 +76,10 @@ async def lifespan(app: FastAPI): ], ) +# Attach rate limiter to app state +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + def generate_openapi_json(): try: @@ -92,22 +107,54 @@ def generate_openapi_json(): print(f"Failed to generate openapi.json: {e}") -# Add CORS middleware +# Add security middleware - Trusted Host +app.add_middleware( + TrustedHostMiddleware, + allowed_hosts=["localhost", "127.0.0.1", "0.0.0.0"], +) + +# Add CORS middleware with restricted origins app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins + allow_origins=ALLOWED_ORIGINS, # Only allow specific origins from config allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=[ + "Content-Type", + "Authorization", + "X-API-Key", + "Accept", + ], + expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"], ) -# Basic health check endpoint +# Add security headers middleware +@app.middleware("http") +async def add_security_headers(request: Request, call_next): + """Add security headers to all responses.""" + response = await call_next(request) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" + return response + + +# Basic health check endpoint (no rate limit) @app.get("/health", tags=["Health"]) +@limiter.exempt async def root(): - return {"message": "PictoPy Server is up and running!"} + """Health check endpoint to verify server status.""" + return { + "status": "healthy", + "message": "PictoPy Server is up and running!", + "version": app.version, + } +# Include routers +app.include_router(auth_router, prefix="/auth", tags=["Authentication"]) app.include_router(folders_router, prefix="/folders", tags=["Folders"]) app.include_router(albums_router, prefix="/albums", tags=["Albums"]) app.include_router(images_router, prefix="/images", tags=["Images"]) diff --git a/backend/requirements.txt b/backend/requirements.txt index 00bec243a..922ad1b9d 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -8,6 +8,8 @@ email_validator==2.1.1 exceptiongroup==1.2.1 fastapi==0.111.0 fastapi-cli==0.0.3 +python-jose[cryptography]==3.3.0 +slowapi==0.1.9 flatbuffers==24.3.25 h11==0.14.0 h2==4.1.0 diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py new file mode 100644 index 000000000..e87250a69 --- /dev/null +++ b/backend/tests/test_auth.py @@ -0,0 +1,122 @@ +""" +Tests for authentication system. +""" + +import pytest +from fastapi.testclient import TestClient +from app.main import app +from app.middleware.auth import create_access_token, verify_token +from app.config.settings import API_KEY + +client = TestClient(app) + + +class TestAuthentication: + """Test authentication endpoints and middleware.""" + + def test_health_endpoint_no_auth(self): + """Test that health endpoint works without authentication.""" + response = client.get("/health") + assert response.status_code == 200 + assert response.json()["status"] == "healthy" + + def test_auth_status_without_api_key(self): + """Test auth status endpoint without API key.""" + response = client.get("/auth/status") + assert response.status_code == 200 + data = response.json() + assert data["authenticated"] is False + assert data["auth_method"] is None + + def test_auth_status_with_valid_api_key(self): + """Test auth status endpoint with valid API key.""" + response = client.get("/auth/status", headers={"X-API-Key": API_KEY}) + assert response.status_code == 200 + data = response.json() + assert data["authenticated"] is True + assert data["auth_method"] == "api_key" + + def test_auth_status_with_invalid_api_key(self): + """Test auth status endpoint with invalid API key.""" + response = client.get("/auth/status", headers={"X-API-Key": "invalid-key"}) + assert response.status_code == 200 + data = response.json() + assert data["authenticated"] is False + + def test_generate_token_with_valid_api_key(self): + """Test JWT token generation with valid API key.""" + response = client.post( + "/auth/token", + json={"client_id": "test-client", "api_key": API_KEY}, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + assert data["expires_in"] > 0 + + def test_generate_token_with_invalid_api_key(self): + """Test JWT token generation with invalid API key.""" + response = client.post( + "/auth/token", + json={"client_id": "test-client", "api_key": "invalid-key"}, + ) + assert response.status_code == 403 + assert response.json()["detail"] == "Invalid API key" + + def test_create_and_verify_token(self): + """Test JWT token creation and verification.""" + # Create token + token = create_access_token(data={"sub": "test-user", "client": "test"}) + assert token is not None + + # Verify token + payload = verify_token(token) + assert payload["sub"] == "test-user" + assert payload["client"] == "test" + assert "exp" in payload + + def test_verify_invalid_token(self): + """Test verification of invalid JWT token.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + verify_token("invalid-token") + assert exc_info.value.status_code == 401 + + def test_cors_headers_present(self): + """Test that CORS headers are properly set.""" + response = client.options("/health") + # Should not fail with CORS error + assert response.status_code in [200, 405] # 405 if OPTIONS not explicitly handled + + def test_security_headers_present(self): + """Test that security headers are present in responses.""" + response = client.get("/health") + headers = response.headers + + # Check security headers + assert "x-content-type-options" in headers + assert headers["x-content-type-options"] == "nosniff" + assert "x-frame-options" in headers + assert headers["x-frame-options"] == "DENY" + assert "x-xss-protection" in headers + assert "strict-transport-security" in headers + + +class TestRateLimiting: + """Test rate limiting functionality.""" + + def test_rate_limit_headers_present(self): + """Test that rate limit headers are present in responses.""" + response = client.get("/auth/status") + headers = response.headers + + # slowapi adds rate limit headers + # Note: These might not always be present in test environment + # This is more of a smoke test + assert response.status_code == 200 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/sync-microservice/main.py b/sync-microservice/main.py index db62c605d..6b45ee6f0 100644 --- a/sync-microservice/main.py +++ b/sync-microservice/main.py @@ -3,6 +3,12 @@ from app.core.lifespan import lifespan from app.routes import health, watcher, folders from fastapi.middleware.cors import CORSMiddleware +import os + +# Load allowed origins from environment or use defaults +ALLOWED_ORIGINS = os.getenv( + "ALLOWED_ORIGINS", "http://localhost:8000,http://localhost:1420,tauri://localhost" +).split(",") # Create FastAPI app with lifespan management app = FastAPI( @@ -11,12 +17,14 @@ version="1.0.0", lifespan=lifespan, ) + +# Add CORS middleware with restricted origins app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins + allow_origins=ALLOWED_ORIGINS, # Only allow specific origins allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-API-Key", "Accept"], ) # Include route modules app.include_router(health.router, prefix="/api/v1")