diff --git a/internal/config/validation.go b/internal/config/validation.go index 57c156ff..accb87b3 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -20,6 +20,21 @@ var varExprPattern = regexp.MustCompile(`\$\{([A-Za-z_][A-Za-z0-9_]*)\}`) var logValidation = logger.New("config:validation") +// logValidateServerStart logs the beginning of server config validation. +func logValidateServerStart(name, serverType string) { + logValidation.Printf("Validating server config: name=%s, type=%s", name, serverType) +} + +// logValidateServerPassed logs a successful server config validation. +func logValidateServerPassed(name string) { + logValidation.Printf("Server config validation passed: name=%s", name) +} + +// logValidateServerFailed logs a failed server config validation with the given reason. +func logValidateServerFailed(name, reason string) { + logValidation.Printf("Validation failed: %s, name=%s", reason, name) +} + // expandVariablesCore is the shared implementation for variable expansion. // It works with byte slices and handles the core expansion logic, tracking undefined variables. // This eliminates code duplication between expandVariables and ExpandRawJSONVariables. @@ -105,7 +120,7 @@ func validateMounts(mounts []string, jsonPath string) error { // validateServerConfigWithCustomSchemas validates a server configuration with custom schema support func validateServerConfigWithCustomSchemas(name string, server *StdinServerConfig, customSchemas map[string]interface{}) error { - logValidation.Printf("Validating server config: name=%s, type=%s", name, server.Type) + logValidateServerStart(name, server.Type) jsonPath := fmt.Sprintf("mcpServers.%s", name) // Validate type (empty defaults to stdio) @@ -134,7 +149,7 @@ func validateStandardServerConfig(name string, server *StdinServerConfig, jsonPa // For stdio servers, container is required if server.Type == "stdio" || server.Type == "local" { if server.Container == "" { - logValidation.Printf("Validation failed: stdio server missing container field, name=%s", name) + logValidateServerFailed(name, "stdio server missing container field") return rules.MissingRequired("container", "stdio", jsonPath, "Add a 'container' field (e.g., \"ghcr.io/owner/image:tag\")") } @@ -150,16 +165,16 @@ func validateStandardServerConfig(name string, server *StdinServerConfig, jsonPa // For HTTP servers, url is required and mounts are not allowed if server.Type == "http" { if server.URL == "" { - logValidation.Printf("Validation failed: HTTP server missing url field, name=%s", name) + logValidateServerFailed(name, "HTTP server missing url field") return rules.MissingRequired("url", "HTTP", jsonPath, "Add a 'url' field (e.g., \"https://example.com/mcp\")") } if len(server.Mounts) > 0 { - logValidation.Printf("Validation failed: HTTP server has mounts field, name=%s", name) + logValidateServerFailed(name, "HTTP server has mounts field") return rules.UnsupportedField("mounts", "mounts are only supported for stdio (containerized) servers", jsonPath, "Remove the 'mounts' field from HTTP server configuration; mounts only apply to stdio servers") } } - logValidation.Printf("Server config validation passed: name=%s", name) + logValidateServerPassed(name) return nil } @@ -403,7 +418,7 @@ func validateTOMLStdioContainerization(servers map[string]*ServerConfig) error { // Check if command is Docker if cfg.Command != "docker" { - logValidation.Printf("Validation failed: stdio server using non-Docker command, name=%s, command=%s", name, cfg.Command) + logValidateServerFailed(name, fmt.Sprintf("stdio server using non-Docker command, command=%s", cfg.Command)) return fmt.Errorf( "server '%s': stdio servers must use containerized execution (command must be 'docker', got '%s'). "+ "This is required by MCP Gateway Specification Section 3.2.1 (Containerization Requirement). "+ diff --git a/internal/logger/common.go b/internal/logger/common.go index 5f947b08..876d0c86 100644 --- a/internal/logger/common.go +++ b/internal/logger/common.go @@ -11,27 +11,24 @@ import ( // Close Pattern for Logger Types // -// All logger types in this package should implement their Close() method using this pattern: +// All logger types in this package implement their Close() method using the withLock +// helper to ensure consistent mutex handling: // // func (l *Logger) Close() error { -// l.mu.Lock() -// defer l.mu.Unlock() -// -// // Optional: Perform cleanup before closing (e.g., write footer) -// // if l.logFile != nil { -// // if err := writeCleanup(); err != nil { -// // return closeLogFile(l.logFile, &l.mu, "loggerName") -// // } -// // } -// -// return closeLogFile(l.logFile, &l.mu, "loggerName") +// return l.withLock(func() error { +// // Optional: Perform cleanup before closing (e.g., write footer) +// return closeLogFile(l.logFile, &l.mu, "loggerName") +// }) // } // +// The withLock helper (defined on each logger type) acquires the mutex, executes the +// callback, then releases the mutex — ensuring the lock is always released via defer. +// // Why this pattern? // -// 1. Mutex protection: Acquire lock at method entry to ensure thread-safe cleanup -// 2. Deferred unlock: Use defer to release lock even if errors occur -// 3. Optional cleanup: Logger-specific cleanup (like MarkdownLogger's footer) goes before closeLogFile +// 1. Consistent locking: withLock enforces acquire-on-enter / release-on-exit +// 2. Deferred unlock: Implemented inside withLock using defer, so it's never forgotten +// 3. Optional cleanup: Logger-specific cleanup (like MarkdownLogger's footer) goes inside the callback // 4. Shared helper: Always delegate to closeLogFile() for consistent sync and close behavior // 5. Error handling: Return errors from closeLogFile to indicate serious issues // @@ -40,38 +37,28 @@ import ( // Simple Close() with no cleanup (FileLogger, JSONLLogger): // // func (fl *FileLogger) Close() error { -// fl.mu.Lock() -// defer fl.mu.Unlock() -// return closeLogFile(fl.logFile, &fl.mu, "file") +// return fl.withLock(func() error { +// return closeLogFile(fl.logFile, &fl.mu, "file") +// }) // } // // Close() with custom cleanup (MarkdownLogger): // // func (ml *MarkdownLogger) Close() error { -// ml.mu.Lock() -// defer ml.mu.Unlock() -// -// if ml.logFile != nil { -// // Write closing details tag before closing -// footer := "\n\n" -// if _, err := ml.logFile.WriteString(footer); err != nil { -// // Even if footer write fails, try to close the file properly +// return ml.withLock(func() error { +// if ml.logFile != nil { +// footer := "\n\n" +// if _, err := ml.logFile.WriteString(footer); err != nil { +// return closeLogFile(ml.logFile, &ml.mu, "markdown") +// } // return closeLogFile(ml.logFile, &ml.mu, "markdown") // } -// -// // Footer written successfully, now close -// return closeLogFile(ml.logFile, &ml.mu, "markdown") -// } -// return nil +// return nil +// }) // } // -// This pattern is intentionally duplicated across logger types rather than abstracted: -// - It's a standard Go idiom for wrapper methods -// - The duplication is minimal (5-14 lines per type) -// - Each logger can customize cleanup as needed -// - The shared closeLogFile() helper eliminates complex logic duplication -// -// When adding a new logger type, follow this pattern to ensure consistent behavior. +// When adding a new logger type, add a withLock helper and follow this pattern to ensure +// consistent, safe Close() behavior. // Initialization Pattern for Logger Types // @@ -81,14 +68,12 @@ import ( // // Standard Initialization Pattern: // -// All logger types use the initLogger() generic helper function for initialization: +// All logger types use the initLogger() generic helper function for initialization. +// The setup and error-handler callbacks are defined as named package-level functions +// (e.g., setupFileLogger, handleFileLoggerError) to aid readability and testability: // // func Init*Logger(logDir, fileName string) error { -// logger, err := initLogger( -// logDir, fileName, fileFlags, -// setupFunc, // Configure logger after file is opened -// errorHandler, // Handle initialization failures -// ) +// logger, err := initLogger(logDir, fileName, fileFlags, setup*Logger, handle*LoggerError) // initGlobal*Logger(logger) // return err // } @@ -96,8 +81,8 @@ import ( // The initLogger() helper: // 1. Attempts to create the log directory (if needed) // 2. Opens the log file with specified flags (os.O_APPEND, os.O_TRUNC, etc.) -// 3. Calls setupFunc to configure the logger instance -// 4. On error, calls errorHandler to implement fallback behavior +// 3. Calls setup*Logger to configure the logger instance +// 4. On error, calls handle*LoggerError to implement fallback behavior // 5. Returns the initialized logger and any error // // Fallback Behavior Strategies: diff --git a/internal/logger/file_logger.go b/internal/logger/file_logger.go index f5505ded..9038e9a9 100644 --- a/internal/logger/file_logger.go +++ b/internal/logger/file_logger.go @@ -23,46 +23,52 @@ var ( globalLoggerMu sync.RWMutex ) +// setupFileLogger configures a FileLogger after the log file has been opened. +func setupFileLogger(file *os.File, logDir, fileName string) (*FileLogger, error) { + fl := &FileLogger{ + logDir: logDir, + fileName: fileName, + logFile: file, + logger: log.New(file, "", 0), + } + log.Printf("Logging to file: %s", filepath.Join(logDir, fileName)) + return fl, nil +} + +// handleFileLoggerError falls back to stdout when the log file cannot be opened. +func handleFileLoggerError(err error, logDir, fileName string) (*FileLogger, error) { + log.Printf("WARNING: Failed to initialize log file: %v", err) + log.Printf("WARNING: Falling back to stdout for logging") + fl := &FileLogger{ + logDir: logDir, + fileName: fileName, + useFallback: true, + logger: log.New(os.Stdout, "", 0), + } + return fl, nil +} + // InitFileLogger initializes the global file logger // If the log directory doesn't exist and can't be created, falls back to stdout func InitFileLogger(logDir, fileName string) error { - logger, err := initLogger( - logDir, fileName, os.O_APPEND, - // Setup function: configure the logger after file is opened - func(file *os.File, logDir, fileName string) (*FileLogger, error) { - fl := &FileLogger{ - logDir: logDir, - fileName: fileName, - logFile: file, - logger: log.New(file, "", 0), - } - log.Printf("Logging to file: %s", filepath.Join(logDir, fileName)) - return fl, nil - }, - // Error handler: fallback to stdout on error - func(err error, logDir, fileName string) (*FileLogger, error) { - log.Printf("WARNING: Failed to initialize log file: %v", err) - log.Printf("WARNING: Falling back to stdout for logging") - fl := &FileLogger{ - logDir: logDir, - fileName: fileName, - useFallback: true, - logger: log.New(os.Stdout, "", 0), // We'll add our own timestamp - } - return fl, nil - }, - ) - + logger, err := initLogger(logDir, fileName, os.O_APPEND, setupFileLogger, handleFileLoggerError) initGlobalFileLogger(logger) return err } -// Close closes the log file -func (fl *FileLogger) Close() error { +// withLock acquires fl.mu, executes fn, then releases fl.mu. +// Use this in methods that return an error to avoid repeating the lock/unlock preamble. +func (fl *FileLogger) withLock(fn func() error) error { fl.mu.Lock() defer fl.mu.Unlock() + return fn() +} - return closeLogFile(fl.logFile, &fl.mu, "file") +// Close closes the log file +func (fl *FileLogger) Close() error { + return fl.withLock(func() error { + return closeLogFile(fl.logFile, &fl.mu, "file") + }) } // LogLevel represents the severity of a log message diff --git a/internal/logger/jsonl_logger.go b/internal/logger/jsonl_logger.go index 8725ac5c..cbbb3631 100644 --- a/internal/logger/jsonl_logger.go +++ b/internal/logger/jsonl_logger.go @@ -37,25 +37,25 @@ type JSONLRPCMessage struct { Payload json.RawMessage `json:"payload"` // Full sanitized payload as raw JSON } +// setupJSONLLogger configures a JSONLLogger after the log file has been opened. +func setupJSONLLogger(file *os.File, logDir, fileName string) (*JSONLLogger, error) { + jl := &JSONLLogger{ + logDir: logDir, + fileName: fileName, + logFile: file, + encoder: json.NewEncoder(file), + } + return jl, nil +} + +// handleJSONLLoggerError returns the error immediately — JSONLLogger has no fallback mode. +func handleJSONLLoggerError(err error, _ string, _ string) (*JSONLLogger, error) { + return nil, err +} + // InitJSONLLogger initializes the global JSONL logger func InitJSONLLogger(logDir, fileName string) error { - logger, err := initLogger( - logDir, fileName, os.O_APPEND, - // Setup function: configure the logger after file is opened - func(file *os.File, logDir, fileName string) (*JSONLLogger, error) { - jl := &JSONLLogger{ - logDir: logDir, - fileName: fileName, - logFile: file, - encoder: json.NewEncoder(file), - } - return jl, nil - }, - // Error handler: return error immediately (no fallback) - func(err error, logDir, fileName string) (*JSONLLogger, error) { - return nil, err - }, - ) + logger, err := initLogger(logDir, fileName, os.O_APPEND, setupJSONLLogger, handleJSONLLoggerError) // Only initialize global logger if successful (no error) // Unlike FileLogger/MarkdownLogger which return fallback loggers, @@ -67,12 +67,19 @@ func InitJSONLLogger(logDir, fileName string) error { return err } -// Close closes the JSONL log file -func (jl *JSONLLogger) Close() error { +// withLock acquires jl.mu, executes fn, then releases jl.mu. +// Use this in methods that return an error to avoid repeating the lock/unlock preamble. +func (jl *JSONLLogger) withLock(fn func() error) error { jl.mu.Lock() defer jl.mu.Unlock() + return fn() +} - return closeLogFile(jl.logFile, &jl.mu, "JSONL") +// Close closes the JSONL log file +func (jl *JSONLLogger) Close() error { + return jl.withLock(func() error { + return closeLogFile(jl.logFile, &jl.mu, "JSONL") + }) } // LogMessage logs an RPC message to the JSONL file diff --git a/internal/logger/markdown_logger.go b/internal/logger/markdown_logger.go index 856c7213..aaa8e020 100644 --- a/internal/logger/markdown_logger.go +++ b/internal/logger/markdown_logger.go @@ -24,31 +24,30 @@ var ( globalMarkdownMu sync.RWMutex ) +// setupMarkdownLogger configures a MarkdownLogger after the log file has been opened. +func setupMarkdownLogger(file *os.File, logDir, fileName string) (*MarkdownLogger, error) { + ml := &MarkdownLogger{ + logDir: logDir, + fileName: fileName, + logFile: file, + initialized: false, // Will be initialized on first write + } + return ml, nil +} + +// handleMarkdownLoggerError sets fallback mode (no stdout redirect) when the file cannot be opened. +func handleMarkdownLoggerError(_ error, logDir, fileName string) (*MarkdownLogger, error) { + ml := &MarkdownLogger{ + logDir: logDir, + fileName: fileName, + useFallback: true, + } + return ml, nil +} + // InitMarkdownLogger initializes the global markdown logger func InitMarkdownLogger(logDir, fileName string) error { - logger, err := initLogger( - logDir, fileName, os.O_TRUNC, - // Setup function: configure the logger after file is opened - func(file *os.File, logDir, fileName string) (*MarkdownLogger, error) { - ml := &MarkdownLogger{ - logDir: logDir, - fileName: fileName, - logFile: file, - initialized: false, // Will be initialized on first write - } - return ml, nil - }, - // Error handler: set fallback mode (no stdout redirect) - func(err error, logDir, fileName string) (*MarkdownLogger, error) { - ml := &MarkdownLogger{ - logDir: logDir, - fileName: fileName, - useFallback: true, - } - return ml, nil - }, - ) - + logger, err := initLogger(logDir, fileName, os.O_TRUNC, setupMarkdownLogger, handleMarkdownLoggerError) initGlobalMarkdownLogger(logger) return err } @@ -69,23 +68,30 @@ func (ml *MarkdownLogger) initializeFile() error { return nil } -// Close closes the log file and writes the closing details tag -func (ml *MarkdownLogger) Close() error { +// withLock acquires ml.mu, executes fn, then releases ml.mu. +// Use this in methods that return an error to avoid repeating the lock/unlock preamble. +func (ml *MarkdownLogger) withLock(fn func() error) error { ml.mu.Lock() defer ml.mu.Unlock() + return fn() +} - if ml.logFile != nil { - // Write closing details tag before closing - footer := "\n\n" - if _, err := ml.logFile.WriteString(footer); err != nil { - // Even if footer write fails, try to close the file properly +// Close closes the log file and writes the closing details tag +func (ml *MarkdownLogger) Close() error { + return ml.withLock(func() error { + if ml.logFile != nil { + // Write closing details tag before closing + footer := "\n\n" + if _, err := ml.logFile.WriteString(footer); err != nil { + // Even if footer write fails, try to close the file properly + return closeLogFile(ml.logFile, &ml.mu, "markdown") + } + + // Footer written successfully, now close return closeLogFile(ml.logFile, &ml.mu, "markdown") } - - // Footer written successfully, now close - return closeLogFile(ml.logFile, &ml.mu, "markdown") - } - return nil + return nil + }) } // getEmojiForLevel returns the appropriate emoji for the log level diff --git a/internal/logger/tools_logger.go b/internal/logger/tools_logger.go index 8250c487..b1f7952f 100644 --- a/internal/logger/tools_logger.go +++ b/internal/logger/tools_logger.go @@ -39,62 +39,69 @@ var ( globalToolsMu sync.RWMutex ) -// InitToolsLogger initializes the global tools logger -// If the log directory doesn't exist and can't be created, falls back to no-op -func InitToolsLogger(logDir, fileName string) error { - logger, err := initLogger( - logDir, fileName, os.O_TRUNC, // Truncate existing file to start fresh - // Setup function: configure the logger after directory is ready - func(file *os.File, logDir, fileName string) (*ToolsLogger, error) { - // Close the file immediately - we'll write directly later - if file != nil { - file.Close() - } - - tl := &ToolsLogger{ - logDir: logDir, - fileName: fileName, - data: &ToolsData{ - Servers: make(map[string][]ToolInfo), - }, - } - log.Printf("Tools logging to file: %s", filepath.Join(logDir, fileName)) - return tl, nil +// setupToolsLogger configures a ToolsLogger after the log file has been opened. +// The file is closed immediately because ToolsLogger writes atomically on each update. +func setupToolsLogger(file *os.File, logDir, fileName string) (*ToolsLogger, error) { + // Close the file immediately - we'll write directly later + if file != nil { + file.Close() + } + + tl := &ToolsLogger{ + logDir: logDir, + fileName: fileName, + data: &ToolsData{ + Servers: make(map[string][]ToolInfo), }, - // Error handler: fallback to no-op on error - func(err error, logDir, fileName string) (*ToolsLogger, error) { - log.Printf("WARNING: Failed to initialize tools log file: %v", err) - log.Printf("WARNING: Tools logging disabled") - tl := &ToolsLogger{ - logDir: logDir, - fileName: fileName, - useFallback: true, - data: &ToolsData{ - Servers: make(map[string][]ToolInfo), - }, - } - return tl, nil + } + log.Printf("Tools logging to file: %s", filepath.Join(logDir, fileName)) + return tl, nil +} + +// handleToolsLoggerError falls back to a no-op logger when the file cannot be opened. +func handleToolsLoggerError(err error, logDir, fileName string) (*ToolsLogger, error) { + log.Printf("WARNING: Failed to initialize tools log file: %v", err) + log.Printf("WARNING: Tools logging disabled") + tl := &ToolsLogger{ + logDir: logDir, + fileName: fileName, + useFallback: true, + data: &ToolsData{ + Servers: make(map[string][]ToolInfo), }, - ) + } + return tl, nil +} +// InitToolsLogger initializes the global tools logger +// If the log directory doesn't exist and can't be created, falls back to no-op +func InitToolsLogger(logDir, fileName string) error { + logger, err := initLogger(logDir, fileName, os.O_TRUNC, setupToolsLogger, handleToolsLoggerError) initGlobalToolsLogger(logger) return err } -// LogTools logs the tools for a specific server -func (tl *ToolsLogger) LogTools(serverID string, tools []ToolInfo) error { +// withLock acquires tl.mu, executes fn, then releases tl.mu. +// Use this in methods that return an error to avoid repeating the lock/unlock preamble. +func (tl *ToolsLogger) withLock(fn func() error) error { tl.mu.Lock() defer tl.mu.Unlock() + return fn() +} - if tl.useFallback { - return nil // Silently skip if in fallback mode - } +// LogTools logs the tools for a specific server +func (tl *ToolsLogger) LogTools(serverID string, tools []ToolInfo) error { + return tl.withLock(func() error { + if tl.useFallback { + return nil // Silently skip if in fallback mode + } - // Update the data structure - tl.data.Servers[serverID] = tools + // Update the data structure + tl.data.Servers[serverID] = tools - // Write the updated data to file - return tl.writeToFile() + // Write the updated data to file + return tl.writeToFile() + }) } // writeToFile writes the current tools data to the JSON file diff --git a/internal/server/session.go b/internal/server/session.go index f66613b9..51bcb113 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -10,6 +10,7 @@ import ( "github.com/github/gh-aw-mcpg/internal/auth" "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/syncutil" ) var logSession = logger.New("server:session") @@ -67,33 +68,30 @@ func (us *UnifiedServer) requireSession(ctx context.Context) error { sessionID := us.getSessionID(ctx) logSession.Printf("Checking session: sessionID=%s", auth.TruncateSessionID(sessionID)) - // Use double-checked locking to auto-create session if needed - us.sessionMu.RLock() - session := us.sessions[sessionID] - us.sessionMu.RUnlock() - - if session == nil { - // Need to create session - acquire write lock - us.sessionMu.Lock() - // Double-check after acquiring write lock to avoid race condition - if us.sessions[sessionID] == nil { - log.Printf("Auto-creating session for ID: %s", auth.TruncateSessionID(sessionID)) - us.sessions[sessionID] = NewSession(sessionID, "") - log.Printf("Session auto-created for ID: %s", auth.TruncateSessionID(sessionID)) + // Use syncutil.GetOrCreate to handle the double-checked locking pattern. + // The isNew flag is set inside the create callback (while the write lock is held) + // so that ensureSessionDirectory is called exactly once per new session. + isNew := false + if _, err := syncutil.GetOrCreate(&us.sessionMu, us.sessions, sessionID, func() (*Session, error) { + logSession.Printf("Auto-creating session for ID: %s", auth.TruncateSessionID(sessionID)) + s := NewSession(sessionID, "") + logSession.Printf("Session auto-created for ID: %s", auth.TruncateSessionID(sessionID)) + isNew = true + return s, nil + }); err != nil { + return err + } - // Ensure session directory exists in payload mount point - // This is done after releasing the lock to avoid holding it during I/O - us.sessionMu.Unlock() - if err := us.ensureSessionDirectory(sessionID); err != nil { - logger.LogWarn("client", "Failed to create session directory for session=%s: %v", auth.TruncateSessionID(sessionID), err) - // Don't fail - payloads will attempt to create the directory when needed - } - return nil + if isNew { + // Ensure session directory exists in payload mount point. + // Called after GetOrCreate releases the lock to avoid holding it during I/O. + if err := us.ensureSessionDirectory(sessionID); err != nil { + logger.LogWarn("client", "Failed to create session directory for session=%s: %v", auth.TruncateSessionID(sessionID), err) + // Don't fail - payloads will attempt to create the directory when needed } - us.sessionMu.Unlock() } - log.Printf("Session validated for ID: %s", auth.TruncateSessionID(sessionID)) + logSession.Printf("Session validated for ID: %s", auth.TruncateSessionID(sessionID)) return nil }