-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use operator context and infer context #3024
use operator context and infer context #3024
Conversation
paddle/framework/operator.cc
Outdated
@@ -79,6 +79,10 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { | |||
outputs_.begin() + output_format.at(offset + 1)}; | |||
} | |||
|
|||
void OperatorBase::InferShape(const std::shared_ptr<Scope>& scope) const { | |||
InferShapeImpl(InferShapeContext(this, scope)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does scope
here have to be of type shared_ptr
? It seems simpler if we can use const Scope&
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It cannot be done, because inside InferShape
/Run
in some operator, e.g., RNN
, the developer will create a new local Scope
which uses std::shared_ptr<Scope>
as an argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里传递指针的引用确实比较confusing,能加一下注释吗?
指针的引用表示指针本身也会被改变,那改变之后,这个指针之前指向的对象怎么办呢?会有内存泄漏吗?
paddle/framework/net.h
Outdated
@@ -57,9 +57,9 @@ class PlainNet : public Net { | |||
* Infer all the operators' input and output variables' shapes, will be called | |||
* before every mini-batch | |||
*/ | |||
void InferShape(const std::shared_ptr<Scope>& scope) const override { | |||
void InferShapeImpl(const InferShapeContext& ctx) const override { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my mind, Impl
is usually a suffix for a class name, which implements an interface. Do we really need to name a function Impl
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's right, maybe just called InferShape
is cool.
paddle/framework/operator.h
Outdated
return scope_->GetVariable(op_.outputs_[index]); | ||
} | ||
|
||
const Variable* Input(const std::string& name) const { | ||
const Variable* InputVar(const std::string& name) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可不可以 std::vector<const Variable*>
整个作为 Variable::Get()
的类型:
typedef std::vector<Tensor*> TensorArray;
TensorArray tensors = var.Get<TensorArray>();
@@ -110,29 +98,32 @@ class OperatorBase { | |||
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_; | |||
}; | |||
|
|||
class KernelContext { | |||
class OperatorContext { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OperatorContext => ExecutionContext?
我们会有两个概念分别叫做 OperatorContext 和 KernelContext 的吗?如果其实没有,那么就叫 Context 或者 ExecutionContext 是不是更清楚?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ExecutionContext is better
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We actually have two contexts, one for InferShape, other for Run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See line 35 and 36
const platform::DeviceContext& device_context) | ||
: op_(*op), scope_(scope), device_context_(device_context) {} | ||
OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope) | ||
: op_(*op), scope_(scope) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have to check OperatorBase* op not null first.
And we use const OperatorBase& op_
as a member, why not const std::shared_ptr<OperatorBase> op_
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we donot need to check because context will only be construct inside a op, so op will nevel be null. And so be need not to use std::shared_ptr
… optimize-context
paddle/framework/operator.h
Outdated
const Variable* Input(int index) const { | ||
int OutputSize() const { return static_cast<int>(op_.outputs_.size()); } | ||
|
||
const Variable* InputVar(int index) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In OperatorBase
, Input
returns string
, In Context
, Input
returns Tensor
, and here is another InputVar
Can we uniform all the Input()
s, and use a single API like template <typename T> input(std::string)
, and implements three types: string
, Tensor
, Variable
.
This is much simplier to understand. @jacquesqiao
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great suggestion, thanx!
paddle/framework/operator.h
Outdated
@@ -84,7 +71,8 @@ class OperatorBase { | |||
|
|||
/// InferShape infer the size of Variables used by this Operator with | |||
/// information inside scope | |||
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0; | |||
virtual void InferShape(const std::shared_ptr<Scope>& scope) const final; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we use the reference of std::shared_ptr<Scope>
as a parameter, and all InferShape
share the same std::shared_ptr<Scope>
.
Does it means that there will be only one std::shared_ptr<Scope>
? If so, why not use std::unique_ptr<Scope>
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please complete unittest
paddle/framework/operator.cc
Outdated
auto names = op_.Inputs(name); | ||
std::vector<const Variable*> res; | ||
std::transform( | ||
names.begin(), names.end(), res.begin(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::back_inserter(res)
? because res is empty, transform cannot copy data to res container.
paddle/framework/operator.cc
Outdated
std::vector<const Variable*> OperatorContext::MultiOutput( | ||
const std::string& name) const { | ||
auto names = op_.Outputs(name); | ||
std::vector<const Variable*> res; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above
paddle/framework/operator.h
Outdated
OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope) | ||
: op_(*op), scope_(scope) {} | ||
|
||
int InputSize() const { return static_cast<int>(op_.inputs_.size()); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
??? why not size_t?
… optimize-context
paddle/operators/add_op.cc
Outdated
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two"); | ||
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); | ||
PADDLE_ENFORCE(ctx.Input<framework::Variable>(0) != nullptr && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The lines of code are not shrunk at all.
Maybe we could add a default template argument as framework::Variable?
template <typename T = framework::Variable>
T Input(size_t idx);
I am not sure if that is a good design or not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
currently, the return type is a pointer, look like this
template <typename T = framework::Variable>
T* Input(size_t idx);
but this does not support std:: string
, that means we will add another 4 interfaces like:
std::string InputName();
std::vector<std::string> MultiInputNames();
std::string OutputName();
std::vector<std::string> MultiOutputNames();
Is there some way to embed std::string
as a template type as follows ?
auto variable_name = Input<std::string>("x")
paddle/framework/operator.cc
Outdated
@@ -99,5 +99,48 @@ std::string OperatorBase::DebugString() const { | |||
return ss.str(); | |||
} | |||
|
|||
template <> | |||
const Variable* OperatorContext::Input<Variable>(int index) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use size_t
instead of int
.
size_t is the standard type for stl container.
paddle/framework/operator.cc
Outdated
} | ||
|
||
template <> | ||
Variable* OperatorContext::Output<Variable>(int index) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int --> size_t
@@ -110,29 +98,32 @@ class OperatorBase { | |||
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_; | |||
}; | |||
|
|||
class KernelContext { | |||
class OperatorContext { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We actually have two contexts, one for InferShape, other for Run.
paddle/framework/operator.h
Outdated
void InferShape(const std::shared_ptr<Scope>& scope) const { | ||
InferShape(InferShapeContext(this, scope)); | ||
} | ||
virtual void InferShape(const InferShapeContext& ctx) const = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make this interface protected
paddle/framework/operator.h
Outdated
int OutputSize() const { return static_cast<int>(op_.outputs_.size()); } | ||
|
||
template <typename T> | ||
const T* Input(int index) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why return const T*
not const T&
… optimize-context
} | ||
|
||
const Variable* Input(const std::string& name) const { | ||
const Variable* InputVar(const std::string& name) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still need InputVar
although we have template <typename T> Input(name)
?
Isn't InputVar
is Input<Variable>
?
paddle/framework/operator.h
Outdated
[this](const std::string& name) { return scope_->GetVariable(name); }); | ||
return res; | ||
} | ||
|
||
template <typename T> | ||
const T* Input(size_t index) const { | ||
return &(InputVar(index)->Get<T>()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a template specification here to support Variable
?
template <>
Variable* Input<Variable>(const std::string &name);
} | ||
|
||
template <typename T> | ||
T* Output(const std::string& name) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as top
PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); | ||
outputs[0]->Resize(inputs[0]->dims()); | ||
void InferShape(const InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
enforce the number of real inputs is the duty of OperatorBase
, here 1
is just the number of inputs in definition of sigmoid.
The simgoid
's op definition has only one input X
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output");
AddComment("Sigmoid function");
}
so we need OperatorBase
's infershape to automatically enforce number of inputs and outputs, the real operator's implementation doesn't need to enforce itself.
@jacquesqiao @reyoung @QiJune
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, the input size condition should be guaranteed by the Op creator.
PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); | ||
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, | ||
void InferShape(const InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as top, when an operator InferShape, it means to check the weather the real inputs and outputs match the definition in its op_proto.
here the 1
is softmax_op.op_proto.inputs().size()
, and OperatorBase
can do this enforce automatically for SoftmaxOp
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
now we use
for shape infer. but there are some problems:
std::vector<Tensor*> x = ctx.Inputs("X");