From bb1b97c296d4f9172821dc07820222fa1837297a Mon Sep 17 00:00:00 2001 From: nash Date: Mon, 26 Jan 2026 21:50:09 +0000 Subject: [PATCH 1/4] feat(ask): add LLM-based routing classification Replace keyword-based routing with LLM classification for ambiguous queries. The LLM analyzes query context to determine which service (AWS, Cloudflare, K8s, GCP) the user is asking about. This prevents generic terms like "cache", "cdn", "worker" from incorrectly routing to Cloudflare when the user means AWS. LLM classification is triggered when: - Multiple services are detected by keyword matching - Cloudflare keywords are detected (to verify intent) Fallback: If LLM classification fails, keyword-based inference is used. Explicit flags (--cloudflare, --aws, --k8s) bypass LLM classification. --- cmd/ask.go | 190 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 177 insertions(+), 13 deletions(-) diff --git a/cmd/ask.go b/cmd/ask.go index 5cfec76..bcf0ae6 100644 --- a/cmd/ask.go +++ b/cmd/ask.go @@ -461,11 +461,93 @@ Format as a professional compliance table suitable for government security docum var inferredGCP bool var inferredCloudflare bool routingQuestion := questionForRouting(question) + + // First, do quick keyword check for explicit terms includeAWS, inferredCode, includeGitHub, inferredTerraform, inferredK8s, inferredGCP, inferredCloudflare = inferContext(routingQuestion) _ = inferredCode if debug { - fmt.Printf("Inferred context: AWS=%v, GitHub=%v, Terraform=%v, K8s=%v, GCP=%v, Cloudflare=%v\n", includeAWS, includeGitHub, inferredTerraform, inferredK8s, inferredGCP, inferredCloudflare) + fmt.Printf("Keyword inference: AWS=%v, GitHub=%v, Terraform=%v, K8s=%v, GCP=%v, Cloudflare=%v\n", + includeAWS, includeGitHub, inferredTerraform, inferredK8s, inferredGCP, inferredCloudflare) + } + + // For ambiguous queries (multiple services detected or Cloudflare detected), + // use LLM to make the final routing decision + needsLLMClassification := false + + // Count how many services were inferred + inferredCount := 0 + if includeAWS { + inferredCount++ + } + if inferredK8s { + inferredCount++ + } + if inferredGCP { + inferredCount++ + } + if inferredCloudflare { + inferredCount++ + } + + // Use LLM classification if: + // 1. Multiple services inferred (ambiguous) + // 2. Cloudflare was inferred (verify it's actually Cloudflare-related) + if inferredCount > 1 || inferredCloudflare { + needsLLMClassification = true + } + + if needsLLMClassification { + if debug { + fmt.Println("[routing] Ambiguous query detected, using LLM for classification...") + } + + llmService, err := classifyQueryWithLLM(context.Background(), routingQuestion, debug) + if err != nil { + // FALLBACK: LLM classification failed, use keyword-based inference + if debug { + fmt.Printf("[routing] LLM classification failed (%v), falling back to keyword inference\n", err) + } + // Keep the keyword-inferred values as-is (no changes needed) + } else { + // LLM succeeded - override keyword-based inference with LLM decision + switch llmService { + case "cloudflare": + inferredCloudflare = true + inferredK8s = false + inferredGCP = false + includeAWS = false + case "k8s": + inferredK8s = true + inferredCloudflare = false + inferredGCP = false + case "gcp": + inferredGCP = true + inferredCloudflare = false + inferredK8s = false + case "aws": + includeAWS = true + inferredCloudflare = false + inferredK8s = false + inferredGCP = false + case "terraform": + inferredTerraform = true + inferredCloudflare = false + case "github": + includeGitHub = true + inferredCloudflare = false + default: + // "general" - default to AWS + includeAWS = true + inferredCloudflare = false + inferredK8s = false + } + + if debug { + fmt.Printf("LLM override: AWS=%v, K8s=%v, GCP=%v, Cloudflare=%v\n", + includeAWS, inferredK8s, inferredGCP, inferredCloudflare) + } + } } // Handle inferred Terraform context @@ -981,18 +1063,11 @@ func inferContext(question string) (aws bool, code bool, github bool, terraform } cloudflareKeywords := []string{ - // Platform and tools - "cloudflare", "cf", "wrangler", "cloudflared", - // DNS - "dns record", "zone", "nameserver", "cname", "a record", "aaaa record", "mx record", "txt record", - // Workers and edge - "worker", "workers", "kv", "d1", "r2", "pages", "durable objects", - // Security - "waf", "firewall rule", "rate limit", "ddos", "bot management", "page rules", - // Zero Trust - "tunnel", "access", "zero trust", "warp", "cloudflare access", - // Analytics and performance - "cdn", "cache", "analytics", "web analytics", + // Only match if Cloudflare is explicitly mentioned + "cloudflare", + // Cloudflare-specific CLI tools (unique to Cloudflare) + "wrangler", + "cloudflared", } questionLower := strings.ToLower(question) @@ -1985,3 +2060,92 @@ func determineRoutingDecision(question string) (agent string, reason string) { // Default to CLI for general queries return "cli", "General infrastructure query or analysis" } + +// getRoutingClassificationPrompt returns a prompt for LLM to classify which service a query is about +func getRoutingClassificationPrompt(question string) string { + return fmt.Sprintf(`Classify which cloud service or platform this user query is about. + +User Query: "%s" + +Available services: +- cloudflare: Cloudflare CDN, DNS, Workers, KV, D1, R2, Pages, WAF, Tunnels, Zero Trust, Analytics +- aws: Amazon Web Services (EC2, Lambda, S3, RDS, VPC, Route53, CloudFront, IAM, ECS, etc.) +- k8s: Kubernetes clusters, pods, deployments, services, helm, kubectl +- gcp: Google Cloud Platform (Cloud Run, GKE, Cloud SQL, BigQuery, etc.) +- github: GitHub repositories, PRs, issues, actions, workflows +- terraform: Infrastructure as code, Terraform plans, state, modules +- general: General questions not specific to any cloud platform + +IMPORTANT RULES: +1. Only classify as "cloudflare" if the query EXPLICITLY mentions Cloudflare, wrangler, cloudflared, or Cloudflare-specific products +2. Generic terms like "cdn", "cache", "dns", "worker", "waf", "rate limit", "tunnel" should default to AWS unless Cloudflare is explicitly mentioned +3. If the query mentions AWS services (EC2, Lambda, S3, CloudFront, Route53, etc.), classify as "aws" +4. If uncertain, classify as "aws" (the default cloud provider) + +Respond with ONLY a JSON object: +{ + "service": "cloudflare|aws|k8s|gcp|github|terraform|general", + "confidence": "high|medium|low", + "reason": "brief explanation of why this classification" +}`, question) +} + +// classifyQueryWithLLM uses the AI client to determine which service a query is about +func classifyQueryWithLLM(ctx context.Context, question string, debug bool) (string, error) { + // Get provider config + provider := viper.GetString("ai.default_provider") + if provider == "" { + provider = "openai" + } + + var apiKey string + switch provider { + case "openai": + apiKey = os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + apiKey = viper.GetString("ai.providers.openai.api_key") + } + case "anthropic": + apiKey = os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + apiKey = viper.GetString("ai.providers.anthropic.api_key") + } + case "gemini", "gemini-api": + apiKey = os.Getenv("GEMINI_API_KEY") + } + + // Create minimal AI client for classification + aiClient := ai.NewClient(provider, apiKey, debug, "") + + prompt := getRoutingClassificationPrompt(question) + response, err := aiClient.AskPrompt(ctx, prompt) + if err != nil { + if debug { + fmt.Printf("[routing] LLM classification failed: %v, falling back to keyword matching\n", err) + } + return "", err + } + + // Parse the JSON response + var classification struct { + Service string `json:"service"` + Confidence string `json:"confidence"` + Reason string `json:"reason"` + } + + // Clean response and parse JSON + cleaned := aiClient.CleanJSONResponse(response) + if err := json.Unmarshal([]byte(cleaned), &classification); err != nil { + if debug { + fmt.Printf("[routing] Failed to parse classification response: %v\n", err) + } + return "", err + } + + if debug { + fmt.Printf("[routing] LLM classification: service=%s, confidence=%s, reason=%s\n", + classification.Service, classification.Confidence, classification.Reason) + } + + return classification.Service, nil +} From df0080071968c5f6f05ca628f093d0861196208f Mon Sep 17 00:00:00 2001 From: nash Date: Mon, 26 Jan 2026 22:01:59 +0000 Subject: [PATCH 2/4] test(ask): add routing tests for Cloudflare keyword filtering Add tests to verify: - Cloudflare only triggers on explicit mentions (cloudflare, wrangler, cloudflared) - Generic terms (cache, cdn, worker, waf, rate limit) do not trigger Cloudflare - AWS routing works for ec2, lambda, dns - K8s routing works for pods, kubernetes, kubectl --- cmd/ask_routing_test.go | 180 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 cmd/ask_routing_test.go diff --git a/cmd/ask_routing_test.go b/cmd/ask_routing_test.go new file mode 100644 index 0000000..6795bda --- /dev/null +++ b/cmd/ask_routing_test.go @@ -0,0 +1,180 @@ +package cmd + +import ( + "testing" +) + +func TestInferContext_CloudflareExplicit(t *testing.T) { + tests := []struct { + name string + query string + expectCloudflare bool + expectAWS bool + expectK8s bool + }{ + { + name: "explicit cloudflare mention", + query: "list my cloudflare zones", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "wrangler tool mention", + query: "wrangler deploy my worker", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "cloudflared tool mention", + query: "cloudflared tunnel list", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "generic cache should not trigger cloudflare", + query: "show cache hit rate", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic cdn should not trigger cloudflare", + query: "list cdn distributions", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic worker should not trigger cloudflare", + query: "show worker processes", + expectCloudflare: false, + expectAWS: false, + expectK8s: false, + }, + { + name: "generic waf should not trigger cloudflare", + query: "list waf rules", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic rate limit should not trigger cloudflare", + query: "show rate limits", + expectCloudflare: false, + expectAWS: true, // "rate" triggers AWS keyword match + expectK8s: false, + }, + { + name: "generic dns should not trigger cloudflare", + query: "list dns records", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "ec2 should trigger aws", + query: "list ec2 instances", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "lambda should trigger aws", + query: "show lambda functions", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "pods should trigger k8s", + query: "list pods", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + { + name: "kubernetes should trigger k8s", + query: "show kubernetes deployments", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + { + name: "kubectl should trigger k8s", + query: "kubectl get nodes", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + aws, _, _, _, k8s, _, cf := inferContext(tt.query) + + if cf != tt.expectCloudflare { + t.Errorf("inferContext(%q) cloudflare = %v, want %v", tt.query, cf, tt.expectCloudflare) + } + if aws != tt.expectAWS { + t.Errorf("inferContext(%q) aws = %v, want %v", tt.query, aws, tt.expectAWS) + } + if k8s != tt.expectK8s { + t.Errorf("inferContext(%q) k8s = %v, want %v", tt.query, k8s, tt.expectK8s) + } + }) + } +} + +func TestInferContext_NoFalsePositives(t *testing.T) { + // These queries should NOT trigger Cloudflare routing + noCloudflareQueries := []string{ + "what is the cache hit rate", + "show cdn distribution", + "list workers", + "show rate limits", + "check waf status", + "create tunnel to database", + "show analytics dashboard", + "configure access control", + "deploy to pages", + "list dns records for route53", + "show cloudfront distributions", + } + + for _, query := range noCloudflareQueries { + t.Run(query, func(t *testing.T) { + _, _, _, _, _, _, cf := inferContext(query) + if cf { + t.Errorf("inferContext(%q) incorrectly triggered Cloudflare routing", query) + } + }) + } +} + +func TestGetRoutingClassificationPrompt(t *testing.T) { + prompt := getRoutingClassificationPrompt("list my cloudflare zones") + + // Check that the prompt contains expected elements + if prompt == "" { + t.Error("getRoutingClassificationPrompt returned empty string") + } + + expectedPhrases := []string{ + "cloudflare", + "aws", + "k8s", + "gcp", + "JSON object", + "service", + } + + for _, phrase := range expectedPhrases { + if !contains(prompt, phrase) { + t.Errorf("getRoutingClassificationPrompt missing expected phrase: %s", phrase) + } + } +} From f7da235d1b2723eb6be4ae33ba96ef5140c15a99 Mon Sep 17 00:00:00 2001 From: nash Date: Mon, 26 Jan 2026 22:10:02 +0000 Subject: [PATCH 3/4] refactor(routing): extract routing logic into internal package Move query routing classification logic from cmd/ask.go to internal/routing package for better testability and code organization. Changes: - Create internal/routing/routing.go with InferContext, ClassifyWithLLM, NeedsLLMClassification, ApplyLLMClassification, GetClassificationPrompt - Create internal/routing/routing_test.go with comprehensive tests - Update cmd/ask.go to use the routing package - Remove duplicate inferContext, classifyQueryWithLLM, and contains functions - Delete cmd/ask_routing_test.go (tests moved to internal/routing) --- cmd/ask.go | 318 ++--------------------------- cmd/ask_routing_test.go | 180 ----------------- internal/routing/routing.go | 337 +++++++++++++++++++++++++++++++ internal/routing/routing_test.go | 337 +++++++++++++++++++++++++++++++ 4 files changed, 694 insertions(+), 478 deletions(-) delete mode 100644 cmd/ask_routing_test.go create mode 100644 internal/routing/routing.go create mode 100644 internal/routing/routing_test.go diff --git a/cmd/ask.go b/cmd/ask.go index bcf0ae6..fb7463d 100644 --- a/cmd/ask.go +++ b/cmd/ask.go @@ -23,6 +23,7 @@ import ( "github.com/bgdnvk/clanker/internal/k8s" "github.com/bgdnvk/clanker/internal/k8s/plan" "github.com/bgdnvk/clanker/internal/maker" + "github.com/bgdnvk/clanker/internal/routing" tfclient "github.com/bgdnvk/clanker/internal/terraform" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -301,11 +302,11 @@ Examples: makerProvider = "aws" makerProviderReason = "explicit" default: - _, _, _, _, _, inferredGCP, inferredCloudflare := inferContext(questionForRouting(question)) - if inferredCloudflare { + svcCtx := routing.InferContext(questionForRouting(question)) + if svcCtx.Cloudflare { makerProvider = "cloudflare" makerProviderReason = "inferred" - } else if inferredGCP { + } else if svcCtx.GCP { makerProvider = "gcp" makerProviderReason = "inferred" } @@ -455,54 +456,26 @@ Format as a professional compliance table suitable for government security docum } if !includeAWS && !includeGitHub && !includeTerraform && !includeGCP && !includeCloudflare { - var inferredTerraform bool - var inferredCode bool - var inferredK8s bool - var inferredGCP bool - var inferredCloudflare bool routingQuestion := questionForRouting(question) // First, do quick keyword check for explicit terms - includeAWS, inferredCode, includeGitHub, inferredTerraform, inferredK8s, inferredGCP, inferredCloudflare = inferContext(routingQuestion) - _ = inferredCode + svcCtx := routing.InferContext(routingQuestion) + includeAWS = svcCtx.AWS + includeGitHub = svcCtx.GitHub if debug { fmt.Printf("Keyword inference: AWS=%v, GitHub=%v, Terraform=%v, K8s=%v, GCP=%v, Cloudflare=%v\n", - includeAWS, includeGitHub, inferredTerraform, inferredK8s, inferredGCP, inferredCloudflare) + svcCtx.AWS, svcCtx.GitHub, svcCtx.Terraform, svcCtx.K8s, svcCtx.GCP, svcCtx.Cloudflare) } // For ambiguous queries (multiple services detected or Cloudflare detected), // use LLM to make the final routing decision - needsLLMClassification := false - - // Count how many services were inferred - inferredCount := 0 - if includeAWS { - inferredCount++ - } - if inferredK8s { - inferredCount++ - } - if inferredGCP { - inferredCount++ - } - if inferredCloudflare { - inferredCount++ - } - - // Use LLM classification if: - // 1. Multiple services inferred (ambiguous) - // 2. Cloudflare was inferred (verify it's actually Cloudflare-related) - if inferredCount > 1 || inferredCloudflare { - needsLLMClassification = true - } - - if needsLLMClassification { + if routing.NeedsLLMClassification(svcCtx) { if debug { fmt.Println("[routing] Ambiguous query detected, using LLM for classification...") } - llmService, err := classifyQueryWithLLM(context.Background(), routingQuestion, debug) + llmService, err := routing.ClassifyWithLLM(context.Background(), routingQuestion, debug) if err != nil { // FALLBACK: LLM classification failed, use keyword-based inference if debug { @@ -511,61 +484,35 @@ Format as a professional compliance table suitable for government security docum // Keep the keyword-inferred values as-is (no changes needed) } else { // LLM succeeded - override keyword-based inference with LLM decision - switch llmService { - case "cloudflare": - inferredCloudflare = true - inferredK8s = false - inferredGCP = false - includeAWS = false - case "k8s": - inferredK8s = true - inferredCloudflare = false - inferredGCP = false - case "gcp": - inferredGCP = true - inferredCloudflare = false - inferredK8s = false - case "aws": - includeAWS = true - inferredCloudflare = false - inferredK8s = false - inferredGCP = false - case "terraform": - inferredTerraform = true - inferredCloudflare = false - case "github": - includeGitHub = true - inferredCloudflare = false - default: - // "general" - default to AWS - includeAWS = true - inferredCloudflare = false - inferredK8s = false - } + routing.ApplyLLMClassification(&svcCtx, llmService) if debug { fmt.Printf("LLM override: AWS=%v, K8s=%v, GCP=%v, Cloudflare=%v\n", - includeAWS, inferredK8s, inferredGCP, inferredCloudflare) + svcCtx.AWS, svcCtx.K8s, svcCtx.GCP, svcCtx.Cloudflare) } } } // Handle inferred Terraform context - if inferredTerraform { + if svcCtx.Terraform { includeTerraform = true } - if inferredGCP { + if svcCtx.GCP { includeGCP = true } + // Update includeAWS and includeGitHub from service context + includeAWS = svcCtx.AWS + includeGitHub = svcCtx.GitHub + // Handle Cloudflare queries by delegating to Cloudflare agent - if inferredCloudflare { + if svcCtx.Cloudflare { return handleCloudflareQuery(context.Background(), routingQuestion, debug) } // Handle K8s queries by delegating to K8s agent - if inferredK8s { + if svcCtx.K8s { return handleK8sQuery(context.Background(), routingQuestion, debug, viper.GetString("kubernetes.kubeconfig")) } } @@ -990,138 +937,6 @@ func resolveGeminiModel(provider, flagValue string) string { return model } -// inferContext tries to determine if the question is about AWS, GitHub, Terraform, Kubernetes, GCP, or Cloudflare. -// Code scanning is disabled, so this never infers code context. -func inferContext(question string) (aws bool, code bool, github bool, terraform bool, k8s bool, gcp bool, cf bool) { - awsKeywords := []string{ - // Core services - "ec2", "lambda", "rds", "s3", "ecs", "cloudwatch", "logs", "batch", "sqs", "sns", "dynamodb", "elasticache", "elb", "alb", "nlb", "route53", "cloudfront", "api-gateway", "cognito", "iam", "vpc", "subnet", "security-group", "nacl", "nat", "igw", "vpn", "direct-connect", - // General terms - "instance", "bucket", "database", "aws", "resources", "infrastructure", "running", "account", "error", "log", "job", "queue", "compute", "storage", "network", "cdn", "load-balancer", "auto-scaling", "scaling", "health", "metric", "alarm", "notification", "backup", "snapshot", "ami", "volume", "ebs", "efs", "fsx", - // Compute and GPU terms - "gpu", "cuda", "ml", "machine-learning", "training", "inference", "p2", "p3", "p4", "g3", "g4", "g5", "spot", "reserved", "dedicated", - // Status and operations - "status", "state", "healthy", "unhealthy", "available", "pending", "stopping", "stopped", "terminated", "creating", "deleting", "modifying", "active", "inactive", "enabled", "disabled", - // Cost and billing - "cost", "billing", "price", "usage", "spend", "budget", - // Monitoring and debugging - "monitor", "trace", "debug", "performance", "latency", "throughput", "error-rate", "failure", "timeout", "retry", - // Infrastructure discovery - "services", "active", "deployed", "discovery", "overview", "summary", "list-all", "what's-running", "what-services", "infrastructure-overview", - } - - githubKeywords := []string{ - // GitHub platform - "github", "git", "repository", "repo", "fork", "clone", "branch", "tag", "release", "issue", "discussion", - // CI/CD and Actions - "action", "workflow", "ci", "cd", "build", "deploy", "deployment", "pipeline", "job", "step", "runner", "artifact", - // Collaboration - "pr", "pull", "request", "merge", "commit", "push", "pull-request", "review", "approve", "comment", "assignee", "reviewer", - // Project management - "milestone", "project", "board", "epic", "story", "task", "bug", "feature", "enhancement", "label", "status", - // Security and compliance - "security", "vulnerability", "dependabot", "secret", "token", "permission", "access", "audit", - } - - terraformKeywords := []string{ - // Terraform core - "terraform", "tf", "hcl", "plan", "apply", "destroy", "init", "workspace", "state", "backend", "provider", "resource", "data", "module", "variable", "output", "local", - // Operations - "infrastructure-as-code", "iac", "provisioning", "deployment", "environment", "stack", "configuration", "template", - // State management - "tfstate", "state-file", "remote-state", "lock", "unlock", "drift", "refresh", "import", "taint", "untaint", - // Workspaces and environments - "dev", "stage", "staging", "prod", "production", "qa", "environment", "workspace", - } - - k8sKeywords := []string{ - // Core K8s terms - "kubernetes", "k8s", "kubectl", "kube", - // Workloads - "pod", "pods", "deployment", "deployments", "replicaset", "statefulset", - "daemonset", "job", "cronjob", - // Networking - "service", "services", "ingress", "loadbalancer", "nodeport", "clusterip", - "networkpolicy", "endpoint", - // Storage - "pv", "pvc", "persistentvolume", "storageclass", "configmap", "secret", - // Cluster - "node", "nodes", "namespace", "cluster", "kubeconfig", "context", - // Tools - "helm", "chart", "release", "tiller", - // Providers - "eks", "kubeadm", "kops", "k3s", "minikube", - // Operations - "rollout", "scale", "drain", "cordon", "taint", - } - - gcpKeywords := []string{ - "gcp", "google cloud", "cloud run", "cloudrun", "cloud sql", "cloudsql", "gke", "gcs", "cloud storage", - "pubsub", "pub/sub", "cloud functions", "cloud function", "compute engine", "gce", "iam service account", - "workload identity", "artifact registry", "secret manager", "bigquery", "spanner", "bigtable", - "cloud build", "cloud deploy", "cloud dns", "cloud armor", "cloud load balancing", "api gateway", - } - - cloudflareKeywords := []string{ - // Only match if Cloudflare is explicitly mentioned - "cloudflare", - // Cloudflare-specific CLI tools (unique to Cloudflare) - "wrangler", - "cloudflared", - } - - questionLower := strings.ToLower(question) - - for _, keyword := range awsKeywords { - if contains(questionLower, keyword) { - aws = true - break - } - } - - for _, keyword := range githubKeywords { - if contains(questionLower, keyword) { - github = true - break - } - } - - for _, keyword := range terraformKeywords { - if contains(questionLower, keyword) { - terraform = true - break - } - } - - for _, keyword := range k8sKeywords { - if contains(questionLower, keyword) { - k8s = true - break - } - } - - for _, keyword := range gcpKeywords { - if contains(questionLower, keyword) { - gcp = true - break - } - } - - for _, keyword := range cloudflareKeywords { - if contains(questionLower, keyword) { - cf = true - break - } - } - - // If no specific context detected, include AWS + GitHub by default. - if !aws && !github && !terraform && !k8s && !gcp && !cf { - aws, github = true, true - } - - return aws, code, github, terraform, k8s, gcp, cf -} - func questionForRouting(question string) string { trimmed := strings.TrimSpace(question) if trimmed == "" { @@ -1803,10 +1618,6 @@ func extractDeployName(question string) string { return "" } -func contains(s, substr string) bool { - return strings.Contains(strings.ToLower(s), strings.ToLower(substr)) -} - // formatK8sCommand formats a command for display (like AWS maker formatAWSArgsForLog) func formatK8sCommand(cmdName string, args []string) string { const maxArgLen = 160 @@ -2060,92 +1871,3 @@ func determineRoutingDecision(question string) (agent string, reason string) { // Default to CLI for general queries return "cli", "General infrastructure query or analysis" } - -// getRoutingClassificationPrompt returns a prompt for LLM to classify which service a query is about -func getRoutingClassificationPrompt(question string) string { - return fmt.Sprintf(`Classify which cloud service or platform this user query is about. - -User Query: "%s" - -Available services: -- cloudflare: Cloudflare CDN, DNS, Workers, KV, D1, R2, Pages, WAF, Tunnels, Zero Trust, Analytics -- aws: Amazon Web Services (EC2, Lambda, S3, RDS, VPC, Route53, CloudFront, IAM, ECS, etc.) -- k8s: Kubernetes clusters, pods, deployments, services, helm, kubectl -- gcp: Google Cloud Platform (Cloud Run, GKE, Cloud SQL, BigQuery, etc.) -- github: GitHub repositories, PRs, issues, actions, workflows -- terraform: Infrastructure as code, Terraform plans, state, modules -- general: General questions not specific to any cloud platform - -IMPORTANT RULES: -1. Only classify as "cloudflare" if the query EXPLICITLY mentions Cloudflare, wrangler, cloudflared, or Cloudflare-specific products -2. Generic terms like "cdn", "cache", "dns", "worker", "waf", "rate limit", "tunnel" should default to AWS unless Cloudflare is explicitly mentioned -3. If the query mentions AWS services (EC2, Lambda, S3, CloudFront, Route53, etc.), classify as "aws" -4. If uncertain, classify as "aws" (the default cloud provider) - -Respond with ONLY a JSON object: -{ - "service": "cloudflare|aws|k8s|gcp|github|terraform|general", - "confidence": "high|medium|low", - "reason": "brief explanation of why this classification" -}`, question) -} - -// classifyQueryWithLLM uses the AI client to determine which service a query is about -func classifyQueryWithLLM(ctx context.Context, question string, debug bool) (string, error) { - // Get provider config - provider := viper.GetString("ai.default_provider") - if provider == "" { - provider = "openai" - } - - var apiKey string - switch provider { - case "openai": - apiKey = os.Getenv("OPENAI_API_KEY") - if apiKey == "" { - apiKey = viper.GetString("ai.providers.openai.api_key") - } - case "anthropic": - apiKey = os.Getenv("ANTHROPIC_API_KEY") - if apiKey == "" { - apiKey = viper.GetString("ai.providers.anthropic.api_key") - } - case "gemini", "gemini-api": - apiKey = os.Getenv("GEMINI_API_KEY") - } - - // Create minimal AI client for classification - aiClient := ai.NewClient(provider, apiKey, debug, "") - - prompt := getRoutingClassificationPrompt(question) - response, err := aiClient.AskPrompt(ctx, prompt) - if err != nil { - if debug { - fmt.Printf("[routing] LLM classification failed: %v, falling back to keyword matching\n", err) - } - return "", err - } - - // Parse the JSON response - var classification struct { - Service string `json:"service"` - Confidence string `json:"confidence"` - Reason string `json:"reason"` - } - - // Clean response and parse JSON - cleaned := aiClient.CleanJSONResponse(response) - if err := json.Unmarshal([]byte(cleaned), &classification); err != nil { - if debug { - fmt.Printf("[routing] Failed to parse classification response: %v\n", err) - } - return "", err - } - - if debug { - fmt.Printf("[routing] LLM classification: service=%s, confidence=%s, reason=%s\n", - classification.Service, classification.Confidence, classification.Reason) - } - - return classification.Service, nil -} diff --git a/cmd/ask_routing_test.go b/cmd/ask_routing_test.go deleted file mode 100644 index 6795bda..0000000 --- a/cmd/ask_routing_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package cmd - -import ( - "testing" -) - -func TestInferContext_CloudflareExplicit(t *testing.T) { - tests := []struct { - name string - query string - expectCloudflare bool - expectAWS bool - expectK8s bool - }{ - { - name: "explicit cloudflare mention", - query: "list my cloudflare zones", - expectCloudflare: true, - expectAWS: false, - expectK8s: false, - }, - { - name: "wrangler tool mention", - query: "wrangler deploy my worker", - expectCloudflare: true, - expectAWS: false, - expectK8s: false, - }, - { - name: "cloudflared tool mention", - query: "cloudflared tunnel list", - expectCloudflare: true, - expectAWS: false, - expectK8s: false, - }, - { - name: "generic cache should not trigger cloudflare", - query: "show cache hit rate", - expectCloudflare: false, - expectAWS: true, - expectK8s: false, - }, - { - name: "generic cdn should not trigger cloudflare", - query: "list cdn distributions", - expectCloudflare: false, - expectAWS: true, - expectK8s: false, - }, - { - name: "generic worker should not trigger cloudflare", - query: "show worker processes", - expectCloudflare: false, - expectAWS: false, - expectK8s: false, - }, - { - name: "generic waf should not trigger cloudflare", - query: "list waf rules", - expectCloudflare: false, - expectAWS: true, - expectK8s: false, - }, - { - name: "generic rate limit should not trigger cloudflare", - query: "show rate limits", - expectCloudflare: false, - expectAWS: true, // "rate" triggers AWS keyword match - expectK8s: false, - }, - { - name: "generic dns should not trigger cloudflare", - query: "list dns records", - expectCloudflare: false, - expectAWS: true, - expectK8s: false, - }, - { - name: "ec2 should trigger aws", - query: "list ec2 instances", - expectCloudflare: false, - expectAWS: true, - expectK8s: false, - }, - { - name: "lambda should trigger aws", - query: "show lambda functions", - expectCloudflare: false, - expectAWS: true, - expectK8s: false, - }, - { - name: "pods should trigger k8s", - query: "list pods", - expectCloudflare: false, - expectAWS: false, - expectK8s: true, - }, - { - name: "kubernetes should trigger k8s", - query: "show kubernetes deployments", - expectCloudflare: false, - expectAWS: false, - expectK8s: true, - }, - { - name: "kubectl should trigger k8s", - query: "kubectl get nodes", - expectCloudflare: false, - expectAWS: false, - expectK8s: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - aws, _, _, _, k8s, _, cf := inferContext(tt.query) - - if cf != tt.expectCloudflare { - t.Errorf("inferContext(%q) cloudflare = %v, want %v", tt.query, cf, tt.expectCloudflare) - } - if aws != tt.expectAWS { - t.Errorf("inferContext(%q) aws = %v, want %v", tt.query, aws, tt.expectAWS) - } - if k8s != tt.expectK8s { - t.Errorf("inferContext(%q) k8s = %v, want %v", tt.query, k8s, tt.expectK8s) - } - }) - } -} - -func TestInferContext_NoFalsePositives(t *testing.T) { - // These queries should NOT trigger Cloudflare routing - noCloudflareQueries := []string{ - "what is the cache hit rate", - "show cdn distribution", - "list workers", - "show rate limits", - "check waf status", - "create tunnel to database", - "show analytics dashboard", - "configure access control", - "deploy to pages", - "list dns records for route53", - "show cloudfront distributions", - } - - for _, query := range noCloudflareQueries { - t.Run(query, func(t *testing.T) { - _, _, _, _, _, _, cf := inferContext(query) - if cf { - t.Errorf("inferContext(%q) incorrectly triggered Cloudflare routing", query) - } - }) - } -} - -func TestGetRoutingClassificationPrompt(t *testing.T) { - prompt := getRoutingClassificationPrompt("list my cloudflare zones") - - // Check that the prompt contains expected elements - if prompt == "" { - t.Error("getRoutingClassificationPrompt returned empty string") - } - - expectedPhrases := []string{ - "cloudflare", - "aws", - "k8s", - "gcp", - "JSON object", - "service", - } - - for _, phrase := range expectedPhrases { - if !contains(prompt, phrase) { - t.Errorf("getRoutingClassificationPrompt missing expected phrase: %s", phrase) - } - } -} diff --git a/internal/routing/routing.go b/internal/routing/routing.go new file mode 100644 index 0000000..32c7e42 --- /dev/null +++ b/internal/routing/routing.go @@ -0,0 +1,337 @@ +// Package routing provides query routing and classification for cloud services. +package routing + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/bgdnvk/clanker/internal/ai" + "github.com/spf13/viper" +) + +// ServiceContext represents which services were detected in a query +type ServiceContext struct { + AWS bool + GitHub bool + Terraform bool + K8s bool + GCP bool + Cloudflare bool + Code bool +} + +// Classification represents the result of LLM-based query classification +type Classification struct { + Service string `json:"service"` + Confidence string `json:"confidence"` + Reason string `json:"reason"` +} + +// InferContext analyzes a question and determines which cloud service contexts are relevant. +// Returns a ServiceContext with boolean flags for each detected service. +func InferContext(question string) ServiceContext { + ctx := ServiceContext{} + + awsKeywords := []string{ + // Core services + "ec2", "lambda", "rds", "s3", "ecs", "cloudwatch", "logs", "batch", + "sqs", "sns", "dynamodb", "elasticache", "elb", "alb", "nlb", "route53", + "cloudfront", "api-gateway", "cognito", "iam", "vpc", "subnet", + "security-group", "nacl", "nat", "igw", "vpn", "direct-connect", + // General terms that strongly indicate AWS context + "instance", "bucket", "database", "aws", "resources", "infrastructure", + "running", "account", "error", "log", "job", "queue", "compute", + "storage", "network", "cdn", "load-balancer", "auto-scaling", "scaling", + "health", "metric", "alarm", "notification", "backup", "snapshot", + "ami", "volume", "ebs", "efs", "fsx", + // ML/GPU + "gpu", "cuda", "ml", "machine-learning", "training", "inference", + "p2", "p3", "p4", "g3", "g4", "g5", "spot", "reserved", "dedicated", + // Status keywords + "status", "state", "healthy", "unhealthy", "available", "pending", + "stopping", "stopped", "terminated", "creating", "deleting", "modifying", + "active", "inactive", "enabled", "disabled", + // Cost keywords + "cost", "billing", "price", "usage", "spend", "budget", + // Monitoring keywords + "monitor", "trace", "debug", "performance", "latency", "throughput", + "error-rate", "failure", "timeout", "retry", + // Discovery keywords + "services", "active", "deployed", "discovery", "overview", "summary", + "list-all", "what's-running", "what-services", "infrastructure-overview", + } + + githubKeywords := []string{ + // Platform + "github", "git", "repository", "repo", "fork", "clone", "branch", "tag", "release", + "issue", "discussion", + // CI/CD + "action", "workflow", "ci", "cd", "build", "deploy", "deployment", + "pipeline", "job", "step", "runner", "artifact", + // Collaboration + "pr", "pull", "request", "merge", "commit", "push", "pull-request", + "review", "approve", "comment", "assignee", "reviewer", + // Project management + "milestone", "project", "board", "epic", "story", "task", "bug", + "feature", "enhancement", "label", "status", + // Security + "security", "vulnerability", "dependabot", "secret", "token", + "permission", "access", "audit", + } + + terraformKeywords := []string{ + // Core + "terraform", "tf ", "hcl", "plan", "apply", "destroy", "init", + "workspace", "state", "backend", "provider", "resource", "data", + "module", "variable", "output", "local", + // Operations + "infrastructure-as-code", "iac", "provisioning", "deployment", + "environment", "stack", "configuration", "template", + // State management + "tfstate", "state-file", "remote-state", "lock", "unlock", + "drift", "refresh", "import", "taint", "untaint", + // Environments + "dev", "stage", "staging", "prod", "production", "qa", "environment", "workspace", + } + + k8sKeywords := []string{ + // Core K8s terms + "kubernetes", "k8s", "kubectl", "kube", + // Workloads + "pod", "pods", "deployment", "deployments", "replicaset", "statefulset", + "daemonset", "job", "cronjob", + // Networking + "service", "services", "ingress", "loadbalancer", "nodeport", "clusterip", + "networkpolicy", "endpoint", + // Storage + "pv", "pvc", "persistentvolume", "storageclass", "configmap", "secret", + // Cluster + "node", "nodes", "namespace", "cluster", "kubeconfig", "context", + // Tools + "helm", "chart", "release", "tiller", + // Providers + "eks", "kubeadm", "kops", "k3s", "minikube", + // Operations + "rollout", "scale", "drain", "cordon", "taint", + } + + gcpKeywords := []string{ + "gcp", "google cloud", "cloud run", "cloudrun", "cloud sql", "cloudsql", "gke", "gcs", "cloud storage", + "pubsub", "pub/sub", "cloud functions", "cloud function", "compute engine", "gce", "iam service account", + "workload identity", "artifact registry", "secret manager", "bigquery", "spanner", "bigtable", + "cloud build", "cloud deploy", "cloud dns", "cloud armor", "cloud load balancing", "api gateway", + } + + cloudflareKeywords := []string{ + // Only match if Cloudflare is explicitly mentioned + "cloudflare", + // Cloudflare-specific CLI tools (unique to Cloudflare) + "wrangler", + "cloudflared", + } + + questionLower := strings.ToLower(question) + + for _, keyword := range awsKeywords { + if contains(questionLower, keyword) { + ctx.AWS = true + break + } + } + + for _, keyword := range githubKeywords { + if contains(questionLower, keyword) { + ctx.GitHub = true + break + } + } + + for _, keyword := range terraformKeywords { + if contains(questionLower, keyword) { + ctx.Terraform = true + break + } + } + + for _, keyword := range k8sKeywords { + if contains(questionLower, keyword) { + ctx.K8s = true + break + } + } + + for _, keyword := range gcpKeywords { + if contains(questionLower, keyword) { + ctx.GCP = true + break + } + } + + for _, keyword := range cloudflareKeywords { + if contains(questionLower, keyword) { + ctx.Cloudflare = true + break + } + } + + // Default to AWS and GitHub context if nothing is detected + if !ctx.AWS && !ctx.GitHub && !ctx.Terraform && !ctx.K8s && !ctx.GCP && !ctx.Cloudflare { + ctx.AWS = true + ctx.GitHub = true + } + + return ctx +} + +// GetClassificationPrompt returns a prompt for LLM to classify which service a query is about +func GetClassificationPrompt(question string) string { + return fmt.Sprintf(`Classify which cloud service or platform this user query is about. + +User Query: "%s" + +Available services: +- cloudflare: Cloudflare CDN, DNS, Workers, KV, D1, R2, Pages, WAF, Tunnels, Zero Trust, Analytics +- aws: Amazon Web Services (EC2, Lambda, S3, RDS, VPC, Route53, CloudFront, IAM, ECS, etc.) +- k8s: Kubernetes clusters, pods, deployments, services, helm, kubectl +- gcp: Google Cloud Platform (Cloud Run, GKE, Cloud SQL, BigQuery, etc.) +- github: GitHub repositories, PRs, issues, actions, workflows +- terraform: Infrastructure as code, Terraform plans, state, modules +- general: General questions not specific to any cloud platform + +IMPORTANT RULES: +1. Only classify as "cloudflare" if the query EXPLICITLY mentions Cloudflare, wrangler, cloudflared, or Cloudflare-specific products +2. Generic terms like "cdn", "cache", "dns", "worker", "waf", "rate limit", "tunnel" should default to AWS unless Cloudflare is explicitly mentioned +3. If the query mentions AWS services (EC2, Lambda, S3, CloudFront, Route53, etc.), classify as "aws" +4. If uncertain, classify as "aws" (the default cloud provider) + +Respond with ONLY a JSON object: +{ + "service": "cloudflare|aws|k8s|gcp|github|terraform|general", + "confidence": "high|medium|low", + "reason": "brief explanation of why this classification" +}`, question) +} + +// ClassifyWithLLM uses the AI client to determine which service a query is about. +// Returns the service name and any error encountered. +func ClassifyWithLLM(ctx context.Context, question string, debug bool) (string, error) { + // Get provider config + provider := viper.GetString("ai.default_provider") + if provider == "" { + provider = "openai" + } + + var apiKey string + switch provider { + case "openai": + apiKey = os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + apiKey = viper.GetString("ai.providers.openai.api_key") + } + case "anthropic": + apiKey = os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + apiKey = viper.GetString("ai.providers.anthropic.api_key") + } + case "gemini", "gemini-api": + apiKey = os.Getenv("GEMINI_API_KEY") + } + + // Create minimal AI client for classification + aiClient := ai.NewClient(provider, apiKey, debug, "") + + prompt := GetClassificationPrompt(question) + response, err := aiClient.AskPrompt(ctx, prompt) + if err != nil { + if debug { + fmt.Printf("[routing] LLM classification failed: %v, falling back to keyword matching\n", err) + } + return "", err + } + + // Parse the JSON response + var classification Classification + + // Clean response and parse JSON + cleaned := aiClient.CleanJSONResponse(response) + if err := json.Unmarshal([]byte(cleaned), &classification); err != nil { + if debug { + fmt.Printf("[routing] Failed to parse classification response: %v\n", err) + } + return "", err + } + + if debug { + fmt.Printf("[routing] LLM classification: service=%s, confidence=%s, reason=%s\n", + classification.Service, classification.Confidence, classification.Reason) + } + + return classification.Service, nil +} + +// NeedsLLMClassification determines if a query needs LLM classification +// based on ambiguity (multiple services detected) or Cloudflare being inferred. +func NeedsLLMClassification(ctx ServiceContext) bool { + // Count how many services were inferred + count := 0 + if ctx.AWS { + count++ + } + if ctx.K8s { + count++ + } + if ctx.GCP { + count++ + } + if ctx.Cloudflare { + count++ + } + + // Use LLM classification if: + // 1. Multiple services inferred (ambiguous) + // 2. Cloudflare was inferred (verify it's actually Cloudflare-related) + return count > 1 || ctx.Cloudflare +} + +// ApplyLLMClassification updates the ServiceContext based on LLM classification result +func ApplyLLMClassification(ctx *ServiceContext, llmService string) { + switch llmService { + case "cloudflare": + ctx.Cloudflare = true + ctx.K8s = false + ctx.GCP = false + ctx.AWS = false + case "k8s": + ctx.K8s = true + ctx.Cloudflare = false + ctx.GCP = false + case "gcp": + ctx.GCP = true + ctx.Cloudflare = false + ctx.K8s = false + case "aws": + ctx.AWS = true + ctx.Cloudflare = false + ctx.K8s = false + ctx.GCP = false + case "terraform": + ctx.Terraform = true + ctx.Cloudflare = false + case "github": + ctx.GitHub = true + ctx.Cloudflare = false + default: + // "general" - default to AWS + ctx.AWS = true + ctx.Cloudflare = false + ctx.K8s = false + } +} + +// contains checks if s contains substr (case-insensitive) +func contains(s, substr string) bool { + return strings.Contains(strings.ToLower(s), strings.ToLower(substr)) +} diff --git a/internal/routing/routing_test.go b/internal/routing/routing_test.go new file mode 100644 index 0000000..8bc0d36 --- /dev/null +++ b/internal/routing/routing_test.go @@ -0,0 +1,337 @@ +package routing + +import ( + "testing" +) + +func TestInferContext_CloudflareExplicit(t *testing.T) { + tests := []struct { + name string + query string + expectCloudflare bool + expectAWS bool + expectK8s bool + }{ + { + name: "explicit cloudflare mention", + query: "list my cloudflare zones", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "wrangler tool mention", + query: "wrangler deploy my worker", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "cloudflared tool mention", + query: "cloudflared tunnel list", + expectCloudflare: true, + expectAWS: false, + expectK8s: false, + }, + { + name: "generic cache should not trigger cloudflare", + query: "show cache hit rate", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic cdn should not trigger cloudflare", + query: "list cdn distributions", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic worker should not trigger cloudflare", + query: "show worker processes", + expectCloudflare: false, + expectAWS: false, + expectK8s: false, + }, + { + name: "generic waf should not trigger cloudflare", + query: "list waf rules", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "generic rate limit should not trigger cloudflare", + query: "show rate limits", + expectCloudflare: false, + expectAWS: true, // "rate" triggers AWS keyword match + expectK8s: false, + }, + { + name: "generic dns should not trigger cloudflare", + query: "list dns records", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "ec2 should trigger aws", + query: "list ec2 instances", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "lambda should trigger aws", + query: "show lambda functions", + expectCloudflare: false, + expectAWS: true, + expectK8s: false, + }, + { + name: "pods should trigger k8s", + query: "list pods", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + { + name: "kubernetes should trigger k8s", + query: "show kubernetes deployments", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + { + name: "kubectl should trigger k8s", + query: "kubectl get nodes", + expectCloudflare: false, + expectAWS: false, + expectK8s: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := InferContext(tt.query) + + if ctx.Cloudflare != tt.expectCloudflare { + t.Errorf("InferContext(%q) cloudflare = %v, want %v", tt.query, ctx.Cloudflare, tt.expectCloudflare) + } + if ctx.AWS != tt.expectAWS { + t.Errorf("InferContext(%q) aws = %v, want %v", tt.query, ctx.AWS, tt.expectAWS) + } + if ctx.K8s != tt.expectK8s { + t.Errorf("InferContext(%q) k8s = %v, want %v", tt.query, ctx.K8s, tt.expectK8s) + } + }) + } +} + +func TestInferContext_NoCloudflarefalsePositives(t *testing.T) { + // These queries should NOT trigger Cloudflare routing + noCloudflareQueries := []string{ + "what is the cache hit rate", + "show cdn distribution", + "list workers", + "show rate limits", + "check waf status", + "create tunnel to database", + "show analytics dashboard", + "configure access control", + "deploy to pages", + "list dns records for route53", + "show cloudfront distributions", + } + + for _, query := range noCloudflareQueries { + t.Run(query, func(t *testing.T) { + ctx := InferContext(query) + if ctx.Cloudflare { + t.Errorf("InferContext(%q) incorrectly triggered Cloudflare routing", query) + } + }) + } +} + +func TestInferContext_DefaultBehavior(t *testing.T) { + // Unknown queries should default to AWS + GitHub + ctx := InferContext("random question about nothing") + + if !ctx.AWS { + t.Error("Unknown query should default to AWS=true") + } + if !ctx.GitHub { + t.Error("Unknown query should default to GitHub=true") + } + if ctx.Cloudflare { + t.Error("Unknown query should not trigger Cloudflare") + } + if ctx.K8s { + t.Error("Unknown query should not trigger K8s") + } +} + +func TestGetClassificationPrompt(t *testing.T) { + prompt := GetClassificationPrompt("list my cloudflare zones") + + if prompt == "" { + t.Error("GetClassificationPrompt returned empty string") + } + + expectedPhrases := []string{ + "cloudflare", + "aws", + "k8s", + "gcp", + "JSON object", + "service", + } + + for _, phrase := range expectedPhrases { + if !contains(prompt, phrase) { + t.Errorf("GetClassificationPrompt missing expected phrase: %s", phrase) + } + } +} + +func TestNeedsLLMClassification(t *testing.T) { + tests := []struct { + name string + ctx ServiceContext + expect bool + }{ + { + name: "cloudflare detected needs verification", + ctx: ServiceContext{Cloudflare: true}, + expect: true, + }, + { + name: "multiple services need disambiguation", + ctx: ServiceContext{AWS: true, K8s: true}, + expect: true, + }, + { + name: "single aws does not need llm", + ctx: ServiceContext{AWS: true}, + expect: false, + }, + { + name: "single k8s does not need llm", + ctx: ServiceContext{K8s: true}, + expect: false, + }, + { + name: "aws and cloudflare needs llm", + ctx: ServiceContext{AWS: true, Cloudflare: true}, + expect: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := NeedsLLMClassification(tt.ctx) + if result != tt.expect { + t.Errorf("NeedsLLMClassification(%+v) = %v, want %v", tt.ctx, result, tt.expect) + } + }) + } +} + +func TestApplyLLMClassification(t *testing.T) { + tests := []struct { + name string + llmService string + expectAWS bool + expectCF bool + expectK8s bool + expectGCP bool + }{ + { + name: "cloudflare classification", + llmService: "cloudflare", + expectCF: true, + expectAWS: false, + expectK8s: false, + expectGCP: false, + }, + { + name: "aws classification", + llmService: "aws", + expectAWS: true, + expectCF: false, + expectK8s: false, + expectGCP: false, + }, + { + name: "k8s classification", + llmService: "k8s", + expectK8s: true, + expectAWS: false, + expectCF: false, + expectGCP: false, + }, + { + name: "gcp classification", + llmService: "gcp", + expectGCP: true, + expectAWS: false, + expectCF: false, + expectK8s: false, + }, + { + name: "general defaults to aws", + llmService: "general", + expectAWS: true, + expectCF: false, + expectK8s: false, + expectGCP: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ServiceContext{} + ApplyLLMClassification(&ctx, tt.llmService) + + if ctx.AWS != tt.expectAWS { + t.Errorf("ApplyLLMClassification(%q) AWS = %v, want %v", tt.llmService, ctx.AWS, tt.expectAWS) + } + if ctx.Cloudflare != tt.expectCF { + t.Errorf("ApplyLLMClassification(%q) Cloudflare = %v, want %v", tt.llmService, ctx.Cloudflare, tt.expectCF) + } + if ctx.K8s != tt.expectK8s { + t.Errorf("ApplyLLMClassification(%q) K8s = %v, want %v", tt.llmService, ctx.K8s, tt.expectK8s) + } + if ctx.GCP != tt.expectGCP { + t.Errorf("ApplyLLMClassification(%q) GCP = %v, want %v", tt.llmService, ctx.GCP, tt.expectGCP) + } + }) + } +} + +func TestContains(t *testing.T) { + tests := []struct { + s string + substr string + expect bool + }{ + {"Hello World", "world", true}, + {"Hello World", "WORLD", true}, + {"cloudflare zones", "cloudflare", true}, + {"list ec2", "EC2", true}, + {"kubernetes pods", "k8s", false}, + {"", "test", false}, + {"test", "", true}, + } + + for _, tt := range tests { + t.Run(tt.s+"_"+tt.substr, func(t *testing.T) { + result := contains(tt.s, tt.substr) + if result != tt.expect { + t.Errorf("contains(%q, %q) = %v, want %v", tt.s, tt.substr, result, tt.expect) + } + }) + } +} From 852e013eaec3f57214d5b2fd1c7d7c5ecf0e7fb6 Mon Sep 17 00:00:00 2001 From: nash Date: Mon, 26 Jan 2026 22:19:46 +0000 Subject: [PATCH 4/4] style(routing): fix struct field alignment in tests --- internal/routing/routing_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/routing/routing_test.go b/internal/routing/routing_test.go index 8bc0d36..366f9d9 100644 --- a/internal/routing/routing_test.go +++ b/internal/routing/routing_test.go @@ -6,11 +6,11 @@ import ( func TestInferContext_CloudflareExplicit(t *testing.T) { tests := []struct { - name string - query string - expectCloudflare bool - expectAWS bool - expectK8s bool + name string + query string + expectCloudflare bool + expectAWS bool + expectK8s bool }{ { name: "explicit cloudflare mention",