-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change net to operator #2846
change net to operator #2846
Changes from 1 commit
1cd208a
843ca37
b4241da
5e34298
c12cf1d
f96d7c0
bf5197d
a57f20e
fc85e6a
fea4359
bb8ab09
6879eb2
994ca58
01f7963
58bfcec
ca10db7
796b763
467bdba
e235d08
c62986f
1429d96
b605ce6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,45 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <paddle/framework/op_desc.pb.h> | ||
#include <paddle/framework/operator.h> | ||
#include "paddle/framework/net_proto.pb.h" | ||
#include "paddle/framework/op_proto.pb.h" | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/framework/scope.h" | ||
#include "paddle/platform/device_context.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
using namespace paddle::platform; | ||
|
||
// operator's index stored in a network. | ||
typedef int OpIndex; | ||
/** | ||
* NOTE following codes are some definitions of unimplemented concepts. | ||
* We write some basic implementation to make Net compilable. These APIs will | ||
* keep updating if the concepts related are implemented. | ||
*/ | ||
|
||
struct OpDesc; | ||
struct OpAttrs {}; | ||
|
||
class Operator { | ||
// tmperary put here for test | ||
class PlainNetOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
Operator(const OpDesc &def) {} | ||
void InferShape() {} | ||
void Run(DeviceContext *ctx) {} | ||
PlainNetOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddType("plainnet_operator"); | ||
AddComment("This is test op"); | ||
} | ||
}; | ||
|
||
/** | ||
* @brief Network that manage the operators it has. | ||
* @brief Network is also a type of Operator | ||
* | ||
* It will manage the operators it has. | ||
* | ||
* Network is the container and controller of a set of operators, user can build | ||
* a real network from a NetDesc which is a protobuf message and use | ||
|
@@ -55,43 +52,22 @@ class Operator { | |
* This is the base class of network, all the networks should implement the apis | ||
* it defines. | ||
*/ | ||
class Net { | ||
class Net : public OperatorBase { | ||
public: | ||
/** | ||
* @brief Infer shapes of all inputs and outputs of operators. | ||
*/ | ||
virtual void InferShape(Scope *scope) = 0; | ||
/** | ||
* @brief Run the network. | ||
* | ||
* Run all the operators and return success(true) or not, with all the | ||
* variables are located in `scope`. `context` describes the detail execution | ||
* environment for ops. `begin` and `end` specify the scope of `ops_` to run, | ||
* If no positive indexes are provided, all operators in `ops_` will run. | ||
*/ | ||
virtual void Run(std::shared_ptr<Scope> scope, DeviceContext *ctx) = 0; | ||
|
||
/** | ||
* @brief Add an Operator according to `def`. | ||
*/ | ||
virtual OpIndex AddOp(const OpProto &def) = 0; | ||
virtual void AddOp(const OpDesc& def) = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is better to make
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
/** | ||
* @brief Add optimizer operators acctording to `attrs`. | ||
*/ | ||
virtual void AddOptimizerOps(const OpAttrs &attrs) = 0; | ||
virtual void AddOptimizerOps() = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed |
||
|
||
/** | ||
* @brief Add backward operators. | ||
*/ | ||
virtual void AddBackwardOps() = 0; | ||
|
||
/** | ||
* @brief Create a network. | ||
*/ | ||
static std::unique_ptr<Net> Create(const NetDesc &def = NetDesc()); | ||
|
||
virtual ~Net() {} | ||
}; | ||
|
||
/** | ||
|
@@ -102,19 +78,11 @@ class Net { | |
*/ | ||
class PlainNet : public Net { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since we have reached an agreement in treating net as a composed Op, seems we do not need a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think net will have a common interface |
||
public: | ||
/** | ||
* @brief Initialize a PlainNet. | ||
* | ||
* Initialize from a network describe by `def`. NetDesc is the definition of | ||
* a network. | ||
*/ | ||
PlainNet(const NetDesc &def); | ||
|
||
/** | ||
* Infer all the operators' input and output varialbes' shapes, will be called | ||
* before every mini-batch | ||
*/ | ||
virtual void InferShape(Scope *scope) override; | ||
void InferShape(const std::shared_ptr<Scope>& scope) const override; | ||
|
||
/** | ||
* @brief Run the network. | ||
|
@@ -123,48 +91,27 @@ class PlainNet : public Net { | |
* scope will be used instead. If no OpContext is provicded, default context | ||
* will be used. | ||
*/ | ||
virtual void Run(std::shared_ptr<Scope> scope, DeviceContext *ctx) override; | ||
void Run(const std::shared_ptr<Scope>& scope, | ||
const platform::DeviceContext& dev_ctx) const override; | ||
|
||
/** | ||
* @brief Add an operator to this network. | ||
*/ | ||
virtual OpIndex AddOp(const OpProto &def) override; | ||
void AddOp(const OpDesc& def) override; | ||
|
||
/** | ||
* @brief Add all optimizer operators related into the network. | ||
*/ | ||
virtual void AddOptimizerOps(const OpAttrs &attrs) override; | ||
void AddOptimizerOps() override {} | ||
|
||
/** | ||
* @brief Add all backward operators related into the network. | ||
*/ | ||
virtual void AddBackwardOps() override; | ||
|
||
virtual ~PlainNet() override {} | ||
|
||
protected: | ||
/** | ||
* @brief Build the network. | ||
* | ||
* Create operators accordding to `def`, will be called by the constructor. | ||
*/ | ||
void BuildNet(const NetDesc &def); | ||
|
||
/** | ||
* @brief Add an operator into this network. | ||
* | ||
* Add a operator which is identified as `type` and has attributes described | ||
* in `attrs`, the `inputs` are the keys of readonly input variables, | ||
* `outputs` are keys of mutable output variables. An `OpIndex` will be | ||
* returned to indicate the offset of the new operator in `ops_`. | ||
*/ | ||
OpIndex AddOp(const std::string &type, const std::vector<std::string> &inputs, | ||
const std::vector<std::string> &outputs, | ||
const OpAttrs &attrs = OpAttrs()); | ||
void AddBackwardOps() override {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we decided to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
|
||
private: | ||
// the operators owned by `Network`. | ||
std::vector<Operator> ops_; | ||
std::vector<std::unique_ptr<OperatorBase>> ops_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. std::shared_ptr is better. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
}; | ||
|
||
} // namespace framework | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#include <gtest/gtest.h> | ||
#include <paddle/framework/net.h> | ||
#include <paddle/framework/op_registry.h> | ||
#include <paddle/framework/operator.h> | ||
|
||
namespace paddle { | ||
namespace framework { | ||
class OperatorTest : public OperatorBase { | ||
public: | ||
void Init() override { x = 1; } | ||
void InferShape(const std::shared_ptr<Scope>& scope) const override {} | ||
void Run(const std::shared_ptr<Scope>& scope, | ||
const platform::DeviceContext& dev_ctx) const override { | ||
float scale = GetAttr<float>("scale"); | ||
ASSERT_NEAR(scale, 3.14, 1e-5); | ||
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); | ||
ASSERT_EQ(x, 1); | ||
std::cout << "this is test_operator" << std::endl; | ||
} | ||
|
||
public: | ||
float x = 0; | ||
}; | ||
|
||
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("input", "input of test op"); | ||
AddOutput("output", "output of test op"); | ||
AddAttr<float>("scale", "scale of cosine op") | ||
.SetDefault(1.0) | ||
.LargerThan(0.0); | ||
AddType("test_operator"); | ||
AddComment("This is test op"); | ||
} | ||
}; | ||
|
||
REGISTER_OP(test_operator, OperatorTest, OperatorTestProtoAndCheckerMaker); | ||
REGISTER_OP(plainnet_operator, PlainNet, PlainNetOpProtoAndCheckerMaker); | ||
|
||
} // namespace framework | ||
} // namespace paddle | ||
|
||
TEST(OpKernel, all) { | ||
using namespace paddle::framework; | ||
using namespace paddle::platform; | ||
|
||
// net op | ||
OpDesc net_op_desc; | ||
net_op_desc.set_type("plainnet_operator"); | ||
|
||
// test op | ||
OpDesc test_op_desc; | ||
test_op_desc.set_type("test_operator"); | ||
*test_op_desc.mutable_inputs()->Add() = "IN1"; | ||
*test_op_desc.mutable_outputs()->Add() = "OUT1"; | ||
auto attr = test_op_desc.mutable_attrs()->Add(); | ||
attr->set_name("scale"); | ||
attr->set_type(paddle::framework::AttrType::FLOAT); | ||
attr->set_f(3.14); | ||
|
||
CPUDeviceContext cpu_device_context; | ||
auto scope = std::make_shared<Scope>(); | ||
|
||
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(net_op_desc); | ||
auto net_op = static_cast<PlainNet*>(op); | ||
|
||
net_op->AddOp(test_op_desc); | ||
op->Run(scope, cpu_device_context); | ||
|
||
delete op; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
emplace_back is better.
emplace_back vs push_back
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I will take a look~~