From b5ffc104e5e825b56538ab0564eeafa69bd68d7c Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sat, 21 Mar 2026 10:02:37 +0100 Subject: [PATCH 1/3] Extract RAG code Signed-off-by: David Gageot --- pkg/app/app.go | 8 ++- pkg/rag/builder.go | 14 ++--- pkg/runtime/loop.go | 4 +- pkg/runtime/rag.go | 105 +++++++++++++++++++++++++++++++++ pkg/runtime/runtime.go | 109 ----------------------------------- pkg/runtime/runtime_test.go | 6 +- pkg/team/team.go | 36 ++---------- pkg/teamloader/teamloader.go | 10 +++- 8 files changed, 132 insertions(+), 160 deletions(-) create mode 100644 pkg/runtime/rag.go diff --git a/pkg/app/app.go b/pkg/app/app.go index 869132bc6..3e01b7540 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -30,6 +30,12 @@ import ( "github.com/docker/docker-agent/pkg/tui/messages" ) +// RAGInitializer is implemented by runtimes that support background RAG initialization. +// Local runtimes use this to start indexing early; remote runtimes typically do not. +type RAGInitializer interface { + StartBackgroundRAGInit(ctx context.Context, sendEvent func(runtime.Event)) +} + type App struct { runtime runtime.Runtime session *session.Session @@ -119,7 +125,7 @@ func New(ctx context.Context, rt runtime.Runtime, sess *session.Session, opts .. // If the runtime supports background RAG initialization, start it // and forward events to the TUI. Remote runtimes typically handle RAG server-side // and won't implement this optional interface. - if ragRuntime, ok := rt.(runtime.RAGInitializer); ok { + if ragRuntime, ok := rt.(RAGInitializer); ok { go ragRuntime.StartBackgroundRAGInit(ctx, func(event runtime.Event) { select { case app.events <- event: diff --git a/pkg/rag/builder.go b/pkg/rag/builder.go index 71f04409a..dc265670d 100644 --- a/pkg/rag/builder.go +++ b/pkg/rag/builder.go @@ -25,17 +25,13 @@ type ManagersBuildConfig struct { } // NewManagers constructs all RAG managers defined in the config. -func NewManagers( - ctx context.Context, - cfg *latest.Config, - buildCfg ManagersBuildConfig, -) (map[string]*Manager, error) { - managers := make(map[string]*Manager) - +func NewManagers(ctx context.Context, cfg *latest.Config, buildCfg ManagersBuildConfig) ([]*Manager, error) { if len(cfg.RAG) == 0 { - return managers, nil + return nil, nil } + var managers []*Manager + for ragName, ragCfg := range cfg.RAG { // Validate that we have at least one strategy if len(ragCfg.Strategies) == 0 { @@ -69,7 +65,7 @@ func NewManagers( return nil, fmt.Errorf("failed to create RAG manager %q: %w", ragName, err) } - managers[ragName] = manager + managers = append(managers, manager) strategyNames := make([]string, len(strategyConfigs)) for i, sc := range strategyConfigs { diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index ae161a815..69caa7eb6 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -97,7 +97,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c events <- TeamInfo(r.agentDetailsFromTeam(), a.Name()) // Initialize RAG and forward events - r.InitializeRAG(ctx, events) + r.StartBackgroundRAGInit(ctx, func(event Event) { + events <- event + }) r.emitAgentWarnings(a, chanSend(events)) r.configureToolsetHandlers(a, events) diff --git a/pkg/runtime/rag.go b/pkg/runtime/rag.go new file mode 100644 index 000000000..8d7c0a6b9 --- /dev/null +++ b/pkg/runtime/rag.go @@ -0,0 +1,105 @@ +package runtime + +import ( + "context" + "fmt" + "log/slog" + + "github.com/docker/docker-agent/pkg/rag" + "github.com/docker/docker-agent/pkg/rag/types" +) + +// StartBackgroundRAGInit initializes RAG in background and forwards events +// Should be called early (e.g., by App) to start indexing before RunStream +func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) { + if r.ragInitialized.Swap(true) { + return + } + + ragManagers := r.team.RAGManagers() + if len(ragManagers) == 0 { + return + } + + // Set up event forwarding BEFORE starting initialization + r.forwardRAGEvents(ctx, ragManagers, sendEvent) + initializeRAG(ctx, ragManagers) + startRAGFileWatchers(ctx, ragManagers) +} + +// forwardRAGEvents forwards RAG manager events to the given callback +// Consolidates duplicated event forwarding logic +func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers []*rag.Manager, sendEvent func(Event)) { + for _, mgr := range ragManagers { + go func() { + ragName := mgr.Name() + slog.Debug("Starting RAG event forwarder goroutine", "rag", ragName) + for { + select { + case <-ctx.Done(): + slog.Debug("RAG event forwarder stopped", "rag", ragName) + return + case ragEvent, ok := <-mgr.Events(): + if !ok { + slog.Debug("RAG events channel closed", "rag", ragName) + return + } + + agentName := r.CurrentAgentName() + slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName) + + switch ragEvent.Type { + case types.EventTypeIndexingStarted: + sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName, agentName)) + case types.EventTypeIndexingProgress: + if ragEvent.Progress != nil { + sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName)) + } + case types.EventTypeIndexingComplete: + sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName, agentName)) + case types.EventTypeUsage: + // Convert RAG usage to TokenUsageEvent so TUI displays it + sendEvent(NewTokenUsageEvent("", agentName, &Usage{ + InputTokens: ragEvent.TotalTokens, + ContextLength: ragEvent.TotalTokens, + Cost: ragEvent.Cost, + })) + case types.EventTypeError: + if ragEvent.Error != nil { + sendEvent(Error(fmt.Sprintf("RAG %s error: %v", ragName, ragEvent.Error))) + } + default: + // Log unhandled events for debugging + slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName) + } + } + } + }() + } +} + +// InitializeRAG initializes all RAG managers in the background +func initializeRAG(ctx context.Context, ragManagers []*rag.Manager) { + for _, mgr := range ragManagers { + go func() { + slog.Debug("Starting RAG manager initialization goroutine", "rag", mgr.Name()) + if err := mgr.Initialize(ctx); err != nil { + slog.Error("Failed to initialize RAG manager", "rag", mgr.Name(), "error", err) + } else { + slog.Info("RAG manager initialized successfully", "rag", mgr.Name()) + } + }() + } +} + +// StartRAGFileWatchers starts file watchers for all RAG managers +func startRAGFileWatchers(ctx context.Context, ragManagers []*rag.Manager) { + for _, mgr := range ragManagers { + go func() { + slog.Debug("Starting RAG file watcher goroutine", "rag", mgr.Name()) + if err := mgr.StartFileWatcher(ctx); err != nil { + slog.Error("Failed to start RAG file watcher", "rag", mgr.Name(), "error", err) + } + }() + } +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index a01928dae..a0e95b422 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -18,8 +18,6 @@ import ( "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/hooks" "github.com/docker/docker-agent/pkg/modelsdev" - "github.com/docker/docker-agent/pkg/rag" - ragtypes "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/sessiontitle" "github.com/docker/docker-agent/pkg/team" @@ -163,12 +161,6 @@ type ModelStore interface { GetDatabase(ctx context.Context) (*modelsdev.Database, error) } -// RAGInitializer is implemented by runtimes that support background RAG initialization. -// Local runtimes use this to start indexing early; remote runtimes typically do not. -type RAGInitializer interface { - StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) -} - // ToolsChangeSubscriber is implemented by runtimes that can notify when // toolsets report a change in their tool list (e.g. after an MCP // ToolListChanged notification). The provided callback is invoked @@ -340,107 +332,6 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { return r, nil } -// StartBackgroundRAGInit initializes RAG in background and forwards events -// Should be called early (e.g., by App) to start indexing before RunStream -func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) { - if r.ragInitialized.Swap(true) { - return - } - - ragManagers := r.team.RAGManagers() - if len(ragManagers) == 0 { - return - } - - slog.Debug("Starting background RAG initialization with event forwarding", "manager_count", len(ragManagers)) - - // Set up event forwarding BEFORE starting initialization - // This ensures all events are captured - r.forwardRAGEvents(ctx, ragManagers, sendEvent) - - // Now start initialization (events will be forwarded) - r.team.InitializeRAG(ctx) - r.team.StartRAGFileWatchers(ctx) -} - -// forwardRAGEvents forwards RAG manager events to the given callback -// Consolidates duplicated event forwarding logic -func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers map[string]*rag.Manager, sendEvent func(Event)) { - for _, mgr := range ragManagers { - go func(mgr *rag.Manager) { - ragName := mgr.Name() - slog.Debug("Starting RAG event forwarder goroutine", "rag", ragName) - for { - select { - case <-ctx.Done(): - slog.Debug("RAG event forwarder stopped", "rag", ragName) - return - case ragEvent, ok := <-mgr.Events(): - if !ok { - slog.Debug("RAG events channel closed", "rag", ragName) - return - } - - agentName := r.CurrentAgentName() - slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName) - - switch ragEvent.Type { - case ragtypes.EventTypeIndexingStarted: - sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName, agentName)) - case ragtypes.EventTypeIndexingProgress: - if ragEvent.Progress != nil { - sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName)) - } - case ragtypes.EventTypeIndexingComplete: - sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName, agentName)) - case ragtypes.EventTypeUsage: - // Convert RAG usage to TokenUsageEvent so TUI displays it - sendEvent(NewTokenUsageEvent("", agentName, &Usage{ - InputTokens: ragEvent.TotalTokens, - ContextLength: ragEvent.TotalTokens, - Cost: ragEvent.Cost, - })) - case ragtypes.EventTypeError: - if ragEvent.Error != nil { - sendEvent(Error(fmt.Sprintf("RAG %s error: %v", ragName, ragEvent.Error))) - } - default: - // Log unhandled events for debugging - slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName) - } - } - } - }(mgr) - } -} - -// InitializeRAG is called within RunStream as a fallback when background init wasn't used -// (e.g., for exec command or API mode where there's no App) -func (r *LocalRuntime) InitializeRAG(ctx context.Context, events chan Event) { - // If already initialized via StartBackgroundRAGInit, skip entirely - // Event forwarding was already set up there - if r.ragInitialized.Swap(true) { - slog.Debug("RAG already initialized, event forwarding already active", "manager_count", len(r.team.RAGManagers())) - return - } - - ragManagers := r.team.RAGManagers() - if len(ragManagers) == 0 { - return - } - - slog.Debug("Setting up RAG initialization (fallback path for non-TUI)", "manager_count", len(ragManagers)) - - // Set up event forwarding BEFORE starting initialization - r.forwardRAGEvents(ctx, ragManagers, func(event Event) { - events <- event - }) - - // Start initialization and file watchers - r.team.InitializeRAG(ctx) - r.team.StartRAGFileWatchers(ctx) -} - func (r *LocalRuntime) CurrentAgentName() string { r.currentAgentMu.RLock() defer r.currentAgentMu.RUnlock() diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 5587b2d8e..cd9d17a66 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -550,12 +550,8 @@ func TestStartBackgroundRAGInit_StopsForwardingAfterContextCancel(t *testing.T) _ = mgr.Close() }() - tm := team.New(team.WithRAGManagers(map[string]*rag.Manager{ - "default": mgr, - })) - rt := &LocalRuntime{ - team: tm, + team: team.New(team.WithRAGManagers([]*rag.Manager{mgr})), currentAgent: "root", } diff --git a/pkg/team/team.go b/pkg/team/team.go index 7de968c97..7e4e34a5e 100644 --- a/pkg/team/team.go +++ b/pkg/team/team.go @@ -15,7 +15,7 @@ import ( type Team struct { agents []*agent.Agent - ragManagers map[string]*rag.Manager + ragManagers []*rag.Manager permissions *permissions.Checker } @@ -27,7 +27,7 @@ func WithAgents(agents ...*agent.Agent) Opt { } } -func WithRAGManagers(managers map[string]*rag.Manager) Opt { +func WithRAGManagers(managers []*rag.Manager) Opt { return func(t *Team) { t.ragManagers = managers } @@ -40,9 +40,7 @@ func WithPermissions(checker *permissions.Checker) Opt { } func New(opts ...Opt) *Team { - t := &Team{ - ragManagers: make(map[string]*rag.Manager), - } + t := &Team{} for _, opt := range opts { opt(t) } @@ -139,36 +137,10 @@ func (t *Team) StopToolSets(ctx context.Context) error { } // RAGManagers returns the RAG managers for this team -func (t *Team) RAGManagers() map[string]*rag.Manager { +func (t *Team) RAGManagers() []*rag.Manager { return t.ragManagers } -// InitializeRAG initializes all RAG managers in the background -func (t *Team) InitializeRAG(ctx context.Context) { - for _, mgr := range t.ragManagers { - go func(m *rag.Manager) { - slog.Debug("Starting RAG manager initialization goroutine", "rag", m.Name()) - if err := m.Initialize(ctx); err != nil { - slog.Error("Failed to initialize RAG manager", "rag", m.Name(), "error", err) - } else { - slog.Info("RAG manager initialized successfully", "rag", m.Name()) - } - }(mgr) - } -} - -// StartRAGFileWatchers starts file watchers for all RAG managers -func (t *Team) StartRAGFileWatchers(ctx context.Context) { - for _, mgr := range t.ragManagers { - go func(m *rag.Manager) { - slog.Debug("Starting RAG file watcher goroutine", "rag", m.Name()) - if err := m.StartFileWatcher(ctx); err != nil { - slog.Error("Failed to start RAG file watcher", "rag", m.Name(), "error", err) - } - }(mgr) - } -} - // Permissions returns the permission checker for this team. // Returns nil if no permissions are configured. func (t *Team) Permissions() *permissions.Checker { diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 8774d980e..55071373f 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -604,7 +604,7 @@ func contextWithExternalDepth(ctx context.Context, depth int) context.Context { } // createRAGToolsForAgent creates RAG tools for an agent, one for each referenced RAG source -func createRAGToolsForAgent(agentConfig *latest.AgentConfig, allManagers map[string]*rag.Manager) []tools.ToolSet { +func createRAGToolsForAgent(agentConfig *latest.AgentConfig, ragManagers []*rag.Manager) []tools.ToolSet { if len(agentConfig.RAG) == 0 { return nil } @@ -612,12 +612,16 @@ func createRAGToolsForAgent(agentConfig *latest.AgentConfig, allManagers map[str var ragTools []tools.ToolSet for _, ragName := range agentConfig.RAG { - mgr, exists := allManagers[ragName] - if !exists { + idx := slices.IndexFunc(ragManagers, func(m *rag.Manager) bool { + return m.Name() == ragName + }) + if idx == -1 { slog.Error("RAG source not found", "rag_source", ragName) continue } + mgr := ragManagers[idx] + // Use custom tool name if configured, otherwise use the RAG source name toolName := cmp.Or(mgr.ToolName(), ragName) From 5a768536564c18d59c13d837c9ab6c8076e6d094 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sat, 21 Mar 2026 10:36:23 +0100 Subject: [PATCH 2/3] Remove dead code Signed-off-by: David Gageot --- pkg/runtime/event.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/runtime/event.go b/pkg/runtime/event.go index 4254e2b61..72c82bde7 100644 --- a/pkg/runtime/event.go +++ b/pkg/runtime/event.go @@ -635,7 +635,6 @@ type MessageAddedEvent struct { Message *session.Message `json:"-"` } -func (e *MessageAddedEvent) GetAgentName() string { return e.AgentName } func (e *MessageAddedEvent) GetSessionID() string { return e.SessionID } func MessageAdded(sessionID string, msg *session.Message, agentName string) Event { @@ -657,8 +656,6 @@ type SubSessionCompletedEvent struct { SubSession any `json:"sub_session"` // *session.Session } -func (e *SubSessionCompletedEvent) GetAgentName() string { return e.AgentName } - func SubSessionCompleted(parentSessionID string, subSession any, agentName string) Event { return &SubSessionCompletedEvent{ Type: "sub_session_completed", From 4265af5d3ad20fb58a7d773b2010c9a6563a86aa Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sat, 21 Mar 2026 10:39:08 +0100 Subject: [PATCH 3/3] Those events are not linked to an agent Signed-off-by: David Gageot --- pkg/runtime/event.go | 8 ++++---- pkg/runtime/rag.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/runtime/event.go b/pkg/runtime/event.go index 72c82bde7..d4b141d1e 100644 --- a/pkg/runtime/event.go +++ b/pkg/runtime/event.go @@ -558,12 +558,12 @@ type RAGIndexingStartedEvent struct { StrategyName string `json:"strategy_name"` } -func RAGIndexingStarted(ragName, strategyName, agentName string) Event { +func RAGIndexingStarted(ragName, strategyName string) Event { return &RAGIndexingStartedEvent{ Type: "rag_indexing_started", RAGName: ragName, StrategyName: strategyName, - AgentContext: newAgentContext(agentName), + AgentContext: newAgentContext(""), } } @@ -596,12 +596,12 @@ type RAGIndexingCompletedEvent struct { StrategyName string `json:"strategy_name"` } -func RAGIndexingCompleted(ragName, strategyName, agentName string) Event { +func RAGIndexingCompleted(ragName, strategyName string) Event { return &RAGIndexingCompletedEvent{ Type: "rag_indexing_completed", RAGName: ragName, StrategyName: strategyName, - AgentContext: newAgentContext(agentName), + AgentContext: newAgentContext(""), } } diff --git a/pkg/runtime/rag.go b/pkg/runtime/rag.go index 8d7c0a6b9..d1e87bd04 100644 --- a/pkg/runtime/rag.go +++ b/pkg/runtime/rag.go @@ -50,13 +50,13 @@ func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers []*rag. switch ragEvent.Type { case types.EventTypeIndexingStarted: - sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName, agentName)) + sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName)) case types.EventTypeIndexingProgress: if ragEvent.Progress != nil { sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName)) } case types.EventTypeIndexingComplete: - sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName, agentName)) + sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName)) case types.EventTypeUsage: // Convert RAG usage to TokenUsageEvent so TUI displays it sendEvent(NewTokenUsageEvent("", agentName, &Usage{