From 04627f6696e44b4a8c213001a3953fa92a96a46f Mon Sep 17 00:00:00 2001 From: xiamu Date: Mon, 19 Jan 2026 12:06:26 +0800 Subject: [PATCH 1/2] feat: deploy add novita provider 1. endpoint support novita provider. 2. initializer add novita watch replica setup. 3. specs.yaml add novita product sepc. 4. novita provider support interface.DeploymentProvider base api(eg: deploy/get/delete/update). --- app/handler/endpoint_handler.go | 3 +- app/router/router.go | 2 +- cmd/initializers.go | 70 +- config/config.example.yaml | 10 +- config/config.yaml | 8 + config/specs.yaml | 32 + pkg/config/config.go | 32 +- pkg/deploy/novita/README.md | 571 +++++++++++++ pkg/deploy/novita/client.go | 161 ++++ pkg/deploy/novita/mapper.go | 457 +++++++++++ pkg/deploy/novita/provider.go | 580 ++++++++++++++ pkg/deploy/novita/provider_test.go | 1196 ++++++++++++++++++++++++++++ pkg/deploy/novita/specs_config.go | 159 ++++ pkg/deploy/novita/types.go | 232 ++++++ pkg/provider/factory.go | 2 + 15 files changed, 3498 insertions(+), 17 deletions(-) create mode 100644 pkg/deploy/novita/README.md create mode 100644 pkg/deploy/novita/client.go create mode 100644 pkg/deploy/novita/mapper.go create mode 100644 pkg/deploy/novita/provider.go create mode 100644 pkg/deploy/novita/provider_test.go create mode 100644 pkg/deploy/novita/specs_config.go create mode 100644 pkg/deploy/novita/types.go diff --git a/app/handler/endpoint_handler.go b/app/handler/endpoint_handler.go index eb767dc..02f441a 100644 --- a/app/handler/endpoint_handler.go +++ b/app/handler/endpoint_handler.go @@ -20,7 +20,8 @@ import ( "github.com/gin-gonic/gin" ) -// EndpointHandler handles endpoint lifecycle APIs (metadata + K8s deployment) +// EndpointHandler handles endpoint lifecycle APIs (metadata + deployment) +// Supports multiple deployment providers: K8s, Novita, etc. type EndpointHandler struct { deploymentProvider interfaces.DeploymentProvider endpointService *endpointsvc.Service diff --git a/app/router/router.go b/app/router/router.go index 9d3d475..223060a 100644 --- a/app/router/router.go +++ b/app/router/router.go @@ -82,7 +82,7 @@ func (r *Router) Setup(engine *gin.Engine) { v2.POST("/job-stream/:worker_id/:task_id", r.workerHandler.SubmitResult) } - // API v1 - K8s application management interface (if enabled) + // API v1 - Endpoint management interface (K8s or Novita, if enabled) if r.endpointHandler != nil { api := engine.Group("/api/v1") { diff --git a/cmd/initializers.go b/cmd/initializers.go index e9661ad..fbba5a7 100644 --- a/cmd/initializers.go +++ b/cmd/initializers.go @@ -13,6 +13,7 @@ import ( "waverless/pkg/autoscaler" "waverless/pkg/config" "waverless/pkg/deploy/k8s" + "waverless/pkg/deploy/novita" "waverless/pkg/interfaces" "waverless/pkg/logger" "waverless/pkg/monitoring" @@ -120,6 +121,14 @@ func (app *Application) initServices() error { } } + // Get Novita deployment provider for status sync + var novitaDeployProvider *novita.NovitaDeploymentProvider + if app.config.Novita.Enabled { + if novitaProv, ok := app.deploymentProvider.(*novita.NovitaDeploymentProvider); ok { + novitaDeployProvider = novitaProv + } + } + // Initialize worker service (MySQL-based) app.workerService = service.NewWorkerService( app.mysqlRepo.Worker, @@ -198,6 +207,12 @@ func (app *Application) initServices() error { // Non-critical feature, continue startup } + // Setup Novita status watcher for endpoint status sync (when Novita is enabled) + if err := app.setupNovitaStatusWatcher(novitaDeployProvider); err != nil { + logger.WarnCtx(app.ctx, "Failed to setup Novita status watcher: %v (non-critical, continuing)", err) + // Non-critical feature, continue startup + } + return nil } @@ -307,6 +322,49 @@ func (app *Application) setupSpotInterruptionWatcher(k8sProvider *k8s.K8sDeploym return nil } +// setupNovitaStatusWatcher sets up Novita status watcher for endpoint status sync +func (app *Application) setupNovitaStatusWatcher(novitaProvider *novita.NovitaDeploymentProvider) error { + if novitaProvider == nil { + logger.InfoCtx(app.ctx, "Novita provider not available, skipping status watcher setup") + return nil + } + + logger.InfoCtx(app.ctx, "Setting up Novita status watcher for endpoint status sync...") + + // Register replica watch callback to sync status to database + err := novitaProvider.WatchReplicas(app.ctx, func(event interfaces.ReplicaEvent) { + endpoint := event.Name + + // Calculate status based on replica state + status := "Pending" + if event.AvailableReplicas == event.DesiredReplicas && event.DesiredReplicas > 0 { + status = "Running" + } else if event.DesiredReplicas == 0 { + status = "Stopped" + } + + // Update endpoint runtime state in database + if app.mysqlRepo != nil && app.mysqlRepo.Endpoint != nil { + runtimeState := map[string]interface{}{ + "replicas": event.DesiredReplicas, + "readyReplicas": event.ReadyReplicas, + "availableReplicas": event.AvailableReplicas, + } + + if err := app.mysqlRepo.Endpoint.UpdateRuntimeState(app.ctx, endpoint, status, runtimeState); err != nil { + logger.ErrorCtx(app.ctx, "Failed to update Novita endpoint runtime state: %v", err) + } + } + }) + + if err != nil { + logger.WarnCtx(app.ctx, "Failed to register Novita status watcher: %v", err) + return err + } + + return nil +} + // setupDeploymentWatcher sets up Deployment change listener (optimizes rolling updates) // This watcher only sets Pod Deletion Cost to guide K8s on which pods to delete first // It does NOT mark workers as DRAINING - that's handled by setupPodWatcher when pods are actually terminated @@ -429,12 +487,18 @@ func (app *Application) initHandlers() error { app.statisticsHandler = handler.NewStatisticsHandler(app.statisticsService, app.workerService) app.monitoringHandler = handler.NewMonitoringHandler(app.monitoringService) - // Initialize K8s Handler (Endpoint Handler) - if app.config.K8s.Enabled { + // Initialize Endpoint Handler (for K8s or Novita) + if app.config.K8s.Enabled || app.config.Novita.Enabled { if app.deploymentProvider == nil { - logger.ErrorCtx(app.ctx, "K8s is enabled but deployment provider is nil") + logger.ErrorCtx(app.ctx, "Deployment provider is enabled but provider is nil") } else { app.endpointHandler = handler.NewEndpointHandler(app.deploymentProvider, app.endpointService, app.workerService) + if app.config.K8s.Enabled { + logger.InfoCtx(app.ctx, "Endpoint handler initialized for K8s") + } + if app.config.Novita.Enabled { + logger.InfoCtx(app.ctx, "Endpoint handler initialized for Novita") + } } } diff --git a/config/config.example.yaml b/config/config.example.yaml index 6bc58bf..9bc864b 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -84,6 +84,14 @@ notification: feishu_webhook_url: "" # Example: "https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxx" providers: - deployment: "k8s" # k8s, docker + deployment: "k8s" # k8s, docker, novita queue: "redis" # redis, mysql metadata: "mysql" # redis, mysql + +# Novita Serverless Configuration +novita: + enabled: false # Enable Novita serverless provider + api_key: "" # Your Novita API key (Bearer token) + base_url: "https://api.novita.ai" # Novita API base URL + config_dir: "./config" # Configuration directory (contains specs.yaml and templates/) + poll_interval: 10 # Poll interval for status updates (seconds, default: 10) diff --git a/config/config.yaml b/config/config.yaml index eb2d35d..d69f526 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -64,3 +64,11 @@ providers: metadata: mysql # Metadata storage: mysql (persistent), redis (ephemeral) # MySQL stores: endpoints, tasks, autoscaler configs, scaling events # Redis stores: worker heartbeats, task queues, distributed locks, cache + +# Novita Serverless Configuration +novita: + enabled: false # Enable Novita serverless provider + api_key: "" # Your Novita API key (Bearer token) + base_url: "https://api.novita.ai" # Novita API base URL + config_dir: "./config" # Configuration directory (contains specs.yaml and templates/) + poll_interval: 10 # Poll interval for status updates (seconds, default: 10) diff --git a/config/specs.yaml b/config/specs.yaml index fb3d4a3..b5879e4 100644 --- a/config/specs.yaml +++ b/config/specs.yaml @@ -92,3 +92,35 @@ specs: operator: "Equal" value: "gpu" effect: "NoSchedule" + + # Novita 5090 Single GPU + - name: "novita-5090-single" + displayName: "Novita 5090 1x GPU" + category: "gpu" + resourceType: "serverless" + resources: + gpu: "1" + gpuType: "NVIDIA GeForce RTX 5090" + cpu: "12" + memory: "50" + ephemeralStorage: "100" + platforms: + novita: + productId: "SL-serverless-3" # Replace with actual Novita product ID + region: "us-dallas-nas-2" + + # Novita 4090 Single GPU + - name: "novita-4090-single" + displayName: "Novita 4090 1x GPU" + category: "gpu" + resourceType: "serverless" + resources: + gpu: "1" + gpuType: "NVIDIA GeForce RTX 4090" + cpu: "8" + memory: "32" + ephemeralStorage: "50" + platforms: + novita: + productId: "1" # Replace with actual Novita product ID + region: "us-dallas-nas-2" \ No newline at end of file diff --git a/pkg/config/config.go b/pkg/config/config.go index d1bd900..634d940 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -18,9 +18,10 @@ type Config struct { Logger LoggerConfig `yaml:"logger"` K8s K8sConfig `yaml:"k8s"` AutoScaler AutoScalerConfig `yaml:"autoscaler"` - Docker DockerConfig `yaml:"docker"` // Docker registry authentication - Notification NotificationConfig `yaml:"notification"` // Notification configuration - Providers *ProvidersConfig `yaml:"providers,omitempty"` // Providers configuration (optional) + Docker DockerConfig `yaml:"docker"` // Docker registry authentication + Notification NotificationConfig `yaml:"notification"` // Notification configuration + Providers *ProvidersConfig `yaml:"providers,omitempty"` // Providers configuration (optional) + Novita NovitaConfig `yaml:"novita"` // Novita serverless configuration } // ServerConfig server configuration @@ -80,18 +81,18 @@ type LoggerConfig struct { // LoggerFileConfig logger file configuration type LoggerFileConfig struct { Path string `yaml:"path"` - MaxSize int `yaml:"max_size"` // MB per file (default: 100) - MaxBackups int `yaml:"max_backups"` // max backup files (default: 3) - MaxAge int `yaml:"max_age"` // days to keep (default: 7) - Compress bool `yaml:"compress"` // compress rotated files (default: false) + MaxSize int `yaml:"max_size"` // MB per file (default: 100) + MaxBackups int `yaml:"max_backups"` // max backup files (default: 3) + MaxAge int `yaml:"max_age"` // days to keep (default: 7) + Compress bool `yaml:"compress"` // compress rotated files (default: false) } // K8sConfig K8s configuration type K8sConfig struct { - Enabled bool `yaml:"enabled"` // whether to enable K8s features - Namespace string `yaml:"namespace"` // K8s namespace - Platform string `yaml:"platform"` // Platform type: generic, aliyun-ack, aws-eks - ConfigDir string `yaml:"config_dir"` // Configuration directory (specs.yaml and templates) + Enabled bool `yaml:"enabled"` // whether to enable K8s features + Namespace string `yaml:"namespace"` // K8s namespace + Platform string `yaml:"platform"` // Platform type: generic, aliyun-ack, aws-eks + ConfigDir string `yaml:"config_dir"` // Configuration directory (specs.yaml and templates) } // ProvidersConfig providers configuration @@ -129,6 +130,15 @@ type NotificationConfig struct { FeishuWebhookURL string `yaml:"feishu_webhook_url"` // Feishu (Lark) webhook URL } +// NovitaConfig Novita serverless configuration +type NovitaConfig struct { + Enabled bool `yaml:"enabled"` // Whether to enable Novita provider + APIKey string `yaml:"api_key"` // Novita API key (Bearer token) + BaseURL string `yaml:"base_url"` // API base URL, default: https://api.novita.ai + ConfigDir string `yaml:"config_dir"` // Configuration directory (specs.yaml and templates) + PollInterval int `yaml:"poll_interval"` // Poll interval for status updates (seconds, default: 10) +} + // Init initializes configuration func Init() error { configPath := os.Getenv("CONFIG_PATH") diff --git a/pkg/deploy/novita/README.md b/pkg/deploy/novita/README.md new file mode 100644 index 0000000..069aacc --- /dev/null +++ b/pkg/deploy/novita/README.md @@ -0,0 +1,571 @@ +# Novita Serverless Provider + +This package implements the `DeploymentProvider` interface for [Novita AI Serverless](https://novita.ai) platform. + +## Features + +### ✅ Implemented Core Features + +- **Deploy**: Create serverless endpoints with auto-scaling workers +- **GetApp**: Retrieve endpoint details and status +- **ListApps**: List all deployed endpoints +- **DeleteApp**: Delete endpoints and clean up resources +- **ScaleApp**: Scale worker count (min/max replicas) +- **GetAppStatus**: Get real-time endpoint status +- **UpdateDeployment**: Update endpoint configuration (image, replicas, env vars) +- **ListSpecs**: List available GPU specifications +- **GetSpec**: Get specific spec details +- **PreviewDeploymentYAML**: Preview Novita configuration as JSON +- **WatchReplicas**: Monitor endpoint status changes via polling (configurable interval) + +### ⚠️ Limitations & Differences + +The following features are **not supported** by Novita's API and will return friendly error messages: + +- **GetAppLogs**: Logs must be accessed via [Novita Dashboard](https://console.novita.ai) +- **GetPods / DescribePod / GetPodYAML**: Novita manages workers internally; use `GetApp` for worker status +- **ListPVCs**: Storage is managed by Novita; persistent storage not yet supported +- **Volume Mounts**: PVC volume mounts are not supported (Novita uses network storage) +- **ShmSize**: Shared memory size configuration not applicable to Novita +- **EnablePtrace**: Ptrace capability not applicable to Novita + +**Behavioral Differences:** + +- **Replica Watching**: Uses polling instead of real-time watch (configurable interval) +- **Worker Lifecycle**: Workers are managed by Novita's auto-scaling system +- **Health Checks**: Fixed to `/health` endpoint on port 8000 +- **Networking**: Endpoints are accessible via Novita's managed load balancer +- **Region**: Specified per-spec (not per-deployment) in `specs.yaml` + +## Configuration + +### 1. Enable Novita Provider + +Edit `config/config.yaml`: + +```yaml +providers: + deployment: "novita" # Switch from k8s to novita + +novita: + enabled: true + api_key: "your-novita-api-key-here" # Your Novita API key (Bearer token) + base_url: "https://api.novita.ai" # Novita API base URL + config_dir: "./config" # Configuration directory (contains specs.yaml) + poll_interval: 10 # Status polling interval in seconds (default: 10) +``` + +### 2. Configure Specs + +Add Novita-compatible specs to `config/specs.yaml`: + +```yaml +specs: + # Novita 5090 Single GPU + - name: "novita-5090-single" + displayName: "Novita 5090 1x GPU" + category: "gpu" + resourceType: "serverless" + resources: + gpu: "1" + gpuType: "NVIDIA GeForce RTX 5090" + cpu: "12" + memory: "50" + ephemeralStorage: "100" # Rootfs size in GB + platforms: + novita: + productId: "SL-serverless-3" # Get from Novita Console + region: "us-dallas-nas-2" # Region/cluster ID + cudaVersion: "any" # Optional: CUDA version + + # Novita 4090 Single GPU + - name: "novita-4090-single" + displayName: "Novita 4090 1x GPU" + category: "gpu" + resourceType: "serverless" + resources: + gpu: "1" + gpuType: "NVIDIA GeForce RTX 4090" + cpu: "8" + memory: "32" + ephemeralStorage: "50" + platforms: + novita: + productId: "1" # Replace with actual Novita product ID + region: "us-dallas-nas-2" + cudaVersion: "any" +``` + +**Important Notes:** +- `ephemeralStorage`: Specifies rootfs disk size in GB (e.g., "100" = 100GB) +- `productId`: Get from [Novita Console](https://console.novita.ai) +- `region`: Novita cluster/region ID (e.g., "us-dallas-nas-2") +- `cudaVersion`: Optional CUDA version specification + +## Usage + +### Deploy Endpoint + +**Via API:** + +```bash +curl -X POST http://localhost:8090/api/v1/endpoints \ + -H "Content-Type: application/json" \ + -d '{ + "endpoint": "base-test", + "specName": "novita-5090-single", + "image": "ubuntu:22.04", + "replicas": 3, + "taskTimeout": 1200, + "env": { + "MODEL_NAME": "llama-3-70b" + } + }' +``` + +**Via Go SDK:** + +```go +import ( + "context" + "waverless/pkg/interfaces" +) + +req := &interfaces.DeployRequest{ + Endpoint: "my-inference-endpoint", + SpecName: "novita-5090-single", + Image: "your-docker-image:latest", + Replicas: 2, + TaskTimeout: 1200, // Task timeout in seconds + Env: map[string]string{ + "MODEL_NAME": "llama-3-70b", + }, +} + +resp, err := provider.Deploy(ctx, req) +``` + +**Request Fields:** +- `endpoint` (required): Unique endpoint name +- `specName` (required): Spec name from `specs.yaml` +- `image` (required): Docker image URL +- `replicas` (optional): Number of workers (default: 1) +- `taskTimeout` (optional): Task timeout in seconds (default: 3600) +- `env` (optional): Environment variables +- `volumeMounts` (optional): PVC volume mounts (not fully supported by Novita yet) +- `shmSize` (optional): Shared memory size (e.g., "1Gi", not applicable to Novita) +- `enablePtrace` (optional): Enable ptrace capability (not applicable to Novita) + +**Complete API Request Example (with all fields):** + +```json +{ + "endpoint": "base-test", + "specName": "novita-5090-single", + "image": "ubuntu:22.04", + "replicas": 3, + "taskTimeout": 1200, + "maxPendingTasks": 10, + "env": { + "MODEL_NAME": "llama-3-70b", + "MODEL_VERSION": "v1.0" + }, + "minReplicas": 1, + "maxReplicas": 10, + "scaleUpThreshold": 2, + "scaleDownIdleTime": 300, + "scaleUpCooldown": 30, + "scaleDownCooldown": 60, + "priority": 50, + "enableDynamicPrio": true, + "highLoadThreshold": 10, + "priorityBoost": 20 +} +``` + +**Autoscaling Fields (optional):** +- `minReplicas`: Minimum replica count (default: 0, scale-to-zero) +- `maxReplicas`: Maximum replica count (default: 10) +- `scaleUpThreshold`: Queue threshold for scale up (default: 1) +- `scaleDownIdleTime`: Idle time before scale down in seconds (default: 300) +- `scaleUpCooldown`: Scale up cooldown in seconds (default: 30) +- `scaleDownCooldown`: Scale down cooldown in seconds (default: 60) +- `priority`: Priority for resource allocation (0-100, default: 50) +- `enableDynamicPrio`: Enable dynamic priority (default: true) +- `highLoadThreshold`: High load threshold for priority boost (default: 10) +- `priorityBoost`: Priority boost amount when high load (default: 20) + +### Region Configuration + +Region is specified in the spec configuration (`config/specs.yaml`): + +```yaml +platforms: + novita: + productId: "SL-serverless-3" + region: "us-dallas-nas-2" # Novita cluster/region ID +``` + +The region value is used as the `clusterID` when creating Novita endpoints. Common regions include: +- `us-dallas-nas-2` (US Dallas) +- `us-west-1` (US West) +- `us-east-1` (US East) + +For available regions, check the [Novita Console](https://console.novita.ai). + +### Preview Deployment Configuration + +Preview the Novita API configuration before deploying: + +**Via API:** + +```bash +curl -X POST http://localhost:8090/api/v1/endpoints/preview \ + -H "Content-Type: application/json" \ + -d '{ + "endpoint": "base-test", + "specName": "novita-5090-single", + "image": "ubuntu:22.04", + "replicas": 3, + "taskTimeout": 1200 + }' +``` + +This returns the complete Novita API request JSON that will be sent to create the endpoint. + +### Scale Endpoint + +```bash +curl -X POST http://localhost:8090/api/v1/endpoints/base-test/scale \ + -H "Content-Type: application/json" \ + -d '{"replicas": 5}' +``` + +**Via Go SDK:** + +```go +// Scale to 5 workers +err := provider.ScaleApp(ctx, "base-test", 5) +``` + +### Get Endpoint Status + +**Via API:** + +```bash +# Get detailed endpoint info +curl http://localhost:8090/api/v1/endpoints/base-test + +# Get status only +curl http://localhost:8090/api/v1/endpoints/base-test/status +``` + +**Via Go SDK:** + +```go +// Get detailed info +app, err := provider.GetApp(ctx, "base-test") +fmt.Printf("Status: %s, Replicas: %d/%d\n", + app.Status, app.ReadyReplicas, app.Replicas) + +// Get status only +status, err := provider.GetAppStatus(ctx, "base-test") +fmt.Printf("Ready Workers: %d/%d\n", + status.ReadyReplicas, status.TotalReplicas) +``` + +### List All Endpoints + +**Via API:** + +```bash +curl http://localhost:8090/api/v1/endpoints +``` + +**Via Go SDK:** + +```go +apps, err := provider.ListApps(ctx) +for _, app := range apps { + fmt.Printf("Endpoint: %s, Status: %s, Replicas: %d/%d\n", + app.Name, app.Status, app.ReadyReplicas, app.Replicas) +} +``` + +### Delete Endpoint + +**Via API:** + +```bash +curl -X DELETE http://localhost:8090/api/v1/endpoints/base-test +``` + +**Via Go SDK:** + +```go +err := provider.DeleteApp(ctx, "base-test") +``` + +### Update Deployment + +**Via API:** + +```bash +curl -X PATCH http://localhost:8090/api/v1/endpoints/base-test/deployment \ + -H "Content-Type: application/json" \ + -d '{ + "image": "ubuntu:22.04", + "replicas": 5, + "env": { + "MODEL_VERSION": "v2" + } + }' +``` + +**Via Go SDK:** + +```go +replicas := 3 +req := &interfaces.UpdateDeploymentRequest{ + Endpoint: "my-inference-endpoint", + Image: "new-image:v2", + Replicas: &replicas, + Env: &map[string]string{ + "MODEL_VERSION": "v2", + }, +} + +resp, err := provider.UpdateDeployment(ctx, req) +``` + +**Update Fields (all optional):** +- `image`: New Docker image +- `replicas`: New worker count (pointer to distinguish from zero) +- `env`: New environment variables (replaces all existing env vars) +- `taskTimeout`: New task timeout in seconds + +### Watch Status Changes + +Monitor endpoint status changes in real-time: + +```go +// Register callback to monitor endpoint status changes +err := provider.WatchReplicas(ctx, func(event interfaces.ReplicaEvent) { + fmt.Printf("[%s] Status changed: desired=%d, ready=%d, available=%d, status=%v\n", + event.Name, + event.DesiredReplicas, + event.ReadyReplicas, + event.AvailableReplicas, + event.Conditions) +}) +``` + +**Polling Mechanism:** + +Unlike K8s which provides real-time watch API, Novita provider uses a **polling mechanism**: + +1. Polls all endpoints at configured interval (default: 10s) +2. Compares current state with cached previous state +3. Triggers callback only when state changes (replicas, status, etc.) +4. Automatically handles endpoint lifecycle (creation, deletion) + +**Configure Poll Interval:** + +```yaml +novita: + poll_interval: 10 # Poll every 10 seconds (default: 10) +``` + +**Status Change Events:** + +The callback receives `ReplicaEvent` with: +- `Name`: Endpoint name +- `DesiredReplicas`: Target worker count (workerConfig.maxNum) +- `ReadyReplicas`: Number of healthy workers +- `AvailableReplicas`: Number of running workers +- `Conditions`: Status conditions (Available, Progressing, Failed, etc.) + +## API Mapping + +### Waverless → Novita Field Mapping + +| Waverless Field | Novita API Field | Notes | +|-----------------|------------------|-------| +| `endpoint` | `endpoint.name` | Unique endpoint identifier | +| `specName` | `products[].id` + `clusterID` | Product ID and region from `specs.yaml` | +| `replicas` | `workerConfig.minNum/maxNum` | Initially set to same value for fixed count | +| `image` | `image.image` | Docker image URL | +| `env` | `envs[]` | Array of `{key, value}` objects | +| `taskTimeout` | `workerConfig.freeTimeout` | Worker idle timeout in seconds (default: 300) | +| - | `workerConfig.requestTimeout` | Request timeout (default: 3600) | +| `spec.resources.gpu` | `workerConfig.gpuNum` | Number of GPUs per worker | +| `spec.resources.ephemeralStorage` | `endpoint.rootfsSize` | Rootfs disk size in GB | +| `spec.platforms.novita.region` | `endpoint.clusterID` | Novita cluster/region ID | +| `spec.platforms.novita.productId` | `products[0].id` | Novita product/GPU type ID | +| `spec.platforms.novita.cudaVersion` | `workerConfig.cudaVersion` | CUDA version (optional) | + +### Novita-Specific Default Values + +| Field | Default Value | Description | +|-------|--------------|-------------| +| `ports` | `[{port: 8000}]` | Default HTTP port for health check and requests | +| `policy` | `{type: "queue", value: 60}` | Auto-scaling policy: queue wait time | +| `healthy.path` | `/health` | Health check endpoint path | +| `workerConfig.maxConcurrent` | `1` | Max concurrent requests per worker | +| `workerConfig.freeTimeout` | `300` | Worker idle timeout (5 minutes) | +| `workerConfig.requestTimeout` | `3600` | Request timeout (1 hour) | + +## Architecture + +``` +┌─────────────────────────────────────────────────────────┐ +│ NovitaDeploymentProvider │ +├─────────────────────────────────────────────────────────┤ +│ - Deploy / GetApp / ListApps / DeleteApp │ +│ - ScaleApp / GetAppStatus / UpdateDeployment │ +│ - Endpoint ID caching (name -> ID mapping) │ +└──────────────────┬──────────────────────────────────────┘ + │ + ┌─────────┴─────────┐ + │ │ + ┌────▼─────┐ ┌─────▼──────┐ + │ Client │ │ Mapper │ + ├──────────┤ ├────────────┤ + │ HTTP API │ │ Data │ + │ calls to │ │ conversion │ + │ Novita │ │ logic │ + └──────────┘ └────────────┘ + │ + ▼ + Novita AI API +``` + +## Testing + +### Manual Testing + +1. Set up configuration: + ```bash + cp config/config.example.yaml config/config.yaml + # Edit config.yaml with your Novita API key + ``` + +2. Add test specs: + ```bash + cat config/specs-novita-example.yaml >> config/specs.yaml + ``` + +3. Deploy test endpoint: + ```bash + curl -X POST http://localhost:8090/api/v1/endpoints \ + -H "Content-Type: application/json" \ + -d '{ + "endpoint": "test-endpoint", + "specName": "novita-5090-single", + "image": "ubuntu:22.04", + "replicas": 1, + "taskTimeout": 1200 + }' + ``` + +## Troubleshooting + +### Error: "novita provider is not enabled" + +**Cause**: Provider not enabled in configuration + +**Solution**: Enable Novita provider in `config.yaml`: + +```yaml +providers: + deployment: "novita" + +novita: + enabled: true + api_key: "your-api-key" +``` + +### Error: "novita API key is required" + +**Cause**: Missing API key in configuration + +**Solution**: Set `novita.api_key` in `config.yaml`: + +```yaml +novita: + api_key: "your-novita-api-key-here" +``` + +Get API key from [Novita Console](https://console.novita.ai). + +### Error: "no Novita product ID found for spec" + +**Cause**: Missing platform configuration for the spec + +**Solution**: Add `platforms.novita` section to your spec in `specs.yaml`: + +```yaml +platforms: + novita: + productId: "SL-serverless-3" # Get from Novita Console + region: "us-dallas-nas-2" +``` + +### Error: "endpoint not found in Novita" + +**Cause**: Endpoint was deleted outside Waverless or cache is stale + +**Solution**: The provider automatically refreshes the cache on next API call. If the endpoint truly doesn't exist, create a new one. + +### Error: "failed to parse rootfs size" + +**Cause**: Invalid `ephemeralStorage` format in spec resources + +**Solution**: Use numeric value (GB) without unit: + +```yaml +resources: + ephemeralStorage: "100" # 100GB (correct) + # NOT: "100Gi" or "100GB" +``` + +### Error: "failed to create Novita endpoint" (401 Unauthorized) + +**Cause**: Invalid API key or expired token + +**Solution**: +1. Verify API key in [Novita Console](https://console.novita.ai) +2. Update `config.yaml` with correct API key +3. Restart Waverless service + +### Error: "failed to create Novita endpoint" (404 Not Found) + +**Cause**: Invalid product ID or region + +**Solution**: +1. Check available products in [Novita Console](https://console.novita.ai) +2. Update `specs.yaml` with correct `productId` and `region` +3. Verify the product is available in the specified region + +## Development + +### Adding New Features + +1. Check if Novita API supports the feature +2. Add types to `types.go` if needed +3. Implement in `client.go` (HTTP calls) +4. Add mapping logic in `mapper.go` +5. Expose via `provider.go` + +### Testing Locally + +```bash +# Run with Novita provider +go run cmd/main.go --config config/config.yaml +``` + +## References + +- [Novita AI Serverless API Documentation](https://novita.ai/docs/api-reference/serverless-create-endpoint) +- [Novita Console](https://console.novita.ai) +- [Waverless Architecture](../../../docs/ARCHITECTURE.md) diff --git a/pkg/deploy/novita/client.go b/pkg/deploy/novita/client.go new file mode 100644 index 0000000..443d1ba --- /dev/null +++ b/pkg/deploy/novita/client.go @@ -0,0 +1,161 @@ +package novita + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "waverless/pkg/config" + "waverless/pkg/logger" +) + +// Client is the Novita API client +type Client struct { + apiKey string + baseURL string + httpClient *http.Client +} + +// NewClient creates a new Novita API client +func NewClient(cfg *config.NovitaConfig) *Client { + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://api.novita.ai" + } + + return &Client{ + apiKey: cfg.APIKey, + baseURL: baseURL, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// CreateEndpoint creates a new endpoint +func (c *Client) CreateEndpoint(ctx context.Context, req *CreateEndpointRequest) (*CreateEndpointResponse, error) { + url := c.baseURL + "/gpu-instance/openapi/v1/endpoint/create" + + respData, err := c.doRequest(ctx, "POST", url, req) + if err != nil { + return nil, err + } + + var resp CreateEndpointResponse + if err := json.Unmarshal(respData, &resp); err != nil { + return nil, fmt.Errorf("failed to parse create endpoint response: %w", err) + } + + return &resp, nil +} + +// GetEndpoint gets endpoint details +func (c *Client) GetEndpoint(ctx context.Context, endpointID string) (*GetEndpointResponse, error) { + url := fmt.Sprintf("%s/gpu-instance/openapi/v1/endpoint?id=%s", c.baseURL, endpointID) + + respData, err := c.doRequest(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + var resp GetEndpointResponse + if err := json.Unmarshal(respData, &resp); err != nil { + return nil, fmt.Errorf("failed to parse get endpoint response: %w", err) + } + + return &resp, nil +} + +// ListEndpoints lists all endpoints +func (c *Client) ListEndpoints(ctx context.Context) (*ListEndpointsResponse, error) { + url := c.baseURL + "/gpu-instance/openapi/v1/endpoints" + + respData, err := c.doRequest(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + var resp ListEndpointsResponse + if err := json.Unmarshal(respData, &resp); err != nil { + return nil, fmt.Errorf("failed to parse list endpoints response: %w", err) + } + + return &resp, nil +} + +// UpdateEndpoint updates an existing endpoint +func (c *Client) UpdateEndpoint(ctx context.Context, req *UpdateEndpointRequest) error { + url := c.baseURL + "/gpu-instance/openapi/v1/endpoint/update" + + _, err := c.doRequest(ctx, "POST", url, req) + return err +} + +// DeleteEndpoint deletes an endpoint +func (c *Client) DeleteEndpoint(ctx context.Context, endpointID string) error { + url := c.baseURL + "/gpu-instance/openapi/v1/endpoint/delete" + + req := &DeleteEndpointRequest{ + ID: endpointID, + } + + _, err := c.doRequest(ctx, "POST", url, req) + return err +} + +// doRequest performs an HTTP request with proper authentication +func (c *Client) doRequest(ctx context.Context, method, url string, body interface{}) ([]byte, error) { + var reqBody io.Reader + if body != nil { + jsonData, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + reqBody = bytes.NewReader(jsonData) + + // Log request for debugging + logger.Debugf("Novita API Request: %s %s, Body: %s", method, url, string(jsonData)) + } else { + logger.Debugf("Novita API Request: %s %s", method, url) + } + + req, err := http.NewRequestWithContext(ctx, method, url, reqBody) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.apiKey) + + // Execute request + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute HTTP request: %w", err) + } + defer resp.Body.Close() + + // Read response body + respData, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Log response for debugging + logger.Debugf("Novita API Response: Status %d, Body: %s", resp.StatusCode, string(respData)) + + // Check for HTTP errors + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var errResp ErrorResponse + if err := json.Unmarshal(respData, &errResp); err == nil && errResp.Message != "" { + return nil, fmt.Errorf("novita API error (status %d): %s", resp.StatusCode, errResp.Message) + } + return nil, fmt.Errorf("novita API error (status %d): %s", resp.StatusCode, string(respData)) + } + + return respData, nil +} diff --git a/pkg/deploy/novita/mapper.go b/pkg/deploy/novita/mapper.go new file mode 100644 index 0000000..8eaea3a --- /dev/null +++ b/pkg/deploy/novita/mapper.go @@ -0,0 +1,457 @@ +package novita + +import ( + "fmt" + "strconv" + "strings" + + "waverless/pkg/interfaces" + "waverless/pkg/logger" +) + +const ( + // Spec names + SpecNameNovitaH100Single = "novita-h100-single" + SpecNameNovitaA100Single = "novita-a100-single" + SpecNameNovitaA10Single = "novita-a10-single" + SpecNameNovitaH200Single = "novita-h200-single" + SpecNameH100Single = "h100-single" + SpecNameA1004x = "a100-4x" + + // Product IDs (placeholder values, replace with actual Novita product IDs) + ProductIDH100Single = "novita-h100-80gb-product-id" + ProductIDA100Single = "novita-a100-40gb-product-id" + ProductIDA10Single = "novita-a10-24gb-product-id" + ProductIDH200Single = "novita-h200-141gb-product-id" + ProductIDA1004x = "novita-a100-4x-product-id" + + // GPU Types + GPUTypeH100 = "NVIDIA-H100" + GPUTypeA100 = "NVIDIA-A100" + GPUTypeA10 = "NVIDIA-A10" + GPUTypeH200 = "NVIDIA-H200" + + // Resource category and type + CategoryGPU = "gpu" + ResourceTypeServerless = "serverless" + + // Standard GPU counts + GPUCount1 = "1" + + // CPU and Memory configurations + CPUH100 = "16" + MemoryH100 = "80Gi" + CPUA100 = "12" + MemoryA100 = "64Gi" + CPUA10 = "8" + MemoryA10 = "32Gi" + CPUH200 = "16" + MemoryH200 = "141Gi" + + // Display names + DisplayNameH100Single = "Novita H100 1x GPU" + DisplayNameA100Single = "Novita A100 1x GPU" + DisplayNameA10Single = "Novita A10 1x GPU" + DisplayNameH200Single = "Novita H200 1x GPU" + + // Default values + DefaultPort = 8000 + DefaultHealthPath = "/health" + DefaultRootfsSize = 100 + DefaultFreeTimeout = 300 // 5 minutes + DefaultMaxConcurrent = 1 + DefaultRegion = "us-dallas-nas-2" // Default region + DefaultQueueWaitTime = 60 // 60 seconds queue wait time + DefaultLocalStorageGB = 30 // 30GB local storage + DefaultRequestTimeout = 3600 // 1 hour request timeout + + // Label keys + LabelKeyRegion = "region" + LabelKeyProvider = "provider" + LabelKeyEndpointID = "endpoint-id" + + // Label values + LabelValueNovita = "novita" + + // Policy types + PolicyTypeQueue = "queue" + PolicyTypeConcurrency = "concurrency" + + // Storage types + // StorageTypeLocal = "local" + StorageTypeNetwork = "network" + + // App/Endpoint types + TypeServerlessEndpoint = "ServerlessEndpoint" + + // Status strings + StatusRunning = "Running" + StatusStopped = "Stopped" + StatusFailed = "Failed" + StatusPending = "Pending" + StatusCreating = "Creating" + StatusUpdating = "Updating" + StatusTerminating = "Terminating" + StatusUnknown = "Unknown" + + // Novita status strings (lowercase) + NovitaStatusServing = "serving" // Endpoint is serving (available) + NovitaStatusRunning = "running" // Worker is running + NovitaStatusStopped = "stopped" + NovitaStatusFailed = "failed" + NovitaStatusPending = "pending" + NovitaStatusCreating = "creating" + NovitaStatusUpdating = "updating" + NovitaStatusDeleting = "deleting" + + // Environment variable keys + EnvKeyNovitaProvider = "NOVITA_PROVIDER" + EnvKeyProviderType = "PROVIDER_TYPE" + EnvKeyNovitaRegion = "NOVITA_REGION" + + // Environment variable values + EnvValueTrue = "true" + EnvValueNovita = "novita" + + // Messages + MessageNoStatusInfo = "No status information available" + MessageDeploySuccess = "Endpoint deployed successfully" + MessageUpdateSuccess = "Endpoint updated successfully" + MessageDeleteSuccess = "Successfully deleted endpoint" + MessageNotSupported = "not supported by Novita provider" + MessageLogsNotSupported = "GetAppLogs is not supported by Novita provider - please use Novita dashboard for logs" + MessagePodsNotSupported = "GetPods is not supported by Novita provider - Novita manages workers internally" + MessageWatchNotSupported = "WatchReplicas is implemented using polling mechanism" +) + +// mapDeployRequestToNovita converts Waverless DeployRequest to Novita CreateEndpointRequest +func mapDeployRequestToNovita(req *interfaces.DeployRequest, spec *interfaces.SpecInfo) (*CreateEndpointRequest, error) { + novitaConfig, ok := spec.Platforms[PlatformNovita].(PlatformConfig) + if !ok { + return nil, fmt.Errorf("novita config not found for spec %s", spec.Name) + } + gpuNum, err := strconv.ParseInt(spec.Resources.GPU, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse GPU number: %w", err) + } + rootfsSize, err := strconv.ParseInt(spec.Resources.EphemeralStorage, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse rootfs size from '%s': %w", spec.Resources.EphemeralStorage, err) + } + // Worker configuration + workerConfig := WorkerConfig{ + MinNum: req.Replicas, + MaxNum: req.Replicas, // Initially set to same as min + FreeTimeout: DefaultFreeTimeout, + MaxConcurrent: DefaultMaxConcurrent, + GPUNum: int(gpuNum), + RequestTimeout: DefaultRequestTimeout, + } + + // Set optional fields only if they have values + if novitaConfig.CudaVersion != "" { + workerConfig.CudaVersion = novitaConfig.CudaVersion + } + + // Port configuration - default to 8000 + ports := []PortConfig{ + {Port: DefaultPort}, + } + + // Auto-scaling policy - single object (not array) + policy := PolicyConfig{ + Type: PolicyTypeQueue, + Value: DefaultQueueWaitTime, + } + + // Image configuration - single object (not array) + imageConfig := ImageConfig{ + Image: req.Image, + } + + // Product configuration + products := []ProductConfig{ + {ID: novitaConfig.ProductID}, + } + + // Environment variables + var envs []EnvVar + for k, v := range req.Env { + envs = append(envs, EnvVar{ + Key: k, + Value: v, + }) + } + + // Add region as environment variable for worker to access + envs = append(envs, EnvVar{ + Key: EnvKeyNovitaRegion, + Value: novitaConfig.Region, + }) + + // TODO: Volume mounts + // var volumeMounts []VolumeMount + // for _, vm := range req.VolumeMounts { + // volumeMounts = append(volumeMounts, VolumeMount{ + // Type: StorageTypeLocal, + // Size: DefaultLocalStorageGB, + // MountPath: vm.MountPath, + // }) + // } + + // Health check + healthCheck := &HealthCheck{ + Path: DefaultHealthPath, + } + + createReq := &CreateEndpointRequest{ + Endpoint: EndpointCreateConfig{ + Name: req.Endpoint, + AppName: req.Endpoint, // Use endpoint name as app name + ClusterID: novitaConfig.Region, // Set region as cluster ID + WorkerConfig: workerConfig, // Single object, not array + Ports: ports, + Policy: policy, // Single object, not array + Image: imageConfig, // Single object, not array + Products: products, + RootfsSize: int(rootfsSize), + // VolumeMounts: volumeMounts, + Envs: envs, + Healthy: healthCheck, + }, + } + + // Log the mapped request for debugging + logger.Debugf("Mapped CreateEndpoint request: Name=%s, ClusterID=%s, RootfsSize=%d, ProductID=%s, Image=%s", + createReq.Endpoint.Name, + createReq.Endpoint.ClusterID, + createReq.Endpoint.RootfsSize, + novitaConfig.ProductID, + imageConfig.Image, + ) + + return createReq, nil +} + +// mapNovitaResponseToAppInfo converts Novita GetEndpointResponse to Waverless AppInfo +func mapNovitaResponseToAppInfo(resp *GetEndpointResponse) *interfaces.AppInfo { + if resp == nil { + return nil + } + + endpoint := resp.Endpoint + + // Extract worker configuration + var replicas, readyReplicas, availableReplicas int32 + replicas = int32(endpoint.WorkerConfig.MaxNum) + + // Count running and healthy workers + runningWorkers := 0 + healthyWorkers := 0 + for _, worker := range endpoint.Workers { + if worker.State.State == NovitaStatusRunning { + runningWorkers++ + } + if worker.Healthy { + healthyWorkers++ + } + } + + // Use healthy workers as ready replicas, running workers as available replicas + readyReplicas = int32(healthyWorkers) + availableReplicas = int32(runningWorkers) + + // Extract image + image := endpoint.Image.Image + + // Build labels + labels := make(map[string]string) + labels[LabelKeyProvider] = LabelValueNovita + labels[LabelKeyEndpointID] = endpoint.ID + + // Extract region from environment variables + for _, env := range endpoint.Envs { + if env.Key == EnvKeyNovitaRegion { + labels[LabelKeyRegion] = env.Value + break + } + } + + return &interfaces.AppInfo{ + Name: endpoint.Name, + Type: TypeServerlessEndpoint, + Status: mapNovitaStatusToWaverless(endpoint.State.State), + Replicas: replicas, + ReadyReplicas: readyReplicas, + AvailableReplicas: availableReplicas, + Image: image, + Labels: labels, + CreatedAt: "", // Not provided in response, could use log timestamp if needed + } +} + +// mapNovitaListItemToAppInfo converts Novita EndpointListItem to Waverless AppInfo +func mapNovitaListItemToAppInfo(item *EndpointListItem) *interfaces.AppInfo { + if item == nil { + return nil + } + + labels := make(map[string]string) + labels[LabelKeyProvider] = LabelValueNovita + labels[LabelKeyEndpointID] = item.ID + + return &interfaces.AppInfo{ + Name: item.Name, + Type: TypeServerlessEndpoint, + Status: mapNovitaStatusToWaverless(item.State.State), + Image: "", // Not available in list view + Labels: labels, + CreatedAt: item.CreatedAt, + } +} + +// mapNovitaStatusToAppStatus converts Novita endpoint data to Waverless AppStatus +func mapNovitaStatusToAppStatus(endpointName string, data *EndpointConfig) *interfaces.AppStatus { + if data == nil { + return &interfaces.AppStatus{ + Endpoint: endpointName, + Status: StatusUnknown, + Message: MessageNoStatusInfo, + } + } + + // Count workers by state + runningWorkers := 0 + healthyWorkers := 0 + pendingWorkers := 0 + for _, worker := range data.Workers { + switch worker.State.State { + case NovitaStatusRunning: + runningWorkers++ + if worker.Healthy { + healthyWorkers++ + } + case NovitaStatusPending, NovitaStatusCreating: + pendingWorkers++ + } + } + + totalReplicas := runningWorkers + pendingWorkers + + return &interfaces.AppStatus{ + Endpoint: endpointName, + Status: mapNovitaStatusToWaverless(data.State.State), + ReadyReplicas: int32(healthyWorkers), + AvailableReplicas: int32(runningWorkers), + TotalReplicas: int32(totalReplicas), + Message: data.State.Message, + } +} + +// mapNovitaStatusToWaverless converts Novita state to Waverless status +func mapNovitaStatusToWaverless(state string) string { + switch strings.ToLower(state) { + case NovitaStatusServing, NovitaStatusRunning: + return StatusRunning + case NovitaStatusStopped: + return StatusStopped + case NovitaStatusFailed: + return StatusFailed + case NovitaStatusPending: + return StatusPending + case NovitaStatusCreating: + return StatusCreating + case NovitaStatusUpdating: + return StatusUpdating + case NovitaStatusDeleting: + return StatusTerminating + default: + return StatusUnknown + } +} + +// mapUpdateRequestToNovita converts Waverless UpdateDeploymentRequest to Novita UpdateEndpointRequest +func mapUpdateRequestToNovita(endpointID string, req *interfaces.UpdateDeploymentRequest, currentConfig *GetEndpointResponse) *UpdateEndpointRequest { + if currentConfig == nil { + logger.Warnf("No current config available for endpoint %s, update may fail", endpointID) + return nil + } + + data := currentConfig.Endpoint + + // Build worker config - use string types for freeTimeout and maxConcurrent + workerConfig := WorkerConfigResponse{ + MinNum: data.WorkerConfig.MinNum, + MaxNum: data.WorkerConfig.MaxNum, + FreeTimeout: data.WorkerConfig.FreeTimeout, // Already string + MaxConcurrent: data.WorkerConfig.MaxConcurrent, // Already string + GPUNum: data.WorkerConfig.GPUNum, + RequestTimeout: data.WorkerConfig.RequestTimeout, + } + + // Build policy config - use string type for value + policy := PolicyResponse{ + Type: data.Policy.Type, + Value: data.Policy.Value, // Already string + } + + // Build image config + imageConfig := ImageConfig{ + Image: data.Image.Image, + AuthID: data.Image.AuthID, + Command: data.Image.Command, + } + + // Build ports config + portsConfig := []PortConfig{} + for _, port := range data.Ports { + portsConfig = append(portsConfig, PortConfig{ + Port: port.Port, + }) + } + + // Build health check config with full details + + // Apply updates from request + // Update image if specified + if req.Image != "" { + imageConfig.Image = req.Image + } + + // Update replicas if specified + if req.Replicas != nil { + workerConfig.MinNum = *req.Replicas + workerConfig.MaxNum = *req.Replicas + } + + // Update environment variables if specified + envs := data.Envs + if req.Env != nil { + envs = []EnvVar{} + for k, v := range *req.Env { + envs = append(envs, EnvVar{ + Key: k, + Value: v, + }) + } + } + healthCheck := &HealthCheck{ + Path: data.Healthy.Path, + } + // Build flattened update request + return &UpdateEndpointRequest{ + ID: endpointID, + Name: data.Name, + AppName: data.AppName, + WorkerConfig: workerConfig, + Policy: policy, + Image: imageConfig, + RootfsSize: data.RootfsSize, + VolumeMounts: data.VolumeMounts, + Envs: envs, + Ports: portsConfig, + Workers: nil, // Set to nil as per API example + Products: data.Products, + Healthy: healthCheck, + } +} diff --git a/pkg/deploy/novita/provider.go b/pkg/deploy/novita/provider.go new file mode 100644 index 0000000..3f33006 --- /dev/null +++ b/pkg/deploy/novita/provider.go @@ -0,0 +1,580 @@ +package novita + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "sync/atomic" + "time" + + "waverless/pkg/config" + "waverless/pkg/interfaces" + "waverless/pkg/logger" +) + +// clientInterface defines the interface for Novita API client (for testing) +type clientInterface interface { + CreateEndpoint(ctx context.Context, req *CreateEndpointRequest) (*CreateEndpointResponse, error) + GetEndpoint(ctx context.Context, endpointID string) (*GetEndpointResponse, error) + ListEndpoints(ctx context.Context) (*ListEndpointsResponse, error) + UpdateEndpoint(ctx context.Context, req *UpdateEndpointRequest) error + DeleteEndpoint(ctx context.Context, endpointID string) error +} + +// replicaCallbackEntry represents a registered replica callback +type replicaCallbackEntry struct { + id uint64 + callback interfaces.ReplicaCallback +} + +// endpointState stores the last known state of an endpoint +type endpointState struct { + DesiredReplicas int + ReadyReplicas int + AvailableReplicas int + Status string +} + +// NovitaDeploymentProvider implements interfaces.DeploymentProvider for Novita Serverless +type NovitaDeploymentProvider struct { + client clientInterface + config *config.NovitaConfig + specsConfig *SpecsConfig + endpointCache sync.Map // Cache endpoint ID mappings: name -> endpointID + + // WatchReplicas support + replicaCallbacks map[uint64]*replicaCallbackEntry + replicaCallbacksLock sync.RWMutex + nextCallbackID uint64 + endpointStates sync.Map // endpoint name -> *endpointState + watcherRunning atomic.Bool + watcherStopCh chan struct{} + pollInterval time.Duration // Configurable poll interval +} + +// NewNovitaDeploymentProvider creates a new Novita deployment provider +func NewNovitaDeploymentProvider(cfg *config.Config) (interfaces.DeploymentProvider, error) { + if !cfg.Novita.Enabled { + return nil, fmt.Errorf("novita provider is not enabled in config") + } + + if cfg.Novita.APIKey == "" { + return nil, fmt.Errorf("novita API key is required") + } + + // Initialize specs configuration + specsConfig, err := NewSpecsConfig(cfg.Novita.ConfigDir) + if err != nil { + return nil, fmt.Errorf("failed to initialize specs config: %w", err) + } + + client := NewClient(&cfg.Novita) + + // Set default poll interval to 10 seconds + pollInterval := 10 * time.Second + if cfg.Novita.PollInterval > 0 { + pollInterval = time.Duration(cfg.Novita.PollInterval) * time.Second + } + + return &NovitaDeploymentProvider{ + client: client, + config: &cfg.Novita, + specsConfig: specsConfig, + replicaCallbacks: make(map[uint64]*replicaCallbackEntry), + watcherStopCh: make(chan struct{}), + pollInterval: pollInterval, + }, nil +} + +// Deploy deploys an application to Novita serverless +func (p *NovitaDeploymentProvider) Deploy(ctx context.Context, req *interfaces.DeployRequest) (*interfaces.DeployResponse, error) { + logger.Infof("Deploying endpoint %s to Novita", req.Endpoint) + + // Get spec from configuration + specInfo, err := p.specsConfig.GetSpec(req.SpecName) + if err != nil { + return nil, fmt.Errorf("failed to get spec for %s: %w", req.SpecName, err) + } + + // Map Waverless request to Novita request (mapper will extract platform config from spec) + novitaReq, err := mapDeployRequestToNovita(req, specInfo) + if err != nil { + return nil, fmt.Errorf("failed to map deploy request to Novita: %w", err) + } + + // Create endpoint + resp, err := p.client.CreateEndpoint(ctx, novitaReq) + if err != nil { + return nil, fmt.Errorf("failed to create Novita endpoint: %w", err) + } + + // Cache endpoint ID mapping + p.endpointCache.Store(req.Endpoint, resp.ID) + + logger.Infof("Successfully deployed endpoint %s to Novita (ID: %s)", req.Endpoint, resp.ID) + + return &interfaces.DeployResponse{ + Endpoint: req.Endpoint, + Message: fmt.Sprintf("%s (ID: %s)", MessageDeploySuccess, resp.ID), + CreatedAt: "", // Novita doesn't return creation time in response + }, nil +} + +// GetApp retrieves application details +func (p *NovitaDeploymentProvider) GetApp(ctx context.Context, endpoint string) (*interfaces.AppInfo, error) { + logger.Debugf("Getting app info for endpoint %s", endpoint) + + // Get endpoint ID from cache or find it + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return nil, err + } + + // Get endpoint details from Novita + resp, err := p.client.GetEndpoint(ctx, endpointID) + if err != nil { + return nil, fmt.Errorf("failed to get endpoint from Novita: %w", err) + } + + return mapNovitaResponseToAppInfo(resp), nil +} + +// ListApps lists all applications +func (p *NovitaDeploymentProvider) ListApps(ctx context.Context) ([]*interfaces.AppInfo, error) { + logger.Debugf("Listing all Novita endpoints") + + resp, err := p.client.ListEndpoints(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list endpoints from Novita: %w", err) + } + + apps := make([]*interfaces.AppInfo, 0, len(resp.Endpoints)) + for _, item := range resp.Endpoints { + // Cache endpoint ID mapping + p.endpointCache.Store(item.Name, item.ID) + + apps = append(apps, mapNovitaListItemToAppInfo(&item)) + } + + return apps, nil +} + +// DeleteApp deletes application +func (p *NovitaDeploymentProvider) DeleteApp(ctx context.Context, endpoint string) error { + logger.Infof("Deleting endpoint %s from Novita", endpoint) + + // Get endpoint ID + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return err + } + + // Delete endpoint + if err := p.client.DeleteEndpoint(ctx, endpointID); err != nil { + return fmt.Errorf("failed to delete endpoint from Novita: %w", err) + } + + // Remove from cache + p.endpointCache.Delete(endpoint) + + logger.Infof("%s %s (ID: %s)", MessageDeleteSuccess, endpoint, endpointID) + return nil +} + +// ScaleApp scales application replicas +func (p *NovitaDeploymentProvider) ScaleApp(ctx context.Context, endpoint string, replicas int) error { + logger.Infof("Scaling endpoint %s to %d replicas", endpoint, replicas) + + // Get endpoint ID + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return err + } + + // Get current configuration + currentConfig, err := p.client.GetEndpoint(ctx, endpointID) + if err != nil { + return fmt.Errorf("failed to get current endpoint config: %w", err) + } + + // Create update request with modified replicas + replicasPtr := &replicas + scaleReq := &interfaces.UpdateDeploymentRequest{ + Endpoint: endpoint, + Replicas: replicasPtr, + } + + updateReq := mapUpdateRequestToNovita(endpointID, scaleReq, currentConfig) + if updateReq == nil { + return fmt.Errorf("failed to create scale request") + } + + if err := p.client.UpdateEndpoint(ctx, updateReq); err != nil { + return fmt.Errorf("failed to scale endpoint: %w", err) + } + + logger.Infof("Successfully scaled endpoint %s to %d replicas", endpoint, replicas) + return nil +} + +// GetAppStatus retrieves application status +func (p *NovitaDeploymentProvider) GetAppStatus(ctx context.Context, endpoint string) (*interfaces.AppStatus, error) { + logger.Debugf("Getting status for endpoint %s", endpoint) + + // Get endpoint ID + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return nil, err + } + + // Get endpoint details + resp, err := p.client.GetEndpoint(ctx, endpointID) + if err != nil { + return nil, fmt.Errorf("failed to get endpoint status: %w", err) + } + + return mapNovitaStatusToAppStatus(endpoint, &resp.Endpoint), nil +} + +// GetAppLogs retrieves application logs (not supported by Novita) +func (p *NovitaDeploymentProvider) GetAppLogs(ctx context.Context, endpoint string, lines int, podName ...string) (string, error) { + return "", fmt.Errorf(MessageLogsNotSupported) +} + +// UpdateDeployment updates deployment +func (p *NovitaDeploymentProvider) UpdateDeployment(ctx context.Context, req *interfaces.UpdateDeploymentRequest) (*interfaces.DeployResponse, error) { + logger.Infof("Updating deployment for endpoint %s", req.Endpoint) + + // Get endpoint ID + endpointID, err := p.getEndpointID(ctx, req.Endpoint) + if err != nil { + return nil, err + } + + // Get current configuration + currentConfig, err := p.client.GetEndpoint(ctx, endpointID) + if err != nil { + return nil, fmt.Errorf("failed to get current endpoint config: %w", err) + } + + // Map update request + updateReq := mapUpdateRequestToNovita(endpointID, req, currentConfig) + if updateReq == nil { + return nil, fmt.Errorf("failed to map update request") + } + + // Update endpoint + if err := p.client.UpdateEndpoint(ctx, updateReq); err != nil { + return nil, fmt.Errorf("failed to update endpoint: %w", err) + } + + logger.Infof("Successfully updated endpoint %s", req.Endpoint) + + return &interfaces.DeployResponse{ + Endpoint: req.Endpoint, + Message: MessageUpdateSuccess, + CreatedAt: "", + }, nil +} + +// ListSpecs lists available specifications +func (p *NovitaDeploymentProvider) ListSpecs(ctx context.Context) ([]*interfaces.SpecInfo, error) { + return p.specsConfig.ListSpecs(), nil +} + +// GetSpec retrieves specification details +func (p *NovitaDeploymentProvider) GetSpec(ctx context.Context, specName string) (*interfaces.SpecInfo, error) { + return p.specsConfig.GetSpec(specName) +} + +// PreviewDeploymentYAML previews deployment configuration (returns JSON for Novita) +func (p *NovitaDeploymentProvider) PreviewDeploymentYAML(ctx context.Context, req *interfaces.DeployRequest) (string, error) { + // Get spec from configuration + specInfo, err := p.specsConfig.GetSpec(req.SpecName) + if err != nil { + return "", fmt.Errorf("failed to get spec for %s: %w", req.SpecName, err) + } + + region := req.Labels[LabelKeyRegion] + if region == "" { + region = DefaultRegion + } + + // Map to Novita request (mapper will extract platform config from spec) + novitaReq, err := mapDeployRequestToNovita(req, specInfo) + if err != nil { + return "", fmt.Errorf("failed to map deploy request to Novita: %w", err) + } + + // Convert to JSON + jsonData, err := json.MarshalIndent(novitaReq, "", " ") + if err != nil { + return "", fmt.Errorf("failed to marshal Novita config: %w", err) + } + + return string(jsonData), nil +} + +// WatchReplicas watches replica count changes using polling mechanism +func (p *NovitaDeploymentProvider) WatchReplicas(ctx context.Context, callback interfaces.ReplicaCallback) error { + if callback == nil { + return fmt.Errorf("replica callback is nil") + } + + // Register callback + p.replicaCallbacksLock.Lock() + callbackID := atomic.AddUint64(&p.nextCallbackID, 1) + p.replicaCallbacks[callbackID] = &replicaCallbackEntry{ + id: callbackID, + callback: callback, + } + p.replicaCallbacksLock.Unlock() + + logger.Infof("Registered replica watch callback (ID: %d) for Novita endpoints", callbackID) + + // Start watcher if not already running + if p.watcherRunning.CompareAndSwap(false, true) { + logger.Infof("Starting Novita replica watcher (poll interval: %v)", p.pollInterval) + go p.runReplicaWatcher(ctx) + } + + // Unregister callback when context is done + go func() { + <-ctx.Done() + p.replicaCallbacksLock.Lock() + delete(p.replicaCallbacks, callbackID) + p.replicaCallbacksLock.Unlock() + logger.Infof("Unregistered replica watch callback (ID: %d)", callbackID) + }() + + return nil +} + +// runReplicaWatcher runs the polling loop to monitor replica changes +func (p *NovitaDeploymentProvider) runReplicaWatcher(ctx context.Context) { + ticker := time.NewTicker(p.pollInterval) + defer ticker.Stop() + + logger.Infof("Novita replica watcher started") + + for { + select { + case <-ctx.Done(): + logger.Infof("Novita replica watcher stopped (context done)") + p.watcherRunning.Store(false) + return + case <-p.watcherStopCh: + logger.Infof("Novita replica watcher stopped (stop signal)") + p.watcherRunning.Store(false) + return + case <-ticker.C: + p.pollEndpointStates(ctx) + } + } +} + +// pollEndpointStates polls all endpoints and detects state changes +func (p *NovitaDeploymentProvider) pollEndpointStates(ctx context.Context) { + // List all endpoints + resp, err := p.client.ListEndpoints(ctx) + if err != nil { + logger.Errorf("Failed to list Novita endpoints for polling: %v", err) + return + } + + // Process each endpoint + for _, item := range resp.Endpoints { + endpointName := item.Name + + // Get status from list item (includes full worker details) + status := p.getEndpointStateFromListItem(&item) + + // Compare with cached state + previousStateInterface, exists := p.endpointStates.Load(endpointName) + + var hasChanged bool + if !exists { + // New endpoint + hasChanged = true + } else { + previousState := previousStateInterface.(*endpointState) + hasChanged = p.hasStateChanged(previousState, status) + } + + // Update cache + p.endpointStates.Store(endpointName, status) + + // Trigger callbacks if state changed + if hasChanged { + logger.Debugf("Detected state change for endpoint %s: desired=%d, ready=%d, available=%d, status=%s", + endpointName, status.DesiredReplicas, status.ReadyReplicas, status.AvailableReplicas, status.Status) + + p.triggerReplicaCallbacks(interfaces.ReplicaEvent{ + Name: endpointName, + DesiredReplicas: status.DesiredReplicas, + ReadyReplicas: status.ReadyReplicas, + AvailableReplicas: status.AvailableReplicas, + Conditions: p.buildConditions(status), + }) + } + } +} + +// getEndpointStateFromListItem extracts state from list item +// Note: Novita's ListEndpoints API returns full endpoint details including workers +func (p *NovitaDeploymentProvider) getEndpointStateFromListItem(item *EndpointListItem) *endpointState { + if item == nil { + return &endpointState{} + } + + status := mapNovitaStatusToWaverless(item.State.State) + + // Get desired replicas from worker config + desiredReplicas := item.WorkerConfig.MaxNum + + // Count workers by state + runningWorkers := 0 + // healthyWorkers := 0 + + for _, worker := range item.Workers { + if worker.State.State == NovitaStatusRunning { + runningWorkers++ + } + // if worker.Healthy { + // healthyWorkers++ + // } + } + + // Use running workers as ready replicas and available replicas + readyReplicas := runningWorkers + availableReplicas := runningWorkers + + return &endpointState{ + DesiredReplicas: desiredReplicas, + ReadyReplicas: readyReplicas, + AvailableReplicas: availableReplicas, + Status: status, + } +} + +// hasStateChanged checks if the endpoint state has changed +func (p *NovitaDeploymentProvider) hasStateChanged(previous, current *endpointState) bool { + return previous.DesiredReplicas != current.DesiredReplicas || + previous.ReadyReplicas != current.ReadyReplicas || + previous.AvailableReplicas != current.AvailableReplicas || + previous.Status != current.Status +} + +// buildConditions builds condition list from endpoint state +func (p *NovitaDeploymentProvider) buildConditions(state *endpointState) []interfaces.ReplicaCondition { + conditions := []interfaces.ReplicaCondition{} + + if state.Status == StatusRunning && state.ReadyReplicas > 0 { + conditions = append(conditions, interfaces.ReplicaCondition{ + Type: "Available", + Status: "True", + Reason: "MinimumReplicasAvailable", + Message: "Endpoint has minimum availability", + }) + } else if state.Status == StatusPending || state.Status == StatusCreating { + conditions = append(conditions, interfaces.ReplicaCondition{ + Type: "Progressing", + Status: "True", + Reason: "NewEndpointAvailable", + Message: "Endpoint is being created", + }) + } else if state.Status == StatusFailed { + conditions = append(conditions, interfaces.ReplicaCondition{ + Type: "Available", + Status: "False", + Reason: "EndpointFailed", + Message: "Endpoint has failed", + }) + } + + return conditions +} + +// triggerReplicaCallbacks triggers all registered callbacks with the event +func (p *NovitaDeploymentProvider) triggerReplicaCallbacks(event interfaces.ReplicaEvent) { + p.replicaCallbacksLock.RLock() + defer p.replicaCallbacksLock.RUnlock() + + for _, entry := range p.replicaCallbacks { + // Call callback in a goroutine to avoid blocking + go func(cb interfaces.ReplicaCallback, e interfaces.ReplicaEvent) { + defer func() { + if r := recover(); r != nil { + logger.Errorf("Panic in replica callback: %v", r) + } + }() + cb(e) + }(entry.callback, event) + } +} + +// StopReplicaWatcher stops the replica watcher +func (p *NovitaDeploymentProvider) StopReplicaWatcher() { + if p.watcherRunning.Load() { + close(p.watcherStopCh) + } +} + +// GetPods retrieves all Pod information (not supported by Novita) +func (p *NovitaDeploymentProvider) GetPods(ctx context.Context, endpoint string) ([]*interfaces.PodInfo, error) { + return nil, fmt.Errorf(MessagePodsNotSupported) +} + +// DescribePod retrieves detailed Pod information (not supported by Novita) +func (p *NovitaDeploymentProvider) DescribePod(ctx context.Context, endpoint string, podName string) (*interfaces.PodDetail, error) { + return nil, fmt.Errorf("DescribePod %s", MessageNotSupported) +} + +// GetPodYAML retrieves Pod YAML (not supported by Novita) +func (p *NovitaDeploymentProvider) GetPodYAML(ctx context.Context, endpoint string, podName string) (string, error) { + return "", fmt.Errorf("GetPodYAML %s", MessageNotSupported) +} + +// ListPVCs lists all PersistentVolumeClaims (not supported by Novita) +func (p *NovitaDeploymentProvider) ListPVCs(ctx context.Context) ([]*interfaces.PVCInfo, error) { + return nil, fmt.Errorf("ListPVCs %s", MessageNotSupported) +} + +// GetDefaultEnv retrieves default environment variables +func (p *NovitaDeploymentProvider) GetDefaultEnv(ctx context.Context) (map[string]string, error) { + // Return default environment variables for Novita + defaultEnv := map[string]string{ + EnvKeyNovitaProvider: EnvValueTrue, + EnvKeyProviderType: EnvValueNovita, + } + + return defaultEnv, nil +} + +// getEndpointID retrieves the Novita endpoint ID for a given endpoint name +// It first checks the cache, then queries the API if not found +func (p *NovitaDeploymentProvider) getEndpointID(ctx context.Context, endpoint string) (string, error) { + // Check cache first + if id, ok := p.endpointCache.Load(endpoint); ok { + return id.(string), nil + } + + // Not in cache, query API + logger.Debugf("Endpoint ID not in cache, querying Novita API for %s", endpoint) + + resp, err := p.client.ListEndpoints(ctx) + if err != nil { + return "", fmt.Errorf("failed to list endpoints: %w", err) + } + + // Find matching endpoint and update cache + for _, item := range resp.Endpoints { + p.endpointCache.Store(item.Name, item.ID) + if item.Name == endpoint { + return item.ID, nil + } + } + + return "", fmt.Errorf("endpoint %s not found in Novita", endpoint) +} diff --git a/pkg/deploy/novita/provider_test.go b/pkg/deploy/novita/provider_test.go new file mode 100644 index 0000000..85c1929 --- /dev/null +++ b/pkg/deploy/novita/provider_test.go @@ -0,0 +1,1196 @@ +package novita + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + "testing" + + "waverless/pkg/config" + "waverless/pkg/interfaces" +) + +// mockClient is a mock implementation of Novita API client for testing +type mockClient struct { + endpoints map[string]*GetEndpointResponse + endpointsMutex sync.RWMutex + createError error + getError error + deleteError error + updateError error + listError error +} + +func newMockClient() *mockClient { + return &mockClient{ + endpoints: make(map[string]*GetEndpointResponse), + } +} + +func (m *mockClient) CreateEndpoint(ctx context.Context, req *CreateEndpointRequest) (*CreateEndpointResponse, error) { + if m.createError != nil { + return nil, m.createError + } + + m.endpointsMutex.Lock() + defer m.endpointsMutex.Unlock() + + endpointID := fmt.Sprintf("ep-%s", req.Endpoint.Name) + + // Create workers based on replica count + workers := make([]WorkerInfo, req.Endpoint.WorkerConfig.MinNum) + for i := 0; i < req.Endpoint.WorkerConfig.MinNum; i++ { + workers[i] = WorkerInfo{ + ID: fmt.Sprintf("worker-%d", i), + State: StateInfo{ + State: NovitaStatusRunning, + Error: "", + Message: "Worker is running", + }, + Healthy: true, + } + } + + // Convert EndpointConfig to GetEndpointResponseData + m.endpoints[endpointID] = &GetEndpointResponse{ + Endpoint: EndpointConfig{ + ID: endpointID, + Name: req.Endpoint.Name, + AppName: req.Endpoint.AppName, + State: StateInfo{ + State: NovitaStatusServing, + Error: "", + Message: "Endpoint is serving", + }, + URL: fmt.Sprintf("https://%s.novita.ai", endpointID), + WorkerConfig: WorkerConfigResponse{ + MinNum: req.Endpoint.WorkerConfig.MinNum, + MaxNum: req.Endpoint.WorkerConfig.MaxNum, + FreeTimeout: fmt.Sprintf("%d", req.Endpoint.WorkerConfig.FreeTimeout), + MaxConcurrent: fmt.Sprintf("%d", req.Endpoint.WorkerConfig.MaxConcurrent), + GPUNum: req.Endpoint.WorkerConfig.GPUNum, + CudaVersion: "11.8", + }, + Policy: PolicyDetails{ + Type: req.Endpoint.Policy.Type, + Value: fmt.Sprintf("%d", req.Endpoint.Policy.Value), + }, + Image: ImageDetails{ + Image: req.Endpoint.Image.Image, + AuthID: req.Endpoint.Image.AuthID, + Command: req.Endpoint.Image.Command, + }, + RootfsSize: req.Endpoint.RootfsSize, + VolumeMounts: req.Endpoint.VolumeMounts, + Envs: req.Endpoint.Envs, + Ports: []PortDetails{{Port: 8000}}, + Workers: workers, + Products: req.Endpoint.Products, + Healthy: nil, + ClusterID: req.Endpoint.ClusterID, + Log: fmt.Sprintf("/logs/%s", endpointID), + }, + } + + return &CreateEndpointResponse{ID: endpointID}, nil +} + +func (m *mockClient) GetEndpoint(ctx context.Context, endpointID string) (*GetEndpointResponse, error) { + if m.getError != nil { + return nil, m.getError + } + + m.endpointsMutex.RLock() + defer m.endpointsMutex.RUnlock() + + ep, ok := m.endpoints[endpointID] + if !ok { + return nil, fmt.Errorf("endpoint %s not found", endpointID) + } + + return ep, nil +} + +func (m *mockClient) ListEndpoints(ctx context.Context) (*ListEndpointsResponse, error) { + if m.listError != nil { + return nil, m.listError + } + + m.endpointsMutex.RLock() + defer m.endpointsMutex.RUnlock() + + items := make([]EndpointListItem, 0, len(m.endpoints)) + for id, ep := range m.endpoints { + items = append(items, EndpointListItem{ + ID: id, + Name: ep.Endpoint.Name, + AppName: ep.Endpoint.AppName, + State: ep.Endpoint.State, + }) + } + + return &ListEndpointsResponse{ + Endpoints: items, + Total: len(items), + }, nil +} + +func (m *mockClient) UpdateEndpoint(ctx context.Context, req *UpdateEndpointRequest) error { + if m.updateError != nil { + return m.updateError + } + + m.endpointsMutex.Lock() + defer m.endpointsMutex.Unlock() + + ep, ok := m.endpoints[req.ID] + if !ok { + return fmt.Errorf("endpoint %s not found", req.ID) + } + + // Update fields from flattened UpdateEndpointRequest + if req.Name != "" { + ep.Endpoint.Name = req.Name + } + if req.AppName != "" { + ep.Endpoint.AppName = req.AppName + } + + // Update worker config + wc := req.WorkerConfig + if wc.MinNum >= 0 { + ep.Endpoint.WorkerConfig.MinNum = wc.MinNum + ep.Endpoint.WorkerConfig.MaxNum = wc.MaxNum + ep.Endpoint.WorkerConfig.FreeTimeout = wc.FreeTimeout + ep.Endpoint.WorkerConfig.MaxConcurrent = wc.MaxConcurrent + } + if wc.GPUNum > 0 { + ep.Endpoint.WorkerConfig.GPUNum = wc.GPUNum + } + + // Update workers count to match new replicas + newWorkerCount := wc.MinNum + if newWorkerCount >= 0 { + if newWorkerCount == 0 { + // Remove all workers + ep.Endpoint.Workers = []WorkerInfo{} + } else if newWorkerCount > len(ep.Endpoint.Workers) { + // Add more workers + for i := len(ep.Endpoint.Workers); i < newWorkerCount; i++ { + ep.Endpoint.Workers = append(ep.Endpoint.Workers, WorkerInfo{ + ID: fmt.Sprintf("worker-%d", i), + State: StateInfo{ + State: NovitaStatusRunning, + Message: "Worker is running", + }, + Healthy: true, + }) + } + } else if newWorkerCount < len(ep.Endpoint.Workers) { + // Remove workers + ep.Endpoint.Workers = ep.Endpoint.Workers[:newWorkerCount] + } + } + + // Update image if provided + if req.Image.Image != "" { + ep.Endpoint.Image.Image = req.Image.Image + ep.Endpoint.Image.AuthID = req.Image.AuthID + ep.Endpoint.Image.Command = req.Image.Command + } + + // Update env vars if provided + if req.Envs != nil { + ep.Endpoint.Envs = req.Envs + } + + return nil +} + +func (m *mockClient) DeleteEndpoint(ctx context.Context, endpointID string) error { + if m.deleteError != nil { + return m.deleteError + } + + m.endpointsMutex.Lock() + defer m.endpointsMutex.Unlock() + + if _, ok := m.endpoints[endpointID]; !ok { + return fmt.Errorf("endpoint %s not found", endpointID) + } + + delete(m.endpoints, endpointID) + return nil +} + +// createTestSpecsFile creates a temporary specs.yaml file for testing +func createTestSpecsFile(t *testing.T) string { + tmpDir := t.TempDir() + specsContent := `- name: novita-h100-single + displayName: "Novita H100 1x GPU" + category: gpu + resourceType: serverless + resources: + gpu: "1" + gpuType: "NVIDIA-H100" + cpu: "16" + memory: "80Gi" + ephemeralStorage: "100" + platforms: + novita: + productId: "novita-h100-80gb-product-id" + region: "eu-west-1" + - name: novita-a100-single + displayName: "Novita A100 1x GPU" + category: gpu + resourceType: serverless + resources: + gpu: "1" + gpuType: "NVIDIA-A100" + cpu: "12" + memory: "40Gi" + ephemeralStorage: "100" + platforms: + novita: + productId: "novita-a100-40gb-product-id" + region: "eu-west-1" + - name: novita-a10-single + displayName: "Novita A10 1x GPU" + category: gpu + resourceType: serverless + resources: + gpu: "1" + gpuType: "NVIDIA-A10" + cpu: "8" + memory: "32Gi" + ephemeralStorage: "100" + platforms: + novita: + productId: "novita-a10-24gb-product-id" + region: "eu-west-1" + - name: novita-h200-single + displayName: "Novita H200 1x GPU" + category: gpu + resourceType: serverless + resources: + gpu: "1" + gpuType: "NVIDIA-H200" + cpu: "16" + memory: "141Gi" + ephemeralStorage: "100" + platforms: + novita: + productId: "novita-h200-141gb-product-id" + region: "eu-west-1" +` + specsFile := filepath.Join(tmpDir, "specs.yaml") + if err := os.WriteFile(specsFile, []byte(specsContent), 0644); err != nil { + t.Fatalf("Failed to create test specs file: %v", err) + } + return tmpDir +} + +// createTestProvider creates a test provider with mock client +func createTestProvider(mockCli *mockClient) *NovitaDeploymentProvider { + tmpDir, _ := os.MkdirTemp("", "novita-test-*") + specsContent := `specs: + - name: novita-h100-single + displayName: "Novita H100 1x GPU" + category: gpu + resourceType: serverless + resources: + gpu: "1" + gpuType: "NVIDIA-H100" + cpu: "16" + memory: "80Gi" + ephemeralStorage: "100" + platforms: + novita: + productId: "novita-h100-80gb-product-id" + region: "us-east-1" + - name: novita-a100-single + displayName: "Novita A100 1x GPU" + category: gpu + resourceType: serverless + resources: + gpu: "1" + gpuType: "NVIDIA-A100" + cpu: "12" + memory: "40Gi" + ephemeralStorage: "100" + platforms: + novita: + productId: "novita-a100-40gb-product-id" + region: "us-east-1" + - name: novita-a10-single + displayName: "Novita A10 1x GPU" + category: gpu + resourceType: serverless + resources: + gpu: "1" + gpuType: "NVIDIA-A10" + cpu: "8" + memory: "32Gi" + ephemeralStorage: "100" + platforms: + novita: + productId: "novita-a10-24gb-product-id" + region: "us-east-1" + - name: novita-h200-single + displayName: "Novita H200 1x GPU" + category: gpu + resourceType: serverless + resources: + gpu: "1" + gpuType: "NVIDIA-H200" + cpu: "16" + memory: "141Gi" + ephemeralStorage: "100" + platforms: + novita: + productId: "novita-h200-141gb-product-id" + region: "us-east-1" +` + specsFile := filepath.Join(tmpDir, "specs.yaml") + if err := os.WriteFile(specsFile, []byte(specsContent), 0644); err != nil { + panic(fmt.Sprintf("Failed to write specs file: %v", err)) + } + + specsConfig, err := NewSpecsConfig(tmpDir) + if err != nil { + panic(fmt.Sprintf("Failed to create specs config: %v", err)) + } + + return &NovitaDeploymentProvider{ + client: clientInterface(mockCli), + specsConfig: specsConfig, + config: &config.NovitaConfig{}, + } +} + +// TestNewNovitaDeploymentProvider tests provider creation +func TestNewNovitaDeploymentProvider(t *testing.T) { + // Create temporary specs file for valid test case + tmpDir := "config/specs-novita-example.yaml" + + tests := []struct { + name string + config *config.Config + wantErr bool + }{ + { + name: "valid config", + config: &config.Config{ + Novita: config.NovitaConfig{ + Enabled: true, + APIKey: "test-api-key", + BaseURL: "https://api.novita.ai", + ConfigDir: tmpDir, + }, + }, + wantErr: false, + }, + { + name: "novita not enabled", + config: &config.Config{ + Novita: config.NovitaConfig{ + Enabled: false, + APIKey: "test-api-key", + ConfigDir: tmpDir, + }, + }, + wantErr: true, + }, + { + name: "missing API key", + config: &config.Config{ + Novita: config.NovitaConfig{ + Enabled: true, + APIKey: "", + ConfigDir: tmpDir, + }, + }, + wantErr: true, + }, + { + name: "missing specs file", + config: &config.Config{ + Novita: config.NovitaConfig{ + Enabled: true, + APIKey: "test-api-key", + ConfigDir: "/nonexistent", + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewNovitaDeploymentProvider(tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewNovitaDeploymentProvider() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// TestMapDeployRequestToNovita tests request mapping +func TestMapDeployRequestToNovita(t *testing.T) { + req := &interfaces.DeployRequest{ + Endpoint: "test-endpoint", + SpecName: "test-spec", + Image: "test-image:latest", + Replicas: 2, + Env: map[string]string{ + "MODEL_NAME": "test-model", + }, + Labels: map[string]string{ + "region": "us-east-1", + }, + } + + // Create test spec info with proper PlatformConfig + platformCfg := PlatformConfig{ + ProductID: "test-product-id", + Region: "us-east-1", + } + + specInfo := &interfaces.SpecInfo{ + Name: "test-spec", + DisplayName: "Test Spec", + Category: CategoryGPU, + Resources: interfaces.ResourceRequirements{ + GPU: "1", + GPUType: GPUTypeH100, + CPU: "16", + Memory: "80Gi", + EphemeralStorage: "100", + }, + Platforms: map[string]interface{}{ + PlatformNovita: platformCfg, + }, + } + + novitaReq, err := mapDeployRequestToNovita(req, specInfo) + if err != nil { + t.Fatalf("mapDeployRequestToNovita failed: %v", err) + } + + // Verify basic fields + if novitaReq.Endpoint.Name != req.Endpoint { + t.Errorf("Expected name %s, got %s", req.Endpoint, novitaReq.Endpoint.Name) + } + + if novitaReq.Endpoint.Image.Image != req.Image { + t.Errorf("Expected image %s, got %s", req.Image, novitaReq.Endpoint.Image.Image) + } + + // Verify worker config + workerCfg := novitaReq.Endpoint.WorkerConfig + if workerCfg.MinNum != req.Replicas { + t.Errorf("Expected minNum %d, got %d", req.Replicas, workerCfg.MinNum) + } + + // Verify product + expectedProductID := "test-product-id" + if len(novitaReq.Endpoint.Products) == 0 || novitaReq.Endpoint.Products[0].ID != expectedProductID { + t.Errorf("Expected product ID %s, got %v", expectedProductID, novitaReq.Endpoint.Products) + } + + // Verify environment variables + foundModel := false + foundRegion := false + for _, env := range novitaReq.Endpoint.Envs { + if env.Key == "MODEL_NAME" && env.Value == "test-model" { + foundModel = true + } + if env.Key == EnvKeyNovitaRegion && env.Value == "us-east-1" { + foundRegion = true + } + } + if !foundModel { + t.Error("Expected MODEL_NAME env var to be set") + } + if !foundRegion { + t.Error("Expected NOVITA_REGION env var to be set to us-east-1 from labels") + } +} + +// TestMapNovitaStatusToWaverless tests status mapping +func TestMapNovitaStatusToWaverless(t *testing.T) { + tests := []struct { + novitaStatus string + expected string + }{ + {"running", "Running"}, + {"stopped", "Stopped"}, + {"failed", "Failed"}, + {"pending", "Pending"}, + {"creating", "Creating"}, + {"updating", "Updating"}, + {"deleting", "Terminating"}, + {"unknown-state", "Unknown"}, + } + + for _, tt := range tests { + t.Run(tt.novitaStatus, func(t *testing.T) { + result := mapNovitaStatusToWaverless(tt.novitaStatus) + if result != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, result) + } + }) + } +} + +// TestGetDefaultEnv tests default environment variables +func TestGetDefaultEnv(t *testing.T) { + // Create temporary specs file + tmpDir := createTestSpecsFile(t) + + cfg := &config.Config{ + Novita: config.NovitaConfig{ + Enabled: true, + APIKey: "test-key", + ConfigDir: tmpDir, + }, + } + + provider, err := NewNovitaDeploymentProvider(cfg) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + ctx := context.Background() + env, err := provider.GetDefaultEnv(ctx) + if err != nil { + t.Fatalf("GetDefaultEnv failed: %v", err) + } + + // Verify required fields + if env[EnvKeyNovitaProvider] != EnvValueTrue { + t.Errorf("Expected %s to be '%s'", EnvKeyNovitaProvider, EnvValueTrue) + } + + if env[EnvKeyProviderType] != EnvValueNovita { + t.Errorf("Expected %s to be '%s'", EnvKeyProviderType, EnvValueNovita) + } + + if env[EnvKeyNovitaRegion] != "eu-west-1" { + t.Errorf("Expected %s to be 'eu-west-1', got %s", EnvKeyNovitaRegion, env[EnvKeyNovitaRegion]) + } +} + +// TestConvertSpecNameToProductID tests spec to product ID conversion +// This test is no longer needed as we now extract platform config directly in mapDeployRequestToNovita +/* +func TestExtractNovitaConfig(t *testing.T) { + // Test removed - platform config is now extracted in mapDeployRequestToNovita +} +*/ + +// TestDeploy tests endpoint deployment +func TestDeploy(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + req := &interfaces.DeployRequest{ + Endpoint: "test-endpoint", + SpecName: SpecNameNovitaH100Single, + Image: "test-image:v1", + Replicas: 2, + Env: map[string]string{ + "MODEL_NAME": "llama-3", + "LOG_LEVEL": "info", + }, + Labels: map[string]string{ + LabelKeyRegion: "us-east-1", + }, + } + + // Test successful deployment + resp, err := provider.Deploy(ctx, req) + if err != nil { + t.Fatalf("Deploy failed: %v", err) + } + + if resp.Endpoint != req.Endpoint { + t.Errorf("Expected endpoint %s, got %s", req.Endpoint, resp.Endpoint) + } + + // Verify endpoint was created in mock + endpointID := fmt.Sprintf("ep-%s", req.Endpoint) + endpoint, err := mockClient.GetEndpoint(ctx, endpointID) + if err != nil { + t.Fatalf("Failed to get created endpoint: %v", err) + } + + if endpoint.Endpoint.Name != req.Endpoint { + t.Errorf("Expected name %s, got %s", req.Endpoint, endpoint.Endpoint.Name) + } + + if endpoint.Endpoint.Image.Image != req.Image { + t.Errorf("Expected image %s, got %s", req.Image, endpoint.Endpoint.Image.Image) + } + + // Test deployment with error + mockClient.createError = fmt.Errorf("API error") + _, err = provider.Deploy(ctx, req) + if err == nil { + t.Error("Expected error when API fails, got nil") + } +} + +// TestGetApp tests getting endpoint information +func TestGetApp(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + // Deploy an endpoint first + req := &interfaces.DeployRequest{ + Endpoint: "test-endpoint", + SpecName: SpecNameNovitaA100Single, + Image: "test-image:latest", + Replicas: 1, + } + + _, err := provider.Deploy(ctx, req) + if err != nil { + t.Fatalf("Failed to deploy: %v", err) + } + + // Test GetApp + appInfo, err := provider.GetApp(ctx, req.Endpoint) + if err != nil { + t.Fatalf("GetApp failed: %v", err) + } + + if appInfo.Name != req.Endpoint { + t.Errorf("Expected name %s, got %s", req.Endpoint, appInfo.Name) + } + + if appInfo.Type != TypeServerlessEndpoint { + t.Errorf("Expected type %s, got %s", TypeServerlessEndpoint, appInfo.Type) + } + + if appInfo.Status != StatusRunning { + t.Errorf("Expected status %s, got %s", StatusRunning, appInfo.Status) + } + + // Test getting non-existent endpoint + _, err = provider.GetApp(ctx, "non-existent") + if err == nil { + t.Error("Expected error when getting non-existent endpoint, got nil") + } +} + +// TestListApps tests listing all endpoints +func TestListApps(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + // Deploy multiple endpoints + endpoints := []string{"endpoint-1", "endpoint-2", "endpoint-3"} + for _, ep := range endpoints { + req := &interfaces.DeployRequest{ + Endpoint: ep, + SpecName: SpecNameNovitaH100Single, + Image: "test-image:latest", + Replicas: 1, + } + _, err := provider.Deploy(ctx, req) + if err != nil { + t.Fatalf("Failed to deploy %s: %v", ep, err) + } + } + + // Test ListApps + apps, err := provider.ListApps(ctx) + if err != nil { + t.Fatalf("ListApps failed: %v", err) + } + + if len(apps) != len(endpoints) { + t.Errorf("Expected %d apps, got %d", len(endpoints), len(apps)) + } + + // Verify all endpoints are in the list + foundEndpoints := make(map[string]bool) + for _, app := range apps { + foundEndpoints[app.Name] = true + } + + for _, ep := range endpoints { + if !foundEndpoints[ep] { + t.Errorf("Endpoint %s not found in list", ep) + } + } +} + +// TestDeleteApp tests endpoint deletion +func TestDeleteApp(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + // Deploy an endpoint + req := &interfaces.DeployRequest{ + Endpoint: "test-endpoint", + SpecName: SpecNameNovitaH100Single, + Image: "test-image:latest", + Replicas: 1, + } + + _, err := provider.Deploy(ctx, req) + if err != nil { + t.Fatalf("Failed to deploy: %v", err) + } + + // Verify endpoint exists + _, err = provider.GetApp(ctx, req.Endpoint) + if err != nil { + t.Fatalf("Endpoint should exist: %v", err) + } + + // Test DeleteApp + err = provider.DeleteApp(ctx, req.Endpoint) + if err != nil { + t.Fatalf("DeleteApp failed: %v", err) + } + + // Verify endpoint is deleted + _, err = provider.GetApp(ctx, req.Endpoint) + if err == nil { + t.Error("Expected error when getting deleted endpoint, got nil") + } + + // Test deleting non-existent endpoint + err = provider.DeleteApp(ctx, "non-existent") + if err == nil { + t.Error("Expected error when deleting non-existent endpoint, got nil") + } +} + +// TestScaleApp tests endpoint scaling +func TestScaleApp(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + // Deploy an endpoint + req := &interfaces.DeployRequest{ + Endpoint: "test-endpoint", + SpecName: SpecNameNovitaA100Single, + Image: "test-image:latest", + Replicas: 2, + } + + _, err := provider.Deploy(ctx, req) + if err != nil { + t.Fatalf("Failed to deploy: %v", err) + } + + // Test scale up + newReplicas := 5 + err = provider.ScaleApp(ctx, req.Endpoint, newReplicas) + if err != nil { + t.Fatalf("ScaleApp failed: %v", err) + } + + // Verify scaling + endpointID := fmt.Sprintf("ep-%s", req.Endpoint) + endpoint, err := mockClient.GetEndpoint(ctx, endpointID) + if err != nil { + t.Fatalf("Failed to get endpoint: %v", err) + } + + if endpoint.Endpoint.WorkerConfig.MinNum != newReplicas { + t.Errorf("Expected minNum %d, got %d", newReplicas, endpoint.Endpoint.WorkerConfig.MinNum) + } + + if endpoint.Endpoint.WorkerConfig.MaxNum != newReplicas { + t.Errorf("Expected maxNum %d, got %d", newReplicas, endpoint.Endpoint.WorkerConfig.MaxNum) + } + + // Test scale down + newReplicas = 1 + err = provider.ScaleApp(ctx, req.Endpoint, newReplicas) + if err != nil { + t.Fatalf("ScaleApp (scale down) failed: %v", err) + } + + endpoint, err = mockClient.GetEndpoint(ctx, endpointID) + if err != nil { + t.Fatalf("Failed to get endpoint: %v", err) + } + + if endpoint.Endpoint.WorkerConfig.MinNum != newReplicas { + t.Errorf("Expected minNum %d, got %d", newReplicas, endpoint.Endpoint.WorkerConfig.MinNum) + } + + // Test scaling non-existent endpoint + err = provider.ScaleApp(ctx, "non-existent", 3) + if err == nil { + t.Error("Expected error when scaling non-existent endpoint, got nil") + } +} + +// TestGetAppStatus tests getting endpoint status +func TestGetAppStatus(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + // Deploy an endpoint + req := &interfaces.DeployRequest{ + Endpoint: "test-endpoint", + SpecName: SpecNameNovitaH100Single, + Image: "test-image:latest", + Replicas: 3, + } + + _, err := provider.Deploy(ctx, req) + if err != nil { + t.Fatalf("Failed to deploy: %v", err) + } + + // Test GetAppStatus + status, err := provider.GetAppStatus(ctx, req.Endpoint) + if err != nil { + t.Fatalf("GetAppStatus failed: %v", err) + } + + if status.Endpoint != req.Endpoint { + t.Errorf("Expected endpoint %s, got %s", req.Endpoint, status.Endpoint) + } + + if status.Status != StatusRunning { + t.Errorf("Expected status %s, got %s", StatusRunning, status.Status) + } + + if status.ReadyReplicas != int32(req.Replicas) { + t.Errorf("Expected ready replicas %d, got %d", req.Replicas, status.ReadyReplicas) + } + + // Test getting status of non-existent endpoint + _, err = provider.GetAppStatus(ctx, "non-existent") + if err == nil { + t.Error("Expected error when getting status of non-existent endpoint, got nil") + } +} + +// TestUpdateDeployment tests updating endpoint deployment +func TestUpdateDeployment(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + // Deploy an endpoint + deployReq := &interfaces.DeployRequest{ + Endpoint: "test-endpoint", + SpecName: SpecNameNovitaH100Single, + Image: "test-image:v1", + Replicas: 2, + Env: map[string]string{ + "VERSION": "v1", + }, + } + + _, err := provider.Deploy(ctx, deployReq) + if err != nil { + t.Fatalf("Failed to deploy: %v", err) + } + + // Test update image + newImage := "test-image:v2" + updateReq := &interfaces.UpdateDeploymentRequest{ + Endpoint: deployReq.Endpoint, + Image: newImage, + } + + resp, err := provider.UpdateDeployment(ctx, updateReq) + if err != nil { + t.Fatalf("UpdateDeployment failed: %v", err) + } + + if resp.Endpoint != deployReq.Endpoint { + t.Errorf("Expected endpoint %s, got %s", deployReq.Endpoint, resp.Endpoint) + } + + // Verify update + endpointID := fmt.Sprintf("ep-%s", deployReq.Endpoint) + endpoint, err := mockClient.GetEndpoint(ctx, endpointID) + if err != nil { + t.Fatalf("Failed to get endpoint: %v", err) + } + + if endpoint.Endpoint.Image.Image != newImage { + t.Errorf("Expected image %s, got %s", newImage, endpoint.Endpoint.Image.Image) + } + + // Test update replicas + newReplicas := 5 + updateReq = &interfaces.UpdateDeploymentRequest{ + Endpoint: deployReq.Endpoint, + Replicas: &newReplicas, + } + + _, err = provider.UpdateDeployment(ctx, updateReq) + if err != nil { + t.Fatalf("UpdateDeployment (replicas) failed: %v", err) + } + + endpoint, err = mockClient.GetEndpoint(ctx, endpointID) + if err != nil { + t.Fatalf("Failed to get endpoint: %v", err) + } + + if endpoint.Endpoint.WorkerConfig.MinNum != newReplicas { + t.Errorf("Expected minNum %d, got %d", newReplicas, endpoint.Endpoint.WorkerConfig.MinNum) + } + + // Test updating non-existent endpoint + updateReq = &interfaces.UpdateDeploymentRequest{ + Endpoint: "non-existent", + Image: "new-image:latest", + } + + _, err = provider.UpdateDeployment(ctx, updateReq) + if err == nil { + t.Error("Expected error when updating non-existent endpoint, got nil") + } +} + +// TestListSpecs tests listing available specifications +func TestListSpecs(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + specs, err := provider.ListSpecs(ctx) + if err != nil { + t.Fatalf("ListSpecs failed: %v", err) + } + + if len(specs) == 0 { + t.Error("Expected at least one spec, got none") + } + + // Verify expected specs are present + expectedSpecs := map[string]bool{ + SpecNameNovitaH100Single: false, + SpecNameNovitaA100Single: false, + SpecNameNovitaA10Single: false, + SpecNameNovitaH200Single: false, + } + + for _, spec := range specs { + if _, ok := expectedSpecs[spec.Name]; ok { + expectedSpecs[spec.Name] = true + } + + // Verify spec has required fields + if spec.DisplayName == "" { + t.Errorf("Spec %s has empty DisplayName", spec.Name) + } + if spec.Category != CategoryGPU { + t.Errorf("Spec %s expected category %s, got %s", spec.Name, CategoryGPU, spec.Category) + } + if spec.ResourceType != ResourceTypeServerless { + t.Errorf("Spec %s expected resource type %s, got %s", spec.Name, ResourceTypeServerless, spec.ResourceType) + } + } + + // Check all expected specs were found + for name, found := range expectedSpecs { + if !found { + t.Errorf("Expected spec %s not found", name) + } + } +} + +// TestGetSpec tests getting a specific specification +func TestGetSpec(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + // Test getting existing spec + spec, err := provider.GetSpec(ctx, SpecNameNovitaH100Single) + if err != nil { + t.Fatalf("GetSpec failed: %v", err) + } + + if spec.Name != SpecNameNovitaH100Single { + t.Errorf("Expected spec name %s, got %s", SpecNameNovitaH100Single, spec.Name) + } + + if spec.Resources.GPUType != GPUTypeH100 { + t.Errorf("Expected GPU type %s, got %s", GPUTypeH100, spec.Resources.GPUType) + } + + // Test getting non-existent spec + _, err = provider.GetSpec(ctx, "non-existent-spec") + if err == nil { + t.Error("Expected error when getting non-existent spec, got nil") + } +} + +// TestWatchReplicas tests the replica status watching functionality +func TestWatchReplicas(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create a test endpoint + req := &interfaces.DeployRequest{ + Endpoint: "test-watch-endpoint", + SpecName: SpecNameNovitaH100Single, + Image: "test-image:latest", + Replicas: 2, + } + + _, err := provider.Deploy(ctx, req) + if err != nil { + t.Fatalf("Deploy failed: %v", err) + } + + // Channel to receive events + eventsChan := make(chan interfaces.ReplicaEvent, 10) + + // Register callback + err = provider.WatchReplicas(ctx, func(event interfaces.ReplicaEvent) { + eventsChan <- event + }) + if err != nil { + t.Fatalf("WatchReplicas failed: %v", err) + } + + // Wait a bit for the watcher to poll + // Note: In a real test, we'd want to mock time or use smaller intervals + t.Log("WatchReplicas registered successfully, watcher is running") + + // Cancel context to stop watcher + cancel() + + // Verify callback was unregistered + provider.replicaCallbacksLock.RLock() + callbackCount := len(provider.replicaCallbacks) + provider.replicaCallbacksLock.RUnlock() + + if callbackCount != 0 { + t.Errorf("Expected 0 callbacks after context cancel, got %d", callbackCount) + } +} + +// TestWatchReplicasMultipleCallbacks tests multiple callbacks registration +func TestWatchReplicasMultipleCallbacks(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Register multiple callbacks + err1 := provider.WatchReplicas(ctx, func(event interfaces.ReplicaEvent) { + // Callback 1 + }) + if err1 != nil { + t.Fatalf("First WatchReplicas failed: %v", err1) + } + + err2 := provider.WatchReplicas(ctx, func(event interfaces.ReplicaEvent) { + // Callback 2 + }) + if err2 != nil { + t.Fatalf("Second WatchReplicas failed: %v", err2) + } + + // Verify multiple callbacks are registered + provider.replicaCallbacksLock.RLock() + callbackCount := len(provider.replicaCallbacks) + provider.replicaCallbacksLock.RUnlock() + + if callbackCount != 2 { + t.Errorf("Expected 2 callbacks, got %d", callbackCount) + } + + // Trigger callbacks manually for testing + testEvent := interfaces.ReplicaEvent{ + Name: "test-endpoint", + DesiredReplicas: 2, + ReadyReplicas: 2, + AvailableReplicas: 2, + } + provider.triggerReplicaCallbacks(testEvent) + + // Give some time for callbacks to execute + // In production, callbacks run in goroutines + t.Log("Multiple callbacks registered and triggered successfully") +} + +// TestWatchReplicasNilCallback tests error handling for nil callback +func TestWatchReplicasNilCallback(t *testing.T) { + mockClient := newMockClient() + provider := createTestProvider(mockClient) + ctx := context.Background() + + err := provider.WatchReplicas(ctx, nil) + if err == nil { + t.Error("Expected error when registering nil callback, got nil") + } + if err != nil && err.Error() != "replica callback is nil" { + t.Errorf("Expected 'replica callback is nil' error, got: %v", err) + } +} + +func TestRealScaleDown(t *testing.T) { + provider, err := NewNovitaDeploymentProvider(&config.Config{ + Novita: config.NovitaConfig{ + Enabled: true, + APIKey: "your api key here", + BaseURL: "https://api.novita.ai", + ConfigDir: "../../../config", + }, + }) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + ctx := context.Background() + + // Deploy an endpoint + req := &interfaces.DeployRequest{ + Endpoint: "base-test", + SpecName: "novita-5090-single", + Image: "ubuntu:22.04", + Replicas: 3, + TaskTimeout: 1200, + } + _, err = provider.Deploy(ctx, req) + if err != nil { + t.Fatalf("Failed to deploy: %v", err) + } + app, err := provider.GetApp(ctx, req.Endpoint) + if err != nil { + t.Fatalf("Failed to get app: %v", err) + } + if app.Replicas != 3 { + t.Errorf("Expected replicas %d, got %d", 3, app.Replicas) + } + + // Scale down to 0 replicas + err = provider.ScaleApp(ctx, req.Endpoint, 0) + if err != nil { + t.Fatalf("Failed to scale down: %v", err) + } + err = provider.DeleteApp(ctx, req.Endpoint) + if err != nil { + t.Fatalf("Failed to delete: %v", err) + } + +} diff --git a/pkg/deploy/novita/specs_config.go b/pkg/deploy/novita/specs_config.go new file mode 100644 index 0000000..9534bde --- /dev/null +++ b/pkg/deploy/novita/specs_config.go @@ -0,0 +1,159 @@ +package novita + +import ( + "fmt" + "os" + "path/filepath" + + "gopkg.in/yaml.v3" + + "waverless/pkg/interfaces" + "waverless/pkg/logger" +) + +const ( + PlatformNovita = "novita" +) + +// ResourceSpec 资源规格定义 +type ResourceSpec struct { + Name string `yaml:"name" json:"name"` + DisplayName string `yaml:"displayName" json:"displayName"` + Category string `yaml:"category" json:"category"` // cpu, gpu + Resources SpecResources `yaml:"resources" json:"resources"` + Platforms map[string]PlatformConfig `yaml:"platforms" json:"platforms"` +} + +// SpecResources 规格资源 +type SpecResources struct { + CPU string `yaml:"cpu,omitempty" json:"cpu,omitempty"` + Memory string `yaml:"memory" json:"memory"` + GPU string `yaml:"gpu,omitempty" json:"gpu,omitempty"` + GpuType string `yaml:"gpuType,omitempty" json:"gpuType,omitempty"` + EphemeralStorage string `yaml:"ephemeralStorage" json:"ephemeralStorage"` + ShmSize string `yaml:"shmSize,omitempty" json:"shmSize,omitempty"` // Shared memory size +} + +// PlatformConfig 平台特定配置 +type PlatformConfig struct { + ProductID string `yaml:"productId" json:"productId"` + Region string `yaml:"region" json:"region"` + CudaVersion string `yaml:"cudaVersion" json:"cudaVersion"` +} + +// SpecsConfig manages Novita specifications from specs.yaml +type SpecsConfig struct { + specs map[string]*ResourceSpec + configDir string +} + +// NewSpecsConfig creates a new specs configuration manager +func NewSpecsConfig(configDir string) (*SpecsConfig, error) { + if configDir == "" { + configDir = "config" + } + + sc := &SpecsConfig{ + specs: make(map[string]*ResourceSpec), + configDir: configDir, + } + + // Load specs from config file + if err := sc.loadSpecs(); err != nil { + return nil, fmt.Errorf("failed to load Novita specs: %w", err) + } + + // Ensure at least one spec is loaded + if len(sc.specs) == 0 { + return nil, fmt.Errorf("no Novita-compatible specs found in %s", filepath.Join(configDir, "specs.yaml")) + } + + return sc, nil +} + +// SpecsFileConfig represents the structure of specs.yaml file +type SpecsFileConfig struct { + Specs []*ResourceSpec `yaml:"specs"` +} + +// loadSpecs loads specs from specs.yaml +func (sc *SpecsConfig) loadSpecs() error { + specsFile := filepath.Join(sc.configDir, "specs.yaml") + + data, err := os.ReadFile(specsFile) + if err != nil { + return fmt.Errorf("failed to read specs file: %w", err) + } + + // Parse YAML file using SpecsFileConfig structure + var config SpecsFileConfig + if err := yaml.Unmarshal(data, &config); err != nil { + return fmt.Errorf("failed to parse specs file: %w", err) + } + + // Clear and load specs - filter for Novita-compatible specs + sc.specs = make(map[string]*ResourceSpec) + loadedCount := 0 + + for _, s := range config.Specs { + sc.specs[s.Name] = s + loadedCount++ + logger.Debugf("Loaded Novita spec: %s %+v", s.Name, s.Platforms) + } + + if loadedCount == 0 { + logger.Warnf("No Novita-compatible specs found in %s", specsFile) + } else { + logger.Infof("Loaded %d Novita-compatible specs from %s", loadedCount, specsFile) + } + return nil +} + +// GetSpec returns a specific spec info by name +func (sc *SpecsConfig) GetSpec(specName string) (*interfaces.SpecInfo, error) { + fmt.Printf("%v ----------- %s \n", sc.specs, specName) + for k, sv := range sc.specs { + fmt.Printf("%s %v \n ", k, sv) + } + resourceSpec, ok := sc.specs[specName] + if !ok { + return nil, fmt.Errorf("spec %s not found", specName) + } + + return sc.convertToSpecInfo(resourceSpec), nil +} + +// ListSpecs returns all available spec infos +func (sc *SpecsConfig) ListSpecs() []*interfaces.SpecInfo { + // Return a copy to prevent external modification + specs := make([]*interfaces.SpecInfo, 0, len(sc.specs)) + for _, spec := range sc.specs { + specs = append(specs, sc.convertToSpecInfo(spec)) + } + return specs +} + +// convertToSpecInfo converts ResourceSpec to interfaces.SpecInfo +func (sc *SpecsConfig) convertToSpecInfo(spec *ResourceSpec) *interfaces.SpecInfo { + // Convert Platforms map to map[string]interface{} + platforms := make(map[string]interface{}) + for platformName, platformConfig := range spec.Platforms { + // Keep the full PlatformConfig struct instead of converting to map + platforms[platformName] = platformConfig + } + + return &interfaces.SpecInfo{ + Name: spec.Name, + DisplayName: spec.DisplayName, + Category: spec.Category, + Resources: interfaces.ResourceRequirements{ + GPU: spec.Resources.GPU, + GPUType: spec.Resources.GpuType, + CPU: spec.Resources.CPU, + Memory: spec.Resources.Memory, + EphemeralStorage: spec.Resources.EphemeralStorage, + ShmSize: spec.Resources.ShmSize, + }, + Platforms: platforms, + } +} diff --git a/pkg/deploy/novita/types.go b/pkg/deploy/novita/types.go new file mode 100644 index 0000000..db729dd --- /dev/null +++ b/pkg/deploy/novita/types.go @@ -0,0 +1,232 @@ +package novita + +// Novita API types based on https://novita.ai/docs/api-reference/serverless-create-endpoint + +// ======================================== +// Request Types (for Create/Update API) +// ======================================== + +// CreateEndpointRequest represents the request to create a Novita endpoint +type CreateEndpointRequest struct { + Endpoint EndpointCreateConfig `json:"endpoint"` +} + +// EndpointCreateConfig represents the endpoint configuration for create/update requests +type EndpointCreateConfig struct { + Name string `json:"name,omitempty"` // Endpoint name + AppName string `json:"appName,omitempty"` // Application name (appears in URL) + WorkerConfig WorkerConfig `json:"workerConfig"` // Worker configuration + Ports []PortConfig `json:"ports"` // HTTP ports + Policy PolicyConfig `json:"policy"` // Auto-scaling policy + Image ImageConfig `json:"image"` // Image information + Products []ProductConfig `json:"products"` // Product information + RootfsSize int `json:"rootfsSize"` // System disk size (GB) + VolumeMounts []VolumeMount `json:"volumeMounts,omitempty"` // Storage information + ClusterID string `json:"clusterID,omitempty"` // Cluster ID + Envs []EnvVar `json:"envs,omitempty"` // Environment variables + Healthy *HealthCheck `json:"healthy,omitempty"` // Health check endpoint +} + +// ======================================== +// Response Types (from Get/List API) +// ======================================== + +// EndpointConfig represents the full endpoint data from Get/List Endpoint API +type EndpointConfig struct { + ID string `json:"id"` // Endpoint ID + Name string `json:"name"` // Endpoint name + AppName string `json:"appName"` // Application name + State StateInfo `json:"state"` // Endpoint state + URL string `json:"url"` // Endpoint URL + WorkerConfig WorkerConfigResponse `json:"workerConfig"` // Worker configuration (from API response) + Policy PolicyDetails `json:"policy"` // Auto-scaling policy + Image ImageDetails `json:"image"` // Image information + RootfsSize int `json:"rootfsSize"` // System disk size (GB) + VolumeMounts []VolumeMount `json:"volumeMounts"` // Storage information + Envs []EnvVar `json:"envs"` // Environment variables + Ports []PortDetails `json:"ports"` // Port information + Workers []WorkerInfo `json:"workers"` // Worker information + Products []ProductConfig `json:"products"` // Product information + Healthy *HealthCheckDetails `json:"healthy"` // Health check configuration + ClusterID string `json:"clusterID"` // Cluster ID + Log string `json:"log"` // Log path +} + +// ======================================== +// Common Types +// ======================================== + +// StateInfo represents the state information from Get Endpoint response +type StateInfo struct { + State string `json:"state"` // State: "serving", "stopped", "failed", etc. + Error string `json:"error"` // Error code if any + Message string `json:"message"` // State message +} + +// WorkerConfig represents worker configuration (for Create API - int types) +type WorkerConfig struct { + MinNum int `json:"minNum"` // Minimum number of workers + MaxNum int `json:"maxNum"` // Maximum number of workers + FreeTimeout int `json:"freeTimeout"` // Idle timeout (seconds) + MaxConcurrent int `json:"maxConcurrent"` // Maximum concurrency + GPUNum int `json:"gpuNum"` // Number of GPUs per worker + RequestTimeout int `json:"requestTimeout,omitempty"` // Request timeout (seconds) + CudaVersion string `json:"cudaVersion,omitempty"` // CUDA version +} + +// WorkerConfigResponse represents worker configuration from API responses (Get/List - string types) +type WorkerConfigResponse struct { + MinNum int `json:"minNum"` // Minimum number of workers + MaxNum int `json:"maxNum"` // Maximum number of workers + FreeTimeout string `json:"freeTimeout"` // Idle timeout (string: "300") + MaxConcurrent string `json:"maxConcurrent"` // Maximum concurrency (string: "1") + GPUNum int `json:"gpuNum"` // Number of GPUs per worker + RequestTimeout int `json:"requestTimeout"` // Request timeout (seconds) + CudaVersion string `json:"cudaVersion"` // CUDA version +} + +// PolicyDetails represents policy configuration in Get Endpoint response +type PolicyDetails struct { + Type string `json:"type"` // Policy type + Value string `json:"value"` // Policy value (string in response) +} + +// ImageDetails represents image configuration in Get Endpoint response +type ImageDetails struct { + Image string `json:"image"` // Image URL + AuthID string `json:"authId"` // Private image credential ID + Command string `json:"command"` // Container startup command +} + +// PortDetails represents port configuration in Get Endpoint response +type PortDetails struct { + Port int `json:"port"` // Port number (int in response) +} + +// WorkerInfo represents individual worker information +type WorkerInfo struct { + ID string `json:"id"` // Worker ID + State StateInfo `json:"state"` // Worker state + Log string `json:"log"` // Log path + Metrics string `json:"metrics"` // Metrics path + Healthy bool `json:"healthy"` // Health status +} + +// HealthCheckDetails represents health check configuration in Get Endpoint response +type HealthCheckDetails struct { + Path string `json:"path"` // Health check path + InitialDelay int `json:"initialDelay"` // Initial delay in seconds + Period int `json:"period"` // Check period in seconds + Timeout int `json:"timeout"` // Timeout in seconds + SuccessThreshold int `json:"successThreshold"` // Success threshold + FailureThreshold int `json:"failureThreshold"` // Failure threshold +} + +// PolicyConfig represents auto-scaling policy (for create/update requests - int value) +type PolicyConfig struct { + Type string `json:"type"` // Policy type: "queue" or "concurrency" + Value int `json:"value"` // Policy value +} + +// ImageConfig represents image configuration (for create/update requests) +type ImageConfig struct { + Image string `json:"image"` // Image URL + AuthID string `json:"authId,omitempty"` // Private image credential ID + Command string `json:"command,omitempty"` // Container startup command +} + +// PortConfig represents port configuration +type PortConfig struct { + Port int `json:"port"` // HTTP port +} + +// ProductConfig represents product configuration +type ProductConfig struct { + ID string `json:"id"` // Product ID +} + +// VolumeMount represents storage configuration +type VolumeMount struct { + Type string `json:"type"` // Storage type: "local" or "network" + Size int `json:"size,omitempty"` // Local storage size (GB) + ID string `json:"id,omitempty"` // Network storage ID + MountPath string `json:"mountPath,omitempty"` // Mount path +} + +// EnvVar represents environment variable +type EnvVar struct { + Key string `json:"key"` // Environment variable name + Value string `json:"value"` // Environment variable value +} + +// HealthCheck represents health check configuration (simple version for create) +type HealthCheck struct { + Path string `json:"path"` // Health check path +} + +// CreateEndpointResponse represents the response from creating an endpoint +type CreateEndpointResponse struct { + ID string `json:"id"` // Created endpoint ID +} + +// UpdateEndpointRequest represents the request to update an endpoint +// Note: Update API uses a flattened structure with string types for worker config +type UpdateEndpointRequest struct { + ID string `json:"id"` // Endpoint ID + Name string `json:"name"` // Endpoint name + AppName string `json:"appName"` // Application name + WorkerConfig WorkerConfigResponse `json:"workerConfig"` // Worker configuration (string types) + Policy PolicyResponse `json:"policy"` // Auto-scaling policy (string value) + Image ImageConfig `json:"image"` // Image configuration + RootfsSize int `json:"rootfsSize"` // System disk size (GB) + VolumeMounts []VolumeMount `json:"volumeMounts"` // Storage information + Envs []EnvVar `json:"envs,omitempty"` // Environment variables + Ports []PortConfig `json:"ports"` // HTTP ports + Workers []WorkerInfo `json:"workers"` // Workers list (can be null) + Products []ProductConfig `json:"products"` // Product information + Healthy *HealthCheck `json:"healthy,omitempty"` // Health check configuration +} + +// PolicyResponse represents policy configuration from API responses (string value) +type PolicyResponse struct { + Type string `json:"type"` // Policy type: "queue" or "concurrency" + Value string `json:"value"` // Policy value (string: "60") +} + +// GetEndpointResponse represents the response from getting an endpoint +// The actual API response has only one field: "endpoint" +type GetEndpointResponse struct { + Endpoint EndpointConfig `json:"endpoint"` // Endpoint data +} + +// ListEndpointsResponse represents the response from listing endpoints +type ListEndpointsResponse struct { + Endpoints []EndpointListItem `json:"endpoints"` // List of endpoints + Total int `json:"total"` // Total count +} + +// EndpointListItem represents a single endpoint in the list +// Note: Novita's ListEndpoints API returns full endpoint details with string types +type EndpointListItem struct { + ID string `json:"id"` // Endpoint ID + Name string `json:"name"` // Endpoint name + AppName string `json:"appName"` // Application name + State StateInfo `json:"state"` // Current state + WorkerConfig WorkerConfigResponse `json:"workerConfig"` // Worker configuration (string types) + Workers []WorkerInfo `json:"workers"` // Workers list (full details) + Policy PolicyDetails `json:"policy"` // Policy configuration (string value) + Image ImageDetails `json:"image"` // Image configuration + CreatedAt string `json:"createdAt"` // Creation time + UpdatedAt string `json:"updatedAt"` // Last update time +} + +// DeleteEndpointRequest represents the request to delete an endpoint +type DeleteEndpointRequest struct { + ID string `json:"id"` // Endpoint ID +} + +// ErrorResponse represents an error response from Novita API +type ErrorResponse struct { + Code int `json:"code"` // Error code + Message string `json:"message"` // Error message +} diff --git a/pkg/provider/factory.go b/pkg/provider/factory.go index 8159477..5f1bafa 100644 --- a/pkg/provider/factory.go +++ b/pkg/provider/factory.go @@ -7,6 +7,7 @@ import ( "waverless/pkg/config" "waverless/pkg/deploy/docker" "waverless/pkg/deploy/k8s" + "waverless/pkg/deploy/novita" "waverless/pkg/interfaces" ) @@ -37,6 +38,7 @@ func init() { RegisterDeploymentProvider("k8s", k8s.NewK8sDeploymentProvider) RegisterDeploymentProvider("kubernetes", k8s.NewK8sDeploymentProvider) RegisterDeploymentProvider("docker", docker.NewDockerDeploymentProvider) + RegisterDeploymentProvider("novita", novita.NewNovitaDeploymentProvider) } func (f *ProviderFactory) CreateDeploymentProvider(providerType string) (interfaces.DeploymentProvider, error) { From 088de2f7f41321b20a72bee55449ad011b930a80 Mon Sep 17 00:00:00 2001 From: xiamu Date: Fri, 23 Jan 2026 10:40:25 +0800 Subject: [PATCH 2/2] feat: novita provider support registry credential 1. novita client add registry interface. 2. novita provider deploy func add registry credential check. --- pkg/deploy/novita/client.go | 48 +++++++++++++++++++++++++++++++- pkg/deploy/novita/provider.go | 52 +++++++++++++++++++++++++++++++++++ pkg/deploy/novita/types.go | 34 +++++++++++++++++++++++ 3 files changed, 133 insertions(+), 1 deletion(-) diff --git a/pkg/deploy/novita/client.go b/pkg/deploy/novita/client.go index 443d1ba..454c642 100644 --- a/pkg/deploy/novita/client.go +++ b/pkg/deploy/novita/client.go @@ -107,6 +107,52 @@ func (c *Client) DeleteEndpoint(ctx context.Context, endpointID string) error { return err } +// CreateRegistryAuth creates a new container registry authentication +func (c *Client) CreateRegistryAuth(ctx context.Context, req *CreateRegistryAuthRequest) (*CreateRegistryAuthResponse, error) { + url := c.baseURL + "/gpu-instance/openapi/v1/repository/auth/save" + + respData, err := c.doRequest(ctx, "POST", url, req) + if err != nil { + return nil, err + } + + var resp CreateRegistryAuthResponse + if err := json.Unmarshal(respData, &resp); err != nil { + return nil, fmt.Errorf("failed to parse create registry auth response: %w", err) + } + + return &resp, nil +} + +// ListRegistryAuths lists all container registry authentications +func (c *Client) ListRegistryAuths(ctx context.Context) (*ListRegistryAuthsResponse, error) { + url := c.baseURL + "/gpu-instance/openapi/v1/repository/auths" + + respData, err := c.doRequest(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + var resp ListRegistryAuthsResponse + if err := json.Unmarshal(respData, &resp); err != nil { + return nil, fmt.Errorf("failed to parse list registry auths response: %w", err) + } + + return &resp, nil +} + +// DeleteRegistryAuth deletes a container registry authentication +func (c *Client) DeleteRegistryAuth(ctx context.Context, authID string) error { + url := c.baseURL + "/gpu-instance/openapi/v1/repository/auth/delete" + + req := &DeleteRegistryAuthRequest{ + ID: authID, + } + + _, err := c.doRequest(ctx, "POST", url, req) + return err +} + // doRequest performs an HTTP request with proper authentication func (c *Client) doRequest(ctx context.Context, method, url string, body interface{}) ([]byte, error) { var reqBody io.Reader @@ -146,7 +192,7 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body interfa } // Log response for debugging - logger.Debugf("Novita API Response: Status %d, Body: %s", resp.StatusCode, string(respData)) + // logger.Debugf("Novita API Response: Status %d, Body: %s", resp.StatusCode, string(respData)) // Check for HTTP errors if resp.StatusCode < 200 || resp.StatusCode >= 300 { diff --git a/pkg/deploy/novita/provider.go b/pkg/deploy/novita/provider.go index 3f33006..3f24a7b 100644 --- a/pkg/deploy/novita/provider.go +++ b/pkg/deploy/novita/provider.go @@ -20,6 +20,10 @@ type clientInterface interface { ListEndpoints(ctx context.Context) (*ListEndpointsResponse, error) UpdateEndpoint(ctx context.Context, req *UpdateEndpointRequest) error DeleteEndpoint(ctx context.Context, endpointID string) error + // Registry Auth methods + CreateRegistryAuth(ctx context.Context, req *CreateRegistryAuthRequest) (*CreateRegistryAuthResponse, error) + ListRegistryAuths(ctx context.Context) (*ListRegistryAuthsResponse, error) + DeleteRegistryAuth(ctx context.Context, authID string) error } // replicaCallbackEntry represents a registered replica callback @@ -103,6 +107,17 @@ func (p *NovitaDeploymentProvider) Deploy(ctx context.Context, req *interfaces.D return nil, fmt.Errorf("failed to map deploy request to Novita: %w", err) } + // Handle registry credential - sync auth to Novita if provided + if req.RegistryCredential != nil { + authID, err := p.ensureRegistryAuth(ctx, req.RegistryCredential) + if err != nil { + return nil, fmt.Errorf("failed to ensure registry auth: %w", err) + } + // Set the auth ID in the request + novitaReq.Endpoint.Image.AuthID = authID + logger.Infof("Using registry auth ID: %s for image: %s", authID, req.Image) + } + // Create endpoint resp, err := p.client.CreateEndpoint(ctx, novitaReq) if err != nil { @@ -552,6 +567,43 @@ func (p *NovitaDeploymentProvider) GetDefaultEnv(ctx context.Context) (map[strin return defaultEnv, nil } +// ensureRegistryAuth ensures a registry auth exists in Novita +// Returns the auth ID (existing or newly created) +func (p *NovitaDeploymentProvider) ensureRegistryAuth(ctx context.Context, cred *interfaces.RegistryCredential) (string, error) { + if cred == nil { + return "", fmt.Errorf("registry credential is nil") + } + + // List existing registry auths + listResp, err := p.client.ListRegistryAuths(ctx) + if err != nil { + return "", fmt.Errorf("failed to list registry auths: %w", err) + } + // Check if auth already exists by matching registry name + for _, auth := range listResp.Data { + if auth.Name == cred.Registry { + logger.Infof("Found existing registry auth for %s (ID: %s)", cred.Registry, auth.ID) + return auth.ID, nil + } + } + + // Auth doesn't exist, create new one + logger.Infof("Creating new registry auth for %s", cred.Registry) + createReq := &CreateRegistryAuthRequest{ + Name: cred.Registry, + Username: cred.Username, + Password: cred.Password, + } + + createResp, err := p.client.CreateRegistryAuth(ctx, createReq) + if err != nil { + return "", fmt.Errorf("failed to create registry auth: %w", err) + } + + logger.Infof("Created registry auth for %s (ID: %s)", cred.Registry, createResp.ID) + return createResp.ID, nil +} + // getEndpointID retrieves the Novita endpoint ID for a given endpoint name // It first checks the cache, then queries the API if not found func (p *NovitaDeploymentProvider) getEndpointID(ctx context.Context, endpoint string) (string, error) { diff --git a/pkg/deploy/novita/types.go b/pkg/deploy/novita/types.go index db729dd..61d8469 100644 --- a/pkg/deploy/novita/types.go +++ b/pkg/deploy/novita/types.go @@ -230,3 +230,37 @@ type ErrorResponse struct { Code int `json:"code"` // Error code Message string `json:"message"` // Error message } + +// ======================================== +// Container Registry Auth Types +// ======================================== + +// CreateRegistryAuthRequest represents the request to create a container registry auth +type CreateRegistryAuthRequest struct { + Name string `json:"name"` // Auth name (registry URL) + Username string `json:"username"` // Username + Password string `json:"password"` // Password +} + +// CreateRegistryAuthResponse represents the response from creating a registry auth +type CreateRegistryAuthResponse struct { + ID string `json:"id"` // Created auth ID +} + +// ListRegistryAuthsResponse represents the response from listing registry auths +type ListRegistryAuthsResponse struct { + Data []RegistryAuthItem `json:"data"` // List of registry auths +} + +// RegistryAuthItem represents a single registry auth item +type RegistryAuthItem struct { + ID string `json:"id"` // Auth ID + Name string `json:"name"` // Auth name (registry URL) + Username string `json:"username"` // Username + Password string `json:"password"` // Password (may be masked) +} + +// DeleteRegistryAuthRequest represents the request to delete a registry auth +type DeleteRegistryAuthRequest struct { + ID string `json:"id"` // Auth ID to delete +}