Skip to content

Commit

Permalink
Merge pull request #1 from reyoung/net-op
Browse files Browse the repository at this point in the history
Refine NetOp
  • Loading branch information
jacquesqiao committed Jul 15, 2017
2 parents 58bfcec + ca10db7 commit 796b763
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 224 deletions.
2 changes: 1 addition & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ add_dependencies(framework_py_proto framework_py_proto_init)

proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net demo_op)
cc_test(net_op_test SRCS net_op_test.cc DEPS net)
53 changes: 30 additions & 23 deletions paddle/framework/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,40 @@
namespace paddle {
namespace framework {

void PlainNet::AddOp(const OpDesc& desc) {
ops_.push_back(OpRegistry::CreateOp(desc));
}

void PlainNet::AddOp(const OperatorPtr& op) { ops_.push_back(op); }

void PlainNet::InferShape(const ScopePtr& scope) const {
void PlainNet::CompleteAddOp() {
std::unordered_set<std::string> input_set;
std::unordered_set<std::string> output_set;
std::unordered_set<std::string> temp_output;
for (auto& op : ops_) {
op->InferShape(scope);
for (auto& ipt : op->inputs_) {
if (!Contains(output_set, ipt)) { // Not other op's output
input_set.insert(ipt);
} else {
temp_output.insert(ipt);
}
}

for (auto& opt : op->outputs_) {
output_set.insert(opt);
}
}
}
inputs_.reserve(input_set.size());
std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_));

void PlainNet::Run(const ScopePtr& scope, const DeviceContext& ctx) const {
for (auto& op : ops_) {
op->Run(scope, ctx);
outputs_.reserve(output_set.size());
std::vector<int> tmp_index;
tmp_index.reserve(temp_output.size());
int idx = 0;
for (auto& opt : output_set) {
if (Contains(temp_output, opt)) {
tmp_index.push_back(idx);
}
outputs_.push_back(opt);
++idx;
}

attrs_["temporary_index"] = tmp_index;
}

class PlainNetOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
PlainNetOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment("This is test op");
}
};
} // namespace framework
} // namespace paddle

REGISTER_OP(plainnet_operator, paddle::framework::PlainNet,
paddle::framework::PlainNetOpProtoAndCheckerMaker);
} // namespace paddle
48 changes: 28 additions & 20 deletions paddle/framework/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,12 @@ limitations under the License. */

namespace paddle {
namespace framework {
using namespace paddle::platform;

/**
* @brief Network is also a type of Operator
*
* It will manage the operators it has.
*
* Network is the container and controller of a set of operators, user can build
* a real network from a NetDesc which is a protobuf message and use
* Network.Run() * to run all the operators in the network.
* Network is the container and controller of a set of operators.
* A network object knows all Operators belonging to this network. Variables,
* which are inputs and outputs of these operators, are created and managed by a
Expand All @@ -44,14 +40,12 @@ using namespace paddle::platform;
*/
class Net : public OperatorBase {
public:
/*
* @brief Add an Operator according to `def`.
*/
virtual void AddOp(const OpDesc& def) = 0;

virtual void AddOp(const OperatorPtr& op) = 0;
virtual void CompleteAddOp() = 0;
};

using NetPtr = std::shared_ptr<Net>;

/**
* @brief a basic implementation of Net.
*
Expand All @@ -64,7 +58,11 @@ class PlainNet : public Net {
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
*/
void InferShape(const ScopePtr& scope) const override;
void InferShape(const ScopePtr& scope) const override {
for (auto& op : ops_) {
op->InferShape(scope);
}
}

/**
* @brief Run the network.
Expand All @@ -74,21 +72,31 @@ class PlainNet : public Net {
* will be used.
*/
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override;

/**
* @brief Add an Operator by OpDesc.
*/
void AddOp(const OpDesc& def) override;
const platform::DeviceContext& dev_ctx) const override {
for (auto& op : ops_) {
op->Run(scope, dev_ctx);
}
}

/**
* @brief Add an operator by ptr
*/
void AddOp(const OperatorPtr& def) override;
void AddOp(const OperatorPtr& op) override {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op);
}

