Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion backend/controllers/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,19 @@ import (

"ccsync_backend/utils"

"github.com/gorilla/sessions"
"github.com/gorilla/websocket"
)

// getEnv returns the environment mode, defaulting to "development"
func getEnv() string {
env := os.Getenv("ENV")
if env == "" {
return "development"
}
return env
}

type JobStatus struct {
Job string `json:"job"`
Status string `json:"status"`
Expand All @@ -21,7 +31,7 @@ func checkWebSocketOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin")

// In development mode, be more permissive
if os.Getenv("ENV") != "production" {
if getEnv() != "production" {
if origin == "" ||
strings.HasPrefix(origin, "http://localhost") ||
strings.HasPrefix(origin, "http://127.0.0.1") {
Expand Down Expand Up @@ -73,6 +83,45 @@ var upgrader = websocket.Upgrader{
var clients = make(map[*websocket.Conn]bool)
var broadcast = make(chan JobStatus)

// AuthenticatedWebSocketHandler creates a WebSocket handler that requires authentication
func AuthenticatedWebSocketHandler(store *sessions.CookieStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Validate session before upgrading to WebSocket
session, err := store.Get(r, "session-name")
if err != nil {
utils.Logger.Warnf("WebSocket auth failed: could not get session: %v", err)
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

userInfo, ok := session.Values["user"].(map[string]interface{})
if !ok || userInfo == nil {
utils.Logger.Warnf("WebSocket auth failed: no user in session")
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

// User is authenticated, proceed with WebSocket upgrade
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
utils.Logger.Error("WebSocket Upgrade Error:", err)
return
}
defer ws.Close()

clients[ws] = true
for {
_, _, err := ws.ReadMessage()
if err != nil {
delete(clients, ws)
break
}
}
}
}

// WebSocketHandler is kept for backward compatibility but should not be used
// Use AuthenticatedWebSocketHandler instead
func WebSocketHandler(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func main() {

mux.HandleFunc("/health", controllers.HealthCheckHandler)

mux.HandleFunc("/ws", controllers.WebSocketHandler)
mux.HandleFunc("/ws", controllers.AuthenticatedWebSocketHandler(store))

// API documentation endpoint
mux.HandleFunc("/api/docs/", httpSwagger.WrapHandler)
Expand Down
Loading