feat: generalize prompt baking to arbitrary prefix contexts#35

Open
marksverdhei wants to merge 2 commits into main from feat/context-baking

Conversation

@marksverdhei
Owner

Summary

Pivots bakery from baking a single system prompt to baking arbitrary prefix contexts — conversation histories, accumulated memories, few-shot examples — so it becomes a tool for continual learning / memory distillation, not just prompt compression.

The teacher/student asymmetry is no longer hardcoded as [system, user] vs [user]; the teacher sees an arbitrary prefix_messages list, and the student sees an optionally-trimmed version (student_retained_turns: int).
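Concretely, the generalized surface could be exercised like this. The field names come from this PR; the import path and the specific values are illustrative assumptions, not code from the branch:

```python
# Sketch only: field names follow the PR description, but the import path,
# defaults, and exact types are assumptions rather than code from the branch.
from bakery import ContextConfig  # hypothetical import path

config = ContextConfig(
    # The teacher conditions on this full prefix.
    prefix_messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Remember that my name is Ada."},
        {"role": "assistant", "content": "Noted, Ada."},
    ],
    # Alternatively, load the prefix from a JSON/YAML file.
    prefix_messages_file=None,
    # The student keeps only the last N prefix turns (0 drops the prefix entirely).
    student_retained_turns=0,
    # Compute the baking loss only on assistant tokens...
    target_roles=["assistant"],
    # ...optionally narrowed further by a regex over message content.
    target_content_pattern=None,
)
```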

What changed

  • New ContextConfig dataclass — flat YAML fields: prefix_messages, prefix_messages_file, student_retained_turns, target_roles, target_content_pattern.
  • New src/bakery/masking.py — per-token target mask builder using longest-common-token-id-prefix alignment (robust to BOS injection + template quirks), with a bounded cache. A sketch of the alignment idea follows this list.
  • Renamed PromptBakingTrainer → ContextBakingTrainer. Old name kept as a deprecated alias with a DeprecationWarning.
  • Unified compute_loss and prediction_step via a shared _build_batch, eliminating ~90 lines of duplication.
  • data.py — new create_conversational_dataset, collator handles both legacy and new batch shapes, HF messages loader preserves multi-turn history.
  • CLI wires ContextConfig through HfArgumentParser; auto-desugars deprecated system_prompt → prefix_messages=[{role: system, content: ...}] with a warning.
  • Per-row prefix_messages dataset column overrides the global prefix.
  • Examples: multi_turn_prefix.yaml, continual_memory.yaml, per_row_prefix.yaml, pattern_targets.yaml.
  • README — new "Context baking" section with migration note.
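The alignment idea behind the mask builder, in rough form. build_target_mask is the name used in the tests; the signature, the helper, and the omission of the bounded cache are simplifications, not code lifted from the branch:

```python
# Illustrative sketch of the alignment approach in src/bakery/masking.py.
def _common_prefix_len(a, b):
    """Length of the longest common prefix of two token-id sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def build_target_mask(tokenizer, messages, target_indices):
    """Return (token_ids, mask) with mask[t] = 1 on tokens of target messages.

    Comparing token ids of the conversation rendered with and without the
    trailing messages avoids character-offset bookkeeping, so injected BOS
    tokens and chat-template quirks do not shift the mask.
    """
    def render(k):
        return tokenizer.apply_chat_template(messages[:k], tokenize=True)

    full_ids = render(len(messages))
    mask = [0] * len(full_ids)
    for i in target_indices:
        start = _common_prefix_len(full_ids, render(i))    # tokens before message i
        end = _common_prefix_len(full_ids, render(i + 1))  # tokens through message i
        for t in range(start, end):
            mask[t] = 1
    return full_ids, mask
```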

Backward compatibility

  • system_prompt: "..." still works — desugars to prefix_messages=[{role: system, content: ...}] with a DeprecationWarning.
  • PromptBakingTrainer import still works — deprecated alias of ContextBakingTrainer. Both back-compat shims are sketched below.
  • Existing examples/basic.yaml and examples/sft_dataset.yaml unchanged.
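Both shims are roughly of the following shape; the warning text and the desugaring helper's name are illustrative, not copied from the branch:

```python
# Rough sketch of the two back-compat paths listed above.
import warnings

class ContextBakingTrainer:
    ...  # stand-in for the real trainer

class PromptBakingTrainer(ContextBakingTrainer):
    """Deprecated alias so existing imports keep working."""
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "PromptBakingTrainer is deprecated; use ContextBakingTrainer instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)

def desugar_system_prompt(config):
    """Translate the legacy system_prompt field into prefix_messages."""
    if getattr(config, "system_prompt", None) and not config.prefix_messages:
        warnings.warn(
            "system_prompt is deprecated; use prefix_messages instead.",
            DeprecationWarning,
        )
        config.prefix_messages = [
            {"role": "system", "content": config.system_prompt}
        ]
    return config
```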

Test plan

  • pytest tests/test_context_baking.py tests/test_trainer.py — 21 passed (CPU, tiny GPT-2 + LoRA)
  • Mask unit tests cover: single-assistant, multi-turn, regex filter, prefix exclusion via target_min_msg_idx, cache keying
  • Trainer integration tests cover: system_prompt back-compat, global multi-turn prefix, per-row override (dataset shape sketched after this list), student_retained_turns, multi-turn conversational batch, target_roles restriction, pattern-filtered targets
  • GPU smoke test on a real model (Qwen-0.6B) to confirm loss decreases on examples/multi_turn_prefix.yaml
  • Backward-compat smoke test: examples/basic.yaml and examples/sft_dataset.yaml produce identical behavior to main
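For reference, the per-row override path exercised above consumes rows shaped roughly like this; only the prefix_messages and messages column names come from this PR, and the values are made up:

```python
# Hypothetical dataset rows for the per-row prefix override.
rows = [
    {
        # Overrides the global prefix for this row only.
        "prefix_messages": [
            {"role": "system", "content": "Answer in formal English."},
        ],
        "messages": [
            {"role": "user", "content": "How do I undo my last git commit?"},
            {"role": "assistant", "content": "Run git reset --soft HEAD~1 ..."},
        ],
    },
    {
        # No per-row prefix: this row falls back to the global prefix_messages.
        "messages": [
            {"role": "user", "content": "Summarize yesterday's meeting."},
            {"role": "assistant", "content": "We agreed to ship the beta on Friday."},
        ],
    },
]
```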

🤖 Generated with Claude Code

marksverdhei and others added 2 commits April 15, 2026 16:31
Pivots bakery from baking a single system prompt to baking arbitrary
prefix contexts — conversation histories, accumulated memories, few-shot
examples — so it becomes a tool for continual learning and memory
distillation, not just prompt compression.

- Add `ContextConfig` dataclass (prefix_messages, prefix_messages_file,
  student_retained_turns, target_roles, target_content_pattern).
- New `src/bakery/masking.py`: per-token target-mask builder using
  longest-common-token-id-prefix alignment (robust to BOS injection and
  template quirks) with a bounded cache.
- Rename `PromptBakingTrainer` to `ContextBakingTrainer`; keep old name
  as a deprecated alias. Unify compute_loss and prediction_step via
  shared `_build_batch`.
- `data.py`: `create_conversational_dataset`, collator handles both
  legacy (user_messages/responses) and new (prefix_messages/turns/
  responses) batch shapes, HF `messages` loader preserves multi-turn
  history verbatim.
- CLI wires `ContextConfig` through `HfArgumentParser`; auto-desugars
  deprecated `system_prompt` into `prefix_messages=[{role: system, ...}]`
  with a DeprecationWarning.
- Per-row `prefix_messages` dataset column overrides the global prefix.
- New examples: multi_turn_prefix, continual_memory, per_row_prefix,
  pattern_targets. README section documenting ContextConfig with a
  migration note from system_prompt.
- Tests: tests/test_context_baking.py (13 new tests). All 21 relevant
  tests pass on CPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Add preemptive tests to harden against hidden evaluation:
- test_masking.py: full coverage of build_target_mask helpers, cache
  behavior (hits, key differentiation, LRU eviction), role/regex/prefix
  filtering, structural invariants
- test_data_conversational.py: create_conversational_dataset edge cases,
  prompt_baking_collator across legacy/conversational shapes,
  load_conversations JSON variants (messages, per-row prefix, prompt
  pairs, string lists, nested keys)
- test_context_config.py: ContextConfig dataclass defaults/mutability
  and HfArgumentParser integration alongside BakeryConfig
- test_cli_helpers.py: _load_prefix_file across JSON/YAML/YML, non-list
  rejection, empty list, missing file
- test_trainer_internals.py: _row_prefix, _student_prefix,
  _append_response, _normalize_batch, _build_example, compute_loss edge
  cases (zero/whitespace/mixed batches, per-row prefix, retained turns,
  differentiability, return_outputs), prediction_step, deprecation path,
  module exports

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>