Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ cog build [flags]
--progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto")
--secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file'
--separate-weights Separate model weights from code in image layers
--skip-schema-validation Skip OpenAPI schema generation and validation
-t, --tag string A name for the built image in the form 'repository:tag'
--use-cog-base-image Use pre-built Cog base image for faster cold boots (default true)
--use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto")
Expand Down
1 change: 1 addition & 0 deletions docs/llms.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ require (
golang.org/x/sys v0.42.0
golang.org/x/term v0.41.0
google.golang.org/grpc v1.79.3
gopkg.in/yaml.v3 v3.0.1
)

require (
Expand Down Expand Up @@ -148,7 +149,6 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/gotestsum v1.12.2 // indirect
)

Expand Down
33 changes: 20 additions & 13 deletions pkg/cli/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ var buildDockerfileFile string
var buildUseCogBaseImage bool
var buildStrip bool
var buildPrecompile bool
var buildSkipSchemaValidation bool
var configFilename string

const useCogBaseImageFlagKey = "use-cog-base-image"
Expand Down Expand Up @@ -66,6 +67,7 @@ with 'cog push'.`,
addStripFlag(cmd)
addPrecompileFlag(cmd)
addConfigFlag(cmd)
addSkipSchemaValidationFlag(cmd)
cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'")
return cmd
}
Expand Down Expand Up @@ -194,22 +196,27 @@ func DetermineUseCogBaseImage(cmd *cobra.Command) *bool {
return useCogBaseImage
}

// addSkipSchemaValidationFlag registers the boolean --skip-schema-validation
// flag on cmd. The flag defaults to false and its value is stored in
// buildSkipSchemaValidation.
func addSkipSchemaValidationFlag(cmd *cobra.Command) {
	flagSet := cmd.Flags()
	flagSet.BoolVar(
		&buildSkipSchemaValidation,
		"skip-schema-validation",
		false,
		"Skip OpenAPI schema generation and validation",
	)
}

// buildOptionsFromFlags creates BuildOptions from the current CLI flag values.
// The imageName and annotations parameters vary by command and must be provided.
// The remaining fields are copied from the build* flag variables, except
// UseCogBaseImage, which is resolved through DetermineUseCogBaseImage(cmd), and
// OCIIndex, which comes from model.OCIIndexEnabled().
func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map[string]string) model.BuildOptions {
return model.BuildOptions{
ImageName: imageName,
Secrets: buildSecrets,
NoCache: buildNoCache,
SeparateWeights: buildSeparateWeights,
UseCudaBaseImage: buildUseCudaBaseImage,
ProgressOutput: buildProgressOutput,
SchemaFile: buildSchemaFile,
DockerfileFile: buildDockerfileFile,
UseCogBaseImage: DetermineUseCogBaseImage(cmd),
Strip: buildStrip,
Precompile: buildPrecompile,
Annotations: annotations,
OCIIndex: model.OCIIndexEnabled(),
ImageName: imageName,
Secrets: buildSecrets,
NoCache: buildNoCache,
SeparateWeights: buildSeparateWeights,
UseCudaBaseImage: buildUseCudaBaseImage,
ProgressOutput: buildProgressOutput,
SchemaFile: buildSchemaFile,
DockerfileFile: buildDockerfileFile,
UseCogBaseImage: DetermineUseCogBaseImage(cmd),
Strip: buildStrip,
Precompile: buildPrecompile,
Annotations: annotations,
OCIIndex: model.OCIIndexEnabled(),
SkipSchemaValidation: buildSkipSchemaValidation,
}
}
62 changes: 48 additions & 14 deletions tools/test-harness/cmd/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package cmd
import (
"context"
"fmt"
"strings"

"github.com/spf13/cobra"

"github.com/replicate/cog/tools/test-harness/internal/manifest"
"github.com/replicate/cog/tools/test-harness/internal/report"
"github.com/replicate/cog/tools/test-harness/internal/runner"
)
Expand All @@ -23,6 +23,10 @@ func newBuildCommand() *cobra.Command {
}

func runBuild(ctx context.Context) error {
if err := validateConcurrency(); err != nil {
return err
}

_, models, resolved, err := resolveSetup()
if err != nil {
return err
Expand All @@ -32,7 +36,13 @@ func runBuild(ctx context.Context) error {
fmt.Println("No models to build")
return nil
}
fmt.Printf("Building %d model(s)\n\n", len(models))

parallel := concurrency > 1 && len(models) > 1
if parallel {
fmt.Printf("Building %d model(s) with concurrency %d\n\n", len(models), concurrency)
} else {
fmt.Printf("Building %d model(s)\n\n", len(models))
}

// Create runner
r, err := runner.New(runner.Options{
Expand All @@ -41,33 +51,57 @@ func runBuild(ctx context.Context) error {
SDKWheel: resolved.SDKWheel,
CleanImages: cleanImages,
KeepOutputs: keepOutputs,
Parallel: parallel,
})
if err != nil {
return fmt.Errorf("creating runner: %w", err)
}
defer r.Cleanup()
defer func() { _ = r.Cleanup() }()

// Build models
var results []report.ModelResult
for _, model := range models {
fmt.Printf("Building %s...\n", model.Name)
result := r.BuildModel(ctx, model)
results = append(results, *result)
}
results := make([]report.ModelResult, len(models))

runModels(ctx, models, results, parallel,
func(ctx context.Context, model manifest.Model) *report.ModelResult {
return r.BuildModel(ctx, model)
},
func(index, total int, model manifest.Model) string {
if parallel {
return fmt.Sprintf(" [%d/%d] Building %s...\n", index, total, model.Name)
}
return fmt.Sprintf("Building %s...\n", model.Name)
},
func(index, total int, model manifest.Model, result *report.ModelResult) string {
if parallel {
switch {
case result.Passed:
return fmt.Sprintf(" [%d/%d] + %s (%.1fs)\n", index, total, model.Name, result.BuildDuration)
case result.Skipped:
return fmt.Sprintf(" [%d/%d] - %s (skipped: %s)\n", index, total, model.Name, result.SkipReason)
default:
return fmt.Sprintf(" [%d/%d] x %s FAILED\n", index, total, model.Name)
}
}
switch {
case result.Passed:
return fmt.Sprintf(" + %s built successfully (%.1fs)\n", model.Name, result.BuildDuration)
case result.Skipped:
return fmt.Sprintf(" - %s (skipped: %s)\n", model.Name, result.SkipReason)
default:
return fmt.Sprintf(" x %s FAILED\n", model.Name)
}
},
)

// Output results
report.ConsoleReport(results, resolved.SDKVersion, resolved.CogVersion)

// Check for failures
var failedNames []string
for _, r := range results {
if !r.Passed && !r.Skipped {
failedNames = append(failedNames, r.Name)
return formatFailureSummary("build", results)
}
}
if len(failedNames) > 0 {
return fmt.Errorf("%d build(s) failed: %s", len(failedNames), strings.Join(failedNames, ", "))
}

return nil
}
114 changes: 114 additions & 0 deletions tools/test-harness/cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package cmd

import (
"context"
"fmt"
"strings"
"sync"

"github.com/spf13/cobra"

"github.com/replicate/cog/tools/test-harness/internal/manifest"
"github.com/replicate/cog/tools/test-harness/internal/report"
"github.com/replicate/cog/tools/test-harness/internal/resolver"
)

Expand All @@ -21,6 +25,7 @@ var (
sdkWheel string
cleanImages bool
keepOutputs bool
concurrency int
)

// NewRootCommand creates the root command
Expand All @@ -46,6 +51,7 @@ It reads the same manifest.yaml format as the Python version.`,
rootCmd.PersistentFlags().StringVar(&sdkWheel, "sdk-wheel", "", "Path to pre-built SDK wheel")
rootCmd.PersistentFlags().BoolVar(&cleanImages, "clean-images", false, "Remove Docker images after run (default: keep them)")
rootCmd.PersistentFlags().BoolVar(&keepOutputs, "keep-outputs", false, "Preserve prediction outputs (images, files) in the work directory")
rootCmd.PersistentFlags().IntVar(&concurrency, "concurrency", 4, "Maximum number of models to build/test in parallel")

