Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental] Add func based auto_reset #1017

Merged
merged 1 commit into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pgx/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from pgx.experimental.utils import act_randomly
from pgx.experimental.wrappers import auto_reset
102 changes: 24 additions & 78 deletions pgx/experimental/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,37 @@
from typing import Tuple

# I personally don't prefer env = Wrapper(env) style.
import jax
import jax.numpy as jnp

import pgx

TRUE = jnp.bool_(True)
FALSE = jnp.bool_(False)


class Wrapper(pgx.Env):
def __init__(self, env: pgx.Env):
self.env: pgx.Env = env

def _init(self, key: jax.random.KeyArray) -> pgx.State:
"""Implement game-specific init function here."""
return self.env._init(key)

def _step(self, state, action) -> pgx.State:
"""Implement game-specific step function here."""
return self.env._step(state, action)

def _observe(
self, state: pgx.State, player_id: jnp.ndarray
) -> jnp.ndarray:
"""Implement game-specific observe function here."""
return self.env._observe(state, player_id)

@property
def id(self) -> pgx.EnvId:
return self.env.id

@property
def version(self) -> str:
return self.env.version
def auto_reset(step_fn, init_fn):
"""Auto reset wrapper.

@property
def num_players(self) -> int:
return self.env.num_players
There are several concerns before staging this wrapper:

@property
def num_actions(self) -> int:
return self.env.num_actions
1. Final state (observation)
When auto-reset happens, the terminal (or truncated) state/observation is
replaced by the initial state/observation. This is fine for termination.
However, when truncation happens, the value of the truncated state/observation
might be used by agents, so it must be stored somewhere.
For example,

@property
def observation_shape(self) -> Tuple[int, ...]:
return self.env.observation_shape
https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/autoreset.py#L59

However, currently, truncation does *NOT* actually happen because
all Pgx environments (games) are finite-horizon and
terminate within a reasonable number of steps.
Note that chess, shogi, and Go have `max_termination_steps`, as in AlphaZero.
So, this implementation is enough (so far).

class AutoReset(Wrapper):
"""AutoReset wrapper resets the state to the initial state immediately just after termination or truncation.
Note that the state before reset is required when truncation occurs,
but Pgx does not have such an environment at present, so it is not a practical problem.
2. Performance:
Might harm the performance as it always generates new state.
Memory usage might be doubled. Need to check.
"""

def step(self, state: pgx.State, action: jnp.ndarray) -> pgx.State:
def wrapped_step_fn(state, action):
state = jax.lax.cond(
(state.terminated | state.truncated),
lambda: state.replace( # type: ignore
Expand All @@ -63,49 +41,17 @@ def step(self, state: pgx.State, action: jnp.ndarray) -> pgx.State:
),
lambda: state,
)
state = self.env.step(state, action)
state = step_fn(state, action)
state = jax.lax.cond(
(state.terminated | state.truncated),
# state is replaced by initial state,
# but preserve (terminated, truncated, reward)
lambda: self.env.init(state._rng_key).replace( # type: ignore
lambda: init_fn(state._rng_key).replace( # type: ignore
terminated=state.terminated,
truncated=state.truncated,
rewards=state.rewards,
),
lambda: state,
)
return state


class ToSingle(Wrapper):
    """Keep only player 0's reward column.

    Both `init` and `step` delegate to the wrapped environment and then
    replace `rewards` with `rewards[:, 0]`, i.e. the rewards are flattened
    to (batch_size,) under the assumption that only <player_id=0> plays.
    """

    def init(self, key: jax.random.KeyArray) -> pgx.State:
        """Initialize the wrapped env and keep only the first reward column."""
        initial = self.env.init(key)
        first_player_rewards = initial.rewards[:, 0]
        return initial.replace(rewards=first_player_rewards)  # type: ignore

    def step(self, state: pgx.State, action: jnp.ndarray) -> pgx.State:
        """Step the wrapped env and keep only the first reward column."""
        stepped = self.env.step(state, action)
        first_player_rewards = stepped.rewards[:, 0]
        return stepped.replace(rewards=first_player_rewards)  # type: ignore


class SpecifyFirstPlayer(Wrapper):
    """Wrapper that adds an init variant with an explicit first player.

    Construction asserts that the wrapped environment reports exactly
    two players.
    """

    def __init__(self, env: pgx.Env):
        super().__init__(env)
        is_two_player = self.num_players == 2
        assert is_two_player, "SpecifyFirstPlayer is only for two-player game."

    def init_with_first_player(
        self, key: jax.random.KeyArray, first_player_id: jnp.ndarray
    ) -> pgx.State:
        """Special init function for two-player perfect information game.

        Args:
            key: pseudo-random generator key in JAX
            first_player_id: zero or one

        Returns:
            State: initial state of environment, with `current_player`
            overwritten by `first_player_id` cast to int8
        """
        fresh_state = self.init(key=key)
        starting_player = jnp.int8(first_player_id)
        return fresh_state.replace(current_player=starting_player)  # type: ignore
return wrapped_step_fn