Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions internal/launcher/health_monitor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package launcher

import (
"log"
"time"

"github.com/github/gh-aw-mcpg/internal/logger"
)

const (
	// DefaultHealthCheckInterval is the recommended periodic health check interval (spec §8).
	DefaultHealthCheckInterval = 30 * time.Second

	// maxConsecutiveRestartFailures caps how many consecutive restart failures
	// are allowed before the monitor stops retrying a particular server.
	maxConsecutiveRestartFailures = 3
)

// logHealth is the package-scoped logger for health-monitor messages.
var logHealth = logger.New("launcher:health")

// HealthMonitor periodically checks backend server health and automatically
// restarts servers that are in an error state (MCP Gateway Specification §8).
// HealthMonitor periodically checks backend server health and automatically
// restarts servers that are in an error state (MCP Gateway Specification §8).
//
// Lifecycle: Start launches the run loop in a goroutine; Stop closes stopCh
// to signal it and then waits on doneCh for the loop to finish.
type HealthMonitor struct {
	launcher *Launcher     // launcher whose servers are checked and restarted
	interval time.Duration // delay between health sweeps (defaulted in NewHealthMonitor)
	stopCh   chan struct{} // closed by Stop to request shutdown of the run loop
	doneCh   chan struct{} // closed by run when the loop has fully exited

	// Track consecutive restart failures per server to avoid infinite retry loops.
	// NOTE(review): accessed without locking — assumes all reads/writes happen on
	// the single run-loop goroutine (or in single-threaded tests); confirm callers.
	consecutiveFailures map[string]int
}

// NewHealthMonitor creates a health monitor for the given launcher.
// A non-positive interval falls back to DefaultHealthCheckInterval.
func NewHealthMonitor(l *Launcher, interval time.Duration) *HealthMonitor {
	hm := &HealthMonitor{
		launcher:            l,
		interval:            interval,
		stopCh:              make(chan struct{}),
		doneCh:              make(chan struct{}),
		consecutiveFailures: map[string]int{},
	}
	if hm.interval <= 0 {
		hm.interval = DefaultHealthCheckInterval
	}
	return hm
}

// Start begins periodic health checks in a background goroutine.
//
// NOTE(review): Start is not idempotent — calling it twice spawns two run
// loops that will both close doneCh and panic. Callers must invoke it exactly
// once; guarding with a sync.Once field would make this safe. TODO confirm
// all call sites invoke Start a single time.
func (hm *HealthMonitor) Start() {
	log.Printf("[HEALTH] Starting health monitor (interval=%s)", hm.interval)
	logger.LogInfo("startup", "Health monitor started (interval=%s)", hm.interval)
	go hm.run()
}

// Stop signals the health monitor to stop and waits for it to finish.
func (hm *HealthMonitor) Stop() {
close(hm.stopCh)
<-hm.doneCh
Comment on lines +47 to +57
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Start()/Stop() aren’t safe to call more than once: Start() can spawn multiple goroutines that will both close(hm.doneCh) (panic), and Stop() will close(hm.stopCh) (panic) if invoked twice. Consider guarding lifecycle with a sync.Once (or an atomic state) so start/stop are idempotent and safe under repeated shutdown paths.

Copilot uses AI. Check for mistakes.
logHealth.Print("Health monitor stopped")
logger.LogInfo("shutdown", "Health monitor stopped")
Comment on lines +56 to +59
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stop() can block shutdown for up to the backend startupTimeout (default 60s) per restart attempt, because the goroutine may be inside checkAll()/GetOrLaunch() and won’t observe stopCh until it returns. Consider making Stop() time-bounded (e.g., select on doneCh with a timeout) and/or making restart attempts cancellable so shutdown isn’t held hostage by a hung backend launch.

Suggested change
close(hm.stopCh)
<-hm.doneCh
logHealth.Print("Health monitor stopped")
logger.LogInfo("shutdown", "Health monitor stopped")
// If already stopped, return immediately to avoid closing an already-closed channel.
select {
case <-hm.doneCh:
logHealth.Print("Health monitor already stopped")
logger.LogInfo("shutdown", "Health monitor already stopped")
return
default:
}
// Signal the run loop to stop.
close(hm.stopCh)
// Bound the time we wait for the run loop to exit, so shutdown isn't blocked indefinitely
// by a stuck health check or backend restart attempt.
timeout := hm.interval
if timeout <= 0 {
timeout = DefaultHealthCheckInterval
}
select {
case <-hm.doneCh:
logHealth.Print("Health monitor stopped")
logger.LogInfo("shutdown", "Health monitor stopped")
case <-time.After(timeout):
logHealth.Printf("Health monitor stop timed out after %s", timeout)
logger.LogInfo("shutdown", "Health monitor stop timed out after %s", timeout)
}

Copilot uses AI. Check for mistakes.
}

