magicciv/tooling/rl_self_play/encoders.py
Natalie b67764ec67
feat(ai): shared obs encoder contract — schema as single source of truth (Python side)
Replace the hand-duplicated observation encoder with a schema-driven contract:
obs_schema.json declares the layout (version, obs_dim, per-field ops from a fixed
vocabulary: scalar/reduce/clamp_div, +onehot/frac/histogram/per_entity for v2),
and both Python and Rust interpret it instead of hardcoding the math. Kills the
bit-exact-drift risk that made growing 32->96 dims dangerous.
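
A minimal sketch of what "interpreting the schema" could look like. The schema shape, the field keys (`path`, `key`, `div`), and the exact semantics given to `scalar`, `reduce`, and `clamp_div` are assumptions made for illustration; they are not the contents of obs_schema.json or obs_contract.py.

```python
from __future__ import annotations

from typing import Any

# Hypothetical schema for illustration only; not the real obs_schema.json.
TOY_SCHEMA: dict[str, Any] = {
    "version": 1,
    "obs_dim": 3,
    "fields": [
        {"op": "scalar", "path": "turn"},
        {"op": "reduce", "path": "cities", "key": "population"},
        {"op": "clamp_div", "path": "gold", "div": 100.0},
    ],
}


def encode_with_schema(view: dict[str, Any], schema: dict[str, Any] = TOY_SCHEMA) -> list[float]:
    """Walk the schema's fields in order and emit one float per field."""
    out: list[float] = []
    for field in schema["fields"]:
        op = field["op"]
        if op == "scalar":
            # Copy a top-level number through unchanged.
            out.append(float(view.get(field["path"], 0)))
        elif op == "reduce":
            # Sum one key over a list of entities (assumed meaning of "reduce").
            items = view.get(field["path"], [])
            out.append(float(sum(item.get(field["key"], 0) for item in items)))
        elif op == "clamp_div":
            # Divide by a constant, then clamp into [0, 1] (assumed meaning).
            scaled = float(view.get(field["path"], 0)) / field["div"]
            out.append(min(1.0, max(0.0, scaled)))
        else:
            raise ValueError(f"unknown op: {op!r}")
    if len(out) != schema["obs_dim"]:
        raise ValueError("schema obs_dim does not match the number of fields")
    return out


view = {"turn": 5, "cities": [{"population": 2}, {"population": 3}], "gold": 250}
print(encode_with_schema(view))  # [5.0, 5.0, 1.0]
```

Because the layout is data, a second interpreter in another language only has to agree on the op vocabulary, not on a hand-copied list of fields.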

This commit lands the Python half + the v1 schema (reproduces the historical
32-dim encoder EXACTLY): obs_contract.py interprets the schema; encoders.py
delegates to it (OBS_DIM + field math now come from the schema, not module code).
Verified locally: encoders.encode_observation matches all 56 parity fixtures with
ZERO drift. Design: .project/designs/obs-contract.md.
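
One way a "zero drift" parity check could be written is to compare the float32 bit patterns rather than the values, so that even a `-0.0` versus `0.0` difference is reported. This is a sketch under assumptions: the fixture shape (`name`, `view`, `expected`) and the helper names are hypothetical, not the project's fixture format or harness.

```python
from __future__ import annotations

import struct
from typing import Any, Callable, Sequence


def f32_bits(values: Sequence[float]) -> bytes:
    """Pack values as little-endian float32 and return the raw bytes."""
    return struct.pack(f"<{len(values)}f", *values)


def check_parity(
    encode: Callable[[dict[str, Any]], Sequence[float]],
    fixtures: Sequence[dict[str, Any]],
) -> list[tuple[str, int]]:
    """Return (fixture name, first differing index) for every mismatch."""
    failures: list[tuple[str, int]] = []
    for fx in fixtures:
        got = f32_bits(encode(fx["view"]))
        want = f32_bits(fx["expected"])
        if got == want:
            continue
        # Each float32 occupies 4 bytes; find the first slot that differs.
        n = min(len(got), len(want)) // 4
        first = next(
            (i for i in range(n) if got[4 * i:4 * i + 4] != want[4 * i:4 * i + 4]),
            n,  # all shared slots match, so the vectors differ in length
        )
        failures.append((fx["name"], first))
    return failures
```

An empty result means every fixture matched bit for bit; a non-empty one names the first slot to inspect.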

Next: Rust interpreter (encoder.rs reads the embedded schema), verify-obs-contract
gate + version assertions, then bump to v2 (richer 96-dim) as a schema data change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-30 11:49:04 -04:00

209 lines
8.4 KiB
Python

"""PlayerView ⇄ fixed-shape tensors for RL.

The wire-side `PlayerView` is a deeply-nested JSON dict; RL libraries
need fixed-shape numeric arrays. We pin two contracts here:

1. **Observation encoder** (`encode_observation`) projects the view into
   a fixed-length float32 vector. Length is `OBS_DIM`; the layout is
   deterministic and declared in `obs_schema.json`, so the policy net
   can learn a stable embedding.

2. **Action index encoder** (`encode_legal_actions` /
   `decode_action_index`) flattens the view's `legal_actions` (top-level
   + per-unit + per-city) into a fixed-size index space `[0, ACTION_DIM)`.
   Indices not occupied by a legal action in the current state are
   masked out by `legal_action_mask`. MaskablePPO consumes that mask
   directly.

These encoders are intentionally lossy: they discard tile-by-tile data
and only summarise the macro state. Replace with a CNN-based observation
once the macro head proves the loop works end-to-end.
"""
from __future__ import annotations

from typing import Any

import numpy as np

try:
    from . import obs_contract
except ImportError:  # flat import when the dir is on sys.path (harness/tests)
    import obs_contract

# ── Observation shape ────────────────────────────────────────────────
# The observation layout is the SHARED CONTRACT in
# `public/games/age-of-dwarves/data/ai/obs_schema.json`, interpreted by
# `obs_contract` (Python) and `mc-player-api/src/learned/encoder.rs` (Rust).
# OBS_DIM and the per-field math live in the schema, not here — see
# `.project/designs/obs-contract.md`.
OBS_DIM = obs_contract.obs_dim()

# ── Action index layout ──────────────────────────────────────────────
# We bucket legal actions deterministically:
#   [0]                      end_turn
#   [1]                      noop
#   [2, 2 + MAX_UNITS * K)   per-unit slots, K = PER_UNIT_ACTIONS:
#                            skip, fortify, sentry, found_city, unfortify,
#                            move N/NE/SE/S/SW/NW (6 dirs),
#                            attack N/NE/SE/S/SW/NW (6 dirs)
#   tail                     per-city build queue: MAX_CITIES blocks, each
#                            indexing the fixed, priority-ordered roster
#                            CITY_QUEUE_ITEMS
#
# Anything legal but outside this layout is silently dropped, so the RL
# agent cannot learn to take it. For a duel game, the layout below
# covers >95% of legitimate openings; for the full 5-player huge-map
# case we extend MAX_UNITS / MAX_CITIES / CITY_QUEUE_ITEMS once the
# basic loop trains.
MAX_UNITS = 16
PER_UNIT_ACTIONS = 17  # skip, fortify, sentry, found_city, unfortify, move×6, attack×6
MAX_CITIES = 4
CITY_QUEUE_ITEMS: tuple[str, ...] = (
    "worker", "warrior", "library", "barracks", "forge",
    "walls", "longhouse", "monument", "dwarf_warrior", "dwarf_founder",
    "spearmen", "archer", "temple", "high_guild_hall", "chronicle_tower",
    "mead_hall",
)
CITY_QUEUE_DIM = len(CITY_QUEUE_ITEMS)
ACTION_DIM = (
    2  # end_turn, noop
    + MAX_UNITS * PER_UNIT_ACTIONS
    + MAX_CITIES * CITY_QUEUE_DIM
)
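
The index arithmetic above can be inverted with two `divmod` calls. The sketch below is illustrative only: `describe_index` is a hypothetical helper, not part of this module, and it takes the layout sizes as parameters so that the toy numbers in the usage lines are not mistaken for the module's constants.

```python
from __future__ import annotations


def describe_index(
    index: int, max_units: int, per_unit: int, max_cities: int, queue_dim: int
) -> tuple[str, int | None, int | None]:
    """Map a flat action index back to (kind, slot, sub-index)."""
    unit_base = 2  # indices 0 and 1 are end_turn and noop
    city_base = unit_base + max_units * per_unit
    total = city_base + max_cities * queue_dim
    if not 0 <= index < total:
        raise IndexError(f"index {index} outside [0, {total})")
    if index == 0:
        return ("end_turn", None, None)
    if index == 1:
        return ("noop", None, None)
    if index < city_base:
        slot, sub = divmod(index - unit_base, per_unit)
        return ("unit", slot, sub)
    slot, item = divmod(index - city_base, queue_dim)
    return ("city", slot, item)


# Toy layout: 2 unit slots x 3 sub-actions, 2 cities x 4 queue items -> 16 indices.
print(describe_index(7, 2, 3, 2, 4))   # ('unit', 1, 2)
print(describe_index(15, 2, 3, 2, 4))  # ('city', 1, 3)
```

Keeping every block a fixed size is what makes this inversion a constant-time calculation, at the cost of silently dropping anything beyond the last slot.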
# Hex neighbour offsets. The world uses an **odd-q** offset layout
# (flat-top hexes laid out in columns; odd columns shifted down), so a
# hex's six neighbours are determined by its COLUMN parity. Verified
# against live `view.units[].legal_actions` move targets (Stage 6.1.6
# diagnostic). An earlier revision keyed the table on row parity and
# silently dropped roughly half of every odd-parity unit's legal moves
# out of the action mask — a latent bug in the RL env's move space.
# See public/games/age-of-dwarves/docs/HEX_GEOMETRY.md.
_DIR_OFFSETS_EVEN_COL: tuple[tuple[int, int], ...] = (
    (0, -1), (1, -1), (1, 0), (0, 1), (-1, 0), (-1, -1),
)
_DIR_OFFSETS_ODD_COL: tuple[tuple[int, int], ...] = (
    (0, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0),
)


def _hex_direction(from_pos: tuple[int, int], to_pos: tuple[int, int]) -> int | None:
    """Return 0..5 indexing the matching neighbour offset, or None if the
    target is not one of `from_pos`'s six neighbours. odd-q layout: the
    offset table is selected by the *column* parity of `from_pos`."""
    fc, fr = from_pos
    tc, tr = to_pos
    dc, dr = tc - fc, tr - fr
    table = _DIR_OFFSETS_EVEN_COL if (fc % 2 == 0) else _DIR_OFFSETS_ODD_COL
    for i, (odc, odr) in enumerate(table):
        if (odc, odr) == (dc, dr):
            return i
    return None
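
A quick way to sanity-check an offset table like the one above is reciprocity: stepping in direction `i` and then looking back should always give direction `(i + 3) % 6`. The sketch below copies the two tables from this module; `neighbours` and the public name `hex_direction` are hypothetical helpers introduced for the example.

```python
from __future__ import annotations

# Offset tables copied from this module: (d_col, d_row) per direction index.
EVEN_COL = ((0, -1), (1, -1), (1, 0), (0, 1), (-1, 0), (-1, -1))
ODD_COL = ((0, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0))


def neighbours(pos: tuple[int, int]) -> list[tuple[int, int]]:
    """Six neighbours of `pos`, with the table chosen by column parity."""
    col, row = pos
    table = EVEN_COL if col % 2 == 0 else ODD_COL
    return [(col + dc, row + dr) for dc, dr in table]


def hex_direction(from_pos: tuple[int, int], to_pos: tuple[int, int]) -> int | None:
    """Direction index 0..5 from `from_pos` to `to_pos`, or None if not adjacent."""
    try:
        return neighbours(from_pos).index(to_pos)
    except ValueError:
        return None


# Reciprocity over a small patch of the map, covering both column parities.
for col in range(-3, 4):
    for row in range(-3, 4):
        for i, n in enumerate(neighbours((col, row))):
            assert hex_direction(n, (col, row)) == (i + 3) % 6

print(hex_direction((2, 2), (3, 2)))  # 2
print(hex_direction((3, 2), (2, 2)))  # 5
```

A table keyed on the wrong parity fails this loop immediately, which makes it a cheap regression test for the class of bug described in the comment above.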


def encode_observation(view: dict[str, Any]) -> np.ndarray:
    """Project a PlayerView dict into the shared-contract observation vector.

    Delegates to the schema interpreter (`obs_contract`), the single source of
    truth shared with the Rust encoder. The asinh normalization, field layout,
    and OBS_DIM all live in `obs_schema.json`. See
    `.project/designs/obs-contract.md`.
    """
    return obs_contract.encode_observation(view)


def _unit_action_offset(unit_slot: int, sub: int) -> int:
    return 2 + unit_slot * PER_UNIT_ACTIONS + sub


def _city_action_offset(city_slot: int, item_idx: int) -> int:
    return 2 + MAX_UNITS * PER_UNIT_ACTIONS + city_slot * CITY_QUEUE_DIM + item_idx


def encode_legal_actions(
    view: dict[str, Any],
) -> tuple[np.ndarray, dict[int, dict[str, Any]]]:
    """Build the action-mask + an index→PlayerAction lookup table.

    Returns (mask[ACTION_DIM] bool, idx_to_action dict). Only entries
    present in the returned dict are legal this step; the mask is True
    at those positions. MaskablePPO uses the mask to zero out the
    sampling distribution before drawing.
    """
    mask = np.zeros(ACTION_DIM, dtype=bool)
    idx_to_action: dict[int, dict[str, Any]] = {}

    # Top-level actions occupy the two fixed leading indices.
    top = view.get("legal_actions", [])
    for entry in top:
        a = entry.get("action", {})
        if a.get("type") == "end_turn":
            mask[0] = True
            idx_to_action[0] = a
        elif a.get("type") == "noop":
            mask[1] = True
            idx_to_action[1] = a

    # Per-unit actions: the first MAX_UNITS units we own, in view order.
    units = view.get("units", [])
    me = int(view.get("player", 0))
    my_units = [u for u in units if int(u.get("owner", -1)) == me]
    for slot, u in enumerate(my_units[:MAX_UNITS]):
        upos = tuple(int(x) for x in u.get("position", (0, 0)))
        for entry in u.get("legal_actions", []):
            a = entry.get("action", {})
            sub: int | None = None
            t = a.get("type")
            if t == "skip":
                sub = 0
            elif t == "fortify":
                sub = 1
            elif t == "sentry":
                sub = 2
            elif t == "found_city":
                sub = 3
            elif t == "unfortify":
                sub = 4
            elif t == "move":
                dir_idx = _hex_direction(
                    upos, tuple(int(x) for x in a.get("to", (0, 0)))
                )
                if dir_idx is not None:
                    sub = 5 + dir_idx  # 5..10
            elif t == "attack":
                dir_idx = _hex_direction(
                    upos, tuple(int(x) for x in a.get("target", (0, 0)))
                )
                if dir_idx is not None:
                    # 11..16. PER_UNIT_ACTIONS must be at least 17 for all six
                    # attack directions to get distinct slots; if it is smaller,
                    # the clamp keeps the index inside this unit's block and the
                    # last attack directions share one slot.
                    sub = min(11 + dir_idx, PER_UNIT_ACTIONS - 1)
            if sub is None:
                continue  # action type (or non-adjacent target) outside the layout
            offset = _unit_action_offset(slot, sub)
            if offset < ACTION_DIM:
                mask[offset] = True
                idx_to_action[offset] = a

    # Per-city build queue: only items on the fixed roster get an index.
    cities = view.get("cities", [])
    for slot, c in enumerate(cities[:MAX_CITIES]):
        for entry in c.get("legal_actions", []):
            a = entry.get("action", {})
            if a.get("type") != "queue_production":
                continue
            item = str(a.get("item", ""))
            if item not in CITY_QUEUE_ITEMS:
                continue
            item_idx = CITY_QUEUE_ITEMS.index(item)
            offset = _city_action_offset(slot, item_idx)
            if offset < ACTION_DIM:
                mask[offset] = True
                idx_to_action[offset] = a

    return mask, idx_to_action


def decode_action_index(
    index: int, idx_to_action: dict[int, dict[str, Any]]
) -> dict[str, Any]:
    """Invert `encode_legal_actions`. If the policy picks an index that
    has been masked (shouldn't happen with MaskablePPO, but defensive
    code is cheap), fall back to `end_turn`."""
    return idx_to_action.get(index, {"type": "end_turn"})
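
To make the role of the mask concrete, here is a small standard-library sketch of masked sampling followed by the same `end_turn` fallback. It illustrates the idea only; it is not how MaskablePPO is implemented, and `masked_probs` and `sample_action` are hypothetical names.

```python
from __future__ import annotations

import math
import random
from typing import Any, Sequence


def masked_probs(logits: Sequence[float], mask: Sequence[bool]) -> list[float]:
    """Softmax over the legal entries only; illegal entries get probability 0."""
    legal = [l for l, ok in zip(logits, mask) if ok]
    if not legal:
        raise ValueError("no legal action to sample from")
    peak = max(legal)  # subtract the max for numerical stability
    weights = [math.exp(l - peak) if ok else 0.0 for l, ok in zip(logits, mask)]
    total = sum(weights)
    return [w / total for w in weights]


def sample_action(
    logits: Sequence[float],
    mask: Sequence[bool],
    idx_to_action: dict[int, dict[str, Any]],
    rng: random.Random,
) -> dict[str, Any]:
    """Draw a legal index, then decode it with the end_turn fallback."""
    probs = masked_probs(logits, mask)
    index = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return idx_to_action.get(index, {"type": "end_turn"})


# Index 1 has by far the largest logit but is masked out, so it is never drawn.
print(masked_probs([0.0, 100.0, 0.0], [True, False, True]))  # [0.5, 0.0, 0.5]
```

The fallback in the decoder only matters if an index slips through that the lookup table does not contain; with a correct mask the sampled index is always a key of `idx_to_action`.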