diff --git a/internal/guard/guard_test.go b/internal/guard/guard_test.go index 93b85d9f..47a8426a 100644 --- a/internal/guard/guard_test.go +++ b/internal/guard/guard_test.go @@ -2,6 +2,7 @@ package guard import ( "context" + "errors" "sync" "testing" @@ -17,6 +18,20 @@ type mockGuard struct { id string } +// mockClosableGuard is a guard that tracks whether Close was called +type mockClosableGuard struct { + mockGuard + closed bool + closeCount int + closeErr error +} + +func (m *mockClosableGuard) Close(ctx context.Context) error { + m.closed = true + m.closeCount++ + return m.closeErr +} + func (m *mockGuard) Name() string { return "mock-" + m.id } func (m *mockGuard) LabelAgent(ctx context.Context, policy interface{}, backend BackendCaller, caps *difc.Capabilities) (*LabelAgentResult, error) { return &LabelAgentResult{DIFCMode: difc.ModeStrict}, nil @@ -450,6 +465,70 @@ func TestGuardRegistry_HasNonNoopGuard(t *testing.T) { }) } +func TestGuardRegistry_Close(t *testing.T) { + t.Run("close calls Close on guards that implement it", func(t *testing.T) { + registry := NewRegistry() + g := &mockClosableGuard{mockGuard: mockGuard{id: "wasm"}} + registry.Register("server1", g) + + registry.Close(context.Background()) + + assert.True(t, g.closed, "expected guard Close to be called") + }) + + t.Run("close skips guards that do not implement Close", func(t *testing.T) { + registry := NewRegistry() + registry.Register("server1", NewNoopGuard()) + + // Should not panic + registry.Close(context.Background()) + }) + + t.Run("close on empty registry is safe", func(t *testing.T) { + registry := NewRegistry() + // Should not panic + registry.Close(context.Background()) + }) + + t.Run("close calls Close on all closable guards", func(t *testing.T) { + registry := NewRegistry() + g1 := &mockClosableGuard{mockGuard: mockGuard{id: "wasm1"}} + g2 := &mockClosableGuard{mockGuard: mockGuard{id: "wasm2"}} + registry.Register("server1", g1) + registry.Register("server2", g2) + + registry.Close(context.Background()) + + assert.True(t, g1.closed, "expected guard 1 Close to be called") + assert.True(t, g2.closed, "expected guard 2 Close to be called") + }) + + t.Run("close continues when one guard returns an error", func(t *testing.T) { + registry := NewRegistry() + g1 := &mockClosableGuard{mockGuard: mockGuard{id: "failing"}, closeErr: errors.New("close failed")} + g2 := &mockClosableGuard{mockGuard: mockGuard{id: "ok"}} + registry.Register("server1", g1) + registry.Register("server2", g2) + + // Should not panic even when one guard returns an error + registry.Close(context.Background()) + + assert.True(t, g1.closed, "expected failing guard Close to be called") + assert.True(t, g2.closed, "expected ok guard Close to be called") + }) + + t.Run("double close is safe", func(t *testing.T) { + registry := NewRegistry() + g := &mockClosableGuard{mockGuard: mockGuard{id: "wasm"}} + registry.Register("server1", g) + + registry.Close(context.Background()) + registry.Close(context.Background()) + + assert.Equal(t, 2, g.closeCount, "Close should be called twice without panic") + }) +} + func TestCreateGuard(t *testing.T) { tests := []struct { name string diff --git a/internal/guard/registry.go b/internal/guard/registry.go index 08143455..44bfceb4 100644 --- a/internal/guard/registry.go +++ b/internal/guard/registry.go @@ -1,6 +1,7 @@ package guard import ( + "context" "fmt" "sync" @@ -30,7 +31,7 @@ func (r *Registry) Register(serverID string, guard Guard) { defer r.mu.Unlock() r.guards[serverID] = guard - log.Printf("[Guard] Registered guard '%s' for server '%s'", guard.Name(), serverID) + logger.LogInfo("guard", "Registered guard '%s' for server '%s'", guard.Name(), serverID) } // Get retrieves the guard for a server, or returns a noop guard if not found @@ -46,7 +47,6 @@ func (r *Registry) Get(serverID string) Guard { // Return noop guard as default debugLog.Printf("No guard registered for serverID=%s, returning noop guard", serverID) - log.Printf("[Guard] No guard registered for server '%s', using noop guard", serverID) return NewNoopGuard() } @@ -76,7 +76,7 @@ func (r *Registry) Remove(serverID string) { r.mu.Lock() defer r.mu.Unlock() delete(r.guards, serverID) - log.Printf("[Guard] Removed guard for server '%s'", serverID) + logger.LogInfo("guard", "Removed guard for server '%s'", serverID) } // List returns all registered server IDs @@ -103,6 +103,33 @@ func (r *Registry) GetGuardInfo() map[string]string { return info } +// Close closes all registered guards that implement Close(context.Context) error. +// It should be called during server shutdown to release WASM runtime resources. +func (r *Registry) Close(ctx context.Context) { + type closableGuard struct { + id string + c interface{ Close(context.Context) error } + } + + r.mu.RLock() + closers := make([]closableGuard, 0, len(r.guards)) + for id, g := range r.guards { + if c, ok := g.(interface{ Close(context.Context) error }); ok { + closers = append(closers, closableGuard{id: id, c: c}) + } + } + r.mu.RUnlock() + + for _, guard := range closers { + if err := guard.c.Close(ctx); err != nil { + logger.LogWarn("guard", "Failed to close guard for server %s: %v", guard.id, err) + } + } + if len(closers) > 0 { + logger.LogInfo("guard", "Closed %d guard(s)", len(closers)) + } +} + // GuardFactory is a function that creates a guard instance type GuardFactory func() (Guard, error) @@ -116,7 +143,7 @@ func RegisterGuardType(name string, factory GuardFactory) { registeredGuardsMu.Lock() defer registeredGuardsMu.Unlock() registeredGuards[name] = factory - log.Printf("[Guard] Registered guard type: %s", name) + logger.LogInfo("guard", "Registered guard type: %s", name) } // CreateGuard creates a guard instance by name using registered factories diff --git a/internal/guard/wasm.go b/internal/guard/wasm.go index ea144fa9..2a300d59 100644 --- a/internal/guard/wasm.go +++ b/internal/guard/wasm.go @@ -15,6 +15,7 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + "github.com/tetratelabs/wazero/sys" ) var logWasm = logger.New("guard:wasm") @@ -830,10 +831,22 @@ func parsePathLabeledResponse(responseJSON []byte, originalData interface{}) (di return pld.ToCollectionLabeledData(), nil } -// isWasmTrap reports whether err is a WASM execution trap such as the -// "wasm error: unreachable" produced when a Rust-compiled guard panics. +// isWasmTrap reports whether err represents a WASM execution trap that should +// permanently poison the guard. Normal process exits (exit code 0, e.g. TinyGo +// init) are NOT considered traps. A non-zero exit code is treated as a trap. +// As a fallback for wazero execution faults (e.g. Rust panic → unreachable), +// the function also matches on wazero's "wasm error:" message prefix. func isWasmTrap(err error) bool { - return err != nil && strings.Contains(err.Error(), "wasm error:") + if err == nil { + return false + } + // A normal WASI process exit (exit code 0) is not a trap — don't poison the guard. + var exitErr *sys.ExitError + if errors.As(err, &exitErr) { + return exitErr.ExitCode() != 0 + } + // Fallback for wazero execution traps (e.g. Rust panic → unreachable). + return strings.Contains(err.Error(), "wasm error:") } // callWasmFunction calls an exported function in the WASM module. diff --git a/internal/guard/wasm_test.go b/internal/guard/wasm_test.go index 66314b2d..776501fe 100644 --- a/internal/guard/wasm_test.go +++ b/internal/guard/wasm_test.go @@ -16,6 +16,7 @@ import ( "github.com/github/gh-aw-mcpg/internal/difc" "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/sys" ) func TestMain(m *testing.M) { @@ -1152,6 +1153,26 @@ func TestIsWasmTrap(t *testing.T) { err := errors.New("wasm error: out of bounds memory access") assert.True(t, isWasmTrap(err)) }) + + t.Run("sys.ExitError with exit code 0 is not a trap", func(t *testing.T) { + err := sys.NewExitError(0) + assert.False(t, isWasmTrap(err)) + }) + + t.Run("sys.ExitError with non-zero exit code is a trap", func(t *testing.T) { + err := sys.NewExitError(1) + assert.True(t, isWasmTrap(err)) + }) + + t.Run("wrapped sys.ExitError with exit code 0 is not a trap", func(t *testing.T) { + err := fmt.Errorf("wrapper: %w", sys.NewExitError(0)) + assert.False(t, isWasmTrap(err)) + }) + + t.Run("wrapped sys.ExitError with non-zero exit code is a trap", func(t *testing.T) { + err := fmt.Errorf("wrapper: %w", sys.NewExitError(2)) + assert.True(t, isWasmTrap(err)) + }) } func TestWasmGuardFailedState(t *testing.T) { diff --git a/internal/server/unified.go b/internal/server/unified.go index 7c0f5ce8..ce433ecb 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -718,7 +718,7 @@ func (us *UnifiedServer) GetToolHandler(backendID string, toolName string) func( // Close cleans up resources func (us *UnifiedServer) Close() error { - us.launcher.Close() + us.InitiateShutdown() return nil } @@ -753,6 +753,11 @@ func (us *UnifiedServer) InitiateShutdown() int { logger.LogInfo("shutdown", "Terminating %d backend servers", serversTerminated) us.launcher.Close() + // Release WASM runtime resources held by guards + if us.guardRegistry != nil { + us.guardRegistry.Close(context.Background()) + } + logger.LogInfo("shutdown", "Backend servers terminated successfully") }) return serversTerminated