magicciv/tooling/rl_self_play/evaluate.py
2026-05-27 20:26:00 -07:00

183 lines
6.9 KiB
Python

"""Run a held-out evaluation of a trained MaskablePPO model.
Usage:
python -m tooling.rl_self_play.evaluate \
--model-path tooling/rl_self_play/models/duel-v1/final.zip \
--episodes 50
Prints a one-line JSON verdict to stdout (cap warnings go to stderr):
{"episodes": 50, "wins": 28, "losses": 18, "draws": 4,
"turn_caps": 0, "step_caps": 0,
"win_rate": 0.56, "mean_turns": 142.3}
Verdicts:
win Terminated with `reason=won` (opponent eliminated).
loss Terminated with `reason=eliminated` or `reason=harness_error`.
draw Any other outcome: no clear winner, or a missing/unrecognized `reason`.
turn_cap Episode hit --max-turns with no terminal verdict. Not a win:
PlayerView doesn't expose opponent score, so standings at the
cap are not comparable.
step_cap Episode hit the per-episode step cap. Degenerate non-result.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
# Directory holding this script (``tooling/rl_self_play``) and the directory
# two levels above it, which contains the ``tooling`` package.
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parents[1]

# Running the file directly (``python evaluate.py``) leaves ``__package__``
# unset, so put the project root on sys.path to let the deferred
# ``tooling.rl_self_play.*`` imports in ``main`` resolve.
if __package__ is None:
    sys.path.insert(0, str(PROJECT_ROOT))
def _build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script.

    Options cover the run setup (model path, episode count, turn cap,
    seeding, player count, map size), an optional frozen-snapshot opponent,
    and whether the evaluated slot-0 policy acts deterministically.

    argparse %-formats every help string when rendering ``--help`` (so that
    ``%(default)s`` and similar work). A literal percent sign must therefore
    be written as ``%%`` in help text.
    """
    p = argparse.ArgumentParser(description="Evaluate a trained policy against built-in AI")
    p.add_argument("--model-path", required=True, type=Path)
    p.add_argument("--episodes", type=int, default=50)
    p.add_argument("--max-turns", type=int, default=1000)
    p.add_argument("--seed-offset", type=int, default=10_000,
                   help="Eval episode seeds = offset + episode_idx; avoids overlap with train seeds")
    p.add_argument("--players", type=int, default=2)
    p.add_argument("--map-size", default="duel")
    p.add_argument("--opponent-model", default=None, type=Path,
                   help=("Frozen MaskablePPO snapshot (.zip) to use as the "
                         "opponent instead of the harness MCTS. Set to the "
                         "graduated snapshot to measure win-rate against a "
                         "learned policy (self-play curriculum gate)."))
    p.add_argument("--opponent-slots", default="1",
                   help="Comma-separated opponent slot indices (default '1').")
    p.add_argument("--opponent-device", default="cpu")
    p.add_argument("--opponent-deterministic", action="store_true",
                   help="Argmax opponent actions (default: stochastic sampling).")
    p.add_argument("--learner-deterministic", action=argparse.BooleanOptionalAction,
                   default=True,
                   help=("Argmax the evaluated (slot-0) policy. Default True. "
                         "For a symmetric self-play sanity check (e.g. v4 vs "
                         # ``%%`` renders as a literal ``%``. A bare ``%)``
                         # made argparse raise ValueError ("unsupported
                         # format character ')'") when formatting help.
                         "v4, expect ~50%%) pass --no-learner-deterministic so "
                         "both sides sample from the masked softmax — matching "
                         "the stochastic training-eval regime."))
    return p
def _classify_episode(info_history: list[dict[str, object]]) -> str:
    """Map an episode's final ``info`` dict to a verdict string.

    Only the ``"reason"`` entry of the LAST info dict is consulted:

    * ``"won"`` -> ``"win"``
    * ``"eliminated"`` or ``"harness_error"`` -> ``"loss"``
    * ``"step_cap"`` / ``"turn_cap"`` -> passed through as their own verdicts
    * anything else, a missing reason, or an empty history -> ``"draw"``

    Truncations are never promoted to a win on the strength of positive
    score-shaping. The simulator's PlayerView doesn't expose opponent score,
    so there is no honest way to compare standings when the clock runs out.
    """
    if not info_history:
        return "draw"

    # Ordered (reason, verdict) pairs, checked with ``==`` in sequence.
    outcome_table = (
        ("won", "win"),
        ("eliminated", "loss"),
        ("harness_error", "loss"),
        ("step_cap", "step_cap"),
        ("turn_cap", "turn_cap"),
    )
    final_reason = info_history[-1].get("reason")
    for known_reason, verdict in outcome_table:
        if final_reason == known_reason:
            return verdict
    return "draw"
