This repository was archived by the owner on Sep 15, 2021, and is now read-only.

feat: make full precision decentralized op stateless (#36)
BREAKING CHANGE: `BaguaBucketPy::append_decentralized_synchronous_op` now only supports full precision decentralized communication.
wangraying authored Jul 21, 2021
1 parent 30cdb67 commit 98319c9
Showing 5 changed files with 209 additions and 306 deletions.
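In short: `DecentralizedFullPrecisionSynchronous` used to pull a scratch peer tensor from the CUDA memory pool on every execution, gate communication on a `communication_interval` counter, and schedule a post-backward op to copy the averaged result back. After this change the caller registers a long-lived `peer_weight: BaguaTensor` when constructing the op, and each step simply writes into that buffer. A minimal sketch of the ownership pattern (stand-in types, not the actual Bagua API):

// Sketch only: `PeerBuf` stands in for Bagua's `BaguaTensor`; the names
// here are illustrative, not the library's real types or methods.
use std::sync::Mutex;

struct PeerBuf(Mutex<Vec<f32>>);

struct DecentralizedOp {
    step: Mutex<usize>,
    peer_weight: PeerBuf, // registered once by the caller, reused every step
}

impl DecentralizedOp {
    fn execute(&self, local: &[f32]) {
        // Write into the pre-registered buffer instead of pulling a fresh
        // pool allocation and scheduling a copy-back op afterwards.
        self.peer_weight.0.lock().unwrap().copy_from_slice(local);
        *self.step.lock().unwrap() += 1;
    }
}

The diff of the first changed file follows.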
@@ -1,6 +1,6 @@
 use crate::comm_ops::CommOpTrait;
 use crate::communicators::{BaguaCommunicator, BaguaHierarchicalCommunicator, NCCLGroupGuard};
-use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor};
+use crate::datatypes::{BaguaBucket, BaguaTensor, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor};
 use crate::events::BaguaEventChannel;
 use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
 use crate::{BaguaCommOpChannels, BaguaScheduledCommOp};
@@ -19,7 +19,7 @@ pub struct DecentralizedFullPrecisionSynchronous {
     pub communicator: BaguaCommunicator,
     pub peer_selection_mode: PeerSelectionMode,
     pub step: Mutex<usize>,
-    pub communication_interval: usize,
+    pub peer_weight: BaguaTensor,
 }
 
 impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
@@ -45,22 +45,11 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
             },
         };
 
-        let t = &communication_tensor;
-        let peer_tensor_buffer = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id]
-            .try_pull(t.raw.num_elem_allocated * t.raw.dtype.bytes())
-            .expect("cannot allocate gpu memory");
-        let mut peer_tensor = BaguaTensorRaw {
-            ptr: peer_tensor_buffer.ptr,
-            num_elem_allocated: t.raw.num_elem_allocated,
-            dtype: t.raw.dtype,
-            num_elem: t.raw.num_elem,
-            device_id: t.raw.device_id,
-            pool_allocations: vec![Arc::new(peer_tensor_buffer)],
-        };
-
         let peer_mode = &self.peer_selection_mode;
-        let comm_interval = &self.communication_interval;
-        let step = { *self.step.lock() };
+
+        let mut peer_guard = self.peer_weight.inner.write();
+        let mut peer_tensor = peer_guard.raw.as_mut();
+        let step = { *self.step.lock() } as i64;
 
         self.communicator.execute_communication(
             &mut communication_tensor,
@@ -71,95 +60,43 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
                 match peer_mode {
                     PeerSelectionMode::All => {
                         {
-                            if step % comm_interval == 0 {
-                                peer_tensor.clone_from(&t.raw, c.stream_ptr);
-                                let _guard = NCCLGroupGuard::new();
-                                c.allreduce_inplace(&mut peer_tensor, BaguaReductionOp::SUM);
-                                peer_tensor.divide_inplace(stream_ptr, c.nranks as f32);
-                            }
+                            peer_tensor.clone_from(&t.raw, c.stream_ptr);
+                            let _guard = NCCLGroupGuard::new();
+                            c.allreduce_inplace(peer_tensor, BaguaReductionOp::SUM);
+                            peer_tensor.divide_inplace(stream_ptr, c.nranks as f32);
                         }
                     }
                     PeerSelectionMode::ShiftOne => {
-                        if step % comm_interval == 0 {
-                            assert_eq!(
-                                c.nranks % 2,
-                                0,
-                                "you cannot use decentralized algorithm with average_all off when there are odd number of ranks, current n_ranks {}",
-                                c.nranks
-                            );
-                            let comm_step = (step / comm_interval) as i64;
-                            let rank = c.rank as i64;
-                            let nranks = c.nranks as i64;
-                            let peer_rank = if c.rank < c.nranks / 2 {
-                                ((comm_step + rank) % ((nranks + 1) / 2)) + (nranks / 2)
-                            } else {
-                                (rank - (nranks / 2) - comm_step).rem_euclid(nranks / 2)
-                            } as i32;
-                            tracing::debug!("rank {} peer_rank {}", c.rank, peer_rank);
-                            {
-                                let _guard = NCCLGroupGuard::new();
-                                c.send(&t.raw, peer_rank);
-                                c.recv(&mut peer_tensor, peer_rank);
-                            }
-                            peer_tensor.average_inplace(&t.raw, c.stream_ptr);
-                        }
+                        assert_eq!(
+                            c.nranks % 2,
+                            0,
+                            "you cannot use decentralized algorithm with average_all off when there are odd number of ranks, current n_ranks {}",
+                            c.nranks
+                        );
+                        let rank = c.rank as i64;
+                        let nranks = c.nranks as i64;
+                        let peer_rank = if c.rank < c.nranks / 2 {
+                            ((step + rank) % ((nranks + 1) / 2)) + (nranks / 2)
+                        } else {
+                            (rank - (nranks / 2) - step).rem_euclid(nranks / 2)
+                        } as i32;
+                        tracing::debug!("rank {} peer_rank {}", c.rank, peer_rank);
+                        {
+                            let _guard = NCCLGroupGuard::new();
+                            c.send(&t.raw, peer_rank);
+                            c.recv(peer_tensor, peer_rank);
+                        }
+                        peer_tensor.average_inplace(&t.raw, c.stream_ptr);
                     },
                     PeerSelectionMode::Ring => {
                         unimplemented!()
                     }
                 },
             }
         },
     );

-        if step % comm_interval == 0 {
-            // TODO: move this to .then() python API instead of hard code this in op
-            let post_backward_comm_op = BaguaScheduledCommOp {
-                name: format!("post backward comm op for bucket {}", bucket.name),
-                bucket: bucket.clone(),
-                ops: vec![Arc::new(DecentralizedFullPrecisionSynchronousPostStep {
-                    communicator: self.communicator.clone(),
-                    result_weight: peer_tensor,
-                })],
-                event_channel: BaguaEventChannel::new("decentralized_post_backward"),
-            };
-
-            comm_op_channels
-                .not_waited_post_backward_events_sender
-                .send(post_backward_comm_op.event_channel.clone())
-                .expect("cannot send post backward event");
-            comm_op_channels
-                .post_backward_channel_sender
-                .send(post_backward_comm_op)
-                .expect("cannot send post backward op");
-        }
 
         *self.step.lock() += 1;
     }
 }

-#[derive(Debug)]
-pub struct DecentralizedFullPrecisionSynchronousPostStep {
-    pub communicator: BaguaCommunicator,
-    pub result_weight: BaguaTensorRaw,
-}
-
-impl CommOpTrait for DecentralizedFullPrecisionSynchronousPostStep {
-    fn execute_background_communication(
-        &self,
-        bucket: Arc<BaguaBucket>,
-        _comm_op_channels: &BaguaCommOpChannels,
-    ) {
-        let bucket = bucket.inner.lock();
-        let stream_ptr = self.communicator.stream_ptr();
-        let mut communication_tensor = bucket.get_communication_tensor(stream_ptr, false, false);
-        self.communicator.execute_communication(
-            &mut communication_tensor,
-            false,
-            false,
-            true,
-            &mut |c, t| {
-                t.raw.clone_from(&self.result_weight, c.stream_ptr);
-            },
-        );
-    }
-}
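Two pieces of the new code are easy to check outside the NCCL plumbing. First, the `PeerSelectionMode::All` branch is a plain global average: allreduce with `BaguaReductionOp::SUM` into `peer_weight`, then `divide_inplace` by `nranks`. A CPU analogy (illustrative only, no Bagua types):

// What the All branch computes: after allreduce(SUM) + divide_inplace(nranks),
// every rank holds the average of all ranks' weights.
fn all_mode_result(rank_weights: &[Vec<f32>]) -> Vec<f32> {
    let nranks = rank_weights.len() as f32;
    let mut avg = vec![0.0f32; rank_weights[0].len()];
    for w in rank_weights {
        for (a, x) in avg.iter_mut().zip(w) {
            *a += x; // the allreduce SUM step
        }
    }
    for a in &mut avg {
        *a /= nranks; // the divide_inplace step
    }
    avg
}

fn main() {
    let weights = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    assert_eq!(all_mode_result(&weights), vec![2.0, 3.0]);
}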
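Second, the `ShiftOne` branch pairs every rank in the lower half with one in the upper half and rotates that pairing by `step` (previously by `step / comm_interval`), then exchanges weights with the peer and averages. The pairing is an involution, i.e. my peer's peer is me, which is what makes the matched `send`/`recv` calls line up. A standalone check of the arithmetic, with `peer_rank` lifted from the new code above:

// peer_rank arithmetic from the ShiftOne branch; `nranks` must be even,
// as the assert_eq! in the op enforces.
fn peer_rank(rank: i64, nranks: i64, step: i64) -> i64 {
    if rank < nranks / 2 {
        ((step + rank) % ((nranks + 1) / 2)) + (nranks / 2)
    } else {
        (rank - (nranks / 2) - step).rem_euclid(nranks / 2)
    }
}

fn main() {
    let nranks = 6;
    for step in 0..4 {
        for rank in 0..nranks {
            let peer = peer_rank(rank, nranks, step);
            // Involution: both sides of each send/recv agree on the pairing.
            assert_eq!(peer_rank(peer, nranks, step), rank);
        }
    }
    // e.g. step 0 pairs (0,3) (1,4) (2,5); step 1 pairs (0,4) (1,5) (2,3).
}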
