Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use stackless threads where appropriate #1037

Merged
merged 9 commits into from
Nov 23, 2023
42 changes: 27 additions & 15 deletions include/dlaf/eigensolver/band_to_tridiag/mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,15 @@ class BandBlock {
template <Device D, class Sender>
auto copy_diag(SizeType j, Sender source) noexcept {
using dlaf::internal::transform;
using pika::execution::thread_priority;
using pika::execution::thread_stacksize;
namespace ex = pika::execution::experimental;

constexpr auto B = dlaf::matrix::internal::CopyBackend_v<D, Device::CPU>;

if constexpr (D == Device::CPU) {
return transform(
dlaf::internal::Policy<B>(pika::execution::thread_priority::high),
dlaf::internal::Policy<B>(thread_priority::high, thread_stacksize::nostack),
[j, this](const matrix::Tile<const T, D>& source) {
constexpr auto General = blas::Uplo::General;
constexpr auto Lower = blas::Uplo::Lower;
Expand Down Expand Up @@ -269,7 +271,7 @@ class BandBlock {
else if constexpr (D == Device::GPU) {
DLAF_ASSERT_HEAVY(is_accessible_from_GPU(), "BandBlock memory should be accessible from GPU");
return transform(
dlaf::internal::Policy<B>(pika::execution::thread_priority::high),
dlaf::internal::Policy<B>(thread_priority::high),
[j, this](const matrix::Tile<const T, D>& source, whip::stream_t stream) {
constexpr auto General = blas::Uplo::General;
constexpr auto Lower = blas::Uplo::Lower;
Expand Down Expand Up @@ -306,14 +308,15 @@ class BandBlock {
template <Device D, class Sender>
auto copy_off_diag(const SizeType j, Sender source) noexcept {
using dlaf::internal::transform;

using pika::execution::thread_priority;
using pika::execution::thread_stacksize;
namespace ex = pika::execution::experimental;

constexpr auto B = dlaf::matrix::internal::CopyBackend_v<D, Device::CPU>;

if constexpr (D == Device::CPU) {
return transform(
dlaf::internal::Policy<B>(pika::execution::thread_priority::high),
dlaf::internal::Policy<B>(thread_priority::high, thread_stacksize::nostack),
[j, this](const matrix::Tile<const T, D>& source) {
constexpr auto General = blas::Uplo::General;
constexpr auto Upper = blas::Uplo::Upper;
Expand Down Expand Up @@ -346,7 +349,7 @@ class BandBlock {
else if constexpr (D == Device::GPU) {
DLAF_ASSERT_HEAVY(is_accessible_from_GPU(), "BandBlock memory should be accessible from GPU");
return transform(
dlaf::internal::Policy<B>(pika::execution::thread_priority::high),
dlaf::internal::Policy<B>(thread_priority::high),
[j, this](const matrix::Tile<const T, D>& source, whip::stream_t stream) {
constexpr auto General = blas::Uplo::General;
constexpr auto Upper = blas::Uplo::Upper;
Expand Down Expand Up @@ -682,14 +685,18 @@ TridiagResult<T, Device::CPU> BandToTridiag<Backend::MC, D, T>::call_L(
using common::internal::vector;
using util::ceilDiv;

using pika::execution::thread_priority;
using pika::execution::thread_stacksize;
using pika::resource::get_num_threads;
using SemaphorePtr = std::shared_ptr<pika::counting_semaphore<>>;
using TileVector = std::vector<matrix::Tile<T, Device::CPU>>;
using TileVectorPtr = std::shared_ptr<TileVector>;

namespace ex = pika::execution::experimental;

const auto policy_hp = dlaf::internal::Policy<Backend::MC>(pika::execution::thread_priority::high);
const auto policy_hp = dlaf::internal::Policy<Backend::MC>(thread_priority::high);
const auto policy_hp_nostack =
dlaf::internal::Policy<Backend::MC>(thread_priority::high, thread_stacksize::nostack);

// note: A is square and has square blocksize
const SizeType size = mat_a.size().cols();
Expand Down Expand Up @@ -825,12 +832,13 @@ TridiagResult<T, Device::CPU> BandToTridiag<Backend::MC, D, T>::call_L(
sem = std::move(sem_next);
}

auto copy_tridiag = [policy_hp, a_ws, size, nb, &mat_trid, copy_tridiag_task](SizeType i, auto&& dep) {
auto copy_tridiag = [policy_hp_nostack, a_ws, size, nb, &mat_trid, copy_tridiag_task](SizeType i,
auto&& dep) {
const auto tile_index = (i - 1) / nb;
const auto start = tile_index * nb;
ex::when_all(ex::just(start, std::min(nb, size - start), std::min(nb, size - 1 - start)),
mat_trid.readwrite(GlobalTileIndex{tile_index, 0}), std::forward<decltype(dep)>(dep)) |
dlaf::internal::transformDetach(policy_hp, copy_tridiag_task);
dlaf::internal::transformDetach(policy_hp_nostack, copy_tridiag_task);
};

auto dep = ex::just(std::move(sem)) |
Expand Down Expand Up @@ -1011,6 +1019,8 @@ TridiagResult<T, Device::CPU> BandToTridiag<Backend::MC, D, T>::call_L(
using matrix::internal::CopyBackend_v;
using util::ceilDiv;

using pika::execution::thread_priority;
using pika::execution::thread_stacksize;
using pika::resource::get_num_threads;
using SemaphorePtr = std::shared_ptr<pika::counting_semaphore<>>;
using Tile = matrix::Tile<T, Device::CPU>;
Expand Down Expand Up @@ -1045,7 +1055,9 @@ TridiagResult<T, Device::CPU> BandToTridiag<Backend::MC, D, T>::call_L(
const auto prev_rank = (rank == 0 ? ranks - 1 : rank - 1);
const auto next_rank = (rank + 1 == ranks ? 0 : rank + 1);

auto policy_hp = dlaf::internal::Policy<Backend::MC>(pika::execution::thread_priority::high);
auto policy_hp = dlaf::internal::Policy<Backend::MC>(thread_priority::high);
auto policy_hp_nostack =
dlaf::internal::Policy<Backend::MC>(thread_priority::high, thread_stacksize::nostack);

const SizeType nb_band = get1DBlockSize(nb);
const SizeType tiles_per_block = nb_band / nb;
Expand Down Expand Up @@ -1194,10 +1206,10 @@ TridiagResult<T, Device::CPU> BandToTridiag<Backend::MC, D, T>::call_L(
nr_release = ceilDiv(size - k * nb, b) + 1;
}

prev_dep =
ex::when_all(ex::just(nr_release, sems[k_block_local]), std::move(prev_dep),
std::move(dep)) |
dlaf::internal::transform(policy_hp, [](SizeType nr, auto&& sem) { sem->release(nr); });
prev_dep = ex::when_all(ex::just(nr_release, sems[k_block_local]), std::move(prev_dep),
std::move(dep)) |
dlaf::internal::transform(policy_hp_nostack,
[](SizeType nr, auto&& sem) { sem->release(nr); });
}
else {
if (rank == rank_diag) {
Expand Down Expand Up @@ -1469,15 +1481,15 @@ TridiagResult<T, Device::CPU> BandToTridiag<Backend::MC, D, T>::call_L(

// Rank 0 (owner of the first band matrix block) copies the last parts of the tridiag matrix.
if (rank == 0) {
auto copy_tridiag = [policy_hp, size, nb, &mat_trid, &copy_tridiag_task](
auto copy_tridiag = [policy_hp_nostack, size, nb, &mat_trid, &copy_tridiag_task](
std::shared_ptr<BandBlock<T, true>> a_block, SizeType i, auto&& dep) {
const auto tile_index = (i - 1) / nb;
const auto start = tile_index * nb;
ex::when_all(ex::just(std::move(a_block), start, std::min(nb, size - start),
std::min(nb, size - 1 - start)),
mat_trid.readwrite(GlobalTileIndex{tile_index, 0}),
std::forward<decltype(dep)>(dep)) |
dlaf::internal::transformDetach(policy_hp, copy_tridiag_task);
dlaf::internal::transformDetach(policy_hp_nostack, copy_tridiag_task);
};

auto dep = ex::just(std::move(sems[0])) |
Expand Down
30 changes: 21 additions & 9 deletions include/dlaf/eigensolver/bt_band_to_tridiag/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,13 +510,15 @@ struct HHManager<Backend::MC, Device::CPU, T> {
auto computeVW(const SizeType nb_apply, const LocalTileIndex ij, const TileAccessHelper& helper,
SenderHH&& tile_hh, matrix::Panel<Coord::Col, T, D>& mat_v,
matrix::Panel<Coord::Col, T, D>& mat_t, matrix::Panel<Coord::Col, T, D>& mat_w) {
using pika::execution::thread_stacksize;
namespace ex = pika::execution::experimental;

return dlaf::internal::whenAllLift(b, std::forward<SenderHH>(tile_hh), nb_apply,
splitTile(mat_v.readwrite(ij), helper.specHH()),
splitTile(mat_t.readwrite(ij), helper.specT()),
splitTile(mat_w.readwrite(ij), helper.specHH())) |
dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(), bt_tridiag::computeVW<T>) |
dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(thread_stacksize::nostack),
bt_tridiag::computeVW<T>) |
ex::split_tuple();
}

Expand All @@ -538,6 +540,7 @@ struct HHManager<Backend::GPU, Device::GPU, T> {
auto computeVW(const SizeType hhr_nb, const LocalTileIndex ij, const TileAccessHelper& helper,
SenderHH&& tile_hh, matrix::Panel<Coord::Col, T, D>& mat_v,
matrix::Panel<Coord::Col, T, D>& mat_t, matrix::Panel<Coord::Col, T, D>& mat_w) {
using pika::execution::thread_stacksize;
namespace ex = pika::execution::experimental;

auto& mat_v_h = w_panels_h.nextResource();
Expand All @@ -550,7 +553,8 @@ struct HHManager<Backend::GPU, Device::GPU, T> {
dlaf::internal::whenAllLift(b, std::forward<SenderHH>(tile_hh), hhr_nb,
splitTile(mat_v_h.readwrite(ij), helper.specHH()),
splitTile(mat_t_h.readwrite(ij_t), t_spec)) |
dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(), computeVT<T>) |
dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(thread_stacksize::nostack),
computeVT<T>) |
ex::split_tuple();

auto copyVTandComputeW =
Expand Down Expand Up @@ -585,7 +589,7 @@ struct HHManager<Backend::GPU, Device::GPU, T> {
splitTile(mat_t.readwrite(ij_t), t_spec),
splitTile(mat_w.readwrite(ij), helper.specHH())) |
dlaf::internal::transform<dlaf::internal::TransformDispatchType::Blas>(
dlaf::internal::Policy<Backend::GPU>(), copyVTandComputeW) |
dlaf::internal::Policy<Backend::GPU>(thread_stacksize::nostack), copyVTandComputeW) |
ex::split_tuple();
}

Expand All @@ -601,6 +605,7 @@ template <Backend B, Device D, class T>
void BackTransformationT2B<B, D, T>::call(const SizeType band_size, Matrix<T, D>& mat_e,
Matrix<const T, Device::CPU>& mat_hh) {
using pika::execution::thread_priority;
using pika::execution::thread_stacksize;
namespace ex = pika::execution::experimental;

using common::iterate_range2d;
Expand Down Expand Up @@ -702,15 +707,17 @@ void BackTransformationT2B<B, D, T>::call(const SizeType band_size, Matrix<T, D>
ex::when_all(ex::just(group_size), tile_v, tile_w,
mat_w2.readwrite(LocalTileIndex(0, j_e)), mat_e_rt.readwrite(idx_e)) |
dlaf::internal::transform<dlaf::internal::TransformDispatchType::Blas>(
dlaf::internal::Policy<B>(thread_priority::normal), ApplyHHToSingleTileRow<B, T>{}));
dlaf::internal::Policy<B>(thread_priority::normal, thread_stacksize::nostack),
ApplyHHToSingleTileRow<B, T>{}));
}
else {
ex::start_detached(
ex::when_all(ex::just(group_size), tile_v, tile_w,
mat_w2.readwrite(LocalTileIndex(0, j_e)), mat_e_rt.readwrite(idx_e),
mat_e_rt.readwrite(helper.bottomIndexE(j_e))) |
dlaf::internal::transform<dlaf::internal::TransformDispatchType::Blas>(
dlaf::internal::Policy<B>(thread_priority::normal), ApplyHHToDoubleTileRow<B, T>{}));
dlaf::internal::Policy<B>(thread_priority::normal, thread_stacksize::nostack),
ApplyHHToDoubleTileRow<B, T>{}));
}
}

Expand All @@ -726,6 +733,7 @@ template <Backend B, Device D, class T>
void BackTransformationT2B<B, D, T>::call(comm::CommunicatorGrid grid, const SizeType band_size,
Matrix<T, D>& mat_e, Matrix<const T, Device::CPU>& mat_hh) {
using pika::execution::thread_priority;
using pika::execution::thread_stacksize;
namespace ex = pika::execution::experimental;

using common::iterate_range2d;
Expand Down Expand Up @@ -940,7 +948,8 @@ void BackTransformationT2B<B, D, T>::call(comm::CommunicatorGrid grid, const Siz
splitTile(mat_w2.readwrite(idx_w2), helper.specW2(nb)),
mat_e_rt.readwrite(idx_e_top)) |
dlaf::internal::transform<dlaf::internal::TransformDispatchType::Blas>(
dlaf::internal::Policy<B>(thread_priority::normal),
dlaf::internal::Policy<B>(thread_priority::normal,
thread_stacksize::nostack),
ApplyHHToSingleTileRow<B, T>{}));
}
// TWO ROWs
Expand All @@ -954,7 +963,8 @@ void BackTransformationT2B<B, D, T>::call(comm::CommunicatorGrid grid, const Siz
splitTile(mat_w2.readwrite(idx_w2), helper.specW2(nb)),
mat_e_rt.readwrite(idx_e_top), mat_e_rt.readwrite(idx_e_bottom)) |
dlaf::internal::transform<dlaf::internal::TransformDispatchType::Blas>(
dlaf::internal::Policy<B>(thread_priority::normal), ApplyHHToDoubleTileRow<B, T>{}));
dlaf::internal::Policy<B>(thread_priority::normal, thread_stacksize::nostack),
ApplyHHToDoubleTileRow<B, T>{}));
}
// TWO ROWs TWO RANKs UPDATE (MAIN + PARTNER)
else {
Expand All @@ -980,7 +990,8 @@ void BackTransformationT2B<B, D, T>::call(comm::CommunicatorGrid grid, const Siz
dlaf::internal::whenAllLift(blas::Op::ConjTrans, blas::Op::NoTrans, T(1),
std::move(subtile_v), std::move(subtile_e_ro), T(0),
splitTile(mat_w2tmp.readwrite(idx_w2), helper.specW2(nb))) |
dlaf::tile::gemm(dlaf::internal::Policy<B>(thread_priority::normal)));
dlaf::tile::gemm(dlaf::internal::Policy<B>(thread_priority::normal,
thread_stacksize::nostack)));

// Compute final W2 by adding the contribution from the partner rank
ex::start_detached( //
Expand All @@ -995,7 +1006,8 @@ void BackTransformationT2B<B, D, T>::call(comm::CommunicatorGrid grid, const Siz
std::move(subtile_w),
splitTile(mat_w2.read(idx_w2), helper.specW2(nb)), T(1),
std::move(subtile_e)) |
dlaf::tile::gemm(dlaf::internal::Policy<B>(thread_priority::normal)));
dlaf::tile::gemm(dlaf::internal::Policy<B>(thread_priority::normal,
thread_stacksize::nostack)));
}
}
}
Expand Down
13 changes: 9 additions & 4 deletions include/dlaf/eigensolver/bt_reduction_to_band/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ struct Helpers<Backend::GPU> {

template <Backend backend, typename SrcSender, typename DstSender>
void copyAndSetHHUpperTiles(SizeType j_diag, SrcSender&& src, DstSender&& dst) {
using pika::execution::thread_priority;
using pika::execution::thread_stacksize;
namespace ex = pika::execution::experimental;
using ElementType = dlaf::internal::SenderElementType<DstSender>;

ex::start_detached(dlaf::internal::transform(
dlaf::internal::Policy<backend>(pika::execution::thread_priority::high),
dlaf::internal::Policy<backend>(thread_priority::high, thread_stacksize::nostack),
Helpers<backend>::template copyAndSetHHUpperTiles<ElementType>,
dlaf::internal::whenAllLift(j_diag, std::forward<SrcSender>(src), std::forward<DstSender>(dst))));
}
Expand All @@ -83,36 +85,39 @@ template <Backend backend, class TSender, class SourcePanelSender, class PanelTi
void trmmPanel(pika::execution::thread_priority priority, TSender&& t, SourcePanelSender&& v,
PanelTileSender&& w) {
using ElementType = dlaf::internal::SenderElementType<PanelTileSender>;
using pika::execution::thread_stacksize;

pika::execution::experimental::start_detached(
dlaf::internal::whenAllLift(blas::Side::Right, blas::Uplo::Upper, blas::Op::ConjTrans,
blas::Diag::NonUnit, ElementType(1.0), std::forward<TSender>(t),
std::forward<SourcePanelSender>(v), std::forward<PanelTileSender>(w)) |
tile::trmm3(dlaf::internal::Policy<backend>(priority)));
tile::trmm3(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
}

template <Backend backend, class PanelTileSender, class MatrixTileSender, class ColPanelSender>
void gemmUpdateW2(pika::execution::thread_priority priority, PanelTileSender&& w, MatrixTileSender&& c,
ColPanelSender&& w2) {
using ElementType = dlaf::internal::SenderElementType<PanelTileSender>;
using pika::execution::thread_stacksize;

pika::execution::experimental::start_detached(
dlaf::internal::whenAllLift(blas::Op::ConjTrans, blas::Op::NoTrans, ElementType(1.0),
std::forward<PanelTileSender>(w), std::forward<MatrixTileSender>(c),
ElementType(1.0), std::forward<ColPanelSender>(w2)) |
tile::gemm(dlaf::internal::Policy<backend>(priority)));
tile::gemm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
}

template <Backend backend, class PanelTileSender, class ColPanelSender, class MatrixTileSender>
void gemmTrailingMatrix(pika::execution::thread_priority priority, PanelTileSender&& v,
ColPanelSender&& w2, MatrixTileSender&& c) {
using ElementType = dlaf::internal::SenderElementType<PanelTileSender>;
using pika::execution::thread_stacksize;

pika::execution::experimental::start_detached(
dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::NoTrans, ElementType(-1.0),
std::forward<PanelTileSender>(v), std::forward<ColPanelSender>(w2),
ElementType(1.0), std::forward<MatrixTileSender>(c)) |
tile::gemm(dlaf::internal::Policy<backend>(priority)));
tile::gemm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
}
}

Expand Down
Loading
Loading