diff --git a/.github/workflows/periodic.yaml b/.github/workflows/periodic.yaml index b00231b3..42f064a0 100644 --- a/.github/workflows/periodic.yaml +++ b/.github/workflows/periodic.yaml @@ -39,7 +39,7 @@ jobs: - name: Clean up VPCs if: steps.identify-resources.outputs.AWS_VPC_IDS != '' - uses: NVIDIA/holodeck@v0.3.3 + uses: NVIDIA/holodeck@v0.3.4 with: action: cleanup vpc_ids: ${{ steps.identify-resources.outputs.AWS_VPC_IDS }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 71de0a66..e877cef1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ All notable changes to this project will be documented in this file. +## [v0.3.4] - 2026-04-01 + +### Bug Fixes + +- **fix: handle InvalidInternetGatewayID.NotFound in IGW detach** — When an Internet Gateway is already deleted, the detach step now recognizes `InvalidInternetGatewayID.NotFound` alongside `Gateway.NotAttached` and skips retries, fixing cleanup hangs where the IGW was deleted out-of-band. +- **fix: handle NotFound errors in NLB/listener/target-group deletion** — All NLB cleanup paths now check for `LoadBalancerNotFound`, `ListenerNotFound`, and `TargetGroupNotFound` before retrying, treating already-deleted resources as success. +- **fix: add SSH keepalive and handshake timeout** — SSH connections now send keepalive probes every 30 seconds to prevent session drops during long operations (e.g., `kubeadm init`). A 15-second handshake timeout prevents `connectOrDie` from blocking indefinitely against hosts that accept TCP but never complete the SSH handshake. +- **fix: suppress NotFound warnings in cleanup deleteInternetGateways** — The periodic cleanup job no longer logs misleading "Failed to detach/delete internet gateway" warnings when an IGW is already gone. + ## [v0.3.3] - 2026-04-01 ### Bug Fixes diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 681c47e5..b3905bef 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -40,7 +40,7 @@ const ( // ProgramName is the canonical name of this program ProgramName = "holodeck" // ProgramVersion is the current version of the program - ProgramVersion = "0.3.3" + ProgramVersion = "0.3.4" ) type config struct { diff --git a/cmd/cli/main_test.go b/cmd/cli/main_test.go index 20fc89f1..f9ff1171 100644 --- a/cmd/cli/main_test.go +++ b/cmd/cli/main_test.go @@ -26,8 +26,8 @@ func TestNewApp(t *testing.T) { log := logger.NewLogger() app := NewApp(log) - if app.Version != "0.3.3" { - t.Errorf("expected app version %q, got %q", "0.3.3", app.Version) + if app.Version != "0.3.4" { + t.Errorf("expected app version %q, got %q", "0.3.4", app.Version) } if app.Name != "holodeck" { t.Errorf("expected app name %q, got %q", "holodeck", app.Name) diff --git a/pkg/cleanup/cleanup.go b/pkg/cleanup/cleanup.go index c3b8a0b5..d9d46333 100644 --- a/pkg/cleanup/cleanup.go +++ b/pkg/cleanup/cleanup.go @@ -480,7 +480,9 @@ func (c *Cleaner) deleteSecurityGroups(ctx context.Context, vpcID string) error _, err = c.ec2.DeleteSecurityGroup(ctx, deleteInput) if err != nil { - c.log.Warning("Failed to delete security group %s: %v", safeString(sg.GroupId), err) + if !strings.Contains(err.Error(), "InvalidGroup.NotFound") { + c.log.Warning("Failed to delete security group %s: %v", safeString(sg.GroupId), err) + } } } @@ -512,7 +514,9 @@ func (c *Cleaner) deleteSubnets(ctx context.Context, vpcID string) error { _, err = c.ec2.DeleteSubnet(ctx, deleteInput) if err != nil { - c.log.Warning("Failed to delete subnet %s: %v", safeString(subnet.SubnetId), err) + if !strings.Contains(err.Error(), "InvalidSubnetID.NotFound") { + c.log.Warning("Failed to delete subnet %s: %v", safeString(subnet.SubnetId), err) + } } } @@ -583,7 +587,9 @@ func (c *Cleaner) deleteRouteTables(ctx context.Context, vpcID string) error { _, err = c.ec2.DeleteRouteTable(ctx, deleteInput) if err != nil { - c.log.Warning("Failed to delete route table %s: %v", safeString(rt.RouteTableId), err) + if !strings.Contains(err.Error(), "InvalidRouteTableID.NotFound") { + c.log.Warning("Failed to delete route table %s: %v", safeString(rt.RouteTableId), err) + } } } @@ -617,7 +623,11 @@ func (c *Cleaner) deleteInternetGateways(ctx context.Context, vpcID string) erro _, err = c.ec2.DetachInternetGateway(ctx, detachInput) if err != nil { - c.log.Warning("Failed to detach internet gateway %s: %v", safeString(igw.InternetGatewayId), err) + errMsg := err.Error() + if !strings.Contains(errMsg, "Gateway.NotAttached") && + !strings.Contains(errMsg, "InvalidInternetGatewayID.NotFound") { + c.log.Warning("Failed to detach internet gateway %s: %v", safeString(igw.InternetGatewayId), err) + } } // Delete internet gateway @@ -627,7 +637,9 @@ func (c *Cleaner) deleteInternetGateways(ctx context.Context, vpcID string) erro _, err = c.ec2.DeleteInternetGateway(ctx, deleteInput) if err != nil { - c.log.Warning("Failed to delete internet gateway %s: %v", safeString(igw.InternetGatewayId), err) + if !strings.Contains(err.Error(), "InvalidInternetGatewayID.NotFound") { + c.log.Warning("Failed to delete internet gateway %s: %v", safeString(igw.InternetGatewayId), err) + } } } diff --git a/pkg/cleanup/cleanup_ginkgo_test.go b/pkg/cleanup/cleanup_ginkgo_test.go index 5bd53a89..b1e83520 100644 --- a/pkg/cleanup/cleanup_ginkgo_test.go +++ b/pkg/cleanup/cleanup_ginkgo_test.go @@ -783,6 +783,266 @@ var _ = Describe("Cleanup Package", func() { }) }) + Describe("deleteInternetGateways NotFound handling", func() { + BeforeEach(func() { + mockEC.DescribeInstancesFunc = func(ctx context.Context, + params *ec2.DescribeInstancesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{}, nil + } + mockEC.DescribeSecurityGroupsFunc = func(ctx context.Context, + params *ec2.DescribeSecurityGroupsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return &ec2.DescribeSecurityGroupsOutput{}, nil + } + mockEC.DescribeSubnetsFunc = func(ctx context.Context, + params *ec2.DescribeSubnetsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) { + return &ec2.DescribeSubnetsOutput{}, nil + } + mockEC.DescribeRouteTablesFunc = func(ctx context.Context, + params *ec2.DescribeRouteTablesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) { + return &ec2.DescribeRouteTablesOutput{}, nil + } + mockEC.DeleteVpcFunc = func(ctx context.Context, + params *ec2.DeleteVpcInput, + optFns ...func(*ec2.Options)) (*ec2.DeleteVpcOutput, error) { + return &ec2.DeleteVpcOutput{}, nil + } + }) + + It("should complete successfully when IGW detach/delete return NotFound", func() { + detachCalls := 0 + deleteCalls := 0 + + mockEC.DescribeInternetGatewaysFunc = func(ctx context.Context, + params *ec2.DescribeInternetGatewaysInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) { + return &ec2.DescribeInternetGatewaysOutput{ + InternetGateways: []types.InternetGateway{ + {InternetGatewayId: aws.String("igw-gone")}, + }, + }, nil + } + mockEC.DetachInternetGatewayFunc = func(ctx context.Context, + params *ec2.DetachInternetGatewayInput, + optFns ...func(*ec2.Options)) (*ec2.DetachInternetGatewayOutput, error) { + detachCalls++ + return nil, fmt.Errorf("InvalidInternetGatewayID.NotFound: igw-gone does not exist") + } + mockEC.DeleteInternetGatewayFunc = func(ctx context.Context, + params *ec2.DeleteInternetGatewayInput, + optFns ...func(*ec2.Options)) (*ec2.DeleteInternetGatewayOutput, error) { + deleteCalls++ + return nil, fmt.Errorf("InvalidInternetGatewayID.NotFound: igw-gone does not exist") + } + + cleaner, err := New(log, "us-west-2", WithEC2Client(mockEC)) + Expect(err).NotTo(HaveOccurred()) + + err = cleaner.DeleteVPCResources(context.Background(), "vpc-12345") + Expect(err).NotTo(HaveOccurred()) + // NotFound errors are silently ignored — detach and delete still called + Expect(detachCalls).To(Equal(1)) + Expect(deleteCalls).To(Equal(1)) + }) + }) + + Describe("deleteSecurityGroups NotFound handling", func() { + BeforeEach(func() { + mockEC.DescribeInstancesFunc = func(ctx context.Context, + params *ec2.DescribeInstancesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{}, nil + } + mockEC.DescribeSubnetsFunc = func(ctx context.Context, + params *ec2.DescribeSubnetsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) { + return &ec2.DescribeSubnetsOutput{}, nil + } + mockEC.DescribeRouteTablesFunc = func(ctx context.Context, + params *ec2.DescribeRouteTablesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) { + return &ec2.DescribeRouteTablesOutput{}, nil + } + mockEC.DescribeInternetGatewaysFunc = func(ctx context.Context, + params *ec2.DescribeInternetGatewaysInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) { + return &ec2.DescribeInternetGatewaysOutput{}, nil + } + mockEC.DeleteVpcFunc = func(ctx context.Context, + params *ec2.DeleteVpcInput, + optFns ...func(*ec2.Options)) (*ec2.DeleteVpcOutput, error) { + return &ec2.DeleteVpcOutput{}, nil + } + }) + + It("should complete successfully when SG delete returns InvalidGroup.NotFound", func() { + deleteCalls := 0 + + mockEC.DescribeSecurityGroupsFunc = func(ctx context.Context, + params *ec2.DescribeSecurityGroupsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return &ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []types.SecurityGroup{ + {GroupId: aws.String("sg-default"), GroupName: aws.String("default")}, + {GroupId: aws.String("sg-gone"), GroupName: aws.String("holodeck-sg")}, + }, + }, nil + } + mockEC.DescribeNetworkInterfacesFunc = func(ctx context.Context, + params *ec2.DescribeNetworkInterfacesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) { + return &ec2.DescribeNetworkInterfacesOutput{}, nil + } + mockEC.DeleteSecurityGroupFunc = func(ctx context.Context, + params *ec2.DeleteSecurityGroupInput, + optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) { + deleteCalls++ + return nil, fmt.Errorf("InvalidGroup.NotFound: The security group '%s' does not exist", *params.GroupId) + } + + cleaner, err := New(log, "us-west-2", WithEC2Client(mockEC)) + Expect(err).NotTo(HaveOccurred()) + + err = cleaner.DeleteVPCResources(context.Background(), "vpc-12345") + Expect(err).NotTo(HaveOccurred()) + Expect(deleteCalls).To(Equal(1)) + }) + }) + + Describe("deleteSubnets NotFound handling", func() { + BeforeEach(func() { + mockEC.DescribeInstancesFunc = func(ctx context.Context, + params *ec2.DescribeInstancesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{}, nil + } + mockEC.DescribeSecurityGroupsFunc = func(ctx context.Context, + params *ec2.DescribeSecurityGroupsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return &ec2.DescribeSecurityGroupsOutput{}, nil + } + mockEC.DescribeRouteTablesFunc = func(ctx context.Context, + params *ec2.DescribeRouteTablesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) { + return &ec2.DescribeRouteTablesOutput{}, nil + } + mockEC.DescribeInternetGatewaysFunc = func(ctx context.Context, + params *ec2.DescribeInternetGatewaysInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) { + return &ec2.DescribeInternetGatewaysOutput{}, nil + } + mockEC.DeleteVpcFunc = func(ctx context.Context, + params *ec2.DeleteVpcInput, + optFns ...func(*ec2.Options)) (*ec2.DeleteVpcOutput, error) { + return &ec2.DeleteVpcOutput{}, nil + } + }) + + It("should complete successfully when subnet delete returns InvalidSubnetID.NotFound", func() { + deleteCalls := 0 + + mockEC.DescribeSubnetsFunc = func(ctx context.Context, + params *ec2.DescribeSubnetsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) { + return &ec2.DescribeSubnetsOutput{ + Subnets: []types.Subnet{ + {SubnetId: aws.String("subnet-gone")}, + }, + }, nil + } + mockEC.DeleteSubnetFunc = func(ctx context.Context, + params *ec2.DeleteSubnetInput, + optFns ...func(*ec2.Options)) (*ec2.DeleteSubnetOutput, error) { + deleteCalls++ + return nil, fmt.Errorf("InvalidSubnetID.NotFound: The subnet ID '%s' does not exist", *params.SubnetId) + } + + cleaner, err := New(log, "us-west-2", WithEC2Client(mockEC)) + Expect(err).NotTo(HaveOccurred()) + + err = cleaner.DeleteVPCResources(context.Background(), "vpc-12345") + Expect(err).NotTo(HaveOccurred()) + Expect(deleteCalls).To(Equal(1)) + }) + }) + + Describe("deleteRouteTables NotFound handling", func() { + BeforeEach(func() { + mockEC.DescribeInstancesFunc = func(ctx context.Context, + params *ec2.DescribeInstancesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{}, nil + } + mockEC.DescribeSecurityGroupsFunc = func(ctx context.Context, + params *ec2.DescribeSecurityGroupsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return &ec2.DescribeSecurityGroupsOutput{}, nil + } + mockEC.DescribeSubnetsFunc = func(ctx context.Context, + params *ec2.DescribeSubnetsInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) { + return &ec2.DescribeSubnetsOutput{}, nil + } + mockEC.DescribeInternetGatewaysFunc = func(ctx context.Context, + params *ec2.DescribeInternetGatewaysInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) { + return &ec2.DescribeInternetGatewaysOutput{}, nil + } + mockEC.DeleteVpcFunc = func(ctx context.Context, + params *ec2.DeleteVpcInput, + optFns ...func(*ec2.Options)) (*ec2.DeleteVpcOutput, error) { + return &ec2.DeleteVpcOutput{}, nil + } + }) + + It("should complete successfully when route table delete returns InvalidRouteTableID.NotFound", func() { + deleteCalls := 0 + mainRT := true + + mockEC.DescribeRouteTablesFunc = func(ctx context.Context, + params *ec2.DescribeRouteTablesInput, + optFns ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) { + return &ec2.DescribeRouteTablesOutput{ + RouteTables: []types.RouteTable{ + { + RouteTableId: aws.String("rtb-main"), + Associations: []types.RouteTableAssociation{ + {RouteTableAssociationId: aws.String("rtbassoc-main"), Main: &mainRT}, + }, + }, + { + RouteTableId: aws.String("rtb-gone"), + Associations: []types.RouteTableAssociation{ + {RouteTableAssociationId: aws.String("rtbassoc-gone")}, + }, + }, + }, + }, nil + } + mockEC.ReplaceRouteTableAssociationFunc = func(ctx context.Context, + params *ec2.ReplaceRouteTableAssociationInput, + optFns ...func(*ec2.Options)) (*ec2.ReplaceRouteTableAssociationOutput, error) { + return &ec2.ReplaceRouteTableAssociationOutput{}, nil + } + mockEC.DeleteRouteTableFunc = func(ctx context.Context, + params *ec2.DeleteRouteTableInput, + optFns ...func(*ec2.Options)) (*ec2.DeleteRouteTableOutput, error) { + deleteCalls++ + return nil, fmt.Errorf("InvalidRouteTableID.NotFound: The routeTable ID '%s' does not exist", *params.RouteTableId) + } + + cleaner, err := New(log, "us-west-2", WithEC2Client(mockEC)) + Expect(err).NotTo(HaveOccurred()) + + err = cleaner.DeleteVPCResources(context.Background(), "vpc-12345") + Expect(err).NotTo(HaveOccurred()) + Expect(deleteCalls).To(Equal(1)) + }) + }) + Describe("deleteRouteTables", func() { BeforeEach(func() { mockEC.DescribeInstancesFunc = func(ctx context.Context, diff --git a/pkg/provider/aws/delete.go b/pkg/provider/aws/delete.go index d5158b7c..55b725d9 100644 --- a/pkg/provider/aws/delete.go +++ b/pkg/provider/aws/delete.go @@ -70,6 +70,10 @@ func (p *Provider) deleteNLBForCluster(cache *ClusterCache) error { } describeOutput, err := p.elbv2.DescribeLoadBalancers(ctx, describeInput) if err != nil { + if isNLBNotFoundError(err.Error()) { + p.log.Info("No load balancers found for %s, nothing to delete", lbName) + return nil + } return fmt.Errorf("error describing load balancers: %w", err) } @@ -625,8 +629,8 @@ func (p *Provider) deleteInternetGateway(cache *AWS) error { VpcId: &cache.Vpcid, }) if err != nil { - if strings.Contains(err.Error(), "Gateway.NotAttached") { - p.log.Info("Internet Gateway %s already detached", cache.InternetGwid) + if isAlreadyDetachedError(err.Error()) { + p.log.Info("Internet Gateway %s already detached or does not exist", cache.InternetGwid) return nil } return err @@ -712,6 +716,13 @@ func (p *Provider) deleteVPC(cache *AWS) error { return nil } +// isAlreadyDetachedError returns true if the error indicates the IGW +// is already detached or doesn't exist (both mean "nothing to detach"). +func isAlreadyDetachedError(errMsg string) bool { + return strings.Contains(errMsg, "Gateway.NotAttached") || + strings.Contains(errMsg, "InvalidInternetGatewayID.NotFound") +} + // Helper functions // waitForENIsDrained polls DescribeNetworkInterfaces until all non-available diff --git a/pkg/provider/aws/delete_igw_test.go b/pkg/provider/aws/delete_igw_test.go new file mode 100644 index 00000000..96f2db23 --- /dev/null +++ b/pkg/provider/aws/delete_igw_test.go @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/ec2" +) + +func TestDeleteInternetGateway_DetachNotFound(t *testing.T) { + // When DetachInternetGateway returns InvalidInternetGatewayID.NotFound, + // deleteInternetGateway should treat it as success (IGW already gone) + // and proceed to the delete step without error. + detachCalls := 0 + deleteCalls := 0 + + mock := &MockEC2Client{ + DetachIGWFunc: func(ctx context.Context, params *ec2.DetachInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DetachInternetGatewayOutput, error) { + detachCalls++ + return nil, fmt.Errorf("InvalidInternetGatewayID.NotFound: The internetGateway ID 'igw-gone' does not exist") + }, + DeleteIGWFunc: func(ctx context.Context, params *ec2.DeleteInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DeleteInternetGatewayOutput, error) { + deleteCalls++ + return nil, fmt.Errorf("InvalidInternetGatewayID.NotFound: The internetGateway ID 'igw-gone' does not exist") + }, + } + + provider := &Provider{ec2: mock, log: mockLogger(), sleep: noopSleep} + cache := &AWS{InternetGwid: "igw-gone", Vpcid: "vpc-123"} + + err := provider.deleteInternetGateway(cache) + if err != nil { + t.Fatalf("expected no error when IGW is already gone, got: %v", err) + } + + // Detach should be called exactly once (NotFound stops retries) + if detachCalls != 1 { + t.Errorf("expected 1 detach call (NotFound stops retries), got %d", detachCalls) + } + + // Delete should also be called exactly once + if deleteCalls != 1 { + t.Errorf("expected 1 delete call (NotFound stops retries), got %d", deleteCalls) + } +} + +func TestDeleteInternetGateway_DetachNotAttached(t *testing.T) { + // Original behavior: Gateway.NotAttached during detach is still treated as success. + detachCalls := 0 + + mock := &MockEC2Client{ + DetachIGWFunc: func(ctx context.Context, params *ec2.DetachInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DetachInternetGatewayOutput, error) { + detachCalls++ + return nil, fmt.Errorf("Gateway.NotAttached: The gateway igw-123 is not attached") + }, + DeleteIGWFunc: func(ctx context.Context, params *ec2.DeleteInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DeleteInternetGatewayOutput, error) { + return &ec2.DeleteInternetGatewayOutput{}, nil + }, + } + + provider := &Provider{ec2: mock, log: mockLogger(), sleep: noopSleep} + cache := &AWS{InternetGwid: "igw-123", Vpcid: "vpc-456"} + + err := provider.deleteInternetGateway(cache) + if err != nil { + t.Fatalf("expected no error for Gateway.NotAttached, got: %v", err) + } + + if detachCalls != 1 { + t.Errorf("expected 1 detach call, got %d", detachCalls) + } +} + +func TestDeleteInternetGateway_DetachRealErrorRetries(t *testing.T) { + // A non-NotFound error during detach should be retried (and eventually fail). + detachCalls := 0 + + mock := &MockEC2Client{ + DetachIGWFunc: func(ctx context.Context, params *ec2.DetachInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DetachInternetGatewayOutput, error) { + detachCalls++ + return nil, fmt.Errorf("DependencyViolation: gateway has active connections") + }, + } + + provider := &Provider{ec2: mock, log: mockLogger(), sleep: noopSleep} + cache := &AWS{InternetGwid: "igw-busy", Vpcid: "vpc-789"} + + err := provider.deleteInternetGateway(cache) + if err == nil { + t.Fatal("expected error for DependencyViolation, got nil") + } + + // Should have retried maxRetries times + if detachCalls != maxRetries { + t.Errorf("expected %d detach calls (full retry), got %d", maxRetries, detachCalls) + } +} diff --git a/pkg/provider/aws/nlb.go b/pkg/provider/aws/nlb.go index 6bd4e85d..47bb729a 100644 --- a/pkg/provider/aws/nlb.go +++ b/pkg/provider/aws/nlb.go @@ -19,6 +19,7 @@ package aws import ( "context" "fmt" + "strings" "time" "github.com/NVIDIA/holodeck/internal/logger" @@ -28,6 +29,21 @@ import ( elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" ) +// isNLBNotFoundError returns true if the error indicates the load balancer doesn't exist. +func isNLBNotFoundError(errMsg string) bool { + return strings.Contains(errMsg, "LoadBalancerNotFound") +} + +// isTargetGroupNotFoundError returns true if the error indicates the target group doesn't exist. +func isTargetGroupNotFoundError(errMsg string) bool { + return strings.Contains(errMsg, "TargetGroupNotFound") +} + +// isListenerNotFoundError returns true if the error indicates the listener doesn't exist. +func isListenerNotFoundError(errMsg string) bool { + return strings.Contains(errMsg, "ListenerNotFound") +} + const ( // Port for Kubernetes API server k8sAPIPort = 6443 @@ -271,6 +287,11 @@ func (p *Provider) deleteNLB(cache *ClusterCache) error { _, err := p.elbv2.DeleteLoadBalancer(ctx, deleteLBInput) if err != nil { + if isNLBNotFoundError(err.Error()) { + p.log.Info("Load balancer %s already deleted", cache.LoadBalancerArn) + cancelLoading(nil) + return nil + } cancelLoading(logger.ErrLoadingFailed) return fmt.Errorf("error deleting load balancer: %w", err) } @@ -296,6 +317,9 @@ func (p *Provider) deleteListener(cache *ClusterCache) error { describeOutput, err := p.elbv2.DescribeListeners(ctx, describeInput) if err != nil { + if isListenerNotFoundError(err.Error()) || isNLBNotFoundError(err.Error()) { + return nil + } return fmt.Errorf("error describing listeners: %w", err) } @@ -310,6 +334,10 @@ func (p *Provider) deleteListener(cache *ClusterCache) error { cancelDel() if err != nil { + if isListenerNotFoundError(err.Error()) { + p.log.Info("Listener %s already deleted", aws.ToString(listener.ListenerArn)) + continue + } return fmt.Errorf("error deleting listener %s: %w", aws.ToString(listener.ListenerArn), err) } } @@ -332,6 +360,9 @@ func (p *Provider) deleteTargetGroup(cache *ClusterCache) error { defer cancelTargets() targetsOutput, err := p.elbv2.DescribeTargetHealth(ctxTargets, describeTargetsInput) + if err != nil && isTargetGroupNotFoundError(err.Error()) { + return nil + } if err == nil && len(targetsOutput.TargetHealthDescriptions) > 0 { targets := make([]elbv2types.TargetDescription, 0, len(targetsOutput.TargetHealthDescriptions)) for _, th := range targetsOutput.TargetHealthDescriptions { @@ -362,6 +393,10 @@ func (p *Provider) deleteTargetGroup(cache *ClusterCache) error { _, err = p.elbv2.DeleteTargetGroup(ctx, deleteTGInput) if err != nil { + if isTargetGroupNotFoundError(err.Error()) { + p.log.Info("Target group %s already deleted", cache.TargetGroupArn) + return nil + } return fmt.Errorf("error deleting target group: %w", err) } diff --git a/pkg/provider/aws/nlb_delete_test.go b/pkg/provider/aws/nlb_delete_test.go new file mode 100644 index 00000000..b3d3f944 --- /dev/null +++ b/pkg/provider/aws/nlb_delete_test.go @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" +) + +// MockELBv2Client implements internalaws.ELBv2Client for testing. +type MockELBv2Client struct { + CreateLBFunc func(ctx context.Context, params *elasticloadbalancingv2.CreateLoadBalancerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.CreateLoadBalancerOutput, error) + DescribeLBsFunc func(ctx context.Context, params *elasticloadbalancingv2.DescribeLoadBalancersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) + DeleteLBFunc func(ctx context.Context, params *elasticloadbalancingv2.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) + CreateTGFunc func(ctx context.Context, params *elasticloadbalancingv2.CreateTargetGroupInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.CreateTargetGroupOutput, error) + DescribeTGsFunc func(ctx context.Context, params *elasticloadbalancingv2.DescribeTargetGroupsInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeTargetGroupsOutput, error) + DescribeTHFunc func(ctx context.Context, params *elasticloadbalancingv2.DescribeTargetHealthInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeTargetHealthOutput, error) + DeleteTGFunc func(ctx context.Context, params *elasticloadbalancingv2.DeleteTargetGroupInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteTargetGroupOutput, error) + RegisterFunc func(ctx context.Context, params *elasticloadbalancingv2.RegisterTargetsInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.RegisterTargetsOutput, error) + DeregisterFunc func(ctx context.Context, params *elasticloadbalancingv2.DeregisterTargetsInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeregisterTargetsOutput, error) + CreateListenerFunc func(ctx context.Context, params *elasticloadbalancingv2.CreateListenerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.CreateListenerOutput, error) + DescribeListenerFunc func(ctx context.Context, params *elasticloadbalancingv2.DescribeListenersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeListenersOutput, error) + DeleteListenerFunc func(ctx context.Context, params *elasticloadbalancingv2.DeleteListenerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteListenerOutput, error) + AddTagsFunc func(ctx context.Context, params *elasticloadbalancingv2.AddTagsInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.AddTagsOutput, error) +} + +func (m *MockELBv2Client) CreateLoadBalancer(ctx context.Context, params *elasticloadbalancingv2.CreateLoadBalancerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.CreateLoadBalancerOutput, error) { + if m.CreateLBFunc != nil { + return m.CreateLBFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.CreateLoadBalancerOutput{}, nil +} + +func (m *MockELBv2Client) DescribeLoadBalancers(ctx context.Context, params *elasticloadbalancingv2.DescribeLoadBalancersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) { + if m.DescribeLBsFunc != nil { + return m.DescribeLBsFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.DescribeLoadBalancersOutput{}, nil +} + +func (m *MockELBv2Client) DeleteLoadBalancer(ctx context.Context, params *elasticloadbalancingv2.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) { + if m.DeleteLBFunc != nil { + return m.DeleteLBFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.DeleteLoadBalancerOutput{}, nil +} + +func (m *MockELBv2Client) CreateTargetGroup(ctx context.Context, params *elasticloadbalancingv2.CreateTargetGroupInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.CreateTargetGroupOutput, error) { + if m.CreateTGFunc != nil { + return m.CreateTGFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.CreateTargetGroupOutput{}, nil +} + +func (m *MockELBv2Client) DescribeTargetGroups(ctx context.Context, params *elasticloadbalancingv2.DescribeTargetGroupsInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeTargetGroupsOutput, error) { + if m.DescribeTGsFunc != nil { + return m.DescribeTGsFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.DescribeTargetGroupsOutput{}, nil +} + +func (m *MockELBv2Client) DescribeTargetHealth(ctx context.Context, params *elasticloadbalancingv2.DescribeTargetHealthInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeTargetHealthOutput, error) { + if m.DescribeTHFunc != nil { + return m.DescribeTHFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.DescribeTargetHealthOutput{}, nil +} + +func (m *MockELBv2Client) DeleteTargetGroup(ctx context.Context, params *elasticloadbalancingv2.DeleteTargetGroupInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteTargetGroupOutput, error) { + if m.DeleteTGFunc != nil { + return m.DeleteTGFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.DeleteTargetGroupOutput{}, nil +} + +func (m *MockELBv2Client) RegisterTargets(ctx context.Context, params *elasticloadbalancingv2.RegisterTargetsInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.RegisterTargetsOutput, error) { + if m.RegisterFunc != nil { + return m.RegisterFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.RegisterTargetsOutput{}, nil +} + +func (m *MockELBv2Client) DeregisterTargets(ctx context.Context, params *elasticloadbalancingv2.DeregisterTargetsInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeregisterTargetsOutput, error) { + if m.DeregisterFunc != nil { + return m.DeregisterFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.DeregisterTargetsOutput{}, nil +} + +func (m *MockELBv2Client) CreateListener(ctx context.Context, params *elasticloadbalancingv2.CreateListenerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.CreateListenerOutput, error) { + if m.CreateListenerFunc != nil { + return m.CreateListenerFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.CreateListenerOutput{}, nil +} + +func (m *MockELBv2Client) DescribeListeners(ctx context.Context, params *elasticloadbalancingv2.DescribeListenersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeListenersOutput, error) { + if m.DescribeListenerFunc != nil { + return m.DescribeListenerFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.DescribeListenersOutput{}, nil +} + +func (m *MockELBv2Client) DeleteListener(ctx context.Context, params *elasticloadbalancingv2.DeleteListenerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteListenerOutput, error) { + if m.DeleteListenerFunc != nil { + return m.DeleteListenerFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.DeleteListenerOutput{}, nil +} + +func (m *MockELBv2Client) AddTags(ctx context.Context, params *elasticloadbalancingv2.AddTagsInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.AddTagsOutput, error) { + if m.AddTagsFunc != nil { + return m.AddTagsFunc(ctx, params, optFns...) + } + return &elasticloadbalancingv2.AddTagsOutput{}, nil +} + +func TestDeleteNLB_LoadBalancerNotFound(t *testing.T) { + mock := &MockELBv2Client{ + DeleteLBFunc: func(ctx context.Context, params *elasticloadbalancingv2.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) { + return nil, fmt.Errorf("LoadBalancerNotFound: One or more load balancers not found") + }, + } + + provider := &Provider{elbv2: mock, log: mockLogger(), sleep: noopSleep} + cache := &ClusterCache{LoadBalancerArn: "arn:aws:elasticloadbalancing:us-east-1:123:loadbalancer/net/gone/abc"} + + err := provider.deleteNLB(cache) + if err != nil { + t.Fatalf("expected no error when NLB is already deleted, got: %v", err) + } +} + +func TestDeleteNLB_RealError(t *testing.T) { + mock := &MockELBv2Client{ + DeleteLBFunc: func(ctx context.Context, params *elasticloadbalancingv2.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) { + return nil, fmt.Errorf("InternalError: something went wrong") + }, + } + + provider := &Provider{elbv2: mock, log: mockLogger(), sleep: noopSleep} + cache := &ClusterCache{LoadBalancerArn: "arn:aws:elasticloadbalancing:us-east-1:123:loadbalancer/net/test/abc"} + + err := provider.deleteNLB(cache) + if err == nil { + t.Fatal("expected error for InternalError, got nil") + } +} + +func TestDeleteListener_ListenerNotFound(t *testing.T) { + mock := &MockELBv2Client{ + DescribeListenerFunc: func(ctx context.Context, params *elasticloadbalancingv2.DescribeListenersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeListenersOutput, error) { + return &elasticloadbalancingv2.DescribeListenersOutput{ + Listeners: []elbv2types.Listener{ + {ListenerArn: aws.String("arn:listener/gone")}, + }, + }, nil + }, + DeleteListenerFunc: func(ctx context.Context, params *elasticloadbalancingv2.DeleteListenerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteListenerOutput, error) { + return nil, fmt.Errorf("ListenerNotFound: One or more listeners not found") + }, + } + + provider := &Provider{elbv2: mock, log: mockLogger(), sleep: noopSleep} + cache := &ClusterCache{LoadBalancerArn: "arn:lb/test"} + + err := provider.deleteListener(cache) + if err != nil { + t.Fatalf("expected no error when listener is already deleted, got: %v", err) + } +} + +func TestDeleteListener_DescribeNotFound(t *testing.T) { + mock := &MockELBv2Client{ + DescribeListenerFunc: func(ctx context.Context, params *elasticloadbalancingv2.DescribeListenersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeListenersOutput, error) { + return nil, fmt.Errorf("LoadBalancerNotFound: LB already gone") + }, + } + + provider := &Provider{elbv2: mock, log: mockLogger(), sleep: noopSleep} + cache := &ClusterCache{LoadBalancerArn: "arn:lb/gone"} + + err := provider.deleteListener(cache) + if err != nil { + t.Fatalf("expected no error when LB is already deleted during describe, got: %v", err) + } +} + +func TestDeleteTargetGroup_NotFound(t *testing.T) { + mock := &MockELBv2Client{ + DescribeTHFunc: func(ctx context.Context, params *elasticloadbalancingv2.DescribeTargetHealthInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeTargetHealthOutput, error) { + return nil, fmt.Errorf("TargetGroupNotFound: target group gone") + }, + DeleteTGFunc: func(ctx context.Context, params *elasticloadbalancingv2.DeleteTargetGroupInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteTargetGroupOutput, error) { + return nil, fmt.Errorf("TargetGroupNotFound: target group gone") + }, + } + + provider := &Provider{elbv2: mock, log: mockLogger(), sleep: noopSleep} + cache := &ClusterCache{TargetGroupArn: "arn:tg/gone"} + + err := provider.deleteTargetGroup(cache) + if err != nil { + t.Fatalf("expected no error when target group is already deleted, got: %v", err) + } +} + +func TestDeleteTargetGroup_RealError(t *testing.T) { + mock := &MockELBv2Client{ + DescribeTHFunc: func(ctx context.Context, params *elasticloadbalancingv2.DescribeTargetHealthInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeTargetHealthOutput, error) { + return &elasticloadbalancingv2.DescribeTargetHealthOutput{}, nil + }, + DeleteTGFunc: func(ctx context.Context, params *elasticloadbalancingv2.DeleteTargetGroupInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteTargetGroupOutput, error) { + return nil, fmt.Errorf("InternalError: something went wrong") + }, + } + + provider := &Provider{elbv2: mock, log: mockLogger(), sleep: noopSleep} + cache := &ClusterCache{TargetGroupArn: "arn:tg/test"} + + err := provider.deleteTargetGroup(cache) + if err == nil { + t.Fatal("expected error for InternalError, got nil") + } +} + +func TestDeleteNLBForCluster_DescribeNotFound(t *testing.T) { + mock := &MockELBv2Client{ + DescribeLBsFunc: func(ctx context.Context, params *elasticloadbalancingv2.DescribeLoadBalancersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) { + return nil, fmt.Errorf("LoadBalancerNotFound: One or more load balancers not found") + }, + } + + env := &v1alpha1.Environment{ + ObjectMeta: metav1.ObjectMeta{Name: "test-env"}, + } + provider := &Provider{elbv2: mock, log: mockLogger(), sleep: noopSleep, Environment: env} + cache := &ClusterCache{LoadBalancerDNS: "gone-nlb.elb.amazonaws.com"} + + err := provider.deleteNLBForCluster(cache) + if err != nil { + t.Fatalf("expected no error when NLB is already deleted, got: %v", err) + } +} diff --git a/pkg/provisioner/provisioner.go b/pkg/provisioner/provisioner.go index 83663c9f..c9363b17 100644 --- a/pkg/provisioner/provisioner.go +++ b/pkg/provisioner/provisioner.go @@ -55,6 +55,13 @@ const ( sshMaxRetries = 20 // sshRetryDelay is the delay between SSH connection retry attempts. sshRetryDelay = 1 * time.Second + // sshKeepaliveInterval is how often we send keepalive requests to prevent + // network middleboxes from dropping idle SSH connections during long-running + // commands like kubeadm init (~10-20 minutes). + sshKeepaliveInterval = 30 * time.Second + // sshHandshakeTimeout is the maximum time for the SSH handshake to complete. + // Without this, connections to unresponsive hosts block indefinitely. + sshHandshakeTimeout = 15 * time.Second ) type Provisioner struct { @@ -466,6 +473,23 @@ func addScriptHeader(tpl *bytes.Buffer) error { return nil } +// startKeepalive sends periodic SSH keepalive requests to prevent network +// middleboxes (NATs, firewalls) from dropping idle connections during +// long-running remote commands. The goroutine self-terminates when the +// client connection is closed. +func startKeepalive(client *ssh.Client) { + go func() { + ticker := time.NewTicker(sshKeepaliveInterval) + defer ticker.Stop() + for range ticker.C { + _, _, err := client.SendRequest("keepalive@holodeck", true, nil) + if err != nil { + return + } + } + }() +} + // createSshClient creates a ssh client, and retries if it fails to connect. // When transport is non-nil, it uses transport.Dial() to get a net.Conn and // creates the SSH client via ssh.NewClientConn. When transport is nil, it @@ -498,6 +522,7 @@ func connectOrDie(keyPath, userName, hostUrl string, transport Transport) (*ssh. ssh.PublicKeys(signer), }, HostKeyCallback: sshutil.TOFUHostKeyCallback(), + Timeout: sshHandshakeTimeout, } addr := hostUrl + ":22" @@ -507,12 +532,19 @@ func connectOrDie(keyPath, userName, hostUrl string, transport Transport) (*ssh. var conn net.Conn conn, err = transport.Dial() if err == nil { + // Set a deadline for the SSH handshake. ssh.NewClientConn + // does not use ClientConfig.Timeout (that only applies to + // ssh.Dial), so we set it on the underlying connection. + _ = conn.SetDeadline(time.Now().Add(sshHandshakeTimeout)) var sshConn ssh.Conn var chans <-chan ssh.NewChannel var reqs <-chan *ssh.Request sshConn, chans, reqs, err = ssh.NewClientConn(conn, addr, sshConfig) if err == nil { + // Clear the deadline for normal operation + _ = conn.SetDeadline(time.Time{}) client = ssh.NewClient(sshConn, chans, reqs) + startKeepalive(client) return client, nil } _ = conn.Close() @@ -521,6 +553,7 @@ func connectOrDie(keyPath, userName, hostUrl string, transport Transport) (*ssh. // Fall back to direct SSH dial client, err = ssh.Dial("tcp", addr, sshConfig) if err == nil { + startKeepalive(client) return client, nil } } diff --git a/pkg/provisioner/ssh_config_test.go b/pkg/provisioner/ssh_config_test.go new file mode 100644 index 00000000..92626b81 --- /dev/null +++ b/pkg/provisioner/ssh_config_test.go @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package provisioner + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "net" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/NVIDIA/holodeck/internal/logger" +) + +// countingTransport counts Dial() calls while connecting to a black hole. +type countingTransport struct { + addr string + calls atomic.Int32 +} + +func (t *countingTransport) Dial() (net.Conn, error) { + t.calls.Add(1) + return net.DialTimeout("tcp", t.addr, 5*time.Second) +} + +func (t *countingTransport) Target() string { return t.addr } +func (t *countingTransport) Close() error { return nil } + +// TestNew_HandshakeTimeout verifies that New() (via connectOrDie) configures +// an SSH handshake timeout. Without the timeout, ssh.NewClientConn blocks +// forever against a host that accepts TCP but never responds with the SSH +// banner. With the timeout, each attempt fails in ~15s. +// +// We verify this by connecting to a black hole server and checking that +// multiple retry attempts complete within a bounded time (proving the +// handshake timeout fires instead of blocking indefinitely). +func TestNew_HandshakeTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping: test waits for SSH handshake timeouts") + } + + // Start a TCP listener that never performs the SSH handshake + addr := startBlackHoleServer(t) + keyPath := writeTestSSHKey(t) + log := logger.NewLogger() + + transport := &countingTransport{addr: addr} + + // Run New() in a goroutine — it will retry up to sshMaxRetries times, + // each timing out in sshHandshakeTimeout (~15s). We don't want to wait + // for all 20 retries (~5 min), so we observe progress from the outside. + errCh := make(chan error, 1) + go func() { + _, err := New(log, keyPath, "testuser", "black-hole-host", WithTransport(transport)) + errCh <- err + }() + + // Wait for at least 2 Dial() calls — proving that: + // 1. The first SSH handshake attempt timed out (didn't block forever) + // 2. The retry logic moved on to attempt #2 + deadline := time.After(45 * time.Second) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-deadline: + calls := transport.calls.Load() + t.Fatalf("timed out waiting for retry progress; only %d Dial() calls observed — "+ + "handshake timeout may not be configured", calls) + case <-ticker.C: + if transport.calls.Load() >= 2 { + // Success: at least 2 attempts means the first one timed out + // and the retry loop continued. The timeout is working. + return + } + case err := <-errCh: + // New() returned — all retries exhausted + if err == nil { + t.Fatal("expected connection error, got nil") + } + calls := transport.calls.Load() + if calls < 2 { + t.Fatalf("New() returned after only %d Dial() calls — "+ + "expected multiple retries with handshake timeouts", calls) + } + return + } + } +} + +// startBlackHoleServer starts a TCP listener that accepts connections and +// holds them open indefinitely without performing an SSH handshake. +func startBlackHoleServer(t *testing.T) string { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to start listener: %v", err) + } + + done := make(chan struct{}) + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + <-done + _ = c.Close() + }(conn) + } + }() + + t.Cleanup(func() { + close(done) + _ = listener.Close() + }) + + return listener.Addr().String() +} + +// writeTestSSHKey creates a temporary ed25519 SSH private key in PEM format. +func writeTestSSHKey(t *testing.T) string { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + pemBlock, err := ssh.MarshalPrivateKey(priv, "") + if err != nil { + t.Fatalf("failed to marshal private key: %v", err) + } + + pemBytes := pem.EncodeToMemory(pemBlock) + + dir := t.TempDir() + keyPath := filepath.Join(dir, "test_key") + if err := os.WriteFile(keyPath, pemBytes, 0600); err != nil { + t.Fatalf("failed to write key file: %v", err) + } + return keyPath +}