void CompleteAddOp() override;

private:
// the operators owned by `Network`.
std::vector<OperatorPtr> ops_;

private:
bool add_op_done_{false};

template <typename T, typename KeyType>
static bool Contains(T container, KeyType key) {
return container.find(key) != container.end();
}
};

} // namespace framework
Expand Down
87 changes: 57 additions & 30 deletions paddle/framework/net_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,63 @@
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>

USE_OP_WITHOUT_KERNEL(test_operator);
USE_OP_WITHOUT_KERNEL(plainnet_operator);
namespace pd = paddle::framework;

static int infer_shape_cnt = 0;
static int run_cnt = 0;

class TestOp : public pd::OperatorBase {
public:
void InferShape(const paddle::framework::ScopePtr& scope) const override {
++infer_shape_cnt;
}
void Run(const paddle::framework::ScopePtr& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
++run_cnt;
}
};

template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
const std::vector<T>& actual) {
ASSERT_EQ(expected.size(), actual.size());
std::unordered_set<T> expected_set;
for (auto& tmp : expected) {
expected_set.insert(tmp);
}
for (auto& act : actual) {
ASSERT_NE(expected_set.end(), expected_set.find(act));
}
}

TEST(OpKernel, all) {
using namespace paddle::framework;
using namespace paddle::platform;

// net op
OpDesc net_op_desc;
net_op_desc.set_type("plainnet_operator");

// test op
OpDesc test_op_desc;
test_op_desc.set_type("test_operator");
*test_op_desc.mutable_inputs()->Add() = "IN1";
*test_op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = test_op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.14);

auto test_op = OpRegistry::CreateOp(test_op_desc);

CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();

OperatorPtr op = paddle::framework::OpRegistry::CreateOp(net_op_desc);
auto net_op = static_cast<PlainNet*>(op.get());

net_op->AddOp(test_op_desc);
net_op->AddOp(test_op);
net_op->Run(scope, cpu_device_context);
auto net = std::make_shared<paddle::framework::PlainNet>();
ASSERT_NE(net, nullptr);

auto op1 = std::make_shared<TestOp>();
op1->inputs_ = {"x", "w1", "b1"};
op1->outputs_ = {"y"};
net->AddOp(op1);

auto op2 = std::make_shared<TestOp>();
op2->inputs_ = {"y", "w2", "b2"};
op2->outputs_ = {"z"};
net->AddOp(op2);

net->CompleteAddOp();
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net->inputs_);
AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_);
auto tmp_idx_iter = net->attrs_.find("temporary_index");
ASSERT_NE(net->attrs_.end(), tmp_idx_iter);
auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);

auto scope = std::make_shared<pd::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;

net->InferShape(scope);
net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
}
2 changes: 1 addition & 1 deletion paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class OpRegistry {
static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type();
OperatorPtr op(creators().at(op_type)());
op->desc_ = op_desc;
op->type_ = op_desc.type();
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_));
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace framework {
std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
ss << "type = " << desc_.type() << "\n";
ss << "type = " << type_ << "\n";
ss << "inputs = [";
for (auto& ipt : inputs_) {
ss << ipt << ", ";
Expand Down
7 changes: 2 additions & 5 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ class OperatorBase {
virtual void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const = 0;

protected:
std::string Type() const { return desc_.type(); }

public:
OpDesc desc_;
std::string type_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attrs_;
Expand Down Expand Up @@ -142,7 +139,7 @@ class OperatorWithKernel : public OperatorBase {

void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
}

Expand Down
4 changes: 0 additions & 4 deletions paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,3 @@ else()
cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim)
endif()
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)


cc_library(demo_op SRCS demo_op.cc DEPS operator op_registry)
cc_test(demo_op_test SRCS demo_op_test.cc DEPS demo_op)
79 changes: 0 additions & 79 deletions paddle/operators/demo_op.cc

This file was deleted.

Loading

0 comments on commit 796b763

Please sign in to comment.