feat(rl-self-play): ✨ Introduce turn/step cap tracking in evaluation metrics for improved RL self-play observability
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
47ca06f270
commit
3241bdacd1
1 changed files with 33 additions and 18 deletions
|
|
@ -7,7 +7,17 @@ Usage:
|
|||
|
||||
Prints a one-line JSON verdict:
|
||||
{"episodes": 50, "wins": 28, "losses": 18, "draws": 4,
|
||||
"turn_caps": 0, "step_caps": 0,
|
||||
"win_rate": 0.56, "mean_turns": 142.3}
|
||||
|
||||
Verdicts:
|
||||
win Terminated with `reason=won` (opponent eliminated).
|
||||
loss Terminated with `reason=eliminated` or `reason=harness_error`.
|
||||
draw Terminated with no clear winner.
|
||||
turn_cap Episode hit --max-turns with no terminal verdict. Not a win:
|
||||
PlayerView doesn't expose opponent score, so standings at the
|
||||
cap are not comparable.
|
||||
step_cap Episode hit the per-episode step cap. Degenerate non-result.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -34,32 +44,29 @@ def _build_argparser() -> argparse.ArgumentParser:
|
|||
return p
|
||||
|
||||
|
||||
def _classify_episode(info_history: list[dict[str, object]], total_reward: float) -> str:
|
||||
"""Decide win/loss/draw from the last step's info + accumulated reward.
|
||||
def _classify_episode(info_history: list[dict[str, object]]) -> str:
|
||||
"""Decide verdict from the last step's info.
|
||||
|
||||
Win: terminated with positive terminal reward (i.e. opponent eliminated
|
||||
or score-fallback in our favour — currently only "opponent
|
||||
eliminated" because the env doesn't yet read win events).
|
||||
Loss: terminated with `reason=eliminated` (we ran out of cities).
|
||||
Draw: truncated at max_turns OR rolled positive score-shaping but no
|
||||
terminal signal — neither side decisively won.
|
||||
Only terminal events count as win/loss. Truncations (turn_cap, step_cap)
|
||||
are reported as their own categories; we don't promote positive
|
||||
score-shaping to a "win" because the simulator's PlayerView doesn't
|
||||
expose opponent score, so we have no honest way to compare standings
|
||||
when the clock runs out.
|
||||
"""
|
||||
if not info_history:
|
||||
return "draw"
|
||||
last = info_history[-1]
|
||||
reason = last.get("reason")
|
||||
if reason == "won":
|
||||
return "win"
|
||||
if reason == "eliminated":
|
||||
return "loss"
|
||||
if reason == "harness_error":
|
||||
return "loss"
|
||||
if reason == "step_cap":
|
||||
# Policy stuck in a no-progress loop and the env truncated the
|
||||
# whole episode — degenerate non-result, surfaced as its own
|
||||
# category so it's visible in the eval JSON.
|
||||
return "step_cap"
|
||||
# No explicit win yet from the env; use score sign as tiebreaker.
|
||||
if total_reward > 0.5:
|
||||
return "win"
|
||||
if reason == "turn_cap":
|
||||
return "turn_cap"
|
||||
return "draw"
|
||||
|
||||
|
||||
|
|
@ -71,7 +78,7 @@ def main() -> int:
|
|||
|
||||
model = MaskablePPO.load(str(args.model_path))
|
||||
|
||||
wins = losses = draws = step_caps = 0
|
||||
wins = losses = draws = step_caps = turn_caps = 0
|
||||
turns_per_episode: list[int] = []
|
||||
for episode in range(args.episodes):
|
||||
cfg = HarnessConfig(
|
||||
|
|
@ -84,22 +91,22 @@ def main() -> int:
|
|||
try:
|
||||
obs, info = env.reset()
|
||||
done = False
|
||||
total_reward = 0.0
|
||||
info_history: list[dict[str, object]] = []
|
||||
while not done:
|
||||
mask = env.action_masks()
|
||||
action, _ = model.predict(obs, action_masks=mask, deterministic=True)
|
||||
obs, reward, terminated, truncated, info = env.step(int(action))
|
||||
total_reward += reward
|
||||
info_history.append(info)
|
||||
done = terminated or truncated
|
||||
verdict = _classify_episode(info_history, total_reward)
|
||||
verdict = _classify_episode(info_history)
|
||||
if verdict == "win":
|
||||
wins += 1
|
||||
elif verdict == "loss":
|
||||
losses += 1
|
||||
elif verdict == "step_cap":
|
||||
step_caps += 1
|
||||
elif verdict == "turn_cap":
|
||||
turn_caps += 1
|
||||
else:
|
||||
draws += 1
|
||||
turns_per_episode.append(int(info.get("turn", 0)))
|
||||
|
|
@ -113,6 +120,7 @@ def main() -> int:
|
|||
"wins": wins,
|
||||
"losses": losses,
|
||||
"draws": draws,
|
||||
"turn_caps": turn_caps,
|
||||
"step_caps": step_caps,
|
||||
"win_rate": wins / total,
|
||||
"mean_turns": round(mean_turns, 1),
|
||||
|
|
@ -125,6 +133,13 @@ def main() -> int:
|
|||
f"loop. Check encoder/reward shaping.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if turn_caps:
|
||||
print(
|
||||
f"NOTE: {turn_caps}/{args.episodes} eval episodes hit the "
|
||||
f"--max-turns={args.max_turns} cap with no terminal verdict. "
|
||||
f"Raise --max-turns or expose opponent score for tie-breaking.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue