Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions backend/api/api_gateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
# Load environment variables
load_dotenv()

MAX_REQUEST_BODY_SIZE = 1 * 1024 * 1024 # 1MB
EXCLUDED_HEADERS = {
"host", "connection", "keep-alive", "proxy-authenticate",
"proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade",
"content-length",
}

# Create FastAPI app
app = FastAPI(
title="TaskHub API Gateway",
Expand Down Expand Up @@ -89,24 +96,38 @@ async def forward_request(
Returns:
JSONResponse: Response from service
"""
# Get request body
body = await request.body()

# Get request headers
headers = dict(request.headers)
# Filter headers
temp_headers = {}
for name, value in request.headers.items():
if name.lower() not in EXCLUDED_HEADERS:
temp_headers[name] = value

# Add user ID to headers if available
if hasattr(request.state, "user_id"):
headers["X-User-ID"] = request.state.user_id
temp_headers["X-User-ID"] = str(request.state.user_id)

# Prepare arguments for circuit_breaker.call_service
request_body = await request.body()

if len(request_body) > MAX_REQUEST_BODY_SIZE:
return JSONResponse(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content={"detail": f"Request body exceeds maximum allowed size of {MAX_REQUEST_BODY_SIZE} bytes."}
)

service_kwargs = {
"headers": temp_headers,
"params": dict(request.query_params)
}

if request.method.upper() not in ("GET", "HEAD", "DELETE"):
service_kwargs["content"] = request_body

# Forward request to service using circuit breaker
response = await circuit_breaker.call_service( # type: ignore
service_name=service_name,
url=target_url,
method=request.method,
headers=headers,
content=body,
params=dict(request.query_params),
**service_kwargs
)

# Return response
Expand Down
84 changes: 38 additions & 46 deletions backend/api/api_gateway/middleware/auth_middleware.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import os
from typing import Awaitable, Callable, Optional

import httpx
from dotenv import load_dotenv
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from jose import ExpiredSignatureError, JWTError, jwt

# Load environment variables
load_dotenv()

# Auth service URL
AUTH_SERVICE_URL = os.getenv("AUTH_SERVICE_URL", "http://localhost:8001")
SUPABASE_JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET")
SUPABASE_AUDIENCE = os.getenv("SUPABASE_AUDIENCE", "authenticated")
# Optional: Add SUPABASE_ISSUER if you want to validate the 'iss' claim, e.g.:
# SUPABASE_ISSUER = os.getenv("SUPABASE_ISSUER")


async def auth_middleware(
request: Request, call_next: Callable[[Request], Awaitable[JSONResponse]]
) -> JSONResponse:
if request.method == "OPTIONS":
return await call_next(request)
"""
Middleware for authentication.

Expand Down Expand Up @@ -102,56 +106,44 @@ def _get_token_from_request(request: Request) -> Optional[str]:


async def _validate_token(token: str) -> str:
"""
Validate token with auth service.

Args:
token (str): JWT token

Returns:
str: User ID
if not SUPABASE_JWT_SECRET:
print('ERROR: SUPABASE_JWT_SECRET is not configured in the environment.')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Authentication system configuration error.',
)

Raises:
HTTPException: If token is invalid
"""
try:
# Make request to auth service
async with httpx.AsyncClient() as client:
response = await client.get(
f"{AUTH_SERVICE_URL}/auth/validate",
headers={"Authorization": f"Bearer {token}"},
payload = jwt.decode(
token,
SUPABASE_JWT_SECRET,
algorithms=['HS256'],
audience=SUPABASE_AUDIENCE
# If validating issuer, add: issuer=SUPABASE_ISSUER
)

user_id = payload.get('sub')
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Invalid token: User ID (sub) not found in token.',
)

return user_id

# Check response
if response.status_code != 200:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)

# Parse response
data = response.json()

# Extract user ID from token
# In a real application, you would decode the token and extract the user ID
# For simplicity, we'll assume the auth service returns the user ID
user_id = data.get("user_id")

if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token, user_id not in response",
)

return user_id
except httpx.RequestError as e:
except ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail='Token has expired.'
)
except JWTError as e:
print(f'JWTError during token validation: {str(e)}') # Server log
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Auth service unavailable: {str(e)}",
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Invalid token.',
)
except Exception as e:
# It's good practice to log the error here
# logger.error(f"Unexpected error during token validation with auth service: {str(e)}")
print(f'Unexpected error during token validation: {str(e)}') # Server log
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred while validating the token.",
detail='An unexpected error occurred during token validation.',
)
4 changes: 2 additions & 2 deletions backend/api/api_gateway/middleware/circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: int = 30,
timeout: float = 5.0,
timeout: float = 10.0,
):
"""
Initialize CircuitBreaker.

Args:
failure_threshold (int, optional): Number of failures before opening circuit. Defaults to 5.
recovery_timeout (int, optional): Seconds to wait before trying again. Defaults to 30.
timeout (float, optional): Request timeout in seconds. Defaults to 5.0.
timeout (float, optional): Request timeout in seconds. Defaults to 10.0.
"""
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
Expand Down
3 changes: 3 additions & 0 deletions backend/api/api_gateway/utils/service_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def __init__(self):
"methods": ["POST"],
},
{"path": "/health", "methods": ["GET"]},
{"path": "/analytics/card/{card_id}", "methods": ["GET"]},
{"path": "/calendar/events", "methods": ["GET", "POST"]},
{"path": "/ai/inference/{model}", "methods": ["POST"]},
],
},
}
Expand Down
30 changes: 2 additions & 28 deletions backend/api/auth_service/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from api.auth_service.app.schemas.user import (
TokenDTO,
TokenValidationResponseDTO,
# TokenValidationResponseDTO, # No longer needed
UserProfileDTO,
UserRegisterDTO,
)
Expand Down Expand Up @@ -67,33 +67,6 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
return auth_service.login(form_data.username, form_data.password)


@app.get(
"/auth/validate", response_model=TokenValidationResponseDTO, tags=["Authentication"]
)
async def validate(token: str = Security(oauth2_scheme)):
"""
Validate a token. Also returns user_id along with new tokens.

Args:
token (str): JWT token
"""
return auth_service.validate_token(token)


@app.post("/auth/refresh", response_model=TokenDTO, tags=["Authentication"])
async def refresh(refresh_token: str) -> Any:
"""
Refresh a token.

Args:
refresh_token (str): Refresh token

Returns:
TokenDTO: Authentication tokens
"""
return auth_service.refresh_token(refresh_token)


@app.post("/auth/logout", tags=["Authentication"])
async def logout(token: str = Security(oauth2_scheme)):
"""
Expand Down Expand Up @@ -131,3 +104,4 @@ async def health_check() -> Any:
Dict[str, str]: Health status
"""
return {"status": "healthy"}

Loading
Loading