sharding mode support avg_weight (PaddlePaddle#95)
* sharding mode, erase other device param

* sharding mode support avg_weight
gumplh authored Nov 6, 2023
1 parent 1d44c42 commit 0bd927a
Showing 1 changed file with 37 additions and 18 deletions.

paddle/fluid/framework/boxps_worker.cc
@@ -551,6 +551,34 @@ int BoxPSWorker::IsParameter(const std::string& name, bool full_match) {
    return -1;
  }
}
+
+static bool FindVarInMap(const VariableNameMap& op_var_map,
+                         const std::multiset<std::string>& var_set) {
+  for (auto& o : op_var_map) {
+    for (auto& name : o.second) {
+      if (var_set.find(name) != var_set.end()) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+static bool IsAvgOp(OpDesc* op_desc) {
+  if (op_desc->Type() != "elementwise_add" &&
+      op_desc->Type() != "elementwise_mul") {
+    return false;
+  }
+  for (auto& o : op_desc->Outputs()) {
+    for (auto& name : o.second) {
+      if (name.find("avg_weight") != std::string::npos ||
+          name.find("@avg") != std::string::npos) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
void BoxPSWorker::BuildShardingDepends(const ProgramDesc& program) {
nccl_rank_id_ = place_.GetDeviceId();
#if defined(PADDLE_WITH_CUDA)
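The two helpers added above are pure name-based classification: FindVarInMap reports whether any variable listed in an op's input/output map belongs to a given set, and IsAvgOp treats an elementwise_add or elementwise_mul as an averaging op when any output name contains "avg_weight" or "@avg". Below is a standalone sketch of that matching, not part of the commit; the toy variable names are hypothetical, and a plain std::map stands in for Paddle's VariableNameMap.

// Standalone sketch -- not part of this commit.
#include <iostream>
#include <map>
#include <string>
#include <vector>

using VarMap = std::map<std::string, std::vector<std::string>>;  // stand-in

// Same name test as IsAvgOp above, over a toy (op type, outputs) pair.
static bool LooksLikeAvgOp(const std::string& type, const VarMap& outputs) {
  if (type != "elementwise_add" && type != "elementwise_mul") {
    return false;
  }
  for (const auto& o : outputs) {
    for (const auto& name : o.second) {
      if (name.find("avg_weight") != std::string::npos ||
          name.find("@avg") != std::string::npos) {
        return true;
      }
    }
  }
  return false;
}

int main() {
  VarMap avg_out = {{"Out", {"fc_0.w_0@avg"}}};
  VarMap plain_out = {{"Out", {"fc_0.tmp_1"}}};
  std::cout << LooksLikeAvgOp("elementwise_mul", avg_out) << "\n";   // 1
  std::cout << LooksLikeAvgOp("elementwise_add", plain_out) << "\n"; // 0
  return 0;
}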
@@ -615,33 +643,24 @@ void BoxPSWorker::BuildShardingDepends(const ProgramDesc& program) {
        unpersist_vars_.insert(name);
      }
    } else {
-      // adam ubmq1_h2_param.b_0_moment1_0
+      // adam ubmq1_h2_param.b_0_moment1_0, avg_weight @avg @w_backup
      remove_vars_.insert(name);
    }
  }

  std::multiset<std::string> all_remove_inputs;
  for (auto& op_desc : all_desc) {
-    bool find = false;
-    for (auto& o : op_desc->Inputs()) {
-      for (auto& name : o.second) {
-        if (remove_vars_.find(name) == remove_vars_.end()) {
-          continue;
-        }
-        find = true;
-        break;
-      }
-      if (find) {
-        break;
-      }
-    }
-    if (find) {
+    if (FindVarInMap(op_desc->Inputs(), remove_vars_)) {
      for (auto& o : op_desc->Inputs()) {
        for (auto& name : o.second) {
          all_remove_inputs.insert(name);
        }
      }
      remove_ops_.insert(op_desc);
+    } else if (IsAvgOp(op_desc) &&
+               (FindVarInMap(op_desc->Outputs(), remove_vars_) ||
+                FindVarInMap(op_desc->Inputs(), unpersist_vars_))) {
+      remove_ops_.insert(op_desc);
    }
  }
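Taken together, the rewritten loop prunes two kinds of ops from this device's sharded program: any op that consumes a variable erased from this device (with all of its input names recorded in all_remove_inputs), and, new in this commit, avg ops whose outputs were erased here or whose inputs are unpersisted here. A minimal sketch of that removal rule follows, again standalone and with a hypothetical ToyOp struct standing in for OpDesc.

// Standalone sketch -- not part of this commit.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using VarMap = std::map<std::string, std::vector<std::string>>;

struct ToyOp {  // hypothetical stand-in for OpDesc
  std::string type;
  VarMap inputs, outputs;
};

// True if any variable named in m is present in s (as in FindVarInMap).
static bool InSet(const VarMap& m, const std::multiset<std::string>& s) {
  for (const auto& o : m) {
    for (const auto& n : o.second) {
      if (s.count(n) != 0) return true;
    }
  }
  return false;
}

static bool IsAvgName(const std::string& n) {
  return n.find("avg_weight") != std::string::npos ||
         n.find("@avg") != std::string::npos;
}

// Mirrors the pruning rule above: drop the op if it reads an erased variable,
// or if it is an avg op that writes erased state or reads unpersisted state.
static bool ShouldRemove(const ToyOp& op,
                         const std::multiset<std::string>& remove_vars,
                         const std::multiset<std::string>& unpersist_vars) {
  bool is_avg = (op.type == "elementwise_add" || op.type == "elementwise_mul");
  if (is_avg) {
    is_avg = false;
    for (const auto& o : op.outputs) {
      for (const auto& n : o.second) {
        if (IsAvgName(n)) is_avg = true;
      }
    }
  }
  return InSet(op.inputs, remove_vars) ||
         (is_avg && (InSet(op.outputs, remove_vars) ||
                     InSet(op.inputs, unpersist_vars)));
}

int main() {
  std::multiset<std::string> remove_vars = {"fc_0.w_0@avg"};
  std::multiset<std::string> unpersist_vars;
  ToyOp avg_op{"elementwise_mul",
               {{"X", {"fc_0.w_0"}}, {"Y", {"avg_weight"}}},
               {{"Out", {"fc_0.w_0@avg"}}}};
  std::cout << ShouldRemove(avg_op, remove_vars, unpersist_vars) << "\n";  // 1
  return 0;
}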

@@ -1026,7 +1045,7 @@ void BoxPSWorker::CreateThreadScopeForSharding(const ProgramDesc& program) {
      ptr->GetMutable<phi::DenseTensor>()->Resize(dims).set_type(var_dtype);
      ++unpersist_num;
      ++persistable_num;
      total_persistable_len += ptr->GetMutable<phi::DenseTensor>()->numel();
      continue;
    }
    Variable* root_var = root_scope_->FindVar(name);
@@ -1079,8 +1098,8 @@ void BoxPSWorker::CreateThreadScopeForSharding(const ProgramDesc& program) {
      share_persistable_len += len;
    } else {
      TensorCopy(*static_cast<const Tensor*>(root_tensor),
                 place_,
                 static_cast<Tensor*>(gpu_tensor));
      ++copy_persist_num;
      // device 0 needs to sync datanorm and learning rate to the root scope
      if (device_id_ == 0) {
