Skip to content

Commit

Permalink
fixup! [dist-mat] use row-gatherer
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Oct 8, 2024
1 parent 341e781 commit b2025a8
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,33 +394,38 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::col_scale(
auto comm = this->get_communicator();
size_type n_local_cols = local_mtx_->get_size()[1];
size_type n_non_local_cols = non_local_mtx_->get_size()[1];

std::unique_ptr<global_vector_type> scaling_factors_single_stride;
auto stride = scaling_factors->get_stride();
if (stride != 1) {
auto scaling_stride = scaling_factors->get_stride();
if (scaling_stride != 1) {
scaling_factors_single_stride = global_vector_type::create(exec, comm);
scaling_factors_single_stride->copy_from(scaling_factors.get());
}
const auto scale_values =
stride == 1 ? scaling_factors->get_const_local_values()
: scaling_factors_single_stride->get_const_local_values();
const global_vector_type* scaling_factors_ptr =
scaling_stride == 1 ? scaling_factors.get()
: scaling_factors_single_stride.get();
const auto scale_diag = gko::matrix::Diagonal<ValueType>::create_const(
exec, n_local_cols,
make_const_array_view(exec, n_local_cols, scale_values));

auto req = this->communicate(
stride == 1 ? scaling_factors->get_local_vector()
: scaling_factors_single_stride->get_local_vector());
make_const_array_view(exec, n_local_cols,
scaling_factors_ptr->get_const_local_values()));

auto recv_dim = dim<2>{
static_cast<size_type>(
row_gatherer_->get_collective_communicator()->get_recv_size()),
scaling_factors->get_size()[1]};
auto recv_exec =
mpi::requires_host_buffer(exec, comm) ? exec->get_master() : exec;
recv_buffer_.init(recv_exec, recv_dim);

auto req =
row_gatherer_->apply_async(scaling_factors_ptr, recv_buffer_.get());
scale_diag->rapply(local_mtx_, local_mtx_);
req.wait();
if (n_non_local_cols > 0) {
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
const auto non_local_scale_diag =
gko::matrix::Diagonal<ValueType>::create_const(
exec, n_non_local_cols,
make_const_array_view(exec, n_non_local_cols,
make_const_array_view(recv_exec, n_non_local_cols,
recv_buffer_->get_const_values()));
non_local_scale_diag->rapply(non_local_mtx_, non_local_mtx_);
}
Expand Down

0 comments on commit b2025a8

Please sign in to comment.