diff --git a/pgx/_src/dwg/tictactoe.py b/pgx/_src/dwg/tictactoe.py index b3cc46fef..b8cc66748 100644 --- a/pgx/_src/dwg/tictactoe.py +++ b/pgx/_src/dwg/tictactoe.py @@ -77,7 +77,7 @@ def _make_tictactoe_dwg(dwg, state: TictactoeState, config): ) ) - for i, mark in enumerate(state._board): + for i, mark in enumerate(state._x._board): x = i % BOARD_WIDTH y = i // BOARD_HEIGHT if mark == 0: # 先手 diff --git a/pgx/tic_tac_toe.py b/pgx/tic_tac_toe.py index 08897dd4d..135cc331a 100644 --- a/pgx/tic_tac_toe.py +++ b/pgx/tic_tac_toe.py @@ -23,6 +23,15 @@ TRUE = jnp.bool_(True) +@dataclass +class GameState: + _turn: Array = jnp.int32(0) + # 0 1 2 + # 3 4 5 + # 6 7 8 + _board: Array = -jnp.ones(9, jnp.int32) # -1 (empty), 0, 1 + + @dataclass class State(core.State): current_player: Array = jnp.int32(0) @@ -32,12 +41,7 @@ class State(core.State): truncated: Array = FALSE legal_action_mask: Array = jnp.ones(9, dtype=jnp.bool_) _step_count: Array = jnp.int32(0) - # --- Tic-tac-toe specific --- - _turn: Array = jnp.int32(0) - # 0 1 2 - # 3 4 5 - # 6 7 8 - _board: Array = -jnp.ones(9, jnp.int32) # -1 (empty), 0, 1 + _x: GameState = GameState() @property def env_id(self) -> core.EnvId: @@ -79,8 +83,8 @@ def _init(rng: PRNGKey) -> State: def _step(state: State, action: Array) -> State: - state = state.replace(_board=state._board.at[action].set(state._turn)) # type: ignore - won = _win_check(state._board, state._turn) + state = state.replace(_x=state._x.replace(_board=state._x._board.at[action].set(state._x._turn))) # type: ignore + won = _win_check(state._x._board, state._x._turn) reward = jax.lax.cond( won, lambda: jnp.float32([-1, -1]).at[state.current_player].set(1), @@ -88,10 +92,10 @@ def _step(state: State, action: Array) -> State: ) return state.replace( # type: ignore current_player=(state.current_player + 1) % 2, - legal_action_mask=state._board < 0, + legal_action_mask=state._x._board < 0, rewards=reward, - terminated=won | jnp.all(state._board != -1), - _turn=(state._turn + 1) % 2, + terminated=won | jnp.all(state._x._board != -1), + _x=state._x.replace(_turn=(state._x._turn + 1) % 2), # type: ignore ) @@ -103,13 +107,13 @@ def _win_check(board, turn) -> Array: def _observe(state: State, player_id: Array) -> Array: @jax.vmap def plane(i): - return (state._board == i).reshape((3, 3)) + return (state._x._board == i).reshape((3, 3)) # flip if player_id is opposite x = jax.lax.cond( state.current_player == player_id, - lambda: jnp.int32([state._turn, 1 - state._turn]), - lambda: jnp.int32([1 - state._turn, state._turn]), + lambda: jnp.int32([state._x._turn, 1 - state._x._turn]), + lambda: jnp.int32([1 - state._x._turn, state._x._turn]), ) return jnp.stack(plane(x), -1) diff --git a/tests/test_tic_tac_toe.py b/tests/test_tic_tac_toe.py index d60cad8ba..c9fd63287 100644 --- a/tests/test_tic_tac_toe.py +++ b/tests/test_tic_tac_toe.py @@ -19,13 +19,13 @@ def test_step(): key = jax.random.PRNGKey(1) state = init(key=key) assert state.current_player == 1 - assert state._turn == 0 + assert state._x._turn == 0 assert jnp.all( state.legal_action_mask == jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1], jnp.bool_) ) # fmt: ignore assert jnp.all( - state._board == jnp.int32([-1, -1, -1, -1, -1, -1, -1, -1, -1]) + state._x._board == jnp.int32([-1, -1, -1, -1, -1, -1, -1, -1, -1]) ) assert not state.terminated # -1 -1 -1 @@ -35,13 +35,13 @@ def test_step(): action = jnp.int32(4) state = step(state, action) assert state.current_player == 0 - assert state._turn == 1 + assert state._x._turn == 1 assert jnp.all( state.legal_action_mask == jnp.array([1, 1, 1, 1, 0, 1, 1, 1, 1], jnp.bool_) ) # fmt: ignore assert jnp.all( - state._board == jnp.int32([-1, -1, -1, -1, 0, -1, -1, -1, -1]) + state._x._board == jnp.int32([-1, -1, -1, -1, 0, -1, -1, -1, -1]) ) assert jnp.all(state.rewards == 0) # fmt: ignore assert not state.terminated @@ -52,12 +52,12 @@ def test_step(): action = jnp.int32(0) state = step(state, action) assert state.current_player == 1 - assert state._turn == 0 + assert state._x._turn == 0 assert jnp.all( state.legal_action_mask == jnp.array([0, 1, 1, 1, 0, 1, 1, 1, 1], jnp.bool_) ) # fmt: ignore - assert jnp.all(state._board == jnp.int32([1, -1, -1, -1, 0, -1, -1, -1, -1])) + assert jnp.all(state._x._board == jnp.int32([1, -1, -1, -1, 0, -1, -1, -1, -1])) assert jnp.all(state.rewards == 0) # fmt: ignore assert not state.terminated # 1 -1 -1 @@ -67,12 +67,12 @@ def test_step(): action = jnp.int32(1) state = step(state, action) assert state.current_player == 0 - assert state._turn == 1 + assert state._x._turn == 1 assert jnp.all( state.legal_action_mask == jnp.array([0, 0, 1, 1, 0, 1, 1, 1, 1], jnp.bool_) ) # fmt: ignore - assert jnp.all(state._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, -1])) + assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, -1])) assert jnp.all(state.rewards == 0) # fmt: ignore assert not state.terminated # 1 0 -1 @@ -82,12 +82,12 @@ def test_step(): action = jnp.int32(8) state = step(state, action) assert state.current_player == 1 - assert state._turn == 0 + assert state._x._turn == 0 assert jnp.all( state.legal_action_mask == jnp.array([0, 0, 1, 1, 0, 1, 1, 1, 0], jnp.bool_) ) # fmt: ignore - assert jnp.all(state._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, 1])) + assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, 1])) assert jnp.all(state.rewards == 0) # fmt: ignore assert not state.terminated # 1 0 -1 @@ -97,12 +97,12 @@ def test_step(): action = jnp.int32(7) state = step(state, action) assert state.current_player == 0 - assert state._turn == 1 + assert state._x._turn == 1 assert jnp.all( state.legal_action_mask == jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1], jnp.bool_) ) # fmt: ignore - assert jnp.all(state._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, 0, 1])) + assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, 0, 1])) assert jnp.all(state.rewards == jnp.int32([-1, 1])) # fmt: ignore assert state.terminated # 1 0 -1