diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs index 25f7ac7..7f10cad 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs @@ -87,11 +87,13 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous { "you cannot use decentralized algorithm with average_all off when there are odd number of ranks, current n_ranks {}", c.nranks ); - let comm_step = step / comm_interval; + let comm_step = (step / comm_interval) as i64; + let rank = c.rank as i64; + let nranks = c.nranks as i64; let peer_rank = if c.rank < c.nranks / 2 { - ((comm_step + c.rank) % ((c.nranks + 1) / 2)) + (c.nranks / 2) + ((comm_step + rank) % ((nranks + 1) / 2)) + (nranks / 2) } else { - (c.rank - (c.nranks / 2) - comm_step).rem_euclid(c.nranks / 2) + (rank - (nranks / 2) - comm_step).rem_euclid(nranks / 2) } as i32; tracing::debug!("rank {} peer_rank {}", c.rank, peer_rank); {