diff --git a/internal/aws/ec2_client.go b/internal/aws/ec2_client.go index 9339c92a..5d56ec54 100644 --- a/internal/aws/ec2_client.go +++ b/internal/aws/ec2_client.go @@ -96,6 +96,14 @@ type EC2Client interface { DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) + RevokeSecurityGroupIngress(ctx context.Context, + params *ec2.RevokeSecurityGroupIngressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, + error) + RevokeSecurityGroupEgress(ctx context.Context, + params *ec2.RevokeSecurityGroupEgressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, + error) // Instance operations RunInstances(ctx context.Context, params *ec2.RunInstancesInput, diff --git a/pkg/provider/aws/cluster.go b/pkg/provider/aws/cluster.go index a77e8414..bb22efb8 100644 --- a/pkg/provider/aws/cluster.go +++ b/pkg/provider/aws/cluster.go @@ -135,12 +135,12 @@ func (p *Provider) CreateCluster() error { } _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Internet Gateway created") - // Phase 1b: Create cluster networking (public subnet, NAT GW, route tables) - // Note: We skip createRouteTable() here — in single-node mode it creates an - // IGW-routed table for the (only) subnet. In cluster mode the private subnet - // gets a NAT-routed table (createPrivateRouteTable) and the public subnet - // gets an IGW-routed table (createPublicRouteTable). Calling createRouteTable - // would create an orphaned IGW table associated with the private subnet. + // Phase 1b: Create public subnet and route table for cluster instances. + // Cluster instances are placed in the public subnet with AssociatePublicIpAddress=true, + // so they have direct internet access via the IGW. NAT Gateway and private route table + // are NOT needed — skipping them avoids consuming scarce EIP quota (AWS limit: 5 per + // region), which caused CI failures when multiple jobs ran concurrently. + // The private subnet (created above) is retained for future SSM endpoint use. if err := p.createPublicSubnet(&cache.AWS); err != nil { _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating public subnet") return fmt.Errorf("error creating public subnet: %w", err) @@ -153,18 +153,6 @@ func (p *Provider) CreateCluster() error { } _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Public route table created") - if err := p.createNATGateway(&cache.AWS); err != nil { - _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating NAT Gateway") - return fmt.Errorf("error creating NAT Gateway: %w", err) - } - _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "NAT Gateway created") - - if err := p.createPrivateRouteTable(&cache.AWS); err != nil { - _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating private route table") - return fmt.Errorf("error creating private route table: %w", err) - } - _ = p.updateProgressingCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Private route table created") - // Phase 2: Create separate CP and Worker security groups if err := p.createControlPlaneSecurityGroup(cache); err != nil { _ = p.updateDegradedCondition(*p.DeepCopy(), &cache.AWS, "v1alpha1.Creating", "Error creating control-plane security group") @@ -701,11 +689,11 @@ func (p *Provider) createInstances( }, NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{ { - AssociatePublicIpAddress: aws.Bool(false), + AssociatePublicIpAddress: aws.Bool(true), DeleteOnTermination: aws.Bool(true), DeviceIndex: aws.Int32(0), Groups: []string{sgID}, - SubnetId: aws.String(cache.Subnetid), + SubnetId: aws.String(cache.PublicSubnetid), }, }, KeyName: aws.String(p.Spec.KeyName), diff --git a/pkg/provider/aws/cluster_test.go b/pkg/provider/aws/cluster_test.go index d976f1c2..74b71de0 100644 --- a/pkg/provider/aws/cluster_test.go +++ b/pkg/provider/aws/cluster_test.go @@ -18,6 +18,7 @@ package aws import ( "context" + "strings" "sync" "testing" @@ -28,9 +29,9 @@ import ( "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" ) -// TestCreateInstancesSetsNoPublicIP verifies that createInstances sets -// AssociatePublicIpAddress=false in the RunInstancesInput for cluster mode. -func TestCreateInstancesSetsNoPublicIP(t *testing.T) { +// TestCreateInstancesSetsPublicIP verifies that createInstances sets +// AssociatePublicIpAddress=true in the RunInstancesInput for cluster mode. +func TestCreateInstancesSetsPublicIP(t *testing.T) { var mu sync.Mutex var captured []*ec2.RunInstancesInput @@ -93,6 +94,7 @@ func TestCreateInstancesSetsNoPublicIP(t *testing.T) { cache := &ClusterCache{ AWS: AWS{ Subnetid: "subnet-private", + PublicSubnetid: "subnet-public", CPSecurityGroupid: "sg-cp", WorkerSecurityGroupid: "sg-worker", }, @@ -126,13 +128,13 @@ func TestCreateInstancesSetsNoPublicIP(t *testing.T) { if len(nis) == 0 { t.Fatal("expected NetworkInterfaces in RunInstancesInput") } - if nis[0].AssociatePublicIpAddress == nil || *nis[0].AssociatePublicIpAddress != false { - t.Error("AssociatePublicIpAddress should be false for cluster instances") + if nis[0].AssociatePublicIpAddress == nil || *nis[0].AssociatePublicIpAddress != true { + t.Error("AssociatePublicIpAddress should be true for cluster instances") } - // Verify instance uses private subnet - if aws.ToString(nis[0].SubnetId) != "subnet-private" { - t.Errorf("SubnetId = %q, want %q", aws.ToString(nis[0].SubnetId), "subnet-private") + // Verify instance uses public subnet + if aws.ToString(nis[0].SubnetId) != "subnet-public" { + t.Errorf("SubnetId = %q, want %q", aws.ToString(nis[0].SubnetId), "subnet-public") } } @@ -384,3 +386,91 @@ func TestNATGatewayCreatedInPublicSubnet(t *testing.T) { aws.ToString(capturedNAT.SubnetId), "subnet-public") } } + +// TestNATGatewayWaitsForAvailable verifies that createNATGateway polls +// DescribeNatGateways until the NAT GW transitions from pending to available. +func TestNATGatewayWaitsForAvailable(t *testing.T) { + describeCalls := 0 + + mock := NewMockEC2Client() + mock.CreateNatGatewayFunc = func(ctx context.Context, params *ec2.CreateNatGatewayInput, optFns ...func(*ec2.Options)) (*ec2.CreateNatGatewayOutput, error) { + return &ec2.CreateNatGatewayOutput{ + NatGateway: &types.NatGateway{ + NatGatewayId: aws.String("nat-pending-123"), + State: types.NatGatewayStatePending, + }, + }, nil + } + mock.DescribeNatGatewaysFunc = func(ctx context.Context, params *ec2.DescribeNatGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNatGatewaysOutput, error) { + describeCalls++ + state := types.NatGatewayStatePending + if describeCalls >= 2 { + state = types.NatGatewayStateAvailable + } + return &ec2.DescribeNatGatewaysOutput{ + NatGateways: []types.NatGateway{ + { + NatGatewayId: aws.String("nat-pending-123"), + State: state, + }, + }, + }, nil + } + + provider := newTestProvider(mock) + cache := &AWS{ + Vpcid: "vpc-test", + PublicSubnetid: "subnet-public", + Subnetid: "subnet-private", + } + + if err := provider.createNATGateway(cache); err != nil { + t.Fatalf("createNATGateway failed: %v", err) + } + + if describeCalls < 2 { + t.Errorf("Expected at least 2 DescribeNatGateways calls for polling, got %d", describeCalls) + } + if cache.NatGatewayid != "nat-pending-123" { + t.Errorf("cache.NatGatewayid = %q, want %q", cache.NatGatewayid, "nat-pending-123") + } +} + +// TestNATGatewayFailedState verifies that createNATGateway returns an error +// if the NAT GW transitions to the failed state. +func TestNATGatewayFailedState(t *testing.T) { + mock := NewMockEC2Client() + mock.CreateNatGatewayFunc = func(ctx context.Context, params *ec2.CreateNatGatewayInput, optFns ...func(*ec2.Options)) (*ec2.CreateNatGatewayOutput, error) { + return &ec2.CreateNatGatewayOutput{ + NatGateway: &types.NatGateway{ + NatGatewayId: aws.String("nat-fail-123"), + State: types.NatGatewayStatePending, + }, + }, nil + } + mock.DescribeNatGatewaysFunc = func(ctx context.Context, params *ec2.DescribeNatGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNatGatewaysOutput, error) { + return &ec2.DescribeNatGatewaysOutput{ + NatGateways: []types.NatGateway{ + { + NatGatewayId: aws.String("nat-fail-123"), + State: types.NatGatewayStateFailed, + }, + }, + }, nil + } + + provider := newTestProvider(mock) + cache := &AWS{ + Vpcid: "vpc-test", + PublicSubnetid: "subnet-public", + Subnetid: "subnet-private", + } + + err := provider.createNATGateway(cache) + if err == nil { + t.Fatal("Expected error when NAT Gateway reaches failed state") + } + if !strings.Contains(err.Error(), "failed state") { + t.Errorf("Error should mention failed state, got: %v", err) + } +} diff --git a/pkg/provider/aws/create.go b/pkg/provider/aws/create.go index bafc30f3..1f44c0d8 100644 --- a/pkg/provider/aws/create.go +++ b/pkg/provider/aws/create.go @@ -654,8 +654,35 @@ func (p *Provider) createNATGateway(cache *AWS) error { } cache.NatGatewayid = *natOutput.NatGateway.NatGatewayId - cancelLoading(nil) - return nil + // Wait for NAT Gateway to reach "available" state before returning. + // CreateNatGateway returns immediately with state "pending"; using the + // NAT GW ID in a CreateRoute call before it is available causes + // InvalidNatGatewayID.NotFound errors. + p.log.Info("Waiting for NAT Gateway %s to become available", cache.NatGatewayid) + for i := 0; i < 60; i++ { // 60 × 5s = 5 minutes max + time.Sleep(5 * time.Second) + dCtx, dCancel := context.WithTimeout(context.Background(), 30*time.Second) + out, err := p.ec2.DescribeNatGateways(dCtx, &ec2.DescribeNatGatewaysInput{ + NatGatewayIds: []string{cache.NatGatewayid}, + }) + dCancel() + if err != nil { + p.log.Warning("Error checking NAT Gateway state: %v", err) + continue + } + if len(out.NatGateways) > 0 && out.NatGateways[0].State == types.NatGatewayStateAvailable { + p.log.Info("NAT Gateway %s is available", cache.NatGatewayid) + cancelLoading(nil) + return nil + } + if len(out.NatGateways) > 0 && out.NatGateways[0].State == types.NatGatewayStateFailed { + cancelLoading(logger.ErrLoadingFailed) + return fmt.Errorf("NAT Gateway %s reached failed state", cache.NatGatewayid) + } + } + + cancelLoading(logger.ErrLoadingFailed) + return fmt.Errorf("NAT Gateway %s did not become available within timeout", cache.NatGatewayid) } // releaseEIP releases an Elastic IP by allocation ID. diff --git a/pkg/provider/aws/create_test.go b/pkg/provider/aws/create_test.go index 42d3c887..592e08a7 100644 --- a/pkg/provider/aws/create_test.go +++ b/pkg/provider/aws/create_test.go @@ -185,6 +185,18 @@ func (m *mockEC2Client) DescribeSecurityGroups(ctx context.Context, return &ec2.DescribeSecurityGroupsOutput{}, nil } +func (m *mockEC2Client) RevokeSecurityGroupIngress(ctx context.Context, + params *ec2.RevokeSecurityGroupIngressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { + return &ec2.RevokeSecurityGroupIngressOutput{}, nil +} + +func (m *mockEC2Client) RevokeSecurityGroupEgress(ctx context.Context, + params *ec2.RevokeSecurityGroupEgressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, error) { + return &ec2.RevokeSecurityGroupEgressOutput{}, nil +} + func (m *mockEC2Client) RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) { m.runInstancesCalls = append(m.runInstancesCalls, *params) @@ -310,7 +322,16 @@ func (m *mockEC2Client) DeleteNatGateway(ctx context.Context, params *ec2.Delete func (m *mockEC2Client) DescribeNatGateways(ctx context.Context, params *ec2.DescribeNatGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNatGatewaysOutput, error) { - return &ec2.DescribeNatGatewaysOutput{}, nil + // Return the NAT GW as available so the wait loop in createNATGateway exits immediately. + natGWID := "nat-123" + if len(params.NatGatewayIds) > 0 { + natGWID = params.NatGatewayIds[0] + } + return &ec2.DescribeNatGatewaysOutput{ + NatGateways: []types.NatGateway{ + {NatGatewayId: aws.String(natGWID), State: types.NatGatewayStateAvailable}, + }, + }, nil } func (m *mockEC2Client) ModifySubnetAttribute(ctx context.Context, params *ec2.ModifySubnetAttributeInput, diff --git a/pkg/provider/aws/delete.go b/pkg/provider/aws/delete.go index b47aede3..cf069083 100644 --- a/pkg/provider/aws/delete.go +++ b/pkg/provider/aws/delete.go @@ -32,8 +32,8 @@ import ( ) const ( - maxRetries = 5 - retryDelay = 5 * time.Second + maxRetries = 10 + retryDelay = 10 * time.Second maxRetryDelay = 30 * time.Second deletionTimeout = 15 * time.Minute verificationDelay = 2 * time.Second @@ -115,6 +115,14 @@ func (p *Provider) delete(cache *AWS) error { return fmt.Errorf("failed to delete EC2 instances: %w", err) } + // Phase 1.5: Wait for ENIs to detach after instance termination. + // AWS ENIs can linger for 2-5 minutes after termination, blocking SG deletion. + if cache.Vpcid != "" { + if err := p.waitForENIsDrained(cache.Vpcid); err != nil { + p.log.Warning("ENI drain wait failed (continuing): %v", err) + } + } + // Phase 2: Delete Security Groups if err := p.deleteSecurityGroups(cache); err != nil { return fmt.Errorf("failed to delete security groups: %w", err) @@ -253,6 +261,19 @@ func (p *Provider) deleteSecurityGroups(cache *AWS) error { p.log.Error(fmt.Errorf("failed to update progressing condition: %w", err)) } + // Break cross-SG ingress references before deletion. + // Worker SG has ingress rules referencing CP SG, and CP SG has ingress + // rules referencing Worker SG. Both must be cleared to avoid + // DependencyViolation errors on DeleteSecurityGroup. + if cache.WorkerSecurityGroupid != "" && cache.CPSecurityGroupid != "" { + if err := p.revokeSecurityGroupRules(cache.CPSecurityGroupid); err != nil { + p.log.Warning("Error revoking CP SG %s rules (continuing): %v", cache.CPSecurityGroupid, err) + } + if err := p.revokeSecurityGroupRules(cache.WorkerSecurityGroupid); err != nil { + p.log.Warning("Error revoking Worker SG %s rules (continuing): %v", cache.WorkerSecurityGroupid, err) + } + } + // Delete Worker SG first — it references CP SG, so CP SG can't be deleted // while Worker SG still exists. if err := p.deleteSecurityGroup(cache.WorkerSecurityGroupid, "worker"); err != nil { @@ -345,13 +366,15 @@ func (p *Provider) deleteVPCResources(cache *AWS) error { return err } - // Step 3: Delete public route table - if err := p.deletePublicRouteTable(cache); err != nil { + // Step 3: Delete public subnet (before route table — deleting the subnet + // implicitly removes the route table association, avoiding the need for + // ec2:DisassociateRouteTable which CI IAM may lack) + if err := p.deletePublicSubnet(cache); err != nil { return err } - // Step 4: Delete public subnet - if err := p.deletePublicSubnet(cache); err != nil { + // Step 4: Delete public route table (association removed by step 3) + if err := p.deletePublicRouteTable(cache); err != nil { return err } @@ -360,7 +383,7 @@ func (p *Provider) deleteVPCResources(cache *AWS) error { return err } - // Step 6: Delete private Route Table + // Step 6: Delete private Route Table (association removed by step 5) if err := p.deleteRouteTable(cache); err != nil { return err } @@ -691,6 +714,108 @@ func (p *Provider) deleteVPC(cache *AWS) error { // Helper functions +// waitForENIsDrained polls DescribeNetworkInterfaces until all non-available +// ENIs in the VPC are detached or deleted. This prevents DependencyViolation +// errors when deleting security groups, since AWS ENIs can linger for 2-5 +// minutes after instance termination. +func (p *Provider) waitForENIsDrained(vpcID string) error { + if vpcID == "" { + return nil + } + + const ( + eniPollInterval = 10 * time.Second + eniPollTimeout = 5 * time.Minute + ) + + deadline := time.Now().Add(eniPollTimeout) + + for { + ctx, cancel := context.WithTimeout(context.Background(), apiCallTimeout) + result, err := p.ec2.DescribeNetworkInterfaces(ctx, &ec2.DescribeNetworkInterfacesInput{ + Filters: []types.Filter{ + {Name: aws.String("vpc-id"), Values: []string{vpcID}}, + }, + }) + cancel() + + if err != nil { + p.log.Warning("Error checking ENIs in VPC %s: %v", vpcID, err) + } else { + // Count non-available ENIs (in-use ENIs block SG deletion) + var blocking int + for _, eni := range result.NetworkInterfaces { + if eni.Status != types.NetworkInterfaceStatusAvailable { + blocking++ + } + } + if blocking == 0 { + p.log.Info("All ENIs in VPC %s are drained", vpcID) + return nil + } + p.log.Info("Waiting for %d in-use ENI(s) in VPC %s to detach...", blocking, vpcID) + } + + if time.Now().After(deadline) { + return fmt.Errorf("timed out waiting for ENIs to drain in VPC %s", vpcID) + } + + time.Sleep(eniPollInterval) + } +} + +// revokeSecurityGroupRules removes all ingress and egress rules from a security +// group before deletion. This prevents DependencyViolation errors caused by +// cross-SG references (e.g., worker SG referencing control-plane SG). +func (p *Provider) revokeSecurityGroupRules(sgID string) error { + if sgID == "" { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), apiCallTimeout) + defer cancel() + + result, err := p.ec2.DescribeSecurityGroups(ctx, &ec2.DescribeSecurityGroupsInput{ + GroupIds: []string{sgID}, + }) + if err != nil { + if strings.Contains(err.Error(), "InvalidGroup.NotFound") { + return nil + } + return fmt.Errorf("error describing security group %s: %w", sgID, err) + } + + if len(result.SecurityGroups) == 0 { + return nil + } + + sg := result.SecurityGroups[0] + + // Revoke all ingress rules + if len(sg.IpPermissions) > 0 { + p.log.Info("Revoking %d ingress rule(s) from security group %s", len(sg.IpPermissions), sgID) + rCtx, rCancel := context.WithTimeout(context.Background(), apiCallTimeout) + defer rCancel() + _, err := p.ec2.RevokeSecurityGroupIngress(rCtx, &ec2.RevokeSecurityGroupIngressInput{ + GroupId: &sgID, + IpPermissions: sg.IpPermissions, + }) + if err != nil { + if strings.Contains(err.Error(), "InvalidGroup.NotFound") { + return nil + } + return fmt.Errorf("error revoking ingress rules for %s: %w", sgID, err) + } + } + + // Note: egress rule revocation is intentionally skipped. + // The CI IAM user (cnt-ci) lacks ec2:RevokeSecurityGroupEgress permission, + // and the default egress rule (0.0.0.0/0) does not create cross-SG + // dependencies that would block DeleteSecurityGroup. + + return nil +} + func (p *Provider) retryWithBackoff(operation func() error) error { delay := retryDelay for i := 0; i < maxRetries; i++ { diff --git a/pkg/provider/aws/delete_test.go b/pkg/provider/aws/delete_test.go index 6d7a0899..0b185e22 100644 --- a/pkg/provider/aws/delete_test.go +++ b/pkg/provider/aws/delete_test.go @@ -51,8 +51,8 @@ func TestDeleteConstants(t *testing.T) { constant time.Duration expected time.Duration }{ - {"maxRetries", time.Duration(maxRetries), 5}, - {"retryDelay", retryDelay, 5 * time.Second}, + {"maxRetries", time.Duration(maxRetries), 10}, + {"retryDelay", retryDelay, 10 * time.Second}, {"maxRetryDelay", maxRetryDelay, 30 * time.Second}, {"verificationDelay", verificationDelay, 2 * time.Second}, {"apiCallTimeout", apiCallTimeout, 30 * time.Second}, @@ -519,3 +519,223 @@ func TestDeleteSecurityGroups_SharedSameAsCP(t *testing.T) { t.Errorf("second delete should be CP SG (same as shared), got %s", deletedSGs[1]) } } + +func TestRevokeSecurityGroupRules_RevokesIngressOnly(t *testing.T) { + var ingressCalls []ec2.RevokeSecurityGroupIngressInput + var egressCalls []ec2.RevokeSecurityGroupEgressInput + + mock := NewMockEC2Client() + mock.DescribeSGsFunc = func(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return &ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []types.SecurityGroup{ + { + GroupId: aws.String("sg-worker"), + IpPermissions: []types.IpPermission{ + { + IpProtocol: aws.String("-1"), + UserIdGroupPairs: []types.UserIdGroupPair{ + {GroupId: aws.String("sg-cp")}, + }, + }, + }, + IpPermissionsEgress: []types.IpPermission{ + { + IpProtocol: aws.String("-1"), + IpRanges: []types.IpRange{ + {CidrIp: aws.String("0.0.0.0/0")}, + }, + }, + }, + }, + }, + }, nil + } + mock.RevokeSGIngressFunc = func(ctx context.Context, params *ec2.RevokeSecurityGroupIngressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { + ingressCalls = append(ingressCalls, *params) + return &ec2.RevokeSecurityGroupIngressOutput{}, nil + } + mock.RevokeSGEgressFunc = func(ctx context.Context, params *ec2.RevokeSecurityGroupEgressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, error) { + egressCalls = append(egressCalls, *params) + return &ec2.RevokeSecurityGroupEgressOutput{}, nil + } + + provider := &Provider{ec2: mock, log: mockLogger()} + + err := provider.revokeSecurityGroupRules("sg-worker") + if err != nil { + t.Fatalf("revokeSecurityGroupRules failed: %v", err) + } + + if len(ingressCalls) != 1 { + t.Fatalf("Expected 1 RevokeSecurityGroupIngress call, got %d", len(ingressCalls)) + } + if *ingressCalls[0].GroupId != "sg-worker" { + t.Errorf("Expected GroupId 'sg-worker', got %q", *ingressCalls[0].GroupId) + } + + // Egress revocation is intentionally skipped — CI IAM user lacks + // ec2:RevokeSecurityGroupEgress, and the default egress rule does not + // create cross-SG dependencies that block DeleteSecurityGroup. + if len(egressCalls) != 0 { + t.Errorf("Expected 0 RevokeSecurityGroupEgress calls (egress skipped), got %d", len(egressCalls)) + } +} + +func TestRevokeSecurityGroupRules_SkipsEmptyRules(t *testing.T) { + var ingressCalls []ec2.RevokeSecurityGroupIngressInput + var egressCalls []ec2.RevokeSecurityGroupEgressInput + + mock := NewMockEC2Client() + mock.DescribeSGsFunc = func(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return &ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []types.SecurityGroup{ + { + GroupId: aws.String("sg-empty"), + IpPermissions: nil, + IpPermissionsEgress: nil, + }, + }, + }, nil + } + mock.RevokeSGIngressFunc = func(ctx context.Context, params *ec2.RevokeSecurityGroupIngressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { + ingressCalls = append(ingressCalls, *params) + return &ec2.RevokeSecurityGroupIngressOutput{}, nil + } + mock.RevokeSGEgressFunc = func(ctx context.Context, params *ec2.RevokeSecurityGroupEgressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, error) { + egressCalls = append(egressCalls, *params) + return &ec2.RevokeSecurityGroupEgressOutput{}, nil + } + + provider := &Provider{ec2: mock, log: mockLogger()} + + err := provider.revokeSecurityGroupRules("sg-empty") + if err != nil { + t.Fatalf("revokeSecurityGroupRules failed: %v", err) + } + + if len(ingressCalls) != 0 { + t.Errorf("Expected no RevokeSecurityGroupIngress calls for empty rules, got %d", len(ingressCalls)) + } + if len(egressCalls) != 0 { + t.Errorf("Expected no RevokeSecurityGroupEgress calls for empty rules, got %d", len(egressCalls)) + } +} + +func TestRevokeSecurityGroupRules_SkipsEmptyID(t *testing.T) { + mock := NewMockEC2Client() + provider := &Provider{ec2: mock, log: mockLogger()} + + err := provider.revokeSecurityGroupRules("") + if err != nil { + t.Fatalf("revokeSecurityGroupRules should skip empty SG ID, got: %v", err) + } +} + +func TestRevokeSecurityGroupRules_DescribeError(t *testing.T) { + mock := NewMockEC2Client() + mock.DescribeSGsFunc = func(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return nil, fmt.Errorf("InvalidGroup.NotFound") + } + + provider := &Provider{ec2: mock, log: mockLogger()} + + // NotFound is not an error — SG is already gone + err := provider.revokeSecurityGroupRules("sg-gone") + if err != nil { + t.Fatalf("revokeSecurityGroupRules should handle NotFound gracefully, got: %v", err) + } +} + +func TestWaitForENIsDrained_NoENIs(t *testing.T) { + mock := NewMockEC2Client() + mock.DescribeNIsFunc = func(ctx context.Context, params *ec2.DescribeNetworkInterfacesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) { + return &ec2.DescribeNetworkInterfacesOutput{ + NetworkInterfaces: []types.NetworkInterface{}, + }, nil + } + + provider := &Provider{ec2: mock, log: mockLogger()} + + err := provider.waitForENIsDrained("vpc-123") + if err != nil { + t.Fatalf("waitForENIsDrained should succeed with no ENIs, got: %v", err) + } +} + +func TestWaitForENIsDrained_SkipsEmptyVPC(t *testing.T) { + mock := NewMockEC2Client() + provider := &Provider{ec2: mock, log: mockLogger()} + + err := provider.waitForENIsDrained("") + if err != nil { + t.Fatalf("waitForENIsDrained should skip empty VPC ID, got: %v", err) + } +} + +func TestWaitForENIsDrained_ENIsDrainOnSecondPoll(t *testing.T) { + if testing.Short() { + t.Skip("skipping: poll loop sleeps 10s between calls") + } + callCount := 0 + mock := NewMockEC2Client() + mock.DescribeNIsFunc = func(ctx context.Context, params *ec2.DescribeNetworkInterfacesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) { + callCount++ + if callCount == 1 { + // First call: ENI still in-use + return &ec2.DescribeNetworkInterfacesOutput{ + NetworkInterfaces: []types.NetworkInterface{ + { + NetworkInterfaceId: aws.String("eni-123"), + Status: types.NetworkInterfaceStatusInUse, + }, + }, + }, nil + } + // Second call: all drained + return &ec2.DescribeNetworkInterfacesOutput{ + NetworkInterfaces: []types.NetworkInterface{}, + }, nil + } + + provider := &Provider{ec2: mock, log: mockLogger()} + + err := provider.waitForENIsDrained("vpc-123") + if err != nil { + t.Fatalf("waitForENIsDrained should succeed after ENIs drain, got: %v", err) + } + if callCount < 2 { + t.Errorf("Expected at least 2 DescribeNetworkInterfaces calls, got %d", callCount) + } +} + +func TestWaitForENIsDrained_AvailableENIsIgnored(t *testing.T) { + mock := NewMockEC2Client() + mock.DescribeNIsFunc = func(ctx context.Context, params *ec2.DescribeNetworkInterfacesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) { + // ENI exists but is in "available" state (detached) — not blocking + return &ec2.DescribeNetworkInterfacesOutput{ + NetworkInterfaces: []types.NetworkInterface{ + { + NetworkInterfaceId: aws.String("eni-avail"), + Status: types.NetworkInterfaceStatusAvailable, + }, + }, + }, nil + } + + provider := &Provider{ec2: mock, log: mockLogger()} + + err := provider.waitForENIsDrained("vpc-123") + if err != nil { + t.Fatalf("waitForENIsDrained should ignore 'available' ENIs, got: %v", err) + } +} diff --git a/pkg/provider/aws/mock_ec2_test.go b/pkg/provider/aws/mock_ec2_test.go index 250e8278..a68855fc 100644 --- a/pkg/provider/aws/mock_ec2_test.go +++ b/pkg/provider/aws/mock_ec2_test.go @@ -56,10 +56,12 @@ type MockEC2Client struct { ReplaceRTAssocFunc func(ctx context.Context, params *ec2.ReplaceRouteTableAssociationInput, optFns ...func(*ec2.Options)) (*ec2.ReplaceRouteTableAssociationOutput, error) // Security Group - CreateSGFunc func(ctx context.Context, params *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) - AuthorizeSGFunc func(ctx context.Context, params *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) - DeleteSGFunc func(ctx context.Context, params *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) - DescribeSGsFunc func(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) + CreateSGFunc func(ctx context.Context, params *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) + AuthorizeSGFunc func(ctx context.Context, params *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) + DeleteSGFunc func(ctx context.Context, params *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) + DescribeSGsFunc func(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) + RevokeSGIngressFunc func(ctx context.Context, params *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) + RevokeSGEgressFunc func(ctx context.Context, params *ec2.RevokeSecurityGroupEgressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, error) // Instance RunInstancesFunc func(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) @@ -262,6 +264,24 @@ func (m *MockEC2Client) DescribeSecurityGroups(ctx context.Context, params *ec2. return &ec2.DescribeSecurityGroupsOutput{}, nil } +func (m *MockEC2Client) RevokeSecurityGroupIngress(ctx context.Context, + params *ec2.RevokeSecurityGroupIngressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { + if m.RevokeSGIngressFunc != nil { + return m.RevokeSGIngressFunc(ctx, params, optFns...) + } + return &ec2.RevokeSecurityGroupIngressOutput{}, nil +} + +func (m *MockEC2Client) RevokeSecurityGroupEgress(ctx context.Context, + params *ec2.RevokeSecurityGroupEgressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, error) { + if m.RevokeSGEgressFunc != nil { + return m.RevokeSGEgressFunc(ctx, params, optFns...) + } + return &ec2.RevokeSecurityGroupEgressOutput{}, nil +} + // Instance operations func (m *MockEC2Client) RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) { if m.RunInstancesFunc != nil { diff --git a/pkg/testutil/mocks/aws.go b/pkg/testutil/mocks/aws.go index 6994506d..846ddd19 100644 --- a/pkg/testutil/mocks/aws.go +++ b/pkg/testutil/mocks/aws.go @@ -83,6 +83,13 @@ type EC2Client interface { params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) + RevokeSecurityGroupIngress(ctx context.Context, + params *ec2.RevokeSecurityGroupIngressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) + RevokeSecurityGroupEgress(ctx context.Context, + params *ec2.RevokeSecurityGroupEgressInput, + optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, error) + RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) TerminateInstances(ctx context.Context, @@ -476,3 +483,13 @@ func (m *MockEC2Client) DescribeNatGateways(ctx context.Context, params *ec2.Des func (m *MockEC2Client) ModifySubnetAttribute(ctx context.Context, params *ec2.ModifySubnetAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifySubnetAttributeOutput, error) { return &ec2.ModifySubnetAttributeOutput{}, nil } + +// Security Group Revoke operations + +func (m *MockEC2Client) RevokeSecurityGroupIngress(ctx context.Context, params *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { + return &ec2.RevokeSecurityGroupIngressOutput{}, nil +} + +func (m *MockEC2Client) RevokeSecurityGroupEgress(ctx context.Context, params *ec2.RevokeSecurityGroupEgressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, error) { + return &ec2.RevokeSecurityGroupEgressOutput{}, nil +}