Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Shogi] Add and fix buggy test samples #358

Merged
merged 12 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pgx/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,9 @@ def _pseudo_legal_drops(
# double pawn
has_pawn = (state.piece_board == PAWN).reshape(9, 9).any(axis=1)
has_pawn = jnp.tile(has_pawn, reps=(9, 1)).transpose().flatten()
legal_drops = jnp.where(has_pawn, FALSE, legal_drops)
legal_drops = legal_drops.at[0].set(
jnp.where(has_pawn, FALSE, legal_drops[0])
)

return legal_drops

Expand Down Expand Up @@ -1144,20 +1146,18 @@ def _from_sfen(sfen):
else:
s_turn = jnp.int8(1)
s_hand = jnp.zeros(14, dtype=jnp.int8)
if hand == "-":
s_hand = jnp.reshape(s_hand, (2, 7))
else:
if hand != "-":
num_piece = 1
for char in hand:
if char.isdigit():
num_piece = int(char)
else:
s_hand = s_hand.at[hand_char_dir.index(char)].set(num_piece)
num_piece = 1
return State.from_board(
return State._from_board(
turn=s_turn,
piece_board=jnp.rot90(piece_board.reshape((9, 9)), k=1).flatten(),
hand=s_hand,
hand=jnp.reshape(s_hand, (2, 7)),
)


Expand All @@ -1179,4 +1179,4 @@ def _from_cshogi(board):
for i in range(2):
for j in range(7):
hand = hand.at[i, j].set(pieces_in_hand[i][hand_piece_dir[j]])
return State.from_board(turn=board.turn, piece_board=pb, hand=hand)
return State._from_board(turn=board.turn, piece_board=pb, hand=hand)
2 changes: 2 additions & 0 deletions tests/assets/shogi/legal_action_mask_015.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions tests/assets/shogi/legal_action_mask_016.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions tests/assets/shogi/legal_action_mask_017.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
33 changes: 32 additions & 1 deletion tests/test_shogi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax.numpy as jnp

from pgx.shogi import *
from pgx.shogi import _init, _step, _step_move, _step_drop, _flip, _effects_all, _legal_actions, _rotate, _to_direction
from pgx.shogi import _init, _step, _step_move, _step_drop, _flip, _effects_all, _legal_actions, _rotate, _to_direction, _from_sfen, _pseudo_legal_drops


# check visualization results by image preview plugins
Expand Down Expand Up @@ -473,3 +473,34 @@ def test_legal_action_mask():
assert not s.legal_action_mask[2 * 81 + xy2i(2, 2)] # 金打の後は角の利きが止まっている
assert not s.legal_action_mask[2 * 81 + xy2i(4, 4)]
assert s.legal_action_mask[0 * 81 + xy2i(4, 3)] # 金の利きが増える


def test_buggy_samples():
# 歩以外の持ち駒に対しての二歩判定回避
sfen = "9/9/9/9/9/9/PPPPPPPPP/9/9 b NLP 1"
state = _from_sfen(sfen)
visualize(state, "tests/assets/shogi/legal_action_mask_015.svg")

# 歩は二歩になるので打てない
assert (~state.legal_action_mask[20 * 81:21 * 81]).all()
# 香車は2列目には打てるが、1列目と7列目(歩がいる)には打てない
assert (state.legal_action_mask[21 * 81 + 1:22 * 81:9]).all()
assert (~state.legal_action_mask[21 * 81:22 * 81:9]).all()
assert (~state.legal_action_mask[21 * 81 + 6:22 * 81:9]).all()
# 桂馬は1,2列目に打てないが3列目には打てる
assert (~state.legal_action_mask[22 * 81:23 * 81:9]).all()
assert (~state.legal_action_mask[22 * 81 + 1:23 * 81:9]).all()
assert (state.legal_action_mask[21 * 81 + 2:22 * 81:9]).all()

# 成駒のpromotion判定
sfen = "9/2+B1G1+P2/9/9/9/9/9/9/9 b - 1"
state = _from_sfen(sfen)
visualize(state, "tests/assets/shogi/legal_action_mask_016.svg")
# promotionは生成されてたらダメ
assert (~state.legal_action_mask[10 * 81:]).all()

# 角は成れないはず
sfen = "l+B6l/6k2/3pg2P1/p6p1/1pP1pB2p/2p3n2/P+r1GP3P/4KS1+s1/LNG5L b RGN2sn6p 1"
state = _from_sfen(sfen)
visualize(state, "tests/assets/shogi/legal_action_mask_017.svg")
assert ~state.legal_action_mask[13 * 81 + 72] # = 1125, promote + left (91角成)