diff --git a/cmd/src/gateway.go b/cmd/src/gateway.go deleted file mode 100644 index 0540d694e3..0000000000 --- a/cmd/src/gateway.go +++ /dev/null @@ -1,39 +0,0 @@ -package main - -import ( - "flag" - "fmt" -) - -var gatewayCommands commander - -func init() { - usage := `'src gateway' interacts with Cody Gateway (directly or through a Sourcegraph instance). - -Usage: - - src gateway command [command options] - -The commands are: - - benchmark runs benchmarks against Cody Gateway - benchmark-stream runs benchmarks against Cody Gateway code completion streaming endpoints - -Use "src gateway [command] -h" for more information about a command. - -` - - flagSet := flag.NewFlagSet("gateway", flag.ExitOnError) - handler := func(args []string) error { - gatewayCommands.run(flagSet, "src gateway", usage, args) - return nil - } - - // Register the command. - commands = append(commands, &command{ - flagSet: flagSet, - aliases: []string{}, // No aliases for gateway command - handler: handler, - usageFunc: func() { fmt.Println(usage) }, - }) -} diff --git a/cmd/src/gateway_benchmark.go b/cmd/src/gateway_benchmark.go deleted file mode 100644 index cc354bc625..0000000000 --- a/cmd/src/gateway_benchmark.go +++ /dev/null @@ -1,523 +0,0 @@ -package main - -import ( - "encoding/csv" - "flag" - "fmt" - "io" - "net/http" - "os" - "sort" - "strings" - "time" - - "github.com/gorilla/websocket" - - "github.com/sourcegraph/src-cli/internal/cmderrors" -) - -type Stats struct { - Avg time.Duration - P5 time.Duration - P75 time.Duration - P80 time.Duration - P95 time.Duration - Median time.Duration - Total time.Duration -} - -type requestResult struct { - duration time.Duration - traceID string // X-Trace header value -} - -func init() { - usage := ` -'src gateway benchmark' runs performance benchmarks against Cody Gateway and Sourcegraph test endpoints. - -Usage: - - src gateway benchmark [flags] - -Examples: - - $ src gateway benchmark --sgp - $ src gateway benchmark --requests 50 --sgp - $ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp - $ src gateway benchmark --requests 50 --csv results.csv --request-csv requests.csv --sgp - $ src gateway benchmark --gateway https://cody-gateway.sourcegraph.com --sourcegraph https://sourcegraph.com --sgp --use-special-header -` - - flagSet := flag.NewFlagSet("benchmark", flag.ExitOnError) - - var ( - requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint") - csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)") - requestLevelCsvOutput = flagSet.String("request-csv", "", "Export request results to CSV file (provide filename)") - gatewayEndpoint = flagSet.String("gateway", "", "Cody Gateway endpoint") - sgEndpoint = flagSet.String("sourcegraph", "", "Sourcegraph endpoint") - sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance") - useSpecialHeader = flagSet.Bool("use-special-header", false, "Use special header to test the gateway") - ) - - handler := func(args []string) error { - if err := flagSet.Parse(args); err != nil { - return err - } - - if len(flagSet.Args()) != 0 { - return cmderrors.Usage("additional arguments not allowed") - } - - if *useSpecialHeader { - fmt.Println("Using special header 'cody-core-gc-test'") - } - - var ( - httpClient = &http.Client{} - endpoints = map[string]any{} // Values: URL `string`s or `*webSocketClient`s - ) - if *gatewayEndpoint != "" { - fmt.Println("Benchmarking Cody Gateway instance:", *gatewayEndpoint) - headers := http.Header{ - "X-Sourcegraph-Should-Trace": []string{"true"}, - } - endpoints["ws(s): gateway"] = &webSocketClient{ - conn: nil, - URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1), - reqHeaders: headers, - } - endpoints["http(s): gateway"] = fmt.Sprint(*gatewayEndpoint, "/v2/http") - } else { - fmt.Println("warning: not benchmarking Cody Gateway (-gateway endpoint not provided)") - } - if *sgEndpoint != "" { - if *sgpToken == "" { - return cmderrors.Usage("must specify --sgp ") - } - fmt.Println("Benchmarking Sourcegraph instance:", *sgEndpoint) - headers := http.Header{ - "Authorization": []string{"token " + *sgpToken}, - "X-Sourcegraph-Should-Trace": []string{"true"}, - } - if *useSpecialHeader { - headers.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&") - } - - endpoints["ws(s): sourcegraph"] = &webSocketClient{ - conn: nil, - URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1), - reqHeaders: headers, - } - endpoints["http(s): sourcegraph"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http") - endpoints["http(s): http-then-ws"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http-then-websocket") - } else { - fmt.Println("warning: not benchmarking Sourcegraph instance (-sourcegraph endpoint not provided)") - } - - fmt.Printf("Starting benchmark with %d requests per endpoint...\n", *requestCount) - - var eResults []endpointResult - rResults := map[string][]requestResult{} - for name, clientOrURL := range endpoints { - durations := make([]time.Duration, 0, *requestCount) - rResults[name] = make([]requestResult, 0, *requestCount) - fmt.Printf("\nTesting %s...", name) - - for i := 0; i < *requestCount; i++ { - if ws, ok := clientOrURL.(*webSocketClient); ok { - result := benchmarkEndpointWebSocket(ws) - if result.duration > 0 { - durations = append(durations, result.duration) - rResults[name] = append(rResults[name], result) - } - } else if url, ok := clientOrURL.(string); ok { - result := benchmarkEndpointHTTP(httpClient, url, *sgpToken, *useSpecialHeader) - if result.duration > 0 { - durations = append(durations, result.duration) - rResults[name] = append(rResults[name], result) - } - } - } - fmt.Println() - - stats := calculateStats(durations) - - eResults = append(eResults, endpointResult{ - name: name, - avg: stats.Avg, - median: stats.Median, - p5: stats.P5, - p75: stats.P75, - p80: stats.P80, - p95: stats.P95, - total: stats.Total, - successful: len(durations), - }) - } - - printResults(eResults, requestCount) - - if *csvOutput != "" { - if err := writeResultsToCSV(*csvOutput, eResults, requestCount); err != nil { - return fmt.Errorf("failed to export CSV: %v", err) - } - fmt.Printf("\nResults exported to %s\n", *csvOutput) - } - if *requestLevelCsvOutput != "" { - if err := writeRequestResultsToCSV(*requestLevelCsvOutput, rResults); err != nil { - return fmt.Errorf("failed to export request-level CSV: %v", err) - } - fmt.Printf("\nRequest-level results exported to %s\n", *requestLevelCsvOutput) - } - - return nil - } - - gatewayCommands = append(gatewayCommands, &command{ - flagSet: flagSet, - aliases: []string{}, - handler: handler, - usageFunc: func() { - _, err := fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src gateway %s':\n", flagSet.Name()) - if err != nil { - return - } - flagSet.PrintDefaults() - fmt.Println(usage) - }, - }) -} - -type webSocketClient struct { - conn *websocket.Conn - URL string - reqHeaders http.Header - respHeaders http.Header -} - -func (c *webSocketClient) reconnect() error { - if c.conn != nil { - c.conn.Close() // don't leak connections - } - fmt.Println("Connecting to WebSocket..", c.URL) - var err error - var resp *http.Response - //nolint:bodyclose // closed as part of webSocketClient - c.conn, resp, err = websocket.DefaultDialer.Dial(c.URL, c.reqHeaders) - if err != nil { - c.conn = nil // retry again later - return fmt.Errorf("WebSocket dial(%s): %v", c.URL, err) - } - c.respHeaders = resp.Header - fmt.Println("Connected!") - return nil -} - -type endpointResult struct { - name string - avg time.Duration - median time.Duration - p5 time.Duration - p75 time.Duration - p80 time.Duration - p95 time.Duration - total time.Duration - successful int -} - -func benchmarkEndpointHTTP(client *http.Client, url, accessToken string, useSpecialHeader bool) requestResult { - start := time.Now() - req, err := http.NewRequest("POST", url, strings.NewReader("ping")) - if err != nil { - fmt.Printf("Error creating request: %v\n", err) - return requestResult{} - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "token "+accessToken) - req.Header.Set("X-Sourcegraph-Should-Trace", "true") - if useSpecialHeader { - req.Header.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&") - } - resp, err := client.Do(req) - if err != nil { - fmt.Printf("Error calling %s: %v\n", url, err) - return requestResult{} - } - defer func() { - err := resp.Body.Close() - if err != nil { - fmt.Printf("Error closing response body: %v\n", err) - } - }() - if resp.StatusCode != http.StatusOK { - fmt.Printf("non-200 response: %v\n", resp.Status) - return requestResult{} - } - body, err := io.ReadAll(resp.Body) - if err != nil { - fmt.Printf("Error reading response body: %v\n", err) - return requestResult{} - } - if string(body) != "pong" { - fmt.Printf("Expected 'pong' response, got: %q\n", string(body)) - return requestResult{} - } - - return requestResult{ - duration: time.Since(start), - traceID: resp.Header.Get("X-Trace"), - } -} - -func benchmarkEndpointWebSocket(client *webSocketClient) requestResult { - // Perform initial websocket connection, if needed. - if client.conn == nil { - if err := client.reconnect(); err != nil { - fmt.Printf("Error reconnecting: %v\n", err) - return requestResult{} - } - } - - // Perform the benchmarked request using the websocket. - start := time.Now() - err := client.conn.WriteMessage(websocket.TextMessage, []byte("ping")) - if err != nil { - fmt.Printf("WebSocket write error: %v\n", err) - if err := client.reconnect(); err != nil { - fmt.Printf("Error reconnecting: %v\n", err) - } - return requestResult{} - } - _, message, err := client.conn.ReadMessage() - - if err != nil { - fmt.Printf("WebSocket read error: %v\n", err) - if err := client.reconnect(); err != nil { - fmt.Printf("Error reconnecting: %v\n", err) - } - return requestResult{} - } - if string(message) != "pong" { - fmt.Printf("Expected 'pong' response, got: %q\n", string(message)) - if err := client.reconnect(); err != nil { - fmt.Printf("Error reconnecting: %v\n", err) - } - return requestResult{} - } - return requestResult{ - duration: time.Since(start), - traceID: client.respHeaders.Get("Content-Type"), - } -} - -func calculateStats(durations []time.Duration) Stats { - if len(durations) == 0 { - return Stats{0, 0, 0, 0, 0, 0, 0} - } - - // Sort durations in ascending order - sort.Slice(durations, func(i, j int) bool { - return durations[i] < durations[j] - }) - - var sum time.Duration - for _, d := range durations { - sum += d - } - avg := sum / time.Duration(len(durations)) - - return Stats{ - Avg: avg, - P5: durations[int(float64(len(durations))*0.05)], - P75: durations[int(float64(len(durations))*0.75)], - P80: durations[int(float64(len(durations))*0.80)], - P95: durations[int(float64(len(durations))*0.95)], - Median: durations[(len(durations) / 2)], - Total: sum, - } -} - -func formatDuration(d time.Duration, best bool, worst bool) string { - value := fmt.Sprintf("%.2fms", float64(d.Microseconds())/1000) - if best { - return ansiColors["green"] + value + ansiColors["nc"] - } - if worst { - return ansiColors["red"] + value + ansiColors["nc"] - } - return ansiColors["yellow"] + value + ansiColors["nc"] -} - -func formatSuccessRate(successful, total int, best bool, worst bool) string { - value := fmt.Sprintf("%d/%d", successful, total) - if best { - return ansiColors["green"] + value + ansiColors["nc"] - } - if worst { - return ansiColors["red"] + value + ansiColors["nc"] - } - return ansiColors["yellow"] + value + ansiColors["nc"] -} - -func printResults(results []endpointResult, requestCount *int) { - // Print header - headerFmt := ansiColors["blue"] + "%-25s | %-10s | %-10s | %-10s | %-10s | %-10s | %-10s | %-10s | %-10s" + ansiColors["nc"] + "\n" - fmt.Printf("\n"+headerFmt, - "Endpoint ", "Average", "Median", "P5", "P75", "P80", "P95", "Total", "Success") - fmt.Println(ansiColors["blue"] + strings.Repeat("-", 121) + ansiColors["nc"]) - - // Find best/worst values for each metric - var bestAvg, worstAvg time.Duration - var bestMedian, worstMedian time.Duration - var bestP5, worstP5 time.Duration - var bestP75, worstP75 time.Duration - var bestP80, worstP80 time.Duration - var bestP95, worstP95 time.Duration - var bestTotal, worstTotal time.Duration - var bestSuccess, worstSuccess int - - for i, r := range results { - if i == 0 || r.avg < bestAvg { - bestAvg = r.avg - } - if i == 0 || r.avg > worstAvg { - worstAvg = r.avg - } - if i == 0 || r.median < bestMedian { - bestMedian = r.median - } - if i == 0 || r.median > worstMedian { - worstMedian = r.median - } - if i == 0 || r.p5 < bestP5 { - bestP5 = r.p5 - } - if i == 0 || r.p5 > worstP5 { - worstP5 = r.p5 - } - if i == 0 || r.p75 < bestP75 { - bestP75 = r.p75 - } - if i == 0 || r.p75 > worstP75 { - worstP75 = r.p75 - } - if i == 0 || r.p80 < bestP80 { - bestP80 = r.p80 - } - if i == 0 || r.p80 > worstP80 { - worstP80 = r.p80 - } - if i == 0 || r.p95 < bestP95 { - bestP95 = r.p95 - } - if i == 0 || r.p95 > worstP95 { - worstP95 = r.p95 - } - if i == 0 || r.total < bestTotal { - bestTotal = r.total - } - if i == 0 || r.total > worstTotal { - worstTotal = r.total - } - if i == 0 || r.successful > bestSuccess { - bestSuccess = r.successful - } - if i == 0 || r.successful < worstSuccess { - worstSuccess = r.successful - } - } - - // Print each row - for _, r := range results { - fmt.Printf("%-25s | %-19s | %-19s | %-19s | %-19s | %-19s | %-19s | %-19s | %s\n", - r.name, - formatDuration(r.avg, r.avg == bestAvg, r.avg == worstAvg), - formatDuration(r.median, r.median == bestMedian, r.median == worstMedian), - formatDuration(r.p5, r.p5 == bestP5, r.p5 == worstP5), - formatDuration(r.p75, r.p75 == bestP75, r.p75 == worstP75), - formatDuration(r.p80, r.p80 == bestP80, r.p80 == worstP80), - formatDuration(r.p95, r.p95 == bestP95, r.p95 == worstP95), - formatDuration(r.total, r.total == bestTotal, r.total == worstTotal), - formatSuccessRate(r.successful, *requestCount, r.successful == bestSuccess, r.successful == worstSuccess)) - } -} - -func writeResultsToCSV(filename string, results []endpointResult, requestCount *int) error { - file, err := os.Create(filename) - if err != nil { - return fmt.Errorf("failed to create CSV file: %v", err) - } - defer func() { - err := file.Close() - if err != nil { - return - } - }() - - writer := csv.NewWriter(file) - defer writer.Flush() - - // Write header - header := []string{"Endpoint", "Average (ms)", "Median (ms)", "P5 (ms)", "P75 (ms)", "P80 (ms)", "P95 (ms)", "Total (ms)", "Success Rate"} - if err := writer.Write(header); err != nil { - return fmt.Errorf("failed to write CSV header: %v", err) - } - - // Write data rows - for _, r := range results { - row := []string{ - r.name, - fmt.Sprintf("%.2f", float64(r.avg.Microseconds())/1000), - fmt.Sprintf("%.2f", float64(r.median.Microseconds())/1000), - fmt.Sprintf("%.2f", float64(r.p5.Microseconds())/1000), - fmt.Sprintf("%.2f", float64(r.p75.Microseconds())/1000), - fmt.Sprintf("%.2f", float64(r.p80.Microseconds())/1000), - fmt.Sprintf("%.2f", float64(r.p95.Microseconds())/1000), - fmt.Sprintf("%.2f", float64(r.total.Microseconds())/1000), - fmt.Sprintf("%d/%d", r.successful, *requestCount), - } - if err := writer.Write(row); err != nil { - return fmt.Errorf("failed to write CSV row: %v", err) - } - } - - return nil -} - -func writeRequestResultsToCSV(filename string, results map[string][]requestResult) error { - file, err := os.Create(filename) - if err != nil { - return fmt.Errorf("failed to create CSV file: %v", err) - } - defer func() { - err := file.Close() - if err != nil { - return - } - }() - - writer := csv.NewWriter(file) - defer writer.Flush() - - // Write header - header := []string{"Endpoint", "Duration (ms)", "Trace ID"} - if err := writer.Write(header); err != nil { - return fmt.Errorf("failed to write CSV header: %v", err) - } - - for endpoint, requestResults := range results { - for _, result := range requestResults { - row := []string{ - endpoint, - fmt.Sprintf("%.2f", float64(result.duration.Microseconds())/1000), - result.traceID, - } - if err := writer.Write(row); err != nil { - return fmt.Errorf("failed to write CSV row: %v", err) - } - } - } - - return nil -} diff --git a/cmd/src/gateway_benchmark_stream.go b/cmd/src/gateway_benchmark_stream.go deleted file mode 100644 index 66154a5c8b..0000000000 --- a/cmd/src/gateway_benchmark_stream.go +++ /dev/null @@ -1,283 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/sourcegraph/src-cli/internal/cmderrors" -) - -type httpEndpoint struct { - url string - authHeader string - body string -} - -func init() { - usage := ` -'src gateway benchmark-stream' runs performance benchmarks against Cody Gateway and Sourcegraph -code completion streaming endpoints. - -Usage: - - src gateway benchmark-stream [flags] - -Examples: - - $ src gateway benchmark-stream --requests 50 --csv results.csv --sgd --sgp - $ src gateway benchmark-stream --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgd --sgp - $ src gateway benchmark-stream --requests 250 --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgd --sgp --max-tokens 50 --provider fireworks --stream -` - - flagSet := flag.NewFlagSet("benchmark-stream", flag.ExitOnError) - - var ( - requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint") - csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)") - requestLevelCsvOutput = flagSet.String("request-csv", "", "Export request results to CSV file (provide filename)") - gatewayEndpoint = flagSet.String("gateway", "", "Cody Gateway endpoint") - sgEndpoint = flagSet.String("sourcegraph", "", "Sourcegraph endpoint") - sgdToken = flagSet.String("sgd", "", "Sourcegraph Dotcom user key for Cody Gateway") - sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance") - maxTokens = flagSet.Int("max-tokens", 256, "Maximum number of tokens to generate") - provider = flagSet.String("provider", "anthropic", "Provider to use for completion. Supported values: 'anthropic', 'fireworks'") - stream = flagSet.Bool("stream", false, "Whether to stream completions. Default: false") - ) - - handler := func(args []string) error { - // Parse the flags. - if err := flagSet.Parse(args); err != nil { - return err - } - if len(flagSet.Args()) != 0 { - return cmderrors.Usage("additional arguments not allowed") - } - if *gatewayEndpoint != "" && *sgdToken == "" { - return cmderrors.Usage("must specify --sgp ") - } - if *sgEndpoint != "" && *sgpToken == "" { - return cmderrors.Usage("must specify --sgp ") - } - - var httpClient = &http.Client{} - var cgResult, sgResult endpointResult - var cgRequestResults, sgRequestResults []requestResult - - // Do the benchmarking. - fmt.Printf("Starting benchmark with %d requests per endpoint...\n", *requestCount) - if *gatewayEndpoint != "" { - fmt.Println("Benchmarking Cody Gateway instance:", *gatewayEndpoint) - endpoint := buildGatewayHttpEndpoint(*gatewayEndpoint, *sgdToken, *maxTokens, *provider, *stream) - cgResult, cgRequestResults = benchmarkCodeCompletions("gateway", httpClient, endpoint, *requestCount) - fmt.Println() - } else { - fmt.Println("warning: not benchmarking Cody Gateway (-gateway endpoint not provided)") - } - if *sgEndpoint != "" { - fmt.Println("Benchmarking Sourcegraph instance:", *sgEndpoint) - endpoint := buildSourcegraphHttpEndpoint(*sgEndpoint, *sgpToken, *maxTokens, *provider, *stream) - sgResult, sgRequestResults = benchmarkCodeCompletions("sourcegraph", httpClient, endpoint, *requestCount) - fmt.Println() - } else { - fmt.Println("warning: not benchmarking Sourcegraph instance (-sourcegraph endpoint not provided)") - } - - // Output the results. - endpointResults := []endpointResult{cgResult, sgResult} - printResults(endpointResults, requestCount) - if *csvOutput != "" { - if err := writeResultsToCSV(*csvOutput, endpointResults, requestCount); err != nil { - return fmt.Errorf("failed to export CSV: %v", err) - } - fmt.Printf("\nAggregate results exported to %s\n", *csvOutput) - } - if *requestLevelCsvOutput != "" { - if err := writeRequestResultsToCSV(*requestLevelCsvOutput, map[string][]requestResult{"gateway": cgRequestResults, "sourcegraph": sgRequestResults}); err != nil { - return fmt.Errorf("failed to export CSV: %v", err) - } - fmt.Printf("\nRequest-level results exported to %s\n", *requestLevelCsvOutput) - } - - return nil - } - gatewayCommands = append(gatewayCommands, &command{ - flagSet: flagSet, - aliases: []string{}, - handler: handler, - usageFunc: func() { - _, err := fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src gateway %s':\n", flagSet.Name()) - if err != nil { - return - } - flagSet.PrintDefaults() - fmt.Println(usage) - }, - }) -} - -func buildGatewayHttpEndpoint(gatewayEndpoint string, sgdToken string, maxTokens int, provider string, stream bool) httpEndpoint { - s := "true" - if !stream { - s = "false" - } - if provider == "anthropic" { - return httpEndpoint{ - url: fmt.Sprint(gatewayEndpoint, "/v1/completions/anthropic-messages"), - authHeader: fmt.Sprintf("Bearer %s", sgdToken), - body: fmt.Sprintf(`{ - "model": "claude-3-haiku-20240307", - "messages": [ - {"role": "user", "content": "def bubble_sort(arr):"}, - {"role": "assistant", "content": "Here is a bubble sort:"} - ], - "max_tokens": %d, - "temperature": 0.0, - "stream": %s -}`, maxTokens, s), - } - } else if provider == "fireworks" { - return httpEndpoint{ - url: fmt.Sprint(gatewayEndpoint, "/v1/completions/fireworks"), - authHeader: fmt.Sprintf("Bearer %s", sgdToken), - body: fmt.Sprintf(`{ - "model": "starcoder", - "prompt": "#hello.ts<|fim▁begin|>const sayHello = () => <|fim▁hole|><|fim▁end|>", - "max_tokens": %d, - "stop": [ - "\n\n", - "\n\r\n", - "<|fim▁begin|>", - "<|fim▁hole|>", - "<|fim▁end|>, <|eos_token|>" - ], - "temperature": 0.2, - "topK": 0, - "topP": 0, - "stream": %s -}`, maxTokens, s), - } - } - - return httpEndpoint{} -} - -func buildSourcegraphHttpEndpoint(sgEndpoint string, sgpToken string, maxTokens int, provider string, stream bool) httpEndpoint { - s := "true" - if !stream { - s = "false" - } - if provider == "anthropic" { - return httpEndpoint{ - url: fmt.Sprint(sgEndpoint, "/.api/completions/stream"), - authHeader: fmt.Sprintf("token %s", sgpToken), - body: fmt.Sprintf(`{ - "model": "anthropic::2023-06-01::claude-3-haiku", - "messages": [ - {"speaker": "human", "text": "def bubble_sort(arr):"}, - {"speaker": "assistant", "text": "Here is a bubble sort:"} - ], - "maxTokensToSample": %d, - "temperature": 0.0, - "stream": %s -}`, maxTokens, s), - } - } else if provider == "fireworks" { - return httpEndpoint{ - url: fmt.Sprint(sgEndpoint, "/.api/completions/code"), - authHeader: fmt.Sprintf("token %s", sgpToken), - body: fmt.Sprintf(`{ - "model": "fireworks::v1::starcoder", - "messages": [ - {"speaker": "human", "text": "#hello.ts<|fim▁begin|>const sayHello = () => <|fim▁hole|><|fim▁end|>"} - ], - "maxTokensToSample": %d, - "stopSequences": [ - "\n\n", - "\n\r\n", - "<|fim▁begin|>", - "<|fim▁hole|>", - "<|fim▁end|>, <|eos_token|>" - ], - "temperature": 0.2, - "topK": 0, - "topP": 0, - "stream": %s -}`, maxTokens, s), - } - } - - return httpEndpoint{} -} - -func benchmarkCodeCompletions(benchmarkName string, client *http.Client, endpoint httpEndpoint, requestCount int) (endpointResult, []requestResult) { - results := make([]requestResult, 0, requestCount) - durations := make([]time.Duration, 0, requestCount) - - for i := 0; i < requestCount; i++ { - result := benchmarkCodeCompletion(client, endpoint) - if result.duration > 0 { - results = append(results, result) - durations = append(durations, result.duration) - } - } - stats := calculateStats(durations) - - return toEndpointResult(benchmarkName, stats, len(durations)), results -} - -func benchmarkCodeCompletion(client *http.Client, endpoint httpEndpoint) requestResult { - start := time.Now() - req, err := http.NewRequest("POST", endpoint.url, strings.NewReader(endpoint.body)) - if err != nil { - fmt.Printf("Error creating request: %v\n", err) - return requestResult{0, ""} - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", endpoint.authHeader) - req.Header.Set("X-Sourcegraph-Should-Trace", "true") - req.Header.Set("X-Sourcegraph-Feature", "code_completions") - resp, err := client.Do(req) - if err != nil { - fmt.Printf("Error calling %s: %v\n", endpoint.url, err) - return requestResult{0, ""} - } - defer func() { - err := resp.Body.Close() - if err != nil { - fmt.Printf("Error closing response body: %v\n", err) - } - }() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - fmt.Printf("non-200 response: %v - %s\n", resp.Status, body) - return requestResult{0, ""} - } - _, err = io.ReadAll(resp.Body) - if err != nil { - fmt.Printf("Error reading response body: %v\n", err) - return requestResult{0, ""} - } - - return requestResult{ - duration: time.Since(start), - traceID: resp.Header.Get("X-Trace"), - } -} - -func toEndpointResult(name string, stats Stats, requestCount int) endpointResult { - return endpointResult{ - name: name, - avg: stats.Avg, - median: stats.Median, - p5: stats.P5, - p75: stats.P75, - p80: stats.P80, - p95: stats.P95, - successful: requestCount, - total: stats.Total, - } -}