Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 13 additions & 48 deletions cmd/root/record.go
Original file line number Diff line number Diff line change
@@ -1,71 +1,36 @@
package root

import (
"fmt"
"log/slog"
"strings"
"time"

"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/fake"
"github.com/docker/docker-agent/pkg/recording"
)

// setupFakeProxy starts a fake proxy if fakeResponses is non-empty.
// streamDelayMs controls simulated streaming: 0 = disabled, >0 = delay in milliseconds between chunks.
// It returns a cleanup function that must be called when done (typically via defer).
// It configures the runtime config's ModelsGateway to point to the proxy.
func setupFakeProxy(fakeResponses string, streamDelayMs int, runConfig *config.RuntimeConfig) (cleanup func() error, err error) {
if fakeResponses == "" {
return func() error { return nil }, nil
}

// Normalize path by stripping .yaml suffix (go-vcr adds it automatically)
fakeResponses = strings.TrimSuffix(fakeResponses, ".yaml")

var opts []fake.ProxyOption
if streamDelayMs > 0 {
opts = append(opts,
fake.WithSimulateStream(true),
fake.WithStreamChunkDelay(time.Duration(streamDelayMs)*time.Millisecond),
)
}

proxyURL, cleanupFn, err := fake.StartProxy(fakeResponses, opts...)
proxyURL, cleanupFn, err := recording.SetupFakeProxy(fakeResponses, streamDelayMs)
if err != nil {
return nil, fmt.Errorf("failed to start fake proxy: %w", err)
return nil, err
}

runConfig.ModelsGateway = proxyURL
slog.Info("Fake mode enabled", "cassette", fakeResponses, "proxy", proxyURL)
if proxyURL != "" {
runConfig.ModelsGateway = proxyURL
}

return cleanupFn, nil
}

// setupRecordingProxy starts a recording proxy if recordPath is non-empty.
// It handles auto-generating a filename when recordPath is "true" (from NoOptDefVal),
// and normalizes the path by stripping any .yaml suffix.
// Returns the cassette path (with .yaml extension) and a cleanup function.
// The cleanup function must be called when done (typically via defer).
// It configures the runtime config's ModelsGateway to point to the proxy.
func setupRecordingProxy(recordPath string, runConfig *config.RuntimeConfig) (cassettePath string, cleanup func() error, err error) {
if recordPath == "" {
return "", func() error { return nil }, nil
}

// Handle auto-generated filename (from NoOptDefVal)
if recordPath == "true" {
recordPath = fmt.Sprintf("cagent-recording-%d", time.Now().Unix())
} else {
recordPath = strings.TrimSuffix(recordPath, ".yaml")
}

proxyURL, cleanupFn, err := fake.StartRecordingProxy(recordPath)
cassettePath, proxyURL, cleanupFn, err := recording.SetupRecordingProxy(recordPath)
if err != nil {
return "", nil, fmt.Errorf("failed to start recording proxy: %w", err)
return "", nil, err
}

runConfig.ModelsGateway = proxyURL
cassettePath = recordPath + ".yaml"

slog.Info("Recording mode enabled", "cassette", cassettePath, "proxy", proxyURL)
if proxyURL != "" {
runConfig.ModelsGateway = proxyURL
}

return cassettePath, cleanupFn, nil
}
56 changes: 12 additions & 44 deletions cmd/root/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"log/slog"
"os"
"path/filepath"
goruntime "runtime"
"runtime/pprof"
"sync"
"time"

