Skip to content

Commit

Permalink
make variable->Grad() a weak_ptr (#11453)
Browse files Browse the repository at this point in the history
* fix #11416

* make sgd check tape has been backwarded_

* add error message
  • Loading branch information
Yang Yang(Tony) authored Jun 14, 2018
1 parent a59c3b7 commit f790b96
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
3 changes: 2 additions & 1 deletion paddle/contrib/tape/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class SGD {
}

void operator()(VariableHandle input) {
PADDLE_ENFORCE(get_global_tape().HasBeenBackwarded(),
"optimization must happen after the backward");
Tape temp_tape;
temp_tape.AddOp("sgd",
{{"Param", {input}},
Expand All @@ -120,7 +122,6 @@ class SGD {
{{"ParamOut", {input}}},
{});
temp_tape.Forward();
input->ResetGrad();
}

private:
Expand Down
2 changes: 2 additions & 0 deletions paddle/contrib/tape/tape.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class Tape {
void Forward();
void Backward(VariableHandle target);

bool HasBeenBackwarded() { return has_been_backwarded_; }

private:
bool has_been_backwarded_ = false;
size_t current_position_ = 0;
Expand Down
14 changes: 7 additions & 7 deletions paddle/contrib/tape/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ class Variable {
void InitializeVariable();

VariableHandle Grad() {
if (grad_ == nullptr) {
grad_.reset(new Variable(desc_.Name(), true));
if (grad_.expired()) {
VariableHandle new_grad(new Variable(desc_.Name(), true));
grad_ = new_grad;
return new_grad;
} else {
return VariableHandle(grad_);
}

return grad_;
}

void ResetGrad() { grad_ = nullptr; }

// Stochastic Gradient Descent with Momentum
// VariableHandle Momentum ();

Expand All @@ -79,7 +79,7 @@ class Variable {
framework::VarDesc desc_;
framework::Variable var_;

VariableHandle grad_;
std::weak_ptr<Variable> grad_;
};
}
}

0 comments on commit f790b96

Please sign in to comment.