This repository has been archived by the owner on Sep 15, 2021. It is now read-only.

fix: fix scatter and reduce_scatter implementation #40

Merged — 1 commit, merged Jul 16, 2021
10 changes: 5 additions & 5 deletions bagua-core-internal/src/communicators/mod.rs
@@ -820,13 +820,13 @@ impl BaguaCommunicatorInner {
             cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", src as "int", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
             {
                 if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
-                    Al::Gather<Al::NCCLBackend>(static_cast<float*>(send_ptr), static_cast<float*>(recv_ptr), count, src, *communicator_ptr);
+                    Al::Scatter<Al::NCCLBackend>(static_cast<float*>(send_ptr), static_cast<float*>(recv_ptr), count, src, *communicator_ptr);
                 } else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
-                    Al::Gather<Al::NCCLBackend>(static_cast<__half*>(send_ptr), static_cast<__half*>(recv_ptr), count, src, *communicator_ptr);
+                    Al::Scatter<Al::NCCLBackend>(static_cast<__half*>(send_ptr), static_cast<__half*>(recv_ptr), count, src, *communicator_ptr);
                 } else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
-                    Al::Gather<Al::NCCLBackend>(static_cast<unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, src, *communicator_ptr);
+                    Al::Scatter<Al::NCCLBackend>(static_cast<unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, src, *communicator_ptr);
                 } else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
-                    Al::Gather<Al::NCCLBackend>(static_cast<long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, src, *communicator_ptr);
+                    Al::Scatter<Al::NCCLBackend>(static_cast<long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, src, *communicator_ptr);
                 } else {
                     fputs("unsupport tensor data type.\n", stderr);
                     abort();
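The hunk above is the scatter path: it previously called Al::Gather, which compiles because Aluminum's Gather and Scatter share the same (send_ptr, recv_ptr, count, root, comm) argument shape (visible in the diff itself), but moves data in the opposite direction. For reference, a minimal single-process sketch of MPI-style scatter semantics; the 4-rank setup and values are hypothetical illustration, not Bagua code:

```cpp
#include <cstdio>
#include <vector>

// Sketch: what Scatter does with (send_ptr, recv_ptr, count, src).
// The root (src) holds nranks * count elements; rank i receives the
// count-element block at offset i * count. Gather is the inverse,
// collecting one block from every rank into the root -- hence the
// one-word fix in each branch above.
int main() {
    const int nranks = 4, count = 2;
    // Root's send buffer: one block per destination rank.
    std::vector<float> send_on_root = {0, 1, 10, 11, 20, 21, 30, 31};
    for (int rank = 0; rank < nranks; ++rank) {
        std::vector<float> recv(send_on_root.begin() + rank * count,
                                send_on_root.begin() + (rank + 1) * count);
        std::printf("rank %d receives: %g %g\n", rank, recv[0], recv[1]);
    }
    return 0;
}
```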
@@ -906,7 +906,7 @@ impl BaguaCommunicatorInner {
         );
         let communicator_ptr = self.comm_ptr;
         let tensor_ptr = tensor.data_ptr();
-        let count = tensor.num_elements_allocated();
+        let count = tensor.num_elements_allocated() / self.nranks;
         let nccl_tensor_type = tensor.dtype().to_nccl_datatype();
         unsafe {
             cpp::cpp!([tensor_ptr as "void *", count as "size_t", op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
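The second hunk corrects the element count for reduce_scatter: each rank contributes the whole tensor but receives only the reduced 1/nranks slice, so the backend expects the per-rank receive count rather than the total. NCCL, which Aluminum's NCCLBackend wraps, documents this in its reduce-scatter signature; a short sketch of the arithmetic follows (the helper function is illustrative, not part of the patch):

```cpp
#include <cstddef>

// NCCL's reduce-scatter takes the *per-rank* receive count, not the total:
//   ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
//                                  size_t recvcount,  // elements per rank
//                                  ncclDataType_t datatype, ncclRedOp_t op,
//                                  ncclComm_t comm, cudaStream_t stream);
//
// Illustrative helper mirroring the corrected count in the Rust hunk above.
// Passing the total (the old code) would make every rank consume nranks
// times too many elements from the send buffer and overrun its output slice.
size_t reduce_scatter_recvcount(size_t total_elements, size_t nranks) {
    // Assumes total_elements % nranks == 0, i.e. the tensor divides
    // evenly across ranks.
    return total_elements / nranks;
}
```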