Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OperatorWithKernel #2815

Merged
merged 1 commit into from
Jul 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion paddle/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ if(Boost_FOUND)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif()

Expand Down
33 changes: 16 additions & 17 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"

using namespace paddle::framework;

namespace paddle {
namespace framework {
class CosineOp : public OperatorWithKernel {
class CosineOp : public OperatorBase {
public:
void Run(const OpRunContext* context) const override {
printf("%s\n", DebugString().c_str());
}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
Expand All @@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)

class MyTestOp : public OperatorWithKernel {
class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}

public:
void Run(const OpRunContext* ctx) const override {
printf("%s\n", DebugString().c_str());
printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
}
};

class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
Expand Down Expand Up @@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext();
op->Run(scope, &dev_ctx);
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(scale_get, scale);
}
Expand Down Expand Up @@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext();
op->Run(scope, &dev_ctx);
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
}

Expand Down Expand Up @@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(4);
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto dev_ctx = DeviceContext();
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>();
op->Run(scope, &dev_ctx);
op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
Expand Down
8 changes: 0 additions & 8 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
return ss.str();
}

const Variable* OpRunContext::Input(int index) const {
return scope_->GetVariable(op_->inputs_[index]);
}

Variable* OpRunContext::Output(int index) const {
return scope_->GetVariable(op_->outputs_[index]);
}

} // namespace framework
} // namespace paddle
117 changes: 80 additions & 37 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,22 @@ limitations under the License. */

#pragma once

#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/utils/Error.h"

namespace paddle {
namespace framework {

class OperatorBase;

class DeviceContext {};

/**
* OpRunContext is the only parameter of Operator's Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* OpRunContext. User should construct it before run the Operator.
*/
class OpRunContext {
public:
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
const DeviceContext* device_context)
: op_(op), scope_(scope), device_context_(device_context) {}

const Variable* Input(int index) const;
Variable* Output(int index) const;

public:
const OperatorBase* op_;
const std::shared_ptr<Scope> scope_;
const DeviceContext* device_context_;
};

/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
Expand All @@ -77,7 +55,10 @@ class OperatorBase {

/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const = 0;
const platform::DeviceContext& dev_ctx) const = 0;

protected:
std::string Type() const { return desc_.type(); }

public:
OpDesc desc_;
Expand All @@ -86,22 +67,84 @@ class OperatorBase {
AttributeMap attrs_;
};

class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}

const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}

Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};

virtual void Compute(const KernelContext& context) const = 0;

virtual ~OpKernel() {}
};

class OperatorWithKernel : public OperatorBase {
public:
virtual ~OperatorWithKernel() {}
struct OpKernelKey {
platform::Place place_;

virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
OpKernelKey() = default;
OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}

bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
};

struct OpKernelHash {
std::hash<bool> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
}
};

using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const {
OpRunContext op_ctx(this, scope, dev_ctx);
Run(&op_ctx);
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
}

/// when implement an Op, your should implement this function.
/// this function should be moved to OpKernel later
virtual void Run(const OpRunContext* context) const = 0;
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
};
};

} // namespace framework
} // namespace paddle

#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__
39 changes: 16 additions & 23 deletions paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@ limitations under the License. */
namespace paddle {
namespace framework {

class OperatorTest : public OperatorWithKernel {
class OperatorTest : public OperatorBase {
public:
void Run(const OpRunContext* ctx) const override {
float scale = ctx->op_->GetAttr<float>("scale");
PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized");
PADDLE_ENFORCE(ctx->Output(0) == nullptr,
"Output(1) should not initialized");
auto output1 = ctx->scope_->CreateVariable("output1");
PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope");
printf("get attr %s = %f\n", "scale", scale);
printf("%s\n", DebugString().c_str());
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
}
};

Expand All @@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)

TEST(OperatorBase, DebugString) {
TEST(OperatorBase, all) {
OpDesc op_desc;
op_desc.set_type("test_operator");
std::vector<std::string> inputs = {"IN1", "IN2"};
for (auto& input : inputs) {
op_desc.add_inputs(input);
}
std::vector<std::string> outputs = {"OUT1", "OUT2"};
for (auto& output : outputs) {
op_desc.add_outputs(output);
}
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
float scale = 3.14;
attr->set_f(scale);

DeviceContext device_context;
platform::CPUDeviceContext device_context;
auto scope = std::make_shared<Scope>();

OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(op->inputs_, inputs);
ASSERT_EQ(op->outputs_, outputs);
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
op->Run(scope, &device_context);
scope->CreateVariable("OUT1");
op->Run(scope, device_context);
std::cout << op->DebugString() << std::endl;
delete op;
}

} // namespace framework
Expand Down
5 changes: 0 additions & 5 deletions paddle/operators/.clang-format

This file was deleted.

Empty file removed paddle/operators/CMakeLists.txt
Empty file.
59 changes: 0 additions & 59 deletions paddle/operators/demo_op.h

This file was deleted.

Loading