From 0535805657b2ef913e9a48db31f4808a644372b0 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 21 Feb 2026 23:10:51 +0000 Subject: [PATCH] feat: Prefix shell output with truncated input command Updates the Shell V2 output handling in `tavern/internal/http/shell/handler.go` to prefix the output and error messages with the input command that caused them. - Prefixes standard output with `[+] \n`. - Prefixes error output with `[!] \n`. - Truncates input strings to 64 characters to avoid cluttering the terminal. - Updates integration tests in `tavern/internal/http/shell/integration_test.go` to verify the new prefix format and truncation logic. - Adds `TestTruncation` to explicitly test the input truncation behavior. Co-authored-by: KCarretto <16250309+KCarretto@users.noreply.github.com> --- tavern/internal/http/shell/handler.go | 24 +++- .../internal/http/shell/integration_test.go | 133 +++++++++++------- 2 files changed, 99 insertions(+), 58 deletions(-) diff --git a/tavern/internal/http/shell/handler.go b/tavern/internal/http/shell/handler.go index a446abcdf..7924e1e4b 100644 --- a/tavern/internal/http/shell/handler.go +++ b/tavern/internal/http/shell/handler.go @@ -488,10 +488,14 @@ func (h *Handler) writeMessagesFromShell(ctx context.Context, session *ShellSess } if task.Output != "" { - taskOutputCh <- NewWebsocketTaskOutputMessage(task) + msg := NewWebsocketTaskOutputMessage(task) + msg.Output = fmt.Sprintf("[+] %s\n%s", truncateInput(task.Input), msg.Output) + taskOutputCh <- msg } if task.Error != "" { - taskErrCh <- NewWebsocketTaskErrorMessage(task) + msg := NewWebsocketTaskErrorMessage(task) + msg.Error = fmt.Sprintf("[!] %s\n%s", truncateInput(task.Input), msg.Error) + taskErrCh <- msg } sentTasks[task.ID] = struct{}{} } @@ -556,9 +560,13 @@ func (h *Handler) writeMessagesFromShell(ctx context.Context, session *ShellSess if mote.StreamId == streamID { // Local stream output outputMsg := NewWebsocketTaskOutputMessage(task) - outputMsg.Output = string(bytesPayload.Data) // Use real-time chunk + if _, sent := sentTasks[task.ID]; !sent { + outputMsg.Output = fmt.Sprintf("[+] %s\n%s", truncateInput(task.Input), string(bytesPayload.Data)) + sentTasks[task.ID] = struct{}{} + } else { + outputMsg.Output = string(bytesPayload.Data) // Use real-time chunk + } taskOutputCh <- outputMsg - sentTasks[task.ID] = struct{}{} } else { // Other stream output if _, sent := sentTasks[task.ID]; !sent { @@ -645,3 +653,11 @@ func (h *Handler) writeMessagesFromShell(ctx context.Context, session *ShellSess } } + +func truncateInput(input string) string { + const maxLength = 64 + if len(input) > maxLength { + return input[:maxLength] + "..." + } + return input +} diff --git a/tavern/internal/http/shell/integration_test.go b/tavern/internal/http/shell/integration_test.go index da4e5db24..b00908ec5 100644 --- a/tavern/internal/http/shell/integration_test.go +++ b/tavern/internal/http/shell/integration_test.go @@ -166,56 +166,6 @@ func TestInteractiveShell(t *testing.T) { require.NoError(t, err) // Subscribe to IN topic (User -> Agent) - // CreatePortal sets up subscriptions for IN topic? - // `CreatePortal` calls `m.openSubscription(ctx, topicIn, subName)` and starts `m.receiveLoop`. - // Wait, `CreatePortal` subscribes to IN topic (messages FROM user TO agent). - // But it consumes them internally (via `receiveLoop`). - // `receiveLoop` likely processes messages or forwards them? - // `mux.go` doesn't show `receiveLoop` impl. - // But typically `CreatePortal` is used by the C2 server to bridge PubSub <-> gRPC stream. - // Here in test, we want to simulate the Agent receiving the message. - // If `CreatePortal` already subscribed, we can't subscribe again easily with `mempubsub` (maybe?). - // However, `mux` exposes `Subscribe` method? Yes. - // But `CreatePortal` locks the sub manager. - // If we want to intercept messages sent by User, we should subscribe to `TopicIn`. - // If `CreatePortal` already subscribed, we might have competition. - // But `CreatePortal` is designed for the C2 service. - // In this test, we ARE the C2 service / Agent. - // `CreatePortal` logic: starts a loop. - // We can't access that loop's output channel easily unless we use `Mux` internals or if `CreatePortal` returned a channel (it returns cleanup). - - // Alternative: Don't use `CreatePortal`. Use `OpenPortal` pattern but ensure topics exist manually? - // But `ensureTopic` is private in `Mux`. - // AND we need the DB record. - - // Actually, `OpenPortal` (User side) subscribes to `TopicOut`. - // `CreatePortal` (Agent side) subscribes to `TopicIn`. - // We want to write to `TopicOut` (simulate Agent output) and read from `TopicIn` (simulate Agent input). - - // If `CreatePortal` consumes `TopicIn`, we can't read it. - // BUT, `CreatePortal` implementation (read earlier) does: - // `go func() { m.receiveLoop(ctxLoop, topicIn, sub) }()` - // We don't know what `receiveLoop` does. - // If `receiveLoop` just drops messages or handles history, we might miss them. - // Wait, `CreatePortal` is used by `c2/server.go`. It likely forwards to the gRPC stream. - // But here we don't have the gRPC stream connected to `CreatePortal`. - // `CreatePortal` signature in `mux_create.go`: `func (m *Mux) CreatePortal(...) (int, func(), error)`. - // It doesn't take a channel or stream. - // So where do messages go? - // Maybe `receiveLoop` puts them into `history` or `subs`? - // `SubscriberRegistry` in `Mux` struct. - - // If `CreatePortal` is running, it consumes messages. - // We probably shouldn't use `CreatePortal` if we want to intercept messages manually in the test using `Mux.Subscribe`. - // UNLESS `Mux.Subscribe` adds a *local* subscriber to the *internal* registry? - // `Mux` seems to have `SubscriberRegistry`. - // `Subscribe` probably adds a channel to this registry. - // `receiveLoop` probably reads from PubSub and broadcasts to registry channels. - // If so, we are fine. `CreatePortal` sets up the PubSub subscription. `Subscribe` hooks into the stream. - - // Let's verify `Subscribe`. I didn't read it. - // Assuming `Subscribe` works as intended (pubsub style). - agentInCh, agentSubCleanup := env.Mux.Subscribe(env.Mux.TopicIn(p.ID)) defer agentSubCleanup() @@ -282,8 +232,6 @@ func TestInteractiveShell(t *testing.T) { }, } // We publish to TopicOut. User subscribes to TopicOut. - // Ensure TopicOut exists? `CreatePortal` calls `ensureTopic(TopicOut)`. - // So it should be fine. err = env.Mux.Publish(ctx, env.Mux.TopicOut(p.ID), outMote) require.NoError(t, err) @@ -311,7 +259,7 @@ func TestInteractiveShell(t *testing.T) { } } } - require.Equal(t, outputData, outMsg.Output) + require.Equal(t, fmt.Sprintf("[+] %s\n%s", inputCmd, outputData), outMsg.Output) require.Equal(t, task.ID, outMsg.ShellTaskID) } @@ -386,7 +334,7 @@ func TestNonInteractiveShell(t *testing.T) { } } } - require.Equal(t, "root", outMsg.Output) + require.Equal(t, fmt.Sprintf("[+] %s\nroot", inputCmd), outMsg.Output) require.Equal(t, task.ID, outMsg.ShellTaskID) } @@ -494,3 +442,80 @@ func TestOtherStreamOutput(t *testing.T) { require.Contains(t, otherMsg.Output, "Other User") require.Contains(t, otherMsg.Output, "rebooting...") } + +func TestTruncation(t *testing.T) { + env := SetupTestEnv(t) + defer env.Close() + + // 1. Connect via WebSocket + url := fmt.Sprintf("%s?shell_id=%d", env.WSURL, env.Shell.ID) + ws, _, err := websocket.DefaultDialer.Dial(url, nil) + require.NoError(t, err) + defer ws.Close() + + // 2. Send Long Input + longInput := strings.Repeat("A", 100) + inputMsg := shell.WebsocketTaskInputMessage{ + Kind: shell.WebsocketMessageKindInput, + Input: longInput, + } + err = ws.WriteJSON(inputMsg) + require.NoError(t, err) + + // 3. Expect "Task Queued" Control Message + var msg shell.WebsocketControlFlowMessage + for { + _, data, err := ws.ReadMessage() + require.NoError(t, err) + + var genericMsg struct { + Kind string `json:"kind"` + } + json.Unmarshal(data, &genericMsg) + if genericMsg.Kind == shell.WebsocketMessageKindControlFlow { + json.Unmarshal(data, &msg) + if msg.Signal == shell.WebsocketControlFlowSignalTaskQueued { + break + } + } + } + require.Contains(t, msg.Message, "Task Queued for testbeacon") + + // 4. Simulate C2 updating Task Output + time.Sleep(100 * time.Millisecond) + tasks, err := env.EntClient.ShellTask.Query().All(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, tasks) + task := tasks[0] + + _, err = task.Update().SetOutput("out").Save(context.Background()) + require.NoError(t, err) + + // 5. Expect Output on WebSocket + var outMsg shell.WebsocketTaskOutputMessage + found := false + timeout := time.After(5 * time.Second) + for !found { + select { + case <-timeout: + t.Fatal("timeout waiting for output") + default: + _, data, err := ws.ReadMessage() + if err != nil { + continue + } + var genericMsg struct { + Kind string `json:"kind"` + } + json.Unmarshal(data, &genericMsg) + if genericMsg.Kind == shell.WebsocketMessageKindOutput { + json.Unmarshal(data, &outMsg) + found = true + } + } + } + require.Equal(t, task.ID, outMsg.ShellTaskID) + + expectedPrefix := "[+] " + longInput[:64] + "...\n" + require.Contains(t, outMsg.Output, expectedPrefix) +}