def _run_episode(model, env, *, deterministic: bool) -> tuple[str, int]:
    """Play one evaluation episode in ``env`` with ``model`` choosing actions.

    Resets ``env``, then repeatedly queries ``model.predict`` with the env's
    current action mask and steps the env until it reports ``terminated`` or
    ``truncated``. The caller owns ``env`` and is responsible for closing it.

    Args:
        model: Loaded policy exposing
            ``predict(obs, action_masks=..., deterministic=...)``.
        env: Environment exposing ``reset``, ``step`` and ``action_masks``.
        deterministic: Forwarded to ``model.predict``.

    Returns:
        ``(verdict, final_turn)``. ``verdict`` is ``_classify_episode``'s
        result for the final step's info. ``final_turn`` is that info's
        ``"turn"`` as an int, or 0 when the key is absent.
    """
    obs, info = env.reset()
    done = False
    while not done:
        mask = env.action_masks()
        action, _ = model.predict(obs, action_masks=mask, deterministic=deterministic)
        obs, _reward, terminated, truncated, info = env.step(int(action))
        done = terminated or truncated
    # The loop runs at least once, so ``info`` is the last *step* info. It is
    # also the only entry _classify_episode reads, so the per-step info dicts
    # are no longer accumulated for the whole episode.
    return _classify_episode([info]), int(info.get("turn", 0))


def main() -> int:
    """Evaluate a MaskablePPO checkpoint and print a one-line JSON verdict.

    Plays ``--episodes`` games, seeding episode ``i`` with
    ``--seed-offset + i``. The evaluated policy is slot 0. When
    ``--opponent-model`` is given, a frozen ``ModelOpponent`` plays
    ``--opponent-slots``. Otherwise no opponent object is passed to the env.
    The tally goes to stdout as JSON, and notes about capped episodes go to
    stderr.

    Returns:
        Process exit status. It is always 0 once evaluation completes.
    """
    args = _build_argparser().parse_args()

    # Imported only after argument parsing. Presumably this keeps ``--help``
    # usable without the RL stack installed (TODO confirm).
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from tooling.rl_self_play.harness_client import HarnessConfig  # type: ignore[import-not-found]
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]
    from tooling.rl_self_play.opponent import ModelOpponent  # type: ignore[import-not-found]

    model = MaskablePPO.load(str(args.model_path))
    opp_slots: tuple[int, ...] = tuple(
        int(s) for s in str(args.opponent_slots).split(",") if s.strip()
    )

    def _make_opponent() -> ModelOpponent | None:
        # ``None`` leaves the env's built-in opponent in place.
        # NOTE(review): a fresh ModelOpponent is built for every episode,
        # presumably reloading the snapshot each time. Hoist it out of the
        # loop if one instance can safely be shared across env instances.
        if not args.opponent_model:
            return None
        return ModelOpponent(
            model_path=str(args.opponent_model),
            slots=opp_slots,
            device=args.opponent_device,
            deterministic=args.opponent_deterministic,
        )

    counts = {"win": 0, "loss": 0, "draw": 0, "step_cap": 0, "turn_cap": 0}
    turns_per_episode: list[int] = []
    for episode in range(args.episodes):
        cfg = HarnessConfig(
            seed=args.seed_offset + episode,
            players=args.players,
            player_slot=0,
            map_size=args.map_size,
        )
        env = MagicCivEnv(
            harness_config=cfg, max_turns=args.max_turns, opponent=_make_opponent()
        )
        try:
            verdict, final_turn = _run_episode(
                model, env, deterministic=args.learner_deterministic
            )
        finally:
            env.close()
        # Any verdict outside the known set is tallied as a draw.
        counts[verdict if verdict in counts else "draw"] += 1
        turns_per_episode.append(final_turn)

    wins = counts["win"]
    step_caps = counts["step_cap"]
    turn_caps = counts["turn_cap"]
    # Capped episodes stay in the win-rate denominator. max() avoids a
    # division by zero when --episodes is 0.
    total = max(args.episodes, 1)
    mean_turns = sum(turns_per_episode) / max(len(turns_per_episode), 1)
    summary = {
        "episodes": args.episodes,
        "wins": wins,
        "losses": counts["loss"],
        "draws": counts["draw"],
        "turn_caps": turn_caps,
        "step_caps": step_caps,
        "win_rate": wins / total,
        "mean_turns": round(mean_turns, 1),
    }
    print(json.dumps(summary))
    if step_caps:
        print(
            f"WARNING: {step_caps}/{args.episodes} eval episodes hit the "
            f"per-episode step cap — policy got stuck in a no-progress "
            f"loop. Check encoder/reward shaping.",
            file=sys.stderr,
        )
    if turn_caps:
        print(
            f"NOTE: {turn_caps}/{args.episodes} eval episodes hit the "
            f"--max-turns={args.max_turns} cap with no terminal verdict. "
            f"Raise --max-turns or expose opponent score for tie-breaking.",
            file=sys.stderr,
        )
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())