test(rl-self-play): Add evaluation functions, opponent models, and smoke tests for divergence mining in RL self-play tools

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
autocommit 2026-05-27 20:26:00 -07:00
parent f0ae52a746
commit dbeb3f4088
3 changed files with 24 additions and 1 deletion

View file

@ -51,6 +51,13 @@ def _build_argparser() -> argparse.ArgumentParser:
p.add_argument("--opponent-device", default="cpu")
p.add_argument("--opponent-deterministic", action="store_true",
help="Argmax opponent actions (default: stochastic sampling).")
p.add_argument("--learner-deterministic", action=argparse.BooleanOptionalAction,
default=True,
help=("Argmax the evaluated (slot-0) policy. Default True. "
"For a symmetric self-play sanity check (e.g. v4 vs "
"v4, expect ~50%) pass --no-learner-deterministic so "
"both sides sample from the masked softmax — matching "
"the stochastic training-eval regime."))
return p
@ -121,7 +128,9 @@ def main() -> int:
info_history: list[dict[str, object]] = []
while not done:
mask = env.action_masks()
action, _ = model.predict(obs, action_masks=mask, deterministic=True)
action, _ = model.predict(
obs, action_masks=mask, deterministic=args.learner_deterministic
)
obs, reward, terminated, truncated, info = env.step(int(action))
info_history.append(info)
done = terminated or truncated

View file

@ -293,6 +293,12 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
"score": new_score,
"city_count": int(view.get("score", {}).get("city_count", 0)),
}
if self._opponent is not None:
# Diagnostic: how many wire events the frozen opponent's turn
# produced this step. A total of zero across a whole episode means
# the opponent never actually acted (e.g. stale binary not skipping
# the external slot) — the smoke test asserts the run-wide total is >0.
info["opp_events"] = len(opp_events)
if reason:
info["reason"] = reason
elif step_capped:

View file

@ -56,6 +56,7 @@ def main() -> int:
"max_turn_seen": 0,
"mask_violations": 0,
"opp_turns_implied": 0,
"opp_events_total": 0,
"terminal_reason": None,
}
@ -89,6 +90,7 @@ def main() -> int:
details["mask_violations"] += 1
obs, reward, terminated, truncated, info = env.step(action)
mask = info.get("action_mask", np.zeros_like(mask))
details["opp_events_total"] += int(info.get("opp_events", 0))
turn = int(info.get("turn", 0))
if turn > details["max_turn_seen"]:
details["max_turn_seen"] = turn
@ -110,6 +112,12 @@ def main() -> int:
reasons.append("turn counter never advanced — opponent/turn loop stuck")
if details["mask_violations"] > 0:
reasons.append(f"{details['mask_violations']} mask violations")
if details["opp_events_total"] < 1:
reasons.append(
"opponent produced zero wire events across the run — frozen "
"opponent never acted (likely a stale binary not skipping the "
"external slot, so the simulator AI drove it instead)"
)
passed = not reasons
print(json.dumps({"passed": passed, "reasons": reasons, "details": details}))