diff --git a/go.mod b/go.mod index cfb230614..5c2478ad5 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/fluxcd/source-controller/api v1.6.2 github.com/getsops/sops/v3 v3.10.2 github.com/goccy/go-yaml v1.18.0 + github.com/google/cel-go v0.26.0 github.com/google/go-containerregistry v0.20.6 github.com/google/go-jsonnet v0.21.0 github.com/hashicorp/hcl/v2 v2.24.0 @@ -57,6 +58,7 @@ require ( github.com/ProtonMail/gopenpgp/v2 v2.8.3 // indirect github.com/adrg/xdg v0.5.3 // indirect github.com/agext/levenshtein v1.2.3 // indirect + github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect @@ -172,6 +174,7 @@ require ( github.com/siderolabs/protoenc v0.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect + github.com/stoewer/go-strcase v1.3.0 // indirect github.com/tetratelabs/wabin v0.0.0-20230304001439-f6f874872834 // indirect github.com/tetratelabs/wazero v1.9.0 // indirect github.com/urfave/cli v1.22.16 // indirect @@ -194,6 +197,7 @@ require ( go.uber.org/zap v1.27.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect golang.org/x/mod v0.27.0 // indirect golang.org/x/net v0.43.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect diff --git a/go.sum b/go.sum index 4e7d83480..01259eb4e 100644 --- a/go.sum +++ b/go.sum @@ -439,6 +439,7 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/pkg/blueprint/feature_evaluator.go b/pkg/blueprint/feature_evaluator.go new file mode 100644 index 000000000..9ab5bd41c --- /dev/null +++ b/pkg/blueprint/feature_evaluator.go @@ -0,0 +1,156 @@ +package blueprint + +import ( + "fmt" + "reflect" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// The FeatureEvaluator is a CEL-based expression evaluator for blueprint feature conditions. +// It provides GitHub Actions-style conditional expression evaluation capabilities with support +// for nested object access, logical operators, and type-safe variable declarations. +// The FeatureEvaluator enables dynamic feature activation based on user configuration values. + +// ============================================================================= +// Types +// ============================================================================= + +// FeatureEvaluator provides CEL expression evaluation capabilities for feature conditions. +type FeatureEvaluator struct { + env *cel.Env +} + +// ============================================================================= +// Constructor +// ============================================================================= + +// NewFeatureEvaluator creates a new CEL-based feature evaluator configured for evaluating feature conditions. +// The evaluator is pre-configured with standard libraries and custom functions needed +// for blueprint feature evaluation. +func NewFeatureEvaluator() (*FeatureEvaluator, error) { + env, err := cel.NewEnv( + cel.HomogeneousAggregateLiterals(), + cel.EagerlyValidateDeclarations(true), + ) + if err != nil { + return nil, fmt.Errorf("failed to create CEL environment: %w", err) + } + + return &FeatureEvaluator{ + env: env, + }, nil +} + +// ============================================================================= +// Public Methods +// ============================================================================= + +// CompileExpression compiles a CEL expression string with variable declarations derived from the config structure. +// The expression should follow GitHub Actions-style syntax with support for: +// - Equality/inequality: ==, != +// - Logical operators: &&, || +// - Parentheses for grouping: (expression) +// - Nested object access: provider, observability.enabled, vm.driver +// Returns a compiled program that can be evaluated against configuration data. +func (e *FeatureEvaluator) CompileExpression(expression string, config map[string]any) (cel.Program, error) { + if expression == "" { + return nil, fmt.Errorf("expression cannot be empty") + } + + var envOptions []cel.EnvOption + envOptions = append(envOptions, cel.HomogeneousAggregateLiterals()) + envOptions = append(envOptions, cel.EagerlyValidateDeclarations(true)) + + for key, value := range config { + envOptions = append(envOptions, cel.Variable(key, e.getCELType(value))) + } + + env, err := cel.NewEnv(envOptions...) + if err != nil { + return nil, fmt.Errorf("failed to create CEL environment with config: %w", err) + } + + ast, issues := env.Compile(expression) + if issues.Err() != nil { + return nil, fmt.Errorf("failed to compile expression '%s': %w", expression, issues.Err()) + } + + program, err := env.Program(ast) + if err != nil { + return nil, fmt.Errorf("failed to create program for expression '%s': %w", expression, err) + } + + return program, nil +} + +// EvaluateProgram executes a compiled CEL program against the provided configuration data. +// The configuration data should be a map containing the user's configuration values +// that the expression will be evaluated against. +// Returns true if the expression evaluates to true, false otherwise. +func (e *FeatureEvaluator) EvaluateProgram(program cel.Program, config map[string]any) (bool, error) { + if config == nil { + config = make(map[string]any) + } + + result, _, err := program.Eval(config) + if err != nil { + return false, fmt.Errorf("failed to evaluate expression: %w", err) + } + + return e.convertToBool(result) +} + +// EvaluateExpression is a convenience method that compiles and evaluates an expression in one call. +// This is useful for one-time evaluations where the compiled program won't be reused. +func (e *FeatureEvaluator) EvaluateExpression(expression string, config map[string]any) (bool, error) { + program, err := e.CompileExpression(expression, config) + if err != nil { + return false, err + } + + return e.EvaluateProgram(program, config) +} + +// ============================================================================= +// Private Methods +// ============================================================================= + +// convertToBool converts a CEL result value to a boolean. +// CEL expressions should evaluate to boolean values for feature conditions. +func (e *FeatureEvaluator) convertToBool(result ref.Val) (bool, error) { + if result.Type() == types.BoolType { + return result.Value().(bool), nil + } + + return false, fmt.Errorf("expression must evaluate to boolean, got %s", result.Type()) +} + +// getCELType determines the appropriate CEL type for a Go value. +// This is used to create variable declarations for the CEL environment. +func (e *FeatureEvaluator) getCELType(value any) *cel.Type { + if value == nil { + return cel.DynType + } + + switch reflect.TypeOf(value).Kind() { + case reflect.String: + return cel.StringType + case reflect.Bool: + return cel.BoolType + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return cel.IntType + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return cel.UintType + case reflect.Float32, reflect.Float64: + return cel.DoubleType + case reflect.Map: + return cel.MapType(cel.StringType, cel.DynType) + case reflect.Slice, reflect.Array: + return cel.ListType(cel.DynType) + default: + return cel.DynType + } +} diff --git a/pkg/blueprint/feature_evaluator_test.go b/pkg/blueprint/feature_evaluator_test.go new file mode 100644 index 000000000..dbe236937 --- /dev/null +++ b/pkg/blueprint/feature_evaluator_test.go @@ -0,0 +1,378 @@ +package blueprint + +import ( + "testing" +) + +// ============================================================================= +// Test Constructor +// ============================================================================= + +func TestNewFeatureEvaluator(t *testing.T) { + t.Run("CreatesNewFeatureEvaluatorSuccessfully", func(t *testing.T) { + evaluator, err := NewFeatureEvaluator() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if evaluator == nil { + t.Fatal("Expected evaluator, got nil") + } + if evaluator.env == nil { + t.Fatal("Expected CEL env to be initialized") + } + }) +} + +// ============================================================================= +// Test Public Methods +// ============================================================================= + +func TestCompileExpression(t *testing.T) { + evaluator, err := NewFeatureEvaluator() + if err != nil { + t.Fatalf("Failed to create evaluator: %v", err) + } + + tests := []struct { + name string + expression string + shouldError bool + }{ + { + name: "EmptyExpressionFails", + expression: "", + shouldError: true, + }, + { + name: "SimpleEqualityExpression", + expression: "provider == 'aws'", + shouldError: false, + }, + { + name: "SimpleInequalityExpression", + expression: "provider != 'local'", + shouldError: false, + }, + { + name: "LogicalAndExpression", + expression: "provider == 'local' && observability.enabled == true", + shouldError: false, + }, + { + name: "LogicalOrExpression", + expression: "provider == 'aws' || provider == 'azure'", + shouldError: false, + }, + { + name: "ParenthesesGrouping", + expression: "provider == 'local' && (vm.driver != 'docker-desktop' || loadbalancer.enabled == true)", + shouldError: false, + }, + { + name: "NestedObjectAccess", + expression: "observability.enabled == true && observability.backend == 'quickwit'", + shouldError: false, + }, + { + name: "BooleanComparison", + expression: "dns.enabled == true", + shouldError: false, + }, + { + name: "InvalidSyntaxFails", + expression: "provider ===", + shouldError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := map[string]any{ + "provider": "aws", + "observability": map[string]any{ + "enabled": true, + "backend": "quickwit", + }, + "vm": map[string]any{ + "driver": "virtualbox", + }, + "loadbalancer": map[string]any{ + "enabled": true, + }, + "dns": map[string]any{ + "enabled": true, + }, + } + + program, err := evaluator.CompileExpression(tt.expression, config) + + if tt.shouldError { + if err == nil { + t.Errorf("Expected error for expression '%s', got none", tt.expression) + } + } else { + if err != nil { + t.Errorf("Expected no error for expression '%s', got %v", tt.expression, err) + } + if program == nil { + t.Errorf("Expected program for expression '%s', got nil", tt.expression) + } + } + }) + } +} + +func TestEvaluateProgram(t *testing.T) { + evaluator, err := NewFeatureEvaluator() + if err != nil { + t.Fatalf("Failed to create evaluator: %v", err) + } + + tests := []struct { + name string + expression string + config map[string]any + expected bool + shouldErr bool + }{ + { + name: "SimpleStringEqualityTrue", + expression: "provider == 'aws'", + config: map[string]any{"provider": "aws"}, + expected: true, + }, + { + name: "SimpleStringEqualityFalse", + expression: "provider == 'aws'", + config: map[string]any{"provider": "local"}, + expected: false, + }, + { + name: "StringInequalityTrue", + expression: "provider != 'local'", + config: map[string]any{"provider": "aws"}, + expected: true, + }, + { + name: "StringInequalityFalse", + expression: "provider != 'local'", + config: map[string]any{"provider": "local"}, + expected: false, + }, + { + name: "BooleanEqualityTrue", + expression: "observability.enabled == true", + config: map[string]any{ + "observability": map[string]any{ + "enabled": true, + }, + }, + expected: true, + }, + { + name: "BooleanEqualityFalse", + expression: "observability.enabled == true", + config: map[string]any{ + "observability": map[string]any{ + "enabled": false, + }, + }, + expected: false, + }, + { + name: "LogicalAndBothTrue", + expression: "provider == 'local' && observability.enabled == true", + config: map[string]any{ + "provider": "local", + "observability": map[string]any{ + "enabled": true, + }, + }, + expected: true, + }, + { + name: "LogicalAndFirstFalse", + expression: "provider == 'local' && observability.enabled == true", + config: map[string]any{ + "provider": "aws", + "observability": map[string]any{ + "enabled": true, + }, + }, + expected: false, + }, + { + name: "LogicalOrFirstTrue", + expression: "provider == 'aws' || provider == 'azure'", + config: map[string]any{"provider": "aws"}, + expected: true, + }, + { + name: "LogicalOrSecondTrue", + expression: "provider == 'aws' || provider == 'azure'", + config: map[string]any{"provider": "azure"}, + expected: true, + }, + { + name: "LogicalOrBothFalse", + expression: "provider == 'aws' || provider == 'azure'", + config: map[string]any{"provider": "local"}, + expected: false, + }, + { + name: "ParenthesesGroupingComplexExpressionTrue", + expression: "provider == 'local' && (vm.driver != 'docker-desktop' || loadbalancer.enabled == true)", + config: map[string]any{ + "provider": "local", + "vm": map[string]any{ + "driver": "virtualbox", + }, + "loadbalancer": map[string]any{ + "enabled": false, + }, + }, + expected: true, + }, + { + name: "NestedObjectAccessMultipleLevels", + expression: "observability.enabled == true && observability.backend == 'quickwit'", + config: map[string]any{ + "observability": map[string]any{ + "enabled": true, + "backend": "quickwit", + }, + }, + expected: true, + }, + { + name: "MissingFieldEvaluatesToNullFalseComparison", + expression: "missing.field == 'value'", + config: map[string]any{ + "missing": map[string]any{ + "field": nil, + }, + }, + expected: false, + }, + { + name: "NilConfigHandledGracefully", + expression: "provider == 'aws'", + config: map[string]any{ + "provider": nil, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + program, err := evaluator.CompileExpression(tt.expression, tt.config) + if err != nil { + t.Fatalf("Failed to compile expression '%s': %v", tt.expression, err) + } + + result, err := evaluator.EvaluateProgram(program, tt.config) + if tt.shouldErr { + if err == nil { + t.Errorf("Expected error for expression '%s', got none", tt.expression) + } + return + } + + if err != nil { + t.Errorf("Expected no error for expression '%s', got %v", tt.expression, err) + return + } + + if result != tt.expected { + t.Errorf("Expected %v for expression '%s' with config %v, got %v", + tt.expected, tt.expression, tt.config, result) + } + }) + } +} + +func TestEvaluateExpression(t *testing.T) { + evaluator, err := NewFeatureEvaluator() + if err != nil { + t.Fatalf("Failed to create evaluator: %v", err) + } + + t.Run("ConvenienceMethodWorksCorrectly", func(t *testing.T) { + config := map[string]any{ + "provider": "aws", + "observability": map[string]any{ + "enabled": true, + "backend": "quickwit", + }, + } + + result, err := evaluator.EvaluateExpression("provider == 'aws' && observability.enabled == true", config) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !result { + t.Errorf("Expected true, got false") + } + }) + + t.Run("InvalidExpressionReturnsError", func(t *testing.T) { + _, err := evaluator.EvaluateExpression("invalid === syntax", map[string]any{}) + if err == nil { + t.Error("Expected error for invalid expression, got none") + } + }) +} + +func TestConvertToBool(t *testing.T) { + evaluator, err := NewFeatureEvaluator() + if err != nil { + t.Fatalf("Failed to create evaluator: %v", err) + } + + tests := []struct { + name string + expression string + config map[string]any + shouldErr bool + }{ + { + name: "BooleanResultConvertsSuccessfully", + expression: "provider == 'aws'", + config: map[string]any{"provider": "aws"}, + shouldErr: false, + }, + { + name: "StringResultShouldError", + expression: "provider", + config: map[string]any{"provider": "aws"}, + shouldErr: true, + }, + { + name: "NumberResultShouldError", + expression: "count", + config: map[string]any{"count": 5}, + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + program, err := evaluator.CompileExpression(tt.expression, tt.config) + if err != nil { + t.Fatalf("Failed to compile expression '%s': %v", tt.expression, err) + } + + _, err = evaluator.EvaluateProgram(program, tt.config) + if tt.shouldErr { + if err == nil { + t.Errorf("Expected error for non-boolean result, got none") + } + } else { + if err != nil { + t.Errorf("Expected no error for boolean result, got %v", err) + } + } + }) + } +}