Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions internal/guard/guard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package guard

import (
"context"
"errors"
"sync"
"testing"

Expand All @@ -17,6 +18,20 @@ type mockGuard struct {
id string
}

// mockClosableGuard is a test double that embeds mockGuard and additionally
// provides a Close method. It records whether, and how many times, Close was
// called so tests can assert on the registry's close behavior.
type mockClosableGuard struct {
mockGuard
// closed is set to true by Close.
closed bool
// closeCount is incremented on every call to Close.
closeCount int
// closeErr is the value Close returns to its caller.
closeErr error
}

// Close marks the mock as closed, bumps the call counter, and returns the
// configured closeErr. The context argument is not used.
func (m *mockClosableGuard) Close(ctx context.Context) error {
	m.closed, m.closeCount = true, m.closeCount+1
	return m.closeErr
}

// Name returns the mock's id prefixed with "mock-".
func (m *mockGuard) Name() string {
	const prefix = "mock-"
	return prefix + m.id
}
func (m *mockGuard) LabelAgent(ctx context.Context, policy interface{}, backend BackendCaller, caps *difc.Capabilities) (*LabelAgentResult, error) {
return &LabelAgentResult{DIFCMode: difc.ModeStrict}, nil
Expand Down Expand Up @@ -450,6 +465,70 @@ func TestGuardRegistry_HasNonNoopGuard(t *testing.T) {
})
}

// TestGuardRegistry_Close verifies Registry.Close: closable guards are closed,
// non-closable guards and empty registries are tolerated, an error from one
// guard does not prevent others from being closed, and repeated Close calls
// are safe.
func TestGuardRegistry_Close(t *testing.T) {
	ctx := context.Background()

	t.Run("close calls Close on guards that implement it", func(t *testing.T) {
		reg := NewRegistry()
		closable := &mockClosableGuard{mockGuard: mockGuard{id: "wasm"}}
		reg.Register("server1", closable)

		reg.Close(ctx)

		assert.True(t, closable.closed, "expected guard Close to be called")
	})

	t.Run("close skips guards that do not implement Close", func(t *testing.T) {
		reg := NewRegistry()
		reg.Register("server1", NewNoopGuard())

		// Passing means Close returned without panicking.
		reg.Close(ctx)
	})

	t.Run("close on empty registry is safe", func(t *testing.T) {
		// Passing means Close returned without panicking.
		NewRegistry().Close(ctx)
	})

	t.Run("close calls Close on all closable guards", func(t *testing.T) {
		reg := NewRegistry()
		first := &mockClosableGuard{mockGuard: mockGuard{id: "wasm1"}}
		second := &mockClosableGuard{mockGuard: mockGuard{id: "wasm2"}}
		reg.Register("server1", first)
		reg.Register("server2", second)

		reg.Close(ctx)

		assert.True(t, first.closed, "expected guard 1 Close to be called")
		assert.True(t, second.closed, "expected guard 2 Close to be called")
	})

	t.Run("close continues when one guard returns an error", func(t *testing.T) {
		reg := NewRegistry()
		failing := &mockClosableGuard{
			mockGuard: mockGuard{id: "failing"},
			closeErr:  errors.New("close failed"),
		}
		healthy := &mockClosableGuard{mockGuard: mockGuard{id: "ok"}}
		reg.Register("server1", failing)
		reg.Register("server2", healthy)

		// The failing guard's error must neither panic nor short-circuit the loop.
		reg.Close(ctx)

		assert.True(t, failing.closed, "expected failing guard Close to be called")
		assert.True(t, healthy.closed, "expected ok guard Close to be called")
	})

	t.Run("double close is safe", func(t *testing.T) {
		reg := NewRegistry()
		closable := &mockClosableGuard{mockGuard: mockGuard{id: "wasm"}}
		reg.Register("server1", closable)

		for i := 0; i < 2; i++ {
			reg.Close(ctx)
		}

		assert.Equal(t, 2, closable.closeCount, "Close should be called twice without panic")
	})
}

