diff --git a/include/hspp.h b/include/hspp.h index a7b59a3..c8bfc6a 100644 --- a/include/hspp.h +++ b/include/hspp.h @@ -4129,6 +4129,19 @@ constexpr auto apply = toGFunc<1> | [](auto p) } // namespace parser +// For io +constexpr auto mapM_ = toGFunc<2> | [](auto func, auto lst) +{ + return data::io([=] + { + for (auto e : lst) + { + (e >>= func).run(); + } + return _o_; + }); +}; + } // namespace hspp #endif // HSPP_PARSER_H diff --git a/test/hspp/stm.cpp b/test/hspp/stm.cpp index e9c9390..97db2d1 100644 --- a/test/hspp/stm.cpp +++ b/test/hspp/stm.cpp @@ -419,7 +419,6 @@ struct TVar IORef>> waitQueue; }; -template struct RSE { ID id; @@ -428,6 +427,12 @@ struct RSE IORef>> waitQueue; }; +bool operator<(RSE const& lhs, RSE const& rhs) +{ + auto result = (lhs.id rhs.id); + return result == Ordering::kLT; +} + template struct WSE { @@ -443,15 +448,19 @@ constexpr auto toWSE = toGFunc<5> | [](Lock lock, IORef writeStamp, aut return WSE{lock, writeStamp, content, waitQueue, newValue}; }; -// optimize me later class ReadSet { - using T = std::map; + using T = std::set; public: std::shared_ptr data = std::make_shared(); }; -using WriteSet = ReadSet; +class WriteSet +{ + using T = std::map; +public: + std::shared_ptr data = std::make_shared(); +}; using TId = std::thread::id; using Stamp = Integer; @@ -474,7 +483,8 @@ IORef globalClock{initIORef(1)}; constexpr auto readIORef = toGFunc<1> | [](auto const& ioRef) { - return io([&ioRef]{ + return io([&ioRef] + { return ioRef.data->load(); }); }; @@ -583,7 +593,14 @@ constexpr auto putWS = toFunc<> | [](WriteSet ws, ID id, std::any ptr) }); }; -constexpr auto putRS = putWS; +constexpr auto putRS = toFunc<> | [](ReadSet rs, RSE entry) +{ + return io([=] + { + rs.data->insert(entry); + return _o_; + }); +}; constexpr auto lookUpWS = toFunc<> | [](WriteSet ws, ID id) { @@ -730,9 +747,15 @@ class MonadBase namespace concurrent { +constexpr auto myTId = io([] +{ + return 2 * std::hash{}(std::this_thread::get_id()); +}); + constexpr auto newTState = io([] { - return TState{std::this_thread::get_id(), {}, {}, {}}; + auto const readStamp = readIORef(globalClock).run(); + return TState{std::this_thread::get_id(), readStamp, {}, {}}; }); template @@ -747,10 +770,13 @@ auto atomicallyImpl(STM const& stmac) return std::visit(overload( [=](Valid const& v_) -> A { - (void)v_; - // auto [nts, a] = v_; - // auto transid = nts.transId; - // auto writeSet = nts.writeSet; + auto [nts, a] = v_; + auto ti = myTId.run(); + assert(ti == (nts.transId)); + (void)ti; + (void)a; + // auto ws = nts.writeSet; + // tup = getLocks | nts.transId | wslist; return A{}; }, [=](Retry const&) diff --git a/test/hspp/test.cpp b/test/hspp/test.cpp index badb9a2..287218a 100644 --- a/test/hspp/test.cpp +++ b/test/hspp/test.cpp @@ -1474,3 +1474,20 @@ TEST(do_, comprehension3) auto const expected = std::vector>{ { 3, 4, 5 }, { 6, 8, 10 }, { 5, 12, 13 }, { 9, 12, 15 }, { 8, 15, 17 } }; EXPECT_EQ(result, expected); } + +TEST(MapM_, IO) +{ + testing::internal::CaptureStdout(); + + const auto lst = std::vector{ioData("3"s), ioData("4"s)}; + const auto func = putStrLn; + const auto mapM_result = mapM_ | func | lst; + + std::string output0 = testing::internal::GetCapturedStdout(); + EXPECT_EQ(output0, ""); + + testing::internal::CaptureStdout(); + EXPECT_EQ(mapM_result.run(), _o_); + std::string output1 = testing::internal::GetCapturedStdout(); + EXPECT_EQ(output1, "3\n4\n"); +}