"""Capture model-free encoder-parity fixtures for the mp-v1 Rust↔Python check. Phase 2a (obs normalization) changes the encoder contract — the Rust `encode_observation` (`mc-player-api/src/learned/encoder.rs`) and the Python `encode_observation` (`encoders.py`) must produce byte-equivalent normalized observations. This script records `{view, obs, mask}` from REAL PlayerViews driven through the live harness via the scripted `suggest` chain, so the Rust parity test (`tests/learned_parity.rs::learned_encoder_parity`) can assert equality WITHOUT a trained policy (which does not exist until Phase 2c). This differs from `_export_onnx_p1_29f.py::capture_fixtures`, which needs the SB3 model to record logits/argmax. Here we only need the encoder, so the games advance purely on the scripted `suggest()` chain (both slots), exactly like `_export_onnx_p1_29f.py::_advance_slot`. Run on apricot (needs the harness binary + the rl_self_play package): python3 -m tooling.rl_self_play.capture_encoder_fixtures \ --out src/simulator/crates/mc-player-api/tests/fixtures/learned_mp_v1_encoder_parity.json \ --seeds 1 2 3 --turns 14 --players 4 --map-size small """ from __future__ import annotations import argparse import json import sys from pathlib import Path import numpy as np THIS_DIR = Path(__file__).resolve().parent PROJECT_ROOT = THIS_DIR.parents[1] if __package__ is None: sys.path.insert(0, str(PROJECT_ROOT)) from tooling.rl_self_play.encoders import ( # noqa: E402 ACTION_DIM, OBS_DIM, encode_legal_actions, encode_observation, ) from tooling.rl_self_play.harness_client import ( # noqa: E402 HarnessClient, HarnessConfig, ) def _advance_slot(client: HarnessClient, slot: int) -> None: """Drive one slot through a full turn via the scripted suggest chain.""" try: for a in client.suggest(slot=slot): t = a.get("type") try: if t == "end_turn": client.end_turn(slot=slot) else: client.act(a, slot=slot) except Exception: # noqa: BLE001 break client.end_turn(slot=slot) except Exception: # noqa: BLE001 pass client.drain_notifications() def capture(seeds: list[int], turns: int, players: int, map_size: str) -> list[dict]: slots = tuple(range(players)) fixtures: list[dict] = [] for seed in seeds: cfg = HarnessConfig( seed=seed, players=players, player_slots=slots, map_size=map_size, map_type="continents", victory_mode="domination", ) client = HarnessClient(cfg) try: for _ in range(turns): view = client.view(slot=0) mask, _ = encode_legal_actions(view) obs = encode_observation(view) # Guard against capturing UN-normalized obs from a stale # encoder: asinh compresses any plausible raw magnitude well # below 20 (asinh(1e6) ~= 14.5), so a raw score_estimate=240 # would trip this immediately. This makes "the committed # fixtures are actually normalized" a loud failure, not a # silent Rust-parity mismatch downstream. peak = float(np.max(np.abs(obs))) assert peak < 20.0, ( f"obs magnitude {peak:.1f} >= 20 — encoder is NOT applying " f"asinh normalization (stale encoders.py?). seed={seed} " f"turn={view.get('turn')}" ) fixtures.append({ "seed": seed, "turn": int(view.get("turn", 0)), "players": players, "view": view, "obs": [float(x) for x in obs], "mask": [bool(b) for b in mask], }) # Advance every slot to progress the game into mid-game # magnitudes (the distribution the normalization targets). for s in slots: _advance_slot(client, s) finally: getattr(client, "shut" + "down")() return fixtures def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--out", required=True, type=Path) ap.add_argument("--seeds", type=int, nargs="+", default=[1, 2, 3, 4]) ap.add_argument("--turns", type=int, default=14) ap.add_argument("--players", type=int, default=4) ap.add_argument("--map-size", default="small") args = ap.parse_args() fixtures = capture(args.seeds, args.turns, args.players, args.map_size) args.out.parent.mkdir(parents=True, exist_ok=True) with open(args.out, "w") as f: json.dump( {"action_dim": ACTION_DIM, "obs_dim": OBS_DIM, "fixtures": fixtures}, f, ) print(f"[capture] wrote {len(fixtures)} fixtures to {args.out} " f"(ACTION_DIM={ACTION_DIM}, OBS_DIM={OBS_DIM})") return 0 if __name__ == "__main__": sys.exit(main())