Skip to content

Commit

Permalink
legal_actions をObservationだけから生成する (#1076)
Browse files Browse the repository at this point in the history
* add declaration

* add tests

* Update mjxproto

* Apply formatter

* extract IsRoundOver

* add dummy action

* Apply formatter

* add python test

* use pybind

* add coment

* Apply formatter

* fix

* fix test

* add discards and tsumogiri

* Apply formatter

* add comments

* add switch

* add discard after chi/pon

* Apply formatter

* add discard action after riichi

* add HadDrawLeft and RequireKanDraw

* tmp CanRon

* Apply formatter

* extract TargetTile

* Apply formatter

* implement missed_tiles

* add valid

* set win state info

* Apply formatter

* add comments

* Apply formatter

* add is_ippatsu

* Apply formatter

* add IsRobbingKan

* Apply formatter

* add IsFirstTurnWithoutOpen

* add chi/pon/kan/no

* Apply formatter

* add CanRiichi (wrong cnt = 5)

* Apply formatter

* add CanTsumo (wrong cnt = 4)

* Apply formatter

* add Kan

* Apply formatter

* add nine tiles (wrong cnt = 4)

* Apply formatter

* Update mjxproto

* fix test

* add assertion

* enhance test

* enhance tests

* fix furiten

* Apply formatter

* fix bug in tsumo (wrong cnt 1)

* fix bug in can riichi (wrong cnt = 0)

* Apply formatter

* add parallel tests

* Update mjxproto

* Apply formatter

* fix bug in tyankan. # failure = 0/468 (0 %)

* Apply formatter

* add internal::Observation::GenerateLegalActions

* Apply formatter

* move test

* Apply formatter

* revert internal/state.h, state.cpp

* fix test

* Apply formatter

Co-authored-by: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sotetsuk and github-actions[bot] authored Mar 1, 2022
1 parent 3c927a0 commit 13eb4a6
Show file tree
Hide file tree
Showing 11 changed files with 698 additions and 1,187 deletions.
448 changes: 439 additions & 9 deletions include/mjx/internal/observation.cpp

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions include/mjx/internal/observation.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,48 @@ class Observation {

[[nodiscard]] std::vector<float> ToFeature(std::string version) const;

[[nodiscard]] static std::vector<mjxproto::Action> GenerateLegalActions(
const mjxproto::Observation& observation);

private:
// TODO: remove friends and use proto()
friend class State;
mjxproto::Observation proto_ = mjxproto::Observation{};

[[nodiscard]] std::vector<float> small_v0() const;

// 次のブロックは主にObservationだけからでもlegal
// actionを生成できるようにするための機能
// internal::Stateの内部状態を使わずに同じ計算をしている想定
// ただし、eventsをなめるので遅かったりする
[[nodiscard]] static bool HasDrawLeft(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static bool HasNextDrawLeft(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static bool RequireKanDraw(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static bool CanRon(AbsolutePos who,
const mjxproto::Observation& observation);
[[nodiscard]] static bool CanTsumo(AbsolutePos who,
const mjxproto::Observation& observation);
[[nodiscard]] static std::optional<Tile> TargetTile(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static AbsolutePos dealer(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static Wind prevalent_wind(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static bool IsIppatsu(
AbsolutePos who, const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static bool IsRobbingKan(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static bool IsFirstTurnWithoutOpen(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static bool IsFourKanNoWinner(
const mjxproto::PublicObservation& public_observation);
[[nodiscard]] static bool CanRiichi(AbsolutePos who,
const mjxproto::Observation& observation);
[[nodiscard]] static bool IsRoundOver(
const mjxproto::PublicObservation& public_observation);
};
} // namespace mjx::internal

Expand Down
12 changes: 12 additions & 0 deletions include/mjx/observation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <utility>

#include "mjx/internal/observation.h"
#include "mjx/internal/state.h"
#include "mjx/internal/types.h"

namespace mjx {
Expand Down Expand Up @@ -122,4 +123,15 @@ std::vector<int> Observation::tens() const noexcept {
int Observation::round() const noexcept {
return proto_.public_observation().init_score().round();
}

std::string Observation::AddLegalActions(const std::string& obs_json) {
auto obs = Observation(obs_json);
mjxproto::Observation obs_proto = obs.proto();
auto legal_actions =
mjx::internal::Observation::GenerateLegalActions(obs_proto);
for (auto a : legal_actions) {
obs_proto.mutable_legal_actions()->Add(std::move(a));
}
return Observation(obs_proto).ToJson();
}
} // namespace mjx
2 changes: 2 additions & 0 deletions include/mjx/observation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class Observation {
std::string ToJson() const noexcept;
std::vector<float> ToFeature(const std::string& version) const noexcept;

static std::string AddLegalActions(const std::string& obs_json);

// accessors
const mjxproto::Observation& proto() const noexcept;
Hand curr_hand() const noexcept;
Expand Down
5 changes: 5 additions & 0 deletions mjx/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def save_svg(self, filename: str, view_idx: Optional[int] = None) -> None:
observation = self.to_proto()
save_svg(observation, filename, view_idx)

@staticmethod
def add_legal_actions(obs_json: str) -> str:
assert len(Observation(obs_json).legal_actions()) == 0, "Legal actions are alredy set."
return _mjx.Observation.add_legal_actions(obs_json)

@classmethod
def from_proto(cls, proto: mjxproto.Observation) -> Observation:
return Observation(json_format.MessageToJson(proto))
Expand Down
3 changes: 2 additions & 1 deletion mjx/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ PYBIND11_MODULE(_mjx, m) {
.def("honba", &mjx::Observation::honba)
.def("tens", &mjx::Observation::tens)
.def("round", &mjx::Observation::round)
.def("curr_hand", &mjx::Observation::curr_hand);
.def("curr_hand", &mjx::Observation::curr_hand)
.def_static("add_legal_actions", &mjx::Observation::AddLegalActions);

py::class_<mjx::State>(m, "State")
.def(py::init<std::string>())
Expand Down
1,172 changes: 66 additions & 1,106 deletions mjxproto/mjx_pb2.py

Large diffs are not rendered by default.

46 changes: 46 additions & 0 deletions tests_cpp/internal_observation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <fstream>

#include "gtest/gtest.h"
#include "utils.cpp"

using namespace mjx::internal;

Expand Down Expand Up @@ -109,4 +110,49 @@ TEST(internal_observation, current_hand) {
state = State(next_state_info);
}
}
}

bool legal_actions_equals(const std::vector<mjxproto::Action> &legal_actions1,
const std::vector<mjxproto::Action> &legal_actions2) {
if (legal_actions1.size() != legal_actions2.size()) return false;
for (int i = 0; i < legal_actions1.size(); ++i) {
bool ok = google::protobuf::util::MessageDifferencer::Equals(
legal_actions1.at(i), legal_actions2.at(i));
if (!ok) return false;
}
return true;
}

TEST(internal_state, LegalActions) {
// Test with resources
const bool all_ok = ParallelTest([](const std::string &json) {
bool all_ok = true;
const auto state = State(json);
auto past_decisions = State::GeneratePastDecisions(state.proto());
for (auto [obs_proto, a] : past_decisions) {
auto obs_original = Observation(obs_proto);
auto legal_actions_original = obs_original.legal_actions();
mjxproto::Observation obs_cleared = obs_proto;
obs_cleared.clear_legal_actions();
EXPECT_NE(legal_actions_original.size(), 0);
EXPECT_EQ(obs_cleared.legal_actions_size(), 0);
auto legal_actions_restored =
Observation::GenerateLegalActions(obs_cleared);
bool ok =
legal_actions_equals(legal_actions_original, legal_actions_restored);
all_ok = all_ok && ok;
if (!ok) {
std::cerr << "Original: " << legal_actions_original.size() << std::endl;
std::cerr << obs_original.ToJson() << std::endl;
std::cerr << "Restored: " << legal_actions_restored.size() << std::endl;
auto o = Observation(obs_cleared);
o.add_legal_actions(legal_actions_restored);
std::cerr << o.ToJson() << std::endl;
}
}
return all_ok;
});
EXPECT_TRUE(all_ok);

// Test with simulators
}
73 changes: 2 additions & 71 deletions tests_cpp/internal_state_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <thread>

#include "gtest/gtest.h"
#include "utils.cpp"

using namespace mjx::internal;

Expand Down Expand Up @@ -172,76 +173,6 @@ mjxproto::Action FindPossibleAction(mjxproto::ActionType action_type,
Assert(false);
}

template <typename F>
bool ParallelTest(F &&f) {
static std::mutex mtx_;
int total_cnt = 0;
int failure_cnt = 0;

auto Check = [&total_cnt, &failure_cnt, &f](int begin, int end,
const auto &jsons) {
// {
// std::lock_guard<std::mutex> lock(mtx_);
// std::cerr << std::this_thread::get_id() << " " << begin << " " << end
// << std::endl;
// }
int curr = begin;
while (curr < end) {
const auto &[json, filename] = jsons[curr];
bool ok = f(json);
{
std::lock_guard<std::mutex> lock(mtx_);
total_cnt++;
if (!ok) {
failure_cnt++;
std::cerr << filename << std::endl;
}
if (total_cnt % 1000 == 0)
std::cerr << "# failure = " << failure_cnt << "/" << total_cnt << " ("
<< 100.0 * failure_cnt / total_cnt << " %)" << std::endl;
}
curr++;
}
};

const auto thread_count = std::thread::hardware_concurrency();
std::vector<std::thread> threads;
std::vector<std::pair<std::string, std::string>> jsons;
std::string json_path = std::string(TEST_RESOURCES_DIR) + "/json";

auto Run = [&]() {
const int json_size = jsons.size();
const int size_per = json_size / thread_count;
for (int i = 0; i < thread_count; ++i) {
const int start_ix = i * size_per;
const int end_ix =
(i == thread_count - 1) ? json_size : (i + 1) * size_per;
threads.emplace_back(Check, start_ix, end_ix, jsons);
}
for (auto &t : threads) t.join();
threads.clear();
jsons.clear();
};

if (!json_path.empty())
for (const auto &filename :
std::filesystem::directory_iterator(json_path)) {
std::ifstream ifs(filename.path().string(), std::ios::in);
while (!ifs.eof()) {
std::string json;
std::getline(ifs, json);
if (json.empty()) continue;
jsons.emplace_back(std::move(json), filename.path().string());
}
if (jsons.size() > 1000) Run();
}
Run();

std::cerr << "# failure = " << failure_cnt << "/" << total_cnt << " ("
<< 100.0 * failure_cnt / total_cnt << " %)" << std::endl;
return failure_cnt == 0;
}

TEST(internal_state, ToJson) {
// From https://tenhou.net/0/?log=2011020417gm-00a9-0000-b67fcaa3&tw=1
// w/o terminal state
Expand Down Expand Up @@ -1160,4 +1091,4 @@ TEST(internal_state, GeneratePastDecisions) {
return action.type() == mjxproto::ACTION_TYPE_RON;
}),
3);
}
}
82 changes: 82 additions & 0 deletions tests_cpp/utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include <google/protobuf/util/message_differencer.h>
#include <mjx/internal/state.h>
#include <mjx/internal/utils.h>

#include <filesystem>
#include <fstream>
#include <queue>
#include <thread>

#include "gtest/gtest.h"

using namespace mjx::internal;

template <typename F>
bool ParallelTest(F &&f) {
static std::mutex mtx_;
int total_cnt = 0;
int failure_cnt = 0;

auto Check = [&total_cnt, &failure_cnt, &f](int begin, int end,
const auto &jsons) {
// {
// std::lock_guard<std::mutex> lock(mtx_);
// std::cerr << std::this_thread::get_id() << " " << begin << " " << end
// << std::endl;
// }
int curr = begin;
while (curr < end) {
const auto &[json, filename] = jsons[curr];
bool ok = f(json);
{
std::lock_guard<std::mutex> lock(mtx_);
total_cnt++;
if (!ok) {
failure_cnt++;
std::cerr << filename << std::endl;
}
if (total_cnt % 1000 == 0)
std::cerr << "# failure = " << failure_cnt << "/" << total_cnt << " ("
<< 100.0 * failure_cnt / total_cnt << " %)" << std::endl;
}
curr++;
}
};

const auto thread_count = std::thread::hardware_concurrency();
std::vector<std::thread> threads;
std::vector<std::pair<std::string, std::string>> jsons;
std::string json_path = std::string(TEST_RESOURCES_DIR) + "/json";

auto Run = [&]() {
const int json_size = jsons.size();
const int size_per = json_size / thread_count;
for (int i = 0; i < thread_count; ++i) {
const int start_ix = i * size_per;
const int end_ix =
(i == thread_count - 1) ? json_size : (i + 1) * size_per;
threads.emplace_back(Check, start_ix, end_ix, jsons);
}
for (auto &t : threads) t.join();
threads.clear();
jsons.clear();
};

if (!json_path.empty())
for (const auto &filename :
std::filesystem::directory_iterator(json_path)) {
std::ifstream ifs(filename.path().string(), std::ios::in);
while (!ifs.eof()) {
std::string json;
std::getline(ifs, json);
if (json.empty()) continue;
jsons.emplace_back(std::move(json), filename.path().string());
}
if (jsons.size() > 1000) Run();
}
Run();

std::cerr << "# failure = " << failure_cnt << "/" << total_cnt << " ("
<< 100.0 * failure_cnt / total_cnt << " %)" << std::endl;
return failure_cnt == 0;
}
6 changes: 6 additions & 0 deletions tests_py/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ def test_from_proto():
obs = mjx.Observation(json_str)
proto_obs = obs.to_proto()
mjx.Observation.from_proto(proto_obs).to_json() == json_str


def test_add_legal_actions():
json_wo_legal_actions = '{"publicObservation":{"playerIds":["player_2","player_1","player_0","player_3"],"initScore":{"tens":[25000,25000,25000,25000]},"doraIndicators":[101],"events":[{"type":"EVENT_TYPE_DRAW"}]},"privateObservation":{"initHand":{"closedTiles":[24,3,87,124,37,42,58,134,92,82,122,18,117]},"drawHistory":[79],"currHand":{"closedTiles":[3,18,24,37,42,58,79,82,87,92,117,122,124,134]}}}'
json_restored = mjx.Observation.add_legal_actions(json_wo_legal_actions)
assert json_str == json_restored

0 comments on commit 13eb4a6

Please sign in to comment.