Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions internal/guard/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package guard
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
Expand All @@ -18,12 +19,25 @@ import (

var logWasm = logger.New("guard:wasm")

// globalCompilationCache is a process-level compilation cache shared across all
// WasmGuard instances. wazero's cache is goroutine-safe and eliminates redundant
// JIT compilation when multiple guards load the same WASM binary.
var globalCompilationCache = wazero.NewCompilationCache()

// WasmGuardOptions configures optional settings for WASM guard creation
type WasmGuardOptions struct {
// Stdout is the writer for WASM stdout output. Defaults to os.Stdout if nil.
Stdout io.Writer
// Stderr is the writer for WASM stderr output. Defaults to os.Stderr if nil.
Stderr io.Writer
// CompilationCache overrides the process-level compilation cache.
// If nil and DisableCompilationCache is false, the shared globalCompilationCache is used.
CompilationCache wazero.CompilationCache
// DisableCompilationCache, when true, prevents any compilation cache from being used
// even if a global or per-instance cache is available. This can be useful to avoid
// unbounded memory growth when many distinct WASM binaries are loaded over the
// lifetime of a long-lived process.
DisableCompilationCache bool
}

// WasmGuard implements Guard interface by executing a WASM module in-process
Expand Down Expand Up @@ -79,9 +93,15 @@ func NewWasmGuardFromBytes(ctx context.Context, name string, wasmBytes []byte, b
func NewWasmGuardWithOptions(ctx context.Context, name string, wasmBytes []byte, backend BackendCaller, opts *WasmGuardOptions) (*WasmGuard, error) {
logWasm.Printf("Creating WASM guard from bytes: name=%s, size=%d", name, len(wasmBytes))

// Create WASM runtime with explicit compiler config and context-based cancellation
// WithCloseOnContextDone enables request-scoped timeouts to propagate into guard execution
// Select compilation cache: explicit opt-out, injected cache, or shared global.
runtimeConfig := wazero.NewRuntimeConfigCompiler().WithCloseOnContextDone(true)
if opts != nil && opts.DisableCompilationCache {
// Caller explicitly disabled caching
} else if opts != nil && opts.CompilationCache != nil {
runtimeConfig = runtimeConfig.WithCompilationCache(opts.CompilationCache)
} else {
runtimeConfig = runtimeConfig.WithCompilationCache(globalCompilationCache)
}
runtime := wazero.NewRuntimeWithConfig(ctx, runtimeConfig)

// Instantiate WASI
Expand Down Expand Up @@ -624,10 +644,6 @@ func (g *WasmGuard) callWasmGuardFunction(ctx context.Context, funcName string,
func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend BackendCaller, caps *difc.Capabilities) (*LabelAgentResult, error) {
logWasm.Printf("LabelAgent called: guard=%s", g.name)

if g.module.ExportedFunction("label_agent") == nil {
return nil, fmt.Errorf("WASM guard does not export label_agent")
}

// Normalisation and payload-build operate only on the caller-supplied `policy`
// argument and do not access any g.* fields, so they are safe to run outside
// the lock that callWasmGuardFunction acquires.
Expand Down Expand Up @@ -810,7 +826,18 @@ func (g *WasmGuard) callWasmFunction(ctx context.Context, funcName string, input
return nil, fmt.Errorf("input too large: %d bytes (max %d)", len(inputJSON), maxInputSize)
}

// Try with initial buffer size, retry with larger buffer if needed
// Adaptive output buffer strategy:
//
// WASM guards communicate buffer-too-small via a return code convention:
// -2 → buffer too small; first 4 bytes of the output buffer MAY contain the
// required size as a little-endian uint32. If present and > 0, we use
// that size for the next attempt; otherwise we double the buffer.
// < 0 → other error (returned as-is to the caller).
// >= 0 → success; value is the number of bytes written to the output buffer.
//
// We retry up to maxRetries times, growing from 4MB toward the 16MB ceiling.
// A WASM trap (e.g. "wasm error: unreachable" from a Rust panic) permanently
// marks the guard as failed because the module's internal state may be corrupt.
outputSize := initialOutputSize
const maxRetries = 3

Expand Down Expand Up @@ -1130,13 +1157,12 @@ func parseCollectionLabeledData(items []interface{}) (*difc.CollectionLabeledDat

// Close releases WASM runtime resources
func (g *WasmGuard) Close(ctx context.Context) error {
var moduleErr, runtimeErr error
if g.module != nil {
if err := g.module.Close(ctx); err != nil {
logWasm.Printf("Error closing module: %v", err)
}
moduleErr = g.module.Close(ctx)
}
if g.runtime != nil {
return g.runtime.Close(ctx)
runtimeErr = g.runtime.Close(ctx)
}
return nil
return errors.Join(moduleErr, runtimeErr)
}
80 changes: 80 additions & 0 deletions internal/guard/wasm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"os"
"testing"
"time"

Expand All @@ -16,6 +18,17 @@ import (
"github.com/tetratelabs/wazero"
)

func TestMain(m *testing.M) {
code := m.Run()
if err := globalCompilationCache.Close(context.Background()); err != nil {
fmt.Fprintf(os.Stderr, "failed to close global compilation cache: %v\n", err)
if code == 0 {
code = 1
}
}
os.Exit(code)
}
Comment on lines +21 to +30
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestMain closes globalCompilationCache but ignores the returned error. If Close can fail (e.g., for future disk-backed caches or internal cleanup), the test process would still exit with a success code. Consider checking the error and failing the test run (e.g., printing to stderr and using a non-zero exit code) so cache lifecycle issues are surfaced.

Copilot uses AI. Check for mistakes.

type ctxKey string

const testCtxKey ctxKey = "test-key"
Expand Down Expand Up @@ -1081,3 +1094,70 @@ func TestWasmGuardFailedState(t *testing.T) {
assert.ErrorIs(t, err, originalErr)
})
}

func TestWasmGuardCompilationCache(t *testing.T) {
t.Run("global compilation cache is not nil", func(t *testing.T) {
assert.NotNil(t, globalCompilationCache)
})

t.Run("custom cache is used when provided via options", func(t *testing.T) {
ctx := context.Background()

// Use a disk-backed cache in a temp dir so we can observe that it was used.
cacheDir := t.TempDir()
customCache, err := wazero.NewCompilationCacheWithDir(cacheDir)
require.NoError(t, err)
defer customCache.Close(ctx)

opts := &WasmGuardOptions{
CompilationCache: customCache,
}

// Instantiation will fail (minimal WASM), but the code path
// that selects the cache runs before module compilation, which
// should populate the disk-backed cache.
_, err = NewWasmGuardWithOptions(ctx, "cache-test", minimalGuardWasm, &mockBackendCaller{}, opts)
require.Error(t, err)

entries, readErr := os.ReadDir(cacheDir)
require.NoError(t, readErr)
assert.NotEmpty(t, entries, "expected compilation artifacts in custom cache directory")
})

t.Run("global cache is used when options cache is nil", func(t *testing.T) {
ctx := context.Background()

// Swap in a disk-backed global cache pointing at a temp dir so we can
// observe that the global cache path is actually exercised.
cacheDir := t.TempDir()
tmpCache, err := wazero.NewCompilationCacheWithDir(cacheDir)
require.NoError(t, err)

origCache := globalCompilationCache
globalCompilationCache = tmpCache
defer func() {
globalCompilationCache = origCache
tmpCache.Close(ctx)
}()

// nil opts → global cache path
_, err = NewWasmGuardWithOptions(ctx, "cache-test", minimalGuardWasm, &mockBackendCaller{}, nil)
require.Error(t, err)

entries, readErr := os.ReadDir(cacheDir)
require.NoError(t, readErr)
assert.NotEmpty(t, entries, "expected compilation artifacts in global cache directory")
})

t.Run("cache is disabled when DisableCompilationCache is true", func(t *testing.T) {
ctx := context.Background()

opts := &WasmGuardOptions{
DisableCompilationCache: true,
}

// Should still work (fail on minimal WASM) but without caching
_, err := NewWasmGuardWithOptions(ctx, "no-cache-test", minimalGuardWasm, &mockBackendCaller{}, opts)
require.Error(t, err)
})
}
Loading