Skip to content

Commit

Permalink
[KUNLUN] update xccl lib & use native Reduce in dygraph (PaddlePaddle#49941)
Browse files Browse the repository at this point in the history

* update xccl lib & use native Reduce in dygraph

* minor
  • Loading branch information
XiaociZhang authored and pangengzheng committed Feb 2, 2023
1 parent ead7011 commit f583870
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 35 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ else()
endif()

set(XPU_XCCL_BASE_URL
"https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.6")
"https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.7")

if(WITH_AARCH64)
set(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
Expand Down
44 changes: 10 additions & 34 deletions paddle/fluid/distributed/collective/process_group_bkcl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,41 +352,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
phi::DenseTensor output_t;
paddle::framework::TensorCopy(*output, platform::XPUPlace(), &output_t);
const auto& place = input.place();
auto* calc_ctx = static_cast<phi::XPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
switch (input.dtype()) {
case phi::DataType::FLOAT32:
calc_ctx->template Alloc<float>(&output_t);
break;
case phi::DataType::FLOAT16:
calc_ctx->template Alloc<float16>(&output_t);
break;
case phi::DataType::INT32:
calc_ctx->template Alloc<int>(&output_t);
break;
default:
VLOG(0) << "Error: type " << input.dtype() << " not supported for "
<< GetBackendName();
break;
}
int ret =
bkcl_all_reduce(comm,
input.data(),
output_t.data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
stream);
if (rank_ == opts.root_rank) {
*output = output_t;
}
return ret;
return bkcl_reduce(comm,
input.data(),
output->data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
opts.root_rank,
stream);
},
CommType::ALLREDUCE,
CommType::REDUCE,
sync_op,
use_calc_stream);
}
Expand Down

0 comments on commit f583870

Please sign in to comment.