diff --git a/backend/controllers/websocket.go b/backend/controllers/websocket.go index 6f318b50..5d245f5d 100644 --- a/backend/controllers/websocket.go +++ b/backend/controllers/websocket.go @@ -2,6 +2,9 @@ package controllers import ( "net/http" + "net/url" + "os" + "strings" "ccsync_backend/utils" @@ -13,10 +16,58 @@ type JobStatus struct { Status string `json:"status"` } -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { +// checkWebSocketOrigin validates the Origin header against allowed origins +func checkWebSocketOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + + // In development mode, be more permissive + if os.Getenv("ENV") != "production" { + if origin == "" || + strings.HasPrefix(origin, "http://localhost") || + strings.HasPrefix(origin, "http://127.0.0.1") { + return true + } + } + + // In production, require an origin header + if origin == "" { + utils.Logger.Warn("WebSocket connection rejected: missing Origin header in production") + return false + } + + // Check against configured allowed origin (exact match) + allowedOrigin := os.Getenv("ALLOWED_ORIGIN") + if allowedOrigin != "" && origin == allowedOrigin { return true - }, + } + + // Fallback: parse origin and compare hostname exactly with request host + parsedOrigin, err := url.Parse(origin) + if err != nil { + utils.Logger.Warnf("WebSocket connection rejected: invalid origin URL: %s", origin) + return false + } + + // Extract hostname from request Host header (may include port) + requestHost := r.Host + if idx := strings.LastIndex(requestHost, ":"); idx != -1 { + // Be careful with IPv6 addresses like [::1]:8080 + if !strings.HasPrefix(requestHost, "[") || idx > strings.Index(requestHost, "]") { + requestHost = requestHost[:idx] + } + } + + // Exact hostname comparison + if parsedOrigin.Hostname() == requestHost { + return true + } + + utils.Logger.Warnf("WebSocket connection rejected from origin: %s", origin) + return false +} + +var upgrader = websocket.Upgrader{ + CheckOrigin: checkWebSocketOrigin, } var clients = make(map[*websocket.Conn]bool) diff --git a/backend/main.go b/backend/main.go index b98b2436..00057bd1 100644 --- a/backend/main.go +++ b/backend/main.go @@ -80,6 +80,16 @@ func main() { utils.Logger.Fatal("SESSION_KEY environment variable is not set or empty") } store := sessions.NewCookieStore(sessionKey) + + // Configure secure cookie options + store.Options = &sessions.Options{ + Path: "/", + MaxAge: 86400 * 7, // 7 days + HttpOnly: true, // Prevent JavaScript access + Secure: os.Getenv("ENV") == "production", // HTTPS only in production + SameSite: http.SameSiteLaxMode, // CSRF protection (Lax allows OAuth redirects) + } + gob.Register(map[string]interface{}{}) app := controllers.App{Config: conf, SessionStore: store} diff --git a/production/example.backend.env b/production/example.backend.env index e01b533f..78f1e186 100644 --- a/production/example.backend.env +++ b/production/example.backend.env @@ -1,7 +1,20 @@ -REDIRECT_URL_DEV="http://localhost:8000/auth/callback" -SESSION_KEY="Random key" -CLIENT_SEC="Via Google Oauth" -CLIENT_ID="Via Google Oauth" -FRONTEND_ORIGIN_DEV="http://localhost" -CONTAINER_ORIGIN="http://production-syncserver-1:8080/" +# Environment: set to "production" for secure cookies and strict origin checking +ENV="production" + +# OAuth configuration +REDIRECT_URL_DEV="https://your-domain.com/auth/callback" +CLIENT_ID="your-google-oauth-client-id" +CLIENT_SEC="your-google-oauth-client-secret" + +# Session configuration (generate a random 32+ character key) +SESSION_KEY="generate-a-random-secret-key-here" + +# CORS and WebSocket origin (your frontend URL, no trailing slash) +FRONTEND_ORIGIN_DEV="https://your-domain.com" +ALLOWED_ORIGIN="https://your-domain.com" + +# Sync server container URL (internal Docker network) +CONTAINER_ORIGIN="http://syncserver:8080/" + +# Port (usually 8000) PORT=8000