diff --git a/paddle/contrib/tape/function.h b/paddle/contrib/tape/function.h index 0584f4ec8aaae..8c9694d9a21b5 100644 --- a/paddle/contrib/tape/function.h +++ b/paddle/contrib/tape/function.h @@ -112,6 +112,8 @@ class SGD { } void operator()(VariableHandle input) { + PADDLE_ENFORCE(get_global_tape().HasBeenBackwarded(), + "optimization must happen after the backward"); Tape temp_tape; temp_tape.AddOp("sgd", {{"Param", {input}}, @@ -120,7 +122,6 @@ class SGD { {{"ParamOut", {input}}}, {}); temp_tape.Forward(); - input->ResetGrad(); } private: diff --git a/paddle/contrib/tape/tape.h b/paddle/contrib/tape/tape.h index 9938ce9a7f46a..ed79de17a7fca 100644 --- a/paddle/contrib/tape/tape.h +++ b/paddle/contrib/tape/tape.h @@ -47,6 +47,8 @@ class Tape { void Forward(); void Backward(VariableHandle target); + bool HasBeenBackwarded() { return has_been_backwarded_; } + private: bool has_been_backwarded_ = false; size_t current_position_ = 0; diff --git a/paddle/contrib/tape/variable.h b/paddle/contrib/tape/variable.h index 7e63aa38a7a63..35c328e69c9eb 100644 --- a/paddle/contrib/tape/variable.h +++ b/paddle/contrib/tape/variable.h @@ -45,15 +45,15 @@ class Variable { void InitializeVariable(); VariableHandle Grad() { - if (grad_ == nullptr) { - grad_.reset(new Variable(desc_.Name(), true)); + if (grad_.expired()) { + VariableHandle new_grad(new Variable(desc_.Name(), true)); + grad_ = new_grad; + return new_grad; + } else { + return VariableHandle(grad_); } - - return grad_; } - void ResetGrad() { grad_ = nullptr; } - // Stochastic Gradient Descent with Momentum // VariableHandle Momentum (); @@ -79,7 +79,7 @@ class Variable { framework::VarDesc desc_; framework::Variable var_; - VariableHandle grad_; + std::weak_ptr grad_; }; } }