diff --git a/pkg/cli/codemod_factory.go b/pkg/cli/codemod_factory.go new file mode 100644 index 0000000000..c9618fd5fc --- /dev/null +++ b/pkg/cli/codemod_factory.go @@ -0,0 +1,68 @@ +package cli + +import "github.com/github/gh-aw/pkg/logger" + +// PostTransformFunc is an optional hook called after the primary field removal. +// It receives the already-modified lines, the full frontmatter, and the removed +// field's value. It returns the (potentially further modified) lines. +type PostTransformFunc func(lines []string, frontmatter map[string]any, fieldValue any) []string + +// fieldRemovalCodemodConfig holds the configuration for a field-removal codemod. +type fieldRemovalCodemodConfig struct { + ID string + Name string + Description string + IntroducedIn string + ParentKey string // Top-level frontmatter key that contains the field + FieldKey string // Child field to remove from the parent block + LogMsg string // Debug log message emitted when the codemod is applied + Log *logger.Logger // Logger for the codemod + PostTransform PostTransformFunc // Optional hook for additional transforms after field removal +} + +// newFieldRemovalCodemod creates a Codemod that: +// 1. Checks that the parent key is present in the frontmatter and is a map. +// 2. Checks that the child field is present in that map. +// 3. Removes the field (and any nested content) from the YAML block. +// 4. Optionally invokes PostTransform for any additional line-level changes. +func newFieldRemovalCodemod(cfg fieldRemovalCodemodConfig) Codemod { + return Codemod{ + ID: cfg.ID, + Name: cfg.Name, + Description: cfg.Description, + IntroducedIn: cfg.IntroducedIn, + Apply: func(content string, frontmatter map[string]any) (string, bool, error) { + parentValue, hasParent := frontmatter[cfg.ParentKey] + if !hasParent { + return content, false, nil + } + + parentMap, ok := parentValue.(map[string]any) + if !ok { + return content, false, nil + } + + fieldValue, hasField := parentMap[cfg.FieldKey] + if !hasField { + return content, false, nil + } + + newContent, applied, err := applyFrontmatterLineTransform(content, func(lines []string) ([]string, bool) { + result, modified := removeFieldFromBlock(lines, cfg.FieldKey, cfg.ParentKey) + if !modified { + return lines, false + } + + if cfg.PostTransform != nil { + result = cfg.PostTransform(result, frontmatter, fieldValue) + } + + return result, true + }) + if applied { + cfg.Log.Print(cfg.LogMsg) + } + return newContent, applied, err + }, + } +} diff --git a/pkg/cli/codemod_factory_test.go b/pkg/cli/codemod_factory_test.go new file mode 100644 index 0000000000..00bff3c071 --- /dev/null +++ b/pkg/cli/codemod_factory_test.go @@ -0,0 +1,271 @@ +//go:build !integration + +package cli + +import ( + "testing" + + "github.com/github/gh-aw/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testFactoryLog = logger.New("cli:codemod_factory") + +// baseFieldRemovalConfig returns a minimal valid fieldRemovalCodemodConfig for testing. +func baseFieldRemovalConfig() fieldRemovalCodemodConfig { + return fieldRemovalCodemodConfig{ + ID: "test-removal", + Name: "Remove test field", + Description: "Removes the test field for testing purposes", + IntroducedIn: "1.0.0", + ParentKey: "parent", + FieldKey: "child", + LogMsg: "Applied test field removal", + Log: testFactoryLog, + } +} + +func TestNewFieldRemovalCodemod_Metadata(t *testing.T) { + cfg := baseFieldRemovalConfig() + codemod := newFieldRemovalCodemod(cfg) + + assert.Equal(t, cfg.ID, codemod.ID, "ID should match config") + assert.Equal(t, cfg.Name, codemod.Name, "Name should match config") + assert.Equal(t, cfg.Description, codemod.Description, "Description should match config") + assert.Equal(t, cfg.IntroducedIn, codemod.IntroducedIn, "IntroducedIn should match config") + require.NotNil(t, codemod.Apply, "Apply function should not be nil") +} + +func TestNewFieldRemovalCodemod_ParentKeyMissing(t *testing.T) { + codemod := newFieldRemovalCodemod(baseFieldRemovalConfig()) + + content := `--- +on: workflow_dispatch +other: value +--- + +# Test` + + frontmatter := map[string]any{ + "on": "workflow_dispatch", + "other": "value", + } + + result, applied, err := codemod.Apply(content, frontmatter) + + require.NoError(t, err, "Apply should not return an error when parent key is missing") + assert.False(t, applied, "Codemod should not report changes when parent key is missing") + assert.Equal(t, content, result, "Content should remain unchanged") +} + +func TestNewFieldRemovalCodemod_ParentKeyWrongType(t *testing.T) { + codemod := newFieldRemovalCodemod(baseFieldRemovalConfig()) + + content := `--- +on: workflow_dispatch +parent: simple_string +--- + +# Test` + + frontmatter := map[string]any{ + "on": "workflow_dispatch", + "parent": "simple_string", + } + + result, applied, err := codemod.Apply(content, frontmatter) + + require.NoError(t, err, "Apply should not return an error when parent is not a map") + assert.False(t, applied, "Codemod should not report changes when parent is not a map") + assert.Equal(t, content, result, "Content should remain unchanged") +} + +func TestNewFieldRemovalCodemod_FieldKeyMissing(t *testing.T) { + codemod := newFieldRemovalCodemod(baseFieldRemovalConfig()) + + content := `--- +on: workflow_dispatch +parent: + other: value +--- + +# Test` + + frontmatter := map[string]any{ + "on": "workflow_dispatch", + "parent": map[string]any{ + "other": "value", + }, + } + + result, applied, err := codemod.Apply(content, frontmatter) + + require.NoError(t, err, "Apply should not return an error when child field is missing") + assert.False(t, applied, "Codemod should not report changes when child field is missing") + assert.Equal(t, content, result, "Content should remain unchanged") +} + +func TestNewFieldRemovalCodemod_SuccessfulRemoval(t *testing.T) { + codemod := newFieldRemovalCodemod(baseFieldRemovalConfig()) + + content := `--- +on: workflow_dispatch +parent: + child: true + sibling: value +--- + +# Test` + + frontmatter := map[string]any{ + "on": "workflow_dispatch", + "parent": map[string]any{ + "child": true, + "sibling": "value", + }, + } + + result, applied, err := codemod.Apply(content, frontmatter) + + require.NoError(t, err, "Apply should not return an error on successful removal") + assert.True(t, applied, "Codemod should report that changes were applied") + assert.NotContains(t, result, "child:", "Result should not contain the removed field") + assert.Contains(t, result, "sibling: value", "Result should preserve sibling fields") +} + +func TestNewFieldRemovalCodemod_PostTransformInvoked(t *testing.T) { + var postTransformCalled bool + var capturedFieldValue any + + cfg := baseFieldRemovalConfig() + cfg.PostTransform = func(lines []string, frontmatter map[string]any, fieldValue any) []string { + postTransformCalled = true + capturedFieldValue = fieldValue + // Append a marker line so we can verify the hook ran + return append(lines, "# post-transform-marker") + } + + codemod := newFieldRemovalCodemod(cfg) + + content := `--- +on: workflow_dispatch +parent: + child: sentinel +--- + +# Test` + + frontmatter := map[string]any{ + "on": "workflow_dispatch", + "parent": map[string]any{ + "child": "sentinel", + }, + } + + result, applied, err := codemod.Apply(content, frontmatter) + + require.NoError(t, err, "Apply should not return an error") + assert.True(t, applied, "Codemod should report changes") + assert.True(t, postTransformCalled, "PostTransform hook should have been called") + assert.Equal(t, "sentinel", capturedFieldValue, "PostTransform should receive the removed field's value") + assert.Contains(t, result, "# post-transform-marker", "Result should contain the output of the PostTransform hook") +} + +func TestNewFieldRemovalCodemod_PostTransformNotCalledWhenFieldAbsent(t *testing.T) { + var postTransformCalled bool + + cfg := baseFieldRemovalConfig() + cfg.PostTransform = func(lines []string, frontmatter map[string]any, fieldValue any) []string { + postTransformCalled = true + return lines + } + + codemod := newFieldRemovalCodemod(cfg) + + content := `--- +on: workflow_dispatch +parent: + other: value +--- + +# Test` + + frontmatter := map[string]any{ + "on": "workflow_dispatch", + "parent": map[string]any{ + "other": "value", + }, + } + + _, applied, err := codemod.Apply(content, frontmatter) + + require.NoError(t, err, "Apply should not return an error") + assert.False(t, applied, "Codemod should not report changes") + assert.False(t, postTransformCalled, "PostTransform should not be called when field is absent") +} + +func TestNewFieldRemovalCodemod_TableDriven(t *testing.T) { + tests := []struct { + name string + content string + frontmatter map[string]any + wantApplied bool + wantContent string // expected substring in result when applied, or full match when not applied + }{ + { + name: "parent key absent", + content: "---\non: workflow_dispatch\n---\n\n# Test", + frontmatter: map[string]any{ + "on": "workflow_dispatch", + }, + wantApplied: false, + }, + { + name: "parent not a map", + content: "---\non: workflow_dispatch\nparent: scalar\n---\n\n# Test", + frontmatter: map[string]any{ + "on": "workflow_dispatch", + "parent": "scalar", + }, + wantApplied: false, + }, + { + name: "child field absent", + content: "---\non: workflow_dispatch\nparent:\n other: val\n---\n\n# Test", + frontmatter: map[string]any{ + "on": "workflow_dispatch", + "parent": map[string]any{"other": "val"}, + }, + wantApplied: false, + }, + { + name: "child field present", + content: "---\non: workflow_dispatch\nparent:\n child: yes\n other: val\n---\n\n# Test", + frontmatter: map[string]any{ + "on": "workflow_dispatch", + "parent": map[string]any{"child": true, "other": "val"}, + }, + wantApplied: true, + wantContent: "other: val", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + codemod := newFieldRemovalCodemod(baseFieldRemovalConfig()) + + result, applied, err := codemod.Apply(tt.content, tt.frontmatter) + + require.NoError(t, err, "Apply should not return an error") + assert.Equal(t, tt.wantApplied, applied, "Applied flag should match expectation") + + if tt.wantApplied { + assert.Contains(t, result, tt.wantContent, "Result should contain expected content") + assert.NotContains(t, result, "child:", "Result should not contain removed field") + } else { + assert.Equal(t, tt.content, result, "Content should be unchanged when not applied") + } + }) + } +} diff --git a/pkg/cli/codemod_grep_tool.go b/pkg/cli/codemod_grep_tool.go index 7f18dbc6b8..2a2d7db23a 100644 --- a/pkg/cli/codemod_grep_tool.go +++ b/pkg/cli/codemod_grep_tool.go @@ -6,36 +6,14 @@ var grepToolCodemodLog = logger.New("cli:codemod_grep_tool") // getGrepToolRemovalCodemod creates a codemod for removing the deprecated tools.grep field func getGrepToolRemovalCodemod() Codemod { - return Codemod{ + return newFieldRemovalCodemod(fieldRemovalCodemodConfig{ ID: "grep-tool-removal", Name: "Remove deprecated tools.grep field", Description: "Removes 'tools.grep' field as grep is now always enabled as part of default bash tools", IntroducedIn: "0.7.0", - Apply: func(content string, frontmatter map[string]any) (string, bool, error) { - // Check if tools.grep exists - toolsValue, hasTools := frontmatter["tools"] - if !hasTools { - return content, false, nil - } - - toolsMap, ok := toolsValue.(map[string]any) - if !ok { - return content, false, nil - } - - // Check if grep field exists in tools - _, hasGrep := toolsMap["grep"] - if !hasGrep { - return content, false, nil - } - - newContent, applied, err := applyFrontmatterLineTransform(content, func(lines []string) ([]string, bool) { - return removeFieldFromBlock(lines, "grep", "tools") - }) - if applied { - grepToolCodemodLog.Print("Applied grep tool removal") - } - return newContent, applied, err - }, - } + ParentKey: "tools", + FieldKey: "grep", + LogMsg: "Applied grep tool removal", + Log: grepToolCodemodLog, + }) } diff --git a/pkg/cli/codemod_mcp_scripts.go b/pkg/cli/codemod_mcp_scripts.go index a391c5ac8d..4bde6ce25c 100644 --- a/pkg/cli/codemod_mcp_scripts.go +++ b/pkg/cli/codemod_mcp_scripts.go @@ -8,36 +8,14 @@ var mcpScriptsModeCodemodLog = logger.New("cli:codemod_mcp_scripts") // getMCPScriptsModeCodemod creates a codemod for removing the deprecated mcp-scripts.mode field func getMCPScriptsModeCodemod() Codemod { - return Codemod{ + return newFieldRemovalCodemod(fieldRemovalCodemodConfig{ ID: "mcp-scripts-mode-removal", Name: "Remove deprecated mcp-scripts.mode field", Description: "Removes the deprecated 'mcp-scripts.mode' field (HTTP is now the only supported mode)", IntroducedIn: "0.2.0", - Apply: func(content string, frontmatter map[string]any) (string, bool, error) { - // Check if mcp-scripts.mode exists - mcpScriptsValue, hasMCPScripts := frontmatter["mcp-scripts"] - if !hasMCPScripts { - return content, false, nil - } - - mcpScriptsMap, ok := mcpScriptsValue.(map[string]any) - if !ok { - return content, false, nil - } - - // Check if mode field exists in mcp-scripts - _, hasMode := mcpScriptsMap["mode"] - if !hasMode { - return content, false, nil - } - - newContent, applied, err := applyFrontmatterLineTransform(content, func(lines []string) ([]string, bool) { - return removeFieldFromBlock(lines, "mode", "mcp-scripts") - }) - if applied { - mcpScriptsModeCodemodLog.Print("Applied mcp-scripts.mode removal") - } - return newContent, applied, err - }, - } + ParentKey: "mcp-scripts", + FieldKey: "mode", + LogMsg: "Applied mcp-scripts.mode removal", + Log: mcpScriptsModeCodemodLog, + }) } diff --git a/pkg/cli/codemod_network_firewall.go b/pkg/cli/codemod_network_firewall.go index f81d25b10d..42cd153d8a 100644 --- a/pkg/cli/codemod_network_firewall.go +++ b/pkg/cli/codemod_network_firewall.go @@ -10,88 +10,63 @@ var networkFirewallCodemodLog = logger.New("cli:codemod_network_firewall") // getNetworkFirewallCodemod creates a codemod for migrating network.firewall to sandbox.agent func getNetworkFirewallCodemod() Codemod { - return Codemod{ + return newFieldRemovalCodemod(fieldRemovalCodemodConfig{ ID: "network-firewall-migration", Name: "Migrate network.firewall to sandbox.agent", Description: "Removes deprecated 'network.firewall' field (firewall is now always enabled via sandbox.agent: awf default)", IntroducedIn: "0.1.0", - Apply: func(content string, frontmatter map[string]any) (string, bool, error) { - // Check if network.firewall exists - networkValue, hasNetwork := frontmatter["network"] - if !hasNetwork { - return content, false, nil - } - - networkMap, ok := networkValue.(map[string]any) - if !ok { - return content, false, nil - } - - // Check if firewall field exists in network - firewallValue, hasFirewall := networkMap["firewall"] - if !hasFirewall { - return content, false, nil - } - + ParentKey: "network", + FieldKey: "firewall", + LogMsg: "Applied network.firewall migration (firewall now always enabled via sandbox.agent: awf default)", + Log: networkFirewallCodemodLog, + PostTransform: func(lines []string, frontmatter map[string]any, fieldValue any) []string { // Note: We no longer set sandbox.agent: false since the firewall is mandatory // The firewall is always enabled via the default sandbox.agent: awf _, hasSandbox := frontmatter["sandbox"] - newContent, applied, err := applyFrontmatterLineTransform(content, func(lines []string) ([]string, bool) { - // Remove the firewall field from the network block - result, modified := removeFieldFromBlock(lines, "firewall", "network") - if !modified { - return lines, false + // Add sandbox.agent if not already present AND if firewall was explicitly true + // (no need to add sandbox.agent: awf if firewall was false, since awf is now the default) + if !hasSandbox && fieldValue == true { + // Only add sandbox.agent: awf if firewall was explicitly set to true + sandboxLines := []string{ + "sandbox:", + " agent: awf # Firewall enabled (migrated from network.firewall)", } - // Add sandbox.agent if not already present AND if firewall was explicitly true - // (no need to add sandbox.agent: awf if firewall was false, since awf is now the default) - if !hasSandbox && firewallValue == true { - // Only add sandbox.agent: awf if firewall was explicitly set to true - sandboxLines := []string{ - "sandbox:", - " agent: awf # Firewall enabled (migrated from network.firewall)", - } - - // Try to place it after network block - insertIndex := -1 - inNet := false - for i, line := range result { - trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "network:") { - inNet = true - } else if inNet && len(trimmed) > 0 { - // Check if this is a top-level key (no leading whitespace) - if isTopLevelKey(line) { - // Found next top-level key - insertIndex = i - break - } + // Try to place it after network block + insertIndex := -1 + inNet := false + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "network:") { + inNet = true + } else if inNet && len(trimmed) > 0 { + // Check if this is a top-level key (no leading whitespace) + if isTopLevelKey(line) { + // Found next top-level key + insertIndex = i + break } } + } - if insertIndex >= 0 { - // Insert after network block - newLines := make([]string, 0, len(result)+len(sandboxLines)) - newLines = append(newLines, result[:insertIndex]...) - newLines = append(newLines, sandboxLines...) - newLines = append(newLines, result[insertIndex:]...) - result = newLines - } else { - // Append at the end - result = append(result, sandboxLines...) - } - - networkFirewallCodemodLog.Print("Added sandbox.agent: awf (firewall was explicitly enabled)") + if insertIndex >= 0 { + // Insert after network block + newLines := make([]string, 0, len(lines)+len(sandboxLines)) + newLines = append(newLines, lines[:insertIndex]...) + newLines = append(newLines, sandboxLines...) + newLines = append(newLines, lines[insertIndex:]...) + lines = newLines + } else { + // Append at the end + lines = append(lines, sandboxLines...) } - return result, true - }) - if applied { - networkFirewallCodemodLog.Printf("Applied network.firewall removal (firewall: %v removed, firewall now always enabled via sandbox.agent: awf default)", firewallValue) + networkFirewallCodemodLog.Print("Added sandbox.agent: awf (firewall was explicitly enabled)") } - return newContent, applied, err + + return lines }, - } + }) }