Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pkg/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ import (
"github.com/docker/docker-agent/pkg/tui/messages"
)

// RAGInitializer is implemented by runtimes that support background RAG initialization.
// Local runtimes use this to start indexing early; remote runtimes typically do not.
type RAGInitializer interface {
StartBackgroundRAGInit(ctx context.Context, sendEvent func(runtime.Event))
}

type App struct {
runtime runtime.Runtime
session *session.Session
Expand Down Expand Up @@ -119,7 +125,7 @@ func New(ctx context.Context, rt runtime.Runtime, sess *session.Session, opts ..
// If the runtime supports background RAG initialization, start it
// and forward events to the TUI. Remote runtimes typically handle RAG server-side
// and won't implement this optional interface.
if ragRuntime, ok := rt.(runtime.RAGInitializer); ok {
if ragRuntime, ok := rt.(RAGInitializer); ok {
go ragRuntime.StartBackgroundRAGInit(ctx, func(event runtime.Event) {
select {
case app.events <- event:
Expand Down
14 changes: 5 additions & 9 deletions pkg/rag/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,13 @@ type ManagersBuildConfig struct {
}

// NewManagers constructs all RAG managers defined in the config.
func NewManagers(
ctx context.Context,
cfg *latest.Config,
buildCfg ManagersBuildConfig,
) (map[string]*Manager, error) {
managers := make(map[string]*Manager)

func NewManagers(ctx context.Context, cfg *latest.Config, buildCfg ManagersBuildConfig) ([]*Manager, error) {
if len(cfg.RAG) == 0 {
return managers, nil
return nil, nil
}

var managers []*Manager

for ragName, ragCfg := range cfg.RAG {
// Validate that we have at least one strategy
if len(ragCfg.Strategies) == 0 {
Expand Down Expand Up @@ -69,7 +65,7 @@ func NewManagers(
return nil, fmt.Errorf("failed to create RAG manager %q: %w", ragName, err)
}

managers[ragName] = manager
managers = append(managers, manager)

strategyNames := make([]string, len(strategyConfigs))
for i, sc := range strategyConfigs {
Expand Down
11 changes: 4 additions & 7 deletions pkg/runtime/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,12 @@ type RAGIndexingStartedEvent struct {
StrategyName string `json:"strategy_name"`
}

func RAGIndexingStarted(ragName, strategyName, agentName string) Event {
func RAGIndexingStarted(ragName, strategyName string) Event {
return &RAGIndexingStartedEvent{
Type: "rag_indexing_started",
RAGName: ragName,
StrategyName: strategyName,
AgentContext: newAgentContext(agentName),
AgentContext: newAgentContext(""),
}
}

Expand Down Expand Up @@ -596,12 +596,12 @@ type RAGIndexingCompletedEvent struct {
StrategyName string `json:"strategy_name"`
}

func RAGIndexingCompleted(ragName, strategyName, agentName string) Event {
func RAGIndexingCompleted(ragName, strategyName string) Event {
return &RAGIndexingCompletedEvent{
Type: "rag_indexing_completed",
RAGName: ragName,
StrategyName: strategyName,
AgentContext: newAgentContext(agentName),
AgentContext: newAgentContext(""),
}
}

Expand Down Expand Up @@ -635,7 +635,6 @@ type MessageAddedEvent struct {
Message *session.Message `json:"-"`
}

func (e *MessageAddedEvent) GetAgentName() string { return e.AgentName }
func (e *MessageAddedEvent) GetSessionID() string { return e.SessionID }

func MessageAdded(sessionID string, msg *session.Message, agentName string) Event {
Expand All @@ -657,8 +656,6 @@ type SubSessionCompletedEvent struct {
SubSession any `json:"sub_session"` // *session.Session
}

func (e *SubSessionCompletedEvent) GetAgentName() string { return e.AgentName }

func SubSessionCompleted(parentSessionID string, subSession any, agentName string) Event {
return &SubSessionCompletedEvent{
Type: "sub_session_completed",
Expand Down
4 changes: 3 additions & 1 deletion pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
events <- TeamInfo(r.agentDetailsFromTeam(), a.Name())

// Initialize RAG and forward events
r.InitializeRAG(ctx, events)
r.StartBackgroundRAGInit(ctx, func(event Event) {
events <- event
})

r.emitAgentWarnings(a, chanSend(events))
r.configureToolsetHandlers(a, events)
Expand Down
105 changes: 105 additions & 0 deletions pkg/runtime/rag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package runtime

import (
"context"
"fmt"
"log/slog"

"github.com/docker/docker-agent/pkg/rag"
"github.com/docker/docker-agent/pkg/rag/types"
)

// StartBackgroundRAGInit initializes RAG in background and forwards events
// Should be called early (e.g., by App) to start indexing before RunStream
func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) {
if r.ragInitialized.Swap(true) {
return
}

ragManagers := r.team.RAGManagers()
if len(ragManagers) == 0 {
return
}

// Set up event forwarding BEFORE starting initialization
r.forwardRAGEvents(ctx, ragManagers, sendEvent)
initializeRAG(ctx, ragManagers)
startRAGFileWatchers(ctx, ragManagers)
}

// forwardRAGEvents forwards RAG manager events to the given callback
// Consolidates duplicated event forwarding logic
func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers []*rag.Manager, sendEvent func(Event)) {
for _, mgr := range ragManagers {
go func() {
ragName := mgr.Name()
slog.Debug("Starting RAG event forwarder goroutine", "rag", ragName)
for {
select {
case <-ctx.Done():
slog.Debug("RAG event forwarder stopped", "rag", ragName)
return
case ragEvent, ok := <-mgr.Events():
if !ok {
slog.Debug("RAG events channel closed", "rag", ragName)
return
}

agentName := r.CurrentAgentName()
slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName)

switch ragEvent.Type {
case types.EventTypeIndexingStarted:
sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName))
case types.EventTypeIndexingProgress:
if ragEvent.Progress != nil {
sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName))
}
case types.EventTypeIndexingComplete:
sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName))
case types.EventTypeUsage:
// Convert RAG usage to TokenUsageEvent so TUI displays it
sendEvent(NewTokenUsageEvent("", agentName, &Usage{
InputTokens: ragEvent.TotalTokens,
ContextLength: ragEvent.TotalTokens,
Cost: ragEvent.Cost,
}))
case types.EventTypeError:
if ragEvent.Error != nil {
sendEvent(Error(fmt.Sprintf("RAG %s error: %v", ragName, ragEvent.Error)))
}
default:
// Log unhandled events for debugging
slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName)
}
}
}
}()
}
}

// InitializeRAG initializes all RAG managers in the background
func initializeRAG(ctx context.Context, ragManagers []*rag.Manager) {
for _, mgr := range ragManagers {
go func() {
slog.Debug("Starting RAG manager initialization goroutine", "rag", mgr.Name())
if err := mgr.Initialize(ctx); err != nil {
slog.Error("Failed to initialize RAG manager", "rag", mgr.Name(), "error", err)
} else {
slog.Info("RAG manager initialized successfully", "rag", mgr.Name())
}
}()
}
}

// StartRAGFileWatchers starts file watchers for all RAG managers
func startRAGFileWatchers(ctx context.Context, ragManagers []*rag.Manager) {
for _, mgr := range ragManagers {
go func() {
slog.Debug("Starting RAG file watcher goroutine", "rag", mgr.Name())
if err := mgr.StartFileWatcher(ctx); err != nil {
slog.Error("Failed to start RAG file watcher", "rag", mgr.Name(), "error", err)
}
}()
}
}
109 changes: 0 additions & 109 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import (
"github.com/docker/docker-agent/pkg/config/types"
"github.com/docker/docker-agent/pkg/hooks"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/rag"
ragtypes "github.com/docker/docker-agent/pkg/rag/types"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/sessiontitle"
"github.com/docker/docker-agent/pkg/team"
Expand Down Expand Up @@ -163,12 +161,6 @@ type ModelStore interface {
GetDatabase(ctx context.Context) (*modelsdev.Database, error)
}

// RAGInitializer is implemented by runtimes that support background RAG initialization.
// Local runtimes use this to start indexing early; remote runtimes typically do not.
type RAGInitializer interface {
StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event))
}

// ToolsChangeSubscriber is implemented by runtimes that can notify when
// toolsets report a change in their tool list (e.g. after an MCP
// ToolListChanged notification). The provided callback is invoked
Expand Down Expand Up @@ -340,107 +332,6 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
return r, nil
}

// StartBackgroundRAGInit initializes RAG in background and forwards events
// Should be called early (e.g., by App) to start indexing before RunStream
func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) {
if r.ragInitialized.Swap(true) {
return
}

ragManagers := r.team.RAGManagers()
if len(ragManagers) == 0 {
return
}

slog.Debug("Starting background RAG initialization with event forwarding", "manager_count", len(ragManagers))

// Set up event forwarding BEFORE starting initialization
// This ensures all events are captured
r.forwardRAGEvents(ctx, ragManagers, sendEvent)

// Now start initialization (events will be forwarded)
r.team.InitializeRAG(ctx)
r.team.StartRAGFileWatchers(ctx)
}

// forwardRAGEvents forwards RAG manager events to the given callback
// Consolidates duplicated event forwarding logic
func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers map[string]*rag.Manager, sendEvent func(Event)) {
for _, mgr := range ragManagers {
go func(mgr *rag.Manager) {
ragName := mgr.Name()
slog.Debug("Starting RAG event forwarder goroutine", "rag", ragName)
for {
select {
case <-ctx.Done():
slog.Debug("RAG event forwarder stopped", "rag", ragName)
return
case ragEvent, ok := <-mgr.Events():
if !ok {
slog.Debug("RAG events channel closed", "rag", ragName)
return
}

agentName := r.CurrentAgentName()
slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName)

switch ragEvent.Type {
case ragtypes.EventTypeIndexingStarted:
sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName, agentName))
case ragtypes.EventTypeIndexingProgress:
if ragEvent.Progress != nil {
sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName))
}
case ragtypes.EventTypeIndexingComplete:
sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName, agentName))
case ragtypes.EventTypeUsage:
// Convert RAG usage to TokenUsageEvent so TUI displays it
sendEvent(NewTokenUsageEvent("", agentName, &Usage{
InputTokens: ragEvent.TotalTokens,
ContextLength: ragEvent.TotalTokens,
Cost: ragEvent.Cost,
}))
case ragtypes.EventTypeError:
if ragEvent.Error != nil {
sendEvent(Error(fmt.Sprintf("RAG %s error: %v", ragName, ragEvent.Error)))
}
default:
// Log unhandled events for debugging
slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName)
}
}
}
}(mgr)
}
}

