"""Smoke test for the model-backed (self-play) opponent path.

Unlike `smoke.py` (stdlib-only, MCTS opponent) this necessarily imports
`sb3_contrib` + `torch` because it loads a frozen MaskablePPO snapshot
into the opponent slot. It exercises the full integration: a real
`MagicCivEnv` configured with a `ModelOpponent`, driven by a *random*
masked policy on the learner slot for a bounded number of steps.

Verifies, against a live harness:
  * the multi-slot wire is honoured (both slots externally driven, no
    harness errors),
  * the frozen opponent takes real turns (the turn counter advances),
  * opponent-driven events bubble into the env (no silent captures),
  * the episode can run without wedging and reports a terminal/truncation
    reason.

Failure conditions actually checked by this script:
  * the opponent model file does not exist,
  * any exception while importing, constructing, resetting, stepping or
    closing the env/opponent,
  * the action mask becomes empty before the episode ends,
  * the reported turn never advances past the value seen at reset.

Usage:
    python3 -m tooling.rl_self_play.smoke_model_opponent \
        --opponent-model tooling/rl_self_play/models/duel-v4-encfix-s7/best_model.zip \
        --steps 400

Prints a one-line JSON verdict; exit 0 on `passed: true`.
"""
from __future__ import annotations

import argparse
import json
import sys
from collections.abc import Sequence
from pathlib import Path
from typing import Any

import numpy as np

THIS_DIR = Path(__file__).resolve().parent
# Two levels up from tooling/rl_self_play/. `.parent.parent` equals
# `.parents[1]` in the repo layout but never raises on a shallow path.
PROJECT_ROOT = THIS_DIR.parent.parent
# Run as a plain script (not `python -m`): make `tooling.*` importable.
if __package__ is None:
    sys.path.insert(0, str(PROJECT_ROOT))


def _emit_verdict(reasons: list[str], details: dict[str, Any]) -> int:
    """Print the one-line JSON verdict and return the matching exit status."""
    passed = not reasons
    print(json.dumps({"passed": passed, "reasons": reasons, "details": details}))
    return 0 if passed else 1


def _drive_episode(
    env: Any,
    rng: np.random.Generator,
    steps: int,
    reasons: list[str],
    details: dict[str, Any],
) -> None:
    """Drive ``env`` with uniformly random legal actions for up to ``steps`` steps.

    Mutates ``details`` (step/turn counters, terminal reason) and appends to
    ``reasons`` when the action mask becomes empty. Exceptions raised by the
    env propagate to the caller.
    """
    _obs, info = env.reset()
    mask = info["action_mask"]
    # Baseline the turn counter on what reset reports so that a harness whose
    # turns start above 0 does not register a phantom advance on step one.
    # NOTE(review): assumes reset info may carry "turn"; falls back to 0.
    prev_turn = int(info.get("turn") or 0)

    for step_idx in range(steps):
        legal = np.flatnonzero(mask)
        if legal.size == 0:
            reasons.append(f"empty action mask at step {step_idx}")
            break
        action = int(rng.choice(legal))
        # Cannot trigger while the action is drawn from `legal`; kept as a
        # sanity guard so `mask_violations` stays part of the verdict.
        if not mask[action]:
            details["mask_violations"] += 1

        _obs, _reward, terminated, truncated, info = env.step(action)
        mask = info.get("action_mask", np.zeros_like(mask))
        turn = int(info.get("turn", 0))
        details["max_turn_seen"] = max(details["max_turn_seen"], turn)
        # Each turn advance past prev implies the opponent took a turn
        # (in all-external 2-slot play the processor steps per end_turn).
        if turn > prev_turn:
            details["opp_turns_implied"] += 1
            prev_turn = turn
        details["steps"] = step_idx + 1
        if terminated or truncated:
            details["terminal_reason"] = info.get("reason")
            break


def main(argv: Sequence[str] | None = None) -> int:
    """Run the model-opponent smoke and print a one-line JSON verdict.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        0 when the verdict is ``passed: true``, otherwise 1.
    """
    parser = argparse.ArgumentParser(description="Model-opponent self-play smoke")
    parser.add_argument("--opponent-model", required=True, type=Path)
    parser.add_argument("--steps", type=int, default=400)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-turns", type=int, default=200)
    parser.add_argument("--opponent-device", default="cpu")
    args = parser.parse_args(argv)

    reasons: list[str] = []
    details: dict[str, Any] = {
        "steps": 0,
        "max_turn_seen": 0,
        "mask_violations": 0,
        "opp_turns_implied": 0,
        "terminal_reason": None,
    }

    # Cheap check first: no point importing torch for a missing snapshot.
    if not args.opponent_model.is_file():
        reasons.append(f"opponent model not found: {args.opponent_model}")
        return _emit_verdict(reasons, details)

    env = None
    try:
        # Imported lazily (heavy: pulls in sb3_contrib + torch) and inside the
        # guarded region so an import or construction failure still yields
        # the JSON verdict instead of a bare traceback.
        from tooling.rl_self_play.harness_client import HarnessConfig
        from tooling.rl_self_play.magic_civ_env import MagicCivEnv
        from tooling.rl_self_play.opponent import ModelOpponent

        rng = np.random.default_rng(args.seed)
        opponent = ModelOpponent(
            model_path=str(args.opponent_model),
            slots=(1,),
            device=args.opponent_device,
            deterministic=False,
        )
        cfg = HarnessConfig(seed=args.seed, players=2, player_slot=0, map_size="duel")
        env = MagicCivEnv(harness_config=cfg, max_turns=args.max_turns, opponent=opponent)
        _drive_episode(env, rng, args.steps, reasons, details)
    except Exception as e:  # noqa: BLE001 — smoke wants the failure surfaced
        reasons.append(f"exception: {type(e).__name__}: {e}")
    finally:
        if env is not None:
            try:
                env.close()
            except Exception as e:  # noqa: BLE001 — report, don't mask the verdict
                reasons.append(f"exception during close: {type(e).__name__}: {e}")

    # Keyed on observed advances relative to the reset baseline; identical to
    # `max_turn_seen < 1` whenever the baseline is 0.
    if details["opp_turns_implied"] < 1:
        reasons.append("turn counter never advanced — opponent/turn loop stuck")
    if details["mask_violations"] > 0:
        reasons.append(f"{details['mask_violations']} mask violations")

    return _emit_verdict(reasons, details)


if __name__ == "__main__":
    sys.exit(main())