// run is the monitor's main loop: it performs a health sweep every interval
// until Stop is called or the launcher's context is cancelled, then closes
// doneCh so Stop can return.
func (hm *HealthMonitor) run() {
	defer close(hm.doneCh)

	tick := time.NewTicker(hm.interval)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			hm.checkAll()
		case <-hm.stopCh:
			return
		case <-hm.launcher.ctx.Done():
			return
		}
	}
}

// checkAll iterates over every configured backend and attempts to restart
// any server that is in an error state.
func (hm *HealthMonitor) checkAll() {
for _, serverID := range hm.launcher.ServerIDs() {
state := hm.launcher.GetServerState(serverID)

switch state.Status {
case "error":
hm.handleErrorState(serverID, state)
case "running":
// Reset consecutive failure counter on healthy server.
if hm.consecutiveFailures[serverID] > 0 {
hm.consecutiveFailures[serverID] = 0
}
}
Comment on lines +82 to +94
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checkAll() only consults GetServerState(), which is currently derived from cached serverStartTimes/serverErrors (i.e., it won’t detect a backend that crashes after a successful launch). If the intent is true periodic health monitoring, consider adding an active probe (e.g., a lightweight RPC like tools/list/initialize on the cached connection) and recording an error when that probe fails so auto-restart can trigger on real liveness failures.

Copilot uses AI. Check for mistakes.
}
}

// handleErrorState attempts to restart a server that reported an error,
// giving up permanently (until gateway restart) once
// maxConsecutiveRestartFailures consecutive attempts have failed.
func (hm *HealthMonitor) handleErrorState(serverID string, state ServerState) {
	attempts := hm.consecutiveFailures[serverID]
	if attempts >= maxConsecutiveRestartFailures {
		// The threshold was logged when first crossed; stay quiet now.
		return
	}

	logger.LogWarn("backend", "Health check: server %q in error state (%s), attempting restart (%d/%d)",
		serverID, state.LastError, attempts+1, maxConsecutiveRestartFailures)

	// Wipe the error record and cached connection so GetOrLaunch retries.
	hm.launcher.clearServerForRestart(serverID)

	if _, err := GetOrLaunch(hm.launcher, serverID); err != nil {
		attempts++
		hm.consecutiveFailures[serverID] = attempts
		logger.LogError("backend", "Health check: restart failed for server %q: %v (attempt %d/%d)",
			serverID, err, attempts, maxConsecutiveRestartFailures)
		if attempts >= maxConsecutiveRestartFailures {
			logger.LogError("backend",
				"Health check: server %q reached max restart attempts (%d), will not retry until gateway restart",
				serverID, maxConsecutiveRestartFailures)
		}
		return
	}

	// Restart succeeded: reset the failure streak.
	hm.consecutiveFailures[serverID] = 0
	log.Printf("[HEALTH] Successfully restarted server %q", serverID)
	logger.LogInfo("backend", "Health check: successfully restarted server %q", serverID)
}
137 changes: 137 additions & 0 deletions internal/launcher/health_monitor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package launcher

import (
"context"
"testing"
"time"

"github.com/github/gh-aw-mcpg/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// newTestLauncher builds a Launcher over the given server configs with a
// background context, for use in health-monitor tests.
func newTestLauncher(servers map[string]*config.ServerConfig) *Launcher {
	return New(context.Background(), &config.Config{Servers: servers})
}

func TestHealthMonitor_StartStop(t *testing.T) {
l := newTestLauncher(map[string]*config.ServerConfig{})
hm := NewHealthMonitor(l, 50*time.Millisecond)

hm.Start()
// Let at least one tick fire
time.Sleep(100 * time.Millisecond)
hm.Stop()

// Verify doneCh is closed (Stop returned)
select {
case <-hm.doneCh:
// expected
default:
t.Fatal("doneCh should be closed after Stop")
}
}

// TestHealthMonitor_DefaultInterval verifies that non-positive intervals
// (zero and negative) fall back to DefaultHealthCheckInterval. The guard in
// NewHealthMonitor is `interval <= 0`, so both sides of it are exercised.
func TestHealthMonitor_DefaultInterval(t *testing.T) {
	l := newTestLauncher(map[string]*config.ServerConfig{})

	hm := NewHealthMonitor(l, 0)
	assert.Equal(t, DefaultHealthCheckInterval, hm.interval)

	// Negative intervals hit the same <= 0 branch.
	hm = NewHealthMonitor(l, -time.Second)
	assert.Equal(t, DefaultHealthCheckInterval, hm.interval)
}

func TestHealthMonitor_RunningServerResetsFailureCounter(t *testing.T) {
servers := map[string]*config.ServerConfig{
"test-server": {Type: "http", URL: "http://localhost:9999"},
}
l := newTestLauncher(servers)

// Simulate a running server
l.recordStart("test-server")

hm := NewHealthMonitor(l, 50*time.Millisecond)
hm.consecutiveFailures["test-server"] = 2

hm.checkAll()

assert.Equal(t, 0, hm.consecutiveFailures["test-server"])
}

func TestHealthMonitor_ErrorStateIncrementsFailureCounter(t *testing.T) {
// Use a server config that will fail to launch (no Docker available in test)
servers := map[string]*config.ServerConfig{
"bad-server": {Type: "stdio", Command: "nonexistent-binary-xyz"},
}
l := newTestLauncher(servers)

// Simulate the server being in error state
l.recordError("bad-server", "process crashed")

hm := NewHealthMonitor(l, time.Hour) // large interval; we call checkAll manually

hm.checkAll()

// Server should have failed restart and incremented counter
assert.Equal(t, 1, hm.consecutiveFailures["bad-server"])
}

func TestHealthMonitor_StopsRetryingAtMaxFailures(t *testing.T) {
servers := map[string]*config.ServerConfig{
"bad-server": {Type: "stdio", Command: "nonexistent-binary-xyz"},
}
l := newTestLauncher(servers)

hm := NewHealthMonitor(l, time.Hour)
hm.consecutiveFailures["bad-server"] = maxConsecutiveRestartFailures

// Simulate error state
l.recordError("bad-server", "still broken")

hm.checkAll()

// Should not have incremented further
assert.Equal(t, maxConsecutiveRestartFailures, hm.consecutiveFailures["bad-server"])

// Error should still be present (no restart attempted)
state := l.GetServerState("bad-server")
assert.Equal(t, "error", state.Status)
}

// TestClearServerForRestart checks that clearing a server wipes both its
// error record and start time, returning it to the "stopped" state.
func TestClearServerForRestart(t *testing.T) {
	l := newTestLauncher(map[string]*config.ServerConfig{
		"srv": {Type: "http", URL: "http://localhost:9999"},
	})

	// Seed internal records so the server reports an error state.
	l.serverStartTimes["srv"] = time.Now()
	l.serverErrors["srv"] = "something failed"
	require.Equal(t, "error", l.GetServerState("srv").Status)

	l.clearServerForRestart("srv")

	got := l.GetServerState("srv")
	assert.Equal(t, "stopped", got.Status)
	assert.Empty(t, got.LastError)
}

func TestHealthMonitor_RespectsContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cfg := &config.Config{Servers: map[string]*config.ServerConfig{}}
l := New(ctx, cfg)

hm := NewHealthMonitor(l, 50*time.Millisecond)
hm.Start()

// Cancel context — monitor should exit
cancel()

select {
case <-hm.doneCh:
// expected — monitor stopped
case <-time.After(2 * time.Second):
t.Fatal("health monitor did not stop after context cancellation")
}
}
17 changes: 17 additions & 0 deletions internal/launcher/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,23 @@ func (l *Launcher) recordError(serverID string, errMsg string) {
logLauncher.Printf("Recorded server error: serverID=%s, error=%s", serverID, errMsg)
}

