
feat: support reduce and allgather op with Reduction op enum
liuhatry authored Jun 30, 2021
1 parent 1903645 commit dee0b7a
Showing 8 changed files with 74 additions and 28 deletions.
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions bagua-core-internal/Cargo.toml
@@ -26,6 +26,8 @@ serde = { version = "1", features = ["derive"] }
scheduled-thread-pool = "0.2"
serde_json = "1.0"
ureq = "2.1"
num-traits = "0.2"
num-derive = "0.3"

[dependencies.pyo3]
version = "0.13.2"
@@ -1,6 +1,6 @@
use crate::comm_ops::CommOpTrait;
use crate::communicators::BaguaCommunicator;
use crate::datatypes::{BaguaBucket, BaguaTensorRaw};
use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw};
use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
use crate::BaguaCommOpChannels;
use std::sync::Arc;
@@ -54,7 +54,7 @@ impl CommOpTrait for CentralizedFullPrecisionSynchronous {
tracing::debug!("internode communication done")
} else {
tracing::debug!("start allreduce");
c.allreduce(&mut t.raw);
c.allreduce(&mut t.raw, BaguaReductionOp::SUM);
tracing::debug!("internode communication done");
if self.average {
t.raw.divide_inplace(stream_ptr, c.nranks as f32);
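Averaging in this op is expressed as a SUM allreduce followed by an in-place division by the number of ranks, as the lines above show. A minimal arithmetic sketch of that equivalence in plain Rust (illustrative only, not the actual CUDA kernel):

// With per-rank values x_i, averaging is sum(x_i) / nranks.
fn average_via_sum(per_rank_values: &[f32]) -> f32 {
    let nranks = per_rank_values.len() as f32;
    let sum: f32 = per_rank_values.iter().sum(); // what the SUM allreduce produces
    sum / nranks                                 // what divide_inplace then applies
}

fn main() {
    assert_eq!(average_via_sum(&[1.0, 2.0, 3.0, 6.0]), 3.0);
}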
@@ -1,6 +1,6 @@
use crate::comm_ops::CommOpTrait;
use crate::communicators::{BaguaCommunicator, BaguaHierarchicalCommunicator, NCCLGroupGuard};
use crate::datatypes::{BaguaBucket, BaguaTensorRaw};
use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw};
use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
use crate::{BaguaCommOpChannels, BaguaScheduledCommOp};
use parking_lot::Mutex;
@@ -72,7 +72,7 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
if step % comm_interval == 0 {
peer_tensor.clone_from(&t.raw, c.stream_ptr);
let _guard = NCCLGroupGuard::new();
c.allreduce(&mut peer_tensor);
c.allreduce(&mut peer_tensor, BaguaReductionOp::SUM);
peer_tensor.divide_inplace(stream_ptr, c.nranks as f32);
}
}
39 changes: 20 additions & 19 deletions bagua-core-internal/src/communicators/mod.rs
@@ -1,4 +1,4 @@
use crate::datatypes::{BaguaCommunicationTensor, BaguaTensor, BaguaTensorRaw};
use crate::datatypes::{BaguaCommunicationTensor, BaguaReductionOp, BaguaTensor, BaguaTensorRaw};
use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
use crate::BaguaCoreError;
use itertools::Itertools;
@@ -67,17 +67,18 @@ impl BaguaSingleCommunicator {
self.inner.device_id
}

pub fn allreduce(&self, tensor: &mut BaguaTensor) {
self.inner.allreduce(&mut tensor.inner.write().raw);
pub fn allreduce(&self, tensor: &mut BaguaTensor, op: BaguaReductionOp) {
self.inner.allreduce(&mut tensor.inner.write().raw, op);
}

pub fn broadcast(&self, tensor: &mut BaguaTensor, root_rank: i32) {
self.inner
.broadcast(&mut tensor.inner.write().raw, root_rank);
}

pub fn reduce(&self, tensor: &mut BaguaTensor, root_rank: i32) {
self.inner.reduce(&mut tensor.inner.write().raw, root_rank);
pub fn reduce(&self, tensor: &mut BaguaTensor, root_rank: i32, op: BaguaReductionOp) {
self.inner
.reduce(&mut tensor.inner.write().raw, root_rank, op);
}

pub fn send(&self, tensor: &mut BaguaTensor, peer_rank: i32) {
@@ -172,7 +173,7 @@ impl BaguaHierarchicalCommunicatorLeader {
let stream_ptr = intranode_communicator.stream_ptr;
assert_eq!(communication_tensor.stream_ptr, stream_ptr);
tracing::debug!("reduce start");
intranode_communicator.reduce(&mut communication_tensor.raw, 0);
intranode_communicator.reduce(&mut communication_tensor.raw, 0, BaguaReductionOp::SUM);
tracing::debug!("reduce done");
if average {
communication_tensor
@@ -199,7 +200,7 @@ pub struct BaguaHierarchicalCommunicatorWorker {
impl BaguaHierarchicalCommunicatorWorker {
pub fn hierarchical_worker_pre(&self, communication_tensor: &mut BaguaCommunicationTensor) {
let intranode_communicator = self.intranode.inner.clone();
intranode_communicator.reduce(&mut communication_tensor.raw, 0);
intranode_communicator.reduce(&mut communication_tensor.raw, 0, BaguaReductionOp::SUM);
}

pub fn hierarchical_worker_post(&self, communication_tensor: &mut BaguaCommunicationTensor) {
@@ -357,23 +358,23 @@ impl BaguaCommunicatorInner {
}
}

pub fn reduce(&self, tensor: &mut BaguaTensorRaw, root_rank: i32) {
pub fn reduce(&self, tensor: &mut BaguaTensorRaw, root_rank: i32, op: BaguaReductionOp) {
let communicator_ptr = self.comm_ptr;
let tensor_ptr = tensor.ptr;
let total_num_elem = tensor.num_elem_allocated;
let nccl_tensor_type = tensor.dtype.to_nccl_datatype();

unsafe {
cpp::cpp!([tensor_ptr as "void *", root_rank as "int", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
cpp::cpp!([tensor_ptr as "void *", root_rank as "int", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", op as "uint8_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Reduce<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), total_num_elem, Al::ReductionOperator::sum, root_rank, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), root_rank, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Reduce<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), total_num_elem, Al::ReductionOperator::sum, root_rank, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), root_rank, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Reduce<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), total_num_elem, Al::ReductionOperator::sum, root_rank, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), root_rank, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Reduce<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), total_num_elem, Al::ReductionOperator::sum, root_rank, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), root_rank, *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
@@ -581,23 +582,23 @@ impl BaguaCommunicatorInner {
}
}

pub fn allreduce(&self, tensor: &mut BaguaTensorRaw) {
pub fn allreduce(&self, tensor: &mut BaguaTensorRaw, op: BaguaReductionOp) {
let communicator_ptr = self.comm_ptr;
let tensor_ptr = tensor.ptr;
let total_num_elem = tensor.num_elem_allocated;
let nccl_tensor_type = tensor.dtype.to_nccl_datatype();

unsafe {
cpp::cpp!([tensor_ptr as "void *", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
cpp::cpp!([tensor_ptr as "void *", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", op as "uint8_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Allreduce<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), total_num_elem, Al::ReductionOperator::sum, *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Allreduce<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), total_num_elem, Al::ReductionOperator::sum, *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Allreduce<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), total_num_elem, Al::ReductionOperator::sum, *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Allreduce<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), total_num_elem, Al::ReductionOperator::sum, *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
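For reference, a minimal Rust sketch of how a caller uses the updated communicator API after this change. The function and tensor setup are hypothetical; only the allreduce/reduce signatures come from the diff above.

use bagua_core_internal::communicators::BaguaSingleCommunicator;
use bagua_core_internal::datatypes::{BaguaReductionOp, BaguaTensor};

// Hypothetical call sites: the communicator and tensor are assumed to be
// constructed elsewhere; only the added `op` argument is new here.
fn sync_example(comm: &BaguaSingleCommunicator, tensor: &mut BaguaTensor) {
    // Every rank receives the element-wise sum.
    comm.allreduce(tensor, BaguaReductionOp::SUM);
    // Reduce onto root rank 0 with the element-wise maximum instead.
    comm.reduce(tensor, 0, BaguaReductionOp::MAX);
}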
17 changes: 17 additions & 0 deletions bagua-core-internal/src/datatypes/mod.rs
@@ -10,10 +10,27 @@ use crate::resource_pool::{CudaMemory, CUDA_DEVICE_MEMORY_POOL};
use crate::telemetry::TELEMETRY;
use crate::{kernels, BaguaCoreError};
use itertools::Itertools;
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use parking_lot::{Mutex, RwLock};
use sized_object_pool::DynamicPoolItem;
use std::sync::Arc;

// must be consistent with Aluminum ReductionOperator: https://github.com/BaguaSys/Aluminum/blob/master/include/aluminum/base.hpp
#[derive(Clone, Copy, Debug, PartialEq, FromPrimitive)]
pub enum BaguaReductionOp {
SUM,
PROD,
MIN,
MAX,
LOR,
LAND,
LXOR,
BOR,
BAND,
BXOR,
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BaguaTensorDtype {
F32,
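A minimal sketch (illustrative, not part of the commit) of what the FromPrimitive derive gives this enum: raw integer codes map back to variants in declaration order, so the ordering must stay in sync with Aluminum's ReductionOperator, as the comment above notes.

use bagua_core_internal::datatypes::BaguaReductionOp;
use num_traits::FromPrimitive;

fn decode_op(code: u8) -> Option<BaguaReductionOp> {
    // The derived from_u8 returns None for codes outside 0..=9.
    BaguaReductionOp::from_u8(code)
}

fn main() {
    assert_eq!(decode_op(0), Some(BaguaReductionOp::SUM));
    assert_eq!(decode_op(3), Some(BaguaReductionOp::MAX));
    assert_eq!(decode_op(42), None);
}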
2 changes: 2 additions & 0 deletions bagua-core-py/Cargo.toml
@@ -20,6 +20,8 @@ color-eyre = "0.5"
numpy = "0.13.2"
parking_lot = { version = "0.11", features = ["deadlock_detection"] }
openssl-sys = { version = "*", features = ["vendored"] }
num-traits = "0.2"
num-derive = "0.3"

[dependencies.pyo3]
version = "0.13"
19 changes: 14 additions & 5 deletions bagua-core-py/src/lib.rs
@@ -1,8 +1,12 @@
#![allow(clippy::needless_return)]

use bagua_core_internal::communicators::BaguaSingleCommunicator;
use bagua_core_internal::datatypes::{BaguaBucket, BaguaTensor, BaguaTensorDtype};
use bagua_core_internal::datatypes::{
BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorDtype,
};
use bagua_core_internal::BaguaCommBackend;
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use numpy::{IntoPyArray, PyArray1};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
@@ -47,16 +51,21 @@ impl BaguaSingleCommunicatorPy {
self.inner.device_id()
}

pub fn allreduce(&self, tensor: &mut BaguaTensorPy) {
self.inner.allreduce(&mut tensor.inner)
pub fn allreduce(&self, tensor: &mut BaguaTensorPy, op: u8) {
self.inner
.allreduce(&mut tensor.inner, BaguaReductionOp::from_u8(op).unwrap())
}

pub fn broadcast(&self, tensor: &mut BaguaTensorPy, root_rank: i32) {
self.inner.broadcast(&mut tensor.inner, root_rank)
}

pub fn reduce(&self, tensor: &mut BaguaTensorPy, root_rank: i32) {
self.inner.reduce(&mut tensor.inner, root_rank)
pub fn reduce(&self, tensor: &mut BaguaTensorPy, root_rank: i32, op: u8) {
self.inner.reduce(
&mut tensor.inner,
root_rank,
BaguaReductionOp::from_u8(op).unwrap(),
)
}

pub fn send(&self, tensor: &mut BaguaTensorPy, peer_rank: i32) {
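The Python-facing methods now take the reduction op as a plain u8 and decode it with BaguaReductionOp::from_u8(op).unwrap(), which panics on an unknown code. Below is a hedged sketch of an alternative decode step that surfaces a Python error instead; the helper name and error type are illustrative and not part of this commit.

use bagua_core_internal::datatypes::BaguaReductionOp;
use num_traits::FromPrimitive;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

// Illustrative helper: turn the u8 received from Python into the enum,
// reporting unknown codes as a Python ValueError rather than panicking.
fn decode_reduction_op(op: u8) -> PyResult<BaguaReductionOp> {
    BaguaReductionOp::from_u8(op)
        .ok_or_else(|| PyValueError::new_err(format!("invalid reduction op code: {}", op)))
}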
