use operator context and infer context #3024
Changes from all commits
@@ -31,22 +31,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-template <typename T>
-struct EigenDeviceConverter;
-
-template <>
-struct EigenDeviceConverter<platform::CPUPlace> {
-  using EigenDeviceType = Eigen::DefaultDevice;
-};
-
-#ifndef PADDLE_ONLY_CPU
-template <>
-struct EigenDeviceConverter<platform::GPUPlace> {
-  using EigenDeviceType = Eigen::GpuDevice;
-};
-#endif
-
 class OperatorBase;
+class InferShapeContext;
+class ExecutionContext;
 /**
  * OperatorBase has the basic element that Net will call to do computation.
  * Only CreateOperator from OpRegistry will new Operator directly. User
@@ -112,85 +99,151 @@ class OperatorBase {
   std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
 };

-class KernelContext {
+class OperatorContext {
  public:
-  KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
-                const platform::DeviceContext& device_context)
-      : op_(*op), scope_(scope), device_context_(device_context) {}
+  OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
+      : op_(*op), scope_(scope) {}
Review comment: We have to check that `OperatorBase* op` is not null first.

Reply: We do not need to check, because the context is only constructed inside an operator, so `op` will never be null. For the same reason, there is no need to use `std::shared_ptr` here.
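For reference, if such a guard were ever wanted, it would have to run before `*op` is dereferenced in the member-initializer list. A minimal sketch, assuming a `PADDLE_ENFORCE(cond, msg)`-style assertion macro is available elsewhere in the codebase (the `CheckedOp` helper is hypothetical, not part of this diff):

```cpp
// Hypothetical sketch only. The check must precede the *op dereference that
// happens in the constructor's initializer list.
static const OperatorBase& CheckedOp(const OperatorBase* op) {
  PADDLE_ENFORCE(op != nullptr, "OperatorContext requires a non-null op");
  return *op;
}

OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
    : op_(CheckedOp(op)), scope_(scope) {}
```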

+  size_t InputSize() const { return op_.inputs_.size(); }
+
+  size_t OutputSize() const { return op_.outputs_.size(); }

-  const Variable* Input(int index) const {
-    return scope_->GetVariable(op_.inputs_[index]);
+  const Variable* InputVar(const size_t& index) const {
+    return scope_->GetVariable(op_.inputs_.at(index));
   }

-  Variable* Output(int index) const {
-    return scope_->GetVariable(op_.outputs_[index]);
+  Variable* OutputVar(const size_t& index) const {
+    return scope_->GetVariable(op_.outputs_.at(index));
   }

-  const Variable* Input(const std::string& name) const {
+  const Variable* InputVar(const std::string& name) const {
Review comment: Still need … Isn't …?
     return scope_->GetVariable(op_.Input(name));
   }

-  const Variable* Output(const std::string& name) const {
+  Variable* OutputVar(const std::string& name) const {
     return scope_->GetVariable(op_.Output(name));
   }

-  const std::vector<const Variable*> Inputs(const std::string& name) const {
+  const std::vector<const Variable*> MultiInputVar(
+      const std::string& name) const {
     auto names = op_.Inputs(name);
     std::vector<const Variable*> res;
     res.reserve(names.size());
     std::transform(
-        names.begin(), names.end(), res.begin(),
+        names.begin(), names.end(), std::back_inserter(res),
         [this](const std::string& name) { return scope_->GetVariable(name); });
     return res;
   }

-  const std::vector<const Variable*> Outputs(const std::string& name) const {
+  std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
     auto names = op_.Outputs(name);
     std::vector<const Variable*> res;
     res.reserve(names.size());
     std::transform(
-        names.begin(), names.end(), res.begin(),
+        names.begin(), names.end(), std::back_inserter(res),
         [this](const std::string& name) { return scope_->GetVariable(name); });
     return res;
   }

+  template <typename T>
+  const T* Input(const size_t& index) const {
+    return &(InputVar(index)->Get<T>());
+  }
+
+  template <typename T>
+  T* Output(const size_t& index) const {
+    return OutputVar(index)->GetMutable<T>();
+  }
+
+  template <typename T>
+  const T* Input(const std::string& name) const {
+    return &(InputVar(name)->Get<T>());
+  }
+
+  template <typename T>
+  T* Output(const std::string& name) const {
Review comment: Same as the comment above.
+    return OutputVar(name)->GetMutable<T>();
+  }
+
+  template <typename T>
+  const std::vector<const T*> MultiInput(const std::string& name) const {
+    auto names = op_.Inputs(name);
+    std::vector<const T*> res;
+    res.reserve(names.size());
+    std::transform(names.begin(), names.end(), std::back_inserter(res),
+                   [this](const std::string& name) {
+                     return &scope_->GetVariable(name)->Get<T>();
+                   });
+    return res;
+  }
+
+  template <typename T>
+  std::vector<const T*> MultiOutput(const std::string& name) const {
+    auto names = op_.Outputs(name);
+    std::vector<const T*> res;
+    res.reserve(names.size());
+    std::transform(names.begin(), names.end(), std::back_inserter(res),
+                   [this](const std::string& name) {
+                     return scope_->GetVariable(name)->GetMutable<T>();
+                   });
+    return res;
+  }
+
+  const OperatorBase& op_;
+  const std::shared_ptr<Scope>& scope_;
+};
+
+class InferShapeContext : public OperatorContext {
+ public:
+  InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
+      : OperatorContext(op, scope) {}
+};
+
+template <typename T>
+struct EigenDeviceConverter;
+
+template <>
+struct EigenDeviceConverter<platform::CPUPlace> {
+  using EigenDeviceType = Eigen::DefaultDevice;
+};
+
+#ifndef PADDLE_ONLY_CPU
+template <>
+struct EigenDeviceConverter<platform::GPUPlace> {
+  using EigenDeviceType = Eigen::GpuDevice;
+};
+#endif
+
+class ExecutionContext : public OperatorContext {
+ public:
+  ExecutionContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
+                   const platform::DeviceContext& device_context)
+      : OperatorContext(op, scope), device_context_(device_context) {}

   template <typename PlaceType,
             typename DeviceType =
                 typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
   DeviceType* GetEigenDevice() const;

   platform::Place GetPlace() const { return device_context_.GetPlace(); }

-  const OperatorBase& op_;
-  const std::shared_ptr<Scope>& scope_;
   const platform::DeviceContext& device_context_;
 };

 class OpKernel {
  public:
   /**
-   * KernelContext is the only parameter of Kernel Run function.
+   * ExecutionContext is the only parameter of Kernel Run function.
    * Run will get input/output variables, state such as momentum and
    * device resource such as CUDA stream, cublas handle, etc. from
-   * KernelContext. User should construct it before run the Operator.
+   * ExecutionContext. User should construct it before run the Operator.
    */

-  virtual void Compute(const KernelContext& context) const = 0;
+  virtual void Compute(const ExecutionContext& context) const = 0;

   virtual ~OpKernel() {}
 };

-template <typename T>
-struct VarToTensor {};
-
-template <>
-struct VarToTensor<Tensor*> {
-  Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
-};
-
-template <>
-struct VarToTensor<const Tensor*> {
-  const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
-};
-
 class OperatorWithKernel : public OperatorBase {
  public:
   struct OpKernelKey {
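To make the new accessor surface concrete, here is a rough sketch of a kernel written against `ExecutionContext`. It is illustrative only: the `ScaleKernel` class, the input/output names "X"/"Out", and the `Tensor::mutable_data` / `EigenVector<T>::Flatten` helpers are assumptions about the surrounding codebase, not part of this diff.

```cpp
// Illustrative sketch: a scale-by-two kernel using the new ExecutionContext.
template <typename Place, typename T>
class ScaleKernel : public OpKernel {
 public:
  void Compute(const ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");      // const Tensor*, looked up by name
    auto* out = ctx.Output<Tensor>("Out");  // mutable Tensor*
    out->mutable_data<T>(ctx.GetPlace());   // allocate on the kernel's place

    auto x = EigenVector<T>::Flatten(*in);
    auto y = EigenVector<T>::Flatten(*out);
    // Device-specific work goes through the Eigen device that the
    // ExecutionContext derives from its DeviceContext.
    y.device(*ctx.GetEigenDevice<Place>()) = x * static_cast<T>(2);
  }
};
```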
@@ -216,10 +269,14 @@ class OperatorWithKernel : public OperatorBase {
   using OpKernelMap =
       std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

+  void InferShape(const std::shared_ptr<Scope>& scope) const {
+    InferShape(InferShapeContext(this, scope));
+  }
+
   void Run(const std::shared_ptr<Scope>& scope,
            const platform::DeviceContext& dev_ctx) const final {
     auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
-    opKernel->Compute(KernelContext(this, scope, dev_ctx));
+    opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
   }

   static std::unordered_map<std::string /* op_type */, OpKernelMap>&
@@ -228,34 +285,8 @@ class OperatorWithKernel : public OperatorBase {
     return g_all_op_kernels;
   }

-  void InferShape(const std::shared_ptr<Scope>& scope) const final {
-    std::vector<const Tensor*> ins;
-    VarNamesToTensors(scope, inputs_, &ins);
-    std::vector<Tensor*> outs;
-    VarNamesToTensors(scope, outputs_, &outs);
-    InferShape(ins, outs);
-  };
-
- private:
-  template <typename T>
-  void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
-                         const std::vector<std::string>& var_names,
-                         std::vector<T>* container) const {
-    container->reserve(var_names.size());
-    VarToTensor<T> convert;
-    for (auto& name : var_names) {
-      auto var = scope->GetVariable(name);
-      if (var != nullptr) {
-        container->push_back(convert(var));
-      } else {
-        container->push_back(nullptr);
-      }
-    }
-  }
-
  protected:
-  virtual void InferShape(const std::vector<const Tensor*>& inputs,
-                          const std::vector<Tensor*>& outputs) const = 0;
+  virtual void InferShape(const InferShapeContext& ctx) const = 0;
 };

 }  // namespace framework
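Under the new signature, a concrete operator resolves its own inputs and outputs from the scope instead of receiving pre-converted tensor vectors. A minimal sketch of what an override might look like (the `AddLikeOp` class and the names "X"/"Y"/"Out" are hypothetical; `dims()`, `Resize()`, and `PADDLE_ENFORCE` are assumed from the surrounding codebase):

```cpp
// Illustrative sketch: an element-wise operator implementing the new
// InferShape(const InferShapeContext&) hook.
class AddLikeOp : public OperatorWithKernel {
 protected:
  void InferShape(const InferShapeContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    PADDLE_ENFORCE(x->dims() == y->dims(),
                   "Inputs of an element-wise op must have the same shape");
    ctx.Output<Tensor>("Out")->Resize(x->dims());
  }
};
```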
Review comment: OperatorContext => ExecutionContext? Will we really have two separate concepts called OperatorContext and KernelContext? If not, wouldn't calling it simply Context or ExecutionContext be clearer?

Reply: ExecutionContext is better.

Reply: We actually have two contexts, one for InferShape and the other for Run.

Reply: See lines 35 and 36.
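The split the thread is describing is visible from the caller's side: InferShape needs only a Scope, while Run also needs a DeviceContext, which is why InferShapeContext and ExecutionContext diverge below OperatorContext. A rough caller-side sketch (the `RunOnCPU` helper and the way `op` is obtained are hypothetical, not part of this diff):

```cpp
// Illustrative sketch: "op" stands for an operator created elsewhere,
// e.g. through OpRegistry.
void RunOnCPU(const std::shared_ptr<OperatorBase>& op) {
  auto scope = std::make_shared<Scope>();
  platform::CPUDeviceContext cpu_ctx;

  // Shape-inference phase: only the Scope is needed, so the operator builds an
  // InferShapeContext(this, scope) internally.
  op->InferShape(scope);

  // Execution phase: a DeviceContext is needed as well, so the operator builds
  // an ExecutionContext(this, scope, dev_ctx) and hands it to its OpKernel.
  op->Run(scope, cpu_ctx);
}
```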