add Instruction (PaddlePaddle#162)
Superjomn authored Aug 10, 2020
1 parent 2710d85 commit fc6992f
Showing 6 changed files with 184 additions and 1 deletion.
1 change: 1 addition & 0 deletions cinn/common/test_helper.h
@@ -85,6 +85,7 @@ struct ArgsBuilder {
template <typename T>
ArgsBuilder& Add(T x) {
data_.emplace_back(x);
LOG(INFO) << "ArgsBuilder adds an argument with type_code " << data_.back().type_code();
return *this;
}

2 changes: 2 additions & 0 deletions cinn/hlir/framework/CMakeLists.txt
@@ -4,6 +4,7 @@ set(srcs
variable.cc
buffer.cc
memory.cc
instruction.cc
)

if(WITH_CUDA)
@@ -14,6 +15,7 @@ endif()

cc_test(test_hlir_framework_tensor SRCS tensor_test.cc DEPS core)
cc_test(test_hlir_framework_scope SRCS scope_test.cc DEPS core)
cc_test(test_hlir_framework_instruction SRCS instruction_test.cc DEPS core)


foreach(cpp ${srcs})
29 changes: 29 additions & 0 deletions cinn/hlir/framework/instruction.cc
@@ -0,0 +1,29 @@
#include "cinn/hlir/framework/instruction.h"

namespace cinn {
namespace hlir {
namespace framework {

std::vector<cinn_pod_value_t>& Instruction::PreparePodArgs() {
if (!args_cached_.empty()) return args_cached_;

common::ArgsBuilder builder;
std::vector<std::string> all_args(in_args_.begin(), in_args_.end());
all_args.insert(std::end(all_args), out_args_.begin(), out_args_.end());

for (auto& arg : all_args) {
auto* var = scope_->FindVar(arg);
CHECK(var) << "Argument [" << arg << "] not found in the scope";

// TODO(Superjomn) Support other types.
auto& tensor = std::get<Tensor>(*var);
builder.Add(tensor.buffer());
}

args_cached_ = builder.Build();
return args_cached_;
}

} // namespace framework
} // namespace hlir
} // namespace cinn
68 changes: 68 additions & 0 deletions cinn/hlir/framework/instruction.h
@@ -0,0 +1,68 @@
#pragma once

#include <string>
#include <vector>

#include "cinn/common/test_helper.h"
#include "cinn/hlir/framework/scope.h"

namespace cinn {
namespace hlir {
namespace framework {

/**
* Instruction is the basic executable element in the runtime. It holds a pointer to the JIT-compiled LoweredFunc,
* collects the cinn_buffers of the inputs and outputs from the scope, prepares the arguments, and finally passes them
* to the LoweredFunc and executes it.
*/
class Instruction {
public:
using infershape_t = std::function<void(Scope*, const std::vector<std::string>&)>;

/**
* Constructor.
* @param target The \p target the instruction runs on.
* @param scope The scope containing all the runtime variables (Tensors and PODs).
* @param in_args The names of the inputs.
* @param out_args The names of the outputs.
*/
Instruction(const Target& target,
Scope* scope,
const std::vector<std::string>& in_args,
const std::vector<std::string>& out_args)
: target_(target), scope_(scope), in_args_(in_args), out_args_(out_args) {}

/**
* Set compiled function address.
* @param fn The JIT compiled function address.
*/
void SetLoweredFunc(lower_func_ptr_t fn) { fn_ = fn; }

/**
* Run the Instruction.
*/
void Run() {
CHECK(fn_) << "The LoweredFunc address should be set first by calling the SetLoweredFunc method";
auto& pod_args = PreparePodArgs();
fn_(pod_args.data(), pod_args.size());
}

protected:
std::vector<cinn_pod_value_t>& PreparePodArgs();

private:
Scope* scope_{};
std::vector<std::string> in_args_;
std::vector<std::string> out_args_;

std::vector<cinn_pod_value_t> args_cached_;

Target target_;

lower_func_ptr_t fn_{};
};

} // namespace framework
} // namespace hlir
} // namespace cinn
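
For orientation, here is a minimal usage sketch of the Instruction API documented above. It is illustrative only: the tensor names "a" and "b", the fn_addr pointer, and the RunSketch helper are assumptions, not part of this commit; the real end-to-end flow (populating the scope, JIT-compiling the function, and checking results) is exercised in instruction_test.cc below.

#include "cinn/hlir/framework/instruction.h"

using namespace cinn;                         // for brevity in this sketch
using namespace cinn::hlir::framework;

// fn_addr is assumed to be a JIT-compiled function address (see GetLoweredFunc in
// instruction_test.cc below); scope is assumed to already hold Tensors "a" and "b".
void RunSketch(Scope* scope, void* fn_addr) {
  Instruction instr(common::DefaultHostTarget(), scope, /*in_args=*/{"a"}, /*out_args=*/{"b"});
  instr.SetLoweredFunc(reinterpret_cast<lower_func_ptr_t>(fn_addr));
  instr.Run();  // gathers the cinn_buffers from the scope and calls the lowered function
}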
81 changes: 81 additions & 0 deletions cinn/hlir/framework/instruction_test.cc
@@ -0,0 +1,81 @@
#include "cinn/hlir/framework/instruction.h"

#include <gtest/gtest.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cinn/backends/llvm/simple_jit.h"

namespace cinn {
namespace hlir {
namespace framework {

std::unique_ptr<backends::SimpleJIT> GetLoweredFunc(int M, int N) {
Expr m(M);
Expr n(N);

Placeholder<float> x("x", {m, n});
Placeholder<float> y("y", {m, n});

auto z = Compute(
{m, n}, [=](Expr i, Expr j) { return x(i, j) + y(i, j); }, "z");

auto fn = Lower("fn", {x, y, z});

lang::Module::Builder builder("some_module", common::DefaultHostTarget());
builder.AddFunction(fn);

auto jit = backends::SimpleJIT::Create();
jit->Link(builder.Build());
return std::move(jit);
}

TEST(Instruction, basic) {
const int M = 10;
const int N = 20;

Scope scope;

auto get_tensor = [&](const std::string& name) {
auto* var = scope.Var<Tensor>(name);
auto& tensor = std::get<Tensor>(*var);
return tensor;
};

for (auto& name : std::vector<std::string>({"x", "y", "z"})) {
auto tensor = get_tensor(name);
tensor.Resize(Shape{{M, N}});
auto* data = tensor.mutable_data<float>(common::DefaultHostTarget());
for (int i = 0; i < M * N; i++) {
data[i] = (rand() * 1.f) / RAND_MAX; // NOLINT
}
}

// create Instruction
Instruction instr(common::DefaultHostTarget(), &scope, {"x", "y"}, {"z"});
auto jit = GetLoweredFunc(M, N);
auto fn_addr = jit->Lookup("fn");
CHECK(fn_addr);

instr.SetLoweredFunc(reinterpret_cast<lower_func_ptr_t>(fn_addr));
instr.Run();

// check result
{
auto xd = get_tensor("x").data<float>();
auto yd = get_tensor("y").data<float>();
auto zd = get_tensor("z").data<float>();

for (int i = 0; i < M * N; i++) {
LOG_FIRST_N(INFO, 3) << "data: " << xd[i] << " + " << yd[i] << " = " << zd[i];
ASSERT_NEAR(xd[i] + yd[i], zd[i], 1e-5);
}
}
}

} // namespace framework
} // namespace hlir
} // namespace cinn
4 changes: 3 additions & 1 deletion cinn/hlir/framework/tensor.h
@@ -51,9 +51,11 @@ class Tensor final {

template <typename T>
const T* data() const {
return buffer_->data()->memory;
return reinterpret_cast<T*>(buffer_->data()->memory);
}

cinn_buffer_t* buffer() { return buffer_->data(); }

private:
// A shared ptr to make it easier to share buffer between tensors.
std::shared_ptr<Buffer> buffer_;
