diff --git a/pkg/provider/core.go b/pkg/provider/core.go index 1cfacc0..90d438a 100644 --- a/pkg/provider/core.go +++ b/pkg/provider/core.go @@ -10,9 +10,7 @@ import ( ) const ( - StackitProviderName = "stackit" - StackitMachineLabel = "kubernetes.io/machine" - StackitMachineClassLabel = "kubernetes.io/machineclass" + StackitProviderName = "stackit" ) // GetVolumeIDs extracts volume IDs from PersistentVolume specs diff --git a/pkg/provider/create.go b/pkg/provider/create.go index 0c13e5a..e2c947a 100644 --- a/pkg/provider/create.go +++ b/pkg/provider/create.go @@ -114,8 +114,8 @@ func (p *Provider) createServerRequest(req *driver.CreateMachineRequest, provide } // Add MCM-specific labels for server identification and orphan VM detection - labels[StackitMachineLabel] = req.Machine.Name - labels[StackitMachineClassLabel] = req.MachineClass.Name + labels[p.GetMachineLabelKey()] = req.Machine.Name + labels[p.GetMachineClassLabelKey()] = req.MachineClass.Name // Create server request createReq := &client.CreateServerRequest{ @@ -212,7 +212,7 @@ func (p *Provider) createServerRequest(req *driver.CreateMachineRequest, provide func (p *Provider) getServerByName(ctx context.Context, projectID, region, serverName string) (*client.Server, error) { // Check if the server got already created labelSelector := map[string]string{ - StackitMachineLabel: serverName, + p.GetMachineLabelKey(): serverName, } servers, err := p.client.ListServers(ctx, projectID, region, labelSelector) if err != nil { diff --git a/pkg/provider/list.go b/pkg/provider/list.go index e96b84a..4b9f73e 100644 --- a/pkg/provider/list.go +++ b/pkg/provider/list.go @@ -42,7 +42,7 @@ func (p *Provider) ListMachines(ctx context.Context, req *driver.ListMachinesReq // Call STACKIT API to list all servers labelSelector := map[string]string{ - StackitMachineClassLabel: req.MachineClass.Name, + p.GetMachineClassLabelKey(): req.MachineClass.Name, } servers, err := p.client.ListServers(ctx, projectID, providerSpec.Region, labelSelector) if err != nil { @@ -51,7 +51,7 @@ func (p *Provider) ListMachines(ctx context.Context, req *driver.ListMachinesReq } // Filter servers by MachineClass label - // We use the "kubernetes.io/machineclass" label to identify which servers belong to this MachineClass + // We use the label to identify which servers belong to this MachineClass machineList := make(map[string]string) for _, server := range servers { // Generate ProviderID in format: stackit:/// @@ -59,7 +59,7 @@ func (p *Provider) ListMachines(ctx context.Context, req *driver.ListMachinesReq // Get machine name from labels (fallback to server name if not found) machineName := server.Name - if machineLabel, ok := server.Labels[StackitMachineLabel]; ok { + if machineLabel, ok := server.Labels[p.GetMachineLabelKey()]; ok { machineName = machineLabel } diff --git a/pkg/provider/list_test.go b/pkg/provider/list_test.go index 162d4c8..b77dba3 100644 --- a/pkg/provider/list_test.go +++ b/pkg/provider/list_test.go @@ -32,7 +32,8 @@ var _ = Describe("ListMachines", func() { ctx = context.Background() mockClient = &mock.StackitClient{} provider = &Provider{ - client: mockClient, + client: mockClient, + customLabelDomain: "kubernetes.io", } // Create secret with projectId @@ -142,4 +143,44 @@ var _ = Describe("ListMachines", func() { Expect(statusErr.Code()).To(Equal(codes.Internal)) }) }) + + Context("with custom label domain", func() { + It("should use custom domain in labels", func() { + customProvider := &Provider{ + client: mockClient, + customLabelDomain: "custom.domain", + } + + Expect(customProvider.GetMachineLabelKey()).To(Equal("custom.domain/machine")) + Expect(customProvider.GetMachineClassLabelKey()).To(Equal("custom.domain/machineclass")) + }) + + It("should list machines filtered by custom MachineClass label", func() { + customProvider := &Provider{ + client: mockClient, + customLabelDomain: "custom.domain", + } + + mockClient.ListServersFunc = func(_ context.Context, _, _ string, selector map[string]string) ([]*client.Server, error) { + Expect(selector["custom.domain/machineclass"]).To(Equal("test-machine-class")) + + return []*client.Server{ + { + ID: "server-1", + Name: "machine-1", + Labels: map[string]string{ + "custom.domain/machineclass": "test-machine-class", + "custom.domain/machine": "machine-1", + }, + }, + }, nil + } + + resp, err := customProvider.ListMachines(ctx, req) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + Expect(resp.MachineList).To(HaveLen(1)) + }) + }) }) diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 684523c..c60ada5 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -2,6 +2,7 @@ package provider import ( "fmt" + "os" "sync" "time" @@ -28,14 +29,23 @@ type Provider struct { // intervals need to be configurable to speed up tests pollingInterval time.Duration // Interval between polling attempts pollingTimeout time.Duration // Maximum time to wait during polling + // NOTE: only change this if you know what you are doing! + // changing this value without a migration plan could lead to orphaned cloud resources + customLabelDomain string } // NewProvider returns an empty provider object func NewProvider(i spi.SessionProviderInterface) driver.Driver { + customLabelDomain := os.Getenv("CUSTOM_LABEL_DOMAIN") + if customLabelDomain == "" { + customLabelDomain = "kubernetes.io" + } + return &Provider{ - SPI: i, - pollingInterval: 5 * time.Second, - pollingTimeout: 10 * time.Minute, + SPI: i, + pollingInterval: 5 * time.Second, + pollingTimeout: 10 * time.Minute, + customLabelDomain: customLabelDomain, } } @@ -72,3 +82,13 @@ func (p *Provider) ensureClient(serviceAccountKey string) error { return p.clientErr } + +// GetMachineLabelKey returns the fully-qualified machine label key using the configured label domain +func (p *Provider) GetMachineLabelKey() string { + return fmt.Sprintf("%s/machine", p.customLabelDomain) +} + +// GetMachineClassLabelKey returns the fully-qualified machine class label key using the configured label domain +func (p *Provider) GetMachineClassLabelKey() string { + return fmt.Sprintf("%s/machineclass", p.customLabelDomain) +}