// InitializeRAG is called within RunStream as a fallback when background init wasn't used
// (e.g., for exec command or API mode where there's no App)
func (r *LocalRuntime) InitializeRAG(ctx context.Context, events chan Event) {
// If already initialized via StartBackgroundRAGInit, skip entirely
// Event forwarding was already set up there
if r.ragInitialized.Swap(true) {
slog.Debug("RAG already initialized, event forwarding already active", "manager_count", len(r.team.RAGManagers()))
return
}

ragManagers := r.team.RAGManagers()
if len(ragManagers) == 0 {
return
}

slog.Debug("Setting up RAG initialization (fallback path for non-TUI)", "manager_count", len(ragManagers))

// Set up event forwarding BEFORE starting initialization
r.forwardRAGEvents(ctx, ragManagers, func(event Event) {
events <- event
})

// Start initialization and file watchers
r.team.InitializeRAG(ctx)
r.team.StartRAGFileWatchers(ctx)
}

func (r *LocalRuntime) CurrentAgentName() string {
r.currentAgentMu.RLock()
defer r.currentAgentMu.RUnlock()
Expand Down
6 changes: 1 addition & 5 deletions pkg/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,12 +550,8 @@ func TestStartBackgroundRAGInit_StopsForwardingAfterContextCancel(t *testing.T)
_ = mgr.Close()
}()

tm := team.New(team.WithRAGManagers(map[string]*rag.Manager{
"default": mgr,
}))

rt := &LocalRuntime{
team: tm,
team: team.New(team.WithRAGManagers([]*rag.Manager{mgr})),
currentAgent: "root",
}

Expand Down
Loading
Loading