Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions tavern/internal/http/shell/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,14 @@ func (h *Handler) writeMessagesFromShell(ctx context.Context, session *ShellSess
}

if task.Output != "" {
taskOutputCh <- NewWebsocketTaskOutputMessage(task)
msg := NewWebsocketTaskOutputMessage(task)
msg.Output = fmt.Sprintf("[+] %s\n%s", truncateInput(task.Input), msg.Output)
taskOutputCh <- msg
}
if task.Error != "" {
taskErrCh <- NewWebsocketTaskErrorMessage(task)
msg := NewWebsocketTaskErrorMessage(task)
msg.Error = fmt.Sprintf("[!] %s\n%s", truncateInput(task.Input), msg.Error)
taskErrCh <- msg
}
sentTasks[task.ID] = struct{}{}
}
Expand Down Expand Up @@ -556,9 +560,13 @@ func (h *Handler) writeMessagesFromShell(ctx context.Context, session *ShellSess
if mote.StreamId == streamID {
// Local stream output
outputMsg := NewWebsocketTaskOutputMessage(task)
outputMsg.Output = string(bytesPayload.Data) // Use real-time chunk
if _, sent := sentTasks[task.ID]; !sent {
outputMsg.Output = fmt.Sprintf("[+] %s\n%s", truncateInput(task.Input), string(bytesPayload.Data))
sentTasks[task.ID] = struct{}{}
} else {
outputMsg.Output = string(bytesPayload.Data) // Use real-time chunk
}
taskOutputCh <- outputMsg
sentTasks[task.ID] = struct{}{}
} else {
// Other stream output
if _, sent := sentTasks[task.ID]; !sent {
Expand Down Expand Up @@ -645,3 +653,11 @@ func (h *Handler) writeMessagesFromShell(ctx context.Context, session *ShellSess

}
}

func truncateInput(input string) string {
const maxLength = 64
if len(input) > maxLength {
return input[:maxLength] + "..."
}
return input
}
133 changes: 79 additions & 54 deletions tavern/internal/http/shell/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,56 +166,6 @@ func TestInteractiveShell(t *testing.T) {
require.NoError(t, err)

// Subscribe to IN topic (User -> Agent)
// CreatePortal sets up subscriptions for IN topic?
// `CreatePortal` calls `m.openSubscription(ctx, topicIn, subName)` and starts `m.receiveLoop`.
// Wait, `CreatePortal` subscribes to IN topic (messages FROM user TO agent).
// But it consumes them internally (via `receiveLoop`).
// `receiveLoop` likely processes messages or forwards them?
// `mux.go` doesn't show `receiveLoop` impl.
// But typically `CreatePortal` is used by the C2 server to bridge PubSub <-> gRPC stream.
// Here in test, we want to simulate the Agent receiving the message.
// If `CreatePortal` already subscribed, we can't subscribe again easily with `mempubsub` (maybe?).
// However, `mux` exposes `Subscribe` method? Yes.
// But `CreatePortal` locks the sub manager.
// If we want to intercept messages sent by User, we should subscribe to `TopicIn`.
// If `CreatePortal` already subscribed, we might have competition.
// But `CreatePortal` is designed for the C2 service.
// In this test, we ARE the C2 service / Agent.
// `CreatePortal` logic: starts a loop.
// We can't access that loop's output channel easily unless we use `Mux` internals or if `CreatePortal` returned a channel (it returns cleanup).

// Alternative: Don't use `CreatePortal`. Use `OpenPortal` pattern but ensure topics exist manually?
// But `ensureTopic` is private in `Mux`.
// AND we need the DB record.

// Actually, `OpenPortal` (User side) subscribes to `TopicOut`.
// `CreatePortal` (Agent side) subscribes to `TopicIn`.
// We want to write to `TopicOut` (simulate Agent output) and read from `TopicIn` (simulate Agent input).

// If `CreatePortal` consumes `TopicIn`, we can't read it.
// BUT, `CreatePortal` implementation (read earlier) does:
// `go func() { m.receiveLoop(ctxLoop, topicIn, sub) }()`
// We don't know what `receiveLoop` does.
// If `receiveLoop` just drops messages or handles history, we might miss them.
// Wait, `CreatePortal` is used by `c2/server.go`. It likely forwards to the gRPC stream.
// But here we don't have the gRPC stream connected to `CreatePortal`.
// `CreatePortal` signature in `mux_create.go`: `func (m *Mux) CreatePortal(...) (int, func(), error)`.
// It doesn't take a channel or stream.
// So where do messages go?
// Maybe `receiveLoop` puts them into `history` or `subs`?
// `SubscriberRegistry` in `Mux` struct.

// If `CreatePortal` is running, it consumes messages.
// We probably shouldn't use `CreatePortal` if we want to intercept messages manually in the test using `Mux.Subscribe`.
// UNLESS `Mux.Subscribe` adds a *local* subscriber to the *internal* registry?
// `Mux` seems to have `SubscriberRegistry`.
// `Subscribe` probably adds a channel to this registry.
// `receiveLoop` probably reads from PubSub and broadcasts to registry channels.
// If so, we are fine. `CreatePortal` sets up the PubSub subscription. `Subscribe` hooks into the stream.

// Let's verify `Subscribe`. I didn't read it.
// Assuming `Subscribe` works as intended (pubsub style).

agentInCh, agentSubCleanup := env.Mux.Subscribe(env.Mux.TopicIn(p.ID))
defer agentSubCleanup()

Expand Down Expand Up @@ -282,8 +232,6 @@ func TestInteractiveShell(t *testing.T) {
},
}
// We publish to TopicOut. User subscribes to TopicOut.
// Ensure TopicOut exists? `CreatePortal` calls `ensureTopic(TopicOut)`.
// So it should be fine.
err = env.Mux.Publish(ctx, env.Mux.TopicOut(p.ID), outMote)
require.NoError(t, err)

Expand Down Expand Up @@ -311,7 +259,7 @@ func TestInteractiveShell(t *testing.T) {
}
}
}
require.Equal(t, outputData, outMsg.Output)
require.Equal(t, fmt.Sprintf("[+] %s\n%s", inputCmd, outputData), outMsg.Output)
require.Equal(t, task.ID, outMsg.ShellTaskID)
}

Expand Down Expand Up @@ -386,7 +334,7 @@ func TestNonInteractiveShell(t *testing.T) {
}
}
}
require.Equal(t, "root", outMsg.Output)
require.Equal(t, fmt.Sprintf("[+] %s\nroot", inputCmd), outMsg.Output)
require.Equal(t, task.ID, outMsg.ShellTaskID)
}

Expand Down Expand Up @@ -494,3 +442,80 @@ func TestOtherStreamOutput(t *testing.T) {
require.Contains(t, otherMsg.Output, "Other User")
require.Contains(t, otherMsg.Output, "rebooting...")
}

func TestTruncation(t *testing.T) {
env := SetupTestEnv(t)
defer env.Close()

// 1. Connect via WebSocket
url := fmt.Sprintf("%s?shell_id=%d", env.WSURL, env.Shell.ID)
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
require.NoError(t, err)
defer ws.Close()

// 2. Send Long Input
longInput := strings.Repeat("A", 100)
inputMsg := shell.WebsocketTaskInputMessage{
Kind: shell.WebsocketMessageKindInput,
Input: longInput,
}
err = ws.WriteJSON(inputMsg)
require.NoError(t, err)

// 3. Expect "Task Queued" Control Message
var msg shell.WebsocketControlFlowMessage
for {
_, data, err := ws.ReadMessage()
require.NoError(t, err)

var genericMsg struct {
Kind string `json:"kind"`
}
json.Unmarshal(data, &genericMsg)
if genericMsg.Kind == shell.WebsocketMessageKindControlFlow {
json.Unmarshal(data, &msg)
if msg.Signal == shell.WebsocketControlFlowSignalTaskQueued {
break
}
}
}
require.Contains(t, msg.Message, "Task Queued for testbeacon")

// 4. Simulate C2 updating Task Output
time.Sleep(100 * time.Millisecond)
tasks, err := env.EntClient.ShellTask.Query().All(context.Background())
require.NoError(t, err)
require.NotEmpty(t, tasks)
task := tasks[0]

_, err = task.Update().SetOutput("out").Save(context.Background())
require.NoError(t, err)

// 5. Expect Output on WebSocket
var outMsg shell.WebsocketTaskOutputMessage
found := false
timeout := time.After(5 * time.Second)
for !found {
select {
case <-timeout:
t.Fatal("timeout waiting for output")
default:
_, data, err := ws.ReadMessage()
if err != nil {
continue
}
var genericMsg struct {
Kind string `json:"kind"`
}
json.Unmarshal(data, &genericMsg)
if genericMsg.Kind == shell.WebsocketMessageKindOutput {
json.Unmarshal(data, &outMsg)
found = true
}
}
}
require.Equal(t, task.ID, outMsg.ShellTaskID)

expectedPrefix := "[+] " + longInput[:64] + "...\n"
require.Contains(t, outMsg.Output, expectedPrefix)
}
Loading