Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,34 @@ var initCmd = &cobra.Command{
}
}

// Set platform-specific configurations
if initPlatform != "" {
switch initPlatform {
case "aws":
if err := configHandler.SetContextValue("aws.enabled", true); err != nil {
return fmt.Errorf("Error setting aws.enabled: %w", err)
}
if err := configHandler.SetContextValue("cluster.driver", "eks"); err != nil {
return fmt.Errorf("Error setting cluster.driver: %w", err)
}
case "azure":
if err := configHandler.SetContextValue("azure.enabled", true); err != nil {
return fmt.Errorf("Error setting azure.enabled: %w", err)
}
if err := configHandler.SetContextValue("cluster.driver", "aks"); err != nil {
return fmt.Errorf("Error setting cluster.driver: %w", err)
}
case "metal":
if err := configHandler.SetContextValue("cluster.driver", "talos"); err != nil {
return fmt.Errorf("Error setting cluster.driver: %w", err)
}
case "local":
if err := configHandler.SetContextValue("cluster.driver", "talos"); err != nil {
return fmt.Errorf("Error setting cluster.driver: %w", err)
}
}
}

// Set the vm driver only if it's configured and not overridden by --set flag
if vmDriverConfig != "" && configHandler.GetString("vm.driver") == "" {
if err := configHandler.SetContextValue("vm.driver", vmDriverConfig); err != nil {
Expand Down
99 changes: 99 additions & 0 deletions cmd/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,3 +836,102 @@ func TestInitCmd(t *testing.T) {
}
})
}

type platformTest struct {
name string
flag string
enabledKey string
enabledValue bool
driverKey string
driverExpected string
}

