This repository has been archived by the owner on Sep 15, 2021. It is now read-only.

feat: add scatter, gather, reduce_scatter communication primitives and their in-place versions #37

Merged · 2 commits · Jul 15, 2021
Changes from 1 commit
@@ -50,7 +50,7 @@ impl CommOpTrait for CentralizedFullPrecisionSynchronous {
temp_tensor.reduce_sum_inplace(c.nranks, c.rank, c.stream_ptr);
}
tracing::debug!("start allgather");
c.allgather(&temp_tensor, &mut t.raw);
c.allgather2(&temp_tensor, &mut t.raw);
tracing::debug!("internode communication done")
} else {
tracing::debug!("start allreduce");
@@ -73,7 +73,7 @@ impl CommOpTrait for CentralizedLowPrecisionSynchronous {
)
.expect("cannot compress tensor");
tracing::debug!("start allgather");
c.allgather(compressed_tensor.as_ref(), &mut temp_tensor);
c.allgather2(compressed_tensor.as_ref(), &mut temp_tensor);
tracing::debug!("start decompress");
t.raw.decompress_from(
&self.compression_method,
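Both centralized comm ops above switch from allgather to allgather2. Judging from the allgather2 implementation later in this diff, the send and receive tensors are expected to have the same allocated size (a multiple of nranks), and each rank transmits only its own chunk of the send tensor; the old allgather expected a chunk-sized send buffer instead. A rough layout sketch with nranks = 4 (illustrative only, names are hypothetical):

// send (full size): [ c0 | c1 | c2 | c3 ]   -- rank r's own chunk cr is the part that gets sent
// recv (full size): [ c0 | c1 | c2 | c3 ]   -- after the call, holds every rank's chunk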
bagua-core-internal/src/communicators/mod.rs (251 changes: 241 additions & 10 deletions)
@@ -125,6 +125,42 @@ impl BaguaSingleCommunicator {
);
}

pub fn allgather2(&self, send_tensor: &mut BaguaTensor, recv_tensor: &mut BaguaTensor) {
self.inner.allgather2(
send_tensor.inner.write().raw.as_mut(),
recv_tensor.inner.write().raw.as_mut(),
);
}

pub fn gather(&self, send_tensor: &mut BaguaTensor, recv_tensor: &mut BaguaTensor, dst: i32) {
self.inner.gather(
send_tensor.inner.write().raw.as_mut(),
recv_tensor.inner.write().raw.as_mut(),
dst,
);
}

pub fn scatter(&self, send_tensor: &mut BaguaTensor, recv_tensor: &mut BaguaTensor, src: i32) {
self.inner.scatter(
send_tensor.inner.write().raw.as_mut(),
recv_tensor.inner.write().raw.as_mut(),
src,
);
}

pub fn reduce_scatter(
&self,
send_tensor: &mut BaguaTensor,
recv_tensor: &mut BaguaTensor,
op: BaguaReductionOp,
) {
self.inner.reduce_scatter(
send_tensor.inner.write().raw.as_mut(),
recv_tensor.inner.write().raw.as_mut(),
op,
);
}

pub fn barrier(&self) {
self.inner.barrier();
}
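A minimal usage sketch of the wrapper methods added in this hunk (names are hypothetical; how the BaguaTensors are constructed is not part of this diff). The size relations in the comments follow from the assert_eq! checks in the BaguaCommunicatorInner implementations below.

// Sketch only; `send`, `recv`, `comm`, and `op` are illustrative, not taken from the PR.
// comm: BaguaSingleCommunicator, send/recv: BaguaTensor, op: BaguaReductionOp

// gather: every rank sends n elements; recv must hold n * nranks (result lands on rank dst).
comm.gather(&mut send, &mut recv, /* dst = */ 0);

// scatter: send holds n * nranks elements (read on rank src); every rank receives n elements.
comm.scatter(&mut send, &mut recv, /* src = */ 0);

// reduce_scatter: send holds n * nranks elements; each rank receives its reduced n-element chunk.
comm.reduce_scatter(&mut send, &mut recv, op);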
@@ -388,12 +424,16 @@ impl BaguaCommunicatorInner {

pub fn alltoall(&self, send_tensor: &dyn RawBaguaTensor, recv_tensor: &mut dyn RawBaguaTensor) {
let communicator_ptr = self.comm_ptr;
// TODO: also check recv buf?
assert_eq!(send_tensor.dtype(), recv_tensor.dtype());
assert_eq!(
send_tensor.num_elements_allocated() % self.nranks,
0,
"tensors must be aligned before using allscatter"
);
assert_eq!(
send_tensor.num_elements_allocated(),
recv_tensor.num_elements_allocated(),
);
let send_chunk_size = send_tensor.num_elements_allocated() / self.nranks;
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();

@@ -515,10 +555,58 @@ impl BaguaCommunicatorInner {
}
}

pub fn allgather_impl(
&self,
send_ptr: u64,
recv_ptr: u64,
count: usize,
communicator_ptr: u64,
nccl_tensor_type: i32,
) {
unsafe {
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Allgather<Al::NCCLBackend>(static_cast<float*>(send_ptr), static_cast<float*>(recv_ptr), count, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Allgather<Al::NCCLBackend>(static_cast<__half*>(send_ptr), static_cast<__half*>(recv_ptr), count, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Allgather<Al::NCCLBackend>(static_cast<unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Allgather<Al::NCCLBackend>(static_cast<long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
}

pub fn allgather(
&self,
send_tensor: &dyn RawBaguaTensor,
recv_tensor: &mut dyn RawBaguaTensor,
) {
let communicator_ptr = self.comm_ptr;
let send_ptr = send_tensor.data_ptr();
let recv_ptr = recv_tensor.data_ptr();
let count = send_tensor.num_elements_allocated();
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();
assert_eq!(count * self.nranks, recv_tensor.num_elements_allocated());
assert_eq!(send_tensor.dtype(), recv_tensor.dtype());
self.allgather_impl(
send_ptr,
recv_ptr,
count,
communicator_ptr,
nccl_tensor_type,
);
}

pub fn allgather2(
&self,
send_tensor: &dyn RawBaguaTensor,
recv_tensor: &mut dyn RawBaguaTensor,
) {
let communicator_ptr = self.comm_ptr;
let send_tensor_ptr = send_tensor.data_ptr();
@@ -532,24 +620,42 @@ impl BaguaCommunicatorInner {
0,
"tensors must be aligned before using allgather"
);
let send_chunk_size = send_tensor.num_elements_allocated() / self.nranks;
let count = send_tensor.num_elements_allocated() / self.nranks;
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();

let send_buf_ptr = send_tensor_ptr
+ self.rank as u64 * send_chunk_size as u64 * send_tensor.dtype().bytes() as u64;
let recv_buf_ptr = recv_tensor.data_ptr();
let send_ptr =
send_tensor_ptr + self.rank as u64 * count as u64 * send_tensor.dtype().bytes() as u64;
let recv_ptr = recv_tensor.data_ptr();

self.allgather_impl(
send_ptr,
recv_ptr,
count,
communicator_ptr,
nccl_tensor_type,
);
}
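For reference, the two variants implemented above differ in what the send tensor holds: allgather expects a chunk-sized send tensor (count elements, recv must be count * nranks), while allgather2 expects send and recv to have the same allocated size and transmits only this rank's chunk, offsetting the send pointer by rank * count * dtype bytes. A rough sketch, assuming 2 ranks and an 8-element receive buffer (illustrative only):

// allgather  : send = 4 elems (this rank's data)             -> recv = 8 elems
// allgather2 : send = 8 elems, rank r contributes [4r, 4r+4)  -> recv = 8 elems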

pub fn gather_impl(
&self,
send_ptr: u64,
recv_ptr: u64,
count: usize,
dst: i32,
communicator_ptr: u64,
nccl_tensor_type: i32,
) {
unsafe {
cpp::cpp!([recv_buf_ptr as "void *", send_buf_ptr as "void *", send_chunk_size as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", dst as "int", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Allgather<Al::NCCLBackend>(static_cast<float*>(send_buf_ptr), static_cast<float*>(recv_buf_ptr), send_chunk_size, *communicator_ptr);
Al::Gather<Al::NCCLBackend>(static_cast<float*>(send_ptr), static_cast<float*>(recv_ptr), count, dst, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Allgather<Al::NCCLBackend>(static_cast<__half*>(send_buf_ptr), static_cast<__half*>(recv_buf_ptr), send_chunk_size, *communicator_ptr);
Al::Gather<Al::NCCLBackend>(static_cast<__half*>(send_ptr), static_cast<__half*>(recv_ptr), count, dst, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Allgather<Al::NCCLBackend>(static_cast<unsigned char*>(send_buf_ptr), static_cast<unsigned char*>(recv_buf_ptr), send_chunk_size, *communicator_ptr);
Al::Gather<Al::NCCLBackend>(static_cast<unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, dst, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Allgather<Al::NCCLBackend>(static_cast<long long int*>(send_buf_ptr), static_cast<long long int*>(recv_buf_ptr), send_chunk_size, *communicator_ptr);
Al::Gather<Al::NCCLBackend>(static_cast<long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, dst, *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
@@ -558,6 +664,131 @@ impl BaguaCommunicatorInner {
}
}

pub fn gather(
&self,
send_tensor: &dyn RawBaguaTensor,
recv_tensor: &mut dyn RawBaguaTensor,
dst: i32,
) {
let communicator_ptr = self.comm_ptr;
let send_ptr = send_tensor.data_ptr();
let recv_ptr = recv_tensor.data_ptr();
let count = send_tensor.num_elements_allocated();
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();
assert_eq!(count * self.nranks, recv_tensor.num_elements_allocated());
assert_eq!(send_tensor.dtype(), recv_tensor.dtype());
self.gather_impl(
send_ptr,
recv_ptr,
count,
dst,
communicator_ptr,
nccl_tensor_type,
);
}

pub fn scatter_impl(
&self,
send_ptr: u64,
recv_ptr: u64,
count: usize,
src: i32,
communicator_ptr: u64,
nccl_tensor_type: i32,
) {
unsafe {
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", src as "int", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Gather<Al::NCCLBackend>(static_cast<float*>(send_ptr), static_cast<float*>(recv_ptr), count, src, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Gather<Al::NCCLBackend>(static_cast<__half*>(send_ptr), static_cast<__half*>(recv_ptr), count, src, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Gather<Al::NCCLBackend>(static_cast<unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, src, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Gather<Al::NCCLBackend>(static_cast<long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, src, *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
}

pub fn scatter(
&self,
send_tensor: &dyn RawBaguaTensor,
recv_tensor: &mut dyn RawBaguaTensor,
src: i32,
) {
let communicator_ptr = self.comm_ptr;
let send_ptr = send_tensor.data_ptr();
let recv_ptr = recv_tensor.data_ptr();
let count = recv_tensor.num_elements_allocated();
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();
assert_eq!(count * self.nranks, send_tensor.num_elements_allocated());
assert_eq!(send_tensor.dtype(), recv_tensor.dtype());
self.scatter_impl(
send_ptr,
recv_ptr,
count,
src,
communicator_ptr,
nccl_tensor_type,
);
}
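Note that scatter derives count from the receive tensor and asserts that the send tensor holds count * nranks elements; since the assertion runs on every rank, the full-size send buffer must be allocated everywhere even though only rank src supplies the data. Following the backend's usual scatter semantics, rank r should end up with chunk r of the root's send buffer. A rough sketch with hypothetical values:

// nranks = 4, count = n (illustrative only):
// send on rank src: [ c0 | c1 | c2 | c3 ]   (n elements per chunk)
// recv on rank r  : [ cr ]                  (n elements)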

pub fn reduce_scatter_impl(
&self,
send_ptr: u64,
recv_ptr: u64,
count: usize,
communicator_ptr: u64,
nccl_tensor_type: i32,
op: BaguaReductionOp,
) {
unsafe {
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const float*>(send_ptr), static_cast<float*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const __half*>(send_ptr), static_cast<__half*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
}

pub fn reduce_scatter(
&self,
send_tensor: &dyn RawBaguaTensor,
recv_tensor: &mut dyn RawBaguaTensor,
op: BaguaReductionOp,
) {
let communicator_ptr = self.comm_ptr;
let send_ptr = send_tensor.data_ptr();
let recv_ptr = recv_tensor.data_ptr();
let count = recv_tensor.num_elements_allocated();
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();
assert_eq!(count * self.nranks, send_tensor.num_elements_allocated());
assert_eq!(send_tensor.dtype(), recv_tensor.dtype());
self.reduce_scatter_impl(
send_ptr,
recv_ptr,
count,
communicator_ptr,
nccl_tensor_type,
op,
);
}
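The reduction op crosses the Rust/C++ boundary as a plain byte: the cpp! closure captures op as uint8_t and the kernel casts it with static_cast<Al::ReductionOperator>(op), so BaguaReductionOp is implicitly assumed to use the same discriminant values as Al::ReductionOperator. A minimal sketch of what that layout assumption looks like (the variant names and values here are illustrative, not taken from either library):

// Illustrative only: the u8 discriminants must line up with Al::ReductionOperator on the C++ side.
#[repr(u8)]
enum ReductionOpSketch {
    Sum = 0,
    Prod = 1,
    Min = 2,
    Max = 3,
}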

pub fn barrier(&self) {
let communicator_ptr = self.comm_ptr;

bagua-core-py/src/lib.rs (33 changes: 33 additions & 0 deletions)
@@ -103,6 +103,39 @@ impl BaguaSingleCommunicatorPy {
.allgather(&mut send_tensor.inner, &mut recv_tensor.inner)
}

pub fn gather(
&self,
send_tensor: &mut BaguaTensorPy,
recv_tensor: &mut BaguaTensorPy,
dst: i32,
) {
self.inner
.gather(&mut send_tensor.inner, &mut recv_tensor.inner, dst)
}

pub fn scatter(
&self,
send_tensor: &mut BaguaTensorPy,
recv_tensor: &mut BaguaTensorPy,
src: i32,
) {
self.inner
.scatter(&mut send_tensor.inner, &mut recv_tensor.inner, src)
}

pub fn reduce_scatter(
&self,
send_tensor: &mut BaguaTensorPy,
recv_tensor: &mut BaguaTensorPy,
op: u8,
) {
self.inner.reduce_scatter(
&mut send_tensor.inner,
&mut recv_tensor.inner,
BaguaReductionOp::from_u8(op).unwrap(),
)
}

pub fn barrier(&self) {
self.inner.barrier()
}
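On the Python-facing side, reduce_scatter receives the reduction op as a raw u8 and decodes it with BaguaReductionOp::from_u8(op).unwrap(), so an out-of-range value panics inside the extension rather than surfacing as a Python error. A small sketch of a more defensive decode (not part of the PR, and assuming from_u8 returns an Option as the unwrap() suggests):

// Sketch only, not from the PR.
fn decode_op(op: u8) -> Result<BaguaReductionOp, String> {
    match BaguaReductionOp::from_u8(op) {
        Some(v) => Ok(v),
        None => Err(format!("unknown reduction op: {}", op)),
    }
}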