This repository has been archived by the owner on Sep 15, 2021. It is now read-only.

feat: support creating BaguaTensor by passing torch tensor directly (#19)

BREAKING CHANGE: `BaguaBucketPy` and `BaguaTensorPy` now require a name. `BaguaTensorPy` is now created by passing a PyTorch tensor directly.
NOBLES5E authored Jun 30, 2021
1 parent dee0b7a commit c9b9473
Showing 15 changed files with 2,342 additions and 471 deletions.
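
The headline API change: a tensor object is now constructed from a PyTorch tensor plus a mandatory name, rather than from raw pointer, dtype, and size arguments. A minimal sketch of that shape, with the field set inferred from the BaguaTensorRaw fields visible in the diffs below; NamedTensor, from_torch, and its signature are illustrative assumptions, not the actual bagua-core API:

struct NamedTensor {
    name: String,     // now mandatory, per the breaking change above
    ptr: u64,         // data pointer taken from the torch tensor
    num_elem: usize,
    device_id: usize,
}

impl NamedTensor {
    // Hypothetical constructor: callers hand over a torch tensor's
    // metadata in one step instead of assembling raw fields by hand.
    fn from_torch(name: &str, data_ptr: u64, num_elem: usize, device_id: usize) -> Self {
        Self {
            name: name.to_owned(),
            ptr: data_ptr,
            num_elem,
            device_id,
        }
    }
}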
21 changes: 14 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default.

6 changes: 4 additions & 2 deletions bagua-core-c/src/lib.rs
@@ -1,4 +1,6 @@
use bagua_core_internal::communicators::BaguaSingleCommunicator;
+ use std::ffi::CStr;
+ use std::os::raw::c_char;

pub struct BaguaSingleCommunicatorC {
inner: BaguaSingleCommunicator,
@@ -9,15 +11,15 @@ pub extern "C" fn bagua_single_communicator_c_create(
nranks: usize,
device_id: usize,
stream_ptr: u64,
- nccl_unique_id_str: &str,
+ nccl_unique_id_str: *const c_char,
) -> *mut BaguaSingleCommunicatorC {
let obj = BaguaSingleCommunicatorC {
inner: bagua_core_internal::communicators::BaguaSingleCommunicator::new(
rank,
nranks,
device_id,
stream_ptr,
- nccl_unique_id_str,
+ unsafe { CStr::from_ptr(nccl_unique_id_str).to_str().unwrap() },
),
};

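
The hunk above fixes an FFI soundness problem: &str is a Rust fat pointer with no stable C ABI, so a C caller cannot construct one. The boundary now takes a NUL-terminated *const c_char and converts it inside the function. The same pattern in isolation (a self-contained sketch, not bagua-core code):

use std::ffi::CStr;
use std::os::raw::c_char;

#[no_mangle]
pub extern "C" fn string_len(s: *const c_char) -> usize {
    // SAFETY: the caller must pass a valid, NUL-terminated C string
    // that stays alive for the duration of this call.
    unsafe { CStr::from_ptr(s) }
        .to_str()
        .expect("string is not valid UTF-8")
        .len()
}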
2 changes: 1 addition & 1 deletion bagua-core-internal/Cargo.toml
@@ -15,8 +15,8 @@ itertools = "0.10"
shadow-rs = "0.6"
parking_lot = { version = "0.11", features = ["deadlock_detection"] }
hashbrown = "0.11"
- lazy_id = "0.1"
flume = "0.10"
+ derivative = "2.2.0"
oneshot = "0.1"
cpp = "0.5"
sized-object-pool = "0.2"
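
Per the file header (1 addition, 1 deletion), the dependency change is lazy_id out, derivative in, which fits the breaking change: once every tensor and bucket carries a caller-supplied name, auto-generated lazy IDs are no longer needed. derivative customizes standard derives; one typical use, shown here as a hypothetical since the commit does not show the call site, is a Debug derive that skips fields such as raw device pointers:

use derivative::Derivative;

#[derive(Derivative)]
#[derivative(Debug)]
struct TensorHandle {
    name: String,
    // Raw CUDA pointer: printing it is rarely useful, so skip it.
    #[derivative(Debug = "ignore")]
    ptr: u64,
}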
14 changes: 9 additions & 5 deletions bagua-core-internal/build.rs
@@ -14,16 +14,18 @@ fn main() {
let mut cuda_cc = cc::Build::new();
cuda_cc
.cuda(true)
.opt_level(3)
.include("cpp/include")
.include("third_party/cub-1.8.0")
.include("../python/bagua_core/.data/include")
.flag("-std=c++14")
.flag("-cudart=shared");
- for sm in supported_sms {
- cuda_cc
- .flag("-gencode")
- .flag(format!("arch=compute_{},code=sm_{}", sm, sm).as_str());

+ if std::env::var("PROFILE").unwrap() == "release" {
+ for sm in supported_sms {
+ cuda_cc
+ .flag("-gencode")
+ .flag(format!("arch=compute_{},code=sm_{}", sm, sm).as_str());
+ }
+ }
cuda_cc
.file("kernels/bagua_kernels.cu")
@@ -82,5 +84,7 @@ fn main() {
println!("cargo:rerun-if-changed=src/");
println!("cargo:rerun-if-changed=kernels/");
println!("cargo:rerun-if-changed=build.rs");

+ // bindgen --allowlist-type '.*TensorImpl.*' --enable-cxx-namespaces --ignore-functions --ignore-methods --size_t-is-usize --default-enum-style=rust --opaque-type 'std.*' --opaque-type 'c10::optional.*' wrapper.h -- -x c++ -std=c++14 > src/torch_ffi.rs
shadow_rs::new().unwrap();
}
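
The first hunk gates the per-architecture -gencode flags behind the release profile: generating device code for every supported SM is slow, so debug builds now compile only for the default architecture. The pattern in isolation, mirroring the diff (Cargo sets the PROFILE environment variable for build scripts; the SM list here is an example):

fn main() {
    let supported_sms = ["70", "75", "80"]; // example values
    let mut cuda_cc = cc::Build::new();
    // Pay the multi-architecture compile cost only for release builds.
    if std::env::var("PROFILE").unwrap() == "release" {
        for sm in supported_sms {
            cuda_cc
                .flag("-gencode")
                .flag(format!("arch=compute_{},code=sm_{}", sm, sm).as_str());
        }
    }
}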
@@ -1,6 +1,6 @@
use crate::comm_ops::CommOpTrait;
use crate::communicators::BaguaCommunicator;
- use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw};
+ use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor};
use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
use crate::BaguaCommOpChannels;
use std::sync::Arc;
@@ -38,7 +38,7 @@ impl CommOpTrait for CentralizedFullPrecisionSynchronous {
dtype: t.raw.dtype.clone(),
num_elem: t.raw.num_elem,
device_id: t.raw.device_id,
- pool_allocation: Some(temp_buf),
+ pool_allocations: vec![Arc::new(temp_buf)],
};
if self.scattergather {
tracing::debug!("start alltoall");
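
Across the comm-op files, the temporary tensor's backing memory moves from pool_allocation: Some(temp_buf) to pool_allocations: vec![Arc::new(temp_buf)]. A sketch of what the new shape buys, with PooledBuffer as a hypothetical stand-in for whatever CUDA_DEVICE_MEMORY_POOL returns:

use std::sync::Arc;

struct PooledBuffer { ptr: u64, bytes: usize } // hypothetical stand-in

struct TensorRaw {
    ptr: u64,
    // A tensor can now pin several pooled buffers at once, and an
    // Arc'd buffer can back more than one tensor; each buffer returns
    // to the pool only when its last Arc is dropped.
    pool_allocations: Vec<Arc<PooledBuffer>>,
}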
@@ -1,6 +1,6 @@
use crate::comm_ops::CommOpTrait;
use crate::communicators::BaguaCommunicator;
- use crate::datatypes::{BaguaBucket, BaguaTensorRaw, TensorCompressionMethod};
+ use crate::datatypes::{BaguaBucket, BaguaTensorRaw, RawBaguaTensor, TensorCompressionMethod};
use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
use crate::BaguaCommOpChannels;
use std::sync::Arc;
@@ -35,19 +35,20 @@ impl CommOpTrait for CentralizedLowPrecisionSynchronous {
.expect("cannot compress tensor");
let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id]
.try_pull(
- compressed_tensor.num_elem_allocated * compressed_tensor.dtype.bytes(),
+ compressed_tensor.num_elements_allocated()
+ * compressed_tensor.dtype().bytes(),
)
.expect("cannot allocate cuda memory");
let mut temp_tensor = BaguaTensorRaw {
ptr: temp_buf.ptr,
- num_elem_allocated: compressed_tensor.num_elem_allocated,
- dtype: compressed_tensor.dtype.clone(),
- num_elem: compressed_tensor.num_elem,
- device_id: compressed_tensor.device_id,
- pool_allocation: Some(temp_buf),
+ num_elem_allocated: compressed_tensor.num_elements_allocated(),
+ dtype: compressed_tensor.dtype().clone(),
+ num_elem: compressed_tensor.num_elements(),
+ device_id: compressed_tensor.device_id(),
+ pool_allocations: vec![Arc::new(temp_buf)],
};
tracing::debug!("start alltoall");
- c.alltoall(&compressed_tensor, &mut temp_tensor);
+ c.alltoall(compressed_tensor.as_ref(), &mut temp_tensor);
tracing::debug!("start decompress");
t.raw.decompress_from(
&self.compression_method,
@@ -72,7 +73,7 @@ impl CommOpTrait for CentralizedLowPrecisionSynchronous {
)
.expect("cannot compress tensor");
tracing::debug!("start allgather");
- c.allgather(&compressed_tensor, &mut temp_tensor);
+ c.allgather(compressed_tensor.as_ref(), &mut temp_tensor);
tracing::debug!("start decompress");
t.raw.decompress_from(
&self.compression_method,
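
This file shows the other half of the refactor: compressed_tensor is no longer a concrete BaguaTensorRaw with public fields but is accessed through RawBaguaTensor trait methods (num_elements(), num_elements_allocated(), dtype(), device_id()) and handed to communicators as compressed_tensor.as_ref(), i.e. as a trait object. A minimal sketch of such a trait, with method names taken from the call sites above and everything else assumed:

trait RawBaguaTensor {
    fn num_elements(&self) -> usize;
    fn num_elements_allocated(&self) -> usize;
    fn device_id(&self) -> usize;
}

// Communication helpers can then accept any tensor implementation,
// plain or compressed, behind one interface:
fn payload_bytes(t: &dyn RawBaguaTensor, elem_bytes: usize) -> usize {
    t.num_elements_allocated() * elem_bytes
}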
@@ -1,6 +1,7 @@
use crate::comm_ops::CommOpTrait;
use crate::communicators::{BaguaCommunicator, BaguaHierarchicalCommunicator, NCCLGroupGuard};
- use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw};
+ use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor};
+ use crate::events::BaguaEventChannel;
use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
use crate::{BaguaCommOpChannels, BaguaScheduledCommOp};
use parking_lot::Mutex;
@@ -53,7 +54,7 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
dtype: t.raw.dtype,
num_elem: t.raw.num_elem,
device_id: t.raw.device_id,
- pool_allocation: Some(peer_tensor_buffer),
+ pool_allocations: vec![Arc::new(peer_tensor_buffer)],
};

let peer_mode = &self.peer_selection_mode;
@@ -107,12 +108,13 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
if step % comm_interval == 0 {
// TODO: move this to .then() python API instead of hard code this in op
let post_backward_comm_op = BaguaScheduledCommOp {
+ name: format!("post backward comm op for bucket {}", bucket.name),
bucket: bucket.clone(),
ops: vec![Arc::new(DecentralizedFullPrecisionSynchronousPostStep {
communicator: self.communicator.clone(),
result_weight: peer_tensor,
})],
- event_channel: Default::default(),
+ event_channel: BaguaEventChannel::new("decentralized_post_backward"),
};

comm_op_channels
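
Besides the pool_allocations change, this hunk names things that used to be anonymous: the scheduled op gets a name derived from its bucket, and the event channel is created with BaguaEventChannel::new("decentralized_post_backward") instead of Default::default(), which makes in-flight ops identifiable when tracing. A guess at what a named event channel could look like, assuming it wraps a flume channel (flume appears in the Cargo.toml hunk above; the real internals are not shown in this commit):

struct EventChannel {
    name: String, // surfaces in logs to identify the op being waited on
    sender: flume::Sender<()>,
    receiver: flume::Receiver<()>,
}

impl EventChannel {
    fn new(name: &str) -> Self {
        let (sender, receiver) = flume::unbounded();
        Self { name: name.to_owned(), sender, receiver }
    }
}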
4 changes: 1 addition & 3 deletions bagua-core-internal/src/comm_ops/python_ffi_op.rs
@@ -1,7 +1,5 @@
use crate::comm_ops::CommOpTrait;
use crate::communicators::BaguaCommunicator;
- use crate::datatypes::{BaguaBucket, BaguaTensorRaw};
- use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
+ use crate::datatypes::BaguaBucket;
use crate::BaguaCommOpChannels;
use pyo3::Python;
use std::sync::Arc;
(Diffs for the remaining changed files are not rendered here.)
