magicciv/tooling/rl_self_play/capture_encoder_fixtures.py
Natalie 00e98329fa feat(@projects/@magic-civilization): update objectives dashboard and climate integration
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-06-09 01:07:07 -07:00

130 lines
5 KiB
Python

"""Capture model-free encoder-parity fixtures for the mp-v1 Rust↔Python check.
Phase 2a (obs normalization) changes the encoder contract — the Rust
`encode_observation` (`mc-player-api/src/learned/encoder.rs`) and the Python
`encode_observation` (`encoders.py`) must produce byte-equivalent normalized
observations. This script records `{view, obs, mask}` from REAL PlayerViews
driven through the live harness via the scripted `suggest` chain, so the Rust
parity test (`tests/learned_parity.rs::learned_encoder_parity`) can assert
equality WITHOUT a trained policy (which does not exist until Phase 2c).
This differs from `_export_onnx_p1_29f.py::capture_fixtures`, which needs the
SB3 model to record logits/argmax. Here we only need the encoder, so the
games advance purely on the scripted `suggest()` chain (both slots), exactly
like `_export_onnx_p1_29f.py::_advance_slot`.
Run on apricot (needs the harness binary + the rl_self_play package):

    python3 -m tooling.rl_self_play.capture_encoder_fixtures \
        --out src/simulator/crates/mc-player-api/tests/fixtures/learned_mp_v1_encoder_parity.json \
        --seeds 1 2 3 --turns 14 --players 4 --map-size small
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import numpy as np
# Directory containing this script, and the project root two levels above it
# (the directory that holds the top-level `tooling` package).
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parents[1]
# `__package__` is None when the file is executed directly as a script and ""
# when it is imported as a top-level module. In both cases the absolute
# `tooling.rl_self_play.…` imports that follow only resolve if the project
# root is on sys.path. Under `python -m tooling.rl_self_play.…` the package
# name is set and nothing is inserted.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.encoders import ( # noqa: E402
ACTION_DIM,
OBS_DIM,
encode_legal_actions,
encode_observation,
)
from tooling.rl_self_play.harness_client import ( # noqa: E402
HarnessClient,
HarnessConfig,
)
def _advance_slot(client: HarnessClient, slot: int) -> None:
"""Drive one slot through a full turn via the scripted suggest chain."""
try:
for a in client.suggest(slot=slot):
t = a.get("type")
try:
if t == "end_turn":
client.end_turn(slot=slot)
else:
client.act(a, slot=slot)
except Exception: # noqa: BLE001
break
client.end_turn(slot=slot)
except Exception: # noqa: BLE001
pass
client.drain_notifications()
def capture(seeds: list[int], turns: int, players: int, map_size: str) -> list[dict]:
    """Record ``{view, obs, mask}`` encoder fixtures from live harness games.

    One harness game is started per seed, with slots ``0..players-1`` all
    passed as ``player_slots`` on a "continents" map in "domination" victory
    mode. For each of ``turns`` iterations the function snapshots slot 0's
    view, encodes it with ``encode_observation`` / ``encode_legal_actions``,
    appends one fixture, and then advances every slot through
    ``_advance_slot`` so later snapshots come from a progressed game.

    Args:
        seeds: Game seeds; one game is played per seed.
        turns: Number of snapshots (and advance rounds) per game.
        players: Number of slots in each game.
        map_size: Map size string forwarded to ``HarnessConfig``.

    Returns:
        A flat list of fixture dicts with keys ``seed``, ``turn``,
        ``players``, ``view``, ``obs`` (list of float) and ``mask`` (list of
        bool), in capture order.

    Raises:
        AssertionError: If any encoded observation has a peak absolute value
            that is not strictly below 20 (including NaN), i.e. the encoder
            does not appear to apply asinh normalization.
    """
    slots = tuple(range(players))
    fixtures: list[dict] = []
    for seed in seeds:
        cfg = HarnessConfig(
            seed=seed, players=players, player_slots=slots,
            map_size=map_size, map_type="continents", victory_mode="domination",
        )
        client = HarnessClient(cfg)
        try:
            for _turn in range(turns):
                view = client.view(slot=0)
                mask, _ = encode_legal_actions(view)
                obs = encode_observation(view)
                # Guard against capturing UN-normalized obs from a stale
                # encoder: asinh compresses any plausible raw magnitude well
                # below 20 (asinh(1e6) ~= 14.5), so a raw score_estimate=240
                # would trip this immediately. This makes "the committed
                # fixtures are actually normalized" a loud failure, not a
                # silent Rust-parity mismatch downstream.
                #
                # Raised explicitly instead of via `assert` so the guard
                # survives `python -O`; phrased as `not peak < 20.0` so a NaN
                # peak fails the check too.
                peak = float(np.max(np.abs(obs)))
                if not peak < 20.0:
                    raise AssertionError(
                        f"obs magnitude {peak:.1f} >= 20 — encoder is NOT applying "
                        f"asinh normalization (stale encoders.py?). seed={seed} "
                        f"turn={view.get('turn')}"
                    )
                fixtures.append({
                    "seed": seed,
                    "turn": int(view.get("turn", 0)),
                    "players": players,
                    "view": view,
                    "obs": [float(x) for x in obs],
                    "mask": [bool(b) for b in mask],
                })
                # Advance every slot to progress the game into mid-game
                # magnitudes (the distribution the normalization targets).
                for s in slots:
                    _advance_slot(client, s)
        finally:
            # Always release the harness, even when the guard above fires.
            client.shutdown()
    return fixtures
def main(argv: list[str] | None = None) -> int:
    """Parse CLI arguments, capture fixtures and write them as one JSON file.

    The output document has the shape
    ``{"action_dim": ACTION_DIM, "obs_dim": OBS_DIM, "fixtures": [...]}``.
    Parent directories of ``--out`` are created as needed.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.

    Returns:
        ``0`` on success, for use as the process exit status.

    Raises:
        SystemExit: From argparse on invalid or missing arguments.
        ValueError: If the captured data contains NaN or infinite floats.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--seeds", type=int, nargs="+", default=[1, 2, 3, 4])
    ap.add_argument("--turns", type=int, default=14)
    ap.add_argument("--players", type=int, default=4)
    ap.add_argument("--map-size", default="small")
    args = ap.parse_args(argv)
    fixtures = capture(args.seeds, args.turns, args.players, args.map_size)
    # Serialize fully before touching the filesystem so a serialization error
    # cannot leave a truncated fixture file behind. allow_nan=False rejects
    # NaN/Infinity, which are not valid JSON and would otherwise be written
    # as non-standard tokens.
    payload = json.dumps(
        {"action_dim": ACTION_DIM, "obs_dim": OBS_DIM, "fixtures": fixtures},
        allow_nan=False,
    )
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(payload, encoding="utf-8")
    print(f"[capture] wrote {len(fixtures)} fixtures to {args.out} "
          f"(ACTION_DIM={ACTION_DIM}, OBS_DIM={OBS_DIM})")
    return 0
# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())