diff --git a/bagua-core-internal/src/communicators/mod.rs b/bagua-core-internal/src/communicators/mod.rs
index 2b0f3ae..229f0ec 100644
--- a/bagua-core-internal/src/communicators/mod.rs
+++ b/bagua-core-internal/src/communicators/mod.rs
@@ -820,13 +820,13 @@ impl BaguaCommunicatorInner {
             cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", src as "int", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
             {
                 if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
-                    Al::Gather<Al::NCCLBackend>(static_cast<float*>(send_ptr), static_cast<float*>(recv_ptr), count, src, *communicator_ptr);
+                    Al::Scatter<Al::NCCLBackend>(static_cast<float*>(send_ptr), static_cast<float*>(recv_ptr), count, src, *communicator_ptr);
                 } else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
-                    Al::Gather<Al::NCCLBackend>(static_cast<__half*>(send_ptr), static_cast<__half*>(recv_ptr), count, src, *communicator_ptr);
+                    Al::Scatter<Al::NCCLBackend>(static_cast<__half*>(send_ptr), static_cast<__half*>(recv_ptr), count, src, *communicator_ptr);
                 } else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
-                    Al::Gather<Al::NCCLBackend>(static_cast<unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, src, *communicator_ptr);
+                    Al::Scatter<Al::NCCLBackend>(static_cast<unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, src, *communicator_ptr);
                 } else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
-                    Al::Gather<Al::NCCLBackend>(static_cast<long long*>(send_ptr), static_cast<long long*>(recv_ptr), count, src, *communicator_ptr);
+                    Al::Scatter<Al::NCCLBackend>(static_cast<long long*>(send_ptr), static_cast<long long*>(recv_ptr), count, src, *communicator_ptr);
                 } else {
                     fputs("unsupport tensor data type.\n", stderr);
                     abort();
@@ -906,7 +906,7 @@ impl BaguaCommunicatorInner {
         );
         let communicator_ptr = self.comm_ptr;
        let tensor_ptr = tensor.data_ptr();
-        let count = tensor.num_elements_allocated();
+        let count = tensor.num_elements_allocated() / self.nranks;
         let nccl_tensor_type = tensor.dtype().to_nccl_datatype();
         unsafe {
             cpp::cpp!([tensor_ptr as "void *", count as "size_t", op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]