diff --git a/go.mod b/go.mod index 84a44f1b45..9f4e907903 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( golang.org/x/sys v0.42.0 golang.org/x/term v0.41.0 google.golang.org/grpc v1.79.3 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -148,7 +149,6 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/protobuf v1.36.11 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect gotest.tools/gotestsum v1.12.2 // indirect ) diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 490ab585f3..8e52b5924a 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -15,18 +15,21 @@ import ( "github.com/replicate/cog/pkg/util/console" ) -var buildTag string -var buildSeparateWeights bool -var buildSecrets []string -var buildNoCache bool -var buildProgressOutput string -var buildSchemaFile string -var buildUseCudaBaseImage string -var buildDockerfileFile string -var buildUseCogBaseImage bool -var buildStrip bool -var buildPrecompile bool -var configFilename string +var ( + buildTag string + buildSeparateWeights bool + buildSecrets []string + buildNoCache bool + buildProgressOutput string + buildSchemaFile string + buildUseCudaBaseImage string + buildDockerfileFile string + buildUseCogBaseImage bool + buildStrip bool + buildPrecompile bool + buildSkipSchemaValidation bool + configFilename string +) const useCogBaseImageFlagKey = "use-cog-base-image" @@ -66,6 +69,7 @@ with 'cog push'.`, addStripFlag(cmd) addPrecompileFlag(cmd) addConfigFlag(cmd) + addSkipSchemaValidationFlag(cmd) cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'") return cmd } @@ -194,22 +198,28 @@ func DetermineUseCogBaseImage(cmd *cobra.Command) *bool { return useCogBaseImage } +func addSkipSchemaValidationFlag(cmd *cobra.Command) { + cmd.Flags().BoolVar(&buildSkipSchemaValidation, "skip-schema-validation", 
false, "Skip OpenAPI schema generation and validation") + _ = cmd.Flags().MarkHidden("skip-schema-validation") +} + // buildOptionsFromFlags creates BuildOptions from the current CLI flag values. // The imageName and annotations parameters vary by command and must be provided. func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map[string]string) model.BuildOptions { return model.BuildOptions{ - ImageName: imageName, - Secrets: buildSecrets, - NoCache: buildNoCache, - SeparateWeights: buildSeparateWeights, - UseCudaBaseImage: buildUseCudaBaseImage, - ProgressOutput: buildProgressOutput, - SchemaFile: buildSchemaFile, - DockerfileFile: buildDockerfileFile, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - Strip: buildStrip, - Precompile: buildPrecompile, - Annotations: annotations, - OCIIndex: model.OCIIndexEnabled(), + ImageName: imageName, + Secrets: buildSecrets, + NoCache: buildNoCache, + SeparateWeights: buildSeparateWeights, + UseCudaBaseImage: buildUseCudaBaseImage, + ProgressOutput: buildProgressOutput, + SchemaFile: buildSchemaFile, + DockerfileFile: buildDockerfileFile, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + Strip: buildStrip, + Precompile: buildPrecompile, + Annotations: annotations, + OCIIndex: model.OCIIndexEnabled(), + SkipSchemaValidation: buildSkipSchemaValidation, } } diff --git a/tools/test-harness/cmd/build.go b/tools/test-harness/cmd/build.go index cceee3bcfa..615b2464d9 100644 --- a/tools/test-harness/cmd/build.go +++ b/tools/test-harness/cmd/build.go @@ -3,10 +3,10 @@ package cmd import ( "context" "fmt" - "strings" "github.com/spf13/cobra" + "github.com/replicate/cog/tools/test-harness/internal/manifest" "github.com/replicate/cog/tools/test-harness/internal/report" "github.com/replicate/cog/tools/test-harness/internal/runner" ) @@ -23,6 +23,10 @@ func newBuildCommand() *cobra.Command { } func runBuild(ctx context.Context) error { + if err := validateConcurrency(); err != nil { + return err + } + _, 
models, resolved, err := resolveSetup() if err != nil { return err @@ -32,7 +36,13 @@ func runBuild(ctx context.Context) error { fmt.Println("No models to build") return nil } - fmt.Printf("Building %d model(s)\n\n", len(models)) + + parallel := concurrency > 1 && len(models) > 1 + if parallel { + fmt.Printf("Building %d model(s) with concurrency %d\n\n", len(models), concurrency) + } else { + fmt.Printf("Building %d model(s)\n\n", len(models)) + } // Create runner r, err := runner.New(runner.Options{ @@ -41,33 +51,57 @@ func runBuild(ctx context.Context) error { SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, + Parallel: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) } - defer r.Cleanup() + defer func() { _ = r.Cleanup() }() // Build models - var results []report.ModelResult - for _, model := range models { - fmt.Printf("Building %s...\n", model.Name) - result := r.BuildModel(ctx, model) - results = append(results, *result) - } + results := make([]report.ModelResult, len(models)) + + runModels(ctx, models, results, parallel, + func(ctx context.Context, model manifest.Model) *report.ModelResult { + return r.BuildModel(ctx, model) + }, + func(index, total int, model manifest.Model) string { + if parallel { + return fmt.Sprintf(" [%d/%d] Building %s...\n", index, total, model.Name) + } + return fmt.Sprintf("Building %s...\n", model.Name) + }, + func(index, total int, model manifest.Model, result *report.ModelResult) string { + if parallel { + switch { + case result.Passed: + return fmt.Sprintf(" [%d/%d] + %s (%.1fs)\n", index, total, model.Name, result.BuildDuration) + case result.Skipped: + return fmt.Sprintf(" [%d/%d] - %s (skipped: %s)\n", index, total, model.Name, result.SkipReason) + default: + return fmt.Sprintf(" [%d/%d] x %s FAILED\n", index, total, model.Name) + } + } + switch { + case result.Passed: + return fmt.Sprintf(" + %s built successfully (%.1fs)\n", model.Name, result.BuildDuration) + case 
result.Skipped: + return fmt.Sprintf(" - %s (skipped: %s)\n", model.Name, result.SkipReason) + default: + return fmt.Sprintf(" x %s FAILED\n", model.Name) + } + }, + ) // Output results report.ConsoleReport(results, resolved.SDKVersion, resolved.CogVersion) // Check for failures - var failedNames []string for _, r := range results { if !r.Passed && !r.Skipped { - failedNames = append(failedNames, r.Name) + return formatFailureSummary("build", results) } } - if len(failedNames) > 0 { - return fmt.Errorf("%d build(s) failed: %s", len(failedNames), strings.Join(failedNames, ", ")) - } return nil } diff --git a/tools/test-harness/cmd/root.go b/tools/test-harness/cmd/root.go index 686dfa7319..7f701befd3 100644 --- a/tools/test-harness/cmd/root.go +++ b/tools/test-harness/cmd/root.go @@ -1,11 +1,15 @@ package cmd import ( + "context" "fmt" + "strings" + "sync" "github.com/spf13/cobra" "github.com/replicate/cog/tools/test-harness/internal/manifest" + "github.com/replicate/cog/tools/test-harness/internal/report" "github.com/replicate/cog/tools/test-harness/internal/resolver" ) @@ -21,6 +25,7 @@ var ( sdkWheel string cleanImages bool keepOutputs bool + concurrency int ) // NewRootCommand creates the root command @@ -46,6 +51,7 @@ It reads the same manifest.yaml format as the Python version.`, rootCmd.PersistentFlags().StringVar(&sdkWheel, "sdk-wheel", "", "Path to pre-built SDK wheel") rootCmd.PersistentFlags().BoolVar(&cleanImages, "clean-images", false, "Remove Docker images after run (default: keep them)") rootCmd.PersistentFlags().BoolVar(&keepOutputs, "keep-outputs", false, "Preserve prediction outputs (images, files) in the work directory") + rootCmd.PersistentFlags().IntVar(&concurrency, "concurrency", 4, "Maximum number of models to build/test in parallel") // Subcommands rootCmd.AddCommand(newRunCommand()) @@ -77,3 +83,111 @@ func resolveSetup() (*manifest.Manifest, []manifest.Model, *resolver.Result, err models := mf.FilterModels(modelFilter, noGPU, gpuOnly) 
return mf, models, resolved, nil } + +// validateConcurrency checks that the concurrency flag is a valid value. +// The parallel path sizes a semaphore channel with this value: make panics on a negative buffer size, and a size of 0 would deadlock. +func validateConcurrency() error { + if concurrency < 1 { + return fmt.Errorf("--concurrency must be at least 1, got %d", concurrency) + } + return nil +} + +// modelAction is a function that processes a single model and returns a result. +type modelAction[T any] func(ctx context.Context, model manifest.Model) *T + +// statusPrinter formats a per-model status line after processing completes. +type statusPrinter[T any] func(index, total int, model manifest.Model, result *T) string + +// runModels executes an action for each model, either sequentially or in parallel +// depending on the concurrency setting. It handles the common pattern of: +// - printing a "starting" line +// - running the action +// - printing a "done/failed" status line +// +// The results slice is pre-allocated by the caller. This function fills it in. 
+func runModels[T any]( + ctx context.Context, + models []manifest.Model, + results []T, + parallel bool, + action modelAction[T], + startLine func(index, total int, model manifest.Model) string, + statusLine statusPrinter[T], +) { + if parallel { + sem := make(chan struct{}, concurrency) + var wg sync.WaitGroup + var mu sync.Mutex + + for i, model := range models { + wg.Go(func() { + sem <- struct{}{} // acquire + defer func() { <-sem }() // release + + mu.Lock() + fmt.Print(startLine(i+1, len(models), model)) + mu.Unlock() + + result := action(ctx, model) + results[i] = *result + + mu.Lock() + fmt.Print(statusLine(i+1, len(models), model, result)) + mu.Unlock() + }) + } + wg.Wait() + } else { + for i, model := range models { + fmt.Print(startLine(i+1, len(models), model)) + result := action(ctx, model) + results[i] = *result + fmt.Print(statusLine(i+1, len(models), model, result)) + } + } +} + +// formatFailureSummary builds an error message with per-model failure details. +// +//nolint:gosec // G705: writes to strings.Builder, not an HTTP response — no XSS risk +func formatFailureSummary(action string, results []report.ModelResult) error { + var b strings.Builder + var failCount int + for _, r := range results { + if r.Passed || r.Skipped { + continue + } + failCount++ + fmt.Fprintf(&b, "\n x %s", r.Name) + if r.Error != "" { + // Show first line of the error + firstLine := r.Error + if idx := strings.Index(firstLine, "\n"); idx != -1 { + firstLine = firstLine[:idx] + } + fmt.Fprintf(&b, ": %s", firstLine) + } else { + // Summarize failed tests + for _, t := range r.TestResults { + if !t.Passed { + msg := t.Message + if idx := strings.Index(msg, "\n"); idx != -1 { + msg = msg[:idx] + } + fmt.Fprintf(&b, "\n test %q: %s", t.Description, msg) + } + } + for _, t := range r.TrainResults { + if !t.Passed { + msg := t.Message + if idx := strings.Index(msg, "\n"); idx != -1 { + msg = msg[:idx] + } + fmt.Fprintf(&b, "\n train %q: %s", t.Description, msg) + } + } + } + } 
+ return fmt.Errorf("%d %s(s) failed:%s", failCount, action, b.String()) +} diff --git a/tools/test-harness/cmd/run.go b/tools/test-harness/cmd/run.go index 16585318e1..c7167d9779 100644 --- a/tools/test-harness/cmd/run.go +++ b/tools/test-harness/cmd/run.go @@ -4,10 +4,10 @@ import ( "context" "fmt" "os" - "strings" "github.com/spf13/cobra" + "github.com/replicate/cog/tools/test-harness/internal/manifest" "github.com/replicate/cog/tools/test-harness/internal/report" "github.com/replicate/cog/tools/test-harness/internal/runner" ) @@ -36,6 +36,10 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { return fmt.Errorf("invalid output format %q: must be 'console' or 'json'", outputFormat) } + if err := validateConcurrency(); err != nil { + return err + } + _, models, resolved, err := resolveSetup() if err != nil { return err @@ -51,7 +55,13 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { fmt.Println("No models to run") return nil } - fmt.Printf("Running %d model(s)\n\n", len(models)) + + parallel := concurrency > 1 && len(models) > 1 + if parallel { + fmt.Printf("Running %d model(s) with concurrency %d\n\n", len(models), concurrency) + } else { + fmt.Printf("Running %d model(s)\n\n", len(models)) + } // Create runner r, err := runner.New(runner.Options{ @@ -60,19 +70,48 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, + Parallel: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) } - defer r.Cleanup() + defer func() { _ = r.Cleanup() }() // Run tests - var results []report.ModelResult - for _, model := range models { - fmt.Printf("Running %s...\n", model.Name) - result := r.RunModel(ctx, model) - results = append(results, *result) - } + results := make([]report.ModelResult, len(models)) + + runModels(ctx, models, results, parallel, + func(ctx context.Context, model manifest.Model) 
*report.ModelResult { + return r.RunModel(ctx, model) + }, + func(index, total int, model manifest.Model) string { + if parallel { + return fmt.Sprintf(" [%d/%d] Running %s...\n", index, total, model.Name) + } + return fmt.Sprintf("Running %s...\n", model.Name) + }, + func(index, total int, model manifest.Model, result *report.ModelResult) string { + testCount := len(result.TestResults) + len(result.TrainResults) + if parallel { + switch { + case result.Skipped: + return fmt.Sprintf(" [%d/%d] - %s (skipped: %s)\n", index, total, model.Name, result.SkipReason) + case result.Passed: + return fmt.Sprintf(" [%d/%d] + %s (%.1fs build, %d tests passed)\n", index, total, model.Name, result.BuildDuration, testCount) + default: + return fmt.Sprintf(" [%d/%d] x %s FAILED\n", index, total, model.Name) + } + } + switch { + case result.Skipped: + return fmt.Sprintf(" - %s (skipped: %s)\n", model.Name, result.SkipReason) + case result.Passed: + return fmt.Sprintf(" + %s (%.1fs build, %d tests passed)\n", model.Name, result.BuildDuration, testCount) + default: + return fmt.Sprintf(" x %s FAILED\n", model.Name) + } + }, + ) // Output results if outputFormat == "json" { @@ -111,15 +150,11 @@ func runRun(ctx context.Context, outputFormat, outputFile string) error { } // Check for failures - var failedNames []string for _, r := range results { if !r.Passed && !r.Skipped { - failedNames = append(failedNames, r.Name) + return formatFailureSummary("model", results) } } - if len(failedNames) > 0 { - return fmt.Errorf("%d model(s) failed: %s", len(failedNames), strings.Join(failedNames, ", ")) - } return nil } diff --git a/tools/test-harness/cmd/schema_compare.go b/tools/test-harness/cmd/schema_compare.go index aff44eb44a..724cf30418 100644 --- a/tools/test-harness/cmd/schema_compare.go +++ b/tools/test-harness/cmd/schema_compare.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" + "github.com/replicate/cog/tools/test-harness/internal/manifest" 
"github.com/replicate/cog/tools/test-harness/internal/report" "github.com/replicate/cog/tools/test-harness/internal/runner" ) @@ -36,6 +37,10 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro return fmt.Errorf("invalid output format %q: must be 'console' or 'json'", outputFormat) } + if err := validateConcurrency(); err != nil { + return err + } + _, models, resolved, err := resolveSetup() if err != nil { return err @@ -45,7 +50,13 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro fmt.Println("No models to compare") return nil } - fmt.Printf("Comparing schemas for %d model(s)\n\n", len(models)) + + parallel := concurrency > 1 && len(models) > 1 + if parallel { + fmt.Printf("Comparing schemas for %d model(s) with concurrency %d\n\n", len(models), concurrency) + } else { + fmt.Printf("Comparing schemas for %d model(s)\n\n", len(models)) + } // Create runner r, err := runner.New(runner.Options{ @@ -54,19 +65,39 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro SDKWheel: resolved.SDKWheel, CleanImages: cleanImages, KeepOutputs: keepOutputs, + Parallel: parallel, }) if err != nil { return fmt.Errorf("creating runner: %w", err) } - defer r.Cleanup() + defer func() { _ = r.Cleanup() }() // Compare schemas - var results []report.SchemaCompareResult - for _, model := range models { - fmt.Printf("Comparing %s...\n", model.Name) - result := r.CompareSchema(ctx, model) - results = append(results, *result) - } + results := make([]report.SchemaCompareResult, len(models)) + + runModels(ctx, models, results, parallel, + func(ctx context.Context, model manifest.Model) *report.SchemaCompareResult { + return r.CompareSchema(ctx, model) + }, + func(index, total int, model manifest.Model) string { + if parallel { + return fmt.Sprintf(" [%d/%d] Comparing %s...\n", index, total, model.Name) + } + return fmt.Sprintf("Comparing %s...\n", model.Name) + }, + func(index, total int, model 
manifest.Model, result *report.SchemaCompareResult) string { + if parallel { + if result.Passed { + return fmt.Sprintf(" [%d/%d] + %s schemas match\n", index, total, model.Name) + } + return fmt.Sprintf(" [%d/%d] x %s FAILED\n", index, total, model.Name) + } + if result.Passed { + return fmt.Sprintf(" + %s schemas match\n", model.Name) + } + return fmt.Sprintf(" x %s FAILED\n", model.Name) + }, + ) // Output results if outputFormat == "json" { @@ -105,14 +136,24 @@ func runSchemaCompare(ctx context.Context, outputFormat, outputFile string) erro } // Check for failures - var failedNames []string + var failedDetails []string for _, r := range results { if !r.Passed { - failedNames = append(failedNames, r.Name) + detail := r.Name + if r.Error != "" { + firstLine := r.Error + if idx := strings.Index(firstLine, "\n"); idx != -1 { + firstLine = firstLine[:idx] + } + detail += ": " + firstLine + } else if r.Diff != "" { + detail += ": schemas differ" + } + failedDetails = append(failedDetails, " x "+detail) } } - if len(failedNames) > 0 { - return fmt.Errorf("%d schema comparison(s) failed: %s", len(failedNames), strings.Join(failedNames, ", ")) + if len(failedDetails) > 0 { + return fmt.Errorf("%d schema comparison(s) failed:\n%s", len(failedDetails), strings.Join(failedDetails, "\n")) } return nil diff --git a/tools/test-harness/internal/manifest/manifest.go b/tools/test-harness/internal/manifest/manifest.go index c325204fd4..8be4a3ef8e 100644 --- a/tools/test-harness/internal/manifest/manifest.go +++ b/tools/test-harness/internal/manifest/manifest.go @@ -23,18 +23,20 @@ type Defaults struct { // Model represents a single model definition type Model struct { - Name string `yaml:"name"` - Repo string `yaml:"repo"` - Path string `yaml:"path"` - GPU bool `yaml:"gpu"` - Timeout int `yaml:"timeout"` - RequiresEnv []string `yaml:"requires_env"` - Env map[string]string `yaml:"env"` - SDKVersion string `yaml:"sdk_version"` - CogYAMLOverrides map[string]any 
`yaml:"cog_yaml_overrides"` - Setup []string `yaml:"setup"` - Tests []TestCase `yaml:"tests"` - TrainTests []TestCase `yaml:"train_tests"` + Name string `yaml:"name"` + Repo string `yaml:"repo"` + Path string `yaml:"path"` + GPU bool `yaml:"gpu"` + Timeout int `yaml:"timeout"` + RequiresEnv []string `yaml:"requires_env"` + RequiresTools []string `yaml:"requires_tools"` + Env map[string]string `yaml:"env"` + SDKVersion string `yaml:"sdk_version"` + CogYAMLOverrides map[string]any `yaml:"cog_yaml_overrides"` + Setup []string `yaml:"setup"` + SkipSchemaValidation bool `yaml:"skip_schema_validation"` + Tests []TestCase `yaml:"tests"` + TrainTests []TestCase `yaml:"train_tests"` } // TestCase represents a single test case diff --git a/tools/test-harness/internal/patcher/patcher.go b/tools/test-harness/internal/patcher/patcher.go index 847a3f605b..e289b45648 100644 --- a/tools/test-harness/internal/patcher/patcher.go +++ b/tools/test-harness/internal/patcher/patcher.go @@ -2,6 +2,7 @@ package patcher import ( "fmt" + "maps" "os" "gopkg.in/yaml.v3" @@ -44,7 +45,7 @@ func Patch(cogYAMLPath string, sdkVersion string, overrides map[string]any) erro return fmt.Errorf("marshaling cog.yaml: %w", err) } - if err := os.WriteFile(cogYAMLPath, output, 0644); err != nil { + if err := os.WriteFile(cogYAMLPath, output, 0o644); err != nil { return fmt.Errorf("writing cog.yaml: %w", err) } @@ -54,9 +55,7 @@ func Patch(cogYAMLPath string, sdkVersion string, overrides map[string]any) erro // deepMerge recursively merges override into base func deepMerge(base, override map[string]any) map[string]any { result := make(map[string]any) - for k, v := range base { - result[k] = v - } + maps.Copy(result, base) for k, v := range override { if baseVal, ok := result[k]; ok { diff --git a/tools/test-harness/internal/patcher/patcher_test.go b/tools/test-harness/internal/patcher/patcher_test.go index 19cc109755..99a11811e7 100644 --- a/tools/test-harness/internal/patcher/patcher_test.go +++ 
b/tools/test-harness/internal/patcher/patcher_test.go @@ -18,7 +18,7 @@ func TestPatch(t *testing.T) { python_version: "3.10" predict: predict.py ` - require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0644)) + require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0o644)) require.NoError(t, Patch(cogYAML, "0.16.12", nil)) data, err := os.ReadFile(cogYAML) @@ -32,7 +32,7 @@ predict: predict.py python_version: "3.10" predict: predict.py ` - require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0644)) + require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0o644)) overrides := map[string]any{ "build": map[string]any{ @@ -53,7 +53,7 @@ predict: predict.py python_version: "3.10" predict: predict.py ` - require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0644)) + require.NoError(t, os.WriteFile(cogYAML, []byte(content), 0o644)) overrides := map[string]any{ "predict": "new_predict.py", diff --git a/tools/test-harness/internal/report/report.go b/tools/test-harness/internal/report/report.go index 6021784e88..b36debde12 100644 --- a/tools/test-harness/internal/report/report.go +++ b/tools/test-harness/internal/report/report.go @@ -36,7 +36,8 @@ type SchemaCompareResult struct { Name string `json:"name"` Passed bool `json:"passed"` Error string `json:"error,omitempty"` - Diff string `json:"diff,omitempty"` + Diff string `json:"diff,omitempty"` // Real differences (failures) + ExpectedDiff string `json:"expected_diff,omitempty"` // Known limitations (informational) StaticBuild float64 `json:"static_build_s"` RuntimeBuild float64 `json:"runtime_build_s"` } @@ -76,7 +77,7 @@ func ConsoleReport(results []ModelResult, sdkVersion, cogVersion string) { } writeStatus("x", r.Name, firstLine, r.GPU) // Print full error details indented below - for _, line := range strings.Split(r.Error, "\n") { + for line := range strings.SplitSeq(r.Error, "\n") { if line != "" { fmt.Printf(" %s\n", line) } @@ -123,7 +124,7 @@ func ConsoleReport(results []ModelResult, 
sdkVersion, cogVersion string) { // Print individual failures with full output for _, t := range failures { fmt.Printf(" x %s:\n", t.Description) - for _, line := range strings.Split(t.Message, "\n") { + for line := range strings.SplitSeq(t.Message, "\n") { fmt.Printf(" %s\n", line) } } @@ -192,11 +193,12 @@ func JSONReport(results []ModelResult, sdkVersion, cogVersion string) map[string failed := 0 skipped := 0 for _, r := range results { - if r.Skipped { + switch { + case r.Skipped: skipped++ - } else if r.Passed { + case r.Passed: passed++ - } else { + default: failed++ } } @@ -242,7 +244,7 @@ func SchemaCompareConsoleReport(results []SchemaCompareResult, cogVersion string firstLine = firstLine[:idx] } writeStatus("x", r.Name, firstLine, false) - for _, line := range strings.Split(r.Error, "\n") { + for line := range strings.SplitSeq(r.Error, "\n") { if line != "" { fmt.Printf(" %s\n", line) } @@ -253,15 +255,39 @@ func SchemaCompareConsoleReport(results []SchemaCompareResult, cogVersion string if r.Passed { timing := fmt.Sprintf("(static %.1fs, runtime %.1fs)", r.StaticBuild, r.RuntimeBuild) - writeStatus("+", r.Name, fmt.Sprintf("schemas match %s", timing), false) + status := "schemas match" + if r.ExpectedDiff != "" { + status = "schemas match (with expected differences)" + } + writeStatus("+", r.Name, fmt.Sprintf("%s %s", status, timing), false) passed++ + + // Show expected differences as informational notes + if r.ExpectedDiff != "" { + fmt.Println() + fmt.Printf(" Known limitations (not failures):\n") + for line := range strings.SplitSeq(r.ExpectedDiff, "\n") { + fmt.Printf(" %s\n", line) + } + fmt.Println() + } } else { writeStatus("x", r.Name, "schemas differ", false) failed++ if r.Diff != "" { - for _, line := range strings.Split(r.Diff, "\n") { - fmt.Printf(" %s\n", line) + fmt.Println() + for line := range strings.SplitSeq(r.Diff, "\n") { + fmt.Printf(" %s\n", line) + } + fmt.Println() + } + // Also show expected diffs for context + if r.ExpectedDiff 
!= "" { + fmt.Printf(" Additionally, known limitations (not counted as failures):\n") + for line := range strings.SplitSeq(r.ExpectedDiff, "\n") { + fmt.Printf(" %s\n", line) } + fmt.Println() } } } @@ -339,7 +365,7 @@ func round(val float64, precision int) float64 { // SaveResults saves results to a JSON file in the results directory func SaveResults(results []ModelResult, sdkVersion, cogVersion string) (string, error) { resultsDir := "results" - if err := os.MkdirAll(resultsDir, 0755); err != nil { + if err := os.MkdirAll(resultsDir, 0o755); err != nil { return "", fmt.Errorf("creating results dir: %w", err) } diff --git a/tools/test-harness/internal/resolver/resolver.go b/tools/test-harness/internal/resolver/resolver.go index 8051356332..84194efad7 100644 --- a/tools/test-harness/internal/resolver/resolver.go +++ b/tools/test-harness/internal/resolver/resolver.go @@ -63,21 +63,22 @@ func Resolve(cogBinary, cogVersion, cogRef, sdkVersion, sdkWheel string, manifes result.CogVersion = version // Determine SDK wheel: explicit --sdk-wheel wins over ref-built - if sdkWheel != "" { + switch { + case sdkWheel != "": result.SDKWheel = sdkWheel result.SDKVersion = fmt.Sprintf("wheel:%s", filepath.Base(sdkWheel)) if sdkVersion != "" { result.SDKPatchVersion = sdkVersion result.SDKVersion = sdkVersion } - } else if wheel != "" { + case wheel != "": result.SDKWheel = wheel result.SDKVersion = version if sdkVersion != "" { result.SDKPatchVersion = sdkVersion result.SDKVersion = sdkVersion } - } else { + default: // Resolve SDK version from PyPI or explicit flag sdkVer, err := resolveSDKVersion(sdkVersion, manifestDefaults) if err != nil { @@ -168,7 +169,7 @@ func resolveLatestCogVersion() (string, error) { setGitHubAuth(req) client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL is constructed from a known constant if err != nil { return "", fmt.Errorf("fetching GitHub releases: %w", err) } @@ -208,7 
+209,7 @@ func resolveLatestPyPIVersion() (string, error) { req.Header.Set("Accept", "application/json") client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL is a known PyPI API constant if err != nil { return "", fmt.Errorf("fetching PyPI: %w", err) } @@ -266,7 +267,7 @@ func downloadCogBinary(tag string) (dest string, err error) { return "", fmt.Errorf("cannot determine home directory: %w", err) } baseDir := filepath.Join(home, ".cache", "cog-harness", "bin") - if err := os.MkdirAll(baseDir, 0755); err != nil { + if err := os.MkdirAll(baseDir, 0o755); err != nil { return "", fmt.Errorf("creating bin cache dir: %w", err) } tmpDir, err := os.MkdirTemp(baseDir, "cog-bin-*") @@ -275,7 +276,7 @@ func downloadCogBinary(tag string) (dest string, err error) { } defer func() { if err != nil { - os.RemoveAll(tmpDir) + _ = os.RemoveAll(tmpDir) } }() @@ -289,7 +290,7 @@ func downloadCogBinary(tag string) (dest string, err error) { setGitHubAuth(req) client := &http.Client{Timeout: 120 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL is constructed from GitHub releases if err != nil { return "", fmt.Errorf("downloading cog binary: %w", err) } @@ -299,13 +300,13 @@ func downloadCogBinary(tag string) (dest string, err error) { return "", fmt.Errorf("download returned %d", resp.StatusCode) } - f, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY, 0755) + f, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY, 0o755) if err != nil { return "", fmt.Errorf("creating binary file: %w", err) } if _, copyErr := io.Copy(f, resp.Body); copyErr != nil { - f.Close() + _ = f.Close() return "", fmt.Errorf("writing binary: %w", copyErr) } @@ -336,7 +337,7 @@ func verifyDownloadedBinary(tag, assetName, dest string) error { setGitHubAuth(req) client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL 
is constructed from GitHub releases if err != nil { return fmt.Errorf("downloading checksum file: %w", err) } @@ -369,7 +370,7 @@ func verifyDownloadedBinary(tag, assetName, dest string) error { } func parseChecksum(content, assetName string) (string, error) { - for _, line := range strings.Split(content, "\n") { + for line := range strings.SplitSeq(content, "\n") { line = strings.TrimSpace(line) if line == "" { continue @@ -470,7 +471,7 @@ func buildCogFromRef(ref string) (string, string, string, error) { // Build SDK wheel wheelDir := filepath.Join(tmpDir, "dist") - if err := os.MkdirAll(wheelDir, 0755); err != nil { + if err := os.MkdirAll(wheelDir, 0o755); err != nil { return "", "", "", fmt.Errorf("creating wheel dir: %w", err) } diff --git a/tools/test-harness/internal/runner/runner.go b/tools/test-harness/internal/runner/runner.go index d36c2bf9e3..406974f195 100644 --- a/tools/test-harness/internal/runner/runner.go +++ b/tools/test-harness/internal/runner/runner.go @@ -14,9 +14,11 @@ import ( "runtime" "sort" "strings" + "sync" "time" "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" "github.com/replicate/cog/tools/test-harness/internal/manifest" "github.com/replicate/cog/tools/test-harness/internal/patcher" @@ -26,6 +28,95 @@ import ( const openapiSchemaLabel = "run.cog.openapi_schema" +// maxPrefixBuf is the maximum size of the prefixWriter line buffer before +// it is force-flushed. This prevents unbounded memory growth when a +// subprocess writes very long lines without newlines (e.g. progress bars). +const maxPrefixBuf = 64 * 1024 // 64 KiB + +// prefixWriter wraps an io.Writer and prepends a prefix to each line. +// Partial lines (no trailing newline) are buffered until a newline arrives +// or the buffer exceeds maxPrefixBuf. 
+type prefixWriter struct { + prefix string + dest io.Writer + mu sync.Mutex + buf []byte +} + +func newPrefixWriter(dest io.Writer, modelName string) *prefixWriter { + return &prefixWriter{ + prefix: fmt.Sprintf("[%-20s] ", modelName), + dest: dest, + } +} + +func (pw *prefixWriter) Write(p []byte) (int, error) { + pw.mu.Lock() + defer pw.mu.Unlock() + + total := len(p) + pw.buf = append(pw.buf, p...) + + for { + idx := bytes.IndexByte(pw.buf, '\n') + if idx < 0 { + break + } + line := pw.buf[:idx] + pw.buf = pw.buf[idx+1:] + if _, err := fmt.Fprintf(pw.dest, "%s%s\n", pw.prefix, line); err != nil { + return total, err + } + } + + // Force-flush if the buffer has grown too large (e.g. no newlines in output). + if len(pw.buf) > maxPrefixBuf { + if _, err := fmt.Fprintf(pw.dest, "%s%s\n", pw.prefix, pw.buf); err != nil { + return total, err + } + pw.buf = pw.buf[:0] + } + + return total, nil +} + +// Flush writes any remaining buffered content (partial line without trailing newline). +func (pw *prefixWriter) Flush() { + pw.mu.Lock() + defer pw.mu.Unlock() + + if len(pw.buf) > 0 { + _, _ = fmt.Fprintf(pw.dest, "%s%s\n", pw.prefix, pw.buf) + pw.buf = nil + } +} + +// modelLoggers returns stdout/stderr writers for status logging only. +// In parallel mode, output is prefixed with the model name. +// In sequential mode, output streams directly to the terminal. +// Use modelOutput when captured output is needed for error reporting. +func (r *Runner) modelLoggers(modelName string) (logw io.Writer, flush func()) { + if r.opts.Parallel { + pw := newPrefixWriter(os.Stderr, modelName) + return pw, pw.Flush + } + return os.Stderr, func() {} +} + +// modelOutput returns stdout/stderr writers for a model that also capture +// output into a buffer for error reporting. +// In parallel mode, output is prefixed with the model name. +// In sequential mode, output streams directly to the terminal. 
+func (r *Runner) modelOutput(modelName string) (stdout, stderr io.Writer, capture *bytes.Buffer, flush func()) { + var buf bytes.Buffer + if r.opts.Parallel { + pw := newPrefixWriter(os.Stderr, modelName) + w := io.MultiWriter(pw, &buf) + return w, w, &buf, pw.Flush + } + return io.MultiWriter(os.Stdout, &buf), io.MultiWriter(os.Stderr, &buf), &buf, func() {} +} + // Options configures a Runner. type Options struct { CogBinary string @@ -34,14 +125,17 @@ type Options struct { FixturesDir string CleanImages bool KeepOutputs bool + Parallel bool // Prefix output lines with model name (for parallel execution) } -// Runner orchestrates the test lifecycle +// Runner orchestrates the test lifecycle. +// It is safe to call RunModel, BuildModel, and CompareSchema concurrently +// from multiple goroutines. type Runner struct { opts Options fixturesDir string workDir string - clonedRepos map[string]string + cloneGroup singleflight.Group // deduplicates concurrent clones of the same repo } // New creates a new Runner @@ -65,7 +159,7 @@ func New(opts Options) (*Runner, error) { return nil, fmt.Errorf("cannot determine home directory for work dir (set $HOME): %w", err) } baseDir := filepath.Join(home, ".cache", "cog-harness") - if err := os.MkdirAll(baseDir, 0755); err != nil { + if err := os.MkdirAll(baseDir, 0o755); err != nil { return nil, fmt.Errorf("creating harness cache dir: %w", err) } workDir, err := os.MkdirTemp(baseDir, "run-*") @@ -77,7 +171,6 @@ func New(opts Options) (*Runner, error) { opts: opts, fixturesDir: fixturesDir, workDir: workDir, - clonedRepos: make(map[string]string), }, nil } @@ -144,6 +237,8 @@ func (r *Runner) Cleanup() error { // RunModel runs all tests for a single model func (r *Runner) RunModel(ctx context.Context, model manifest.Model) *report.ModelResult { + logw, flush := r.modelLoggers(model.Name) + result := &report.ModelResult{ Name: model.Name, Passed: true, @@ -161,46 +256,71 @@ func (r *Runner) RunModel(ctx context.Context, model 
manifest.Model) *report.Mod } // Prepare model + _, _ = fmt.Fprintf(logw, "=== Preparing %s...\n", model.Name) modelDir, err := r.prepareModel(ctx, model) if err != nil { result.Passed = false result.Error = fmt.Sprintf("Preparation failed: %v", err) + flush() return result } // Build + _, _ = fmt.Fprintf(logw, "=== Building %s (timeout %ds)...\n", model.Name, model.Timeout) buildStart := time.Now() if err := r.buildModel(ctx, modelDir, model); err != nil { result.Passed = false result.BuildDuration = time.Since(buildStart).Seconds() result.Error = fmt.Sprintf("Build failed: %v", err) + _, _ = fmt.Fprintf(logw, "=== Build FAILED after %.1fs\n", result.BuildDuration) + flush() return result } result.BuildDuration = time.Since(buildStart).Seconds() + _, _ = fmt.Fprintf(logw, "=== Build complete (%.1fs)\n", result.BuildDuration) // Run train tests - for _, tc := range model.TrainTests { + for i, tc := range model.TrainTests { + desc := tc.Description + if desc == "" { + desc = "train" + } + _, _ = fmt.Fprintf(logw, "=== Train test %d/%d: %s (timeout %ds)...\n", i+1, len(model.TrainTests), desc, model.Timeout) tr := r.runTrainTest(ctx, modelDir, model, tc) result.TrainResults = append(result.TrainResults, tr) - if !tr.Passed { + if tr.Passed { + _, _ = fmt.Fprintf(logw, "=== Train test %d/%d PASSED (%.1fs)\n", i+1, len(model.TrainTests), tr.DurationSec) + } else { + _, _ = fmt.Fprintf(logw, "=== Train test %d/%d FAILED (%.1fs)\n", i+1, len(model.TrainTests), tr.DurationSec) result.Passed = false } } // Run predict tests - for _, tc := range model.Tests { + for i, tc := range model.Tests { + desc := tc.Description + if desc == "" { + desc = "predict" + } + _, _ = fmt.Fprintf(logw, "=== Predict test %d/%d: %s (timeout %ds)...\n", i+1, len(model.Tests), desc, model.Timeout) tr := r.runPredictTest(ctx, modelDir, model, tc) result.TestResults = append(result.TestResults, tr) - if !tr.Passed { + if tr.Passed { + _, _ = fmt.Fprintf(logw, "=== Predict test %d/%d PASSED 
(%.1fs)\n", i+1, len(model.Tests), tr.DurationSec) + } else { + _, _ = fmt.Fprintf(logw, "=== Predict test %d/%d FAILED (%.1fs)\n", i+1, len(model.Tests), tr.DurationSec) result.Passed = false } } + flush() return result } // BuildModel builds a model image only func (r *Runner) BuildModel(ctx context.Context, model manifest.Model) *report.ModelResult { + logw, flush := r.modelLoggers(model.Name) + result := &report.ModelResult{ Name: model.Name, Passed: true, @@ -218,23 +338,30 @@ func (r *Runner) BuildModel(ctx context.Context, model manifest.Model) *report.M } // Prepare model + _, _ = fmt.Fprintf(logw, "=== Preparing %s...\n", model.Name) modelDir, err := r.prepareModel(ctx, model) if err != nil { result.Passed = false result.Error = fmt.Sprintf("Preparation failed: %v", err) + flush() return result } // Build + _, _ = fmt.Fprintf(logw, "=== Building %s (timeout %ds)...\n", model.Name, model.Timeout) buildStart := time.Now() if err := r.buildModel(ctx, modelDir, model); err != nil { result.Passed = false result.BuildDuration = time.Since(buildStart).Seconds() result.Error = fmt.Sprintf("Build failed: %v", err) + _, _ = fmt.Fprintf(logw, "=== Build FAILED after %.1fs\n", result.BuildDuration) + flush() return result } result.BuildDuration = time.Since(buildStart).Seconds() + _, _ = fmt.Fprintf(logw, "=== Build complete (%.1fs)\n", result.BuildDuration) + flush() return result } @@ -258,8 +385,8 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor // Always clean up schema comparison images when done defer func() { - exec.Command("docker", "rmi", "-f", staticTag).Run() - exec.Command("docker", "rmi", "-f", runtimeTag).Run() + _ = exec.Command("docker", "rmi", "-f", staticTag).Run() + _ = exec.Command("docker", "rmi", "-f", runtimeTag).Run() }() staticDir := filepath.Join(r.workDir, fmt.Sprintf("schema-static-%s", model.Name)) @@ -286,7 +413,7 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor 
runtimeStart := time.Now() g.Go(func() error { - staticErr = r.buildModelWithEnv(ctx, staticDir, model, staticTag, map[string]string{"COG_STATIC_SCHEMA": "1"}) + staticErr = r.quietBuildModelWithEnv(ctx, staticDir, model, staticTag, map[string]string{"COG_STATIC_SCHEMA": "1"}) if staticErr != nil { return nil // Don't fail the group, we'll check errors after } @@ -296,7 +423,7 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor }) g.Go(func() error { - runtimeErr = r.buildModelWithEnv(ctx, runtimeDir, model, runtimeTag, map[string]string{}) + runtimeErr = r.quietBuildModelWithEnv(ctx, runtimeDir, model, runtimeTag, map[string]string{}) if runtimeErr != nil { return nil } @@ -307,7 +434,7 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor if err := g.Wait(); err != nil { result.Passed = false - result.Error = fmt.Sprintf("context cancelled: %v", err) + result.Error = fmt.Sprintf("context canceled: %v", err) return result } @@ -349,11 +476,15 @@ func (r *Runner) CompareSchema(ctx context.Context, model manifest.Model) *repor return result } - // Compare - diff := jsonDiff(staticJSON, runtimeJSON) - if diff != "" { + // Compare and classify differences + cmp := jsonCompare(staticJSON, runtimeJSON) + + if len(cmp.Real) > 0 { result.Passed = false - result.Diff = diff + result.Diff = formatDiffHunks(cmp.Real) + } + if len(cmp.Expected) > 0 { + result.ExpectedDiff = formatDiffHunks(cmp.Expected) } return result @@ -377,12 +508,20 @@ func (r *Runner) prepareModel(ctx context.Context, model manifest.Model) (string } modelDir = dest } else { - // Clone repo + // Clone repo (shared cache, thread-safe) repoDir, err := r.cloneRepo(ctx, model.Repo) if err != nil { return "", err } - modelDir = filepath.Join(repoDir, model.Path) + + // Each model gets its own copy so that setup commands (e.g. + // select.sh) don't clobber each other when running in parallel. 
+ srcDir := filepath.Join(repoDir, model.Path) + dest := filepath.Join(r.workDir, fmt.Sprintf("model-%s", model.Name)) + if err := copyDir(srcDir, dest); err != nil { + return "", fmt.Errorf("copying repo for model %s: %w", model.Name, err) + } + modelDir = dest } // Run setup commands (e.g. script/select.sh to generate cog.yaml) @@ -390,9 +529,14 @@ func (r *Runner) prepareModel(ctx context.Context, model manifest.Model) (string return "", fmt.Errorf("running setup commands: %w", err) } - if _, err := os.Stat(filepath.Join(modelDir, "cog.yaml")); err != nil { + cogYAMLPath := filepath.Join(modelDir, "cog.yaml") + info, err := os.Stat(cogYAMLPath) + if err != nil { return "", fmt.Errorf("no cog.yaml in %s (did setup commands run correctly?)", modelDir) } + if info.Size() == 0 { + return "", fmt.Errorf("cog.yaml in %s is empty (setup commands may have failed silently — check that tools like yq are installed)", modelDir) + } // Patch cog.yaml sdkVersion := model.SDKVersion @@ -406,47 +550,88 @@ func (r *Runner) prepareModel(ctx context.Context, model manifest.Model) (string return modelDir, nil } +// checkRequiredTools verifies that all tools listed in requires_tools are +// available on PATH. Returns a descriptive error naming each missing +// tool; no install instructions are generated. +func checkRequiredTools(tools []string) error { + var missing []string + for _, tool := range tools { + if _, err := exec.LookPath(tool); err != nil { + missing = append(missing, tool) + } + } + if len(missing) == 0 { + return nil + } + return fmt.Errorf("required tool(s) not found on PATH: %s", strings.Join(missing, ", ")) +} + // runSetupCommands executes the model's setup commands in the model directory. // Setup commands run after clone/copy but before cog.yaml validation and patching. // This is used for models that need preparation steps like generating cog.yaml // from templates (e.g. "script/select.sh dev" in replicate/cog-flux).
func (r *Runner) runSetupCommands(ctx context.Context, modelDir string, model manifest.Model) error { + // Check required tools before running any setup commands + if err := checkRequiredTools(model.RequiresTools); err != nil { + return err + } + + stdout, stderr, capture, flush := r.modelOutput(model.Name) + for _, cmdStr := range model.Setup { - fmt.Printf(" Running setup: %s\n", cmdStr) - cmd := exec.CommandContext(ctx, "sh", "-c", cmdStr) + _, _ = fmt.Fprintf(stderr, " Running setup: %s\n", cmdStr) + // Use bash with strict mode so any failing command in the + // script (e.g. a missing yq binary) is caught immediately + // rather than silently producing an empty/invalid cog.yaml. + // We use bash (not sh) because dash does not support pipefail. + cmd := exec.CommandContext(ctx, "bash", "-euo", "pipefail", "-c", cmdStr) cmd.Dir = modelDir - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr cmd.Env = os.Environ() for k, v := range model.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, os.ExpandEnv(v))) } + cmd.Stdout = stdout + cmd.Stderr = stderr if err := cmd.Run(); err != nil { - return fmt.Errorf("setup command %q failed: %w", cmdStr, err) + flush() + // Truncate captured output to avoid unwieldy error messages, + // matching the pattern used in buildModelWithEnv. + output := capture.String() + const maxTail = 2000 + if len(output) > maxTail { + output = "...\n" + output[len(output)-maxTail:] + } + return fmt.Errorf("setup command %q failed: %w\n%s", cmdStr, err, output) } } + flush() return nil } +// cloneRepo clones a repo once and caches the result. Thread-safe. +// Uses singleflight to deduplicate concurrent clones of the same repo +// without holding a mutex during the (potentially slow) git clone. 
func (r *Runner) cloneRepo(ctx context.Context, repo string) (string, error) { - if dir, ok := r.clonedRepos[repo]; ok { - return dir, nil - } + result, err, _ := r.cloneGroup.Do(repo, func() (any, error) { + dest := filepath.Join(r.workDir, strings.ReplaceAll(repo, "/", "--")) - dest := filepath.Join(r.workDir, strings.ReplaceAll(repo, "/", "--")) + // Remove if exists + _ = os.RemoveAll(dest) - // Remove if exists - os.RemoveAll(dest) + url := fmt.Sprintf("https://github.com/%s.git", repo) + cmd := exec.CommandContext(ctx, "git", "clone", "--depth=1", url, dest) + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("cloning %s: %w\n%s", repo, err, stderr.String()) + } - url := fmt.Sprintf("https://github.com/%s.git", repo) - cmd := exec.CommandContext(ctx, "git", "clone", "--depth=1", url, dest) - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("cloning %s: %w", repo, err) + return dest, nil + }) + if err != nil { + return "", err } - - r.clonedRepos[repo] = dest - return dest, nil + return result.(string), nil } func (r *Runner) buildModel(ctx context.Context, modelDir string, model manifest.Model) error { @@ -454,7 +639,22 @@ func (r *Runner) buildModel(ctx context.Context, modelDir string, model manifest return r.buildModelWithEnv(ctx, modelDir, model, imageTag, nil) } +// buildModelWithEnv builds a model image, streaming output to the terminal +// (or prefixed in parallel mode) for real-time progress visibility. func (r *Runner) buildModelWithEnv(ctx context.Context, modelDir string, model manifest.Model, imageTag string, extraEnv map[string]string) error { + _, stderr, capture, flush := r.modelOutput(model.Name) + return r.doBuild(ctx, modelDir, model, imageTag, extraEnv, stderr, capture, flush) +} + +// quietBuildModelWithEnv builds a model image without streaming output to +// the terminal. Output is captured for inclusion in error messages on failure. 
+// Used by CompareSchema where build logs would interleave with diff results. +func (r *Runner) quietBuildModelWithEnv(ctx context.Context, modelDir string, model manifest.Model, imageTag string, extraEnv map[string]string) error { + var capture bytes.Buffer + return r.doBuild(ctx, modelDir, model, imageTag, extraEnv, &capture, &capture, func() {}) +} + +func (r *Runner) doBuild(ctx context.Context, modelDir string, model manifest.Model, imageTag string, extraEnv map[string]string, output io.Writer, capture *bytes.Buffer, flush func()) error { // Set timeout timeout := model.Timeout if timeout == 0 { @@ -463,7 +663,11 @@ func (r *Runner) buildModelWithEnv(ctx context.Context, modelDir string, model m ctx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, r.opts.CogBinary, "build", "-t", imageTag) + buildArgs := []string{"build", "-t", imageTag} + if model.SkipSchemaValidation { + buildArgs = append(buildArgs, "--skip-schema-validation") + } + cmd := exec.CommandContext(ctx, r.opts.CogBinary, buildArgs...) cmd.Dir = modelDir cmd.Env = os.Environ() if r.opts.SDKWheel != "" { @@ -488,13 +692,18 @@ func (r *Runner) buildModelWithEnv(ctx context.Context, modelDir string, model m cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) } - // Stream build output to stderr in real-time so the user can see progress, - // while also capturing it for error reporting if the build fails. - var outputBuf bytes.Buffer - cmd.Stdout = io.MultiWriter(os.Stderr, &outputBuf) - cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf) - if err := cmd.Run(); err != nil { - return fmt.Errorf("%w\n%s", err, outputBuf.String()) + cmd.Stdout = output + cmd.Stderr = output + err := cmd.Run() + flush() + if err != nil { + // Include the last portion of build output for context. 
+ out := capture.String() + const maxTail = 2000 + if len(out) > maxTail { + out = "...\n" + out[len(out)-maxTail:] + } + return fmt.Errorf("%w\n%s", err, out) } return nil } @@ -526,8 +735,15 @@ func (r *Runner) runCogTest(ctx context.Context, modelDir string, model manifest start := time.Now() - // Build command - args := []string{command} + // Set timeout + timeout := model.Timeout + if timeout == 0 { + timeout = 300 + } + + // Build command — pass setup-timeout matching the model timeout so + // cog predict doesn't kill the container during model weight downloads. + args := []string{command, "--setup-timeout", fmt.Sprintf("%d", timeout)} keys := make([]string, 0, len(tc.Inputs)) for k := range tc.Inputs { keys = append(keys, k) @@ -539,12 +755,6 @@ func (r *Runner) runCogTest(ctx context.Context, modelDir string, model manifest args = append(args, "-i", fmt.Sprintf("%s=%s", key, resolved)) } - // Set timeout - timeout := model.Timeout - if timeout == 0 { - timeout = 300 - } - ctx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) defer cancel() @@ -634,7 +844,7 @@ func (r *Runner) resolveInput(value any) string { func extractOutput(stdout, stderr, modelDir string) string { // For file outputs (e.g. images), cog writes the file to CWD and prints // "Written output to: " on stderr. Check stderr for this pattern. 
- for _, line := range strings.Split(stderr, "\n") { + for line := range strings.SplitSeq(stderr, "\n") { if strings.Contains(line, "Written output to:") { parts := strings.SplitN(line, "Written output to:", 2) if len(parts) == 2 { @@ -671,7 +881,7 @@ func copyDir(src, dst string) error { dstPath := filepath.Join(dst, rel) if d.IsDir() { - return os.MkdirAll(dstPath, 0755) + return os.MkdirAll(dstPath, 0o755) } data, err := os.ReadFile(path) @@ -688,21 +898,137 @@ func copyDir(src, dst string) error { }) } -func jsonDiff(a, b map[string]any) string { - var lines []string - diffRecursive(a, b, "$", &lines) - if len(lines) == 0 { +// diffResult contains classified diff hunks from a schema comparison. +type diffResult struct { + Real []diffHunk // Genuine mismatches that indicate a bug + Expected []diffHunk // Known limitations (dynamic descriptions, training schemas) +} + +// jsonCompare compares two JSON schemas and classifies differences as +// "real" (genuine mismatches) or "expected" (known static-gen limitations). +// +// Expected differences: +// - Training schemas/paths present in static but absent in runtime +// (runtime only generates predict schema) +// - Descriptions present in runtime but absent in static +// (static can't resolve dynamically-constructed descriptions) +func jsonCompare(a, b map[string]any) diffResult { + var hunks []diffHunk + collectDiffs(a, b, "$", &hunks) + + var result diffResult + for _, h := range hunks { + if isExpectedDiff(h) { + result.Expected = append(result.Expected, h) + } else { + result.Real = append(result.Real, h) + } + } + return result +} + +// isExpectedDiff returns true if a diff hunk represents a known limitation +// of static schema generation rather than a real bug. +func isExpectedDiff(h diffHunk) bool { + // Static generates training schemas; runtime only generates predict. + // Training-related schemas/paths are expected to be missing in runtime. 
+ if h.Kind == "missing_in_runtime" { + if strings.Contains(h.Path, "Training") || + strings.Contains(h.Path, "trainings") { + return true + } + } + + // Static can't resolve dynamically-constructed descriptions (e.g. + // f-strings with conditional logic in class methods). + if h.Kind == "missing_in_static" && strings.HasSuffix(h.Path, ".description") { + return true + } + + return false +} + +// formatDiffHunks formats a slice of hunks as a unified-diff-style string. +func formatDiffHunks(hunks []diffHunk) string { + if len(hunks) == 0 { return "" } - return strings.Join(lines, "\n") + var buf strings.Builder + buf.WriteString("--- static schema\n") + buf.WriteString("+++ runtime schema\n") + for _, h := range hunks { + buf.WriteString("\n") + buf.WriteString(h.String()) + } + return buf.String() +} + +// jsonDiff produces a unified-diff-style comparison between two JSON objects. +// Only includes "real" differences (not expected/known limitations). +func jsonDiff(a, b map[string]any) string { + result := jsonCompare(a, b) + return formatDiffHunks(result.Real) +} + +// diffHunk represents a single difference between two JSON values. 
+type diffHunk struct { + Path string + Kind string // "missing_in_static", "missing_in_runtime", "changed", "type_mismatch", "array_length" + StaticV any // value in static (nil if missing) + RuntimeV any // value in runtime (nil if missing) +} + +func (h diffHunk) String() string { + var buf strings.Builder + fmt.Fprintf(&buf, "@@ %s @@\n", h.Path) + switch h.Kind { + case "missing_in_runtime": + // Present in static, absent in runtime + for _, line := range prettyLines(h.StaticV) { + fmt.Fprintf(&buf, "-%s\n", line) + } + case "missing_in_static": + // Absent in static, present in runtime + for _, line := range prettyLines(h.RuntimeV) { + fmt.Fprintf(&buf, "+%s\n", line) + } + case "changed", "type_mismatch": + for _, line := range prettyLines(h.StaticV) { + fmt.Fprintf(&buf, "-%s\n", line) + } + for _, line := range prettyLines(h.RuntimeV) { + fmt.Fprintf(&buf, "+%s\n", line) + } + case "array_length": + for _, line := range prettyLines(h.StaticV) { + fmt.Fprintf(&buf, "-%s\n", line) + } + for _, line := range prettyLines(h.RuntimeV) { + fmt.Fprintf(&buf, "+%s\n", line) + } + } + return buf.String() +} + +// prettyLines returns a compact JSON representation split into lines. 
+func prettyLines(v any) []string { + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + return []string{fmt.Sprintf("%v", v)} + } + return strings.Split(string(data), "\n") } -func diffRecursive(a, b any, path string, lines *[]string) { +func collectDiffs(a, b any, path string, hunks *[]diffHunk) { if a == nil && b == nil { return } - if a == nil || b == nil { - *lines = append(*lines, fmt.Sprintf(" %s: one side is nil", path)) + if a == nil { + *hunks = append(*hunks, diffHunk{Path: path, Kind: "missing_in_static", RuntimeV: b}) + return + } + if b == nil { + *hunks = append(*hunks, diffHunk{Path: path, Kind: "missing_in_runtime", StaticV: a}) return } @@ -710,9 +1036,10 @@ func diffRecursive(a, b any, path string, lines *[]string) { case map[string]any: bv, ok := b.(map[string]any) if !ok { - *lines = append(*lines, fmt.Sprintf(" %s: type mismatch (object vs %T)", path, b)) + *hunks = append(*hunks, diffHunk{Path: path, Kind: "type_mismatch", StaticV: a, RuntimeV: b}) return } + // Collect all keys, sorted for deterministic output allKeys := make(map[string]bool) for k := range av { allKeys[k] = true @@ -720,33 +1047,42 @@ func diffRecursive(a, b any, path string, lines *[]string) { for k := range bv { allKeys[k] = true } + sortedKeys := make([]string, 0, len(allKeys)) for k := range allKeys { + sortedKeys = append(sortedKeys, k) + } + sort.Strings(sortedKeys) + + for _, k := range sortedKeys { childPath := fmt.Sprintf("%s.%s", path, k) - if _, ok := av[k]; !ok { - *lines = append(*lines, fmt.Sprintf(" %s: missing in static", childPath)) - } else if _, ok := bv[k]; !ok { - *lines = append(*lines, fmt.Sprintf(" %s: missing in runtime", childPath)) - } else { - diffRecursive(av[k], bv[k], childPath, lines) + aVal, aOK := av[k] + bVal, bOK := bv[k] + switch { + case !aOK: + *hunks = append(*hunks, diffHunk{Path: childPath, Kind: "missing_in_static", RuntimeV: bVal}) + case !bOK: + *hunks = append(*hunks, diffHunk{Path: childPath, Kind: 
"missing_in_runtime", StaticV: aVal}) + default: + collectDiffs(aVal, bVal, childPath, hunks) } } case []any: bv, ok := b.([]any) if !ok { - *lines = append(*lines, fmt.Sprintf(" %s: type mismatch (array vs %T)", path, b)) + *hunks = append(*hunks, diffHunk{Path: path, Kind: "type_mismatch", StaticV: a, RuntimeV: b}) return } if len(av) != len(bv) { - *lines = append(*lines, fmt.Sprintf(" %s: array length mismatch (%d vs %d)", path, len(av), len(bv))) + *hunks = append(*hunks, diffHunk{Path: path, Kind: "array_length", StaticV: a, RuntimeV: b}) return } for i := range av { childPath := fmt.Sprintf("%s[%d]", path, i) - diffRecursive(av[i], bv[i], childPath, lines) + collectDiffs(av[i], bv[i], childPath, hunks) } default: - if a != b { - *lines = append(*lines, fmt.Sprintf(" %s: value mismatch (%v vs %v)", path, a, b)) + if fmt.Sprint(a) != fmt.Sprint(b) { + *hunks = append(*hunks, diffHunk{Path: path, Kind: "changed", StaticV: a, RuntimeV: b}) } } } diff --git a/tools/test-harness/internal/runner/runner_test.go b/tools/test-harness/internal/runner/runner_test.go index cdcc9f72c6..8491eb3e9f 100644 --- a/tools/test-harness/internal/runner/runner_test.go +++ b/tools/test-harness/internal/runner/runner_test.go @@ -2,6 +2,7 @@ package runner import ( "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -38,6 +39,157 @@ func TestSafeSubpathRejectsAbsoluteOutsidePath(t *testing.T) { assert.Contains(t, err.Error(), "must be relative") } +func TestJsonDiffUnifiedFormat(t *testing.T) { + tests := []struct { + name string + static map[string]any + runtime map[string]any + contains []string // substrings the diff must contain + empty bool // expect no diff + }{ + { + name: "identical schemas produce no diff", + static: map[string]any{"a": 1, "b": "hello"}, + runtime: map[string]any{"a": 1, "b": "hello"}, + empty: true, + }, + { + name: "missing key in static shows + lines", + static: map[string]any{"a": 1}, + runtime: map[string]any{"a": 1, "b": 
"new"}, + contains: []string{"--- static schema", "+++ runtime schema", "@@ $.b @@", "+\"new\""}, + }, + { + name: "missing key in runtime shows - lines", + static: map[string]any{"a": 1, "b": "old"}, + runtime: map[string]any{"a": 1}, + contains: []string{"@@ $.b @@", "-\"old\""}, + }, + { + name: "changed value shows - and + lines", + static: map[string]any{"a": "foo"}, + runtime: map[string]any{"a": "bar"}, + contains: []string{"@@ $.a @@", "-\"foo\"", "+\"bar\""}, + }, + { + name: "nested object diff shows full path", + static: map[string]any{"outer": map[string]any{"inner": 1}}, + runtime: map[string]any{"outer": map[string]any{"inner": 2}}, + contains: []string{"@@ $.outer.inner @@", "-1", "+2"}, + }, + { + name: "array length mismatch shows both arrays", + static: map[string]any{"arr": []any{"a", "b"}}, + runtime: map[string]any{"arr": []any{"a", "b", "c"}}, + contains: []string{"@@ $.arr @@", "-[", "+[", "+ \"c\""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + diff := jsonDiff(tt.static, tt.runtime) + if tt.empty { + assert.Empty(t, diff) + return + } + require.NotEmpty(t, diff) + for _, substr := range tt.contains { + assert.True(t, strings.Contains(diff, substr), + "diff should contain %q but got:\n%s", substr, diff) + } + }) + } +} + +func TestJsonCompareClassifiesExpectedDiffs(t *testing.T) { + t.Run("training schemas missing in runtime are expected", func(t *testing.T) { + static := map[string]any{ + "components": map[string]any{ + "schemas": map[string]any{ + "Input": map[string]any{"type": "object"}, + "TrainingInput": map[string]any{"type": "object"}, + }, + }, + "paths": map[string]any{ + "/predictions": map[string]any{"post": true}, + "/trainings": map[string]any{"post": true}, + "/trainings/{training_id}/cancel": map[string]any{"post": true}, + }, + } + runtime := map[string]any{ + "components": map[string]any{ + "schemas": map[string]any{ + "Input": map[string]any{"type": "object"}, + }, + }, + "paths": 
map[string]any{ + "/predictions": map[string]any{"post": true}, + }, + } + + result := jsonCompare(static, runtime) + assert.Empty(t, result.Real, "training schema diffs should not be real failures") + assert.Len(t, result.Expected, 3, "should have 3 expected diffs (TrainingInput, /trainings, /trainings/.../cancel)") + }) + + t.Run("missing description in static is expected", func(t *testing.T) { + static := map[string]any{ + "properties": map[string]any{ + "steps": map[string]any{ + "type": "integer", + "default": float64(4), + }, + }, + } + runtime := map[string]any{ + "properties": map[string]any{ + "steps": map[string]any{ + "type": "integer", + "default": float64(4), + "description": "Number of denoising steps.", + }, + }, + } + + result := jsonCompare(static, runtime) + assert.Empty(t, result.Real, "missing description in static should not be a real failure") + assert.Len(t, result.Expected, 1) + }) + + t.Run("real diffs are not classified as expected", func(t *testing.T) { + static := map[string]any{ + "type": "integer", + "default": float64(4), + } + runtime := map[string]any{ + "type": "string", + "default": float64(4), + } + + result := jsonCompare(static, runtime) + assert.Len(t, result.Real, 1, "type mismatch should be a real failure") + assert.Empty(t, result.Expected) + }) + + t.Run("only expected diffs means jsonDiff returns empty", func(t *testing.T) { + static := map[string]any{ + "components": map[string]any{ + "schemas": map[string]any{ + "TrainingInput": map[string]any{"type": "object"}, + }, + }, + } + runtime := map[string]any{ + "components": map[string]any{ + "schemas": map[string]any{}, + }, + } + + diff := jsonDiff(static, runtime) + assert.Empty(t, diff, "jsonDiff should return empty when only expected diffs exist") + }) +} + func TestExtractOutput(t *testing.T) { tests := []struct { name string diff --git a/tools/test-harness/internal/validator/validator.go b/tools/test-harness/internal/validator/validator.go index f13b2fc6db..4482c5064a 
100644 --- a/tools/test-harness/internal/validator/validator.go +++ b/tools/test-harness/internal/validator/validator.go @@ -249,10 +249,3 @@ func getKeys(m map[string]any) []string { } return keys } - -func min(a, b int) int { - if a < b { - return a - } - return b -} diff --git a/tools/test-harness/manifest.yaml b/tools/test-harness/manifest.yaml index bdd8870023..b72cf62749 100644 --- a/tools/test-harness/manifest.yaml +++ b/tools/test-harness/manifest.yaml @@ -21,6 +21,13 @@ # Example: # setup: # - "script/select.sh dev" +# +# Required tools: +# Optional `requires_tools` list of CLI tools that must be on PATH for +# setup commands to work. The harness checks before running and fails +# fast with an error naming any missing tools (e.g. yq, envsubst). +# Example: +# requires_tools: ["yq"] defaults: sdk_version: "latest" # "latest" = newest stable from PyPI; or pin e.g. "0.16.12"