From e2aed0ebf6f19611865235103b94fe37a50ca185 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Apr 2026 11:53:55 +0000 Subject: [PATCH 1/2] Initial plan From b4f844fdcd06596590fe9e509a0cff63a7f47743 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Apr 2026 12:08:37 +0000 Subject: [PATCH 2/2] test(agentdrain): improve miner_test.go with testify assertions and new coverage Agent-Logs-Url: https://github.com/github/gh-aw/sessions/4f410dc1-fcf0-49ef-961b-738282a6dcd6 Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- pkg/agentdrain/miner_test.go | 336 ++++++++++++++++++++++++----------- 1 file changed, 233 insertions(+), 103 deletions(-) diff --git a/pkg/agentdrain/miner_test.go b/pkg/agentdrain/miner_test.go index 1bd454f6046..2b3fe4432a1 100644 --- a/pkg/agentdrain/miner_test.go +++ b/pkg/agentdrain/miner_test.go @@ -4,102 +4,145 @@ package agentdrain import ( "fmt" - "strings" "sync" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewMiner(t *testing.T) { cfg := DefaultConfig() m, err := NewMiner(cfg) - if err != nil { - t.Fatalf("NewMiner: unexpected error: %v", err) - } - if m == nil { - t.Fatal("NewMiner: expected non-nil miner") + require.NoError(t, err, "NewMiner should not return an error") + require.NotNil(t, m, "NewMiner should return a non-nil miner") + assert.Equal(t, 0, m.ClusterCount(), "new miner should have zero clusters") +} + +func TestTrain(t *testing.T) { + tests := []struct { + name string + simThreshold float64 + lines []string + wantClusters int + wantWildcard bool + wantClusterIDNZ bool // last result ClusterID should be non-zero + }{ + { + name: "single line creates one cluster", + simThreshold: DefaultConfig().SimThreshold, + lines: []string{"stage=plan action=start"}, + wantClusters: 1, + wantWildcard: false, + wantClusterIDNZ: true, + }, + { + name: "two identical lines stay in one cluster without wildcard", + simThreshold: DefaultConfig().SimThreshold, + lines: []string{"stage=plan action=start", "stage=plan action=start"}, + wantClusters: 1, + wantWildcard: false, + wantClusterIDNZ: true, + }, + { + name: "two distinct lines create separate clusters", + simThreshold: DefaultConfig().SimThreshold, + lines: []string{"stage=plan action=start", "stage=finish status=ok"}, + wantClusters: 2, + wantWildcard: false, + }, + { + name: "similar lines merge and produce wildcard", + simThreshold: 0.4, + lines: []string{"stage=tool_call tool=search", "stage=tool_call tool=read_file"}, + wantClusters: 1, + wantWildcard: true, + wantClusterIDNZ: true, + }, } - if m.ClusterCount() != 0 { - t.Errorf("NewMiner: expected 0 clusters, got %d", m.ClusterCount()) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.SimThreshold = tt.simThreshold + m, err := NewMiner(cfg) + require.NoError(t, err, "NewMiner should succeed") + + var result *MatchResult + for _, line := range tt.lines { + result, err = m.Train(line) + require.NoError(t, err, "Train should not return an error for line %q", line) + } + + assert.Equal(t, tt.wantClusters, m.ClusterCount(), "cluster count mismatch") + if tt.wantClusterIDNZ { + assert.NotZero(t, result.ClusterID, "last result ClusterID should be non-zero") + } + if tt.wantWildcard { + assert.Contains(t, result.Template, "<*>", "merged template should contain wildcard") + } + }) } } -func TestTrain_ClusterCreation(t *testing.T) { +func TestTrainEvent(t *testing.T) { m, err := NewMiner(DefaultConfig()) - if err != nil { - t.Fatalf("NewMiner: %v", err) - } - result, err := m.Train("stage=plan action=start") - if err != nil { - t.Fatalf("Train: unexpected error: %v", err) - } - if result.ClusterID == 0 { - t.Error("Train: expected non-zero ClusterID") - } - if m.ClusterCount() != 1 { - t.Errorf("Train: expected 1 cluster, got %d", m.ClusterCount()) + require.NoError(t, err, "NewMiner should succeed") + + evt := AgentEvent{ + Stage: "plan", + Fields: map[string]string{"action": "start"}, } + result, err := m.TrainEvent(evt) + require.NoError(t, err, "TrainEvent should not return an error") + require.NotNil(t, result, "TrainEvent should return a non-nil result") + assert.Equal(t, "plan", result.Stage, "TrainEvent should propagate the stage to the result") + assert.Equal(t, 1, m.ClusterCount(), "TrainEvent should create one cluster") } -func TestTrain_ClusterMerge(t *testing.T) { - cfg := DefaultConfig() - cfg.SimThreshold = 0.4 - m, err := NewMiner(cfg) - if err != nil { - t.Fatalf("NewMiner: %v", err) - } +func TestClusters(t *testing.T) { + m, err := NewMiner(DefaultConfig()) + require.NoError(t, err, "NewMiner should succeed") - // These two lines differ only in the tool name value. - _, err = m.Train("stage=tool_call tool=search") - if err != nil { - t.Fatalf("Train 1: %v", err) - } - result, err := m.Train("stage=tool_call tool=read_file") - if err != nil { - t.Fatalf("Train 2: %v", err) - } + assert.Empty(t, m.Clusters(), "Clusters should be empty for a new miner") - // Should merge into one cluster. - if m.ClusterCount() != 1 { - t.Errorf("expected 1 cluster after merge, got %d", m.ClusterCount()) - } - if !strings.Contains(result.Template, "<*>") { - t.Errorf("expected wildcard in merged template, got: %q", result.Template) - } + _, err = m.Train("stage=plan action=start") + require.NoError(t, err, "Train should not return an error") + + clusters := m.Clusters() + assert.Len(t, clusters, 1, "Clusters should return one cluster after training one line") + assert.NotZero(t, clusters[0].ID, "cluster ID should be non-zero") } func TestMasking(t *testing.T) { masker, err := NewMasker(DefaultConfig().MaskRules) - if err != nil { - t.Fatalf("NewMasker: %v", err) - } + require.NoError(t, err, "NewMasker should not return an error") tests := []struct { - input string - check func(string) bool - name string + name string + input string + wantContain string }{ { - name: "UUID replaced", - input: "id=550e8400-e29b-41d4-a716-446655440000 msg=ok", - check: func(s string) bool { return strings.Contains(s, "") }, + name: "UUID replaced", + input: "id=550e8400-e29b-41d4-a716-446655440000 msg=ok", + wantContain: "", }, { - name: "URL replaced", - input: "fetching https://example.com/api/v1", - check: func(s string) bool { return strings.Contains(s, "") }, + name: "URL replaced", + input: "fetching https://example.com/api/v1", + wantContain: "", }, { - name: "Number value replaced", - input: "latency_ms=250", - check: func(s string) bool { return strings.Contains(s, "=") }, + name: "Number value replaced", + input: "latency_ms=250", + wantContain: "=", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { out := masker.Mask(tt.input) - if !tt.check(out) { - t.Errorf("Mask(%q) = %q, check failed", tt.input, out) - } + assert.Contains(t, out, tt.wantContain, "Mask(%q) should contain %q", tt.input, tt.wantContain) }) } } @@ -117,26 +160,29 @@ func TestFlattenEvent(t *testing.T) { exclude := []string{"session_id"} result := FlattenEvent(evt, exclude) - // session_id must be excluded. - if strings.Contains(result, "session_id") { - t.Errorf("FlattenEvent: excluded field present: %q", result) - } - // Keys should be sorted: latency_ms < query < tool. - idx := func(s string) int { return strings.Index(result, s) } - if idx("latency_ms=") > idx("query=") || idx("query=") > idx("tool=") { - t.Errorf("FlattenEvent: keys not sorted: %q", result) - } - // Stage should appear first. - if !strings.HasPrefix(result, "stage=tool_call") { - t.Errorf("FlattenEvent: stage not first: %q", result) + assert.NotContains(t, result, "session_id", "excluded field should not appear in flattened output") + assert.True(t, len(result) > 0 && + indexIn(result, "latency_ms=") < indexIn(result, "query=") && + indexIn(result, "query=") < indexIn(result, "tool="), + "keys should be sorted alphabetically in flattened output: %q", result) + assert.True(t, len(result) >= len("stage=tool_call") && + result[:len("stage=tool_call")] == "stage=tool_call", + "stage should appear first in flattened output: %q", result) +} + +// indexIn returns the byte offset of substr in s, or -1 if not found. +func indexIn(s, substr string) int { + for i := range len(s) - len(substr) + 1 { + if s[i:i+len(substr)] == substr { + return i + } } + return -1 } func TestConcurrency(t *testing.T) { m, err := NewMiner(DefaultConfig()) - if err != nil { - t.Fatalf("NewMiner: %v", err) - } + require.NoError(t, err, "NewMiner should succeed") var wg sync.WaitGroup const goroutines = 10 @@ -148,26 +194,21 @@ func TestConcurrency(t *testing.T) { defer wg.Done() for i := range linesEach { line := fmt.Sprintf("stage=work goroutine=%d iter=%d", id, i) - if _, err := m.Train(line); err != nil { - t.Errorf("Train: %v", err) - } + _, trainErr := m.Train(line) + assert.NoError(t, trainErr, "Train should not error during concurrent access") } }(g) } wg.Wait() - if m.ClusterCount() == 0 { - t.Error("expected clusters after concurrent training") - } + assert.Positive(t, m.ClusterCount(), "there should be clusters after concurrent training") } func TestStageRouting(t *testing.T) { cfg := DefaultConfig() stages := []string{"plan", "tool_call", "finish"} coord, err := NewCoordinator(cfg, stages) - if err != nil { - t.Fatalf("NewCoordinator: %v", err) - } + require.NoError(t, err, "NewCoordinator should succeed") events := []AgentEvent{ {Stage: "plan", Fields: map[string]string{"action": "start"}}, @@ -175,15 +216,113 @@ func TestStageRouting(t *testing.T) { {Stage: "finish", Fields: map[string]string{"status": "ok"}}, } for _, evt := range events { - if _, err := coord.TrainEvent(evt); err != nil { - t.Fatalf("TrainEvent(%q): %v", evt.Stage, err) - } + _, err := coord.TrainEvent(evt) + require.NoError(t, err, "TrainEvent should succeed for known stage %q", evt.Stage) } - // Unknown stage should error. _, err = coord.TrainEvent(AgentEvent{Stage: "unknown", Fields: map[string]string{}}) - if err == nil { - t.Error("expected error for unknown stage, got nil") + assert.Error(t, err, "TrainEvent should return an error for an unknown stage") +} + +func TestCoordinatorAnalyzeEvent(t *testing.T) { + cfg := DefaultConfig() + stages := []string{"plan", "tool_call"} + coord, err := NewCoordinator(cfg, stages) + require.NoError(t, err, "NewCoordinator should succeed") + + evt := AgentEvent{Stage: "plan", Fields: map[string]string{"action": "start"}} + + // First occurrence should be a new template. + result, report, err := coord.AnalyzeEvent(evt) + require.NoError(t, err, "AnalyzeEvent should not error on first event") + require.NotNil(t, result, "AnalyzeEvent should return a non-nil result") + require.NotNil(t, report, "AnalyzeEvent should return a non-nil report") + assert.True(t, report.IsNewTemplate, "first event should be flagged as a new template") + + // Second identical occurrence should not be new. + _, report2, err := coord.AnalyzeEvent(evt) + require.NoError(t, err, "AnalyzeEvent should not error on second event") + require.NotNil(t, report2, "second AnalyzeEvent should return a non-nil report") + assert.False(t, report2.IsNewTemplate, "second identical event should not be flagged as a new template") + + // Unknown stage should error. + _, _, err = coord.AnalyzeEvent(AgentEvent{Stage: "unknown"}) + assert.Error(t, err, "AnalyzeEvent should return an error for an unknown stage") +} + +func TestStageSequence(t *testing.T) { + tests := []struct { + name string + events []AgentEvent + expected string + }{ + { + name: "empty slice", + events: []AgentEvent{}, + expected: "", + }, + { + name: "single event", + events: []AgentEvent{ + {Stage: "plan"}, + }, + expected: "plan", + }, + { + name: "typical pipeline", + events: []AgentEvent{ + {Stage: "plan"}, + {Stage: "tool_call"}, + {Stage: "tool_result"}, + {Stage: "finish"}, + }, + expected: "plan tool_call tool_result finish", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := StageSequence(tt.events) + assert.Equal(t, tt.expected, got, "StageSequence result mismatch") + }) + } +} + +func TestPersistenceRoundTrip(t *testing.T) { + cfg := DefaultConfig() + stages := []string{"plan", "tool_call", "finish"} + coord, err := NewCoordinator(cfg, stages) + require.NoError(t, err, "NewCoordinator should succeed") + + // Train each stage miner with some events. + trainingEvents := []AgentEvent{ + {Stage: "plan", Fields: map[string]string{"action": "start"}}, + {Stage: "plan", Fields: map[string]string{"action": "stop"}}, + {Stage: "tool_call", Fields: map[string]string{"tool": "search", "query": "foo"}}, + {Stage: "finish", Fields: map[string]string{"status": "ok"}}, + } + for _, evt := range trainingEvents { + _, err := coord.TrainEvent(evt) + require.NoError(t, err, "TrainEvent should succeed for stage %q", evt.Stage) + } + + // Save snapshots. + snapshots, err := coord.SaveSnapshots() + require.NoError(t, err, "SaveSnapshots should succeed") + assert.Len(t, snapshots, len(stages), "SaveSnapshots should return one entry per stage") + + // Create a new coordinator and restore state. + coord2, err := NewCoordinator(cfg, stages) + require.NoError(t, err, "NewCoordinator for restore should succeed") + err = coord2.LoadSnapshots(snapshots) + require.NoError(t, err, "LoadSnapshots should succeed") + + // Cluster counts must match the original coordinator. + original := coord.AllClusters() + restored := coord2.AllClusters() + for _, stage := range stages { + assert.Len(t, restored[stage], len(original[stage]), + "restored cluster count for stage %q should match original", stage) } } @@ -229,9 +368,7 @@ func TestComputeSimilarity(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := computeSimilarity(tt.a, tt.b, param) - if got != tt.expected { - t.Errorf("computeSimilarity = %v, want %v", got, tt.expected) - } + assert.InDelta(t, tt.expected, got, 1e-9, "computeSimilarity(%v, %v) mismatch", tt.a, tt.b) }) } } @@ -272,14 +409,7 @@ func TestMergeTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := mergeTemplate(tt.existing, tt.incoming, param) - if len(got) != len(tt.expected) { - t.Fatalf("mergeTemplate len = %d, want %d", len(got), len(tt.expected)) - } - for i, tok := range got { - if tok != tt.expected[i] { - t.Errorf("mergeTemplate[%d] = %q, want %q", i, tok, tt.expected[i]) - } - } + assert.Equal(t, tt.expected, got, "mergeTemplate(%v, %v) mismatch", tt.existing, tt.incoming) }) } }