feat: compatible mode for getting pytorch tensor info with Python interpreter
NOBLES5E committed Jul 1, 2021
1 parent a7e34ba commit 1534d23
Showing 3 changed files with 112 additions and 21 deletions.
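
The change adds a "compatible mode": Bagua normally reads tensor metadata (data pointer, element count, device index) directly from PyTorch's internal C++ TensorImpl layout, and this commit makes it verify those values against the running Python interpreter and fall back to querying the tensor through Python via pyo3 whenever they disagree. Below is a minimal sketch of that fallback pattern, not part of the commit itself; the helper name data_ptr_via_python is invented for illustration, and it assumes pyo3 0.13 with the auto-initialize feature this commit enables.

use pyo3::prelude::*;
use pyo3::types::IntoPyDict;

// Rebuild a torch.Tensor from its raw cdata pointer inside the Python
// interpreter and ask it for its data pointer, instead of dereferencing
// the C++ TensorImpl struct directly.
fn data_ptr_via_python(cdata: u64) -> PyResult<u64> {
    Python::with_gil(|py| {
        let torch = py.import("torch")?;
        let tensor = torch
            .getattr("Tensor")?
            .call((), Some([("cdata", cdata)].into_py_dict(py)))?;
        tensor.call_method0("data_ptr")?.extract()
    })
}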
1 change: 1 addition & 0 deletions bagua-core-internal/Cargo.toml
@@ -31,6 +31,7 @@ num-derive = "0.3"

 [dependencies.pyo3]
 version = "0.13.2"
+features = ["auto-initialize"]

 [build-dependencies]
 shadow-rs = "0.6"
126 changes: 107 additions & 19 deletions bagua-core-internal/src/datatypes/mod.rs
@@ -14,6 +14,7 @@ use itertools::Itertools;
 use num_derive::FromPrimitive;
 use num_traits::FromPrimitive;
 use parking_lot::{Mutex, RwLock};
+use pyo3::types::IntoPyDict;
 use sized_object_pool::DynamicPoolItem;
 use std::ffi::c_void;
 use std::fmt::Debug;
@@ -68,6 +69,7 @@ impl BaguaTensorDtype {
 #[derive(Debug)]
 pub struct TorchTensorRaw {
     pub torch_tensor_cdata: u64,
+    pub python_fallback: bool,
     pub dtype: BaguaTensorDtype,
 }

@@ -385,6 +387,47 @@ pub trait RawBaguaTensor: Debug {
 }

 impl TorchTensorRaw {
+    pub fn get_pytensor<'a>(cdata: u64, py: &'a pyo3::Python) -> pyo3::PyResult<&'a pyo3::PyAny> {
+        let torch = py.import("torch")?;
+        let tensor_class = torch.get("Tensor")?;
+        tensor_class.call((), Some([("cdata", cdata)].into_py_dict(py.to_owned())))
+    }
+
+    pub fn check_consistency_with_python(&self) -> pyo3::PyResult<bool> {
+        fn check_consistency(
+            py: pyo3::Python,
+            cdata: u64,
+            data_ptr: u64,
+            numel: usize,
+            device_id: usize,
+        ) -> pyo3::PyResult<bool> {
+            let py_tensor = TorchTensorRaw::get_pytensor(cdata, &py)?;
+            let py_data_ptr: u64 = py_tensor.call_method0("data_ptr")?.extract()?;
+            if py_data_ptr != data_ptr {
+                return Ok(false);
+            }
+            let py_numel: usize = py_tensor.call_method0("numel")?.extract()?;
+            if py_numel != numel {
+                return Ok(false);
+            }
+            let py_device_id: usize = py_tensor.getattr("device")?.getattr("index")?.extract()?;
+            if py_device_id != device_id {
+                return Ok(false);
+            }
+            return Ok(true);
+        }
+
+        pyo3::Python::with_gil(|py| {
+            check_consistency(
+                py,
+                self.torch_tensor_cdata,
+                self.data_ptr(),
+                self.num_elements(),
+                self.device_id(),
+            )
+        })
+    }
+
     fn extract_torch_c_data(&self) -> &TensorImpl {
         unsafe { (self.torch_tensor_cdata as *const TensorImpl).as_ref() }
             .expect("torch c data pointer is null")
@@ -401,28 +444,59 @@ impl TorchTensorRaw {

 impl RawBaguaTensor for TorchTensorRaw {
     fn data_ptr(&self) -> u64 {
-        let cdata = self.extract_torch_c_data();
-        let storage = self.extract_storage();
-        storage.data_ptr_.ptr_.data_ as u64
-            + cdata.storage_offset_ as u64 * self.dtype.bytes() as u64
+        if self.python_fallback {
+            pyo3::Python::with_gil(|py| {
+                let py_tensor = TorchTensorRaw::get_pytensor(self.torch_tensor_cdata, &py).unwrap();
+                py_tensor
+                    .call_method0("data_ptr")
+                    .unwrap()
+                    .extract()
+                    .unwrap()
+            })
+        } else {
+            let cdata = self.extract_torch_c_data();
+            let storage = self.extract_storage();
+            storage.data_ptr_.ptr_.data_ as u64
+                + cdata.storage_offset_ as u64 * self.dtype.bytes() as u64
+        }
     }

     fn num_elements(&self) -> usize {
-        self.extract_torch_c_data().numel_ as _
+        if self.python_fallback {
+            pyo3::Python::with_gil(|py| {
+                let py_tensor = TorchTensorRaw::get_pytensor(self.torch_tensor_cdata, &py).unwrap();
+                py_tensor.call_method0("numel").unwrap().extract().unwrap()
+            })
+        } else {
+            self.extract_torch_c_data().numel_ as _
+        }
     }

     fn num_elements_allocated(&self) -> usize {
         self.num_elements()
     }

     fn device_id(&self) -> usize {
-        let storage_data_ptr = &self.extract_storage().data_ptr_;
-        assert_eq!(
-            storage_data_ptr.device_.type_,
-            DeviceType::CUDA,
-            "currently only cuda tensors are supported in Bagua"
-        );
-        return storage_data_ptr.device_.index_ as _;
+        if self.python_fallback {
+            pyo3::Python::with_gil(|py| {
+                let py_tensor = TorchTensorRaw::get_pytensor(self.torch_tensor_cdata, &py).unwrap();
+                py_tensor
+                    .getattr("device")
+                    .unwrap()
+                    .getattr("index")
+                    .unwrap()
+                    .extract()
+                    .unwrap()
+            })
+        } else {
+            let storage_data_ptr = &self.extract_storage().data_ptr_;
+            assert_eq!(
+                storage_data_ptr.device_.type_,
+                DeviceType::CUDA,
+                "currently only cuda tensors are supported in Bagua"
+            );
+            return storage_data_ptr.device_.index_ as _;
+        }
     }

     fn dtype(&self) -> BaguaTensorDtype {
@@ -552,18 +626,32 @@ pub struct BaguaTensor {
 }

 impl BaguaTensor {
-    pub fn new_from_torch(name: String, torch_cdata_ptr: u64, dtype: BaguaTensorDtype) -> Self {
-        Self {
+    pub fn new_from_torch(
+        name: String,
+        torch_cdata_ptr: u64,
+        dtype: BaguaTensorDtype,
+    ) -> pyo3::PyResult<Self> {
+        let mut torch_tensor = TorchTensorRaw {
+            torch_tensor_cdata: torch_cdata_ptr,
+            python_fallback: false,
+            dtype,
+        };
+        let consistency = torch_tensor.check_consistency_with_python()?;
+        if !consistency {
+            tracing::warn!(
+                r#"PyTorch tensor memory layout inconsistent with latest PyTorch. Bagua will fallback to Python interface. This will degrade system performance. We suggest upgrading to latest PyTorch."#
+            )
+        }
+        torch_tensor.python_fallback = !consistency;
+
+        Ok(Self {
             inner: Arc::new(RwLock::new(BaguaTensorInner {
                 name,
-                raw: Box::new(TorchTensorRaw {
-                    torch_tensor_cdata: torch_cdata_ptr,
-                    dtype,
-                }),
+                raw: Box::new(torch_tensor),
                 ready_for_comm: false,
                 ready_cuda_event_ptr: 0,
             })),
-        }
+        })
     }

     pub fn mark_comm_ready(&self, cuda_event_ptr: u64) {
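Because the constructor now runs this consistency check through the Python interpreter, BaguaTensor::new_from_torch returns pyo3::PyResult<Self> instead of Self, so callers propagate the error with `?` (the bagua-core-py change below does exactly this). A hypothetical caller sketch; the name "grad_0" and the BaguaTensorDtype::F32 variant are placeholders, not taken from this commit.

fn register_torch_tensor(torch_cdata_ptr: u64) -> pyo3::PyResult<BaguaTensor> {
    // `?` surfaces any Python-side failure from the consistency check.
    let tensor =
        BaguaTensor::new_from_torch("grad_0".to_string(), torch_cdata_ptr, BaguaTensorDtype::F32)?;
    Ok(tensor)
}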
6 changes: 4 additions & 2 deletions bagua-core-py/src/lib.rs
@@ -5,7 +5,6 @@ use bagua_core_internal::datatypes::{
     BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorDtype,
 };
 use bagua_core_internal::BaguaCommBackend;
-use num_derive::FromPrimitive;
 use num_traits::FromPrimitive;
 use numpy::{IntoPyArray, PyArray1};
 use pyo3::exceptions::PyRuntimeError;
@@ -162,7 +161,7 @@ impl BaguaTensorPy {
                    .expect("must pass valid torch tensor")
                    .extract()?,
                bagua_dtype,
-            ),
+            )?,
         })
     }

@@ -384,6 +383,9 @@ impl BaguaBucketPy {

 #[pymodule]
 fn bagua_core(_py: Python, m: &PyModule) -> PyResult<()> {
+    if std::env::var("LOG_LEVEL").is_err() {
+        std::env::set_var("LOG_LEVEL", "WARN");
+    }
     tracing_subscriber::fmt()
         .with_env_filter(tracing_subscriber::EnvFilter::from_env("LOG_LEVEL"))
         .init();
