feat(ai): shared obs encoder contract — schema as single source of truth (Python side)
Replace the hand-duplicated observation encoder with a schema-driven contract: obs_schema.json declares the layout (version, obs_dim, per-field ops from a fixed vocabulary: scalar/reduce/clamp_div, +onehot/frac/histogram/per_entity for v2), and both Python and Rust interpret it instead of hardcoding the math. Kills the bit-exact-drift risk that made growing 32->96 dims dangerous. This commit lands the Python half + the v1 schema (reproduces the historical 32-dim encoder EXACTLY): obs_contract.py interprets the schema; encoders.py delegates to it (OBS_DIM + field math now come from the schema, not module code). Verified locally: encoders.encode_observation matches all 56 parity fixtures with ZERO drift. Design: .project/designs/obs-contract.md. Next: Rust interpreter (encoder.rs reads the embedded schema), verify-obs-contract gate + version assertions, then bump to v2 (richer 96-dim) as a schema data change. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
bbf56c7ab7
commit
b67764ec67
4 changed files with 251 additions and 81 deletions
82
.project/designs/obs-contract.md
Normal file
82
.project/designs/obs-contract.md
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
# Observation Encoder Contract — single source of truth (Python ⇄ Rust)
|
||||
|
||||
> Owner ask (2026-06-30): one shared contract — schema + versioning + verification
|
||||
> scripting — that Python (`tooling/rl_self_play/encoders.py`) and Rust
|
||||
> (`mc-player-api/src/learned/encoder.rs`) both derive from, instead of two
|
||||
> hand-maintained encoders kept in sync by luck + a parity test. Supersedes the
|
||||
> hand-rewrite approach in `.project/designs/richer-encoder-spec.md` (that spec's
|
||||
> 96-dim field list becomes the v2 schema *data*, not dual hand-code).
|
||||
|
||||
## Problem
|
||||
|
||||
Today the observation encoder is implemented **twice** — `encoders.py::encode_observation`
|
||||
and `encoder.rs::encode_observation` — and kept bit-exact only by `learned_parity.rs`
|
||||
against captured fixtures. Every field change is a dual edit that silently breaks
|
||||
parity if the two drift. Growing 32→96 dims this way is 96 chances to diverge.
|
||||
|
||||
## Design: schema-driven encoder
|
||||
|
||||
**1. The schema is the source of truth.** A versioned JSON, shipped as a data pack:
|
||||
`public/games/age-of-dwarves/data/ai/obs_schema.json`. It declares `version`,
|
||||
`obs_dim`, the uniform `normalize` ("asinh"), `clan_order`, the action layout
|
||||
(`max_units`, `per_unit_actions`, `max_cities`, `city_queue_items`), and a `fields`
|
||||
list. Each field is `{i, op, ...args}`. Unlisted indices are 0 (then normalized).
|
||||
|
||||
**2. Op vocabulary (fixed, interpretable in both languages).**
|
||||
|
||||
| op | args | meaning |
|
||||
|---|---|---|
|
||||
| `scalar` | `path` | float at a dot-path from the view root (`resources.gold`, `turn`) |
|
||||
| `reduce` | `list`, `agg`∈{sum,avg,count}, opt `select` (subfield), opt `where` (filter), opt `contains` ({field,needle}) | aggregate a list-valued path; `avg` of empty = 0 |
|
||||
| `clamp_div` | `path`, `divisor`, `max` | `min(max, path/divisor)` (turn-progress signal) |
|
||||
| `onehot` (v2) | `path` (int), `size` | 1.0 at `i+value` for value in `0..size`; all-0 if value <0 |
|
||||
| `frac` (v2) | `num`, `den` | `num/den`, 0 if `den==0` (army-health, food-progress, siege) |
|
||||
| `histogram` (v2) | `list`, `field`, `vocab[]` | per-bucket counts over a fixed vocab (+ "other") |
|
||||
| `per_entity` (v2) | `list`, `k`, `subfields[]` | k entity blocks (per-city/unit), each a sub-field list |
|
||||
|
||||
`where` predicates: `{field, eq}` (eq may be the sentinel `"$me"` = `view.player`),
|
||||
`{field, truthy:true}`, `{field, contains}`. A tiny, total, side-effect-free language —
|
||||
NOT arbitrary code — so both interpreters are ~120 lines and provably equivalent.
|
||||
|
||||
**3. Both sides are thin interpreters.**
|
||||
- Python: `encoders.py` loads the schema once and `encode_observation(view)` walks
|
||||
`fields` applying ops → `np.float32[obs_dim]` → asinh. `OBS_DIM` etc. come FROM the
|
||||
schema (no module constants to drift).
|
||||
- Rust: `encoder.rs` embeds the schema via `include_str!` (compile-time, no runtime
|
||||
file dependency for the engine) and the same interpreter → `[f32; obs_dim]`. The
|
||||
action layout (`CITY_QUEUE_ITEMS`, `MAX_UNITS`, …) also reads from the schema.
|
||||
|
||||
**4. Versioning.** `schema.version` is asserted equal on both sides at load. It is
|
||||
written into the ONNX export sidecar and the `LearnedController::ident()` version
|
||||
string, and checked before a model runs — a model trained on v1 obs cannot be served
|
||||
against a v2 schema. Bumping the schema (e.g. 32→96) is a deliberate version bump that
|
||||
retires old artifacts cleanly (the contract already broke once for normalization).
|
||||
|
||||
**5. Verification scripting (the shared gate).** `scripts/verify-obs-contract.sh`:
|
||||
1. `python -m tooling.rl_self_play.verify_obs_contract` — schema well-formed; Python
|
||||
`obs_dim` == `len(fields'-max-index+1)` consistent; encodes the fixture views and
|
||||
re-emits `learned_mp_v1_encoder_parity.json` from the schema.
|
||||
2. `cargo test -p mc-player-api learned_parity` — Rust interpreter on the same fixtures
|
||||
== the Python-emitted obs, float-exact (≤1e-4), mask bit-exact, **and** asserts the
|
||||
Rust schema `version`/`obs_dim` match the JSON.
|
||||
Exit 0 only if schema + Python + Rust all agree. Runs in CI (Python step) + on the
|
||||
fleet (cargo step, no local Rust toolchain — see `mc-no-local-rust-toolchain`).
|
||||
|
||||
## Migration order (parity stays green at every step)
|
||||
|
||||
1. Write `obs_schema.json` v1 — encodes the **current 32-dim** layout EXACTLY.
|
||||
2. Refactor `encoders.py` to interpret it; assert it reproduces the existing fixtures
|
||||
byte-for-byte (behavior-preserving — proves the interpreter).
|
||||
3. Refactor `encoder.rs` to interpret the embedded schema; `learned_parity` stays green
|
||||
on the **unchanged** fixtures (proves cross-language equivalence of the interpreter).
|
||||
4. Add `verify-obs-contract.sh` + the version assertions. **← contract proven here.**
|
||||
5. THEN bump to v2: extend `fields` with the richer channels (tech/territory/army/
|
||||
terrain/civics + clan one-hot) + new ops; regenerate fixtures once; both interpreters
|
||||
pick it up with **zero field-math edits**. 32→96 becomes a data change.
|
||||
|
||||
## Why interpreter, not codegen
|
||||
|
||||
Codegen (schema → generated .py/.rs) gives one source too, but adds a generator +
|
||||
build step + generated-code review surface. The op vocabulary is small and the obs is
|
||||
tiny (≤128 floats, encoded once per action), so a runtime interpreter's cost is
|
||||
negligible and there is no build-step/regen drift. Interpreter wins for this size.
|
||||
39
public/games/age-of-dwarves/data/ai/obs_schema.json
Normal file
39
public/games/age-of-dwarves/data/ai/obs_schema.json
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"$comment": "SINGLE SOURCE OF TRUTH for the RL observation + action contract. Python (tooling/rl_self_play/encoders.py) and Rust (mc-player-api/src/learned/encoder.rs) both interpret this; neither hardcodes the layout. See .project/designs/obs-contract.md. Paths are JSON WIRE keys of PlayerView (note serde rename: UnitView.type_id -> \"type\"). v1 reproduces the historical 32-dim encoder byte-for-byte.",
|
||||
"version": 1,
|
||||
"obs_dim": 32,
|
||||
"normalize": "asinh",
|
||||
"clan_order": ["ironhold", "goldvein", "blackhammer", "deepforge", "tinkersmith", "runesmith"],
|
||||
"action": {
|
||||
"max_units": 16,
|
||||
"per_unit_actions": 16,
|
||||
"max_cities": 4,
|
||||
"city_queue_items": [
|
||||
"worker", "warrior", "library", "barracks", "forge",
|
||||
"walls", "longhouse", "monument", "dwarf_warrior", "dwarf_founder",
|
||||
"spearmen", "archer", "temple", "high_guild_hall", "chronicle_tower",
|
||||
"mead_hall"
|
||||
]
|
||||
},
|
||||
"fields": [
|
||||
{"i": 0, "op": "scalar", "path": "resources.gold"},
|
||||
{"i": 1, "op": "scalar", "path": "resources.gold_per_turn"},
|
||||
{"i": 2, "op": "scalar", "path": "resources.science_per_turn"},
|
||||
{"i": 3, "op": "scalar", "path": "score.score_estimate"},
|
||||
{"i": 4, "op": "scalar", "path": "score.city_count"},
|
||||
{"i": 5, "op": "scalar", "path": "score.unit_count"},
|
||||
{"i": 6, "op": "scalar", "path": "resources.happiness_pool"},
|
||||
{"i": 7, "op": "scalar", "path": "resources.culture_per_turn"},
|
||||
{"i": 8, "op": "reduce", "agg": "sum", "list": "cities", "select": "yields.food"},
|
||||
{"i": 9, "op": "reduce", "agg": "sum", "list": "cities", "select": "yields.production"},
|
||||
{"i": 10, "op": "reduce", "agg": "avg", "list": "cities", "select": "population"},
|
||||
{"i": 11, "op": "reduce", "agg": "count", "list": "units", "where": {"field": "owner", "eq": "$me"}, "contains": {"field": "type", "needle": "warrior"}},
|
||||
{"i": 12, "op": "reduce", "agg": "count", "list": "units", "where": {"field": "owner", "eq": "$me"}, "contains": {"field": "type", "needle": "founder"}},
|
||||
{"i": 16, "op": "reduce", "agg": "count", "list": "diplomacy"},
|
||||
{"i": 17, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "relation", "eq": "war"}},
|
||||
{"i": 18, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "relation", "eq": "peace"}},
|
||||
{"i": 19, "op": "reduce", "agg": "count", "list": "diplomacy", "where": {"field": "open_borders", "truthy": true}},
|
||||
{"i": 24, "op": "scalar", "path": "turn"},
|
||||
{"i": 25, "op": "clamp_div", "path": "turn", "divisor": 1000.0, "max": 1.0}
|
||||
]
|
||||
}
|
||||
|
|
@ -25,22 +25,18 @@ from typing import Any
|
|||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from . import obs_contract
|
||||
except ImportError: # flat import when the dir is on sys.path (harness/tests)
|
||||
import obs_contract
|
||||
|
||||
# ── Observation shape ────────────────────────────────────────────────
|
||||
# The fixed-length observation vector has three blocks:
|
||||
# [0:8] self resources + score (gold, gold_per_turn, sci_per_turn,
|
||||
# score_estimate, city_count, unit_count,
|
||||
# happiness_pool, culture_per_turn)
|
||||
# [8:16] self per-turn yields summed across cities (food, production,
|
||||
# science, gold, culture)
|
||||
# + (avg city pop, total mil units,
|
||||
# total founder units)
|
||||
# [16:24] opponent intel snapshot (opponent count seen, # at war,
|
||||
# # at peace, # open_borders, ...)
|
||||
# padded to 8 floats
|
||||
# [24:32] turn counters (turn number, fraction of game elapsed,
|
||||
# # cities lost, # cities captured,
|
||||
# ... pad to 8)
|
||||
OBS_DIM = 32
|
||||
# The observation layout is the SHARED CONTRACT in
|
||||
# `public/games/age-of-dwarves/data/ai/obs_schema.json`, interpreted by
|
||||
# `obs_contract` (Python) and `mc-player-api/src/learned/encoder.rs` (Rust).
|
||||
# OBS_DIM and the per-field math live in the schema, not here — see
|
||||
# `.project/designs/obs-contract.md`.
|
||||
OBS_DIM = obs_contract.obs_dim()
|
||||
|
||||
# ── Action index layout ──────────────────────────────────────────────
|
||||
# We bucket legal actions deterministically:
|
||||
|
|
@ -106,73 +102,14 @@ def _hex_direction(from_pos: tuple[int, int], to_pos: tuple[int, int]) -> int |
|
|||
|
||||
|
||||
def encode_observation(view: dict[str, Any]) -> np.ndarray:
|
||||
"""Project a PlayerView dict into a fixed-shape float32 vector."""
|
||||
obs = np.zeros(OBS_DIM, dtype=np.float32)
|
||||
res = view.get("resources", {})
|
||||
score = view.get("score", {})
|
||||
obs[0] = float(res.get("gold", 0.0))
|
||||
obs[1] = float(res.get("gold_per_turn", 0.0))
|
||||
obs[2] = float(res.get("science_per_turn", 0.0))
|
||||
obs[3] = float(score.get("score_estimate", 0.0))
|
||||
obs[4] = float(score.get("city_count", 0.0))
|
||||
obs[5] = float(score.get("unit_count", 0.0))
|
||||
obs[6] = float(res.get("happiness_pool", 0.0))
|
||||
obs[7] = float(res.get("culture_per_turn", 0.0))
|
||||
"""Project a PlayerView dict into the shared-contract observation vector.
|
||||
|
||||
cities = view.get("cities", [])
|
||||
if cities:
|
||||
food = sum(float(c.get("yields", {}).get("food", 0)) for c in cities)
|
||||
prod = sum(float(c.get("yields", {}).get("production", 0)) for c in cities)
|
||||
obs[8] = food
|
||||
obs[9] = prod
|
||||
obs[10] = sum(float(c.get("population", 0)) for c in cities) / len(cities)
|
||||
|
||||
units = view.get("units", [])
|
||||
me = int(view.get("player", 0))
|
||||
my_units = [u for u in units if int(u.get("owner", -1)) == me]
|
||||
obs[11] = float(sum(1 for u in my_units if "warrior" in str(u.get("type", ""))))
|
||||
obs[12] = float(sum(1 for u in my_units if "founder" in str(u.get("type", ""))))
|
||||
|
||||
diplo = view.get("diplomacy", [])
|
||||
obs[16] = float(len(diplo))
|
||||
obs[17] = float(sum(1 for d in diplo if d.get("relation") == "war"))
|
||||
obs[18] = float(sum(1 for d in diplo if d.get("relation") == "peace"))
|
||||
obs[19] = float(sum(1 for d in diplo if d.get("open_borders")))
|
||||
|
||||
obs[24] = float(view.get("turn", 0))
|
||||
# Bound turn at 1000 (Stage 6.1.5 max_turns) for a rough [0,1] progress signal.
|
||||
obs[25] = min(1.0, float(view.get("turn", 0)) / 1000.0)
|
||||
return _normalize_obs(obs)
|
||||
|
||||
|
||||
# ── Observation normalization (mp-v1) ────────────────────────────────
|
||||
# duel-v4 fed RAW magnitudes (score 0-1000, raw gold/yields/counts) into a
|
||||
# bare MLP with no VecNormalize. An MLP whose hidden activations were fit on
|
||||
# duel-early magnitudes (score~35, 4 units) saw mid-game magnitudes
|
||||
# (score~240, 13 units) as out-of-distribution — the root cause of the
|
||||
# state-magnitude generalization failure (diagnosed via the p1-29k probe).
|
||||
# mp-v1 applies a fixed, signed, large-dynamic-range compression so
|
||||
# competence transfers across magnitude ranges including unseen ones.
|
||||
#
|
||||
# Transform: asinh (inverse hyperbolic sine), applied UNIFORMLY to all 32
|
||||
# dims. Properties that make it the right choice here:
|
||||
# * defined on ALL reals (unlike log1p, which NaNs on negatives) — dims
|
||||
# that can plausibly go negative (gold_per_turn deficit, happiness_pool)
|
||||
# are safe with NO per-dim sign audit, and stay safe as the multiplayer
|
||||
# distribution populates dims that were zero in bench-v1;
|
||||
# * asinh(0) == 0 — unused / zero dims and the already-[0,1] progress dim
|
||||
# (obs[25]) pass through ~unchanged (near-linear for |x| < 1);
|
||||
# * log-compresses large |x| — score 240 -> ~6.2, 35 -> ~4.3, so the
|
||||
# mid-game/duel-early gap collapses from 205 to ~1.9.
|
||||
#
|
||||
# Parity contract (p1-29f): the Rust encoder
|
||||
# (`mc-player-api/src/learned/encoder.rs::normalize_obs`) MUST apply the
|
||||
# byte-equivalent transform. Both sides compute asinh in f64 then cast to
|
||||
# f32 — f64 asinh agrees cross-implementation to ~1 ULP, well inside the
|
||||
# parity test's 1e-4 obs tolerance. Do NOT hand-roll ln(x + sqrt(x^2 + 1)):
|
||||
# f32 intermediates diverge across implementations.
|
||||
def _normalize_obs(obs: np.ndarray) -> np.ndarray:
|
||||
return np.arcsinh(obs.astype(np.float64)).astype(np.float32)
|
||||
Delegates to the schema interpreter (`obs_contract`) — the single source of
|
||||
truth shared with the Rust encoder. The asinh normalization, field layout,
|
||||
and OBS_DIM all live in `obs_schema.json`. See
|
||||
`.project/designs/obs-contract.md`.
|
||||
"""
|
||||
return obs_contract.encode_observation(view)
|
||||
|
||||
|
||||
def _unit_action_offset(unit_slot: int, sub: int) -> int:
|
||||
|
|
|
|||
112
tooling/rl_self_play/obs_contract.py
Normal file
112
tooling/rl_self_play/obs_contract.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
"""Schema-driven observation encoder — the Python half of the shared contract.
|
||||
|
||||
The observation layout is NOT hardcoded here; it is read from the single source
|
||||
of truth `public/games/age-of-dwarves/data/ai/obs_schema.json`. The Rust half
|
||||
(`mc-player-api/src/learned/encoder.rs`) interprets the same schema. A field's
|
||||
math lives in exactly one place — the op vocabulary below — on each side, and
|
||||
the verification gate asserts the two produce byte-equal output.
|
||||
|
||||
See `.project/designs/obs-contract.md`. Paths are PlayerView JSON WIRE keys
|
||||
(note serde rename: `UnitView.type_id` -> `"type"`).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Resolve the schema relative to the repo root (this file is at
|
||||
# <repo>/tooling/rl_self_play/obs_contract.py).
|
||||
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
_SCHEMA_REL = "public/games/age-of-dwarves/data/ai/obs_schema.json"
|
||||
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def load_schema(path: str | None = None) -> dict[str, Any]:
|
||||
"""Load + lightly validate the obs schema. Cached by path."""
|
||||
p = path or os.environ.get("MC_OBS_SCHEMA") or os.path.join(_REPO_ROOT, _SCHEMA_REL)
|
||||
with open(p, encoding="utf-8") as fh:
|
||||
schema = json.load(fh)
|
||||
obs_dim = int(schema["obs_dim"])
|
||||
max_i = max((int(f["i"]) for f in schema["fields"]), default=-1)
|
||||
if max_i >= obs_dim:
|
||||
raise ValueError(f"obs schema: field index {max_i} >= obs_dim {obs_dim}")
|
||||
if schema.get("normalize") not in ("asinh", None):
|
||||
raise ValueError(f"obs schema: unsupported normalize {schema.get('normalize')!r}")
|
||||
return schema
|
||||
|
||||
|
||||
def _resolve(obj: Any, path: str) -> Any:
|
||||
"""Walk a dot-path through nested dicts. Returns None on any miss."""
|
||||
cur = obj
|
||||
for part in path.split("."):
|
||||
if isinstance(cur, dict):
|
||||
cur = cur.get(part)
|
||||
else:
|
||||
return None
|
||||
return cur
|
||||
|
||||
|
||||
def _matches(item: dict[str, Any], where: dict[str, Any] | None, me: Any) -> bool:
|
||||
if not where:
|
||||
return True
|
||||
f = item.get(where["field"])
|
||||
if "eq" in where:
|
||||
target = where["eq"]
|
||||
if target == "$me":
|
||||
target = me
|
||||
return f == target
|
||||
if where.get("truthy"):
|
||||
return bool(f)
|
||||
return True
|
||||
|
||||
|
||||
def _apply_field(view: dict[str, Any], fld: dict[str, Any], me: Any) -> float:
|
||||
op = fld["op"]
|
||||
if op == "scalar":
|
||||
v = _resolve(view, fld["path"])
|
||||
return float(v) if v is not None else 0.0
|
||||
if op == "clamp_div":
|
||||
v = _resolve(view, fld["path"])
|
||||
v = float(v) if v is not None else 0.0
|
||||
return float(min(float(fld["max"]), v / float(fld["divisor"])))
|
||||
if op == "reduce":
|
||||
items = view.get(fld["list"], []) or []
|
||||
where = fld.get("where")
|
||||
filt = [it for it in items if isinstance(it, dict) and _matches(it, where, me)]
|
||||
contains = fld.get("contains")
|
||||
if contains is not None:
|
||||
needle, field = contains["needle"], contains["field"]
|
||||
return float(sum(1 for it in filt if needle in str(it.get(field, ""))))
|
||||
agg = fld["agg"]
|
||||
if agg == "count":
|
||||
return float(len(filt))
|
||||
sel = fld.get("select")
|
||||
vals = [float(_resolve(it, sel) or 0) if sel else 1.0 for it in filt]
|
||||
if agg == "sum":
|
||||
return float(sum(vals))
|
||||
if agg == "avg":
|
||||
return float(sum(vals) / len(vals)) if vals else 0.0
|
||||
raise ValueError(f"obs schema: unknown agg {agg!r}")
|
||||
raise ValueError(f"obs schema: unknown op {op!r}")
|
||||
|
||||
|
||||
def obs_dim(path: str | None = None) -> int:
|
||||
return int(load_schema(path)["obs_dim"])
|
||||
|
||||
|
||||
def encode_observation(view: dict[str, Any], path: str | None = None) -> np.ndarray:
|
||||
"""Project a PlayerView dict into the schema's fixed-shape float32 obs."""
|
||||
schema = load_schema(path)
|
||||
dim = int(schema["obs_dim"])
|
||||
obs = np.zeros(dim, dtype=np.float32)
|
||||
me = int(view.get("player", 0))
|
||||
for fld in schema["fields"]:
|
||||
obs[int(fld["i"])] = _apply_field(view, fld, me)
|
||||
if schema.get("normalize") == "asinh":
|
||||
# f64 asinh then cast — byte-equivalent to the Rust side (≤1 ULP).
|
||||
obs = np.arcsinh(obs.astype(np.float64)).astype(np.float32)
|
||||
return obs
|
||||
Loading…
Add table
Reference in a new issue