-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add sgd op #2950
Add sgd op #2950
Changes from 7 commits
54335ec
c843b25
04db57c
3ffcc8d
74cf950
4ab4560
df1e4a9
5b9e807
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/sgd_op.h" | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/framework/tensor.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class SGDOp : public framework::OperatorWithKernel { | ||
protected: | ||
void InferShape( | ||
const std::vector<const framework::Tensor *> &inputs, | ||
const std::vector<framework::Tensor *> &outputs) const override { | ||
PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); | ||
PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); | ||
PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); | ||
PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set"); | ||
PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set"); | ||
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), | ||
"Two input of SGD Op's dimension must be same."); | ||
outputs[0]->set_dims(inputs[0]->dims()); | ||
} | ||
}; | ||
|
||
class SGDOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) | ||
: framework::OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("param", "input parameter"); | ||
AddInput("grad", "input gradient"); | ||
AddOutput("param_out", "output parameter"); | ||
AddAttr<float>("learning_rate", "learning rate of sgd"); | ||
AddComment(R"DOC( | ||
|
||
Simplest sgd algorithm. | ||
|
||
param_out = param - learning_rate * grad; | ||
|
||
)DOC"); | ||
} | ||
}; | ||
} // namespace operators | ||
} // namespace paddle | ||
|
||
REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker); | ||
typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float> | ||
SGDOpKernel_CPU_float; | ||
REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#include "paddle/operators/sgd_op.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float; | ||
REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include "glog/logging.h" | ||
#include "paddle/framework/operator.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename Place, typename T> | ||
class SGDOpKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::KernelContext& ctx) const override { | ||
auto param = ctx.Input("param")->Get<framework::Tensor>(); | ||
auto grad = ctx.Input("grad")->Get<framework::Tensor>(); | ||
auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>(); | ||
float lr = ctx.op_.GetAttr<float>("learning_rate"); | ||
|
||
param_out->mutable_data<T>(ctx.GetPlace()); | ||
|
||
param_out->flat<T>().device(*(ctx.GetEigenDevice<Place>())) = | ||
param.flat<T>() - lr * grad.flat<T>(); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <gtest/gtest.h> | ||
#define private public | ||
#include <paddle/framework/op_registry.h> | ||
USE_OP(sgd); | ||
TEST(SGDOp, GetOpProto) { | ||
auto& protos = paddle::framework::OpRegistry::protos(); | ||
auto it = protos.find("sgd"); | ||
ASSERT_NE(it, protos.end()); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python | ||
add_op fc_op) | ||
add_op fc_op sgd_op) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
add_python_test(test_framework test_protobuf.py test_scope.py | ||
test_default_scope_funcs.py test_op_creation_methods.py | ||
test_tensor.py test_fc_op.py) | ||
test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import paddle.v2.framework.core as core | ||
import unittest | ||
import numpy | ||
import paddle.v2.framework.create_op_creation_methods as creation | ||
|
||
|
||
class OpTestMeta(type): | ||
def __new__(cls, name, bases, attrs): | ||
obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs) | ||
|
||
def test_all(self): | ||
func = getattr(creation.op_creations, self.type, None) | ||
self.assertIsNotNone(func) | ||
|
||
scope = core.Scope(None) | ||
kwargs = dict() | ||
|
||
for in_name in func.all_input_args: | ||
if hasattr(self, in_name): | ||
kwargs[in_name] = in_name | ||
var = scope.create_var(in_name).get_tensor() | ||
arr = getattr(self, in_name) | ||
var.set_dims(arr.shape) | ||
var.set(arr) | ||
else: | ||
kwargs[in_name] = "@EMPTY@" | ||
|
||
for out_name in func.all_output_args: | ||
if hasattr(self, out_name): | ||
kwargs[out_name] = out_name | ||
scope.create_var(out_name).get_tensor() | ||
|
||
for attr_name in func.all_attr_args: | ||
if hasattr(self, attr_name): | ||
kwargs[attr_name] = getattr(self, attr_name) | ||
|
||
op = func(**kwargs) | ||
|
||
op.infer_shape(scope) | ||
|
||
ctx = core.DeviceContext.cpu_context() | ||
op.run(scope, ctx) | ||
|
||
for out_name in func.all_output_args: | ||
actual = numpy.array(scope.get_var(out_name).get_tensor()) | ||
expect = getattr(self, out_name) | ||
numpy.testing.assert_almost_equal(actual, expect) | ||
|
||
obj.test_all = test_all | ||
return obj |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import unittest | ||
from op_test_util import OpTestMeta | ||
import numpy | ||
|
||
|
||
class TestAddOp(unittest.TestCase): | ||
__metaclass__ = OpTestMeta | ||
|
||
def setUp(self): | ||
self.type = "add_two" | ||
self.X = numpy.random.random((342, 345)).astype("float32") | ||
self.Y = numpy.random.random((342, 345)).astype("float32") | ||
self.Out = self.X + self.Y | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import unittest | ||
import numpy | ||
from op_test_util import OpTestMeta | ||
|
||
|
||
class TestSGD(unittest.TestCase): | ||
__metaclass__ = OpTestMeta | ||
|
||
def setUp(self): | ||
self.type = "sgd" | ||
self.param = numpy.random.random((342, 345)).astype("float32") | ||
self.grad = numpy.random.random((342, 345)).astype("float32") | ||
self.learning_rate = 0.1 | ||
self.param_out = self.param - self.learning_rate * self.grad | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems it is does not finished check? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By using current op unit test framework, the developer only needs to |
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no black magic, please.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed