Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract InferShape to many cc files #5174

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)

cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
Expand Down
132 changes: 130 additions & 2 deletions paddle/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,51 @@ limitations under the License. */
#include <functional>
#include <mutex>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"

#include "glog/logging.h"
#include "paddle/framework/shape_inference.h"

namespace paddle {
namespace framework {

class OpDescBind;
class BlockDescBind;
class CompileTimeInferShapeContext : public InferShapeContext {
public:
CompileTimeInferShapeContext(const OpDescBind &op,
const BlockDescBind &block);

bool HasInput(const std::string &name) const override;

bool HasOutput(const std::string &name) const override;

bool HasInputs(const std::string &name) const override;

bool HasOutputs(const std::string &name) const override;

DDim GetInputDim(const std::string &name) const override;

void SetOutputDim(const std::string &name, const DDim &dim) override;

AttrReader Attrs() const override;

const std::vector<std::string> &Inputs(
const std::string &name) const override;

const std::vector<std::string> &Outputs(
const std::string &name) const override;

private:
DDim GetDim(const std::string &name) const override;

void SetDim(const std::string &name, const DDim &dim) override;

const OpDescBind &op_;
const BlockDescBind &block_;
};

OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs) {
Expand Down Expand Up @@ -288,5 +324,97 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
}
}

CompileTimeInferShapeContext::CompileTimeInferShapeContext(
const OpDescBind &op, const BlockDescBind &block)
: op_(op), block_(block) {}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
const std::vector<std::string> &input_names = op_.Input(name);
auto length = input_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
const std::vector<std::string> &output_names = op_.Output(name);
auto length = output_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Output(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVarRecursive(output_names[0]);
}

bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
const std::vector<std::string> &input_names = op_.Input(name);
if (input_names.empty()) {
return false;
}
for (auto &input : input_names) {
if (!block_.HasVarRecursive(input)) return false;
}
return true;
}

bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
const std::vector<std::string> &output_names = op_.Output(name);
if (output_names.empty()) {
return false;
}
for (auto &output : output_names) {
if (!block_.HasVarRecursive(output)) return false;
}
return true;
}

DDim CompileTimeInferShapeContext::GetInputDim(const std::string &name) const {
std::vector<DDim> ddims = GetInputsDim(name);
auto length = ddims.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have 1 value, "
"but it has %d now",
name, length);
return ddims[0];
}

void CompileTimeInferShapeContext::SetOutputDim(const std::string &name,
const DDim &dim) {
SetOutputsDim(name, {dim});
}

AttrReader CompileTimeInferShapeContext::Attrs() const {
return AttrReader(op_.GetAttrMap());
}

const std::vector<std::string> &CompileTimeInferShapeContext::Inputs(
const std::string &name) const {
return op_.Input(name);
}

const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
const std::string &name) const {
return op_.Output(name);
}

DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
return framework::make_ddim(var->Shape());
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
const DDim &dim) {
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
}

} // namespace framework
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h"

namespace paddle {
namespace framework {
Expand Down
132 changes: 132 additions & 0 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include <algorithm>
#include <atomic>
#include "paddle/framework/shape_inference.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -273,5 +274,136 @@ bool OpSupportGPU(const std::string& op_type) {
return false;
}

class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}

bool HasInput(const std::string& name) const override {
auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}

bool HasOutput(const std::string& name) const override {
auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}

bool HasInputs(const std::string& name) const override {
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false;
}
for (auto& input : inputs) {
if (scope_.FindVar(input) == nullptr) {
return false;
}
}
return true;
}

bool HasOutputs(const std::string& name) const override {
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
return false;
}
for (auto& output : outputs) {
if (scope_.FindVar(output) == nullptr) {
return false;
}
}
return true;
}

DDim GetInputDim(const std::string& name) const override {
return GetDim(op_.Input(name));
}

void SetOutputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Output(name), dim);
}

AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }

const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Inputs(name);
}

const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Outputs(name);
}

private:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}

void SetDim(const std::string& name, const DDim& dim) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}

const OperatorBase& op_;
const Scope& scope_;
};

void OperatorWithKernel::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
VLOG(3) << "Running operator " << this->Type();
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);

ExecutionContext ctx(*this, scope, dev_ctx);

// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW("op[%s] has no kernel", type_);
}

// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
auto kernel_iter = kernels.find(kernel_key);

if (kernel_iter == kernels.end()) {
PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_, kernel_key);
}

kernel_iter->second->Compute(ctx);
}

} // namespace framework
} // namespace paddle
Loading