From e5a5d595836c0abc2d07dad4f47c9083361886e7 Mon Sep 17 00:00:00 2001
From: Caleb Gross
Date: Mon, 13 Apr 2026 18:11:18 -0400
Subject: [PATCH 1/8] =?UTF-8?q?feat:=20continuous=20learning=20Phase=20B?=
 =?UTF-8?q?=20=E2=80=94=20curriculum=20generation=20(#391)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds curriculum generation to the dreaming agent (Phase 4.75). When
enabled, the dreaming cycle re-encodes needs_improvement memories via
the teacher model (API provider), producing corrected outputs that
become training pairs for the local spoke model.

New infrastructure:
- CLCurriculumConfig with enable flag, cooldown, and batch limits
- Migration 017: corrected_output columns on experience_buffer +
  curriculum_runs tracking table
- 5 new ContinuousLearningStore methods (ListNeedsImprovement,
  UpdateExperienceCorrectedOutput, curriculum run CRUD)
- Export BuildCompressionPrompt for cross-package prompt reuse
- 5 store tests covering correction lifecycle, dedup, and limits

The pipeline: reclassify experience buffer → fetch worst entries →
rebuild identical encoding prompt → call teacher model → validate
response (JSON + required fields + EPR ≥ 0.7) → store correction.
Gated by config flag, minimum entry threshold, and cooldown timer.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 cmd/mnemonic/serve.go                         |   1 +
 internal/agent/dreaming/agent.go              |  12 +
 internal/agent/dreaming/curriculum.go         | 194 +++++++++++++++
 internal/agent/encoding/agent.go              |   6 +
 internal/config/config.go                     |  15 +-
 internal/store/sqlite/continuous_learning.go  | 115 +++++++++
 .../store/sqlite/continuous_learning_test.go  | 220 ++++++++++++++++++
 internal/store/sqlite/schema.go               |  32 ++-
 internal/store/store.go                       |  33 ++-
 internal/store/storetest/mock.go              |  11 +
 migrations/008_curriculum_generation.sql      |  21 ++
 11 files changed, 654 insertions(+), 6 deletions(-)
 create mode 100644 internal/agent/dreaming/curriculum.go
 create mode 100644 internal/store/sqlite/continuous_learning_test.go
 create mode 100644 migrations/008_curriculum_generation.sql

diff --git a/cmd/mnemonic/serve.go b/cmd/mnemonic/serve.go
index 997172c1..4b3e21da 100644
--- a/cmd/mnemonic/serve.go
+++ b/cmd/mnemonic/serve.go
@@ -517,6 +517,7 @@ func serveCommand(configPath string) {
 		DeadMemoryWindow:  cfg.Dreaming.DeadMemoryWindow,
 		InsightsBudget:    cfg.Dreaming.InsightsBudget,
 		DefaultConfidence: cfg.Dreaming.DefaultConfidence,
+		Curriculum:        cfg.ContinuousLearning.Curriculum,
 	}, log)
 
 	if err := dreamer.Start(rootCtx, bus); err != nil {
diff --git a/internal/agent/dreaming/agent.go b/internal/agent/dreaming/agent.go
index da49bcd3..2d61dce9 100644
--- a/internal/agent/dreaming/agent.go
+++ b/internal/agent/dreaming/agent.go
@@ -11,6 +11,7 @@ import (
 	"time"
 
 	"github.com/appsprout-dev/mnemonic/internal/agent/agentutil"
+	"github.com/appsprout-dev/mnemonic/internal/config"
 	"github.com/appsprout-dev/mnemonic/internal/events"
 	"github.com/appsprout-dev/mnemonic/internal/llm"
 	"github.com/appsprout-dev/mnemonic/internal/store"
@@ -26,6 +27,7 @@ type DreamingConfig struct {
 	DeadMemoryWindow  time.Duration
 	InsightsBudget    int
 	DefaultConfidence float32
+	Curriculum        config.CLCurriculumConfig
 }
 
 type DreamingAgent struct {
@@ -183,6 +185,16 @@ func (da *DreamingAgent) runCycle(ctx context.Context) (*DreamReport, error) {
 		da.log.Info("reclassified experience buffer entries", "count", reclassified)
 	}
 
+	// Phase 4.75: Curriculum generation — re-encode bad memories via teacher model
+	if currReport, err := da.curriculumGeneration(ctx, da.config.Curriculum); err != nil && ctx.Err() == nil {
+		da.log.Error("curriculum generation phase failed", "error", err)
+	} else if currReport != nil {
+		da.log.Info("curriculum generation completed",
+			"attempted", currReport.CorrectionsAttempted,
+			"passed", currReport.CorrectionsPassed,
+			"failed", currReport.CorrectionsFailed)
+	}
+
 	// Phase 5: Link replayed memories to matching patterns
 	if err := da.linkToPatterns(ctx, replayed, report); err != nil && ctx.Err() == nil {
 		da.log.Error("pattern linking phase failed", "error", err)
diff --git a/internal/agent/dreaming/curriculum.go b/internal/agent/dreaming/curriculum.go
new file mode 100644
index 00000000..93373bcf
--- /dev/null
+++ b/internal/agent/dreaming/curriculum.go
@@ -0,0 +1,194 @@
+package dreaming
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/appsprout-dev/mnemonic/internal/agent/agentutil"
+	"github.com/appsprout-dev/mnemonic/internal/agent/encoding"
+	"github.com/appsprout-dev/mnemonic/internal/config"
+	"github.com/appsprout-dev/mnemonic/internal/llm"
+	"github.com/appsprout-dev/mnemonic/internal/store"
+	"github.com/google/uuid"
+)
+
+// CurriculumReport tracks the results of a curriculum generation cycle.
+type CurriculumReport struct {
+	CorrectionsAttempted int
+	CorrectionsPassed    int
+	CorrectionsFailed    int
+}
+
+// curriculumGeneration runs Phase 4.75: re-encode bad memories via teacher model (Gemini API).
+// Produces corrected outputs that become training pairs for the local spoke model.
+func (da *DreamingAgent) curriculumGeneration(ctx context.Context, cfg config.CLCurriculumConfig) (*CurriculumReport, error) {
+	if !cfg.Enabled {
+		return nil, nil
+	}
+
+	// Check minimum entries
+	stats, err := da.store.GetExperienceBufferStats(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("getting experience stats: %w", err)
+	}
+	if stats.NeedsImprovement < cfg.MinNeedsImprovement {
+		da.log.Debug("curriculum generation skipped: insufficient needs_improvement entries",
+			"have", stats.NeedsImprovement, "need", cfg.MinNeedsImprovement)
+		return nil, nil
+	}
+
+	// Check cooldown
+	lastRun, err := da.store.GetLastCurriculumRunTime(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("getting last curriculum run time: %w", err)
+	}
+	if !lastRun.IsZero() && time.Since(lastRun) < time.Duration(cfg.CooldownHours)*time.Hour {
+		da.log.Debug("curriculum generation skipped: cooldown active",
+			"last_run", lastRun, "cooldown_hours", cfg.CooldownHours)
+		return nil, nil
+	}
+
+	// Get entries to correct
+	entries, err := da.store.ListNeedsImprovement(ctx, cfg.MaxCorrectionsPerCycle)
+	if err != nil {
+		return nil, fmt.Errorf("listing needs_improvement entries: %w", err)
+	}
+	if len(entries) == 0 {
+		return nil, nil
+	}
+
+	// Start a curriculum run
+	run := store.CurriculumRun{
+		ID:        uuid.New().String(),
+		StartedAt: time.Now(),
+		Status:    "running",
+	}
+	if err := da.store.WriteCurriculumRun(ctx, run); err != nil {
+		return nil, fmt.Errorf("writing curriculum run: %w", err)
+	}
+
+	report := &CurriculumReport{}
+
+	for _, entry := range entries {
+		if ctx.Err() != nil {
+			break
+		}
+
+		report.CorrectionsAttempted++
+
+		if err := da.correctEntry(ctx, entry); err != nil {
+			da.log.Warn("curriculum correction failed",
+				"entry_id", entry.ID, "memory_id", entry.MemoryID, "error", err)
+			report.CorrectionsFailed++
+			continue
+		}
+		report.CorrectionsPassed++
+	}
+
+	// Complete the run
+	now := time.Now()
+	run.CompletedAt = &now
+	run.CorrectionsAttempted = report.CorrectionsAttempted
+	run.CorrectionsPassed = report.CorrectionsPassed
+	run.CorrectionsFailed = report.CorrectionsFailed
+	run.Status = "completed"
+	if err := da.store.UpdateCurriculumRun(ctx, run); err != nil {
+		da.log.Warn("failed to update curriculum run", "error", err)
+	}
+
+	return report, nil
+}
+
+// correctEntry re-encodes a single bad memory using the teacher model (API provider).
+func (da *DreamingAgent) correctEntry(ctx context.Context, entry store.ExperienceEntry) error {
+	// Load the original raw memory
+	raw, err := da.store.GetRaw(ctx, entry.RawID)
+	if err != nil {
+		return fmt.Errorf("loading raw memory %s: %w", entry.RawID, err)
+	}
+
+	// Build the same prompt the local model saw
+	truncatedContent := agentutil.Truncate(raw.Content, 4000)
+	prompt := encoding.BuildCompressionPrompt(truncatedContent, raw.Source, raw.Type, "", "", nil)
+
+	// Call the teacher model
+	req := llm.CompletionRequest{
+		Messages: []llm.Message{
+			{Role: "system", Content: "You are a memory encoder. You receive events and output structured JSON. Never explain, never apologize, never chat. Just fill in the JSON fields based on the event data."},
+			{Role: "user", Content: prompt},
+		},
+		MaxTokens:   1024,
+		Temperature: 0.1,
+	}
+
+	resp, err := da.llmProvider.Complete(ctx, req)
+	if err != nil {
+		return fmt.Errorf("teacher model completion failed: %w", err)
+	}
+
+	// Parse and validate the response
+	jsonStr := agentutil.ExtractJSON(resp.Content)
+	if jsonStr == "" {
+		return fmt.Errorf("teacher model returned no valid JSON")
+	}
+
+	// Basic structure validation — must be valid JSON with required fields
+	var parsed map[string]any
+	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
+		return fmt.Errorf("teacher model response not valid JSON: %w", err)
+	}
+
+	// Check required fields exist
+	for _, field := range []string{"summary", "content", "concepts"} {
+		if _, ok := parsed[field]; !ok {
+			return fmt.Errorf("teacher model response missing required field: %s", field)
+		}
+	}
+
+	// Compute EPR on the corrected output
+	epr := computeSimpleEPR(raw.Content, jsonStr)
+	if epr < 0.7 {
+		return fmt.Errorf("teacher model output EPR too low (%.2f), skipping", epr)
+	}
+
+	// Store the corrected output
+	if err := da.store.UpdateExperienceCorrectedOutput(ctx, entry.ID, jsonStr, epr, 0.0, "api"); err != nil {
+		return fmt.Errorf("storing corrected output: %w", err)
+	}
+
+	da.log.Info("curriculum correction stored",
+		"entry_id", entry.ID, "original_epr", entry.EncodingEPR, "corrected_epr", epr)
+	return nil
+}
+
+// computeSimpleEPR calculates a basic Entity Preservation Rate by checking
+// how many significant tokens from the input appear in the output.
+func computeSimpleEPR(rawContent, outputJSON string) float64 {
+	rawLower := strings.ToLower(rawContent)
+	outLower := strings.ToLower(outputJSON)
+
+	// Extract tokens of 4+ characters (likely meaningful entities/terms)
+	words := strings.Fields(rawLower)
+	var significant int
+	var preserved int
+	for _, w := range words {
+		// Skip short common words
+		clean := strings.Trim(w, ".,;:!?\"'()[]{}")
+		if len(clean) < 4 {
+			continue
+		}
+		significant++
+		if strings.Contains(outLower, clean) {
+			preserved++
+		}
+	}
+
+	if significant == 0 {
+		return 1.0
+	}
+	return float64(preserved) / float64(significant)
+}
diff --git a/internal/agent/encoding/agent.go b/internal/agent/encoding/agent.go
index f993256e..1d01f735 100644
--- a/internal/agent/encoding/agent.go
+++ b/internal/agent/encoding/agent.go
@@ -1272,6 +1272,12 @@ func (ea *EncodingAgent) compressAndExtractConcepts(ctx context.Context, raw sto
 // The previous verbose prompt (field-by-field descriptions, concept vocabulary,
 // coaching instructions) actively hurt faithfulness by confusing the model with
 // noise. See training/docs/experiment_registry.md EXP-29 for full data.
+// BuildCompressionPrompt constructs the encoding prompt for a raw memory.
+// Exported for use by curriculum generation (dreaming agent Phase B).
+func BuildCompressionPrompt(content, source, memType, episodeCtx, coachingInstructions string, conceptVocabulary []string) string {
+	return buildCompressionPrompt(content, source, memType, episodeCtx, coachingInstructions, conceptVocabulary)
+}
+
 func buildCompressionPrompt(content, source, memType, episodeCtx, coachingInstructions string, conceptVocabulary []string) string {
 	var b strings.Builder
 
diff --git a/internal/config/config.go b/internal/config/config.go
index 5da43ff2..de9df32e 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -481,9 +481,18 @@ type CoachingConfig struct {
 
 // ContinuousLearningConfig holds settings for the continuous learning pipeline.
 type ContinuousLearningConfig struct {
-	Enabled  bool             `yaml:"enabled"`
-	Training CLTrainingConfig `yaml:"training"`
-	Trigger  CLTriggerConfig  `yaml:"trigger"`
+	Enabled    bool               `yaml:"enabled"`
+	Training   CLTrainingConfig   `yaml:"training"`
+	Curriculum CLCurriculumConfig `yaml:"curriculum"`
+	Trigger    CLTriggerConfig    `yaml:"trigger"`
+}
+
+// CLCurriculumConfig holds settings for Phase B curriculum generation.
+type CLCurriculumConfig struct {
+	Enabled                bool `yaml:"enabled"`                   // enable curriculum generation in dreaming (default: false)
+	MaxCorrectionsPerCycle int  `yaml:"max_corrections_per_cycle"` // max entries to re-encode per dream cycle (default: 20)
+	MinNeedsImprovement    int  `yaml:"min_needs_improvement"`     // min needs_improvement entries before running (default: 10)
+	CooldownHours          int  `yaml:"cooldown_hours"`            // hours between curriculum runs (default: 24)
 }
 
 // CLTrainingConfig holds training-specific settings for continuous learning.
diff --git a/internal/store/sqlite/continuous_learning.go b/internal/store/sqlite/continuous_learning.go
index b8b31698..1abe3e44 100644
--- a/internal/store/sqlite/continuous_learning.go
+++ b/internal/store/sqlite/continuous_learning.go
@@ -219,6 +219,121 @@ func (s *SQLiteStore) GetEncodingQualityWindow(ctx context.Context, windowSize i
 	return w, nil
 }
 
+// --- Phase B: Curriculum generation ---
+
+func (s *SQLiteStore) UpdateExperienceCorrectedOutput(ctx context.Context, entryID string, output string, epr float64, fr float64, source string) error {
+	now := time.Now()
+	_, err := s.db.ExecContext(ctx,
+		`UPDATE experience_buffer
+		 SET corrected_output = ?, corrected_epr = ?, corrected_fr = ?,
+		     correction_source = ?, corrected_at = ?, updated_at = ?
+		 WHERE id = ?`,
+		output, epr, fr, source, now, now, entryID,
+	)
+	if err != nil {
+		return fmt.Errorf("updating corrected output for entry %s: %w", entryID, err)
+	}
+	return nil
+}
+
+func (s *SQLiteStore) ListNeedsImprovement(ctx context.Context, limit int) ([]store.ExperienceEntry, error) {
+	// Return needs_improvement entries that haven't been corrected yet
+	rows, err := s.db.QueryContext(ctx,
+		`SELECT id, raw_id, memory_id, encoding_epr, encoding_fr, encoding_flags,
+		        recall_score, recall_count, category, used_in_training, created_at, updated_at
+		 FROM experience_buffer
+		 WHERE category = 'needs_improvement' AND corrected_output IS NULL
+		 ORDER BY encoding_epr ASC
+		 LIMIT ?`,
+		limit,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("listing needs_improvement entries: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	var entries []store.ExperienceEntry
+	for rows.Next() {
+		var e store.ExperienceEntry
+		var flagsJSON string
+		var usedInt int
+		if err := rows.Scan(&e.ID, &e.RawID, &e.MemoryID, &e.EncodingEPR, &e.EncodingFR, &flagsJSON,
+			&e.RecallScore, &e.RecallCount, &e.Category, &usedInt, &e.CreatedAt, &e.UpdatedAt); err != nil {
+			return nil, fmt.Errorf("scanning experience entry: %w", err)
+		}
+		_ = json.Unmarshal([]byte(flagsJSON), &e.EncodingFlags)
+		e.UsedInTraining = usedInt != 0
+		entries = append(entries, e)
+	}
+	return entries, rows.Err()
+}
+
+func (s *SQLiteStore) WriteCurriculumRun(ctx context.Context, run store.CurriculumRun) error {
+	if run.ID == "" {
+		run.ID = uuid.New().String()
+	}
+	_, err := s.db.ExecContext(ctx,
+		`INSERT INTO curriculum_runs (id, started_at, completed_at, corrections_attempted, corrections_passed,
+		                              corrections_failed, entries_reclassified, training_batch_path, status, created_at)
+		 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+		run.ID, run.StartedAt, run.CompletedAt,
+		run.CorrectionsAttempted, run.CorrectionsPassed, run.CorrectionsFailed,
+		run.EntriesReclassified, run.TrainingBatchPath, run.Status, time.Now(),
+	)
+	if err != nil {
+		return fmt.Errorf("writing curriculum run %s: %w", run.ID, err)
+	}
+	return nil
+}
+
+func (s *SQLiteStore) UpdateCurriculumRun(ctx context.Context, run store.CurriculumRun) error {
+	_, err := s.db.ExecContext(ctx,
+		`UPDATE curriculum_runs
+		 SET completed_at = ?, corrections_attempted = ?, corrections_passed = ?,
+		     corrections_failed = ?, entries_reclassified = ?, training_batch_path = ?, status = ?
+		 WHERE id = ?`,
+		run.CompletedAt, run.CorrectionsAttempted, run.CorrectionsPassed,
+		run.CorrectionsFailed, run.EntriesReclassified, run.TrainingBatchPath, run.Status, run.ID,
+	)
+	if err != nil {
+		return fmt.Errorf("updating curriculum run %s: %w", run.ID, err)
+	}
+	return nil
+}
+
+func (s *SQLiteStore) GetLastCurriculumRunTime(ctx context.Context) (time.Time, error) {
+	var raw *string
+	err := s.db.QueryRowContext(ctx,
+		`SELECT MAX(started_at) FROM curriculum_runs WHERE status = 'completed'`,
+	).Scan(&raw)
+	if err != nil {
+		return time.Time{}, fmt.Errorf("getting last curriculum run time: %w", err)
+	}
+	if raw == nil || *raw == "" {
+		return time.Time{}, nil
+	}
+	// Try multiple time formats — SQLite + Go's time.Time.String() output
+	formats := []string{
+		time.RFC3339Nano,
+		time.RFC3339,
+		"2006-01-02 15:04:05-07:00",
+		"2006-01-02T15:04:05Z",
+		"2006-01-02 15:04:05 -0700 MST",
+	}
+	var t time.Time
+	var parseErr error
+	for _, f := range formats {
+		t, parseErr = time.Parse(f, *raw)
+		if parseErr == nil {
+			break
+		}
+	}
+	if parseErr != nil {
+		return time.Time{}, fmt.Errorf("parsing curriculum run time %q: %w", *raw, parseErr)
+	}
+	return t, nil
+}
+
 func (s *SQLiteStore) ListRecentEncodingQuality(ctx context.Context, limit int) ([]store.EncodingQualityEntry, error) {
 	rows, err := s.db.QueryContext(ctx,
 		`SELECT m.id, COALESCE(m.summary, ''), COALESCE(m.source, ''),
diff --git a/internal/store/sqlite/continuous_learning_test.go b/internal/store/sqlite/continuous_learning_test.go
new file mode 100644
index 00000000..b05576c7
--- /dev/null
+++ b/internal/store/sqlite/continuous_learning_test.go
@@ -0,0 +1,220 @@
+//go:build sqlite_fts5
+
+package sqlite
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/appsprout-dev/mnemonic/internal/store"
+)
+
+// writeMemoryForExperience creates the prerequisite memory + raw_memory rows
+// that the experience_buffer FK requires.
+func writeMemoryForExperience(t *testing.T, s *SQLiteStore, id string) {
+	t.Helper()
+	rawID := "raw-" + id
+	writeRawForMemory(t, s, rawID)
+	mem := store.Memory{
+		ID:        "mem-" + id,
+		RawID:     rawID,
+		Summary:   "test memory " + id,
+		CreatedAt: time.Now(),
+	}
+	if err := s.WriteMemory(context.Background(), mem); err != nil {
+		t.Fatalf("writing prerequisite memory %s: %v", id, err)
+	}
+}
+
+func TestListNeedsImprovement(t *testing.T) {
+	s := createTestStore(t)
+	defer func() { _ = s.Close() }()
+	ctx := context.Background()
+
+	ids := []string{"a", "b", "c", "d"}
+	cats := []string{"gold", "needs_improvement", "needs_improvement", "ambiguous"}
+
+	for i, id := range ids {
+		writeMemoryForExperience(t, s, id)
+		entry := store.ExperienceEntry{
+			ID:          "entry-" + id,
+			RawID:       "raw-" + id,
+			MemoryID:    "mem-" + id,
+			EncodingEPR: float64(i) * 0.2,
+			Category:    cats[i],
+		}
+		if err := s.WriteExperienceEntry(ctx, entry); err != nil {
+			t.Fatalf("writing entry %s: %v", id, err)
+		}
+	}
+
+	entries, err := s.ListNeedsImprovement(ctx, 10)
+	if err != nil {
+		t.Fatalf("ListNeedsImprovement: %v", err)
+	}
+	if len(entries) != 2 {
+		t.Fatalf("expected 2 needs_improvement entries, got %d", len(entries))
+	}
+	if entries[0].EncodingEPR > entries[1].EncodingEPR {
+		t.Errorf("expected ascending EPR order, got %.2f then %.2f", entries[0].EncodingEPR, entries[1].EncodingEPR)
+	}
+}
+
+func TestListNeedsImprovement_ExcludesCorrected(t *testing.T) {
+	s := createTestStore(t)
+	defer func() { _ = s.Close() }()
+	ctx := context.Background()
+
+	for _, id := range []string{"a", "b"} {
+		writeMemoryForExperience(t, s, id)
+		entry := store.ExperienceEntry{
+			ID:       "entry-" + id,
+			RawID:    "raw-" + id,
+			MemoryID: "mem-" + id,
+			Category: "needs_improvement",
+		}
+		if err := s.WriteExperienceEntry(ctx, entry); err != nil {
+			t.Fatalf("writing entry %s: %v", id, err)
+		}
+	}
+
+	if err := s.UpdateExperienceCorrectedOutput(ctx, "entry-a", `{"summary":"corrected"}`, 0.95, 1.0, "gemini"); err != nil {
+		t.Fatalf("UpdateExperienceCorrectedOutput: %v", err)
+	}
+
+	entries, err := s.ListNeedsImprovement(ctx, 10)
+	if err != nil {
+		t.Fatalf("ListNeedsImprovement: %v", err)
+	}
+	if len(entries) != 1 {
+		t.Fatalf("expected 1 uncorrected entry, got %d", len(entries))
+	}
+	if entries[0].ID != "entry-b" {
+		t.Errorf("expected entry-b, got %s", entries[0].ID)
+	}
+}
+
+func TestUpdateExperienceCorrectedOutput(t *testing.T) {
+	s := createTestStore(t)
+	defer func() { _ = s.Close() }()
+	ctx := context.Background()
+
+	writeMemoryForExperience(t, s, "1")
+	entry := store.ExperienceEntry{
+		ID:       "entry-1",
+		RawID:    "raw-1",
+		MemoryID: "mem-1",
+		Category: "needs_improvement",
+	}
+	if err := s.WriteExperienceEntry(ctx, entry); err != nil {
+		t.Fatalf("writing entry: %v", err)
+	}
+
+	if err := s.UpdateExperienceCorrectedOutput(ctx, "entry-1", `{"summary":"better"}`, 0.92, 1.0, "gemini"); err != nil {
+		t.Fatalf("UpdateExperienceCorrectedOutput: %v", err)
+	}
+
+	var output string
+	var epr float64
+	var source string
+	err := s.db.QueryRow(`SELECT corrected_output, corrected_epr, correction_source FROM experience_buffer WHERE id = ?`, "entry-1").
+		Scan(&output, &epr, &source)
+	if err != nil {
+		t.Fatalf("querying corrected output: %v", err)
+	}
+	if output != `{"summary":"better"}` {
+		t.Errorf("expected corrected output, got %s", output)
+	}
+	if epr != 0.92 {
+		t.Errorf("expected corrected EPR 0.92, got %.2f", epr)
+	}
+	if source != "gemini" {
+		t.Errorf("expected source 'gemini', got %s", source)
+	}
+}
+
+func TestCurriculumRunLifecycle(t *testing.T) {
+	s := createTestStore(t)
+	defer func() { _ = s.Close() }()
+	ctx := context.Background()
+
+	// No runs yet — should return zero time
+	lastRun, err := s.GetLastCurriculumRunTime(ctx)
+	if err != nil {
+		t.Fatalf("GetLastCurriculumRunTime: %v", err)
+	}
+	if !lastRun.IsZero() {
+		t.Errorf("expected zero time for no runs, got %v", lastRun)
+	}
+
+	// Write a completed run
+	now := time.Now().Truncate(time.Second)
+	run := store.CurriculumRun{
+		ID:                   "run-1",
+		StartedAt:            now,
+		CompletedAt:          &now,
+		CorrectionsAttempted: 10,
+		CorrectionsPassed:    7,
+		CorrectionsFailed:    3,
+		Status:               "completed",
+	}
+	if err := s.WriteCurriculumRun(ctx, run); err != nil {
+		t.Fatalf("WriteCurriculumRun: %v", err)
+	}
+
+	lastRun, err = s.GetLastCurriculumRunTime(ctx)
+	if err != nil {
+		t.Fatalf("GetLastCurriculumRunTime after write: %v", err)
+	}
+	if lastRun.Before(now.Add(-time.Second)) || lastRun.After(now.Add(time.Second)) {
+		t.Errorf("expected last run near %v, got %v", now, lastRun)
+	}
+
+	// Update the run
+	later := now.Add(time.Minute)
+	run.CompletedAt = &later
+	run.CorrectionsPassed = 8
+	if err := s.UpdateCurriculumRun(ctx, run); err != nil {
+		t.Fatalf("UpdateCurriculumRun: %v", err)
+	}
+
+	// Verify update took
+	var passed int
+	err = s.db.QueryRow(`SELECT corrections_passed FROM curriculum_runs WHERE id = ?`, "run-1").Scan(&passed)
+	if err != nil {
+		t.Fatalf("querying updated run: %v", err)
+	}
+	if passed != 8 {
+		t.Errorf("expected corrections_passed=8, got %d", passed)
+	}
+}
+
+func TestListNeedsImprovement_RespectsLimit(t *testing.T) {
+	s := createTestStore(t)
+	defer func() { _ = s.Close() }()
+	ctx := context.Background()
+
+	for i := 0; i < 5; i++ {
+		id := fmt.Sprintf("%d", i)
+		writeMemoryForExperience(t, s, id)
+		entry := store.ExperienceEntry{
+			ID:       "entry-" + id,
+			RawID:    "raw-" + id,
+			MemoryID: "mem-" + id,
+			Category: "needs_improvement",
+		}
+		if err := s.WriteExperienceEntry(ctx, entry); err != nil {
+			t.Fatalf("writing entry %d: %v", i, err)
+		}
+	}
+
+	entries, err := s.ListNeedsImprovement(ctx, 3)
+	if err != nil {
+		t.Fatalf("ListNeedsImprovement: %v", err)
+	}
+	if len(entries) != 3 {
+		t.Fatalf("expected limit of 3, got %d", len(entries))
+	}
+}
diff --git a/internal/store/sqlite/schema.go b/internal/store/sqlite/schema.go
index c64ec324..c0ce00e3 100644
--- a/internal/store/sqlite/schema.go
+++ b/internal/store/sqlite/schema.go
@@ -10,7 +10,7 @@ import (
 // migration is added. It is written to PRAGMA user_version after InitSchema
 // completes, and read by the pre-migration backup logic to skip backups when
 // the schema is already current.
-const SchemaVersion = 16
+const SchemaVersion = 17
 
 const schema = `
 -- Raw observations before encoding
@@ -612,6 +612,36 @@ INSERT OR IGNORE INTO forum_categories (id, name, slug, description, icon, color
 	_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_experience_buffer_category ON experience_buffer(category)`)
 	_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_experience_buffer_memory ON experience_buffer(memory_id)`)
 
+	// Migration 017: Curriculum generation — corrected output columns + run tracking (#391 Phase B)
+	for _, col := range []struct{ column, def string }{
+		{"corrected_output", "TEXT DEFAULT NULL"},
+		{"corrected_epr", "REAL DEFAULT NULL"},
+		{"corrected_fr", "REAL DEFAULT NULL"},
+		{"correction_source", "TEXT DEFAULT NULL"},
+		{"corrected_at", "DATETIME DEFAULT NULL"},
+	} {
+		_, err = db.Exec(fmt.Sprintf(`ALTER TABLE experience_buffer ADD COLUMN %s %s`, col.column, col.def))
+		if err != nil && !isAlterTableDuplicateColumn(err) {
+			return fmt.Errorf("failed to add experience_buffer.%s column: %w", col.column, err)
+		}
+	}
+
+	if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS curriculum_runs (
+		id TEXT PRIMARY KEY,
+		started_at DATETIME NOT NULL,
+		completed_at DATETIME,
+		corrections_attempted INTEGER DEFAULT 0,
+		corrections_passed INTEGER DEFAULT 0,
+		corrections_failed INTEGER DEFAULT 0,
+		entries_reclassified INTEGER DEFAULT 0,
+		training_batch_path TEXT,
+		status TEXT DEFAULT 'pending',
+		created_at DATETIME DEFAULT CURRENT_TIMESTAMP
+	)`); err != nil {
+		return fmt.Errorf("failed to create curriculum_runs table: %w", err)
+	}
+	_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_curriculum_runs_status ON curriculum_runs(status)`)
+
 	// Record the schema version so pre-migration backups can skip when current.
 	if _, err := db.Exec(fmt.Sprintf("PRAGMA user_version = %d", SchemaVersion)); err != nil {
 		return fmt.Errorf("failed to set user_version: %w", err)
diff --git a/internal/store/store.go b/internal/store/store.go
index b9350364..9f810f91 100644
--- a/internal/store/store.go
+++ b/internal/store/store.go
@@ -594,8 +594,30 @@ type ExperienceEntry struct {
 	RecallCount    int       `json:"recall_count"`
 	Category       string    `json:"category"` // gold, needs_improvement, ambiguous
 	UsedInTraining bool      `json:"used_in_training"`
-	CreatedAt      time.Time `json:"created_at"`
-	UpdatedAt      time.Time `json:"updated_at"`
+
+	// Phase B: Curriculum generation — corrected output from teacher model
+	CorrectedOutput  string     `json:"corrected_output,omitempty"`
+	CorrectedEPR     float64    `json:"corrected_epr,omitempty"`
+	CorrectedFR      float64    `json:"corrected_fr,omitempty"`
+	CorrectionSource string     `json:"correction_source,omitempty"` // "gemini", "api"
+	CorrectedAt      *time.Time `json:"corrected_at,omitempty"`
+
+	CreatedAt time.Time `json:"created_at"`
+	UpdatedAt time.Time `json:"updated_at"`
+}
+
+// CurriculumRun tracks a single curriculum generation cycle.
+type CurriculumRun struct {
+	ID                   string     `json:"id"`
+	StartedAt            time.Time  `json:"started_at"`
+	CompletedAt          *time.Time `json:"completed_at,omitempty"`
+	CorrectionsAttempted int        `json:"corrections_attempted"`
+	CorrectionsPassed    int        `json:"corrections_passed"`
+	CorrectionsFailed    int        `json:"corrections_failed"`
+	EntriesReclassified  int        `json:"entries_reclassified"`
+	TrainingBatchPath    string     `json:"training_batch_path,omitempty"`
+	Status               string     `json:"status"` // pending, completed, failed
+	CreatedAt            time.Time  `json:"created_at"`
 }
 
 // ExperienceStats summarizes the experience buffer contents.
@@ -641,6 +663,13 @@ type ContinuousLearningStore interface {
 	WriteRecallFeedbackEntry(ctx context.Context, entry RecallFeedbackEntry) error
 	GetRecallHistory(ctx context.Context, memoryID string) ([]RecallFeedbackEntry, error)
 
+	// Curriculum generation (Phase B)
+	UpdateExperienceCorrectedOutput(ctx context.Context, entryID string, output string, epr float64, fr float64, source string) error
+	ListNeedsImprovement(ctx context.Context, limit int) ([]ExperienceEntry, error)
+	WriteCurriculumRun(ctx context.Context, run CurriculumRun) error
+	UpdateCurriculumRun(ctx context.Context, run CurriculumRun) error
+	GetLastCurriculumRunTime(ctx context.Context) (time.Time, error)
+
 	// Quality drift detection
 	GetEncodingQualityWindow(ctx context.Context, windowSize int) (EncodingQualityWindow, error)
 
diff --git a/internal/store/storetest/mock.go b/internal/store/storetest/mock.go
index 0240735e..6c1ee31e 100644
--- a/internal/store/storetest/mock.go
+++ b/internal/store/storetest/mock.go
@@ -392,6 +392,17 @@ func (MockStore) GetEncodingQualityWindow(context.Context, int) (store.EncodingQ
 func (MockStore) ListRecentEncodingQuality(context.Context, int) ([]store.EncodingQualityEntry, error) {
 	return nil, nil
 }
+func (MockStore) UpdateExperienceCorrectedOutput(context.Context, string, string, float64, float64, string) error {
+	return nil
+}
+func (MockStore) ListNeedsImprovement(context.Context, int) ([]store.ExperienceEntry, error) {
+	return nil, nil
+}
+func (MockStore) WriteCurriculumRun(context.Context, store.CurriculumRun) error  { return nil }
+func (MockStore) UpdateCurriculumRun(context.Context, store.CurriculumRun) error { return nil }
+func (MockStore) GetLastCurriculumRunTime(context.Context) (time.Time, error) {
+	return time.Time{}, nil
+}
 
 // --- Lifecycle ---
 
diff --git a/migrations/008_curriculum_generation.sql b/migrations/008_curriculum_generation.sql
new file mode 100644
index 00000000..f7a8e0c5
--- /dev/null
+++ b/migrations/008_curriculum_generation.sql
@@ -0,0 +1,21 @@
+-- Phase B: Curriculum generation columns on experience_buffer
+ALTER TABLE experience_buffer ADD COLUMN corrected_output TEXT DEFAULT NULL;
+ALTER TABLE experience_buffer ADD COLUMN corrected_epr REAL DEFAULT NULL;
+ALTER TABLE experience_buffer ADD COLUMN corrected_fr REAL DEFAULT NULL;
+ALTER TABLE experience_buffer ADD COLUMN correction_source TEXT DEFAULT NULL;
+ALTER TABLE experience_buffer ADD COLUMN corrected_at DATETIME DEFAULT NULL;
+
+-- Curriculum run tracking
+CREATE TABLE IF NOT EXISTS curriculum_runs (
+    id TEXT PRIMARY KEY,
+    started_at DATETIME NOT NULL,
+    completed_at DATETIME,
+    corrections_attempted INTEGER DEFAULT 0,
+    corrections_passed INTEGER DEFAULT 0,
+    corrections_failed INTEGER DEFAULT 0,
+    entries_reclassified INTEGER DEFAULT 0,
+    training_batch_path TEXT,
+    status TEXT DEFAULT 'pending',
+    created_at DATETIME DEFAULT CURRENT_TIMESTAMP
+);
+CREATE INDEX IF NOT EXISTS idx_curriculum_runs_status ON curriculum_runs(status);

From 99d3c28a788105d335183323ae53d6a46331457c Mon Sep 17 00:00:00 2001
From: Caleb Gross
Date: Mon, 13 Apr 2026 18:13:11 -0400
Subject: [PATCH 2/8] feat: training data assembly for continuous learning
 (#391)

AssembleTrainingBatch exports gold and corrected encoding pairs as
JSONL for spoke fine-tuning. Splits 70/30: 70% from experience buffer
(gold + corrective pairs), 30% reserved for replay mixing by the Python
training script. Each example includes the full encoding prompt and
target output for direct tokenization.
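An illustrative JSONL line (field values hypothetical; the fields are
the json tags of TrainingExample in the diff below):

  {"type":"corrective","prompt":"<full encoding prompt>","output":"{\"summary\":\"...\",\"content\":\"...\",\"concepts\":[...]}","memory_id":"mem-1a2b","epr":0.91}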
Writes batch_{id}.jsonl + batch_{id}_manifest.json with provenance.
Called by Phase C (automated training trigger) or via MCP tool.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 internal/agent/dreaming/training_data.go | 196 +++++++++++++++++++++++
 1 file changed, 196 insertions(+)
 create mode 100644 internal/agent/dreaming/training_data.go

diff --git a/internal/agent/dreaming/training_data.go b/internal/agent/dreaming/training_data.go
new file mode 100644
index 00000000..1064131d
--- /dev/null
+++ b/internal/agent/dreaming/training_data.go
@@ -0,0 +1,196 @@
+package dreaming
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"time"
+
+	"github.com/appsprout-dev/mnemonic/internal/agent/agentutil"
+	"github.com/appsprout-dev/mnemonic/internal/agent/encoding"
+	"github.com/appsprout-dev/mnemonic/internal/store"
+	"github.com/google/uuid"
+)
+
+// TrainingExample is a single training pair written to JSONL.
+// The Python training script tokenizes and mixes with replay data.
+type TrainingExample struct {
+	Type     string  `json:"type"`      // "gold" or "corrective"
+	Prompt   string  `json:"prompt"`    // system + user prompt (identical to what the model saw)
+	Output   string  `json:"output"`    // the target completion (gold encoding or corrected encoding)
+	MemoryID string  `json:"memory_id"` // provenance
+	EPR      float64 `json:"epr"`       // EPR score of the output
+}
+
+// TrainingBatchManifest describes a training batch for reproducibility.
+type TrainingBatchManifest struct {
+	ID             string    `json:"id"`
+	CreatedAt      time.Time `json:"created_at"`
+	GoldCount      int       `json:"gold_count"`
+	CorrectedCount int       `json:"corrected_count"`
+	TotalExamples  int       `json:"total_examples"`
+	DataPath       string    `json:"data_path"`
+}
+
+// AssembleTrainingBatch writes gold and corrected encoding pairs to a JSONL file.
+// Returns the manifest and output path. The Python training script handles
+// replay mixing (30% from base dataset) and tokenization.
+// Called by Phase C (automated training trigger) or via MCP tool.
+func (da *DreamingAgent) AssembleTrainingBatch(ctx context.Context, outputDir string, maxExamples int) (*TrainingBatchManifest, error) {
+	if maxExamples <= 0 {
+		maxExamples = 200
+	}
+
+	// 70/30 split: 70% from experience buffer (gold + corrected), 30% reserved for replay
+	bufferBudget := maxExamples * 7 / 10
+	goldBudget := bufferBudget / 2
+	correctedBudget := bufferBudget - goldBudget
+
+	// Fetch gold entries
+	goldEntries, err := da.store.ListExperienceByCategory(ctx, "gold", goldBudget)
+	if err != nil {
+		return nil, fmt.Errorf("listing gold entries: %w", err)
+	}
+
+	// Fetch corrected entries (needs_improvement with corrected_output set)
+	correctedEntries, err := da.store.ListExperienceByCategory(ctx, "needs_improvement", correctedBudget*3)
+	if err != nil {
+		return nil, fmt.Errorf("listing corrected entries: %w", err)
+	}
+	// Filter to only those with corrections
+	var corrected []store.ExperienceEntry
+	for _, e := range correctedEntries {
+		if e.CorrectedOutput != "" {
+			corrected = append(corrected, e)
+			if len(corrected) >= correctedBudget {
+				break
+			}
+		}
+	}
+
+	if len(goldEntries) == 0 && len(corrected) == 0 {
+		return nil, fmt.Errorf("no training examples available (0 gold, 0 corrected)")
+	}
+
+	// Create output directory
+	if err := os.MkdirAll(outputDir, 0o755); err != nil {
+		return nil, fmt.Errorf("creating output dir: %w", err)
+	}
+
+	batchID := uuid.New().String()[:8]
+	dataPath := filepath.Join(outputDir, fmt.Sprintf("batch_%s.jsonl", batchID))
+
+	f, err := os.Create(dataPath)
+	if err != nil {
+		return nil, fmt.Errorf("creating batch file: %w", err)
+	}
+	defer func() { _ = f.Close() }()
+
+	enc := json.NewEncoder(f)
+	var totalWritten int
+
+	// Write gold examples
+	for _, entry := range goldEntries {
+		example, err := da.buildTrainingExample(ctx, entry, "gold")
+		if err != nil {
+			da.log.Debug("skipping gold entry", "entry_id", entry.ID, "error", err)
+			continue
+		}
+		if err := enc.Encode(example); err != nil {
+			return nil, fmt.Errorf("writing gold example: %w", err)
+		}
+		totalWritten++
+	}
+
+	// Write corrective examples (using the teacher model's output)
+	for _, entry := range corrected {
+		example := TrainingExample{
+			Type:     "corrective",
+			MemoryID: entry.MemoryID,
+			EPR:      entry.CorrectedEPR,
+			Output:   entry.CorrectedOutput,
+		}
+		// Build the prompt from raw memory
+		raw, err := da.store.GetRaw(ctx, entry.RawID)
+		if err != nil {
+			da.log.Debug("skipping corrected entry", "entry_id", entry.ID, "error", err)
+			continue
+		}
+		truncated := agentutil.Truncate(raw.Content, 4000)
+		example.Prompt = encoding.BuildCompressionPrompt(truncated, raw.Source, raw.Type, "", "", nil)
+
+		if err := enc.Encode(example); err != nil {
+			return nil, fmt.Errorf("writing corrective example: %w", err)
+		}
+		totalWritten++
+	}
+
+	manifest := &TrainingBatchManifest{
+		ID:             batchID,
+		CreatedAt:      time.Now(),
+		GoldCount:      len(goldEntries),
+		CorrectedCount: len(corrected),
+		TotalExamples:  totalWritten,
+		DataPath:       dataPath,
+	}
+
+	// Write manifest
+	manifestPath := filepath.Join(outputDir, fmt.Sprintf("batch_%s_manifest.json", batchID))
+	mf, err := os.Create(manifestPath)
+	if err != nil {
+		return manifest, fmt.Errorf("creating manifest file: %w", err)
+	}
+	defer func() { _ = mf.Close() }()
+
+	manifestEnc := json.NewEncoder(mf)
+	manifestEnc.SetIndent("", "  ")
+	if err := manifestEnc.Encode(manifest); err != nil {
+		return manifest, fmt.Errorf("writing manifest: %w", err)
+	}
+
+	da.log.Info("training batch assembled",
+		"batch_id", batchID, "gold", manifest.GoldCount,
+		"corrected", manifest.CorrectedCount, "total", manifest.TotalExamples,
+		"path", dataPath)
+
+	return manifest, nil
+}
+
+// buildTrainingExample creates a training example from a gold experience entry.
+// Loads the raw memory and the encoded memory to reconstruct the prompt+output pair.
+func (da *DreamingAgent) buildTrainingExample(ctx context.Context, entry store.ExperienceEntry, exType string) (*TrainingExample, error) {
+	raw, err := da.store.GetRaw(ctx, entry.RawID)
+	if err != nil {
+		return nil, fmt.Errorf("loading raw memory %s: %w", entry.RawID, err)
+	}
+
+	// Get the encoded memory (the model's output that was rated as gold)
+	mem, err := da.store.GetMemory(ctx, entry.MemoryID)
+	if err != nil {
+		return nil, fmt.Errorf("loading memory %s: %w", entry.MemoryID, err)
+	}
+
+	truncated := agentutil.Truncate(raw.Content, 4000)
+	prompt := encoding.BuildCompressionPrompt(truncated, raw.Source, raw.Type, "", "", nil)
+
+	// Reconstruct the encoding output as JSON
+	output, err := json.Marshal(map[string]any{
+		"summary":  mem.Summary,
+		"content":  mem.Content,
+		"concepts": mem.Concepts,
+		"salience": mem.Salience,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("marshaling memory output: %w", err)
+	}
+
+	return &TrainingExample{
+		Type:     exType,
+		Prompt:   prompt,
+		Output:   string(output),
+		MemoryID: entry.MemoryID,
+		EPR:      entry.EncodingEPR,
+	}, nil
+}

From a85e0bed19fc69300218b234cde4ab5cf07b53c7 Mon Sep 17 00:00:00 2001
From: Caleb Gross
Date: Mon, 13 Apr 2026 23:01:04 -0400
Subject: [PATCH 3/8] refactor: reduce session startup context overhead by 30%

Merge scientific-method.md, experiment-logging.md, and
peer-review-standard.md into a single research-standards.md. Remove MCP
tools table from CLAUDE.md (redundant with MCP server schemas).
Deduplicate conventions and platform sections that were repeated across
rules files and CLAUDE.md. Tighten git-safety.md and code-quality.md to
remove rules already enforced by the Claude Code system prompt and git
hooks.

Saves ~3,350 tokens per session startup (~13.4KB).

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 .claude/rules/code-quality.md         |  31 ++-----
 .claude/rules/experiment-logging.md   |  69 ---------------
 .claude/rules/git-safety.md           |  46 ++--------
 .claude/rules/peer-review-standard.md |  29 -------
 .claude/rules/research-standards.md   |  69 +++++++++++++++
 .claude/rules/scientific-method.md    | 116 --------------------------
 CLAUDE.md                             |  77 +----------------
 7 files changed, 88 insertions(+), 349 deletions(-)
 delete mode 100644 .claude/rules/experiment-logging.md
 delete mode 100644 .claude/rules/peer-review-standard.md
 create mode 100644 .claude/rules/research-standards.md
 delete mode 100644 .claude/rules/scientific-method.md

diff --git a/.claude/rules/code-quality.md b/.claude/rules/code-quality.md
index ec10c364..6546b31d 100644
--- a/.claude/rules/code-quality.md
+++ b/.claude/rules/code-quality.md
@@ -3,30 +3,13 @@
 ## Finish the Work First
 
 - The goal is working, tested code — not a PR. PRs are the delivery mechanism, not the deliverable.
-- Do NOT rush to commit, push, or open a PR until the actual work is done and verified.
-- A task is not done when the code compiles. It's done when it runs correctly and is tested.
-- Stay in the problem. If something is "working," verify it properly before moving to git/PR mechanics.
-- Only create a PR when explicitly asked, or when the work is genuinely complete and tested.
-- Never treat "open a PR" as a natural next step — wait for the user to decide when the work is ready.
+- Do NOT rush to commit/push/PR until work is done and verified. A task is done when it runs correctly and is tested, not when it compiles.
+- Only create a PR when explicitly asked, or when work is genuinely complete and tested.
 
 ## Scope
 
-- Only change what was asked for — don't touch surrounding code
-- If you spot something worth fixing but it wasn't requested, call it out instead of silently doing it
-- No drive-by refactors, no "while I'm here" improvements
-- One logical change per task — don't bundle unrelated fixes
-
-## Change Safety
-
-- Read before edit — always understand a file before modifying it
-- Build and test after changes, don't assume it works
-- No new dependencies without discussing it first
-- Don't delete code you don't fully understand
-
-## Review Mindset
-
-- Don't add comments, docstrings, or type annotations to untouched code
-- Don't rename things that aren't part of the task
-- Don't "improve" error messages or formatting in adjacent code
-- Keep PRs reviewable — small, focused diffs
-- If a change is getting large, pause and check in with the user
+- Only change what was asked for. No drive-by refactors, no "while I'm here" improvements.
+- If you spot something worth fixing, call it out — don't silently do it.
+- One logical change per task. Keep PRs reviewable — small, focused diffs.
+- No new dependencies without discussing it first.
+- If a change is getting large, pause and check in with the user.
diff --git a/.claude/rules/experiment-logging.md b/.claude/rules/experiment-logging.md
deleted file mode 100644
index f5cc1047..00000000
--- a/.claude/rules/experiment-logging.md
+++ /dev/null
@@ -1,69 +0,0 @@
-# Experiment Logging
-
-This is a serious research project. Mnemonic-LM training docs are scientific documents, not scratch pads.
-
-## Document Structure
-
-`training/docs/experiments.md` follows a fixed structure. Do not deviate:
-
-1. **Overview** — research question and project goals
-2. **Experimental Protocol** — training setup, hardware, data mix, evaluation metrics
-3. **Baselines** — Gemini quality floor, stub-LLM IR benchmarks, and what they represent
-4. **HP Sweep Results** — grouped by variable (LR, batch size, beta2, warmup, etc.)
-5. **Pretraining Runs** — full training results with loss curves and checkpoints
-6. **Planned Experiments** — with hypothesis and motivation
-7. **Summary** — results table, key findings, open questions
-
-## Quality Bar for Experiment Entries
-
-Every experiment entry MUST include ALL of the following:
-
-- **Header line:** Experiment name, date, config, hardware — on one line
-- **Control and variable:** Explicitly state what is being compared and what single variable changed
-- **Results table:** Loss and PPL at minimum. Throughput (tokens/sec) for batch experiments.
-- **Analysis paragraph:** Not bullet points. A proper paragraph explaining:
-  - What happened and by how much (quantitative)
-  - Why it happened (mechanistic interpretation if possible)
-  - What it implies for the next experiment or the full run
-- **For sweep runs:** Compare against the same-phase baseline, not a different phase's results
-
-## When Logging Sweep Phases
-
-Group runs under a shared subsection (e.g., "Phase 1: LR + Weight Decay") with:
-- A table of all runs in the phase
-- The best run highlighted
-- A combined analysis paragraph drawing conclusions across the group
-
-## Before Running Any Experiment
-
-- State the hypothesis and what variable is being tested
-- Pre-register in `training/docs/experiment_registry.md`
-- Note the config, HP, and how results will be compared
-
-## After Every Run
-
-- Record results immediately — no "I'll write it up later"
-- Update `training/sweep_results.tsv` with raw numbers
-- Update `training/docs/experiments.md` with analysis
-- Update the summary table — every experiment gets a row
-- If the result changes any prior conclusions, update those entries
-
-## Benchmark Logging
-
-Benchmark results (IR quality, end-to-end Gemini) require:
-- **Exact command used** to produce the numbers
-- **Software state** — git commit hash, binary version, config
-- **Environment** — hardware, daemon config, LLM provider and model
-- **All metrics** — not just the headline number
-- **Comparison context** — what baseline this represents and what it needs to beat
-
-## Do Not
-
-- Write informal one-liner "key findings" — write analysis paragraphs
-- Log experiments chronologically instead of by category
-- Skip the control/variable line — every experiment is a comparison
-- Use vague language ("clearly better", "significantly worse") without numbers
-- Cherry-pick metrics — report all of them, including unfavorable ones
-- Leave an experiment as "RUNNING" after it finishes
-- Run a benchmark, get a number, and not document methodology
-- Let the registry config drift from what actually ran — if batch size, accum, or any HP changes mid-experiment, update the config line immediately
diff --git a/.claude/rules/git-safety.md b/.claude/rules/git-safety.md
index 0d450a33..4edbe02e 100644
--- a/.claude/rules/git-safety.md
+++ b/.claude/rules/git-safety.md
@@ -2,48 +2,20 @@
 
 ## Branch Workflow
 
-- Remote: `origin` (https://github.com/appsprout-dev/mnemonic.git)
-- Primary branch: `main`
-- **All new work starts on a feature branch** — never commit directly to `main`
-- Branch naming: `feat/`, `fix/`
-- Before branching: `git stash` (if dirty), `git pull origin main`, then `git checkout -b <branch>`
-- **Before committing:** Run `git branch --show-current` to verify you're on the intended branch. Bash tool does not persist shell state — a prior `git checkout` may not have taken effect.
-- **All changes go through a PR** — push the branch, open a PR with `gh pr create`, get it reviewed
-- **Closing issues:** When a PR resolves a GitHub issue, comment on the issue with a reference to the PR before or after closing it. Never close issues silently.
-- No blind commits to main, no YOLO pushes
+- Remote: `origin` (https://github.com/appsprout-dev/mnemonic.git), primary branch: `main`
+- All new work on feature branches (`feat/`, `fix/`) — never commit directly to `main`
+- Before branching: `git stash` (if dirty), `git pull origin main`, then `git checkout -b <branch>`
+- **Before committing:** `git branch --show-current` to verify — Bash tool doesn't persist shell state
+- All changes go through PRs (`gh pr create`). When a PR resolves an issue, comment on the issue with a PR reference.
-## Forbidden Operations
+## Forbidden (enforced by hooks)
 
-Enforced by `.claude/hooks/protect-git.sh` and `.claude/hooks/no-secrets.sh`:
-
-- `git push --force` / `git push -f` -- destroys remote history
-- `git reset --hard` -- destroys local changes
-- `git clean -f` -- permanently deletes untracked files
-- `git checkout .` / `git restore .` -- discards all unstaged changes
-- Staging `.env`, `credentials`, `*.db`, `settings.local.json`
+`.claude/hooks/protect-git.sh` and `.claude/hooks/no-secrets.sh` block: force push, `reset --hard`, `clean -f`, `checkout .`/`restore .`, staging `.env`/`credentials`/`*.db`/`settings.local.json`.
 
 ## Commit Messages (Conventional Commits)
 
-Use [Conventional Commits](https://www.conventionalcommits.org/) format — release-please uses these to auto-generate changelogs and version bumps:
-
-- `feat: add memory source tracking` — new feature (bumps minor)
-- `fix: prevent nil pointer in retrieval` — bug fix (bumps patch)
-- `docs: update README with Gemini setup` — documentation only
-- `refactor: simplify consolidation loop` — code change, no behavior change
-- `test: add encoding agent coverage` — tests only
-- `chore: update dependencies` — maintenance
-- `ci: fix release workflow runner` — CI/CD changes
-
-Rules:
-
-- Short, direct subject line describing the change
-- Body for context when non-obvious
-- No issue-closing keywords in commit messages unless explicitly asked
-- Use Co-Authored-By for Claude contributions
-- Append `!` after the type for breaking changes: `feat!: redesign store interface`
+Format: `type: description` — release-please uses these for changelogs/version bumps.
 
-## Secrets
+Types: `feat` (minor), `fix` (patch), `docs`, `refactor`, `test`, `chore`, `ci`. Append `!` for breaking changes.
 
-- `settings.local.json` contains machine-specific permissions -- NEVER commit
-- `*.db` files contain user data -- gitignored
-- Never include API tokens in commit messages or code
+Rules: short subject, body when non-obvious, no issue-closing keywords unless asked, Co-Authored-By for Claude, `settings.local.json` and `*.db` never committed.
diff --git a/.claude/rules/peer-review-standard.md b/.claude/rules/peer-review-standard.md
deleted file mode 100644
index 72072e00..00000000
--- a/.claude/rules/peer-review-standard.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# Peer Review Standard
-
-This project is under review by Aaron Gokaslan and Andrej Karpathy. All work must meet the standard of a published research project, not a hobby repo.
-
-## What This Means
-
-### Experiments
-- Every result must be reproducible from the registry entry alone — exact commands, configs, hardware, data paths
-- Report ALL metrics, not just the favorable ones
-- Look at actual model outputs, not just aggregate numbers — open random examples and read them
-- Statistical claims require sufficient sample sizes and confidence intervals
-- Negative results get the same documentation quality as positive results
-
-### Code
-- Scripts must be self-documenting — clear argument parsing, docstrings, usage examples
-- No dead code, no commented-out experiments, no "TODO: clean up later"
-- Training pipelines must run end-to-end from a clean checkout
-- Evaluation scripts must produce deterministic results given the same inputs
-
-### Documentation
-- The experiment registry tells the complete story of every experiment
-- Design documents explain the reasoning, not just the implementation
-- Every architectural decision has a recorded rationale
-
-### Claims
-- "The model doesn't hallucinate" requires evidence, not assertion
-- "X is better than Y" requires controlled comparison on matched conditions
-- Fabrication rate of 10% is not "low" — it means 1 in 10 memories is corrupted
-- 25 test inputs is a pilot, not a proof — acknowledge sample size limitations
diff --git a/.claude/rules/research-standards.md b/.claude/rules/research-standards.md
new file mode 100644
index 00000000..045d067f
--- /dev/null
+++ b/.claude/rules/research-standards.md
@@ -0,0 +1,69 @@
+# Research Standards
+
+This project is under review by Aaron Gokaslan and Andrej Karpathy. All work must meet the standard of a published research project. Follow the scientific method — every experiment is a test of a hypothesis, not a fishing expedition.
+
+## Core Principles
+
+- **Let the data decide.** Judge results by numbers, not desire. No reinterpreting negative results as "needs more training" without evidence.
+- **No motivated reasoning.** Report the number you got, not the number you wanted. Negative results get the same documentation quality.
+- **Actively disprove.** After a positive result: could this be an LR artifact? Param count mismatch? Training duration effect? Random seed noise? Not confirmed until alternatives are ruled out.
+- **Reproducibility.** Every result must be reproducible from the registry entry alone — exact commands, configs, hardware, data paths.
+
+## Pre-Registration (BEFORE any training or sweep)
+
+Create an entry in `training/docs/experiment_registry.md`:
+
+```markdown
+### EXP-{number}: {name}
+- **Date:** {YYYY-MM-DD}
+- **Status:** REGISTERED | RUNNING | COMPLETED | FAILED
+- **Hypothesis:** {What you expect and why}
+- **Variable:** {The ONE thing changed vs control}
+- **Control:** {Comparison target with its result}
+- **Prediction:** {Quantitative — e.g., "expect LR 1e-3 to beat 6e-4 by 5-10%"}
+- **Config:** {model, HP, hardware, data}
+- **Result:** {filled after run}
+- **Verdict:** CONFIRMED | REFUTED | INCONCLUSIVE
+- **Analysis:** {What happened, why, what it means}
+```
+
+## After Every Run
+
+1. Record result in registry (Status -> COMPLETED)
+2. Compare to prediction — was your mental model right?
+3. Positive result: list alternative explanations, which are ruled out
+4. Negative result: what does this tell us about config/architecture?
+5. Update `training/sweep_results.tsv` with raw numbers
+6. Update `training/docs/experiments.md` with analysis paragraph (not bullet points)
+7. If result changes prior conclusions, update those entries too
+
+## Experiment Document Structure
+
+`training/docs/experiments.md` follows: Overview, Experimental Protocol, Baselines, HP Sweep Results (by variable), Pretraining Runs, Planned Experiments, Summary.
+
+Every entry needs: header line (name/date/config/hardware), control and variable, results table (loss + PPL minimum), analysis paragraph (quantitative, mechanistic, implications). Sweep phases get a combined table + cross-group analysis.
+
+## Benchmark Logging
+
+Benchmarks require: exact command, software state (commit hash, version, config), environment (hardware, provider, model), ALL metrics, comparison context (baseline and target).
+
+## Evaluation Protocol
+
+Standard budgets (RX 7800 XT): short test 1K-2K optimizer steps, full sweep 4K+ micro-steps, full pretrain ~400K micro-steps.
+
+Metrics — Training: loss, PPL, tokens/sec, VRAM peak (report ALL). Quality: nDCG@5 (primary), Precision@5, Recall@5, MRR, JSON compliance, latency.
+
+## Claims Bar
+
+- "Doesn't hallucinate" requires evidence. "X > Y" requires controlled comparison on matched conditions.
+- Fabrication rate of 10% is not "low." 25 test inputs is a pilot, not proof.
+- Look at actual outputs, not just aggregates.
+- Code: no dead code, scripts self-documenting, pipelines run from clean checkout, evals deterministic.
+
+## Red Flags
+
+- Running without a hypothesis -> pre-register first
+- 3+ experiments all confirmed -> testing hard enough?
+- Comparing across different LRs/steps/batch sizes -> unfair
+- Explaining away negatives -> data is probably right
+- Registry config drifts from actual run -> update immediately
diff --git a/.claude/rules/scientific-method.md b/.claude/rules/scientific-method.md
deleted file mode 100644
index 7685f6fb..00000000
--- a/.claude/rules/scientific-method.md
+++ /dev/null
@@ -1,116 +0,0 @@
-# Scientific Method & Mertonian Norms
-
-This project follows the scientific method. Every experiment is a test of a hypothesis,
-not a fishing expedition. These rules enforce rigor based on Merton's norms of science:
-communalism, universalism, disinterestedness, and organized skepticism.
-
-## The Four Norms (Applied)
-
-### 1. Communalism — Share Everything
-
-- All findings belong to the project, not to a session. Document so that anyone
-  (including future-you with no memory) can understand and reproduce.
-- Every experiment must be registered in `training/docs/experiment_registry.md` BEFORE training starts.
-- Results, methodology, and failed attempts are all public record in the project docs.
-
-### 2. Universalism — Let the Data Decide
-
-- Judge results by the numbers, not by how much you want them to work.
-- A hypothesis is supported or refuted by the data. Do not reinterpret negative results
-  as "needs more training" or "probably works at scale" without evidence.
-- Same evaluation protocol for every experiment. No special treatment for favored configs.
-
-### 3. Disinterestedness — No Motivated Reasoning
-
-- Pre-register the hypothesis AND the expected outcome BEFORE running.
-- If the result contradicts expectation, that's information. Don't rationalize it away.
-- Report the number you got, not the number you wanted.
-- Negative results get the same documentation quality as positive results.
-
-### 4. Organized Skepticism — Actively Try to Disprove
-
-- After a positive result, ask: "What else could explain this?"
-  - LR artifact? (Run the baseline at the same LR)
-  - Param count mismatch? (Check overhead percentage)
-  - Training duration effect? (Compare at matched steps)
-  - Random seed variance? (Note if the delta is small enough to be noise)
-- A result is not "confirmed" until the obvious alternative explanations are ruled out.
-
-## Pre-Registration Protocol
-
-BEFORE launching any training or sweep run, create an entry in `training/docs/experiment_registry.md`:
-
-```markdown
-### EXP-{number}: {name}
-- **Date:** {YYYY-MM-DD}
-- **Status:** REGISTERED | RUNNING | COMPLETED | FAILED
-- **Hypothesis:** {What you expect to happen and why}
-- **Variable:** {The ONE thing that changed vs control}
-- **Control:** {What you're comparing against, with its result}
-- **Prediction:** {Quantitative — e.g., "expect LR 1e-3 to beat 6e-4 by 5-10% lower loss"}
-- **Config:** {model config, HP, hardware, data}
-- **Result:** {filled in after run completes}
-- **Verdict:** CONFIRMED | REFUTED | INCONCLUSIVE
-- **Analysis:** {What happened, why, and what it means}
-```
-
-The prediction forces you to think about effect size before running. If you can't
-predict a direction, that's fine — say "exploratory, no directional prediction" — but
-be honest that you're exploring, not testing.
-
-## Post-Experiment Checklist
-
-After EVERY completed run:
-
-1. Record the result in the registry entry (Status -> COMPLETED)
-2. Compare result to prediction — was your mental model right?
-3. If result is positive: list alternative explanations and whether they're ruled out
-4. If result is negative: what does this tell us about the config/architecture?
-5. Update `training/sweep_results.tsv` with the raw numbers
-6. Update the findings document with analysis
-7. If the result changes any prior conclusions, update those entries too
-
-## Red Flags — Stop and Think
-
-- You're about to run something without a hypothesis -> pre-register first
-- You ran 3+ experiments that all confirmed your expectations -> are you testing hard enough?
-- You're comparing results from different LRs/steps/batch sizes -> not a fair test
-- You're explaining away a negative result -> the data is probably right
-- You got benchmark numbers and were about to slap them on an issue without methodology -> stop
-
-## Evaluation Protocol
-
-### Sweep Runs (HP Search)
-
-Standard budgets for the RX 7800 XT with 100M v3:
-- **Short directional test:** 1000-2000 optimizer steps (~4000-8000 micro-steps at accum 4)
-- **Full sweep run:** 4000+ micro-steps per config
-- **Full pretraining:** ~400K micro-steps (1 epoch through 6.5B tokens)
-
-### Metrics
-
-- **Loss** (cross-entropy): Primary metric for pretraining sweeps. Lower is better.
-- **PPL** (perplexity): exp(loss). More interpretable for comparison with prior felixlm work.
-- **Tokens/sec**: Throughput. Report alongside loss — a 2% loss win at 3x cost may not be worth it.
-- **VRAM peak**: Report for batch size experiments.
-- Always report ALL metrics, not just the favorable ones.
-
-### Benchmark Metrics (Mnemonic Quality)
-
-- **nDCG@5**: Primary retrieval quality metric (IR benchmark)
-- **Precision@5, Recall@5, MRR**: Supporting retrieval metrics
-- **JSON compliance rate**: Encoding output validity
-- **Latency**: End-to-end response time
-- Report against established baselines (Gemini, stub LLM) with exact commands used.
-
-## What Counts as "Confirmed"
-
-A finding is confirmed when:
-1. The effect is observed at the predicted scale AND direction
-2. The most obvious alternative explanation (usually LR or param mismatch) is ruled out
-3. Ideally, the effect holds across multiple runs or conditions
-
-A finding is "promising" (not confirmed) when:
-1. Observed under one condition only
-2. Delta is small (could be noise)
-3. No alternative explanation has been explicitly tested
diff --git a/CLAUDE.md b/CLAUDE.md
index 2c245a89..4dad1a54 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -84,11 +84,8 @@ scripts/          Utility scripts
 
 ## Conventions
 
-- **Event bus architecture:** Agents communicate via events, never direct calls. To add behavior, subscribe to events in the bus.
-- **Store interface:** All data access goes through `store.Store` interface. The SQLite implementation is in `internal/store/sqlite/`.
-- **Error handling:** Wrap errors with context: `fmt.Errorf("encoding memory %s: %w", id, err)`
-- **Platform-specific code:** Use Go build tags (`//go:build darwin`, `//go:build !darwin`). See `internal/watcher/filesystem/` for examples.
-- **Config:** All tunables live in `config.yaml`. Add new fields to `internal/config/config.go` struct.
+See `.claude/rules/go-conventions.md` for Go style, lint, architecture, and build details.
+
 - **Spoke routing:** When a spoke provider is configured (`LLM.Spoke` in config), specific agent tasks route to the spoke model via `CompositeProvider` (completions → spoke, embeddings → main provider). Configure task routing in `config.yaml`'s `LLM.Spoke.Tasks` list. Health-checked at startup in `cmd/mnemonic/serve.go`.
 
 ## Adding Things
@@ -98,14 +95,6 @@
 - **New API route:** Add handler in `internal/api/routes/`, register in `internal/api/server.go`. Existing routes include `/api/v1/activity` (watcher concept tracker for MCP sync).
 - **New MCP tool:** Add to `internal/mcp/server.go` tool registration.
 
-## Platform Support
-
-| Platform | Status |
-|----------|--------|
-| macOS ARM | Full support |
-| Linux x86_64 | Full support (primary dev platform) — systemd service, RX 7800 XT + ROCm for training/inference |
-| Windows x86_64 | Supported — `serve`, `install`, `start`, `stop`, `uninstall` work via Windows Services |
-
 ## Training (Felix-LM / Mnemonic-LM)
 
 Felix-LM is a hub-and-spoke architecture for language models. The "central post" is a frozen pretrained base model. "Spokes" are lightweight low-rank adapters (~25M params, <1% overhead) injected at each decoder layer. The spokes are the only trainable parameters — the base model is frozen.
@@ -132,64 +121,4 @@ All experiments must be pre-registered in `training/docs/experiment_registry.md`
 
 See [GitHub Issues](https://github.com/appsprout-dev/mnemonic/issues) for tracked bugs.
 
---- - -## MCP Tools Available - -You have 24 tools via the `mnemonic` MCP server: - -| Tool | When to Use | -|------|-------------| -| `remember` | Store decisions, errors, insights, learnings (returns raw ID + salience) | -| `recall` | Semantic search with spread activation (`explain`, `include_associations`, `format`, `type`, `types`, `include_patterns`, `include_abstractions`, `synthesize` params) | -| `batch_recall` | Run multiple recall queries in parallel — ideal for session start | -| `get_context` | Proactive suggestions based on recent daemon activity — call at natural breakpoints | -| `forget` | Archive irrelevant memories | -| `amend` | Update a memory's content in place (preserves associations, history, salience) | -| `check_memory` | Inspect a memory's encoding status, concepts, and associations | -| `status` | System health, encoding pipeline status, source distribution | -| `recall_project` | Get project-specific context and patterns | -| `recall_timeline` | See what happened in a time range | -| `recall_session` | Retrieve all memories from a specific MCP session | -| `list_sessions` | List recent sessions with time range and memory count | -| `session_summary` | Summarize current/recent session | -| `get_patterns` | View discovered recurring patterns (returns IDs for dismissal, supports `min_strength`) | -| `get_insights` | View metacognition observations and abstractions (returns IDs for dismissal) | -| `feedback` | Report recall quality (drives ranking, can auto-suppress noisy memories) | -| `audit_encodings` | Review recent encoding quality and suggest improvements | -| `coach_local_llm` | Write coaching guidance to improve local LLM prompts | -| `ingest_project` | Bulk-ingest a project directory into memory | -| `exclude_path` | Add a watcher exclusion pattern at runtime | -| `list_exclusions` | List all runtime watcher exclusion patterns | -| `dismiss_pattern` | Archive a stale or irrelevant pattern to stop it surfacing in recall | -| `dismiss_abstraction` | Archive a stale or irrelevant principle/axiom to stop it surfacing in recall | -| `create_handoff` | Store structured session handoff notes (high salience, surfaced by recall_project) | - -### At Session Start - -- Use `recall_project` to load context for the current project -- Use `recall` with relevant keywords to find prior decisions - -### During Work - -- `remember` decisions with `type: "decision"` — e.g., "chose SQLite over Postgres for simplicity" -- `remember` errors with `type: "error"` — e.g., "nil pointer in auth middleware, fixed with guard clause" -- `remember` insights with `type: "insight"` — e.g., "spread activation works best with 3 hops max" -- `remember` learnings with `type: "learning"` — e.g., "Go's sql.NullString needed for nullable columns" - -### After Recalls - -- Use `feedback` to rate recall quality — this helps the system improve -- `helpful` = memories were relevant and useful -- `partial` = some relevant, some not -- `irrelevant` = memories didn't help - -### Memory Types - -When using `remember`, set the `type` field: - -- `decision` — architectural choices, tradeoffs, "we chose X because Y" -- `error` — bugs found, error patterns, debugging insights -- `insight` — realizations about code, architecture, or process -- `learning` — new knowledge, API behaviors, framework quirks -- `general` — everything else (default) +MCP tool usage protocol: see `.claude/rules/mnemonic-usage.md`. Tool schemas come from the MCP server — no need to duplicate here. 
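The spoke routing added to CLAUDE.md above amounts to a task-keyed dispatch. The sketch below is illustrative only: the stand-in types, the `Task` field, and the task names are assumptions, not the real `internal/llm` API.

```go
package main

import (
	"context"
	"fmt"
)

// Stand-ins for the real internal/llm types; the names are guesses.
type CompletionRequest struct {
	Task   string // assumed task label used for routing
	Prompt string
}

type CompletionResponse struct{ Content string }

type Provider interface {
	Complete(ctx context.Context, req CompletionRequest) (CompletionResponse, error)
}

// stubProvider tags responses with its name so the routing is visible.
type stubProvider struct{ name string }

func (s stubProvider) Complete(_ context.Context, req CompletionRequest) (CompletionResponse, error) {
	return CompletionResponse{Content: s.name + " handled " + req.Task}, nil
}

// compositeProvider routes completions for configured tasks to the spoke
// model; everything else (including embeddings, omitted here) stays on the
// main provider.
type compositeProvider struct {
	main  Provider
	spoke Provider
	tasks map[string]struct{} // task names from LLM.Spoke.Tasks
}

func (c compositeProvider) Complete(ctx context.Context, req CompletionRequest) (CompletionResponse, error) {
	if _, ok := c.tasks[req.Task]; ok {
		return c.spoke.Complete(ctx, req)
	}
	return c.main.Complete(ctx, req)
}

func main() {
	cp := compositeProvider{
		main:  stubProvider{name: "api"},
		spoke: stubProvider{name: "spoke"},
		tasks: map[string]struct{}{"encoding": {}}, // example task name only
	}
	for _, task := range []string{"encoding", "insights"} {
		resp, _ := cp.Complete(context.Background(), CompletionRequest{Task: task})
		fmt.Println(resp.Content) // encoding -> spoke, insights -> api
	}
}
```

Pinning embeddings to the main provider, as the CLAUDE.md line states, keeps stored vectors comparable across spoke redeployments, which is presumably why only completions are routed.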
From 5bb0094952ed98e54702aa7b61d2549a4925bf9b Mon Sep 17 00:00:00 2001 From: Caleb Gross Date: Mon, 13 Apr 2026 23:05:58 -0400 Subject: [PATCH 4/8] =?UTF-8?q?feat:=20continuous=20learning=20Phase=20C?= =?UTF-8?q?=20=E2=80=94=20training=20trigger=20&=20orchestration=20(#391)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the automated spoke training pipeline (Phase C) to the continuous learning system. When enough experience data accumulates in the buffer, the daemon can assemble training batches, run spoke fine-tuning via Python subprocess, evaluate against quality gates, and deploy new spokes. New infrastructure: - TrainingRun type and 5 new ContinuousLearningStore methods (WriteTrainingRun, UpdateTrainingRun, GetLastTrainingRunTime, CountUntrainedExperience, MarkExperienceUsedInTraining) - Migration 018: training_runs table for audit trail - Training orchestrator in dreaming agent (Phase 4.85 in dream cycle) with subprocess execution, quality gate (EPR >= 0.90, FR <= 0.05, SC >= 0.95), and atomic deployment with rollback - train_model MCP tool (#25) for manual training trigger - 26 new tests across curriculum, training data, and trigger logic The pipeline: check untrained count >= threshold → assemble JSONL batch → run train_spokes.py subprocess → evaluate via eval_encoding.py → deploy via deploy_model.sh if quality passes → record result. Gated by config flags, training window, and minimum data threshold. Also includes minor fixes from other agents: episoding debug logging, embedded LLM grammar improvements. Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/mnemonic/serve.go | 32 +- internal/agent/dreaming/agent.go | 11 + internal/agent/dreaming/curriculum.go | 1 - internal/agent/dreaming/curriculum_test.go | 422 ++++++++++++++++++ internal/agent/dreaming/training_data.go | 30 +- internal/agent/dreaming/training_data_test.go | 406 +++++++++++++++++ internal/agent/dreaming/training_trigger.go | 412 +++++++++++++++++ .../agent/dreaming/training_trigger_test.go | 266 +++++++++++ internal/agent/episoding/agent.go | 6 + internal/config/config.go | 2 +- internal/llm/embedded.go | 7 +- internal/llm/embedded_test.go | 18 + internal/llm/grammar.go | 26 ++ internal/mcp/server.go | 28 ++ internal/mcp/server_test.go | 5 +- internal/mcp/session.go | 23 +- internal/mcp/tools.go | 12 + internal/store/sqlite/continuous_learning.go | 100 +++++ internal/store/sqlite/schema.go | 26 +- internal/store/store.go | 67 ++- internal/store/storetest/mock.go | 13 +- migrations/009_training_runs.sql | 25 ++ scripts/backfill_verification.go | 14 +- 23 files changed, 1891 insertions(+), 61 deletions(-) create mode 100644 internal/agent/dreaming/curriculum_test.go create mode 100644 internal/agent/dreaming/training_data_test.go create mode 100644 internal/agent/dreaming/training_trigger.go create mode 100644 internal/agent/dreaming/training_trigger_test.go create mode 100644 migrations/009_training_runs.sql diff --git a/cmd/mnemonic/serve.go b/cmd/mnemonic/serve.go index 4b3e21da..0a37e5d9 100644 --- a/cmd/mnemonic/serve.go +++ b/cmd/mnemonic/serve.go @@ -518,6 +518,7 @@ func serveCommand(configPath string) { InsightsBudget: cfg.Dreaming.InsightsBudget, DefaultConfidence: cfg.Dreaming.DefaultConfidence, Curriculum: cfg.ContinuousLearning.Curriculum, + ContinuousLearning: cfg.ContinuousLearning, }, log) if err := dreamer.Start(rootCtx, bus); err != nil { @@ -688,7 +689,7 @@ func serveCommand(configPath string) { // Create MCP session manager for HTTP transport 
mcpResolver := config.NewProjectResolver(cfg.Projects) - mcpSessions := mcp.NewSessionManager(mcp.SessionManagerConfig{ + smCfg := mcp.SessionManagerConfig{ Store: memStore, Retriever: retriever, Bus: bus, @@ -709,7 +710,34 @@ func serveCommand(configPath string) { FeedbackStrengthDelta: cfg.MemoryDefaults.FeedbackStrengthDelta, FeedbackSalienceBoost: cfg.MemoryDefaults.FeedbackSalienceBoost, }, - }) + } + + // Wire up manual training trigger if dreaming agent is available + if dreamer != nil && cfg.ContinuousLearning.Trigger.Manual { + clCfg := cfg.ContinuousLearning + smCfg.TrainingTriggerFn = func(ctx context.Context) (map[string]any, error) { + result, err := dreamer.RunTrainingCycle(ctx, clCfg) + if err != nil { + return nil, err + } + if result == nil { + return nil, nil + } + return map[string]any{ + "status": result.Status, + "batch_id": result.BatchID, + "total_examples": result.TotalExamples, + "quality_passed": result.QualityPassed, + "checkpoint": result.CheckpointPath, + "model": result.ModelPath, + "eval_epr": result.EvalEPR, + "eval_sc": result.EvalSC, + "error": result.ErrorMessage, + }, nil + } + } + + mcpSessions := mcp.NewSessionManager(smCfg) apiDeps.MCPSessions = mcpSessions defer mcpSessions.Stop(rootCtx) diff --git a/internal/agent/dreaming/agent.go b/internal/agent/dreaming/agent.go index 2d61dce9..d16f35ff 100644 --- a/internal/agent/dreaming/agent.go +++ b/internal/agent/dreaming/agent.go @@ -28,6 +28,7 @@ type DreamingConfig struct { InsightsBudget int DefaultConfidence float32 Curriculum config.CLCurriculumConfig + ContinuousLearning config.ContinuousLearningConfig } type DreamingAgent struct { @@ -195,6 +196,16 @@ func (da *DreamingAgent) runCycle(ctx context.Context) (*DreamReport, error) { "failed", currReport.CorrectionsFailed) } + // Phase 4.85: Training trigger — check if enough data for spoke fine-tuning + if trainResult, err := da.trainingCheck(ctx, da.config.ContinuousLearning); err != nil && ctx.Err() == nil { + da.log.Error("training trigger phase failed", "error", err) + } else if trainResult != nil { + da.log.Info("training cycle result", + "status", trainResult.Status, + "examples", trainResult.TotalExamples, + "quality_passed", trainResult.QualityPassed) + } + // Phase 5: Link replayed memories to matching patterns if err := da.linkToPatterns(ctx, replayed, report); err != nil && ctx.Err() == nil { da.log.Error("pattern linking phase failed", "error", err) diff --git a/internal/agent/dreaming/curriculum.go b/internal/agent/dreaming/curriculum.go index 93373bcf..25a261e7 100644 --- a/internal/agent/dreaming/curriculum.go +++ b/internal/agent/dreaming/curriculum.go @@ -191,4 +191,3 @@ func computeSimpleEPR(rawContent, outputJSON string) float64 { } return float64(preserved) / float64(significant) } - diff --git a/internal/agent/dreaming/curriculum_test.go b/internal/agent/dreaming/curriculum_test.go new file mode 100644 index 00000000..caf4d064 --- /dev/null +++ b/internal/agent/dreaming/curriculum_test.go @@ -0,0 +1,422 @@ +package dreaming + +import ( + "context" + "fmt" + "io" + "log/slog" + "testing" + "time" + + "github.com/appsprout-dev/mnemonic/internal/config" + "github.com/appsprout-dev/mnemonic/internal/llm" + "github.com/appsprout-dev/mnemonic/internal/store" + "github.com/appsprout-dev/mnemonic/internal/store/storetest" +) + +// curriculumMockStore provides controlled responses for curriculum tests. 
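+// It embeds storetest.MockStore so only the methods curriculumGeneration touches
+// need overriding; writes are captured in correctedEntries and the curriculumRuns
+// slices so tests can assert on them.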
+type curriculumMockStore struct { + storetest.MockStore + stats store.ExperienceStats + lastRunTime time.Time + needsImpEntries []store.ExperienceEntry + rawMemories map[string]store.RawMemory + correctedEntries map[string]correctedUpdate // entry_id -> update + curriculumRunsW []store.CurriculumRun + curriculumRunsU []store.CurriculumRun +} + +type correctedUpdate struct { + output string + epr float64 + source string +} + +func (m *curriculumMockStore) GetExperienceBufferStats(_ context.Context) (store.ExperienceStats, error) { + return m.stats, nil +} + +func (m *curriculumMockStore) GetLastCurriculumRunTime(_ context.Context) (time.Time, error) { + return m.lastRunTime, nil +} + +func (m *curriculumMockStore) ListNeedsImprovement(_ context.Context, limit int) ([]store.ExperienceEntry, error) { + if limit < len(m.needsImpEntries) { + return m.needsImpEntries[:limit], nil + } + return m.needsImpEntries, nil +} + +func (m *curriculumMockStore) GetRaw(_ context.Context, id string) (store.RawMemory, error) { + raw, ok := m.rawMemories[id] + if !ok { + return store.RawMemory{}, store.ErrNotFound + } + return raw, nil +} + +func (m *curriculumMockStore) UpdateExperienceCorrectedOutput(_ context.Context, entryID string, output string, epr float64, _ float64, source string) error { + if m.correctedEntries == nil { + m.correctedEntries = make(map[string]correctedUpdate) + } + m.correctedEntries[entryID] = correctedUpdate{output: output, epr: epr, source: source} + return nil +} + +func (m *curriculumMockStore) WriteCurriculumRun(_ context.Context, run store.CurriculumRun) error { + m.curriculumRunsW = append(m.curriculumRunsW, run) + return nil +} + +func (m *curriculumMockStore) UpdateCurriculumRun(_ context.Context, run store.CurriculumRun) error { + m.curriculumRunsU = append(m.curriculumRunsU, run) + return nil +} + +// curriculumMockLLM returns configurable teacher model responses. 
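+// Tests script the teacher by swapping completeFn; the zero value returns an
+// empty completion with no error.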
+type curriculumMockLLM struct { + completeFn func(ctx context.Context, req llm.CompletionRequest) (llm.CompletionResponse, error) +} + +func (p *curriculumMockLLM) Complete(ctx context.Context, req llm.CompletionRequest) (llm.CompletionResponse, error) { + if p.completeFn != nil { + return p.completeFn(ctx, req) + } + return llm.CompletionResponse{}, nil +} + +func (p *curriculumMockLLM) Embed(_ context.Context, _ string) ([]float32, error) { + return nil, nil +} + +func (p *curriculumMockLLM) BatchEmbed(_ context.Context, _ []string) ([][]float32, error) { + return nil, nil +} + +func (p *curriculumMockLLM) Health(_ context.Context) error { return nil } + +func (p *curriculumMockLLM) ModelInfo(_ context.Context) (llm.ModelMetadata, error) { + return llm.ModelMetadata{Name: "mock-teacher"}, nil +} + +func enabledCurriculumCfg() config.CLCurriculumConfig { + return config.CLCurriculumConfig{ + Enabled: true, + MaxCorrectionsPerCycle: 20, + MinNeedsImprovement: 2, + CooldownHours: 0, + } +} + +func TestCurriculumGeneration_Disabled(t *testing.T) { + ms := &curriculumMockStore{} + agent := NewDreamingAgent(ms, nil, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + cfg := config.CLCurriculumConfig{Enabled: false} + report, err := agent.curriculumGeneration(context.Background(), cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report != nil { + t.Fatal("expected nil report when disabled") + } +} + +func TestCurriculumGeneration_InsufficientEntries(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 3}, + } + agent := NewDreamingAgent(ms, nil, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + cfg := enabledCurriculumCfg() + cfg.MinNeedsImprovement = 10 // requires 10, only 3 exist + + report, err := agent.curriculumGeneration(context.Background(), cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report != nil { + t.Fatal("expected nil report when insufficient entries") + } +} + +func TestCurriculumGeneration_CooldownActive(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 20}, + lastRunTime: time.Now().Add(-1 * time.Hour), // ran 1 hour ago + } + agent := NewDreamingAgent(ms, nil, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + cfg := enabledCurriculumCfg() + cfg.CooldownHours = 24 // needs 24h between runs + + report, err := agent.curriculumGeneration(context.Background(), cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report != nil { + t.Fatal("expected nil report when cooldown active") + } +} + +func TestCurriculumGeneration_SuccessfulCorrection(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 10}, + needsImpEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.4, Category: "needs_improvement"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Fixed authentication middleware null pointer when session expires on production server", Source: "terminal", Type: "command_executed"}, + }, + } + + teacherResponse := `{"summary":"auth middleware null pointer fix","content":"Fixed null pointer in authentication middleware triggered by expired sessions on production","concepts":["authentication","middleware","null-pointer","production"],"salience":0.8}` + llmProv := &curriculumMockLLM{ + completeFn: func(_ 
context.Context, _ llm.CompletionRequest) (llm.CompletionResponse, error) { + return llm.CompletionResponse{Content: teacherResponse}, nil + }, + } + + agent := NewDreamingAgent(ms, llmProv, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + cfg := enabledCurriculumCfg() + report, err := agent.curriculumGeneration(context.Background(), cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report == nil { + t.Fatal("expected non-nil report") + } + if report.CorrectionsAttempted != 1 { + t.Errorf("expected 1 attempted, got %d", report.CorrectionsAttempted) + } + if report.CorrectionsPassed != 1 { + t.Errorf("expected 1 passed, got %d", report.CorrectionsPassed) + } + + // Verify corrected output was stored + update, ok := ms.correctedEntries["e1"] + if !ok { + t.Fatal("expected corrected entry to be stored") + } + if update.output != teacherResponse { + t.Errorf("stored output mismatch") + } + if update.source != "api" { + t.Errorf("expected source 'api', got %q", update.source) + } + + // Verify curriculum run was written and updated + if len(ms.curriculumRunsW) != 1 { + t.Fatalf("expected 1 curriculum run written, got %d", len(ms.curriculumRunsW)) + } + if len(ms.curriculumRunsU) != 1 { + t.Fatalf("expected 1 curriculum run updated, got %d", len(ms.curriculumRunsU)) + } + if ms.curriculumRunsU[0].Status != "completed" { + t.Errorf("expected status 'completed', got %q", ms.curriculumRunsU[0].Status) + } +} + +func TestCurriculumGeneration_TeacherReturnsInvalidJSON(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 10}, + needsImpEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.3, Category: "needs_improvement"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Some content that needs correction", Source: "terminal", Type: "command_executed"}, + }, + } + + llmProv := &curriculumMockLLM{ + completeFn: func(_ context.Context, _ llm.CompletionRequest) (llm.CompletionResponse, error) { + return llm.CompletionResponse{Content: "I cannot process this request."}, nil + }, + } + + agent := NewDreamingAgent(ms, llmProv, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + report, err := agent.curriculumGeneration(context.Background(), enabledCurriculumCfg()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report.CorrectionsFailed != 1 { + t.Errorf("expected 1 failed, got %d", report.CorrectionsFailed) + } + if report.CorrectionsPassed != 0 { + t.Errorf("expected 0 passed, got %d", report.CorrectionsPassed) + } + + // Should not have stored anything + if len(ms.correctedEntries) != 0 { + t.Error("expected no corrected entries stored") + } +} + +func TestCurriculumGeneration_TeacherMissingFields(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 10}, + needsImpEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.3, Category: "needs_improvement"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Some content that needs correction with enough words to pass threshold", Source: "terminal", Type: "command_executed"}, + }, + } + + // Valid JSON but missing required "concepts" field + llmProv := &curriculumMockLLM{ + completeFn: func(_ context.Context, _ llm.CompletionRequest) (llm.CompletionResponse, error) { + return llm.CompletionResponse{Content: 
`{"summary":"ok","content":"stuff"}`}, nil + }, + } + + agent := NewDreamingAgent(ms, llmProv, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + report, err := agent.curriculumGeneration(context.Background(), enabledCurriculumCfg()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report.CorrectionsFailed != 1 { + t.Errorf("expected 1 failed (missing field), got %d", report.CorrectionsFailed) + } +} + +func TestCurriculumGeneration_LowEPRRejected(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 10}, + needsImpEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.3, Category: "needs_improvement"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Deployed kubernetes cluster with terraform infrastructure monitoring dashboards grafana prometheus", Source: "terminal", Type: "command_executed"}, + }, + } + + // Teacher returns valid JSON but the output doesn't preserve entities from input + llmProv := &curriculumMockLLM{ + completeFn: func(_ context.Context, _ llm.CompletionRequest) (llm.CompletionResponse, error) { + return llm.CompletionResponse{Content: `{"summary":"did something","content":"generic stuff","concepts":["general"]}`}, nil + }, + } + + agent := NewDreamingAgent(ms, llmProv, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + report, err := agent.curriculumGeneration(context.Background(), enabledCurriculumCfg()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report.CorrectionsFailed != 1 { + t.Errorf("expected 1 failed (low EPR), got %d", report.CorrectionsFailed) + } +} + +func TestCurriculumGeneration_LLMError(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 10}, + needsImpEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.3, Category: "needs_improvement"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Some content", Source: "terminal", Type: "command_executed"}, + }, + } + + llmProv := &curriculumMockLLM{ + completeFn: func(_ context.Context, _ llm.CompletionRequest) (llm.CompletionResponse, error) { + return llm.CompletionResponse{}, fmt.Errorf("API rate limit exceeded") + }, + } + + agent := NewDreamingAgent(ms, llmProv, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + report, err := agent.curriculumGeneration(context.Background(), enabledCurriculumCfg()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report.CorrectionsFailed != 1 { + t.Errorf("expected 1 failed (LLM error), got %d", report.CorrectionsFailed) + } +} + +func TestCurriculumGeneration_MultipleEntries(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 10}, + needsImpEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.3, Category: "needs_improvement"}, + {ID: "e2", RawID: "raw-2", MemoryID: "mem-2", EncodingEPR: 0.4, Category: "needs_improvement"}, + {ID: "e3", RawID: "raw-3", MemoryID: "mem-3", EncodingEPR: 0.35, Category: "needs_improvement"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Fixed authentication middleware null pointer when session expires on production server", Source: "terminal", Type: "command_executed"}, + "raw-2": {ID: "raw-2", Content: "Deployed 
kubernetes cluster with terraform and configured monitoring dashboards for production", Source: "filesystem", Type: "file_modified"}, + "raw-3": {ID: "raw-3", Content: "Some short thing", Source: "mcp", Type: "general"}, + }, + } + + callCount := 0 + llmProv := &curriculumMockLLM{ + completeFn: func(_ context.Context, _ llm.CompletionRequest) (llm.CompletionResponse, error) { + callCount++ + switch callCount { + case 1: + // Must preserve enough 4+ char tokens from raw-1 for EPR >= 0.7 + return llm.CompletionResponse{Content: `{"summary":"auth fix","content":"Fixed authentication middleware null pointer when session expires on production server","concepts":["authentication","middleware","production"]}`}, nil + case 2: + return llm.CompletionResponse{Content: `{"summary":"k8s deploy","content":"Deployed kubernetes cluster with terraform and configured monitoring dashboards for production","concepts":["kubernetes","terraform","monitoring","production"]}`}, nil + default: + return llm.CompletionResponse{}, fmt.Errorf("API error") + } + }, + } + + agent := NewDreamingAgent(ms, llmProv, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + report, err := agent.curriculumGeneration(context.Background(), enabledCurriculumCfg()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report.CorrectionsAttempted != 3 { + t.Errorf("expected 3 attempted, got %d", report.CorrectionsAttempted) + } + if report.CorrectionsPassed != 2 { + t.Errorf("expected 2 passed, got %d", report.CorrectionsPassed) + } + if report.CorrectionsFailed != 1 { + t.Errorf("expected 1 failed, got %d", report.CorrectionsFailed) + } +} + +func TestCurriculumGeneration_ContextCancelled(t *testing.T) { + ms := &curriculumMockStore{ + stats: store.ExperienceStats{NeedsImprovement: 10}, + needsImpEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.3, Category: "needs_improvement"}, + {ID: "e2", RawID: "raw-2", MemoryID: "mem-2", EncodingEPR: 0.4, Category: "needs_improvement"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "First entry content with enough words for meaningful processing", Source: "terminal", Type: "command_executed"}, + "raw-2": {ID: "raw-2", Content: "Second entry content that should not be processed", Source: "terminal", Type: "command_executed"}, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + llmProv := &curriculumMockLLM{ + completeFn: func(_ context.Context, _ llm.CompletionRequest) (llm.CompletionResponse, error) { + cancel() // cancel after first call + return llm.CompletionResponse{Content: `{"summary":"test","content":"test content","concepts":["test"]}`}, nil + }, + } + + agent := NewDreamingAgent(ms, llmProv, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + report, err := agent.curriculumGeneration(ctx, enabledCurriculumCfg()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should have processed at most 1 (context cancelled before second) + if report.CorrectionsAttempted > 1 { + t.Errorf("expected at most 1 attempted after cancel, got %d", report.CorrectionsAttempted) + } +} diff --git a/internal/agent/dreaming/training_data.go b/internal/agent/dreaming/training_data.go index 1064131d..fade5627 100644 --- a/internal/agent/dreaming/training_data.go +++ b/internal/agent/dreaming/training_data.go @@ -17,21 +17,21 @@ import ( // TrainingExample is a single training pair written to JSONL. 
// The Python training script tokenizes and mixes with replay data. type TrainingExample struct { - Type string `json:"type"` // "gold" or "corrective" - Prompt string `json:"prompt"` // system + user prompt (identical to what the model saw) - Output string `json:"output"` // the target completion (gold encoding or corrected encoding) - MemoryID string `json:"memory_id"` // provenance - EPR float64 `json:"epr"` // EPR score of the output + Type string `json:"type"` // "gold" or "corrective" + Prompt string `json:"prompt"` // system + user prompt (identical to what the model saw) + Output string `json:"output"` // the target completion (gold encoding or corrected encoding) + MemoryID string `json:"memory_id"` // provenance + EPR float64 `json:"epr"` // EPR score of the output } // TrainingBatchManifest describes a training batch for reproducibility. type TrainingBatchManifest struct { - ID string `json:"id"` - CreatedAt time.Time `json:"created_at"` - GoldCount int `json:"gold_count"` - CorrectedCount int `json:"corrected_count"` - TotalExamples int `json:"total_examples"` - DataPath string `json:"data_path"` + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + GoldCount int `json:"gold_count"` + CorrectedCount int `json:"corrected_count"` + TotalExamples int `json:"total_examples"` + DataPath string `json:"data_path"` } // AssembleTrainingBatch writes gold and corrected encoding pairs to a JSONL file. @@ -177,10 +177,10 @@ func (da *DreamingAgent) buildTrainingExample(ctx context.Context, entry store.E // Reconstruct the encoding output as JSON output, err := json.Marshal(map[string]any{ - "summary": mem.Summary, - "content": mem.Content, - "concepts": mem.Concepts, - "salience": mem.Salience, + "summary": mem.Summary, + "content": mem.Content, + "concepts": mem.Concepts, + "salience": mem.Salience, }) if err != nil { return nil, fmt.Errorf("marshaling memory output: %w", err) diff --git a/internal/agent/dreaming/training_data_test.go b/internal/agent/dreaming/training_data_test.go new file mode 100644 index 00000000..a7cdce23 --- /dev/null +++ b/internal/agent/dreaming/training_data_test.go @@ -0,0 +1,406 @@ +package dreaming + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "os" + "path/filepath" + "testing" + "time" + + "github.com/appsprout-dev/mnemonic/internal/store" + "github.com/appsprout-dev/mnemonic/internal/store/storetest" +) + +// trainingDataMockStore provides controlled responses for training data assembly tests. 
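+// Only the two categories AssembleTrainingBatch asks for ("gold" and
+// "needs_improvement") are populated; any other category yields nil.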
+type trainingDataMockStore struct { + storetest.MockStore + goldEntries []store.ExperienceEntry + needsImpEntries []store.ExperienceEntry + rawMemories map[string]store.RawMemory + memories map[string]store.Memory +} + +func (m *trainingDataMockStore) ListExperienceByCategory(_ context.Context, category string, limit int) ([]store.ExperienceEntry, error) { + switch category { + case "gold": + if limit < len(m.goldEntries) { + return m.goldEntries[:limit], nil + } + return m.goldEntries, nil + case "needs_improvement": + if limit < len(m.needsImpEntries) { + return m.needsImpEntries[:limit], nil + } + return m.needsImpEntries, nil + } + return nil, nil +} + +func (m *trainingDataMockStore) GetRaw(_ context.Context, id string) (store.RawMemory, error) { + raw, ok := m.rawMemories[id] + if !ok { + return store.RawMemory{}, store.ErrNotFound + } + return raw, nil +} + +func (m *trainingDataMockStore) GetMemory(_ context.Context, id string) (store.Memory, error) { + mem, ok := m.memories[id] + if !ok { + return store.Memory{}, store.ErrNotFound + } + return mem, nil +} + +func newTestAgent(s store.Store) *DreamingAgent { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + cfg := DreamingConfig{ + Interval: 3 * time.Hour, + BatchSize: 20, + } + return NewDreamingAgent(s, nil, cfg, logger) +} + +func TestAssembleTrainingBatch_GoldOnly(t *testing.T) { + ms := &trainingDataMockStore{ + goldEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.95, Category: "gold"}, + {ID: "e2", RawID: "raw-2", MemoryID: "mem-2", EncodingEPR: 0.92, Category: "gold"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Fixed null pointer in auth middleware when session is expired", Source: "terminal", Type: "command_executed"}, + "raw-2": {ID: "raw-2", Content: "Added retry logic for flaky HTTP connections to upstream API", Source: "filesystem", Type: "file_modified"}, + }, + memories: map[string]store.Memory{ + "mem-1": {ID: "mem-1", Summary: "auth middleware null pointer fix", Content: "Fixed NPE in auth middleware for expired sessions", Concepts: []string{"auth", "null-pointer"}, Salience: 0.8}, + "mem-2": {ID: "mem-2", Summary: "HTTP retry logic", Content: "Added retry logic for upstream API connections", Concepts: []string{"http", "retry"}, Salience: 0.7}, + }, + } + + agent := newTestAgent(ms) + dir := t.TempDir() + + manifest, err := agent.AssembleTrainingBatch(context.Background(), dir, 100) + if err != nil { + t.Fatalf("AssembleTrainingBatch: %v", err) + } + if manifest.GoldCount != 2 { + t.Errorf("expected 2 gold, got %d", manifest.GoldCount) + } + if manifest.CorrectedCount != 0 { + t.Errorf("expected 0 corrected, got %d", manifest.CorrectedCount) + } + if manifest.TotalExamples != 2 { + t.Errorf("expected 2 total, got %d", manifest.TotalExamples) + } + + // Verify JSONL file exists and is parseable + data, err := os.ReadFile(manifest.DataPath) + if err != nil { + t.Fatalf("reading data file: %v", err) + } + lines := splitJSONLLines(t, data) + if len(lines) != 2 { + t.Fatalf("expected 2 JSONL lines, got %d", len(lines)) + } + + var ex TrainingExample + if err := json.Unmarshal(lines[0], &ex); err != nil { + t.Fatalf("parsing first JSONL line: %v", err) + } + if ex.Type != "gold" { + t.Errorf("expected type 'gold', got %q", ex.Type) + } + if ex.MemoryID != "mem-1" { + t.Errorf("expected memory_id 'mem-1', got %q", ex.MemoryID) + } + if ex.Prompt == "" { + t.Error("expected non-empty prompt") + } + if ex.Output == "" { + 
t.Error("expected non-empty output") + } + + // Output should be valid JSON with expected fields + var outputFields map[string]any + if err := json.Unmarshal([]byte(ex.Output), &outputFields); err != nil { + t.Fatalf("gold output is not valid JSON: %v", err) + } + if _, ok := outputFields["summary"]; !ok { + t.Error("gold output missing 'summary' field") + } + if _, ok := outputFields["content"]; !ok { + t.Error("gold output missing 'content' field") + } +} + +func TestAssembleTrainingBatch_CorrectedOnly(t *testing.T) { + ms := &trainingDataMockStore{ + needsImpEntries: []store.ExperienceEntry{ + { + ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.45, Category: "needs_improvement", + CorrectedOutput: `{"summary":"corrected summary","content":"corrected content","concepts":["auth"]}`, + CorrectedEPR: 0.92, + }, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Debugging auth middleware null pointer error in production", Source: "terminal", Type: "command_executed"}, + }, + } + + agent := newTestAgent(ms) + dir := t.TempDir() + + manifest, err := agent.AssembleTrainingBatch(context.Background(), dir, 100) + if err != nil { + t.Fatalf("AssembleTrainingBatch: %v", err) + } + if manifest.GoldCount != 0 { + t.Errorf("expected 0 gold, got %d", manifest.GoldCount) + } + if manifest.CorrectedCount != 1 { + t.Errorf("expected 1 corrected, got %d", manifest.CorrectedCount) + } + + data, err := os.ReadFile(manifest.DataPath) + if err != nil { + t.Fatalf("reading data file: %v", err) + } + lines := splitJSONLLines(t, data) + if len(lines) != 1 { + t.Fatalf("expected 1 JSONL line, got %d", len(lines)) + } + + var ex TrainingExample + if err := json.Unmarshal(lines[0], &ex); err != nil { + t.Fatalf("parsing JSONL line: %v", err) + } + if ex.Type != "corrective" { + t.Errorf("expected type 'corrective', got %q", ex.Type) + } + if ex.EPR != 0.92 { + t.Errorf("expected EPR 0.92, got %.2f", ex.EPR) + } + if ex.Output != `{"summary":"corrected summary","content":"corrected content","concepts":["auth"]}` { + t.Errorf("corrective output mismatch: %s", ex.Output) + } +} + +func TestAssembleTrainingBatch_MixedGoldAndCorrected(t *testing.T) { + ms := &trainingDataMockStore{ + goldEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.95, Category: "gold"}, + }, + needsImpEntries: []store.ExperienceEntry{ + { + ID: "e2", RawID: "raw-2", MemoryID: "mem-2", EncodingEPR: 0.4, Category: "needs_improvement", + CorrectedOutput: `{"summary":"fixed","content":"better","concepts":["go"]}`, + CorrectedEPR: 0.88, + }, + // This one has no correction — should be filtered out + {ID: "e3", RawID: "raw-3", MemoryID: "mem-3", EncodingEPR: 0.3, Category: "needs_improvement"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Deployed new spoke model v2 with improved EPR", Source: "mcp", Type: "decision"}, + "raw-2": {ID: "raw-2", Content: "Refactored encoding pipeline to use batch processing", Source: "terminal", Type: "command_executed"}, + }, + memories: map[string]store.Memory{ + "mem-1": {ID: "mem-1", Summary: "spoke v2 deployment", Content: "Deployed spoke model v2", Concepts: []string{"model", "deployment"}, Salience: 0.9}, + }, + } + + agent := newTestAgent(ms) + dir := t.TempDir() + + manifest, err := agent.AssembleTrainingBatch(context.Background(), dir, 100) + if err != nil { + t.Fatalf("AssembleTrainingBatch: %v", err) + } + if manifest.GoldCount != 1 { + t.Errorf("expected 1 gold, got %d", 
manifest.GoldCount) + } + if manifest.CorrectedCount != 1 { + t.Errorf("expected 1 corrected, got %d", manifest.CorrectedCount) + } + if manifest.TotalExamples != 2 { + t.Errorf("expected 2 total, got %d", manifest.TotalExamples) + } +} + +func TestAssembleTrainingBatch_EmptyBuffer(t *testing.T) { + ms := &trainingDataMockStore{} + agent := newTestAgent(ms) + dir := t.TempDir() + + _, err := agent.AssembleTrainingBatch(context.Background(), dir, 100) + if err == nil { + t.Fatal("expected error for empty buffer") + } +} + +func TestAssembleTrainingBatch_ManifestWritten(t *testing.T) { + ms := &trainingDataMockStore{ + goldEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.95, Category: "gold"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Test content for manifest verification", Source: "mcp", Type: "general"}, + }, + memories: map[string]store.Memory{ + "mem-1": {ID: "mem-1", Summary: "test", Content: "test content", Concepts: []string{"test"}, Salience: 0.5}, + }, + } + + agent := newTestAgent(ms) + dir := t.TempDir() + + manifest, err := agent.AssembleTrainingBatch(context.Background(), dir, 100) + if err != nil { + t.Fatalf("AssembleTrainingBatch: %v", err) + } + + // Verify manifest JSON was written alongside data file + manifestPath := filepath.Join(dir, "batch_"+manifest.ID+"_manifest.json") + data, err := os.ReadFile(manifestPath) + if err != nil { + t.Fatalf("reading manifest file: %v", err) + } + + var diskManifest TrainingBatchManifest + if err := json.Unmarshal(data, &diskManifest); err != nil { + t.Fatalf("parsing manifest: %v", err) + } + if diskManifest.ID != manifest.ID { + t.Errorf("manifest ID mismatch: %s vs %s", diskManifest.ID, manifest.ID) + } + if diskManifest.TotalExamples != manifest.TotalExamples { + t.Errorf("manifest total mismatch: %d vs %d", diskManifest.TotalExamples, manifest.TotalExamples) + } +} + +func TestAssembleTrainingBatch_DefaultMaxExamples(t *testing.T) { + ms := &trainingDataMockStore{ + goldEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.95, Category: "gold"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Test default max examples path", Source: "mcp", Type: "general"}, + }, + memories: map[string]store.Memory{ + "mem-1": {ID: "mem-1", Summary: "test", Content: "test content", Concepts: []string{"test"}, Salience: 0.5}, + }, + } + + agent := newTestAgent(ms) + dir := t.TempDir() + + // Pass 0 — should use default of 200 + manifest, err := agent.AssembleTrainingBatch(context.Background(), dir, 0) + if err != nil { + t.Fatalf("AssembleTrainingBatch with 0: %v", err) + } + if manifest.TotalExamples != 1 { + t.Errorf("expected 1 total, got %d", manifest.TotalExamples) + } +} + +func TestAssembleTrainingBatch_SkipsMissingRaw(t *testing.T) { + ms := &trainingDataMockStore{ + goldEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-missing", MemoryID: "mem-1", EncodingEPR: 0.95, Category: "gold"}, + {ID: "e2", RawID: "raw-2", MemoryID: "mem-2", EncodingEPR: 0.93, Category: "gold"}, + }, + rawMemories: map[string]store.RawMemory{ + // raw-missing is intentionally absent + "raw-2": {ID: "raw-2", Content: "This one exists and should be written", Source: "mcp", Type: "general"}, + }, + memories: map[string]store.Memory{ + "mem-2": {ID: "mem-2", Summary: "valid", Content: "valid content", Concepts: []string{"valid"}, Salience: 0.8}, + }, + } + + agent := newTestAgent(ms) + dir := 
t.TempDir()
+
+	manifest, err := agent.AssembleTrainingBatch(context.Background(), dir, 100)
+	if err != nil {
+		t.Fatalf("AssembleTrainingBatch: %v", err)
+	}
+	// First gold entry is skipped (missing raw), second succeeds
+	if manifest.TotalExamples != 1 {
+		t.Errorf("expected 1 total (1 skipped), got %d", manifest.TotalExamples)
+	}
+}
+
+func TestComputeSimpleEPR(t *testing.T) {
+	tests := []struct {
+		name   string
+		raw    string
+		output string
+		minEPR float64
+		maxEPR float64
+	}{
+		{
+			name:   "high preservation",
+			raw:    "Fixed authentication middleware null pointer when session expires",
+			output: `{"summary":"Fixed authentication middleware null pointer for expired sessions","concepts":["authentication","middleware"]}`,
+			minEPR: 0.5,
+			maxEPR: 1.0,
+		},
+		{
+			name:   "low preservation",
+			raw:    "Deployed kubernetes cluster with terraform and configured monitoring dashboards",
+			output: `{"summary":"did something","concepts":["general"]}`,
+			minEPR: 0.0,
+			maxEPR: 0.5,
+		},
+		{
+			name:   "no significant tokens returns 1.0",
+			raw:    "hi",
+			output: `{"summary":"hi"}`,
+			minEPR: 1.0,
+			maxEPR: 1.0,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			epr := computeSimpleEPR(tc.raw, tc.output)
+			if epr < tc.minEPR || epr > tc.maxEPR {
+				t.Errorf("EPR %.3f outside expected range [%.2f, %.2f]", epr, tc.minEPR, tc.maxEPR)
+			}
+		})
+	}
+}
+
+// splitJSONLLines splits JSONL bytes into individual JSON lines, skipping empty lines.
+func splitJSONLLines(t *testing.T, data []byte) []json.RawMessage {
+	t.Helper()
+	var lines []json.RawMessage
+	for _, line := range splitBytes(data, '\n') {
+		if len(line) == 0 {
+			continue
+		}
+		lines = append(lines, json.RawMessage(line))
+	}
+	return lines
+}
+
+// splitBytes is a simple byte splitter.
+func splitBytes(data []byte, sep byte) [][]byte {
+	var result [][]byte
+	start := 0
+	for i, b := range data {
+		if b == sep {
+			result = append(result, data[start:i])
+			start = i + 1
+		}
+	}
+	if start < len(data) {
+		result = append(result, data[start:])
+	}
+	return result
+}
diff --git a/internal/agent/dreaming/training_trigger.go b/internal/agent/dreaming/training_trigger.go
new file mode 100644
index 00000000..6bd521b5
--- /dev/null
+++ b/internal/agent/dreaming/training_trigger.go
@@ -0,0 +1,412 @@
+package dreaming
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"time"
+
+	"github.com/appsprout-dev/mnemonic/internal/config"
+	"github.com/appsprout-dev/mnemonic/internal/store"
+	"github.com/google/uuid"
+)
+
+// TrainingResult reports the outcome of a training cycle.
+type TrainingResult struct {
+	BatchID        string
+	TotalExamples  int
+	Status         string // completed, failed
+	CheckpointPath string
+	ModelPath      string
+	EvalEPR        float64
+	EvalSC         float64
+	QualityPassed  bool
+	ErrorMessage   string
+}
+
+// trainingCheck runs Phase 4.85: check if we should trigger spoke training.
+// Only runs during dreaming if auto-trigger is enabled. Also callable via MCP.
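+// A nil result with a nil error means a gate held: the feature or auto-trigger
+// is disabled, the clock is outside the training window, or RunTrainingCycle
+// found too little untrained data.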
+func (da *DreamingAgent) trainingCheck(ctx context.Context, clCfg config.ContinuousLearningConfig) (*TrainingResult, error) { + if !clCfg.Enabled { + return nil, nil + } + if !clCfg.Trigger.Auto { + da.log.Debug("training auto-trigger disabled, skipping") + return nil, nil + } + + // Check training window + if clCfg.Trigger.TrainingWindow != "" && !inTrainingWindow(clCfg.Trigger.TrainingWindow) { + da.log.Debug("outside training window, skipping", "window", clCfg.Trigger.TrainingWindow) + return nil, nil + } + + return da.RunTrainingCycle(ctx, clCfg) +} + +// RunTrainingCycle executes the full training pipeline: +// 1. Check if enough untrained experience exists +// 2. Assemble training batch (JSONL) +// 3. Run spoke training (Python subprocess) +// 4. Run quality gate evaluation +// 5. Deploy new spokes if quality passes +// +// This is the manual entry point called by MCP tools or dreaming auto-trigger. +func (da *DreamingAgent) RunTrainingCycle(ctx context.Context, clCfg config.ContinuousLearningConfig) (*TrainingResult, error) { + tCfg := clCfg.Training + + // Step 1: Check if enough untrained data exists + untrained, err := da.store.CountUntrainedExperience(ctx) + if err != nil { + return nil, fmt.Errorf("counting untrained experience: %w", err) + } + + minExamples := tCfg.MinNewExamples + if minExamples <= 0 { + minExamples = 50 + } + if untrained < minExamples { + da.log.Info("training skipped: insufficient untrained data", + "untrained", untrained, "min_required", minExamples) + return nil, nil + } + + // Step 2: Assemble training batch + outputDir := filepath.Join(os.TempDir(), "mnemonic-training") + maxExamples := tCfg.MaxExamplesPerRun + if maxExamples <= 0 { + maxExamples = 200 + } + + manifest, err := da.AssembleTrainingBatch(ctx, outputDir, maxExamples) + if err != nil { + return nil, fmt.Errorf("assembling training batch: %w", err) + } + + // Create a training run record + runID := uuid.New().String()[:8] + run := store.TrainingRun{ + ID: runID, + BatchID: manifest.ID, + BatchPath: manifest.DataPath, + GoldCount: manifest.GoldCount, + CorrectedCount: manifest.CorrectedCount, + TotalExamples: manifest.TotalExamples, + Status: "training", + StartedAt: time.Now(), + } + if err := da.store.WriteTrainingRun(ctx, run); err != nil { + return nil, fmt.Errorf("writing training run: %w", err) + } + + da.log.Info("training cycle started", + "run_id", runID, "batch_id", manifest.ID, + "examples", manifest.TotalExamples) + + result := &TrainingResult{ + BatchID: manifest.ID, + TotalExamples: manifest.TotalExamples, + } + + // Step 3: Run spoke training + checkpointPath, err := da.runSpokeTraining(ctx, manifest.DataPath, tCfg) + if err != nil { + result.Status = "failed" + result.ErrorMessage = fmt.Sprintf("training failed: %v", err) + da.failTrainingRun(ctx, &run, result.ErrorMessage) + return result, nil + } + run.CheckpointPath = checkpointPath + run.Status = "evaluating" + _ = da.store.UpdateTrainingRun(ctx, run) + + // Step 4: Run quality gate + evalResult, err := da.runQualityGate(ctx, checkpointPath) + if err != nil { + result.Status = "failed" + result.ErrorMessage = fmt.Sprintf("evaluation failed: %v", err) + da.failTrainingRun(ctx, &run, result.ErrorMessage) + return result, nil + } + + run.EvalEPR = evalResult.EPR + run.EvalFR = evalResult.FR + run.EvalSC = evalResult.SC + run.QualityPassed = evalResult.Passed + result.EvalEPR = evalResult.EPR + result.EvalSC = evalResult.SC + + if !evalResult.Passed { + result.Status = "failed" + result.QualityPassed = false + 
result.ErrorMessage = fmt.Sprintf("quality gate failed: EPR=%.2f FR=%.2f SC=%.2f", evalResult.EPR, evalResult.FR, evalResult.SC)
+		da.failTrainingRun(ctx, &run, result.ErrorMessage)
+		da.log.Warn("training quality gate failed — discarding checkpoint",
+			"run_id", runID, "epr", evalResult.EPR, "fr", evalResult.FR, "sc", evalResult.SC)
+		return result, nil
+	}
+
+	// Step 5: Deploy new spokes
+	run.Status = "deploying"
+	_ = da.store.UpdateTrainingRun(ctx, run)
+
+	modelPath, err := da.deploySpokeModel(ctx, checkpointPath)
+	if err != nil {
+		result.Status = "failed"
+		result.ErrorMessage = fmt.Sprintf("deployment failed: %v", err)
+		da.failTrainingRun(ctx, &run, result.ErrorMessage)
+		return result, nil
+	}
+
+	// Success
+	now := time.Now()
+	run.ModelPath = modelPath
+	run.Status = "completed"
+	run.CompletedAt = &now
+	_ = da.store.UpdateTrainingRun(ctx, run)
+
+	result.Status = "completed"
+	result.QualityPassed = true
+	result.CheckpointPath = checkpointPath
+	result.ModelPath = modelPath
+
+	da.log.Info("training cycle completed",
+		"run_id", runID, "epr", evalResult.EPR, "sc", evalResult.SC,
+		"model", modelPath)
+
+	return result, nil
+}
+
+// failTrainingRun records a failed training run in the store.
+func (da *DreamingAgent) failTrainingRun(ctx context.Context, run *store.TrainingRun, errMsg string) {
+	now := time.Now()
+	run.Status = "failed"
+	run.ErrorMessage = errMsg
+	run.CompletedAt = &now
+	_ = da.store.UpdateTrainingRun(ctx, *run)
+}
+
+// qualityGateResult holds the evaluation metrics from the quality gate.
+type qualityGateResult struct {
+	EPR    float64
+	FR     float64
+	SC     float64
+	Passed bool
+}
+
+// runSpokeTraining executes the Python training script as a subprocess.
+// Returns the path to the output checkpoint.
+func (da *DreamingAgent) runSpokeTraining(ctx context.Context, batchPath string, tCfg config.CLTrainingConfig) (string, error) {
+	// The training script lives relative to the daemon binary's project root.
+	// Use the MNEMONIC_PROJECT_DIR env var or default to $HOME/Projects/mem.
+	projectDir := os.Getenv("MNEMONIC_PROJECT_DIR")
+	if projectDir == "" {
+		homeDir, _ := os.UserHomeDir()
+		projectDir = filepath.Join(homeDir, "Projects", "mem")
+	}
+
+	scriptPath := filepath.Join(projectDir, "training", "scripts", "train_spokes.py")
+	if _, err := os.Stat(scriptPath); err != nil {
+		return "", fmt.Errorf("training script not found at %s: %w", scriptPath, err)
+	}
+
+	checkpointDir := filepath.Join(projectDir, "checkpoints", "continuous_learning")
+	if err := os.MkdirAll(checkpointDir, 0o755); err != nil {
+		return "", fmt.Errorf("creating checkpoint dir: %w", err)
+	}
+
+	// Construct the training command. Prefer the felixlm venv's Python
+	// interpreter; if it is missing, fall back to whatever python3 is on PATH.
+	venvPython := filepath.Join(os.Getenv("HOME"), "Projects", "felixlm", ".venv", "bin", "python")
+	if _, err := os.Stat(venvPython); err != nil {
+		venvPython = "python3" // fallback
+	}
+
+	args := []string{
+		scriptPath,
+		"--model-type", "gemma",
+		"--data", batchPath,
+		"--output-dir", checkpointDir,
+		"--steps", "500",
+		"--batch-size", "1",
+		"--grad-accum", "8",
+		"--lr", "1e-4",
+	}
+
+	da.log.Info("running spoke training",
+		"script", scriptPath, "data", batchPath,
+		"output_dir", checkpointDir)
+
+	cmd := exec.CommandContext(ctx, venvPython, args...)
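+	// CommandContext ties the subprocess to the dreaming cycle's context, so
+	// cancelling the cycle kills an in-flight training run instead of orphaning it.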
+ cmd.Dir = projectDir + cmd.Env = append(os.Environ(), "PYTHONUNBUFFERED=1") + + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("training script failed: %w\nOutput: %s", err, string(output)) + } + + // Find the checkpoint — the script writes to output_dir/last.pt + checkpointPath := filepath.Join(checkpointDir, "last.pt") + if _, err := os.Stat(checkpointPath); err != nil { + return "", fmt.Errorf("checkpoint not found after training at %s", checkpointPath) + } + + da.log.Info("spoke training completed", "checkpoint", checkpointPath) + return checkpointPath, nil +} + +// runQualityGate evaluates the trained checkpoint against probe inputs. +// Returns metrics and whether the model passes the quality threshold. +func (da *DreamingAgent) runQualityGate(ctx context.Context, checkpointPath string) (*qualityGateResult, error) { + projectDir := os.Getenv("MNEMONIC_PROJECT_DIR") + if projectDir == "" { + homeDir, _ := os.UserHomeDir() + projectDir = filepath.Join(homeDir, "Projects", "mem") + } + + evalScript := filepath.Join(projectDir, "training", "scripts", "eval_encoding.py") + if _, err := os.Stat(evalScript); err != nil { + return nil, fmt.Errorf("eval script not found at %s: %w", evalScript, err) + } + + venvPython := filepath.Join(os.Getenv("HOME"), "Projects", "felixlm", ".venv", "bin", "python") + if _, err := os.Stat(venvPython); err != nil { + venvPython = "python3" + } + + args := []string{ + evalScript, + "--checkpoint", checkpointPath, + "--mode", "generate", + "--json-output", + } + + cmd := exec.CommandContext(ctx, venvPython, args...) + cmd.Dir = projectDir + + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("eval script failed: %w\nOutput: %s", err, string(output)) + } + + // Parse the JSON output from the eval script. + // The script outputs a JSON line with EPR, FR, SC metrics. + result, err := parseEvalOutput(string(output)) + if err != nil { + return nil, fmt.Errorf("parsing eval output: %w", err) + } + + // Apply quality thresholds from the design doc: + // EPR >= 0.90, FR <= 0.05, SC >= 0.95 + result.Passed = result.EPR >= 0.90 && result.FR <= 0.05 && result.SC >= 0.95 + + da.log.Info("quality gate evaluation", + "epr", result.EPR, "fr", result.FR, "sc", result.SC, + "passed", result.Passed) + + return result, nil +} + +// deploySpokeModel exports the checkpoint to GGUF and deploys it. +func (da *DreamingAgent) deploySpokeModel(ctx context.Context, checkpointPath string) (string, error) { + projectDir := os.Getenv("MNEMONIC_PROJECT_DIR") + if projectDir == "" { + homeDir, _ := os.UserHomeDir() + projectDir = filepath.Join(homeDir, "Projects", "mem") + } + + deployScript := filepath.Join(projectDir, "training", "scripts", "deploy_model.sh") + if _, err := os.Stat(deployScript); err != nil { + return "", fmt.Errorf("deploy script not found at %s: %w", deployScript, err) + } + + // Version the model with timestamp + modelName := fmt.Sprintf("gemma-spokes-cl-%s", time.Now().Format("20060102-150405")) + + cmd := exec.CommandContext(ctx, "bash", deployScript, checkpointPath, "--name", modelName) + cmd.Dir = projectDir + + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("deploy script failed: %w\nOutput: %s", err, string(output)) + } + + modelPath := filepath.Join(projectDir, "models", modelName+".gguf") + da.log.Info("spoke model deployed", "path", modelPath, "name", modelName) + + return modelPath, nil +} + +// parseEvalOutput extracts metrics from the evaluation script's JSON output. 
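+// Scanning backwards picks the last parseable '{'-prefixed line, so banner text
+// before the JSON summary and log noise after it are both tolerated.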
+func parseEvalOutput(output string) (*qualityGateResult, error) { + // The eval script outputs various lines. We look for the JSON summary. + // For now, use a simple heuristic: find the last line that starts with '{'. + lines := splitLines(output) + for i := len(lines) - 1; i >= 0; i-- { + line := lines[i] + if len(line) > 0 && line[0] == '{' { + var metrics struct { + EPR float64 `json:"epr"` + FR float64 `json:"fr"` + SC float64 `json:"sc"` + } + if err := json.Unmarshal([]byte(line), &metrics); err != nil { + continue + } + return &qualityGateResult{ + EPR: metrics.EPR, + FR: metrics.FR, + SC: metrics.SC, + }, nil + } + } + return nil, fmt.Errorf("no JSON metrics found in eval output") +} + +// inTrainingWindow checks if the current time is within the configured window. +// Window format: "HH:MM-HH:MM" (24-hour, e.g. "02:00-06:00"). +func inTrainingWindow(window string) bool { + if window == "" { + return true + } + var startH, startM, endH, endM int + n, _ := fmt.Sscanf(window, "%d:%d-%d:%d", &startH, &startM, &endH, &endM) + if n != 4 { + return true // malformed window, allow + } + + now := time.Now() + currentMin := now.Hour()*60 + now.Minute() + startMin := startH*60 + startM + endMin := endH*60 + endM + + if startMin <= endMin { + return currentMin >= startMin && currentMin < endMin + } + // Wraps midnight (e.g. "22:00-06:00") + return currentMin >= startMin || currentMin < endMin +} + +// splitLines splits a string into lines, trimming trailing whitespace. +func splitLines(s string) []string { + var lines []string + start := 0 + for i := 0; i < len(s); i++ { + if s[i] == '\n' { + line := s[start:i] + if len(line) > 0 && line[len(line)-1] == '\r' { + line = line[:len(line)-1] + } + lines = append(lines, line) + start = i + 1 + } + } + if start < len(s) { + lines = append(lines, s[start:]) + } + return lines +} diff --git a/internal/agent/dreaming/training_trigger_test.go b/internal/agent/dreaming/training_trigger_test.go new file mode 100644 index 00000000..34901426 --- /dev/null +++ b/internal/agent/dreaming/training_trigger_test.go @@ -0,0 +1,266 @@ +package dreaming + +import ( + "context" + "io" + "log/slog" + "testing" + "time" + + "github.com/appsprout-dev/mnemonic/internal/config" + "github.com/appsprout-dev/mnemonic/internal/store" + "github.com/appsprout-dev/mnemonic/internal/store/storetest" +) + +// triggerMockStore provides controlled responses for training trigger tests. 
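+// trainingRunsW captures WriteTrainingRun calls and trainingRunsU captures
+// UpdateTrainingRun calls, letting tests assert the run lifecycle end to end.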
+type triggerMockStore struct { + storetest.MockStore + untrainedCount int + goldEntries []store.ExperienceEntry + needsImpEntries []store.ExperienceEntry + rawMemories map[string]store.RawMemory + memories map[string]store.Memory + trainingRunsW []store.TrainingRun + trainingRunsU []store.TrainingRun +} + +func (m *triggerMockStore) CountUntrainedExperience(_ context.Context) (int, error) { + return m.untrainedCount, nil +} + +func (m *triggerMockStore) ListExperienceByCategory(_ context.Context, category string, limit int) ([]store.ExperienceEntry, error) { + switch category { + case "gold": + if limit < len(m.goldEntries) { + return m.goldEntries[:limit], nil + } + return m.goldEntries, nil + case "needs_improvement": + if limit < len(m.needsImpEntries) { + return m.needsImpEntries[:limit], nil + } + return m.needsImpEntries, nil + } + return nil, nil +} + +func (m *triggerMockStore) GetRaw(_ context.Context, id string) (store.RawMemory, error) { + raw, ok := m.rawMemories[id] + if !ok { + return store.RawMemory{}, store.ErrNotFound + } + return raw, nil +} + +func (m *triggerMockStore) GetMemory(_ context.Context, id string) (store.Memory, error) { + mem, ok := m.memories[id] + if !ok { + return store.Memory{}, store.ErrNotFound + } + return mem, nil +} + +func (m *triggerMockStore) WriteTrainingRun(_ context.Context, run store.TrainingRun) error { + m.trainingRunsW = append(m.trainingRunsW, run) + return nil +} + +func (m *triggerMockStore) UpdateTrainingRun(_ context.Context, run store.TrainingRun) error { + m.trainingRunsU = append(m.trainingRunsU, run) + return nil +} + +func baseCLConfig() config.ContinuousLearningConfig { + return config.ContinuousLearningConfig{ + Enabled: true, + Training: config.CLTrainingConfig{ + MinNewExamples: 5, // low threshold for tests + MaxExamplesPerRun: 50, + ReplayRatio: 0.3, + RollbackVersions: 3, + }, + Curriculum: config.CLCurriculumConfig{ + Enabled: true, + }, + Trigger: config.CLTriggerConfig{ + Auto: true, + Manual: true, + }, + } +} + +func TestTrainingCheck_Disabled(t *testing.T) { + ms := &triggerMockStore{untrainedCount: 100} + agent := NewDreamingAgent(ms, nil, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + clCfg := baseCLConfig() + clCfg.Enabled = false + + result, err := agent.trainingCheck(context.Background(), clCfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != nil { + t.Fatal("expected nil result when disabled") + } +} + +func TestTrainingCheck_AutoTriggerDisabled(t *testing.T) { + ms := &triggerMockStore{untrainedCount: 100} + agent := NewDreamingAgent(ms, nil, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + clCfg := baseCLConfig() + clCfg.Trigger.Auto = false + + result, err := agent.trainingCheck(context.Background(), clCfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != nil { + t.Fatal("expected nil result when auto-trigger disabled") + } +} + +func TestRunTrainingCycle_InsufficientData(t *testing.T) { + ms := &triggerMockStore{untrainedCount: 3} + agent := NewDreamingAgent(ms, nil, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + clCfg := baseCLConfig() + clCfg.Training.MinNewExamples = 50 + + result, err := agent.RunTrainingCycle(context.Background(), clCfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != nil { + t.Fatal("expected nil result for insufficient data") + } +} + +func 
TestRunTrainingCycle_AssemblesAndRecords(t *testing.T) { + ms := &triggerMockStore{ + untrainedCount: 10, + goldEntries: []store.ExperienceEntry{ + {ID: "e1", RawID: "raw-1", MemoryID: "mem-1", EncodingEPR: 0.95, Category: "gold"}, + }, + rawMemories: map[string]store.RawMemory{ + "raw-1": {ID: "raw-1", Content: "Test content for training trigger", Source: "mcp", Type: "general"}, + }, + memories: map[string]store.Memory{ + "mem-1": {ID: "mem-1", Summary: "test", Content: "test content", Concepts: []string{"test"}, Salience: 0.5}, + }, + } + + agent := NewDreamingAgent(ms, nil, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + clCfg := baseCLConfig() + + // RunTrainingCycle will assemble data, write a training run, then fail on + // the subprocess call (no Python env in tests). That's expected — we're testing + // the trigger logic and record-keeping, not the actual training. + result, err := agent.RunTrainingCycle(context.Background(), clCfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should have assembled data and started a training run + if len(ms.trainingRunsW) != 1 { + t.Fatalf("expected 1 training run written, got %d", len(ms.trainingRunsW)) + } + run := ms.trainingRunsW[0] + if run.Status != "training" { + t.Errorf("expected initial status 'training', got %q", run.Status) + } + if run.TotalExamples != 1 { + t.Errorf("expected 1 total example, got %d", run.TotalExamples) + } + + // Training script will fail (not available in test env) — result should reflect that + if result == nil { + t.Fatal("expected non-nil result") + } + if result.Status != "failed" { + t.Errorf("expected status 'failed' (no training env), got %q", result.Status) + } + if result.ErrorMessage == "" { + t.Error("expected error message") + } + + // Should have updated the training run to failed + if len(ms.trainingRunsU) < 1 { + t.Fatal("expected at least 1 training run update") + } + lastUpdate := ms.trainingRunsU[len(ms.trainingRunsU)-1] + if lastUpdate.Status != "failed" { + t.Errorf("expected updated status 'failed', got %q", lastUpdate.Status) + } +} + +func TestInTrainingWindow(t *testing.T) { + tests := []struct { + name string + window string + want bool + }{ + {"empty window always allows", "", true}, + {"malformed window allows", "bad", true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := inTrainingWindow(tc.window) + if got != tc.want { + t.Errorf("inTrainingWindow(%q) = %v, want %v", tc.window, got, tc.want) + } + }) + } +} + +func TestParseEvalOutput(t *testing.T) { + t.Run("valid JSON metrics", func(t *testing.T) { + output := "Loading model...\nRunning evaluation...\n{\"epr\": 0.92, \"fr\": 0.03, \"sc\": 0.96}\nDone." 
+ result, err := parseEvalOutput(output) + if err != nil { + t.Fatalf("parseEvalOutput: %v", err) + } + if result.EPR != 0.92 { + t.Errorf("expected EPR 0.92, got %.2f", result.EPR) + } + if result.FR != 0.03 { + t.Errorf("expected FR 0.03, got %.2f", result.FR) + } + if result.SC != 0.96 { + t.Errorf("expected SC 0.96, got %.2f", result.SC) + } + }) + + t.Run("no JSON in output", func(t *testing.T) { + _, err := parseEvalOutput("No metrics here\nJust text output") + if err == nil { + t.Fatal("expected error for missing JSON") + } + }) + + t.Run("quality gate pass", func(t *testing.T) { + output := `{"epr": 0.95, "fr": 0.02, "sc": 0.98}` + result, err := parseEvalOutput(output) + if err != nil { + t.Fatalf("parseEvalOutput: %v", err) + } + result.Passed = result.EPR >= 0.90 && result.FR <= 0.05 && result.SC >= 0.95 + if !result.Passed { + t.Error("expected quality gate to pass") + } + }) + + t.Run("quality gate fail low EPR", func(t *testing.T) { + output := `{"epr": 0.85, "fr": 0.02, "sc": 0.98}` + result, err := parseEvalOutput(output) + if err != nil { + t.Fatalf("parseEvalOutput: %v", err) + } + result.Passed = result.EPR >= 0.90 && result.FR <= 0.05 && result.SC >= 0.95 + if result.Passed { + t.Error("expected quality gate to fail for low EPR") + } + }) +} diff --git a/internal/agent/episoding/agent.go b/internal/agent/episoding/agent.go index f1203f3f..a24d2a38 100644 --- a/internal/agent/episoding/agent.go +++ b/internal/agent/episoding/agent.go @@ -528,6 +528,12 @@ func parseEpisodeSynthesis(response string) episodeSynthesis { var result episodeSynthesis jsonStr := agentutil.ExtractJSON(response) if err := json.Unmarshal([]byte(jsonStr), &result); err != nil { + slog.Warn("episode synthesis parse failed", + "error", err, + "response_len", len(response), + "extracted_json_len", len(jsonStr), + "response_preview", agentutil.Truncate(response, 200), + ) return episodeSynthesis{ Title: "Untitled session", Summary: "Episode synthesis failed — LLM returned unparseable response.", diff --git a/internal/config/config.go b/internal/config/config.go index de9df32e..5b69d971 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -489,7 +489,7 @@ type ContinuousLearningConfig struct { // CLCurriculumConfig holds settings for Phase B curriculum generation. 
type CLCurriculumConfig struct { - Enabled bool `yaml:"enabled"` // enable curriculum generation in dreaming (default: false) + Enabled bool `yaml:"enabled"` // enable curriculum generation in dreaming (default: false) MaxCorrectionsPerCycle int `yaml:"max_corrections_per_cycle"` // max entries to re-encode per dream cycle (default: 20) MinNeedsImprovement int `yaml:"min_needs_improvement"` // min needs_improvement entries before running (default: 10) CooldownHours int `yaml:"cooldown_hours"` // hours between curriculum runs (default: 24) diff --git a/internal/llm/embedded.go b/internal/llm/embedded.go index dc2b29f8..b369b9a0 100644 --- a/internal/llm/embedded.go +++ b/internal/llm/embedded.go @@ -292,9 +292,12 @@ func (p *EmbeddedProvider) Complete(ctx context.Context, req CompletionRequest) grammar = GBNFJSONObject } if req.ResponseFormat != nil && req.ResponseFormat.Type == "json_schema" && req.ResponseFormat.JSONSchema != nil { - if req.ResponseFormat.JSONSchema.Name == "encoding_response" { + switch req.ResponseFormat.JSONSchema.Name { + case "encoding_response": grammar = GBNFEncodingResponse - } else { + case "episode_synthesis": + grammar = GBNFEpisodeSynthesis + default: grammar = GBNFJSONObject } } diff --git a/internal/llm/embedded_test.go b/internal/llm/embedded_test.go index 1b4c52e5..88bf6ca2 100644 --- a/internal/llm/embedded_test.go +++ b/internal/llm/embedded_test.go @@ -259,6 +259,24 @@ func TestEmbeddedProviderGrammarRouting(t *testing.T) { t.Errorf("expected encoding-specific GBNF grammar for encoding_response schema, got generic") } + // json_schema with episode_synthesis name — should use episode-specific grammar + _, err = p.Complete(ctx, CompletionRequest{ + Messages: []Message{{Role: "user", Content: "hello"}}, + ResponseFormat: &ResponseFormat{ + Type: "json_schema", + JSONSchema: &JSONSchema{ + Name: "episode_synthesis", + Strict: true, + }, + }, + }) + if err != nil { + t.Fatalf("Complete failed: %v", err) + } + if capturedGrammar != GBNFEpisodeSynthesis { + t.Errorf("expected episode-specific GBNF grammar for episode_synthesis schema, got generic") + } + // json_schema with other name — should fall back to generic JSON grammar _, err = p.Complete(ctx, CompletionRequest{ Messages: []Message{{Role: "user", Content: "hello"}}, diff --git a/internal/llm/grammar.go b/internal/llm/grammar.go index 4bc061a9..81514e97 100644 --- a/internal/llm/grammar.go +++ b/internal/llm/grammar.go @@ -57,6 +57,32 @@ number ::= "-"? ("0" | [1-9] [0-9]*) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws ::= ([ \t\n] ws)? ` +// GBNFEpisodeSynthesis constrains output to the episode synthesis schema. +// Fixed key order and typed values prevent the embedded model from producing +// type mismatches (e.g. salience as string) that break json.Unmarshal. +const GBNFEpisodeSynthesis = `root ::= "{" ws title-kv "," ws summary-kv "," ws narrative-kv "," ws emotional-tone-kv "," ws outcome-kv "," ws concepts-kv "," ws salience-kv ws "}" + +title-kv ::= "\"title\"" ws ":" ws string +summary-kv ::= "\"summary\"" ws ":" ws string +narrative-kv ::= "\"narrative\"" ws ":" ws string +emotional-tone-kv ::= "\"emotional_tone\"" ws ":" ws string +outcome-kv ::= "\"outcome\"" ws ":" ws string +concepts-kv ::= "\"concepts\"" ws ":" ws string-array +salience-kv ::= "\"salience\"" ws ":" ws number + +string-array ::= "[" ws "]" | "[" ws string ("," ws string)* ws "]" + +string ::= + "\"" ( + [^\\"\x00-\x1f] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* "\"" + +number ::= "-"? 
("0" | [1-9] [0-9]*) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? + +ws ::= ([ \t\n] ws)? +` + // GBNFEncodingResponse constrains output to the mnemonic encoding response schema. // Fixed key order eliminates ambiguity for small models and enforces all required fields. const GBNFEncodingResponse = `root ::= "{" ws gist-kv "," ws summary-kv "," ws content-kv "," ws narrative-kv "," ws concepts-kv "," ws structured-concepts-kv "," ws significance-kv "," ws emotional-tone-kv "," ws outcome-kv "," ws salience-kv ws "}" diff --git a/internal/mcp/server.go b/internal/mcp/server.go index e99c120f..92debb7b 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -121,6 +121,9 @@ type MCPServer struct { // Daemon activity sync (for context_boost in MCP processes) daemonURL string // base URL of daemon API (e.g. "http://127.0.0.1:9999") + + // Optional training trigger (set when dreaming agent is available) + trainingTriggerFn func(ctx context.Context) (map[string]any, error) } // NewMCPServer creates a new MCP server with the given dependencies. @@ -160,6 +163,12 @@ func NewMCPServer(s store.Store, r *retrieval.RetrievalAgent, bus events.Bus, lo } } +// SetTrainingTrigger sets the function used by the train_model MCP tool. +// Called from serve.go when the dreaming agent is available. +func (srv *MCPServer) SetTrainingTrigger(fn func(ctx context.Context) (map[string]any, error)) { + srv.trainingTriggerFn = fn +} + // detectProject determines the project name from the current working directory. func detectProject() string { wd, err := os.Getwd() @@ -349,6 +358,8 @@ func (srv *MCPServer) handleToolCall(ctx context.Context, req *JSONRPCRequest) * result, toolErr = srv.handleDismissAbstraction(ctx, params.Arguments) case "create_handoff": result, toolErr = srv.handleCreateHandoff(ctx, params.Arguments) + case "train_model": + result, toolErr = srv.handleTrainModel(ctx, params.Arguments) default: return errorResponse(req.ID, -32602, fmt.Sprintf("Unknown tool: %s", params.Name)) } @@ -2883,3 +2894,20 @@ func (srv *MCPServer) handleCreateHandoff(ctx context.Context, args map[string]a srv.log.Info("session handoff created", "id", raw.ID, "project", srv.project) return toolResult(fmt.Sprintf("Handoff stored (id: %s, salience: 0.95)\nWill be surfaced by recall_project in the next session.", raw.ID)), nil } + +// handleTrainModel triggers a spoke training cycle manually. +func (srv *MCPServer) handleTrainModel(ctx context.Context, _ map[string]any) (any, error) { + if srv.trainingTriggerFn == nil { + return nil, fmt.Errorf("training not available — daemon must be running with dreaming agent enabled") + } + + result, err := srv.trainingTriggerFn(ctx) + if err != nil { + return nil, fmt.Errorf("training cycle failed: %w", err) + } + if result == nil { + return toolResult("Training skipped — insufficient untrained data in experience buffer. 
Check `status` for details."), nil + } + + return result, nil +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 1bd850e4..938f0f98 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -140,8 +140,8 @@ func TestHandleToolsList(t *testing.T) { t.Fatalf("tools is not an array, got %T", toolsInterface) } - if len(toolsArray) != 24 { - t.Fatalf("expected 24 tools, got %d", len(toolsArray)) + if len(toolsArray) != 25 { + t.Fatalf("expected 25 tools, got %d", len(toolsArray)) } // Verify tool names @@ -170,6 +170,7 @@ func TestHandleToolsList(t *testing.T) { "dismiss_pattern": false, "dismiss_abstraction": false, "create_handoff": false, + "train_model": false, } for _, toolInterface := range toolsArray { diff --git a/internal/mcp/session.go b/internal/mcp/session.go index 5ffc338b..eb4b6715 100644 --- a/internal/mcp/session.go +++ b/internal/mcp/session.go @@ -30,8 +30,9 @@ type SessionManager struct { excludePatterns []string maxContentBytes int resolver ProjectResolver - daemonURL string - memDefaults MemoryDefaults + daemonURL string + memDefaults MemoryDefaults + trainingTriggerFn func(ctx context.Context) (map[string]any, error) idleTimeout time.Duration // how long before an idle session is expired stopCh chan struct{} // signals the reaper goroutine to stop @@ -53,9 +54,10 @@ type SessionManagerConfig struct { ExcludePatterns []string MaxContentBytes int Resolver *config.ProjectResolver - DaemonURL string - MemDefaults MemoryDefaults - IdleTimeout time.Duration // default: 30 minutes + DaemonURL string + MemDefaults MemoryDefaults + TrainingTriggerFn func(ctx context.Context) (map[string]any, error) + IdleTimeout time.Duration // default: 30 minutes } // NewSessionManager creates a session manager for HTTP MCP transport. @@ -76,9 +78,10 @@ func NewSessionManager(cfg SessionManagerConfig) *SessionManager { excludePatterns: cfg.ExcludePatterns, maxContentBytes: cfg.MaxContentBytes, resolver: cfg.Resolver, - daemonURL: cfg.DaemonURL, - memDefaults: cfg.MemDefaults, - idleTimeout: timeout, + daemonURL: cfg.DaemonURL, + memDefaults: cfg.MemDefaults, + trainingTriggerFn: cfg.TrainingTriggerFn, + idleTimeout: timeout, stopCh: make(chan struct{}), } @@ -111,6 +114,10 @@ func (sm *SessionManager) GetOrCreate(clientSessionID string) (*MCPServer, strin sm.memDefaults, ) + if sm.trainingTriggerFn != nil { + srv.SetTrainingTrigger(sm.trainingTriggerFn) + } + // Use the MCPServer's generated session ID as the key key := srv.SessionID() sm.sessions[key] = &httpSession{ diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go index f9ce2c40..2ac42c1d 100644 --- a/internal/mcp/tools.go +++ b/internal/mcp/tools.go @@ -701,5 +701,17 @@ func allToolDefs() []ToolDefinition { dismissPatternToolDef(), dismissAbstractionToolDef(), createHandoffToolDef(), + trainModelToolDef(), + } +} + +func trainModelToolDef() ToolDefinition { + return ToolDefinition{ + Name: "train_model", + Description: "Trigger a spoke fine-tuning cycle using accumulated experience data. Assembles gold and corrected encoding pairs into a training batch, runs spoke training, evaluates quality against probes, and deploys if the quality gate passes. 
Requires sufficient untrained data in the experience buffer (default: 50 entries).", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, } } diff --git a/internal/store/sqlite/continuous_learning.go b/internal/store/sqlite/continuous_learning.go index 1abe3e44..1eaca1e1 100644 --- a/internal/store/sqlite/continuous_learning.go +++ b/internal/store/sqlite/continuous_learning.go @@ -334,6 +334,106 @@ func (s *SQLiteStore) GetLastCurriculumRunTime(ctx context.Context) (time.Time, return t, nil } +// --- Phase C: Training runs --- + +func (s *SQLiteStore) WriteTrainingRun(ctx context.Context, run store.TrainingRun) error { + _, err := s.db.ExecContext(ctx, + `INSERT INTO training_runs (id, batch_id, batch_path, gold_count, corrected_count, + total_examples, status, checkpoint_path, model_path, eval_epr, eval_fr, eval_sc, + quality_passed, error_message, started_at, completed_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + run.ID, run.BatchID, run.BatchPath, run.GoldCount, run.CorrectedCount, + run.TotalExamples, run.Status, run.CheckpointPath, run.ModelPath, + run.EvalEPR, run.EvalFR, run.EvalSC, + run.QualityPassed, run.ErrorMessage, run.StartedAt, run.CompletedAt, + ) + if err != nil { + return fmt.Errorf("writing training run %s: %w", run.ID, err) + } + return nil +} + +func (s *SQLiteStore) UpdateTrainingRun(ctx context.Context, run store.TrainingRun) error { + _, err := s.db.ExecContext(ctx, + `UPDATE training_runs + SET status = ?, checkpoint_path = ?, model_path = ?, + eval_epr = ?, eval_fr = ?, eval_sc = ?, + quality_passed = ?, error_message = ?, completed_at = ? + WHERE id = ?`, + run.Status, run.CheckpointPath, run.ModelPath, + run.EvalEPR, run.EvalFR, run.EvalSC, + run.QualityPassed, run.ErrorMessage, run.CompletedAt, run.ID, + ) + if err != nil { + return fmt.Errorf("updating training run %s: %w", run.ID, err) + } + return nil +} + +func (s *SQLiteStore) GetLastTrainingRunTime(ctx context.Context) (time.Time, error) { + var raw *string + err := s.db.QueryRowContext(ctx, + `SELECT MAX(started_at) FROM training_runs WHERE status = 'completed'`, + ).Scan(&raw) + if err != nil { + return time.Time{}, fmt.Errorf("getting last training run time: %w", err) + } + if raw == nil || *raw == "" { + return time.Time{}, nil + } + formats := []string{ + time.RFC3339Nano, time.RFC3339, + "2006-01-02 15:04:05-07:00", "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05 -0700 MST", + } + var t time.Time + var parseErr error + for _, f := range formats { + t, parseErr = time.Parse(f, *raw) + if parseErr == nil { + break + } + } + if parseErr != nil { + return time.Time{}, fmt.Errorf("parsing training run time %q: %w", *raw, parseErr) + } + return t, nil +} + +func (s *SQLiteStore) CountUntrainedExperience(ctx context.Context) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM experience_buffer + WHERE used_in_training = 0 + AND (category = 'gold' OR (category = 'needs_improvement' AND corrected_output IS NOT NULL))`, + ).Scan(&count) + if err != nil { + return 0, fmt.Errorf("counting untrained experience: %w", err) + } + return count, nil +} + +func (s *SQLiteStore) MarkExperienceUsedInTraining(ctx context.Context, batchID string, entryIDs []string) error { + if len(entryIDs) == 0 { + return nil + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("beginning transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + for _, id := range entryIDs { + if _, err := 
tx.ExecContext(ctx, + `UPDATE experience_buffer SET used_in_training = 1, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, + id, + ); err != nil { + return fmt.Errorf("marking entry %s as used: %w", id, err) + } + } + return tx.Commit() +} + func (s *SQLiteStore) ListRecentEncodingQuality(ctx context.Context, limit int) ([]store.EncodingQualityEntry, error) { rows, err := s.db.QueryContext(ctx, `SELECT m.id, COALESCE(m.summary, ''), COALESCE(m.source, ''), diff --git a/internal/store/sqlite/schema.go b/internal/store/sqlite/schema.go index c0ce00e3..e44bc8b3 100644 --- a/internal/store/sqlite/schema.go +++ b/internal/store/sqlite/schema.go @@ -10,7 +10,7 @@ import ( // migration is added. It is written to PRAGMA user_version after InitSchema // completes, and read by the pre-migration backup logic to skip backups when // the schema is already current. -const SchemaVersion = 17 +const SchemaVersion = 18 const schema = ` -- Raw observations before encoding @@ -642,6 +642,30 @@ INSERT OR IGNORE INTO forum_categories (id, name, slug, description, icon, color } _, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_curriculum_runs_status ON curriculum_runs(status)`) + // Migration 018: Training runs table (Phase C — automated spoke training) + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS training_runs ( + id TEXT PRIMARY KEY, + batch_id TEXT NOT NULL, + batch_path TEXT NOT NULL, + gold_count INTEGER DEFAULT 0, + corrected_count INTEGER DEFAULT 0, + total_examples INTEGER DEFAULT 0, + status TEXT DEFAULT 'pending', + checkpoint_path TEXT, + model_path TEXT, + eval_epr REAL DEFAULT 0, + eval_fr REAL DEFAULT 0, + eval_sc REAL DEFAULT 0, + quality_passed BOOLEAN DEFAULT FALSE, + error_message TEXT, + started_at DATETIME NOT NULL, + completed_at DATETIME + )`); err != nil { + return fmt.Errorf("failed to create training_runs table: %w", err) + } + _, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_training_runs_status ON training_runs(status)`) + _, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_training_runs_started_at ON training_runs(started_at)`) + // Record the schema version so pre-migration backups can skip when current. if _, err := db.Exec(fmt.Sprintf("PRAGMA user_version = %d", SchemaVersion)); err != nil { return fmt.Errorf("failed to set user_version: %w", err) diff --git a/internal/store/store.go b/internal/store/store.go index 9f810f91..2fbb6f59 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -584,16 +584,16 @@ type AnalyticsStore interface { // ExperienceEntry represents a training candidate in the experience buffer. 
type ExperienceEntry struct { - ID string `json:"id"` - RawID string `json:"raw_id"` - MemoryID string `json:"memory_id"` - EncodingEPR float64 `json:"encoding_epr"` - EncodingFR float64 `json:"encoding_fr"` - EncodingFlags []string `json:"encoding_flags"` - RecallScore float64 `json:"recall_score"` - RecallCount int `json:"recall_count"` - Category string `json:"category"` // gold, needs_improvement, ambiguous - UsedInTraining bool `json:"used_in_training"` + ID string `json:"id"` + RawID string `json:"raw_id"` + MemoryID string `json:"memory_id"` + EncodingEPR float64 `json:"encoding_epr"` + EncodingFR float64 `json:"encoding_fr"` + EncodingFlags []string `json:"encoding_flags"` + RecallScore float64 `json:"recall_score"` + RecallCount int `json:"recall_count"` + Category string `json:"category"` // gold, needs_improvement, ambiguous + UsedInTraining bool `json:"used_in_training"` // Phase B: Curriculum generation — corrected output from teacher model CorrectedOutput string `json:"corrected_output,omitempty"` @@ -608,16 +608,16 @@ type ExperienceEntry struct { // CurriculumRun tracks a single curriculum generation cycle. type CurriculumRun struct { - ID string `json:"id"` - StartedAt time.Time `json:"started_at"` - CompletedAt *time.Time `json:"completed_at,omitempty"` - CorrectionsAttempted int `json:"corrections_attempted"` - CorrectionsPassed int `json:"corrections_passed"` - CorrectionsFailed int `json:"corrections_failed"` - EntriesReclassified int `json:"entries_reclassified"` - TrainingBatchPath string `json:"training_batch_path,omitempty"` - Status string `json:"status"` // pending, completed, failed - CreatedAt time.Time `json:"created_at"` + ID string `json:"id"` + StartedAt time.Time `json:"started_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + CorrectionsAttempted int `json:"corrections_attempted"` + CorrectionsPassed int `json:"corrections_passed"` + CorrectionsFailed int `json:"corrections_failed"` + EntriesReclassified int `json:"entries_reclassified"` + TrainingBatchPath string `json:"training_batch_path,omitempty"` + Status string `json:"status"` // pending, completed, failed + CreatedAt time.Time `json:"created_at"` } // ExperienceStats summarizes the experience buffer contents. @@ -638,6 +638,26 @@ type RecallFeedbackEntry struct { CreatedAt time.Time `json:"created_at"` } +// TrainingRun tracks a single spoke fine-tuning cycle. +type TrainingRun struct { + ID string `json:"id"` + BatchID string `json:"batch_id"` // links to TrainingBatchManifest + BatchPath string `json:"batch_path"` // JSONL file path + GoldCount int `json:"gold_count"` + CorrectedCount int `json:"corrected_count"` + TotalExamples int `json:"total_examples"` + Status string `json:"status"` // pending, training, evaluating, deploying, completed, failed + CheckpointPath string `json:"checkpoint_path,omitempty"` + ModelPath string `json:"model_path,omitempty"` // deployed GGUF path + EvalEPR float64 `json:"eval_epr,omitempty"` + EvalFR float64 `json:"eval_fr,omitempty"` + EvalSC float64 `json:"eval_sc,omitempty"` // schema compliance + QualityPassed bool `json:"quality_passed"` + ErrorMessage string `json:"error_message,omitempty"` + StartedAt time.Time `json:"started_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` +} + // EncodingQualityWindow holds rolling quality metrics for drift detection. 
type EncodingQualityWindow struct { WindowSize int `json:"window_size"` @@ -670,6 +690,13 @@ type ContinuousLearningStore interface { UpdateCurriculumRun(ctx context.Context, run CurriculumRun) error GetLastCurriculumRunTime(ctx context.Context) (time.Time, error) + // Training runs (Phase C) + WriteTrainingRun(ctx context.Context, run TrainingRun) error + UpdateTrainingRun(ctx context.Context, run TrainingRun) error + GetLastTrainingRunTime(ctx context.Context) (time.Time, error) + CountUntrainedExperience(ctx context.Context) (int, error) + MarkExperienceUsedInTraining(ctx context.Context, batchID string, entryIDs []string) error + // Quality drift detection GetEncodingQualityWindow(ctx context.Context, windowSize int) (EncodingQualityWindow, error) diff --git a/internal/store/storetest/mock.go b/internal/store/storetest/mock.go index 6c1ee31e..c96375d1 100644 --- a/internal/store/storetest/mock.go +++ b/internal/store/storetest/mock.go @@ -398,11 +398,20 @@ func (MockStore) UpdateExperienceCorrectedOutput(context.Context, string, string func (MockStore) ListNeedsImprovement(context.Context, int) ([]store.ExperienceEntry, error) { return nil, nil } -func (MockStore) WriteCurriculumRun(context.Context, store.CurriculumRun) error { return nil } -func (MockStore) UpdateCurriculumRun(context.Context, store.CurriculumRun) error { return nil } +func (MockStore) WriteCurriculumRun(context.Context, store.CurriculumRun) error { return nil } +func (MockStore) UpdateCurriculumRun(context.Context, store.CurriculumRun) error { return nil } func (MockStore) GetLastCurriculumRunTime(context.Context) (time.Time, error) { return time.Time{}, nil } +func (MockStore) WriteTrainingRun(context.Context, store.TrainingRun) error { return nil } +func (MockStore) UpdateTrainingRun(context.Context, store.TrainingRun) error { return nil } +func (MockStore) GetLastTrainingRunTime(context.Context) (time.Time, error) { + return time.Time{}, nil +} +func (MockStore) CountUntrainedExperience(context.Context) (int, error) { return 0, nil } +func (MockStore) MarkExperienceUsedInTraining(context.Context, string, []string) error { + return nil +} // --- Lifecycle --- diff --git a/migrations/009_training_runs.sql b/migrations/009_training_runs.sql new file mode 100644 index 00000000..0623e7f0 --- /dev/null +++ b/migrations/009_training_runs.sql @@ -0,0 +1,25 @@ +-- Migration 009: Training runs table (Phase C — automated spoke training) +-- Tracks each spoke fine-tuning cycle for auditing and rollback. +-- Linked to experience_buffer entries via batch_id. 
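+--
+-- A typical consumer query (illustrative; this is what the daemon's
+-- GetLastTrainingRunTime runs when checking for a prior completed run):
+--   SELECT MAX(started_at) FROM training_runs WHERE status = 'completed';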
+ +CREATE TABLE IF NOT EXISTS training_runs ( + id TEXT PRIMARY KEY, + batch_id TEXT NOT NULL, -- links to training batch manifest + batch_path TEXT NOT NULL, -- JSONL file path + gold_count INTEGER DEFAULT 0, + corrected_count INTEGER DEFAULT 0, + total_examples INTEGER DEFAULT 0, + status TEXT DEFAULT 'pending', -- pending, training, evaluating, deploying, completed, failed + checkpoint_path TEXT, -- PyTorch checkpoint after training + model_path TEXT, -- deployed GGUF path + eval_epr REAL DEFAULT 0, -- post-training evaluation: entity preservation rate + eval_fr REAL DEFAULT 0, -- post-training evaluation: fabrication rate + eval_sc REAL DEFAULT 0, -- post-training evaluation: schema compliance + quality_passed BOOLEAN DEFAULT FALSE, + error_message TEXT, + started_at DATETIME NOT NULL, + completed_at DATETIME +); + +CREATE INDEX IF NOT EXISTS idx_training_runs_status ON training_runs(status); +CREATE INDEX IF NOT EXISTS idx_training_runs_started_at ON training_runs(started_at); diff --git a/scripts/backfill_verification.go b/scripts/backfill_verification.go index 7627ee3f..6c4c4a09 100644 --- a/scripts/backfill_verification.go +++ b/scripts/backfill_verification.go @@ -27,13 +27,13 @@ type verificationResult struct { } var ( - numberRE = regexp.MustCompile(`-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+\.\d+[eE][+-]?\d+|-?\d+\.\d+%|-?\d+%|-?\d+\.\d+|\d+/\d+|\d+`) - pathRE = regexp.MustCompile(`[a-zA-Z_~/][\w/~.-]+\.(?:go|py|js|ts|html|css|yaml|yml|json|jsonl|toml|md|sh|sql|gguf|db|txt|log|patch|cuh|cpp|c|h)\b|/(?:home|usr|etc|var|tmp|opt|api|static)[\w./~-]+`) - versionRE = regexp.MustCompile(`v\d+\.\d+(?:\.\d+)?`) - multiWordRE = regexp.MustCompile(`\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b`) - singleProperRE = regexp.MustCompile(`(?:[a-z,;]\s)([A-Z][a-z]{2,})\b`) - mentionRE = regexp.MustCompile(`@(\w+)`) - camelCaseRE = regexp.MustCompile(`\b([A-Z][a-z]+[A-Z]\w+)\b`) + numberRE = regexp.MustCompile(`-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+\.\d+[eE][+-]?\d+|-?\d+\.\d+%|-?\d+%|-?\d+\.\d+|\d+/\d+|\d+`) + pathRE = regexp.MustCompile(`[a-zA-Z_~/][\w/~.-]+\.(?:go|py|js|ts|html|css|yaml|yml|json|jsonl|toml|md|sh|sql|gguf|db|txt|log|patch|cuh|cpp|c|h)\b|/(?:home|usr|etc|var|tmp|opt|api|static)[\w./~-]+`) + versionRE = regexp.MustCompile(`v\d+\.\d+(?:\.\d+)?`) + multiWordRE = regexp.MustCompile(`\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b`) + singleProperRE = regexp.MustCompile(`(?:[a-z,;]\s)([A-Z][a-z]{2,})\b`) + mentionRE = regexp.MustCompile(`@(\w+)`) + camelCaseRE = regexp.MustCompile(`\b([A-Z][a-z]+[A-Z]\w+)\b`) ) var commonWords = map[string]bool{ From 524c17b0b32d8176354bb134c212a19aa5c899ba Mon Sep 17 00:00:00 2001 From: Caleb Gross Date: Mon, 13 Apr 2026 23:53:12 -0400 Subject: [PATCH 5/8] refactor: adapt training data format for Gemma prep pipeline (#391) Changes TrainingExample from prompt/output to raw_input/encoded/task_type format matching prepare_gemma_finetune_data.py input. Adds a tokenization step (prepareTrainingData) before training, and updates runSpokeTraining args to match current train_spokes.py CLI for Gemma. 
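
For reference, one line of the new JSONL format looks roughly like this (values are illustrative, not from a real run):

```json
{"raw_input": "Fixed the token refresh race in session.go", "encoded": {"summary": "...", "content": "...", "concepts": ["auth"], "salience": 0.6}, "task_type": "encoding", "memory_id": "mem-1", "epr": 0.95}
```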
Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/agent/dreaming/training_data.go | 60 ++++++++-------- internal/agent/dreaming/training_data_test.go | 40 ++++++----- internal/agent/dreaming/training_trigger.go | 69 +++++++++++++++++-- .../agent/dreaming/training_trigger_test.go | 10 +-- 4 files changed, 125 insertions(+), 54 deletions(-) diff --git a/internal/agent/dreaming/training_data.go b/internal/agent/dreaming/training_data.go index fade5627..08822a6e 100644 --- a/internal/agent/dreaming/training_data.go +++ b/internal/agent/dreaming/training_data.go @@ -9,17 +9,21 @@ import ( "time" "github.com/appsprout-dev/mnemonic/internal/agent/agentutil" - "github.com/appsprout-dev/mnemonic/internal/agent/encoding" "github.com/appsprout-dev/mnemonic/internal/store" "github.com/google/uuid" ) // TrainingExample is a single training pair written to JSONL. -// The Python training script tokenizes and mixes with replay data. +// Format matches prepare_gemma_finetune_data.py input: +// +// {"raw_input": "...", "encoded": {...}, "task_type": "encoding", "memory_id": "...", "epr": 0.95} +// +// The prep script applies the chat template, tokenizes, and produces input_ids/completion_start +// for the training script. type TrainingExample struct { - Type string `json:"type"` // "gold" or "corrective" - Prompt string `json:"prompt"` // system + user prompt (identical to what the model saw) - Output string `json:"output"` // the target completion (gold encoding or corrected encoding) + RawInput string `json:"raw_input"` // raw memory content (prep script applies chat template) + Encoded any `json:"encoded"` // structured encoding output (JSON object) + TaskType string `json:"task_type"` // "encoding" (for prep script compatibility) MemoryID string `json:"memory_id"` // provenance EPR float64 `json:"epr"` // EPR score of the output } @@ -93,7 +97,7 @@ func (da *DreamingAgent) AssembleTrainingBatch(ctx context.Context, outputDir st // Write gold examples for _, entry := range goldEntries { - example, err := da.buildTrainingExample(ctx, entry, "gold") + example, err := da.buildTrainingExample(ctx, entry) if err != nil { da.log.Debug("skipping gold entry", "entry_id", entry.ID, "error", err) continue @@ -106,20 +110,26 @@ func (da *DreamingAgent) AssembleTrainingBatch(ctx context.Context, outputDir st // Write corrective examples (using the teacher model's output) for _, entry := range corrected { - example := TrainingExample{ - Type: "corrective", - MemoryID: entry.MemoryID, - EPR: entry.CorrectedEPR, - Output: entry.CorrectedOutput, - } - // Build the prompt from raw memory raw, err := da.store.GetRaw(ctx, entry.RawID) if err != nil { da.log.Debug("skipping corrected entry", "entry_id", entry.ID, "error", err) continue } - truncated := agentutil.Truncate(raw.Content, 4000) - example.Prompt = encoding.BuildCompressionPrompt(truncated, raw.Source, raw.Type, "", "", nil) + + // Parse corrected output back to structured JSON + var encoded any + if err := json.Unmarshal([]byte(entry.CorrectedOutput), &encoded); err != nil { + da.log.Debug("skipping corrected entry with invalid JSON", "entry_id", entry.ID, "error", err) + continue + } + + example := TrainingExample{ + RawInput: agentutil.Truncate(raw.Content, 4000), + Encoded: encoded, + TaskType: "encoding", + MemoryID: entry.MemoryID, + EPR: entry.CorrectedEPR, + } if err := enc.Encode(example); err != nil { return nil, fmt.Errorf("writing corrective example: %w", err) @@ -159,8 +169,8 @@ func (da *DreamingAgent) AssembleTrainingBatch(ctx 
context.Context, outputDir st } // buildTrainingExample creates a training example from a gold experience entry. -// Loads the raw memory and the encoded memory to reconstruct the prompt+output pair. -func (da *DreamingAgent) buildTrainingExample(ctx context.Context, entry store.ExperienceEntry, exType string) (*TrainingExample, error) { +// Loads the raw memory and the encoded memory to reconstruct the raw_input+encoded pair. +func (da *DreamingAgent) buildTrainingExample(ctx context.Context, entry store.ExperienceEntry) (*TrainingExample, error) { raw, err := da.store.GetRaw(ctx, entry.RawID) if err != nil { return nil, fmt.Errorf("loading raw memory %s: %w", entry.RawID, err) @@ -172,24 +182,18 @@ func (da *DreamingAgent) buildTrainingExample(ctx context.Context, entry store.E return nil, fmt.Errorf("loading memory %s: %w", entry.MemoryID, err) } - truncated := agentutil.Truncate(raw.Content, 4000) - prompt := encoding.BuildCompressionPrompt(truncated, raw.Source, raw.Type, "", "", nil) - - // Reconstruct the encoding output as JSON - output, err := json.Marshal(map[string]any{ + // Reconstruct the encoding output as a structured object + encoded := map[string]any{ "summary": mem.Summary, "content": mem.Content, "concepts": mem.Concepts, "salience": mem.Salience, - }) - if err != nil { - return nil, fmt.Errorf("marshaling memory output: %w", err) } return &TrainingExample{ - Type: exType, - Prompt: prompt, - Output: string(output), + RawInput: agentutil.Truncate(raw.Content, 4000), + Encoded: encoded, + TaskType: "encoding", MemoryID: entry.MemoryID, EPR: entry.EncodingEPR, }, nil diff --git a/internal/agent/dreaming/training_data_test.go b/internal/agent/dreaming/training_data_test.go index a7cdce23..7af31059 100644 --- a/internal/agent/dreaming/training_data_test.go +++ b/internal/agent/dreaming/training_data_test.go @@ -111,29 +111,29 @@ func TestAssembleTrainingBatch_GoldOnly(t *testing.T) { if err := json.Unmarshal(lines[0], &ex); err != nil { t.Fatalf("parsing first JSONL line: %v", err) } - if ex.Type != "gold" { - t.Errorf("expected type 'gold', got %q", ex.Type) + if ex.TaskType != "encoding" { + t.Errorf("expected task_type 'encoding', got %q", ex.TaskType) } if ex.MemoryID != "mem-1" { t.Errorf("expected memory_id 'mem-1', got %q", ex.MemoryID) } - if ex.Prompt == "" { - t.Error("expected non-empty prompt") + if ex.RawInput == "" { + t.Error("expected non-empty raw_input") } - if ex.Output == "" { - t.Error("expected non-empty output") + if ex.Encoded == nil { + t.Error("expected non-nil encoded") } - // Output should be valid JSON with expected fields - var outputFields map[string]any - if err := json.Unmarshal([]byte(ex.Output), &outputFields); err != nil { - t.Fatalf("gold output is not valid JSON: %v", err) + // Encoded should be a map with expected fields + encodedMap, ok := ex.Encoded.(map[string]any) + if !ok { + t.Fatalf("encoded is not a map, got %T", ex.Encoded) } - if _, ok := outputFields["summary"]; !ok { - t.Error("gold output missing 'summary' field") + if _, ok := encodedMap["summary"]; !ok { + t.Error("encoded missing 'summary' field") } - if _, ok := outputFields["content"]; !ok { - t.Error("gold output missing 'content' field") + if _, ok := encodedMap["content"]; !ok { + t.Error("encoded missing 'content' field") } } @@ -178,14 +178,18 @@ func TestAssembleTrainingBatch_CorrectedOnly(t *testing.T) { if err := json.Unmarshal(lines[0], &ex); err != nil { t.Fatalf("parsing JSONL line: %v", err) } - if ex.Type != "corrective" { - t.Errorf("expected type 
'corrective', got %q", ex.Type) + if ex.TaskType != "encoding" { + t.Errorf("expected task_type 'encoding', got %q", ex.TaskType) } if ex.EPR != 0.92 { t.Errorf("expected EPR 0.92, got %.2f", ex.EPR) } - if ex.Output != `{"summary":"corrected summary","content":"corrected content","concepts":["auth"]}` { - t.Errorf("corrective output mismatch: %s", ex.Output) + encodedMap, ok := ex.Encoded.(map[string]any) + if !ok { + t.Fatalf("encoded is not a map, got %T", ex.Encoded) + } + if encodedMap["summary"] != "corrected summary" { + t.Errorf("expected summary 'corrected summary', got %v", encodedMap["summary"]) } } diff --git a/internal/agent/dreaming/training_trigger.go b/internal/agent/dreaming/training_trigger.go index 6bd521b5..fb3bcc15 100644 --- a/internal/agent/dreaming/training_trigger.go +++ b/internal/agent/dreaming/training_trigger.go @@ -111,8 +111,17 @@ func (da *DreamingAgent) RunTrainingCycle(ctx context.Context, clCfg config.Cont TotalExamples: manifest.TotalExamples, } - // Step 3: Run spoke training - checkpointPath, err := da.runSpokeTraining(ctx, manifest.DataPath, tCfg) + // Step 3: Tokenize the batch (raw_input+encoded → input_ids+completion_start) + tokenizedPath, err := da.prepareTrainingData(ctx, manifest.DataPath, outputDir) + if err != nil { + result.Status = "failed" + result.ErrorMessage = fmt.Sprintf("data preparation failed: %v", err) + da.failTrainingRun(ctx, &run, result.ErrorMessage) + return result, nil + } + + // Step 4: Run spoke training + checkpointPath, err := da.runSpokeTraining(ctx, tokenizedPath, tCfg) if err != nil { result.Status = "failed" result.ErrorMessage = fmt.Sprintf("training failed: %v", err) @@ -197,6 +206,55 @@ type qualityGateResult struct { Passed bool } +// prepareTrainingData runs the Gemma data prep script to tokenize raw_input+encoded +// pairs into input_ids+completion_start JSONL that the training script expects. +func (da *DreamingAgent) prepareTrainingData(ctx context.Context, batchPath string, outputDir string) (string, error) { + projectDir := os.Getenv("MNEMONIC_PROJECT_DIR") + if projectDir == "" { + homeDir, _ := os.UserHomeDir() + projectDir = filepath.Join(homeDir, "Projects", "mem") + } + + prepScript := filepath.Join(projectDir, "training", "scripts", "prepare_gemma_finetune_data.py") + if _, err := os.Stat(prepScript); err != nil { + return "", fmt.Errorf("prep script not found at %s: %w", prepScript, err) + } + + venvPython := filepath.Join(os.Getenv("HOME"), "Projects", "felixlm", ".venv", "bin", "python") + if _, err := os.Stat(venvPython); err != nil { + venvPython = "python3" + } + + tokenizedDir := filepath.Join(outputDir, "tokenized") + + args := []string{ + prepScript, + "--input", batchPath, + "--output-dir", tokenizedDir, + "--max-seq-len", "2048", + "--eval-ratio", "0", + } + + da.log.Info("preparing training data", "script", prepScript, "input", batchPath, "output_dir", tokenizedDir) + + cmd := exec.CommandContext(ctx, venvPython, args...) 
+ cmd.Dir = projectDir + + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("prep script failed: %w\nOutput: %s", err, string(output)) + } + + // The prep script writes train.jsonl in the output dir + tokenizedPath := filepath.Join(tokenizedDir, "train.jsonl") + if _, err := os.Stat(tokenizedPath); err != nil { + return "", fmt.Errorf("tokenized data not found at %s after prep", tokenizedPath) + } + + da.log.Info("training data prepared", "path", tokenizedPath) + return tokenizedPath, nil +} + // runSpokeTraining executes the Python training script as a subprocess. // Returns the path to the output checkpoint. func (da *DreamingAgent) runSpokeTraining(ctx context.Context, batchPath string, tCfg config.CLTrainingConfig) (string, error) { @@ -228,12 +286,15 @@ func (da *DreamingAgent) runSpokeTraining(ctx context.Context, batchPath string, args := []string{ scriptPath, "--model-type", "gemma", - "--data", batchPath, - "--output-dir", checkpointDir, + "--base-model", "google/gemma-4-E2B-it", + "--train-data", batchPath, + "--checkpoint-dir", checkpointDir, + "--seq-len", "2048", "--steps", "500", "--batch-size", "1", "--grad-accum", "8", "--lr", "1e-4", + "--no-wandb", } da.log.Info("running spoke training", diff --git a/internal/agent/dreaming/training_trigger_test.go b/internal/agent/dreaming/training_trigger_test.go index 34901426..c1cd88a5 100644 --- a/internal/agent/dreaming/training_trigger_test.go +++ b/internal/agent/dreaming/training_trigger_test.go @@ -155,10 +155,12 @@ func TestRunTrainingCycle_AssemblesAndRecords(t *testing.T) { clCfg := baseCLConfig() - // RunTrainingCycle will assemble data, write a training run, then fail on - // the subprocess call (no Python env in tests). That's expected — we're testing - // the trigger logic and record-keeping, not the actual training. - result, err := agent.RunTrainingCycle(context.Background(), clCfg) + // Use a short timeout — we only test trigger logic and record-keeping. + // The subprocess will be killed quickly rather than loading a full model. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + result, err := agent.RunTrainingCycle(ctx, clCfg) if err != nil { t.Fatalf("unexpected error: %v", err) } From ebaaebeda20c5cc5c9ab53d3b377f44ed8eeb2cf Mon Sep 17 00:00:00 2001 From: Caleb Gross Date: Mon, 13 Apr 2026 23:54:55 -0400 Subject: [PATCH 6/8] docs: systemd training orchestration design (#391) Specifies the refactor from inline subprocess training to systemd- orchestrated training. Daemon writes a request file; systemd path unit triggers a separate service that stops the daemon, trains, and restarts. Eliminates VRAM contention that was crashing the system. Co-Authored-By: Claude Opus 4.6 (1M context) --- ...3-systemd-training-orchestration-design.md | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 docs/superpowers/specs/2026-04-13-systemd-training-orchestration-design.md diff --git a/docs/superpowers/specs/2026-04-13-systemd-training-orchestration-design.md b/docs/superpowers/specs/2026-04-13-systemd-training-orchestration-design.md new file mode 100644 index 00000000..b03a467d --- /dev/null +++ b/docs/superpowers/specs/2026-04-13-systemd-training-orchestration-design.md @@ -0,0 +1,107 @@ +# Systemd Training Orchestration + +## 2026-04-13 | Issue #391 Phase C refinement + +### Problem + +The daemon runs an embedded llama.cpp model on the GPU (~3GB VRAM). 
When `RunTrainingCycle` spawns `train_spokes.py` as a subprocess, both compete for VRAM. On the RX 7800 XT with ~12GB usable, loading Gemma 4 E2B (~10GB for training) alongside the running model causes an OOM and a system crash. This has happened multiple times.
+
+The spec (SPEC_continuous_learning.md, section 4.1-4.2) already designed the correct solution: hybrid orchestration via systemd. The daemon writes a request file; systemd handles the rest, including stopping and restarting the daemon. This doc specifies the implementation.
+
+### Design
+
+**Daemon side (Go):**
+
+`RunTrainingCycle` splits into two responsibilities:
+
+1. **Data assembly** (stays in daemon): count untrained experience, assemble JSONL batch, write manifest. This is pure Go, no GPU, fast.
+
+2. **Training request** (new): write a `pending.json` request file to `~/.local/share/mnemonic/training_requests/`. The daemon does NOT run any Python subprocesses. `RunTrainingCycle` returns `{status: "training_requested", request_id, batch_id}`.
+
+The request file contains everything the training pipeline needs:
+```json
+{
+  "request_id": "tr-20260413-abc123",
+  "timestamp": "2026-04-13T03:00:00Z",
+  "trigger": "manual|auto",
+  "batch_path": "/tmp/mnemonic-training/batch_abc123.jsonl",
+  "total_examples": 87,
+  "gold_count": 52,
+  "corrected_count": 35,
+  "run_id": "abc12345"
+}
+```
+
+After writing the request, the daemon does NOT shut itself down. The systemd path unit triggers `mnemonic-train.service`, which stops the daemon before training. This keeps shutdown authority with systemd.
+
+**Systemd side (new units):**
+
+`mnemonic-train.path` watches for `pending.json`. When it appears, it triggers `mnemonic-train.service`.
+
+`mnemonic-train.service` runs `scripts/continuous_train.sh`, which:
+1. Stops the daemon (`systemctl --user stop mnemonic`) to free VRAM
+2. Reads the request file
+3. Runs `prepare_gemma_finetune_data.py` (tokenize)
+4. Runs `train_spokes.py` (train spokes)
+5. Runs `eval_encoding.py` (quality gate)
+6. If quality passes: runs `deploy_model.sh`
+7. Writes `result.json` with outcome (example below)
+8. Archives `pending.json`
+9. Restarts the daemon (`systemctl --user start mnemonic`); this runs in an EXIT trap, so it always executes
+
+**Daemon startup (result pickup):**
+
+On startup, the daemon checks for `result.json` in the training requests directory. If found:
+- Reads the result
+- Updates the corresponding `training_runs` record in the DB
+- Logs the outcome
+- Deletes `result.json`
+
+This closes the feedback loop — the training run record goes from `status: "training_requested"` to `completed` or `failed`.
+
+**MCP tool change:**
+
+`train_model` returns immediately with `{status: "training_requested", request_id}`. The caller (Claude Code agent) can check status later via the existing `status` tool or the training_runs table.
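+
+For concreteness, the `result.json` written in step 7 might look like this for a passing run (illustrative values; the field set mirrors the training_runs columns):
+
+```json
+{
+  "request_id": "tr-20260413-abc123",
+  "run_id": "abc12345",
+  "status": "completed",
+  "checkpoint_path": "/tmp/mnemonic-training/checkpoints/last.pt",
+  "model_path": "~/Projects/mem/models/gemma-spokes-cl-20260413-030500.gguf",
+  "eval_epr": 0.93,
+  "eval_fr": 0.02,
+  "eval_sc": 0.97,
+  "quality_passed": true,
+  "completed_at": "2026-04-13T03:25:00Z"
+}
+```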
+ +### What moves out of Go + +These functions in `training_trigger.go` become dead code and are removed: +- `prepareTrainingData` (moves to shell script) +- `runSpokeTraining` (moves to shell script) +- `runQualityGate` (moves to shell script) +- `deploySpokeModel` (moves to shell script) +- `parseEvalOutput` (moves to shell script / not needed) + +### What stays in Go + +- `trainingCheck` (auto-trigger gating: enabled, window check) +- `RunTrainingCycle` (refactored: assemble data + write request) +- `failTrainingRun` (for assembly failures) +- `inTrainingWindow`, `splitLines` (utilities) +- `AssembleTrainingBatch` (data assembly, no GPU) + +### Files changed + +| File | Change | +|------|--------| +| `internal/agent/dreaming/training_trigger.go` | Refactor: remove subprocess funcs, add request file writing | +| `internal/agent/dreaming/training_trigger_test.go` | Update tests for async flow | +| `cmd/mnemonic/serve.go` | Add training result pickup on startup | +| `internal/mcp/server.go` | Update handleTrainModel response format | +| `scripts/continuous_train.sh` | New: training orchestrator | +| `scripts/systemd/mnemonic-train.path` | New: path watcher | +| `scripts/systemd/mnemonic-train.service` | New: training service | + +### Quality gate thresholds (from spec) + +Run by the shell script, not Go. Pass criteria: +- EPR >= 0.90 +- FR <= 0.05 (monitoring, lenient) +- SC >= 0.95 + +### Testing + +- Unit tests: request file writing, result pickup, manifest validation +- Integration: mock the file write, verify training_runs record lifecycle +- Manual: install systemd units, trigger training, verify full flow +- The shell script itself is tested manually (requires GPU) From 2a87d4aabfaf9edfb77b45dbda31b7cdfb646e4c Mon Sep 17 00:00:00 2001 From: Caleb Gross Date: Tue, 14 Apr 2026 00:06:48 -0400 Subject: [PATCH 7/8] refactor: systemd-orchestrated training to prevent GPU crashes (#391) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The daemon was running Python training subprocesses while holding VRAM for the embedded llama.cpp model, causing OOM crashes. Refactored to match the spec's hybrid orchestration design: - RunTrainingCycle now assembles data and writes pending.json, no subprocess calls. Returns "training_requested" status. - New continuous_train.sh stops the daemon (freeing VRAM), runs tokenization/training/eval/deploy, always restarts the daemon. - New systemd units: mnemonic-train.path watches for pending.json, mnemonic-train.service runs the training script with 30min timeout. - Daemon picks up result.json on startup to close the feedback loop. - MCP train_model tool now returns async (request_id, status). - Duplicate request prevention (skips if pending.json already exists). Tested: unit tests (9/9), full suite (0 failures), manual systemd path trigger with fake request — daemon stopped, script ran, failed correctly on missing batch, daemon restarted, result picked up. 
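
For orientation, the path unit is essentially this shape (a sketch; the units under scripts/systemd/ are authoritative):

```ini
[Unit]
Description=Trigger spoke training when a request file appears

[Path]
PathExists=%h/.mnemonic/training_requests/pending.json
Unit=mnemonic-train.service

[Install]
WantedBy=default.target
```

Enable it as a user unit with `systemctl --user enable --now mnemonic-train.path`.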
Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/mnemonic/serve.go | 16 +- internal/agent/dreaming/agent.go | 6 +- internal/agent/dreaming/training_trigger.go | 485 ++++++------------ .../agent/dreaming/training_trigger_test.go | 244 ++++++--- internal/mcp/session.go | 80 +-- internal/mcp/tools.go | 2 +- scripts/continuous_train.sh | 235 +++++++++ scripts/systemd/mnemonic-train.path | 9 + scripts/systemd/mnemonic-train.service | 14 + 9 files changed, 632 insertions(+), 459 deletions(-) create mode 100755 scripts/continuous_train.sh create mode 100644 scripts/systemd/mnemonic-train.path create mode 100644 scripts/systemd/mnemonic-train.service diff --git a/cmd/mnemonic/serve.go b/cmd/mnemonic/serve.go index 0a37e5d9..eeb71012 100644 --- a/cmd/mnemonic/serve.go +++ b/cmd/mnemonic/serve.go @@ -134,6 +134,13 @@ func serveCommand(configPath string) { } intCancel() + // Pick up training results from a previous systemd training run + pickupCtx, pickupCancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := dreaming.PickUpTrainingResult(pickupCtx, memStore, log); err != nil { + log.Warn("failed to pick up training result", "error", err) + } + pickupCancel() + // Check available disk space dbDir := filepath.Dir(cfg.Store.DBPath) if availBytes, diskErr := diskAvailable(dbDir); diskErr == nil { @@ -716,7 +723,7 @@ func serveCommand(configPath string) { if dreamer != nil && cfg.ContinuousLearning.Trigger.Manual { clCfg := cfg.ContinuousLearning smCfg.TrainingTriggerFn = func(ctx context.Context) (map[string]any, error) { - result, err := dreamer.RunTrainingCycle(ctx, clCfg) + result, err := dreamer.RunTrainingCycle(ctx, clCfg, "manual") if err != nil { return nil, err } @@ -725,13 +732,10 @@ func serveCommand(configPath string) { } return map[string]any{ "status": result.Status, + "request_id": result.RequestID, "batch_id": result.BatchID, "total_examples": result.TotalExamples, - "quality_passed": result.QualityPassed, - "checkpoint": result.CheckpointPath, - "model": result.ModelPath, - "eval_epr": result.EvalEPR, - "eval_sc": result.EvalSC, + "request_path": result.RequestPath, "error": result.ErrorMessage, }, nil } diff --git a/internal/agent/dreaming/agent.go b/internal/agent/dreaming/agent.go index d16f35ff..9153ee7a 100644 --- a/internal/agent/dreaming/agent.go +++ b/internal/agent/dreaming/agent.go @@ -200,10 +200,10 @@ func (da *DreamingAgent) runCycle(ctx context.Context) (*DreamReport, error) { if trainResult, err := da.trainingCheck(ctx, da.config.ContinuousLearning); err != nil && ctx.Err() == nil { da.log.Error("training trigger phase failed", "error", err) } else if trainResult != nil { - da.log.Info("training cycle result", + da.log.Info("training request result", "status", trainResult.Status, - "examples", trainResult.TotalExamples, - "quality_passed", trainResult.QualityPassed) + "request_id", trainResult.RequestID, + "examples", trainResult.TotalExamples) } // Phase 5: Link replayed memories to matching patterns diff --git a/internal/agent/dreaming/training_trigger.go b/internal/agent/dreaming/training_trigger.go index fb3bcc15..aedab5af 100644 --- a/internal/agent/dreaming/training_trigger.go +++ b/internal/agent/dreaming/training_trigger.go @@ -2,10 +2,9 @@ package dreaming import ( "context" - "fmt" "encoding/json" + "fmt" "os" - "os/exec" "path/filepath" "time" @@ -14,17 +13,54 @@ import ( "github.com/google/uuid" ) -// TrainingResult reports the outcome of a training cycle. +// TrainingResult reports the outcome of a training request. 
+// With systemd orchestration, the daemon only assembles data and writes a request file. +// Actual training happens in a separate systemd service after the daemon stops. type TrainingResult struct { - BatchID string - TotalExamples int - Status string // completed, failed - CheckpointPath string - ModelPath string - EvalEPR float64 - EvalSC float64 - QualityPassed bool - ErrorMessage string + RequestID string `json:"request_id"` + BatchID string `json:"batch_id"` + TotalExamples int `json:"total_examples"` + Status string `json:"status"` // "training_requested" or "failed" + RequestPath string `json:"request_path"` // path to pending.json + ErrorMessage string `json:"error_message,omitempty"` +} + +// TrainingRequest is the JSON written to pending.json for the systemd training service. +type TrainingRequest struct { + RequestID string `json:"request_id"` + RunID string `json:"run_id"` + Timestamp string `json:"timestamp"` + Trigger string `json:"trigger"` // "manual" or "auto" + BatchPath string `json:"batch_path"` + TotalExamples int `json:"total_examples"` + GoldCount int `json:"gold_count"` + CorrectedCount int `json:"corrected_count"` +} + +// TrainingResultFile is the JSON written by continuous_train.sh after training completes. +// The daemon reads this on startup to update the training_runs record. +type TrainingResultFile struct { + RequestID string `json:"request_id"` + RunID string `json:"run_id"` + Status string `json:"status"` // "completed" or "failed" + CheckpointPath string `json:"checkpoint_path,omitempty"` + ModelPath string `json:"model_path,omitempty"` + EvalEPR float64 `json:"eval_epr,omitempty"` + EvalFR float64 `json:"eval_fr,omitempty"` + EvalSC float64 `json:"eval_sc,omitempty"` + QualityPassed bool `json:"quality_passed"` + ErrorMessage string `json:"error_message,omitempty"` + CompletedAt string `json:"completed_at"` +} + +// TrainingRequestsDir returns the path to the training requests directory. +// Uses ~/.mnemonic/training_requests/ by default. +func TrainingRequestsDir() string { + if dir := os.Getenv("MNEMONIC_TRAINING_REQUESTS_DIR"); dir != "" { + return dir + } + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".mnemonic", "training_requests") } // trainingCheck runs Phase 4.85: check if we should trigger spoke training. @@ -44,18 +80,16 @@ func (da *DreamingAgent) trainingCheck(ctx context.Context, clCfg config.Continu return nil, nil } - return da.RunTrainingCycle(ctx, clCfg) + return da.RunTrainingCycle(ctx, clCfg, "auto") } -// RunTrainingCycle executes the full training pipeline: -// 1. Check if enough untrained experience exists -// 2. Assemble training batch (JSONL) -// 3. Run spoke training (Python subprocess) -// 4. Run quality gate evaluation -// 5. Deploy new spokes if quality passes +// RunTrainingCycle assembles training data and writes a request file for the +// systemd training service. The daemon does NOT run training subprocesses. // -// This is the manual entry point called by MCP tools or dreaming auto-trigger. -func (da *DreamingAgent) RunTrainingCycle(ctx context.Context, clCfg config.ContinuousLearningConfig) (*TrainingResult, error) { +// Flow: check untrained count -> assemble JSONL batch -> write pending.json +// The systemd path unit detects pending.json, stops the daemon, runs training, +// and restarts the daemon. Results are picked up on next startup. 
+func (da *DreamingAgent) RunTrainingCycle(ctx context.Context, clCfg config.ContinuousLearningConfig, trigger string) (*TrainingResult, error) { tCfg := clCfg.Training // Step 1: Check if enough untrained data exists @@ -74,6 +108,14 @@ func (da *DreamingAgent) RunTrainingCycle(ctx context.Context, clCfg config.Cont return nil, nil } + // Check for an existing pending request — don't stack requests + requestDir := TrainingRequestsDir() + pendingPath := filepath.Join(requestDir, "pending.json") + if _, err := os.Stat(pendingPath); err == nil { + da.log.Info("training skipped: pending request already exists", "path", pendingPath) + return nil, nil + } + // Step 2: Assemble training batch outputDir := filepath.Join(os.TempDir(), "mnemonic-training") maxExamples := tCfg.MaxExamplesPerRun @@ -95,98 +137,68 @@ func (da *DreamingAgent) RunTrainingCycle(ctx context.Context, clCfg config.Cont GoldCount: manifest.GoldCount, CorrectedCount: manifest.CorrectedCount, TotalExamples: manifest.TotalExamples, - Status: "training", + Status: "requested", StartedAt: time.Now(), } if err := da.store.WriteTrainingRun(ctx, run); err != nil { return nil, fmt.Errorf("writing training run: %w", err) } - da.log.Info("training cycle started", - "run_id", runID, "batch_id", manifest.ID, - "examples", manifest.TotalExamples) - - result := &TrainingResult{ - BatchID: manifest.ID, - TotalExamples: manifest.TotalExamples, + // Step 3: Write the training request file + request := TrainingRequest{ + RequestID: fmt.Sprintf("tr-%s-%s", time.Now().Format("20060102"), runID), + RunID: runID, + Timestamp: time.Now().UTC().Format(time.RFC3339), + Trigger: trigger, + BatchPath: manifest.DataPath, + TotalExamples: manifest.TotalExamples, + GoldCount: manifest.GoldCount, + CorrectedCount: manifest.CorrectedCount, } - // Step 3: Tokenize the batch (raw_input+encoded → input_ids+completion_start) - tokenizedPath, err := da.prepareTrainingData(ctx, manifest.DataPath, outputDir) + requestPath, err := da.writeTrainingRequest(request) if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("data preparation failed: %v", err) - da.failTrainingRun(ctx, &run, result.ErrorMessage) - return result, nil + da.failTrainingRun(ctx, &run, fmt.Sprintf("writing request file: %v", err)) + return &TrainingResult{ + RequestID: request.RequestID, + BatchID: manifest.ID, + Status: "failed", + ErrorMessage: fmt.Sprintf("writing request file: %v", err), + }, nil } - // Step 4: Run spoke training - checkpointPath, err := da.runSpokeTraining(ctx, tokenizedPath, tCfg) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("training failed: %v", err) - da.failTrainingRun(ctx, &run, result.ErrorMessage) - return result, nil - } - run.CheckpointPath = checkpointPath - run.Status = "evaluating" - _ = da.store.UpdateTrainingRun(ctx, run) + da.log.Info("training request written — systemd will handle training", + "run_id", runID, "request_id", request.RequestID, + "examples", manifest.TotalExamples, "path", requestPath) - // Step 4: Run quality gate - evalResult, err := da.runQualityGate(ctx, checkpointPath) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("evaluation failed: %v", err) - da.failTrainingRun(ctx, &run, result.ErrorMessage) - return result, nil - } + return &TrainingResult{ + RequestID: request.RequestID, + BatchID: manifest.ID, + TotalExamples: manifest.TotalExamples, + Status: "training_requested", + RequestPath: requestPath, + }, nil +} - run.EvalEPR = evalResult.EPR - 
run.EvalFR = evalResult.FR - run.EvalSC = evalResult.SC - run.QualityPassed = evalResult.Passed - result.EvalEPR = evalResult.EPR - result.EvalSC = evalResult.SC - - if !evalResult.Passed { - result.Status = "failed" - result.QualityPassed = false - result.ErrorMessage = fmt.Sprintf("quality gate failed: EPR=%.2f FR=%.2f SC=%.2f", evalResult.EPR, evalResult.FR, evalResult.SC) - da.failTrainingRun(ctx, &run, result.ErrorMessage) - da.log.Warn("training quality gate failed — discarding checkpoint", - "run_id", runID, "epr", evalResult.EPR, "fr", evalResult.FR, "sc", evalResult.SC) - return result, nil +// writeTrainingRequest writes the pending.json file that triggers the systemd training service. +func (da *DreamingAgent) writeTrainingRequest(request TrainingRequest) (string, error) { + requestDir := TrainingRequestsDir() + if err := os.MkdirAll(requestDir, 0o755); err != nil { + return "", fmt.Errorf("creating request dir %s: %w", requestDir, err) } - // Step 5: Deploy new spokes - run.Status = "deploying" - _ = da.store.UpdateTrainingRun(ctx, run) + pendingPath := filepath.Join(requestDir, "pending.json") - modelPath, err := da.deploySpokeModel(ctx, checkpointPath) + data, err := json.MarshalIndent(request, "", " ") if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("deployment failed: %v", err) - da.failTrainingRun(ctx, &run, result.ErrorMessage) - return result, nil + return "", fmt.Errorf("marshaling request: %w", err) } - // Success - now := time.Now() - run.ModelPath = modelPath - run.Status = "completed" - run.CompletedAt = &now - _ = da.store.UpdateTrainingRun(ctx, run) - - result.Status = "completed" - result.QualityPassed = true - result.CheckpointPath = checkpointPath - result.ModelPath = modelPath - - da.log.Info("training cycle completed", - "run_id", runID, "epr", evalResult.EPR, "sc", evalResult.SC, - "model", modelPath) + if err := os.WriteFile(pendingPath, data, 0o644); err != nil { + return "", fmt.Errorf("writing %s: %w", pendingPath, err) + } - return result, nil + return pendingPath, nil } // failTrainingRun records a failed training run in the store. @@ -198,234 +210,61 @@ func (da *DreamingAgent) failTrainingRun(ctx context.Context, run *store.Trainin _ = da.store.UpdateTrainingRun(ctx, *run) } -// qualityGateResult holds the evaluation metrics from the quality gate. -type qualityGateResult struct { - EPR float64 - FR float64 - SC float64 - Passed bool -} - -// prepareTrainingData runs the Gemma data prep script to tokenize raw_input+encoded -// pairs into input_ids+completion_start JSONL that the training script expects. 
-func (da *DreamingAgent) prepareTrainingData(ctx context.Context, batchPath string, outputDir string) (string, error) { - projectDir := os.Getenv("MNEMONIC_PROJECT_DIR") - if projectDir == "" { - homeDir, _ := os.UserHomeDir() - projectDir = filepath.Join(homeDir, "Projects", "mem") - } - - prepScript := filepath.Join(projectDir, "training", "scripts", "prepare_gemma_finetune_data.py") - if _, err := os.Stat(prepScript); err != nil { - return "", fmt.Errorf("prep script not found at %s: %w", prepScript, err) - } - - venvPython := filepath.Join(os.Getenv("HOME"), "Projects", "felixlm", ".venv", "bin", "python") - if _, err := os.Stat(venvPython); err != nil { - venvPython = "python3" - } - - tokenizedDir := filepath.Join(outputDir, "tokenized") - - args := []string{ - prepScript, - "--input", batchPath, - "--output-dir", tokenizedDir, - "--max-seq-len", "2048", - "--eval-ratio", "0", - } - - da.log.Info("preparing training data", "script", prepScript, "input", batchPath, "output_dir", tokenizedDir) - - cmd := exec.CommandContext(ctx, venvPython, args...) - cmd.Dir = projectDir - - output, err := cmd.CombinedOutput() - if err != nil { - return "", fmt.Errorf("prep script failed: %w\nOutput: %s", err, string(output)) - } - - // The prep script writes train.jsonl in the output dir - tokenizedPath := filepath.Join(tokenizedDir, "train.jsonl") - if _, err := os.Stat(tokenizedPath); err != nil { - return "", fmt.Errorf("tokenized data not found at %s after prep", tokenizedPath) - } - - da.log.Info("training data prepared", "path", tokenizedPath) - return tokenizedPath, nil -} - -// runSpokeTraining executes the Python training script as a subprocess. -// Returns the path to the output checkpoint. -func (da *DreamingAgent) runSpokeTraining(ctx context.Context, batchPath string, tCfg config.CLTrainingConfig) (string, error) { - // The training script lives relative to the daemon binary's project root. - // Use the MNEMONIC_PROJECT_DIR env var or default to /home//Projects/mem. - projectDir := os.Getenv("MNEMONIC_PROJECT_DIR") - if projectDir == "" { - homeDir, _ := os.UserHomeDir() - projectDir = filepath.Join(homeDir, "Projects", "mem") - } - - scriptPath := filepath.Join(projectDir, "training", "scripts", "train_spokes.py") - if _, err := os.Stat(scriptPath); err != nil { - return "", fmt.Errorf("training script not found at %s: %w", scriptPath, err) - } - - checkpointDir := filepath.Join(projectDir, "checkpoints", "continuous_learning") - if err := os.MkdirAll(checkpointDir, 0o755); err != nil { - return "", fmt.Errorf("creating checkpoint dir: %w", err) - } - - // Construct training command. The venv must be activated by the caller - // or the script must be runnable with the system Python. - venvPython := filepath.Join(os.Getenv("HOME"), "Projects", "felixlm", ".venv", "bin", "python") - if _, err := os.Stat(venvPython); err != nil { - venvPython = "python3" // fallback - } +// PickUpTrainingResult checks for a result.json from a previous training run +// and updates the corresponding training_runs record. Called on daemon startup. 
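The companion result file, mirroring the TrainingResultFile fields, looks like this for a passing run (the metric values match the test fixtures further down; the paths and timestamp are illustrative):

{
  "request_id": "tr-20260413-abc",
  "run_id": "abc12345",
  "status": "completed",
  "checkpoint_path": "/tmp/checkpoint",
  "model_path": "/tmp/model.gguf",
  "eval_epr": 0.95,
  "eval_fr": 0.02,
  "eval_sc": 0.98,
  "quality_passed": true,
  "completed_at": "2026-04-14T03:41:00Z"
}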
+func PickUpTrainingResult(ctx context.Context, s store.Store, log interface{ Info(string, ...any) }) error { + requestDir := TrainingRequestsDir() + resultPath := filepath.Join(requestDir, "result.json") - args := []string{ - scriptPath, - "--model-type", "gemma", - "--base-model", "google/gemma-4-E2B-it", - "--train-data", batchPath, - "--checkpoint-dir", checkpointDir, - "--seq-len", "2048", - "--steps", "500", - "--batch-size", "1", - "--grad-accum", "8", - "--lr", "1e-4", - "--no-wandb", - } - - da.log.Info("running spoke training", - "script", scriptPath, "data", batchPath, - "output_dir", checkpointDir) - - cmd := exec.CommandContext(ctx, venvPython, args...) - cmd.Dir = projectDir - cmd.Env = append(os.Environ(), "PYTHONUNBUFFERED=1") - - output, err := cmd.CombinedOutput() + data, err := os.ReadFile(resultPath) if err != nil { - return "", fmt.Errorf("training script failed: %w\nOutput: %s", err, string(output)) - } - - // Find the checkpoint — the script writes to output_dir/last.pt - checkpointPath := filepath.Join(checkpointDir, "last.pt") - if _, err := os.Stat(checkpointPath); err != nil { - return "", fmt.Errorf("checkpoint not found after training at %s", checkpointPath) - } - - da.log.Info("spoke training completed", "checkpoint", checkpointPath) - return checkpointPath, nil -} - -// runQualityGate evaluates the trained checkpoint against probe inputs. -// Returns metrics and whether the model passes the quality threshold. -func (da *DreamingAgent) runQualityGate(ctx context.Context, checkpointPath string) (*qualityGateResult, error) { - projectDir := os.Getenv("MNEMONIC_PROJECT_DIR") - if projectDir == "" { - homeDir, _ := os.UserHomeDir() - projectDir = filepath.Join(homeDir, "Projects", "mem") - } - - evalScript := filepath.Join(projectDir, "training", "scripts", "eval_encoding.py") - if _, err := os.Stat(evalScript); err != nil { - return nil, fmt.Errorf("eval script not found at %s: %w", evalScript, err) - } - - venvPython := filepath.Join(os.Getenv("HOME"), "Projects", "felixlm", ".venv", "bin", "python") - if _, err := os.Stat(venvPython); err != nil { - venvPython = "python3" - } - - args := []string{ - evalScript, - "--checkpoint", checkpointPath, - "--mode", "generate", - "--json-output", - } - - cmd := exec.CommandContext(ctx, venvPython, args...) - cmd.Dir = projectDir - - output, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("eval script failed: %w\nOutput: %s", err, string(output)) + if os.IsNotExist(err) { + return nil // no result to pick up + } + return fmt.Errorf("reading result file: %w", err) } - // Parse the JSON output from the eval script. - // The script outputs a JSON line with EPR, FR, SC metrics. - result, err := parseEvalOutput(string(output)) - if err != nil { - return nil, fmt.Errorf("parsing eval output: %w", err) + var result TrainingResultFile + if err := json.Unmarshal(data, &result); err != nil { + return fmt.Errorf("parsing result file: %w", err) } - // Apply quality thresholds from the design doc: - // EPR >= 0.90, FR <= 0.05, SC >= 0.95 - result.Passed = result.EPR >= 0.90 && result.FR <= 0.05 && result.SC >= 0.95 - - da.log.Info("quality gate evaluation", - "epr", result.EPR, "fr", result.FR, "sc", result.SC, - "passed", result.Passed) - - return result, nil -} - -// deploySpokeModel exports the checkpoint to GGUF and deploys it. 
-func (da *DreamingAgent) deploySpokeModel(ctx context.Context, checkpointPath string) (string, error) { - projectDir := os.Getenv("MNEMONIC_PROJECT_DIR") - if projectDir == "" { - homeDir, _ := os.UserHomeDir() - projectDir = filepath.Join(homeDir, "Projects", "mem") + // Update the training run record + completedAt, _ := time.Parse(time.RFC3339, result.CompletedAt) + now := completedAt + if now.IsZero() { + now = time.Now() } - deployScript := filepath.Join(projectDir, "training", "scripts", "deploy_model.sh") - if _, err := os.Stat(deployScript); err != nil { - return "", fmt.Errorf("deploy script not found at %s: %w", deployScript, err) - } - - // Version the model with timestamp - modelName := fmt.Sprintf("gemma-spokes-cl-%s", time.Now().Format("20060102-150405")) - - cmd := exec.CommandContext(ctx, "bash", deployScript, checkpointPath, "--name", modelName) - cmd.Dir = projectDir - - output, err := cmd.CombinedOutput() - if err != nil { - return "", fmt.Errorf("deploy script failed: %w\nOutput: %s", err, string(output)) - } - - modelPath := filepath.Join(projectDir, "models", modelName+".gguf") - da.log.Info("spoke model deployed", "path", modelPath, "name", modelName) - - return modelPath, nil -} - -// parseEvalOutput extracts metrics from the evaluation script's JSON output. -func parseEvalOutput(output string) (*qualityGateResult, error) { - // The eval script outputs various lines. We look for the JSON summary. - // For now, use a simple heuristic: find the last line that starts with '{'. - lines := splitLines(output) - for i := len(lines) - 1; i >= 0; i-- { - line := lines[i] - if len(line) > 0 && line[0] == '{' { - var metrics struct { - EPR float64 `json:"epr"` - FR float64 `json:"fr"` - SC float64 `json:"sc"` - } - if err := json.Unmarshal([]byte(line), &metrics); err != nil { - continue - } - return &qualityGateResult{ - EPR: metrics.EPR, - FR: metrics.FR, - SC: metrics.SC, - }, nil - } - } - return nil, fmt.Errorf("no JSON metrics found in eval output") + run := store.TrainingRun{ + ID: result.RunID, + Status: result.Status, + CheckpointPath: result.CheckpointPath, + ModelPath: result.ModelPath, + EvalEPR: result.EvalEPR, + EvalFR: result.EvalFR, + EvalSC: result.EvalSC, + QualityPassed: result.QualityPassed, + ErrorMessage: result.ErrorMessage, + CompletedAt: &now, + } + if err := s.UpdateTrainingRun(ctx, run); err != nil { + return fmt.Errorf("updating training run %s: %w", result.RunID, err) + } + + log.Info("picked up training result from previous run", + "run_id", result.RunID, "status", result.Status, + "quality_passed", result.QualityPassed, + "epr", result.EvalEPR, "sc", result.EvalSC) + + // Archive the result file + archivePath := filepath.Join(requestDir, fmt.Sprintf("result_%s_%s.json", result.RunID, time.Now().Format("20060102_150405"))) + if err := os.Rename(resultPath, archivePath); err != nil { + // Not fatal — log and continue + log.Info("could not archive result file", "error", err) + } + + return nil } // inTrainingWindow checks if the current time is within the configured window. @@ -451,23 +290,3 @@ func inTrainingWindow(window string) bool { // Wraps midnight (e.g. "22:00-06:00") return currentMin >= startMin || currentMin < endMin } - -// splitLines splits a string into lines, trimming trailing whitespace. 
-func splitLines(s string) []string { - var lines []string - start := 0 - for i := 0; i < len(s); i++ { - if s[i] == '\n' { - line := s[start:i] - if len(line) > 0 && line[len(line)-1] == '\r' { - line = line[:len(line)-1] - } - lines = append(lines, line) - start = i + 1 - } - } - if start < len(s) { - lines = append(lines, s[start:]) - } - return lines -} diff --git a/internal/agent/dreaming/training_trigger_test.go b/internal/agent/dreaming/training_trigger_test.go index c1cd88a5..a271c0b2 100644 --- a/internal/agent/dreaming/training_trigger_test.go +++ b/internal/agent/dreaming/training_trigger_test.go @@ -2,8 +2,11 @@ package dreaming import ( "context" + "encoding/json" "io" "log/slog" + "os" + "path/filepath" "testing" "time" @@ -128,7 +131,7 @@ func TestRunTrainingCycle_InsufficientData(t *testing.T) { clCfg := baseCLConfig() clCfg.Training.MinNewExamples = 50 - result, err := agent.RunTrainingCycle(context.Background(), clCfg) + result, err := agent.RunTrainingCycle(context.Background(), clCfg, "manual") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -137,7 +140,11 @@ func TestRunTrainingCycle_InsufficientData(t *testing.T) { } } -func TestRunTrainingCycle_AssemblesAndRecords(t *testing.T) { +func TestRunTrainingCycle_WritesRequestFile(t *testing.T) { + // Use a temp dir for training requests so we don't pollute the real one + tmpDir := t.TempDir() + t.Setenv("MNEMONIC_TRAINING_REQUESTS_DIR", tmpDir) + ms := &triggerMockStore{ untrainedCount: 10, goldEntries: []store.ExperienceEntry{ @@ -155,46 +162,181 @@ func TestRunTrainingCycle_AssemblesAndRecords(t *testing.T) { clCfg := baseCLConfig() - // Use a short timeout — we only test trigger logic and record-keeping. - // The subprocess will be killed quickly rather than loading a full model. 
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - result, err := agent.RunTrainingCycle(ctx, clCfg) + result, err := agent.RunTrainingCycle(context.Background(), clCfg, "manual") if err != nil { t.Fatalf("unexpected error: %v", err) } + if result == nil { + t.Fatal("expected non-nil result") + } - // Should have assembled data and started a training run + // Should return training_requested status + if result.Status != "training_requested" { + t.Errorf("expected status 'training_requested', got %q", result.Status) + } + if result.RequestID == "" { + t.Error("expected non-empty request_id") + } + if result.TotalExamples != 1 { + t.Errorf("expected 1 total example, got %d", result.TotalExamples) + } + + // Should have written a training run record if len(ms.trainingRunsW) != 1 { t.Fatalf("expected 1 training run written, got %d", len(ms.trainingRunsW)) } run := ms.trainingRunsW[0] - if run.Status != "training" { - t.Errorf("expected initial status 'training', got %q", run.Status) + if run.Status != "requested" { + t.Errorf("expected initial status 'requested', got %q", run.Status) } - if run.TotalExamples != 1 { - t.Errorf("expected 1 total example, got %d", run.TotalExamples) + + // Should have written a pending.json file + pendingPath := filepath.Join(tmpDir, "pending.json") + data, err := os.ReadFile(pendingPath) + if err != nil { + t.Fatalf("reading pending.json: %v", err) } - // Training script will fail (not available in test env) — result should reflect that - if result == nil { - t.Fatal("expected non-nil result") + var request TrainingRequest + if err := json.Unmarshal(data, &request); err != nil { + t.Fatalf("parsing pending.json: %v", err) } - if result.Status != "failed" { - t.Errorf("expected status 'failed' (no training env), got %q", result.Status) + if request.Trigger != "manual" { + t.Errorf("expected trigger 'manual', got %q", request.Trigger) } - if result.ErrorMessage == "" { - t.Error("expected error message") + if request.TotalExamples != 1 { + t.Errorf("expected 1 total example in request, got %d", request.TotalExamples) } + if request.RunID == "" { + t.Error("expected non-empty run_id in request") + } +} - // Should have updated the training run to failed - if len(ms.trainingRunsU) < 1 { - t.Fatal("expected at least 1 training run update") +func TestRunTrainingCycle_SkipsWhenPendingExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("MNEMONIC_TRAINING_REQUESTS_DIR", tmpDir) + + // Pre-create a pending.json + if err := os.WriteFile(filepath.Join(tmpDir, "pending.json"), []byte(`{}`), 0o644); err != nil { + t.Fatal(err) } - lastUpdate := ms.trainingRunsU[len(ms.trainingRunsU)-1] - if lastUpdate.Status != "failed" { - t.Errorf("expected updated status 'failed', got %q", lastUpdate.Status) + + ms := &triggerMockStore{untrainedCount: 100} + agent := NewDreamingAgent(ms, nil, DreamingConfig{Interval: time.Hour}, slog.New(slog.NewTextHandler(io.Discard, nil))) + + result, err := agent.RunTrainingCycle(context.Background(), baseCLConfig(), "auto") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != nil { + t.Fatal("expected nil result when pending request exists") + } +} + +func TestPickUpTrainingResult_NoFile(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("MNEMONIC_TRAINING_REQUESTS_DIR", tmpDir) + + ms := &triggerMockStore{} + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + + err := PickUpTrainingResult(context.Background(), ms, log) + if err != nil { + t.Fatalf("unexpected error: %v", 
err) + } + // No updates should have happened + if len(ms.trainingRunsU) != 0 { + t.Errorf("expected no training run updates, got %d", len(ms.trainingRunsU)) + } +} + +func TestPickUpTrainingResult_CompletedRun(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("MNEMONIC_TRAINING_REQUESTS_DIR", tmpDir) + + // Write a result file + result := TrainingResultFile{ + RequestID: "tr-20260413-abc", + RunID: "abc12345", + Status: "completed", + CheckpointPath: "/tmp/checkpoint", + ModelPath: "/tmp/model.gguf", + EvalEPR: 0.95, + EvalFR: 0.02, + EvalSC: 0.98, + QualityPassed: true, + CompletedAt: time.Now().UTC().Format(time.RFC3339), + } + data, _ := json.Marshal(result) + if err := os.WriteFile(filepath.Join(tmpDir, "result.json"), data, 0o644); err != nil { + t.Fatal(err) + } + + ms := &triggerMockStore{} + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + + err := PickUpTrainingResult(context.Background(), ms, log) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should have updated the training run + if len(ms.trainingRunsU) != 1 { + t.Fatalf("expected 1 training run update, got %d", len(ms.trainingRunsU)) + } + update := ms.trainingRunsU[0] + if update.ID != "abc12345" { + t.Errorf("expected run ID 'abc12345', got %q", update.ID) + } + if update.Status != "completed" { + t.Errorf("expected status 'completed', got %q", update.Status) + } + if !update.QualityPassed { + t.Error("expected quality_passed to be true") + } + if update.EvalEPR != 0.95 { + t.Errorf("expected EPR 0.95, got %.2f", update.EvalEPR) + } + + // Result file should be archived (renamed) + if _, err := os.Stat(filepath.Join(tmpDir, "result.json")); !os.IsNotExist(err) { + t.Error("expected result.json to be archived (renamed)") + } +} + +func TestPickUpTrainingResult_FailedRun(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("MNEMONIC_TRAINING_REQUESTS_DIR", tmpDir) + + result := TrainingResultFile{ + RequestID: "tr-20260413-def", + RunID: "def12345", + Status: "failed", + ErrorMessage: "quality gate failed: EPR=0.82", + CompletedAt: time.Now().UTC().Format(time.RFC3339), + } + data, _ := json.Marshal(result) + if err := os.WriteFile(filepath.Join(tmpDir, "result.json"), data, 0o644); err != nil { + t.Fatal(err) + } + + ms := &triggerMockStore{} + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + + err := PickUpTrainingResult(context.Background(), ms, log) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(ms.trainingRunsU) != 1 { + t.Fatalf("expected 1 training run update, got %d", len(ms.trainingRunsU)) + } + update := ms.trainingRunsU[0] + if update.Status != "failed" { + t.Errorf("expected status 'failed', got %q", update.Status) + } + if update.ErrorMessage != "quality gate failed: EPR=0.82" { + t.Errorf("unexpected error message: %q", update.ErrorMessage) } } @@ -216,53 +358,3 @@ func TestInTrainingWindow(t *testing.T) { }) } } - -func TestParseEvalOutput(t *testing.T) { - t.Run("valid JSON metrics", func(t *testing.T) { - output := "Loading model...\nRunning evaluation...\n{\"epr\": 0.92, \"fr\": 0.03, \"sc\": 0.96}\nDone." 
- result, err := parseEvalOutput(output) - if err != nil { - t.Fatalf("parseEvalOutput: %v", err) - } - if result.EPR != 0.92 { - t.Errorf("expected EPR 0.92, got %.2f", result.EPR) - } - if result.FR != 0.03 { - t.Errorf("expected FR 0.03, got %.2f", result.FR) - } - if result.SC != 0.96 { - t.Errorf("expected SC 0.96, got %.2f", result.SC) - } - }) - - t.Run("no JSON in output", func(t *testing.T) { - _, err := parseEvalOutput("No metrics here\nJust text output") - if err == nil { - t.Fatal("expected error for missing JSON") - } - }) - - t.Run("quality gate pass", func(t *testing.T) { - output := `{"epr": 0.95, "fr": 0.02, "sc": 0.98}` - result, err := parseEvalOutput(output) - if err != nil { - t.Fatalf("parseEvalOutput: %v", err) - } - result.Passed = result.EPR >= 0.90 && result.FR <= 0.05 && result.SC >= 0.95 - if !result.Passed { - t.Error("expected quality gate to pass") - } - }) - - t.Run("quality gate fail low EPR", func(t *testing.T) { - output := `{"epr": 0.85, "fr": 0.02, "sc": 0.98}` - result, err := parseEvalOutput(output) - if err != nil { - t.Fatalf("parseEvalOutput: %v", err) - } - result.Passed = result.EPR >= 0.90 && result.FR <= 0.05 && result.SC >= 0.95 - if result.Passed { - t.Error("expected quality gate to fail for low EPR") - } - }) -} diff --git a/internal/mcp/session.go b/internal/mcp/session.go index eb4b6715..ebad3325 100644 --- a/internal/mcp/session.go +++ b/internal/mcp/session.go @@ -21,18 +21,18 @@ type SessionManager struct { sessions map[string]*httpSession // Shared dependencies (from daemon) - store store.Store - retriever *retrieval.RetrievalAgent - bus events.Bus - log *slog.Logger - version string - coachingFile string - excludePatterns []string - maxContentBytes int - resolver ProjectResolver - daemonURL string - memDefaults MemoryDefaults - trainingTriggerFn func(ctx context.Context) (map[string]any, error) + store store.Store + retriever *retrieval.RetrievalAgent + bus events.Bus + log *slog.Logger + version string + coachingFile string + excludePatterns []string + maxContentBytes int + resolver ProjectResolver + daemonURL string + memDefaults MemoryDefaults + trainingTriggerFn func(ctx context.Context) (map[string]any, error) idleTimeout time.Duration // how long before an idle session is expired stopCh chan struct{} // signals the reaper goroutine to stop @@ -45,19 +45,19 @@ type httpSession struct { // SessionManagerConfig holds configuration for the session manager. type SessionManagerConfig struct { - Store store.Store - Retriever *retrieval.RetrievalAgent - Bus events.Bus - Log *slog.Logger - Version string - CoachingFile string - ExcludePatterns []string - MaxContentBytes int - Resolver *config.ProjectResolver - DaemonURL string - MemDefaults MemoryDefaults - TrainingTriggerFn func(ctx context.Context) (map[string]any, error) - IdleTimeout time.Duration // default: 30 minutes + Store store.Store + Retriever *retrieval.RetrievalAgent + Bus events.Bus + Log *slog.Logger + Version string + CoachingFile string + ExcludePatterns []string + MaxContentBytes int + Resolver *config.ProjectResolver + DaemonURL string + MemDefaults MemoryDefaults + TrainingTriggerFn func(ctx context.Context) (map[string]any, error) + IdleTimeout time.Duration // default: 30 minutes } // NewSessionManager creates a session manager for HTTP MCP transport. 
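The func-valued TrainingTriggerFn keeps package mcp decoupled from the dreaming package: serve.go closes over the dreamer and hands the session manager a plain function, so mcp never imports dreaming. A minimal sketch of the consuming side (the handler name and the nil-check fallback are assumptions for illustration, not code from this patch):

func (sm *SessionManager) handleTrainModel(ctx context.Context) (map[string]any, error) {
	// Manual triggering is only wired up when cfg.ContinuousLearning.Trigger.Manual is true.
	if sm.trainingTriggerFn == nil {
		return map[string]any{"status": "disabled"}, nil
	}
	// Delegates to the closure installed in serve.go; the returned map carries
	// status, request_id, batch_id, total_examples, request_path, and error.
	return sm.trainingTriggerFn(ctx)
}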
@@ -68,21 +68,21 @@ func NewSessionManager(cfg SessionManagerConfig) *SessionManager { } sm := &SessionManager{ - sessions: make(map[string]*httpSession), - store: cfg.Store, - retriever: cfg.Retriever, - bus: cfg.Bus, - log: cfg.Log, - version: cfg.Version, - coachingFile: cfg.CoachingFile, - excludePatterns: cfg.ExcludePatterns, - maxContentBytes: cfg.MaxContentBytes, - resolver: cfg.Resolver, - daemonURL: cfg.DaemonURL, - memDefaults: cfg.MemDefaults, - trainingTriggerFn: cfg.TrainingTriggerFn, - idleTimeout: timeout, - stopCh: make(chan struct{}), + sessions: make(map[string]*httpSession), + store: cfg.Store, + retriever: cfg.Retriever, + bus: cfg.Bus, + log: cfg.Log, + version: cfg.Version, + coachingFile: cfg.CoachingFile, + excludePatterns: cfg.ExcludePatterns, + maxContentBytes: cfg.MaxContentBytes, + resolver: cfg.Resolver, + daemonURL: cfg.DaemonURL, + memDefaults: cfg.MemDefaults, + trainingTriggerFn: cfg.TrainingTriggerFn, + idleTimeout: timeout, + stopCh: make(chan struct{}), } // Start background reaper for idle sessions diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go index 2ac42c1d..1768de27 100644 --- a/internal/mcp/tools.go +++ b/internal/mcp/tools.go @@ -708,7 +708,7 @@ func allToolDefs() []ToolDefinition { func trainModelToolDef() ToolDefinition { return ToolDefinition{ Name: "train_model", - Description: "Trigger a spoke fine-tuning cycle using accumulated experience data. Assembles gold and corrected encoding pairs into a training batch, runs spoke training, evaluates quality against probes, and deploys if the quality gate passes. Requires sufficient untrained data in the experience buffer (default: 50 entries).", + Description: "Request a spoke fine-tuning cycle using accumulated experience data. Assembles gold and corrected encoding pairs into a training batch and writes a training request for the systemd training service. Training runs asynchronously after the daemon stops to free VRAM. Results are picked up on the next daemon startup. Requires sufficient untrained data in the experience buffer (default: 50 entries).", InputSchema: map[string]any{ "type": "object", "properties": map[string]any{}, diff --git a/scripts/continuous_train.sh b/scripts/continuous_train.sh new file mode 100755 index 00000000..bfda96eb --- /dev/null +++ b/scripts/continuous_train.sh @@ -0,0 +1,235 @@ +#!/bin/bash +# continuous_train.sh — Orchestrates spoke training outside the daemon process. +# +# Called by mnemonic-train.service (triggered by mnemonic-train.path when +# pending.json appears). Stops the daemon to free VRAM, runs training, +# writes a result file, and always restarts the daemon. +# +# Usage: continuous_train.sh +# Reads: ~/.mnemonic/training_requests/pending.json +# Writes: ~/.mnemonic/training_requests/result.json + +set -uo pipefail + +REQUEST_DIR="${MNEMONIC_TRAINING_REQUESTS_DIR:-$HOME/.mnemonic/training_requests}" +REQUEST="$REQUEST_DIR/pending.json" +RESULT="$REQUEST_DIR/result.json" +LOG="$REQUEST_DIR/train_$(date +%Y%m%d_%H%M%S).log" + +PROJECT_DIR="${MNEMONIC_PROJECT_DIR:-$HOME/Projects/mem}" +VENV_PYTHON="${MNEMONIC_VENV_PYTHON:-$HOME/Projects/felixlm/.venv/bin/python}" + +# Fall back to system python if venv not found +if [ ! -f "$VENV_PYTHON" ]; then + VENV_PYTHON="python3" +fi + +# CRITICAL: Always restart the daemon, even on training failure. +# The daemon must come back up regardless of what happens here. +cleanup() { + local exit_code=$? 
+ + # Archive the request file (move out of watched path) + if [ -f "$REQUEST" ]; then + mv "$REQUEST" "$REQUEST_DIR/completed_$(date +%Y%m%d_%H%M%S).json" + fi + + # Restart the daemon — this MUST happen + echo "[continuous_train] Restarting mnemonic daemon..." + systemctl --user start mnemonic + + # Keep only the last 10 log files + ls -t "$REQUEST_DIR"/train_*.log 2>/dev/null | tail -n +11 | xargs -r rm + + echo "[continuous_train] Done (exit code: $exit_code)" +} +trap cleanup EXIT + +echo "[continuous_train] Starting training cycle at $(date)" +echo "[continuous_train] Request: $REQUEST" + +# Validate request file exists +if [ ! -f "$REQUEST" ]; then + echo "[continuous_train] ERROR: No pending request at $REQUEST" + exit 1 +fi + +# Parse request +REQUEST_ID=$(jq -r '.request_id' "$REQUEST") +RUN_ID=$(jq -r '.run_id' "$REQUEST") +BATCH_PATH=$(jq -r '.batch_path' "$REQUEST") +TOTAL_EXAMPLES=$(jq -r '.total_examples' "$REQUEST") + +echo "[continuous_train] Request ID: $REQUEST_ID" +echo "[continuous_train] Run ID: $RUN_ID" +echo "[continuous_train] Batch: $BATCH_PATH ($TOTAL_EXAMPLES examples)" + +# Validate batch file exists +if [ ! -f "$BATCH_PATH" ]; then + echo "[continuous_train] ERROR: Batch file not found at $BATCH_PATH" + jq -n \ + --arg request_id "$REQUEST_ID" \ + --arg run_id "$RUN_ID" \ + --arg error "batch file not found at $BATCH_PATH" \ + --arg completed_at "$(date -u +%Y-%m-%dT%H:%M:%SZ)" \ + '{request_id: $request_id, run_id: $run_id, status: "failed", error_message: $error, quality_passed: false, completed_at: $completed_at}' \ + > "$RESULT" + exit 1 +fi + +# Helper: write a failure result and exit +write_failure() { + local msg="$1" + echo "[continuous_train] FAILED: $msg" + jq -n \ + --arg request_id "$REQUEST_ID" \ + --arg run_id "$RUN_ID" \ + --arg error "$msg" \ + --arg completed_at "$(date -u +%Y-%m-%dT%H:%M:%SZ)" \ + '{request_id: $request_id, run_id: $run_id, status: "failed", error_message: $error, quality_passed: false, completed_at: $completed_at}' \ + > "$RESULT" + exit 1 +} + +# Stop the daemon to free VRAM +echo "[continuous_train] Stopping mnemonic daemon to free VRAM..." +systemctl --user stop mnemonic || true +sleep 2 # Give GPU time to release memory + +# Check VRAM is actually free +if command -v rocm-smi &>/dev/null; then + VRAM_USED=$(rocm-smi --showmeminfo vram 2>/dev/null | grep "Used" | awk '{print $NF}' | head -1) + echo "[continuous_train] VRAM used after daemon stop: ${VRAM_USED:-unknown}" +fi + +# Step 1: Tokenize the batch data +echo "[continuous_train] Step 1: Preparing training data..." +PREP_SCRIPT="$PROJECT_DIR/training/scripts/prepare_gemma_finetune_data.py" +TOKENIZED_DIR="$(dirname "$BATCH_PATH")/tokenized" + +if [ ! -f "$PREP_SCRIPT" ]; then + write_failure "prep script not found at $PREP_SCRIPT" +fi + +"$VENV_PYTHON" "$PREP_SCRIPT" \ + --input "$BATCH_PATH" \ + --output-dir "$TOKENIZED_DIR" \ + --max-seq-len 2048 \ + --eval-ratio 0 \ + 2>&1 | tee -a "$LOG" + +TOKENIZED_PATH="$TOKENIZED_DIR/train.jsonl" +if [ ! -f "$TOKENIZED_PATH" ]; then + write_failure "tokenized data not found at $TOKENIZED_PATH after prep" +fi + +echo "[continuous_train] Data prepared: $TOKENIZED_PATH" + +# Step 2: Run spoke training +echo "[continuous_train] Step 2: Training spokes..." +TRAIN_SCRIPT="$PROJECT_DIR/training/scripts/train_spokes.py" +CHECKPOINT_DIR="$PROJECT_DIR/checkpoints/continuous_learning" +mkdir -p "$CHECKPOINT_DIR" + +if [ ! 
-f "$TRAIN_SCRIPT" ]; then + write_failure "training script not found at $TRAIN_SCRIPT" +fi + +"$VENV_PYTHON" "$TRAIN_SCRIPT" \ + --model-type gemma \ + --base-model google/gemma-4-E2B-it \ + --train-data "$TOKENIZED_PATH" \ + --checkpoint-dir "$CHECKPOINT_DIR" \ + --seq-len 2048 \ + --steps 500 \ + --batch-size 1 \ + --grad-accum 8 \ + --lr 1e-4 \ + --no-wandb \ + 2>&1 | tee -a "$LOG" + +CHECKPOINT_PATH="$CHECKPOINT_DIR/last.pt" +if [ ! -f "$CHECKPOINT_PATH" ]; then + write_failure "checkpoint not found after training at $CHECKPOINT_PATH" +fi + +echo "[continuous_train] Training complete: $CHECKPOINT_PATH" + +# Step 3: Quality gate evaluation +echo "[continuous_train] Step 3: Running quality gate..." +EVAL_SCRIPT="$PROJECT_DIR/training/scripts/eval_encoding.py" + +if [ ! -f "$EVAL_SCRIPT" ]; then + write_failure "eval script not found at $EVAL_SCRIPT" +fi + +EVAL_OUTPUT=$("$VENV_PYTHON" "$EVAL_SCRIPT" \ + --checkpoint "$CHECKPOINT_PATH" \ + --mode generate \ + --json-output \ + 2>&1 | tee -a "$LOG") + +# Extract the JSON metrics line (last line starting with '{') +EVAL_JSON=$(echo "$EVAL_OUTPUT" | grep '^{' | tail -1) +if [ -z "$EVAL_JSON" ]; then + write_failure "no JSON metrics in eval output" +fi + +EPR=$(echo "$EVAL_JSON" | jq -r '.epr') +FR=$(echo "$EVAL_JSON" | jq -r '.fr') +SC=$(echo "$EVAL_JSON" | jq -r '.sc') + +echo "[continuous_train] Quality gate results: EPR=$EPR FR=$FR SC=$SC" + +# Check thresholds: EPR >= 0.90, FR <= 0.05, SC >= 0.95 +PASSED=$(echo "$EPR $FR $SC" | awk '{ + if ($1 >= 0.90 && $2 <= 0.05 && $3 >= 0.95) print "true" + else print "false" +}') + +if [ "$PASSED" = "false" ]; then + echo "[continuous_train] Quality gate FAILED" + jq -n \ + --arg request_id "$REQUEST_ID" \ + --arg run_id "$RUN_ID" \ + --arg checkpoint "$CHECKPOINT_PATH" \ + --argjson epr "$EPR" \ + --argjson fr "$FR" \ + --argjson sc "$SC" \ + --arg error "quality gate failed: EPR=$EPR FR=$FR SC=$SC" \ + --arg completed_at "$(date -u +%Y-%m-%dT%H:%M:%SZ)" \ + '{request_id: $request_id, run_id: $run_id, status: "failed", checkpoint_path: $checkpoint, eval_epr: $epr, eval_fr: $fr, eval_sc: $sc, quality_passed: false, error_message: $error, completed_at: $completed_at}' \ + > "$RESULT" + exit 0 # Not an error — training ran, quality was insufficient +fi + +echo "[continuous_train] Quality gate PASSED" + +# Step 4: Deploy new spokes +echo "[continuous_train] Step 4: Deploying model..." 
+DEPLOY_SCRIPT="$PROJECT_DIR/training/scripts/deploy_model.sh" +MODEL_NAME="gemma-spokes-cl-$(date +%Y%m%d-%H%M%S)" + +if [ -f "$DEPLOY_SCRIPT" ]; then + bash "$DEPLOY_SCRIPT" "$CHECKPOINT_PATH" --name "$MODEL_NAME" 2>&1 | tee -a "$LOG" + MODEL_PATH="$PROJECT_DIR/models/${MODEL_NAME}.gguf" +else + echo "[continuous_train] WARNING: deploy script not found, skipping deployment" + MODEL_PATH="" +fi + +# Write success result +jq -n \ + --arg request_id "$REQUEST_ID" \ + --arg run_id "$RUN_ID" \ + --arg checkpoint "$CHECKPOINT_PATH" \ + --arg model "${MODEL_PATH:-}" \ + --argjson epr "$EPR" \ + --argjson fr "$FR" \ + --argjson sc "$SC" \ + --arg completed_at "$(date -u +%Y-%m-%dT%H:%M:%SZ)" \ + '{request_id: $request_id, run_id: $run_id, status: "completed", checkpoint_path: $checkpoint, model_path: $model, eval_epr: $epr, eval_fr: $fr, eval_sc: $sc, quality_passed: true, completed_at: $completed_at}' \ + > "$RESULT" + +echo "[continuous_train] Training cycle completed successfully" +echo "[continuous_train] Model: $MODEL_PATH" diff --git a/scripts/systemd/mnemonic-train.path b/scripts/systemd/mnemonic-train.path new file mode 100644 index 00000000..b7910e81 --- /dev/null +++ b/scripts/systemd/mnemonic-train.path @@ -0,0 +1,9 @@ +[Unit] +Description=Watch for Mnemonic training requests + +[Path] +PathExists=%h/.mnemonic/training_requests/pending.json +Unit=mnemonic-train.service + +[Install] +WantedBy=default.target diff --git a/scripts/systemd/mnemonic-train.service b/scripts/systemd/mnemonic-train.service new file mode 100644 index 00000000..8e3ce990 --- /dev/null +++ b/scripts/systemd/mnemonic-train.service @@ -0,0 +1,14 @@ +[Unit] +Description=Mnemonic continuous learning training cycle +# Don't require mnemonic.service — the script stops it before training +# and restarts it after, regardless of outcome. + +[Service] +Type=oneshot +ExecStart=%h/Projects/mem/scripts/continuous_train.sh +# 30 minutes should be enough for 500 steps on RX 7800 XT. +# If training takes longer, the timeout kills the process and the +# EXIT trap in the script still restarts the daemon. +TimeoutStartSec=1800 +Environment=PYTHONUNBUFFERED=1 +WorkingDirectory=%h/Projects/mem From ac2bf1bcb6bb7b213450759943241029be361741 Mon Sep 17 00:00:00 2001 From: Caleb Gross Date: Tue, 14 Apr 2026 00:11:12 -0400 Subject: [PATCH 8/8] feat: enable continuous learning by default (#391) Pipeline is built, tested, and safe (systemd-orchestrated). No reason to keep it gated behind a flag. Auto-trigger runs during 02:00-06:00 training window, curriculum generation runs during dreaming hours. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/config/config.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/internal/config/config.go b/internal/config/config.go index 5b69d971..630d0f2f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -881,6 +881,26 @@ func Default() *Config { MCP: MCPConfig{ Enabled: true, }, + ContinuousLearning: ContinuousLearningConfig{ + Enabled: true, + Training: CLTrainingConfig{ + MinNewExamples: 50, + MaxExamplesPerRun: 200, + ReplayRatio: 0.30, + RollbackVersions: 3, + }, + Curriculum: CLCurriculumConfig{ + Enabled: true, + MaxCorrectionsPerCycle: 20, + MinNeedsImprovement: 10, + CooldownHours: 24, + }, + Trigger: CLTriggerConfig{ + Auto: true, + Manual: true, + TrainingWindow: "02:00-06:00", + }, + }, AgentSDK: AgentSDKConfig{ Enabled: false, EvolutionDir: "",
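One subtlety in these defaults: the TrainingWindow of "02:00-06:00" is evaluated by inTrainingWindow, which also supports windows that wrap midnight. A standalone re-derivation of that check, with the current minute injected instead of read from time.Now() (the non-wrapping branch is inferred from the wrap case shown earlier in the patch, so treat this as a sketch):

func inWindowAt(currentMin, startMin, endMin int) bool {
	if startMin < endMin {
		// Plain window, e.g. "02:00-06:00"
		return currentMin >= startMin && currentMin < endMin
	}
	// Wraps midnight, e.g. "22:00-06:00"
	return currentMin >= startMin || currentMin < endMin
}

// inWindowAt(3*60, 2*60, 6*60)   == true   (03:00 falls inside 02:00-06:00)
// inWindowAt(12*60, 2*60, 6*60)  == false  (noon is outside, so auto-training waits)
// inWindowAt(23*60, 22*60, 6*60) == true   (23:00 falls inside the wrapped window)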