magicciv/tooling/rl_self_play/train.py
2026-05-27 20:15:33 -07:00

252 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Train a MaskablePPO policy against the harness's built-in AI.
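Usage (run from the project root, i.e. the directory that contains
`tooling/`, so that the `tooling.rl_self_play` package is importable):
    pip install -r tooling/rl_self_play/requirements.txt
    python -m tooling.rl_self_play.train --total-steps 1_000_000
Watch live training curves with TensorBoard:
    tensorboard --logdir tooling/rl_self_play/runs/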
The training loop:
1. K parallel `MagicCivEnv` instances are spawned (each owns a Godot
harness subprocess; rule of thumb: K = min(physical cores // 2, 8)).
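2. MaskablePPO collects on-policy rollouts across all K envs and
learns for `--total-steps` environment steps.
3. Every `--eval-freq` steps we run `--eval-episodes` held-out eval
episodes against the same baseline. The checkpoint with the best mean
eval reward is saved as `models/<run-name>/best_model.zip`, and the
policy at the end of training as `models/<run-name>/final.zip`.
NOTE: `--target-win-rate` (default 0.55) is parsed but not yet
enforced. The intended behaviour is to save `models/winner.zip` and
exit once eval win-rate crosses it. Today training always runs for
the full `--total-steps`.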
This script is intentionally minimal — no curriculum, no
self-play-against-frozen-snapshots, no league. Those are reasonable
extensions once the basic policy starts winning at all (which itself
will take hours on apricot).
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
import numpy as np
# Directory holding this file (tooling/rl_self_play/).
THIS_DIR = Path(__file__).resolve().parent
# Two levels up from THIS_DIR: the directory that contains `tooling/`, i.e.
# the root from which `tooling.rl_self_play` is importable.
PROJECT_ROOT = THIS_DIR.parents[1]
# Make `tooling.rl_self_play` importable however this file is launched:
#   * as a module from the project root
#     (`python -m tooling.rl_self_play.train`, the canonical way), or
#   * directly (`python train.py`), which helps quick iteration without
#     re-installing.
# A direct run leaves `__package__` as None, and `python -m train` from
# inside this directory leaves it as "". In either case the project root
# is not necessarily on sys.path, so it is added explicitly.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.harness_client import HarnessConfig  # type: ignore[import-not-found]
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Train MaskablePPO on Magic Civilization")
p.add_argument("--total-steps", type=int, default=1_000_000,
help="Total environment steps (default: 1M).")
p.add_argument("--num-envs", type=int, default=4,
help="Parallel envs; each spawns its own harness (default: 4).")
p.add_argument("--max-turns", type=int, default=1000,
help="Per-episode turn limit before truncation (default: 1000, Stage 6.1.5).")
p.add_argument("--map-size", default="duel",
help="MapGenerator size key (default: duel).")
p.add_argument("--players", type=int, default=2,
help="Total player slots in each game (default: 2).")
p.add_argument("--eval-freq", type=int, default=20_000,
help="Run eval every N steps (default: 20k).")
p.add_argument("--eval-episodes", type=int, default=20,
help="Episodes per eval (default: 20).")
p.add_argument("--target-win-rate", type=float, default=0.55,
help="Stop training once eval win-rate exceeds this (default: 0.55).")
p.add_argument("--run-name", default="duel-v1",
help="Subdirectory under runs/ + models/ (default: duel-v1).")
p.add_argument("--seed", type=int, default=42,
help="Base RNG seed; per-env seeds offset from this (default: 42).")
p.add_argument("--init-from", default=None,
help=("Path to a MaskablePPO checkpoint (.zip) to warm-start "
"from — e.g. a behavioural-cloning checkpoint produced "
"by bc_pretrain.py (Stage 6.1.6). When set, PPO refines "
"the loaded policy instead of training from scratch."))
p.add_argument("--device", default="auto",
help=("Torch device for the policy net: 'auto' (default — "
"picks cuda if available, else cpu), 'cuda', "
"'cuda:1' (second GPU), 'mps' (Apple Silicon), or "
"'cpu'. On apricot, prefer 'cuda:1' so cuda:0 stays "
"free for model-boss / MCTS rollouts."))
p.add_argument("--opponent-model", default=None,
help=("Path to a frozen MaskablePPO snapshot (.zip) to use "
"as the opponent instead of the harness MCTS — the "
"AlphaZero-style self-play curriculum rung. When set, "
"slot 0 (learner) and the opponent slot(s) are both "
"externally driven; the frozen policy plays the "
"opponent in-process. Default None = MCTS opponent."))
p.add_argument("--opponent-slots", default="1",
help="Comma-separated opponent slot indices the frozen "
"model drives (default: '1').")
p.add_argument("--opponent-device", default="cpu",
help=("Torch device for the frozen opponent net (default "
"'cpu' — keeps the GPU for the learner; the opponent "
"MlpPolicy is tiny and runs many copies, one per env)."))
p.add_argument("--opponent-deterministic", action="store_true",
help=("Sample the opponent's actions with argmax instead of "
"from the masked softmax. Default off: a stochastic "
"opponent varies its play across seeds so the learner "
"cannot overfit a single trajectory."))
return p
def _make_env_factory(args: argparse.Namespace, env_idx: int):
    """Return a zero-argument thunk that builds one ``MagicCivEnv``.

    sb3's vec-envs take factories rather than env instances. Under
    ``SubprocVecEnv``, each worker process calls its thunk itself, so each
    worker constructs its own env and its own Godot harness subprocess.
    Workers are started via forkserver/spawn by default, not a plain fork.

    Args:
        args: Parsed CLI namespace (see ``_build_argparser``).
        env_idx: Index of this env. It is added to ``args.seed`` so every env
            runs a differently seeded game.

    Raises:
        ValueError: if ``--opponent-slots`` contains a non-integer entry.
            When ``--opponent-model`` is set, also raised if the flag names
            no slot, names slot 0 (the learner's seat), or names a slot
            outside ``range(args.players)``. The check runs here, eagerly,
            so a bad flag is reported before any harness process starts.
    """
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]
    from tooling.rl_self_play.opponent import ModelOpponent  # type: ignore[import-not-found]

    # The learner always occupies this seat; opponents must use the others.
    learner_slot = 0
    opp_slots: tuple[int, ...] = tuple(
        int(s) for s in str(args.opponent_slots).split(",") if s.strip()
    )
    if args.opponent_model:
        if not opp_slots:
            raise ValueError(
                "--opponent-slots must name at least one slot when --opponent-model is set"
            )
        bad_slots = [
            s for s in opp_slots if s == learner_slot or not 0 <= s < args.players
        ]
        if bad_slots:
            raise ValueError(
                f"--opponent-slots {bad_slots} invalid: opponent slots must be in "
                f"1..{args.players - 1} (slot {learner_slot} is the learner)"
            )

    def _make() -> MagicCivEnv:
        cfg = HarnessConfig(
            seed=args.seed + env_idx,
            players=args.players,
            player_slot=learner_slot,
            map_size=args.map_size,
            map_type="continents",
        )
        # Each env gets its own frozen-opponent instance (no sharing across
        # envs); None means the harness's built-in AI plays the other slots.
        opponent = None
        if args.opponent_model:
            opponent = ModelOpponent(
                model_path=args.opponent_model,
                slots=opp_slots,
                device=args.opponent_device,
                deterministic=args.opponent_deterministic,
            )
        return MagicCivEnv(
            harness_config=cfg, max_turns=args.max_turns, opponent=opponent
        )

    return _make
def _resolve_device(requested: str) -> str:
    """Translate the ``--device`` flag into a concrete torch device string.

    ``"auto"`` resolves as follows:
      * ``"cuda"`` if CUDA is available;
      * otherwise ``"mps"`` if the MPS backend exists and is available;
      * otherwise ``"cpu"``.

    Any other value is returned unchanged; torch/sb3 validate it when the
    policy is built.
    """
    import torch  # type: ignore[import-not-found]

    if requested != "auto":
        return requested
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def main() -> int:
    """Parse CLI flags, train a MaskablePPO policy, and save the final model.

    Returns:
        Process exit status:
          * ``0`` once training finishes and ``models/<run-name>/final.zip``
            is written;
          * ``1`` if ``--init-from`` does not name an existing file.
    """
    import contextlib

    args = _build_argparser().parse_args()

    # Check the warm-start path before spawning anything. Every env owns a
    # Godot harness subprocess, so bailing out after the vec-envs exist
    # would leave those processes running.
    init_path = Path(args.init_from) if args.init_from else None
    if init_path is not None and not init_path.is_file():
        print(f"--init-from checkpoint not found: {init_path}", file=sys.stderr)
        return 1

    # Lazy imports — sb3 + torch are heavy and only needed once we
    # commit to running. Lets `--help` stay fast.
    import torch  # type: ignore[import-not-found]
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from sb3_contrib.common.maskable.callbacks import (  # type: ignore[import-not-found]
        MaskableEvalCallback,
    )
    from stable_baselines3.common.vec_env import (  # type: ignore[import-not-found]
        DummyVecEnv,
        SubprocVecEnv,
    )

    run_dir = THIS_DIR / "runs" / args.run_name
    model_dir = THIS_DIR / "models" / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)

    # Build every factory up front. _make_env_factory validates the
    # opponent flags, so a bad flag fails before any harness is launched.
    train_factories = [_make_env_factory(args, i) for i in range(args.num_envs)]
    # Seed offset 1000 keeps the eval env's seed distinct from the training
    # envs' seeds as long as --num-envs <= 1000.
    eval_factory = _make_env_factory(args, 1000)

    # Print the concrete policy device (never "auto") so a multi-GPU box
    # (apricot has 2× RTX 3090) can be confirmed at a glance.
    resolved_device = _resolve_device(args.device)
    cuda_available = torch.cuda.is_available()
    print(
        f"policy device: {resolved_device} "
        f"(cuda_available={cuda_available}, "
        f"cuda_devices={torch.cuda.device_count() if cuda_available else 0})"
    )

    # The ExitStack closes both vec-envs (eval first, then train) however
    # this block exits, including when env construction, checkpoint loading
    # or training raises. That way no harness subprocess is left orphaned.
    with contextlib.ExitStack() as cleanup:
        # SubprocVecEnv runs each env in its own process. This is necessary
        # because each env owns a Godot subprocess: we don't want one
        # harness's JSON-Lines pipe to block sibling envs. DummyVecEnv is
        # the single-process fallback for debugging.
        env_cls = SubprocVecEnv if args.num_envs > 1 else DummyVecEnv
        train_env = env_cls(train_factories)
        cleanup.callback(train_env.close)
        eval_env = DummyVecEnv([eval_factory])
        cleanup.callback(eval_env.close)

        eval_callback = MaskableEvalCallback(
            eval_env,
            best_model_save_path=str(model_dir),
            log_path=str(run_dir / "eval"),
            # eval_freq counts callback calls, i.e. vec-env steps. Each
            # vec-env step advances num_envs environment steps, hence the
            # division.
            eval_freq=max(args.eval_freq // args.num_envs, 1),
            n_eval_episodes=args.eval_episodes,
            # Stochastic eval: a barely-trained net's argmax over the
            # 322-dim action head has ~zero chance of being end_turn
            # (idx 0). Deterministic eval episodes would then never advance
            # past turn 0, and every one would hit step_cap with reward 0.
            # Sampling from the masked softmax keeps end_turn reachable
            # until the policy has consolidated enough mass on a real
            # strategy.
            deterministic=False,
            render=False,
        )

        if init_path is not None:
            # Warm-start: load the BC checkpoint's policy weights and re-bind
            # it to the live training envs and this run's tensorboard log.
            # The checkpoint was saved with identical hyperparameters (see
            # bc_pretrain.py), so PPO continues with the intended schedule.
            print(f"warm-starting from BC checkpoint: {init_path}")
            model = MaskablePPO.load(
                str(init_path),
                env=train_env,
                device=resolved_device,
                tensorboard_log=str(run_dir),
            )
        else:
            model = MaskablePPO(
                "MlpPolicy",
                train_env,
                verbose=1,
                tensorboard_log=str(run_dir),
                seed=args.seed,
                device=resolved_device,
                n_steps=512,
                batch_size=128,
                learning_rate=3e-4,
                gamma=0.995,
                gae_lambda=0.95,
                ent_coef=0.01,
            )

        model.learn(
            total_timesteps=args.total_steps,
            callback=eval_callback,
            progress_bar=True,
            reset_num_timesteps=True,
        )

    # Saving needs only the policy, not the (now closed) envs.
    final_path = model_dir / "final.zip"
    model.save(str(final_path))
    print(f"training complete; model saved to {final_path}")
    return 0
if __name__ == "__main__":
    # Run the CLI and hand main()'s integer result to the interpreter as
    # the process exit status.
    raise SystemExit(main())