diff --git a/internal/guard/wasm.go b/internal/guard/wasm.go
index d36795ae..bc05b2ed 100644
--- a/internal/guard/wasm.go
+++ b/internal/guard/wasm.go
@@ -3,6 +3,7 @@ package guard
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"os"
@@ -18,12 +19,25 @@ import (
 
 var logWasm = logger.New("guard:wasm")
 
+// globalCompilationCache is a process-level compilation cache shared across all
+// WasmGuard instances. wazero's cache is goroutine-safe and eliminates redundant
+// compilation when multiple guards load the same WASM binary.
+var globalCompilationCache = wazero.NewCompilationCache()
+
 // WasmGuardOptions configures optional settings for WASM guard creation
 type WasmGuardOptions struct {
 	// Stdout is the writer for WASM stdout output. Defaults to os.Stdout if nil.
 	Stdout io.Writer
 	// Stderr is the writer for WASM stderr output. Defaults to os.Stderr if nil.
 	Stderr io.Writer
+	// CompilationCache overrides the process-level compilation cache.
+	// If nil and DisableCompilationCache is false, the shared globalCompilationCache is used.
+	CompilationCache wazero.CompilationCache
+	// DisableCompilationCache, when true, prevents any compilation cache from being used,
+	// taking precedence over both CompilationCache and the shared global cache. This can
+	// be useful to avoid unbounded memory growth when many distinct WASM binaries are
+	// loaded over the lifetime of a long-lived process.
+	DisableCompilationCache bool
 }
 
 // WasmGuard implements Guard interface by executing a WASM module in-process
@@ -79,9 +93,15 @@ func NewWasmGuardFromBytes(ctx context.Context, name string, wasmBytes []byte, b
 func NewWasmGuardWithOptions(ctx context.Context, name string, wasmBytes []byte, backend BackendCaller, opts *WasmGuardOptions) (*WasmGuard, error) {
 	logWasm.Printf("Creating WASM guard from bytes: name=%s, size=%d", name, len(wasmBytes))
 
-	// Create WASM runtime with explicit compiler config and context-based cancellation
-	// WithCloseOnContextDone enables request-scoped timeouts to propagate into guard execution
+	// Pick the compilation cache: none on explicit opt-out, else the injected cache, else the shared global.
 	runtimeConfig := wazero.NewRuntimeConfigCompiler().WithCloseOnContextDone(true)
+	if opts == nil || !opts.DisableCompilationCache {
+		cache := globalCompilationCache
+		if opts != nil && opts.CompilationCache != nil {
+			cache = opts.CompilationCache
+		}
+		runtimeConfig = runtimeConfig.WithCompilationCache(cache)
+	}
 	runtime := wazero.NewRuntimeWithConfig(ctx, runtimeConfig)
 
 	// Instantiate WASI
@@ -624,10 +644,6 @@ func (g *WasmGuard) callWasmGuardFunction(ctx context.Context, funcName string,
 func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend BackendCaller, caps *difc.Capabilities) (*LabelAgentResult, error) {
 	logWasm.Printf("LabelAgent called: guard=%s", g.name)
 
-	if g.module.ExportedFunction("label_agent") == nil {
-		return nil, fmt.Errorf("WASM guard does not export label_agent")
-	}
-
 	// Normalisation and payload-build operate only on the caller-supplied `policy`
 	// argument and do not access any g.* fields, so they are safe to run outside
 	// the lock that callWasmGuardFunction acquires.
@@ -810,7 +826,18 @@ func (g *WasmGuard) callWasmFunction(ctx context.Context, funcName string, input
 		return nil, fmt.Errorf("input too large: %d bytes (max %d)", len(inputJSON), maxInputSize)
 	}
 
-	// Try with initial buffer size, retry with larger buffer if needed
+	// Adaptive output buffer strategy:
+	//
+	// WASM guards communicate buffer-too-small via a return code convention:
+	//   -2        → buffer too small; first 4 bytes of the output buffer MAY contain the
+	//               required size as a little-endian uint32. If present and > 0, we use
+	//               that size for the next attempt; otherwise we double the buffer.
+	//   other < 0 → error (returned as-is to the caller).
+	//   >= 0      → success; value is the number of bytes written to the output buffer.
+	//
+	// We retry up to maxRetries times, growing from 4MB toward the 16MB ceiling.
+	// A WASM trap (e.g. "wasm error: unreachable" from a Rust panic) permanently
+	// marks the guard as failed because the module's internal state may be corrupt.
 	outputSize := initialOutputSize
 	const maxRetries = 3
 
@@ -1130,13 +1157,14 @@ func parseCollectionLabeledData(items []interface{}) (*difc.CollectionLabeledDat
 
 // Close releases WASM runtime resources
 func (g *WasmGuard) Close(ctx context.Context) error {
+	// Attempt both closes even if the first fails, and report every failure;
+	// errors.Join returns nil when both errors are nil.
+	var moduleErr, runtimeErr error
 	if g.module != nil {
-		if err := g.module.Close(ctx); err != nil {
-			logWasm.Printf("Error closing module: %v", err)
-		}
+		moduleErr = g.module.Close(ctx)
 	}
 	if g.runtime != nil {
-		return g.runtime.Close(ctx)
+		runtimeErr = g.runtime.Close(ctx)
 	}
-	return nil
+	return errors.Join(moduleErr, runtimeErr)
 }
diff --git a/internal/guard/wasm_test.go b/internal/guard/wasm_test.go
index 543aee0e..d069f42d 100644
--- a/internal/guard/wasm_test.go
+++ b/internal/guard/wasm_test.go
@@ -6,6 +6,9 @@ import (
 	"encoding/binary"
 	"encoding/json"
 	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
 	"testing"
 	"time"
 
@@ -16,6 +19,19 @@ import (
 	"github.com/tetratelabs/wazero"
 )
 
+// TestMain runs the package tests and then closes the process-level compilation
+// cache, turning a failed close into a failing exit code.
+func TestMain(m *testing.M) {
+	code := m.Run()
+	if err := globalCompilationCache.Close(context.Background()); err != nil {
+		fmt.Fprintf(os.Stderr, "failed to close global compilation cache: %v\n", err)
+		if code == 0 {
+			code = 1
+		}
+	}
+	os.Exit(code)
+}
+
 type ctxKey string
 
 const testCtxKey ctxKey = "test-key"
@@ -1081,3 +1097,102 @@ func TestWasmGuardFailedState(t *testing.T) {
 		assert.ErrorIs(t, err, originalErr)
 	})
 }
+
+// countCacheFiles returns the number of regular files anywhere under dir.
+//
+// A top-level os.ReadDir is not a reliable signal: wazero keeps disk-cache
+// entries in a version-specific subdirectory, which can exist before anything
+// is compiled. Only regular files below dir show that the cache was written to.
+func countCacheFiles(t *testing.T, dir string) int {
+	t.Helper()
+	files := 0
+	err := filepath.WalkDir(dir, func(_ string, d os.DirEntry, walkErr error) error {
+		if walkErr != nil {
+			return walkErr
+		}
+		if d.Type().IsRegular() {
+			files++
+		}
+		return nil
+	})
+	require.NoError(t, err)
+	return files
+}
+
+// useTempGlobalCache replaces globalCompilationCache with a disk-backed cache in
+// a fresh temp dir for the duration of the test and returns that directory.
+// Not safe for parallel tests: it mutates package-level state.
+func useTempGlobalCache(t *testing.T) string {
+	t.Helper()
+	cacheDir := t.TempDir()
+	tmpCache, err := wazero.NewCompilationCacheWithDir(cacheDir)
+	require.NoError(t, err)
+
+	origCache := globalCompilationCache
+	globalCompilationCache = tmpCache
+	t.Cleanup(func() {
+		globalCompilationCache = origCache
+		assert.NoError(t, tmpCache.Close(context.Background()))
+	})
+	return cacheDir
+}
+
+func TestWasmGuardCompilationCache(t *testing.T) {
+	t.Run("global compilation cache is not nil", func(t *testing.T) {
+		assert.NotNil(t, globalCompilationCache)
+	})
+
+	t.Run("custom cache is used when provided via options", func(t *testing.T) {
+		ctx := context.Background()
+
+		// Use a disk-backed cache in a temp dir so we can observe that it was used.
+		cacheDir := t.TempDir()
+		customCache, err := wazero.NewCompilationCacheWithDir(cacheDir)
+		require.NoError(t, err)
+		defer customCache.Close(ctx)
+
+		opts := &WasmGuardOptions{
+			CompilationCache: customCache,
+		}
+
+		// Instantiation will fail (minimal WASM), but the code path
+		// that selects the cache runs before module compilation, which
+		// should populate the disk-backed cache.
+		_, err = NewWasmGuardWithOptions(ctx, "cache-test", minimalGuardWasm, &mockBackendCaller{}, opts)
+		require.Error(t, err)
+
+		assert.NotZero(t, countCacheFiles(t, cacheDir), "expected compilation artifacts in custom cache directory")
+	})
+
+	t.Run("global cache is used when options cache is nil", func(t *testing.T) {
+		ctx := context.Background()
+
+		// Swap in a disk-backed global cache pointing at a temp dir so we can
+		// observe that the global cache path is actually exercised.
+		cacheDir := useTempGlobalCache(t)
+
+		// nil opts → global cache path
+		_, err := NewWasmGuardWithOptions(ctx, "cache-test", minimalGuardWasm, &mockBackendCaller{}, nil)
+		require.Error(t, err)
+
+		assert.NotZero(t, countCacheFiles(t, cacheDir), "expected compilation artifacts in global cache directory")
+	})
+
+	t.Run("cache is disabled when DisableCompilationCache is true", func(t *testing.T) {
+		ctx := context.Background()
+
+		// Redirect the global cache to a temp dir: if the opt-out were ignored,
+		// the fallback path would leave compilation artifacts there.
+		cacheDir := useTempGlobalCache(t)
+
+		opts := &WasmGuardOptions{
+			DisableCompilationCache: true,
+		}
+
+		// Guard creation still fails on the minimal WASM, but must not touch any cache.
+		_, err := NewWasmGuardWithOptions(ctx, "no-cache-test", minimalGuardWasm, &mockBackendCaller{}, opts)
+		require.Error(t, err)
+
+		assert.Zero(t, countCacheFiles(t, cacheDir), "expected no compilation artifacts when caching is disabled")
+	})
+}