diff --git a/cmd/ask.go b/cmd/ask.go index 5cfec76..fb7463d 100644 --- a/cmd/ask.go +++ b/cmd/ask.go @@ -23,6 +23,7 @@ import ( "github.com/bgdnvk/clanker/internal/k8s" "github.com/bgdnvk/clanker/internal/k8s/plan" "github.com/bgdnvk/clanker/internal/maker" + "github.com/bgdnvk/clanker/internal/routing" tfclient "github.com/bgdnvk/clanker/internal/terraform" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -301,11 +302,11 @@ Examples: makerProvider = "aws" makerProviderReason = "explicit" default: - _, _, _, _, _, inferredGCP, inferredCloudflare := inferContext(questionForRouting(question)) - if inferredCloudflare { + svcCtx := routing.InferContext(questionForRouting(question)) + if svcCtx.Cloudflare { makerProvider = "cloudflare" makerProviderReason = "inferred" - } else if inferredGCP { + } else if svcCtx.GCP { makerProvider = "gcp" makerProviderReason = "inferred" } @@ -455,35 +456,63 @@ Format as a professional compliance table suitable for government security docum } if !includeAWS && !includeGitHub && !includeTerraform && !includeGCP && !includeCloudflare { - var inferredTerraform bool - var inferredCode bool - var inferredK8s bool - var inferredGCP bool - var inferredCloudflare bool routingQuestion := questionForRouting(question) - includeAWS, inferredCode, includeGitHub, inferredTerraform, inferredK8s, inferredGCP, inferredCloudflare = inferContext(routingQuestion) - _ = inferredCode + + // First, do quick keyword check for explicit terms + svcCtx := routing.InferContext(routingQuestion) + includeAWS = svcCtx.AWS + includeGitHub = svcCtx.GitHub if debug { - fmt.Printf("Inferred context: AWS=%v, GitHub=%v, Terraform=%v, K8s=%v, GCP=%v, Cloudflare=%v\n", includeAWS, includeGitHub, inferredTerraform, inferredK8s, inferredGCP, inferredCloudflare) + fmt.Printf("Keyword inference: AWS=%v, GitHub=%v, Terraform=%v, K8s=%v, GCP=%v, Cloudflare=%v\n", + svcCtx.AWS, svcCtx.GitHub, svcCtx.Terraform, svcCtx.K8s, svcCtx.GCP, svcCtx.Cloudflare) + } + + // For ambiguous 
queries (multiple services detected or Cloudflare detected), + // use LLM to make the final routing decision + if routing.NeedsLLMClassification(svcCtx) { + if debug { + fmt.Println("[routing] Ambiguous query detected, using LLM for classification...") + } + + llmService, err := routing.ClassifyWithLLM(context.Background(), routingQuestion, debug) + if err != nil { + // FALLBACK: LLM classification failed, use keyword-based inference + if debug { + fmt.Printf("[routing] LLM classification failed (%v), falling back to keyword inference\n", err) + } + // Keep the keyword-inferred values as-is (no changes needed) + } else { + // LLM succeeded - override keyword-based inference with LLM decision + routing.ApplyLLMClassification(&svcCtx, llmService) + + if debug { + fmt.Printf("LLM override: AWS=%v, K8s=%v, GCP=%v, Cloudflare=%v\n", + svcCtx.AWS, svcCtx.K8s, svcCtx.GCP, svcCtx.Cloudflare) + } + } } // Handle inferred Terraform context - if inferredTerraform { + if svcCtx.Terraform { includeTerraform = true } - if inferredGCP { + if svcCtx.GCP { includeGCP = true } + // Update includeAWS and includeGitHub from service context + includeAWS = svcCtx.AWS + includeGitHub = svcCtx.GitHub + // Handle Cloudflare queries by delegating to Cloudflare agent - if inferredCloudflare { + if svcCtx.Cloudflare { return handleCloudflareQuery(context.Background(), routingQuestion, debug) } // Handle K8s queries by delegating to K8s agent - if inferredK8s { + if svcCtx.K8s { return handleK8sQuery(context.Background(), routingQuestion, debug, viper.GetString("kubernetes.kubeconfig")) } } @@ -908,145 +937,6 @@ func resolveGeminiModel(provider, flagValue string) string { return model } -// inferContext tries to determine if the question is about AWS, GitHub, Terraform, Kubernetes, GCP, or Cloudflare. -// Code scanning is disabled, so this never infers code context. 
-func inferContext(question string) (aws bool, code bool, github bool, terraform bool, k8s bool, gcp bool, cf bool) { - awsKeywords := []string{ - // Core services - "ec2", "lambda", "rds", "s3", "ecs", "cloudwatch", "logs", "batch", "sqs", "sns", "dynamodb", "elasticache", "elb", "alb", "nlb", "route53", "cloudfront", "api-gateway", "cognito", "iam", "vpc", "subnet", "security-group", "nacl", "nat", "igw", "vpn", "direct-connect", - // General terms - "instance", "bucket", "database", "aws", "resources", "infrastructure", "running", "account", "error", "log", "job", "queue", "compute", "storage", "network", "cdn", "load-balancer", "auto-scaling", "scaling", "health", "metric", "alarm", "notification", "backup", "snapshot", "ami", "volume", "ebs", "efs", "fsx", - // Compute and GPU terms - "gpu", "cuda", "ml", "machine-learning", "training", "inference", "p2", "p3", "p4", "g3", "g4", "g5", "spot", "reserved", "dedicated", - // Status and operations - "status", "state", "healthy", "unhealthy", "available", "pending", "stopping", "stopped", "terminated", "creating", "deleting", "modifying", "active", "inactive", "enabled", "disabled", - // Cost and billing - "cost", "billing", "price", "usage", "spend", "budget", - // Monitoring and debugging - "monitor", "trace", "debug", "performance", "latency", "throughput", "error-rate", "failure", "timeout", "retry", - // Infrastructure discovery - "services", "active", "deployed", "discovery", "overview", "summary", "list-all", "what's-running", "what-services", "infrastructure-overview", - } - - githubKeywords := []string{ - // GitHub platform - "github", "git", "repository", "repo", "fork", "clone", "branch", "tag", "release", "issue", "discussion", - // CI/CD and Actions - "action", "workflow", "ci", "cd", "build", "deploy", "deployment", "pipeline", "job", "step", "runner", "artifact", - // Collaboration - "pr", "pull", "request", "merge", "commit", "push", "pull-request", "review", "approve", "comment", "assignee", 
"reviewer", - // Project management - "milestone", "project", "board", "epic", "story", "task", "bug", "feature", "enhancement", "label", "status", - // Security and compliance - "security", "vulnerability", "dependabot", "secret", "token", "permission", "access", "audit", - } - - terraformKeywords := []string{ - // Terraform core - "terraform", "tf", "hcl", "plan", "apply", "destroy", "init", "workspace", "state", "backend", "provider", "resource", "data", "module", "variable", "output", "local", - // Operations - "infrastructure-as-code", "iac", "provisioning", "deployment", "environment", "stack", "configuration", "template", - // State management - "tfstate", "state-file", "remote-state", "lock", "unlock", "drift", "refresh", "import", "taint", "untaint", - // Workspaces and environments - "dev", "stage", "staging", "prod", "production", "qa", "environment", "workspace", - } - - k8sKeywords := []string{ - // Core K8s terms - "kubernetes", "k8s", "kubectl", "kube", - // Workloads - "pod", "pods", "deployment", "deployments", "replicaset", "statefulset", - "daemonset", "job", "cronjob", - // Networking - "service", "services", "ingress", "loadbalancer", "nodeport", "clusterip", - "networkpolicy", "endpoint", - // Storage - "pv", "pvc", "persistentvolume", "storageclass", "configmap", "secret", - // Cluster - "node", "nodes", "namespace", "cluster", "kubeconfig", "context", - // Tools - "helm", "chart", "release", "tiller", - // Providers - "eks", "kubeadm", "kops", "k3s", "minikube", - // Operations - "rollout", "scale", "drain", "cordon", "taint", - } - - gcpKeywords := []string{ - "gcp", "google cloud", "cloud run", "cloudrun", "cloud sql", "cloudsql", "gke", "gcs", "cloud storage", - "pubsub", "pub/sub", "cloud functions", "cloud function", "compute engine", "gce", "iam service account", - "workload identity", "artifact registry", "secret manager", "bigquery", "spanner", "bigtable", - "cloud build", "cloud deploy", "cloud dns", "cloud armor", "cloud load 
balancing", "api gateway", - } - - cloudflareKeywords := []string{ - // Platform and tools - "cloudflare", "cf", "wrangler", "cloudflared", - // DNS - "dns record", "zone", "nameserver", "cname", "a record", "aaaa record", "mx record", "txt record", - // Workers and edge - "worker", "workers", "kv", "d1", "r2", "pages", "durable objects", - // Security - "waf", "firewall rule", "rate limit", "ddos", "bot management", "page rules", - // Zero Trust - "tunnel", "access", "zero trust", "warp", "cloudflare access", - // Analytics and performance - "cdn", "cache", "analytics", "web analytics", - } - - questionLower := strings.ToLower(question) - - for _, keyword := range awsKeywords { - if contains(questionLower, keyword) { - aws = true - break - } - } - - for _, keyword := range githubKeywords { - if contains(questionLower, keyword) { - github = true - break - } - } - - for _, keyword := range terraformKeywords { - if contains(questionLower, keyword) { - terraform = true - break - } - } - - for _, keyword := range k8sKeywords { - if contains(questionLower, keyword) { - k8s = true - break - } - } - - for _, keyword := range gcpKeywords { - if contains(questionLower, keyword) { - gcp = true - break - } - } - - for _, keyword := range cloudflareKeywords { - if contains(questionLower, keyword) { - cf = true - break - } - } - - // If no specific context detected, include AWS + GitHub by default. 
// ServiceContext records which services a query appears to be about.
// The flags are independent: several may be set for the same query.
type ServiceContext struct {
	AWS        bool
	GitHub     bool
	Terraform  bool
	K8s        bool
	GCP        bool
	Cloudflare bool
	Code       bool // Not set by InferContext.
}

// Classification is the JSON object the LLM is asked to return when
// classifying a query (see GetClassificationPrompt).
type Classification struct {
	Service    string `json:"service"`
	Confidence string `json:"confidence"`
	Reason     string `json:"reason"`
}

// Keyword tables consulted by InferContext. Every entry is lower-case and is
// matched as a plain substring of the lower-cased question.
var (
	awsKeywords = []string{
		// Core services
		"ec2", "lambda", "rds", "s3", "ecs", "cloudwatch", "logs", "batch",
		"sqs", "sns", "dynamodb", "elasticache", "elb", "alb", "nlb", "route53",
		"cloudfront", "api-gateway", "cognito", "iam", "vpc", "subnet",
		"security-group", "nacl", "nat", "igw", "vpn", "direct-connect",
		// General terms
		"instance", "bucket", "database", "aws", "resources", "infrastructure",
		"running", "account", "error", "log", "job", "queue", "compute",
		"storage", "network", "cdn", "load-balancer", "auto-scaling", "scaling",
		"health", "metric", "alarm", "notification", "backup", "snapshot",
		"ami", "volume", "ebs", "efs", "fsx",
		// ML/GPU
		"gpu", "cuda", "ml", "machine-learning", "training", "inference",
		"p2", "p3", "p4", "g3", "g4", "g5", "spot", "reserved", "dedicated",
		// Status keywords
		"status", "state", "healthy", "unhealthy", "available", "pending",
		"stopping", "stopped", "terminated", "creating", "deleting", "modifying",
		"active", "inactive", "enabled", "disabled",
		// Cost keywords
		"cost", "billing", "price", "usage", "spend", "budget",
		// Monitoring keywords
		"monitor", "trace", "debug", "performance", "latency", "throughput",
		"error-rate", "failure", "timeout", "retry",
		// Discovery keywords
		"services", "active", "deployed", "discovery", "overview", "summary",
		"list-all", "what's-running", "what-services", "infrastructure-overview",
	}

	githubKeywords = []string{
		// Platform
		"github", "git", "repository", "repo", "fork", "clone", "branch", "tag", "release",
		"issue", "discussion",
		// CI/CD
		"action", "workflow", "ci", "cd", "build", "deploy", "deployment",
		"pipeline", "job", "step", "runner", "artifact",
		// Collaboration
		"pr", "pull", "request", "merge", "commit", "push", "pull-request",
		"review", "approve", "comment", "assignee", "reviewer",
		// Project management
		"milestone", "project", "board", "epic", "story", "task", "bug",
		"feature", "enhancement", "label", "status",
		// Security
		"security", "vulnerability", "dependabot", "secret", "token",
		"permission", "access", "audit",
	}

	// The "tf" shorthand is deliberately absent from this table: as a
	// substring it would hit words such as "platform", so InferContext
	// matches it as a whole word instead.
	terraformKeywords = []string{
		// Core
		"terraform", "hcl", "plan", "apply", "destroy", "init",
		"workspace", "state", "backend", "provider", "resource", "data",
		"module", "variable", "output", "local",
		// Operations
		"infrastructure-as-code", "iac", "provisioning", "deployment",
		"environment", "stack", "configuration", "template",
		// State management
		"tfstate", "state-file", "remote-state", "lock", "unlock",
		"drift", "refresh", "import", "taint", "untaint",
		// Environments
		"dev", "stage", "staging", "prod", "production", "qa", "environment", "workspace",
	}

	k8sKeywords = []string{
		// Core K8s terms
		"kubernetes", "k8s", "kubectl", "kube",
		// Workloads
		"pod", "pods", "deployment", "deployments", "replicaset", "statefulset",
		"daemonset", "job", "cronjob",
		// Networking
		"service", "services", "ingress", "loadbalancer", "nodeport", "clusterip",
		"networkpolicy", "endpoint",
		// Storage
		"pv", "pvc", "persistentvolume", "storageclass", "configmap", "secret",
		// Cluster
		"node", "nodes", "namespace", "cluster", "kubeconfig", "context",
		// Tools
		"helm", "chart", "release", "tiller",
		// Providers
		"eks", "kubeadm", "kops", "k3s", "minikube",
		// Operations
		"rollout", "scale", "drain", "cordon", "taint",
	}

	gcpKeywords = []string{
		"gcp", "google cloud", "cloud run", "cloudrun", "cloud sql", "cloudsql", "gke", "gcs", "cloud storage",
		"pubsub", "pub/sub", "cloud functions", "cloud function", "compute engine", "gce", "iam service account",
		"workload identity", "artifact registry", "secret manager", "bigquery", "spanner", "bigtable",
		"cloud build", "cloud deploy", "cloud dns", "cloud armor", "cloud load balancing", "api gateway",
	}

	// Only explicit mentions of Cloudflare or its CLI tools count; generic
	// terms (cdn, cache, worker, ...) are intentionally not listed.
	cloudflareKeywords = []string{
		"cloudflare",
		"wrangler",
		"cloudflared",
	}
)

// containsAny reports whether text contains at least one of keywords as a
// substring. Both text and keywords are expected to be lower-case already.
func containsAny(text string, keywords []string) bool {
	for _, kw := range keywords {
		if strings.Contains(text, kw) {
			return true
		}
	}
	return false
}

// isWordByte reports whether c can be part of a word for containsWord.
// Bytes >= 0x80 (parts of multi-byte UTF-8 runes) count as word bytes so a
// match is never accepted in the middle of a non-ASCII word.
func isWordByte(c byte) bool {
	switch {
	case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9':
		return true
	default:
		return c == '_' || c >= 0x80
	}
}

// containsWord reports whether word occurs in text delimited on both sides by
// a non-word byte or by the start/end of text.
func containsWord(text, word string) bool {
	if word == "" {
		return false
	}
	for from := 0; ; {
		i := strings.Index(text[from:], word)
		if i < 0 {
			return false
		}
		start := from + i
		end := start + len(word)
		leftOK := start == 0 || !isWordByte(text[start-1])
		rightOK := end == len(text) || !isWordByte(text[end])
		if leftOK && rightOK {
			return true
		}
		from = start + 1
	}
}

// InferContext analyzes a question and determines which service contexts are
// relevant, using case-insensitive keyword matching.
//
// A flag is set when any keyword of the corresponding table occurs as a
// substring of the question; the Terraform flag is additionally set when "tf"
// appears as a standalone word. When no table matches at all, AWS and GitHub
// are both set as the default context.
func InferContext(question string) ServiceContext {
	q := strings.ToLower(question)

	ctx := ServiceContext{
		AWS:    containsAny(q, awsKeywords),
		GitHub: containsAny(q, githubKeywords),
		// "tf" is matched as a whole word so that it is recognised at the end
		// of a question or before punctuation, but not inside other words.
		Terraform:  containsAny(q, terraformKeywords) || containsWord(q, "tf"),
		K8s:        containsAny(q, k8sKeywords),
		GCP:        containsAny(q, gcpKeywords),
		Cloudflare: containsAny(q, cloudflareKeywords),
	}

	// Default to AWS and GitHub context if nothing is detected.
	if !ctx.AWS && !ctx.GitHub && !ctx.Terraform && !ctx.K8s && !ctx.GCP && !ctx.Cloudflare {
		ctx.AWS = true
		ctx.GitHub = true
	}

	return ctx
}

// GetClassificationPrompt returns the prompt that asks an LLM to classify
// which service a query is about and to answer with a JSON object matching
// Classification.
//
// The question is embedded with %q, so it is always a single double-quoted
// line: quotes and newlines inside the question are escaped and cannot
// terminate the quoted field or add lines to the prompt. For questions
// without such characters the output is the same as plain "%s" quoting.
func GetClassificationPrompt(question string) string {
	return fmt.Sprintf(`Classify which cloud service or platform this user query is about.

User Query: %q

Available services:
- cloudflare: Cloudflare CDN, DNS, Workers, KV, D1, R2, Pages, WAF, Tunnels, Zero Trust, Analytics
- aws: Amazon Web Services (EC2, Lambda, S3, RDS, VPC, Route53, CloudFront, IAM, ECS, etc.)
- k8s: Kubernetes clusters, pods, deployments, services, helm, kubectl
- gcp: Google Cloud Platform (Cloud Run, GKE, Cloud SQL, BigQuery, etc.)
- github: GitHub repositories, PRs, issues, actions, workflows
- terraform: Infrastructure as code, Terraform plans, state, modules
- general: General questions not specific to any cloud platform

IMPORTANT RULES:
1. Only classify as "cloudflare" if the query EXPLICITLY mentions Cloudflare, wrangler, cloudflared, or Cloudflare-specific products
2. Generic terms like "cdn", "cache", "dns", "worker", "waf", "rate limit", "tunnel" should default to AWS unless Cloudflare is explicitly mentioned
3. If the query mentions AWS services (EC2, Lambda, S3, CloudFront, Route53, etc.), classify as "aws"
4. If uncertain, classify as "aws" (the default cloud provider)

Respond with ONLY a JSON object:
{
  "service": "cloudflare|aws|k8s|gcp|github|terraform|general",
  "confidence": "high|medium|low",
  "reason": "brief explanation of why this classification"
}`, question)
}
+func ClassifyWithLLM(ctx context.Context, question string, debug bool) (string, error) { + // Get provider config + provider := viper.GetString("ai.default_provider") + if provider == "" { + provider = "openai" + } + + var apiKey string + switch provider { + case "openai": + apiKey = os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + apiKey = viper.GetString("ai.providers.openai.api_key") + } + case "anthropic": + apiKey = os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + apiKey = viper.GetString("ai.providers.anthropic.api_key") + } + case "gemini", "gemini-api": + apiKey = os.Getenv("GEMINI_API_KEY") + } + + // Create minimal AI client for classification + aiClient := ai.NewClient(provider, apiKey, debug, "") + + prompt := GetClassificationPrompt(question) + response, err := aiClient.AskPrompt(ctx, prompt) + if err != nil { + if debug { + fmt.Printf("[routing] LLM classification failed: %v, falling back to keyword matching\n", err) + } + return "", err + } + + // Parse the JSON response + var classification Classification + + // Clean response and parse JSON + cleaned := aiClient.CleanJSONResponse(response) + if err := json.Unmarshal([]byte(cleaned), &classification); err != nil { + if debug { + fmt.Printf("[routing] Failed to parse classification response: %v\n", err) + } + return "", err + } + + if debug { + fmt.Printf("[routing] LLM classification: service=%s, confidence=%s, reason=%s\n", + classification.Service, classification.Confidence, classification.Reason) + } + + return classification.Service, nil +} + +// NeedsLLMClassification determines if a query needs LLM classification +// based on ambiguity (multiple services detected) or Cloudflare being inferred. +func NeedsLLMClassification(ctx ServiceContext) bool { + // Count how many services were inferred + count := 0 + if ctx.AWS { + count++ + } + if ctx.K8s { + count++ + } + if ctx.GCP { + count++ + } + if ctx.Cloudflare { + count++ + } + + // Use LLM classification if: + // 1. 
Multiple services inferred (ambiguous) + // 2. Cloudflare was inferred (verify it's actually Cloudflare-related) + return count > 1 || ctx.Cloudflare +} + +// ApplyLLMClassification updates the ServiceContext based on LLM classification result +func ApplyLLMClassification(ctx *ServiceContext, llmService string) { + switch llmService { + case "cloudflare": + ctx.Cloudflare = true + ctx.K8s = false + ctx.GCP = false + ctx.AWS = false + case "k8s": + ctx.K8s = true + ctx.Cloudflare = false + ctx.GCP = false + case "gcp": + ctx.GCP = true + ctx.Cloudflare = false + ctx.K8s = false + case "aws": + ctx.AWS = true + ctx.Cloudflare = false + ctx.K8s = false + ctx.GCP = false + case "terraform": + ctx.Terraform = true + ctx.Cloudflare = false + case "github": + ctx.GitHub = true + ctx.Cloudflare = false + default: + // "general" - default to AWS + ctx.AWS = true + ctx.Cloudflare = false + ctx.K8s = false + } +} + +// contains checks if s contains substr (case-insensitive) +func contains(s, substr string) bool { + return strings.Contains(strings.ToLower(s), strings.ToLower(substr)) +} diff --git a/internal/routing/routing_test.go b/internal/routing/routing_test.go new file mode 100644 index 0000000..366f9d9 --- /dev/null +++ b/internal/routing/routing_test.go @@ -0,0 +1,337 @@ +package routing + +import ( + "testing" +) + +func TestInferContext_CloudflareExplicit(t *testing.T) { + tests := []struct { + name string + query string + expectCloudflare bool + expectAWS bool + expectK8s bool + }{ + { + name: "explicit cloudflare mention", + query: "list my cloudflare zones", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "wrangler tool mention", + query: "wrangler deploy my worker", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "cloudflared tool mention", + query: "cloudflared tunnel list", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "generic cache should not trigger 
cloudflare", + query: "show cache hit rate", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic cdn should not trigger cloudflare", + query: "list cdn distributions", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic worker should not trigger cloudflare", + query: "show worker processes", + expectCloudflare: false, + expectAWS: false, + expectK8s: false, + }, + { + name: "generic waf should not trigger cloudflare", + query: "list waf rules", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic rate limit should not trigger cloudflare", + query: "show rate limits", + expectCloudflare: false, + expectAWS: true, // "rate" triggers AWS keyword match + expectK8s: false, + }, + { + name: "generic dns should not trigger cloudflare", + query: "list dns records", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "ec2 should trigger aws", + query: "list ec2 instances", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "lambda should trigger aws", + query: "show lambda functions", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "pods should trigger k8s", + query: "list pods", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + { + name: "kubernetes should trigger k8s", + query: "show kubernetes deployments", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + { + name: "kubectl should trigger k8s", + query: "kubectl get nodes", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := InferContext(tt.query) + + if ctx.Cloudflare != tt.expectCloudflare { + t.Errorf("InferContext(%q) cloudflare = %v, want %v", tt.query, ctx.Cloudflare, tt.expectCloudflare) + } + if ctx.AWS != tt.expectAWS { + t.Errorf("InferContext(%q) aws = 
%v, want %v", tt.query, ctx.AWS, tt.expectAWS) + } + if ctx.K8s != tt.expectK8s { + t.Errorf("InferContext(%q) k8s = %v, want %v", tt.query, ctx.K8s, tt.expectK8s) + } + }) + } +} + +func TestInferContext_NoCloudflarefalsePositives(t *testing.T) { + // These queries should NOT trigger Cloudflare routing + noCloudflareQueries := []string{ + "what is the cache hit rate", + "show cdn distribution", + "list workers", + "show rate limits", + "check waf status", + "create tunnel to database", + "show analytics dashboard", + "configure access control", + "deploy to pages", + "list dns records for route53", + "show cloudfront distributions", + } + + for _, query := range noCloudflareQueries { + t.Run(query, func(t *testing.T) { + ctx := InferContext(query) + if ctx.Cloudflare { + t.Errorf("InferContext(%q) incorrectly triggered Cloudflare routing", query) + } + }) + } +} + +func TestInferContext_DefaultBehavior(t *testing.T) { + // Unknown queries should default to AWS + GitHub + ctx := InferContext("random question about nothing") + + if !ctx.AWS { + t.Error("Unknown query should default to AWS=true") + } + if !ctx.GitHub { + t.Error("Unknown query should default to GitHub=true") + } + if ctx.Cloudflare { + t.Error("Unknown query should not trigger Cloudflare") + } + if ctx.K8s { + t.Error("Unknown query should not trigger K8s") + } +} + +func TestGetClassificationPrompt(t *testing.T) { + prompt := GetClassificationPrompt("list my cloudflare zones") + + if prompt == "" { + t.Error("GetClassificationPrompt returned empty string") + } + + expectedPhrases := []string{ + "cloudflare", + "aws", + "k8s", + "gcp", + "JSON object", + "service", + } + + for _, phrase := range expectedPhrases { + if !contains(prompt, phrase) { + t.Errorf("GetClassificationPrompt missing expected phrase: %s", phrase) + } + } +} + +func TestNeedsLLMClassification(t *testing.T) { + tests := []struct { + name string + ctx ServiceContext + expect bool + }{ + { + name: "cloudflare detected needs 
verification", + ctx: ServiceContext{Cloudflare: true}, + expect: true, + }, + { + name: "multiple services need disambiguation", + ctx: ServiceContext{AWS: true, K8s: true}, + expect: true, + }, + { + name: "single aws does not need llm", + ctx: ServiceContext{AWS: true}, + expect: false, + }, + { + name: "single k8s does not need llm", + ctx: ServiceContext{K8s: true}, + expect: false, + }, + { + name: "aws and cloudflare needs llm", + ctx: ServiceContext{AWS: true, Cloudflare: true}, + expect: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := NeedsLLMClassification(tt.ctx) + if result != tt.expect { + t.Errorf("NeedsLLMClassification(%+v) = %v, want %v", tt.ctx, result, tt.expect) + } + }) + } +} + +func TestApplyLLMClassification(t *testing.T) { + tests := []struct { + name string + llmService string + expectAWS bool + expectCF bool + expectK8s bool + expectGCP bool + }{ + { + name: "cloudflare classification", + llmService: "cloudflare", + expectCF: true, + expectAWS: false, + expectK8s: false, + expectGCP: false, + }, + { + name: "aws classification", + llmService: "aws", + expectAWS: true, + expectCF: false, + expectK8s: false, + expectGCP: false, + }, + { + name: "k8s classification", + llmService: "k8s", + expectK8s: true, + expectAWS: false, + expectCF: false, + expectGCP: false, + }, + { + name: "gcp classification", + llmService: "gcp", + expectGCP: true, + expectAWS: false, + expectCF: false, + expectK8s: false, + }, + { + name: "general defaults to aws", + llmService: "general", + expectAWS: true, + expectCF: false, + expectK8s: false, + expectGCP: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ServiceContext{} + ApplyLLMClassification(&ctx, tt.llmService) + + if ctx.AWS != tt.expectAWS { + t.Errorf("ApplyLLMClassification(%q) AWS = %v, want %v", tt.llmService, ctx.AWS, tt.expectAWS) + } + if ctx.Cloudflare != tt.expectCF { + 
t.Errorf("ApplyLLMClassification(%q) Cloudflare = %v, want %v", tt.llmService, ctx.Cloudflare, tt.expectCF) + } + if ctx.K8s != tt.expectK8s { + t.Errorf("ApplyLLMClassification(%q) K8s = %v, want %v", tt.llmService, ctx.K8s, tt.expectK8s) + } + if ctx.GCP != tt.expectGCP { + t.Errorf("ApplyLLMClassification(%q) GCP = %v, want %v", tt.llmService, ctx.GCP, tt.expectGCP) + } + }) + } +} + +func TestContains(t *testing.T) { + tests := []struct { + s string + substr string + expect bool + }{ + {"Hello World", "world", true}, + {"Hello World", "WORLD", true}, + {"cloudflare zones", "cloudflare", true}, + {"list ec2", "EC2", true}, + {"kubernetes pods", "k8s", false}, + {"", "test", false}, + {"test", "", true}, + } + + for _, tt := range tests { + t.Run(tt.s+"_"+tt.substr, func(t *testing.T) { + result := contains(tt.s, tt.substr) + if result != tt.expect { + t.Errorf("contains(%q, %q) = %v, want %v", tt.s, tt.substr, result, tt.expect) + } + }) + } +}