func TestInitCmd_PlatformFlag(t *testing.T) {
platforms := []platformTest{
{
name: "aws",
flag: "aws",
enabledKey: "aws.enabled",
enabledValue: true,
driverKey: "cluster.driver",
driverExpected: "eks",
},
{
name: "azure",
flag: "azure",
enabledKey: "azure.enabled",
enabledValue: true,
driverKey: "cluster.driver",
driverExpected: "aks",
},
{
name: "metal",
flag: "metal",
enabledKey: "",
enabledValue: false,
driverKey: "cluster.driver",
driverExpected: "talos",
},
{
name: "local",
flag: "local",
enabledKey: "",
enabledValue: false,
driverKey: "cluster.driver",
driverExpected: "talos",
},
}

for _, tc := range platforms {
t.Run(tc.name, func(t *testing.T) {
// Use a real map-backed mock config handler
store := make(map[string]interface{})
mockConfigHandler := config.NewMockConfigHandler()
mockConfigHandler.SetContextValueFunc = func(key string, value any) error {
store[key] = value
return nil
}
mockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string {
if v, ok := store[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return ""
}
mockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool {
if v, ok := store[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return false
}

mocks := setupInitMocks(t, &SetupOptions{ConfigHandler: mockConfigHandler})
rootCmd.ResetFlags()
initCmd.ResetFlags()
initCmd.Flags().StringVar(&initPlatform, "platform", "", "Specify the platform to use [local|metal]")

rootCmd.SetArgs([]string{"init", "--platform", tc.flag})
err := Execute(mocks.Controller)
if err != nil {
t.Fatalf("Expected success, got error: %v", err)
}
if tc.enabledKey != "" {
if !mockConfigHandler.GetBool(tc.enabledKey) {
t.Errorf("Expected %s to be true", tc.enabledKey)
}
}
if got := mockConfigHandler.GetString(tc.driverKey); got != tc.driverExpected {
t.Errorf("Expected %s to be %q, got %q", tc.driverKey, tc.driverExpected, got)
}
})
}
}
73 changes: 65 additions & 8 deletions pkg/blueprint/blueprint_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ type BlueprintHandler interface {
//go:embed templates/default.jsonnet
var defaultJsonnetTemplate string

//go:embed templates/local.jsonnet
var localJsonnetTemplate string

//go:embed templates/metal.jsonnet
var metalJsonnetTemplate string

//go:embed templates/aws.jsonnet
var awsJsonnetTemplate string

//go:embed templates/azure.jsonnet
var azureJsonnetTemplate string

type BaseBlueprintHandler struct {
BlueprintHandler
injector di.Injector
Expand Down Expand Up @@ -134,13 +146,34 @@ func (b *BaseBlueprintHandler) LoadConfig(path ...string) error {
basePath = path[0]
}

jsonnetData, jsonnetErr := b.loadFileData(basePath + ".jsonnet")
yamlData, yamlErr := b.loadFileData(basePath + ".yaml")
if jsonnetErr != nil {
return jsonnetErr
// Get platform from context
platform := ""
if b.configHandler.GetConfig().Cluster != nil && b.configHandler.GetConfig().Cluster.Platform != nil {
platform = *b.configHandler.GetConfig().Cluster.Platform
}

// Try to load platform-specific template first
platformData, err := b.loadPlatformTemplate(platform)
if err != nil {
return fmt.Errorf("error loading platform template: %w", err)
}
if yamlErr != nil && !os.IsNotExist(yamlErr) {
return yamlErr

var yamlData []byte
// If no platform template, fall back to default
if len(platformData) == 0 {
jsonnetData, jsonnetErr := b.loadFileData(basePath + ".jsonnet")
var yamlErr error
yamlData, yamlErr = b.loadFileData(basePath + ".yaml")
if jsonnetErr != nil {
return jsonnetErr
}
if yamlErr != nil && !os.IsNotExist(yamlErr) {
return yamlErr
}

if len(jsonnetData) > 0 {
platformData = jsonnetData
}
}

config := b.configHandler.GetConfig()
Expand Down Expand Up @@ -168,8 +201,8 @@ func (b *BaseBlueprintHandler) LoadConfig(path ...string) error {
vm := b.shims.NewJsonnetVM()
vm.ExtCode("context", string(contextJSON))

if len(jsonnetData) > 0 {
evaluatedJsonnet, err = vm.EvaluateAnonymousSnippet("blueprint.jsonnet", string(jsonnetData))
if len(platformData) > 0 {
evaluatedJsonnet, err = vm.EvaluateAnonymousSnippet("blueprint.jsonnet", string(platformData))
if err != nil {
return fmt.Errorf("error generating blueprint from jsonnet: %w", err)
}
Expand Down Expand Up @@ -221,6 +254,10 @@ func (b *BaseBlueprintHandler) WriteConfig(path ...string) error {
return fmt.Errorf("error creating directory: %w", err)
}

if _, err := b.shims.Stat(finalPath); err == nil {
return nil
}

fullBlueprint := b.blueprint.DeepCopy()

for i := range fullBlueprint.TerraformComponents {
Expand Down Expand Up @@ -643,6 +680,26 @@ func (b *BaseBlueprintHandler) loadFileData(path string) ([]byte, error) {
return nil, nil
}

// loadPlatformTemplate loads a platform-specific template if one exists
func (b *BaseBlueprintHandler) loadPlatformTemplate(platform string) ([]byte, error) {
if platform == "" {
return nil, nil
}

switch platform {
case "local":
return []byte(localJsonnetTemplate), nil
case "metal":
return []byte(metalJsonnetTemplate), nil
case "aws":
return []byte(awsJsonnetTemplate), nil
case "azure":
return []byte(azureJsonnetTemplate), nil
default:
return nil, nil
}
}

// yamlMarshalWithDefinedPaths marshals data to YAML format while ensuring all parent paths are defined.
// It handles various Go types including structs, maps, slices, and primitive types, preserving YAML
// tags and properly representing nil values.
Expand Down
77 changes: 76 additions & 1 deletion pkg/blueprint/blueprint_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1115,12 +1115,15 @@ func TestBlueprintHandler_WriteConfig(t *testing.T) {
t.Run("Success", func(t *testing.T) {
// Given a blueprint handler with metadata
handler, mocks := setup(t)
// Patch Stat to simulate file does not exist
mocks.Shims.Stat = func(name string) (os.FileInfo, error) {
return nil, os.ErrNotExist
}
expectedMetadata := blueprintv1alpha1.Metadata{
Name: "test-blueprint",
Description: "A test blueprint",
Authors: []string{"John Doe"},
}

handler.SetMetadata(expectedMetadata)

// And a mock file system that captures written data
Expand Down Expand Up @@ -1164,6 +1167,10 @@ func TestBlueprintHandler_WriteConfig(t *testing.T) {
t.Run("WriteNoPath", func(t *testing.T) {
// Given a blueprint handler with metadata
handler, mocks := setup(t)
// Patch Stat to simulate file does not exist
mocks.Shims.Stat = func(name string) (os.FileInfo, error) {
return nil, os.ErrNotExist
}
expectedMetadata := blueprintv1alpha1.Metadata{
Name: "test-blueprint",
Description: "A test blueprint",
Expand Down Expand Up @@ -1260,6 +1267,10 @@ func TestBlueprintHandler_WriteConfig(t *testing.T) {
t.Run("ErrorMarshallingYaml", func(t *testing.T) {
// Given a blueprint handler
handler, mocks := setup(t)
// Patch Stat to simulate file does not exist
mocks.Shims.Stat = func(name string) (os.FileInfo, error) {
return nil, os.ErrNotExist
}

// And a mock yaml marshaller that returns an error
mocks.Shims.YamlMarshalNonNull = func(in any) ([]byte, error) {
Expand All @@ -1281,6 +1292,10 @@ func TestBlueprintHandler_WriteConfig(t *testing.T) {
t.Run("ErrorWritingFile", func(t *testing.T) {
// Given a blueprint handler
handler, mocks := setup(t)
// Patch Stat to simulate file does not exist
mocks.Shims.Stat = func(name string) (os.FileInfo, error) {
return nil, os.ErrNotExist
}

// And a mock file system that fails to write files
mocks.Shims.WriteFile = func(name string, data []byte, perm fs.FileMode) error {
Expand All @@ -1302,6 +1317,10 @@ func TestBlueprintHandler_WriteConfig(t *testing.T) {
t.Run("CleanupEmptyPostBuild", func(t *testing.T) {
// Given a blueprint handler with kustomizations containing empty PostBuild
handler, mocks := setup(t)
// Patch Stat to simulate file does not exist
mocks.Shims.Stat = func(name string) (os.FileInfo, error) {
return nil, os.ErrNotExist
}
emptyPostBuildKustomizations := []blueprintv1alpha1.Kustomization{
{
Name: "kustomization-empty-postbuild",
Expand Down Expand Up @@ -1370,6 +1389,10 @@ func TestBlueprintHandler_WriteConfig(t *testing.T) {
t.Run("ClearTerraformComponentsVariablesAndValues", func(t *testing.T) {
// Given a blueprint handler with terraform components containing variables and values
handler, mocks := setup(t)
// Patch Stat to simulate file does not exist
mocks.Shims.Stat = func(name string) (os.FileInfo, error) {
return nil, os.ErrNotExist
}
terraformComponents := []blueprintv1alpha1.TerraformComponent{
{
Source: "source1",
Expand Down Expand Up @@ -4089,3 +4112,55 @@ func TestBaseBlueprintHandler_WaitForKustomizations(t *testing.T) {
}
})
}

func TestBaseBlueprintHandler_loadPlatformTemplate(t *testing.T) {
t.Run("ValidPlatforms", func(t *testing.T) {
// Given a BaseBlueprintHandler
handler := &BaseBlueprintHandler{}

// When loading templates for valid platforms
platforms := []string{"local", "metal", "aws", "azure"}
for _, platform := range platforms {
// Then the template should be loaded successfully
template, err := handler.loadPlatformTemplate(platform)
if err != nil {
t.Errorf("Expected no error for platform %s, got: %v", platform, err)
}
if len(template) == 0 {
t.Errorf("Expected non-empty template for platform %s", platform)
}
}
})

t.Run("InvalidPlatform", func(t *testing.T) {
// Given a BaseBlueprintHandler
handler := &BaseBlueprintHandler{}

// When loading template for invalid platform
template, err := handler.loadPlatformTemplate("invalid-platform")

// Then no error should occur but template should be empty
if err != nil {
t.Errorf("Expected no error for invalid platform, got: %v", err)
}
if len(template) != 0 {
t.Errorf("Expected empty template for invalid platform, got length: %d", len(template))
}
})

t.Run("EmptyPlatform", func(t *testing.T) {
// Given a BaseBlueprintHandler
handler := &BaseBlueprintHandler{}

// When loading template with empty platform
template, err := handler.loadPlatformTemplate("")

// Then no error should occur and template should be empty
if err != nil {
t.Errorf("Expected no error for empty platform, got: %v", err)
}
if len(template) != 0 {
t.Errorf("Expected empty template for empty platform, got length: %d", len(template))
}
})
}
Loading
Loading