feat(ai): shared obs encoder contract — schema as single source of truth (Python side)

Replace the hand-duplicated observation encoder with a schema-driven contract:
obs_schema.json declares the layout (version, obs_dim, per-field ops from a fixed
vocabulary: scalar/reduce/clamp_div, +onehot/frac/histogram/per_entity for v2),
and both Python and Rust interpret it instead of hardcoding the math. Kills the
bit-exact-drift risk that made growing 32->96 dims dangerous.

This commit lands the Python half + the v1 schema (reproduces the historical
32-dim encoder EXACTLY): obs_contract.py interprets the schema; encoders.py
delegates to it (OBS_DIM + field math now come from the schema, not module code).
Verified locally: encoders.encode_observation matches all 56 parity fixtures with
ZERO drift. Design: .project/designs/obs-contract.md.

Next: Rust interpreter (encoder.rs reads the embedded schema), verify-obs-contract
gate + version assertions, then bump to v2 (richer 96-dim) as a schema data change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Natalie 2026-06-30 11:49:04 -04:00
parent bbf56c7ab7
commit b67764ec67
4 changed files with 251 additions and 81 deletions


@ -0,0 +1,82 @@
# Observation Encoder Contract — single source of truth (Python ⇄ Rust)
> Owner ask (2026-06-30): one shared contract — schema + versioning + verification
> scripting — that Python (`tooling/rl_self_play/encoders.py`) and Rust
> (`mc-player-api/src/learned/encoder.rs`) both derive from, instead of two
> hand-maintained encoders kept in sync by luck + a parity test. Supersedes the
> hand-rewrite approach in `.project/designs/richer-encoder-spec.md` (that spec's
> 96-dim field list becomes the v2 schema *data*, not dual hand-code).
## Problem
Today the observation encoder is implemented **twice**, in `encoders.py::encode_observation`
and `encoder.rs::encode_observation`, and kept bit-exact only by `learned_parity.rs`
against captured fixtures. Every field change is a dual edit that silently breaks
parity if the two drift. Growing 32→96 dims this way is 96 chances to diverge.
## Design: schema-driven encoder
**1. The schema is the source of truth.** A versioned JSON, shipped as a data pack:
`public/games/age-of-dwarves/data/ai/obs_schema.json`. It declares `version`,
`obs_dim`, the uniform `normalize` ("asinh"), `clan_order`, the action layout
(`max_units`, `per_unit_actions`, `max_cities`, `city_queue_items`), and a `fields`
list. Each field is `{i, op, ...args}`. Unlisted indices are 0 (then normalized).
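
As a rough illustration of the layout just described, the sketch below builds a cut-down schema fragment in memory and checks that every `i` falls inside `obs_dim`. The helper name `check_schema` and the duplicate-index rule are assumptions made for this example; the field values are illustrative, not the shipped v1 schema.

```python
from typing import Any


def check_schema(schema: dict[str, Any]) -> list[int]:
    """Return the sorted populated indices, or raise on a bad layout.

    Hypothetical helper: rejecting duplicate indices is an assumption,
    not something the design text requires.
    """
    obs_dim = int(schema["obs_dim"])
    seen: set[int] = set()
    for fld in schema["fields"]:
        i = int(fld["i"])
        if not 0 <= i < obs_dim:
            raise ValueError(f"field index {i} outside [0, {obs_dim})")
        if i in seen:
            raise ValueError(f"field index {i} declared twice")
        seen.add(i)
    return sorted(seen)


# A fragment in the shape of obs_schema.json (values are illustrative).
schema = {
    "version": 1,
    "obs_dim": 8,
    "normalize": "asinh",
    "fields": [
        {"i": 0, "op": "scalar", "path": "resources.gold"},
        {"i": 1, "op": "scalar", "path": "turn"},
        {"i": 5, "op": "clamp_div", "path": "turn", "divisor": 1000.0, "max": 1.0},
    ],
}

populated = check_schema(schema)
unlisted = [i for i in range(schema["obs_dim"]) if i not in populated]
print(populated)  # indices written by a field
print(unlisted)   # indices left at 0 before normalization
```

Because asinh(0) is 0, the unlisted indices stay at zero after normalization as well.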
**2. Op vocabulary (fixed, interpretable in both languages).**
| op | args | meaning |
|---|---|---|
| `scalar` | `path` | float at a dot-path from the view root (`resources.gold`, `turn`) |
| `reduce` | `list`, `agg`∈{sum,avg,count}, opt `select` (subfield), opt `where` (filter), opt `contains` ({field,needle}) | aggregate a list-valued path; `avg` of empty = 0 |
| `clamp_div` | `path`, `divisor`, `max` | `min(max, path/divisor)` (turn-progress signal) |
| `onehot` (v2) | `path` (int), `size` | 1.0 at `i+value` for value in `0..size`; all-0 if value <0 |
| `frac` (v2) | `num`, `den` | `num/den`, 0 if `den==0` (army-health, food-progress, siege) |
| `histogram` (v2) | `list`, `field`, `vocab[]` | per-bucket counts over a fixed vocab (+ "other") |
| `per_entity` (v2) | `list`, `k`, `subfields[]` | k entity blocks (per-city/unit), each a sub-field list |
`where` predicates: `{field, eq}` (eq may be the sentinel `"$me"` = `view.player`),
`{field, truthy:true}`, `{field, contains}`. A tiny, total, side-effect-free language —
NOT arbitrary code — so both interpreters are ~120 lines and provably equivalent.
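
To make the table concrete, here is a minimal sketch of how two of the v2 ops might be interpreted in Python. These ops are planned rather than shipped, so the function names and the handling of `value >= size` are assumptions; only the `value < 0` and `den == 0` cases come from the table.

```python
def onehot(value: int, size: int) -> list[float]:
    """`onehot`: `size` slots with 1.0 at offset `value`.

    The caller would write the result starting at the field's base index `i`.
    """
    out = [0.0] * size
    # The table specifies all-zero for value < 0. Treating value >= size the
    # same way is an assumption of this sketch.
    if 0 <= value < size:
        out[value] = 1.0
    return out


def frac(num: float, den: float) -> float:
    """`frac`: num / den, defined as 0 when the denominator is 0."""
    return 0.0 if den == 0 else num / den


print(onehot(2, 6))
print(onehot(-1, 6))
print(frac(3.0, 4.0))
print(frac(3.0, 0.0))
```

Both functions are total: every input yields a value and nothing raises, which is the property that keeps the two interpreters comparable.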
**3. Both sides are thin interpreters.**
- Python: `encoders.py` loads the schema once and `encode_observation(view)` walks
`fields` applying ops → `np.float32[obs_dim]` → asinh. `OBS_DIM` etc. come FROM the
schema (no module constants to drift).
- Rust: `encoder.rs` embeds the schema via `include_str!` (compile-time, no runtime
file dependency for the engine) and the same interpreter → `[f32; obs_dim]`. The
action layout (`CITY_QUEUE_ITEMS`, `MAX_UNITS`, …) also reads from the schema.
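
The "thin interpreter" idea can be sketched in a few lines of standard-library Python. This version handles only `scalar` and `clamp_div`, uses plain lists instead of NumPy, and skips the float32 cast, so it is an illustration of the control flow rather than the encoder itself; `resolve` and `encode` are names chosen for the example.

```python
import math
from typing import Any


def resolve(obj: Any, path: str) -> Any:
    """Walk a dot-path through nested dicts; None on any miss."""
    cur = obj
    for part in path.split("."):
        if not isinstance(cur, dict):
            return None
        cur = cur.get(part)
    return cur


def encode(view: dict[str, Any], schema: dict[str, Any]) -> list[float]:
    """Zero-fill, apply each field's op at its index, then normalize."""
    obs = [0.0] * int(schema["obs_dim"])
    for fld in schema["fields"]:
        op = fld["op"]
        if op not in ("scalar", "clamp_div"):
            raise ValueError(f"unsupported op {op!r} in this sketch")
        raw = resolve(view, fld["path"])
        value = float(raw) if raw is not None else 0.0
        if op == "clamp_div":
            value = min(float(fld["max"]), value / float(fld["divisor"]))
        obs[int(fld["i"])] = value
    if schema.get("normalize") == "asinh":
        obs = [math.asinh(x) for x in obs]
    return obs


schema = {
    "obs_dim": 4,
    "normalize": "asinh",
    "fields": [
        {"i": 0, "op": "scalar", "path": "resources.gold"},
        {"i": 2, "op": "clamp_div", "path": "turn", "divisor": 1000.0, "max": 1.0},
    ],
}
view = {"resources": {"gold": 35}, "turn": 2500}
print(encode(view, schema))
```

Here turn 2500 clamps to 1.0 before normalization, and indices 1 and 3 stay at zero because no field claims them.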
**4. Versioning.** `schema.version` is asserted equal on both sides at load. It is
written into the ONNX export sidecar and the `LearnedController::ident()` version
string, and checked before a model runs — a model trained on v1 obs cannot be served
against a v2 schema. Bumping the schema (e.g. 32→96) is a deliberate version bump that
retires old artifacts cleanly (the contract already broke once for normalization).
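
One way the pre-run check could look, as a sketch only: the sidecar is modeled as a plain dict, and the key names `version` and `obs_dim` on the sidecar, the exception class, and `assert_compatible` are all hypothetical.

```python
class ObsSchemaMismatch(RuntimeError):
    """Raised when a model artifact and the loaded schema disagree."""


def assert_compatible(schema: dict, sidecar: dict) -> None:
    """Refuse to serve a model whose recorded contract differs from the schema."""
    for key in ("version", "obs_dim"):
        if schema[key] != sidecar.get(key):
            raise ObsSchemaMismatch(
                f"model sidecar {key}={sidecar.get(key)!r} "
                f"but schema {key}={schema[key]!r}"
            )


schema_v2 = {"version": 2, "obs_dim": 96}
model_v1_sidecar = {"version": 1, "obs_dim": 32}  # hypothetical sidecar contents

try:
    assert_compatible(schema_v2, model_v1_sidecar)
except ObsSchemaMismatch as err:
    print(f"refused: {err}")
```

Failing at load time keeps a v1 model from ever receiving a 96-dim observation.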
**5. Verification scripting (the shared gate).** `scripts/verify-obs-contract.sh`:
1. `python -m tooling.rl_self_play.verify_obs_contract`: checks that the schema is
well-formed and that `obs_dim` is consistent with the fields (max field index + 1 ≤
`obs_dim`); encodes the fixture views and re-emits
`learned_mp_v1_encoder_parity.json` from the schema.
2. `cargo test -p mc-player-api learned_parity`: runs the Rust interpreter on the same
fixtures and requires its obs to match the Python-emitted obs within 1e-4, the mask
to match bit-exactly, **and** the Rust schema `version`/`obs_dim` to match the JSON.
Exit 0 only if schema + Python + Rust all agree. Runs in CI (Python step) + on the
fleet (cargo step, no local Rust toolchain — see `mc-no-local-rust-toolchain`).
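
The comparison rule in the gate (obs within a tolerance, mask exact) can be sketched as below. The fixture file's actual structure is not shown here, so the inputs are plain lists and the function names are assumptions.

```python
OBS_TOLERANCE = 1e-4  # the obs tolerance quoted for the parity test


def obs_matches(expected: list[float], actual: list[float],
                tol: float = OBS_TOLERANCE) -> bool:
    """True when both vectors have the same length and agree element-wise within tol."""
    if len(expected) != len(actual):
        return False
    return all(abs(e - a) <= tol for e, a in zip(expected, actual))


def mask_matches(expected: list[bool], actual: list[bool]) -> bool:
    """Action masks get no tolerance: they must be identical."""
    return expected == actual


python_obs = [4.248699, 0.0, 0.881374]
rust_obs = [4.248700, 0.0, 0.881373]
print(obs_matches(python_obs, rust_obs))
print(mask_matches([True, False, True], [True, False, True]))
```

A length mismatch fails outright, which also catches the two sides disagreeing on `obs_dim`.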
## Migration order (parity stays green at every step)
1. Write `obs_schema.json` v1 — encodes the **current 32-dim** layout EXACTLY.
2. Refactor `encoders.py` to interpret it; assert it reproduces the existing fixtures
byte-for-byte (behavior-preserving — proves the interpreter).
3. Refactor `encoder.rs` to interpret the embedded schema; `learned_parity` stays green
on the **unchanged** fixtures (proves cross-language equivalence of the interpreter).
4. Add `verify-obs-contract.sh` + the version assertions. **← contract proven here.**
5. THEN bump to v2: extend `fields` with the richer channels (tech/territory/army/
terrain/civics + clan one-hot) + new ops; regenerate fixtures once; both interpreters
pick it up with **zero field-math edits**. 32→96 becomes a data change.
## Why interpreter, not codegen
Codegen (schema → generated .py/.rs) gives one source too, but adds a generator +
build step + generated-code review surface. The op vocabulary is small and the obs is
tiny (≤128 floats, encoded once per action), so a runtime interpreter's cost is
negligible and there is no build-step/regen drift. Interpreter wins for this size.


@ -0,0 +1,39 @@
{
"$comment": "SINGLE SOURCE OF TRUTH for the RL observation + action contract. Python (tooling/rl_self_play/encoders.py) and Rust (mc-player-api/src/learned/encoder.rs) both interpret this; neither hardcodes the layout. See .project/designs/obs-contract.md. Paths are JSON WIRE keys of PlayerView (note serde rename: UnitView.type_id -> \"type\"). v1 reproduces the historical 32-dim encoder byte-for-byte.",
"version": 1,
"obs_dim": 32,
"normalize": "asinh",
"clan_order": ["ironhold", "goldvein", "blackhammer", "deepforge", "tinkersmith", "runesmith"],
"action": {
"max_units": 16,
"per_unit_actions": 16,
"max_cities": 4,
"city_queue_items": [
"worker", "warrior", "library", "barracks", "forge",
"walls", "longhouse", "monument", "dwarf_warrior", "dwarf_founder",
"spearmen", "archer", "temple", "high_guild_hall", "chronicle_tower",
"mead_hall"
]
},
"fields": [
{"i": 0, "op": "scalar", "path": "resources.gold"},
{"i": 1, "op": "scalar", "path": "resources.gold_per_turn"},
{"i": 2, "op": "scalar", "path": "resources.science_per_turn"},
{"i": 3, "op": "scalar", "path": "score.score_estimate"},
{"i": 4, "op": "scalar", "path": "score.city_count"},
{"i": 5, "op": "scalar", "path": "score.unit_count"},
{"i": 6, "op": "scalar", "path": "resources.happiness_pool"},
{"i": 7, "op": "scalar", "path": "resources.culture_per_turn"},
{"i": 8, "op": "reduce", "agg": "sum", "list": "cities", "select": "yields.food"},
{"i": 9, "op": "reduce", "agg": "sum", "list": "cities", "select": "yields.production"},
{"i": 10, "op": "reduce", "agg": "avg", "list": "cities", "select": "population"},
{"i": 11, "op": "reduce", "agg": "count", "list": "units", "where": {"field": "owner", "eq": "$me"}, "contains": {"field": "type", "needle": "warrior"}},
{"i": 12, "op": "reduce", "agg": "count", "list": "units", "where": {"field": "owner", "eq": "$me"}, "contains": {"field": "type", "needle": "founder"}},
{"i": 16, "op": "reduce", "agg": "count", "list": "diplomacy"},
{"i": 17, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "relation", "eq": "war"}},
{"i": 18, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "relation", "eq": "peace"}},
{"i": 19, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "open_borders", "truthy": true}},
{"i": 24, "op": "scalar", "path": "turn"},
{"i": 25, "op": "clamp_div", "path": "turn", "divisor": 1000.0, "max": 1.0}
]
}


@ -25,22 +25,18 @@ from typing import Any
import numpy as np
try:
    from . import obs_contract
except ImportError:  # flat import when the dir is on sys.path (harness/tests)
    import obs_contract
# ── Observation shape ────────────────────────────────────────────────
# The fixed-length observation vector has three blocks:
# [0:8] self resources + score (gold, gold_per_turn, sci_per_turn,
# score_estimate, city_count, unit_count,
# happiness_pool, culture_per_turn)
# [8:16] self per-turn yields summed across cities (food, production,
# science, gold, culture)
# + (avg city pop, total mil units,
# total founder units)
# [16:24] opponent intel snapshot (opponent count seen, # at war,
# # at peace, # open_borders, ...)
# padded to 8 floats
# [24:32] turn counters (turn number, fraction of game elapsed,
# # cities lost, # cities captured,
# ... pad to 8)
OBS_DIM = 32
# The observation layout is the SHARED CONTRACT in
# `public/games/age-of-dwarves/data/ai/obs_schema.json`, interpreted by
# `obs_contract` (Python) and `mc-player-api/src/learned/encoder.rs` (Rust).
# OBS_DIM and the per-field math live in the schema, not here — see
# `.project/designs/obs-contract.md`.
OBS_DIM = obs_contract.obs_dim()
# ── Action index layout ──────────────────────────────────────────────
# We bucket legal actions deterministically:
@ -106,73 +102,14 @@ def _hex_direction(from_pos: tuple[int, int], to_pos: tuple[int, int]) -> int |
def encode_observation(view: dict[str, Any]) -> np.ndarray:
"""Project a PlayerView dict into a fixed-shape float32 vector."""
obs = np.zeros(OBS_DIM, dtype=np.float32)
res = view.get("resources", {})
score = view.get("score", {})
obs[0] = float(res.get("gold", 0.0))
obs[1] = float(res.get("gold_per_turn", 0.0))
obs[2] = float(res.get("science_per_turn", 0.0))
obs[3] = float(score.get("score_estimate", 0.0))
obs[4] = float(score.get("city_count", 0.0))
obs[5] = float(score.get("unit_count", 0.0))
obs[6] = float(res.get("happiness_pool", 0.0))
obs[7] = float(res.get("culture_per_turn", 0.0))
"""Project a PlayerView dict into the shared-contract observation vector.
cities = view.get("cities", [])
if cities:
food = sum(float(c.get("yields", {}).get("food", 0)) for c in cities)
prod = sum(float(c.get("yields", {}).get("production", 0)) for c in cities)
obs[8] = food
obs[9] = prod
obs[10] = sum(float(c.get("population", 0)) for c in cities) / len(cities)
units = view.get("units", [])
me = int(view.get("player", 0))
my_units = [u for u in units if int(u.get("owner", -1)) == me]
obs[11] = float(sum(1 for u in my_units if "warrior" in str(u.get("type", ""))))
obs[12] = float(sum(1 for u in my_units if "founder" in str(u.get("type", ""))))
diplo = view.get("diplomacy", [])
obs[16] = float(len(diplo))
obs[17] = float(sum(1 for d in diplo if d.get("relation") == "war"))
obs[18] = float(sum(1 for d in diplo if d.get("relation") == "peace"))
obs[19] = float(sum(1 for d in diplo if d.get("open_borders")))
obs[24] = float(view.get("turn", 0))
# Bound turn at 1000 (Stage 6.1.5 max_turns) for a rough [0,1] progress signal.
obs[25] = min(1.0, float(view.get("turn", 0)) / 1000.0)
return _normalize_obs(obs)
# ── Observation normalization (mp-v1) ────────────────────────────────
# duel-v4 fed RAW magnitudes (score 0-1000, raw gold/yields/counts) into a
# bare MLP with no VecNormalize. An MLP whose hidden activations were fit on
# duel-early magnitudes (score~35, 4 units) saw mid-game magnitudes
# (score~240, 13 units) as out-of-distribution — the root cause of the
# state-magnitude generalization failure (diagnosed via the p1-29k probe).
# mp-v1 applies a fixed, signed, large-dynamic-range compression so
# competence transfers across magnitude ranges including unseen ones.
#
# Transform: asinh (inverse hyperbolic sine), applied UNIFORMLY to all 32
# dims. Properties that make it the right choice here:
# * defined on ALL reals (unlike log1p, which NaNs on negatives) — dims
# that can plausibly go negative (gold_per_turn deficit, happiness_pool)
# are safe with NO per-dim sign audit, and stay safe as the multiplayer
# distribution populates dims that were zero in bench-v1;
# * asinh(0) == 0 — unused / zero dims and the already-[0,1] progress dim
# (obs[25]) pass through ~unchanged (near-linear for |x| < 1);
# * log-compresses large |x| — score 240 -> ~6.2, 35 -> ~4.3, so the
# mid-game/duel-early gap collapses from 205 to ~1.9.
#
# Parity contract (p1-29f): the Rust encoder
# (`mc-player-api/src/learned/encoder.rs::normalize_obs`) MUST apply the
# byte-equivalent transform. Both sides compute asinh in f64 then cast to
# f32 — f64 asinh agrees cross-implementation to ~1 ULP, well inside the
# parity test's 1e-4 obs tolerance. Do NOT hand-roll ln(x + sqrt(x^2 + 1)):
# f32 intermediates diverge across implementations.
def _normalize_obs(obs: np.ndarray) -> np.ndarray:
return np.arcsinh(obs.astype(np.float64)).astype(np.float32)
Delegates to the schema interpreter (`obs_contract`), which reads the single source of
truth shared with the Rust encoder. The asinh normalization, field layout,
and OBS_DIM all live in `obs_schema.json`. See
`.project/designs/obs-contract.md`.
"""
return obs_contract.encode_observation(view)
def _unit_action_offset(unit_slot: int, sub: int) -> int:


@ -0,0 +1,112 @@
"""Schema-driven observation encoder — the Python half of the shared contract.
The observation layout is NOT hardcoded here; it is read from the single source
of truth `public/games/age-of-dwarves/data/ai/obs_schema.json`. The Rust half
(`mc-player-api/src/learned/encoder.rs`) interprets the same schema. A field's
math lives in exactly one place (the op vocabulary below, on each side), and
the verification gate asserts the two produce byte-equal output.
See `.project/designs/obs-contract.md`. Paths are PlayerView JSON WIRE keys
(note serde rename: `UnitView.type_id` -> `"type"`).
"""
from __future__ import annotations
import json
import os
from functools import lru_cache
from typing import Any
import numpy as np
# Resolve the schema relative to the repo root (this file is at
# <repo>/tooling/rl_self_play/obs_contract.py).
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
_SCHEMA_REL = "public/games/age-of-dwarves/data/ai/obs_schema.json"
@lru_cache(maxsize=4)
def load_schema(path: str | None = None) -> dict[str, Any]:
    """Load + lightly validate the obs schema. Cached by path."""
    # The cache key is `path`, so with path=None the MC_OBS_SCHEMA override is
    # only consulted on the first call.
    p = path or os.environ.get("MC_OBS_SCHEMA") or os.path.join(_REPO_ROOT, _SCHEMA_REL)
    with open(p, encoding="utf-8") as fh:
        schema = json.load(fh)
    obs_dim = int(schema["obs_dim"])
    max_i = max((int(f["i"]) for f in schema["fields"]), default=-1)
    if max_i >= obs_dim:
        raise ValueError(f"obs schema: field index {max_i} >= obs_dim {obs_dim}")
    if schema.get("normalize") not in ("asinh", None):
        raise ValueError(f"obs schema: unsupported normalize {schema.get('normalize')!r}")
    return schema


def _resolve(obj: Any, path: str) -> Any:
    """Walk a dot-path through nested dicts. Returns None on any miss."""
    cur = obj
    for part in path.split("."):
        if isinstance(cur, dict):
            cur = cur.get(part)
        else:
            return None
    return cur


def _matches(item: dict[str, Any], where: dict[str, Any] | None, me: Any) -> bool:
    """Evaluate a `where` predicate (`eq` or `truthy`) against one list item."""
    if not where:
        return True
    f = item.get(where["field"])
    if "eq" in where:
        target = where["eq"]
        if target == "$me":
            target = me
        return f == target
    if where.get("truthy"):
        return bool(f)
    return True


def _apply_field(view: dict[str, Any], fld: dict[str, Any], me: Any) -> float:
    """Compute the raw (pre-normalization) value of one schema field."""
    op = fld["op"]
    if op == "scalar":
        v = _resolve(view, fld["path"])
        return float(v) if v is not None else 0.0
    if op == "clamp_div":
        v = _resolve(view, fld["path"])
        v = float(v) if v is not None else 0.0
        return float(min(float(fld["max"]), v / float(fld["divisor"])))
    if op == "reduce":
        items = view.get(fld["list"], []) or []
        where = fld.get("where")
        filt = [it for it in items if isinstance(it, dict) and _matches(it, where, me)]
        contains = fld.get("contains")
        if contains is not None:
            # A `contains` clause always yields a count, whatever `agg` says.
            needle, field = contains["needle"], contains["field"]
            return float(sum(1 for it in filt if needle in str(it.get(field, ""))))
        agg = fld["agg"]
        if agg == "count":
            return float(len(filt))
        sel = fld.get("select")
        vals = [float(_resolve(it, sel) or 0) if sel else 1.0 for it in filt]
        if agg == "sum":
            return float(sum(vals))
        if agg == "avg":
            return float(sum(vals) / len(vals)) if vals else 0.0
        raise ValueError(f"obs schema: unknown agg {agg!r}")
    raise ValueError(f"obs schema: unknown op {op!r}")


def obs_dim(path: str | None = None) -> int:
    return int(load_schema(path)["obs_dim"])


def encode_observation(view: dict[str, Any], path: str | None = None) -> np.ndarray:
    """Project a PlayerView dict into the schema's fixed-shape float32 obs."""
    schema = load_schema(path)
    dim = int(schema["obs_dim"])
    obs = np.zeros(dim, dtype=np.float32)
    me = int(view.get("player", 0))
    for fld in schema["fields"]:
        obs[int(fld["i"])] = _apply_field(view, fld, me)
    if schema.get("normalize") == "asinh":
        # f64 asinh, then cast to f32; agrees with the Rust side to ~1 ULP,
        # well inside the parity test's 1e-4 obs tolerance.
        obs = np.arcsinh(obs.astype(np.float64)).astype(np.float32)
    return obs