diff --git a/sage/cognition/thalamic_router/gameplay_capture.py b/sage/cognition/thalamic_router/gameplay_capture.py index ae8402a15..33c03e904 100644 --- a/sage/cognition/thalamic_router/gameplay_capture.py +++ b/sage/cognition/thalamic_router/gameplay_capture.py @@ -375,7 +375,16 @@ def run(self) -> CaptureResult: self.errors.append(f"step {step.index}: decide failed: {e!r}") continue - # Build the record with provenance metadata + # Build the record with provenance metadata. + # + # `known_good_*` fields are the supervised-training labels. Since + # this trace is a WIN, the action actually taken at this tick is + # by definition a good next action. This lets downstream training + # use the record as a supervised triple: + # state → (baseline-proposed dispatch) [router BC] + # state → known_good_action [action prediction] + # state × skill_params → known_good_action [motor-skill BC] + # Plus outcome-weighted shaping: sample weight ∝ game_outcome.won. metadata = { "source": "gameplay", "game": self.trace.game, @@ -385,6 +394,10 @@ def run(self) -> CaptureResult: "step_index": step.index, "level": step.level, "synthetic_kernel_state": True, + # Supervised labels from the winning trace + "known_good_action": step.action, + "known_good_data": step.data, + "known_good_level": step.level, } record = RouterRecord( router_input=router_input, diff --git a/sage/cognition/thalamic_router/tests/test_gameplay_capture.py b/sage/cognition/thalamic_router/tests/test_gameplay_capture.py index 01220162c..ade50148e 100644 --- a/sage/cognition/thalamic_router/tests/test_gameplay_capture.py +++ b/sage/cognition/thalamic_router/tests/test_gameplay_capture.py @@ -186,6 +186,10 @@ def test_capture_end_to_end_emits_records(tmp_path): assert rec.metadata.get("game") == "testgame" assert rec.metadata.get("synthetic_kernel_state") is True assert "game_outcome" in rec.metadata + # Supervised-training labels + assert "known_good_action" in rec.metadata + assert rec.metadata["known_good_action"] in (1, 3, 6) # from _make_trace_3_steps + assert "known_good_level" in rec.metadata def test_capture_writes_to_partition(tmp_path): @@ -271,3 +275,36 @@ def test_capture_all_records_have_valid_router_output(tmp_path): for rec in capture.records: ok, reason = rec.router_output.validate() assert ok, f"invalid output: {reason}" + + +def test_capture_known_good_labels_match_trace_actions(tmp_path): + """Each record's known_good_action must equal the corresponding TraceStep.action.""" + trace = _make_trace_3_steps() + writer = RouterDatasetWriter(base_dir=tmp_path / "dataset", + machine="cbp", compress=False) + capture = GameplayCapture(trace=trace, writer=writer, machine="cbp", + env_factory=_make_env_factory()) + capture.run() + writer.close() + assert len(capture.records) == 3 + assert capture.records[0].metadata["known_good_action"] == 1 # UP + assert capture.records[1].metadata["known_good_action"] == 3 # LEFT + assert capture.records[2].metadata["known_good_action"] == 6 # CLICK + # Click step carries click data + assert capture.records[2].metadata["known_good_data"] == {"x": 10, "y": 20} + # Non-click steps have None data + assert capture.records[0].metadata["known_good_data"] is None + + +def test_capture_known_good_levels_match_trace_levels(tmp_path): + """known_good_level passes through the trace's per-step level.""" + trace = _make_trace_3_steps() + writer = RouterDatasetWriter(base_dir=tmp_path / "dataset", + machine="cbp", compress=False) + capture = GameplayCapture(trace=trace, writer=writer, machine="cbp", + env_factory=_make_env_factory()) + capture.run() + writer.close() + assert capture.records[0].metadata["known_good_level"] == 0 + assert capture.records[1].metadata["known_good_level"] == 0 + assert capture.records[2].metadata["known_good_level"] == 1