diff --git a/cmd/ww/cat/cat.go b/cmd/ww/cat/cat.go index 8bd9b4c..0b0a94e 100644 --- a/cmd/ww/cat/cat.go +++ b/cmd/ww/cat/cat.go @@ -22,7 +22,7 @@ var env util.IPFSEnv func Command() *cli.Command { return &cli.Command{ Name: "cat", - ArgsUsage: " ", + ArgsUsage: " [method]", Usage: "Connect to a peer and execute a procedure over a stream", Description: `Connect to a specified peer and execute a procedure over a custom protocol stream. The command will: @@ -33,7 +33,8 @@ The command will: Examples: ww cat QmPeer123 /echo - ww cat 12D3KooW... /myproc`, + ww cat 12D3KooW... /myproc echo + ww cat 12D3KooW... /myproc poll`, Flags: append([]cli.Flag{ &cli.StringFlag{ Name: "ipfs", @@ -58,12 +59,13 @@ func Main(c *cli.Context) error { ctx, cancel := context.WithTimeout(c.Context, c.Duration("timeout")) defer cancel() - if c.NArg() != 2 { - return cli.Exit("cat requires exactly two arguments: ", 1) + if c.NArg() < 3 { + return cli.Exit("cat requires 2-3 arguments: [method]", 1) } peerIDStr := c.Args().Get(0) procName := c.Args().Get(1) + method := c.Args().Get(2) // Parse peer ID peerID, err := peer.Decode(peerIDStr) @@ -73,6 +75,9 @@ func Main(c *cli.Context) error { // Construct protocol ID protocolID := protocol.ID("/ww/0.1.0/" + procName) + if method != "" && method != "poll" { + protocolID = protocol.ID("/ww/0.1.0/" + procName + "/" + method) + } // Create libp2p host in client mode h, err := util.NewClient() diff --git a/cmd/ww/run/run.go b/cmd/ww/run/run.go index 2852c1e..760529f 100644 --- a/cmd/ww/run/run.go +++ b/cmd/ww/run/run.go @@ -9,11 +9,13 @@ import ( "os" "os/exec" "path/filepath" + "strings" "github.com/ipfs/boxo/path" "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/experimental/sys" "github.com/urfave/cli/v2" @@ -134,20 +136,39 @@ func Main(c *cli.Context) error { "peer", env.Host.ID(), "endpoint", p.Endpoint.Name) - env.Host.SetStreamHandler(p.Endpoint.Protocol(), func(s network.Stream) { + // Set up stream handler that matches both exact protocol and with method suffix + baseProto := p.Endpoint.Protocol() + env.Host.SetStreamHandlerMatch(baseProto, func(protocol protocol.ID) bool { + // Match exact base protocol (/ww/0.1.0/) or with method suffix (/ww/0.1.0//) + return protocol == baseProto || strings.HasPrefix(string(protocol), string(baseProto)+"/") + }, func(s network.Stream) { defer s.CloseRead() + + // Extract method from protocol string + method := "poll" // default + protocolStr := string(s.Protocol()) + if strings.HasPrefix(protocolStr, string(baseProto)+"/") { + // Extract method from /ww/0.1.0// + parts := strings.Split(protocolStr, "/") + if len(parts) > 0 { + method = parts[len(parts)-1] + } + } + slog.InfoContext(ctx, "stream connected", "peer", s.Conn().RemotePeer(), "stream-id", s.ID(), - "endpoint", p.Endpoint.Name) - if err := p.Poll(ctx, s, nil); err != nil { + "endpoint", p.Endpoint.Name, + "method", method) + if err := p.ProcessMessage(ctx, s, method); err != nil { slog.ErrorContext(ctx, "failed to poll process", "id", p.ID(), "stream", s.ID(), + "method", method, "reason", err) } }) - defer env.Host.RemoveStreamHandler(p.Endpoint.Protocol()) + defer env.Host.RemoveStreamHandler(baseProto) for { select { diff --git a/examples/echo/main.go b/examples/echo/main.go index cbe8515..622d424 100644 --- a/examples/echo/main.go +++ b/examples/echo/main.go @@ -10,26 +10,19 @@ import ( // main is the entry point for synchronous mode. // It processes one complete message from stdin and exits. func main() { - // Echo: copy stdin to stdout using io.Copy - // io.Copy uses an internal 32KB buffer by default - if _, err := io.Copy(os.Stdout, os.Stdin); err != nil { - os.Stderr.WriteString("Error copying stdin to stdout: " + err.Error() + "\n") - os.Exit(1) - } - defer os.Stdout.Sync() - // implicitly returns 0 to indicate successful completion + echo() } -// poll is the async entry point for stream-based processing. +// echo is the async entry point for stream-based processing. // This function is called by the wetware runtime when a new stream // is established for this process. // -//export poll -func poll() { +//export echo +func echo() { // In async mode, we process each incoming stream // This is the same logic as main() but for individual streams if _, err := io.Copy(os.Stdout, os.Stdin); err != nil { - os.Stderr.WriteString("Error in poll: " + err.Error() + "\n") + os.Stderr.WriteString("Error in echo: " + err.Error() + "\n") os.Exit(1) } defer os.Stdout.Sync() diff --git a/examples/echo/main.wasm b/examples/echo/main.wasm index ba70093..e07bf35 100644 Binary files a/examples/echo/main.wasm and b/examples/echo/main.wasm differ diff --git a/system/proc.go b/system/proc.go index c8a4185..9bc4eaa 100644 --- a/system/proc.go +++ b/system/proc.go @@ -166,8 +166,8 @@ func (p Proc) ID() string { // ProcessMessage processes one complete message synchronously. // In sync mode: lets _start run automatically and process one message -// In async mode: calls the poll export function -func (p Proc) ProcessMessage(ctx context.Context, s network.Stream) error { +// In async mode: calls the specified export function +func (p Proc) ProcessMessage(ctx context.Context, s network.Stream, method string) error { if deadline, ok := ctx.Deadline(); ok { if err := s.SetReadDeadline(deadline); err != nil { return fmt.Errorf("set read deadline: %w", err) @@ -187,14 +187,23 @@ func (p Proc) ProcessMessage(ctx context.Context, s network.Stream) error { p.Endpoint.ReadWriteCloser = nil }() - // In async mode, call the poll export function + // In async mode, call the specified export function if p.Config.Async { - if poll := p.Module.ExportedFunction("poll"); poll == nil { - return fmt.Errorf("%s::poll: not found", p.ID()) - } else if err := poll.CallWithStack(ctx, nil); err != nil { + // Normalize method: if empty string, use "poll" + if method == "" { + method = "poll" + } + + exp := p.Module.ExportedFunction(method) + if exp == nil { + _ = s.Reset() + return fmt.Errorf("unknown method: %s", method) + } + + if err := exp.CallWithStack(ctx, nil); err != nil { var exitErr *sys.ExitError if errors.As(err, &exitErr) && exitErr.ExitCode() != 0 { - return fmt.Errorf("%s::poll: %w", p.ID(), err) + return fmt.Errorf("%s::%s: %w", p.ID(), method, err) } // If it's ExitError with code 0, treat as success } @@ -204,11 +213,6 @@ func (p Proc) ProcessMessage(ctx context.Context, s network.Stream) error { return nil } -// Poll is an alias for ProcessMessage for backward compatibility -func (p Proc) Poll(ctx context.Context, s network.Stream, stack []uint64) error { - return p.ProcessMessage(ctx, s) -} - type CloserSlice []api.Closer func (cs CloserSlice) Close(ctx context.Context) error { diff --git a/system/proc_test.go b/system/proc_test.go index e2b68a0..f1a8fca 100644 --- a/system/proc_test.go +++ b/system/proc_test.go @@ -218,16 +218,17 @@ func TestProc_Poll_WithGomock(t *testing.T) { defer proc.Close(ctx) defer runtime.Close(ctx) - // Test that poll function exists - pollFunc := proc.Module.ExportedFunction("poll") - assert.NotNil(t, pollFunc, "poll function should exist") + // Test that echo function exists + echoFunc := proc.Module.ExportedFunction("echo") + assert.NotNil(t, echoFunc, "echo function should exist") // The echo WASM module will try to read from stdin, so we need to expect Read calls - // The poll function reads up to 512 bytes from stdin + // The echo function reads up to 512 bytes from stdin mockStream.EXPECT().Read(gomock.Any()).Return(0, io.EOF).AnyTimes() + mockStream.EXPECT().Reset().Return(nil).AnyTimes() - // Actually call the Poll method - this should succeed since we have a valid WASM function - err = proc.ProcessMessage(ctx, mockStream) + // Actually call the ProcessMessage method - this should succeed since we have a valid WASM function + err = proc.ProcessMessage(ctx, mockStream, "echo") // The WASM function should execute successfully assert.NoError(t, err) }) @@ -272,13 +273,14 @@ func TestProc_Poll_WithGomock(t *testing.T) { mockStream.EXPECT().SetReadDeadline(deadline).Return(nil) // The echo WASM module will try to read from stdin mockStream.EXPECT().Read(gomock.Any()).Return(0, io.EOF).AnyTimes() + mockStream.EXPECT().Reset().Return(nil).AnyTimes() - // Test that poll function exists (for async mode) - pollFunc := proc.Module.ExportedFunction("poll") - assert.NotNil(t, pollFunc, "poll function should exist") + // Test that echo function exists (for async mode) + echoFunc := proc.Module.ExportedFunction("echo") + assert.NotNil(t, echoFunc, "echo function should exist") // Actually call the ProcessMessage method to trigger the mock expectations - err = proc.ProcessMessage(ctxWithDeadline, mockStream) + err = proc.ProcessMessage(ctxWithDeadline, mockStream, "echo") // The WASM function should execute successfully assert.NoError(t, err) }) @@ -303,7 +305,7 @@ func TestProc_Poll_WithGomock(t *testing.T) { deadlineErr := errors.New("deadline error") mockStream.EXPECT().SetReadDeadline(deadline).Return(deadlineErr) - err := proc.ProcessMessage(ctxWithDeadline, mockStream) + err := proc.ProcessMessage(ctxWithDeadline, mockStream, "echo") assert.Error(t, err) assert.Contains(t, err.Error(), "set read deadline") assert.Contains(t, err.Error(), "deadline error") @@ -547,9 +549,9 @@ func TestProc_Integration_WithRealWasm(t *testing.T) { assert.NotEmpty(t, proc.Endpoint.Name, "Endpoint should have a name") assert.NotEmpty(t, proc.ID(), "String should not be empty") - // Test that the poll function is exported - pollFunc := proc.Module.ExportedFunction("poll") - assert.NotNil(t, pollFunc, "poll function should be exported") + // Test that the echo function is exported + echoFunc := proc.Module.ExportedFunction("echo") + assert.NotNil(t, echoFunc, "echo function should be exported") } // TestEcho_Synchronous tests the echo example in synchronous mode @@ -689,9 +691,14 @@ func TestEcho_Asynchronous(t *testing.T) { Return(nil). AnyTimes() + mockStream.EXPECT(). + Reset(). + Return(nil). + AnyTimes() + // Process message with the mock stream // This should process one complete message (until EOF) - err = proc.ProcessMessage(ctx, mockStream) + err = proc.ProcessMessage(ctx, mockStream, "echo") require.NoError(t, err, "ProcessMessage should succeed") // Verify the output matches the input @@ -780,8 +787,13 @@ func TestEcho_RepeatedAsync(t *testing.T) { Return(nil). AnyTimes() + mockStream.EXPECT(). + Reset(). + Return(nil). + AnyTimes() + // Process message with the mock stream - err = proc.ProcessMessage(ctx, mockStream) + err = proc.ProcessMessage(ctx, mockStream, "echo") require.NoError(t, err, "ProcessMessage should succeed for message %d", i+1) // Verify the output matches the input