Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class Settings(BaseSettings):
"If not set, defaults to http://localhost:8087 when BASE_URL is http://localhost:8086, otherwise defaults to BASE_URL.",
)

# Proxy
trusted_proxy_ips: str | None = Field(
default=None,
description="Comma-separated trusted proxy IPs/CIDRs for X-Forwarded-For (e.g. '10.0.0.0/8'). "
"If unset, all proxies are trusted (only safe when the app is not directly exposed).",
)

# Database
database_url: PostgresDsn = Field(
default="postgresql://mfbt:iammfbt@localhost:5432/mfbt_dev",
Expand Down
34 changes: 28 additions & 6 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
import logging
from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from starlette.middleware.sessions import SessionMiddleware
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

from app.auth.trial import require_active_trial, require_tokens_available
from app.config import settings
from app.rate_limit import limiter
from app.routers import (
activity,
agent_api,
Expand Down Expand Up @@ -144,16 +148,25 @@ async def lifespan(app: FastAPI):
lifespan=lifespan,
)

# Trust X-Forwarded-For from reverse proxy so request.client.host is the real client IP
# (required for rate limiting to work correctly behind nginx/Traefik/load balancers)
# Set TRUSTED_PROXY_IPS to comma-separated IPs/CIDRs in production (e.g. "10.0.0.0/8")
_trusted_hosts = settings.trusted_proxy_ips.split(",") if settings.trusted_proxy_ips else ["*"]
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts=_trusted_hosts)

