Skip to content

Commit

Permalink
move configure check
Browse files Browse the repository at this point in the history
  • Loading branch information
FeixLiu committed Dec 2, 2021
1 parent 5753e4b commit ddc3f3e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 26 deletions.
22 changes: 0 additions & 22 deletions paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,6 @@ AmplifierInterceptor::AmplifierInterceptor(int64_t interceptor_id,
run_at_offset_ = node->run_at_offset();
reply_up_per_steps_ = node->reply_up_per_steps();
send_down_per_steps_ = node->send_down_per_steps();

PADDLE_ENFORCE_GE(
run_per_steps_, 1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but now is %ld", run_per_steps_));
PADDLE_ENFORCE_GE(
run_at_offset_, 0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but now is %ld", run_at_offset_));
PADDLE_ENFORCE_LT(run_at_offset_, run_per_steps_,
platform::errors::InvalidArgument(
"run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld",
run_at_offset_, run_per_steps_));
PADDLE_ENFORCE_GE(
reply_up_per_steps_, 1,
platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but now is %ld", reply_up_per_steps_));
PADDLE_ENFORCE_GE(send_down_per_steps_, 1,
platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but now is %ld",
send_down_per_steps_));
}

void AmplifierInterceptor::RunOps() {
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ void Carrier::CreateInterceptors() {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;

PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(),
platform::errors::InvalidArgument(
"Interceptor's run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld",
task_node->run_at_offset(), task_node->run_per_steps()));

std::unique_ptr<Interceptor> interceptor;
if (task_node->type().empty()) {
// TODO(wangxi): delete this in future
Expand Down
29 changes: 29 additions & 0 deletions paddle/fluid/distributed/fleet_executor/task_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,34 @@ std::string TaskNode::DebugString() const {
os << "\n";
return os.str();
}

void TaskNode::SetRunPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(value, 1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but received %ld", value));
run_per_steps_ = value;
}

void TaskNode::SetRunAtOffset(int64_t value) {
PADDLE_ENFORCE_GE(value, 0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but received %ld", value));
run_at_offset_ = value;
}

void TaskNode::SetReplyUpPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value, 1, platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but received %ld", value));
reply_up_per_steps_ = value;
}

void TaskNode::SetSendDownPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value, 1, platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but received %ld", value));
send_down_per_steps_ = value;
}

} // namespace distributed
} // namespace paddle
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/fleet_executor/task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class TaskNode final {
const paddle::framework::ProgramDesc& program() const { return program_; }
const std::vector<OperatorBase*>& ops() const { return ops_; }

void SetRunPerSteps(int64_t value) { run_per_steps_ = value; }
void SetRunAtOffset(int64_t value) { run_at_offset_ = value; }
void SetReplyUpPerSteps(int64_t value) { reply_up_per_steps_ = value; }
void SetSendDownPerSteps(int64_t value) { send_down_per_steps_ = value; }
void SetRunPerSteps(int64_t value);
void SetRunAtOffset(int64_t value);
void SetReplyUpPerSteps(int64_t value);
void SetSendDownPerSteps(int64_t value);
void SetType(const std::string& type) { type_ = type; }

bool AddUpstreamTask(int64_t task_id);
Expand Down

0 comments on commit ddc3f3e

Please sign in to comment.