Init hlir framework (PaddlePaddle#161)
* init hlir framework

* rename host_memory to memory in cinn_buffer_t
Superjomn authored Aug 10, 2020
1 parent 2fe7162 commit 2710d85
Showing 45 changed files with 711 additions and 207 deletions.
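The change that touches most of these files is the rename of the data pointer in cinn_buffer_t from host_memory to memory (as the CUDA tests below show, the field can now hold a raw device allocation as well as a host one). A minimal sketch of the affected field follows; the surrounding members and the exact header are assumptions for illustration, not part of this commit.

// Hedged sketch of the renamed field; only the rename itself is confirmed by this diff.
typedef struct cinn_buffer_t {
  // Was `uint8_t* host_memory`. The neutral name reflects that the pointer may
  // reference either host memory or a device allocation (see the CUDA tests below).
  uint8_t* memory;
  // ... other members (device, type, dims, memory_size, ...) elided ...
} cinn_buffer_t;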
2 changes: 1 addition & 1 deletion cinn/backends/codegen_c.cc
@@ -328,7 +328,7 @@ void CodeGenC::PrintCall_buffer_get_data_handle(const ir::Call *op) {
auto *buffer = op->read_args[0].As<ir::_Buffer_>();
os() << buffer->name;
os() << "->";
os() << "host_memory";
os() << "memory";
}

void CodeGenC::PrintCall_get_address(const ir::Call *op) {
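The printer above is what stitches the renamed field into every generated function: for a buffer argument it now emits buffer->memory instead of buffer->host_memory. A small illustrative sample of the resulting C, with variable names taken from the test expectations below:

// What PrintCall_buffer_get_data_handle contributes to the generated source (illustrative):
//   _A->memory
// which the emitted function immediately casts to a typed pointer:
const float* A = ((const float*)(_A->memory));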
44 changes: 22 additions & 22 deletions cinn/backends/codegen_c_test.cc
@@ -71,9 +71,9 @@ void add1(void* _args, int32_t num_args)
const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
cinn_buffer_malloc((void*)(0), _C);
const float* A = ((const float*)(_A->host_memory));
const float* B = ((const float*)(_B->host_memory));
float* C = ((float*)(_C->host_memory));
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
float* C = ((float*)(_C->memory));
for (int32_t i = 0; i < 100; i += 1) {
for (int32_t j = 0; j < 20; j += 1) {
C[((20 * i) + j)] = (A[((20 * i) + j)] + B[((20 * i) + j)]);
@@ -166,10 +166,10 @@ void add1(void* _args, int32_t num_args)
cinn_buffer_t* _D = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[3]));
cinn_buffer_malloc((void*)(0), _C);
cinn_buffer_malloc((void*)(0), _D);
const float* A = ((const float*)(_A->host_memory));
const float* B = ((const float*)(_B->host_memory));
float* C = ((float*)(_C->host_memory));
float* D = ((float*)(_D->host_memory));
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
float* C = ((float*)(_C->memory));
float* D = ((float*)(_D->memory));
for (int32_t i_outer = 0; i_outer < 25; i_outer += 1) {
for (int32_t i_inner = 0; i_inner < 4; i_inner += 1) {
for (int32_t j = 0; j < 20; j += 1) {
@@ -251,10 +251,10 @@ void matmul(void* _args, int32_t num_args)
const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
cinn_buffer_malloc((void*)(0), _C);
const float* A = ((const float*)(_A->host_memory));
const float* B = ((const float*)(_B->host_memory));
float* C = ((float*)(_C->host_memory));
float* C_init = ((float*)(_C->host_memory));
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
float* C = ((float*)(_C->memory));
float* C_init = ((float*)(_C->memory));
for (int32_t i = 0; i < 100; i += 1) {
for (int32_t j = 0; j < 50; j += 1) {
C_init[((50 * i) + j)] = 0;
@@ -272,9 +272,9 @@ void main(void* _args, int32_t num_args)
const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
cinn_buffer_malloc((void*)(0), _C);
const float* A = ((const float*)(_A->host_memory));
const float* B = ((const float*)(_B->host_memory));
float* C = ((float*)(_C->host_memory));
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
float* C = ((float*)(_C->memory));
{
cinn_pod_value_t _pod_val__8;
buffer_p_to_cinn_pod_value(_A, &_pod_val__8);
@@ -351,10 +351,10 @@ void matmul(void* _args, int32_t num_args)
const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
cinn_buffer_malloc((void*)(0), _C);
const float* A = ((const float*)(_A->host_memory));
const float* B = ((const float*)(_B->host_memory));
float* C = ((float*)(_C->host_memory));
float* C_init = ((float*)(_C->host_memory));
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
float* C = ((float*)(_C->memory));
float* C_init = ((float*)(_C->memory));
for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (1 + ((int32_t)(cinn_min(31, (99 + (-32 * i_outer)))))); i_inner += 1) {
@@ -431,10 +431,10 @@ void matmul_with_packing(void* _args, int32_t num_args)
cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[3]));
cinn_buffer_malloc((void*)(0), _PackedB);
cinn_buffer_malloc((void*)(0), _C);
const float* A = ((const float*)(_A->host_memory));
const float* B = ((const float*)(_B->host_memory));
float* C = ((float*)(_C->host_memory));
float* PackedB = ((float*)(_PackedB->host_memory));
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
float* C = ((float*)(_C->memory));
float* PackedB = ((float*)(_PackedB->memory));
for (int32_t i = 0; i < 15; i += 1) {
for (int32_t j = 0; j < 200; j += 1) {
for (int32_t k = 0; k < 32; k += 1) {
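The expected outputs above all follow the same calling convention: a generated function takes a type-erased cinn_pod_value_t array plus a count, unpacks each entry back into a cinn_buffer_t*, and reads the data through the renamed memory field. A hedged sketch of driving such a function from the host; the header path, shapes, and the assumption that add1 has been compiled into the current binary are illustrative, not taken from this commit.

#include <cstdint>
#include <vector>
#include "cinn/runtime/cinn_runtime.h"  // assumed header for the cinn_buffer_t / pod-value API

extern "C" void add1(void* _args, int32_t num_args);  // generated function, as in the test above

void call_add1() {
  auto shape = std::vector<int>{100, 20};
  cinn_buffer_t* A = cinn_buffer_new(cinn_x86_device, cinn_float32_t(), shape);
  cinn_buffer_t* B = cinn_buffer_new(cinn_x86_device, cinn_float32_t(), shape);
  cinn_buffer_t* C = cinn_buffer_new(cinn_x86_device, cinn_float32_t(), shape);
  cinn_buffer_malloc(nullptr, A);
  cinn_buffer_malloc(nullptr, B);  // C is malloc'ed by the generated code itself

  cinn_pod_value_t a_arg, b_arg, c_arg;
  buffer_p_to_cinn_pod_value(A, &a_arg);  // same packing the generated `main` above performs
  buffer_p_to_cinn_pod_value(B, &b_arg);
  buffer_p_to_cinn_pod_value(C, &c_arg);
  cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg};

  add1(args, 3);  // inside, the data is reached via _A->memory, _B->memory, _C->memory
}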
60 changes: 30 additions & 30 deletions cinn/backends/codegen_cuda_dev_test.cc
@@ -417,9 +417,9 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) {
cinn_buffer_t* C_buf =
cinn_buffer_new(cinn_x86_device, cinn_float32_t(), std::vector<int>{{M.as_int32(), N.as_int32()}});

A_buf->host_memory = reinterpret_cast<uint8_t*>(Ad);
B_buf->host_memory = reinterpret_cast<uint8_t*>(Bd);
C_buf->host_memory = reinterpret_cast<uint8_t*>(Cd);
A_buf->memory = reinterpret_cast<uint8_t*>(Ad);
B_buf->memory = reinterpret_cast<uint8_t*>(Bd);
C_buf->memory = reinterpret_cast<uint8_t*>(Cd);

CUDA_CALL(cudaDeviceSynchronize());

@@ -714,9 +714,9 @@ TEST(elementwise_add, share_local_cache) {
cinn_buffer_t* C_buf =
cinn_buffer_new(cinn_x86_device, cinn_float32_t(), std::vector<int>{{M.as_int32(), N.as_int32()}});

A_buf->host_memory = reinterpret_cast<uint8_t*>(Ad);
B_buf->host_memory = reinterpret_cast<uint8_t*>(Bd);
C_buf->host_memory = reinterpret_cast<uint8_t*>(Cd);
A_buf->memory = reinterpret_cast<uint8_t*>(Ad);
B_buf->memory = reinterpret_cast<uint8_t*>(Bd);
C_buf->memory = reinterpret_cast<uint8_t*>(Cd);

CUDA_CALL(cudaDeviceSynchronize());

@@ -971,23 +971,23 @@ void fn0_kernel(const float* __restrict__ A, const float* __restrict__ B, float*

cinn_buffer_t* dev_bufs[3];
for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t;
dev_bufs[0]->host_memory = reinterpret_cast<uint8_t*>(A_dev);
dev_bufs[1]->host_memory = reinterpret_cast<uint8_t*>(B_dev);
dev_bufs[2]->host_memory = reinterpret_cast<uint8_t*>(C_dev);
auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build();
dev_bufs[0]->memory = reinterpret_cast<uint8_t*>(A_dev);
dev_bufs[1]->memory = reinterpret_cast<uint8_t*>(B_dev);
dev_bufs[2]->memory = reinterpret_cast<uint8_t*>(C_dev);
auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build();

CUDA_CALL(cudaDeviceSynchronize());
tester("fn0", args.data(), args.size());
CUDA_CALL(cudaDeviceSynchronize());

CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(C_target_host->host_memory),
CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(C_target_host->memory),
C_dev,
C_target_host->num_elements() * sizeof(float),
cudaMemcpyDeviceToHost));

- auto* C_target_mem = reinterpret_cast<float*>(C_target_host->host_memory);
- auto* A_mem = reinterpret_cast<float*>(A_host->host_memory);
- auto* B_mem = reinterpret_cast<float*>(B_host->host_memory);
+ auto* C_target_mem = reinterpret_cast<float*>(C_target_host->memory);
+ auto* A_mem = reinterpret_cast<float*>(A_host->memory);
+ auto* B_mem = reinterpret_cast<float*>(B_host->memory);
for (int i = 0; i < C_target_host->num_elements(); i++) {
ASSERT_NEAR(C_target_mem[i], A_mem[i] + B_mem[i], 1e-5);
}
@@ -1099,23 +1099,23 @@ void fn1_kernel(const float* __restrict__ A, const float* __restrict__ B, float*

cinn_buffer_t* dev_bufs[3];
for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t;
dev_bufs[0]->host_memory = reinterpret_cast<uint8_t*>(A_dev);
dev_bufs[1]->host_memory = reinterpret_cast<uint8_t*>(B_dev);
dev_bufs[2]->host_memory = reinterpret_cast<uint8_t*>(C_dev);
auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build();
dev_bufs[0]->memory = reinterpret_cast<uint8_t*>(A_dev);
dev_bufs[1]->memory = reinterpret_cast<uint8_t*>(B_dev);
dev_bufs[2]->memory = reinterpret_cast<uint8_t*>(C_dev);
auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build();

CUDA_CALL(cudaDeviceSynchronize());
tester("fn1", args.data(), args.size());
CUDA_CALL(cudaDeviceSynchronize());

CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(C_target_host->host_memory),
CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(C_target_host->memory),
C_dev,
C_target_host->num_elements() * sizeof(float),
cudaMemcpyDeviceToHost));

- auto* C_target_mem = reinterpret_cast<float*>(C_target_host->host_memory);
- auto* A_mem = reinterpret_cast<float*>(A_host->host_memory);
- auto* B_mem = reinterpret_cast<float*>(B_host->host_memory);
+ auto* C_target_mem = reinterpret_cast<float*>(C_target_host->memory);
+ auto* A_mem = reinterpret_cast<float*>(A_host->memory);
+ auto* B_mem = reinterpret_cast<float*>(B_host->memory);
for (int i = 0; i < M.as_int32() - 2; i++) {
for (int j = 0; j < N.as_int32(); j++) {
ASSERT_NEAR(C_target_mem[i * N.as_int32() + j],
@@ -1167,23 +1167,23 @@ void TestElementwiseAddPrecisionBasic(

cinn_buffer_t* dev_bufs[3];
for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t;
dev_bufs[0]->host_memory = reinterpret_cast<uint8_t*>(A_dev);
dev_bufs[1]->host_memory = reinterpret_cast<uint8_t*>(B_dev);
dev_bufs[2]->host_memory = reinterpret_cast<uint8_t*>(C_dev);
auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build();
dev_bufs[0]->memory = reinterpret_cast<uint8_t*>(A_dev);
dev_bufs[1]->memory = reinterpret_cast<uint8_t*>(B_dev);
dev_bufs[2]->memory = reinterpret_cast<uint8_t*>(C_dev);
auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build();

CUDA_CALL(cudaDeviceSynchronize());
tester(fn_name, args.data(), args.size());
CUDA_CALL(cudaDeviceSynchronize());

CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(C_target_host->host_memory),
CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(C_target_host->memory),
C_dev,
C_target_host->num_elements() * sizeof(float),
cudaMemcpyDeviceToHost));

- auto* C_target_mem = reinterpret_cast<float*>(C_target_host->host_memory);
- auto* A_mem = reinterpret_cast<float*>(A_host->host_memory);
- auto* B_mem = reinterpret_cast<float*>(B_host->host_memory);
+ auto* C_target_mem = reinterpret_cast<float*>(C_target_host->memory);
+ auto* A_mem = reinterpret_cast<float*>(A_host->memory);
+ auto* B_mem = reinterpret_cast<float*>(B_host->memory);
for (int i = 0; i < M.as_int32() - 2; i++) {
for (int j = 0; j < N.as_int32(); j++) {
ASSERT_NEAR(
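All three CUDA tests above share one pattern, and it is the reason the field is no longer called host_memory: a fresh cinn_buffer_t is pointed at a raw device allocation through memory, the jitted host stub is invoked through the pod-value args, and the result is copied back into a host-side buffer for checking. A condensed, hedged restatement of that pattern; the names follow the tests, and A_dev/B_dev/C_dev, C_target_host, and tester are assumed to exist as they do there.

// Device side: the renamed `memory` field carries a cudaMalloc'ed pointer, not host memory.
cinn_buffer_t* dev_bufs[3];
for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t;
dev_bufs[0]->memory = reinterpret_cast<uint8_t*>(A_dev);
dev_bufs[1]->memory = reinterpret_cast<uint8_t*>(B_dev);
dev_bufs[2]->memory = reinterpret_cast<uint8_t*>(C_dev);

// Host stub call through type-erased args, then copy the result back for verification.
auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build();
tester("fn0", args.data(), args.size());
CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(C_target_host->memory), C_dev,
                     C_target_host->num_elements() * sizeof(float), cudaMemcpyDeviceToHost));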
6 changes: 3 additions & 3 deletions cinn/backends/llvm/codegen_x86_test.cc
@@ -45,9 +45,9 @@ TEST(Vectorize, basic) {

fn_ptr(reinterpret_cast<void**>(args.data()), args.size());

- auto* A_data = reinterpret_cast<float*>(A_buf->host_memory);
- auto* B_data = reinterpret_cast<float*>(B_buf->host_memory);
- auto* C_data = reinterpret_cast<float*>(C_buf->host_memory);
+ auto* A_data = reinterpret_cast<float*>(A_buf->memory);
+ auto* B_data = reinterpret_cast<float*>(B_buf->memory);
+ auto* C_data = reinterpret_cast<float*>(C_buf->memory);
for (int i = 0; i < C_buf->num_elements(); i++) {
ASSERT_NEAR(A_data[i] + B_data[i], C_data[i], 1e-5);
}
24 changes: 12 additions & 12 deletions cinn/backends/llvm/execution_engine_test.cc
@@ -62,15 +62,15 @@ auto CreateTestBuffer() {
cinn_buffer_malloc(nullptr, A);
cinn_buffer_malloc(nullptr, B);
cinn_buffer_malloc(nullptr, C);
- float *Ad = reinterpret_cast<float *>(A->host_memory);
- float *Bd = reinterpret_cast<float *>(B->host_memory);
+ float *Ad = reinterpret_cast<float *>(A->memory);
+ float *Bd = reinterpret_cast<float *>(B->memory);

for (int i = 0; i < A->num_elements(); i++) {
Ad[i] = static_cast<float>(rand()) / RAND_MAX; // NOLINT
Bd[i] = static_cast<float>(rand()) / RAND_MAX; // NOLINT
}

- float *Cd = reinterpret_cast<float *>(C->host_memory);
+ float *Cd = reinterpret_cast<float *>(C->memory);
CHECK_EQ(C->num_elements(), A->num_elements());

return std::make_tuple(A, B, C);
@@ -119,9 +119,9 @@ TEST(llvm_test01, elementwise_add) {
cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg};
elementwise_add(args, 3);

- float *ad = reinterpret_cast<float *>(a->host_memory);
- float *bd = reinterpret_cast<float *>(b->host_memory);
- float *cd = reinterpret_cast<float *>(c->host_memory);
+ float *ad = reinterpret_cast<float *>(a->memory);
+ float *bd = reinterpret_cast<float *>(b->memory);
+ float *cd = reinterpret_cast<float *>(c->memory);

for (int i = 0; i < c->num_elements(); i++) {
EXPECT_EQ(ad[i] + bd[i], cd[i]);
@@ -180,11 +180,11 @@ TEST(llvm, module_call_lowered_func) {

elementwise_add(args, 3);

- auto *ad = reinterpret_cast<float *>(ab->host_memory);
- auto *bd = reinterpret_cast<float *>(bb->host_memory);
+ auto *ad = reinterpret_cast<float *>(ab->memory);
+ auto *bd = reinterpret_cast<float *>(bb->memory);
for (int i = 0; i < kM; i++) {
for (int j = 0; j < kN; j++) {
- auto *data = reinterpret_cast<float *>(cb->host_memory);
+ auto *data = reinterpret_cast<float *>(cb->memory);
ASSERT_NEAR(data[i * kN + j], ad[i * kN + j] + bd[i * kN + j], 1e-5);
}
}
@@ -309,9 +309,9 @@ TEST(ExecutionEngine, call_extern) {

comp(args, 3);

- auto *ad = reinterpret_cast<float *>(ab->host_memory);
- auto *bd = reinterpret_cast<float *>(bb->host_memory);
- auto *cd = reinterpret_cast<float *>(cb->host_memory);
+ auto *ad = reinterpret_cast<float *>(ab->memory);
+ auto *bd = reinterpret_cast<float *>(bb->memory);
+ auto *cd = reinterpret_cast<float *>(cb->memory);
for (int m = 0; m < kM; m++) {
for (int n = 0; n < kN; n++) {
ASSERT_NEAR(cd[m * kN + n], cinn_cpu_tanh_fp32(ad[m * kN + n] + bd[m * kN + n]), 1e-5);
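CreateTestBuffer above hands back the three freshly allocated buffers as a tuple; the tests then view each one through the renamed memory pointer. A short hedged sketch of that consumption (the structured-binding unpacking and the reference computation are illustrative; the names follow the tests):

auto [ab, bb, cb] = CreateTestBuffer();               // cinn_buffer_t* A, B, C, already malloc'ed
auto* ad = reinterpret_cast<float*>(ab->memory);      // typed views over the renamed data pointer
auto* bd = reinterpret_cast<float*>(bb->memory);
auto* cd = reinterpret_cast<float*>(cb->memory);
for (int i = 0; i < cb->num_elements(); i++) cd[i] = ad[i] + bd[i];  // host reference result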
4 changes: 2 additions & 2 deletions cinn/common/cuda_test_helper.cc
@@ -48,12 +48,12 @@ void CudaModuleTester::Compile(const lang::Module& m, const std::string& rewrite
}

void* CudaModuleTester::CreateDeviceBuffer(const cinn_buffer_t* host_buffer) {
CHECK(host_buffer->host_memory);
CHECK(host_buffer->memory);
int num_bytes = host_buffer->num_elements() * sizeof(float);
CUdeviceptr data;
cuMemAlloc(&data, num_bytes);

CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(data), host_buffer->host_memory, num_bytes, cudaMemcpyHostToDevice));
CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(data), host_buffer->memory, num_bytes, cudaMemcpyHostToDevice));
return reinterpret_cast<void*>(data);
}

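CreateDeviceBuffer mirrors a host buffer on the GPU: it checks that the renamed memory pointer is populated, allocates the same number of bytes with cuMemAlloc, and copies the host contents over. A hedged usage sketch, as the CUDA tests above consume it; tester, A_host, and the follow-up wiring are assumptions drawn from those tests:

void* A_dev = tester.CreateDeviceBuffer(A_host);        // device copy of A_host->memory
cinn_buffer_t* A_dev_buf = new cinn_buffer_t;
A_dev_buf->memory = reinterpret_cast<uint8_t*>(A_dev);  // the device pointer rides in the same field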
2 changes: 2 additions & 0 deletions cinn/common/macros.h
@@ -9,3 +9,5 @@
void operator=(const TypeName&) = delete

#define CINN_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented";

#define CINN_RESULT_SHOULD_USE __attribute__((warn_unused_result))
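CINN_RESULT_SHOULD_USE wraps GCC/Clang's warn_unused_result attribute, so a caller that silently drops an annotated function's return value gets a compile-time warning. An illustrative use; the function below is hypothetical and not part of this commit:

#include "cinn/common/macros.h"

CINN_RESULT_SHOULD_USE int AllocateThing();   // hypothetical annotated function

void Caller() {
  AllocateThing();            // warning: ignoring return value declared with warn_unused_result
  int id = AllocateThing();   // fine: the result is consumed
  (void)id;
}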
18 changes: 18 additions & 0 deletions cinn/common/target.cc
@@ -73,5 +73,23 @@ std::ostream &operator<<(std::ostream &os, const Target &target) {
return os;
}

std::ostream &operator<<(std::ostream &os, Target::Arch arch) {
switch (arch) {
case Target::Arch::Unk:
os << "Unk";
break;
case Target::Arch::X86:
os << "X86";
break;
case Target::Arch::ARM:
os << "ARM";
break;
case Target::Arch::NVGPU:
os << "NVGPU";
break;
}
return os;
}

} // namespace common
} // namespace cinn
2 changes: 2 additions & 0 deletions cinn/common/target.h
@@ -70,5 +70,7 @@ static const Target& DefaultNVGPUTarget() {
return target;
}

std::ostream& operator<<(std::ostream& os, Target::Arch arch);

} // namespace common
} // namespace cinn
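With the declaration above visible, the new overload lets an architecture value be streamed directly, for example into logs. A minimal usage sketch:

#include <iostream>
#include "cinn/common/target.h"

int main() {
  std::cout << cinn::common::Target::Arch::NVGPU << "\n";  // prints "NVGPU"
  return 0;
}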
14 changes: 7 additions & 7 deletions cinn/common/test_helper.cc
@@ -23,26 +23,26 @@ cinn_buffer_t* BufferBuilder::Build() {

switch (init_type_) {
case InitType::kZero:
memset(buffer->host_memory, 0, buffer->memory_size);
memset(buffer->memory, 0, buffer->memory_size);
break;

case InitType::kRandom:
if (type_ == type_of<float>()) {
RandomFloat<float>(buffer->host_memory, buffer->num_elements());
RandomFloat<float>(buffer->memory, buffer->num_elements());
} else if (type_ == type_of<double>()) {
RandomFloat<double>(buffer->host_memory, buffer->num_elements());
RandomFloat<double>(buffer->memory, buffer->num_elements());
} else if (type_ == type_of<int32_t>()) {
RandomInt<int32_t>(buffer->host_memory, buffer->num_elements());
RandomInt<int32_t>(buffer->memory, buffer->num_elements());
} else if (type_ == type_of<int64_t>()) {
RandomInt<int64_t>(buffer->host_memory, buffer->num_elements());
RandomInt<int64_t>(buffer->memory, buffer->num_elements());
}
break;

case InitType::kSetValue:
if (type_ == type_of<int>()) {
SetVal<int>(buffer->host_memory, buffer->num_elements(), init_val_);
SetVal<int>(buffer->memory, buffer->num_elements(), init_val_);
} else if (type_ == type_of<float>()) {
SetVal<float>(buffer->host_memory, buffer->num_elements(), init_val_);
SetVal<float>(buffer->memory, buffer->num_elements(), init_val_);
} else {
CINN_NOT_IMPLEMENTED
}
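Each branch of Build() above is just a direct write through the renamed memory pointer of an already-malloc'ed buffer. A hedged stand-alone equivalent for a float buffer, using only calls that appear elsewhere in this diff; the shape and fill value are illustrative:

cinn_buffer_t* buf = cinn_buffer_new(cinn_x86_device, cinn_float32_t(), std::vector<int>{32, 32});
cinn_buffer_malloc(nullptr, buf);
memset(buf->memory, 0, buf->memory_size);                        // InitType::kZero
auto* data = reinterpret_cast<float*>(buf->memory);
for (int i = 0; i < buf->num_elements(); i++) data[i] = 1.0f;    // InitType::kSetValue with init_val_ = 1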
1 change: 1 addition & 0 deletions cinn/hlir/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(framework)
add_subdirectory(instruction)
add_subdirectory(pe)