# Register rate limiter
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Configure CORS
# For development, allow frontend origin to support OAuth cookie flow
# Also allow other origins for MCP clients (but those won't get credentials)
_cors_origins = [settings.frontend_url]
if settings.is_development:
_cors_origins.extend(["http://localhost:8087", "http://127.0.0.1:8087"])
app.add_middleware(
CORSMiddleware,
allow_origins=[
settings.frontend_url, # Frontend origin for OAuth cookie flow
"http://localhost:8087", # Explicit localhost for development
"http://127.0.0.1:8087", # Alternative localhost
],
allow_origins=_cors_origins,
allow_credentials=True, # Required for cookies to work cross-origin
allow_methods=["*"], # Allow all methods
allow_headers=["*"], # Allow all headers
Expand Down Expand Up @@ -253,6 +266,15 @@ async def lifespan(app: FastAPI):
app.include_router(_rp.router, prefix=_prefix, dependencies=_deps)


@app.middleware("http")
async def add_security_headers(request: Request, call_next):
"""Add security headers to all responses."""
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
return response


@app.get("/health", tags=["health"])
async def health_check():
"""
Expand Down
6 changes: 6 additions & 0 deletions backend/app/rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Rate limiting configuration using slowapi."""

from slowapi import Limiter
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In-memory rate limiting broken under multiple workers

The Limiter is constructed here without a shared backing store, so it defaults to an in-process memory counter. In a production deployment with multiple uvicorn workers (e.g. --workers 4), each process holds its own independent counter. A client can therefore make limit × worker_count attempts per window — e.g. 5 register calls × 4 workers = 20 effective attempts per minute — defeating the purpose of the limit entirely.

The app already has a Redis connection configured in Settings. SlowAPI's Limiter can be pointed at Redis so that all workers share a single rate-limit bucket per client. Without a shared backing store, the rate limiting added in this PR will not function correctly in any multi-worker production environment.

Prompt To Fix With AI
This is a comment left during a code review.
Path: backend/app/rate_limit.py
Line: 6

Comment:
**In-memory rate limiting broken under multiple workers**

The `Limiter` is constructed here without a shared backing store, so it defaults to an in-process memory counter. In a production deployment with multiple uvicorn workers (e.g. `--workers 4`), each process holds its own independent counter. A client can therefore make `limit × worker_count` attempts per window — e.g. 5 register calls × 4 workers = 20 effective attempts per minute — defeating the purpose of the limit entirely.

The app already has a Redis connection configured in `Settings`. SlowAPI's `Limiter` can be pointed at Redis so that all workers share a single rate-limit bucket per client. Without a shared backing store, the rate limiting added in this PR will not function correctly in any multi-worker production environment.

How can I resolve this? If you propose a fix, please make it concise.

13 changes: 11 additions & 2 deletions backend/app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from app.database import get_async_db, get_db
from app.models.user import User
from app.plugin_registry import get_plugin_registry
from app.rate_limit import limiter
from app.schemas.api_key import ApiKeyCreate
from app.schemas.auth import (
OrgMembershipResponse,
Expand Down Expand Up @@ -143,7 +144,9 @@ def _get_known_provider_slugs() -> set[str]:


@router.post("/register", response_model=RegistrationResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("5/minute")
async def register(
request: Request,
user_data: UserCreate,
db: Annotated[Session, Depends(get_db)],
async_db: Annotated[AsyncSession, Depends(get_async_db)],
Expand Down Expand Up @@ -266,7 +269,9 @@ async def register(


@router.post("/login", response_model=TokenResponse)
@limiter.limit("10/minute")
def login(
request: Request,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[Session, Depends(get_db)],
) -> TokenResponse:
Expand Down Expand Up @@ -373,7 +378,9 @@ def get_current_user_info(


@router.get("/me/token", response_model=TokenResponse)
@limiter.limit("30/minute")
def get_session_token(
request: Request,
session_cookie: Annotated[str | None, Cookie(alias="mfbt_session")] = None,
db: Session = Depends(get_db),
) -> TokenResponse:
Expand Down Expand Up @@ -658,8 +665,10 @@ def verify_email(


@router.post("/resend-verification", response_model=ResendVerificationResponse)
@limiter.limit("3/minute")
async def resend_verification(
request: ResendVerificationRequest,
request: Request,
verification_request: ResendVerificationRequest,
db: Annotated[Session, Depends(get_db)],
async_db: Annotated[AsyncSession, Depends(get_async_db)],
) -> ResendVerificationResponse:
Expand All @@ -678,7 +687,7 @@ async def resend_verification(
Returns:
ResendVerificationResponse with generic message
"""
user = UserService.get_user_by_email(db, request.email)
user = UserService.get_user_by_email(db, verification_request.email)

# Always return the same message for security (don't reveal if email exists)
generic_message = "If an account with that email exists and is not yet verified, a verification email has been sent"
Expand Down
21 changes: 21 additions & 0 deletions backend/app/routers/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy.orm import Session

from app.auth.utils import decode_access_token
from app.config import settings
from app.models.user import User
from app.permissions.context import OrgContext
from app.services.user_service import UserService
Expand All @@ -19,6 +20,15 @@
router = APIRouter(prefix="/ws", tags=["websocket"])


def get_allowed_ws_origins() -> set[str]:
"""Build the set of allowed WebSocket origins (same as CORS config)."""
origins = {settings.frontend_url}
if settings.is_development:
origins.update({"http://localhost:8087", "http://127.0.0.1:8087"})
# Normalize: strip trailing slashes
return {o.rstrip("/") for o in origins}


async def get_current_user_ws(
token: str,
db: Session,
Expand Down Expand Up @@ -86,6 +96,17 @@ async def websocket_jobs_endpoint(
}
}
"""
# Validate Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH).
# Absent Origin is allowed: non-browser clients (MCP tools, CLI) don't send it.
# These clients are still authenticated via JWT token in the query parameter.
origin = websocket.headers.get("origin")
if origin is not None:
normalized_origin = origin.rstrip("/")
if normalized_origin not in get_allowed_ws_origins():
logger.warning(f"WebSocket rejected: disallowed origin={origin}")
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Origin not allowed")
return

# Authenticate user with a short-lived DB session (not held during WebSocket lifetime)
from app.database import SessionLocal

Expand Down
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
"mistune>=3.0",
"tavily-python>=0.5.0",
"redis>=7.1.0",
"slowapi>=0.1.9",
]

[project.optional-dependencies]
Expand Down
Loading
Loading