// clearServerForRestart removes the error record and any cached connection for
// serverID so that a subsequent GetOrLaunch call will attempt a fresh launch.
// Called by HealthMonitor before retrying a failed server.
func (l *Launcher) clearServerForRestart(serverID string) {
l.mu.Lock()
defer l.mu.Unlock()

delete(l.serverErrors, serverID)
delete(l.serverStartTimes, serverID)

if conn, ok := l.connections[serverID]; ok {
conn.Close()
delete(l.connections, serverID)
Comment on lines +336 to +337
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clearServerForRestart calls conn.Close() while holding l.mu. Since Close() can block (e.g., waiting on SDK session shutdown), this can stall unrelated launcher operations that need the same mutex. Prefer removing the connection from the map under the lock, then closing it after unlocking.

Suggested change
conn.Close()
delete(l.connections, serverID)
delete(l.connections, serverID)
go conn.Close()

Copilot uses AI. Check for mistakes.
}
logLauncher.Printf("Cleared server state for restart: serverID=%s", serverID)
}

// GetServerState returns the observed runtime state for a single server.
func (l *Launcher) GetServerState(serverID string) ServerState {
l.mu.RLock()
Expand Down
12 changes: 12 additions & 0 deletions internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ type UnifiedServer struct {

// Testing support - when true, skips os.Exit() call
testMode bool

// Health monitoring
healthMonitor *launcher.HealthMonitor
}

// NewUnified creates a new unified MCP server
Expand Down Expand Up @@ -192,6 +195,10 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error)
return nil, fmt.Errorf("failed to register tools: %w", err)
}

// Start periodic health monitoring and auto-restart (spec §8)
us.healthMonitor = launcher.NewHealthMonitor(l, launcher.DefaultHealthCheckInterval)
us.healthMonitor.Start()

Comment on lines +199 to +201
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NewUnified() now starts the health monitor, but (*UnifiedServer).Close() does not stop it (many tests use defer us.Close() without calling InitiateShutdown()). This can leak a ticker goroutine and may cause the monitor to keep calling into a closed launcher. Consider stopping the monitor in Close() (or making Close() delegate to InitiateShutdown()), and/or cancelling the shared context used by the monitor/launcher.

Suggested change
us.healthMonitor = launcher.NewHealthMonitor(l, launcher.DefaultHealthCheckInterval)
us.healthMonitor.Start()
hm := launcher.NewHealthMonitor(l, launcher.DefaultHealthCheckInterval)
us.healthMonitor = hm
hm.Start()
// Ensure the health monitor stops when the context is cancelled to avoid goroutine leaks.
go func() {
<-ctx.Done()
hm.Stop()
}()

Copilot uses AI. Check for mistakes.
logUnified.Printf("Unified server created successfully with %d tools", len(us.tools))
return us, nil
}
Expand Down Expand Up @@ -661,6 +668,11 @@ func (us *UnifiedServer) InitiateShutdown() int {
log.Println("Initiating gateway shutdown...")
logger.LogInfo("shutdown", "Gateway shutdown initiated")

// Stop health monitor before closing connections
if us.healthMonitor != nil {
us.healthMonitor.Stop()
}

// Count servers before closing
serversTerminated = len(us.launcher.ServerIDs())

Expand Down
Loading