
use operator context and infer context #3024

Merged: 25 commits, Aug 1, 2017

Changes from all commits (25 commits):
299525d: use operator context (jacquesqiao, Jul 23, 2017)
11eabf8: optimize code (jacquesqiao, Jul 24, 2017)
4280a60: update net infershape (jacquesqiao, Jul 24, 2017)
dda4881: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (jacquesqiao, Jul 24, 2017)
fb1b3d1: update InferShape (jacquesqiao, Jul 24, 2017)
081c7ca: disable override InferShape(scope) in OperatorBase (jacquesqiao, Jul 24, 2017)
5273c7e: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (jacquesqiao, Jul 25, 2017)
0d693fe: change InferShapeImpl to InferShape (jacquesqiao, Jul 25, 2017)
bf3940b: add template to OperatorContext Input/Output (jacquesqiao, Jul 26, 2017)
362ba2f: merge Input InputVar, Output OutputVar (jacquesqiao, Jul 26, 2017)
a4bfb61: change Inputs to MultiInput (jacquesqiao, Jul 26, 2017)
217186e: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (jacquesqiao, Jul 27, 2017)
1c91df3: fix conflict (jacquesqiao, Jul 27, 2017)
2460af4: fix MultiInput bugs and add unit test (jacquesqiao, Jul 27, 2017)
fb1980e: rename KernelContext to ExecutionContext (jacquesqiao, Jul 27, 2017)
9ff3595: clean code (jacquesqiao, Jul 27, 2017)
9fafc46: change InferShape to protected (jacquesqiao, Jul 28, 2017)
e87d253: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (jacquesqiao, Jul 30, 2017)
9a2640b: fix template bug (jacquesqiao, Jul 30, 2017)
b6764c9: refine code (jacquesqiao, Jul 30, 2017)
fab7737: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (jacquesqiao, Jul 30, 2017)
e4445d6: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (jacquesqiao, Jul 31, 2017)
eda5493: use InputVar instead of Input<Variable> (jacquesqiao, Jul 31, 2017)
5f0ed40: typo (jacquesqiao, Jul 31, 2017)
bd8872c: optimize code (jacquesqiao, Aug 1, 2017)
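Taken together, these commits split the old KernelContext into a small hierarchy: a shared OperatorContext base, an InferShapeContext for shape inference, and an ExecutionContext that adds device resources. A minimal self-contained sketch of that shape (Variable, Scope, OperatorBase, and DeviceContext below are mocked stand-ins, not Paddle's real classes):

```cpp
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Mock stand-ins for the framework types touched by this PR.
struct Variable { int value = 0; };

struct Scope {
  std::unordered_map<std::string, Variable> vars;
  Variable* GetVariable(const std::string& name) { return &vars[name]; }
};

struct OperatorBase {
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};

struct DeviceContext { const char* place = "cpu"; };

// Shared base: name-to-variable lookup used by both phases.
class OperatorContext {
 public:
  OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
      : op_(*op), scope_(scope) {}

  Variable* InputVar(size_t index) const {
    return scope_->GetVariable(op_.inputs_.at(index));
  }
  Variable* OutputVar(size_t index) const {
    return scope_->GetVariable(op_.outputs_.at(index));
  }

  const OperatorBase& op_;
  const std::shared_ptr<Scope>& scope_;
};

// Shape inference needs only the scope...
class InferShapeContext : public OperatorContext {
 public:
  using OperatorContext::OperatorContext;
};

// ...while execution additionally carries device resources.
class ExecutionContext : public OperatorContext {
 public:
  ExecutionContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
                   const DeviceContext& dev)
      : OperatorContext(op, scope), device_context_(dev) {}

  const DeviceContext& device_context_;
};
```

The payoff of the split is that operators that only check shapes never see a device context at all, which is exactly what the diffs below implement.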
3 changes: 1 addition & 2 deletions paddle/framework/net_op_test.cc
@@ -16,8 +16,7 @@ static int run_cnt = 0;

class TestOp : public OperatorBase {
public:
-  void InferShape(
-      const std::shared_ptr<framework::Scope>& scope) const override {
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {
++infer_shape_cnt;
}
void Run(const std::shared_ptr<framework::Scope>& scope,
4 changes: 2 additions & 2 deletions paddle/framework/operator.cc
@@ -20,15 +20,15 @@ namespace paddle {
namespace framework {

template <>
-Eigen::DefaultDevice* KernelContext::GetEigenDevice<
+Eigen::DefaultDevice* ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}

#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
-KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
+ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif
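The hunk above renames the class hosting two explicit specializations of a member function template. The underlying mechanism, declaring the template inside the class and defining per-type specializations at namespace scope, can be sketched without Eigen (all names below are hypothetical):

```cpp
#include <string>

struct CPUPlace {};
struct GPUPlace {};

struct Context {
  // Primary template is declared but never defined; only the explicit
  // specializations below are usable, mirroring GetEigenDevice.
  template <typename Place>
  std::string DeviceName() const;
};

template <>
std::string Context::DeviceName<CPUPlace>() const { return "cpu"; }

template <>
std::string Context::DeviceName<GPUPlace>() const { return "gpu"; }
```

Leaving the primary template undefined turns a call with an unsupported place type into a link-time error rather than silently wrong behavior.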
181 changes: 106 additions & 75 deletions paddle/framework/operator.h
@@ -31,22 +31,9 @@ limitations under the License. */
namespace paddle {
namespace framework {

-template <typename T>
-struct EigenDeviceConverter;
-
-template <>
-struct EigenDeviceConverter<platform::CPUPlace> {
-  using EigenDeviceType = Eigen::DefaultDevice;
-};
-
-#ifndef PADDLE_ONLY_CPU
-template <>
-struct EigenDeviceConverter<platform::GPUPlace> {
-  using EigenDeviceType = Eigen::GpuDevice;
-};
-#endif
-
 class OperatorBase;
+class InferShapeContext;
+class ExecutionContext;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -112,85 +99,151 @@ class OperatorBase {
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
};

-class KernelContext {
+class OperatorContext {
[Review, Collaborator] OperatorContext => ExecutionContext?
Will we actually have two separate concepts named OperatorContext and KernelContext? If not, wouldn't plain Context or ExecutionContext be clearer?

[Reply, Member (author)] ExecutionContext is better.

[Reply, Collaborator] We actually have two contexts, one for InferShape, the other for Run.

[Reply, Collaborator] See line 35 and 36.

public:
-  KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
-                const platform::DeviceContext& device_context)
-      : op_(*op), scope_(scope), device_context_(device_context) {}
+  OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
+      : op_(*op), scope_(scope) {}
[Review, Member] We have to check that OperatorBase* op is not null first. Also, we keep const OperatorBase& op_ as a member; why not const std::shared_ptr<OperatorBase> op_?

[Reply, Member (author)] We do not need the check: the context is only ever constructed inside an op, so op can never be null. For the same reason there is no need for a std::shared_ptr.
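The exchange above is really about lifetime: storing const OperatorBase& op_ from *op is safe only because the context is created inside the operator's own call and never outlives it, which is why no shared_ptr (and no null check) is needed. A tiny sketch of that borrowed-reference pattern (names are invented for illustration):

```cpp
#include <string>

struct Operator;  // forward declaration

struct Context {
  explicit Context(const Operator& op) : op_(op) {}
  const Operator& op_;  // non-owning: valid only while the Operator lives
};

struct Operator {
  std::string name;

  // The context's lifetime is strictly nested inside this call, so the
  // reference it holds can never dangle and no shared ownership is required.
  std::string Run() const {
    Context ctx(*this);
    return ctx.op_.name;
  }
};
```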


   size_t InputSize() const { return op_.inputs_.size(); }
+  size_t OutputSize() const { return op_.outputs_.size(); }

-  const Variable* Input(int index) const {
-    return scope_->GetVariable(op_.inputs_[index]);
+  const Variable* InputVar(const size_t& index) const {
+    return scope_->GetVariable(op_.inputs_.at(index));
   }

-  Variable* Output(int index) const {
-    return scope_->GetVariable(op_.outputs_[index]);
+  Variable* OutputVar(const size_t& index) const {
+    return scope_->GetVariable(op_.outputs_.at(index));
   }

-  const Variable* Input(const std::string& name) const {
+  const Variable* InputVar(const std::string& name) const {
[Review, Contributor] Do we still need InputVar now that we have template <typename T> Input(name)? Isn't InputVar just Input<Variable>?

return scope_->GetVariable(op_.Input(name));
}

-  const Variable* Output(const std::string& name) const {
+  Variable* OutputVar(const std::string& name) const {
return scope_->GetVariable(op_.Output(name));
}

-  const std::vector<const Variable*> Inputs(const std::string& name) const {
+  const std::vector<const Variable*> MultiInputVar(
+      const std::string& name) const {
     auto names = op_.Inputs(name);
     std::vector<const Variable*> res;
     res.reserve(names.size());
     std::transform(
-        names.begin(), names.end(), res.begin(),
+        names.begin(), names.end(), std::back_inserter(res),
         [this](const std::string& name) { return scope_->GetVariable(name); });
     return res;
   }

-  const std::vector<const Variable*> Outputs(const std::string& name) const {
+  std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
     auto names = op_.Outputs(name);
     std::vector<const Variable*> res;
     res.reserve(names.size());
     std::transform(
-        names.begin(), names.end(), res.begin(),
+        names.begin(), names.end(), std::back_inserter(res),
         [this](const std::string& name) { return scope_->GetVariable(name); });
     return res;
   }
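This hunk also fixes a genuine bug, which commit 2460af4 ("fix MultiInput bugs and add unit test") addresses: reserve() raises capacity but leaves size() at zero, so std::transform writing through res.begin() stores into elements that do not exist (undefined behavior). std::back_inserter appends instead. A minimal reproduction of the corrected pattern (Decorate is a made-up example function):

```cpp
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

std::vector<std::string> Decorate(const std::vector<std::string>& names) {
  std::vector<std::string> res;
  res.reserve(names.size());  // capacity only; res.size() is still 0
  // Buggy form: std::transform(names.begin(), names.end(), res.begin(), ...)
  // writes past the end of an empty vector. back_inserter appends safely.
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [](const std::string& n) { return "@" + n; });
  return res;
}
```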

template <typename T>
const T* Input(const size_t& index) const {
return &(InputVar(index)->Get<T>());
}

template <typename T>
T* Output(const size_t& index) const {
return OutputVar(index)->GetMutable<T>();
}

template <typename T>
const T* Input(const std::string& name) const {
return &(InputVar(name)->Get<T>());
}

template <typename T>
T* Output(const std::string& name) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as top

return OutputVar(name)->GetMutable<T>();
}

template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return &scope_->GetVariable(name)->Get<T>();
});
return res;
}

template <typename T>
std::vector<const T*> MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return scope_->GetVariable(name)->GetMutable<T>();
});
return res;
}

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
};
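The templated Input<T>/Output<T> accessors above lean on Variable being a type-erased holder exposing Get<T>() and GetMutable<T>(). A toy version of such a holder (simplified; the real Variable does stricter type checking):

```cpp
#include <memory>

// Minimal type-erased variable: holds at most one value of one type.
class Variable {
 public:
  template <typename T>
  T* GetMutable() {
    auto* typed = dynamic_cast<Holder<T>*>(holder_.get());
    if (typed == nullptr) {  // empty, or currently holding another type
      holder_.reset(new Holder<T>());
      typed = static_cast<Holder<T>*>(holder_.get());
    }
    return &typed->value;
  }

  template <typename T>
  const T& Get() const {
    // Assumes the caller asks for the type actually stored.
    return static_cast<Holder<T>*>(holder_.get())->value;
  }

 private:
  struct HolderBase { virtual ~HolderBase() = default; };
  template <typename T>
  struct Holder : HolderBase { T value{}; };

  std::unique_ptr<HolderBase> holder_;
};
```

With this shape, Input<Tensor>(name) is just InputVar(name)->Get<Tensor>() wrapped in a template, which is exactly how the diff defines it.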

class InferShapeContext : public OperatorContext {
public:
InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
: OperatorContext(op, scope) {}
};

template <typename T>
struct EigenDeviceConverter;

template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif

class ExecutionContext : public OperatorContext {
public:
ExecutionContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: OperatorContext(op, scope), device_context_(device_context) {}

template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;

platform::Place GetPlace() const { return device_context_.GetPlace(); }

-  const OperatorBase& op_;
-  const std::shared_ptr<Scope>& scope_;
   const platform::DeviceContext& device_context_;
};
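EigenDeviceConverter above is a compile-time trait: it maps a place type to its device type so GetEigenDevice's second template parameter can be defaulted and callers name only the place. The same pattern with Eigen stripped out (CpuDevice, GpuDevice, and IsGpuDevice are invented for illustration):

```cpp
#include <type_traits>

struct CPUPlace {};
struct GPUPlace {};
struct CpuDevice {};
struct GpuDevice {};

// Primary template intentionally undefined: unsupported places fail to compile.
template <typename Place>
struct DeviceConverter;

template <>
struct DeviceConverter<CPUPlace> { using DeviceType = CpuDevice; };

template <>
struct DeviceConverter<GPUPlace> { using DeviceType = GpuDevice; };

// Callers name only the place; the device type is derived by the trait,
// just like ExecutionContext::GetEigenDevice's defaulted DeviceType parameter.
template <typename Place,
          typename Device = typename DeviceConverter<Place>::DeviceType>
constexpr bool IsGpuDevice() {
  return std::is_same<Device, GpuDevice>::value;
}
```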

class OpKernel {
public:
/**
-   * KernelContext is the only parameter of Kernel Run function.
+   * ExecutionContext is the only parameter of Kernel Run function.
    * Run will get input/output variables, state such as momentum and
    * device resource such as CUDA stream, cublas handle, etc. from
-   * KernelContext. User should construct it before run the Operator.
+   * ExecutionContext. User should construct it before run the Operator.
    */

-  virtual void Compute(const KernelContext& context) const = 0;
+  virtual void Compute(const ExecutionContext& context) const = 0;

virtual ~OpKernel() {}
};

-template <typename T>
-struct VarToTensor {};
-
-template <>
-struct VarToTensor<Tensor*> {
-  Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
-};
-
-template <>
-struct VarToTensor<const Tensor*> {
-  const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
-};

class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
@@ -216,10 +269,14 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

void InferShape(const std::shared_ptr<Scope>& scope) const {
InferShape(InferShapeContext(this, scope));
}

void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
-    opKernel->Compute(KernelContext(this, scope, dev_ctx));
+    opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
}

static std::unordered_map<std::string /* op_type */, OpKernelMap>&
@@ -228,34 +285,8 @@
return g_all_op_kernels;
}

-  void InferShape(const std::shared_ptr<Scope>& scope) const final {
-    std::vector<const Tensor*> ins;
-    VarNamesToTensors(scope, inputs_, &ins);
-    std::vector<Tensor*> outs;
-    VarNamesToTensors(scope, outputs_, &outs);
-    InferShape(ins, outs);
-  };
-
- private:
-  template <typename T>
-  void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
-                         const std::vector<std::string>& var_names,
-                         std::vector<T>* container) const {
-    container->reserve(var_names.size());
-    VarToTensor<T> convert;
-    for (auto& name : var_names) {
-      auto var = scope->GetVariable(name);
-      if (var != nullptr) {
-        container->push_back(convert(var));
-      } else {
-        container->push_back(nullptr);
-      }
-    }
-  }
-
  protected:
-  virtual void InferShape(const std::vector<const Tensor*>& inputs,
-                          const std::vector<Tensor*>& outputs) const = 0;
+  virtual void InferShape(const InferShapeContext& ctx) const = 0;
};
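The new public InferShape(scope) wrapping a protected pure-virtual InferShape(ctx) is the non-virtual-interface idiom: the base class fixes how the context is built, and subclasses only fill in the shape logic. A stripped-down sketch (the two methods get distinct names here to sidestep C++ name hiding; the PR itself overloads a single name):

```cpp
#include <string>

struct InferShapeContext { std::string op_name; };

class OperatorWithKernel {
 public:
  virtual ~OperatorWithKernel() = default;

  // Public and non-virtual: the base class controls context construction.
  std::string InferShape(const std::string& op_name) const {
    InferShapeContext ctx{op_name};
    return InferShapeImpl(ctx);
  }

 protected:
  // Protected hook: concrete operators supply the actual shape logic.
  virtual std::string InferShapeImpl(const InferShapeContext& ctx) const = 0;
};

class AddOp : public OperatorWithKernel {
 protected:
  std::string InferShapeImpl(const InferShapeContext& ctx) const override {
    return "shapes checked for " + ctx.op_name;
  }
};
```

Making the hook protected (commit 9fafc46) keeps callers on the single public entry point, so every operator's shape inference goes through the same context-building path.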

} // namespace framework
38 changes: 32 additions & 6 deletions paddle/framework/operator_test.cc
@@ -24,7 +24,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
public:
void Init() override { x = 1; }
-  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void InferShape(
+      const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
op_run_num++;
@@ -73,6 +74,7 @@ TEST(OperatorBase, all) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->InferShape(scope);
op->Run(scope, device_context);
ASSERT_EQ(paddle::framework::op_run_num, 1);
}
@@ -97,14 +99,13 @@ static int cpu_kernel_run_num = 0;

class OpWithKernelTest : public OperatorWithKernel {
protected:
-  void InferShape(const std::vector<const Tensor*>& inputs,
-                  const std::vector<Tensor*>& outputs) const override {}
+  void InferShape(const framework::InferShapeContext& ctx) const override {}
};

template <typename T1, typename T2>
class CPUKernelTest : public OpKernel {
public:
-  void Compute(const KernelContext& ctx) const {
+  void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl;
std::cout << ctx.op_.DebugString() << std::endl;
cpu_kernel_run_num++;
@@ -117,7 +118,8 @@ class CPUKernelTest : public OpKernel {
class OperatorMultiInputsTest : public OperatorBase {
public:
void Init() override { x = 1; }
-  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void InferShape(
+      const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
@@ -149,13 +151,31 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker

class CPUKernalMultiInputsTest : public OpKernel {
public:
-  void Compute(const KernelContext& ctx) const {
+  void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op_.Inputs("xs");
ASSERT_EQ(xs.size(), 3UL);
ASSERT_EQ(xs[0], "x0");
ASSERT_EQ(xs[1], "x1");
ASSERT_EQ(xs[2], "x2");

auto inVar0 = ctx.MultiInputVar("xs");
ASSERT_EQ(inVar0.size(), 3);

auto intVar1 = ctx.InputVar("k");
ASSERT_NE(intVar1, nullptr);

auto outVar0 = ctx.MultiOutputVar("ys");
ASSERT_EQ(outVar0.size(), 2);

auto inTensor0 = ctx.MultiInput<Tensor>("xs");
ASSERT_EQ(inTensor0.size(), 3);

auto intTensor1 = ctx.Input<Tensor>("k");
ASSERT_NE(intTensor1, nullptr);

auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
ASSERT_EQ(outTensor0.size(), 2);

auto k = ctx.op_.Input("k");
ASSERT_EQ(k, "k0");

@@ -233,6 +253,12 @@ TEST(OpKernel, multi_inputs) {

paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
scope->CreateVariable("x0")->GetMutable<Tensor>();
scope->CreateVariable("x1")->GetMutable<Tensor>();
scope->CreateVariable("x2")->GetMutable<Tensor>();
scope->CreateVariable("k0")->GetMutable<Tensor>();
scope->CreateVariable("y0")->GetMutable<Tensor>();
scope->CreateVariable("y1")->GetMutable<Tensor>();

auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);