Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions backend/controllers/app_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package controllers
import (
"ccsync_backend/utils"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"net/http"
"os"
Expand All @@ -11,6 +13,15 @@ import (
"golang.org/x/oauth2"
)

// generateOAuthState creates a cryptographically secure random state string
func generateOAuthState() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}

type App struct {
Config *oauth2.Config
SessionStore *sessions.CookieStore
Expand All @@ -26,9 +37,27 @@ type App struct {
// @Accept json
// @Produce json
// @Success 307 {string} string "Redirect to OAuth provider"
// @Failure 500 {string} string "Internal server error"
// @Router /auth/oauth [get]
func (a *App) OAuthHandler(w http.ResponseWriter, r *http.Request) {
url := a.Config.AuthCodeURL("state", oauth2.AccessTypeOffline)
// Generate a cryptographically secure random state to prevent CSRF attacks
state, err := generateOAuthState()
if err != nil {
utils.Logger.Errorf("Failed to generate OAuth state: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}

// Store state in session for validation in callback
session, _ := a.SessionStore.Get(r, "session-name")
session.Values["oauth_state"] = state
if err := session.Save(r, w); err != nil {
utils.Logger.Errorf("Failed to save OAuth state to session: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}

url := a.Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
}

Expand All @@ -39,11 +68,25 @@ func (a *App) OAuthHandler(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Param code query string true "OAuth authorization code"
// @Param state query string true "OAuth state parameter for CSRF protection"
// @Success 303 {string} string "Redirect to frontend home page"
// @Failure 400 {string} string "Bad request"
// @Failure 403 {string} string "Invalid OAuth state"
// @Failure 500 {string} string "Internal server error"
// @Router /auth/callback [get]
func (a *App) OAuthCallbackHandler(w http.ResponseWriter, r *http.Request) {
// Validate OAuth state parameter to prevent CSRF attacks
state := r.URL.Query().Get("state")
session, _ := a.SessionStore.Get(r, "session-name")
expectedState, ok := session.Values["oauth_state"].(string)
if !ok || state == "" || state != expectedState {
utils.Logger.Warnf("OAuth state mismatch: expected=%s, got=%s", expectedState, state)
http.Error(w, "Invalid OAuth state", http.StatusForbidden)
return
}
// Clear the state from session after validation (one-time use)
delete(session.Values, "oauth_state")

code := r.URL.Query().Get("code")

t, err := a.Config.Exchange(context.Background(), code)
Expand Down Expand Up @@ -77,7 +120,6 @@ func (a *App) OAuthCallbackHandler(w http.ResponseWriter, r *http.Request) {

userInfo["uuid"] = uuidStr
userInfo["encryption_secret"] = encryptionSecret
session, _ := a.SessionStore.Get(r, "session-name")
session.Values["user"] = userInfo
if err := session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down
Loading