diff --git a/tooling/rl_self_play/evaluate.py b/tooling/rl_self_play/evaluate.py index 20ade18a..60d050eb 100644 --- a/tooling/rl_self_play/evaluate.py +++ b/tooling/rl_self_play/evaluate.py @@ -7,7 +7,17 @@ Usage: Prints a one-line JSON verdict: {"episodes": 50, "wins": 28, "losses": 18, "draws": 4, + "turn_caps": 0, "step_caps": 0, "win_rate": 0.56, "mean_turns": 142.3} + +Verdicts: + win Terminated with `reason=won` (opponent eliminated). + loss Terminated with `reason=eliminated` or `reason=harness_error`. + draw Ended with any other (or no) `reason`; no clear winner. + turn_cap Episode hit --max-turns with no terminal verdict. Not a win: + PlayerView doesn't expose opponent score, so standings at the + cap are not comparable. + step_cap Episode hit the per-episode step cap. Degenerate non-result. """ from __future__ import annotations @@ -34,32 +44,29 @@ def _build_argparser() -> argparse.ArgumentParser: return p -def _classify_episode(info_history: list[dict[str, object]], total_reward: float) -> str: - """Decide win/loss/draw from the last step's info + accumulated reward. +def _classify_episode(info_history: list[dict[str, object]]) -> str: + """Decide verdict from the last step's info. - Win: terminated with positive terminal reward (i.e. opponent eliminated - or score-fallback in our favour — currently only "opponent - eliminated" because the env doesn't yet read win events). - Loss: terminated with `reason=eliminated` (we ran out of cities). - Draw: truncated at max_turns OR rolled positive score-shaping but no - terminal signal — neither side decisively won. + Only terminal events count as win/loss. Truncations (turn_cap, step_cap) + are reported as their own categories; we don't promote positive + score-shaping to a "win" because the simulator's PlayerView doesn't + expose opponent score, so we have no honest way to compare standings + when the clock runs out. 
""" if not info_history: return "draw" last = info_history[-1] reason = last.get("reason") + if reason == "won": + return "win" if reason == "eliminated": return "loss" if reason == "harness_error": return "loss" if reason == "step_cap": - # Policy stuck in a no-progress loop and the env truncated the - # whole episode — degenerate non-result, surfaced as its own - # category so it's visible in the eval JSON. return "step_cap" - # No explicit win yet from the env; use score sign as tiebreaker. - if total_reward > 0.5: - return "win" + if reason == "turn_cap": + return "turn_cap" return "draw" @@ -71,7 +78,7 @@ def main() -> int: model = MaskablePPO.load(str(args.model_path)) - wins = losses = draws = step_caps = 0 + wins = losses = draws = step_caps = turn_caps = 0 turns_per_episode: list[int] = [] for episode in range(args.episodes): cfg = HarnessConfig( @@ -84,22 +91,22 @@ def main() -> int: try: obs, info = env.reset() done = False - total_reward = 0.0 info_history: list[dict[str, object]] = [] while not done: mask = env.action_masks() action, _ = model.predict(obs, action_masks=mask, deterministic=True) obs, reward, terminated, truncated, info = env.step(int(action)) - total_reward += reward info_history.append(info) done = terminated or truncated - verdict = _classify_episode(info_history, total_reward) + verdict = _classify_episode(info_history) if verdict == "win": wins += 1 elif verdict == "loss": losses += 1 elif verdict == "step_cap": step_caps += 1 + elif verdict == "turn_cap": + turn_caps += 1 else: draws += 1 turns_per_episode.append(int(info.get("turn", 0))) @@ -113,6 +120,7 @@ def main() -> int: "wins": wins, "losses": losses, "draws": draws, + "turn_caps": turn_caps, "step_caps": step_caps, "win_rate": wins / total, "mean_turns": round(mean_turns, 1), @@ -125,6 +133,13 @@ def main() -> int: f"loop. 
Check encoder/reward shaping.", file=sys.stderr, ) + if turn_caps: + print( + f"NOTE: {turn_caps}/{args.episodes} eval episodes hit the " + f"--max-turns={args.max_turns} cap with no terminal verdict. " + f"Raise --max-turns or expose opponent score for tie-breaking.", + file=sys.stderr, + ) return 0