test(rl-self-play): ✅ Add evaluation functions, opponent models, and smoke tests for divergence mining in RL self-play tools
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
f0ae52a746
commit
dbeb3f4088
3 changed files with 24 additions and 1 deletions
|
|
@ -51,6 +51,13 @@ def _build_argparser() -> argparse.ArgumentParser:
|
|||
p.add_argument("--opponent-device", default="cpu")
|
||||
p.add_argument("--opponent-deterministic", action="store_true",
|
||||
help="Argmax opponent actions (default: stochastic sampling).")
|
||||
p.add_argument("--learner-deterministic", action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help=("Argmax the evaluated (slot-0) policy. Default True. "
|
||||
"For a symmetric self-play sanity check (e.g. v4 vs "
|
||||
"v4, expect ~50%) pass --no-learner-deterministic so "
|
||||
"both sides sample from the masked softmax — matching "
|
||||
"the stochastic training-eval regime."))
|
||||
return p
|
||||
|
||||
|
||||
|
|
@ -121,7 +128,9 @@ def main() -> int:
|
|||
info_history: list[dict[str, object]] = []
|
||||
while not done:
|
||||
mask = env.action_masks()
|
||||
action, _ = model.predict(obs, action_masks=mask, deterministic=True)
|
||||
action, _ = model.predict(
|
||||
obs, action_masks=mask, deterministic=args.learner_deterministic
|
||||
)
|
||||
obs, reward, terminated, truncated, info = env.step(int(action))
|
||||
info_history.append(info)
|
||||
done = terminated or truncated
|
||||
|
|
|
|||
|
|
@ -293,6 +293,12 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
"score": new_score,
|
||||
"city_count": int(view.get("score", {}).get("city_count", 0)),
|
||||
}
|
||||
if self._opponent is not None:
|
||||
# Diagnostic: how many wire events the frozen opponent's turn
|
||||
# produced this step. Zero across a whole episode means the
|
||||
# opponent never actually acted (e.g. stale binary not skipping
|
||||
# the external slot) — the smoke asserts this is >0.
|
||||
info["opp_events"] = len(opp_events)
|
||||
if reason:
|
||||
info["reason"] = reason
|
||||
elif step_capped:
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ def main() -> int:
|
|||
"max_turn_seen": 0,
|
||||
"mask_violations": 0,
|
||||
"opp_turns_implied": 0,
|
||||
"opp_events_total": 0,
|
||||
"terminal_reason": None,
|
||||
}
|
||||
|
||||
|
|
@ -89,6 +90,7 @@ def main() -> int:
|
|||
details["mask_violations"] += 1
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
mask = info.get("action_mask", np.zeros_like(mask))
|
||||
details["opp_events_total"] += int(info.get("opp_events", 0))
|
||||
turn = int(info.get("turn", 0))
|
||||
if turn > details["max_turn_seen"]:
|
||||
details["max_turn_seen"] = turn
|
||||
|
|
@ -110,6 +112,12 @@ def main() -> int:
|
|||
reasons.append("turn counter never advanced — opponent/turn loop stuck")
|
||||
if details["mask_violations"] > 0:
|
||||
reasons.append(f"{details['mask_violations']} mask violations")
|
||||
if details["opp_events_total"] < 1:
|
||||
reasons.append(
|
||||
"opponent produced zero wire events across the run — frozen "
|
||||
"opponent never acted (likely a stale binary not skipping the "
|
||||
"external slot, so the simulator AI drove it instead)"
|
||||
)
|
||||
|
||||
passed = not reasons
|
||||
print(json.dumps({"passed": passed, "reasons": reasons, "details": details}))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue