Skip to content

Commit

Permalink
[TicTacToe] Separate TicTacToe specific attributes (#1146)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Jan 8, 2024
1 parent c8b988b commit 08c79dc
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/dwg/tictactoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _make_tictactoe_dwg(dwg, state: TictactoeState, config):
)
)

for i, mark in enumerate(state._board):
for i, mark in enumerate(state._x._board):
x = i % BOARD_WIDTH
y = i // BOARD_HEIGHT
if mark == 0: # first player
Expand Down
32 changes: 18 additions & 14 deletions pgx/tic_tac_toe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
TRUE = jnp.bool_(True)


# Tic-tac-toe specific game state; the environment `State` holds one instance in its `_x` field.
@dataclass
class GameState:
# Mark (0 or 1) that the next move writes onto the board; `_step` flips it after every move.
_turn: Array = jnp.int32(0)
# Cell indices of the flattened 3x3 board, listed row by row:
# 0 1 2
# 3 4 5
# 6 7 8
_board: Array = -jnp.ones(9, jnp.int32) # -1 (empty), 0, 1


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
Expand All @@ -32,12 +41,7 @@ class State(core.State):
truncated: Array = FALSE
legal_action_mask: Array = jnp.ones(9, dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
# --- Tic-tac-toe specific ---
_turn: Array = jnp.int32(0)
# 0 1 2
# 3 4 5
# 6 7 8
_board: Array = -jnp.ones(9, jnp.int32) # -1 (empty), 0, 1
_x: GameState = GameState()

@property
def env_id(self) -> core.EnvId:
Expand Down Expand Up @@ -79,19 +83,19 @@ def _init(rng: PRNGKey) -> State:


def _step(state: State, action: Array) -> State:
state = state.replace(_board=state._board.at[action].set(state._turn)) # type: ignore
won = _win_check(state._board, state._turn)
state = state.replace(_x=state._x.replace(_board=state._x._board.at[action].set(state._x._turn))) # type: ignore
won = _win_check(state._x._board, state._x._turn)
reward = jax.lax.cond(
won,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
lambda: jnp.zeros(2, jnp.float32),
)
return state.replace( # type: ignore
current_player=(state.current_player + 1) % 2,
legal_action_mask=state._board < 0,
legal_action_mask=state._x._board < 0,
rewards=reward,
terminated=won | jnp.all(state._board != -1),
_turn=(state._turn + 1) % 2,
terminated=won | jnp.all(state._x._board != -1),
_x=state._x.replace(_turn=(state._x._turn + 1) % 2), # type: ignore
)


Expand All @@ -103,13 +107,13 @@ def _win_check(board, turn) -> Array:
def _observe(state: State, player_id: Array) -> Array:
@jax.vmap
def plane(i):
return (state._board == i).reshape((3, 3))
return (state._x._board == i).reshape((3, 3))

# flip if player_id is opposite
x = jax.lax.cond(
state.current_player == player_id,
lambda: jnp.int32([state._turn, 1 - state._turn]),
lambda: jnp.int32([1 - state._turn, state._turn]),
lambda: jnp.int32([state._x._turn, 1 - state._x._turn]),
lambda: jnp.int32([1 - state._x._turn, state._x._turn]),
)

return jnp.stack(plane(x), -1)
24 changes: 12 additions & 12 deletions tests/test_tic_tac_toe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def test_step():
key = jax.random.PRNGKey(1)
state = init(key=key)
assert state.current_player == 1
assert state._turn == 0
assert state._x._turn == 0
assert jnp.all(
state.legal_action_mask
== jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(
state._board == jnp.int32([-1, -1, -1, -1, -1, -1, -1, -1, -1])
state._x._board == jnp.int32([-1, -1, -1, -1, -1, -1, -1, -1, -1])
)
assert not state.terminated
# -1 -1 -1
Expand All @@ -35,13 +35,13 @@ def test_step():
action = jnp.int32(4)
state = step(state, action)
assert state.current_player == 0
assert state._turn == 1
assert state._x._turn == 1
assert jnp.all(
state.legal_action_mask
== jnp.array([1, 1, 1, 1, 0, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(
state._board == jnp.int32([-1, -1, -1, -1, 0, -1, -1, -1, -1])
state._x._board == jnp.int32([-1, -1, -1, -1, 0, -1, -1, -1, -1])
)
assert jnp.all(state.rewards == 0) # fmt: ignore
assert not state.terminated
Expand All @@ -52,12 +52,12 @@ def test_step():
action = jnp.int32(0)
state = step(state, action)
assert state.current_player == 1
assert state._turn == 0
assert state._x._turn == 0
assert jnp.all(
state.legal_action_mask
== jnp.array([0, 1, 1, 1, 0, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(state._board == jnp.int32([1, -1, -1, -1, 0, -1, -1, -1, -1]))
assert jnp.all(state._x._board == jnp.int32([1, -1, -1, -1, 0, -1, -1, -1, -1]))
assert jnp.all(state.rewards == 0) # fmt: ignore
assert not state.terminated
# 1 -1 -1
Expand All @@ -67,12 +67,12 @@ def test_step():
action = jnp.int32(1)
state = step(state, action)
assert state.current_player == 0
assert state._turn == 1
assert state._x._turn == 1
assert jnp.all(
state.legal_action_mask
== jnp.array([0, 0, 1, 1, 0, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(state._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, -1]))
assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, -1]))
assert jnp.all(state.rewards == 0) # fmt: ignore
assert not state.terminated
# 1 0 -1
Expand All @@ -82,12 +82,12 @@ def test_step():
action = jnp.int32(8)
state = step(state, action)
assert state.current_player == 1
assert state._turn == 0
assert state._x._turn == 0
assert jnp.all(
state.legal_action_mask
== jnp.array([0, 0, 1, 1, 0, 1, 1, 1, 0], jnp.bool_)
) # fmt: ignore
assert jnp.all(state._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, 1]))
assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, 1]))
assert jnp.all(state.rewards == 0) # fmt: ignore
assert not state.terminated
# 1 0 -1
Expand All @@ -97,12 +97,12 @@ def test_step():
action = jnp.int32(7)
state = step(state, action)
assert state.current_player == 0
assert state._turn == 1
assert state._x._turn == 1
assert jnp.all(
state.legal_action_mask
== jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(state._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, 0, 1]))
assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, 0, 1]))
assert jnp.all(state.rewards == jnp.int32([-1, 1])) # fmt: ignore
assert state.terminated
# 1 0 -1
Expand Down

0 comments on commit 08c79dc

Please sign in to comment.