Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions internal/agent/dreaming/training_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,20 @@ func (da *DreamingAgent) AssembleTrainingBatch(ctx context.Context, outputDir st
totalWritten++
}

// Mark all assembled entries as used so they don't re-trigger training.
// This is critical: without it, CountUntrainedExperience never drops,
// and every dreaming cycle re-triggers training on the same data.
var usedEntryIDs []string
for _, entry := range goldEntries {
usedEntryIDs = append(usedEntryIDs, entry.ID)
}
for _, entry := range corrected {
usedEntryIDs = append(usedEntryIDs, entry.ID)
}
if err := da.store.MarkExperienceUsedInTraining(ctx, batchID, usedEntryIDs); err != nil {
da.log.Warn("failed to mark experience as used in training", "error", err, "count", len(usedEntryIDs))
}

manifest := &TrainingBatchManifest{
ID: batchID,
CreatedAt: time.Now(),
Expand Down
54 changes: 54 additions & 0 deletions internal/agent/dreaming/training_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ func TrainingRequestsDir() string {
return filepath.Join(homeDir, ".mnemonic", "training_requests")
}

// TrainingDisabledPath returns the path to the e-stop sentinel file.
// When this file exists, all auto and manual training is blocked.
func TrainingDisabledPath() string {
homeDir, _ := os.UserHomeDir()
return filepath.Join(homeDir, ".mnemonic", "training.disabled")
}

// isTrainingDisabled checks for the e-stop sentinel file.
func isTrainingDisabled() bool {
_, err := os.Stat(TrainingDisabledPath())
return err == nil
}

// trainingCheck runs Phase 4.85: check if we should trigger spoke training.
// Only runs during dreaming if auto-trigger is enabled. Also callable via MCP.
func (da *DreamingAgent) trainingCheck(ctx context.Context, clCfg config.ContinuousLearningConfig) (*TrainingResult, error) {
Expand All @@ -74,12 +87,53 @@ func (da *DreamingAgent) trainingCheck(ctx context.Context, clCfg config.Continu
return nil, nil
}

// E-stop: check for sentinel file
if isTrainingDisabled() {
da.log.Warn("training disabled by e-stop file", "path", TrainingDisabledPath())
return nil, nil
}

// Check training window
if clCfg.Trigger.TrainingWindow != "" && !inTrainingWindow(clCfg.Trigger.TrainingWindow) {
da.log.Debug("outside training window, skipping", "window", clCfg.Trigger.TrainingWindow)
return nil, nil
}

// Circuit breaker: stop after too many consecutive failures
maxFailures := clCfg.Trigger.MaxConsecutiveFailures
if maxFailures <= 0 {
maxFailures = 3
}
consecutiveFailures, err := da.store.CountConsecutiveFailedTrainingRuns(ctx)
if err != nil {
da.log.Warn("failed to check consecutive training failures", "error", err)
} else if consecutiveFailures >= maxFailures {
da.log.Warn("training circuit breaker open: too many consecutive failures",
"consecutive_failures", consecutiveFailures, "max", maxFailures)
return nil, nil
}

// Cooldown: don't re-trigger too soon after a failed run
cooldownHours := clCfg.Trigger.FailureCooldownHours
if cooldownHours <= 0 {
cooldownHours = 24
}
if consecutiveFailures > 0 {
lastEnd, endErr := da.store.GetLastTrainingRunEndTime(ctx)
if endErr != nil {
da.log.Warn("failed to check last training run time", "error", endErr)
} else if !lastEnd.IsZero() {
cooldown := time.Duration(cooldownHours) * time.Hour
if time.Since(lastEnd) < cooldown {
da.log.Info("training skipped: cooling down after failure",
"last_run_ended", lastEnd.Format(time.RFC3339),
"cooldown_hours", cooldownHours,
"consecutive_failures", consecutiveFailures)
return nil, nil
}
}
}

return da.RunTrainingCycle(ctx, clCfg, "auto")
}

Expand Down
160 changes: 153 additions & 7 deletions internal/agent/dreaming/training_trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ import (
// triggerMockStore provides controlled responses for training trigger tests.
type triggerMockStore struct {
storetest.MockStore
untrainedCount int
goldEntries []store.ExperienceEntry
needsImpEntries []store.ExperienceEntry
rawMemories map[string]store.RawMemory
memories map[string]store.Memory
trainingRunsW []store.TrainingRun
trainingRunsU []store.TrainingRun
untrainedCount int
goldEntries []store.ExperienceEntry
needsImpEntries []store.ExperienceEntry
rawMemories map[string]store.RawMemory
memories map[string]store.Memory
trainingRunsW []store.TrainingRun
trainingRunsU []store.TrainingRun
consecutiveFailures int
lastTrainingRunEndTime time.Time
markedUsedEntryIDs []string
}

func (m *triggerMockStore) CountUntrainedExperience(_ context.Context) (int, error) {
Expand Down Expand Up @@ -73,6 +76,19 @@ func (m *triggerMockStore) UpdateTrainingRun(_ context.Context, run store.Traini
return nil
}

// CountConsecutiveFailedTrainingRuns returns the canned failure count
// configured on the mock; it never fails.
func (m *triggerMockStore) CountConsecutiveFailedTrainingRuns(_ context.Context) (int, error) {
	n := m.consecutiveFailures
	return n, nil
}

// GetLastTrainingRunEndTime returns the canned end time of the most
// recent training run; it never fails.
func (m *triggerMockStore) GetLastTrainingRunEndTime(_ context.Context) (time.Time, error) {
	end := m.lastTrainingRunEndTime
	return end, nil
}

// MarkExperienceUsedInTraining records every entry ID it is asked to
// mark so tests can assert on them; it never fails.
func (m *triggerMockStore) MarkExperienceUsedInTraining(_ context.Context, _ string, entryIDs []string) error {
	for _, id := range entryIDs {
		m.markedUsedEntryIDs = append(m.markedUsedEntryIDs, id)
	}
	return nil
}

func baseCLConfig() config.ContinuousLearningConfig {
return config.ContinuousLearningConfig{
Enabled: true,
Expand Down Expand Up @@ -340,6 +356,136 @@ func TestPickUpTrainingResult_FailedRun(t *testing.T) {
}
}

// TestTrainingCheck_CircuitBreakerBlocks verifies that hitting the
// configured maximum of consecutive failures opens the breaker and
// silently skips training.
func TestTrainingCheck_CircuitBreakerBlocks(t *testing.T) {
	mock := &triggerMockStore{
		untrainedCount:      100,
		consecutiveFailures: 3,
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	agent := NewDreamingAgent(mock, nil, DreamingConfig{Interval: time.Hour}, logger)

	cfg := baseCLConfig()
	cfg.Trigger.MaxConsecutiveFailures = 3

	result, err := agent.trainingCheck(context.Background(), cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result != nil {
		t.Fatal("expected nil result when circuit breaker is open")
	}
}

// TestTrainingCheck_CooldownBlocks verifies that after a failed run,
// training does not re-trigger until the cooldown window has elapsed.
func TestTrainingCheck_CooldownBlocks(t *testing.T) {
	// One failure, last run ended 30 minutes ago, cooldown is 24h.
	mock := &triggerMockStore{
		untrainedCount:         100,
		consecutiveFailures:    1,
		lastTrainingRunEndTime: time.Now().Add(-30 * time.Minute),
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	agent := NewDreamingAgent(mock, nil, DreamingConfig{Interval: time.Hour}, logger)

	cfg := baseCLConfig()
	cfg.Trigger.FailureCooldownHours = 24

	result, err := agent.trainingCheck(context.Background(), cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result != nil {
		t.Fatal("expected nil result during cooldown period")
	}
}

// TestTrainingCheck_AllowsAfterCooldown verifies that once the failure
// cooldown has expired, the trigger proceeds all the way to writing a
// training request.
func TestTrainingCheck_AllowsAfterCooldown(t *testing.T) {
	tmpDir := t.TempDir()
	t.Setenv("MNEMONIC_TRAINING_REQUESTS_DIR", tmpDir)

	// Last failed run ended 25h ago; cooldown below is 24h.
	mock := &triggerMockStore{
		untrainedCount:         100,
		consecutiveFailures:    1,
		lastTrainingRunEndTime: time.Now().Add(-25 * time.Hour),
		goldEntries: []store.ExperienceEntry{
			{ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.95, Category: "gold"},
		},
		rawMemories: map[string]store.RawMemory{
			"raw-1": {ID: "raw-1", Content: "Test", Source: "mcp", Type: "general"},
		},
		memories: map[string]store.Memory{
			"mem-1": {ID: "mem-1", Summary: "test", Content: "test", Concepts: []string{"test"}, Salience: 0.5},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	agent := NewDreamingAgent(mock, nil, DreamingConfig{Interval: time.Hour}, logger)

	cfg := baseCLConfig()
	cfg.Trigger.FailureCooldownHours = 24

	result, err := agent.trainingCheck(context.Background(), cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result == nil {
		t.Fatal("expected training to proceed after cooldown expires")
	}
	if got := result.Status; got != "training_requested" {
		t.Errorf("expected status 'training_requested', got %q", got)
	}
}

// TestTrainingCheck_EStopBlocks verifies that the presence of the
// e-stop sentinel file blocks training even when plenty of untrained
// experience is available.
func TestTrainingCheck_EStopBlocks(t *testing.T) {
	// Place the sentinel under a fake home directory and point HOME at
	// it so isTrainingDisabled() resolves the file there.
	// NOTE(review): HOME is the Unix lookup key — this test presumably
	// does not run on Windows (USERPROFILE); confirm in CI config.
	tmpDir := t.TempDir()
	sentinel := filepath.Join(tmpDir, ".mnemonic", "training.disabled")
	if err := os.MkdirAll(filepath.Dir(sentinel), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(sentinel, []byte("stopped"), 0o644); err != nil {
		t.Fatal(err)
	}
	t.Setenv("HOME", tmpDir)

	mock := &triggerMockStore{untrainedCount: 100}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	agent := NewDreamingAgent(mock, nil, DreamingConfig{Interval: time.Hour}, logger)

	result, err := agent.trainingCheck(context.Background(), baseCLConfig())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result != nil {
		t.Fatal("expected nil result when e-stop file exists")
	}
}

// TestAssembleTrainingBatch_MarksExperienceAsUsed verifies that every
// entry assembled into a batch is handed to MarkExperienceUsedInTraining
// so it cannot re-trigger training on the next cycle.
func TestAssembleTrainingBatch_MarksExperienceAsUsed(t *testing.T) {
	outDir := t.TempDir()

	mock := &triggerMockStore{
		untrainedCount: 10,
		goldEntries: []store.ExperienceEntry{
			{ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.95, Category: "gold"},
			{ID: "e2", RawID: "raw-2", MemoryID: "mem-2", EncodingEPR: 0.90, Category: "gold"},
		},
		rawMemories: map[string]store.RawMemory{
			"raw-1": {ID: "raw-1", Content: "First event", Source: "mcp", Type: "general"},
			"raw-2": {ID: "raw-2", Content: "Second event", Source: "mcp", Type: "general"},
		},
		memories: map[string]store.Memory{
			"mem-1": {ID: "mem-1", Summary: "first", Content: "first content", Concepts: []string{"test"}, Salience: 0.5},
			"mem-2": {ID: "mem-2", Summary: "second", Content: "second content", Concepts: []string{"test"}, Salience: 0.5},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	agent := NewDreamingAgent(mock, nil, DreamingConfig{Interval: time.Hour}, logger)

	if _, err := agent.AssembleTrainingBatch(context.Background(), outDir, 50); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Both gold entries must have been marked as used.
	if got := len(mock.markedUsedEntryIDs); got != 2 {
		t.Fatalf("expected 2 entries marked as used, got %d", got)
	}
}

func TestInTrainingWindow(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading