diff --git a/pkg/provider/aws/cluster.go b/pkg/provider/aws/cluster.go
index 5cd8e2f4a..2073ed6e4 100644
--- a/pkg/provider/aws/cluster.go
+++ b/pkg/provider/aws/cluster.go
@@ -398,10 +398,21 @@ func (p *Provider) createInstances(
 	image *v1alpha1.Image,
 ) ([]InstanceInfo, error) {
 	// Resolve AMI for this node pool
-	// Determine architecture from image spec
+	// Determine architecture: prefer explicit spec, then infer from instance type
 	var arch string
 	if image != nil && image.Architecture != "" {
 		arch = image.Architecture
+	} else {
+		// Infer architecture from instance type (e.g., arm64 for g5g/m7g/c7g)
+		inferred, err := p.inferArchFromInstanceType(instanceType)
+		if err != nil {
+			return nil, fmt.Errorf(
+				"failed to infer architecture from instance type %s: %w; set spec.image.architecture explicitly to override",
+				instanceType,
+				err,
+			)
+		}
+		arch = inferred
 	}
 	resolved, err := p.resolveImageForNode(os, image, arch)
 	if err != nil {
diff --git a/pkg/provider/aws/image.go b/pkg/provider/aws/image.go
index 4e3388b94..02c18b577 100644
--- a/pkg/provider/aws/image.go
+++ b/pkg/provider/aws/image.go
@@ -70,7 +70,16 @@ func (p *Provider) setAMI() error {
 func (p *Provider) resolveOSToAMI() error {
 	arch := p.Spec.Image.Architecture
 	if arch == "" {
-		arch = "x86_64" // Default architecture
+		// Infer architecture from instance type (e.g., arm64 for g5g/m7g/c7g)
+		inferred, err := p.inferArchFromInstanceType(p.Spec.Type)
+		if err != nil {
+			return fmt.Errorf(
+				"failed to infer architecture from instance type %s: %w; set spec.image.architecture explicitly to override",
+				p.Spec.Type,
+				err,
+			)
+		}
+		arch = inferred
 	}
 
 	//nolint:staticcheck // Instance is embedded but explicit access is clearer
@@ -244,18 +253,27 @@ func (p *Provider) findLegacyAMI(arch string) (string, error) {
 // setLegacyAMI implements the original Ubuntu 22.04 default behavior for
 // backward compatibility when OS is not specified. This mutates provider state.
 func (p *Provider) setLegacyAMI() error {
-	imageID, err := p.findLegacyAMI("")
+	// Determine architecture before AMI lookup
+	arch := p.Spec.Image.Architecture
+	if arch == "" {
+		// Infer architecture from instance type (e.g., arm64 for g5g/m7g/c7g)
+		inferred, err := p.inferArchFromInstanceType(p.Spec.Type)
+		if err != nil {
+			return fmt.Errorf(
+				"failed to infer architecture from instance type %s: %w; set spec.image.architecture explicitly to override",
+				p.Spec.Type,
+				err,
+			)
+		}
+		arch = inferred
+	}
+
+	imageID, err := p.findLegacyAMI(arch)
 	if err != nil {
 		return err
 	}
 	p.Spec.Image.ImageId = &imageID
-
-	// Store the resolved architecture (normalized to EC2 form) for cross-validation in DryRun
-	if p.Spec.Image.Architecture == "" {
-		p.Spec.Image.Architecture = "x86_64" // Legacy default
-	} else {
-		p.Spec.Image.Architecture = normalizeArchToEC2(p.Spec.Image.Architecture)
-	}
+	p.Spec.Image.Architecture = normalizeArchToEC2(arch)
 
 	// Set default username for Ubuntu if not provided
 	//nolint:staticcheck // Auth is embedded but explicit access is clearer
@@ -353,6 +371,32 @@ func normalizeArchToEC2(arch string) string {
 	}
 }
 
+// inferArchFromInstanceType asks getInstanceTypeArch which architectures the
+// instance type supports and picks one for AMI selection. Entries are matched
+// by prefix, so variants such as "arm64_mac" count as arm64. It returns "arm64"
+// only when arm64 is reported and x86_64 is not; every other case (including an
+// empty list) yields "x86_64" for backward compatibility. Errors pass through.
+func (p *Provider) inferArchFromInstanceType(instanceType string) (string, error) {
+	archs, err := p.getInstanceTypeArch(instanceType)
+	if err != nil {
+		return "", err
+	}
+	hasX86 := false
+	hasArm := false
+	for _, a := range archs {
+		switch {
+		case strings.HasPrefix(a, "x86_64"):
+			hasX86 = true
+		case strings.HasPrefix(a, "arm64"):
+			hasArm = true
+		}
+	}
+	if hasArm && !hasX86 {
+		return "arm64", nil
+	}
+	return "x86_64", nil
+}
+
 // describeImageArch queries EC2 DescribeImages for a specific AMI ID and
 // returns its architecture string (e.g., "x86_64" or "arm64").
 func (p *Provider) describeImageArch(imageID string) (string, error) {
diff --git a/pkg/provider/aws/image_test.go b/pkg/provider/aws/image_test.go
index dd1ec083c..f5c339bd5 100644
--- a/pkg/provider/aws/image_test.go
+++ b/pkg/provider/aws/image_test.go
@@ -1004,3 +1004,236 @@ func TestDryRun_ArchitectureMatch(t *testing.T) {
 	err := p.DryRun()
 	require.NoError(t, err)
 }
+
+func TestInferArchFromInstanceType(t *testing.T) {
+	tests := []struct {
+		name         string
+		instanceType string
+		setupMock    func(*MockEC2Client)
+		wantArch     string
+		wantErr      bool
+	}{
+		{
+			name:         "arm64-only instance type infers arm64",
+			instanceType: "g5g.xlarge",
+			setupMock: func(ec2Mock *MockEC2Client) {
+				ec2Mock.DescribeInstTypesFunc = func(ctx context.Context,
+					params *ec2.DescribeInstanceTypesInput,
+					optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) {
+					return &ec2.DescribeInstanceTypesOutput{
+						InstanceTypes: []types.InstanceTypeInfo{
+							{
+								InstanceType: "g5g.xlarge",
+								ProcessorInfo: &types.ProcessorInfo{
+									SupportedArchitectures: []types.ArchitectureType{
+										types.ArchitectureTypeArm64,
+									},
+								},
+							},
+						},
+					}, nil
+				}
+			},
+			wantArch: "arm64",
+			wantErr:  false,
+		},
+		{
+			name:         "x86_64-only instance type infers x86_64",
+			instanceType: "g4dn.xlarge",
+			setupMock: func(ec2Mock *MockEC2Client) {
+				ec2Mock.DescribeInstTypesFunc = func(ctx context.Context,
+					params *ec2.DescribeInstanceTypesInput,
+					optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) {
+					return &ec2.DescribeInstanceTypesOutput{
+						InstanceTypes: []types.InstanceTypeInfo{
+							{
+								InstanceType: "g4dn.xlarge",
+								ProcessorInfo: &types.ProcessorInfo{
+									SupportedArchitectures: []types.ArchitectureType{
+										types.ArchitectureTypeX8664,
+									},
+								},
+							},
+						},
+					}, nil
+				}
+			},
+			wantArch: "x86_64",
+			wantErr:  false,
+		},
+		{
+			name:         "dual-arch instance type defaults to x86_64",
+			instanceType: "synthetic.dualarch",
+			setupMock: func(ec2Mock *MockEC2Client) {
+				ec2Mock.DescribeInstTypesFunc = func(ctx context.Context,
+					params *ec2.DescribeInstanceTypesInput,
+					optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) {
+					return &ec2.DescribeInstanceTypesOutput{
+						InstanceTypes: []types.InstanceTypeInfo{
+							{
+								InstanceType: "synthetic.dualarch",
+								ProcessorInfo: &types.ProcessorInfo{
+									SupportedArchitectures: []types.ArchitectureType{
+										types.ArchitectureTypeX8664,
+										types.ArchitectureTypeArm64,
+									},
+								},
+							},
+						},
+					}, nil
+				}
+			},
+			wantArch: "x86_64",
+			wantErr:  false,
+		},
+		{
+			name:         "arm64_mac variant infers arm64",
+			instanceType: "mac2-m2.metal",
+			setupMock: func(ec2Mock *MockEC2Client) {
+				ec2Mock.DescribeInstTypesFunc = func(ctx context.Context,
+					params *ec2.DescribeInstanceTypesInput,
+					optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) {
+					return &ec2.DescribeInstanceTypesOutput{
+						InstanceTypes: []types.InstanceTypeInfo{
+							{
+								InstanceType: "mac2-m2.metal",
+								ProcessorInfo: &types.ProcessorInfo{
+									SupportedArchitectures: []types.ArchitectureType{
+										types.ArchitectureTypeArm64Mac,
+									},
+								},
+							},
+						},
+					}, nil
+				}
+			},
+			wantArch: "arm64",
+			wantErr:  false,
+		},
+		{
+			name:         "x86_64_mac variant infers x86_64",
+			instanceType: "mac1.metal",
+			setupMock: func(ec2Mock *MockEC2Client) {
+				ec2Mock.DescribeInstTypesFunc = func(ctx context.Context,
+					params *ec2.DescribeInstanceTypesInput,
+					optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) {
+					return &ec2.DescribeInstanceTypesOutput{
+						InstanceTypes: []types.InstanceTypeInfo{
+							{
+								InstanceType: "mac1.metal",
+								ProcessorInfo: &types.ProcessorInfo{
+									SupportedArchitectures: []types.ArchitectureType{
+										types.ArchitectureTypeX8664Mac,
+									},
+								},
+							},
+						},
+					}, nil
+				}
+			},
+			wantArch: "x86_64",
+			wantErr:  false,
+		},
+		{
+			name:         "API error returns error",
+			instanceType: "unknown.type",
+			setupMock: func(ec2Mock *MockEC2Client) {
+				ec2Mock.DescribeInstTypesFunc = func(ctx context.Context,
+					params *ec2.DescribeInstanceTypesInput,
+					optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) {
+					return nil, fmt.Errorf("instance type not found")
+				}
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ec2Mock := NewMockEC2Client()
+			if tt.setupMock != nil {
+				tt.setupMock(ec2Mock)
+			}
+
+			p := &Provider{ec2: ec2Mock}
+			arch, err := p.inferArchFromInstanceType(tt.instanceType)
+
+			if tt.wantErr {
+				require.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+			assert.Equal(t, tt.wantArch, arch)
+		})
+	}
+}
+
+func TestResolveOSToAMI_InfersArchFromInstanceType(t *testing.T) {
+	// When Architecture is empty and instance type is arm64-only,
+	// resolveOSToAMI should infer arm64 and resolve an arm64 AMI.
+	ec2Mock := NewMockEC2Client()
+	ssmMock := &mockSSMClient{}
+
+	// Mock: g5g.xlarge is arm64-only
+	ec2Mock.DescribeInstTypesFunc = func(ctx context.Context,
+		params *ec2.DescribeInstanceTypesInput,
+		optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) {
+		return &ec2.DescribeInstanceTypesOutput{
+			InstanceTypes: []types.InstanceTypeInfo{
+				{
+					InstanceType: "g5g.xlarge",
+					ProcessorInfo: &types.ProcessorInfo{
+						SupportedArchitectures: []types.ArchitectureType{
+							types.ArchitectureTypeArm64,
+						},
+					},
+				},
+			},
+		}, nil
+	}
+
+	// Mock: SSM returns arm64 AMI when arm64 is in path
+	ssmMock.GetParameterFunc = func(ctx context.Context, params *ssm.GetParameterInput,
+		optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) {
+		if params.Name != nil && strings.Contains(*params.Name, "arm64") {
+			return &ssm.GetParameterOutput{
+				Parameter: &ssmtypes.Parameter{
+					Value: aws.String("ami-arm64-inferred"),
+				},
+			}, nil
+		}
+		return nil, fmt.Errorf("expected arm64 in SSM path, got: %s", aws.ToString(params.Name))
+	}
+
+	resolver := ami.NewResolver(ec2Mock, ssmMock, "us-east-1")
+
+	env := v1alpha1.Environment{
+		ObjectMeta: metav1.ObjectMeta{Name: "test-arm64-inference"},
+		Spec: v1alpha1.EnvironmentSpec{
+			Provider: v1alpha1.ProviderAWS,
+			Instance: v1alpha1.Instance{
+				Type:   "g5g.xlarge", // arm64-only instance type
+				Region: "us-east-1",
+				OS:     "ubuntu-22.04",
+			},
+			// Architecture is intentionally NOT set
+		},
+	}
+
+	p := &Provider{
+		Environment: &env,
+		ec2:         ec2Mock,
+		amiResolver: resolver,
+	}
+
+	err := p.resolveOSToAMI()
+	require.NoError(t, err)
+
+	// Architecture should have been inferred as arm64
+	assert.Equal(t, "arm64", p.Spec.Image.Architecture,
+		"Should infer arm64 from g5g.xlarge instance type")
+	// AMI should be the arm64 one
+	require.NotNil(t, p.Spec.Image.ImageId)
+	assert.Equal(t, "ami-arm64-inferred", *p.Spec.Image.ImageId,
+		"Should resolve arm64 AMI when architecture inferred from instance type")
+}