func TestCreateGuard(t *testing.T) {
tests := []struct {
name string
Expand Down
35 changes: 31 additions & 4 deletions internal/guard/registry.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package guard

import (
"context"
"fmt"
"sync"

Expand Down Expand Up @@ -30,7 +31,7 @@ func (r *Registry) Register(serverID string, guard Guard) {
defer r.mu.Unlock()

r.guards[serverID] = guard
log.Printf("[Guard] Registered guard '%s' for server '%s'", guard.Name(), serverID)
logger.LogInfo("guard", "Registered guard '%s' for server '%s'", guard.Name(), serverID)
}

// Get retrieves the guard for a server, or returns a noop guard if not found
Expand All @@ -46,7 +47,6 @@ func (r *Registry) Get(serverID string) Guard {

// Return noop guard as default
debugLog.Printf("No guard registered for serverID=%s, returning noop guard", serverID)
log.Printf("[Guard] No guard registered for server '%s', using noop guard", serverID)
return NewNoopGuard()
}

Expand Down Expand Up @@ -76,7 +76,7 @@ func (r *Registry) Remove(serverID string) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.guards, serverID)
log.Printf("[Guard] Removed guard for server '%s'", serverID)
logger.LogInfo("guard", "Removed guard for server '%s'", serverID)
}

// List returns all registered server IDs
Expand All @@ -103,6 +103,33 @@ func (r *Registry) GetGuardInfo() map[string]string {
return info
}

// Close shuts down every registered guard that exposes a
// Close(context.Context) error method; guards without such a method are left
// untouched. It is meant to be called during server shutdown to release WASM
// runtime resources.
//
// The closable guards are collected while holding the registry's read lock and
// are closed only after the lock has been released. A guard whose Close returns
// an error is logged as a warning and does not stop the remaining guards from
// being closed.
func (r *Registry) Close(ctx context.Context) {
	type ctxCloser interface {
		Close(context.Context) error
	}
	type entry struct {
		serverID string
		target   ctxCloser
	}

	// Snapshot the closable guards under the read lock.
	entries := func() []entry {
		r.mu.RLock()
		defer r.mu.RUnlock()

		found := make([]entry, 0, len(r.guards))
		for serverID, g := range r.guards {
			target, ok := g.(ctxCloser)
			if !ok {
				continue
			}
			found = append(found, entry{serverID: serverID, target: target})
		}
		return found
	}()

	for _, e := range entries {
		err := e.target.Close(ctx)
		if err == nil {
			continue
		}
		logger.LogWarn("guard", "Failed to close guard for server %s: %v", e.serverID, err)
	}

	if len(entries) == 0 {
		return
	}
	logger.LogInfo("guard", "Closed %d guard(s)", len(entries))
}

// GuardFactory is a function that creates a guard instance. It takes no
// arguments and returns the new Guard together with an error value.
type GuardFactory func() (Guard, error)

Expand All @@ -116,7 +143,7 @@ func RegisterGuardType(name string, factory GuardFactory) {
registeredGuardsMu.Lock()
defer registeredGuardsMu.Unlock()
registeredGuards[name] = factory
log.Printf("[Guard] Registered guard type: %s", name)
logger.LogInfo("guard", "Registered guard type: %s", name)
}

// CreateGuard creates a guard instance by name using registered factories
Expand Down
19 changes: 16 additions & 3 deletions internal/guard/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/tetratelabs/wazero/sys"
)

var logWasm = logger.New("guard:wasm")
Expand Down Expand Up @@ -830,10 +831,22 @@ func parsePathLabeledResponse(responseJSON []byte, originalData interface{}) (di
return pld.ToCollectionLabeledData(), nil
}

// isWasmTrap reports whether err represents a WASM execution trap that should
// permanently poison the guard.
//
// Classification rules, applied in order:
//   - A nil error is never a trap.
//   - If err is, or wraps, a *sys.ExitError, the exit code decides: a normal
//     process exit (exit code 0, e.g. TinyGo init) is NOT a trap, while any
//     non-zero exit code is treated as a trap.
//   - Otherwise, as a fallback for wazero execution faults (e.g. Rust panic →
//     unreachable), the error is a trap when its message contains wazero's
//     "wasm error:" prefix.
func isWasmTrap(err error) bool {
	if err == nil {
		return false
	}
	// A normal WASI process exit (exit code 0) is not a trap — don't poison the guard.
	var exitErr *sys.ExitError
	if errors.As(err, &exitErr) {
		return exitErr.ExitCode() != 0
	}
	// Fallback for wazero execution traps (e.g. Rust panic → unreachable).
	return strings.Contains(err.Error(), "wasm error:")
}

// callWasmFunction calls an exported function in the WASM module.
Expand Down
21 changes: 21 additions & 0 deletions internal/guard/wasm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/sys"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -1152,6 +1153,26 @@ func TestIsWasmTrap(t *testing.T) {
err := errors.New("wasm error: out of bounds memory access")
assert.True(t, isWasmTrap(err))
})

t.Run("sys.ExitError with exit code 0 is not a trap", func(t *testing.T) {
err := sys.NewExitError(0)
assert.False(t, isWasmTrap(err))
})

t.Run("sys.ExitError with non-zero exit code is a trap", func(t *testing.T) {
err := sys.NewExitError(1)
assert.True(t, isWasmTrap(err))
})

t.Run("wrapped sys.ExitError with exit code 0 is not a trap", func(t *testing.T) {
err := fmt.Errorf("wrapper: %w", sys.NewExitError(0))
assert.False(t, isWasmTrap(err))
})

t.Run("wrapped sys.ExitError with non-zero exit code is a trap", func(t *testing.T) {
err := fmt.Errorf("wrapper: %w", sys.NewExitError(2))
assert.True(t, isWasmTrap(err))
})
}

func TestWasmGuardFailedState(t *testing.T) {
Expand Down
7 changes: 6 additions & 1 deletion internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ func (us *UnifiedServer) GetToolHandler(backendID string, toolName string) func(

// Close cleans up resources by delegating to InitiateShutdown, which closes
// the launcher and, when a guard registry is configured, the guards as well.
// The launcher is deliberately not closed here directly, so it is not closed a
// second time on top of the shutdown path. Close always returns nil.
func (us *UnifiedServer) Close() error {
	us.InitiateShutdown()
	return nil
}

Expand Down Expand Up @@ -753,6 +753,11 @@ func (us *UnifiedServer) InitiateShutdown() int {
logger.LogInfo("shutdown", "Terminating %d backend servers", serversTerminated)
us.launcher.Close()

// Release WASM runtime resources held by guards
if us.guardRegistry != nil {
us.guardRegistry.Close(context.Background())
}
Comment on lines +756 to +759
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InitiateShutdown closes guards immediately while the HTTP server may still be draining in-flight requests. If any in-flight request is currently executing a guard call, closing the underlying WASM module/runtime concurrently can race (WasmGuard serializes calls with a mutex, but Close does not take that mutex). Consider either deferring guardRegistry.Close until after HTTP shutdown/drain completes, or ensuring guard Close implementations synchronize with in-flight calls (e.g. WasmGuard.Close acquiring g.mu).

Copilot uses AI. Check for mistakes.

logger.LogInfo("shutdown", "Backend servers terminated successfully")
})
return serversTerminated
Expand Down
Loading