Skip to content

Commit

Permalink
[Mahjong] Imprement tsumo (#1041)
Browse files Browse the repository at this point in the history
  • Loading branch information
OkanoShinri authored Sep 20, 2023
1 parent 2eeaaec commit 2041a6d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
32 changes: 30 additions & 2 deletions pgx/_mahjong/_mahjong2.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,36 @@ def _riichi(state: State):


def _tsumo(state: State):
# TODO
return _pass(state)
c_p = state.current_player

score = Yaku.score(
state.hand[c_p],
state.melds[c_p],
state.n_meld[c_p],
state.target,
state.riichi[c_p],
is_ron=FALSE,
)
s1 = score + (-score) % 100
s2 = (score * 2) + (-(score * 2)) % 100

oya = (state.oya + state.round) % 4
reward = jax.lax.cond(
oya == c_p,
lambda: jnp.full(4, -s2, dtype=jnp.int32).at[c_p].set(s2 * 3),
lambda: jnp.full(4, -s1, dtype=jnp.int32)
.at[oya]
.set(-s2)
.at[c_p]
.set(s1 * 2 + s2),
)

# 供託
reward -= 1000 * state.riichi
reward = reward.at[c_p].set(reward[c_p] + 1000 * jnp.sum(state.riichi))
return state.replace( # type:ignore
terminated=TRUE, rewards=jnp.float32(reward)
)


def _ron(state: State):
Expand Down
25 changes: 25 additions & 0 deletions tests/test_mahjong.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,34 @@ def test_ron():

state = step(state, Action.RON)
assert state.terminated
assert (
state.rewards
== jnp.array([0.0, 2000.0, -2000.0, 0.0], dtype=jnp.float32)
).all()
visualize(state, "tests/assets/mahjong/ron.svg")


def test_tsumo():
rng = jax.random.PRNGKey(25)
state = init(key=rng)

for i in range(91):
rng, subkey = jax.random.split(rng)
a = act_randomly(subkey, state)
state = step(state, a)

assert not state.terminated
assert state.legal_action_mask[Action.TSUMO]

state = step(state, Action.TSUMO)
assert state.terminated
assert (
state.rewards
== jnp.array([1200.0, -400.0, -400.0, -400.0], dtype=jnp.float32)
).all()
visualize(state, "tests/assets/mahjong/tsumo.svg")


def test_transparent():
rng = jax.random.PRNGKey(31)
state = init(key=rng)
Expand Down

0 comments on commit 2041a6d

Please sign in to comment.