From c9b9473bdd1b0ef10f306501f5492347aa3b7f5f Mon Sep 17 00:00:00 2001
From: Xiangru Lian
Date: Tue, 29 Jun 2021 22:00:42 -0700
Subject: [PATCH] feat: support creating BaguaTensor by passing torch tensor
 directly (#19)

BREAKING CHANGE: `BaguaBucketPy` and `BaguaTensorPy` now require a name.
`BaguaTensorPy` is now created by passing a PyTorch tensor directly.
Usage sketches of the new APIs are appended after the diff.
---
 Cargo.lock                                    |   21 +-
 bagua-core-c/src/lib.rs                       |    6 +-
 bagua-core-internal/Cargo.toml                |    2 +-
 bagua-core-internal/build.rs                  |   14 +-
 .../centralized_full_precision_synchronous.rs |    4 +-
 .../centralized_low_precision_synchronous.rs  |   19 +-
 ...ecentralized_full_precision_synchronous.rs |    8 +-
 .../src/comm_ops/python_ffi_op.rs             |    4 +-
 bagua-core-internal/src/communicators/mod.rs  |  149 +-
 bagua-core-internal/src/datatypes/mod.rs      |  630 +++---
 bagua-core-internal/src/events.rs             |   16 +-
 bagua-core-internal/src/lib.rs                |   41 +-
 bagua-core-internal/src/telemetry/mod.rs      |    8 +-
 bagua-core-internal/src/torch_ffi.rs          | 1792 +++++++++++++++++
 bagua-core-py/src/lib.rs                      |   99 +-
 15 files changed, 2342 insertions(+), 471 deletions(-)
 create mode 100644 bagua-core-internal/src/torch_ffi.rs

diff --git a/Cargo.lock b/Cargo.lock
index 86ced9e..cff2347 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,5 +1,7 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
+version = 3
+
 [[package]]
 name = "addr2line"
 version = "0.15.2"
@@ -101,10 +103,10 @@ dependencies = [
  "cmd_lib",
  "cpp",
  "cpp_build",
+ "derivative",
  "flume",
  "hashbrown 0.11.2",
  "itertools",
- "lazy_id",
  "ndarray",
  "num-derive",
  "num-traits",
@@ -355,6 +357,17 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "derivative"
+version = "2.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "dynamic-pool"
 version = "0.2.2"
@@ -623,12 +636,6 @@ dependencies = [
  "wasm-bindgen",
 ]
 
-[[package]]
-name = "lazy_id"
-version = "0.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "908c9a91c2be2e9f5eae45496061566fc0cd6e7929ff8a96032dc10d4827b961"
-
 [[package]]
 name = "lazy_static"
 version = "1.4.0"
diff --git a/bagua-core-c/src/lib.rs b/bagua-core-c/src/lib.rs
index 1060b0f..9315564 100644
--- a/bagua-core-c/src/lib.rs
+++ b/bagua-core-c/src/lib.rs
@@ -1,4 +1,6 @@
 use bagua_core_internal::communicators::BaguaSingleCommunicator;
+use std::ffi::CStr;
+use std::os::raw::c_char;
 
 pub struct BaguaSingleCommunicatorC {
     inner: BaguaSingleCommunicator,
@@ -9,7 +11,7 @@ pub extern "C" fn bagua_single_communicator_c_create(
     nranks: usize,
     device_id: usize,
     stream_ptr: u64,
-    nccl_unique_id_str: &str,
+    nccl_unique_id_str: *const c_char,
 ) -> *mut BaguaSingleCommunicatorC {
     let obj = BaguaSingleCommunicatorC {
         inner: bagua_core_internal::communicators::BaguaSingleCommunicator::new(
@@ -17,7 +19,7 @@ pub extern "C" fn bagua_single_communicator_c_create(
             nranks,
             device_id,
             stream_ptr,
-            nccl_unique_id_str,
+            unsafe { CStr::from_ptr(nccl_unique_id_str).to_str().unwrap() },
         ),
     };
 
diff --git a/bagua-core-internal/Cargo.toml b/bagua-core-internal/Cargo.toml
index 726789f..8a7a8c6 100644
--- a/bagua-core-internal/Cargo.toml
+++ b/bagua-core-internal/Cargo.toml
@@ -15,8 +15,8 @@ itertools = "0.10"
 shadow-rs = "0.6"
 parking_lot = { version = "0.11", features = ["deadlock_detection"] }
 hashbrown = "0.11"
-lazy_id = "0.1"
 flume = "0.10"
+derivative = "2.2.0"
 oneshot = "0.1"
 cpp = "0.5"
 sized-object-pool = 
"0.2" diff --git a/bagua-core-internal/build.rs b/bagua-core-internal/build.rs index ef26695..997fc52 100644 --- a/bagua-core-internal/build.rs +++ b/bagua-core-internal/build.rs @@ -14,16 +14,18 @@ fn main() { let mut cuda_cc = cc::Build::new(); cuda_cc .cuda(true) - .opt_level(3) .include("cpp/include") .include("third_party/cub-1.8.0") .include("../python/bagua_core/.data/include") .flag("-std=c++14") .flag("-cudart=shared"); - for sm in supported_sms { - cuda_cc - .flag("-gencode") - .flag(format!("arch=compute_{},code=sm_{}", sm, sm).as_str()); + + if std::env::var("PROFILE").unwrap() == "release" { + for sm in supported_sms { + cuda_cc + .flag("-gencode") + .flag(format!("arch=compute_{},code=sm_{}", sm, sm).as_str()); + } } cuda_cc .file("kernels/bagua_kernels.cu") @@ -82,5 +84,7 @@ fn main() { println!("cargo:rerun-if-changed=src/"); println!("cargo:rerun-if-changed=kernels/"); println!("cargo:rerun-if-changed=build.rs"); + + // bindgen --allowlist-type '.*TensorImpl.*' --enable-cxx-namespaces --ignore-functions --ignore-methods --size_t-is-usize --default-enum-style=rust --opaque-type 'std.*' --opaque-type 'c10::optional.*' wrapper.h -- -x c++ -std=c++14 > src/torch_ffi.rs shadow_rs::new().unwrap(); } diff --git a/bagua-core-internal/src/comm_ops/centralized_full_precision_synchronous.rs b/bagua-core-internal/src/comm_ops/centralized_full_precision_synchronous.rs index ed04fd0..9393d9b 100644 --- a/bagua-core-internal/src/comm_ops/centralized_full_precision_synchronous.rs +++ b/bagua-core-internal/src/comm_ops/centralized_full_precision_synchronous.rs @@ -1,6 +1,6 @@ use crate::comm_ops::CommOpTrait; use crate::communicators::BaguaCommunicator; -use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw}; +use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor}; use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL; use crate::BaguaCommOpChannels; use std::sync::Arc; @@ -38,7 +38,7 @@ impl CommOpTrait for CentralizedFullPrecisionSynchronous { dtype: t.raw.dtype.clone(), num_elem: t.raw.num_elem, device_id: t.raw.device_id, - pool_allocation: Some(temp_buf), + pool_allocations: vec![Arc::new(temp_buf)], }; if self.scattergather { tracing::debug!("start alltoall"); diff --git a/bagua-core-internal/src/comm_ops/centralized_low_precision_synchronous.rs b/bagua-core-internal/src/comm_ops/centralized_low_precision_synchronous.rs index 2d49fa0..24691ce 100644 --- a/bagua-core-internal/src/comm_ops/centralized_low_precision_synchronous.rs +++ b/bagua-core-internal/src/comm_ops/centralized_low_precision_synchronous.rs @@ -1,6 +1,6 @@ use crate::comm_ops::CommOpTrait; use crate::communicators::BaguaCommunicator; -use crate::datatypes::{BaguaBucket, BaguaTensorRaw, TensorCompressionMethod}; +use crate::datatypes::{BaguaBucket, BaguaTensorRaw, RawBaguaTensor, TensorCompressionMethod}; use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL; use crate::BaguaCommOpChannels; use std::sync::Arc; @@ -35,19 +35,20 @@ impl CommOpTrait for CentralizedLowPrecisionSynchronous { .expect("cannot compress tensor"); let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id] .try_pull( - compressed_tensor.num_elem_allocated * compressed_tensor.dtype.bytes(), + compressed_tensor.num_elements_allocated() + * compressed_tensor.dtype().bytes(), ) .expect("cannot allocate cuda memory"); let mut temp_tensor = BaguaTensorRaw { ptr: temp_buf.ptr, - num_elem_allocated: compressed_tensor.num_elem_allocated, - dtype: compressed_tensor.dtype.clone(), - num_elem: compressed_tensor.num_elem, - 
device_id: compressed_tensor.device_id, - pool_allocation: Some(temp_buf), + num_elem_allocated: compressed_tensor.num_elements_allocated(), + dtype: compressed_tensor.dtype().clone(), + num_elem: compressed_tensor.num_elements(), + device_id: compressed_tensor.device_id(), + pool_allocations: vec![Arc::new(temp_buf)], }; tracing::debug!("start alltoall"); - c.alltoall(&compressed_tensor, &mut temp_tensor); + c.alltoall(compressed_tensor.as_ref(), &mut temp_tensor); tracing::debug!("start decompress"); t.raw.decompress_from( &self.compression_method, @@ -72,7 +73,7 @@ impl CommOpTrait for CentralizedLowPrecisionSynchronous { ) .expect("cannot compress tensor"); tracing::debug!("start allgather"); - c.allgather(&compressed_tensor, &mut temp_tensor); + c.allgather(compressed_tensor.as_ref(), &mut temp_tensor); tracing::debug!("start decompress"); t.raw.decompress_from( &self.compression_method, diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs index 2a2d9ba..05f9995 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs @@ -1,6 +1,7 @@ use crate::comm_ops::CommOpTrait; use crate::communicators::{BaguaCommunicator, BaguaHierarchicalCommunicator, NCCLGroupGuard}; -use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw}; +use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor}; +use crate::events::BaguaEventChannel; use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL; use crate::{BaguaCommOpChannels, BaguaScheduledCommOp}; use parking_lot::Mutex; @@ -53,7 +54,7 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous { dtype: t.raw.dtype, num_elem: t.raw.num_elem, device_id: t.raw.device_id, - pool_allocation: Some(peer_tensor_buffer), + pool_allocations: vec![Arc::new(peer_tensor_buffer)], }; let peer_mode = &self.peer_selection_mode; @@ -107,12 +108,13 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous { if step % comm_interval == 0 { // TODO: move this to .then() python API instead of hard code this in op let post_backward_comm_op = BaguaScheduledCommOp { + name: format!("post backward comm op for bucket {}", bucket.name), bucket: bucket.clone(), ops: vec![Arc::new(DecentralizedFullPrecisionSynchronousPostStep { communicator: self.communicator.clone(), result_weight: peer_tensor, })], - event_channel: Default::default(), + event_channel: BaguaEventChannel::new("decentralized_post_backward"), }; comm_op_channels diff --git a/bagua-core-internal/src/comm_ops/python_ffi_op.rs b/bagua-core-internal/src/comm_ops/python_ffi_op.rs index ad42395..3bc3093 100644 --- a/bagua-core-internal/src/comm_ops/python_ffi_op.rs +++ b/bagua-core-internal/src/comm_ops/python_ffi_op.rs @@ -1,7 +1,5 @@ use crate::comm_ops::CommOpTrait; -use crate::communicators::BaguaCommunicator; -use crate::datatypes::{BaguaBucket, BaguaTensorRaw}; -use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL; +use crate::datatypes::BaguaBucket; use crate::BaguaCommOpChannels; use pyo3::Python; use std::sync::Arc; diff --git a/bagua-core-internal/src/communicators/mod.rs b/bagua-core-internal/src/communicators/mod.rs index d6d2480..58ad0e5 100644 --- a/bagua-core-internal/src/communicators/mod.rs +++ b/bagua-core-internal/src/communicators/mod.rs @@ -1,5 +1,6 @@ -use crate::datatypes::{BaguaCommunicationTensor, BaguaReductionOp, BaguaTensor, 
BaguaTensorRaw}; -use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL; +use crate::datatypes::{ + BaguaCommunicationTensor, BaguaReductionOp, BaguaTensor, BaguaTensorRaw, RawBaguaTensor, +}; use crate::BaguaCoreError; use itertools::Itertools; use std::sync::Arc; @@ -68,31 +69,33 @@ impl BaguaSingleCommunicator { } pub fn allreduce(&self, tensor: &mut BaguaTensor, op: BaguaReductionOp) { - self.inner.allreduce(&mut tensor.inner.write().raw, op); + self.inner.allreduce(tensor.inner.write().raw.as_mut(), op); } pub fn broadcast(&self, tensor: &mut BaguaTensor, root_rank: i32) { self.inner - .broadcast(&mut tensor.inner.write().raw, root_rank); + .broadcast(tensor.inner.write().raw.as_mut(), root_rank); } pub fn reduce(&self, tensor: &mut BaguaTensor, root_rank: i32, op: BaguaReductionOp) { self.inner - .reduce(&mut tensor.inner.write().raw, root_rank, op); + .reduce(tensor.inner.write().raw.as_mut(), root_rank, op); } pub fn send(&self, tensor: &mut BaguaTensor, peer_rank: i32) { - self.inner.send(&mut tensor.inner.write().raw, peer_rank); + self.inner + .send(tensor.inner.write().raw.as_mut(), peer_rank); } pub fn recv(&self, tensor: &mut BaguaTensor, peer_rank: i32) { - self.inner.recv(&mut tensor.inner.write().raw, peer_rank); + self.inner + .recv(tensor.inner.write().raw.as_mut(), peer_rank); } pub fn alltoall(&self, send_tensor: &mut BaguaTensor, recv_tensor: &mut BaguaTensor) { self.inner.alltoall( - &mut send_tensor.inner.write().raw, - &mut recv_tensor.inner.write().raw, + send_tensor.inner.write().raw.as_mut(), + recv_tensor.inner.write().raw.as_mut(), ); } @@ -106,19 +109,19 @@ impl BaguaSingleCommunicator { recv_displs: &BaguaTensor, ) { self.inner.alltoall_v( - &mut send_tensor.inner.write().raw, - &send_counts.inner.read().raw, - &send_displs.inner.read().raw, - &mut recv_tensor.inner.write().raw, - &recv_counts.inner.read().raw, - &recv_displs.inner.read().raw, + send_tensor.inner.write().raw.as_mut(), + send_counts.inner.read().raw.as_ref(), + send_displs.inner.read().raw.as_ref(), + recv_tensor.inner.write().raw.as_mut(), + recv_counts.inner.read().raw.as_ref(), + recv_displs.inner.read().raw.as_ref(), ); } pub fn allgather(&self, send_tensor: &mut BaguaTensor, recv_tensor: &mut BaguaTensor) { self.inner.allgather( - &mut send_tensor.inner.write().raw, - &mut recv_tensor.inner.write().raw, + send_tensor.inner.write().raw.as_mut(), + recv_tensor.inner.write().raw.as_mut(), ); } @@ -333,11 +336,11 @@ impl Drop for NCCLGroupGuard { } impl BaguaCommunicatorInner { - pub fn broadcast(&self, tensor: &mut BaguaTensorRaw, root_rank: i32) { + pub fn broadcast(&self, tensor: &mut dyn RawBaguaTensor, root_rank: i32) { let communicator_ptr = self.comm_ptr; - let tensor_ptr = tensor.ptr; - let total_num_elem = tensor.num_elem_allocated; - let nccl_tensor_type = tensor.dtype.to_nccl_datatype(); + let tensor_ptr = tensor.data_ptr(); + let total_num_elem = tensor.num_elements_allocated(); + let nccl_tensor_type = tensor.dtype().to_nccl_datatype(); unsafe { cpp::cpp!([tensor_ptr as "void *", root_rank as "int", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"] @@ -358,11 +361,11 @@ impl BaguaCommunicatorInner { } } - pub fn reduce(&self, tensor: &mut BaguaTensorRaw, root_rank: i32, op: BaguaReductionOp) { + pub fn reduce(&self, tensor: &mut dyn RawBaguaTensor, root_rank: i32, op: BaguaReductionOp) { let communicator_ptr = self.comm_ptr; - let tensor_ptr = tensor.ptr; - let total_num_elem = tensor.num_elem_allocated; - let 
nccl_tensor_type = tensor.dtype.to_nccl_datatype(); + let tensor_ptr = tensor.data_ptr(); + let total_num_elem = tensor.num_elements_allocated(); + let nccl_tensor_type = tensor.dtype().to_nccl_datatype(); unsafe { cpp::cpp!([tensor_ptr as "void *", root_rank as "int", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", op as "uint8_t"] @@ -383,19 +386,19 @@ impl BaguaCommunicatorInner { } } - pub fn alltoall(&self, send_tensor: &BaguaTensorRaw, recv_tensor: &mut BaguaTensorRaw) { + pub fn alltoall(&self, send_tensor: &dyn RawBaguaTensor, recv_tensor: &mut dyn RawBaguaTensor) { let communicator_ptr = self.comm_ptr; - let tensor_ptr = send_tensor.ptr; + // TODO: also check recv buf? assert_eq!( - send_tensor.num_elem_allocated % self.nranks, + send_tensor.num_elements_allocated() % self.nranks, 0, "tensors must be aligned before using allscatter" ); - let send_chunk_size = send_tensor.num_elem_allocated / self.nranks; - let nccl_tensor_type = send_tensor.dtype.to_nccl_datatype(); + let send_chunk_size = send_tensor.num_elements_allocated() / self.nranks; + let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype(); - let send_buf_ptr = send_tensor.ptr; - let recv_buf_ptr = recv_tensor.ptr; + let send_buf_ptr = send_tensor.data_ptr(); + let recv_buf_ptr = recv_tensor.data_ptr(); unsafe { cpp::cpp!([recv_buf_ptr as "void *", send_buf_ptr as "void *", send_chunk_size as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"] @@ -418,43 +421,43 @@ impl BaguaCommunicatorInner { pub fn alltoall_v( &self, - send_tensor: &BaguaTensorRaw, - send_counts_tensor: &BaguaTensorRaw, - send_displs_tensor: &BaguaTensorRaw, - recv_tensor: &mut BaguaTensorRaw, - recv_counts_tensor: &BaguaTensorRaw, - recv_displs_tensor: &BaguaTensorRaw, + send_tensor: &dyn RawBaguaTensor, + send_counts_tensor: &dyn RawBaguaTensor, + send_displs_tensor: &dyn RawBaguaTensor, + recv_tensor: &mut dyn RawBaguaTensor, + recv_counts_tensor: &dyn RawBaguaTensor, + recv_displs_tensor: &dyn RawBaguaTensor, ) { let communicator_ptr = self.comm_ptr; let nranks = self.nranks; - let nccl_tensor_type = send_tensor.dtype.to_nccl_datatype(); + let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype(); assert_eq!( - send_counts_tensor.dtype.to_nccl_datatype(), + send_counts_tensor.dtype().to_nccl_datatype(), 5, "send_counts_tensor.dtype must be BaguaTensorDtype::U64" ); assert_eq!( - send_displs_tensor.dtype.to_nccl_datatype(), + send_displs_tensor.dtype().to_nccl_datatype(), 5, "send_displs_tensor.dtype must be BaguaTensorDtype::U64" ); assert_eq!( - recv_counts_tensor.dtype.to_nccl_datatype(), + recv_counts_tensor.dtype().to_nccl_datatype(), 5, "recv_counts_tensor.dtype must be BaguaTensorDtype::U64" ); assert_eq!( - recv_displs_tensor.dtype.to_nccl_datatype(), + recv_displs_tensor.dtype().to_nccl_datatype(), 5, "recv_displs_tensor.dtype must be BaguaTensorDtype::U64" ); - let send_buf_ptr = send_tensor.ptr; - let send_counts_ptr = send_counts_tensor.ptr; - let send_displs_ptr = send_displs_tensor.ptr; - let recv_buf_ptr = recv_tensor.ptr; - let recv_counts_ptr = recv_counts_tensor.ptr; - let recv_displs_ptr = recv_displs_tensor.ptr; + let send_buf_ptr = send_tensor.data_ptr(); + let send_counts_ptr = send_counts_tensor.data_ptr(); + let send_displs_ptr = send_displs_tensor.data_ptr(); + let recv_buf_ptr = recv_tensor.data_ptr(); + let recv_counts_ptr = recv_counts_tensor.data_ptr(); + let recv_displs_ptr = recv_displs_tensor.data_ptr(); 
unsafe { cpp::cpp!([ @@ -482,11 +485,11 @@ impl BaguaCommunicatorInner { } } - pub fn send(&self, send_tensor: &BaguaTensorRaw, peer_rank: i32) { + pub fn send(&self, send_tensor: &dyn RawBaguaTensor, peer_rank: i32) { let communicator_ptr = self.comm_ptr; - let tensor_ptr = send_tensor.ptr; - let total_num_elem = send_tensor.num_elem_allocated; - let nccl_tensor_type = send_tensor.dtype.to_nccl_datatype(); + let tensor_ptr = send_tensor.data_ptr(); + let total_num_elem = send_tensor.num_elements_allocated(); + let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype(); unsafe { cpp::cpp!([tensor_ptr as "void *", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", peer_rank as "int"] @@ -507,11 +510,11 @@ impl BaguaCommunicatorInner { } } - pub fn recv(&self, recv_tensor: &mut BaguaTensorRaw, peer_rank: i32) { + pub fn recv(&self, recv_tensor: &mut dyn RawBaguaTensor, peer_rank: i32) { let communicator_ptr = self.comm_ptr; - let tensor_ptr = recv_tensor.ptr; - let total_num_elem = recv_tensor.num_elem_allocated; - let nccl_tensor_type = recv_tensor.dtype.to_nccl_datatype(); + let tensor_ptr = recv_tensor.data_ptr(); + let total_num_elem = recv_tensor.num_elements_allocated(); + let nccl_tensor_type = recv_tensor.dtype().to_nccl_datatype(); unsafe { cpp::cpp!([tensor_ptr as "void *", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", peer_rank as "int"] @@ -532,25 +535,29 @@ impl BaguaCommunicatorInner { } } - pub fn allgather(&self, send_tensor: &BaguaTensorRaw, recv_tensor: &mut BaguaTensorRaw) { + pub fn allgather( + &self, + send_tensor: &dyn RawBaguaTensor, + recv_tensor: &mut dyn RawBaguaTensor, + ) { let communicator_ptr = self.comm_ptr; - let send_tensor_ptr = send_tensor.ptr; + let send_tensor_ptr = send_tensor.data_ptr(); assert_eq!( - send_tensor.num_elem_allocated, - recv_tensor.num_elem_allocated + send_tensor.num_elements_allocated(), + recv_tensor.num_elements_allocated() ); - assert_eq!(send_tensor.dtype, recv_tensor.dtype); + assert_eq!(send_tensor.dtype(), recv_tensor.dtype()); assert_eq!( - send_tensor.num_elem_allocated % self.nranks, + send_tensor.num_elements_allocated() % self.nranks, 0, "tensors must be aligned before using allgather" ); - let send_chunk_size = send_tensor.num_elem_allocated / self.nranks; - let nccl_tensor_type = send_tensor.dtype.to_nccl_datatype(); + let send_chunk_size = send_tensor.num_elements_allocated() / self.nranks; + let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype(); let send_buf_ptr = send_tensor_ptr - + self.rank as u64 * send_chunk_size as u64 * send_tensor.dtype.bytes() as u64; - let recv_buf_ptr = recv_tensor.ptr; + + self.rank as u64 * send_chunk_size as u64 * send_tensor.dtype().bytes() as u64; + let recv_buf_ptr = recv_tensor.data_ptr(); unsafe { cpp::cpp!([recv_buf_ptr as "void *", send_buf_ptr as "void *", send_chunk_size as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"] @@ -582,11 +589,11 @@ impl BaguaCommunicatorInner { } } - pub fn allreduce(&self, tensor: &mut BaguaTensorRaw, op: BaguaReductionOp) { + pub fn allreduce(&self, tensor: &mut dyn RawBaguaTensor, op: BaguaReductionOp) { let communicator_ptr = self.comm_ptr; - let tensor_ptr = tensor.ptr; - let total_num_elem = tensor.num_elem_allocated; - let nccl_tensor_type = tensor.dtype.to_nccl_datatype(); + let tensor_ptr = tensor.data_ptr(); + let total_num_elem = tensor.num_elements_allocated(); 
+        let nccl_tensor_type = tensor.dtype().to_nccl_datatype();
 
         unsafe {
             cpp::cpp!([tensor_ptr as "void *", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", op as "uint8_t"]
diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs
index b1fb912..e4d2b26 100644
--- a/bagua-core-internal/src/datatypes/mod.rs
+++ b/bagua-core-internal/src/datatypes/mod.rs
@@ -8,12 +8,15 @@ use crate::comm_ops::CommOpTrait;
 use crate::communicators::{BaguaCommunicator, BaguaSingleCommunicator};
 use crate::resource_pool::{CudaMemory, CUDA_DEVICE_MEMORY_POOL};
 use crate::telemetry::TELEMETRY;
+use crate::torch_ffi::root::c10::{DeviceType, StorageImpl, TensorImpl};
 use crate::{kernels, BaguaCoreError};
 use itertools::Itertools;
 use num_derive::FromPrimitive;
 use num_traits::FromPrimitive;
 use parking_lot::{Mutex, RwLock};
 use sized_object_pool::DynamicPoolItem;
+use std::ffi::c_void;
+use std::fmt::Debug;
 use std::sync::Arc;
 
 // must be consistent with Aluminum ReductionOperator: https://github.com/BaguaSys/Aluminum/blob/master/include/aluminum/base.hpp
@@ -62,6 +65,12 @@ impl BaguaTensorDtype {
     }
 }
 
+#[derive(Debug)]
+pub struct TorchTensorRaw {
+    pub torch_tensor_cdata: u64,
+    pub dtype: BaguaTensorDtype,
+}
+
 #[derive(Debug)]
 pub struct BaguaTensorRaw {
     pub ptr: u64,
@@ -69,99 +78,21 @@ pub struct BaguaTensorRaw {
     pub dtype: BaguaTensorDtype,
     pub num_elem: usize,
     pub device_id: usize,
-    pub pool_allocation: Option<DynamicPoolItem<CudaMemory>>,
+    pub pool_allocations: Vec<Arc<DynamicPoolItem<CudaMemory>>>,
 }
 
-#[derive(Clone, Debug)]
-pub struct MinMaxUInt8CompressionParameters {}
-
-impl MinMaxUInt8CompressionParameters {
-    pub fn get_compressed_buffer_size(
-        n_chunks: usize,
-        chunk_size: usize,
-        from_dtype: &BaguaTensorDtype,
-    ) -> usize {
-        #[inline]
-        fn align_size(size: usize, align: usize) -> usize {
-            ((size) + (align) - 1) / (align) * (align)
-        }
-
-        match from_dtype {
-            BaguaTensorDtype::F32 => {
-                let align_bytes = 32;
-                // Note chunk_size is already aligned outside
-                let compressed_align_bytes = align_size(chunk_size * n_chunks, align_bytes);
-                let min_max_align_bytes = align_size(4 * 2, align_bytes) * n_chunks; // assume min max are both f32
-                return compressed_align_bytes + min_max_align_bytes;
-            }
-            BaguaTensorDtype::F16 => {
-                let align_bytes = 32;
-                // Note chunk_size is already aligned outside
-                let compressed_align_bytes = align_size(chunk_size * n_chunks, align_bytes);
-                let min_max_align_bytes = align_size(2 * 2, align_bytes) * n_chunks; // assume min max are both f16
-                return compressed_align_bytes + min_max_align_bytes;
-            }
-            BaguaTensorDtype::U8 => {
-                unimplemented!()
-            }
-            BaguaTensorDtype::I64 => {
-                unimplemented!()
-            }
-            BaguaTensorDtype::U64 => {
-                unimplemented!()
-            }
-        }
-    }
-
-    pub fn get_temp_buffer_size(
-        input_ptr: u64,
-        input_size: usize,
-        output_ptr: u64,
-        stream_ptr: u64,
-        from_dtype: &BaguaTensorDtype,
-    ) -> usize {
-        match from_dtype {
-            BaguaTensorDtype::F32 => unsafe {
-                kernels::array_min_max_size_f32_host(
-                    input_ptr as _,
-                    input_size as _,
-                    output_ptr as _,
-                    stream_ptr as _,
-                )
-            },
+pub trait RawBaguaTensor: Debug {
+    fn data_ptr(&self) -> u64;
+    fn num_elements(&self) -> usize;
+    fn num_elements_allocated(&self) -> usize;
+    fn device_id(&self) -> usize;
+    fn dtype(&self) -> BaguaTensorDtype;
 
-            BaguaTensorDtype::F16 => unsafe {
-                kernels::array_min_max_size_f16_host(
-                    input_ptr as _,
-                    input_size as _,
-                    output_ptr as _,
-                    stream_ptr as _,
-                )
-            },
-            BaguaTensorDtype::U8 => {
-                unimplemented!()
-            }
-            BaguaTensorDtype::I64 => {
-                unimplemented!()
-            }
-            BaguaTensorDtype::U64 => {
-                unimplemented!()
-            }
-        }
-    }
-}
-
-#[derive(Clone, Debug)]
-pub enum TensorCompressionMethod {
-    MinMaxUInt8(MinMaxUInt8CompressionParameters),
-}
-
-impl BaguaTensorRaw {
-    pub fn divide_inplace(&mut self, stream_ptr: u64, divide_factor: f32) {
-        let tensor_ptr = self.ptr;
-        let total_num_elem = self.num_elem;
+    fn divide_inplace(&mut self, stream_ptr: u64, divide_factor: f32) {
+        let tensor_ptr = self.data_ptr();
+        let total_num_elem = self.num_elements();
         unsafe {
-            match self.dtype {
+            match self.dtype() {
                 BaguaTensorDtype::F32 => {
                     kernels::divide_inplace_f32_host(
                         tensor_ptr as _,
@@ -191,13 +122,13 @@ impl BaguaTensorRaw {
         }
     }
 
-    pub fn clone_from(&mut self, other: &Self, stream_ptr: u64) {
-        assert_eq!(self.dtype, other.dtype);
-        assert_eq!(self.num_elem, other.num_elem);
+    fn clone_from(&mut self, other: &dyn RawBaguaTensor, stream_ptr: u64) {
+        assert_eq!(self.dtype(), other.dtype());
+        assert_eq!(self.num_elements(), other.num_elements());
         unsafe {
-            let src = other.ptr;
-            let dst = self.ptr;
-            let count = self.num_elem * self.dtype.bytes();
+            let src = other.data_ptr();
+            let dst = self.data_ptr();
+            let count = self.num_elements() * self.dtype().bytes();
             cpp::cpp!([stream_ptr as "cudaStream_t", dst as "void *", src as "const void *", count as "size_t"]
             {
             CUDACHECK(cudaMemcpyAsync( dst, src , count, cudaMemcpyDeviceToDevice, stream_ptr));
@@ -205,17 +136,17 @@ impl BaguaTensorRaw {
         }
     }
 
-    pub fn average_inplace(&mut self, other: &Self, stream_ptr: u64) {
-        assert_eq!(self.dtype, other.dtype);
-        assert_eq!(self.num_elem, other.num_elem);
-        let tensor_ptr = self.ptr;
-        let total_num_elem = self.num_elem;
+    fn average_inplace(&mut self, other: &dyn RawBaguaTensor, stream_ptr: u64) {
+        assert_eq!(self.dtype(), other.dtype());
+        assert_eq!(self.num_elements(), other.num_elements());
+        let tensor_ptr = self.data_ptr();
+        let total_num_elem = self.num_elements();
         unsafe {
-            match self.dtype {
+            match self.dtype() {
                 BaguaTensorDtype::F32 => {
                     kernels::average_inplace_f32_host(
                         tensor_ptr as _,
-                        other.ptr as _,
+                        other.data_ptr() as _,
                         total_num_elem as i32,
                         stream_ptr as _,
                     );
@@ -223,7 +154,7 @@ impl BaguaTensorRaw {
                 BaguaTensorDtype::F16 => {
                     kernels::average_inplace_f16_host(
                         tensor_ptr as _,
-                        other.ptr as _,
+                        other.data_ptr() as _,
                         total_num_elem as i32,
                         stream_ptr as _,
                     );
@@ -241,45 +172,45 @@ impl BaguaTensorRaw {
         }
     }
 
-    pub fn compress(
+    fn compress(
         &self,
         compression: &TensorCompressionMethod,
         n_chunks: usize,
         stream_ptr: u64,
         target_chunk: i32,
-    ) -> Result<BaguaTensorRaw, BaguaCoreError> {
+    ) -> Result<Box<dyn RawBaguaTensor + Send + Sync + 'static>, BaguaCoreError> {
         match compression {
             TensorCompressionMethod::MinMaxUInt8(_parameters) => {
                 assert_eq!(
-                    self.num_elem_allocated % n_chunks,
+                    self.num_elements_allocated() % n_chunks,
                     0,
                     "compression tensor size % n_chunks must be 0"
                 );
-                let chunk_size = self.num_elem_allocated / n_chunks;
+                let chunk_size = self.num_elements_allocated() / n_chunks;
                 let output_buffer_size =
                     MinMaxUInt8CompressionParameters::get_compressed_buffer_size(
                         n_chunks,
                         chunk_size,
-                        &self.dtype,
+                        &self.dtype(),
                     );
                 let output_buffer =
-                    CUDA_DEVICE_MEMORY_POOL[self.device_id].try_pull(output_buffer_size)?;
+                    CUDA_DEVICE_MEMORY_POOL[self.device_id()].try_pull(output_buffer_size)?;
 
                 let temp_buffer_size = MinMaxUInt8CompressionParameters::get_temp_buffer_size(
-                    self.ptr,
-                    self.num_elem,
+                    self.data_ptr(),
+                    self.num_elements(),
                     output_buffer.ptr,
                     stream_ptr,
-                    &self.dtype,
+                    &self.dtype(),
                 );
                 let temp_buffer =
-                    CUDA_DEVICE_MEMORY_POOL[self.device_id].try_pull(temp_buffer_size)?;
+                    CUDA_DEVICE_MEMORY_POOL[self.device_id()].try_pull(temp_buffer_size)?;
 
-                match self.dtype {
+                match self.dtype() {
                     BaguaTensorDtype::F32 => unsafe {
                         kernels::compress_f32_to_uint8_host(
-                            self.ptr as _,
-                            self.num_elem as _,
+                            self.data_ptr() as _,
+                            self.num_elements() as _,
                             chunk_size as _,
                             n_chunks as _,
                             output_buffer.ptr as _,
@@ -292,8 +223,8 @@ impl BaguaTensorRaw {
                     },
                     BaguaTensorDtype::F16 => unsafe {
                         kernels::compress_f16_to_uint8_host(
-                            self.ptr as _,
-                            self.num_elem as _,
+                            self.data_ptr() as _,
+                            self.num_elements() as _,
                             chunk_size as _,
                             n_chunks as _,
                             output_buffer.ptr as _,
@@ -314,50 +245,52 @@ impl BaguaTensorRaw {
                         unimplemented!()
                     }
                 }
-                return Ok(BaguaTensorRaw {
+                return Ok(Box::new(BaguaTensorRaw {
                     ptr: output_buffer.ptr,
                     num_elem_allocated: output_buffer_size,
                     dtype: BaguaTensorDtype::U8,
                     num_elem: output_buffer_size,
-                    device_id: self.device_id,
-                    pool_allocation: Some(output_buffer),
-                });
+                    device_id: self.device_id(),
+                    pool_allocations: vec![Arc::new(output_buffer)],
+                }));
             }
         }
     }
 
-    pub fn decompress_from(
+    fn decompress_from(
        &mut self,
        compression: &TensorCompressionMethod,
        n_chunks: usize,
-        compressed_buffer: &BaguaTensorRaw,
+        compressed_buffer: &dyn RawBaguaTensor,
        stream_ptr: u64,
    ) {
        assert_eq!(
-            self.num_elem_allocated % n_chunks,
+            self.num_elements_allocated() % n_chunks,
            0,
            "compression tensor size % n_chunks must be 0"
        );
-        let chunk_size = self.num_elem_allocated / n_chunks;
+        let chunk_size = self.num_elements_allocated() / n_chunks;
        match compression {
-            TensorCompressionMethod::MinMaxUInt8(_parameters) => match self.dtype {
+            TensorCompressionMethod::MinMaxUInt8(_parameters) => match self.dtype() {
                BaguaTensorDtype::F32 => unsafe {
                    kernels::decompress_uint8_to_f32_host(
-                        compressed_buffer.ptr as _,
-                        compressed_buffer.num_elem_allocated * compressed_buffer.dtype.bytes(),
+                        compressed_buffer.data_ptr() as _,
+                        compressed_buffer.num_elements_allocated()
+                            * compressed_buffer.dtype().bytes(),
                        chunk_size as _,
                        n_chunks as _,
-                        self.ptr as _,
+                        self.data_ptr() as _,
                        stream_ptr as _,
                    );
                },
                BaguaTensorDtype::F16 => unsafe {
                    kernels::decompress_uint8_to_f16_host(
-                        compressed_buffer.ptr as _,
-                        compressed_buffer.num_elem_allocated * compressed_buffer.dtype.bytes(),
+                        compressed_buffer.data_ptr() as _,
+                        compressed_buffer.num_elements_allocated()
+                            * compressed_buffer.dtype().bytes(),
                        chunk_size as _,
                        n_chunks as _,
-                        self.ptr as _,
+                        self.data_ptr() as _,
                        stream_ptr as _,
                    );
                },
@@ -374,17 +307,17 @@ impl BaguaTensorRaw {
         }
     }
 
-    pub fn reduce_mean_inplace(&mut self, n_chunks: usize, target_chunk: usize, stream_ptr: u64) {
+    fn reduce_mean_inplace(&mut self, n_chunks: usize, target_chunk: usize, stream_ptr: u64) {
         assert_eq!(
-            self.num_elem_allocated % n_chunks,
+            self.num_elements_allocated() % n_chunks,
             0,
             "reduce_mean_inplace requires tensor aligned"
         );
-        let chunk_size = self.num_elem_allocated / n_chunks;
-        match self.dtype {
+        let chunk_size = self.num_elements_allocated() / n_chunks;
+        match self.dtype() {
             BaguaTensorDtype::F32 => unsafe {
                 kernels::reduce_mean_f32_inplace_host(
-                    self.ptr as _,
+                    self.data_ptr() as _,
                     chunk_size as _,
                     n_chunks as _,
                     target_chunk as _,
@@ -393,7 +326,7 @@ impl BaguaTensorRaw {
             },
             BaguaTensorDtype::F16 => unsafe {
                 kernels::reduce_mean_f16_inplace_host(
-                    self.ptr as _,
+                    self.data_ptr() as _,
                     chunk_size as _,
                     n_chunks as _,
                     target_chunk as _,
@@ -412,17 +345,17 @@ impl BaguaTensorRaw {
         }
     }
 
-    pub fn reduce_sum_inplace(&mut self, n_chunks: usize, target_chunk: usize, stream_ptr: u64) {
+    fn reduce_sum_inplace(&mut self, n_chunks: usize, target_chunk: usize, stream_ptr: u64) {
         assert_eq!(
-            self.num_elem_allocated % n_chunks,
+            self.num_elements_allocated() % n_chunks,
             0,
             "reduce_sum_inplace requires tensor aligned"
         );
-        let chunk_size = self.num_elem_allocated / n_chunks;
-        match self.dtype {
+        let chunk_size = self.num_elements_allocated() / n_chunks;
+        match self.dtype() {
             BaguaTensorDtype::F32 => unsafe {
                 kernels::reduce_sum_f32_inplace_host(
-                    self.ptr as _,
+                    self.data_ptr() as _,
                     chunk_size as _,
                     n_chunks as _,
                     target_chunk as _,
@@ -431,7 +364,7 @@ impl BaguaTensorRaw {
             },
             BaguaTensorDtype::F16 => unsafe {
                 kernels::reduce_sum_f16_inplace_host(
-                    self.ptr as _,
+                    self.data_ptr() as _,
                     chunk_size as _,
                     n_chunks as _,
                     target_chunk as _,
@@ -451,20 +384,188 @@ impl BaguaTensorRaw {
     }
 }
 
+impl TorchTensorRaw {
+    fn extract_torch_c_data(&self) -> &TensorImpl {
+        unsafe { (self.torch_tensor_cdata as *const TensorImpl).as_ref() }
+            .expect("torch c data pointer is null")
+    }
+
+    fn extract_storage(&self) -> &StorageImpl {
+        unsafe {
+            (self.extract_torch_c_data().storage_.storage_impl_.target_ as *const StorageImpl)
+                .as_ref()
+                .expect("torch c data has no storage")
+        }
+    }
+}
+
+impl RawBaguaTensor for TorchTensorRaw {
+    fn data_ptr(&self) -> u64 {
+        let cdata = self.extract_torch_c_data();
+        let storage = self.extract_storage();
+        storage.data_ptr_.ptr_.data_ as u64
+            + cdata.storage_offset_ as u64 * self.dtype.bytes() as u64
+    }
+
+    fn num_elements(&self) -> usize {
+        self.extract_torch_c_data().numel_ as _
+    }
+
+    fn num_elements_allocated(&self) -> usize {
+        self.num_elements()
+    }
+
+    fn device_id(&self) -> usize {
+        let storage_data_ptr = &self.extract_storage().data_ptr_;
+        assert_eq!(
+            storage_data_ptr.device_.type_,
+            DeviceType::CUDA,
+            "currently only cuda tensors are supported in Bagua"
+        );
+        return storage_data_ptr.device_.index_ as _;
+    }
+
+    fn dtype(&self) -> BaguaTensorDtype {
+        self.dtype
+    }
+}
+
+impl RawBaguaTensor for BaguaTensorRaw {
+    fn data_ptr(&self) -> u64 {
+        self.ptr
+    }
+
+    fn num_elements(&self) -> usize {
+        self.num_elem
+    }
+
+    fn num_elements_allocated(&self) -> usize {
+        self.num_elem_allocated
+    }
+
+    fn device_id(&self) -> usize {
+        self.device_id
+    }
+
+    fn dtype(&self) -> BaguaTensorDtype {
+        self.dtype
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct MinMaxUInt8CompressionParameters {}
+
+impl MinMaxUInt8CompressionParameters {
+    pub fn get_compressed_buffer_size(
+        n_chunks: usize,
+        chunk_size: usize,
+        from_dtype: &BaguaTensorDtype,
+    ) -> usize {
+        #[inline]
+        fn align_size(size: usize, align: usize) -> usize {
+            ((size) + (align) - 1) / (align) * (align)
+        }
+
+        match from_dtype {
+            BaguaTensorDtype::F32 => {
+                let align_bytes = 32;
+                // Note chunk_size is already aligned outside
+                let compressed_align_bytes = align_size(chunk_size * n_chunks, align_bytes);
+                let min_max_align_bytes = align_size(4 * 2, align_bytes) * n_chunks; // assume min max are both f32
+                return compressed_align_bytes + min_max_align_bytes;
+            }
+            BaguaTensorDtype::F16 => {
+                let align_bytes = 32;
+                // Note chunk_size is already aligned outside
+                let compressed_align_bytes = align_size(chunk_size * n_chunks, align_bytes);
+                let min_max_align_bytes = align_size(2 * 2, align_bytes) * n_chunks; // assume min max are both f16
+                return compressed_align_bytes + min_max_align_bytes;
+            }
+            BaguaTensorDtype::U8 => {
+                unimplemented!()
+            }
+            BaguaTensorDtype::I64 => {
+                unimplemented!()
+            }
+            BaguaTensorDtype::U64 => {
+                unimplemented!()
+            }
+        }
+    }
+
+    pub fn get_temp_buffer_size(
+        input_ptr: u64,
+        input_size: usize,
+        output_ptr: u64,
+        stream_ptr: u64,
+        from_dtype: &BaguaTensorDtype,
+    ) -> usize {
+        match from_dtype {
+            BaguaTensorDtype::F32 => unsafe {
+                kernels::array_min_max_size_f32_host(
+                    input_ptr as _,
+                    input_size as _,
+                    output_ptr as _,
+                    stream_ptr as _,
+                )
+            },
+
+            BaguaTensorDtype::F16 => unsafe {
+                kernels::array_min_max_size_f16_host(
+                    input_ptr as _,
+                    input_size as _,
+                    output_ptr as _,
+                    stream_ptr as _,
+                )
+            },
+            BaguaTensorDtype::U8 => {
+                unimplemented!()
+            }
+            BaguaTensorDtype::I64 => {
+                unimplemented!()
+            }
+            BaguaTensorDtype::U64 => {
+                unimplemented!()
+            }
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub enum TensorCompressionMethod {
+    MinMaxUInt8(MinMaxUInt8CompressionParameters),
+}
+
+impl BaguaTensorRaw {}
+
 #[derive(Debug)]
 pub struct BaguaTensorInner {
+    pub name: String,
-    pub raw: BaguaTensorRaw,
+    pub raw: Box<dyn RawBaguaTensor + Send + Sync + 'static>,
     pub ready_for_comm: bool,
     pub ready_cuda_event_ptr: u64,
 }
 
 #[derive(Clone, Debug)]
 pub struct BaguaTensor {
-    pub id: u64,
     pub inner: Arc<RwLock<BaguaTensorInner>>,
 }
 
 impl BaguaTensor {
+    pub fn new_from_torch(name: String, torch_cdata_ptr: u64, dtype: BaguaTensorDtype) -> Self {
+        Self {
+            inner: Arc::new(RwLock::new(BaguaTensorInner {
+                name,
+                raw: Box::new(TorchTensorRaw {
+                    torch_tensor_cdata: torch_cdata_ptr,
+                    dtype,
+                }),
+                ready_for_comm: false,
+                ready_cuda_event_ptr: 0,
+            })),
+        }
+    }
+
     pub fn mark_comm_ready(&self, cuda_event_ptr: u64) {
         if cuda_event_ptr == 0 {
             tracing::info!("mark comm ready with an event 0, ignoring event");
@@ -472,7 +573,7 @@ impl BaguaTensor {
         match TELEMETRY.as_ref() {
             None => {}
             Some(ref x) => {
-                x.lock().new_tensor_ready(self.id);
+                x.lock().new_tensor_ready(self.inner.read().name.as_str());
             }
         }
         let mut guard = self.inner.write();
@@ -484,53 +585,19 @@ impl BaguaTensor {
         self.inner.write().ready_for_comm = false;
     }
 
-    pub fn ready_for_comm(&self) -> bool {
-        self.inner.read().ready_for_comm
+    pub fn name(&self) -> String {
+        self.inner.read().name.clone()
     }
-}
 
-impl BaguaTensor {
-    pub fn new(
-        ptr: u64,
-        num_elem: usize,
-        num_elem_allocated: usize,
-        dtype: &str,
-        device_id: usize,
-    ) -> Self {
-        let dtype = match dtype {
-            "f32" => BaguaTensorDtype::F32,
-            "f16" => BaguaTensorDtype::F16,
-            "u8" => BaguaTensorDtype::U8,
-            "i64" => BaguaTensorDtype::I64,
-            "u64" => BaguaTensorDtype::U64,
-            _ => {
-                unimplemented!()
-            }
-        };
-        let id = lazy_id::Id::lazy().get();
-        tracing::debug!("generate tensor id {}", id);
-        Self {
-            id,
-            inner: Arc::new(RwLock::new(BaguaTensorInner {
-                raw: BaguaTensorRaw {
-                    ptr,
-                    num_elem,
-                    num_elem_allocated,
-                    dtype,
-                    device_id,
-                    pool_allocation: None,
-                },
-                ready_for_comm: false,
-                ready_cuda_event_ptr: 0,
-            })),
-        }
+    pub fn ready_for_comm(&self) -> bool {
+        self.inner.read().ready_for_comm
     }
 
     pub fn compress(&self, method: &str, n_chunks: usize, target_chunk: i32) -> Self {
         match method {
             "min_max_uint8" => Self {
-                id: 0,
                 inner: Arc::new(RwLock::new(BaguaTensorInner {
+                    name: "compressed_tensor".to_string(),
                     raw: self
                         .inner
                         .read()
                         .raw
                         .compress(
                             &TensorCompressionMethod::MinMaxUInt8(MinMaxUInt8CompressionParameters {}),
                             n_chunks,
                             0,
                             target_chunk,
                         )
-                        .expect("unable to compress"),
+                        .expect("cannot compress tensor"),
                     ready_for_comm: false,
                     ready_cuda_event_ptr: 0,
                 })),
             },
@@ -592,7 +659,7 @@ impl BaguaTensor {
         self.inner.write().raw.decompress_from(
             &TensorCompressionMethod::MinMaxUInt8(MinMaxUInt8CompressionParameters {}),
             n_chunks,
-            &compressed_buffer.inner.read().raw,
+            compressed_buffer.inner.read().raw.as_ref(),
             0,
         );
     }
@@ -602,28 +669,24 @@ impl BaguaTensor {
         }
     }
 
-    pub fn ptr(&self) -> u64 {
-        self.inner.read().raw.ptr
+    pub fn data_ptr(&self) -> u64 {
+        self.inner.read().raw.data_ptr()
     }
 
-    pub fn id(&self) -> u64 {
-        self.id
+    pub fn device_id(&self) -> usize {
+        self.inner.read().raw.device_id()
     }
 
-    pub fn num_elem(&self) -> usize {
-        self.inner.read().raw.num_elem
+    pub fn num_elements(&self) -> usize {
+        self.inner.read().raw.num_elements()
     }
 
-    pub fn num_elem_allocated(&self) -> usize {
-        self.inner.read().raw.num_elem_allocated
+    pub fn num_elements_allocated(&self) -> usize {
+        self.inner.read().raw.num_elements_allocated()
     }
 
     pub fn dtype(&self) -> String {
-        format!("{:?}", self.inner.read().raw.dtype)
-    }
-
-    pub fn reset_ptr(&mut self, ptr: u64) {
-        self.inner.write().raw.ptr = ptr;
+        format!("{:?}", self.inner.read().raw.dtype())
     }
 }
 
@@ -631,9 +694,7 @@ impl BaguaTensor {
 pub struct BaguaBucketInner {
     pub tensors: Vec<BaguaTensor>,
     pub dtype: BaguaTensorDtype,
-    pub inplace: bool,
     pub comm_ops: Vec<Arc<dyn CommOpTrait + Send + Sync>>,
-    pub align_bytes: usize,
 }
 
 pub struct BaguaCommunicationTensor<'b> {
@@ -644,7 +705,47 @@ pub struct BaguaCommunicationTensor<'b> {
 }
 
 impl BaguaBucketInner {
+    pub fn contiguous(&self) -> bool {
+        let bytes_per_element = self.dtype.bytes() as u64;
+        let t = &(*self.tensors.first().unwrap()).inner.read();
+        let mut current_ptr =
+            t.raw.data_ptr() + t.raw.num_elements_allocated() as u64 * bytes_per_element;
+        for tensor in self.tensors.iter().dropping(1) {
+            let inner_tensor = &tensor.inner.read();
+            tracing::debug!(
+                "current_ptr {} next tensor data_ptr {}",
+                current_ptr,
+                inner_tensor.raw.data_ptr()
+            );
+            if current_ptr != inner_tensor.raw.data_ptr() {
+                return false;
+            } else {
+                current_ptr += inner_tensor.raw.num_elements_allocated() as u64 * bytes_per_element;
+            }
+        }
+        return true;
+    }
+
+    pub fn total_num_elements(&self) -> usize {
+        self.tensors
+            .iter()
+            .map(|tensor| tensor.num_elements())
+            .sum::<usize>()
+    }
+
+    pub fn total_num_elements_allocated(&self) -> usize {
+        self.tensors
+            .iter()
+            .map(|tensor| tensor.num_elements_allocated())
+            .sum::<usize>()
+    }
+
+    pub fn total_allocated_bytes(&self) -> usize {
+        self.total_num_elements_allocated() * self.dtype.bytes()
+    }
+
     /// NOTE: this does not wait for memcpy finished
+    // TODO: simplify args
     pub fn get_communication_tensor(
         &self,
         stream_ptr: u64,
@@ -663,28 +764,18 @@ impl BaguaBucketInner {
             }
             tensor.inner.write().ready_cuda_event_ptr = 0;
         }
-        match self.inplace && !force_copy {
+        match self.contiguous() && !force_copy {
             true => {
                 tracing::debug!("bucket is inplace, creating communication tensor without copy");
-                let tensor = self.tensors.first().unwrap().inner.read();
-                let total_num_elem: usize = self
-                    .tensors
-                    .iter()
-                    .map(|x| x.inner.read().raw.num_elem)
-                    .sum();
-                let total_num_elem_allocated: usize = self
-                    .tensors
-                    .iter()
-                    .map(|x| x.inner.read().raw.num_elem_allocated)
-                    .sum();
+                let total_num_elem_allocated: usize = self.total_num_elements_allocated();
                 BaguaCommunicationTensor {
                     raw: BaguaTensorRaw {
-                        ptr: tensor.raw.ptr,
-                        num_elem: total_num_elem,
-                        dtype: tensor.raw.dtype.clone(),
+                        ptr: self.tensors[0].data_ptr(),
+                        num_elem: total_num_elem_allocated,
+                        dtype: self.dtype,
                         num_elem_allocated: total_num_elem_allocated,
-                        device_id: tensor.raw.device_id,
-                        pool_allocation: None,
+                        device_id: self.tensors[0].device_id(),
+                        pool_allocations: vec![],
                     },
                     need_copy_back: false,
                     bucket: &self,
@@ -697,47 +788,34 @@ impl BaguaBucketInner {
                 let total_num_elem: usize = self
                     .tensors
                     .iter()
-                    .map(|x| x.inner.read().raw.num_elem)
+                    .map(|x| x.inner.read().raw.num_elements())
                     .sum();
-                let total_bytes = {
-                    let mut result = total_num_elem * first_tensor.raw.dtype.bytes();
-                    if self.align_bytes > 0 {
-                        result =
-                            (result + (self.align_bytes - 1)) / self.align_bytes * self.align_bytes;
-                    }
-                    result
-                };
-                assert_eq!(
-                    total_bytes % first_tensor.raw.dtype.bytes(),
-                    0,
-                    "cannot align tensor"
-                );
-                let buffer_tensor = CUDA_DEVICE_MEMORY_POOL[first_tensor.raw.device_id]
-                    .try_pull(total_bytes)
+                let buffer_tensor = CUDA_DEVICE_MEMORY_POOL[first_tensor.raw.device_id()]
+                    .try_pull(self.total_allocated_bytes())
                     .expect("unable to allocate GPU buffer memory to do bucketing");
                 let mut dst = buffer_tensor.ptr;
                 for tensor in &self.tensors {
                     let tensor = tensor.inner.read();
-                    let src = tensor.raw.ptr;
-                    let count = tensor.raw.num_elem * tensor.raw.dtype.bytes();
+                    let src = tensor.raw.data_ptr();
+                    let count = tensor.raw.num_elements() * tensor.raw.dtype().bytes();
                     unsafe {
                         cpp::cpp!([stream_ptr as "cudaStream_t", dst as "void *", src as "const void *", count as "size_t"]
                         {
                         CUDACHECK(cudaMemcpyAsync( dst, src , count, cudaMemcpyDeviceToDevice, stream_ptr));
                         });
                     }
-                    dst += tensor.raw.num_elem as u64 * tensor.raw.dtype.bytes() as u64;
+                    dst += tensor.raw.num_elements() as u64 * tensor.raw.dtype().bytes() as u64;
                 }
 
                 BaguaCommunicationTensor {
                     raw: BaguaTensorRaw {
                         ptr: buffer_tensor.ptr,
                         num_elem: total_num_elem,
-                        dtype: first_tensor.raw.dtype.clone(),
-                        num_elem_allocated: total_bytes / first_tensor.raw.dtype.bytes(),
-                        device_id: first_tensor.raw.device_id,
-                        pool_allocation: Some(buffer_tensor),
+                        dtype: first_tensor.raw.dtype().clone(),
+                        num_elem_allocated: self.total_num_elements_allocated(),
+                        device_id: first_tensor.raw.device_id(),
+                        pool_allocations: vec![Arc::new(buffer_tensor)],
                     },
                     need_copy_back: if force_not_copy_back { false } else { true },
                     bucket: &self,
@@ -755,8 +833,8 @@ impl<'b> Drop for BaguaCommunicationTensor<'b> {
             let mut src = self.raw.ptr;
             for tensor in &self.bucket.tensors {
                 let tensor = tensor.inner.read();
-                let dst = tensor.raw.ptr;
-                let count = tensor.raw.num_elem * tensor.raw.dtype.bytes();
+                let dst = tensor.raw.data_ptr();
+                let count = tensor.raw.num_elements() * tensor.raw.dtype().bytes();
                 unsafe {
                     cpp::cpp!([stream_ptr as "cudaStream_t", dst as "void *", src as "const void *", count as "size_t"]
                     {
@@ -779,33 +857,27 @@ impl<'b> Drop for BaguaCommunicationTensor<'b> {
 
 #[derive(Debug, Clone)]
 pub struct BaguaBucket {
-    pub id: u64,
     pub name: String,
     pub inner: Arc<Mutex<BaguaBucketInner>>,
 }
 
 impl BaguaBucket {
-    pub fn new(
-        tensors: &[&BaguaTensor],
-        name: &str,
-        inplace: bool,
-        align_bytes: usize,
-    ) -> Result<Self, BaguaCoreError> {
+    pub fn new(tensors: &[&BaguaTensor], name: &str) -> Result<Self, BaguaCoreError> {
         if tensors.is_empty() {
             return Err(BaguaCoreError::BucketError("bucket is empty".into()));
         }
 
         let first_tensor = (*tensors.first().unwrap()).inner.read();
-        let dtype: &BaguaTensorDtype = &first_tensor.raw.dtype;
-        let device_id = first_tensor.raw.device_id;
+        let dtype: &BaguaTensorDtype = &first_tensor.raw.dtype();
+        let device_id = first_tensor.raw.device_id();
         for tensor in tensors.iter() {
             let tensor = tensor.inner.read();
-            if dtype != &tensor.raw.dtype {
+            if dtype != &tensor.raw.dtype() {
                 return Err(BaguaCoreError::BucketError(
                     "tensors in the same bucket should be of the same dtype".into(),
                 ));
             }
-            if device_id != tensor.raw.device_id {
+            if device_id != tensor.raw.device_id() {
                 return Err(BaguaCoreError::BucketError(
                     "tensors in the same bucket should be of the same device".into(),
                 ));
             }
         }
 
@@ -814,59 +886,19 @@ impl BaguaBucket {
         for tensor in tensors.iter() {
             let tensor = tensor.inner.read();
-            if tensor.raw.num_elem_allocated < tensor.raw.num_elem {
+            if tensor.raw.num_elements_allocated() < tensor.raw.num_elements() {
                 return Err(BaguaCoreError::TensorError(
                     "num_elem_allocated should always be greater than num_elem in a tensor".into(),
                 ));
             }
         }
 
-        // inplace memory contiguous check
-        if inplace {
-            let t = &(*tensors.first().unwrap()).inner.read();
-            let bytes_per_element = dtype.bytes() as u64;
-            let mut current_ptr = t.raw.ptr + t.raw.num_elem_allocated as u64 * bytes_per_element;
-            for tensor in tensors.iter().dropping(1) {
-                let inner_tensor = &tensor.inner.read();
-                if current_ptr != inner_tensor.raw.ptr {
-                    return Err(BaguaCoreError::BucketError(
-                        "tensors in a bucket not contiguous while marked as inplace".into(),
-                    ));
-                } else {
-                    current_ptr += inner_tensor.raw.num_elem_allocated as u64 * bytes_per_element;
-                }
-            }
-
-            let mut total_bytes = 0;
-            for (i, tensor) in tensors.iter().enumerate() {
-                let inner_tensor = &tensor.inner.read();
-                if (inner_tensor.raw.num_elem != inner_tensor.raw.num_elem_allocated)
-                    && (i != (tensors.len() - 1))
-                {
-                    return Err(BaguaCoreError::BucketError(
-                        "non-last tensors in a bucket should have num_elem == num_elem_allocated in inplace mode".into()
-                    ));
-                }
-                total_bytes += inner_tensor.raw.num_elem_allocated * inner_tensor.raw.dtype.bytes();
-            }
-
-            if total_bytes % align_bytes != 0 {
-                return Err(BaguaCoreError::BucketError(
-                    "inplace bucket tensors are not properly aligned".into(),
-                ));
-            }
-        }
-
-        let id = lazy_id::Id::lazy().get();
         Ok(Self {
-            id,
             name: name.to_owned(),
             inner: Arc::new(Mutex::new(BaguaBucketInner {
-                inplace,
                 tensors: tensors.iter().map(|x| (**x).clone()).collect(),
                 comm_ops: vec![],
-                dtype: tensors.first().unwrap().inner.read().raw.dtype.clone(),
-                align_bytes,
+                dtype: tensors.first().unwrap().inner.read().raw.dtype().clone(),
             })),
         })
     }
diff --git a/bagua-core-internal/src/events.rs b/bagua-core-internal/src/events.rs
index fc1cd5f..266cfeb 100644
--- a/bagua-core-internal/src/events.rs
+++ b/bagua-core-internal/src/events.rs
@@ -3,10 +3,18 @@ use std::sync::Arc;
 
 #[derive(Clone, Debug)]
 pub struct BaguaEventChannel {
+    pub name: String,
     inner: Arc<(Mutex<bool>, parking_lot::Condvar)>,
 }
 
 impl BaguaEventChannel {
+    pub fn new(name: &str) -> Self {
+        Self {
+            name: name.to_string(),
+            inner: Arc::new((Mutex::new(false), parking_lot::Condvar::new())),
+        }
+    }
+
     pub fn finish(&self) {
         let &(ref lock, ref cvar) = &*self.inner;
         let mut finished = lock.lock();
@@ -22,11 +30,3 @@ impl BaguaEventChannel {
         }
     }
 }
-
-impl Default for BaguaEventChannel {
-    fn default() -> Self {
-        Self {
-            inner: Arc::new((Mutex::new(false), parking_lot::Condvar::new())),
-        }
-    }
-}
diff --git a/bagua-core-internal/src/lib.rs b/bagua-core-internal/src/lib.rs
index ac63035..289782a 100644
--- a/bagua-core-internal/src/lib.rs
+++ b/bagua-core-internal/src/lib.rs
@@ -11,6 +11,7 @@ pub mod events;
 pub mod kernels;
 pub mod resource_pool;
 pub mod telemetry;
+mod torch_ffi;
 
 use crate::comm_ops::CommOpTrait;
 use crate::telemetry::{SCHEDULED_THREAD_POOL, TELEMETRY};
@@ -57,6 +58,7 @@ pub enum BaguaCoreError {
 
 #[derive(Debug)]
 pub struct BaguaScheduledCommOp {
+    pub name: String,
     pub bucket: Arc<BaguaBucket>,
     pub ops: Vec<Arc<dyn CommOpTrait + Send + Sync>>,
     pub event_channel: BaguaEventChannel,
@@ -120,7 +122,7 @@ pub fn show_version() {
 pub struct BaguaCommBackend {
     ordered_buckets: VecDeque<Arc<BaguaBucket>>,
     ///
-    bucket_mapping: HashMap<u64, Arc<BaguaBucket>>,
+    bucket_mapping: HashMap<String, Arc<BaguaBucket>>,
     channels: Arc<BaguaCommOpChannels>,
     managed_ptrs: HashSet<u64>,
     comm_worker: std::thread::JoinHandle<()>,
@@ -129,10 +131,11 @@ pub struct BaguaCommBackend {
 
 impl BaguaCommBackend {
     pub fn schedule_comm(&self, bucket: Arc<BaguaBucket>) -> Result<(), BaguaCoreError> {
-        let event_channel = BaguaEventChannel::default();
+        let event_channel = BaguaEventChannel::new("comm_op");
         self.channels
             .schedule_channel_sender
             .send(BaguaScheduledCommOp {
+                name: format!("comm op for bucket {}", bucket.name),
                 ops: {
                     let guard = bucket.inner.lock();
                     guard.comm_ops.clone()
@@ -196,19 +199,21 @@ impl BaguaCommBackend {
                         .recv()
                         .expect("cannot receive new comm op");
                     tracing::debug!(
-                        "worker received scheduled communication operation {:?}",
-                        comm_op
+                        "worker received scheduled communication operation {}",
+                        comm_op.name
                     );
-                    monitor_op_start_channel_sender.send(comm_op.bucket.clone());
+                    if let Err(e) = monitor_op_start_channel_sender.send(comm_op.bucket.clone()) {
+                        tracing::error!("{:?}", e);
+                    }
                     for op in &comm_op.ops {
                         op.execute_background_communication(
                             comm_op.bucket.clone(),
                             &channels_clone,
                         );
                     }
-                    tracing::debug!("comm op executed: {:?}", comm_op);
+                    tracing::debug!("comm op executed: {}", comm_op.name);
                     comm_op.event_channel.finish();
-                    tracing::debug!("comm op marked finished: {:?}", comm_op);
+                    tracing::debug!("comm op marked finished: {}", comm_op.name);
                     monitor_op_finish_channel_sender.send(());
                 }
             }),
@@ -239,17 +244,19 @@ impl BaguaCommBackend {
             let bucket = Arc::new((*bucket).clone());
             self.ordered_buckets.push_back(bucket.clone());
             for tensor in &bucket.inner.lock().tensors {
-                if self.bucket_mapping.contains_key(&tensor.id)
-                    || self.managed_ptrs.contains(&tensor.inner.read().raw.ptr)
+                if self.bucket_mapping.contains_key(&tensor.name())
+                    || self
+                        .managed_ptrs
+                        .contains(&tensor.inner.read().raw.data_ptr())
                 {
                     return Err(BaguaCoreError::TensorError(format!(
-                        "duplicated tensor detected, id {}, ptr {}",
-                        &tensor.id,
-                        &tensor.inner.read().raw.ptr
+                        "duplicated tensor detected, name {}, ptr {}",
+                        &tensor.name(),
+                        &tensor.inner.read().raw.data_ptr()
                     )));
                 }
-                self.bucket_mapping.insert(tensor.id, bucket.clone());
-                self.managed_ptrs.insert(tensor.inner.read().raw.ptr);
+                self.bucket_mapping.insert(tensor.name(), bucket.clone());
+                self.managed_ptrs.insert(tensor.inner.read().raw.data_ptr());
             }
         }
         Ok(())
@@ -263,7 +270,7 @@ impl BaguaCommBackend {
         tensor.mark_comm_ready(ready_cuda_event_ptr);
        while self.should_schedule()? {
            let bucket = self.ordered_buckets.pop_front().unwrap();
-            tracing::debug!("bucket {:?} ready for communication", bucket);
+            tracing::debug!("bucket {} ready for communication", bucket.name);
            bucket.reset_comm_ready();
            let bucket_clone = bucket.clone();
            self.ordered_buckets.push_back(bucket);
@@ -280,9 +287,9 @@ impl BaguaCommBackend {
            let ev = self.channels.not_waited_events_receiver.try_recv();
            match ev {
                Ok(x) => {
-                    tracing::debug!("waiting for comm ops event {:?}", x);
+                    tracing::debug!("waiting for comm ops event `{}`", x.name);
                    x.wait();
-                    tracing::debug!("comm ops event {:?} finished", x);
+                    tracing::debug!("comm ops event `{}` finished", x.name);
                    num_ev += 1;
                }
                Err(_) => return Ok(num_ev),
diff --git a/bagua-core-internal/src/telemetry/mod.rs b/bagua-core-internal/src/telemetry/mod.rs
index 6fd8c15..32e8a6c 100644
--- a/bagua-core-internal/src/telemetry/mod.rs
+++ b/bagua-core-internal/src/telemetry/mod.rs
@@ -27,7 +27,7 @@ pub struct BaguaCommCoreTelemetry {
 
 #[derive(Debug, Serialize, Deserialize)]
 pub struct TelemetryPayload {
-    tensor_ready_order: Vec<u64>,
+    tensor_ready_order: Vec<String>,
     communication_time_ms: u64,
 }
 
@@ -55,8 +55,10 @@ impl BaguaCommCoreTelemetry {
         }
     }
 
-    pub fn new_tensor_ready(&mut self, tensor_id: u64) {
-        self.current_payload.tensor_ready_order.push(tensor_id);
+    pub fn new_tensor_ready(&mut self, tensor_name: &str) {
+        self.current_payload
+            .tensor_ready_order
+            .push(tensor_name.to_string());
     }
 
     pub fn push_payload_and_clear(&mut self) -> Result<(), BaguaCoreError> {
diff --git a/bagua-core-internal/src/torch_ffi.rs b/bagua-core-internal/src/torch_ffi.rs
new file mode 100644
index 0000000..2042437
--- /dev/null
+++ b/bagua-core-internal/src/torch_ffi.rs
@@ -0,0 +1,1792 @@
+/* automatically generated by rust-bindgen 0.58.1 */
+
+#[allow(non_snake_case, non_camel_case_types, non_upper_case_globals)]
+pub mod root {
+    #[repr(C)]
+    #[derive(Copy, Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
+    pub struct __BindgenBitfieldUnit<Storage> {
+        storage: Storage,
+    }
+    impl<Storage> __BindgenBitfieldUnit<Storage> {
+        #[inline]
+        pub const fn new(storage: Storage) -> Self {
+            Self { storage }
+        }
+    }
+    impl<Storage> __BindgenBitfieldUnit<Storage>
+    where
+        Storage: AsRef<[u8]> + AsMut<[u8]>,
+    {
+        #[inline]
+        pub fn get_bit(&self, index: usize) -> bool {
+            debug_assert!(index / 8 < self.storage.as_ref().len());
+            let byte_index = index / 8;
+            let byte = self.storage.as_ref()[byte_index];
+            let bit_index = if cfg!(target_endian = "big") {
+                7 - (index % 8)
+            } else {
+                index % 8
+            };
+            let mask = 1 << bit_index;
+            byte & mask == mask
+        }
+        #[inline]
+        pub fn set_bit(&mut self, index: usize, val: bool) {
+            debug_assert!(index / 8 < self.storage.as_ref().len());
+            let byte_index = index / 8;
+            let byte = &mut self.storage.as_mut()[byte_index];
+            let bit_index = if cfg!(target_endian = "big") {
+                7 - (index % 8)
+            } else {
+                index % 8
+            };
+            let mask = 1 << bit_index;
+            if val {
+                *byte |= mask;
+            } else {
+                *byte &= !mask;
+            }
+        }
+        #[inline]
+        pub fn get(&self, bit_offset: usize, bit_width: u8) -> u64 {
+            debug_assert!(bit_width <= 64);
+            debug_assert!(bit_offset / 8 < self.storage.as_ref().len());
+            debug_assert!((bit_offset + (bit_width as usize)) / 8 <= self.storage.as_ref().len());
+            let mut val = 0;
+            for i in 0..(bit_width as usize) {
+                if self.get_bit(i + bit_offset) {
+                    let index = if cfg!(target_endian = "big") {
+                        bit_width as usize - 1 - i
+                    } else {
+                        i
+                    };
+                    val |= 1 << index;
+                }
+            }
+            val
+        }
+        #[inline]
+        pub fn set(&mut self, bit_offset: usize, bit_width: u8, val: u64) {
debug_assert!(bit_width <= 64); + debug_assert!(bit_offset / 8 < self.storage.as_ref().len()); + debug_assert!((bit_offset + (bit_width as usize)) / 8 <= self.storage.as_ref().len()); + for i in 0..(bit_width as usize) { + let mask = 1 << i; + let val_bit_is_set = val & mask == mask; + let index = if cfg!(target_endian = "big") { + bit_width as usize - 1 - i + } else { + i + }; + self.set_bit(index + bit_offset, val_bit_is_set); + } + } + } + #[allow(unused_imports)] + use self::super::root; + pub type __int8_t = ::std::os::raw::c_schar; + pub type __uint8_t = ::std::os::raw::c_uchar; + pub type __uint16_t = ::std::os::raw::c_ushort; + pub type __uint32_t = ::std::os::raw::c_uint; + pub type __int64_t = ::std::os::raw::c_long; + pub type __uint64_t = ::std::os::raw::c_ulong; + pub mod std { + #[allow(unused_imports)] + use self::super::super::root; + pub type string = [u64; 4usize]; + #[repr(C)] + #[repr(align(1))] + #[derive(Debug, Copy, Clone)] + pub struct input_iterator_tag { + pub _bindgen_opaque_blob: u8, + } + #[test] + fn bindgen_test_layout_input_iterator_tag() { + assert_eq!( + ::std::mem::size_of::(), + 1usize, + concat!("Size of: ", stringify!(input_iterator_tag)) + ); + assert_eq!( + ::std::mem::align_of::(), + 1usize, + concat!("Alignment of ", stringify!(input_iterator_tag)) + ); + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct reverse_iterator { + pub _address: u8, + } + pub type reverse_iterator___traits_type = u8; + pub type reverse_iterator_iterator_type = u8; + pub type reverse_iterator_difference_type = u8; + pub type reverse_iterator_pointer = u8; + pub type reverse_iterator_reference = u8; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct allocator { + pub _address: u8, + } + pub type allocator_value_type = u8; + pub type allocator_size_type = u64; + pub type allocator_difference_type = u64; + pub type allocator_pointer = u8; + pub type allocator_const_pointer = u8; + pub type allocator_reference = u8; + pub type allocator_const_reference = u8; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct allocator_rebind { + pub _address: u8, + } + pub type allocator_rebind_other = u8; + pub type allocator_propagate_on_container_move_assignment = u8; + pub type allocator_is_always_equal = u8; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct default_delete { + pub _address: u8, + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct unique_ptr { + pub _address: u8, + } + pub type unique_ptr__DeleterConstraint = u8; + pub type unique_ptr_pointer = u8; + pub type unique_ptr_element_type = u8; + pub type unique_ptr_deleter_type = u8; + pub type unique_ptr___safe_conversion_up = u8; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct vector { + pub _address: u8, + } + pub type vector__Base = u8; + pub type vector__Tp_alloc_type = u8; + pub type vector__Alloc_traits = u8; + pub type vector_value_type = u8; + pub type vector_pointer = u8; + pub type vector_const_pointer = u8; + pub type vector_reference = u8; + pub type vector_const_reference = u8; + pub type vector_iterator = u8; + pub type vector_const_iterator = u8; + pub type vector_const_reverse_iterator = u8; + pub type vector_reverse_iterator = u8; + pub type vector_size_type = u64; + pub type vector_difference_type = u64; + pub type vector_allocator_type = u8; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct vector__Temporary_value { + pub _address: u8, + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct atomic { + pub _address: u8, + } + pub type 
atomic_value_type = u8; + pub mod chrono { + #[allow(unused_imports)] + use self::super::super::super::root; + } + } + pub mod __gnu_cxx { + #[allow(unused_imports)] + use self::super::super::root; + } + pub mod c10 { + #[allow(unused_imports)] + use self::super::super::root; + pub mod detail { + #[allow(unused_imports)] + use self::super::super::super::root; + #[repr(C)] + #[derive(Debug)] + pub struct UniqueVoidPtr { + pub data_: *mut ::std::os::raw::c_void, + pub ctx_: [u64; 2usize], + } + #[test] + fn bindgen_test_layout_UniqueVoidPtr() { + assert_eq!( + ::std::mem::size_of::(), + 24usize, + concat!("Size of: ", stringify!(UniqueVoidPtr)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(UniqueVoidPtr)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).data_ as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(UniqueVoidPtr), + "::", + stringify!(data_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).ctx_ as *const _ as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(UniqueVoidPtr), + "::", + stringify!(ctx_) + ) + ); + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct intrusive_target_default_null_type { + pub _address: u8, + } + } + pub mod guts { + #[allow(unused_imports)] + use self::super::super::super::root; + pub mod typelist { + #[allow(unused_imports)] + use self::super::super::super::super::root; + } + } + #[repr(i8)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum DeviceType { + CPU = 0, + CUDA = 1, + MKLDNN = 2, + OPENGL = 3, + OPENCL = 4, + IDEEP = 5, + HIP = 6, + FPGA = 7, + MSNPU = 8, + XLA = 9, + Vulkan = 10, + Metal = 11, + XPU = 12, + MLC = 13, + Meta = 14, + HPU = 15, + COMPILE_TIME_MAX_DEVICE_TYPES = 16, + } + pub type DeviceIndex = i8; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct Device { + pub type_: root::c10::DeviceType, + pub index_: root::c10::DeviceIndex, + } + pub use self::super::super::root::c10::DeviceType as Device_Type; + #[test] + fn bindgen_test_layout_Device() { + assert_eq!( + ::std::mem::size_of::(), + 2usize, + concat!("Size of: ", stringify!(Device)) + ); + assert_eq!( + ::std::mem::align_of::(), + 1usize, + concat!("Alignment of ", stringify!(Device)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).type_ as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(Device), + "::", + stringify!(type_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).index_ as *const _ as usize }, + 1usize, + concat!( + "Offset of field: ", + stringify!(Device), + "::", + stringify!(index_) + ) + ); + } + extern "C" { + #[link_name = "\u{1}_ZN3c106DeviceC1ERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE"] + pub fn Device_Device( + this: *mut root::c10::Device, + device_string: *const root::std::string, + ); + } + impl Device { + #[inline] + pub unsafe fn new(device_string: *const root::std::string) -> Self { + let mut __bindgen_tmp = ::std::mem::MaybeUninit::uninit(); + Device_Device(__bindgen_tmp.as_mut_ptr(), device_string); + __bindgen_tmp.assume_init() + } + } + pub type DeleterFnPtr = + ::std::option::Option; + #[repr(C)] + #[derive(Debug)] + pub struct DataPtr { + pub ptr_: root::c10::detail::UniqueVoidPtr, + pub device_: root::c10::Device, + } + #[test] + fn bindgen_test_layout_DataPtr() { + assert_eq!( + ::std::mem::size_of::(), + 32usize, + concat!("Size of: ", stringify!(DataPtr)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of 
", stringify!(DataPtr)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).ptr_ as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(DataPtr), + "::", + stringify!(ptr_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).device_ as *const _ as usize }, + 24usize, + concat!( + "Offset of field: ", + stringify!(DataPtr), + "::", + stringify!(device_) + ) + ); + } + #[repr(C)] + pub struct Allocator__bindgen_vtable(::std::os::raw::c_void); + #[repr(C)] + #[derive(Debug)] + pub struct Allocator { + pub vtable_: *const Allocator__bindgen_vtable, + } + #[test] + fn bindgen_test_layout_Allocator() { + assert_eq!( + ::std::mem::size_of::(), + 8usize, + concat!("Size of: ", stringify!(Allocator)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(Allocator)) + ); + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct ArrayRef { + pub Data: *const T, + pub Length: root::c10::ArrayRef_size_type, + pub _phantom_0: ::std::marker::PhantomData<::std::cell::UnsafeCell>, + } + pub type ArrayRef_iterator = *const T; + pub type ArrayRef_const_iterator = *const T; + pub type ArrayRef_size_type = usize; + pub type ArrayRef_value_type = T; + pub type ArrayRef_reverse_iterator = u8; + pub type IntArrayRef = root::c10::ArrayRef; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct optional { + pub _address: u8, + } + pub type optional_OptionalBase = u8; + pub type optional_value_type = u8; + pub mod impl_ { + #[allow(unused_imports)] + use self::super::super::super::root; + #[repr(C)] + pub struct SizesAndStrides { + pub size_: usize, + pub __bindgen_anon_1: root::c10::impl_::SizesAndStrides__bindgen_ty_1, + } + pub type SizesAndStrides_sizes_iterator = *mut i64; + pub type SizesAndStrides_sizes_const_iterator = *const i64; + pub type SizesAndStrides_strides_iterator = *mut i64; + pub type SizesAndStrides_strides_const_iterator = *const i64; + #[repr(C)] + #[derive(Copy, Clone)] + pub union SizesAndStrides__bindgen_ty_1 { + pub outOfLineStorage_: *mut i64, + pub inlineStorage_: [i64; 10usize], + } + #[test] + fn bindgen_test_layout_SizesAndStrides__bindgen_ty_1() { + assert_eq!( + ::std::mem::size_of::(), + 80usize, + concat!("Size of: ", stringify!(SizesAndStrides__bindgen_ty_1)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(SizesAndStrides__bindgen_ty_1)) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).outOfLineStorage_ + as *const _ as usize + }, + 0usize, + concat!( + "Offset of field: ", + stringify!(SizesAndStrides__bindgen_ty_1), + "::", + stringify!(outOfLineStorage_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).inlineStorage_ + as *const _ as usize + }, + 0usize, + concat!( + "Offset of field: ", + stringify!(SizesAndStrides__bindgen_ty_1), + "::", + stringify!(inlineStorage_) + ) + ); + } + #[test] + fn bindgen_test_layout_SizesAndStrides() { + assert_eq!( + ::std::mem::size_of::(), + 88usize, + concat!("Size of: ", stringify!(SizesAndStrides)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(SizesAndStrides)) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).size_ as *const _ as usize + }, + 0usize, + concat!( + "Offset of field: ", + stringify!(SizesAndStrides), + "::", + stringify!(size_) + ) + ); + } + } + pub mod raw { + #[allow(unused_imports)] + use self::super::super::super::root; + } + #[repr(C)] + pub struct 
intrusive_ptr_target__bindgen_vtable(::std::os::raw::c_void); + #[repr(C)] + #[derive(Debug)] + pub struct intrusive_ptr_target { + pub vtable_: *const intrusive_ptr_target__bindgen_vtable, + pub refcount_: u64, + pub weakcount_: u64, + } + #[test] + fn bindgen_test_layout_intrusive_ptr_target() { + assert_eq!( + ::std::mem::size_of::(), + 24usize, + concat!("Size of: ", stringify!(intrusive_ptr_target)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(intrusive_ptr_target)) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).refcount_ as *const _ as usize + }, + 8usize, + concat!( + "Offset of field: ", + stringify!(intrusive_ptr_target), + "::", + stringify!(refcount_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).weakcount_ as *const _ as usize + }, + 16usize, + concat!( + "Offset of field: ", + stringify!(intrusive_ptr_target), + "::", + stringify!(weakcount_) + ) + ); + } + #[repr(C)] + #[derive(Debug)] + pub struct intrusive_ptr { + pub target_: *mut TTarget, + pub _phantom_0: ::std::marker::PhantomData<::std::cell::UnsafeCell>, + } + pub type intrusive_ptr_element_type = TTarget; + #[repr(C)] + #[derive(Debug)] + pub struct StorageImpl { + pub _base: root::c10::intrusive_ptr_target, + pub data_ptr_: root::c10::DataPtr, + pub size_bytes_: usize, + pub resizable_: bool, + pub received_cuda_: bool, + pub allocator_: *mut root::c10::Allocator, + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct StorageImpl_use_byte_size_t { + pub _address: u8, + } + #[test] + fn bindgen_test_layout_StorageImpl_use_byte_size_t() { + assert_eq!( + ::std::mem::size_of::(), + 1usize, + concat!("Size of: ", stringify!(StorageImpl_use_byte_size_t)) + ); + assert_eq!( + ::std::mem::align_of::(), + 1usize, + concat!("Alignment of ", stringify!(StorageImpl_use_byte_size_t)) + ); + } + #[test] + fn bindgen_test_layout_StorageImpl() { + assert_eq!( + ::std::mem::size_of::(), + 80usize, + concat!("Size of: ", stringify!(StorageImpl)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(StorageImpl)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).data_ptr_ as *const _ as usize }, + 24usize, + concat!( + "Offset of field: ", + stringify!(StorageImpl), + "::", + stringify!(data_ptr_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).size_bytes_ as *const _ as usize }, + 56usize, + concat!( + "Offset of field: ", + stringify!(StorageImpl), + "::", + stringify!(size_bytes_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).resizable_ as *const _ as usize }, + 64usize, + concat!( + "Offset of field: ", + stringify!(StorageImpl), + "::", + stringify!(resizable_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).received_cuda_ as *const _ as usize + }, + 65usize, + concat!( + "Offset of field: ", + stringify!(StorageImpl), + "::", + stringify!(received_cuda_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).allocator_ as *const _ as usize }, + 72usize, + concat!( + "Offset of field: ", + stringify!(StorageImpl), + "::", + stringify!(allocator_) + ) + ); + } + impl root::c10::DispatchKey { + pub const CatchAll: root::c10::DispatchKey = DispatchKey::Undefined; + } + impl root::c10::DispatchKey { + pub const EndOfBackendKeys: root::c10::DispatchKey = DispatchKey::PrivateUse3; + } + impl root::c10::DispatchKey { + pub const EndOfAliasKeys: root::c10::DispatchKey = + DispatchKey::CompositeExplicitAutograd; + } + impl root::c10::DispatchKey { + pub 
const CPUTensorId: root::c10::DispatchKey = DispatchKey::CPU; + } + impl root::c10::DispatchKey { + pub const CUDATensorId: root::c10::DispatchKey = DispatchKey::CUDA; + } + impl root::c10::DispatchKey { + pub const DefaultBackend: root::c10::DispatchKey = + DispatchKey::CompositeExplicitAutograd; + } + impl root::c10::DispatchKey { + pub const PrivateUse1_PreAutograd: root::c10::DispatchKey = + DispatchKey::AutogradPrivateUse1; + } + impl root::c10::DispatchKey { + pub const PrivateUse2_PreAutograd: root::c10::DispatchKey = + DispatchKey::AutogradPrivateUse2; + } + impl root::c10::DispatchKey { + pub const PrivateUse3_PreAutograd: root::c10::DispatchKey = + DispatchKey::AutogradPrivateUse3; + } + impl root::c10::DispatchKey { + pub const Autocast: root::c10::DispatchKey = DispatchKey::AutocastCUDA; + } + #[repr(u8)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum DispatchKey { + Undefined = 0, + CPU = 1, + CUDA = 2, + HIP = 3, + FPGA = 4, + MSNPU = 5, + XLA = 6, + MLC = 7, + Vulkan = 8, + Metal = 9, + XPU = 10, + HPU = 11, + Meta = 12, + QuantizedCPU = 13, + QuantizedCUDA = 14, + QuantizedXPU = 15, + CustomRNGKeyId = 16, + MkldnnCPU = 17, + SparseCPU = 18, + SparseCUDA = 19, + SparseHIP = 20, + SparseXPU = 21, + SparseCsrCPU = 22, + SparseCsrCUDA = 23, + NestedTensor = 24, + PrivateUse1 = 25, + PrivateUse2 = 26, + PrivateUse3 = 27, + BackendSelect = 28, + FuncTorchPython = 29, + Named = 30, + FuncTorchDynamicLayerBackMode = 31, + ADInplaceOrView = 32, + AutogradOther = 33, + AutogradCPU = 34, + AutogradCUDA = 35, + AutogradXLA = 36, + AutogradXPU = 37, + AutogradMLC = 38, + AutogradHPU = 39, + AutogradNestedTensor = 40, + AutogradPrivateUse1 = 41, + AutogradPrivateUse2 = 42, + AutogradPrivateUse3 = 43, + Tracer = 44, + AutocastCUDA = 45, + FuncTorchBatched = 46, + FuncTorchVmapMode = 47, + Batched = 48, + VmapMode = 49, + FuncTorchGradWrapper = 50, + FuncTorchDynamicLayerFrontMode = 51, + TESTING_ONLY_GenericWrapper = 52, + TESTING_ONLY_GenericMode = 53, + NumDispatchKeys = 54, + Autograd = 55, + CompositeImplicitAutograd = 56, + CompositeExplicitAutograd = 57, + } + pub mod llvm { + #[allow(unused_imports)] + use self::super::super::super::root; + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct DispatchKeySet { + pub repr_: u64, + } + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum DispatchKeySet_Full { + FULL = 0, + } + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum DispatchKeySet_FullAfter { + FULL_AFTER = 0, + } + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum DispatchKeySet_Raw { + RAW = 0, + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct DispatchKeySet_iterator { + pub data_ptr_: *const u64, + pub i_: u8, + } + pub type DispatchKeySet_iterator_self_type = root::c10::DispatchKeySet_iterator; + pub type DispatchKeySet_iterator_iterator_category = root::std::input_iterator_tag; + pub use self::super::super::root::c10::DispatchKey as DispatchKeySet_iterator_value_type; + pub type DispatchKeySet_iterator_difference_type = isize; + #[test] + fn bindgen_test_layout_DispatchKeySet_iterator() { + assert_eq!( + ::std::mem::size_of::(), + 16usize, + concat!("Size of: ", stringify!(DispatchKeySet_iterator)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(DispatchKeySet_iterator)) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).data_ptr_ as *const _ + as usize + }, + 0usize, + concat!( + "Offset of field: ", + 
stringify!(DispatchKeySet_iterator), + "::", + stringify!(data_ptr_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).i_ as *const _ as usize + }, + 8usize, + concat!( + "Offset of field: ", + stringify!(DispatchKeySet_iterator), + "::", + stringify!(i_) + ) + ); + } + #[test] + fn bindgen_test_layout_DispatchKeySet() { + assert_eq!( + ::std::mem::size_of::(), + 8usize, + concat!("Size of: ", stringify!(DispatchKeySet)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(DispatchKeySet)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).repr_ as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(DispatchKeySet), + "::", + stringify!(repr_) + ) + ); + } + #[repr(i8)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum Layout { + Strided = 0, + Sparse = 1, + SparseCsr = 2, + Mkldnn = 3, + NumOptions = 4, + } + #[repr(i8)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum MemoryFormat { + Contiguous = 0, + Preserve = 1, + ChannelsLast = 2, + ChannelsLast3d = 3, + } + pub mod util { + #[allow(unused_imports)] + use self::super::super::super::root; + } + #[repr(C)] + #[derive(Debug)] + pub struct Storage { + pub storage_impl_: root::c10::intrusive_ptr, + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct Storage_use_byte_size_t { + pub _address: u8, + } + #[test] + fn bindgen_test_layout_Storage_use_byte_size_t() { + assert_eq!( + ::std::mem::size_of::(), + 1usize, + concat!("Size of: ", stringify!(Storage_use_byte_size_t)) + ); + assert_eq!( + ::std::mem::align_of::(), + 1usize, + concat!("Alignment of ", stringify!(Storage_use_byte_size_t)) + ); + } + #[test] + fn bindgen_test_layout_Storage() { + assert_eq!( + ::std::mem::size_of::(), + 8usize, + concat!("Size of: ", stringify!(Storage)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(Storage)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).storage_impl_ as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(Storage), + "::", + stringify!(storage_impl_) + ) + ); + } + #[repr(C)] + pub struct AutogradMetaInterface__bindgen_vtable(::std::os::raw::c_void); + #[repr(C)] + #[derive(Debug)] + pub struct AutogradMetaInterface { + pub vtable_: *const AutogradMetaInterface__bindgen_vtable, + } + #[test] + fn bindgen_test_layout_AutogradMetaInterface() { + assert_eq!( + ::std::mem::size_of::(), + 8usize, + concat!("Size of: ", stringify!(AutogradMetaInterface)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(AutogradMetaInterface)) + ); + } + extern "C" { + #[link_name = "\u{1}_ZN3c1021AutogradMetaInterfaceD1Ev"] + pub fn AutogradMetaInterface_AutogradMetaInterface_destructor( + this: *mut root::c10::AutogradMetaInterface, + ); + } + #[repr(C)] + pub struct NamedTensorMetaInterface__bindgen_vtable(::std::os::raw::c_void); + #[repr(C)] + #[derive(Debug)] + pub struct NamedTensorMetaInterface { + pub vtable_: *const NamedTensorMetaInterface__bindgen_vtable, + } + #[test] + fn bindgen_test_layout_NamedTensorMetaInterface() { + assert_eq!( + ::std::mem::size_of::(), + 8usize, + concat!("Size of: ", stringify!(NamedTensorMetaInterface)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(NamedTensorMetaInterface)) + ); + } + #[repr(C)] + #[derive(Debug)] + pub struct VariableVersion { + pub version_counter_: + root::c10::intrusive_ptr, + } + #[repr(C)] + 
#[derive(Debug)] + pub struct VariableVersion_VersionCounter { + pub _base: root::c10::intrusive_ptr_target, + pub version_: u32, + } + #[test] + fn bindgen_test_layout_VariableVersion_VersionCounter() { + assert_eq!( + ::std::mem::size_of::(), + 32usize, + concat!("Size of: ", stringify!(VariableVersion_VersionCounter)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(VariableVersion_VersionCounter)) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).version_ as *const _ + as usize + }, + 24usize, + concat!( + "Offset of field: ", + stringify!(VariableVersion_VersionCounter), + "::", + stringify!(version_) + ) + ); + } + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum VariableVersion_Disabled { + DISABLED = 0, + } + #[test] + fn bindgen_test_layout_VariableVersion() { + assert_eq!( + ::std::mem::size_of::(), + 8usize, + concat!("Size of: ", stringify!(VariableVersion)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(VariableVersion)) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).version_counter_ as *const _ + as usize + }, + 0usize, + concat!( + "Offset of field: ", + stringify!(VariableVersion), + "::", + stringify!(version_counter_) + ) + ); + } + #[repr(C)] + pub struct TensorImpl { + pub _base: root::c10::intrusive_ptr_target, + pub storage_: root::c10::Storage, + pub autograd_meta_: u64, + pub named_tensor_meta_: u64, + pub version_counter_: root::c10::VariableVersion, + pub pyobj_: *mut root::PyObject, + pub sizes_and_strides_: root::c10::impl_::SizesAndStrides, + pub storage_offset_: i64, + pub numel_: i64, + pub data_type_: root::caffe2::TypeMeta, + pub device_opt_: [u8; 3usize], + pub _bitfield_align_1: [u8; 0], + pub _bitfield_1: root::__BindgenBitfieldUnit<[u8; 1usize]>, + pub storage_access_should_throw_: bool, + pub _bitfield_align_2: [u8; 0], + pub _bitfield_2: root::__BindgenBitfieldUnit<[u8; 1usize]>, + pub key_set_: root::c10::DispatchKeySet, + } + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum TensorImpl_ImplType { + VIEW = 0, + } + #[repr(u8)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum TensorImpl_HasContiguityPolicy { + Default = 0, + ContiguityNotSupported = 1, + CustomBehavior = 2, + } + extern "C" { + #[link_name = "\u{1}_ZN3c1010TensorImpl42err_msg_tensor_metadata_change_not_allowedE"] + pub static TensorImpl_err_msg_tensor_metadata_change_not_allowed: + *const ::std::os::raw::c_char; + } + #[test] + fn bindgen_test_layout_TensorImpl() { + assert_eq!( + ::std::mem::size_of::(), + 184usize, + concat!("Size of: ", stringify!(TensorImpl)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(TensorImpl)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).storage_ as *const _ as usize }, + 24usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(storage_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).autograd_meta_ as *const _ as usize + }, + 32usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(autograd_meta_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).named_tensor_meta_ as *const _ as usize + }, + 40usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(named_tensor_meta_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).version_counter_ as *const _ as usize + }, + 48usize, + 
concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(version_counter_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).pyobj_ as *const _ as usize }, + 56usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(pyobj_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).sizes_and_strides_ as *const _ as usize + }, + 64usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(sizes_and_strides_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).storage_offset_ as *const _ as usize + }, + 152usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(storage_offset_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).numel_ as *const _ as usize }, + 160usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(numel_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).data_type_ as *const _ as usize }, + 168usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(data_type_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).device_opt_ as *const _ as usize }, + 170usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(device_opt_) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).storage_access_should_throw_ as *const _ + as usize + }, + 174usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(storage_access_should_throw_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).key_set_ as *const _ as usize }, + 176usize, + concat!( + "Offset of field: ", + stringify!(TensorImpl), + "::", + stringify!(key_set_) + ) + ); + } + extern "C" { + #[link_name = "\u{1}_ZN3c1010TensorImplC1EONS_7StorageENS_14DispatchKeySetEN6caffe28TypeMetaE"] + pub fn TensorImpl_TensorImpl( + this: *mut root::c10::TensorImpl, + storage: *mut root::c10::Storage, + arg1: root::c10::DispatchKeySet, + data_type: root::caffe2::TypeMeta, + ); + } + extern "C" { + #[link_name = "\u{1}_ZN3c1010TensorImplC1ENS0_8ImplTypeEONS_7StorageENS_14DispatchKeySetEN6caffe28TypeMetaE"] + pub fn TensorImpl_TensorImpl1( + this: *mut root::c10::TensorImpl, + arg1: root::c10::TensorImpl_ImplType, + storage: *mut root::c10::Storage, + arg2: root::c10::DispatchKeySet, + data_type: root::caffe2::TypeMeta, + ); + } + extern "C" { + #[link_name = "\u{1}_ZN3c1010TensorImplC1ENS_14DispatchKeySetEN6caffe28TypeMetaENS_8optionalINS_6DeviceEEE"] + pub fn TensorImpl_TensorImpl2( + this: *mut root::c10::TensorImpl, + arg1: root::c10::DispatchKeySet, + data_type: root::caffe2::TypeMeta, + device_opt: [u8; 3usize], + ); + } + impl TensorImpl { + #[inline] + pub fn is_contiguous_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_1.get(0usize, 1u8) as u8) } + } + #[inline] + pub fn set_is_contiguous_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_1.set(0usize, 1u8, val as u64) + } + } + #[inline] + pub fn has_contiguity_(&self) -> u8 { + unsafe { ::std::mem::transmute(self._bitfield_1.get(1usize, 2u8) as u8) } + } + #[inline] + pub fn set_has_contiguity_(&mut self, val: u8) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_1.set(1usize, 2u8, val as u64) + } + } + #[inline] + pub fn new_bitfield_1( + is_contiguous_: bool, + has_contiguity_: u8, + ) -> root::__BindgenBitfieldUnit<[u8; 1usize]> { + let mut __bindgen_bitfield_unit: 
root::__BindgenBitfieldUnit<[u8; 1usize]> = + Default::default(); + __bindgen_bitfield_unit.set(0usize, 1u8, { + let is_contiguous_: u8 = unsafe { ::std::mem::transmute(is_contiguous_) }; + is_contiguous_ as u64 + }); + __bindgen_bitfield_unit.set(1usize, 2u8, { + let has_contiguity_: u8 = unsafe { ::std::mem::transmute(has_contiguity_) }; + has_contiguity_ as u64 + }); + __bindgen_bitfield_unit + } + #[inline] + pub fn is_channels_last_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_2.get(0usize, 1u8) as u8) } + } + #[inline] + pub fn set_is_channels_last_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_2.set(0usize, 1u8, val as u64) + } + } + #[inline] + pub fn is_channels_last_contiguous_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_2.get(1usize, 1u8) as u8) } + } + #[inline] + pub fn set_is_channels_last_contiguous_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_2.set(1usize, 1u8, val as u64) + } + } + #[inline] + pub fn is_channels_last_3d_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_2.get(2usize, 1u8) as u8) } + } + #[inline] + pub fn set_is_channels_last_3d_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_2.set(2usize, 1u8, val as u64) + } + } + #[inline] + pub fn is_channels_last_3d_contiguous_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_2.get(3usize, 1u8) as u8) } + } + #[inline] + pub fn set_is_channels_last_3d_contiguous_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_2.set(3usize, 1u8, val as u64) + } + } + #[inline] + pub fn is_non_overlapping_and_dense_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_2.get(4usize, 1u8) as u8) } + } + #[inline] + pub fn set_is_non_overlapping_and_dense_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_2.set(4usize, 1u8, val as u64) + } + } + #[inline] + pub fn is_wrapped_number_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_2.get(5usize, 1u8) as u8) } + } + #[inline] + pub fn set_is_wrapped_number_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_2.set(5usize, 1u8, val as u64) + } + } + #[inline] + pub fn allow_tensor_metadata_change_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_2.get(6usize, 1u8) as u8) } + } + #[inline] + pub fn set_allow_tensor_metadata_change_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_2.set(6usize, 1u8, val as u64) + } + } + #[inline] + pub fn reserved_(&self) -> bool { + unsafe { ::std::mem::transmute(self._bitfield_2.get(7usize, 1u8) as u8) } + } + #[inline] + pub fn set_reserved_(&mut self, val: bool) { + unsafe { + let val: u8 = ::std::mem::transmute(val); + self._bitfield_2.set(7usize, 1u8, val as u64) + } + } + #[inline] + pub fn new_bitfield_2( + is_channels_last_: bool, + is_channels_last_contiguous_: bool, + is_channels_last_3d_: bool, + is_channels_last_3d_contiguous_: bool, + is_non_overlapping_and_dense_: bool, + is_wrapped_number_: bool, + allow_tensor_metadata_change_: bool, + reserved_: bool, + ) -> root::__BindgenBitfieldUnit<[u8; 1usize]> { + let mut __bindgen_bitfield_unit: root::__BindgenBitfieldUnit<[u8; 1usize]> = + Default::default(); + __bindgen_bitfield_unit.set(0usize, 1u8, { + let is_channels_last_: u8 = unsafe { 
::std::mem::transmute(is_channels_last_) }; + is_channels_last_ as u64 + }); + __bindgen_bitfield_unit.set(1usize, 1u8, { + let is_channels_last_contiguous_: u8 = + unsafe { ::std::mem::transmute(is_channels_last_contiguous_) }; + is_channels_last_contiguous_ as u64 + }); + __bindgen_bitfield_unit.set(2usize, 1u8, { + let is_channels_last_3d_: u8 = + unsafe { ::std::mem::transmute(is_channels_last_3d_) }; + is_channels_last_3d_ as u64 + }); + __bindgen_bitfield_unit.set(3usize, 1u8, { + let is_channels_last_3d_contiguous_: u8 = + unsafe { ::std::mem::transmute(is_channels_last_3d_contiguous_) }; + is_channels_last_3d_contiguous_ as u64 + }); + __bindgen_bitfield_unit.set(4usize, 1u8, { + let is_non_overlapping_and_dense_: u8 = + unsafe { ::std::mem::transmute(is_non_overlapping_and_dense_) }; + is_non_overlapping_and_dense_ as u64 + }); + __bindgen_bitfield_unit.set(5usize, 1u8, { + let is_wrapped_number_: u8 = + unsafe { ::std::mem::transmute(is_wrapped_number_) }; + is_wrapped_number_ as u64 + }); + __bindgen_bitfield_unit.set(6usize, 1u8, { + let allow_tensor_metadata_change_: u8 = + unsafe { ::std::mem::transmute(allow_tensor_metadata_change_) }; + allow_tensor_metadata_change_ as u64 + }); + __bindgen_bitfield_unit.set(7usize, 1u8, { + let reserved_: u8 = unsafe { ::std::mem::transmute(reserved_) }; + reserved_ as u64 + }); + __bindgen_bitfield_unit + } + #[inline] + pub unsafe fn new( + storage: *mut root::c10::Storage, + arg1: root::c10::DispatchKeySet, + data_type: root::caffe2::TypeMeta, + ) -> Self { + let mut __bindgen_tmp = ::std::mem::MaybeUninit::uninit(); + TensorImpl_TensorImpl(__bindgen_tmp.as_mut_ptr(), storage, arg1, data_type); + __bindgen_tmp.assume_init() + } + #[inline] + pub unsafe fn new1( + arg1: root::c10::TensorImpl_ImplType, + storage: *mut root::c10::Storage, + arg2: root::c10::DispatchKeySet, + data_type: root::caffe2::TypeMeta, + ) -> Self { + let mut __bindgen_tmp = ::std::mem::MaybeUninit::uninit(); + TensorImpl_TensorImpl1(__bindgen_tmp.as_mut_ptr(), arg1, storage, arg2, data_type); + __bindgen_tmp.assume_init() + } + #[inline] + pub unsafe fn new2( + arg1: root::c10::DispatchKeySet, + data_type: root::caffe2::TypeMeta, + device_opt: [u8; 3usize], + ) -> Self { + let mut __bindgen_tmp = ::std::mem::MaybeUninit::uninit(); + TensorImpl_TensorImpl2(__bindgen_tmp.as_mut_ptr(), arg1, data_type, device_opt); + __bindgen_tmp.assume_init() + } + } + #[repr(C)] + pub struct UndefinedTensorImpl { + pub _base: root::c10::TensorImpl, + } + extern "C" { + #[link_name = "\u{1}_ZN3c1019UndefinedTensorImpl10_singletonE"] + pub static mut UndefinedTensorImpl__singleton: root::c10::UndefinedTensorImpl; + } + #[test] + fn bindgen_test_layout_UndefinedTensorImpl() { + assert_eq!( + ::std::mem::size_of::(), + 184usize, + concat!("Size of: ", stringify!(UndefinedTensorImpl)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(UndefinedTensorImpl)) + ); + } + pub mod ivalue { + #[allow(unused_imports)] + use self::super::super::super::root; + } + } + pub mod caffe2 { + #[allow(unused_imports)] + use self::super::super::root; + pub mod detail { + #[allow(unused_imports)] + use self::super::super::super::root; + pub type TypeMetaData_New = + ::std::option::Option *mut ::std::os::raw::c_void>; + pub type TypeMetaData_PlacementNew = ::std::option::Option< + unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void, arg2: usize), + >; + pub type TypeMetaData_Copy = ::std::option::Option< + unsafe extern "C" fn( + arg1: *const 
::std::os::raw::c_void, + arg2: *mut ::std::os::raw::c_void, + arg3: usize, + ), + >; + pub type TypeMetaData_PlacementDelete = ::std::option::Option< + unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void, arg2: usize), + >; + pub type TypeMetaData_Delete = + ::std::option::Option; + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct TypeMeta { + pub index_: u16, + } + pub type TypeMeta_New = root::caffe2::detail::TypeMetaData_New; + pub type TypeMeta_PlacementNew = root::caffe2::detail::TypeMetaData_PlacementNew; + pub type TypeMeta_Copy = root::caffe2::detail::TypeMetaData_Copy; + pub type TypeMeta_PlacementDelete = root::caffe2::detail::TypeMetaData_PlacementDelete; + pub type TypeMeta_Delete = root::caffe2::detail::TypeMetaData_Delete; + extern "C" { + #[link_name = "\u{1}_ZN6caffe28TypeMeta13nextTypeIndexE"] + pub static mut TypeMeta_nextTypeIndex: u16; + } + #[test] + fn bindgen_test_layout_TypeMeta() { + assert_eq!( + ::std::mem::size_of::(), + 2usize, + concat!("Size of: ", stringify!(TypeMeta)) + ); + assert_eq!( + ::std::mem::align_of::(), + 2usize, + concat!("Alignment of ", stringify!(TypeMeta)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).index_ as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(TypeMeta), + "::", + stringify!(index_) + ) + ); + } + extern "C" { + #[link_name = "\u{1}_ZN6caffe28TypeMetaC1Ev"] + pub fn TypeMeta_TypeMeta(this: *mut root::caffe2::TypeMeta); + } + impl TypeMeta { + #[inline] + pub unsafe fn new() -> Self { + let mut __bindgen_tmp = ::std::mem::MaybeUninit::uninit(); + TypeMeta_TypeMeta(__bindgen_tmp.as_mut_ptr()); + __bindgen_tmp.assume_init() + } + } + } + pub mod at { + #[allow(unused_imports)] + use self::super::super::root; + pub mod detail { + #[allow(unused_imports)] + use self::super::super::super::root; + } + pub mod impl_ { + #[allow(unused_imports)] + use self::super::super::super::root; + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct QTensorImpl { + _unused: [u8; 0], + } + } + pub mod torch { + #[allow(unused_imports)] + use self::super::super::root; + } + pub mod ska { + #[allow(unused_imports)] + use self::super::super::root; + } + pub mod google { + #[allow(unused_imports)] + use self::super::super::root; + pub mod base { + #[allow(unused_imports)] + use self::super::super::super::root; + } + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct _object { + _unused: [u8; 0], + } + pub type PyObject = root::_object; + pub mod ska_ordered { + #[allow(unused_imports)] + use self::super::super::root; + } + #[test] + fn __bindgen_test_layout_ArrayRef_open0_int64_t_close0_instantiation() { + assert_eq!( + ::std::mem::size_of::>(), + 16usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::ArrayRef) + ) + ); + assert_eq!( + ::std::mem::align_of::>(), + 8usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::ArrayRef) + ) + ); + } + #[test] + fn __bindgen_test_layout_intrusive_ptr_open0_StorageImpl_intrusive_target_default_null_type_open1_StorageImpl_close1_close0_instantiation( + ) { + assert_eq!( + ::std::mem::size_of::>(), + 8usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::intrusive_ptr) + ) + ); + assert_eq!( + ::std::mem::align_of::>(), + 8usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::intrusive_ptr) + ) + ); + } + #[test] + fn __bindgen_test_layout_intrusive_target_default_null_type_open0_StorageImpl_close0_instantiation( + ) { + 
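+        // These bindgen layout tests pin struct sizes, alignments and field
+        // offsets to the libtorch C++ ABI the bindings were generated against;
+        // a failure here means the checked-in bindings are stale.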
assert_eq!( + ::std::mem::size_of::(), + 1usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::detail::intrusive_target_default_null_type) + ) + ); + assert_eq!( + ::std::mem::align_of::(), + 1usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::detail::intrusive_target_default_null_type) + ) + ); + } + #[test] + fn __bindgen_test_layout_intrusive_ptr_open0_VariableVersion_VersionCounter_intrusive_target_default_null_type_open1_VariableVersion_VersionCounter_close1_close0_instantiation( + ) { + assert_eq!( + ::std::mem::size_of::< + root::c10::intrusive_ptr, + >(), + 8usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::intrusive_ptr) + ) + ); + assert_eq!( + ::std::mem::align_of::< + root::c10::intrusive_ptr, + >(), + 8usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::intrusive_ptr) + ) + ); + } + #[test] + fn __bindgen_test_layout_intrusive_target_default_null_type_open0_VariableVersion_VersionCounter_close0_instantiation( + ) { + assert_eq!( + ::std::mem::size_of::(), + 1usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::detail::intrusive_target_default_null_type) + ) + ); + assert_eq!( + ::std::mem::align_of::(), + 1usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::detail::intrusive_target_default_null_type) + ) + ); + } + #[test] + fn __bindgen_test_layout_intrusive_ptr_open0_TensorImpl_intrusive_target_default_null_type_open1_TensorImpl_close1_close0_instantiation( + ) { + assert_eq!( + ::std::mem::size_of::>(), + 8usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::intrusive_ptr) + ) + ); + assert_eq!( + ::std::mem::align_of::>(), + 8usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::intrusive_ptr) + ) + ); + } + #[test] + fn __bindgen_test_layout_intrusive_target_default_null_type_open0_TensorImpl_close0_instantiation( + ) { + assert_eq!( + ::std::mem::size_of::(), + 1usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::detail::intrusive_target_default_null_type) + ) + ); + assert_eq!( + ::std::mem::align_of::(), + 1usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::detail::intrusive_target_default_null_type) + ) + ); + } + #[test] + fn __bindgen_test_layout_ArrayRef_open0_int64_t_close0_instantiation_1() { + assert_eq!( + ::std::mem::size_of::>(), + 16usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::ArrayRef) + ) + ); + assert_eq!( + ::std::mem::align_of::>(), + 8usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::ArrayRef) + ) + ); + } + #[test] + fn __bindgen_test_layout_ArrayRef_open0_int_close0_instantiation() { + assert_eq!( + ::std::mem::size_of::>(), + 16usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::ArrayRef<::std::os::raw::c_int>) + ) + ); + assert_eq!( + ::std::mem::align_of::>(), + 8usize, + concat!( + "Alignment of template specialization: ", + stringify!(root::c10::ArrayRef<::std::os::raw::c_int>) + ) + ); + } + #[test] + fn __bindgen_test_layout_ArrayRef_open0_size_t_close0_instantiation() { + assert_eq!( + ::std::mem::size_of::>(), + 16usize, + concat!( + "Size of template specialization: ", + stringify!(root::c10::ArrayRef) + ) + ); + assert_eq!( + ::std::mem::align_of::>(), + 8usize, + concat!( + "Alignment of template specialization: ", + 
stringify!(root::c10::ArrayRef<usize>)
+            )
+        );
+    }
+}
diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs
index d8fa09f..bd06a19 100644
--- a/bagua-core-py/src/lib.rs
+++ b/bagua-core-py/src/lib.rs
@@ -11,7 +11,6 @@ use numpy::{IntoPyArray, PyArray1};
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
 use pyo3::PyNativeType;
-use std::sync::Arc;
 
 #[pyclass(dict)]
 pub struct BaguaSingleCommunicatorPy {
@@ -122,17 +121,49 @@ pub struct BaguaTensorPy {
 
 #[pymethods]
 impl BaguaTensorPy {
+    // #[new]
+    // pub fn new(
+    //     ptr: u64,
+    //     num_elem: usize,
+    //     num_elem_allocated: usize,
+    //     dtype: &str,
+    //     device_id: usize,
+    // ) -> Self {
+    //     Self {
+    //         inner: BaguaTensor::new(ptr, num_elem, num_elem_allocated, dtype, device_id),
+    //     }
+    // }
+
     #[new]
-    pub fn new(
-        ptr: u64,
-        num_elem: usize,
-        num_elem_allocated: usize,
-        dtype: &str,
-        device_id: usize,
-    ) -> Self {
-        Self {
-            inner: BaguaTensor::new(ptr, num_elem, num_elem_allocated, dtype, device_id),
-        }
+    pub fn new(torch_tensor: &PyAny, name: String) -> PyResult<Self> {
+        // TODO: sanity check
+        let dtype = torch_tensor
+            .getattr("dtype")
+            .expect("must pass valid torch tensor")
+            .repr()?
+            .to_string();
+        let bagua_dtype = match dtype.as_str() {
+            "torch.float32" => BaguaTensorDtype::F32,
+            "torch.float16" => BaguaTensorDtype::F16,
+            "torch.int64" => BaguaTensorDtype::I64,
+            "torch.uint8" => BaguaTensorDtype::U8,
+            _ => {
+                return Err(PyRuntimeError::new_err(format!(
+                    "unsupported tensor dtype {}",
+                    dtype
+                )))
+            }
+        };
+        Ok(Self {
+            inner: BaguaTensor::new_from_torch(
+                name,
+                torch_tensor
+                    .getattr("_cdata")
+                    .expect("must pass valid torch tensor")
+                    .extract()?,
+                bagua_dtype,
+            ),
+        })
     }
 
     pub fn compress(&self, method: &str, n_chunks: usize, target_chunk: i32) -> Self {
@@ -143,11 +174,11 @@ impl BaguaTensorPy {
 
     pub fn to_numpy_f32<'py>(self_: PyRef<'py, Self>) -> pyo3::Py<PyArray1<f32>> {
         let inner = self_.inner.inner.read();
-        assert_eq!(inner.raw.dtype, BaguaTensorDtype::F32);
-        let mut array = ndarray::Array1::from_elem((inner.raw.num_elem,), 0f32);
+        assert_eq!(inner.raw.dtype(), BaguaTensorDtype::F32);
+        let mut array = ndarray::Array1::from_elem((inner.raw.num_elements(),), 0f32);
         let array_ptr = array.as_mut_ptr();
-        let device_ptr = inner.raw.ptr;
-        let num_bytes = inner.raw.num_elem as i32 * inner.raw.dtype.bytes() as i32;
+        let device_ptr = inner.raw.data_ptr();
+        let num_bytes = inner.raw.num_elements() as i32 * inner.raw.dtype().bytes() as i32;
         unsafe {
             bagua_core_internal::cuda_utils::cuda_memcpy_device_to_host_sync(
                 array_ptr as _,
@@ -160,11 +191,11 @@
 
     pub fn to_numpy_u8<'py>(self_: PyRef<'py, Self>) -> pyo3::Py<PyArray1<u8>> {
         let inner = self_.inner.inner.read();
-        assert_eq!(inner.raw.dtype, BaguaTensorDtype::U8);
-        let mut array = ndarray::Array1::from_elem((inner.raw.num_elem,), 0u8);
+        assert_eq!(inner.raw.dtype(), BaguaTensorDtype::U8);
+        let mut array = ndarray::Array1::from_elem((inner.raw.num_elements(),), 0u8);
         let array_ptr = array.as_mut_ptr();
-        let device_ptr = inner.raw.ptr;
-        let num_bytes = inner.raw.num_elem as i32 * inner.raw.dtype.bytes() as i32;
+        let device_ptr = inner.raw.data_ptr();
+        let num_bytes = inner.raw.num_elements() as i32 * inner.raw.dtype().bytes() as i32;
         unsafe {
             bagua_core_internal::cuda_utils::cuda_memcpy_device_to_host_sync(
                 array_ptr as _,
@@ -180,29 +211,21 @@ impl BaguaTensorPy {
             .decompress_from(method, n_chunks, &compressed_buffer.inner);
     }
 
-    pub fn ptr(&self) -> u64 {
-        self.inner.ptr()
-    }
-
-    pub fn id(&self) -> u64 {
-        self.inner.id()
+    pub fn data_ptr(&self) -> u64 {
+        self.inner.data_ptr()
     }
 
-    pub fn num_elem(&self) -> usize {
-        self.inner.num_elem()
+    pub fn num_elements(&self) -> usize {
+        self.inner.num_elements()
     }
 
-    pub fn num_elem_allocated(&self) -> usize {
-        self.inner.num_elem_allocated()
+    pub fn num_elements_allocated(&self) -> usize {
+        self.inner.num_elements_allocated()
     }
 
     pub fn dtype(&self) -> String {
         self.inner.dtype()
     }
-
-    pub fn reset_ptr(&mut self, ptr: u64) {
-        self.inner.reset_ptr(ptr)
-    }
 }
 
 #[pyclass(dict)]
@@ -271,19 +294,13 @@ pub struct BaguaBucketPy {
 
 #[pymethods]
 impl BaguaBucketPy {
     #[new]
-    #[args(align_bytes = "0")]
-    pub fn new(
-        name: &str,
-        tensors: Vec<PyRef<BaguaTensorPy>>,
-        inplace: bool,
-        align_bytes: usize,
-    ) -> PyResult<Self> {
+    pub fn new(name: &str, tensors: Vec<PyRef<BaguaTensorPy>>) -> PyResult<Self> {
         let mut tensors_inner = Vec::with_capacity(tensors.len());
         for t in tensors.iter() {
             tensors_inner.push(&t.inner)
         }
         Ok(Self {
-            inner: BaguaBucket::new(tensors_inner.as_slice(), name, inplace, align_bytes)
+            inner: BaguaBucket::new(tensors_inner.as_slice(), name)
                 .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?,
         })
     }
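
Note: the sketch below is illustrative and not part of the patch. It shows why the bindgen structs in src/torch_ffi.rs are enough for `BaguaTensor::new_from_torch`: given `torch.Tensor._cdata` (a `TensorImpl*` exposed to Python as an integer), the Rust side can read the element count and device pointer directly from PyTorch's own structures. The helper name `read_torch_metadata` is hypothetical, it assumes `torch_ffi` is exported as a public module from bagua-core-internal, and it hard-codes a 4-byte element size (f32) where real code would consult `data_type_`.

    use bagua_core_internal::torch_ffi::root::c10::{StorageImpl, TensorImpl};

    /// Hypothetical helper: recover (device pointer, element count) from a
    /// `torch.Tensor._cdata` value. Safety: `torch_cdata` must point to a
    /// live, strided tensor's `TensorImpl`.
    unsafe fn read_torch_metadata(torch_cdata: u64) -> (u64, usize) {
        let tensor_impl: &TensorImpl = &*(torch_cdata as *const TensorImpl);
        // `storage_impl_.target_` is the raw `StorageImpl*` held by the
        // intrusive pointer; `data_ptr_.ptr_.data_` is the raw allocation.
        let storage: &StorageImpl = &*tensor_impl.storage_.storage_impl_.target_;
        let base = storage.data_ptr_.ptr_.data_ as u64;
        let elem_size = 4u64; // assumption: f32; real code would map `data_type_`
        // `storage_offset_` is measured in elements, not bytes.
        (
            base + tensor_impl.storage_offset_ as u64 * elem_size,
            tensor_impl.numel_ as usize,
        )
    }

On the Python side the new surface is correspondingly small: `BaguaTensorPy(torch_tensor, name)` infers the dtype from `torch_tensor.dtype` (anything other than float32/float16/int64/uint8 raises a RuntimeError), and `BaguaBucketPy(name, tensors)` replaces the old `(name, tensors, inplace, align_bytes)` constructor.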