Skip to content

Commit

Permalink
[Shogi] Fix capturing checking piece bug (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Feb 10, 2023
1 parent 7306719 commit 36ba18f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 56 deletions.
88 changes: 39 additions & 49 deletions pgx/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,25 +256,35 @@ def _step_drop(state: State, action: Action) -> State:

def _legal_actions(state: State):
effect_boards = _apply_effects(state)
legal_moves = _pseudo_legal_moves(state, effect_boards)
legal_drops = _pseudo_legal_drops(state, effect_boards)

# Prepare necessary materials
flipped_state = _flip(state)
flipped_effect_boards = _apply_effects(flipped_state)
# generate legal moves from effects
legal_moves = _pseudo_legal_moves(state, effect_boards)
checking_point_board, check_defense_board = _check_info(
state, flipped_state, flipped_effect_boards
)

# Filter illegal moves
legal_moves = _filter_suicide_moves(
state, legal_moves, flipped_state, flipped_effect_boards
)
legal_moves = _filter_ignoring_check_moves(
state, legal_moves, flipped_state, flipped_effect_boards
state, legal_moves, checking_point_board, check_defense_board
)
legal_promotion = _legal_promotion(state, legal_moves)
# generate legal drops from effects
legal_drops = _pseudo_legal_drops(state, effect_boards)

# Filter illegal drops
legal_drops = _filter_pawn_drop_mate(
state, legal_drops, effect_boards, flipped_effect_boards
)
legal_drops = _filter_ignoring_check_drops(
state, legal_drops, flipped_state, flipped_effect_boards
legal_drops, checking_point_board, check_defense_board
)

# Generate legal promotion
legal_promotion = _legal_promotion(state, legal_moves)

return legal_moves, legal_promotion, legal_drops


Expand Down Expand Up @@ -347,8 +357,8 @@ def pinned_piece_mask(p, f):
def _filter_ignoring_check_moves(
state: State,
legal_moves: jnp.ndarray,
flipped_state,
flipped_effect_boards,
checking_point_board,
check_defense_board,
) -> jnp.ndarray:
"""Filter moves which ignores check
Expand All @@ -359,49 +369,27 @@ def _filter_ignoring_check_moves(
"""
leave_check_mask = jnp.zeros_like(legal_moves, dtype=jnp.bool_)

# King escapes
opp_effect_boards = jnp.flip(flipped_effect_boards) # (81,)
# King escapes (i.e., Only King can move)
king_mask = state.piece_board == KING
king_escape_mask = jnp.tile(king_mask, reps=(81, 1)).transpose()
leave_check_mask |= king_escape_mask

# Capture the checking piece
flipped_king_pos = (
80 - jnp.nonzero(state.piece_board == KING, size=1)[0].item()
)
flipped_effecting_mask = flipped_effect_boards[
:, flipped_king_pos
] # (81,) 王に利いている駒の位置
capturing_mask = jnp.tile(flipped_effecting_mask, reps=(81, 1))
capturing_mask = jnp.tile(checking_point_board, reps=(81, 1))
leave_check_mask |= capturing_mask

# 駒を動かして合駒をする
@jax.vmap
def between_king(p, f):
return IS_ON_THE_WAY[p, f, flipped_king_pos, :]

from_ = jnp.arange(81)
large_piece = _to_large_piece_ix(flipped_state.piece_board)
flipped_between_king_mask = between_king(large_piece, from_) # (81, 81)
# 王手してない駒からのマスクは外す
flipped_aigoma_area_boards = jnp.where(
flipped_effecting_mask.reshape(81, 1),
flipped_between_king_mask,
jnp.zeros_like(flipped_between_king_mask),
)
aigoma_area_boards = jnp.flip(flipped_aigoma_area_boards).any(
axis=0
) # (81,)
leave_check_mask |= aigoma_area_boards # filter target
leave_check_mask |= check_defense_board # filter target

# 両王手の場合、王が避ける以外ない
is_double_checked = flipped_effecting_mask.sum() > 1
num_checks = checking_point_board.sum()
is_double_checked = num_checks > 1
leave_check_mask = jax.lax.cond(
is_double_checked, lambda: king_escape_mask, lambda: leave_check_mask
)

# 王手がかかってないなら王手放置は考えなくてよい
is_not_checked = ~(opp_effect_boards & king_mask).any() # scalar
is_not_checked = num_checks == 0 # scalar
leave_check_mask |= is_not_checked

# filter by leave check mask
Expand Down Expand Up @@ -531,13 +519,7 @@ def _filter_pawn_drop_mate(
return legal_drops


def _filter_ignoring_check_drops(
state: State,
legal_drops: jnp.ndarray,
flipped_state,
flipped_effect_boards,
):
# 合駒(王手放置)
def _check_info(state, flipped_state, flipped_effect_boards):
flipped_king_pos = (
80 - jnp.nonzero(state.piece_board == KING, size=1)[0].item()
)
Expand All @@ -562,14 +544,22 @@ def between_king(p, f):
axis=0
) # (81,)

opp_effect_boards = jnp.flip(flipped_effect_boards) # (81,)
king_mask = state.piece_board == KING
is_not_checked = ~(opp_effect_boards & king_mask).any() # scalar
return jnp.flip(flipped_effecting_mask), aigoma_area_boards

legal_drops &= is_not_checked | aigoma_area_boards

def _filter_ignoring_check_drops(
legal_drops: jnp.ndarray,
checking_piece_board,
check_defense_board,
):
num_checks = checking_piece_board.sum()

# 合駒(王手放置)
is_not_checked = num_checks == 0
legal_drops &= is_not_checked | check_defense_board

# 両王手の場合、合駒は無駄
is_double_checked = flipped_effecting_mask.sum() > 1
is_double_checked = num_checks > 1
legal_drops &= ~is_double_checked

return legal_drops
Expand Down
2 changes: 1 addition & 1 deletion tests/assets/shogi/legal_moves_006.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 7 additions & 6 deletions tests/test_shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,19 @@ def test_legal_moves():
assert not legal_moves[xy2i(5, 9), xy2i(5, 8)] # 自殺手はNG
assert not legal_moves[xy2i(2, 7), xy2i(2, 6)] # 王を放置するのはNG

#
# Checking piece should be captured
s = init()
s = s.replace(
piece_board=s.piece_board
.at[xy2i(5, 5)].set(OPP_LANCE)
.at[xy2i(5, 7)].set(EMPTY)
.at[xy2i(7, 7)].set(EMPTY)
.at[:].set(EMPTY)
.at[xy2i(1, 9)].set(KING)
.at[xy2i(1, 1)].set(OPP_LANCE)
.at[xy2i(6, 1)].set(ROOK)
)
visualize(s, "tests/assets/shogi/legal_moves_006.svg")
legal_moves, _, _ = _legal_actions(s)
assert not legal_moves[xy2i(8, 8), xy2i(4, 4)] # 角が香を取る以外の動きは王手放置でNG
assert legal_moves[xy2i(8, 8), xy2i(5, 5)] # 角が王手をかけている香を取るのはOK
assert not legal_moves[xy2i(6, 1), xy2i(2, 1)] # 飛車が香を取る以外の動きは王手放置でNG
assert legal_moves[xy2i(6, 1), xy2i(1, 1)] # 飛車が王手をかけている香を取るのはOK

# 合駒
s = init()
Expand Down

0 comments on commit 36ba18f

Please sign in to comment.