From 7659b37c5c4b364182e7fc71e50fc8c71032b0dd Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Fri, 10 Feb 2023 16:29:18 +0900 Subject: [PATCH 1/6] . --- pgx/shogi.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pgx/shogi.py b/pgx/shogi.py index c363f892b..1aca17e21 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -401,7 +401,7 @@ def between_king(p, f): ) # 王手がかかってないなら王手放置は考えなくてよい - is_not_checked = ~(opp_effect_boards & king_mask).any() # scalar + is_not_checked = flipped_effecting_mask.sum() == 0 # scalar leave_check_mask |= is_not_checked # filter by leave check mask @@ -562,9 +562,7 @@ def between_king(p, f): axis=0 ) # (81,) - opp_effect_boards = jnp.flip(flipped_effect_boards) # (81,) - king_mask = state.piece_board == KING - is_not_checked = ~(opp_effect_boards & king_mask).any() # scalar + is_not_checked = flipped_effecting_mask.sum() == 0 # scalar legal_drops &= is_not_checked | aigoma_area_boards From be5a91f9ac7c82bceed74c07ed5f886c214e51d1 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Fri, 10 Feb 2023 16:38:36 +0900 Subject: [PATCH 2/6] . --- pgx/shogi.py | 64 ++++++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 35 deletions(-) diff --git a/pgx/shogi.py b/pgx/shogi.py index 1aca17e21..44afd26bc 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -256,24 +256,29 @@ def _step_drop(state: State, action: Action) -> State: def _legal_actions(state: State): effect_boards = _apply_effects(state) + legal_moves = _pseudo_legal_moves(state, effect_boards) + legal_drops = _pseudo_legal_drops(state, effect_boards) + + # flipped_state = _flip(state) flipped_effect_boards = _apply_effects(flipped_state) - # generate legal moves from effects - legal_moves = _pseudo_legal_moves(state, effect_boards) + num_checks, check_defense_board = _check_defense(state, flipped_state, flipped_effect_boards) + + # legal_moves = _filter_suicide_moves( state, legal_moves, flipped_state, flipped_effect_boards ) legal_moves = _filter_ignoring_check_moves( - state, legal_moves, flipped_state, flipped_effect_boards + state, legal_moves, flipped_state, flipped_effect_boards, num_checks, check_defense_board ) legal_promotion = _legal_promotion(state, legal_moves) - # generate legal drops from effects - legal_drops = _pseudo_legal_drops(state, effect_boards) + + # legal_drops = _filter_pawn_drop_mate( state, legal_drops, effect_boards, flipped_effect_boards ) legal_drops = _filter_ignoring_check_drops( - state, legal_drops, flipped_state, flipped_effect_boards + legal_drops, num_checks, check_defense_board ) return legal_moves, legal_promotion, legal_drops @@ -349,6 +354,8 @@ def _filter_ignoring_check_moves( legal_moves: jnp.ndarray, flipped_state, flipped_effect_boards, + num_checks, + check_defense_board ) -> jnp.ndarray: """Filter moves which ignores check @@ -376,32 +383,16 @@ def _filter_ignoring_check_moves( leave_check_mask |= capturing_mask # 駒を動かして合駒をする - @jax.vmap - def between_king(p, f): - return IS_ON_THE_WAY[p, f, flipped_king_pos, :] - - from_ = jnp.arange(81) - large_piece = _to_large_piece_ix(flipped_state.piece_board) - flipped_between_king_mask = between_king(large_piece, from_) # (81, 81) - # 王手してない駒からのマスクは外す - flipped_aigoma_area_boards = jnp.where( - flipped_effecting_mask.reshape(81, 1), - flipped_between_king_mask, - jnp.zeros_like(flipped_between_king_mask), - ) - aigoma_area_boards = jnp.flip(flipped_aigoma_area_boards).any( - axis=0 - ) # (81,) - leave_check_mask |= aigoma_area_boards # filter target + leave_check_mask |= check_defense_board # filter target # 両王手の場合、王が避ける以外ない - is_double_checked = flipped_effecting_mask.sum() > 1 + is_double_checked = num_checks > 1 leave_check_mask = jax.lax.cond( is_double_checked, lambda: king_escape_mask, lambda: leave_check_mask ) # 王手がかかってないなら王手放置は考えなくてよい - is_not_checked = flipped_effecting_mask.sum() == 0 # scalar + is_not_checked = num_checks == 0 # scalar leave_check_mask |= is_not_checked # filter by leave check mask @@ -531,13 +522,7 @@ def _filter_pawn_drop_mate( return legal_drops -def _filter_ignoring_check_drops( - state: State, - legal_drops: jnp.ndarray, - flipped_state, - flipped_effect_boards, -): - # 合駒(王手放置) +def _check_defense(state, flipped_state, flipped_effect_boards): flipped_king_pos = ( 80 - jnp.nonzero(state.piece_board == KING, size=1)[0].item() ) @@ -562,12 +547,21 @@ def between_king(p, f): axis=0 ) # (81,) - is_not_checked = flipped_effecting_mask.sum() == 0 # scalar + num_checks = flipped_effecting_mask.sum() # scalar + return num_checks, aigoma_area_boards + - legal_drops &= is_not_checked | aigoma_area_boards +def _filter_ignoring_check_drops( + legal_drops: jnp.ndarray, + num_checks, + check_defense_board, +): + # 合駒(王手放置) + is_not_checked = num_checks == 0 + legal_drops &= is_not_checked | check_defense_board # 両王手の場合、合駒は無駄 - is_double_checked = flipped_effecting_mask.sum() > 1 + is_double_checked = num_checks > 1 legal_drops &= ~is_double_checked return legal_drops From c6262e6b251eae2a5b7487af57f8720468c5132f Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Fri, 10 Feb 2023 16:47:56 +0900 Subject: [PATCH 3/6] fix bug --- pgx/shogi.py | 2 +- tests/assets/shogi/legal_moves_006.svg | 2 +- tests/test_shogi.py | 13 +++++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pgx/shogi.py b/pgx/shogi.py index 44afd26bc..4ce6cfe43 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -379,7 +379,7 @@ def _filter_ignoring_check_moves( flipped_effecting_mask = flipped_effect_boards[ :, flipped_king_pos ] # (81,) 王に利いている駒の位置 - capturing_mask = jnp.tile(flipped_effecting_mask, reps=(81, 1)) + capturing_mask = jnp.tile(jnp.flip(flipped_effecting_mask), reps=(81, 1)) leave_check_mask |= capturing_mask # 駒を動かして合駒をする diff --git a/tests/assets/shogi/legal_moves_006.svg b/tests/assets/shogi/legal_moves_006.svg index aff0ce789..7c674104b 100644 --- a/tests/assets/shogi/legal_moves_006.svg +++ b/tests/assets/shogi/legal_moves_006.svg @@ -1,2 +1,2 @@ -123456789 \ No newline at end of file +123456789 \ No newline at end of file diff --git a/tests/test_shogi.py b/tests/test_shogi.py index b8cacec3c..6152a5fd9 100644 --- a/tests/test_shogi.py +++ b/tests/test_shogi.py @@ -163,18 +163,19 @@ def test_legal_moves(): assert not legal_moves[xy2i(5, 9), xy2i(5, 8)] # 自殺手はNG assert not legal_moves[xy2i(2, 7), xy2i(2, 6)] # 王を放置するのはNG - # + # Checking piece should be captured s = init() s = s.replace( piece_board=s.piece_board - .at[xy2i(5, 5)].set(OPP_LANCE) - .at[xy2i(5, 7)].set(EMPTY) - .at[xy2i(7, 7)].set(EMPTY) + .at[:].set(EMPTY) + .at[xy2i(1, 9)].set(KING) + .at[xy2i(1, 1)].set(OPP_LANCE) + .at[xy2i(6, 1)].set(ROOK) ) visualize(s, "tests/assets/shogi/legal_moves_006.svg") legal_moves, _, _ = _legal_actions(s) - assert not legal_moves[xy2i(8, 8), xy2i(4, 4)] # 角が香を取る以外の動きは王手放置でNG - assert legal_moves[xy2i(8, 8), xy2i(5, 5)] # 角が王手をかけている香を取るのはOK + assert not legal_moves[xy2i(6, 1), xy2i(2, 1)] # 飛車が香を取る以外の動きは王手放置でNG + assert legal_moves[xy2i(6, 1), xy2i(1, 1)] # 飛車が王手をかけている香を取るのはOK # 合駒 s = init() From 183a24b2e4d328389609227ee565528277fd3402 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Fri, 10 Feb 2023 16:54:21 +0900 Subject: [PATCH 4/6] . --- pgx/shogi.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/pgx/shogi.py b/pgx/shogi.py index 4ce6cfe43..f7b5f53d5 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -259,27 +259,30 @@ def _legal_actions(state: State): legal_moves = _pseudo_legal_moves(state, effect_boards) legal_drops = _pseudo_legal_drops(state, effect_boards) - # + # Prepare necessary materials flipped_state = _flip(state) flipped_effect_boards = _apply_effects(flipped_state) - num_checks, check_defense_board = _check_defense(state, flipped_state, flipped_effect_boards) + checking_point_board, check_defense_board = _check_info(state, flipped_state, flipped_effect_boards) - # + # Filter illegal moves legal_moves = _filter_suicide_moves( state, legal_moves, flipped_state, flipped_effect_boards ) legal_moves = _filter_ignoring_check_moves( - state, legal_moves, flipped_state, flipped_effect_boards, num_checks, check_defense_board + state, legal_moves, flipped_effect_boards, checking_point_board, check_defense_board ) - legal_promotion = _legal_promotion(state, legal_moves) - # + # Filter illegal drops legal_drops = _filter_pawn_drop_mate( state, legal_drops, effect_boards, flipped_effect_boards ) legal_drops = _filter_ignoring_check_drops( - legal_drops, num_checks, check_defense_board + legal_drops, checking_point_board, check_defense_board ) + + # Generate legal promotion + legal_promotion = _legal_promotion(state, legal_moves) + return legal_moves, legal_promotion, legal_drops @@ -352,9 +355,8 @@ def pinned_piece_mask(p, f): def _filter_ignoring_check_moves( state: State, legal_moves: jnp.ndarray, - flipped_state, flipped_effect_boards, - num_checks, + checking_point_board, check_defense_board ) -> jnp.ndarray: """Filter moves which ignores check @@ -373,19 +375,14 @@ def _filter_ignoring_check_moves( leave_check_mask |= king_escape_mask # Capture the checking piece - flipped_king_pos = ( - 80 - jnp.nonzero(state.piece_board == KING, size=1)[0].item() - ) - flipped_effecting_mask = flipped_effect_boards[ - :, flipped_king_pos - ] # (81,) 王に利いている駒の位置 - capturing_mask = jnp.tile(jnp.flip(flipped_effecting_mask), reps=(81, 1)) + capturing_mask = jnp.tile(checking_point_board, reps=(81, 1)) leave_check_mask |= capturing_mask # 駒を動かして合駒をする leave_check_mask |= check_defense_board # filter target # 両王手の場合、王が避ける以外ない + num_checks = checking_point_board.sum() is_double_checked = num_checks > 1 leave_check_mask = jax.lax.cond( is_double_checked, lambda: king_escape_mask, lambda: leave_check_mask @@ -522,7 +519,7 @@ def _filter_pawn_drop_mate( return legal_drops -def _check_defense(state, flipped_state, flipped_effect_boards): +def _check_info(state, flipped_state, flipped_effect_boards): flipped_king_pos = ( 80 - jnp.nonzero(state.piece_board == KING, size=1)[0].item() ) @@ -547,15 +544,16 @@ def between_king(p, f): axis=0 ) # (81,) - num_checks = flipped_effecting_mask.sum() # scalar - return num_checks, aigoma_area_boards + return jnp.flip(flipped_effecting_mask), aigoma_area_boards def _filter_ignoring_check_drops( legal_drops: jnp.ndarray, - num_checks, + checking_piece_board, check_defense_board, ): + num_checks = checking_piece_board.sum() + # 合駒(王手放置) is_not_checked = num_checks == 0 legal_drops &= is_not_checked | check_defense_board From 16f936818ce4c01026ba2f7bf929d9348159c34c Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Fri, 10 Feb 2023 16:57:42 +0900 Subject: [PATCH 5/6] tidy --- pgx/shogi.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pgx/shogi.py b/pgx/shogi.py index f7b5f53d5..f25b8af48 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -269,7 +269,7 @@ def _legal_actions(state: State): state, legal_moves, flipped_state, flipped_effect_boards ) legal_moves = _filter_ignoring_check_moves( - state, legal_moves, flipped_effect_boards, checking_point_board, check_defense_board + state, legal_moves, checking_point_board, check_defense_board ) # Filter illegal drops @@ -355,7 +355,6 @@ def pinned_piece_mask(p, f): def _filter_ignoring_check_moves( state: State, legal_moves: jnp.ndarray, - flipped_effect_boards, checking_point_board, check_defense_board ) -> jnp.ndarray: @@ -368,8 +367,7 @@ def _filter_ignoring_check_moves( """ leave_check_mask = jnp.zeros_like(legal_moves, dtype=jnp.bool_) - # King escapes - opp_effect_boards = jnp.flip(flipped_effect_boards) # (81,) + # King escapes (i.e., Only King can move) king_mask = state.piece_board == KING king_escape_mask = jnp.tile(king_mask, reps=(81, 1)).transpose() leave_check_mask |= king_escape_mask From 81c0542159ef5dc552bb541a7d129261b4d287d8 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Fri, 10 Feb 2023 16:58:22 +0900 Subject: [PATCH 6/6] . --- pgx/shogi.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pgx/shogi.py b/pgx/shogi.py index f25b8af48..8aa5daa3b 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -262,7 +262,9 @@ def _legal_actions(state: State): # Prepare necessary materials flipped_state = _flip(state) flipped_effect_boards = _apply_effects(flipped_state) - checking_point_board, check_defense_board = _check_info(state, flipped_state, flipped_effect_boards) + checking_point_board, check_defense_board = _check_info( + state, flipped_state, flipped_effect_boards + ) # Filter illegal moves legal_moves = _filter_suicide_moves( @@ -356,7 +358,7 @@ def _filter_ignoring_check_moves( state: State, legal_moves: jnp.ndarray, checking_point_board, - check_defense_board + check_defense_board, ) -> jnp.ndarray: """Filter moves which ignores check