forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
184 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#include "cinn/hlir/framework/instruction.h" | ||
|
||
namespace cinn { | ||
namespace hlir { | ||
namespace framework { | ||
|
||
std::vector<cinn_pod_value_t>& Instruction::PreparePodArgs() { | ||
if (!args_cached_.empty()) return args_cached_; | ||
|
||
common::ArgsBuilder builder; | ||
std::vector<std::string> all_args(in_args_.begin(), in_args_.end()); | ||
all_args.insert(std::end(all_args), out_args_.begin(), out_args_.end()); | ||
|
||
for (auto& arg : all_args) { | ||
auto* var = scope_->FindVar(arg); | ||
CHECK(var) << "Argument [" << arg << "] not found in the scope"; | ||
|
||
// TODO(Superjomn) Support other types. | ||
auto& tensor = std::get<Tensor>(*var); | ||
builder.Add(tensor.buffer()); | ||
} | ||
|
||
args_cached_ = builder.Build(); | ||
return args_cached_; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace hlir | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
#pragma once | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "cinn/common/test_helper.h" | ||
#include "cinn/hlir/framework/scope.h" | ||
|
||
namespace cinn { | ||
namespace hlir { | ||
namespace framework { | ||
|
||
/** | ||
* Instruction is the basic executable element in runtime, it holds a pointer to the JIT-compiled LoweredFunc, and | ||
* collect the cinn_buffer of the inputs and outputs from the scope, prepare the arguments and finally pass them into | ||
* the LoweredFunc and execute it. | ||
*/ | ||
class Instruction { | ||
public: | ||
using infershape_t = std::function<void(Scope*, const std::vector<std::string>&)>; | ||
|
||
/** | ||
* Constructor. | ||
* @param target The \p target the instruction runs on. | ||
* @param scope The scope containing all the runtime variables(Tensors and PODs). | ||
* @param in_args The names of the inputs. | ||
* @param out_args The names of the outputs. | ||
* @param infershape The handler of this Instruction to perform shape inference. | ||
*/ | ||
Instruction(const Target& target, | ||
Scope* scope, | ||
const std::vector<std::string>& in_args, | ||
const std::vector<std::string>& out_args) | ||
: target_(target), scope_(scope), in_args_(in_args), out_args_(out_args) {} | ||
|
||
/** | ||
* Set compiled function address. | ||
* @param fn The JIT compiled function address. | ||
*/ | ||
void SetLoweredFunc(lower_func_ptr_t fn) { fn_ = fn; } | ||
|
||
/** | ||
* Run the Instruction. | ||
*/ | ||
void Run() { | ||
CHECK(fn_) << "The LoweredFunc address should be set first by calling SetLoweredFunc method"; | ||
auto& pod_args = PreparePodArgs(); | ||
fn_(pod_args.data(), pod_args.size()); | ||
} | ||
|
||
protected: | ||
std::vector<cinn_pod_value_t>& PreparePodArgs(); | ||
|
||
private: | ||
Scope* scope_{}; | ||
std::vector<std::string> in_args_; | ||
std::vector<std::string> out_args_; | ||
|
||
std::vector<cinn_pod_value_t> args_cached_; | ||
|
||
Target target_; | ||
|
||
lower_func_ptr_t fn_{}; | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace hlir | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#include "cinn/hlir/framework/instruction.h" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "cinn/backends/llvm/simple_jit.h" | ||
|
||
namespace cinn { | ||
namespace hlir { | ||
namespace framework { | ||
|
||
std::unique_ptr<backends::SimpleJIT> GetLoweredFunc(int M, int N) { | ||
Expr m(M); | ||
Expr n(N); | ||
|
||
Placeholder<float> x("x", {m, n}); | ||
Placeholder<float> y("y", {m, n}); | ||
|
||
auto z = Compute( | ||
{m, n}, [=](Expr i, Expr j) { return x(i, j) + y(i, j); }, "z"); | ||
|
||
auto fn = Lower("fn", {x, y, z}); | ||
|
||
lang::Module::Builder builder("some_module", common::DefaultHostTarget()); | ||
builder.AddFunction(fn); | ||
|
||
auto jit = backends::SimpleJIT::Create(); | ||
jit->Link(builder.Build()); | ||
return std::move(jit); | ||
} | ||
|
||
TEST(Instruction, basic) { | ||
const int M = 10; | ||
const int N = 20; | ||
|
||
Scope scope; | ||
|
||
auto get_tensor = [&](const std::string& name) { | ||
auto* var = scope.Var<Tensor>(name); | ||
auto& tensor = std::get<Tensor>(*var); | ||
return tensor; | ||
}; | ||
|
||
for (auto& name : std::vector<std::string>({"x", "y", "z"})) { | ||
auto tensor = get_tensor(name); | ||
tensor.Resize(Shape{{M, N}}); | ||
auto* data = tensor.mutable_data<float>(common::DefaultHostTarget()); | ||
for (int i = 0; i < M * N; i++) { | ||
data[i] = (rand() * 1.f) / RAND_MAX; // NOLINT | ||
} | ||
} | ||
|
||
// create Instruction | ||
Instruction instr(common::DefaultHostTarget(), &scope, {"x", "y"}, {"z"}); | ||
auto jit = GetLoweredFunc(M, N); | ||
auto fn_addr = jit->Lookup("fn"); | ||
CHECK(fn_addr); | ||
|
||
instr.SetLoweredFunc(reinterpret_cast<lower_func_ptr_t>(fn_addr)); | ||
instr.Run(); | ||
|
||
// check result | ||
{ | ||
auto xd = get_tensor("x").data<float>(); | ||
auto yd = get_tensor("y").data<float>(); | ||
auto zd = get_tensor("z").data<float>(); | ||
|
||
for (int i = 0; i < M * N; i++) { | ||
LOG_FIRST_N(INFO, 3) << "data: " << xd[i] << " + " << yd[i] << " = " << zd[i]; | ||
ASSERT_NEAR(xd[i] + yd[i], zd[i], 1e-5); | ||
} | ||
} | ||
} | ||
|
||
} // namespace framework | ||
} // namespace hlir | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters