feat(rl-self-play): ✨ Add methods to load and integrate learned opponent policies into MagicCivEnv for reinforcement learning workflows
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
e2e578cdab
commit
20d842004d
1 changed file with 49 additions and 19 deletions
|
|
@ -16,6 +16,7 @@ its win rate against this baseline; the policy is considered to have
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
|
|
@ -30,6 +31,7 @@ from .encoders import (
|
|||
encode_observation,
|
||||
)
|
||||
from .harness_client import HarnessClient, HarnessConfig, HarnessError
|
||||
from .opponent import ModelOpponent
|
||||
|
||||
# Reward shape (Stage 6.1.5 redesign, 2026-05-18). The prior shape used
|
||||
# TURN_ADVANCE_BONUS = 1e-2 which on a 200-turn cap accumulates +2.0 —
|
||||
|
|
@ -106,8 +108,18 @@ DEFAULT_MAX_TURNS = 1000
|
|||
|
||||
|
||||
class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
||||
"""Single-player Gym wrapper: our policy controls slot 0, the
|
||||
harness's built-in AI controls slot 1..N-1."""
|
||||
"""Single-learner Gym wrapper: our policy controls slot 0.
|
||||
|
||||
The opponent on slot 1..N-1 is one of:
|
||||
* the harness's built-in MCTS (default — `opponent=None`), driven
|
||||
internally by the simulator's `apply_end_turn` AI loop; or
|
||||
* a frozen `ModelOpponent` (self-play curriculum), driven in-process
|
||||
over the multi-slot wire. When a model opponent is supplied, both
|
||||
slot 0 and the opponent slots are *externally* controlled, so the
|
||||
harness's internal AI loop skips them (see dispatch.rs Stage 4)
|
||||
and this env advances the opponent's turn after the learner's
|
||||
`end_turn`.
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": []}
|
||||
|
||||
|
|
@ -116,11 +128,20 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
harness_config: HarnessConfig | None = None,
|
||||
max_turns: int = DEFAULT_MAX_TURNS,
|
||||
max_steps_per_episode: int = DEFAULT_MAX_STEPS_PER_EPISODE,
|
||||
opponent: ModelOpponent | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._config = harness_config or HarnessConfig()
|
||||
self._max_turns = max_turns
|
||||
self._max_steps_per_episode = max_steps_per_episode
|
||||
self._opponent = opponent
|
||||
self._my_slot = self._config.player_slot
|
||||
# When a model opponent drives the other slot(s), every wire call
|
||||
# MUST name its slot (multi-slot adapter contract). With the
|
||||
# default MCTS opponent we keep the legacy single-slot wire shape
|
||||
# (slot omitted) so nothing about the shipping path changes.
|
||||
self._multi_slot = opponent is not None
|
||||
self._slot_kw = self._my_slot if self._multi_slot else None
|
||||
self.observation_space = spaces.Box(
|
||||
low=-1e6, high=1e6, shape=(OBS_DIM,), dtype=np.float32
|
||||
)
|
||||
|
|
@ -151,22 +172,22 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
if self._client is not None:
|
||||
self._client.shutdown()
|
||||
cfg = self._config
|
||||
# When self-playing against a model opponent, both the learner's
|
||||
# slot and the opponent's slot(s) are externally driven — declare
|
||||
# them so the simulator's AI loop skips them.
|
||||
if self._multi_slot and self._opponent is not None:
|
||||
cfg = replace(cfg, player_slots=(self._my_slot, *self._opponent.slots))
|
||||
# `replace` preserves every other field (player_slots, victory_mode,
|
||||
# player_controllers, …) — the old field-by-field rebuild silently
|
||||
# dropped them, which would have un-declared the external slots.
|
||||
if seed is not None:
|
||||
cfg = HarnessConfig(
|
||||
seed=seed,
|
||||
players=cfg.players,
|
||||
player_slot=cfg.player_slot,
|
||||
map_size=cfg.map_size,
|
||||
map_type=cfg.map_type,
|
||||
omniscient=cfg.omniscient,
|
||||
timeout_sec=cfg.timeout_sec,
|
||||
)
|
||||
cfg = replace(cfg, seed=seed)
|
||||
self._client = HarnessClient(cfg)
|
||||
self._terminated = False
|
||||
self._step_count = 0
|
||||
self._capital_by_player = {}
|
||||
self._city_founded_rewards_issued = 0
|
||||
view = self._client.view()
|
||||
view = self._client.view(slot=self._slot_kw)
|
||||
# Seed capitals from any cities present at game start. In duel
|
||||
# maps each player begins with a founder, so the capital map is
|
||||
# populated on the first CityFounded event per player (handled
|
||||
|
|
@ -197,11 +218,19 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
|
||||
prev_turn = int(self._last_view.get("turn", 0))
|
||||
reward = -_step_penalty(prev_turn)
|
||||
opp_events: list[dict[str, Any]] = []
|
||||
try:
|
||||
if player_action.get("type") == "end_turn":
|
||||
response = self._client.end_turn()
|
||||
response = self._client.end_turn(slot=self._slot_kw)
|
||||
# With a frozen model opponent, the simulator's AI loop
|
||||
# skips the opponent slot (it is externally declared) — so
|
||||
# we drive its full turn here. With the default MCTS
|
||||
# opponent this is a no-op: the AI loop already ran inside
|
||||
# the end_turn dispatch and its events are in `response`.
|
||||
if self._opponent is not None:
|
||||
opp_events = self._opponent.play_turn(self._client)
|
||||
else:
|
||||
response = self._client.act(player_action)
|
||||
response = self._client.act(player_action, slot=self._slot_kw)
|
||||
except HarnessError:
|
||||
# Treat any harness failure as a loss — bad action, dead
|
||||
# subprocess, etc. Terminate the episode.
|
||||
|
|
@ -214,12 +243,13 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
{"action_mask": np.zeros(ACTION_DIM, dtype=bool), "reason": "harness_error"},
|
||||
)
|
||||
|
||||
view = self._client.view()
|
||||
# Collect synchronous events from the act response + any async
|
||||
# notifications buffered while waiting for view's response.
|
||||
# Terminal events (game_over / player_eliminated) may have fired
|
||||
# during the opponent's turn between our act and our view.
|
||||
view = self._client.view(slot=self._slot_kw)
|
||||
# Collect synchronous events from the act response + the opponent's
|
||||
# turn + any async notifications buffered while waiting for view's
|
||||
# response. Terminal events (game_over / player_eliminated) may
|
||||
# have fired during the opponent's turn between our act and view.
|
||||
recent_events: list[dict[str, Any]] = list(response.get("events", []))
|
||||
recent_events.extend(opp_events)
|
||||
recent_events.extend(self._client.drain_notifications())
|
||||
new_turn = int(view.get("turn", 0))
|
||||
me = int(view.get("player", 0))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue