diff --git a/internal/launcher/health_monitor.go b/internal/launcher/health_monitor.go new file mode 100644 index 00000000..c4864e35 --- /dev/null +++ b/internal/launcher/health_monitor.go @@ -0,0 +1,127 @@ +package launcher + +import ( + "log" + "time" + + "github.com/github/gh-aw-mcpg/internal/logger" +) + +const ( + // DefaultHealthCheckInterval is the recommended periodic health check interval (spec §8). + DefaultHealthCheckInterval = 30 * time.Second + + // maxConsecutiveRestartFailures caps how many consecutive restart failures + // are allowed before the monitor stops retrying a particular server. + maxConsecutiveRestartFailures = 3 +) + +var logHealth = logger.New("launcher:health") + +// HealthMonitor periodically checks backend server health and automatically +// restarts servers that are in an error state (MCP Gateway Specification §8). +type HealthMonitor struct { + launcher *Launcher + interval time.Duration + stopCh chan struct{} + doneCh chan struct{} + + // Track consecutive restart failures per server to avoid infinite retry loops. + consecutiveFailures map[string]int +} + +// NewHealthMonitor creates a health monitor for the given launcher. +func NewHealthMonitor(l *Launcher, interval time.Duration) *HealthMonitor { + if interval <= 0 { + interval = DefaultHealthCheckInterval + } + return &HealthMonitor{ + launcher: l, + interval: interval, + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + consecutiveFailures: make(map[string]int), + } +} + +// Start begins periodic health checks in a background goroutine. +func (hm *HealthMonitor) Start() { + log.Printf("[HEALTH] Starting health monitor (interval=%s)", hm.interval) + logger.LogInfo("startup", "Health monitor started (interval=%s)", hm.interval) + go hm.run() +} + +// Stop signals the health monitor to stop and waits for it to finish. +func (hm *HealthMonitor) Stop() { + close(hm.stopCh) + <-hm.doneCh + logHealth.Print("Health monitor stopped") + logger.LogInfo("shutdown", "Health monitor stopped") +} + +func (hm *HealthMonitor) run() { + defer close(hm.doneCh) + + ticker := time.NewTicker(hm.interval) + defer ticker.Stop() + + for { + select { + case <-hm.stopCh: + return + case <-hm.launcher.ctx.Done(): + return + case <-ticker.C: + hm.checkAll() + } + } +} + +// checkAll iterates over every configured backend and attempts to restart +// any server that is in an error state. +func (hm *HealthMonitor) checkAll() { + for _, serverID := range hm.launcher.ServerIDs() { + state := hm.launcher.GetServerState(serverID) + + switch state.Status { + case "error": + hm.handleErrorState(serverID, state) + case "running": + // Reset consecutive failure counter on healthy server. + if hm.consecutiveFailures[serverID] > 0 { + hm.consecutiveFailures[serverID] = 0 + } + } + } +} + +func (hm *HealthMonitor) handleErrorState(serverID string, state ServerState) { + failures := hm.consecutiveFailures[serverID] + if failures >= maxConsecutiveRestartFailures { + // Already logged when the threshold was reached; stay silent. + return + } + + logger.LogWarn("backend", "Health check: server %q in error state (%s), attempting restart (%d/%d)", + serverID, state.LastError, failures+1, maxConsecutiveRestartFailures) + + // Clear error state and cached connection so GetOrLaunch can retry. + hm.launcher.clearServerForRestart(serverID) + + _, err := GetOrLaunch(hm.launcher, serverID) + if err != nil { + hm.consecutiveFailures[serverID] = failures + 1 + logger.LogError("backend", "Health check: restart failed for server %q: %v (attempt %d/%d)", + serverID, err, failures+1, maxConsecutiveRestartFailures) + if hm.consecutiveFailures[serverID] >= maxConsecutiveRestartFailures { + logger.LogError("backend", + "Health check: server %q reached max restart attempts (%d), will not retry until gateway restart", + serverID, maxConsecutiveRestartFailures) + } + return + } + + hm.consecutiveFailures[serverID] = 0 + log.Printf("[HEALTH] Successfully restarted server %q", serverID) + logger.LogInfo("backend", "Health check: successfully restarted server %q", serverID) +} diff --git a/internal/launcher/health_monitor_test.go b/internal/launcher/health_monitor_test.go new file mode 100644 index 00000000..5e16610a --- /dev/null +++ b/internal/launcher/health_monitor_test.go @@ -0,0 +1,137 @@ +package launcher + +import ( + "context" + "testing" + "time" + + "github.com/github/gh-aw-mcpg/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestLauncher(servers map[string]*config.ServerConfig) *Launcher { + ctx := context.Background() + cfg := &config.Config{Servers: servers} + return New(ctx, cfg) +} + +func TestHealthMonitor_StartStop(t *testing.T) { + l := newTestLauncher(map[string]*config.ServerConfig{}) + hm := NewHealthMonitor(l, 50*time.Millisecond) + + hm.Start() + // Let at least one tick fire + time.Sleep(100 * time.Millisecond) + hm.Stop() + + // Verify doneCh is closed (Stop returned) + select { + case <-hm.doneCh: + // expected + default: + t.Fatal("doneCh should be closed after Stop") + } +} + +func TestHealthMonitor_DefaultInterval(t *testing.T) { + l := newTestLauncher(map[string]*config.ServerConfig{}) + hm := NewHealthMonitor(l, 0) + + assert.Equal(t, DefaultHealthCheckInterval, hm.interval) +} + +func TestHealthMonitor_RunningServerResetsFailureCounter(t *testing.T) { + servers := map[string]*config.ServerConfig{ + "test-server": {Type: "http", URL: "http://localhost:9999"}, + } + l := newTestLauncher(servers) + + // Simulate a running server + l.recordStart("test-server") + + hm := NewHealthMonitor(l, 50*time.Millisecond) + hm.consecutiveFailures["test-server"] = 2 + + hm.checkAll() + + assert.Equal(t, 0, hm.consecutiveFailures["test-server"]) +} + +func TestHealthMonitor_ErrorStateIncrementsFailureCounter(t *testing.T) { + // Use a server config that will fail to launch (no Docker available in test) + servers := map[string]*config.ServerConfig{ + "bad-server": {Type: "stdio", Command: "nonexistent-binary-xyz"}, + } + l := newTestLauncher(servers) + + // Simulate the server being in error state + l.recordError("bad-server", "process crashed") + + hm := NewHealthMonitor(l, time.Hour) // large interval; we call checkAll manually + + hm.checkAll() + + // Server should have failed restart and incremented counter + assert.Equal(t, 1, hm.consecutiveFailures["bad-server"]) +} + +func TestHealthMonitor_StopsRetryingAtMaxFailures(t *testing.T) { + servers := map[string]*config.ServerConfig{ + "bad-server": {Type: "stdio", Command: "nonexistent-binary-xyz"}, + } + l := newTestLauncher(servers) + + hm := NewHealthMonitor(l, time.Hour) + hm.consecutiveFailures["bad-server"] = maxConsecutiveRestartFailures + + // Simulate error state + l.recordError("bad-server", "still broken") + + hm.checkAll() + + // Should not have incremented further + assert.Equal(t, maxConsecutiveRestartFailures, hm.consecutiveFailures["bad-server"]) + + // Error should still be present (no restart attempted) + state := l.GetServerState("bad-server") + assert.Equal(t, "error", state.Status) +} + +func TestClearServerForRestart(t *testing.T) { + l := newTestLauncher(map[string]*config.ServerConfig{ + "srv": {Type: "http", URL: "http://localhost:9999"}, + }) + + // Record start then error + l.serverStartTimes["srv"] = time.Now() + l.serverErrors["srv"] = "something failed" + + state := l.GetServerState("srv") + require.Equal(t, "error", state.Status) + + l.clearServerForRestart("srv") + + state = l.GetServerState("srv") + assert.Equal(t, "stopped", state.Status) + assert.Empty(t, state.LastError) +} + +func TestHealthMonitor_RespectsContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cfg := &config.Config{Servers: map[string]*config.ServerConfig{}} + l := New(ctx, cfg) + + hm := NewHealthMonitor(l, 50*time.Millisecond) + hm.Start() + + // Cancel context — monitor should exit + cancel() + + select { + case <-hm.doneCh: + // expected — monitor stopped + case <-time.After(2 * time.Second): + t.Fatal("health monitor did not stop after context cancellation") + } +} diff --git a/internal/launcher/launcher.go b/internal/launcher/launcher.go index 70b9a3cd..6336459a 100644 --- a/internal/launcher/launcher.go +++ b/internal/launcher/launcher.go @@ -322,6 +322,23 @@ func (l *Launcher) recordError(serverID string, errMsg string) { logLauncher.Printf("Recorded server error: serverID=%s, error=%s", serverID, errMsg) } +// clearServerForRestart removes the error record and any cached connection for +// serverID so that a subsequent GetOrLaunch call will attempt a fresh launch. +// Called by HealthMonitor before retrying a failed server. +func (l *Launcher) clearServerForRestart(serverID string) { + l.mu.Lock() + defer l.mu.Unlock() + + delete(l.serverErrors, serverID) + delete(l.serverStartTimes, serverID) + + if conn, ok := l.connections[serverID]; ok { + conn.Close() + delete(l.connections, serverID) + } + logLauncher.Printf("Cleared server state for restart: serverID=%s", serverID) +} + // GetServerState returns the observed runtime state for a single server. func (l *Launcher) GetServerState(serverID string) ServerState { l.mu.RLock() diff --git a/internal/server/unified.go b/internal/server/unified.go index 297232c8..a515197b 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -102,6 +102,9 @@ type UnifiedServer struct { // Testing support - when true, skips os.Exit() call testMode bool + + // Health monitoring + healthMonitor *launcher.HealthMonitor } // NewUnified creates a new unified MCP server @@ -192,6 +195,10 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error) return nil, fmt.Errorf("failed to register tools: %w", err) } + // Start periodic health monitoring and auto-restart (spec §8) + us.healthMonitor = launcher.NewHealthMonitor(l, launcher.DefaultHealthCheckInterval) + us.healthMonitor.Start() + logUnified.Printf("Unified server created successfully with %d tools", len(us.tools)) return us, nil } @@ -661,6 +668,11 @@ func (us *UnifiedServer) InitiateShutdown() int { log.Println("Initiating gateway shutdown...") logger.LogInfo("shutdown", "Gateway shutdown initiated") + // Stop health monitor before closing connections + if us.healthMonitor != nil { + us.healthMonitor.Stop() + } + // Count servers before closing serversTerminated = len(us.launcher.ServerIDs())