forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add NetBuilder API for building a frontend Program. (PaddlePaddle#446)
* Add the coarse builder for building a Program. * Replace CoarseBuilder with NetBuilder. Update codes as suggested by the reviewers. * Add CUDA Mem Set for SetRandData. * Remove the `symbolization` dir. * Call the method `Validate` immediately of a program after building.
- Loading branch information
Showing
8 changed files
with
532 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,4 @@ gen_modules | |
docs/source/cpp | ||
docs/source/doxygen_output | ||
docs/source/tutorials | ||
.vscode* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#include "cinn/frontend/base_builder.h" | ||
|
||
#include <string> | ||
#include <utility> | ||
|
||
#include "cinn/common/common.h" | ||
#include "cinn/common/context.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
|
||
Program BaseBuilder::Build() { | ||
Program program{std::move(instrs_), std::move(inputs_)}; | ||
program.Validate(); | ||
return program; | ||
} | ||
|
||
Placeholder BaseBuilder::CreateInput(const common::Type& type, | ||
const std::vector<int>& shape, | ||
const std::string& id_hint) { | ||
if (!id_hint.empty()) { | ||
CheckVarNameValid(id_hint); | ||
} | ||
std::string id = id_hint.empty() ? common::Context::Global().NewName("placeholder") : id_hint; | ||
|
||
inputs_.emplace_back(id); | ||
auto& var = inputs_.back(); | ||
var->type = type; | ||
var->shape = shape; | ||
return Placeholder(var); | ||
} | ||
|
||
} // namespace frontend | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#pragma once | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "cinn/common/type.h" | ||
#include "cinn/frontend/syntax.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
|
||
class BaseBuilder { | ||
public: | ||
explicit BaseBuilder(const std::string& name) : name_(name) {} | ||
|
||
Program Build(); | ||
|
||
Placeholder CreateInput(const common::Type& type, const std::vector<int>& shape, const std::string& id_hint = ""); | ||
|
||
// name of this builder | ||
const std::string& name() { return name_; } | ||
|
||
virtual ~BaseBuilder() {} | ||
|
||
protected: | ||
void AppendInstruction(const Instruction& instr) { instrs_.push_back(instr); } | ||
|
||
std::string name_; | ||
std::vector<Instruction> instrs_; | ||
std::vector<Variable> inputs_; | ||
}; | ||
|
||
} // namespace frontend | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
#include "cinn/frontend/net_builder.h" | ||
|
||
#include <string> | ||
#include <unordered_map> | ||
#include <utility> | ||
|
||
#include "cinn/frontend/syntax.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
|
||
Variable NetBuilder::add(const Variable& a, const Variable& b) { | ||
Instruction instr("elementwise_add", {a, b}); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::mul(const Variable& a, const Variable& b, int x_num_col_dims, int y_num_col_dims) { | ||
Instruction instr("mul", {a, b}); | ||
instr.SetAttr("x_num_col_dims", x_num_col_dims); | ||
instr.SetAttr("y_num_col_dims", y_num_col_dims); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::mulbias( | ||
const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims, int y_num_col_dims) { | ||
Instruction instr("mulbias", {a, b, c}); | ||
instr.SetAttr("x_num_col_dims", x_num_col_dims); | ||
instr.SetAttr("y_num_col_dims", y_num_col_dims); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(1); | ||
} | ||
|
||
Variable NetBuilder::elementwise_add(const Variable& a, const Variable& b, int axis) { | ||
Instruction instr("elementwise_add", {a, b}); | ||
instr.SetAttr("axis", axis); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::elementwise_mul(const Variable& a, const Variable& b, int axis) { | ||
Instruction instr("elementwise_mul", {a, b}); | ||
instr.SetAttr("axis", axis); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::relu(const Variable& a) { | ||
Instruction instr("relu", {a}); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::relu6(const Variable& a, float threshold) { | ||
Instruction instr("relu6", {a}); | ||
instr.SetAttr("threshold", threshold); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::conv2d(const Variable& a, | ||
const Variable& b, | ||
const std::vector<int>& strides, | ||
const std::vector<int>& paddings, | ||
const std::vector<int>& dilations, | ||
int groups, | ||
const std::string& data_format, | ||
const std::string& padding_algorithm) { | ||
Instruction instr("conv2d"); | ||
instr.SetInputs({a, b}); | ||
instr.SetAttr("strides", strides); | ||
instr.SetAttr("paddings", paddings); | ||
instr.SetAttr("dilations", dilations); | ||
instr.SetAttr("groups", groups); | ||
instr.SetAttr("data_format", data_format); | ||
instr.SetAttr("padding_algorithm", padding_algorithm); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::depthwise_conv2d(const Variable& a, | ||
const Variable& b, | ||
const std::vector<int>& strides, | ||
const std::vector<int>& paddings, | ||
const std::vector<int>& dilations, | ||
int groups, | ||
const std::string& data_format, | ||
const std::string& padding_algorithm) { | ||
Instruction instr("depthwise_conv2d"); | ||
instr.SetInputs({a, b}); | ||
instr.SetAttr("strides", strides); | ||
instr.SetAttr("paddings", paddings); | ||
instr.SetAttr("dilations", dilations); | ||
instr.SetAttr("groups", groups); | ||
instr.SetAttr("data_format", data_format); | ||
instr.SetAttr("padding_algorithm", padding_algorithm); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::pool2d(const Variable& a, | ||
const std::string& pooling_type, | ||
const std::vector<int>& ksize, | ||
const std::vector<int>& strides, | ||
const std::vector<int>& paddings, | ||
bool ceil_mode, | ||
bool exclusive, | ||
bool global_pooling, | ||
const std::string& data_format, | ||
bool adaptive, | ||
const std::string& padding_algorithm) { | ||
Instruction instr("pool2d"); | ||
instr.SetInputs({a}); | ||
instr.SetAttr("pooling_type", pooling_type); | ||
instr.SetAttr("ksize", ksize); | ||
instr.SetAttr("strides", strides); | ||
instr.SetAttr("paddings", paddings); | ||
instr.SetAttr("ceil_mode", ceil_mode); | ||
instr.SetAttr("exclusive", exclusive); | ||
instr.SetAttr("global_pooling", global_pooling); | ||
instr.SetAttr("data_format", data_format); | ||
instr.SetAttr("adaptive", adaptive); | ||
instr.SetAttr("padding_algorithm", padding_algorithm); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::batchnorm(const Variable& a, | ||
const Variable& scale, | ||
const Variable& bias, | ||
const Variable& mean, | ||
const Variable& variance, | ||
float epsilon, | ||
float momentum, | ||
const std::string& data_layout) { | ||
Instruction instr("batchnorm"); | ||
instr.SetInputs({a, scale, bias, mean, variance}); | ||
instr.SetAttr("epsilon", epsilon); | ||
instr.SetAttr("momentum", momentum); | ||
instr.SetAttr("data_layout", data_layout); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::scale(const Variable& a, float scale, float bias, bool bias_after_scale) { | ||
Instruction instr("scale", {a}); | ||
instr.SetAttr("scale", scale); | ||
instr.SetAttr("bias", bias); | ||
instr.SetAttr("bias_after_scale", bias_after_scale); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::softmax(const Variable& a, int axis, const std::string& data_format) { | ||
Instruction instr("softmax", {a}); | ||
instr.SetAttr("axis", axis); | ||
instr.SetAttr("data_format", data_format); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::sigmoid(const Variable& a) { | ||
Instruction instr("sigmoid", {a}); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::slice(const Variable& a, | ||
const std::vector<int>& axes, | ||
const std::vector<int>& starts, | ||
const std::vector<int>& ends, | ||
const std::vector<int>& infer_flags, | ||
const std::vector<int>& decrease_axis) { | ||
Instruction instr("slice", {a}); | ||
instr.SetAttr("axes", axes); | ||
instr.SetAttr("starts", starts); | ||
instr.SetAttr("ends", ends); | ||
instr.SetAttr("infer_flags", infer_flags); | ||
instr.SetAttr("decrease_axis", decrease_axis); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::dropout_infer(const Variable& a, float dropout_prob, const std::string& dropout_implementation) { | ||
Instruction instr("dropout_infer", {a}); | ||
instr.SetAttr("dropout_prob", dropout_prob); | ||
instr.SetAttr("dropout_implementation", dropout_implementation); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
} // namespace frontend | ||
} // namespace cinn |
Oops, something went wrong.