Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge pull request #11608 from panyx0718/doc #11621

Merged
merged 1 commit into from
Jun 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion paddle/fluid/framework/details/multi_devices_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
BuildStrategy::GradientScaleStrategy::kCustomized) {
CreateScaleLossGradOp(&result);
}
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scales the final scalar loss.
// It also assumes backward op will always follow the forward op in
// the block.
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(*op);
Expand Down Expand Up @@ -244,6 +248,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
InsertAllReduceOp(&result, g_name);
}
break;
default:
LOG(FATAL) << "Unknown reduce strategy ";
break;
}
}
} catch (boost::bad_get e) {
Expand All @@ -262,7 +269,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
/*
Dependency graph has been constructed. However, there are still data
harzaeds need to be handled.
hazards need to be handled.
*/
PolishGraphToSupportDataHazards(&result);

Expand Down Expand Up @@ -447,6 +454,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return var;
}

// Find the first occurrence of `prev_op_name` and make current `op` depend
// on it.
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) {
Expand Down Expand Up @@ -490,6 +499,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
}
}

// Create RPC related op handles that connect its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
int op_dev_id = -1;
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);

if (timeout) {
std::lock_guard<std::mutex> l(exception_mu_);
if (exception_) {
auto exp = *exception_;
exception_.reset();
Expand Down Expand Up @@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp(
ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted";
} catch (platform::EnforceNotMet ex) {
std::lock_guard<std::mutex> l(exception_mu_);
exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
LOG(FATAL) << "Unknown exception catched";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_;
std::mutex exception_mu_;
std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_;

Expand Down