Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/aws/ec2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ type EC2Client interface {
DescribeSecurityGroups(ctx context.Context,
params *ec2.DescribeSecurityGroupsInput,
optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error)
RevokeSecurityGroupIngress(ctx context.Context,
params *ec2.RevokeSecurityGroupIngressInput,
optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput,
error)
RevokeSecurityGroupEgress(ctx context.Context,
params *ec2.RevokeSecurityGroupEgressInput,
optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput,
error)

// Instance operations
RunInstances(ctx context.Context, params *ec2.RunInstancesInput,
Expand Down
28 changes: 8 additions & 20 deletions pkg/provider/aws/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ func (p *Provider) CreateCluster() error {
}
_ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Internet Gateway created")

// Phase 1b: Create cluster networking (public subnet, NAT GW, route tables)
// Note: We skip createRouteTable() here — in single-node mode it creates an
// IGW-routed table for the (only) subnet. In cluster mode the private subnet
// gets a NAT-routed table (createPrivateRouteTable) and the public subnet
// gets an IGW-routed table (createPublicRouteTable). Calling createRouteTable
// would create an orphaned IGW table associated with the private subnet.
// Phase 1b: Create public subnet and route table for cluster instances.
// Cluster instances are placed in the public subnet with AssociatePublicIpAddress=true,
// so they have direct internet access via the IGW. NAT Gateway and private route table
// are NOT needed — skipping them avoids consuming scarce EIP quota (AWS limit: 5 per
// region), which caused CI failures when multiple jobs ran concurrently.
// The private subnet (created above) is retained for future SSM endpoint use.
if err := p.createPublicSubnet(&cache.AWS); err != nil {
_ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating public subnet")
return fmt.Errorf("error creating public subnet: %w", err)
Expand All @@ -153,18 +153,6 @@ func (p *Provider) CreateCluster() error {
}
_ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Public route table created")

if err := p.createNATGateway(&cache.AWS); err != nil {
_ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating NAT Gateway")
return fmt.Errorf("error creating NAT Gateway: %w", err)
}
_ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "NAT Gateway created")

if err := p.createPrivateRouteTable(&cache.AWS); err != nil {
_ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating private route table")
return fmt.Errorf("error creating private route table: %w", err)
}
_ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Private route table created")

// Phase 2: Create separate CP and Worker security groups
if err := p.createControlPlaneSecurityGroup(cache); err != nil {
_ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating control-plane security group")
Expand Down Expand Up @@ -701,11 +689,11 @@ func (p *Provider) createInstances(
},
NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{
{
AssociatePublicIpAddress: aws.Bool(false),
AssociatePublicIpAddress: aws.Bool(true),
DeleteOnTermination: aws.Bool(true),
DeviceIndex: aws.Int32(0),
Groups: []string{sgID},
SubnetId: aws.String(cache.Subnetid),
SubnetId: aws.String(cache.PublicSubnetid),
},
},
KeyName: aws.String(p.Spec.KeyName),
Expand Down
106 changes: 98 additions & 8 deletions pkg/provider/aws/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package aws

import (
"context"
"strings"
"sync"
"testing"

Expand All @@ -28,9 +29,9 @@ import (
"github.com/NVIDIA/holodeck/api/holodeck/v1alpha1"
)

// TestCreateInstancesSetsNoPublicIP verifies that createInstances sets
// AssociatePublicIpAddress=false in the RunInstancesInput for cluster mode.
func TestCreateInstancesSetsNoPublicIP(t *testing.T) {
// TestCreateInstancesSetsPublicIP verifies that createInstances sets
// AssociatePublicIpAddress=true in the RunInstancesInput for cluster mode.
func TestCreateInstancesSetsPublicIP(t *testing.T) {
var mu sync.Mutex
var captured []*ec2.RunInstancesInput

Expand Down Expand Up @@ -93,6 +94,7 @@ func TestCreateInstancesSetsNoPublicIP(t *testing.T) {
cache := &ClusterCache{
AWS: AWS{
Subnetid: "subnet-private",
PublicSubnetid: "subnet-public",
CPSecurityGroupid: "sg-cp",
WorkerSecurityGroupid: "sg-worker",
},
Expand Down Expand Up @@ -126,13 +128,13 @@ func TestCreateInstancesSetsNoPublicIP(t *testing.T) {
if len(nis) == 0 {
t.Fatal("expected NetworkInterfaces in RunInstancesInput")
}
if nis[0].AssociatePublicIpAddress == nil || *nis[0].AssociatePublicIpAddress != false {
t.Error("AssociatePublicIpAddress should be false for cluster instances")
if nis[0].AssociatePublicIpAddress == nil || *nis[0].AssociatePublicIpAddress != true {
t.Error("AssociatePublicIpAddress should be true for cluster instances")
}

// Verify instance uses private subnet
if aws.ToString(nis[0].SubnetId) != "subnet-private" {
t.Errorf("SubnetId = %q, want %q", aws.ToString(nis[0].SubnetId), "subnet-private")
// Verify instance uses public subnet
if aws.ToString(nis[0].SubnetId) != "subnet-public" {
t.Errorf("SubnetId = %q, want %q", aws.ToString(nis[0].SubnetId), "subnet-public")
}
}

Expand Down Expand Up @@ -384,3 +386,91 @@ func TestNATGatewayCreatedInPublicSubnet(t *testing.T) {
aws.ToString(capturedNAT.SubnetId), "subnet-public")
}
}

// TestNATGatewayWaitsForAvailable checks that createNATGateway keeps calling
// DescribeNatGateways while the gateway is pending and succeeds once the mock
// reports it as available (from the second describe call onwards).
func TestNATGatewayWaitsForAvailable(t *testing.T) {
	const natID = "nat-pending-123"
	var polls int

	client := NewMockEC2Client()

	// CreateNatGateway hands back a gateway that is still pending.
	client.CreateNatGatewayFunc = func(ctx context.Context, params *ec2.CreateNatGatewayInput, optFns ...func(*ec2.Options)) (*ec2.CreateNatGatewayOutput, error) {
		gw := types.NatGateway{
			NatGatewayId: aws.String(natID),
			State:        types.NatGatewayStatePending,
		}
		return &ec2.CreateNatGatewayOutput{NatGateway: &gw}, nil
	}

	// The first describe call reports pending; every later call reports available.
	client.DescribeNatGatewaysFunc = func(ctx context.Context, params *ec2.DescribeNatGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNatGatewaysOutput, error) {
		polls++
		current := types.NatGatewayStateAvailable
		if polls < 2 {
			current = types.NatGatewayStatePending
		}
		gw := types.NatGateway{
			NatGatewayId: aws.String(natID),
			State:        current,
		}
		return &ec2.DescribeNatGatewaysOutput{NatGateways: []types.NatGateway{gw}}, nil
	}

	p := newTestProvider(client)
	awsCache := &AWS{
		Vpcid:          "vpc-test",
		PublicSubnetid: "subnet-public",
		Subnetid:       "subnet-private",
	}

	err := p.createNATGateway(awsCache)
	if err != nil {
		t.Fatalf("createNATGateway failed: %v", err)
	}

	if polls < 2 {
		t.Errorf("Expected at least 2 DescribeNatGateways calls for polling, got %d", polls)
	}
	if got := awsCache.NatGatewayid; got != natID {
		t.Errorf("cache.NatGatewayid = %q, want %q", got, natID)
	}
}

// TestNATGatewayFailedState checks that createNATGateway surfaces an error
// mentioning "failed state" when DescribeNatGateways reports the gateway as
// failed.
func TestNATGatewayFailedState(t *testing.T) {
	const natID = "nat-fail-123"

	client := NewMockEC2Client()

	// Creation itself succeeds and returns a pending gateway.
	client.CreateNatGatewayFunc = func(ctx context.Context, params *ec2.CreateNatGatewayInput, optFns ...func(*ec2.Options)) (*ec2.CreateNatGatewayOutput, error) {
		gw := types.NatGateway{
			NatGatewayId: aws.String(natID),
			State:        types.NatGatewayStatePending,
		}
		return &ec2.CreateNatGatewayOutput{NatGateway: &gw}, nil
	}

	// Every describe call reports the gateway as failed.
	client.DescribeNatGatewaysFunc = func(ctx context.Context, params *ec2.DescribeNatGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNatGatewaysOutput, error) {
		gw := types.NatGateway{
			NatGatewayId: aws.String(natID),
			State:        types.NatGatewayStateFailed,
		}
		return &ec2.DescribeNatGatewaysOutput{NatGateways: []types.NatGateway{gw}}, nil
	}

	p := newTestProvider(client)
	awsCache := &AWS{
		Vpcid:          "vpc-test",
		PublicSubnetid: "subnet-public",
		Subnetid:       "subnet-private",
	}

	switch err := p.createNATGateway(awsCache); {
	case err == nil:
		t.Fatal("Expected error when NAT Gateway reaches failed state")
	case !strings.Contains(err.Error(), "failed state"):
		t.Errorf("Error should mention failed state, got: %v", err)
	}
}
31 changes: 29 additions & 2 deletions pkg/provider/aws/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,35 @@ func (p *Provider) createNATGateway(cache *AWS) error {
}
cache.NatGatewayid = *natOutput.NatGateway.NatGatewayId

cancelLoading(nil)
return nil
// Wait for NAT Gateway to reach "available" state before returning.
// CreateNatGateway returns immediately with state "pending"; using the
// NAT GW ID in a CreateRoute call before it is available causes
// InvalidNatGatewayID.NotFound errors.
p.log.Info("Waiting for NAT Gateway %s to become available", cache.NatGatewayid)
for i := 0; i < 60; i++ { // 60 × 5s = 5 minutes max
time.Sleep(5 * time.Second)
dCtx, dCancel := context.WithTimeout(context.Background(), 30*time.Second)
out, err := p.ec2.DescribeNatGateways(dCtx, &ec2.DescribeNatGatewaysInput{
NatGatewayIds: []string{cache.NatGatewayid},
})
dCancel()
if err != nil {
p.log.Warning("Error checking NAT Gateway state: %v", err)
continue
}
if len(out.NatGateways) > 0 && out.NatGateways[0].State == types.NatGatewayStateAvailable {
p.log.Info("NAT Gateway %s is available", cache.NatGatewayid)
cancelLoading(nil)
return nil
}
if len(out.NatGateways) > 0 && out.NatGateways[0].State == types.NatGatewayStateFailed {
cancelLoading(logger.ErrLoadingFailed)
return fmt.Errorf("NAT Gateway %s reached failed state", cache.NatGatewayid)
}
}

cancelLoading(logger.ErrLoadingFailed)
return fmt.Errorf("NAT Gateway %s did not become available within timeout", cache.NatGatewayid)
}

// releaseEIP releases an Elastic IP by allocation ID.
Expand Down
23 changes: 22 additions & 1 deletion pkg/provider/aws/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,18 @@ func (m *mockEC2Client) DescribeSecurityGroups(ctx context.Context,
return &ec2.DescribeSecurityGroupsOutput{}, nil
}

// RevokeSecurityGroupIngress is a no-op stub: it ignores its arguments and
// always returns an empty output with a nil error.
func (m *mockEC2Client) RevokeSecurityGroupIngress(
	ctx context.Context,
	params *ec2.RevokeSecurityGroupIngressInput,
	optFns ...func(*ec2.Options),
) (*ec2.RevokeSecurityGroupIngressOutput, error) {
	return new(ec2.RevokeSecurityGroupIngressOutput), nil
}

// RevokeSecurityGroupEgress is a no-op stub: it ignores its arguments and
// always returns an empty output with a nil error.
func (m *mockEC2Client) RevokeSecurityGroupEgress(
	ctx context.Context,
	params *ec2.RevokeSecurityGroupEgressInput,
	optFns ...func(*ec2.Options),
) (*ec2.RevokeSecurityGroupEgressOutput, error) {
	return new(ec2.RevokeSecurityGroupEgressOutput), nil
}

func (m *mockEC2Client) RunInstances(ctx context.Context, params *ec2.RunInstancesInput,
optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) {
m.runInstancesCalls = append(m.runInstancesCalls, *params)
Expand Down Expand Up @@ -310,7 +322,16 @@ func (m *mockEC2Client) DeleteNatGateway(ctx context.Context, params *ec2.Delete

// DescribeNatGateways is a stub that reports a single NAT Gateway in the
// "available" state, so the wait loop in createNATGateway returns on its
// first successful poll.
//
// The returned gateway echoes the first ID in params.NatGatewayIds; when
// params is nil or carries no IDs, the placeholder ID "nat-123" is used.
// The error is always nil.
func (m *mockEC2Client) DescribeNatGateways(ctx context.Context, params *ec2.DescribeNatGatewaysInput,
	optFns ...func(*ec2.Options)) (*ec2.DescribeNatGatewaysOutput, error) {
	natGWID := "nat-123"
	// Guard against nil params so the stub never panics on a bare call.
	if params != nil && len(params.NatGatewayIds) > 0 {
		natGWID = params.NatGatewayIds[0]
	}

	return &ec2.DescribeNatGatewaysOutput{
		NatGateways: []types.NatGateway{
			{NatGatewayId: aws.String(natGWID), State: types.NatGatewayStateAvailable},
		},
	}, nil
}

func (m *mockEC2Client) ModifySubnetAttribute(ctx context.Context, params *ec2.ModifySubnetAttributeInput,
Expand Down
Loading
Loading