diff --git a/go.sum b/go.sum index 0b14b3d1..1ea4d591 100644 --- a/go.sum +++ b/go.sum @@ -234,4 +234,4 @@ sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxO sigs.k8s.io/structured-merge-diff/v6 v6.3.0 h1:jTijUJbW353oVOd9oTlifJqOGEkUw2jB/fXCbTiQEco= sigs.k8s.io/structured-merge-diff/v6 v6.3.0/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE= sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= -sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= +sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= \ No newline at end of file diff --git a/pkg/provider/apis/provider_spec.go b/pkg/provider/apis/provider_spec.go index f8b91ba8..cb1822d4 100644 --- a/pkg/provider/apis/provider_spec.go +++ b/pkg/provider/apis/provider_spec.go @@ -28,6 +28,10 @@ type ProviderSpec struct { // Optional field. If not specified, the server may use default networking or require manual configuration. Networking *NetworkingSpec `json:"networking,omitempty"` + // AllowedAddresses are the IP address ranges (CIDRs) allowed to originate traffic from the server's network interface. + // Optional field. If specified, these ranges are configured as AllowedAddresses on the network interface of the server to bypass anti-spoofing rules. + AllowedAddresses []string `json:"allowedAddresses,omitempty"` + // SecurityGroups are the names of security groups to attach to the server // Optional field. If not specified, the project's default security group will be used. SecurityGroups []string `json:"securityGroups,omitempty"` diff --git a/pkg/provider/apis/validation/validation.go b/pkg/provider/apis/validation/validation.go index c27bac58..bc7a7f55 100644 --- a/pkg/provider/apis/validation/validation.go +++ b/pkg/provider/apis/validation/validation.go @@ -8,6 +8,7 @@ package validation import ( "encoding/json" "fmt" + "net" "regexp" api "github.com/stackitcloud/machine-controller-manager-provider-stackit/pkg/provider/apis" @@ -165,6 +166,15 @@ func ValidateProviderSpecNSecret(spec *api.ProviderSpec, secrets *corev1.Secret) } } + // Validate AllowedAddresses + if len(spec.AllowedAddresses) > 0 { + for _, cidr := range spec.AllowedAddresses { + if _, _, err := net.ParseCIDR(cidr); err != nil { + errors = append(errors, fmt.Errorf("providerSpec.allowedAddresses has an invalid CIDR: %s", cidr)) + } + } + } + // Validate AffinityGroup if spec.AffinityGroup != "" { if !isValidUUID(spec.AffinityGroup) { diff --git a/pkg/provider/core.go b/pkg/provider/core.go index f36e799d..0dcad840 100644 --- a/pkg/provider/core.go +++ b/pkg/provider/core.go @@ -10,10 +10,13 @@ import ( "encoding/base64" "errors" "fmt" + "slices" + "strings" "github.com/gardener/machine-controller-manager/pkg/util/provider/driver" "github.com/gardener/machine-controller-manager/pkg/util/provider/machinecodes/codes" "github.com/gardener/machine-controller-manager/pkg/util/provider/machinecodes/status" + api "github.com/stackitcloud/machine-controller-manager-provider-stackit/pkg/provider/apis" "github.com/stackitcloud/machine-controller-manager-provider-stackit/pkg/provider/apis/validation" "k8s.io/klog/v2" ) @@ -182,27 +185,106 @@ func (p *Provider) CreateMachine(ctx context.Context, req *driver.CreateMachineR createReq.Metadata = providerSpec.Metadata } - // Call STACKIT API to create server - server, err := p.client.CreateServer(ctx, projectID, providerSpec.Region, createReq) + // check if server already exists + server, err := p.getServerByName(ctx, projectID, providerSpec.Region, req.Machine.Name) if err != nil { - klog.Errorf("Failed to create server for machine %q: %v", req.Machine.Name, err) - return nil, status.Error(codes.Internal, fmt.Sprintf("failed to create server: %v", err)) + klog.Errorf("Failed to fetch server for machine %q: %v", req.Machine.Name, err) + return nil, status.Error(codes.Unavailable, fmt.Sprintf("failed to fetch server: %v", err)) } - // Generate ProviderID in format: stackit:/// - providerID := fmt.Sprintf("%s://%s/%s", StackitProviderName, projectID, server.ID) + if server == nil { + // Call STACKIT API to create server + server, err = p.client.CreateServer(ctx, projectID, providerSpec.Region, createReq) + if err != nil { + klog.Errorf("Failed to create server for machine %q: %v", req.Machine.Name, err) + return nil, status.Error(codes.Unavailable, fmt.Sprintf("failed to create server: %v", err)) + } + } - // NodeName is the machine name (will register with this name in Kubernetes) - nodeName := req.Machine.Name + if err := p.patchNetworkInterface(ctx, projectID, server.ID, providerSpec); err != nil { + klog.Errorf("Failed to patch network interface for server %q: %v", req.Machine.Name, err) + return nil, status.Error(codes.Unavailable, fmt.Sprintf("failed to patch network interface for server: %v", err)) + } + // Generate ProviderID in format: stackit:/// + providerID := fmt.Sprintf("%s://%s/%s", StackitProviderName, projectID, server.ID) klog.V(2).Infof("Successfully created server %q with ID %q for machine %q", server.Name, server.ID, req.Machine.Name) return &driver.CreateMachineResponse{ ProviderID: providerID, - NodeName: nodeName, + NodeName: req.Machine.Name, }, nil } +func (p *Provider) getServerByName(ctx context.Context, projectID, region, serverName string) (*Server, error) { + // Check if the server got already created + labelSelector := map[string]string{ + StackitMachineLabel: serverName, + } + servers, err := p.client.ListServers(ctx, projectID, region, labelSelector) + if err != nil { + return nil, fmt.Errorf("SDK ListServers with labelSelector: %v failed: %w", labelSelector, err) + } + + if len(servers) > 1 { + return nil, fmt.Errorf("%v servers found for server name %v", len(servers), serverName) + } + + if len(servers) == 1 { + return servers[0], nil + } + + // no servers found len == 0 + return nil, nil +} + +func (p *Provider) patchNetworkInterface(ctx context.Context, projectID, serverID string, providerSpec *api.ProviderSpec) error { + if len(providerSpec.AllowedAddresses) == 0 { + return nil + } + + nics, err := p.client.GetNICsForServer(ctx, projectID, providerSpec.Region, serverID) + if err != nil { + return fmt.Errorf("failed to get NICs for server %q: %w", serverID, err) + } + + if len(nics) == 0 { + return fmt.Errorf("failed to find NIC for server %q", serverID) + } + + for _, nic := range nics { + // if networking is not set, server is inside the default network + // just patch the interface since the server should only have one + if providerSpec.Networking != nil { + // only process interfaces that are either in the configured network (NetworkID) or are defined in NICIDs + if providerSpec.Networking.NetworkID != nic.NetworkID && !slices.Contains(providerSpec.Networking.NICIDs, nic.ID) { + continue + } + } + + updateNic := false + // check if every cidr in providerspec.allowedAddresses is inside the nic allowedAddresses + for _, allowedAddress := range providerSpec.AllowedAddresses { + if !slices.Contains(nic.AllowedAddresses, allowedAddress) { + nic.AllowedAddresses = append(nic.AllowedAddresses, allowedAddress) + updateNic = true + } + } + + if !updateNic { + continue + } + + if _, err := p.client.UpdateNIC(ctx, projectID, providerSpec.Region, nic.NetworkID, nic.ID, nic.AllowedAddresses); err != nil { + return fmt.Errorf("failed to update allowed addresses for NIC %s: %w", nic.ID, err) + } + + klog.V(2).Infof("Updated allowed addresses for NIC %s to %v", nic.ID, nic.AllowedAddresses) + } + + return nil +} + // DeleteMachine handles a machine deletion request by deleting the STACKIT server // // This method deletes the server identified by the ProviderID from STACKIT infrastructure. @@ -216,33 +298,52 @@ func (p *Provider) DeleteMachine(ctx context.Context, req *driver.DeleteMachineR klog.V(2).Infof("Machine deletion request has been received for %q", req.Machine.Name) defer klog.V(2).Infof("Machine deletion request has been processed for %q", req.Machine.Name) - // Validate ProviderID exists - if req.Machine.Spec.ProviderID == "" { - return nil, status.Error(codes.InvalidArgument, "ProviderID is required") - } - // Extract credentials from Secret serviceAccountKey := string(req.Secret.Data["serviceaccount.json"]) - // Initialize client on first use (lazy initialization) if err := p.ensureClient(serviceAccountKey); err != nil { return nil, status.Error(codes.Internal, fmt.Sprintf("failed to initialize STACKIT client: %v", err)) } - // Parse ProviderID to extract projectID and serverID - projectID, serverID, err := parseProviderID(req.Machine.Spec.ProviderID) + var projectID, serverID string + var err error + if req.Machine.Spec.ProviderID != "" { + if !strings.HasPrefix(req.Machine.Spec.ProviderID, StackitProviderName) { + return nil, status.Error(codes.InvalidArgument, "providerID is not empty and does not start with stackit://") + } + + // Parse ProviderID to extract projectID and serverID + projectID, serverID, err = parseProviderID(req.Machine.Spec.ProviderID) + if err != nil { + klog.V(2).Infof("invalid ProviderID format: %v", err) + } + } + if projectID == "" { projectID = string(req.Secret.Data["project-id"]) } - if err != nil { - return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid ProviderID format: %v", err)) - } providerSpec, err := decodeProviderSpec(req.MachineClass) if err != nil { return nil, status.Error(codes.Internal, err.Error()) } + if serverID == "" { + server, err := p.getServerByName(ctx, projectID, providerSpec.Region, req.Machine.Name) + if err != nil { + return nil, status.Error(codes.Internal, fmt.Sprintf("failed to find server by name: %v", err)) + } + + if server != nil { + serverID = server.ID + } + } + + if serverID == "" { + klog.V(2).Infof("Server is already deleted for machine %q", req.Machine.Name) + return &driver.DeleteMachineResponse{}, nil + } + // Call STACKIT API to delete server err = p.client.DeleteServer(ctx, projectID, providerSpec.Region, serverID) if err != nil { @@ -364,7 +465,9 @@ func (p *Provider) ListMachines(ctx context.Context, req *driver.ListMachinesReq } // Call STACKIT API to list all servers - labelSelector := fmt.Sprintf("%s=%s", StackitMachineClassLabel, req.MachineClass.Name) + labelSelector := map[string]string{ + StackitMachineClassLabel: req.MachineClass.Name, + } servers, err := p.client.ListServers(ctx, projectID, providerSpec.Region, labelSelector) if err != nil { klog.Errorf("Failed to list servers for MachineClass %q: %v", req.MachineClass.Name, err) diff --git a/pkg/provider/core_create_machine_basic_test.go b/pkg/provider/core_create_machine_basic_test.go index 30ef4a8b..1fc72358 100644 --- a/pkg/provider/core_create_machine_basic_test.go +++ b/pkg/provider/core_create_machine_basic_test.go @@ -184,7 +184,7 @@ var _ = Describe("CreateMachine", func() { Expect(err).To(HaveOccurred()) statusErr, ok := status.FromError(err) Expect(ok).To(BeTrue()) - Expect(statusErr.Code()).To(Equal(codes.Internal)) + Expect(statusErr.Code()).To(Equal(codes.Unavailable)) }) }) }) diff --git a/pkg/provider/core_create_machine_networking_test.go b/pkg/provider/core_create_machine_networking_test.go index 94122098..b8649d93 100644 --- a/pkg/provider/core_create_machine_networking_test.go +++ b/pkg/provider/core_create_machine_networking_test.go @@ -147,6 +147,211 @@ var _ = Describe("CreateMachine - Networking", func() { Expect(capturedReq.Networking.NICIDs[0]).To(Equal("880e8400-e29b-41d4-a716-446655440001")) Expect(capturedReq.Networking.NICIDs[1]).To(Equal("990e8400-e29b-41d4-a716-446655440002")) }) + + It("should update the allowedAddresses in the NIC when networkID is set", func() { + providerSpec := &api.ProviderSpec{ + MachineType: "c1.2", + Region: "eu01", + ImageID: "12345678-1234-1234-1234-123456789abc", + Networking: &api.NetworkingSpec{ + NetworkID: "770e8400-e29b-41d4-a716-446655440000", + }, + AllowedAddresses: []string{ + "10.0.0.1/8", + }, + } + providerSpecRaw, _ := encodeProviderSpec(providerSpec) + + machineClass = &v1alpha1.MachineClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-machine-class", + }, + Provider: "stackit", + ProviderSpec: runtime.RawExtension{ + Raw: providerSpecRaw, + }, + } + + req = &driver.CreateMachineRequest{ + Machine: machine, + MachineClass: machineClass, + Secret: secret, + } + + var capturedReq *CreateServerRequest + mockClient.createServerFunc = func(_ context.Context, _, _ string, req *CreateServerRequest) (*Server, error) { + capturedReq = req + return &Server{ + ID: "test-server-id", + Name: req.Name, + Status: "CREATING", + }, nil + } + + mockClient.getNICsFunc = func(_ context.Context, _, _, _ string) ([]*NIC, error) { + return []*NIC{{ + ID: "990e8400-e29b-41d4-a716-446655440002", + NetworkID: "770e8400-e29b-41d4-a716-446655440000", + AllowedAddresses: []string{}, + }}, nil + } + + var called = false + mockClient.updateNICFunc = func(_ context.Context, _, _, _, _ string, addresses []string) (*NIC, error) { + called = true + Expect(addresses).To(HaveLen(1)) + Expect(addresses[0]).To(Equal("10.0.0.1/8")) + return &NIC{ + ID: "990e8400-e29b-41d4-a716-446655440002", + NetworkID: "770e8400-e29b-41d4-a716-446655440000", + AllowedAddresses: addresses, + }, nil + } + + _, err := provider.CreateMachine(ctx, req) + + Expect(err).NotTo(HaveOccurred()) + Expect(called).To(BeTrue()) + Expect(capturedReq.Networking).NotTo(BeNil()) + Expect(capturedReq.Networking.NetworkID).ToNot(BeEmpty()) + }) + + It("should update the allowedAddresses in the NIC when NICIDs are set", func() { + providerSpec := &api.ProviderSpec{ + MachineType: "c1.2", + Region: "eu01", + ImageID: "12345678-1234-1234-1234-123456789abc", + Networking: &api.NetworkingSpec{ + NICIDs: []string{ + "880e8400-e29b-41d4-a716-446655440001", + "990e8400-e29b-41d4-a716-446655440002", + }, + }, + AllowedAddresses: []string{ + "10.0.0.1/8", + }, + } + providerSpecRaw, _ := encodeProviderSpec(providerSpec) + + machineClass = &v1alpha1.MachineClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-machine-class", + }, + Provider: "stackit", + ProviderSpec: runtime.RawExtension{ + Raw: providerSpecRaw, + }, + } + + req = &driver.CreateMachineRequest{ + Machine: machine, + MachineClass: machineClass, + Secret: secret, + } + + var capturedReq *CreateServerRequest + mockClient.createServerFunc = func(_ context.Context, _, _ string, req *CreateServerRequest) (*Server, error) { + capturedReq = req + return &Server{ + ID: "test-server-id", + Name: req.Name, + Status: "CREATING", + }, nil + } + + mockClient.getNICsFunc = func(_ context.Context, _, _, _ string) ([]*NIC, error) { + return []*NIC{{ + ID: "880e8400-e29b-41d4-a716-446655440001", + NetworkID: "770e8400-e29b-41d4-a716-446655440000", + AllowedAddresses: []string{}, + }}, nil + } + + var called = false + mockClient.updateNICFunc = func(_ context.Context, _, _, _, _ string, addresses []string) (*NIC, error) { + called = true + Expect(addresses).To(HaveLen(1)) + Expect(addresses[0]).To(Equal("10.0.0.1/8")) + return &NIC{ + ID: "880e8400-e29b-41d4-a716-446655440001", + NetworkID: "770e8400-e29b-41d4-a716-446655440000", + AllowedAddresses: addresses, + }, nil + } + + _, err := provider.CreateMachine(ctx, req) + + Expect(err).NotTo(HaveOccurred()) + Expect(called).To(BeTrue()) + Expect(capturedReq.Networking).NotTo(BeNil()) + Expect(capturedReq.Networking.NICIDs).To(HaveLen(2)) + }) + + It("should not update the allowedAddresses in the NIC if already present", func() { + providerSpec := &api.ProviderSpec{ + MachineType: "c1.2", + Region: "eu01", + ImageID: "12345678-1234-1234-1234-123456789abc", + Networking: &api.NetworkingSpec{ + NetworkID: "770e8400-e29b-41d4-a716-446655440000", + }, + AllowedAddresses: []string{ + "10.0.0.1/8", + }, + } + providerSpecRaw, _ := encodeProviderSpec(providerSpec) + + machineClass = &v1alpha1.MachineClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-machine-class", + }, + Provider: "stackit", + ProviderSpec: runtime.RawExtension{ + Raw: providerSpecRaw, + }, + } + + req = &driver.CreateMachineRequest{ + Machine: machine, + MachineClass: machineClass, + Secret: secret, + } + + var capturedReq *CreateServerRequest + mockClient.createServerFunc = func(_ context.Context, _, _ string, req *CreateServerRequest) (*Server, error) { + capturedReq = req + return &Server{ + ID: "test-server-id", + Name: req.Name, + Status: "CREATING", + }, nil + } + + mockClient.getNICsFunc = func(_ context.Context, _, _, _ string) ([]*NIC, error) { + return []*NIC{{ + ID: "990e8400-e29b-41d4-a716-446655440002", + NetworkID: "770e8400-e29b-41d4-a716-446655440000", + AllowedAddresses: []string{"10.0.0.1/8"}, + }}, nil + } + + var called = false + mockClient.updateNICFunc = func(_ context.Context, _, _, _, _ string, addresses []string) (*NIC, error) { + called = true + return &NIC{ + ID: "990e8400-e29b-41d4-a716-446655440002", + NetworkID: "770e8400-e29b-41d4-a716-446655440000", + AllowedAddresses: addresses, + }, nil + } + + _, err := provider.CreateMachine(ctx, req) + + Expect(err).NotTo(HaveOccurred()) + Expect(called).To(BeFalse()) + Expect(capturedReq.Networking).NotTo(BeNil()) + Expect(capturedReq.Networking.NetworkID).ToNot(BeEmpty()) + }) }) Context("with networking fallback to Secret", func() { diff --git a/pkg/provider/core_delete_machine_test.go b/pkg/provider/core_delete_machine_test.go index da9684d4..a748d062 100644 --- a/pkg/provider/core_delete_machine_test.go +++ b/pkg/provider/core_delete_machine_test.go @@ -115,15 +115,12 @@ var _ = Describe("DeleteMachine", func() { }) Context("with missing or invalid ProviderID", func() { - It("should return InvalidArgument when ProviderID is missing", func() { + It("should still delete the machine when ProviderID is missing", func() { machine.Spec.ProviderID = "" _, err := provider.DeleteMachine(ctx, req) - Expect(err).To(HaveOccurred()) - statusErr, ok := status.FromError(err) - Expect(ok).To(BeTrue()) - Expect(statusErr.Code()).To(Equal(codes.InvalidArgument)) + Expect(err).ToNot(HaveOccurred()) }) It("should return InvalidArgument when ProviderID has invalid format", func() { diff --git a/pkg/provider/core_get_machine_status_test.go b/pkg/provider/core_get_machine_status_test.go index c70a46a0..4470d700 100644 --- a/pkg/provider/core_get_machine_status_test.go +++ b/pkg/provider/core_get_machine_status_test.go @@ -90,7 +90,7 @@ var _ = Describe("GetMachineStatus", func() { return &Server{ ID: serverID, Name: "test-machine", - Status: "RUNNING", + Status: "ACTIVE", }, nil } @@ -112,7 +112,7 @@ var _ = Describe("GetMachineStatus", func() { return &Server{ ID: serverID, Name: "test-machine", - Status: "RUNNING", + Status: "ACTIVE", }, nil } diff --git a/pkg/provider/core_list_machines_test.go b/pkg/provider/core_list_machines_test.go index 817ca691..0e201372 100644 --- a/pkg/provider/core_list_machines_test.go +++ b/pkg/provider/core_list_machines_test.go @@ -72,8 +72,8 @@ var _ = Describe("ListMachines", func() { Context("with valid inputs", func() { It("should list machines filtered by MachineClass label", func() { - mockClient.listServersFunc = func(_ context.Context, _, _, selector string) ([]*Server, error) { - Expect(selector).To(ContainSubstring("mcm.gardener.cloud/machineclass=test-machine-class")) + mockClient.listServersFunc = func(_ context.Context, _, _ string, selector map[string]string) ([]*Server, error) { + Expect(selector["mcm.gardener.cloud/machineclass"]).To(Equal("test-machine-class")) return []*Server{ { @@ -106,7 +106,7 @@ var _ = Describe("ListMachines", func() { }) It("should return empty list when no servers match", func() { - mockClient.listServersFunc = func(_ context.Context, _, _, _ string) ([]*Server, error) { + mockClient.listServersFunc = func(_ context.Context, _, _ string, _ map[string]string) ([]*Server, error) { return []*Server{}, nil } @@ -118,7 +118,7 @@ var _ = Describe("ListMachines", func() { }) It("should return empty list when no servers exist", func() { - mockClient.listServersFunc = func(_ context.Context, _, _, _ string) ([]*Server, error) { + mockClient.listServersFunc = func(_ context.Context, _, _ string, _ map[string]string) ([]*Server, error) { return []*Server{}, nil } @@ -132,7 +132,7 @@ var _ = Describe("ListMachines", func() { Context("when STACKIT API fails", func() { It("should return Internal error on API failure", func() { - mockClient.listServersFunc = func(_ context.Context, _, _, _ string) ([]*Server, error) { + mockClient.listServersFunc = func(_ context.Context, _, _ string, _ map[string]string) ([]*Server, error) { return nil, fmt.Errorf("API connection failed") } diff --git a/pkg/provider/core_mocks_test.go b/pkg/provider/core_mocks_test.go index 0de56d0e..ce4388c8 100644 --- a/pkg/provider/core_mocks_test.go +++ b/pkg/provider/core_mocks_test.go @@ -16,7 +16,9 @@ type mockStackitClient struct { createServerFunc func(ctx context.Context, projectID, region string, req *CreateServerRequest) (*Server, error) getServerFunc func(ctx context.Context, projectID, region, serverID string) (*Server, error) deleteServerFunc func(ctx context.Context, projectID, region, serverID string) error - listServersFunc func(ctx context.Context, projectID, region, labelSelector string) ([]*Server, error) + listServersFunc func(ctx context.Context, projectID, region string, labelSelector map[string]string) ([]*Server, error) + getNICsFunc func(ctx context.Context, projectID, region, serverID string) ([]*NIC, error) + updateNICFunc func(ctx context.Context, projectID, region, networkID, nicID string, allowedAddresses []string) (*NIC, error) } func (m *mockStackitClient) CreateServer(ctx context.Context, projectID, region string, req *CreateServerRequest) (*Server, error) { @@ -37,7 +39,7 @@ func (m *mockStackitClient) GetServer(ctx context.Context, projectID, region, se return &Server{ ID: serverID, Name: "test-machine", - Status: "RUNNING", + Status: "ACTIVE", }, nil } @@ -48,13 +50,29 @@ func (m *mockStackitClient) DeleteServer(ctx context.Context, projectID, region, return nil } -func (m *mockStackitClient) ListServers(ctx context.Context, projectID, region, labelSelector string) ([]*Server, error) { +func (m *mockStackitClient) ListServers(ctx context.Context, projectID, region string, labelSelector map[string]string) ([]*Server, error) { if m.listServersFunc != nil { return m.listServersFunc(ctx, projectID, region, labelSelector) } return []*Server{}, nil } +func (m *mockStackitClient) GetNICsForServer(ctx context.Context, projectID, region, serverID string) ([]*NIC, error) { + if m.getNICsFunc != nil { + return m.getNICsFunc(ctx, projectID, region, serverID) + } + return []*NIC{}, nil +} + +func (m *mockStackitClient) UpdateNIC(ctx context.Context, projectID, region, networkID, nicID string, allowedAddresses []string) (*NIC, error) { + if m.updateNICFunc != nil { + return m.updateNICFunc(ctx, projectID, region, networkID, nicID, allowedAddresses) + } + return &NIC{}, nil +} + +// UpdateNIC updates a network interface + // encodeProviderSpec is a helper function to encode ProviderSpec for tests func encodeProviderSpec(spec *api.ProviderSpec) ([]byte, error) { return encodeProviderSpecForResponse(spec) diff --git a/pkg/provider/sdk_client.go b/pkg/provider/sdk_client.go index 72079648..fb489083 100644 --- a/pkg/provider/sdk_client.go +++ b/pkg/provider/sdk_client.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "os" + "strings" "github.com/stackitcloud/stackit-sdk-go/core/config" "github.com/stackitcloud/stackit-sdk-go/core/oapierror" @@ -87,23 +88,8 @@ func createIAASClient(serviceAccountKey string) (*iaas.APIClient, error) { // CreateServer creates a new server via STACKIT SDK // -//nolint:gocyclo,funlen//TODO:refactor +//nolint:gocyclo // TODO: refactor func (c *SdkStackitClient) CreateServer(ctx context.Context, projectID, region string, req *CreateServerRequest) (*Server, error) { - // Check if the server got already created - labelSelector := fmt.Sprintf("mcm.gardener.cloud/machine=%s", req.Name) - servers, err := c.ListServers(ctx, projectID, region, labelSelector) - if err != nil { - return nil, fmt.Errorf("SDK ListServers with labelSelector: %v failed: %w", labelSelector, err) - } - - if len(servers) > 1 { - return nil, fmt.Errorf("%v servers found for server name %v", len(servers), req.Name) - } - - if len(servers) == 1 { - return servers[0], nil - } - // Convert our request to SDK payload payload := &iaas.CreateServerPayload{ Name: ptr(req.Name), @@ -273,11 +259,19 @@ func (c *SdkStackitClient) DeleteServer(ctx context.Context, projectID, region, } // ListServers lists all servers in a project via STACKIT SDK -func (c *SdkStackitClient) ListServers(ctx context.Context, projectID, region, labelSelector string) ([]*Server, error) { +func (c *SdkStackitClient) ListServers(ctx context.Context, projectID, region string, labelSelector map[string]string) ([]*Server, error) { serverRequest := c.iaasClient.ListServers(ctx, projectID, region) - if labelSelector != "" { - serverRequest = serverRequest.LabelSelector(labelSelector) + if labelSelector != nil { + sb := strings.Builder{} + for k, v := range labelSelector { + _, err := fmt.Fprintf(&sb, "%s=%s,", k, v) + if err != nil { + return nil, fmt.Errorf("failed to format label selector: %w", err) + } + } + + serverRequest = serverRequest.LabelSelector(sb.String()) } sdkResponse, err := serverRequest.Execute() @@ -304,8 +298,68 @@ func (c *SdkStackitClient) ListServers(ctx context.Context, projectID, region, l return servers, nil } +func (c *SdkStackitClient) GetNICsForServer(ctx context.Context, projectID, region, serverID string) ([]*NIC, error) { + res, err := c.iaasClient.ListServerNICs(ctx, projectID, region, serverID).Execute() + if err != nil { + return nil, fmt.Errorf("SDK ListServerNICs failed: %w", err) + } + + if res.Items == nil { + return []*NIC{}, nil + } + + nics := make([]*NIC, 0) + for _, nic := range *res.Items { + nics = append(nics, convertSDKNICtoNIC(&nic)) + } + + return nics, nil +} + +func (c *SdkStackitClient) UpdateNIC(ctx context.Context, projectID, region, networkID, nicID string, allowedAddresses []string) (*NIC, error) { + addresses := make([]iaas.AllowedAddressesInner, len(allowedAddresses)) + + for i, addr := range allowedAddresses { + addresses[i] = iaas.AllowedAddressesInner{ + String: ptr(addr), + } + } + + payload := iaas.UpdateNicPayload{ + AllowedAddresses: &addresses, + } + + sdkNic, err := c.iaasClient.UpdateNic(ctx, projectID, region, networkID, nicID).UpdateNicPayload(payload).Execute() + if err != nil { + return nil, fmt.Errorf("SDK UpdateNic failed: %w", err) + } + + if sdkNic == nil { + return nil, nil + } + + return convertSDKNICtoNIC(sdkNic), nil +} + // Helper functions +func convertSDKNICtoNIC(nic *iaas.NIC) *NIC { + addresses := make([]string, 0) + if nic.AllowedAddresses != nil { + for _, addr := range *nic.AllowedAddresses { + if addr.String != nil { + addresses = append(addresses, *addr.String) + } + } + } + + return &NIC{ + ID: getStringValue(nic.Id), + NetworkID: getStringValue(nic.NetworkId), + AllowedAddresses: addresses, + } +} + // getStringValue safely dereferences a string pointer, returning empty string if nil func getStringValue(s *string) string { if s == nil { diff --git a/pkg/provider/stackit_client.go b/pkg/provider/stackit_client.go index 442ebb3c..0fe44be9 100644 --- a/pkg/provider/stackit_client.go +++ b/pkg/provider/stackit_client.go @@ -18,6 +18,7 @@ import ( // // Note: region parameter is required by STACKIT SDK v1.0.0+ // It must be extracted from the Secret (e.g., "eu01-1", "eu01-2") +// nolint:dupl // the duplicates are mock functions type StackitClient interface { // CreateServer creates a new server in STACKIT CreateServer(ctx context.Context, projectID, region string, req *CreateServerRequest) (*Server, error) @@ -26,7 +27,11 @@ type StackitClient interface { // DeleteServer deletes a server by ID from STACKIT DeleteServer(ctx context.Context, projectID, region, serverID string) error // ListServers lists all servers in a project - ListServers(ctx context.Context, projectID, region, labelSelector string) ([]*Server, error) + ListServers(ctx context.Context, projectID, region string, labelSelector map[string]string) ([]*Server, error) + // GetNICsForServer retrieves a network interfaces for a given server + GetNICsForServer(ctx context.Context, projectID, region, serverID string) ([]*NIC, error) + // UpdateNIC updates a network interface + UpdateNIC(ctx context.Context, projectID, region, networkID, nicID string, allowedAddresses []string) (*NIC, error) } // CreateServerRequest represents the request to create a server @@ -86,3 +91,10 @@ type Server struct { Status string `json:"status"` Labels map[string]string `json:"labels,omitempty"` } + +// NIC represents a STACKIT network interface +type NIC struct { + ID string + NetworkID string + AllowedAddresses []string +}