From 20d842004d89efe44e33a6b9e6551146124b0034 Mon Sep 17 00:00:00 2001 From: autocommit Date: Wed, 27 May 2026 20:15:33 -0700 Subject: [PATCH] =?UTF-8?q?feat(rl-self-play):=20=E2=9C=A8=20Add=20methods?= =?UTF-8?q?=20to=20load=20and=20integrate=20learned=20opponent=20policies?= =?UTF-8?q?=20into=20MagicCivEnv=20for=20reinforcement=20learning=20workfl?= =?UTF-8?q?ows?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Lilith Autocommit --- tooling/rl_self_play/magic_civ_env.py | 68 +++++++++++++++++++-------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/tooling/rl_self_play/magic_civ_env.py b/tooling/rl_self_play/magic_civ_env.py index cf882bdc..e509da13 100644 --- a/tooling/rl_self_play/magic_civ_env.py +++ b/tooling/rl_self_play/magic_civ_env.py @@ -16,6 +16,7 @@ its win rate against this baseline; the policy is considered to have from __future__ import annotations import sys +from dataclasses import replace from typing import Any import gymnasium as gym @@ -30,6 +31,7 @@ from .encoders import ( encode_observation, ) from .harness_client import HarnessClient, HarnessConfig, HarnessError +from .opponent import ModelOpponent # Reward shape (Stage 6.1.5 redesign, 2026-05-18). The prior shape used # TURN_ADVANCE_BONUS = 1e-2 which on a 200-turn cap accumulates +2.0 — @@ -106,8 +108,18 @@ DEFAULT_MAX_TURNS = 1000 class MagicCivEnv(gym.Env[np.ndarray, np.int64]): - """Single-player Gym wrapper: our policy controls slot 0, the - harness's built-in AI controls slot 1..N-1.""" + """Single-learner Gym wrapper: our policy controls slot 0. + + The opponent on slot 1..N-1 is one of: + * the harness's built-in MCTS (default — `opponent=None`), driven + internally by the simulator's `apply_end_turn` AI loop; or + * a frozen `ModelOpponent` (self-play curriculum), driven in-process + over the multi-slot wire. When a model opponent is supplied, both + slot 0 and the opponent slots are *externally* controlled, so the + harness's internal AI loop skips them (see dispatch.rs Stage 4) + and this env advances the opponent's turn after the learner's + `end_turn`. + """ metadata = {"render_modes": []} @@ -116,11 +128,20 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): harness_config: HarnessConfig | None = None, max_turns: int = DEFAULT_MAX_TURNS, max_steps_per_episode: int = DEFAULT_MAX_STEPS_PER_EPISODE, + opponent: ModelOpponent | None = None, ) -> None: super().__init__() self._config = harness_config or HarnessConfig() self._max_turns = max_turns self._max_steps_per_episode = max_steps_per_episode + self._opponent = opponent + self._my_slot = self._config.player_slot + # When a model opponent drives the other slot(s), every wire call + # MUST name its slot (multi-slot adapter contract). With the + # default MCTS opponent we keep the legacy single-slot wire shape + # (slot omitted) so nothing about the shipping path changes. + self._multi_slot = opponent is not None + self._slot_kw = self._my_slot if self._multi_slot else None self.observation_space = spaces.Box( low=-1e6, high=1e6, shape=(OBS_DIM,), dtype=np.float32 ) @@ -151,22 +172,22 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): if self._client is not None: self._client.shutdown() cfg = self._config + # When self-playing against a model opponent, both the learner's + # slot and the opponent's slot(s) are externally driven — declare + # them so the simulator's AI loop skips them. + if self._multi_slot and self._opponent is not None: + cfg = replace(cfg, player_slots=(self._my_slot, *self._opponent.slots)) + # `replace` preserves every other field (player_slots, victory_mode, + # player_controllers, …) — the old field-by-field rebuild silently + # dropped them, which would have un-declared the external slots. if seed is not None: - cfg = HarnessConfig( - seed=seed, - players=cfg.players, - player_slot=cfg.player_slot, - map_size=cfg.map_size, - map_type=cfg.map_type, - omniscient=cfg.omniscient, - timeout_sec=cfg.timeout_sec, - ) + cfg = replace(cfg, seed=seed) self._client = HarnessClient(cfg) self._terminated = False self._step_count = 0 self._capital_by_player = {} self._city_founded_rewards_issued = 0 - view = self._client.view() + view = self._client.view(slot=self._slot_kw) # Seed capitals from any cities present at game start. In duel # maps each player begins with a founder, so the capital map is # populated on the first CityFounded event per player (handled @@ -197,11 +218,19 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): prev_turn = int(self._last_view.get("turn", 0)) reward = -_step_penalty(prev_turn) + opp_events: list[dict[str, Any]] = [] try: if player_action.get("type") == "end_turn": - response = self._client.end_turn() + response = self._client.end_turn(slot=self._slot_kw) + # With a frozen model opponent, the simulator's AI loop + # skips the opponent slot (it is externally declared) — so + # we drive its full turn here. With the default MCTS + # opponent this is a no-op: the AI loop already ran inside + # the end_turn dispatch and its events are in `response`. + if self._opponent is not None: + opp_events = self._opponent.play_turn(self._client) else: - response = self._client.act(player_action) + response = self._client.act(player_action, slot=self._slot_kw) except HarnessError: # Treat any harness failure as a loss — bad action, dead # subprocess, etc. Terminate the episode. @@ -214,12 +243,13 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): {"action_mask": np.zeros(ACTION_DIM, dtype=bool), "reason": "harness_error"}, ) - view = self._client.view() - # Collect synchronous events from the act response + any async - # notifications buffered while waiting for view's response. - # Terminal events (game_over / player_eliminated) may have fired - # during the opponent's turn between our act and our view. + view = self._client.view(slot=self._slot_kw) + # Collect synchronous events from the act response + the opponent's + # turn + any async notifications buffered while waiting for view's + # response. Terminal events (game_over / player_eliminated) may + # have fired during the opponent's turn between our act and view. recent_events: list[dict[str, Any]] = list(response.get("events", [])) + recent_events.extend(opp_events) recent_events.extend(self._client.drain_notifications()) new_turn = int(view.get("turn", 0)) me = int(view.get("player", 0))