Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pkg/provider/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ import (
)

const (
StackitProviderName = "stackit"
StackitMachineLabel = "kubernetes.io/machine"
StackitMachineClassLabel = "kubernetes.io/machineclass"
StackitProviderName = "stackit"
)

// GetVolumeIDs extracts volume IDs from PersistentVolume specs
Expand Down
6 changes: 3 additions & 3 deletions pkg/provider/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ func (p *Provider) createServerRequest(req *driver.CreateMachineRequest, provide
}

// Add MCM-specific labels for server identification and orphan VM detection
labels[StackitMachineLabel] = req.Machine.Name
labels[StackitMachineClassLabel] = req.MachineClass.Name
labels[p.GetMachineLabelKey()] = req.Machine.Name
labels[p.GetMachineClassLabelKey()] = req.MachineClass.Name

// Create server request
createReq := &client.CreateServerRequest{
Expand Down Expand Up @@ -212,7 +212,7 @@ func (p *Provider) createServerRequest(req *driver.CreateMachineRequest, provide
func (p *Provider) getServerByName(ctx context.Context, projectID, region, serverName string) (*client.Server, error) {
// Check if the server got already created
labelSelector := map[string]string{
StackitMachineLabel: serverName,
p.GetMachineLabelKey(): serverName,
}
servers, err := p.client.ListServers(ctx, projectID, region, labelSelector)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/provider/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (p *Provider) ListMachines(ctx context.Context, req *driver.ListMachinesReq

// Call STACKIT API to list all servers
labelSelector := map[string]string{
StackitMachineClassLabel: req.MachineClass.Name,
p.GetMachineClassLabelKey(): req.MachineClass.Name,
}
servers, err := p.client.ListServers(ctx, projectID, providerSpec.Region, labelSelector)
if err != nil {
Expand All @@ -51,15 +51,15 @@ func (p *Provider) ListMachines(ctx context.Context, req *driver.ListMachinesReq
}

// Filter servers by MachineClass label
// We use the "kubernetes.io/machineclass" label to identify which servers belong to this MachineClass
// We use the label to identify which servers belong to this MachineClass
machineList := make(map[string]string)
for _, server := range servers {
// Generate ProviderID in format: stackit://<projectId>/<serverId>
providerID := fmt.Sprintf("stackit://%s/%s", projectID, server.ID)

// Get machine name from labels (fallback to server name if not found)
machineName := server.Name
if machineLabel, ok := server.Labels[StackitMachineLabel]; ok {
if machineLabel, ok := server.Labels[p.GetMachineLabelKey()]; ok {
machineName = machineLabel
}

Expand Down
43 changes: 42 additions & 1 deletion pkg/provider/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ var _ = Describe("ListMachines", func() {
ctx = context.Background()
mockClient = &mock.StackitClient{}
provider = &Provider{
client: mockClient,
client: mockClient,
customLabelDomain: "kubernetes.io",
}

// Create secret with projectId
Expand Down Expand Up @@ -142,4 +143,44 @@ var _ = Describe("ListMachines", func() {
Expect(statusErr.Code()).To(Equal(codes.Internal))
})
})

Context("with custom label domain", func() {
It("should use custom domain in labels", func() {
customProvider := &Provider{
client: mockClient,
customLabelDomain: "custom.domain",
}

Expect(customProvider.GetMachineLabelKey()).To(Equal("custom.domain/machine"))
Expect(customProvider.GetMachineClassLabelKey()).To(Equal("custom.domain/machineclass"))
})

It("should list machines filtered by custom MachineClass label", func() {
customProvider := &Provider{
client: mockClient,
customLabelDomain: "custom.domain",
}

mockClient.ListServersFunc = func(_ context.Context, _, _ string, selector map[string]string) ([]*client.Server, error) {
Expect(selector["custom.domain/machineclass"]).To(Equal("test-machine-class"))

return []*client.Server{
{
ID: "server-1",
Name: "machine-1",
Labels: map[string]string{
"custom.domain/machineclass": "test-machine-class",
"custom.domain/machine": "machine-1",
},
},
}, nil
}

resp, err := customProvider.ListMachines(ctx, req)

Expect(err).NotTo(HaveOccurred())
Expect(resp).NotTo(BeNil())
Expect(resp.MachineList).To(HaveLen(1))
})
})
})
26 changes: 23 additions & 3 deletions pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package provider

import (
"fmt"
"os"
"sync"
"time"

Expand All @@ -28,14 +29,23 @@ type Provider struct {
// intervals need to be configurable to speed up tests
pollingInterval time.Duration // Interval between polling attempts
pollingTimeout time.Duration // Maximum time to wait during polling
// NOTE: only change this if you know what you are doing!
// changing this value without a migration plan could lead to orphaned cloud resources
customLabelDomain string
}

// NewProvider returns an empty provider object
func NewProvider(i spi.SessionProviderInterface) driver.Driver {
customLabelDomain := os.Getenv("CUSTOM_LABEL_DOMAIN")
if customLabelDomain == "" {
customLabelDomain = "kubernetes.io"
}

return &Provider{
SPI: i,
pollingInterval: 5 * time.Second,
pollingTimeout: 10 * time.Minute,
SPI: i,
pollingInterval: 5 * time.Second,
pollingTimeout: 10 * time.Minute,
customLabelDomain: customLabelDomain,
}
}

Expand Down Expand Up @@ -72,3 +82,13 @@ func (p *Provider) ensureClient(serviceAccountKey string) error {

return p.clientErr
}

// GetMachineLabelKey returns the fully-qualified machine label key using the configured label domain
func (p *Provider) GetMachineLabelKey() string {
return fmt.Sprintf("%s/machine", p.customLabelDomain)
}

// GetMachineClassLabelKey returns the fully-qualified machine class label key using the configured label domain
func (p *Provider) GetMachineClassLabelKey() string {
return fmt.Sprintf("%s/machineclass", p.customLabelDomain)
}