// Subcommands
rootCmd.AddCommand(newRunCommand())
Expand Down Expand Up @@ -77,3 +83,111 @@ func resolveSetup() (*manifest.Manifest, []manifest.Model, *resolver.Result, err
models := mf.FilterModels(modelFilter, noGPU, gpuOnly)
return mf, models, resolved, nil
}

// validateConcurrency rejects --concurrency values below 1.
// runModels uses the value as the capacity of its semaphore channel, so it
// has to be a positive worker count.
//
// It returns nil when the flag is usable, otherwise an error that names the
// offending value.
func validateConcurrency() error {
	if concurrency >= 1 {
		return nil
	}
	return fmt.Errorf("--concurrency must be at least 1, got %d", concurrency)
}

// modelAction is a function that processes a single model and returns a result.
// runModels dereferences the returned pointer when it stores the result, so an
// action must not return nil.
type modelAction[T any] func(ctx context.Context, model manifest.Model) *T

// statusPrinter formats a per-model status line after processing completes.
// runModels calls it with index set to the model's 1-based position, total set
// to the number of models, and result set to the pointer the action returned.
// The returned string is written with fmt.Print, which adds no newline.
type statusPrinter[T any] func(index, total int, model manifest.Model, result *T) string

// runModels runs action once per model and stores each result in results at
// the model's own index; the caller pre-allocates results with at least
// len(models) elements.
//
// Around every action it prints startLine(...) beforehand and statusLine(...)
// afterwards, passing the model's 1-based position and the total model count.
//
// When parallel is false the models are processed one after another, in order.
// When parallel is true one goroutine is started per model: a semaphore channel
// with capacity `concurrency` bounds how many of them run at once, and a mutex
// keeps the printed lines from interleaving. The function returns only after
// every model has been processed.
func runModels[T any](
	ctx context.Context,
	models []manifest.Model,
	results []T,
	parallel bool,
	action modelAction[T],
	startLine func(index, total int, model manifest.Model) string,
	statusLine statusPrinter[T],
) {
	total := len(models)

	// Sequential path: no goroutines or locking required.
	if !parallel {
		for idx, m := range models {
			pos := idx + 1
			fmt.Print(startLine(pos, total, m))
			res := action(ctx, m)
			results[idx] = *res
			fmt.Print(statusLine(pos, total, m, res))
		}
		return
	}

	slots := make(chan struct{}, concurrency)
	var (
		workers sync.WaitGroup
		printMu sync.Mutex
	)

	// say prints the line produced by format while holding printMu.
	say := func(format func() string) {
		printMu.Lock()
		defer printMu.Unlock()
		fmt.Print(format())
	}

	for idx, m := range models {
		workers.Go(func() {
			slots <- struct{}{}        // wait for a free slot
			defer func() { <-slots }() // hand the slot back

			pos := idx + 1
			say(func() string { return startLine(pos, total, m) })

			res := action(ctx, m)
			results[idx] = *res // each goroutine writes only its own index

			say(func() string { return statusLine(pos, total, m, res) })
		})
	}
	workers.Wait()
}

// formatFailureSummary builds an error describing every failed model in results.
//
// Models that passed or were skipped are ignored. Each failed model gets its
// own line: when the model carries a top-level Error, the first line of that
// error is appended; otherwise the first line of the message of every failed
// test and train case is listed beneath it. action (e.g. "build") is used in
// the "<n> <action>(s) failed:" header.
//
// It returns nil when results contains no failures, so callers may invoke it
// unconditionally instead of receiving a "0 ...(s) failed" error.
//
//nolint:gosec // G705: writes to strings.Builder, not an HTTP response — no XSS risk
func formatFailureSummary(action string, results []report.ModelResult) error {
	// Only the first line of an error or message is reported, to keep the
	// summary compact.
	firstLine := func(s string) string {
		line, _, _ := strings.Cut(s, "\n")
		return line
	}

	var b strings.Builder
	var failCount int
	for _, r := range results {
		if r.Passed || r.Skipped {
			continue
		}
		failCount++
		fmt.Fprintf(&b, "\n x %s", r.Name)

		if r.Error != "" {
			fmt.Fprintf(&b, ": %s", firstLine(r.Error))
			continue
		}

		// No top-level error: summarize the individual failed cases instead.
		for _, t := range r.TestResults {
			if !t.Passed {
				fmt.Fprintf(&b, "\n test %q: %s", t.Description, firstLine(t.Message))
			}
		}
		for _, t := range r.TrainResults {
			if !t.Passed {
				fmt.Fprintf(&b, "\n train %q: %s", t.Description, firstLine(t.Message))
			}
		}
	}

	if failCount == 0 {
		return nil
	}
	return fmt.Errorf("%d %s(s) failed:%s", failCount, action, b.String())
}
Loading
Loading