Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion sage/cognition/thalamic_router/gameplay_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,16 @@ def run(self) -> CaptureResult:
self.errors.append(f"step {step.index}: decide failed: {e!r}")
continue

# Build the record with provenance metadata
# Build the record with provenance metadata.
#
# `known_good_*` fields are the supervised-training labels. Since
# this trace is a WIN, the action actually taken at this tick is
# by definition a good next action. This lets downstream training
# use the record as a supervised triple:
# state → (baseline-proposed dispatch) [router BC]
# state → known_good_action [action prediction]
# state × skill_params → known_good_action [motor-skill BC]
# Plus outcome-weighted shaping: sample weight ∝ game_outcome.won.
metadata = {
"source": "gameplay",
"game": self.trace.game,
Expand All @@ -385,6 +394,10 @@ def run(self) -> CaptureResult:
"step_index": step.index,
"level": step.level,
"synthetic_kernel_state": True,
# Supervised labels from the winning trace
"known_good_action": step.action,
"known_good_data": step.data,
"known_good_level": step.level,
}
record = RouterRecord(
router_input=router_input,
Expand Down
37 changes: 37 additions & 0 deletions sage/cognition/thalamic_router/tests/test_gameplay_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def test_capture_end_to_end_emits_records(tmp_path):
assert rec.metadata.get("game") == "testgame"
assert rec.metadata.get("synthetic_kernel_state") is True
assert "game_outcome" in rec.metadata
# Supervised-training labels
assert "known_good_action" in rec.metadata
assert rec.metadata["known_good_action"] in (1, 3, 6) # from _make_trace_3_steps
assert "known_good_level" in rec.metadata


def test_capture_writes_to_partition(tmp_path):
Expand Down Expand Up @@ -271,3 +275,36 @@ def test_capture_all_records_have_valid_router_output(tmp_path):
for rec in capture.records:
ok, reason = rec.router_output.validate()
assert ok, f"invalid output: {reason}"


def test_capture_known_good_labels_match_trace_actions(tmp_path):
"""Each record's known_good_action must equal the corresponding TraceStep.action."""
trace = _make_trace_3_steps()
writer = RouterDatasetWriter(base_dir=tmp_path / "dataset",
machine="cbp", compress=False)
capture = GameplayCapture(trace=trace, writer=writer, machine="cbp",
env_factory=_make_env_factory())
capture.run()
writer.close()
assert len(capture.records) == 3
assert capture.records[0].metadata["known_good_action"] == 1 # UP
assert capture.records[1].metadata["known_good_action"] == 3 # LEFT
assert capture.records[2].metadata["known_good_action"] == 6 # CLICK
# Click step carries click data
assert capture.records[2].metadata["known_good_data"] == {"x": 10, "y": 20}
# Non-click steps have None data
assert capture.records[0].metadata["known_good_data"] is None


def test_capture_known_good_levels_match_trace_levels(tmp_path):
"""known_good_level passes through the trace's per-step level."""
trace = _make_trace_3_steps()
writer = RouterDatasetWriter(base_dir=tmp_path / "dataset",
machine="cbp", compress=False)
capture = GameplayCapture(trace=trace, writer=writer, machine="cbp",
env_factory=_make_env_factory())
capture.run()
writer.close()
assert capture.records[0].metadata["known_good_level"] == 0
assert capture.records[1].metadata["known_good_level"] == 0
assert capture.records[2].metadata["known_good_level"] == 1