Expand All @@ -22,6 +20,7 @@ import (
"github.com/docker/docker-agent/pkg/cli"
"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/paths"
"github.com/docker/docker-agent/pkg/profiling"
"github.com/docker/docker-agent/pkg/runtime"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/sessiontitle"
Expand Down Expand Up @@ -145,37 +144,16 @@ func (f *runExecFlags) runRunCommand(cmd *cobra.Command, args []string) error {
func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []string, useTUI bool) error {
slog.Debug("Starting agent", "agent", f.agentName)

// Start CPU profiling if requested
if f.cpuProfile != "" {
pf, err := os.Create(f.cpuProfile)
if err != nil {
return fmt.Errorf("failed to create CPU profile: %w", err)
}
if err := pprof.StartCPUProfile(pf); err != nil {
pf.Close()
return fmt.Errorf("failed to start CPU profile: %w", err)
}
defer pprof.StopCPUProfile()
defer pf.Close()
slog.Info("CPU profiling enabled", "file", f.cpuProfile)
}

// Write memory profile at exit if requested
if f.memProfile != "" {
defer func() {
mf, err := os.Create(f.memProfile)
if err != nil {
slog.Error("Failed to create memory profile", "error", err)
return
}
defer mf.Close()
goruntime.GC() // Get up-to-date statistics
if err := pprof.WriteHeapProfile(mf); err != nil {
slog.Error("Failed to write memory profile", "error", err)
}
slog.Info("Memory profile written", "file", f.memProfile)
}()
// Start profiling if requested
stopProfiling, err := profiling.Start(f.cpuProfile, f.memProfile)
if err != nil {
return err
}
defer func() {
if err := stopProfiling(); err != nil {
slog.Error("Profiling cleanup failed", "error", err)
}
}()

var agentFileName string
if len(args) > 0 {
Expand Down Expand Up @@ -271,10 +249,6 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
}
defer initialTeamCleanup()

if useTUI {
applyTheme()
}

if f.dryRun {
out.Println("Dry run mode enabled. Agent initialized but will not execute.")
return nil
Expand All @@ -284,19 +258,13 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
return f.handleExecMode(ctx, out, rt, sess, args)
}

applyTheme()
opts, err := f.buildAppOpts(args)
if err != nil {
return err
}

var sessStore session.Store
switch typedRt := rt.(type) {
case *runtime.LocalRuntime:
sessStore = typedRt.SessionStore()
case *runtime.PersistentRuntime:
sessStore = typedRt.SessionStore()
}

sessStore := rt.SessionStore()
return runTUI(ctx, rt, sess, f.createSessionSpawner(agentSource, sessStore), initialTeamCleanup, opts...)
}

Expand Down
65 changes: 65 additions & 0 deletions pkg/profiling/profiling.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Package profiling provides helpers for CPU and memory profiling.
package profiling

import (
"errors"
"fmt"
"os"
"runtime"
"runtime/pprof"
)

// Stop is a function returned by Start that stops profiling and flushes
// any buffered data. It must be called (typically via defer) when the
// profiled section of code completes.
type Stop func() error

// Start begins CPU and/or memory profiling based on the provided file
// paths. Pass an empty string to skip the corresponding profile.
// The returned Stop function must be called to finalise the profiles.
func Start(cpuProfile, memProfile string) (Stop, error) {
var closers []func() error

if cpuProfile != "" {
f, err := os.Create(cpuProfile)
if err != nil {
return noop, fmt.Errorf("failed to create CPU profile: %w", err)
}
if err := pprof.StartCPUProfile(f); err != nil {
f.Close()
return noop, fmt.Errorf("failed to start CPU profile: %w", err)
}
closers = append(closers, func() error {
pprof.StopCPUProfile()
return f.Close()
})
}

if memProfile != "" {
closers = append(closers, func() error {
f, err := os.Create(memProfile)
if err != nil {
return fmt.Errorf("failed to create memory profile: %w", err)
}
defer f.Close()
runtime.GC()
if err := pprof.WriteHeapProfile(f); err != nil {
return fmt.Errorf("failed to write memory profile: %w", err)
}
return nil
})
}

return func() error {
// Run in reverse order so CPU profile is stopped before mem profile is written.
var errs []error
for i := len(closers) - 1; i >= 0; i-- {
if err := closers[i](); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}, nil
}

func noop() error { return nil }
71 changes: 71 additions & 0 deletions pkg/recording/recording.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Package recording provides helpers for recording and replaying AI API interactions.
package recording

import (
"fmt"
"log/slog"
"strings"
"time"

"github.com/docker/docker-agent/pkg/fake"
)

// SetupFakeProxy starts a fake proxy if fakeResponses is non-empty.
// streamDelayMs controls simulated streaming: 0 = disabled, >0 = delay in milliseconds between chunks.
// It returns the proxy URL and a cleanup function that must be called when done (typically via defer).
func SetupFakeProxy(fakeResponses string, streamDelayMs int) (proxyURL string, cleanup func() error, err error) {
if fakeResponses == "" {
return "", noop, nil
}

// Normalize path by stripping .yaml suffix (go-vcr adds it automatically)
fakeResponses = strings.TrimSuffix(fakeResponses, ".yaml")

var opts []fake.ProxyOption
if streamDelayMs > 0 {
opts = append(opts,
fake.WithSimulateStream(true),
fake.WithStreamChunkDelay(time.Duration(streamDelayMs)*time.Millisecond),
)
}

proxyURL, cleanupFn, err := fake.StartProxy(fakeResponses, opts...)
if err != nil {
return "", nil, fmt.Errorf("failed to start fake proxy: %w", err)
}

slog.Info("Fake mode enabled", "cassette", fakeResponses, "proxy", proxyURL)

return proxyURL, cleanupFn, nil
}

// SetupRecordingProxy starts a recording proxy if recordPath is non-empty.
// It handles auto-generating a filename when recordPath is "true" (from NoOptDefVal),
// and normalizes the path by stripping any .yaml suffix.
// Returns the cassette path (with .yaml extension), the proxy URL, and a cleanup function.
// The cleanup function must be called when done (typically via defer).
func SetupRecordingProxy(recordPath string) (cassettePath, proxyURL string, cleanup func() error, err error) {
if recordPath == "" {
return "", "", noop, nil
}

// Handle auto-generated filename (from NoOptDefVal)
if recordPath == "true" {
recordPath = fmt.Sprintf("cagent-recording-%d", time.Now().Unix())
} else {
recordPath = strings.TrimSuffix(recordPath, ".yaml")
}

proxyURL, cleanupFn, err := fake.StartRecordingProxy(recordPath)
if err != nil {
return "", "", nil, fmt.Errorf("failed to start recording proxy: %w", err)
}

cassettePath = recordPath + ".yaml"

slog.Info("Recording mode enabled", "cassette", cassettePath, "proxy", proxyURL)

return cassettePath, proxyURL, cleanupFn, nil
}

func noop() error { return nil }
Loading