diff --git a/pkg/provider/aws/cluster.go b/pkg/provider/aws/cluster.go index 016fab09..a77e8414 100644 --- a/pkg/provider/aws/cluster.go +++ b/pkg/provider/aws/cluster.go @@ -135,11 +135,35 @@ func (p *Provider) CreateCluster() error { } _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Internet Gateway created") - if err := p.createRouteTable(&cache.AWS); err != nil { - _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating route table") - return fmt.Errorf("error creating route table: %w", err) + // Phase 1b: Create cluster networking (public subnet, NAT GW, route tables) + // Note: We skip createRouteTable() here — in single-node mode it creates an + // IGW-routed table for the (only) subnet. In cluster mode the private subnet + // gets a NAT-routed table (createPrivateRouteTable) and the public subnet + // gets an IGW-routed table (createPublicRouteTable). Calling createRouteTable + // would create an orphaned IGW table associated with the private subnet.
+ if err := p.createPublicSubnet(&cache.AWS); err != nil { + _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating public subnet") + return fmt.Errorf("error creating public subnet: %w", err) } - _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Route Table created") + _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Public subnet created") + + if err := p.createPublicRouteTable(&cache.AWS); err != nil { + _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating public route table") + return fmt.Errorf("error creating public route table: %w", err) + } + _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Public route table created") + + if err := p.createNATGateway(&cache.AWS); err != nil { + _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating NAT Gateway") + return fmt.Errorf("error creating NAT Gateway: %w", err) + } + _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "NAT Gateway created") + + if err := p.createPrivateRouteTable(&cache.AWS); err != nil { + _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating private route table") + return fmt.Errorf("error creating private route table: %w", err) + } + _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Private route table created") // Phase 2: Create separate CP and Worker security groups if err := p.createControlPlaneSecurityGroup(cache); err != nil { @@ -677,7 +701,7 @@ func (p *Provider) createInstances( }, NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{ { - AssociatePublicIpAddress: aws.Bool(true), + AssociatePublicIpAddress: aws.Bool(false), DeleteOnTermination: aws.Bool(true), DeviceIndex: aws.Int32(0), Groups: []string{sgID}, diff --git a/pkg/provider/aws/cluster_test.go 
b/pkg/provider/aws/cluster_test.go new file mode 100644 index 00000000..d976f1c2 --- /dev/null +++ b/pkg/provider/aws/cluster_test.go @@ -0,0 +1,386 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aws + +import ( + "context" + "sync" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" +) + +// TestCreateInstancesSetsNoPublicIP verifies that createInstances sets +// AssociatePublicIpAddress=false in the RunInstancesInput for cluster mode. 
+func TestCreateInstancesSetsNoPublicIP(t *testing.T) { + var mu sync.Mutex + var captured []*ec2.RunInstancesInput + + mock := NewMockEC2Client() + + // Capture RunInstances calls + mock.RunInstancesFunc = func(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) { + mu.Lock() + captured = append(captured, params) + mu.Unlock() + return &ec2.RunInstancesOutput{ + Instances: []types.Instance{ + { + InstanceId: aws.String("i-test-12345"), + PrivateIpAddress: aws.String("10.0.0.10"), + NetworkInterfaces: []types.InstanceNetworkInterface{ + {NetworkInterfaceId: aws.String("eni-test-12345")}, + }, + }, + }, + }, nil + } + + // Mock DescribeInstances for the waiter — return instance in "running" state + mock.DescribeInstsFunc = func(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + { + Instances: []types.Instance{ + { + InstanceId: aws.String("i-test-12345"), + State: &types.InstanceState{Name: types.InstanceStateNameRunning}, + PublicDnsName: aws.String(""), + PublicIpAddress: nil, + PrivateIpAddress: aws.String("10.0.0.10"), + NetworkInterfaces: []types.InstanceNetworkInterface{ + {NetworkInterfaceId: aws.String("eni-test-12345")}, + }, + }, + }, + }, + }, + }, nil + } + + // Mock DescribeImages for resolveImageForNode and describeImageRootDevice + mock.DescribeImagesFunc = func(ctx context.Context, params *ec2.DescribeImagesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error) { + return &ec2.DescribeImagesOutput{ + Images: []types.Image{ + { + ImageId: aws.String("ami-test-123"), + Architecture: types.ArchitectureValuesX8664, + RootDeviceName: aws.String("/dev/sda1"), + }, + }, + }, nil + } + + provider := newTestProvider(mock) + cache := &ClusterCache{ + AWS: AWS{ + Subnetid: "subnet-private", + CPSecurityGroupid: "sg-cp", + 
WorkerSecurityGroupid: "sg-worker", + }, + } + + instances, err := provider.createInstances( + cache, + 1, + NodeRoleControlPlane, + "t3.medium", + nil, + "", + &v1alpha1.Image{ImageId: aws.String("ami-test-123")}, + ) + if err != nil { + t.Fatalf("createInstances failed: %v", err) + } + + if len(instances) != 1 { + t.Fatalf("expected 1 instance, got %d", len(instances)) + } + + // Verify no public IP + mu.Lock() + defer mu.Unlock() + if len(captured) != 1 { + t.Fatalf("expected 1 RunInstances call, got %d", len(captured)) + } + + nis := captured[0].NetworkInterfaces + if len(nis) == 0 { + t.Fatal("expected NetworkInterfaces in RunInstancesInput") + } + if nis[0].AssociatePublicIpAddress == nil || *nis[0].AssociatePublicIpAddress != false { + t.Error("AssociatePublicIpAddress should be false for cluster instances") + } + + // Verify instance uses private subnet + if aws.ToString(nis[0].SubnetId) != "subnet-private" { + t.Errorf("SubnetId = %q, want %q", aws.ToString(nis[0].SubnetId), "subnet-private") + } +} + +// TestCreateInstancesUsesRoleSecurityGroup verifies that createInstances selects +// the correct security group based on the node role. 
+func TestCreateInstancesUsesRoleSecurityGroup(t *testing.T) { + tests := []struct { + name string + role NodeRole + wantSG string + }{ + {"control-plane uses CP SG", NodeRoleControlPlane, "sg-cp"}, + {"worker uses Worker SG", NodeRoleWorker, "sg-worker"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var mu sync.Mutex + var captured *ec2.RunInstancesInput + + mock := NewMockEC2Client() + mock.RunInstancesFunc = func(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) { + mu.Lock() + captured = params + mu.Unlock() + return &ec2.RunInstancesOutput{ + Instances: []types.Instance{ + { + InstanceId: aws.String("i-test"), + PrivateIpAddress: aws.String("10.0.0.10"), + NetworkInterfaces: []types.InstanceNetworkInterface{ + {NetworkInterfaceId: aws.String("eni-test")}, + }, + }, + }, + }, nil + } + mock.DescribeInstsFunc = func(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{{ + Instances: []types.Instance{{ + InstanceId: aws.String("i-test"), + State: &types.InstanceState{Name: types.InstanceStateNameRunning}, + PrivateIpAddress: aws.String("10.0.0.10"), + NetworkInterfaces: []types.InstanceNetworkInterface{ + {NetworkInterfaceId: aws.String("eni-test")}, + }, + }}, + }}, + }, nil + } + mock.DescribeImagesFunc = func(ctx context.Context, params *ec2.DescribeImagesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error) { + return &ec2.DescribeImagesOutput{ + Images: []types.Image{{ + ImageId: aws.String("ami-test"), + Architecture: types.ArchitectureValuesX8664, + RootDeviceName: aws.String("/dev/sda1"), + }}, + }, nil + } + + provider := newTestProvider(mock) + cache := &ClusterCache{ + AWS: AWS{ + Subnetid: "subnet-private", + CPSecurityGroupid: "sg-cp", + WorkerSecurityGroupid: "sg-worker", + }, + } + + _, err 
:= provider.createInstances(cache, 1, tt.role, "t3.medium", nil, "", &v1alpha1.Image{ImageId: aws.String("ami-test")}) + if err != nil { + t.Fatalf("createInstances failed: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if captured == nil { + t.Fatal("RunInstances was not called") + } + gotSG := captured.NetworkInterfaces[0].Groups[0] + if gotSG != tt.wantSG { + t.Errorf("SecurityGroup = %q, want %q", gotSG, tt.wantSG) + } + }) + } +} + +// TestPrivateRouteTableRoutesToNATGW verifies that createPrivateRouteTable +// routes 0.0.0.0/0 to the NAT Gateway (not the Internet Gateway). +func TestPrivateRouteTableRoutesToNATGW(t *testing.T) { + var capturedRoute *ec2.CreateRouteInput + + mock := NewMockEC2Client() + mock.CreateRouteFunc = func(ctx context.Context, params *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) { + capturedRoute = params + return &ec2.CreateRouteOutput{}, nil + } + + provider := newTestProvider(mock) + cache := &AWS{ + Vpcid: "vpc-test", + Subnetid: "subnet-private", + NatGatewayid: "nat-test-123", + InternetGwid: "igw-test-456", + } + + if err := provider.createPrivateRouteTable(cache); err != nil { + t.Fatalf("createPrivateRouteTable failed: %v", err) + } + + if capturedRoute == nil { + t.Fatal("CreateRoute was not called") + } + + // Must route to NAT GW, not IGW + if capturedRoute.NatGatewayId == nil || *capturedRoute.NatGatewayId != "nat-test-123" { + t.Errorf("Route should target NAT GW nat-test-123, got NatGatewayId=%v GatewayId=%v", + aws.ToString(capturedRoute.NatGatewayId), aws.ToString(capturedRoute.GatewayId)) + } + if capturedRoute.GatewayId != nil { + t.Error("Private route table should NOT route to IGW (GatewayId should be nil)") + } +} + +// TestPublicRouteTableRoutesToIGW verifies that createPublicRouteTable +// routes 0.0.0.0/0 to the Internet Gateway and associates with the public subnet. 
+func TestPublicRouteTableRoutesToIGW(t *testing.T) { + var capturedRoute *ec2.CreateRouteInput + var capturedAssoc *ec2.AssociateRouteTableInput + + mock := NewMockEC2Client() + mock.CreateRouteFunc = func(ctx context.Context, params *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) { + capturedRoute = params + return &ec2.CreateRouteOutput{}, nil + } + mock.AssociateRTFunc = func(ctx context.Context, params *ec2.AssociateRouteTableInput, optFns ...func(*ec2.Options)) (*ec2.AssociateRouteTableOutput, error) { + capturedAssoc = params + return &ec2.AssociateRouteTableOutput{}, nil + } + + provider := newTestProvider(mock) + cache := &AWS{ + Vpcid: "vpc-test", + PublicSubnetid: "subnet-public", + InternetGwid: "igw-test-456", + } + + if err := provider.createPublicRouteTable(cache); err != nil { + t.Fatalf("createPublicRouteTable failed: %v", err) + } + + // Verify route targets IGW + if capturedRoute == nil { + t.Fatal("CreateRoute was not called") + } + if capturedRoute.GatewayId == nil || *capturedRoute.GatewayId != "igw-test-456" { + t.Errorf("Route should target IGW igw-test-456, got GatewayId=%v", aws.ToString(capturedRoute.GatewayId)) + } + + // Verify association with public subnet (not private) + if capturedAssoc == nil { + t.Fatal("AssociateRouteTable was not called") + } + if aws.ToString(capturedAssoc.SubnetId) != "subnet-public" { + t.Errorf("Route table associated with %q, want public subnet %q", + aws.ToString(capturedAssoc.SubnetId), "subnet-public") + } +} + +// TestPublicSubnetCreatedInCorrectCIDR verifies that createPublicSubnet +// creates a subnet in the 10.0.1.0/24 CIDR and stores it in PublicSubnetid. 
+func TestPublicSubnetCreatedInCorrectCIDR(t *testing.T) { + var capturedSubnet *ec2.CreateSubnetInput + + mock := NewMockEC2Client() + mock.CreateSubnetFunc = func(ctx context.Context, params *ec2.CreateSubnetInput, optFns ...func(*ec2.Options)) (*ec2.CreateSubnetOutput, error) { + capturedSubnet = params + return &ec2.CreateSubnetOutput{ + Subnet: &types.Subnet{SubnetId: aws.String("subnet-public-123")}, + }, nil + } + + provider := newTestProvider(mock) + cache := &AWS{ + Vpcid: "vpc-test", + } + + if err := provider.createPublicSubnet(cache); err != nil { + t.Fatalf("createPublicSubnet failed: %v", err) + } + + if capturedSubnet == nil { + t.Fatal("CreateSubnet was not called") + } + if aws.ToString(capturedSubnet.CidrBlock) != "10.0.1.0/24" { + t.Errorf("Public subnet CIDR = %q, want %q", aws.ToString(capturedSubnet.CidrBlock), "10.0.1.0/24") + } + + // Verify stored in PublicSubnetid (not Subnetid) + if cache.PublicSubnetid != "subnet-public-123" { + t.Errorf("cache.PublicSubnetid = %q, want %q", cache.PublicSubnetid, "subnet-public-123") + } +} + +// TestNATGatewayCreatedInPublicSubnet verifies that createNATGateway +// places the NAT gateway in the public subnet. 
+func TestNATGatewayCreatedInPublicSubnet(t *testing.T) { + var capturedNAT *ec2.CreateNatGatewayInput + + mock := NewMockEC2Client() + mock.CreateNatGatewayFunc = func(ctx context.Context, params *ec2.CreateNatGatewayInput, optFns ...func(*ec2.Options)) (*ec2.CreateNatGatewayOutput, error) { + capturedNAT = params + return &ec2.CreateNatGatewayOutput{ + NatGateway: &types.NatGateway{ + NatGatewayId: aws.String("nat-test-123"), + State: types.NatGatewayStateAvailable, + }, + }, nil + } + // Mock DescribeNatGateways for the wait loop + mock.DescribeNatGatewaysFunc = func(ctx context.Context, params *ec2.DescribeNatGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNatGatewaysOutput, error) { + return &ec2.DescribeNatGatewaysOutput{ + NatGateways: []types.NatGateway{ + { + NatGatewayId: aws.String("nat-test-123"), + State: types.NatGatewayStateAvailable, + }, + }, + }, nil + } + + provider := newTestProvider(mock) + cache := &AWS{ + Vpcid: "vpc-test", + PublicSubnetid: "subnet-public", + Subnetid: "subnet-private", + } + + if err := provider.createNATGateway(cache); err != nil { + t.Fatalf("createNATGateway failed: %v", err) + } + + if capturedNAT == nil { + t.Fatal("CreateNatGateway was not called") + } + if aws.ToString(capturedNAT.SubnetId) != "subnet-public" { + t.Errorf("NAT GW placed in %q, want public subnet %q", + aws.ToString(capturedNAT.SubnetId), "subnet-public") + } +} diff --git a/pkg/provider/aws/nlb.go b/pkg/provider/aws/nlb.go index 2507d3f7..4d880e48 100644 --- a/pkg/provider/aws/nlb.go +++ b/pkg/provider/aws/nlb.go @@ -54,8 +54,8 @@ func (p *Provider) createNLB(cache *ClusterCache) error { } lbName := nlbBaseName + nlbSuffix - // Determine subnet IDs (use the same subnet for NLB) - subnetIDs := []string{cache.Subnetid} + // Use the public subnet for the internet-facing NLB + subnetIDs := []string{cache.PublicSubnetid} // Create load balancer createLBInput := &elasticloadbalancingv2.CreateLoadBalancerInput{ diff --git 
a/pkg/provisioner/cluster.go b/pkg/provisioner/cluster.go index 0dc17e86..29a481e6 100644 --- a/pkg/provisioner/cluster.go +++ b/pkg/provisioner/cluster.go @@ -95,6 +95,16 @@ func (cp *ClusterProvisioner) transportOptsForNode(node NodeInfo) []Option { return nil } +// hostForNode returns the SSH host address for a node. Nodes with a Transport +// (e.g., SSM for private-subnet instances) use PrivateIP since the transport +// handles connectivity. Nodes without a transport use PublicIP for direct SSH. +func hostForNode(node NodeInfo) string { + if node.Transport != nil { + return node.PrivateIP + } + return node.PublicIP +} + // ProvisionCluster provisions a multinode Kubernetes cluster // It follows the order: init first CP → join additional CPs → join workers func (cp *ClusterProvisioner) ProvisionCluster(nodes []NodeInfo) error { @@ -185,7 +195,7 @@ func (cp *ClusterProvisioner) provisionBaseOnAllNodes(nodes []NodeInfo) error { g.Go(func() error { cp.log.Info("Provisioning base dependencies on %s (%s)", node.Name, node.PublicIP) - provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), node.PublicIP, cp.transportOptsForNode(node)...) + provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), hostForNode(node), cp.transportOptsForNode(node)...) if err != nil { return fmt.Errorf("failed to connect to %s: %w", node.Name, err) } @@ -227,7 +237,7 @@ func (cp *ClusterProvisioner) provisionBaseOnAllNodes(nodes []NodeInfo) error { func (cp *ClusterProvisioner) installK8sPrereqs(node NodeInfo) error { cp.log.Info("Installing K8s binaries on %s (%s)", node.Name, node.PublicIP) - provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), node.PublicIP, cp.transportOptsForNode(node)...) + provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), hostForNode(node), cp.transportOptsForNode(node)...) 
if err != nil { return fmt.Errorf("failed to connect to %s: %w", node.Name, err) } @@ -291,7 +301,7 @@ func (cp *ClusterProvisioner) installK8sPrereqs(node NodeInfo) error { // initFirstControlPlane initializes the first control-plane node with kubeadm init func (cp *ClusterProvisioner) initFirstControlPlane(node NodeInfo) error { - provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), node.PublicIP, cp.transportOptsForNode(node)...) + provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), hostForNode(node), cp.transportOptsForNode(node)...) if err != nil { return fmt.Errorf("failed to connect to %s: %w", node.Name, err) } @@ -383,7 +393,7 @@ printf '%s\n%s\n%s' "$TOKEN" "$HASH" "$CERTKEY" // joinControlPlane joins an additional control-plane node to the cluster func (cp *ClusterProvisioner) joinControlPlane(node NodeInfo) error { - provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), node.PublicIP, cp.transportOptsForNode(node)...) + provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), hostForNode(node), cp.transportOptsForNode(node)...) if err != nil { return fmt.Errorf("failed to connect to %s: %w", node.Name, err) } @@ -418,7 +428,7 @@ func (cp *ClusterProvisioner) joinControlPlane(node NodeInfo) error { // joinWorker joins a worker node to the cluster func (cp *ClusterProvisioner) joinWorker(node NodeInfo) error { - provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), node.PublicIP, cp.transportOptsForNode(node)...) + provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(node), hostForNode(node), cp.transportOptsForNode(node)...) 
if err != nil { return fmt.Errorf("failed to connect to %s: %w", node.Name, err) } @@ -460,7 +470,7 @@ func (cp *ClusterProvisioner) isHAEnabled() bool { // configureNodes applies labels, taints, and roles to all cluster nodes // This is run from the first control-plane node after all nodes have joined func (cp *ClusterProvisioner) configureNodes(firstCP NodeInfo, nodes []NodeInfo) error { - provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(firstCP), firstCP.PublicIP, cp.transportOptsForNode(firstCP)...) + provisioner, err := New(cp.log, cp.KeyPath, cp.getUsernameForNode(firstCP), hostForNode(firstCP), cp.transportOptsForNode(firstCP)...) if err != nil { return fmt.Errorf("failed to connect to %s: %w", firstCP.Name, err) } @@ -643,15 +653,16 @@ type NodeHealth struct { InternalIP string } -// GetClusterHealth checks the health of a multinode cluster by querying the first control-plane -func (cp *ClusterProvisioner) GetClusterHealth(firstCPPublicIP string) (*ClusterHealth, error) { +// GetClusterHealth checks the health of a multinode cluster by querying the first control-plane. +// firstCPHost is the SSH-reachable address — PublicIP for direct SSH, PrivateIP for SSM transport. +func (cp *ClusterProvisioner) GetClusterHealth(firstCPHost string) (*ClusterHealth, error) { // Resolve SSH username: check per-node username from cluster status first, // then fall back to global username. This handles OS-based provisioning // where the provider resolves the SSH username per node. 
username := cp.UserName if cp.Environment != nil && cp.Environment.Status.Cluster != nil { for _, node := range cp.Environment.Status.Cluster.Nodes { - if node.PublicIP == firstCPPublicIP && node.SSHUsername != "" { + if (node.PublicIP == firstCPHost || node.PrivateIP == firstCPHost) && node.SSHUsername != "" { username = node.SSHUsername break } @@ -661,7 +672,7 @@ func (cp *ClusterProvisioner) GetClusterHealth(firstCPPublicIP string) (*Cluster var transportOpts []Option if cp.Environment != nil && cp.Environment.Status.Cluster != nil { for _, node := range cp.Environment.Status.Cluster.Nodes { - if node.PublicIP == firstCPPublicIP { + if node.PublicIP == firstCPHost || node.PrivateIP == firstCPHost { nodeInfo := NodeInfo{ PublicIP: node.PublicIP, PrivateIP: node.PrivateIP, @@ -672,7 +683,7 @@ func (cp *ClusterProvisioner) GetClusterHealth(firstCPPublicIP string) (*Cluster } } } - provisioner, err := New(cp.log, cp.KeyPath, username, firstCPPublicIP, transportOpts...) + provisioner, err := New(cp.log, cp.KeyPath, username, firstCPHost, transportOpts...) if err != nil { return &ClusterHealth{ Healthy: false, @@ -782,20 +793,25 @@ func GetClusterHealthFromEnv(log *logger.FunLogger, env *v1alpha1.Environment) ( return nil, fmt.Errorf("not a multinode cluster") } - // Find first control-plane node - var firstCPIP string + // Find first control-plane node — prefer PublicIP for direct SSH, + // fall back to PrivateIP for SSM transport (private subnet nodes). 
+ var firstCPHost string for _, node := range env.Status.Cluster.Nodes { if node.Role == "control-plane" { - firstCPIP = node.PublicIP + if node.PublicIP != "" { + firstCPHost = node.PublicIP + } else { + firstCPHost = node.PrivateIP + } break } } - if firstCPIP == "" { + if firstCPHost == "" { return nil, fmt.Errorf("no control-plane node found") } cp := NewClusterProvisioner(log, env.Spec.PrivateKey, env.Spec.Username, env) - return cp.GetClusterHealth(firstCPIP) + return cp.GetClusterHealth(firstCPHost) } // Note: addScriptHeader is defined in provisioner.go diff --git a/tests/aws_cluster_test.go b/tests/aws_cluster_test.go index 40211ca2..b3b790c2 100644 --- a/tests/aws_cluster_test.go +++ b/tests/aws_cluster_test.go @@ -182,18 +182,23 @@ var _ = DescribeTable("AWS Cluster E2E", Expect(cp.ProvisionCluster(nodes)).To(Succeed(), "Failed to provision cluster") By("Verifying cluster health") - // Get first control-plane IP for health check - var firstCPIP string + // Get first control-plane host for health check — prefer PublicIP, + // fall back to PrivateIP for SSM transport (private subnet nodes). + var firstCPHost string for _, node := range env.Status.Cluster.Nodes { if node.Role == "control-plane" { - firstCPIP = node.PublicIP + if node.PublicIP != "" { + firstCPHost = node.PublicIP + } else { + firstCPHost = node.PrivateIP + } break } } - Expect(firstCPIP).NotTo(BeEmpty(), "First control plane IP should not be empty") + Expect(firstCPHost).NotTo(BeEmpty(), "First control plane host should not be empty") // Check cluster health - health, err := cp.GetClusterHealth(firstCPIP) + health, err := cp.GetClusterHealth(firstCPHost) Expect(err).NotTo(HaveOccurred(), "Failed to get cluster health") Expect(health).NotTo(BeNil(), "Cluster health should not be nil") Expect(health.APIServerStatus).To(Equal("Running"), "API server should be running")