From f3e3c02436febc44f7923327a08ebe4caaeb0bc1 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 10:47:33 -0400 Subject: [PATCH 01/13] feat(test-harness): add parallel execution, better error reporting, and lint fixes Add --concurrency flag (default 4) to run model builds/tests in parallel using errgroup with a configurable semaphore. All three commands (run, build, schema-compare) support parallel execution with: - Pre-allocated result slices to preserve manifest ordering - Per-model progress lines under a mutex - Quiet mode that captures build output instead of streaming to avoid interleaved output in parallel mode - Per-model isolated work directories so models sharing a repo don't clobber each other's setup commands Improve error reporting: - Final error message now includes per-model failure details (error message or failing test names) instead of just listing model names - Setup commands run with 'set -euo pipefail' so missing tools (e.g. yq) are caught immediately instead of silently producing empty files - Validate cog.yaml is non-empty after setup to catch silent failures Fix all 30 golangci-lint issues across the test-harness: - copyloopvar: remove unnecessary Go <1.22 loop variable copies - errcheck: handle intentionally ignored errors - gocritic: use 0o octal literals, rewrite if-else chains as switch - gosec: annotate intentional HTTP calls to known APIs - misspell: fix British/American spelling - modernize: use strings.SplitSeq, maps.Copy, built-in min --- tools/test-harness/cmd/build.go | 59 ++++++++++-- tools/test-harness/cmd/root.go | 46 ++++++++++ tools/test-harness/cmd/run.go | 67 +++++++++++--- tools/test-harness/cmd/schema_compare.go | 71 ++++++++++++--- .../test-harness/internal/patcher/patcher.go | 7 +- .../internal/patcher/patcher_test.go | 6 +- tools/test-harness/internal/report/report.go | 17 ++-- .../internal/resolver/resolver.go | 27 +++--- tools/test-harness/internal/runner/runner.go | 89 ++++++++++++++----- .../internal/validator/validator.go | 7 -- 10 files changed, 309 insertions(+), 87 deletions(-) diff --git a/tools/test-harness/cmd/build.go b/tools/test-harness/cmd/build.go index cceee3bcfa..cba430a145 100644 --- a/tools/test-harness/cmd/build.go +++ b/tools/test-harness/cmd/build.go @@ -3,7 +3,9 @@ package cmd import ( "context" "fmt" - "strings" + "sync" + + "golang.org/x/sync/errgroup" "github.com/spf13/cobra" @@ -32,7 +34,13 @@ func runBuild(ctx context.Context) error { fmt.Println("No models to build") return nil } - fmt.Printf("Building %d model(s)\n\n", len(models)) + + parallel := concurrency > 1 && len(models) > 1 + if parallel { + fmt.Printf("Building %d model(s) with concurrency %d\n\n", len(models), concurrency) + } else { + fmt.Printf("Building %d model(s)\n\n", len(models)) + } // Create runner r, err := runner.New(runner.Options{ @@ -41,18 +49,51 @@ func runBuild(ctx context.Context) error { SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, + Quiet: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) } - defer r.Cleanup() + defer func() { _ = r.Cleanup() }() // Build models - var results []report.ModelResult - for _, model := range models { - fmt.Printf("Building %s...\n", model.Name) - result := r.BuildModel(ctx, model) - results = append(results, *result) + results := make([]report.ModelResult, len(models)) + + if parallel { + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(concurrency) + + var mu sync.Mutex + for i, model := range models { + g.Go(func() error { + mu.Lock() + fmt.Printf(" [%d/%d] Building %s...\n", i+1, len(models), model.Name) + mu.Unlock() + + result := r.BuildModel(ctx, model) + results[i] = *result + + mu.Lock() + switch { + case result.Passed: + fmt.Printf(" [%d/%d] + %s (%.1fs)\n", i+1, len(models), model.Name, result.BuildDuration) + case result.Skipped: + fmt.Printf(" [%d/%d] - %s (skipped: %s)\n", i+1, len(models), model.Name, result.SkipReason) + default: + fmt.Printf(" [%d/%d] x %s FAILED\n", i+1, len(models), model.Name) + } + mu.Unlock() + + return nil + }) + } + _ = g.Wait() + } else { + for i, model := range models { + fmt.Printf("Building %s...\n", model.Name) + result := r.BuildModel(ctx, model) + results[i] = *result + } } // Output results @@ -66,7 +107,7 @@ func runBuild(ctx context.Context) error { } } if len(failedNames) > 0 { - return fmt.Errorf("%d build(s) failed: %s", len(failedNames), strings.Join(failedNames, ", ")) + return formatFailureSummary("build", results) } return nil diff --git a/tools/test-harness/cmd/root.go b/tools/test-harness/cmd/root.go index 686dfa7319..32e91b0d99 100644 --- a/tools/test-harness/cmd/root.go +++ b/tools/test-harness/cmd/root.go @@ -2,10 +2,12 @@ package cmd import ( "fmt" + "strings" "github.com/spf13/cobra" "github.com/replicate/cog/tools/test-harness/internal/manifest" + "github.com/replicate/cog/tools/test-harness/internal/report" "github.com/replicate/cog/tools/test-harness/internal/resolver" ) @@ -21,6 +23,7 @@ var ( sdkWheel string cleanImages bool keepOutputs bool + concurrency int ) // NewRootCommand creates the root command @@ -46,6 +49,7 @@ It reads the same manifest.yaml format as the Python version.`, rootCmd.PersistentFlags().StringVar(&sdkWheel, "sdk-wheel", "", "Path to pre-built SDK wheel") rootCmd.PersistentFlags().BoolVar(&cleanImages, "clean-images", false, "Remove Docker images after run (default: keep them)") rootCmd.PersistentFlags().BoolVar(&keepOutputs, "keep-outputs", false, "Preserve prediction outputs (images, files) in the work directory") + rootCmd.PersistentFlags().IntVar(&concurrency, "concurrency", 4, "Maximum number of models to build/test in parallel") // Subcommands rootCmd.AddCommand(newRunCommand()) @@ -77,3 +81,45 @@ func resolveSetup() (*manifest.Manifest, []manifest.Model, *resolver.Result, err models := mf.FilterModels(modelFilter, noGPU, gpuOnly) return mf, models, resolved, nil } + +// formatFailureSummary builds an error message with per-model failure details. +func formatFailureSummary(action string, results []report.ModelResult) error { + var b strings.Builder + var failCount int + for _, r := range results { + if r.Passed || r.Skipped { + continue + } + failCount++ + fmt.Fprintf(&b, "\n x %s", r.Name) + if r.Error != "" { + // Show first line of the error + firstLine := r.Error + if idx := strings.Index(firstLine, "\n"); idx != -1 { + firstLine = firstLine[:idx] + } + fmt.Fprintf(&b, ": %s", firstLine) + } else { + // Summarize failed tests + for _, t := range r.TestResults { + if !t.Passed { + msg := t.Message + if idx := strings.Index(msg, "\n"); idx != -1 { + msg = msg[:idx] + } + fmt.Fprintf(&b, "\n test %q: %s", t.Description, msg) + } + } + for _, t := range r.TrainResults { + if !t.Passed { + msg := t.Message + if idx := strings.Index(msg, "\n"); idx != -1 { + msg = msg[:idx] + } + fmt.Fprintf(&b, "\n train %q: %s", t.Description, msg) + } + } + } + } + return fmt.Errorf("%d %s(s) failed:%s", failCount, action, b.String()) +} diff --git a/tools/test-harness/cmd/run.go b/tools/test-harness/cmd/run.go index 16585318e1..bea4769609 100644 --- a/tools/test-harness/cmd/run.go +++ b/tools/test-harness/cmd/run.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "os" - "strings" + "sync" + + "golang.org/x/sync/errgroup" "github.com/spf13/cobra" @@ -51,7 +53,13 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { fmt.Println("No models to run") return nil } - fmt.Printf("Running %d model(s)\n\n", len(models)) + + parallel := concurrency > 1 && len(models) > 1 + if parallel { + fmt.Printf("Running %d model(s) with concurrency %d\n\n", len(models), concurrency) + } else { + fmt.Printf("Running %d model(s)\n\n", len(models)) + } // Create runner r, err := runner.New(runner.Options{ @@ -60,18 +68,52 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, + Quiet: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) } - defer r.Cleanup() + defer func() { _ = r.Cleanup() }() // Run tests - var results []report.ModelResult - for _, model := range models { - fmt.Printf("Running %s...\n", model.Name) - result := r.RunModel(ctx, model) - results = append(results, *result) + results := make([]report.ModelResult, len(models)) + + if parallel { + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(concurrency) + + var mu sync.Mutex + for i, model := range models { + g.Go(func() error { + mu.Lock() + fmt.Printf(" [%d/%d] Running %s...\n", i+1, len(models), model.Name) + mu.Unlock() + + result := r.RunModel(ctx, model) + results[i] = *result + + mu.Lock() + switch { + case result.Skipped: + fmt.Printf(" [%d/%d] - %s (skipped: %s)\n", i+1, len(models), model.Name, result.SkipReason) + case result.Passed: + testCount := len(result.TestResults) + len(result.TrainResults) + fmt.Printf(" [%d/%d] + %s (%.1fs build, %d tests passed)\n", i+1, len(models), model.Name, result.BuildDuration, testCount) + default: + fmt.Printf(" [%d/%d] x %s FAILED\n", i+1, len(models), model.Name) + } + mu.Unlock() + + return nil + }) + } + _ = g.Wait() + } else { + for i, model := range models { + fmt.Printf("Running %s...\n", model.Name) + result := r.RunModel(ctx, model) + results[i] = *result + } } // Output results @@ -111,14 +153,15 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { } // Check for failures - var failedNames []string + var hasFailures bool for _, r := range results { if !r.Passed && !r.Skipped { - failedNames = append(failedNames, r.Name) + hasFailures = true + break } } - if len(failedNames) > 0 { - return fmt.Errorf("%d model(s) failed: %s", len(failedNames), strings.Join(failedNames, ", ")) + if hasFailures { + return formatFailureSummary("model", results) } return nil diff --git a/tools/test-harness/cmd/schema_compare.go b/tools/test-harness/cmd/schema_compare.go index aff44eb44a..8a9394008a 100644 --- a/tools/test-harness/cmd/schema_compare.go +++ b/tools/test-harness/cmd/schema_compare.go @@ -5,6 +5,9 @@ import ( "fmt" "os" "strings" + "sync" + + "golang.org/x/sync/errgroup" "github.com/spf13/cobra" @@ -45,7 +48,13 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro fmt.Println("No models to compare") return nil } - fmt.Printf("Comparing schemas for %d model(s)\n\n", len(models)) + + parallel := concurrency > 1 && len(models) > 1 + if parallel { + fmt.Printf("Comparing schemas for %d model(s) with concurrency %d\n\n", len(models), concurrency) + } else { + fmt.Printf("Comparing schemas for %d model(s)\n\n", len(models)) + } // Create runner r, err := runner.New(runner.Options{ @@ -54,18 +63,48 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, + Quiet: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) } - defer r.Cleanup() + defer func() { _ = r.Cleanup() }() // Compare schemas - var results []report.SchemaCompareResult - for _, model := range models { - fmt.Printf("Comparing %s...\n", model.Name) - result := r.CompareSchema(ctx, model) - results = append(results, *result) + results := make([]report.SchemaCompareResult, len(models)) + + if parallel { + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(concurrency) + + var mu sync.Mutex + for i, model := range models { + g.Go(func() error { + mu.Lock() + fmt.Printf(" [%d/%d] Comparing %s...\n", i+1, len(models), model.Name) + mu.Unlock() + + result := r.CompareSchema(ctx, model) + results[i] = *result + + mu.Lock() + if result.Passed { + fmt.Printf(" [%d/%d] + %s schemas match\n", i+1, len(models), model.Name) + } else { + fmt.Printf(" [%d/%d] x %s FAILED\n", i+1, len(models), model.Name) + } + mu.Unlock() + + return nil + }) + } + _ = g.Wait() + } else { + for i, model := range models { + fmt.Printf("Comparing %s...\n", model.Name) + result := r.CompareSchema(ctx, model) + results[i] = *result + } } // Output results @@ -105,14 +144,24 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro } // Check for failures - var failedNames []string + var failedDetails []string for _, r := range results { if !r.Passed { - failedNames = append(failedNames, r.Name) + detail := r.Name + if r.Error != "" { + firstLine := r.Error + if idx := strings.Index(firstLine, "\n"); idx != -1 { + firstLine = firstLine[:idx] + } + detail += ": " + firstLine + } else if r.Diff != "" { + detail += ": schemas differ" + } + failedDetails = append(failedDetails, " x "+detail) } } - if len(failedNames) > 0 { - return fmt.Errorf("%d schema comparison(s) failed: %s", len(failedNames), strings.Join(failedNames, ", ")) + if len(failedDetails) > 0 { + return fmt.Errorf("%d schema comparison(s) failed:\n%s", len(failedDetails), strings.Join(failedDetails, "\n")) } return nil diff --git a/tools/test-harness/internal/patcher/patcher.go b/tools/test-harness/internal/patcher/patcher.go index 847a3f605b..e289b45648 100644 --- a/tools/test-harness/internal/patcher/patcher.go +++ b/tools/test-harness/internal/patcher/patcher.go @@ -2,6 +2,7 @@ package patcher import ( "fmt" + "maps" "os" "gopkg.in/yaml.v3" @@ -44,7 +45,7 @@ func Patch(cogYAMLPath string, sdkVersion string, overrides map[string]any) erro return fmt.Errorf("marshaling cog.yaml: %w", err) } - if err := os.WriteFile(cogYAMLPath, output, 0644); err != nil { + if err := os.WriteFile(cogYAMLPath, output, 0o644); err != nil { return fmt.Errorf("writing cog.yaml: %w", err) } @@ -54,9 +55,7 @@ func Patch(cogYAMLPath string, sdkVersion string, overrides map[string]any) erro // deepMerge recursively merges override into base func deepMerge(base, override map[string]any) map[string]any { result := make(map[string]any) - for k, v := range base { - result[k] = v - } + maps.Copy(result, base) for k, v := range override { if baseVal, ok := result[k]; ok { diff --git a/tools/test-harness/internal/patcher/patcher_test.go b/tools/test-harness/internal/patcher/patcher_test.go index 19cc109755..99a11811e7 100644 --- a/tools/test-harness/internal/patcher/patcher_test.go +++ b/tools/test-harness/internal/patcher/patcher_test.go @@ -18,7 +18,7 @@ func TestPatch(t *testing.T) { python_version: "3.10" predict: predict.py ` - require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0644)) + require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0o644)) require.NoError(t, Patch(cogYAML, "0.16.12", nil)) data, err := os.ReadFile(cogYAML) @@ -32,7 +32,7 @@ predict: predict.py python_version: "3.10" predict: predict.py ` - require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0644)) + require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0o644)) overrides := map[string]any{ "build": map[string]any{ @@ -53,7 +53,7 @@ predict: predict.py python_version: "3.10" predict: predict.py ` - require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0644)) + require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0o644)) overrides := map[string]any{ "predict": "new_predict.py", diff --git a/tools/test-harness/internal/report/report.go b/tools/test-harness/internal/report/report.go index 6021784e88..dfb2a074d1 100644 --- a/tools/test-harness/internal/report/report.go +++ b/tools/test-harness/internal/report/report.go @@ -76,7 +76,7 @@ func ConsoleReport(results []ModelResult, sdkVersion, cogVersion string) { } writeStatus("x", r.Name, firstLine, r.GPU) // Print full error details indented below - for _, line := range strings.Split(r.Error, "\n") { + for line := range strings.SplitSeq(r.Error, "\n") { if line != "" { fmt.Printf(" %s\n", line) } @@ -123,7 +123,7 @@ func ConsoleReport(results []ModelResult, sdkVersion, cogVersion string) { // Print individual failures with full output for _, t := range failures { fmt.Printf(" x %s:\n", t.Description) - for _, line := range strings.Split(t.Message, "\n") { + for line := range strings.SplitSeq(t.Message, "\n") { fmt.Printf(" %s\n", line) } } @@ -192,11 +192,12 @@ func JSONReport(results []ModelResult, sdkVersion, cogVersion string) map[string failed := 0 skipped := 0 for _, r := range results { - if r.Skipped { + switch { + case r.Skipped: skipped++ - } else if r.Passed { + case r.Passed: passed++ - } else { + default: failed++ } } @@ -242,7 +243,7 @@ func SchemaCompareConsoleReport(results []SchemaCompareResult, cogVersion string firstLine = firstLine[:idx] } writeStatus("x", r.Name, firstLine, false) - for _, line := range strings.Split(r.Error, "\n") { + for line := range strings.SplitSeq(r.Error, "\n") { if line != "" { fmt.Printf(" %s\n", line) } @@ -259,7 +260,7 @@ func SchemaCompareConsoleReport(results []SchemaCompareResult, cogVersion string writeStatus("x", r.Name, "schemas differ", false) failed++ if r.Diff != "" { - for _, line := range strings.Split(r.Diff, "\n") { + for line := range strings.SplitSeq(r.Diff, "\n") { fmt.Printf(" %s\n", line) } } @@ -339,7 +340,7 @@ func round(val float64, precision int) float64 { // SaveResults saves results to a JSON file in the results directory func SaveResults(results []ModelResult, sdkVersion, cogVersion string) (string, error) { resultsDir := "results" - if err := os.MkdirAll(resultsDir, 0755); err != nil { + if err := os.MkdirAll(resultsDir, 0o755); err != nil { return "", fmt.Errorf("creating results dir: %w", err) } diff --git a/tools/test-harness/internal/resolver/resolver.go b/tools/test-harness/internal/resolver/resolver.go index 8051356332..84194efad7 100644 --- a/tools/test-harness/internal/resolver/resolver.go +++ b/tools/test-harness/internal/resolver/resolver.go @@ -63,21 +63,22 @@ func Resolve(cogBinary, cogVersion, cogRef, sdkVersion, sdkWheel string, manifes result.CogVersion = version // Determine SDK wheel: explicit --sdk-wheel wins over ref-built - if sdkWheel != "" { + switch { + case sdkWheel != "": result.SDKWheel = sdkWheel result.SDKVersion = fmt.Sprintf("wheel:%s", filepath.Base(sdkWheel)) if sdkVersion != "" { result.SDKPatchVersion = sdkVersion result.SDKVersion = sdkVersion } - } else if wheel != "" { + case wheel != "": result.SDKWheel = wheel result.SDKVersion = version if sdkVersion != "" { result.SDKPatchVersion = sdkVersion result.SDKVersion = sdkVersion } - } else { + default: // Resolve SDK version from PyPI or explicit flag sdkVer, err := resolveSDKVersion(sdkVersion, manifestDefaults) if err != nil { @@ -168,7 +169,7 @@ func resolveLatestCogVersion() (string, error) { setGitHubAuth(req) client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL is constructed from a known constant if err != nil { return "", fmt.Errorf("fetching GitHub releases: %w", err) } @@ -208,7 +209,7 @@ func resolveLatestPyPIVersion() (string, error) { req.Header.Set("Accept", "application/json") client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL is a known PyPI API constant if err != nil { return "", fmt.Errorf("fetching PyPI: %w", err) } @@ -266,7 +267,7 @@ func downloadCogBinary(tag string) (dest string, err error) { return "", fmt.Errorf("cannot determine home directory: %w", err) } baseDir := filepath.Join(home, ".cache", "cog-harness", "bin") - if err := os.MkdirAll(baseDir, 0755); err != nil { + if err := os.MkdirAll(baseDir, 0o755); err != nil { return "", fmt.Errorf("creating bin cache dir: %w", err) } tmpDir, err := os.MkdirTemp(baseDir, "cog-bin-*") @@ -275,7 +276,7 @@ func downloadCogBinary(tag string) (dest string, err error) { } defer func() { if err != nil { - os.RemoveAll(tmpDir) + _ = os.RemoveAll(tmpDir) } }() @@ -289,7 +290,7 @@ func downloadCogBinary(tag string) (dest string, err error) { setGitHubAuth(req) client := &http.Client{Timeout: 120 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL is constructed from GitHub releases if err != nil { return "", fmt.Errorf("downloading cog binary: %w", err) } @@ -299,13 +300,13 @@ func downloadCogBinary(tag string) (dest string, err error) { return "", fmt.Errorf("download returned %d", resp.StatusCode) } - f, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY, 0755) + f, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY, 0o755) if err != nil { return "", fmt.Errorf("creating binary file: %w", err) } if _, copyErr := io.Copy(f, resp.Body); copyErr != nil { - f.Close() + _ = f.Close() return "", fmt.Errorf("writing binary: %w", copyErr) } @@ -336,7 +337,7 @@ func verifyDownloadedBinary(tag, assetName, dest string) error { setGitHubAuth(req) client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL is constructed from GitHub releases if err != nil { return fmt.Errorf("downloading checksum file: %w", err) } @@ -369,7 +370,7 @@ func verifyDownloadedBinary(tag, assetName, dest string) error { } func parseChecksum(content, assetName string) (string, error) { - for _, line := range strings.Split(content, "\n") { + for line := range strings.SplitSeq(content, "\n") { line = strings.TrimSpace(line) if line == "" { continue @@ -470,7 +471,7 @@ func buildCogFromRef(ref string) (string, string, string, error) { // Build SDK wheel wheelDir := filepath.Join(tmpDir, "dist") - if err := os.MkdirAll(wheelDir, 0755); err != nil { + if err := os.MkdirAll(wheelDir, 0o755); err != nil { return "", "", "", fmt.Errorf("creating wheel dir: %w", err) } diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index d36c2bf9e3..ab2d9ae578 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -14,6 +14,7 @@ import ( "runtime" "sort" "strings" + "sync" "time" "golang.org/x/sync/errgroup" @@ -34,14 +35,18 @@ type Options struct { FixturesDir string CleanImages bool KeepOutputs bool + Quiet bool // Suppress real-time build output (for parallel execution) } -// Runner orchestrates the test lifecycle +// Runner orchestrates the test lifecycle. +// It is safe to call RunModel, BuildModel, and CompareSchema concurrently +// from multiple goroutines. type Runner struct { opts Options fixturesDir string workDir string clonedRepos map[string]string + mu sync.Mutex // protects clonedRepos } // New creates a new Runner @@ -65,7 +70,7 @@ func New(opts Options) (*Runner, error) { return nil, fmt.Errorf("cannot determine home directory for work dir (set $HOME): %w", err) } baseDir := filepath.Join(home, ".cache", "cog-harness") - if err := os.MkdirAll(baseDir, 0755); err != nil { + if err := os.MkdirAll(baseDir, 0o755); err != nil { return nil, fmt.Errorf("creating harness cache dir: %w", err) } workDir, err := os.MkdirTemp(baseDir, "run-*") @@ -258,8 +263,8 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor // Always clean up schema comparison images when done defer func() { - exec.Command("docker", "rmi", "-f", staticTag).Run() - exec.Command("docker", "rmi", "-f", runtimeTag).Run() + _ = exec.Command("docker", "rmi", "-f", staticTag).Run() + _ = exec.Command("docker", "rmi", "-f", runtimeTag).Run() }() staticDir := filepath.Join(r.workDir, fmt.Sprintf("schema-static-%s", model.Name)) @@ -307,7 +312,7 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor if err := g.Wait(); err != nil { result.Passed = false - result.Error = fmt.Sprintf("context cancelled: %v", err) + result.Error = fmt.Sprintf("context canceled: %v", err) return result } @@ -377,12 +382,20 @@ func (r *Runner) prepareModel(ctx context.Context, model manifest.Model) (string } modelDir = dest } else { - // Clone repo + // Clone repo (shared cache, thread-safe) repoDir, err := r.cloneRepo(ctx, model.Repo) if err != nil { return "", err } - modelDir = filepath.Join(repoDir, model.Path) + + // Each model gets its own copy so that setup commands (e.g. + // select.sh) don't clobber each other when running in parallel. + srcDir := filepath.Join(repoDir, model.Path) + dest := filepath.Join(r.workDir, fmt.Sprintf("model-%s", model.Name)) + if err := copyDir(srcDir, dest); err != nil { + return "", fmt.Errorf("copying repo for model %s: %w", model.Name, err) + } + modelDir = dest } // Run setup commands (e.g. script/select.sh to generate cog.yaml) @@ -390,9 +403,14 @@ func (r *Runner) prepareModel(ctx context.Context, model manifest.Model) (string return "", fmt.Errorf("running setup commands: %w", err) } - if _, err := os.Stat(filepath.Join(modelDir, "cog.yaml")); err != nil { + cogYAMLPath := filepath.Join(modelDir, "cog.yaml") + info, err := os.Stat(cogYAMLPath) + if err != nil { return "", fmt.Errorf("no cog.yaml in %s (did setup commands run correctly?)", modelDir) } + if info.Size() == 0 { + return "", fmt.Errorf("cog.yaml in %s is empty (setup commands may have failed silently — check that tools like yq are installed)", modelDir) + } // Patch cog.yaml sdkVersion := model.SDKVersion @@ -412,23 +430,41 @@ func (r *Runner) prepareModel(ctx context.Context, model manifest.Model) (string // from templates (e.g. "script/select.sh dev" in replicate/cog-flux). func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model manifest.Model) error { for _, cmdStr := range model.Setup { - fmt.Printf(" Running setup: %s\n", cmdStr) - cmd := exec.CommandContext(ctx, "sh", "-c", cmdStr) + if !r.opts.Quiet { + fmt.Printf(" Running setup: %s\n", cmdStr) + } + // Wrap with set -euo pipefail so any failing command in the + // script (e.g. a missing yq binary) is caught immediately + // rather than silently producing an empty/invalid cog.yaml. + cmd := exec.CommandContext(ctx, "sh", "-c", "set -euo pipefail; "+cmdStr) cmd.Dir = modelDir - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr cmd.Env = os.Environ() for k, v := range model.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, os.ExpandEnv(v))) } - if err := cmd.Run(); err != nil { - return fmt.Errorf("setup command %q failed: %w", cmdStr, err) + if r.opts.Quiet { + var outputBuf bytes.Buffer + cmd.Stdout = &outputBuf + cmd.Stderr = &outputBuf + if err := cmd.Run(); err != nil { + return fmt.Errorf("setup command %q failed: %w\n%s", cmdStr, err, outputBuf.String()) + } + } else { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("setup command %q failed: %w", cmdStr, err) + } } } return nil } +// cloneRepo clones a repo once and caches the result. Thread-safe. func (r *Runner) cloneRepo(ctx context.Context, repo string) (string, error) { + r.mu.Lock() + defer r.mu.Unlock() + if dir, ok := r.clonedRepos[repo]; ok { return dir, nil } @@ -436,7 +472,7 @@ func (r *Runner) cloneRepo(ctx context.Context, repo string) (string, error) { dest := filepath.Join(r.workDir, strings.ReplaceAll(repo, "/", "--")) // Remove if exists - os.RemoveAll(dest) + _ = os.RemoveAll(dest) url := fmt.Sprintf("https://github.com/%s.git", repo) cmd := exec.CommandContext(ctx, "git", "clone", "--depth=1", url, dest) @@ -490,11 +526,24 @@ func (r *Runner) buildModelWithEnv(ctx context.Context, modelDir string, model m // Stream build output to stderr in real-time so the user can see progress, // while also capturing it for error reporting if the build fails. + // When running in parallel mode (opts.Quiet), only capture output and + // include it in error messages to avoid interleaved output. var outputBuf bytes.Buffer - cmd.Stdout = io.MultiWriter(os.Stderr, &outputBuf) - cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf) + if r.opts.Quiet { + cmd.Stdout = &outputBuf + cmd.Stderr = &outputBuf + } else { + cmd.Stdout = io.MultiWriter(os.Stderr, &outputBuf) + cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf) + } if err := cmd.Run(); err != nil { - return fmt.Errorf("%w\n%s", err, outputBuf.String()) + // Include the last portion of build output for context. + output := outputBuf.String() + const maxTail = 2000 + if len(output) > maxTail { + output = "...\n" + output[len(output)-maxTail:] + } + return fmt.Errorf("%w\n%s", err, output) } return nil } @@ -634,7 +683,7 @@ func (r *Runner) resolveInput(value any) string { func extractOutput(stdout, stderr, modelDir string) string { // For file outputs (e.g. images), cog writes the file to CWD and prints // "Written output to: " on stderr. Check stderr for this pattern. - for _, line := range strings.Split(stderr, "\n") { + for line := range strings.SplitSeq(stderr, "\n") { if strings.Contains(line, "Written output to:") { parts := strings.SplitN(line, "Written output to:", 2) if len(parts) == 2 { @@ -671,7 +720,7 @@ func copyDir(src, dst string) error { dstPath := filepath.Join(dst, rel) if d.IsDir() { - return os.MkdirAll(dstPath, 0755) + return os.MkdirAll(dstPath, 0o755) } data, err := os.ReadFile(path) diff --git a/tools/test-harness/internal/validator/validator.go b/tools/test-harness/internal/validator/validator.go index f13b2fc6db..4482c5064a 100644 --- a/tools/test-harness/internal/validator/validator.go +++ b/tools/test-harness/internal/validator/validator.go @@ -249,10 +249,3 @@ func getKeys(m map[string]any) []string { } return keys } - -func min(a, b int) int { - if a < b { - return a - } - return b -} From dd1145df86a6195c48dec44db05eb0073d392e3a Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 10:49:27 -0400 Subject: [PATCH 02/13] fix: use bash instead of sh for setup commands (dash lacks pipefail) --- tools/test-harness/internal/runner/runner.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index ab2d9ae578..e65516b8f5 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -433,10 +433,11 @@ func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model ma if !r.opts.Quiet { fmt.Printf(" Running setup: %s\n", cmdStr) } - // Wrap with set -euo pipefail so any failing command in the + // Use bash with strict mode so any failing command in the // script (e.g. a missing yq binary) is caught immediately // rather than silently producing an empty/invalid cog.yaml. - cmd := exec.CommandContext(ctx, "sh", "-c", "set -euo pipefail; "+cmdStr) + // We use bash (not sh) because dash does not support pipefail. + cmd := exec.CommandContext(ctx, "bash", "-euo", "pipefail", "-c", cmdStr) cmd.Dir = modelDir cmd.Env = os.Environ() for k, v := range model.Env { From e9fac6b6fb40cad85926f8f5d74ce71b0d39ba3f Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 10:56:47 -0400 Subject: [PATCH 03/13] feat(test-harness): add requires_tools manifest field and Replicate model manifest Add requires_tools field to manifest model schema so models can declare CLI tools needed for setup commands. The harness checks tools exist on PATH before running setup and errors with a clear message listing the missing tools. Improve setup error reporting: both quiet and non-quiet modes now capture stderr and include it in the error message so users see the actual failure instead of a downstream error from an empty cog.yaml. Add manifest-replicate.yaml with 15 Replicate production models. --- .../internal/manifest/manifest.go | 1 + tools/test-harness/internal/runner/runner.go | 38 +- tools/test-harness/manifest-replicate.yaml | 392 ++++++++++++++++++ tools/test-harness/manifest.yaml | 7 + 4 files changed, 429 insertions(+), 9 deletions(-) create mode 100644 tools/test-harness/manifest-replicate.yaml diff --git a/tools/test-harness/internal/manifest/manifest.go b/tools/test-harness/internal/manifest/manifest.go index c325204fd4..d894996c20 100644 --- a/tools/test-harness/internal/manifest/manifest.go +++ b/tools/test-harness/internal/manifest/manifest.go @@ -29,6 +29,7 @@ type Model struct { GPU bool `yaml:"gpu"` Timeout int `yaml:"timeout"` RequiresEnv []string `yaml:"requires_env"` + RequiresTools []string `yaml:"requires_tools"` Env map[string]string `yaml:"env"` SDKVersion string `yaml:"sdk_version"` CogYAMLOverrides map[string]any `yaml:"cog_yaml_overrides"` diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index e65516b8f5..d1e29f8143 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -424,11 +424,32 @@ func (r *Runner) prepareModel(ctx context.Context, model manifest.Model) (string return modelDir, nil } +// checkRequiredTools verifies that all tools listed in requires_tools are +// available on PATH. Returns a descriptive error listing missing tools and +// install hints when possible. +func checkRequiredTools(tools []string) error { + var missing []string + for _, tool := range tools { + if _, err := exec.LookPath(tool); err != nil { + missing = append(missing, tool) + } + } + if len(missing) == 0 { + return nil + } + return fmt.Errorf("required tool(s) not found on PATH: %s", strings.Join(missing, ", ")) +} + // runSetupCommands executes the model's setup commands in the model directory. // Setup commands run after clone/copy but before cog.yaml validation and patching. // This is used for models that need preparation steps like generating cog.yaml // from templates (e.g. "script/select.sh dev" in replicate/cog-flux). func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model manifest.Model) error { + // Check required tools before running any setup commands + if err := checkRequiredTools(model.RequiresTools); err != nil { + return err + } + for _, cmdStr := range model.Setup { if !r.opts.Quiet { fmt.Printf(" Running setup: %s\n", cmdStr) @@ -443,19 +464,18 @@ func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model ma for k, v := range model.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, os.ExpandEnv(v))) } + // Always capture output for error reporting. In non-quiet mode, + // also stream to stdout/stderr so the user sees progress. + var outputBuf bytes.Buffer if r.opts.Quiet { - var outputBuf bytes.Buffer cmd.Stdout = &outputBuf cmd.Stderr = &outputBuf - if err := cmd.Run(); err != nil { - return fmt.Errorf("setup command %q failed: %w\n%s", cmdStr, err, outputBuf.String()) - } } else { - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("setup command %q failed: %w", cmdStr, err) - } + cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf) + cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf) + } + if err := cmd.Run(); err != nil { + return fmt.Errorf("setup command %q failed: %w\n%s", cmdStr, err, outputBuf.String()) } } return nil diff --git a/tools/test-harness/manifest-replicate.yaml b/tools/test-harness/manifest-replicate.yaml new file mode 100644 index 0000000000..9d50016bd4 --- /dev/null +++ b/tools/test-harness/manifest-replicate.yaml @@ -0,0 +1,392 @@ +# Replicate Model Test Manifest +# ============================== +# Models from models.tsv — the top Replicate models and their cogified repos. +# +# Input values prefixed with "@" are resolved as fixture file paths relative +# to the fixtures/ directory (e.g. "@test_image.png" -> fixtures/test_image.png). +# +# Many of these models require GPU hardware with significant VRAM (40GB+). +# Proxy models (llama, alibaba) require API keys and run on CPU. +# +# Validation types: +# exact - output string must equal `value` exactly +# contains - output string must contain `value` as a substring +# regex - output string must match `pattern` +# file_exists - output is a file path; optionally check `mime` type +# json_match - parse output as JSON, assert `match` is a subset +# json_keys - parse output as JSON dict, assert it has entries +# not_empty - output is non-empty (loose smoke test) + +defaults: + sdk_version: "latest" + cog_version: "latest" + +models: + + # ── FLUX models (cog-flux) ────────────────────────────────────────── + # All FLUX variants live in replicate/cog-flux and are selected via + # script/select.sh which merges cog.yaml.template with + # model-cog-configs/.yaml to produce cog.yaml. + + - name: flux-schnell + repo: replicate/cog-flux + path: "." + gpu: true + timeout: 900 + requires_tools: ["yq"] + setup: + - "script/select.sh schnell" + tests: + - description: "generate image from prompt" + inputs: + prompt: "a golden retriever sitting in a field of sunflowers" + num_inference_steps: 4 + output_format: "png" + expect: + type: file_exists + mime: "image/png" + + - name: flux-dev-lora + repo: replicate/cog-flux + path: "." + gpu: true + timeout: 900 + requires_tools: ["yq"] + setup: + - "script/select.sh dev-lora" + tests: + - description: "generate image with prompt (no LoRA)" + inputs: + prompt: "a cat wearing a tiny top hat, studio lighting" + num_inference_steps: 28 + guidance: 3.0 + output_format: "png" + expect: + type: file_exists + mime: "image/png" + + - name: flux-fill-dev + repo: replicate/cog-flux + path: "." + gpu: true + timeout: 900 + requires_tools: ["yq"] + setup: + - "script/select.sh fill-dev" + tests: + - description: "inpaint region of image" + inputs: + prompt: "a bright red door" + image: "@test_image.png" + num_inference_steps: 28 + guidance: 30 + output_format: "png" + expect: + type: file_exists + mime: "image/png" + + - name: flux-depth-dev + repo: replicate/cog-flux + path: "." + gpu: true + timeout: 900 + requires_tools: ["yq"] + setup: + - "script/select.sh depth-dev" + tests: + - description: "depth-conditioned generation" + inputs: + prompt: "a futuristic cityscape at sunset" + control_image: "@test_image.png" + num_inference_steps: 28 + guidance: 10 + output_format: "png" + expect: + type: file_exists + mime: "image/png" + + - name: flux-schnell-lora + repo: replicate/cog-flux + path: "." + gpu: true + timeout: 900 + requires_tools: ["yq"] + setup: + - "script/select.sh schnell-lora" + tests: + - description: "generate image with schnell (no LoRA)" + inputs: + prompt: "a minimalist logo of a mountain, vector art" + num_inference_steps: 4 + output_format: "png" + expect: + type: file_exists + mime: "image/png" + + - name: flux-canny-dev + repo: replicate/cog-flux + path: "." + gpu: true + timeout: 900 + requires_tools: ["yq"] + setup: + - "script/select.sh canny-dev" + tests: + - description: "canny edge conditioned generation" + inputs: + prompt: "a pencil sketch of a building" + control_image: "@test_image.png" + num_inference_steps: 28 + guidance: 30 + output_format: "png" + expect: + type: file_exists + mime: "image/png" + + - name: flux-redux-dev + repo: replicate/cog-flux + path: "." + gpu: true + timeout: 900 + requires_tools: ["yq"] + setup: + - "script/select.sh redux-dev" + tests: + - description: "redux image variation" + inputs: + redux_image: "@test_image.png" + num_inference_steps: 28 + guidance: 3 + output_format: "png" + expect: + type: file_exists + mime: "image/png" + + # ── FLUX Kontext LoRA (diy-flux-lora) ─────────────────────────────── + # Training/inference repo for FLUX LoRA fine-tuning. Uses + # script/select-model to generate cog.yaml from template. + + - name: flux-kontext-dev-lora + repo: replicate/diy-flux-lora + path: "." + gpu: true + timeout: 900 + requires_tools: ["envsubst"] + setup: + - "script/select-model kontext-dev" + train_tests: + - description: "train a LoRA" + inputs: + input_images: "@test_image.png" + expect: + type: not_empty + tests: + - description: "generate with kontext (no LoRA)" + inputs: + prompt: "a watercolor painting of a forest" + expect: + type: file_exists + + # ── Qwen Edit (cog-qwen-edit-2509-multi-angle) ───────────────────── + # Direct cog.yaml, no setup script needed. + + - name: qwen-edit-multiangle + repo: replicate/cog-qwen-edit-2509-multi-angle + path: "." + gpu: true + timeout: 600 + tests: + - description: "rotate camera angle on image" + inputs: + image: "@test_image.png" + rotate_degrees: 15 + expect: + type: file_exists + + # ── Real-ESRGAN (cog-official-nightmareai-real-esrgan) ────────────── + # Image upscaling. Direct cog.yaml, no setup needed. + + - name: real-esrgan + repo: replicate/cog-official-nightmareai-real-esrgan + path: "." + gpu: true + timeout: 300 + tests: + - description: "upscale image 4x" + inputs: + image: "@test_image.png" + scale: 4 + face_enhance: false + expect: + type: file_exists + + # ── Stable Diffusion 3 (cog-stable-diffusion-3) ──────────────────── + # Direct cog.yaml, no setup needed. + + - name: stable-diffusion-3 + repo: replicate/cog-stable-diffusion-3 + path: "." + gpu: true + timeout: 600 + tests: + - description: "text-to-image generation" + inputs: + prompt: "a photorealistic image of an astronaut riding a horse on mars" + aspect_ratio: "1:1" + num_outputs: 1 + guidance_scale: 4.5 + output_format: "png" + expect: + type: file_exists + mime: "image/png" + + # ── Wan 2.1 1.3B (cog-wan-2.1) ───────────────────────────────────── + # Text-to-video model. Uses predict-1.3b.py directly via cog.yaml. + + - name: wan-2.1-1.3b + repo: replicate/cog-wan-2.1 + path: "." + gpu: true + timeout: 900 + tests: + - description: "text-to-video generation" + inputs: + prompt: "a timelapse of clouds moving across a blue sky" + aspect_ratio: "16:9" + frame_num: 17 + expect: + type: file_exists + + # ── Resemble AI Chatterbox (cog-resemble-chatterbox) ──────────────── + # TTS model. Direct cog.yaml, no setup needed. + + - name: chatterbox + repo: replicate/cog-resemble-chatterbox + path: "." + gpu: true + timeout: 300 + tests: + - description: "text-to-speech synthesis" + inputs: + prompt: "Hello, this is a test of the Chatterbox text to speech model." + exaggeration: 0.5 + cfg_weight: 0.5 + temperature: 0.8 + expect: + type: file_exists + + # ── Resemble AI Chatterbox Turbo (cog-resemble-chatterbox-turbo) ──── + # Faster TTS variant. Direct cog.yaml, no setup needed. + + - name: chatterbox-turbo + repo: replicate/cog-resemble-chatterbox-turbo + path: "." + gpu: true + timeout: 300 + tests: + - description: "fast text-to-speech synthesis" + inputs: + text: "Hello, this is a test of the Chatterbox Turbo text to speech model." + temperature: 0.8 + seed: 42 + expect: + type: file_exists + + # ── Qwen3 TTS (cog-qwen-tts) ─────────────────────────────────────── + # TTS model with multiple modes. Uses script/select-model. + + - name: qwen3-tts + repo: replicate/cog-qwen-tts + path: "." + gpu: true + timeout: 600 + requires_tools: ["envsubst"] + setup: + - "script/select-model qwen3-tts" + tests: + - description: "text-to-speech with preset voice" + inputs: + text: "The quick brown fox jumps over the lazy dog." + mode: "custom_voice" + language: "auto" + speaker: "Serena" + expect: + type: file_exists + + # ── Proxy models (CPU, require API keys) ──────────────────────────── + # These are thin wrappers around external APIs (Groq, Dashscope, etc.). + # They require API keys written to .proxy-api-key and similar files. + + # - name: meta-llama-3-70b-instruct + # repo: replicate/cog-llama-proxy + # path: "." + # gpu: false + # timeout: 300 + # setup: + # # NOTE: select-model currently only supports 405b in the script. + # # The 70b model uses MetaLlama370bInstructPredictor (backed by Groq). + # # You may need to manually generate cog.yaml or extend select-model. + # - "PREDICTOR=MetaLlama370bInstructPredictor envsubst < cog.yaml.tpl > cog.yaml" + # requires_env: + # - GROQ_API_KEY + # tests: + # - description: "basic chat completion" + # inputs: + # prompt: "What is the capital of France? Answer in one word." + # max_tokens: 10 + # temperature: 0.1 + # expect: + # type: contains + # value: "Paris" + + # - name: qwen-image-2-pro + # repo: replicate/cog-alibaba-proxy + # path: "." + # gpu: false + # timeout: 300 + # setup: + # - "script/select-model qwen-image-2-pro" + # requires_env: + # - DASHSCOPE_API_KEY + # tests: + # - description: "text-to-image generation" + # inputs: + # prompt: "a serene japanese garden with a koi pond, watercolor style" + # aspect_ratio: "1:1" + # expect: + # type: file_exists + + # - name: qwen-image-2 + # repo: replicate/cog-alibaba-proxy + # path: "." + # gpu: false + # timeout: 300 + # setup: + # - "script/select-model qwen-image-2" + # requires_env: + # - DASHSCOPE_API_KEY + # tests: + # - description: "text-to-image generation" + # inputs: + # prompt: "a cozy cabin in the snowy mountains at night" + # aspect_ratio: "16:9" + # expect: + # type: file_exists + + # ── Models without identified repos ───────────────────────────────── + # These models could not be matched to a repo in the replicate org. + # They are listed here as placeholders for future discovery. + + # - name: nsfw-image-detection + # # falcons-ai/nsfw_image_detection — no repo found in replicate org + # repo: unknown + # path: "." + # gpu: true + # tests: [] + + # - name: meta-llama-3-8b-instruct + # # meta/meta-llama-3-8b-instruct — repo not confirmed + # repo: unknown + # path: "." + # gpu: true + # tests: [] diff --git a/tools/test-harness/manifest.yaml b/tools/test-harness/manifest.yaml index bdd8870023..b72cf62749 100644 --- a/tools/test-harness/manifest.yaml +++ b/tools/test-harness/manifest.yaml @@ -21,6 +21,13 @@ # Example: # setup: # - "script/select.sh dev" +# +# Required tools: +# Optional `requires_tools` list of CLI tools that must be on PATH for +# setup commands to work. The harness checks before running and prints +# install instructions for known tools (yq, envsubst). +# Example: +# requires_tools: ["yq"] defaults: sdk_version: "latest" # "latest" = newest stable from PyPI; or pin e.g. "0.16.12" From 036a83f95d86733d2975b045f2e8e018e5e7d843 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 12:09:07 -0400 Subject: [PATCH 04/13] feat(test-harness): stream logs with model name prefix in parallel mode Replace the quiet/buffered approach with a prefixWriter that prepends each output line with the model name, like docker-compose: [flux-schnell ] Step 3/12: RUN pip install ... [chatterbox ] Building image cog-harness-chatterbox:test [flux-schnell ] Successfully built abc123 Output streams in real-time even in parallel mode. Sequential mode (--concurrency 1) streams without prefixes as before. --- tools/test-harness/cmd/build.go | 2 +- tools/test-harness/cmd/run.go | 2 +- tools/test-harness/cmd/schema_compare.go | 2 +- tools/test-harness/internal/runner/runner.go | 109 ++++++++++++++----- 4 files changed, 84 insertions(+), 31 deletions(-) diff --git a/tools/test-harness/cmd/build.go b/tools/test-harness/cmd/build.go index cba430a145..c4e8b84e31 100644 --- a/tools/test-harness/cmd/build.go +++ b/tools/test-harness/cmd/build.go @@ -49,7 +49,7 @@ func runBuild(ctx context.Context) error { SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, - Quiet: parallel, + Parallel: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) diff --git a/tools/test-harness/cmd/run.go b/tools/test-harness/cmd/run.go index bea4769609..7b18be6c20 100644 --- a/tools/test-harness/cmd/run.go +++ b/tools/test-harness/cmd/run.go @@ -68,7 +68,7 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, - Quiet: parallel, + Parallel: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) diff --git a/tools/test-harness/cmd/schema_compare.go b/tools/test-harness/cmd/schema_compare.go index 8a9394008a..14b68e595a 100644 --- a/tools/test-harness/cmd/schema_compare.go +++ b/tools/test-harness/cmd/schema_compare.go @@ -63,7 +63,7 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, - Quiet: parallel, + Parallel: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index d1e29f8143..915fd31c80 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -27,6 +27,69 @@ import ( const openapiSchemaLabel = "run.cog.openapi_schema" +// prefixWriter wraps an io.Writer and prepends a prefix to each line. +// Partial lines (no trailing newline) are buffered until a newline arrives. +type prefixWriter struct { + prefix string + dest io.Writer + mu sync.Mutex + buf []byte +} + +func newPrefixWriter(dest io.Writer, modelName string) *prefixWriter { + return &prefixWriter{ + prefix: fmt.Sprintf("[%-20s] ", modelName), + dest: dest, + } +} + +func (pw *prefixWriter) Write(p []byte) (int, error) { + pw.mu.Lock() + defer pw.mu.Unlock() + + total := len(p) + pw.buf = append(pw.buf, p...) + + for { + idx := bytes.IndexByte(pw.buf, '\n') + if idx < 0 { + break + } + line := pw.buf[:idx] + pw.buf = pw.buf[idx+1:] + if _, err := fmt.Fprintf(pw.dest, "%s%s\n", pw.prefix, line); err != nil { + return total, err + } + } + return total, nil +} + +// Flush writes any remaining buffered content (partial line without trailing newline). +func (pw *prefixWriter) Flush() { + pw.mu.Lock() + defer pw.mu.Unlock() + + if len(pw.buf) > 0 { + _, _ = fmt.Fprintf(pw.dest, "%s%s\n", pw.prefix, pw.buf) + pw.buf = nil + } +} + +// modelOutput returns stdout/stderr writers for a model. +// In parallel mode, output is prefixed with the model name and +// also captured in a buffer for error reporting. +// In sequential mode, output streams directly to the terminal +// and is also captured. +func (r *Runner) modelOutput(modelName string) (stdout, stderr io.Writer, capture *bytes.Buffer, flush func()) { + var buf bytes.Buffer + if r.opts.Parallel { + pw := newPrefixWriter(os.Stderr, modelName) + w := io.MultiWriter(pw, &buf) + return w, w, &buf, pw.Flush + } + return io.MultiWriter(os.Stdout, &buf), io.MultiWriter(os.Stderr, &buf), &buf, func() {} +} + // Options configures a Runner. type Options struct { CogBinary string @@ -35,7 +98,7 @@ type Options struct { FixturesDir string CleanImages bool KeepOutputs bool - Quiet bool // Suppress real-time build output (for parallel execution) + Parallel bool // Prefix output lines with model name (for parallel execution) } // Runner orchestrates the test lifecycle. @@ -450,10 +513,10 @@ func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model ma return err } + stdout, stderr, capture, flush := r.modelOutput(model.Name) + for _, cmdStr := range model.Setup { - if !r.opts.Quiet { - fmt.Printf(" Running setup: %s\n", cmdStr) - } + _, _ = fmt.Fprintf(stderr, " Running setup: %s\n", cmdStr) // Use bash with strict mode so any failing command in the // script (e.g. a missing yq binary) is caught immediately // rather than silently producing an empty/invalid cog.yaml. @@ -464,20 +527,14 @@ func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model ma for k, v := range model.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, os.ExpandEnv(v))) } - // Always capture output for error reporting. In non-quiet mode, - // also stream to stdout/stderr so the user sees progress. - var outputBuf bytes.Buffer - if r.opts.Quiet { - cmd.Stdout = &outputBuf - cmd.Stderr = &outputBuf - } else { - cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf) - cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf) - } + cmd.Stdout = stdout + cmd.Stderr = stderr if err := cmd.Run(); err != nil { - return fmt.Errorf("setup command %q failed: %w\n%s", cmdStr, err, outputBuf.String()) + flush() + return fmt.Errorf("setup command %q failed: %w\n%s", cmdStr, err, capture.String()) } } + flush() return nil } @@ -545,21 +602,17 @@ func (r *Runner) buildModelWithEnv(ctx context.Context, modelDir string, model m cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) } - // Stream build output to stderr in real-time so the user can see progress, + // Stream build output in real-time so the user can see progress, // while also capturing it for error reporting if the build fails. - // When running in parallel mode (opts.Quiet), only capture output and - // include it in error messages to avoid interleaved output. - var outputBuf bytes.Buffer - if r.opts.Quiet { - cmd.Stdout = &outputBuf - cmd.Stderr = &outputBuf - } else { - cmd.Stdout = io.MultiWriter(os.Stderr, &outputBuf) - cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf) - } - if err := cmd.Run(); err != nil { + // In parallel mode, each line is prefixed with the model name. + _, stderr, capture, flush := r.modelOutput(model.Name) + cmd.Stdout = stderr + cmd.Stderr = stderr + err := cmd.Run() + flush() + if err != nil { // Include the last portion of build output for context. - output := outputBuf.String() + output := capture.String() const maxTail = 2000 if len(output) > maxTail { output = "...\n" + output[len(output)-maxTail:] From a9d7ddabb57c75528f5f07e3830fba0f31d04275 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 12:25:45 -0400 Subject: [PATCH 05/13] feat(test-harness): add timing output between build and test steps Print === banners with timing info at each phase boundary so users can see where time is being spent: === Preparing flux-schnell... === Building flux-schnell (timeout 900s)... === Build complete (62.3s) === Predict test 1/1: generate image from prompt (timeout 900s)... === Predict test 1/1 PASSED (48.7s) In parallel mode these are prefixed with the model name like all other output. --- tools/test-harness/internal/runner/runner.go | 42 ++++++++++++++++++-- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index 915fd31c80..f34f3821c6 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -212,6 +212,8 @@ func (r *Runner) Cleanup() error { // RunModel runs all tests for a single model func (r *Runner) RunModel(ctx context.Context, model manifest.Model) *report.ModelResult { + _, logw, _, flush := r.modelOutput(model.Name) + result := &report.ModelResult{ Name: model.Name, Passed: true, @@ -229,46 +231,71 @@ func (r *Runner) RunModel(ctx context.Context, model manifest.Model) *report.Mod } // Prepare model + _, _ = fmt.Fprintf(logw, "=== Preparing %s...\n", model.Name) modelDir, err := r.prepareModel(ctx, model) if err != nil { result.Passed = false result.Error = fmt.Sprintf("Preparation failed: %v", err) + flush() return result } // Build + _, _ = fmt.Fprintf(logw, "=== Building %s (timeout %ds)...\n", model.Name, model.Timeout) buildStart := time.Now() if err := r.buildModel(ctx, modelDir, model); err != nil { result.Passed = false result.BuildDuration = time.Since(buildStart).Seconds() result.Error = fmt.Sprintf("Build failed: %v", err) + _, _ = fmt.Fprintf(logw, "=== Build FAILED after %.1fs\n", result.BuildDuration) + flush() return result } result.BuildDuration = time.Since(buildStart).Seconds() + _, _ = fmt.Fprintf(logw, "=== Build complete (%.1fs)\n", result.BuildDuration) // Run train tests - for _, tc := range model.TrainTests { + for i, tc := range model.TrainTests { + desc := tc.Description + if desc == "" { + desc = "train" + } + _, _ = fmt.Fprintf(logw, "=== Train test %d/%d: %s (timeout %ds)...\n", i+1, len(model.TrainTests), desc, model.Timeout) tr := r.runTrainTest(ctx, modelDir, model, tc) result.TrainResults = append(result.TrainResults, tr) - if !tr.Passed { + if tr.Passed { + _, _ = fmt.Fprintf(logw, "=== Train test %d/%d PASSED (%.1fs)\n", i+1, len(model.TrainTests), tr.DurationSec) + } else { + _, _ = fmt.Fprintf(logw, "=== Train test %d/%d FAILED (%.1fs)\n", i+1, len(model.TrainTests), tr.DurationSec) result.Passed = false } } // Run predict tests - for _, tc := range model.Tests { + for i, tc := range model.Tests { + desc := tc.Description + if desc == "" { + desc = "predict" + } + _, _ = fmt.Fprintf(logw, "=== Predict test %d/%d: %s (timeout %ds)...\n", i+1, len(model.Tests), desc, model.Timeout) tr := r.runPredictTest(ctx, modelDir, model, tc) result.TestResults = append(result.TestResults, tr) - if !tr.Passed { + if tr.Passed { + _, _ = fmt.Fprintf(logw, "=== Predict test %d/%d PASSED (%.1fs)\n", i+1, len(model.Tests), tr.DurationSec) + } else { + _, _ = fmt.Fprintf(logw, "=== Predict test %d/%d FAILED (%.1fs)\n", i+1, len(model.Tests), tr.DurationSec) result.Passed = false } } + flush() return result } // BuildModel builds a model image only func (r *Runner) BuildModel(ctx context.Context, model manifest.Model) *report.ModelResult { + _, logw, _, flush := r.modelOutput(model.Name) + result := &report.ModelResult{ Name: model.Name, Passed: true, @@ -286,23 +313,30 @@ func (r *Runner) BuildModel(ctx context.Context, model manifest.Model) *report.M } // Prepare model + _, _ = fmt.Fprintf(logw, "=== Preparing %s...\n", model.Name) modelDir, err := r.prepareModel(ctx, model) if err != nil { result.Passed = false result.Error = fmt.Sprintf("Preparation failed: %v", err) + flush() return result } // Build + _, _ = fmt.Fprintf(logw, "=== Building %s (timeout %ds)...\n", model.Name, model.Timeout) buildStart := time.Now() if err := r.buildModel(ctx, modelDir, model); err != nil { result.Passed = false result.BuildDuration = time.Since(buildStart).Seconds() result.Error = fmt.Sprintf("Build failed: %v", err) + _, _ = fmt.Fprintf(logw, "=== Build FAILED after %.1fs\n", result.BuildDuration) + flush() return result } result.BuildDuration = time.Since(buildStart).Seconds() + _, _ = fmt.Fprintf(logw, "=== Build complete (%.1fs)\n", result.BuildDuration) + flush() return result } From 80b3c8fcae7943726e2e6fa7f467292176228079 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 14:38:22 -0400 Subject: [PATCH 06/13] feat: add --skip-schema-validation flag to cog build Expose the existing SkipSchemaValidation option as a CLI flag. This is useful when building models that fail schema validation due to import-time side effects (e.g. accessing GPU/weights files) or when using a remote Docker context where the legacy schema validation container cannot access local resources. Also add skip_schema_validation field to the test harness manifest so models can opt out of schema validation per-model. --- pkg/cli/build.go | 33 +++++++++++-------- .../internal/manifest/manifest.go | 27 +++++++-------- tools/test-harness/internal/runner/runner.go | 6 +++- 3 files changed, 39 insertions(+), 27 deletions(-) diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 490ab585f3..577371749e 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -26,6 +26,7 @@ var buildDockerfileFile string var buildUseCogBaseImage bool var buildStrip bool var buildPrecompile bool +var buildSkipSchemaValidation bool var configFilename string const useCogBaseImageFlagKey = "use-cog-base-image" @@ -66,6 +67,7 @@ with 'cog push'.`, addStripFlag(cmd) addPrecompileFlag(cmd) addConfigFlag(cmd) + addSkipSchemaValidationFlag(cmd) cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'") return cmd } @@ -194,22 +196,27 @@ func DetermineUseCogBaseImage(cmd *cobra.Command) *bool { return useCogBaseImage } +func addSkipSchemaValidationFlag(cmd *cobra.Command) { + cmd.Flags().BoolVar(&buildSkipSchemaValidation, "skip-schema-validation", false, "Skip OpenAPI schema generation and validation") +} + // buildOptionsFromFlags creates BuildOptions from the current CLI flag values. // The imageName and annotations parameters vary by command and must be provided. func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map[string]string) model.BuildOptions { return model.BuildOptions{ - ImageName: imageName, - Secrets: buildSecrets, - NoCache: buildNoCache, - SeparateWeights: buildSeparateWeights, - UseCudaBaseImage: buildUseCudaBaseImage, - ProgressOutput: buildProgressOutput, - SchemaFile: buildSchemaFile, - DockerfileFile: buildDockerfileFile, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - Strip: buildStrip, - Precompile: buildPrecompile, - Annotations: annotations, - OCIIndex: model.OCIIndexEnabled(), + ImageName: imageName, + Secrets: buildSecrets, + NoCache: buildNoCache, + SeparateWeights: buildSeparateWeights, + UseCudaBaseImage: buildUseCudaBaseImage, + ProgressOutput: buildProgressOutput, + SchemaFile: buildSchemaFile, + DockerfileFile: buildDockerfileFile, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + Strip: buildStrip, + Precompile: buildPrecompile, + Annotations: annotations, + OCIIndex: model.OCIIndexEnabled(), + SkipSchemaValidation: buildSkipSchemaValidation, } } diff --git a/tools/test-harness/internal/manifest/manifest.go b/tools/test-harness/internal/manifest/manifest.go index d894996c20..8be4a3ef8e 100644 --- a/tools/test-harness/internal/manifest/manifest.go +++ b/tools/test-harness/internal/manifest/manifest.go @@ -23,19 +23,20 @@ type Defaults struct { // Model represents a single model definition type Model struct { - Name string `yaml:"name"` - Repo string `yaml:"repo"` - Path string `yaml:"path"` - GPU bool `yaml:"gpu"` - Timeout int `yaml:"timeout"` - RequiresEnv []string `yaml:"requires_env"` - RequiresTools []string `yaml:"requires_tools"` - Env map[string]string `yaml:"env"` - SDKVersion string `yaml:"sdk_version"` - CogYAMLOverrides map[string]any `yaml:"cog_yaml_overrides"` - Setup []string `yaml:"setup"` - Tests []TestCase `yaml:"tests"` - TrainTests []TestCase `yaml:"train_tests"` + Name string `yaml:"name"` + Repo string `yaml:"repo"` + Path string `yaml:"path"` + GPU bool `yaml:"gpu"` + Timeout int `yaml:"timeout"` + RequiresEnv []string `yaml:"requires_env"` + RequiresTools []string `yaml:"requires_tools"` + Env map[string]string `yaml:"env"` + SDKVersion string `yaml:"sdk_version"` + CogYAMLOverrides map[string]any `yaml:"cog_yaml_overrides"` + Setup []string `yaml:"setup"` + SkipSchemaValidation bool `yaml:"skip_schema_validation"` + Tests []TestCase `yaml:"tests"` + TrainTests []TestCase `yaml:"train_tests"` } // TestCase represents a single test case diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index f34f3821c6..24c3c728f0 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -611,7 +611,11 @@ func (r *Runner) buildModelWithEnv(ctx context.Context, modelDir string, model m ctx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, r.opts.CogBinary, "build", "-t", imageTag) + buildArgs := []string{"build", "-t", imageTag} + if model.SkipSchemaValidation { + buildArgs = append(buildArgs, "--skip-schema-validation") + } + cmd := exec.CommandContext(ctx, r.opts.CogBinary, buildArgs...) cmd.Dir = modelDir cmd.Env = os.Environ() if r.opts.SDKWheel != "" { From b71c93380f014d4ffbbccb8058886604b24b5f27 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 15:53:46 -0400 Subject: [PATCH 07/13] chore: rm manifest-replicate Signed-off-by: Mark Phelps --- go.mod | 2 +- tools/test-harness/manifest-replicate.yaml | 392 --------------------- 2 files changed, 1 insertion(+), 393 deletions(-) delete mode 100644 tools/test-harness/manifest-replicate.yaml diff --git a/go.mod b/go.mod index 84a44f1b45..9f4e907903 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( golang.org/x/sys v0.42.0 golang.org/x/term v0.41.0 google.golang.org/grpc v1.79.3 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -148,7 +149,6 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/protobuf v1.36.11 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect gotest.tools/gotestsum v1.12.2 // indirect ) diff --git a/tools/test-harness/manifest-replicate.yaml b/tools/test-harness/manifest-replicate.yaml deleted file mode 100644 index 9d50016bd4..0000000000 --- a/tools/test-harness/manifest-replicate.yaml +++ /dev/null @@ -1,392 +0,0 @@ -# Replicate Model Test Manifest -# ============================== -# Models from models.tsv — the top Replicate models and their cogified repos. -# -# Input values prefixed with "@" are resolved as fixture file paths relative -# to the fixtures/ directory (e.g. "@test_image.png" -> fixtures/test_image.png). -# -# Many of these models require GPU hardware with significant VRAM (40GB+). -# Proxy models (llama, alibaba) require API keys and run on CPU. -# -# Validation types: -# exact - output string must equal `value` exactly -# contains - output string must contain `value` as a substring -# regex - output string must match `pattern` -# file_exists - output is a file path; optionally check `mime` type -# json_match - parse output as JSON, assert `match` is a subset -# json_keys - parse output as JSON dict, assert it has entries -# not_empty - output is non-empty (loose smoke test) - -defaults: - sdk_version: "latest" - cog_version: "latest" - -models: - - # ── FLUX models (cog-flux) ────────────────────────────────────────── - # All FLUX variants live in replicate/cog-flux and are selected via - # script/select.sh which merges cog.yaml.template with - # model-cog-configs/.yaml to produce cog.yaml. - - - name: flux-schnell - repo: replicate/cog-flux - path: "." - gpu: true - timeout: 900 - requires_tools: ["yq"] - setup: - - "script/select.sh schnell" - tests: - - description: "generate image from prompt" - inputs: - prompt: "a golden retriever sitting in a field of sunflowers" - num_inference_steps: 4 - output_format: "png" - expect: - type: file_exists - mime: "image/png" - - - name: flux-dev-lora - repo: replicate/cog-flux - path: "." - gpu: true - timeout: 900 - requires_tools: ["yq"] - setup: - - "script/select.sh dev-lora" - tests: - - description: "generate image with prompt (no LoRA)" - inputs: - prompt: "a cat wearing a tiny top hat, studio lighting" - num_inference_steps: 28 - guidance: 3.0 - output_format: "png" - expect: - type: file_exists - mime: "image/png" - - - name: flux-fill-dev - repo: replicate/cog-flux - path: "." - gpu: true - timeout: 900 - requires_tools: ["yq"] - setup: - - "script/select.sh fill-dev" - tests: - - description: "inpaint region of image" - inputs: - prompt: "a bright red door" - image: "@test_image.png" - num_inference_steps: 28 - guidance: 30 - output_format: "png" - expect: - type: file_exists - mime: "image/png" - - - name: flux-depth-dev - repo: replicate/cog-flux - path: "." - gpu: true - timeout: 900 - requires_tools: ["yq"] - setup: - - "script/select.sh depth-dev" - tests: - - description: "depth-conditioned generation" - inputs: - prompt: "a futuristic cityscape at sunset" - control_image: "@test_image.png" - num_inference_steps: 28 - guidance: 10 - output_format: "png" - expect: - type: file_exists - mime: "image/png" - - - name: flux-schnell-lora - repo: replicate/cog-flux - path: "." - gpu: true - timeout: 900 - requires_tools: ["yq"] - setup: - - "script/select.sh schnell-lora" - tests: - - description: "generate image with schnell (no LoRA)" - inputs: - prompt: "a minimalist logo of a mountain, vector art" - num_inference_steps: 4 - output_format: "png" - expect: - type: file_exists - mime: "image/png" - - - name: flux-canny-dev - repo: replicate/cog-flux - path: "." - gpu: true - timeout: 900 - requires_tools: ["yq"] - setup: - - "script/select.sh canny-dev" - tests: - - description: "canny edge conditioned generation" - inputs: - prompt: "a pencil sketch of a building" - control_image: "@test_image.png" - num_inference_steps: 28 - guidance: 30 - output_format: "png" - expect: - type: file_exists - mime: "image/png" - - - name: flux-redux-dev - repo: replicate/cog-flux - path: "." - gpu: true - timeout: 900 - requires_tools: ["yq"] - setup: - - "script/select.sh redux-dev" - tests: - - description: "redux image variation" - inputs: - redux_image: "@test_image.png" - num_inference_steps: 28 - guidance: 3 - output_format: "png" - expect: - type: file_exists - mime: "image/png" - - # ── FLUX Kontext LoRA (diy-flux-lora) ─────────────────────────────── - # Training/inference repo for FLUX LoRA fine-tuning. Uses - # script/select-model to generate cog.yaml from template. - - - name: flux-kontext-dev-lora - repo: replicate/diy-flux-lora - path: "." - gpu: true - timeout: 900 - requires_tools: ["envsubst"] - setup: - - "script/select-model kontext-dev" - train_tests: - - description: "train a LoRA" - inputs: - input_images: "@test_image.png" - expect: - type: not_empty - tests: - - description: "generate with kontext (no LoRA)" - inputs: - prompt: "a watercolor painting of a forest" - expect: - type: file_exists - - # ── Qwen Edit (cog-qwen-edit-2509-multi-angle) ───────────────────── - # Direct cog.yaml, no setup script needed. - - - name: qwen-edit-multiangle - repo: replicate/cog-qwen-edit-2509-multi-angle - path: "." - gpu: true - timeout: 600 - tests: - - description: "rotate camera angle on image" - inputs: - image: "@test_image.png" - rotate_degrees: 15 - expect: - type: file_exists - - # ── Real-ESRGAN (cog-official-nightmareai-real-esrgan) ────────────── - # Image upscaling. Direct cog.yaml, no setup needed. - - - name: real-esrgan - repo: replicate/cog-official-nightmareai-real-esrgan - path: "." - gpu: true - timeout: 300 - tests: - - description: "upscale image 4x" - inputs: - image: "@test_image.png" - scale: 4 - face_enhance: false - expect: - type: file_exists - - # ── Stable Diffusion 3 (cog-stable-diffusion-3) ──────────────────── - # Direct cog.yaml, no setup needed. - - - name: stable-diffusion-3 - repo: replicate/cog-stable-diffusion-3 - path: "." - gpu: true - timeout: 600 - tests: - - description: "text-to-image generation" - inputs: - prompt: "a photorealistic image of an astronaut riding a horse on mars" - aspect_ratio: "1:1" - num_outputs: 1 - guidance_scale: 4.5 - output_format: "png" - expect: - type: file_exists - mime: "image/png" - - # ── Wan 2.1 1.3B (cog-wan-2.1) ───────────────────────────────────── - # Text-to-video model. Uses predict-1.3b.py directly via cog.yaml. - - - name: wan-2.1-1.3b - repo: replicate/cog-wan-2.1 - path: "." - gpu: true - timeout: 900 - tests: - - description: "text-to-video generation" - inputs: - prompt: "a timelapse of clouds moving across a blue sky" - aspect_ratio: "16:9" - frame_num: 17 - expect: - type: file_exists - - # ── Resemble AI Chatterbox (cog-resemble-chatterbox) ──────────────── - # TTS model. Direct cog.yaml, no setup needed. - - - name: chatterbox - repo: replicate/cog-resemble-chatterbox - path: "." - gpu: true - timeout: 300 - tests: - - description: "text-to-speech synthesis" - inputs: - prompt: "Hello, this is a test of the Chatterbox text to speech model." - exaggeration: 0.5 - cfg_weight: 0.5 - temperature: 0.8 - expect: - type: file_exists - - # ── Resemble AI Chatterbox Turbo (cog-resemble-chatterbox-turbo) ──── - # Faster TTS variant. Direct cog.yaml, no setup needed. - - - name: chatterbox-turbo - repo: replicate/cog-resemble-chatterbox-turbo - path: "." - gpu: true - timeout: 300 - tests: - - description: "fast text-to-speech synthesis" - inputs: - text: "Hello, this is a test of the Chatterbox Turbo text to speech model." - temperature: 0.8 - seed: 42 - expect: - type: file_exists - - # ── Qwen3 TTS (cog-qwen-tts) ─────────────────────────────────────── - # TTS model with multiple modes. Uses script/select-model. - - - name: qwen3-tts - repo: replicate/cog-qwen-tts - path: "." - gpu: true - timeout: 600 - requires_tools: ["envsubst"] - setup: - - "script/select-model qwen3-tts" - tests: - - description: "text-to-speech with preset voice" - inputs: - text: "The quick brown fox jumps over the lazy dog." - mode: "custom_voice" - language: "auto" - speaker: "Serena" - expect: - type: file_exists - - # ── Proxy models (CPU, require API keys) ──────────────────────────── - # These are thin wrappers around external APIs (Groq, Dashscope, etc.). - # They require API keys written to .proxy-api-key and similar files. - - # - name: meta-llama-3-70b-instruct - # repo: replicate/cog-llama-proxy - # path: "." - # gpu: false - # timeout: 300 - # setup: - # # NOTE: select-model currently only supports 405b in the script. - # # The 70b model uses MetaLlama370bInstructPredictor (backed by Groq). - # # You may need to manually generate cog.yaml or extend select-model. - # - "PREDICTOR=MetaLlama370bInstructPredictor envsubst < cog.yaml.tpl > cog.yaml" - # requires_env: - # - GROQ_API_KEY - # tests: - # - description: "basic chat completion" - # inputs: - # prompt: "What is the capital of France? Answer in one word." - # max_tokens: 10 - # temperature: 0.1 - # expect: - # type: contains - # value: "Paris" - - # - name: qwen-image-2-pro - # repo: replicate/cog-alibaba-proxy - # path: "." - # gpu: false - # timeout: 300 - # setup: - # - "script/select-model qwen-image-2-pro" - # requires_env: - # - DASHSCOPE_API_KEY - # tests: - # - description: "text-to-image generation" - # inputs: - # prompt: "a serene japanese garden with a koi pond, watercolor style" - # aspect_ratio: "1:1" - # expect: - # type: file_exists - - # - name: qwen-image-2 - # repo: replicate/cog-alibaba-proxy - # path: "." - # gpu: false - # timeout: 300 - # setup: - # - "script/select-model qwen-image-2" - # requires_env: - # - DASHSCOPE_API_KEY - # tests: - # - description: "text-to-image generation" - # inputs: - # prompt: "a cozy cabin in the snowy mountains at night" - # aspect_ratio: "16:9" - # expect: - # type: file_exists - - # ── Models without identified repos ───────────────────────────────── - # These models could not be matched to a repo in the replicate org. - # They are listed here as placeholders for future discovery. - - # - name: nsfw-image-detection - # # falcons-ai/nsfw_image_detection — no repo found in replicate org - # repo: unknown - # path: "." - # gpu: true - # tests: [] - - # - name: meta-llama-3-8b-instruct - # # meta/meta-llama-3-8b-instruct — repo not confirmed - # repo: unknown - # path: "." - # gpu: true - # tests: [] From 3fdde107b5e527d27f8dc4b48054f4ea93ca5064 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 15:23:16 -0400 Subject: [PATCH 08/13] fix(test-harness): pass --setup-timeout to cog predict/train cog predict has a default --setup-timeout of 300s which kills the container if model setup (weight downloads) takes longer. Pass the model's timeout value as the setup timeout so large models have enough time to download weights during first run. --- tools/test-harness/internal/runner/runner.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index 24c3c728f0..b3d1aa1b12 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -687,8 +687,15 @@ func (r *Runner) runCogTest(ctx context.Context, modelDir string, model manifest start := time.Now() - // Build command - args := []string{command} + // Set timeout + timeout := model.Timeout + if timeout == 0 { + timeout = 300 + } + + // Build command — pass setup-timeout matching the model timeout so + // cog predict doesn't kill the container during model weight downloads. + args := []string{command, "--setup-timeout", fmt.Sprintf("%d", timeout)} keys := make([]string, 0, len(tc.Inputs)) for k := range tc.Inputs { keys = append(keys, k) @@ -700,12 +707,6 @@ func (r *Runner) runCogTest(ctx context.Context, modelDir string, model manifest args = append(args, "-i", fmt.Sprintf("%s=%s", key, resolved)) } - // Set timeout - timeout := model.Timeout - if timeout == 0 { - timeout = 300 - } - ctx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) defer cancel() From e41286780487984ae3f726be29685826d9be5449 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 15 Apr 2026 17:30:08 -0400 Subject: [PATCH 09/13] fix(test-harness): add pass/fail feedback in sequential mode The sequential code paths (--concurrency 1) were missing per-model status output after each build/run/compare, unlike the parallel paths. Now both modes print consistent pass/fail/skip feedback. --- tools/test-harness/cmd/build.go | 8 ++++++++ tools/test-harness/cmd/run.go | 9 +++++++++ tools/test-harness/cmd/schema_compare.go | 5 +++++ 3 files changed, 22 insertions(+) diff --git a/tools/test-harness/cmd/build.go b/tools/test-harness/cmd/build.go index c4e8b84e31..8ba148362a 100644 --- a/tools/test-harness/cmd/build.go +++ b/tools/test-harness/cmd/build.go @@ -93,6 +93,14 @@ func runBuild(ctx context.Context) error { fmt.Printf("Building %s...\n", model.Name) result := r.BuildModel(ctx, model) results[i] = *result + switch { + case result.Passed: + fmt.Printf(" + %s built successfully (%.1fs)\n", model.Name, result.BuildDuration) + case result.Skipped: + fmt.Printf(" - %s (skipped: %s)\n", model.Name, result.SkipReason) + default: + fmt.Printf(" x %s FAILED\n", model.Name) + } } } diff --git a/tools/test-harness/cmd/run.go b/tools/test-harness/cmd/run.go index 7b18be6c20..351ca455e2 100644 --- a/tools/test-harness/cmd/run.go +++ b/tools/test-harness/cmd/run.go @@ -113,6 +113,15 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { fmt.Printf("Running %s...\n", model.Name) result := r.RunModel(ctx, model) results[i] = *result + switch { + case result.Skipped: + fmt.Printf(" - %s (skipped: %s)\n", model.Name, result.SkipReason) + case result.Passed: + testCount := len(result.TestResults) + len(result.TrainResults) + fmt.Printf(" + %s (%.1fs build, %d tests passed)\n", model.Name, result.BuildDuration, testCount) + default: + fmt.Printf(" x %s FAILED\n", model.Name) + } } } diff --git a/tools/test-harness/cmd/schema_compare.go b/tools/test-harness/cmd/schema_compare.go index 14b68e595a..7bfa9f6591 100644 --- a/tools/test-harness/cmd/schema_compare.go +++ b/tools/test-harness/cmd/schema_compare.go @@ -104,6 +104,11 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro fmt.Printf("Comparing %s...\n", model.Name) result := r.CompareSchema(ctx, model) results[i] = *result + if result.Passed { + fmt.Printf(" + %s schemas match\n", model.Name) + } else { + fmt.Printf(" x %s FAILED\n", model.Name) + } } } From cf7442c03cb1cd13e35cda1f90c72f4f9bde3e98 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 16 Apr 2026 10:35:52 -0400 Subject: [PATCH 10/13] chore: regen cli docs Signed-off-by: Mark Phelps --- docs/cli.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/cli.md b/docs/cli.md index 04446b2e29..0006b69056 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -62,6 +62,7 @@ cog build [flags] --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' --separate-weights Separate model weights from code in image layers + --skip-schema-validation Skip OpenAPI schema generation and validation -t, --tag string A name for the built image in the form 'repository:tag' --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") From f3ea12d429791d002e86bba417affd24ed41954d Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 16 Apr 2026 11:03:08 -0400 Subject: [PATCH 11/13] fix(test-harness): address review findings in parallel execution and runner - Replace errgroup.WithContext with WaitGroup.Go + semaphore channel; goroutines never returned errors so errgroup semantics were misleading and g.Wait() was silently discarded - Extract generic runModels[T] helper to deduplicate the parallel/sequential loop that was copy-pasted across build.go, run.go, and schema_compare.go - Use singleflight.Group in cloneRepo instead of holding a mutex during the entire git clone operation, which serialized all parallel clones - Capture cloneRepo stderr into a buffer instead of writing directly to os.Stderr, preventing interleaved output in parallel mode - Add --concurrency flag validation (reject values < 1 to prevent panic) - Split modelOutput into modelLoggers + modelOutput so RunModel/BuildModel don't allocate capture buffers they never use - Cap prefixWriter line buffer at 64 KiB to prevent unbounded growth from long lines without newlines (e.g. progress bars) - Truncate runSetupCommands error output to last 2000 bytes, matching the pattern already used in buildModelWithEnv --- tools/test-harness/cmd/build.go | 69 ++++++--------- tools/test-harness/cmd/root.go | 68 +++++++++++++++ tools/test-harness/cmd/run.go | 73 ++++++---------- tools/test-harness/cmd/schema_compare.go | 61 ++++++------- tools/test-harness/internal/runner/runner.go | 91 +++++++++++++------- 5 files changed, 209 insertions(+), 153 deletions(-) diff --git a/tools/test-harness/cmd/build.go b/tools/test-harness/cmd/build.go index 8ba148362a..615b2464d9 100644 --- a/tools/test-harness/cmd/build.go +++ b/tools/test-harness/cmd/build.go @@ -3,12 +3,10 @@ package cmd import ( "context" "fmt" - "sync" - - "golang.org/x/sync/errgroup" "github.com/spf13/cobra" + "github.com/replicate/cog/tools/test-harness/internal/manifest" "github.com/replicate/cog/tools/test-harness/internal/report" "github.com/replicate/cog/tools/test-harness/internal/runner" ) @@ -25,6 +23,10 @@ func newBuildCommand() *cobra.Command { } func runBuild(ctx context.Context) error { + if err := validateConcurrency(); err != nil { + return err + } + _, models, resolved, err := resolveSetup() if err != nil { return err @@ -59,64 +61,47 @@ func runBuild(ctx context.Context) error { // Build models results := make([]report.ModelResult, len(models)) - if parallel { - g, ctx := errgroup.WithContext(ctx) - g.SetLimit(concurrency) - - var mu sync.Mutex - for i, model := range models { - g.Go(func() error { - mu.Lock() - fmt.Printf(" [%d/%d] Building %s...\n", i+1, len(models), model.Name) - mu.Unlock() - - result := r.BuildModel(ctx, model) - results[i] = *result - - mu.Lock() + runModels(ctx, models, results, parallel, + func(ctx context.Context, model manifest.Model) *report.ModelResult { + return r.BuildModel(ctx, model) + }, + func(index, total int, model manifest.Model) string { + if parallel { + return fmt.Sprintf(" [%d/%d] Building %s...\n", index, total, model.Name) + } + return fmt.Sprintf("Building %s...\n", model.Name) + }, + func(index, total int, model manifest.Model, result *report.ModelResult) string { + if parallel { switch { case result.Passed: - fmt.Printf(" [%d/%d] + %s (%.1fs)\n", i+1, len(models), model.Name, result.BuildDuration) + return fmt.Sprintf(" [%d/%d] + %s (%.1fs)\n", index, total, model.Name, result.BuildDuration) case result.Skipped: - fmt.Printf(" [%d/%d] - %s (skipped: %s)\n", i+1, len(models), model.Name, result.SkipReason) + return fmt.Sprintf(" [%d/%d] - %s (skipped: %s)\n", index, total, model.Name, result.SkipReason) default: - fmt.Printf(" [%d/%d] x %s FAILED\n", i+1, len(models), model.Name) + return fmt.Sprintf(" [%d/%d] x %s FAILED\n", index, total, model.Name) } - mu.Unlock() - - return nil - }) - } - _ = g.Wait() - } else { - for i, model := range models { - fmt.Printf("Building %s...\n", model.Name) - result := r.BuildModel(ctx, model) - results[i] = *result + } switch { case result.Passed: - fmt.Printf(" + %s built successfully (%.1fs)\n", model.Name, result.BuildDuration) + return fmt.Sprintf(" + %s built successfully (%.1fs)\n", model.Name, result.BuildDuration) case result.Skipped: - fmt.Printf(" - %s (skipped: %s)\n", model.Name, result.SkipReason) + return fmt.Sprintf(" - %s (skipped: %s)\n", model.Name, result.SkipReason) default: - fmt.Printf(" x %s FAILED\n", model.Name) + return fmt.Sprintf(" x %s FAILED\n", model.Name) } - } - } + }, + ) // Output results report.ConsoleReport(results, resolved.SDKVersion, resolved.CogVersion) // Check for failures - var failedNames []string for _, r := range results { if !r.Passed && !r.Skipped { - failedNames = append(failedNames, r.Name) + return formatFailureSummary("build", results) } } - if len(failedNames) > 0 { - return formatFailureSummary("build", results) - } return nil } diff --git a/tools/test-harness/cmd/root.go b/tools/test-harness/cmd/root.go index 32e91b0d99..7f701befd3 100644 --- a/tools/test-harness/cmd/root.go +++ b/tools/test-harness/cmd/root.go @@ -1,8 +1,10 @@ package cmd import ( + "context" "fmt" "strings" + "sync" "github.com/spf13/cobra" @@ -82,7 +84,73 @@ func resolveSetup() (*manifest.Manifest, []manifest.Model, *resolver.Result, err return mf, models, resolved, nil } +// validateConcurrency checks that the concurrency flag is a valid value. +// errgroup.SetLimit panics on 0, and negative values mean unlimited. +func validateConcurrency() error { + if concurrency < 1 { + return fmt.Errorf("--concurrency must be at least 1, got %d", concurrency) + } + return nil +} + +// modelAction is a function that processes a single model and returns a result. +type modelAction[T any] func(ctx context.Context, model manifest.Model) *T + +// statusPrinter formats a per-model status line after processing completes. +type statusPrinter[T any] func(index, total int, model manifest.Model, result *T) string + +// runModels executes an action for each model, either sequentially or in parallel +// depending on the concurrency setting. It handles the common pattern of: +// - printing a "starting" line +// - running the action +// - printing a "done/failed" status line +// +// The results slice is pre-allocated by the caller. This function fills it in. +func runModels[T any]( + ctx context.Context, + models []manifest.Model, + results []T, + parallel bool, + action modelAction[T], + startLine func(index, total int, model manifest.Model) string, + statusLine statusPrinter[T], +) { + if parallel { + sem := make(chan struct{}, concurrency) + var wg sync.WaitGroup + var mu sync.Mutex + + for i, model := range models { + wg.Go(func() { + sem <- struct{}{} // acquire + defer func() { <-sem }() // release + + mu.Lock() + fmt.Print(startLine(i+1, len(models), model)) + mu.Unlock() + + result := action(ctx, model) + results[i] = *result + + mu.Lock() + fmt.Print(statusLine(i+1, len(models), model, result)) + mu.Unlock() + }) + } + wg.Wait() + } else { + for i, model := range models { + fmt.Print(startLine(i+1, len(models), model)) + result := action(ctx, model) + results[i] = *result + fmt.Print(statusLine(i+1, len(models), model, result)) + } + } +} + // formatFailureSummary builds an error message with per-model failure details. +// +//nolint:gosec // G705: writes to strings.Builder, not an HTTP response — no XSS risk func formatFailureSummary(action string, results []report.ModelResult) error { var b strings.Builder var failCount int diff --git a/tools/test-harness/cmd/run.go b/tools/test-harness/cmd/run.go index 351ca455e2..c7167d9779 100644 --- a/tools/test-harness/cmd/run.go +++ b/tools/test-harness/cmd/run.go @@ -4,12 +4,10 @@ import ( "context" "fmt" "os" - "sync" - - "golang.org/x/sync/errgroup" "github.com/spf13/cobra" + "github.com/replicate/cog/tools/test-harness/internal/manifest" "github.com/replicate/cog/tools/test-harness/internal/report" "github.com/replicate/cog/tools/test-harness/internal/runner" ) @@ -38,6 +36,10 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { return fmt.Errorf("invalid output format %q: must be 'console' or 'json'", outputFormat) } + if err := validateConcurrency(); err != nil { + return err + } + _, models, resolved, err := resolveSetup() if err != nil { return err @@ -78,52 +80,38 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { // Run tests results := make([]report.ModelResult, len(models)) - if parallel { - g, ctx := errgroup.WithContext(ctx) - g.SetLimit(concurrency) - - var mu sync.Mutex - for i, model := range models { - g.Go(func() error { - mu.Lock() - fmt.Printf(" [%d/%d] Running %s...\n", i+1, len(models), model.Name) - mu.Unlock() - - result := r.RunModel(ctx, model) - results[i] = *result - - mu.Lock() + runModels(ctx, models, results, parallel, + func(ctx context.Context, model manifest.Model) *report.ModelResult { + return r.RunModel(ctx, model) + }, + func(index, total int, model manifest.Model) string { + if parallel { + return fmt.Sprintf(" [%d/%d] Running %s...\n", index, total, model.Name) + } + return fmt.Sprintf("Running %s...\n", model.Name) + }, + func(index, total int, model manifest.Model, result *report.ModelResult) string { + testCount := len(result.TestResults) + len(result.TrainResults) + if parallel { switch { case result.Skipped: - fmt.Printf(" [%d/%d] - %s (skipped: %s)\n", i+1, len(models), model.Name, result.SkipReason) + return fmt.Sprintf(" [%d/%d] - %s (skipped: %s)\n", index, total, model.Name, result.SkipReason) case result.Passed: - testCount := len(result.TestResults) + len(result.TrainResults) - fmt.Printf(" [%d/%d] + %s (%.1fs build, %d tests passed)\n", i+1, len(models), model.Name, result.BuildDuration, testCount) + return fmt.Sprintf(" [%d/%d] + %s (%.1fs build, %d tests passed)\n", index, total, model.Name, result.BuildDuration, testCount) default: - fmt.Printf(" [%d/%d] x %s FAILED\n", i+1, len(models), model.Name) + return fmt.Sprintf(" [%d/%d] x %s FAILED\n", index, total, model.Name) } - mu.Unlock() - - return nil - }) - } - _ = g.Wait() - } else { - for i, model := range models { - fmt.Printf("Running %s...\n", model.Name) - result := r.RunModel(ctx, model) - results[i] = *result + } switch { case result.Skipped: - fmt.Printf(" - %s (skipped: %s)\n", model.Name, result.SkipReason) + return fmt.Sprintf(" - %s (skipped: %s)\n", model.Name, result.SkipReason) case result.Passed: - testCount := len(result.TestResults) + len(result.TrainResults) - fmt.Printf(" + %s (%.1fs build, %d tests passed)\n", model.Name, result.BuildDuration, testCount) + return fmt.Sprintf(" + %s (%.1fs build, %d tests passed)\n", model.Name, result.BuildDuration, testCount) default: - fmt.Printf(" x %s FAILED\n", model.Name) + return fmt.Sprintf(" x %s FAILED\n", model.Name) } - } - } + }, + ) // Output results if outputFormat == "json" { @@ -162,16 +150,11 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { } // Check for failures - var hasFailures bool for _, r := range results { if !r.Passed && !r.Skipped { - hasFailures = true - break + return formatFailureSummary("model", results) } } - if hasFailures { - return formatFailureSummary("model", results) - } return nil } diff --git a/tools/test-harness/cmd/schema_compare.go b/tools/test-harness/cmd/schema_compare.go index 7bfa9f6591..724cf30418 100644 --- a/tools/test-harness/cmd/schema_compare.go +++ b/tools/test-harness/cmd/schema_compare.go @@ -5,12 +5,10 @@ import ( "fmt" "os" "strings" - "sync" - - "golang.org/x/sync/errgroup" "github.com/spf13/cobra" + "github.com/replicate/cog/tools/test-harness/internal/manifest" "github.com/replicate/cog/tools/test-harness/internal/report" "github.com/replicate/cog/tools/test-harness/internal/runner" ) @@ -39,6 +37,10 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro return fmt.Errorf("invalid output format %q: must be 'console' or 'json'", outputFormat) } + if err := validateConcurrency(); err != nil { + return err + } + _, models, resolved, err := resolveSetup() if err != nil { return err @@ -73,44 +75,29 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro // Compare schemas results := make([]report.SchemaCompareResult, len(models)) - if parallel { - g, ctx := errgroup.WithContext(ctx) - g.SetLimit(concurrency) - - var mu sync.Mutex - for i, model := range models { - g.Go(func() error { - mu.Lock() - fmt.Printf(" [%d/%d] Comparing %s...\n", i+1, len(models), model.Name) - mu.Unlock() - - result := r.CompareSchema(ctx, model) - results[i] = *result - - mu.Lock() + runModels(ctx, models, results, parallel, + func(ctx context.Context, model manifest.Model) *report.SchemaCompareResult { + return r.CompareSchema(ctx, model) + }, + func(index, total int, model manifest.Model) string { + if parallel { + return fmt.Sprintf(" [%d/%d] Comparing %s...\n", index, total, model.Name) + } + return fmt.Sprintf("Comparing %s...\n", model.Name) + }, + func(index, total int, model manifest.Model, result *report.SchemaCompareResult) string { + if parallel { if result.Passed { - fmt.Printf(" [%d/%d] + %s schemas match\n", i+1, len(models), model.Name) - } else { - fmt.Printf(" [%d/%d] x %s FAILED\n", i+1, len(models), model.Name) + return fmt.Sprintf(" [%d/%d] + %s schemas match\n", index, total, model.Name) } - mu.Unlock() - - return nil - }) - } - _ = g.Wait() - } else { - for i, model := range models { - fmt.Printf("Comparing %s...\n", model.Name) - result := r.CompareSchema(ctx, model) - results[i] = *result + return fmt.Sprintf(" [%d/%d] x %s FAILED\n", index, total, model.Name) + } if result.Passed { - fmt.Printf(" + %s schemas match\n", model.Name) - } else { - fmt.Printf(" x %s FAILED\n", model.Name) + return fmt.Sprintf(" + %s schemas match\n", model.Name) } - } - } + return fmt.Sprintf(" x %s FAILED\n", model.Name) + }, + ) // Output results if outputFormat == "json" { diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index b3d1aa1b12..9d86981650 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -18,6 +18,7 @@ import ( "time" "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" "github.com/replicate/cog/tools/test-harness/internal/manifest" "github.com/replicate/cog/tools/test-harness/internal/patcher" @@ -27,8 +28,14 @@ import ( const openapiSchemaLabel = "run.cog.openapi_schema" +// maxPrefixBuf is the maximum size of the prefixWriter line buffer before +// it is force-flushed. This prevents unbounded memory growth when a +// subprocess writes very long lines without newlines (e.g. progress bars). +const maxPrefixBuf = 64 * 1024 // 64 KiB + // prefixWriter wraps an io.Writer and prepends a prefix to each line. -// Partial lines (no trailing newline) are buffered until a newline arrives. +// Partial lines (no trailing newline) are buffered until a newline arrives +// or the buffer exceeds maxPrefixBuf. type prefixWriter struct { prefix string dest io.Writer @@ -61,6 +68,15 @@ func (pw *prefixWriter) Write(p []byte) (int, error) { return total, err } } + + // Force-flush if the buffer has grown too large (e.g. no newlines in output). + if len(pw.buf) > maxPrefixBuf { + if _, err := fmt.Fprintf(pw.dest, "%s%s\n", pw.prefix, pw.buf); err != nil { + return total, err + } + pw.buf = pw.buf[:0] + } + return total, nil } @@ -75,11 +91,22 @@ func (pw *prefixWriter) Flush() { } } -// modelOutput returns stdout/stderr writers for a model. -// In parallel mode, output is prefixed with the model name and -// also captured in a buffer for error reporting. -// In sequential mode, output streams directly to the terminal -// and is also captured. +// modelLoggers returns stdout/stderr writers for status logging only. +// In parallel mode, output is prefixed with the model name. +// In sequential mode, output streams directly to the terminal. +// Use modelOutput when captured output is needed for error reporting. +func (r *Runner) modelLoggers(modelName string) (logw io.Writer, flush func()) { + if r.opts.Parallel { + pw := newPrefixWriter(os.Stderr, modelName) + return pw, pw.Flush + } + return os.Stderr, func() {} +} + +// modelOutput returns stdout/stderr writers for a model that also capture +// output into a buffer for error reporting. +// In parallel mode, output is prefixed with the model name. +// In sequential mode, output streams directly to the terminal. func (r *Runner) modelOutput(modelName string) (stdout, stderr io.Writer, capture *bytes.Buffer, flush func()) { var buf bytes.Buffer if r.opts.Parallel { @@ -108,8 +135,7 @@ type Runner struct { opts Options fixturesDir string workDir string - clonedRepos map[string]string - mu sync.Mutex // protects clonedRepos + cloneGroup singleflight.Group // deduplicates concurrent clones of the same repo } // New creates a new Runner @@ -145,7 +171,6 @@ func New(opts Options) (*Runner, error) { opts: opts, fixturesDir: fixturesDir, workDir: workDir, - clonedRepos: make(map[string]string), }, nil } @@ -212,7 +237,7 @@ func (r *Runner) Cleanup() error { // RunModel runs all tests for a single model func (r *Runner) RunModel(ctx context.Context, model manifest.Model) *report.ModelResult { - _, logw, _, flush := r.modelOutput(model.Name) + logw, flush := r.modelLoggers(model.Name) result := &report.ModelResult{ Name: model.Name, @@ -294,7 +319,7 @@ func (r *Runner) RunModel(ctx context.Context, model manifest.Model) *report.Mod // BuildModel builds a model image only func (r *Runner) BuildModel(ctx context.Context, model manifest.Model) *report.ModelResult { - _, logw, _, flush := r.modelOutput(model.Name) + logw, flush := r.modelLoggers(model.Name) result := &report.ModelResult{ Name: model.Name, @@ -565,7 +590,14 @@ func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model ma cmd.Stderr = stderr if err := cmd.Run(); err != nil { flush() - return fmt.Errorf("setup command %q failed: %w\n%s", cmdStr, err, capture.String()) + // Truncate captured output to avoid unwieldy error messages, + // matching the pattern used in buildModelWithEnv. + output := capture.String() + const maxTail = 2000 + if len(output) > maxTail { + output = "...\n" + output[len(output)-maxTail:] + } + return fmt.Errorf("setup command %q failed: %w\n%s", cmdStr, err, output) } } flush() @@ -573,28 +605,29 @@ func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model ma } // cloneRepo clones a repo once and caches the result. Thread-safe. +// Uses singleflight to deduplicate concurrent clones of the same repo +// without holding a mutex during the (potentially slow) git clone. func (r *Runner) cloneRepo(ctx context.Context, repo string) (string, error) { - r.mu.Lock() - defer r.mu.Unlock() + result, err, _ := r.cloneGroup.Do(repo, func() (any, error) { + dest := filepath.Join(r.workDir, strings.ReplaceAll(repo, "/", "--")) - if dir, ok := r.clonedRepos[repo]; ok { - return dir, nil - } + // Remove if exists + _ = os.RemoveAll(dest) - dest := filepath.Join(r.workDir, strings.ReplaceAll(repo, "/", "--")) - - // Remove if exists - _ = os.RemoveAll(dest) + url := fmt.Sprintf("https://github.com/%s.git", repo) + cmd := exec.CommandContext(ctx, "git", "clone", "--depth=1", url, dest) + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("cloning %s: %w\n%s", repo, err, stderr.String()) + } - url := fmt.Sprintf("https://github.com/%s.git", repo) - cmd := exec.CommandContext(ctx, "git", "clone", "--depth=1", url, dest) - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("cloning %s: %w", repo, err) + return dest, nil + }) + if err != nil { + return "", err } - - r.clonedRepos[repo] = dest - return dest, nil + return result.(string), nil } func (r *Runner) buildModel(ctx context.Context, modelDir string, model manifest.Model) error { From d4f07ddce54a77d04e9ef28c019d4d1433c956bc Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 16 Apr 2026 11:03:28 -0400 Subject: [PATCH 12/13] chore: update llm docs Signed-off-by: Mark Phelps --- docs/llms.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/llms.txt b/docs/llms.txt index 73975c6beb..0c152ccaee 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -258,6 +258,7 @@ cog build [flags] --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' --separate-weights Separate model weights from code in image layers + --skip-schema-validation Skip OpenAPI schema generation and validation -t, --tag string A name for the built image in the form 'repository:tag' --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") From d537b25bbd43c099326c71ba588725c2157affa8 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 16 Apr 2026 13:46:33 -0400 Subject: [PATCH 13/13] feat(test-harness): unified-diff schema comparison and quiet builds Improve schema-compare output so differences are immediately readable: - Replace terse one-line-per-path diff format with unified-diff style output showing actual JSON values with --- static / +++ runtime headers and @@ path @@ hunks (like git diff) - Sort diff keys alphabetically for deterministic output - Suppress docker build log streaming during schema comparison builds (quietBuildModelWithEnv) so build output no longer intermingles with the diff report; build logs are still captured for error messages - Add spacing around diff blocks in console report for readability - Add table-driven tests for the new jsonDiff format --- tools/test-harness/internal/report/report.go | 4 +- tools/test-harness/internal/runner/runner.go | 148 ++++++++++++++---- .../internal/runner/runner_test.go | 63 ++++++++ 3 files changed, 183 insertions(+), 32 deletions(-) diff --git a/tools/test-harness/internal/report/report.go b/tools/test-harness/internal/report/report.go index dfb2a074d1..8822a35d71 100644 --- a/tools/test-harness/internal/report/report.go +++ b/tools/test-harness/internal/report/report.go @@ -260,9 +260,11 @@ func SchemaCompareConsoleReport(results []SchemaCompareResult, cogVersion string writeStatus("x", r.Name, "schemas differ", false) failed++ if r.Diff != "" { + fmt.Println() for line := range strings.SplitSeq(r.Diff, "\n") { - fmt.Printf(" %s\n", line) + fmt.Printf(" %s\n", line) } + fmt.Println() } } } diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index 9d86981650..94eb603221 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -413,7 +413,7 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor runtimeStart := time.Now() g.Go(func() error { - staticErr = r.buildModelWithEnv(ctx, staticDir, model, staticTag, map[string]string{"COG_STATIC_SCHEMA": "1"}) + staticErr = r.quietBuildModelWithEnv(ctx, staticDir, model, staticTag, map[string]string{"COG_STATIC_SCHEMA": "1"}) if staticErr != nil { return nil // Don't fail the group, we'll check errors after } @@ -423,7 +423,7 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor }) g.Go(func() error { - runtimeErr = r.buildModelWithEnv(ctx, runtimeDir, model, runtimeTag, map[string]string{}) + runtimeErr = r.quietBuildModelWithEnv(ctx, runtimeDir, model, runtimeTag, map[string]string{}) if runtimeErr != nil { return nil } @@ -635,7 +635,22 @@ func (r *Runner) buildModel(ctx context.Context, modelDir string, model manifest return r.buildModelWithEnv(ctx, modelDir, model, imageTag, nil) } +// buildModelWithEnv builds a model image, streaming output to the terminal +// (or prefixed in parallel mode) for real-time progress visibility. func (r *Runner) buildModelWithEnv(ctx context.Context, modelDir string, model manifest.Model, imageTag string, extraEnv map[string]string) error { + _, stderr, capture, flush := r.modelOutput(model.Name) + return r.doBuild(ctx, modelDir, model, imageTag, extraEnv, stderr, capture, flush) +} + +// quietBuildModelWithEnv builds a model image without streaming output to +// the terminal. Output is captured for inclusion in error messages on failure. +// Used by CompareSchema where build logs would interleave with diff results. +func (r *Runner) quietBuildModelWithEnv(ctx context.Context, modelDir string, model manifest.Model, imageTag string, extraEnv map[string]string) error { + var capture bytes.Buffer + return r.doBuild(ctx, modelDir, model, imageTag, extraEnv, &capture, &capture, func() {}) +} + +func (r *Runner) doBuild(ctx context.Context, modelDir string, model manifest.Model, imageTag string, extraEnv map[string]string, output io.Writer, capture *bytes.Buffer, flush func()) error { // Set timeout timeout := model.Timeout if timeout == 0 { @@ -673,22 +688,18 @@ func (r *Runner) buildModelWithEnv(ctx context.Context, modelDir string, model m cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) } - // Stream build output in real-time so the user can see progress, - // while also capturing it for error reporting if the build fails. - // In parallel mode, each line is prefixed with the model name. - _, stderr, capture, flush := r.modelOutput(model.Name) - cmd.Stdout = stderr - cmd.Stderr = stderr + cmd.Stdout = output + cmd.Stderr = output err := cmd.Run() flush() if err != nil { // Include the last portion of build output for context. - output := capture.String() + out := capture.String() const maxTail = 2000 - if len(output) > maxTail { - output = "...\n" + output[len(output)-maxTail:] + if len(out) > maxTail { + out = "...\n" + out[len(out)-maxTail:] } - return fmt.Errorf("%w\n%s", err, output) + return fmt.Errorf("%w\n%s", err, out) } return nil } @@ -883,21 +894,86 @@ func copyDir(src, dst string) error { }) } +// jsonDiff produces a unified-diff-style comparison between two JSON objects. +// The output resembles git diff, showing the static schema as "a" (old) and +// the runtime schema as "b" (new), with -, + and ~ prefixes for removed, +// added, and changed values respectively. func jsonDiff(a, b map[string]any) string { - var lines []string - diffRecursive(a, b, "$", &lines) - if len(lines) == 0 { + var hunks []diffHunk + collectDiffs(a, b, "$", &hunks) + if len(hunks) == 0 { return "" } - return strings.Join(lines, "\n") + + var buf strings.Builder + buf.WriteString("--- static schema\n") + buf.WriteString("+++ runtime schema\n") + for _, h := range hunks { + buf.WriteString("\n") + buf.WriteString(h.String()) + } + return buf.String() +} + +// diffHunk represents a single difference between two JSON values. +type diffHunk struct { + Path string + Kind string // "missing_in_static", "missing_in_runtime", "changed", "type_mismatch", "array_length" + StaticV any // value in static (nil if missing) + RuntimeV any // value in runtime (nil if missing) } -func diffRecursive(a, b any, path string, lines *[]string) { +func (h diffHunk) String() string { + var buf strings.Builder + fmt.Fprintf(&buf, "@@ %s @@\n", h.Path) + switch h.Kind { + case "missing_in_runtime": + // Present in static, absent in runtime + for _, line := range prettyLines(h.StaticV) { + fmt.Fprintf(&buf, "-%s\n", line) + } + case "missing_in_static": + // Absent in static, present in runtime + for _, line := range prettyLines(h.RuntimeV) { + fmt.Fprintf(&buf, "+%s\n", line) + } + case "changed", "type_mismatch": + for _, line := range prettyLines(h.StaticV) { + fmt.Fprintf(&buf, "-%s\n", line) + } + for _, line := range prettyLines(h.RuntimeV) { + fmt.Fprintf(&buf, "+%s\n", line) + } + case "array_length": + for _, line := range prettyLines(h.StaticV) { + fmt.Fprintf(&buf, "-%s\n", line) + } + for _, line := range prettyLines(h.RuntimeV) { + fmt.Fprintf(&buf, "+%s\n", line) + } + } + return buf.String() +} + +// prettyLines returns a compact JSON representation split into lines. +func prettyLines(v any) []string { + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + return []string{fmt.Sprintf("%v", v)} + } + return strings.Split(string(data), "\n") +} + +func collectDiffs(a, b any, path string, hunks *[]diffHunk) { if a == nil && b == nil { return } - if a == nil || b == nil { - *lines = append(*lines, fmt.Sprintf(" %s: one side is nil", path)) + if a == nil { + *hunks = append(*hunks, diffHunk{Path: path, Kind: "missing_in_static", RuntimeV: b}) + return + } + if b == nil { + *hunks = append(*hunks, diffHunk{Path: path, Kind: "missing_in_runtime", StaticV: a}) return } @@ -905,9 +981,10 @@ func diffRecursive(a, b any, path string, lines *[]string) { case map[string]any: bv, ok := b.(map[string]any) if !ok { - *lines = append(*lines, fmt.Sprintf(" %s: type mismatch (object vs %T)", path, b)) + *hunks = append(*hunks, diffHunk{Path: path, Kind: "type_mismatch", StaticV: a, RuntimeV: b}) return } + // Collect all keys, sorted for deterministic output allKeys := make(map[string]bool) for k := range av { allKeys[k] = true @@ -915,33 +992,42 @@ func diffRecursive(a, b any, path string, lines *[]string) { for k := range bv { allKeys[k] = true } + sortedKeys := make([]string, 0, len(allKeys)) for k := range allKeys { + sortedKeys = append(sortedKeys, k) + } + sort.Strings(sortedKeys) + + for _, k := range sortedKeys { childPath := fmt.Sprintf("%s.%s", path, k) - if _, ok := av[k]; !ok { - *lines = append(*lines, fmt.Sprintf(" %s: missing in static", childPath)) - } else if _, ok := bv[k]; !ok { - *lines = append(*lines, fmt.Sprintf(" %s: missing in runtime", childPath)) - } else { - diffRecursive(av[k], bv[k], childPath, lines) + aVal, aOK := av[k] + bVal, bOK := bv[k] + switch { + case !aOK: + *hunks = append(*hunks, diffHunk{Path: childPath, Kind: "missing_in_static", RuntimeV: bVal}) + case !bOK: + *hunks = append(*hunks, diffHunk{Path: childPath, Kind: "missing_in_runtime", StaticV: aVal}) + default: + collectDiffs(aVal, bVal, childPath, hunks) } } case []any: bv, ok := b.([]any) if !ok { - *lines = append(*lines, fmt.Sprintf(" %s: type mismatch (array vs %T)", path, b)) + *hunks = append(*hunks, diffHunk{Path: path, Kind: "type_mismatch", StaticV: a, RuntimeV: b}) return } if len(av) != len(bv) { - *lines = append(*lines, fmt.Sprintf(" %s: array length mismatch (%d vs %d)", path, len(av), len(bv))) + *hunks = append(*hunks, diffHunk{Path: path, Kind: "array_length", StaticV: a, RuntimeV: b}) return } for i := range av { childPath := fmt.Sprintf("%s[%d]", path, i) - diffRecursive(av[i], bv[i], childPath, lines) + collectDiffs(av[i], bv[i], childPath, hunks) } default: - if a != b { - *lines = append(*lines, fmt.Sprintf(" %s: value mismatch (%v vs %v)", path, a, b)) + if fmt.Sprint(a) != fmt.Sprint(b) { + *hunks = append(*hunks, diffHunk{Path: path, Kind: "changed", StaticV: a, RuntimeV: b}) } } } diff --git a/tools/test-harness/internal/runner/runner_test.go b/tools/test-harness/internal/runner/runner_test.go index cdcc9f72c6..5ee16cfbce 100644 --- a/tools/test-harness/internal/runner/runner_test.go +++ b/tools/test-harness/internal/runner/runner_test.go @@ -2,6 +2,7 @@ package runner import ( "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -38,6 +39,68 @@ func TestSafeSubpathRejectsAbsoluteOutsidePath(t *testing.T) { assert.Contains(t, err.Error(), "must be relative") } +func TestJsonDiffUnifiedFormat(t *testing.T) { + tests := []struct { + name string + static map[string]any + runtime map[string]any + contains []string // substrings the diff must contain + empty bool // expect no diff + }{ + { + name: "identical schemas produce no diff", + static: map[string]any{"a": 1, "b": "hello"}, + runtime: map[string]any{"a": 1, "b": "hello"}, + empty: true, + }, + { + name: "missing key in static shows + lines", + static: map[string]any{"a": 1}, + runtime: map[string]any{"a": 1, "b": "new"}, + contains: []string{"--- static schema", "+++ runtime schema", "@@ $.b @@", "+\"new\""}, + }, + { + name: "missing key in runtime shows - lines", + static: map[string]any{"a": 1, "b": "old"}, + runtime: map[string]any{"a": 1}, + contains: []string{"@@ $.b @@", "-\"old\""}, + }, + { + name: "changed value shows - and + lines", + static: map[string]any{"a": "foo"}, + runtime: map[string]any{"a": "bar"}, + contains: []string{"@@ $.a @@", "-\"foo\"", "+\"bar\""}, + }, + { + name: "nested object diff shows full path", + static: map[string]any{"outer": map[string]any{"inner": 1}}, + runtime: map[string]any{"outer": map[string]any{"inner": 2}}, + contains: []string{"@@ $.outer.inner @@", "-1", "+2"}, + }, + { + name: "array length mismatch shows both arrays", + static: map[string]any{"arr": []any{"a", "b"}}, + runtime: map[string]any{"arr": []any{"a", "b", "c"}}, + contains: []string{"@@ $.arr @@", "-[", "+[", "+ \"c\""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + diff := jsonDiff(tt.static, tt.runtime) + if tt.empty { + assert.Empty(t, diff) + return + } + require.NotEmpty(t, diff) + for _, substr := range tt.contains { + assert.True(t, strings.Contains(diff, substr), + "diff should contain %q but got:\n%s", substr, diff) + } + }) + } +} + func TestExtractOutput(t *testing.T) { tests := []struct { name string