diff --git a/pkg/workflow/error_aggregation_test.go b/pkg/workflow/error_aggregation_test.go index a2b994afa71..1ac656ec202 100644 --- a/pkg/workflow/error_aggregation_test.go +++ b/pkg/workflow/error_aggregation_test.go @@ -4,12 +4,21 @@ package workflow import ( "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type testFormattedErrorChainType struct { + message string +} + +func (e *testFormattedErrorChainType) Error() string { + return e.message +} + func TestNewErrorCollector(t *testing.T) { tests := []struct { name string @@ -271,3 +280,21 @@ func TestErrorCollectorFormattedError(t *testing.T) { }) } } + +func TestErrorCollectorFormattedError_PreservesErrorChain(t *testing.T) { + collector := NewErrorCollector(false) + sentinelErr := errors.New("sentinel") + typedErr := &testFormattedErrorChainType{message: "typed"} + + require.NoError(t, collector.Add(fmt.Errorf("wrapped sentinel: %w", sentinelErr)), "Should collect wrapped sentinel error") + require.NoError(t, collector.Add(fmt.Errorf("wrapped typed: %w", typedErr)), "Should collect wrapped typed error") + + result := collector.FormattedError("validation") + require.Error(t, result, "Should return formatted error") + + require.ErrorIs(t, result, sentinelErr, "FormattedError should preserve errors.Is chain") + + var extractedTypedErr *testFormattedErrorChainType + require.ErrorAs(t, result, &extractedTypedErr, "FormattedError should preserve errors.As chain") + assert.Equal(t, typedErr, extractedTypedErr, "errors.As should extract the wrapped typed error") +} diff --git a/pkg/workflow/workflow_errors.go b/pkg/workflow/workflow_errors.go index 20841f42cd6..a3031e3a717 100644 --- a/pkg/workflow/workflow_errors.go +++ b/pkg/workflow/workflow_errors.go @@ -250,15 +250,8 @@ func (c *ErrorCollector) FormattedError(category string) error { return c.errors[0] } - // Build formatted error with count header - var sb strings.Builder - fmt.Fprintf(&sb, "Found %d %s errors:", len(c.errors), category) - for _, err := range c.errors { - sb.WriteString("\n • ") - sb.WriteString(err.Error()) - } - - return fmt.Errorf("%s", sb.String()) + header := fmt.Sprintf("Found %d %s errors:", len(c.errors), category) + return fmt.Errorf("%s\n%w", header, errors.Join(c.errors...)) } var sharedWorkflowLog = logger.New("workflow:shared_workflow_error")