From ca10db70d92ef0469c246b09f4674d851469f8ca Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 15 Jul 2017 21:49:03 +0800 Subject: [PATCH] Complete NetOp * OperatorBase should not store OpDesc, because not All op contains an OpDesc and not all ops create from OpDesc. * Networks do not contain OpDesc, and do not created by OpDesc * Do not register Network to OpRegistry. * The network is directly created by user in Python. Not from registry. * Correctly handle the `inputs` and `outputs` of a Network. * Add CompleteAddOp() methods * Remove `AddOp(OpDesc&)`. All op are added by pointer. * Rewrite unit test for truely tested what networks do. * Remove `DemoOp` and `DemoOpTest` because it is useless and break the CI --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/net.cc | 53 ++++++++++--------- paddle/framework/net.h | 48 ++++++++++-------- paddle/framework/net_op_test.cc | 87 +++++++++++++++++++++----------- paddle/framework/op_registry.h | 2 +- paddle/framework/operator.cc | 2 +- paddle/framework/operator.h | 7 +-- paddle/operators/CMakeLists.txt | 4 -- paddle/operators/demo_op.cc | 79 ----------------------------- paddle/operators/demo_op_test.cc | 60 ---------------------- 10 files changed, 120 insertions(+), 224 deletions(-) delete mode 100644 paddle/operators/demo_op.cc delete mode 100644 paddle/operators/demo_op_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 4ba45a0ff9c4a..06b1075882b92 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -22,4 +22,4 @@ add_dependencies(framework_py_proto framework_py_proto_init) proto_library(net_proto SRCS net_proto.proto DEPS op_proto) cc_library(net SRCS net.cc DEPS operator net_proto op_registry) -cc_test(net_op_test SRCS net_op_test.cc DEPS net demo_op) +cc_test(net_op_test SRCS net_op_test.cc DEPS net) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 5cc028c50f75a..e8479839122e1 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -3,33 +3,40 @@ namespace paddle { namespace framework { -void PlainNet::AddOp(const OpDesc& desc) { - ops_.push_back(OpRegistry::CreateOp(desc)); -} - -void PlainNet::AddOp(const OperatorPtr& op) { ops_.push_back(op); } - -void PlainNet::InferShape(const ScopePtr& scope) const { +void PlainNet::CompleteAddOp() { + std::unordered_set input_set; + std::unordered_set output_set; + std::unordered_set temp_output; for (auto& op : ops_) { - op->InferShape(scope); + for (auto& ipt : op->inputs_) { + if (!Contains(output_set, ipt)) { // Not other op's output + input_set.insert(ipt); + } else { + temp_output.insert(ipt); + } + } + + for (auto& opt : op->outputs_) { + output_set.insert(opt); + } } -} + inputs_.reserve(input_set.size()); + std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_)); -void PlainNet::Run(const ScopePtr& scope, const DeviceContext& ctx) const { - for (auto& op : ops_) { - op->Run(scope, ctx); + outputs_.reserve(output_set.size()); + std::vector tmp_index; + tmp_index.reserve(temp_output.size()); + int idx = 0; + for (auto& opt : output_set) { + if (Contains(temp_output, opt)) { + tmp_index.push_back(idx); + } + outputs_.push_back(opt); + ++idx; } + + attrs_["temporary_index"] = tmp_index; } -class PlainNetOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - PlainNetOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddComment("This is test op"); - } -}; } // namespace framework -} // namespace paddle - -REGISTER_OP(plainnet_operator, paddle::framework::PlainNet, - paddle::framework::PlainNetOpProtoAndCheckerMaker); \ No newline at end of file +} // namespace paddle \ No newline at end of file diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 4eccd710b0d3a..19a1620e29b86 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -24,16 +24,12 @@ limitations under the License. */ namespace paddle { namespace framework { -using namespace paddle::platform; - /** * @brief Network is also a type of Operator * * It will manage the operators it has. * - * Network is the container and controller of a set of operators, user can build - * a real network from a NetDesc which is a protobuf message and use - * Network.Run() * to run all the operators in the network. + * Network is the container and controller of a set of operators. * A network object knows all Operators belonging to this network. Variables, * which are inputs and outputs of these operators, are created and managed by a @@ -44,14 +40,12 @@ using namespace paddle::platform; */ class Net : public OperatorBase { public: - /* - * @brief Add an Operator according to `def`. - */ - virtual void AddOp(const OpDesc& def) = 0; - virtual void AddOp(const OperatorPtr& op) = 0; + virtual void CompleteAddOp() = 0; }; +using NetPtr = std::shared_ptr; + /** * @brief a basic implementation of Net. * @@ -64,7 +58,11 @@ class PlainNet : public Net { * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch */ - void InferShape(const ScopePtr& scope) const override; + void InferShape(const ScopePtr& scope) const override { + for (auto& op : ops_) { + op->InferShape(scope); + } + } /** * @brief Run the network. @@ -74,21 +72,31 @@ class PlainNet : public Net { * will be used. */ void Run(const ScopePtr& scope, - const platform::DeviceContext& dev_ctx) const override; - - /** - * @brief Add an Operator by OpDesc. - */ - void AddOp(const OpDesc& def) override; + const platform::DeviceContext& dev_ctx) const override { + for (auto& op : ops_) { + op->Run(scope, dev_ctx); + } + } /** * @brief Add an operator by ptr */ - void AddOp(const OperatorPtr& def) override; + void AddOp(const OperatorPtr& op) override { + PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + ops_.push_back(op); + } + + void CompleteAddOp() override; - private: - // the operators owned by `Network`. std::vector ops_; + + private: + bool add_op_done_{false}; + + template + static bool Contains(T container, KeyType key) { + return container.find(key) != container.end(); + } }; } // namespace framework diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index 0c43ad84b5e40..10083b9e06016 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -3,36 +3,63 @@ #include #include -USE_OP_WITHOUT_KERNEL(test_operator); -USE_OP_WITHOUT_KERNEL(plainnet_operator); +namespace pd = paddle::framework; + +static int infer_shape_cnt = 0; +static int run_cnt = 0; + +class TestOp : public pd::OperatorBase { + public: + void InferShape(const paddle::framework::ScopePtr& scope) const override { + ++infer_shape_cnt; + } + void Run(const paddle::framework::ScopePtr& scope, + const paddle::platform::DeviceContext& dev_ctx) const override { + ++run_cnt; + } +}; + +template +void AssertSameVectorWithoutOrder(const std::vector& expected, + const std::vector& actual) { + ASSERT_EQ(expected.size(), actual.size()); + std::unordered_set expected_set; + for (auto& tmp : expected) { + expected_set.insert(tmp); + } + for (auto& act : actual) { + ASSERT_NE(expected_set.end(), expected_set.find(act)); + } +} TEST(OpKernel, all) { - using namespace paddle::framework; - using namespace paddle::platform; - - // net op - OpDesc net_op_desc; - net_op_desc.set_type("plainnet_operator"); - - // test op - OpDesc test_op_desc; - test_op_desc.set_type("test_operator"); - *test_op_desc.mutable_inputs()->Add() = "IN1"; - *test_op_desc.mutable_outputs()->Add() = "OUT1"; - auto attr = test_op_desc.mutable_attrs()->Add(); - attr->set_name("scale"); - attr->set_type(paddle::framework::AttrType::FLOAT); - attr->set_f(3.14); - - auto test_op = OpRegistry::CreateOp(test_op_desc); - - CPUDeviceContext cpu_device_context; - auto scope = std::make_shared(); - - OperatorPtr op = paddle::framework::OpRegistry::CreateOp(net_op_desc); - auto net_op = static_cast(op.get()); - - net_op->AddOp(test_op_desc); - net_op->AddOp(test_op); - net_op->Run(scope, cpu_device_context); + auto net = std::make_shared(); + ASSERT_NE(net, nullptr); + + auto op1 = std::make_shared(); + op1->inputs_ = {"x", "w1", "b1"}; + op1->outputs_ = {"y"}; + net->AddOp(op1); + + auto op2 = std::make_shared(); + op2->inputs_ = {"y", "w2", "b2"}; + op2->outputs_ = {"z"}; + net->AddOp(op2); + + net->CompleteAddOp(); + AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net->inputs_); + AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_); + auto tmp_idx_iter = net->attrs_.find("temporary_index"); + ASSERT_NE(net->attrs_.end(), tmp_idx_iter); + auto& tmp_idx = boost::get>(tmp_idx_iter->second); + ASSERT_EQ(1UL, tmp_idx.size()); + ASSERT_EQ("y", net->outputs_[tmp_idx[0]]); + + auto scope = std::make_shared(); + paddle::platform::CPUDeviceContext dev_ctx; + + net->InferShape(scope); + net->Run(scope, dev_ctx); + ASSERT_EQ(2, infer_shape_cnt); + ASSERT_EQ(2, run_cnt); } diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 52af03432ed55..1deac1de20c2e 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -201,7 +201,7 @@ class OpRegistry { static OperatorPtr CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); OperatorPtr op(creators().at(op_type)()); - op->desc_ = op_desc; + op->type_ = op_desc.type(); op->inputs_.reserve((size_t)op_desc.inputs_size()); std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), std::back_inserter(op->inputs_)); diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 8f7adff8b3982..e9812e0ca88d4 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -20,7 +20,7 @@ namespace framework { std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; - ss << "type = " << desc_.type() << "\n"; + ss << "type = " << type_ << "\n"; ss << "inputs = ["; for (auto& ipt : inputs_) { ss << ipt << ", "; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index cf79f379fae1e..f7ed6e9f3d942 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -62,11 +62,8 @@ class OperatorBase { virtual void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const = 0; - protected: - std::string Type() const { return desc_.type(); } - public: - OpDesc desc_; + std::string type_; std::vector inputs_; std::vector outputs_; AttributeMap attrs_; @@ -142,7 +139,7 @@ class OperatorWithKernel : public OperatorBase { void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const final { - auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx)); + auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); } diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 8540f44a79c88..40bb326512c11 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -4,7 +4,3 @@ else() cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim) endif() cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) - - -cc_library(demo_op SRCS demo_op.cc DEPS operator op_registry) -cc_test(demo_op_test SRCS demo_op_test.cc DEPS demo_op) diff --git a/paddle/operators/demo_op.cc b/paddle/operators/demo_op.cc deleted file mode 100644 index b67d8caf4d951..0000000000000 --- a/paddle/operators/demo_op.cc +++ /dev/null @@ -1,79 +0,0 @@ -#include -#include -#include - -namespace paddle { -namespace operators { - -class OperatorTest : public framework::OperatorBase { -public: - void Init() override { x = 1; } - void InferShape(const framework::ScopePtr& scope) const override {} - void Run(const framework::ScopePtr& scope, - const platform::DeviceContext& dev_ctx) const override { - float scale = GetAttr("scale"); - std::cout << "this is " << Type() << std::endl - << " scale=" << scale << std::endl; - std::cout << DebugString() << std::endl; - } - -public: - float x = 0; -}; - -class OperatorTestProtoAndCheckerMaker - : public framework::OpProtoAndCheckerMaker { -public: - OperatorTestProtoAndCheckerMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of test op"); - AddOutput("output", "output of test op"); - AddAttr("scale", "scale of cosine op") - .SetDefault(1.0) - .LargerThan(0.0); - AddComment("This is test op"); - } -}; - -class OpKernelTestProtoAndCheckerMaker - : public framework::OpProtoAndCheckerMaker { -public: - OpKernelTestProtoAndCheckerMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of test op"); - AddOutput("output", "output of test op"); - AddAttr("scale", "scale of cosine op") - .SetDefault(1.0) - .LargerThan(0.0); - AddComment("This is test op"); - } -}; - -class OpWithKernelTest : public framework::OperatorWithKernel { -protected: - void InferShape( - const std::vector& inputs, - const std::vector& outputs) const override {} -}; - -class CPUKernelTest : public framework::OpKernel { -public: - void Compute(const framework::OpKernel::KernelContext& context) const { - float scale = context.op_.GetAttr("scale"); - std::cout << "this is cpu kernel, scale=" << scale << std::endl; - std::cout << context.op_.DebugString() << std::endl; - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP(test_operator, - paddle::operators::OperatorTest, - paddle::operators::OperatorTestProtoAndCheckerMaker); -REGISTER_OP(op_with_kernel, - paddle::operators::OpWithKernelTest, - paddle::operators::OpKernelTestProtoAndCheckerMaker); -REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::operators::CPUKernelTest); diff --git a/paddle/operators/demo_op_test.cc b/paddle/operators/demo_op_test.cc deleted file mode 100644 index a1e8425a9ad70..0000000000000 --- a/paddle/operators/demo_op_test.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "gtest/gtest.h" -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" - -USE_OP_WITHOUT_KERNEL(test_operator); -USE_OP(op_with_kernel); - -TEST(OperatorBase, all) { - paddle::framework::OpDesc op_desc; - op_desc.set_type("test_operator"); - *op_desc.mutable_inputs()->Add() = "IN1"; - *op_desc.mutable_outputs()->Add() = "OUT1"; - auto attr = op_desc.mutable_attrs()->Add(); - attr->set_name("scale"); - attr->set_type(paddle::framework::AttrType::FLOAT); - float scale = 3.14; - attr->set_f(scale); - - paddle::platform::CPUDeviceContext device_context; - auto scope = std::make_shared(); - - paddle::framework::OperatorPtr op = - paddle::framework::OpRegistry::CreateOp(op_desc); - ASSERT_EQ(op->GetAttr("scale"), scale); - scope->CreateVariable("OUT1"); - op->Run(scope, device_context); - std::cout << op->DebugString() << std::endl; -} - -TEST(OpKernel, all) { - paddle::framework::OpDesc op_desc; - op_desc.set_type("op_with_kernel"); - *op_desc.mutable_inputs()->Add() = "IN1"; - *op_desc.mutable_outputs()->Add() = "OUT1"; - auto attr = op_desc.mutable_attrs()->Add(); - attr->set_name("scale"); - attr->set_type(paddle::framework::AttrType::FLOAT); - attr->set_f(3.14); - - paddle::platform::CPUDeviceContext cpu_device_context; - auto scope = std::make_shared(); - - paddle::framework::OperatorPtr op = - paddle::framework::OpRegistry::CreateOp(op_desc); - op->Run(scope, cpu_device_context); -}