Skip to content

Commit

Permalink
factor out custom max_element
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro committed Dec 2, 2024
1 parent 1d6bc2a commit 62c4814
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions include/dlaf/auxiliary/norm/mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@

namespace dlaf::auxiliary::internal {

template <class T>
T max_element(std::vector<T>&& values) {
DLAF_ASSERT(!values.empty(), "");
return *std::max_element(values.begin(), values.end());
}

template <class T>
pika::execution::experimental::unique_any_sender<T> reduce_in_place(
pika::execution::experimental::unique_any_sender<dlaf::comm::CommunicatorPipelineExclusiveWrapper>
Expand Down Expand Up @@ -102,17 +108,12 @@ pika::execution::experimental::unique_any_sender<dlaf::BaseType<T>> Norm<

// then it is necessary to reduce max values from all ranks into a single max value for the matrix

auto max_element = [](std::vector<NormT>&& values) {
DLAF_ASSERT(!values.empty(), "");
return *std::max_element(values.begin(), values.end());
};

ex::unique_any_sender<NormT> local_max_value = ex::just(NormT{0});
if (!tiles_max.empty())
local_max_value =
ex::when_all_vector(std::move(tiles_max)) |
dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(thread_stacksize::nostack),
std::move(max_element));
max_element<NormT>);

return reduce_in_place(comm_grid.full_communicator_pipeline().exclusive(),
comm_grid.rankFullCommunicator(rank), MPI_MAX, std::move(local_max_value));
Expand Down Expand Up @@ -149,16 +150,12 @@ pika::execution::experimental::unique_any_sender<dlaf::BaseType<T>> Norm<

// then it is necessary to reduce max values from all ranks into a single max value for the matrix

auto max_element = [](std::vector<NormT>&& values) {
DLAF_ASSERT(!values.empty(), "");
return *std::max_element(values.begin(), values.end());
};
ex::unique_any_sender<NormT> local_max_value = ex::just(NormT{0});
if (!tiles_max.empty())
local_max_value =
ex::when_all_vector(std::move(tiles_max)) |
dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(thread_stacksize::nostack),
std::move(max_element));
max_element<NormT>);

return reduce_in_place(comm_grid.full_communicator_pipeline().exclusive(),
comm_grid.rankFullCommunicator(rank), MPI_MAX, std::move(local_max_value));
Expand Down

0 comments on commit 62c4814

Please sign in to comment.