diff --git a/.project/designs/obs-contract.md b/.project/designs/obs-contract.md new file mode 100644 index 00000000..389d679f --- /dev/null +++ b/.project/designs/obs-contract.md @@ -0,0 +1,82 @@ +# Observation Encoder Contract — single source of truth (Python ⇄ Rust) + +> Owner ask (2026-06-30): one shared contract — schema + versioning + verification +> scripting — that Python (`tooling/rl_self_play/encoders.py`) and Rust +> (`mc-player-api/src/learned/encoder.rs`) both derive from, instead of two +> hand-maintained encoders kept in sync by luck + a parity test. Supersedes the +> hand-rewrite approach in `.project/designs/richer-encoder-spec.md` (that spec's +> 96-dim field list becomes the v2 schema *data*, not dual hand-code). + +## Problem + +Today the observation encoder is implemented **twice** — `encoders.py::encode_observation` +and `encoder.rs::encode_observation` — and kept bit-exact only by `learned_parity.rs` +against captured fixtures. Every field change is a dual edit that silently breaks +parity if the two drift. Growing 32→96 dims this way is 96 chances to diverge. + +## Design: schema-driven encoder + +**1. The schema is the source of truth.** A versioned JSON, shipped as a data pack: +`public/games/age-of-dwarves/data/ai/obs_schema.json`. It declares `version`, +`obs_dim`, the uniform `normalize` ("asinh"), `clan_order`, the action layout +(`max_units`, `per_unit_actions`, `max_cities`, `city_queue_items`), and a `fields` +list. Each field is `{i, op, ...args}`. Unlisted indices are 0 (then normalized). + +**2. 
Op vocabulary (fixed, interpretable in both languages).** + +| op | args | meaning | +|---|---|---| +| `scalar` | `path` | float at a dot-path from the view root (`resources.gold`, `turn`) | +| `reduce` | `list`, `agg`∈{sum,avg,count}, opt `select` (subfield), opt `where` (filter), opt `contains` ({field,needle}) | aggregate a list-valued path; `avg` of empty = 0 | +| `clamp_div` | `path`, `divisor`, `max` | `min(max, path/divisor)` (turn-progress signal) | +| `onehot` (v2) | `path` (int), `size` | 1.0 at `i+value` for value in `0..size`; all-0 if value <0 | +| `frac` (v2) | `num`, `den` | `num/den`, 0 if `den==0` (army-health, food-progress, siege) | +| `histogram` (v2) | `list`, `field`, `vocab[]` | per-bucket counts over a fixed vocab (+ "other") | +| `per_entity` (v2) | `list`, `k`, `subfields[]` | k entity blocks (per-city/unit), each a sub-field list | + +`where` predicates: `{field, eq}` (eq may be the sentinel `"$me"` = `view.player`), +`{field, truthy:true}`, `{field, contains}`. A tiny, total, side-effect-free language — +NOT arbitrary code — so both interpreters are ~120 lines and provably equivalent. + +**3. Both sides are thin interpreters.** +- Python: `encoders.py` loads the schema once and `encode_observation(view)` walks + `fields` applying ops → `np.float32[obs_dim]` → asinh. `OBS_DIM` etc. come FROM the + schema (no module constants to drift). +- Rust: `encoder.rs` embeds the schema via `include_str!` (compile-time, no runtime + file dependency for the engine) and the same interpreter → `[f32; obs_dim]`. The + action layout (`CITY_QUEUE_ITEMS`, `MAX_UNITS`, …) also reads from the schema. + +**4. Versioning.** `schema.version` is asserted equal on both sides at load. It is +written into the ONNX export sidecar and the `LearnedController::ident()` version +string, and checked before a model runs — a model trained on v1 obs cannot be served +against a v2 schema. Bumping the schema (e.g. 
32→96) is a deliberate version bump that +retires old artifacts cleanly (the contract already broke once for normalization). + +**5. Verification scripting (the shared gate).** `scripts/verify-obs-contract.sh`: +1. `python -m tooling.rl_self_play.verify_obs_contract` — schema well-formed; Python + `obs_dim` consistent with `fields` (max index `i` < `obs_dim`); encodes the fixture views and + re-emits `learned_mp_v1_encoder_parity.json` from the schema. +2. `cargo test -p mc-player-api learned_parity` — Rust interpreter on the same fixtures + == the Python-emitted obs within float tolerance (≤1e-4), mask bit-exact, **and** asserts the + Rust schema `version`/`obs_dim` match the JSON. +Exit 0 only if schema + Python + Rust all agree. Runs in CI (Python step) + on the +fleet (cargo step, no local Rust toolchain — see `mc-no-local-rust-toolchain`). + +## Migration order (parity stays green at every step) + +1. Write `obs_schema.json` v1 — encodes the **current 32-dim** layout EXACTLY. +2. Refactor `encoders.py` to interpret it; assert it reproduces the existing fixtures + byte-for-byte (behavior-preserving — proves the interpreter). +3. Refactor `encoder.rs` to interpret the embedded schema; `learned_parity` stays green + on the **unchanged** fixtures (proves cross-language equivalence of the interpreter). +4. Add `verify-obs-contract.sh` + the version assertions. **← contract proven here.** +5. THEN bump to v2: extend `fields` with the richer channels (tech/territory/army/ + terrain/civics + clan one-hot) + new ops; regenerate fixtures once; both interpreters + pick it up with **zero field-math edits**. 32→96 becomes a data change. + +## Why interpreter, not codegen + +Codegen (schema → generated .py/.rs) gives one source too, but adds a generator + +build step + generated-code review surface. The op vocabulary is small and the obs is +tiny (≤128 floats, encoded once per action), so a runtime interpreter's cost is +negligible and there is no build-step/regen drift.
Interpreter wins for this size. diff --git a/public/games/age-of-dwarves/data/ai/obs_schema.json b/public/games/age-of-dwarves/data/ai/obs_schema.json new file mode 100644 index 00000000..fb7fb894 --- /dev/null +++ b/public/games/age-of-dwarves/data/ai/obs_schema.json @@ -0,0 +1,39 @@ +{ + "$comment": "SINGLE SOURCE OF TRUTH for the RL observation + action contract. Python (tooling/rl_self_play/encoders.py) and Rust (mc-player-api/src/learned/encoder.rs) both interpret this; neither hardcodes the layout. See .project/designs/obs-contract.md. Paths are JSON WIRE keys of PlayerView (note serde rename: UnitView.type_id -> \"type\"). v1 reproduces the historical 32-dim encoder byte-for-byte.", + "version": 1, + "obs_dim": 32, + "normalize": "asinh", + "clan_order": ["ironhold", "goldvein", "blackhammer", "deepforge", "tinkersmith", "runesmith"], + "action": { + "max_units": 16, + "per_unit_actions": 16, + "max_cities": 4, + "city_queue_items": [ + "worker", "warrior", "library", "barracks", "forge", + "walls", "longhouse", "monument", "dwarf_warrior", "dwarf_founder", + "spearmen", "archer", "temple", "high_guild_hall", "chronicle_tower", + "mead_hall" + ] + }, + "fields": [ + {"i": 0, "op": "scalar", "path": "resources.gold"}, + {"i": 1, "op": "scalar", "path": "resources.gold_per_turn"}, + {"i": 2, "op": "scalar", "path": "resources.science_per_turn"}, + {"i": 3, "op": "scalar", "path": "score.score_estimate"}, + {"i": 4, "op": "scalar", "path": "score.city_count"}, + {"i": 5, "op": "scalar", "path": "score.unit_count"}, + {"i": 6, "op": "scalar", "path": "resources.happiness_pool"}, + {"i": 7, "op": "scalar", "path": "resources.culture_per_turn"}, + {"i": 8, "op": "reduce", "agg": "sum", "list": "cities", "select": "yields.food"}, + {"i": 9, "op": "reduce", "agg": "sum", "list": "cities", "select": "yields.production"}, + {"i": 10, "op": "reduce", "agg": "avg", "list": "cities", "select": "population"}, + {"i": 11, "op": "reduce", "agg": "count", "list": "units", 
"where": {"field": "owner", "eq": "$me"}, "contains": {"field": "type", "needle": "warrior"}}, + {"i": 12, "op": "reduce", "agg": "count", "list": "units", "where": {"field": "owner", "eq": "$me"}, "contains": {"field": "type", "needle": "founder"}}, + {"i": 16, "op": "reduce", "agg": "count", "list": "diplomacy"}, + {"i": 17, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "relation", "eq": "war"}}, + {"i": 18, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "relation", "eq": "peace"}}, + {"i": 19, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "open_borders", "truthy": true}}, + {"i": 24, "op": "scalar", "path": "turn"}, + {"i": 25, "op": "clamp_div", "path": "turn", "divisor": 1000.0, "max": 1.0} + ] +} diff --git a/tooling/rl_self_play/encoders.py b/tooling/rl_self_play/encoders.py index e787501c..d2d60087 100644 --- a/tooling/rl_self_play/encoders.py +++ b/tooling/rl_self_play/encoders.py @@ -25,22 +25,18 @@ from typing import Any import numpy as np +try: + from . import obs_contract +except ImportError: # flat import when the dir is on sys.path (harness/tests) + import obs_contract + # ── Observation shape ──────────────────────────────────────────────── -# The fixed-length observation vector has three blocks: -# [0:8] self resources + score (gold, gold_per_turn, sci_per_turn, -# score_estimate, city_count, unit_count, -# happiness_pool, culture_per_turn) -# [8:16] self per-turn yields summed across cities (food, production, -# science, gold, culture) -# + (avg city pop, total mil units, -# total founder units) -# [16:24] opponent intel snapshot (opponent count seen, # at war, -# # at peace, # open_borders, ...) -# padded to 8 floats -# [24:32] turn counters (turn number, fraction of game elapsed, -# # cities lost, # cities captured, -# ... 
pad to 8) -OBS_DIM = 32 +# The observation layout is the SHARED CONTRACT in +# `public/games/age-of-dwarves/data/ai/obs_schema.json`, interpreted by +# `obs_contract` (Python) and `mc-player-api/src/learned/encoder.rs` (Rust). +# OBS_DIM and the per-field math live in the schema, not here — see +# `.project/designs/obs-contract.md`. +OBS_DIM = obs_contract.obs_dim() # ── Action index layout ────────────────────────────────────────────── # We bucket legal actions deterministically: @@ -106,73 +102,14 @@ def _hex_direction(from_pos: tuple[int, int], to_pos: tuple[int, int]) -> int | def encode_observation(view: dict[str, Any]) -> np.ndarray: - """Project a PlayerView dict into a fixed-shape float32 vector.""" - obs = np.zeros(OBS_DIM, dtype=np.float32) - res = view.get("resources", {}) - score = view.get("score", {}) - obs[0] = float(res.get("gold", 0.0)) - obs[1] = float(res.get("gold_per_turn", 0.0)) - obs[2] = float(res.get("science_per_turn", 0.0)) - obs[3] = float(score.get("score_estimate", 0.0)) - obs[4] = float(score.get("city_count", 0.0)) - obs[5] = float(score.get("unit_count", 0.0)) - obs[6] = float(res.get("happiness_pool", 0.0)) - obs[7] = float(res.get("culture_per_turn", 0.0)) + """Project a PlayerView dict into the shared-contract observation vector. 
- cities = view.get("cities", []) - if cities: - food = sum(float(c.get("yields", {}).get("food", 0)) for c in cities) - prod = sum(float(c.get("yields", {}).get("production", 0)) for c in cities) - obs[8] = food - obs[9] = prod - obs[10] = sum(float(c.get("population", 0)) for c in cities) / len(cities) - - units = view.get("units", []) - me = int(view.get("player", 0)) - my_units = [u for u in units if int(u.get("owner", -1)) == me] - obs[11] = float(sum(1 for u in my_units if "warrior" in str(u.get("type", "")))) - obs[12] = float(sum(1 for u in my_units if "founder" in str(u.get("type", "")))) - - diplo = view.get("diplomacy", []) - obs[16] = float(len(diplo)) - obs[17] = float(sum(1 for d in diplo if d.get("relation") == "war")) - obs[18] = float(sum(1 for d in diplo if d.get("relation") == "peace")) - obs[19] = float(sum(1 for d in diplo if d.get("open_borders"))) - - obs[24] = float(view.get("turn", 0)) - # Bound turn at 1000 (Stage 6.1.5 max_turns) for a rough [0,1] progress signal. - obs[25] = min(1.0, float(view.get("turn", 0)) / 1000.0) - return _normalize_obs(obs) - - -# ── Observation normalization (mp-v1) ──────────────────────────────── -# duel-v4 fed RAW magnitudes (score 0-1000, raw gold/yields/counts) into a -# bare MLP with no VecNormalize. An MLP whose hidden activations were fit on -# duel-early magnitudes (score~35, 4 units) saw mid-game magnitudes -# (score~240, 13 units) as out-of-distribution — the root cause of the -# state-magnitude generalization failure (diagnosed via the p1-29k probe). -# mp-v1 applies a fixed, signed, large-dynamic-range compression so -# competence transfers across magnitude ranges including unseen ones. -# -# Transform: asinh (inverse hyperbolic sine), applied UNIFORMLY to all 32 -# dims. 
Properties that make it the right choice here: -# * defined on ALL reals (unlike log1p, which NaNs on negatives) — dims -# that can plausibly go negative (gold_per_turn deficit, happiness_pool) -# are safe with NO per-dim sign audit, and stay safe as the multiplayer -# distribution populates dims that were zero in bench-v1; -# * asinh(0) == 0 — unused / zero dims and the already-[0,1] progress dim -# (obs[25]) pass through ~unchanged (near-linear for |x| < 1); -# * log-compresses large |x| — score 240 -> ~6.2, 35 -> ~4.3, so the -# mid-game/duel-early gap collapses from 205 to ~1.9. -# -# Parity contract (p1-29f): the Rust encoder -# (`mc-player-api/src/learned/encoder.rs::normalize_obs`) MUST apply the -# byte-equivalent transform. Both sides compute asinh in f64 then cast to -# f32 — f64 asinh agrees cross-implementation to ~1 ULP, well inside the -# parity test's 1e-4 obs tolerance. Do NOT hand-roll ln(x + sqrt(x^2 + 1)): -# f32 intermediates diverge across implementations. -def _normalize_obs(obs: np.ndarray) -> np.ndarray: - return np.arcsinh(obs.astype(np.float64)).astype(np.float32) + Delegates to the schema interpreter (`obs_contract`) — the single source of + truth shared with the Rust encoder. The asinh normalization, field layout, + and OBS_DIM all live in `obs_schema.json`. See + `.project/designs/obs-contract.md`. + """ + return obs_contract.encode_observation(view) def _unit_action_offset(unit_slot: int, sub: int) -> int: diff --git a/tooling/rl_self_play/obs_contract.py b/tooling/rl_self_play/obs_contract.py new file mode 100644 index 00000000..948c4726 --- /dev/null +++ b/tooling/rl_self_play/obs_contract.py @@ -0,0 +1,112 @@ +"""Schema-driven observation encoder — the Python half of the shared contract. + +The observation layout is NOT hardcoded here; it is read from the single source +of truth `public/games/age-of-dwarves/data/ai/obs_schema.json`. The Rust half +(`mc-player-api/src/learned/encoder.rs`) interprets the same schema. 
A field's +math lives in exactly one place — the op vocabulary below — on each side, and +the verification gate asserts the two agree (obs within 1e-4, mask bit-exact). + +See `.project/designs/obs-contract.md`. Paths are PlayerView JSON WIRE keys +(note serde rename: `UnitView.type_id` -> `"type"`). +""" +from __future__ import annotations + +import json +import os +from functools import lru_cache +from typing import Any + +import numpy as np + +# Resolve the schema relative to the repo root (this file is at +# /tooling/rl_self_play/obs_contract.py). +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +_SCHEMA_REL = "public/games/age-of-dwarves/data/ai/obs_schema.json" + + +@lru_cache(maxsize=4) +def load_schema(path: str | None = None) -> dict[str, Any]: + """Load + lightly validate the obs schema. Cached by the `path` argument, so a later change to `MC_OBS_SCHEMA` is not seen once `path=None` is cached.""" + p = path or os.environ.get("MC_OBS_SCHEMA") or os.path.join(_REPO_ROOT, _SCHEMA_REL) + with open(p, encoding="utf-8") as fh: + schema = json.load(fh) + obs_dim = int(schema["obs_dim"]) + max_i = max((int(f["i"]) for f in schema["fields"]), default=-1) + if max_i >= obs_dim: + raise ValueError(f"obs schema: field index {max_i} >= obs_dim {obs_dim}") + if schema.get("normalize") not in ("asinh", None): + raise ValueError(f"obs schema: unsupported normalize {schema.get('normalize')!r}") + return schema + + +def _resolve(obj: Any, path: str) -> Any: + """Walk a dot-path through nested dicts. 
Returns None on any miss.""" + cur = obj + for part in path.split("."): + if isinstance(cur, dict): + cur = cur.get(part) + else: + return None + return cur + + +def _matches(item: dict[str, Any], where: dict[str, Any] | None, me: Any) -> bool: + if not where: + return True + f = item.get(where["field"]) + if "eq" in where: + target = where["eq"] + if target == "$me": + target = me + return f == target + if where.get("truthy"): + return bool(f) + return True + + +def _apply_field(view: dict[str, Any], fld: dict[str, Any], me: Any) -> float: + op = fld["op"] + if op == "scalar": + v = _resolve(view, fld["path"]) + return float(v) if v is not None else 0.0 + if op == "clamp_div": + v = _resolve(view, fld["path"]) + v = float(v) if v is not None else 0.0 + return float(min(float(fld["max"]), v / float(fld["divisor"]))) + if op == "reduce": + items = view.get(fld["list"], []) or [] + where = fld.get("where") + filt = [it for it in items if isinstance(it, dict) and _matches(it, where, me)] + contains = fld.get("contains") + if contains is not None: + needle, field = contains["needle"], contains["field"] + return float(sum(1 for it in filt if needle in str(it.get(field, "")))) + agg = fld["agg"] + if agg == "count": + return float(len(filt)) + sel = fld.get("select") + vals = [float(_resolve(it, sel) or 0) if sel else 1.0 for it in filt] + if agg == "sum": + return float(sum(vals)) + if agg == "avg": + return float(sum(vals) / len(vals)) if vals else 0.0 + raise ValueError(f"obs schema: unknown agg {agg!r}") + raise ValueError(f"obs schema: unknown op {op!r}") + + +def obs_dim(path: str | None = None) -> int: + return int(load_schema(path)["obs_dim"]) + + +def encode_observation(view: dict[str, Any], path: str | None = None) -> np.ndarray: + """Project a PlayerView dict into the schema's fixed-shape float32 obs.""" + schema = load_schema(path) + dim = int(schema["obs_dim"]) + obs = np.zeros(dim, dtype=np.float32) + me = int(view.get("player", 0)) + for fld in 
schema["fields"]: + obs[int(fld["i"])] = _apply_field(view, fld, me) + if schema.get("normalize") == "asinh": + # asinh in f64, then cast to f32; agrees with the Rust side to within ~1 ULP. + obs = np.arcsinh(obs.astype(np.float64)).astype(np.float32) + return obs
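
Reviewer note: the schema-driven flow in this diff (load the schema, walk `fields`, apply one op per index, then asinh) can be tried in isolation with the sketch below. It is a minimal illustration under stated assumptions, not the code in `obs_contract.py`: `MINI_SCHEMA`, `resolve`, `apply_field`, and `encode` are hypothetical names, the schema is a made-up four-slot layout, only `scalar`, `clamp_div`, and `reduce`/`count` are covered, and plain Python floats stand in for `np.float32`.

```python
from __future__ import annotations

import math
from typing import Any

# Hypothetical four-slot schema for illustration only (NOT the real obs_schema.json).
MINI_SCHEMA: dict[str, Any] = {
    "version": 1,
    "obs_dim": 4,
    "normalize": "asinh",
    "fields": [
        {"i": 0, "op": "scalar", "path": "resources.gold"},
        {"i": 1, "op": "reduce", "agg": "count", "list": "units",
         "where": {"field": "owner", "eq": "$me"}},
        # Index 2 is intentionally unlisted: it stays 0.
        {"i": 3, "op": "clamp_div", "path": "turn", "divisor": 1000.0, "max": 1.0},
    ],
}


def resolve(obj: Any, path: str) -> Any:
    """Walk a dot-path through nested dicts; return None on any miss."""
    cur = obj
    for part in path.split("."):
        if not isinstance(cur, dict):
            return None
        cur = cur.get(part)
    return cur


def apply_field(view: dict[str, Any], fld: dict[str, Any], me: int) -> float:
    """Evaluate one schema field against the view (subset of the op vocabulary)."""
    op = fld["op"]
    if op == "scalar":
        return float(resolve(view, fld["path"]) or 0.0)
    if op == "clamp_div":
        value = float(resolve(view, fld["path"]) or 0.0)
        return min(float(fld["max"]), value / float(fld["divisor"]))
    if op == "reduce" and fld["agg"] == "count":
        where = fld.get("where")
        items = view.get(fld["list"]) or []

        def keep(item: dict[str, Any]) -> bool:
            if not where:
                return True
            # "$me" is the sentinel for the viewing player.
            target = me if where["eq"] == "$me" else where["eq"]
            return item.get(where["field"]) == target

        return float(sum(1 for item in items if keep(item)))
    raise ValueError(f"unsupported op in this sketch: {op!r}")


def encode(view: dict[str, Any], schema: dict[str, Any] = MINI_SCHEMA) -> list[float]:
    """Fill a zero vector from the schema's fields, then apply asinh uniformly."""
    obs = [0.0] * int(schema["obs_dim"])
    me = int(view.get("player", 0))
    for fld in schema["fields"]:
        obs[int(fld["i"])] = apply_field(view, fld, me)
    if schema.get("normalize") == "asinh":
        obs = [math.asinh(x) for x in obs]
    return obs


view = {
    "player": 1,
    "turn": 250,
    "resources": {"gold": 35},
    "units": [{"owner": 1}, {"owner": 2}, {"owner": 1}],
}
obs = encode(view)
print(obs)
```

Adding a channel in this sketch means appending an entry to `fields` rather than editing `encode`, which is the property the design doc relies on for the 32→96 bump. An empty view encodes to all zeros because every miss resolves to 0 and `asinh(0) == 0`.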