From 2710d85111a1aa9570a00612c346cda7a77d9969 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Mon, 10 Aug 2020 16:13:37 +0800 Subject: [PATCH] Init hlir framework (#161) * init hlir framework * rename host_memory to memory in cinn_buffer_t --- cinn/backends/codegen_c.cc | 2 +- cinn/backends/codegen_c_test.cc | 44 ++++++------- cinn/backends/codegen_cuda_dev_test.cc | 60 +++++++++--------- cinn/backends/llvm/codegen_x86_test.cc | 6 +- cinn/backends/llvm/execution_engine_test.cc | 24 ++++---- cinn/common/cuda_test_helper.cc | 4 +- cinn/common/macros.h | 2 + cinn/common/target.cc | 18 ++++++ cinn/common/target.h | 2 + cinn/common/test_helper.cc | 14 ++--- cinn/hlir/CMakeLists.txt | 1 + cinn/hlir/framework/CMakeLists.txt | 23 +++++++ cinn/hlir/framework/buffer.cc | 46 ++++++++++++++ cinn/hlir/framework/buffer.h | 68 +++++++++++++++++++++ cinn/hlir/framework/buffer_test.cc | 47 ++++++++++++++ cinn/hlir/framework/memory.cc | 53 ++++++++++++++++ cinn/hlir/framework/memory.h | 59 ++++++++++++++++++ cinn/hlir/framework/scope.cc | 15 +++++ cinn/hlir/framework/scope.h | 45 ++++++++++++++ cinn/hlir/framework/scope_test.cc | 21 +++++++ cinn/hlir/framework/tensor.cc | 7 +++ cinn/hlir/framework/tensor.h | 65 ++++++++++++++++++++ cinn/hlir/framework/tensor_test.cc | 22 +++++++ cinn/hlir/framework/variable.cc | 7 +++ cinn/hlir/framework/variable.h | 7 +++ cinn/hlir/instruction/compiler_test.cc | 60 +++++++++--------- cinn/hlir/instruction/x86/mkl_math_test.cc | 4 +- cinn/ir/lowered_func.h | 2 +- cinn/lang/tensor_test.cc | 4 +- cinn/optim/transform_polyfor_to_for_test.cc | 6 +- cinn/optim/vectorize_loops_test.cc | 14 ++--- cinn/poly/stage_test.cc | 10 +-- cinn/pybind/runtime.cc | 8 +-- cinn/python/hlir_api_wrapper.cc | 8 +-- cinn/runtime/cinn_runtime.cc | 43 +++++++------ cinn/runtime/cinn_runtime.h | 27 ++++---- cinn/runtime/cinn_runtime_test.cc | 2 +- cinn/runtime/cinn_x86_device_impl.cc | 16 ++--- cinn/runtime/cpu/cblas.cc | 6 +- cinn/runtime/cpu/host_intrinsics.cc | 4 +- cinn/runtime/cpu/mkl_math.cc | 10 +-- cinn/runtime/cpu/mkl_math_test.cc | 4 +- cinn/runtime/cuda/cuda_util.cc | 2 +- tests/test01_elementwise_add_case.cc | 18 +++--- tests/test02_matmul_case.cc | 8 +-- 45 files changed, 711 insertions(+), 207 deletions(-) create mode 100644 cinn/hlir/framework/CMakeLists.txt create mode 100644 cinn/hlir/framework/buffer.cc create mode 100644 cinn/hlir/framework/buffer.h create mode 100644 cinn/hlir/framework/buffer_test.cc create mode 100644 cinn/hlir/framework/memory.cc create mode 100644 cinn/hlir/framework/memory.h create mode 100644 cinn/hlir/framework/scope.cc create mode 100644 cinn/hlir/framework/scope.h create mode 100644 cinn/hlir/framework/scope_test.cc create mode 100644 cinn/hlir/framework/tensor.cc create mode 100644 cinn/hlir/framework/tensor.h create mode 100644 cinn/hlir/framework/tensor_test.cc create mode 100644 cinn/hlir/framework/variable.cc create mode 100644 cinn/hlir/framework/variable.h diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc index 34c938d58161f..e03468dd1992e 100644 --- a/cinn/backends/codegen_c.cc +++ b/cinn/backends/codegen_c.cc @@ -328,7 +328,7 @@ void CodeGenC::PrintCall_buffer_get_data_handle(const ir::Call *op) { auto *buffer = op->read_args[0].As(); os() << buffer->name; os() << "->"; - os() << "host_memory"; + os() << "memory"; } void CodeGenC::PrintCall_get_address(const ir::Call *op) { diff --git a/cinn/backends/codegen_c_test.cc b/cinn/backends/codegen_c_test.cc index 83944302ee366..ac3e981e8ecb5 100644 --- 
a/cinn/backends/codegen_c_test.cc +++ b/cinn/backends/codegen_c_test.cc @@ -71,9 +71,9 @@ void add1(void* _args, int32_t num_args) const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); cinn_buffer_malloc((void*)(0), _C); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); for (int32_t i = 0; i < 100; i += 1) { for (int32_t j = 0; j < 20; j += 1) { C[((20 * i) + j)] = (A[((20 * i) + j)] + B[((20 * i) + j)]); @@ -166,10 +166,10 @@ void add1(void* _args, int32_t num_args) cinn_buffer_t* _D = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[3])); cinn_buffer_malloc((void*)(0), _C); cinn_buffer_malloc((void*)(0), _D); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); - float* D = ((float*)(_D->host_memory)); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + float* D = ((float*)(_D->memory)); for (int32_t i_outer = 0; i_outer < 25; i_outer += 1) { for (int32_t i_inner = 0; i_inner < 4; i_inner += 1) { for (int32_t j = 0; j < 20; j += 1) { @@ -251,10 +251,10 @@ void matmul(void* _args, int32_t num_args) const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); cinn_buffer_malloc((void*)(0), _C); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); - float* C_init = ((float*)(_C->host_memory)); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + float* C_init = ((float*)(_C->memory)); for (int32_t i = 0; i < 100; i += 1) { for (int32_t j = 0; j < 50; j += 1) { C_init[((50 * i) + j)] = 0; @@ -272,9 +272,9 @@ void main(void* _args, int32_t num_args) const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); cinn_buffer_malloc((void*)(0), _C); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); { cinn_pod_value_t _pod_val__8; buffer_p_to_cinn_pod_value(_A, &_pod_val__8); @@ -351,10 +351,10 @@ void matmul(void* _args, int32_t num_args) const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); cinn_buffer_malloc((void*)(0), _C); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); - float* C_init = ((float*)(_C->host_memory)); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + float* C_init = ((float*)(_C->memory)); for (int32_t i_outer 
= 0; i_outer < 4; i_outer += 1) { for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) { for (int32_t i_inner = 0; i_inner < (1 + ((int32_t)(cinn_min(31, (99 + (-32 * i_outer)))))); i_inner += 1) { @@ -431,10 +431,10 @@ void matmul_with_packing(void* _args, int32_t num_args) cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[3])); cinn_buffer_malloc((void*)(0), _PackedB); cinn_buffer_malloc((void*)(0), _C); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); - float* PackedB = ((float*)(_PackedB->host_memory)); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + float* PackedB = ((float*)(_PackedB->memory)); for (int32_t i = 0; i < 15; i += 1) { for (int32_t j = 0; j < 200; j += 1) { for (int32_t k = 0; k < 32; k += 1) { diff --git a/cinn/backends/codegen_cuda_dev_test.cc b/cinn/backends/codegen_cuda_dev_test.cc index 4f35f60bbfb8a..a28119ed2b96a 100644 --- a/cinn/backends/codegen_cuda_dev_test.cc +++ b/cinn/backends/codegen_cuda_dev_test.cc @@ -417,9 +417,9 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) { cinn_buffer_t* C_buf = cinn_buffer_new(cinn_x86_device, cinn_float32_t(), std::vector{{M.as_int32(), N.as_int32()}}); - A_buf->host_memory = reinterpret_cast(Ad); - B_buf->host_memory = reinterpret_cast(Bd); - C_buf->host_memory = reinterpret_cast(Cd); + A_buf->memory = reinterpret_cast(Ad); + B_buf->memory = reinterpret_cast(Bd); + C_buf->memory = reinterpret_cast(Cd); CUDA_CALL(cudaDeviceSynchronize()); @@ -714,9 +714,9 @@ TEST(elementwise_add, share_local_cache) { cinn_buffer_t* C_buf = cinn_buffer_new(cinn_x86_device, cinn_float32_t(), std::vector{{M.as_int32(), N.as_int32()}}); - A_buf->host_memory = reinterpret_cast(Ad); - B_buf->host_memory = reinterpret_cast(Bd); - C_buf->host_memory = reinterpret_cast(Cd); + A_buf->memory = reinterpret_cast(Ad); + B_buf->memory = reinterpret_cast(Bd); + C_buf->memory = reinterpret_cast(Cd); CUDA_CALL(cudaDeviceSynchronize()); @@ -971,23 +971,23 @@ void fn0_kernel(const float* __restrict__ A, const float* __restrict__ B, float* cinn_buffer_t* dev_bufs[3]; for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t; - dev_bufs[0]->host_memory = reinterpret_cast(A_dev); - dev_bufs[1]->host_memory = reinterpret_cast(B_dev); - dev_bufs[2]->host_memory = reinterpret_cast(C_dev); - auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build(); + dev_bufs[0]->memory = reinterpret_cast(A_dev); + dev_bufs[1]->memory = reinterpret_cast(B_dev); + dev_bufs[2]->memory = reinterpret_cast(C_dev); + auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build(); CUDA_CALL(cudaDeviceSynchronize()); tester("fn0", args.data(), args.size()); CUDA_CALL(cudaDeviceSynchronize()); - CUDA_CALL(cudaMemcpy(reinterpret_cast(C_target_host->host_memory), + CUDA_CALL(cudaMemcpy(reinterpret_cast(C_target_host->memory), C_dev, C_target_host->num_elements() * sizeof(float), cudaMemcpyDeviceToHost)); - auto* C_target_mem = reinterpret_cast(C_target_host->host_memory); - auto* A_mem = reinterpret_cast(A_host->host_memory); - auto* B_mem = reinterpret_cast(B_host->host_memory); + auto* C_target_mem = reinterpret_cast(C_target_host->memory); + auto* A_mem = reinterpret_cast(A_host->memory); + auto* B_mem = reinterpret_cast(B_host->memory); for (int i = 0; i < C_target_host->num_elements(); i++) { 
ASSERT_NEAR(C_target_mem[i], A_mem[i] + B_mem[i], 1e-5); } @@ -1099,23 +1099,23 @@ void fn1_kernel(const float* __restrict__ A, const float* __restrict__ B, float* cinn_buffer_t* dev_bufs[3]; for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t; - dev_bufs[0]->host_memory = reinterpret_cast(A_dev); - dev_bufs[1]->host_memory = reinterpret_cast(B_dev); - dev_bufs[2]->host_memory = reinterpret_cast(C_dev); - auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build(); + dev_bufs[0]->memory = reinterpret_cast(A_dev); + dev_bufs[1]->memory = reinterpret_cast(B_dev); + dev_bufs[2]->memory = reinterpret_cast(C_dev); + auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build(); CUDA_CALL(cudaDeviceSynchronize()); tester("fn1", args.data(), args.size()); CUDA_CALL(cudaDeviceSynchronize()); - CUDA_CALL(cudaMemcpy(reinterpret_cast(C_target_host->host_memory), + CUDA_CALL(cudaMemcpy(reinterpret_cast(C_target_host->memory), C_dev, C_target_host->num_elements() * sizeof(float), cudaMemcpyDeviceToHost)); - auto* C_target_mem = reinterpret_cast(C_target_host->host_memory); - auto* A_mem = reinterpret_cast(A_host->host_memory); - auto* B_mem = reinterpret_cast(B_host->host_memory); + auto* C_target_mem = reinterpret_cast(C_target_host->memory); + auto* A_mem = reinterpret_cast(A_host->memory); + auto* B_mem = reinterpret_cast(B_host->memory); for (int i = 0; i < M.as_int32() - 2; i++) { for (int j = 0; j < N.as_int32(); j++) { ASSERT_NEAR(C_target_mem[i * N.as_int32() + j], @@ -1167,23 +1167,23 @@ void TestElementwiseAddPrecisionBasic( cinn_buffer_t* dev_bufs[3]; for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t; - dev_bufs[0]->host_memory = reinterpret_cast(A_dev); - dev_bufs[1]->host_memory = reinterpret_cast(B_dev); - dev_bufs[2]->host_memory = reinterpret_cast(C_dev); - auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build(); + dev_bufs[0]->memory = reinterpret_cast(A_dev); + dev_bufs[1]->memory = reinterpret_cast(B_dev); + dev_bufs[2]->memory = reinterpret_cast(C_dev); + auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build(); CUDA_CALL(cudaDeviceSynchronize()); tester(fn_name, args.data(), args.size()); CUDA_CALL(cudaDeviceSynchronize()); - CUDA_CALL(cudaMemcpy(reinterpret_cast(C_target_host->host_memory), + CUDA_CALL(cudaMemcpy(reinterpret_cast(C_target_host->memory), C_dev, C_target_host->num_elements() * sizeof(float), cudaMemcpyDeviceToHost)); - auto* C_target_mem = reinterpret_cast(C_target_host->host_memory); - auto* A_mem = reinterpret_cast(A_host->host_memory); - auto* B_mem = reinterpret_cast(B_host->host_memory); + auto* C_target_mem = reinterpret_cast(C_target_host->memory); + auto* A_mem = reinterpret_cast(A_host->memory); + auto* B_mem = reinterpret_cast(B_host->memory); for (int i = 0; i < M.as_int32() - 2; i++) { for (int j = 0; j < N.as_int32(); j++) { ASSERT_NEAR( diff --git a/cinn/backends/llvm/codegen_x86_test.cc b/cinn/backends/llvm/codegen_x86_test.cc index e1d53282f2a43..1d049a6de89c5 100644 --- a/cinn/backends/llvm/codegen_x86_test.cc +++ b/cinn/backends/llvm/codegen_x86_test.cc @@ -45,9 +45,9 @@ TEST(Vectorize, basic) { fn_ptr(reinterpret_cast(args.data()), args.size()); - auto* A_data = reinterpret_cast(A_buf->host_memory); - auto* B_data = reinterpret_cast(B_buf->host_memory); - auto* C_data = reinterpret_cast(C_buf->host_memory); + auto* A_data = reinterpret_cast(A_buf->memory); + auto* B_data = 
reinterpret_cast(B_buf->memory); + auto* C_data = reinterpret_cast(C_buf->memory); for (int i = 0; i < C_buf->num_elements(); i++) { ASSERT_NEAR(A_data[i] + B_data[i], C_data[i], 1e-5); } diff --git a/cinn/backends/llvm/execution_engine_test.cc b/cinn/backends/llvm/execution_engine_test.cc index 0d51a5db249ee..ab2b446f0a257 100644 --- a/cinn/backends/llvm/execution_engine_test.cc +++ b/cinn/backends/llvm/execution_engine_test.cc @@ -62,15 +62,15 @@ auto CreateTestBuffer() { cinn_buffer_malloc(nullptr, A); cinn_buffer_malloc(nullptr, B); cinn_buffer_malloc(nullptr, C); - float *Ad = reinterpret_cast(A->host_memory); - float *Bd = reinterpret_cast(B->host_memory); + float *Ad = reinterpret_cast(A->memory); + float *Bd = reinterpret_cast(B->memory); for (int i = 0; i < A->num_elements(); i++) { Ad[i] = static_cast(rand()) / RAND_MAX; // NOLINT Bd[i] = static_cast(rand()) / RAND_MAX; // NOLINT } - float *Cd = reinterpret_cast(C->host_memory); + float *Cd = reinterpret_cast(C->memory); CHECK_EQ(C->num_elements(), A->num_elements()); return std::make_tuple(A, B, C); @@ -119,9 +119,9 @@ TEST(llvm_test01, elementwise_add) { cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg}; elementwise_add(args, 3); - float *ad = reinterpret_cast(a->host_memory); - float *bd = reinterpret_cast(b->host_memory); - float *cd = reinterpret_cast(c->host_memory); + float *ad = reinterpret_cast(a->memory); + float *bd = reinterpret_cast(b->memory); + float *cd = reinterpret_cast(c->memory); for (int i = 0; i < c->num_elements(); i++) { EXPECT_EQ(ad[i] + bd[i], cd[i]); @@ -180,11 +180,11 @@ TEST(llvm, module_call_lowered_func) { elementwise_add(args, 3); - auto *ad = reinterpret_cast(ab->host_memory); - auto *bd = reinterpret_cast(bb->host_memory); + auto *ad = reinterpret_cast(ab->memory); + auto *bd = reinterpret_cast(bb->memory); for (int i = 0; i < kM; i++) { for (int j = 0; j < kN; j++) { - auto *data = reinterpret_cast(cb->host_memory); + auto *data = reinterpret_cast(cb->memory); ASSERT_NEAR(data[i * kN + j], ad[i * kN + j] + bd[i * kN + j], 1e-5); } } @@ -309,9 +309,9 @@ TEST(ExecutionEngine, call_extern) { comp(args, 3); - auto *ad = reinterpret_cast(ab->host_memory); - auto *bd = reinterpret_cast(bb->host_memory); - auto *cd = reinterpret_cast(cb->host_memory); + auto *ad = reinterpret_cast(ab->memory); + auto *bd = reinterpret_cast(bb->memory); + auto *cd = reinterpret_cast(cb->memory); for (int m = 0; m < kM; m++) { for (int n = 0; n < kN; n++) { ASSERT_NEAR(cd[m * kN + n], cinn_cpu_tanh_fp32(ad[m * kN + n] + bd[m * kN + n]), 1e-5); diff --git a/cinn/common/cuda_test_helper.cc b/cinn/common/cuda_test_helper.cc index eab29b31d5bd4..8a494f4f8481b 100644 --- a/cinn/common/cuda_test_helper.cc +++ b/cinn/common/cuda_test_helper.cc @@ -48,12 +48,12 @@ void CudaModuleTester::Compile(const lang::Module& m, const std::string& rewrite } void* CudaModuleTester::CreateDeviceBuffer(const cinn_buffer_t* host_buffer) { - CHECK(host_buffer->host_memory); + CHECK(host_buffer->memory); int num_bytes = host_buffer->num_elements() * sizeof(float); CUdeviceptr data; cuMemAlloc(&data, num_bytes); - CUDA_CALL(cudaMemcpy(reinterpret_cast(data), host_buffer->host_memory, num_bytes, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(reinterpret_cast(data), host_buffer->memory, num_bytes, cudaMemcpyHostToDevice)); return reinterpret_cast(data); } diff --git a/cinn/common/macros.h b/cinn/common/macros.h index f995a27ff6bff..6d8ca74cb35a5 100644 --- a/cinn/common/macros.h +++ b/cinn/common/macros.h @@ -9,3 +9,5 @@ void operator=(const 
TypeName&) = delete #define CINN_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented"; + +#define CINN_RESULT_SHOULD_USE __attribute__((warn_unused_result)) diff --git a/cinn/common/target.cc b/cinn/common/target.cc index 92dea0be3efac..1d327f829beb2 100644 --- a/cinn/common/target.cc +++ b/cinn/common/target.cc @@ -73,5 +73,23 @@ std::ostream &operator<<(std::ostream &os, const Target &target) { return os; } +std::ostream &operator<<(std::ostream &os, Target::Arch arch) { + switch (arch) { + case Target::Arch::Unk: + os << "Unk"; + break; + case Target::Arch::X86: + os << "X86"; + break; + case Target::Arch::ARM: + os << "ARM"; + break; + case Target::Arch::NVGPU: + os << "NVGPU"; + break; + } + return os; +} + } // namespace common } // namespace cinn diff --git a/cinn/common/target.h b/cinn/common/target.h index 45a69b54cd393..3d743fc603131 100644 --- a/cinn/common/target.h +++ b/cinn/common/target.h @@ -70,5 +70,7 @@ static const Target& DefaultNVGPUTarget() { return target; } +std::ostream& operator<<(std::ostream& os, Target::Arch arch); + } // namespace common } // namespace cinn diff --git a/cinn/common/test_helper.cc b/cinn/common/test_helper.cc index e23d7a68bdf46..eb18d84d8e78b 100644 --- a/cinn/common/test_helper.cc +++ b/cinn/common/test_helper.cc @@ -23,26 +23,26 @@ cinn_buffer_t* BufferBuilder::Build() { switch (init_type_) { case InitType::kZero: - memset(buffer->host_memory, 0, buffer->memory_size); + memset(buffer->memory, 0, buffer->memory_size); break; case InitType::kRandom: if (type_ == type_of()) { - RandomFloat(buffer->host_memory, buffer->num_elements()); + RandomFloat(buffer->memory, buffer->num_elements()); } else if (type_ == type_of()) { - RandomFloat(buffer->host_memory, buffer->num_elements()); + RandomFloat(buffer->memory, buffer->num_elements()); } else if (type_ == type_of()) { - RandomInt(buffer->host_memory, buffer->num_elements()); + RandomInt(buffer->memory, buffer->num_elements()); } else if (type_ == type_of()) { - RandomInt(buffer->host_memory, buffer->num_elements()); + RandomInt(buffer->memory, buffer->num_elements()); } break; case InitType::kSetValue: if (type_ == type_of()) { - SetVal(buffer->host_memory, buffer->num_elements(), init_val_); + SetVal(buffer->memory, buffer->num_elements(), init_val_); } else if (type_ == type_of()) { - SetVal(buffer->host_memory, buffer->num_elements(), init_val_); + SetVal(buffer->memory, buffer->num_elements(), init_val_); } else { CINN_NOT_IMPLEMENTED } diff --git a/cinn/hlir/CMakeLists.txt b/cinn/hlir/CMakeLists.txt index 7390df740f4af..fcaecb301fcfd 100644 --- a/cinn/hlir/CMakeLists.txt +++ b/cinn/hlir/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(framework) add_subdirectory(instruction) add_subdirectory(pe) diff --git a/cinn/hlir/framework/CMakeLists.txt b/cinn/hlir/framework/CMakeLists.txt new file mode 100644 index 0000000000000..307e3c2207eb7 --- /dev/null +++ b/cinn/hlir/framework/CMakeLists.txt @@ -0,0 +1,23 @@ +set(srcs + tensor.cc + scope.cc + variable.cc + buffer.cc + memory.cc + ) + +if(WITH_CUDA) + nv_test(test_hlir_framework_buffer SRCS buffer_test.cc DEPS core) +else() + cc_test(test_hlir_framework_buffer SRCS buffer_test.cc DEPS core) +endif() + +cc_test(test_hlir_framework_tensor SRCS tensor_test.cc DEPS core) +cc_test(test_hlir_framework_scope SRCS scope_test.cc DEPS core) + + +foreach(cpp ${srcs}) + set(core_src + "${core_src};cinn/hlir/framework/${cpp}" + CACHE INTERNAL "") +endforeach() diff --git a/cinn/hlir/framework/buffer.cc b/cinn/hlir/framework/buffer.cc new file mode 100644 index 
0000000000000..3964db54adfa9 --- /dev/null +++ b/cinn/hlir/framework/buffer.cc @@ -0,0 +1,46 @@ +#include "cinn/hlir/framework/buffer.h" + +namespace cinn { +namespace hlir { +namespace framework { + +void Buffer::Resize(uint32_t size) { + if (size_ > 0) { + Free(); + } + + if (size_ != size) { + data_.memory = reinterpret_cast(Malloc(size)); + size_ = size; + } +} + +void Buffer::SetTarget(const common::Target& target) { + target_ = target; + memory_mng_cache_ = MemoryManager::Global().RetrieveSafely(target_.arch); +} + +void Buffer::ResizeLazy(uint32_t size) { + if (size <= size_) return; + Resize(size); +} + +void Buffer::Resize(uint32_t size, const common::Target& target) { + if (target.arch != target_.arch) { + Free(); + SetTarget(target); + } + Resize(size); +} + +void Buffer::ResizeLazy(uint32_t size, const common::Target& target) { + if (target.arch != target_.arch) { + Free(); + SetTarget(target); + } + ResizeLazy(size); +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/buffer.h b/cinn/hlir/framework/buffer.h new file mode 100644 index 0000000000000..739531f9ebb77 --- /dev/null +++ b/cinn/hlir/framework/buffer.h @@ -0,0 +1,68 @@ +#pragma once + +#include + +#include +#include + +#include "cinn/common/macros.h" +#include "cinn/common/target.h" +#include "cinn/hlir/framework/memory.h" +#include "cinn/runtime/cinn_runtime.h" + +namespace cinn { +namespace hlir { +namespace framework { + +/** + * Buffer helps to hold the memory, and offers a set of methods to help manage the memory. + */ +struct Buffer final { + Buffer() = default; + explicit Buffer(const common::Target& target) { SetTarget(target); } + + //! Resize the memory hold by this buffer *exactlly* to \p size. + void Resize(uint32_t size); + + //! Lazily resize the memory. + void ResizeLazy(uint32_t size); + + //! Resize the memory to \p size in target \p target. + void Resize(uint32_t size, const common::Target& target); + + //! Lazily resize the memory to \p size in target \p target. + void ResizeLazy(uint32_t size, const common::Target& target); + + void SetTarget(const common::Target& target); + + const cinn_buffer_t* data() const { return &data_; } + cinn_buffer_t* data() { return &data_; } + + //! Free all the memory owned by this buffer. + void Free() { + if (!data_.memory) return; + memory_mng_cache_->free(data_.memory); + } + + private: + inline void* Malloc(uint32_t size) CINN_RESULT_SHOULD_USE { + CHECK(memory_mng_cache_) << "Should set target first"; + return memory_mng_cache_->malloc(size); + } + + private: + cinn_buffer_t data_; + + //! The place where this buffer locates. + common::Target target_; + + //! Number of bytes of this buffer. + uint32_t size_{}; + + //! Hold the corresponding memory manager for speed. 
+ MemoryInterface* memory_mng_cache_{}; +}; + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/buffer_test.cc b/cinn/hlir/framework/buffer_test.cc new file mode 100644 index 0000000000000..380bcedddab3d --- /dev/null +++ b/cinn/hlir/framework/buffer_test.cc @@ -0,0 +1,47 @@ +#include "cinn/hlir/framework/buffer.h" +#ifdef CINN_WITH_CUDA +#include "cinn/backends/cuda_util.h" +#endif +#include + +#include + +namespace cinn { +namespace hlir { +namespace framework { + +TEST(Buffer, basic) { + Buffer buffer(common::DefaultHostTarget()); + buffer.Resize(10 * sizeof(float)); + auto* data = reinterpret_cast(buffer.data()); + for (int i = 0; i < 10; i++) data[i] = i; +} + +#ifdef CINN_WITH_CUDA +TEST(Buffer, nvgpu) { + const int num_elements = 10; + Buffer buffer(common::DefaultNVGPUTarget()); + buffer.Resize(num_elements * sizeof(float)); + auto* data = reinterpret_cast(buffer.data()); + std::vector host_data(num_elements); + std::vector host_target(num_elements, 0); + + for (int i = 0; i < num_elements; i++) { + host_data[i] = i; + } + LOG(INFO) << "Cuda copy data"; + CUDA_DRIVER_CALL(cuMemcpy(reinterpret_cast(data), + reinterpret_cast(host_data.data()), + num_elements * sizeof(float))); + CUDA_DRIVER_CALL(cuMemcpy(reinterpret_cast(host_target.data()), + reinterpret_cast(data), + num_elements * sizeof(float))); + for (int i = 0; i < num_elements; i++) { + ASSERT_EQ(host_target[i], i); + } +} +#endif + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/memory.cc b/cinn/hlir/framework/memory.cc new file mode 100644 index 0000000000000..2af3cf09c5e62 --- /dev/null +++ b/cinn/hlir/framework/memory.cc @@ -0,0 +1,53 @@ +#include "cinn/hlir/framework/memory.h" + +#ifdef CINN_WITH_CUDA +#include +#include + +#include "cinn/backends/cuda_util.h" +#endif + +namespace cinn { +namespace hlir { +namespace framework { + +using common::Target; + +namespace { + +class X86MemoryMng : public MemoryInterface { + public: + void* malloc(size_t nbytes) override { return ::malloc(nbytes); } + void free(void* data) override { + if (!data) return; + ::free(data); + } +}; + +#ifdef CINN_WITH_CUDA +class CudaMemoryMng : public MemoryInterface { + public: + void* malloc(size_t nbytes) override { + void* data; + CUDA_CALL(cudaMalloc(&data, nbytes)); + return data; + } + + void free(void* data) override { CUDA_CALL(cudaFree(data)); } +}; + +#endif + +} // namespace + +MemoryManager::MemoryManager() { + Register(Target::Arch::Unk, new X86MemoryMng); + Register(Target::Arch::X86, new X86MemoryMng); +#ifdef CINN_WITH_CUDA + Register(Target::Arch::NVGPU, new CudaMemoryMng); +#endif +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/memory.h b/cinn/hlir/framework/memory.h new file mode 100644 index 0000000000000..2ec9cdebe5163 --- /dev/null +++ b/cinn/hlir/framework/memory.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include "cinn/common/macros.h" +#include "cinn/common/target.h" + +namespace cinn { +namespace hlir { +namespace framework { + +class MemoryInterface { + public: + virtual void* malloc(size_t nbytes) = 0; + virtual void free(void* data) = 0; +}; + +/** + * MemoryManager holds a map of MemoryInterface for each articture. 
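(Illustrative aside, not part of the patch.) A minimal host-side sketch of how the new cinn::hlir::framework::Buffer is meant to be used, along the lines of TEST(Buffer, basic) above; the helper name BufferSketch is made up for illustration:

#include "cinn/common/target.h"
#include "cinn/hlir/framework/buffer.h"

// Bind a target, resize, then write through the wrapped cinn_buffer_t.
void BufferSketch() {
  using cinn::hlir::framework::Buffer;

  Buffer buffer(cinn::common::DefaultHostTarget());
  buffer.Resize(10 * sizeof(float));  // allocates through the cached MemoryInterface

  auto* data = reinterpret_cast<float*>(buffer.data()->memory);
  for (int i = 0; i < 10; ++i) data[i] = static_cast<float>(i);

  buffer.Free();  // hands the memory back to the target's allocator
}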
+ */ +class MemoryManager final { + public: + using key_t = common::Target::Arch; + + static MemoryManager& Global() { + static auto* x = new MemoryManager; + return *x; + } + + MemoryInterface* Retrieve(key_t key) CINN_RESULT_SHOULD_USE { + auto it = memory_mngs_.find(key); + if (it != memory_mngs_.end()) return it->second.get(); + return nullptr; + } + + MemoryInterface* RetrieveSafely(key_t key) { + auto* res = Retrieve(key); + CHECK(res) << "no MemoryInterface for architecture [" << key << "]"; + return res; + } + + MemoryInterface* Register(key_t key, MemoryInterface* item) { + CHECK(!memory_mngs_.count(key)) << "Duplicate register [" << key << "]"; + memory_mngs_[key].reset(item); + return item; + } + + private: + MemoryManager(); + + std::unordered_map> memory_mngs_; + + CINN_DISALLOW_COPY_AND_ASSIGN(MemoryManager); +}; + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/scope.cc b/cinn/hlir/framework/scope.cc new file mode 100644 index 0000000000000..c3cdd3684a28d --- /dev/null +++ b/cinn/hlir/framework/scope.cc @@ -0,0 +1,15 @@ +#include "cinn/hlir/framework/scope.h" + +namespace cinn { +namespace hlir { +namespace framework { + +Variable* Scope::FindVar(const std::string& name) { + auto it = dic.find(name); + if (it != dic.end()) return it->second.get(); + return nullptr; +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/scope.h b/cinn/hlir/framework/scope.h new file mode 100644 index 0000000000000..21e4c965672ca --- /dev/null +++ b/cinn/hlir/framework/scope.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include +#include +#include +#include + +#include "cinn/common/macros.h" +#include "cinn/hlir/framework/tensor.h" + +namespace cinn { +namespace hlir { +namespace framework { + +using Variable = std::variant; + +class Scope { + public: + Scope() = default; + + //! Get or create a variable. + template + Variable* Var(const std::string& name); + + //! Find a variable, get null if not exists. 
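(Illustrative aside.) A sketch of the MemoryManager registry introduced above: allocators are keyed by Target::Arch, Register takes ownership of the allocator, and RetrieveSafely CHECK-fails when nothing is registered for an architecture. The ArmMemoryMng below is hypothetical; the patch itself only registers X86MemoryMng and, under CUDA, CudaMemoryMng:

#include <cstdlib>

#include "cinn/common/target.h"
#include "cinn/hlir/framework/memory.h"

using cinn::common::Target;
using cinn::hlir::framework::MemoryInterface;
using cinn::hlir::framework::MemoryManager;

// Hypothetical allocator, used only to show the Register/Retrieve flow.
class ArmMemoryMng : public MemoryInterface {
 public:
  void* malloc(size_t nbytes) override { return ::malloc(nbytes); }
  void free(void* data) override {
    if (data) ::free(data);
  }
};

void MemoryManagerSketch() {
  // ARM has no allocator registered by the MemoryManager constructor, so add one first.
  MemoryManager::Global().Register(Target::Arch::ARM, new ArmMemoryMng);

  MemoryInterface* mng = MemoryManager::Global().RetrieveSafely(Target::Arch::ARM);
  void* p = mng->malloc(256);
  mng->free(p);
}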
+ Variable* FindVar(const std::string& name); + + private: + std::unordered_map> dic; + + CINN_DISALLOW_COPY_AND_ASSIGN(Scope); +}; + +template +Variable* Scope::Var(const std::string& name) { + Variable* x = FindVar(name); + if (x) return x; + auto* data = new Variable(T()); + dic[name].reset(data); + return data; +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/scope_test.cc b/cinn/hlir/framework/scope_test.cc new file mode 100644 index 0000000000000..fa289bd6dd515 --- /dev/null +++ b/cinn/hlir/framework/scope_test.cc @@ -0,0 +1,21 @@ +#include "cinn/hlir/framework/scope.h" +#include + +namespace cinn { +namespace hlir { +namespace framework { + +TEST(Scope, basic) { + Scope scope; + auto* var = scope.Var("key"); + auto& tensor = std::get(*var); + tensor.Resize(Shape{{3, 1}}); + auto* data = tensor.mutable_data(common::DefaultHostTarget()); + data[0] = 0.f; + data[1] = 1.f; + data[2] = 2.f; +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/tensor.cc b/cinn/hlir/framework/tensor.cc new file mode 100644 index 0000000000000..98a97977c3a7e --- /dev/null +++ b/cinn/hlir/framework/tensor.cc @@ -0,0 +1,7 @@ +#include "cinn/hlir/framework/tensor.h" + +namespace cinn { +namespace hlir { +namespace framework {} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/tensor.h b/cinn/hlir/framework/tensor.h new file mode 100644 index 0000000000000..9272227176938 --- /dev/null +++ b/cinn/hlir/framework/tensor.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include +#include + +#include "cinn/common/macros.h" +#include "cinn/hlir/framework/buffer.h" + +namespace cinn { +namespace hlir { +namespace framework { +using common::Target; + +struct Shape { + using dim_t = uint32_t; + + Shape() = default; + explicit Shape(const std::vector& data) : data_(data) {} + + void SetData(const std::vector& data) { data_ = data; } + + const std::vector& data() const CINN_RESULT_SHOULD_USE { return data_; } + std::vector& data() CINN_RESULT_SHOULD_USE { return data_; } + size_t size() const CINN_RESULT_SHOULD_USE { return data_.size(); } + uint32_t numel() const CINN_RESULT_SHOULD_USE { + return std::accumulate(data_.begin(), data_.end(), 1, [](dim_t a, dim_t b) { return a * b; }); + } + + private: + std::vector data_; +}; + +class Tensor final { + public: + Tensor() : buffer_(std::make_shared()) {} + + const Shape& shape() const { return shape_; } + + void Resize(const Shape& shape) { + shape_ = shape; + buffer_->data()->resize(reinterpret_cast(shape.data().data()), shape.size()); + } + + template + inline T* mutable_data(const Target& target) { + buffer_->ResizeLazy(shape_.numel() * sizeof(T), target); + return reinterpret_cast(buffer_->data()->memory); + } + + template + const T* data() const { + return buffer_->data()->memory; + } + + private: + // A shared ptr to make it easier to share buffer between tensors. 
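(Illustrative aside.) Scope, Tensor and Buffer compose as in scope_test.cc above; a consolidated usage sketch, where the variable name "fc_out" is arbitrary:

#include <cstdint>
#include <variant>

#include "cinn/common/target.h"
#include "cinn/hlir/framework/scope.h"

void ScopeTensorSketch() {
  using namespace cinn::hlir::framework;  // Scope, Tensor, Shape, Variable

  Scope scope;
  // Var<Tensor>() creates the variable on first use and returns the existing one afterwards.
  Variable* var = scope.Var<Tensor>("fc_out");
  Tensor& tensor = std::get<Tensor>(*var);  // Variable is a std::variant

  // Resize only records the shape; memory is allocated lazily by mutable_data.
  tensor.Resize(Shape{{4, 8}});
  float* data = tensor.mutable_data<float>(cinn::common::DefaultHostTarget());
  for (uint32_t i = 0; i < tensor.shape().numel(); ++i) data[i] = 0.f;
}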
+ std::shared_ptr buffer_; + Shape shape_; +}; + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/tensor_test.cc b/cinn/hlir/framework/tensor_test.cc new file mode 100644 index 0000000000000..ee60413c2a2e6 --- /dev/null +++ b/cinn/hlir/framework/tensor_test.cc @@ -0,0 +1,22 @@ +#include "cinn/hlir/framework/tensor.h" + +#include + +namespace cinn { +namespace hlir { +namespace framework { + +TEST(Tensor, basic) { + Tensor tensor; + tensor.Resize(Shape{{3, 2}}); + + auto* data = tensor.mutable_data(common::DefaultHostTarget()); + + for (int i = 0; i < tensor.shape().numel(); i++) { + data[i] = i; + } +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/variable.cc b/cinn/hlir/framework/variable.cc new file mode 100644 index 0000000000000..bc8a6e6fd4245 --- /dev/null +++ b/cinn/hlir/framework/variable.cc @@ -0,0 +1,7 @@ +#include "cinn/hlir/framework/variable.h" + +namespace cinn { +namespace hlir { +namespace framework {} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/variable.h b/cinn/hlir/framework/variable.h new file mode 100644 index 0000000000000..ee21c6f1f994e --- /dev/null +++ b/cinn/hlir/framework/variable.h @@ -0,0 +1,7 @@ +#pragma once + +namespace cinn { +namespace hlir { +namespace framework {} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/instruction/compiler_test.cc b/cinn/hlir/instruction/compiler_test.cc index 6a6e1f2f1362b..44e252c8fbf51 100644 --- a/cinn/hlir/instruction/compiler_test.cc +++ b/cinn/hlir/instruction/compiler_test.cc @@ -26,15 +26,15 @@ auto CreateTestBuffer(int kM, int kN) { cinn_buffer_malloc(nullptr, A); cinn_buffer_malloc(nullptr, B); cinn_buffer_malloc(nullptr, C); - float* Ad = reinterpret_cast(A->host_memory); - float* Bd = reinterpret_cast(B->host_memory); + float* Ad = reinterpret_cast(A->memory); + float* Bd = reinterpret_cast(B->memory); for (int i = 0; i < A->num_elements(); i++) { Ad[i] = i; Bd[i] = i; } - float* Cd = reinterpret_cast(C->host_memory); + float* Cd = reinterpret_cast(C->memory); CHECK_EQ(C->num_elements(), A->num_elements()); return std::make_tuple(A, B, C); @@ -89,15 +89,15 @@ TEST(Compiler, call_kernel_directly) { compiler.Eval(module.get(), args, 3, "elementwise_add0"); - const float* c_data = reinterpret_cast(c->host_memory); + const float* c_data = reinterpret_cast(c->memory); for (int i = 0; i < c->num_elements(); i++) { ASSERT_EQ(c_data[i], i * 2); } - delete a->host_memory; - delete b->host_memory; - delete c->host_memory; + delete a->memory; + delete b->memory; + delete c->memory; } TEST(Compiler, call_main) { @@ -107,24 +107,24 @@ TEST(Compiler, call_main) { auto [a, b, c] = CreateTestBuffer(100, 200); // NOLINT - cinn_print_debug_string("a.host_memory: %p", a->host_memory); - cinn_print_debug_string("b.host_memory: %p", b->host_memory); - cinn_print_debug_string("c.host_memory: %p", c->host_memory); + cinn_print_debug_string("a.memory: %p", a->memory); + cinn_print_debug_string("b.memory: %p", b->memory); + cinn_print_debug_string("c.memory: %p", c->memory); cinn_pod_value_t a_arg(a), b_arg(b), c_arg(c); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; compiler.Eval(module.get(), args, 3, ""); - const float* c_data = reinterpret_cast(c->host_memory); + const float* c_data = reinterpret_cast(c->memory); for (int i = 0; i < c->num_elements(); i++) { ASSERT_EQ(c_data[i], i * 2); } - delete a->host_memory; - delete 
b->host_memory; - delete c->host_memory; + delete a->memory; + delete b->memory; + delete c->memory; } TEST(Compiler, call_main1) { @@ -134,16 +134,16 @@ TEST(Compiler, call_main1) { auto [a, b, c] = CreateTestBuffer(100, 200); // NOLINT - cinn_print_debug_string("a.host_memory: %p", a->host_memory); - cinn_print_debug_string("b.host_memory: %p", b->host_memory); - cinn_print_debug_string("c.host_memory: %p", c->host_memory); + cinn_print_debug_string("a.memory: %p", a->memory); + cinn_print_debug_string("b.memory: %p", b->memory); + cinn_print_debug_string("c.memory: %p", c->memory); cinn_pod_value_t a_arg(a), b_arg(b), c_arg(c); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; compiler.Eval(module.get(), args, 3, ""); - const float* c_data = reinterpret_cast(c->host_memory); + const float* c_data = reinterpret_cast(c->memory); for (int i = 0; i < c->num_elements(); i++) { ASSERT_EQ(c_data[i], i * 2); @@ -225,10 +225,10 @@ TEST(Compiler, call_main_dense_model) { const int batch_size_runtime = 20; auto handcraft_compu = [&] { - auto* xb_data = reinterpret_cast(Xb->host_memory); - auto* Wb_data = reinterpret_cast(Wb->host_memory); - auto* Biasb_data = reinterpret_cast(Biasb->host_memory); - auto* Outb_data = reinterpret_cast(Outb_target->host_memory); + auto* xb_data = reinterpret_cast(Xb->memory); + auto* Wb_data = reinterpret_cast(Wb->memory); + auto* Biasb_data = reinterpret_cast(Biasb->memory); + auto* Outb_data = reinterpret_cast(Outb_target->memory); for (int b = 0; b < batch_size_runtime; b++) { for (int m = 0; m < M; m++) { @@ -259,13 +259,13 @@ TEST(Compiler, call_main_dense_model) { auto randomize_buffer = [](cinn_buffer_t* buffer) { cinn_buffer_malloc(nullptr, buffer); - auto* data = reinterpret_cast(buffer->host_memory); + auto* data = reinterpret_cast(buffer->memory); // for (int i = 0; i < buffer->num_elements(); i++) data[i] = static_cast(rand()) / RAND_MAX; for (int i = 0; i < buffer->num_elements(); i++) data[i] = 1; }; auto initialize_buffer = [](cinn_buffer_t* buffer) { cinn_buffer_malloc(nullptr, buffer); - auto* data = reinterpret_cast(buffer->host_memory); + auto* data = reinterpret_cast(buffer->memory); memset(data, 2, buffer->num_elements() * sizeof(float)); }; @@ -289,8 +289,8 @@ TEST(Compiler, call_main_dense_model) { { // check result handcraft_compu(); - auto* out_data = reinterpret_cast(Outb->host_memory); - auto* out_target_data = reinterpret_cast(Outb_target->host_memory); + auto* out_data = reinterpret_cast(Outb->memory); + auto* out_target_data = reinterpret_cast(Outb_target->memory); for (int b = 0; b < batch_size_runtime; b++) { for (int m = 0; m < M; m++) { @@ -380,8 +380,8 @@ void TestElementwise() { auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); compiler.Eval("tanh0", args.data(), 2); - auto* x_data = reinterpret_cast(x_buf->host_memory); - auto* out_data = reinterpret_cast(out_buf->host_memory); + auto* x_data = reinterpret_cast(x_buf->memory); + auto* out_data = reinterpret_cast(out_buf->memory); for (int i = 0; i < out_buf->num_elements(); i++) { ASSERT_NEAR(fp(x_data[i]), out_data[i], 1e-5); @@ -432,8 +432,8 @@ TEST(Compiler, dot_cgemm) { compiler.Eval(fn_name, args.data(), args.size()); - auto* out_data = reinterpret_cast(out_buf->host_memory); - auto* out_data1 = reinterpret_cast(out1_buf->host_memory); + auto* out_data = reinterpret_cast(out_buf->memory); + auto* out_data1 = reinterpret_cast(out1_buf->memory); for (int i = 0; i < out_buf->num_elements(); i++) { if (i < 4) { LOG(INFO) << "Dot result: " << out_data[i]; diff 
--git a/cinn/hlir/instruction/x86/mkl_math_test.cc b/cinn/hlir/instruction/x86/mkl_math_test.cc index 6595ed5707569..b464bd01fad20 100644 --- a/cinn/hlir/instruction/x86/mkl_math_test.cc +++ b/cinn/hlir/instruction/x86/mkl_math_test.cc @@ -61,8 +61,8 @@ void TestCallElementwise(const std::string &fn_name, float (*fn_runtime)(float), cinn_pod_value_t args[] = {a_arg, b_arg}; fn_(args, 2); - auto *ad = reinterpret_cast(A_buf->host_memory); - auto *bd = reinterpret_cast(B_buf->host_memory); + auto *ad = reinterpret_cast(A_buf->memory); + auto *bd = reinterpret_cast(B_buf->memory); for (int i = 0; i < A_buf->num_elements(); i++) { ASSERT_NEAR(bd[i], fn_runtime(ad[i]), 1e-5); } diff --git a/cinn/ir/lowered_func.h b/cinn/ir/lowered_func.h index af5ba51f38ad6..349123a82cf81 100644 --- a/cinn/ir/lowered_func.h +++ b/cinn/ir/lowered_func.h @@ -137,7 +137,7 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { std::vector dealloc_output_buffer_exprs; // @} - //! something like: float* A_data = (float*)(A->host_memory); + //! something like: float* A_data = (float*)(A->memory); std::vector buffer_data_cast_exprs; std::vector argument_prepare_exprs; diff --git a/cinn/lang/tensor_test.cc b/cinn/lang/tensor_test.cc index 070757e2be4d5..9b08e4ec859e7 100644 --- a/cinn/lang/tensor_test.cc +++ b/cinn/lang/tensor_test.cc @@ -84,8 +84,8 @@ TEST(Tensor, Collapse) { fn(args.data(), args.size()); // check result - auto* A_data = reinterpret_cast(A_buf->host_memory); - auto* C_data = reinterpret_cast(C_buf->host_memory); + auto* A_data = reinterpret_cast(A_buf->memory); + auto* C_data = reinterpret_cast(C_buf->memory); for (int i = 0; i < A_buf->num_elements(); i++) { ASSERT_NEAR(C_data[i], A_data[i] * A_data[i] + 1.f, 1e-5); } diff --git a/cinn/optim/transform_polyfor_to_for_test.cc b/cinn/optim/transform_polyfor_to_for_test.cc index 55e5495e0343c..a021325f7f030 100644 --- a/cinn/optim/transform_polyfor_to_for_test.cc +++ b/cinn/optim/transform_polyfor_to_for_test.cc @@ -67,9 +67,9 @@ void matmul(void* _args, int32_t num_args) const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); cinn_buffer_malloc((void*)(0), _C); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); for (int32_t i_outer = 0; i_outer < 64; i_outer += 1) { for (int32_t i_inner = 0; i_inner < 8; i_inner += 1) { for (int32_t j_outer = 0; j_outer < 63; j_outer += 1) { diff --git a/cinn/optim/vectorize_loops_test.cc b/cinn/optim/vectorize_loops_test.cc index d732014009f76..4dc15e05b13fa 100644 --- a/cinn/optim/vectorize_loops_test.cc +++ b/cinn/optim/vectorize_loops_test.cc @@ -61,9 +61,9 @@ void matmul(void* _args, int32_t num_args) const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); cinn_buffer_malloc((void*)(0), _C); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); for (int32_t i = 0; i < 100; i += 1) { for (int32_t j = 0; j < 31; j += 1) { 
C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j)), 1, 16)] = (StackedVec::Load(A,((500 * i) + (16 * j))) * StackedVec::Load(B,((500 * i) + (16 * j)))); @@ -129,10 +129,10 @@ TEST(Vectorize, TestMarkVectorize) { void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, struct cinn_buffer_t *_C) { cinn_buffer_malloc((void*)(0), _C); - const float* A = (const float*)(_A->host_memory); - const float* B = (const float*)(_B->host_memory); - float* C = (float*)(_C->host_memory); - float* D = (float*)(_C->host_memory); + const float* A = (const float*)(_A->memory); + const float* B = (const float*)(_B->memory); + float* C = (float*)(_C->memory); + float* D = (float*)(_C->memory); for (int32_t i = 0; i < 100; i += 1) { for (int32_t j_outer = 0; j_outer < 31; j_outer += 1) { C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j_outer)), 1, 16)] = (StackedVec::Load(A,((500 * i) + (16 * j_outer))) * StackedVec::Load(B,((500 * i) + (16 * j_outer)))); diff --git a/cinn/poly/stage_test.cc b/cinn/poly/stage_test.cc index 9e7bb7888144a..237e93185db86 100644 --- a/cinn/poly/stage_test.cc +++ b/cinn/poly/stage_test.cc @@ -209,8 +209,8 @@ function fn (_A, _cache, _C) fn_handler(arg_pack.data(), arg_pack.size()); - auto* C_data = reinterpret_cast(C_buf->host_memory); - auto* A_data = reinterpret_cast(A_buf->host_memory); + auto* C_data = reinterpret_cast(C_buf->memory); + auto* A_data = reinterpret_cast(A_buf->memory); for (int k = 0; k < 10; k++) { for (int i = 0; i < 10; i++) { @@ -451,9 +451,9 @@ void TestElementwiseAddJitPrecession(std::function&& schedule fn_handler(arg_pack.data(), arg_pack.size()); - auto* A_data = reinterpret_cast(A_buf->host_memory); - auto* B_data = reinterpret_cast(B_buf->host_memory); - auto* C_data = reinterpret_cast(C_buf->host_memory); + auto* A_data = reinterpret_cast(A_buf->memory); + auto* B_data = reinterpret_cast(B_buf->memory); + auto* C_data = reinterpret_cast(C_buf->memory); for (int i = 0; i < A_buf->num_elements(); i++) { if (i < 4) LOG(INFO) << C_data[i]; ASSERT_NEAR(A_data[i] + B_data[i], C_data[i], 1e-5); diff --git a/cinn/pybind/runtime.cc b/cinn/pybind/runtime.cc index 5104b65b34b89..f46bf36590947 100644 --- a/cinn/pybind/runtime.cc +++ b/cinn/pybind/runtime.cc @@ -39,7 +39,7 @@ cinn_buffer_t *CreateBufferFromNumpy(py::array data, cinn_device_kind_t device, std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape)); auto *buffer = cinn_buffer_t::new_(device, type, shape, align); cinn_buffer_malloc(nullptr, buffer); - std::memcpy(buffer->host_memory, data.data(), data.nbytes()); + std::memcpy(buffer->memory, data.data(), data.nbytes()); return buffer; } @@ -68,7 +68,7 @@ py::array BufferHostMemoryToNumpy(cinn_buffer_t &buffer) { // NOLINT cinn_buffer_copy_to_host(nullptr, &buffer); switch (buffer.device) { case cinn_x86_device: - std::memcpy(mutable_data, buffer.host_memory, buffer.memory_size); + std::memcpy(mutable_data, buffer.memory, buffer.memory_size); break; } @@ -153,11 +153,11 @@ void BindCinnRuntime(py::module *m) { py::class_ cinn_buffer(*m, "cinn_buffer_t"); cinn_buffer.def_readwrite("device", &cinn_buffer_t::device) .def_readwrite("device_interface", &cinn_buffer_t::device_interface) - .def_readwrite("host_memory", &cinn_buffer_t::host_memory) + .def_readwrite("memory", &cinn_buffer_t::memory) .def_readwrite("flag", &cinn_buffer_t::flag) .def_readwrite("type", &cinn_buffer_t::type) .def_readwrite("dimensions", &cinn_buffer_t::dimensions) - .def_readwrite("dims", &cinn_buffer_t::dims) + //.def_readwrite("dims", 
&cinn_buffer_t::dims) .def_readwrite("lazy", &cinn_buffer_t::lazy) .def_readwrite("memory_size", &cinn_buffer_t::memory_size) .def_readwrite("align", &cinn_buffer_t::align) diff --git a/cinn/python/hlir_api_wrapper.cc b/cinn/python/hlir_api_wrapper.cc index b406fef021b70..5fcc1404fcb63 100644 --- a/cinn/python/hlir_api_wrapper.cc +++ b/cinn/python/hlir_api_wrapper.cc @@ -81,9 +81,9 @@ std::shared_ptr py_buffer::from_numpy(pybind11::array array) { auto buffer = std::make_shared(shape, type, "x86", 32); cinn_buffer_malloc(nullptr, buffer->data_); - CHECK(buffer->data_->host_memory); + CHECK(buffer->data_->memory); - std::memcpy(static_cast(buffer->data_->host_memory), array.mutable_data(), buffer->data_->memory_size); + std::memcpy(static_cast(buffer->data_->memory), array.mutable_data(), buffer->data_->memory_size); return buffer; } @@ -103,8 +103,8 @@ pybind11::array py_buffer::numpy() { pybind11::array::ShapeContainer shape(data_->dims, data_->dims + data_->dimensions); pybind11::array py_data(t, std::move(shape)); - CHECK(data_->host_memory); - std::memcpy(py_data.mutable_data(), data_->host_memory, data_->memory_size); + CHECK(data_->memory); + std::memcpy(py_data.mutable_data(), data_->memory, data_->memory_size); return py_data; } diff --git a/cinn/runtime/cinn_runtime.cc b/cinn/runtime/cinn_runtime.cc index 7d33cba7a6ea1..e011f76c1b8f8 100644 --- a/cinn/runtime/cinn_runtime.cc +++ b/cinn/runtime/cinn_runtime.cc @@ -26,7 +26,7 @@ void* cinn_buffer_slice(struct cinn_buffer_t* buf, uint32_t offset) { CINN_CHECK(buf); uint64_t offset_byte = offset * buf->type.bytes(); CINN_CHECK_LT(offset_byte, buf->memory_size); - return buf->host_memory + offset_byte; + return buf->memory + offset_byte; } int cinn_device_sync(void* context, struct cinn_buffer_t* buf) { @@ -65,12 +65,12 @@ int cinn_buffer_copy(void* context, struct cinn_buffer_t* src, struct cinn_buffe void* cinn_buffer_get_data_handle(struct cinn_buffer_t* buf) { CINN_CHECKP(buf, "%s", "buffer is null"); - return buf->host_memory; + return buf->memory; } void* cinn_buffer_get_data_const_handle(const struct cinn_buffer_t* buf) { CINN_CHECKP(buf, "%s", "buffer is null"); - return buf->host_memory; + return buf->memory; } cinn_type_t cinn_unk_t() { return cinn_type_t(cinn_type_unk, 0); } @@ -87,20 +87,20 @@ struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, cinn_type_t type, const std::vector& shape, int align) { - int32_t dimensions = shape.size(); - cinn_dimension_t* dims = (cinn_dimension_t*)malloc(sizeof(cinn_dimension_t) * dimensions); // NOLINT - memcpy(dims, shape.data(), shape.size() * sizeof(int)); - - struct cinn_buffer_t* x = (struct cinn_buffer_t*)malloc(sizeof(struct cinn_buffer_t)); - x->type = type; - x->device = device; - x->host_memory = nullptr; - x->memory_size = 0; - x->lazy = true; + int32_t dimensions = shape.size(); + CINN_CHECK(shape.size() < CINN_BUFFER_MAX_DIMS); + + struct cinn_buffer_t* buf = (struct cinn_buffer_t*)malloc(sizeof(struct cinn_buffer_t)); + memcpy(&(buf->dims[0]), shape.data(), shape.size() * sizeof(int)); + buf->type = type; + buf->device = device; + buf->memory = nullptr; + buf->memory_size = 0; + buf->lazy = true; // NOTE set device_interface for each buffer. 
- switch (x->device) { + switch (buf->device) { case cinn_x86_device: - x->device_interface = cinn_x86_device_interface(); + buf->device_interface = cinn_x86_device_interface(); break; case cinn_unk_device: fprintf(stderr, "Device type of buffer should be set, found Unk"); @@ -111,10 +111,9 @@ struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, abort(); } - x->dims = dims; - x->dimensions = dimensions; - x->align = align; - return x; + buf->dimensions = dimensions; + buf->align = align; + return buf; } cinn_buffer_t* cinn_buffer_new(cinn_device_kind_t device, cinn_type_t type, const std::vector& shape, int align) { @@ -213,10 +212,10 @@ void debug_pod_value(cinn_pod_value_t v, int i) { switch (v.type_code()) { case cinn_pod_value_t::type_code(): { cinn_buffer_t* node = v; - if (node->host_memory) { - cinn_print_debug_string("arg[%d].host_memory: %p\n", i, node->host_memory); + if (node->memory) { + cinn_print_debug_string("arg[%d].memory: %p\n", i, node->memory); } else { - cinn_print_debug_string("arg[%d].host_memory: %p\n", i, NULL); + cinn_print_debug_string("arg[%d].memory: %p\n", i, NULL); } } break; case cinn_pod_value_t::type_code(): { diff --git a/cinn/runtime/cinn_runtime.h b/cinn/runtime/cinn_runtime.h index 9064cdc94cd26..5b3a2120484df 100644 --- a/cinn/runtime/cinn_runtime.h +++ b/cinn/runtime/cinn_runtime.h @@ -1,3 +1,7 @@ +/** + * This file contains some core runtime concepts, the basic definition is used in C so that it can be deployed in some + * light-weight devices. + */ #ifndef CINN_RUNTIME_CINN_RUNTIME_H_ #define CINN_RUNTIME_CINN_RUNTIME_H_ #ifdef __cplusplus @@ -138,6 +142,7 @@ extern void* cinn_buffer_get_data_handle(struct cinn_buffer_t* buf); extern void* cinn_buffer_get_data_const_handle(const struct cinn_buffer_t* buf); //! The raw representation of a buffer,used in the generated code/lib. +#define CINN_BUFFER_MAX_DIMS 8 typedef struct cinn_buffer_t { //! Tell which kind of device this buffer locates. cinn_device_kind_t device; @@ -146,7 +151,7 @@ typedef struct cinn_buffer_t { const struct cinn_device_interface_t* device_interface; //! A pointer to the memory in host. - uint8_t* host_memory; + uint8_t* memory; //! Extra flags. uint64_t flag; @@ -156,7 +161,7 @@ typedef struct cinn_buffer_t { //! Number of dimensions. int32_t dimensions; - cinn_dimension_t* dims; + cinn_dimension_t dims[CINN_BUFFER_MAX_DIMS]; //! Allocate and deallocate lazily, default true. char lazy; @@ -170,11 +175,10 @@ typedef struct cinn_buffer_t { cinn_buffer_t() : device(cinn_unk_device), device_interface(NULL), - host_memory(NULL), + memory(NULL), flag(0UL), type(cinn_type_t()), dimensions(0), - dims(NULL), memory_size(0), align(0), lazy(true) {} @@ -185,20 +189,13 @@ typedef struct cinn_buffer_t { int align = 0); static void delete_(struct cinn_buffer_t* x) { delete x; } - ~cinn_buffer_t() { - delete host_memory; - delete dims; - } + ~cinn_buffer_t() {} // NOTE the buffer should be resized first. static void alloc(struct cinn_buffer_t*); //! Set the shape of the buffer. NOTE this just record the shape, not allocate the memory. 
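(Illustrative aside.) The net effect of this change on generated host code and tests is captured by a sketch like the one below: buffers are still created with cinn_buffer_new and cinn_buffer_malloc, but the data pointer is now buf->memory and the shape is copied into the inline dims array; explicit freeing is omitted here, as it is in the tests above:

#include <vector>

#include "cinn/runtime/cinn_runtime.h"

void RuntimeBufferSketch() {
  // The shape is memcpy'd into the fixed-size dims array; its size must stay
  // below CINN_BUFFER_MAX_DIMS, as checked in cinn_buffer_t::new_() above.
  cinn_buffer_t* buf = cinn_buffer_new(cinn_x86_device, cinn_float32_t(), std::vector<int>{{32, 16}});

  cinn_buffer_malloc(nullptr, buf);  // fills in buf->memory via the x86 device interface

  // Generated kernels read the renamed field, e.g. ((float*)(buf->memory)).
  float* data = reinterpret_cast<float*>(buf->memory);
  data[0] = 1.f;
}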
CINN_ALWAYS_INLINE void resize(const cinn_dimension_t* dims, int dimensions) { - if (this->dimensions != dimensions) { - if (this->dims) free(this->dims); - this->dims = (cinn_dimension_t*)malloc(dimensions * sizeof(cinn_dimension_t)); // NOLINT - } this->dimensions = dimensions; memcpy(this->dims, dims, dimensions * sizeof(cinn_dimension_t)); } @@ -224,7 +221,7 @@ typedef struct cinn_buffer_t { } CINN_ALWAYS_INLINE uint8_t* begin() const { return 0; } - CINN_ALWAYS_INLINE uint8_t* end() const { return host_memory + num_elements() * type.bytes(); } + CINN_ALWAYS_INLINE uint8_t* end() const { return memory + num_elements() * type.bytes(); } CINN_ALWAYS_INLINE bool get_flag(cinn_buffer_kind_t flag) const { return (this->flag & flag) != 0; } CINN_ALWAYS_INLINE void set_flag(cinn_buffer_kind_t flag, bool value) { @@ -258,10 +255,10 @@ struct cinn_device_interface_impl_t { extern struct cinn_device_interface_t* cinn_x86_device_interface(); inline float cinn_buffer_load_float32(struct cinn_buffer_t* buf, uint32_t index) { - return ((float*)buf->host_memory)[index]; // NOLINT + return ((float*)buf->memory)[index]; // NOLINT } inline double cinn_buffer_load_float64(struct cinn_buffer_t* buf, uint32_t index) { - return ((double*)buf->host_memory)[index]; // NOLINT + return ((double*)buf->memory)[index]; // NOLINT } #endif // __cplusplus diff --git a/cinn/runtime/cinn_runtime_test.cc b/cinn/runtime/cinn_runtime_test.cc index 2934fc499b3ce..97a661de91d64 100644 --- a/cinn/runtime/cinn_runtime_test.cc +++ b/cinn/runtime/cinn_runtime_test.cc @@ -8,7 +8,7 @@ TEST(buffer, basic) { ASSERT_TRUE(buffer->device_interface); ASSERT_EQ(buffer->device_interface, cinn_x86_device_interface()); buffer->device_interface->impl->malloc(NULL, buffer); - auto* data = reinterpret_cast(buffer->host_memory); + auto* data = reinterpret_cast(buffer->memory); data[0] = 0.f; data[1] = 1.f; EXPECT_EQ(data[0], 0.f); diff --git a/cinn/runtime/cinn_x86_device_impl.cc b/cinn/runtime/cinn_x86_device_impl.cc index c30116aed78c7..50e06fbd71e9d 100644 --- a/cinn/runtime/cinn_x86_device_impl.cc +++ b/cinn/runtime/cinn_x86_device_impl.cc @@ -8,28 +8,28 @@ int cinn_x86_malloc(void* context, cinn_buffer_t* buf) { uint64_t memory_size = buf->num_elements() * buf->type.bytes(); CINN_CHECK(memory_size > 0); if (buf->memory_size < memory_size) { - if (buf->host_memory) { - free(buf->host_memory); + if (buf->memory) { + free(buf->memory); } int bytes = buf->type.bytes() * buf->num_elements(); if (buf->align == 0) { - buf->host_memory = (unsigned char*)malloc(bytes); + buf->memory = (unsigned char*)malloc(bytes); } else { - buf->host_memory = (unsigned char*)aligned_alloc(buf->align, bytes); + buf->memory = (unsigned char*)aligned_alloc(buf->align, bytes); } buf->memory_size = memory_size; CINN_LOG("buf.memory size is %ld\n", buf->memory_size); } - ASSERT_NOT_NULL(buf->host_memory); + ASSERT_NOT_NULL(buf->memory); return 0; } int cinn_x86_free(void* context, cinn_buffer_t* buf) { // ASSERT_NOT_NULL(context); ASSERT_NOT_NULL(buf); - if (buf->host_memory) { - free(buf->host_memory); - buf->host_memory = NULL; + if (buf->memory) { + free(buf->memory); + buf->memory = NULL; } return 0; } diff --git a/cinn/runtime/cpu/cblas.cc b/cinn/runtime/cpu/cblas.cc index 5077a88215816..19a35ed655038 100644 --- a/cinn/runtime/cpu/cblas.cc +++ b/cinn/runtime/cpu/cblas.cc @@ -32,12 +32,12 @@ void cinn_cpu_mkl_gemm_fp32(float alpha, N, K, alpha, - reinterpret_cast(A->host_memory), + reinterpret_cast(A->memory), lda, - reinterpret_cast(B->host_memory), + 
reinterpret_cast(B->memory), ldb, beta, - reinterpret_cast(C->host_memory), + reinterpret_cast(C->memory), ldc); } diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc index cd1d5c529e609..7f0781ff8e06a 100644 --- a/cinn/runtime/cpu/host_intrinsics.cc +++ b/cinn/runtime/cpu/host_intrinsics.cc @@ -19,8 +19,8 @@ float __cinn_host_tanh_fp32(float x) { return std::tanh(x); } void __cinn_host_tanh_v(const cinn_buffer_t* x, cinn_buffer_t* out) { CINN_CHECK_EQ(x->num_elements(), out->num_elements()); int xn = x->num_elements(); - auto* x_data = (float*)(x->host_memory); - auto* out_data = (float*)(out->host_memory); + auto* x_data = (float*)(x->memory); + auto* out_data = (float*)(out->memory); for (int i = 0; i < x->num_elements(); i++) { out_data[i] = __cinn_host_tanh_fp32(x_data[i]); } diff --git a/cinn/runtime/cpu/mkl_math.cc b/cinn/runtime/cpu/mkl_math.cc index e45d8c6b9d080..d6a7802d75b53 100644 --- a/cinn/runtime/cpu/mkl_math.cc +++ b/cinn/runtime/cpu/mkl_math.cc @@ -12,25 +12,25 @@ void cinn_mkl_tanh_v_fp32(cinn_buffer_t *x, cinn_buffer_t *out) { CHECK_EQ(x->num_elements(), out->num_elements()); - vsTanh(x->num_elements(), reinterpret_cast(x->host_memory), reinterpret_cast(out->host_memory)); + vsTanh(x->num_elements(), reinterpret_cast(x->memory), reinterpret_cast(out->memory)); } void cinn_mkl_tanh_v_fp64(cinn_buffer_t *x, cinn_buffer_t *out) { CHECK_EQ(x->num_elements(), out->num_elements()); - vdTanh(x->num_elements(), reinterpret_cast(x->host_memory), reinterpret_cast(out->host_memory)); + vdTanh(x->num_elements(), reinterpret_cast(x->memory), reinterpret_cast(out->memory)); } void cinn_mkl_exp_v_fp32(cinn_buffer_t *x, cinn_buffer_t *out) { CHECK_EQ(x->num_elements(), out->num_elements()); - vdExp(x->num_elements(), reinterpret_cast(x->host_memory), reinterpret_cast(out->host_memory)); + vdExp(x->num_elements(), reinterpret_cast(x->memory), reinterpret_cast(out->memory)); } /* void cinn_mkl_cos_v_fp32(cinn_buffer_t *x, cinn_buffer_t *out) { CHECK_EQ(x->num_elements(), out->num_elements()); - vsCosh(x->num_elements(), reinterpret_cast(x->host_memory), reinterpret_cast(out->host_memory)); + vsCosh(x->num_elements(), reinterpret_cast(x->memory), reinterpret_cast(out->memory)); } void cinn_mkl_cos_v_fp64(cinn_buffer_t *x, cinn_buffer_t *out) { CHECK_EQ(x->num_elements(), out->num_elements()); - vdCosh(x->num_elements(), reinterpret_cast(x->host_memory), reinterpret_cast(out->host_memory)); + vdCosh(x->num_elements(), reinterpret_cast(x->memory), reinterpret_cast(out->memory)); } */ diff --git a/cinn/runtime/cpu/mkl_math_test.cc b/cinn/runtime/cpu/mkl_math_test.cc index d6dbe0830371b..5ce06e1f9a273 100644 --- a/cinn/runtime/cpu/mkl_math_test.cc +++ b/cinn/runtime/cpu/mkl_math_test.cc @@ -63,8 +63,8 @@ void TestCallElementwise(const std::string &fn_name, float (*fn_runtime)(float), cinn_pod_value_t args[] = {a_arg, b_arg}; fn_(args, 2); - auto *ad = reinterpret_cast(A_buf->host_memory); - auto *bd = reinterpret_cast(B_buf->host_memory); + auto *ad = reinterpret_cast(A_buf->memory); + auto *bd = reinterpret_cast(B_buf->memory); for (int i = 0; i < A_buf->num_elements(); i++) { ASSERT_NEAR(bd[i], fn_runtime(ad[i]), 1e-5); } diff --git a/cinn/runtime/cuda/cuda_util.cc b/cinn/runtime/cuda/cuda_util.cc index f0d0fe90645b5..48a58748ef908 100644 --- a/cinn/runtime/cuda/cuda_util.cc +++ b/cinn/runtime/cuda/cuda_util.cc @@ -25,7 +25,7 @@ void cinn_call_cuda_kernel(void *kernel_fn, CHECK_LT(num_args, 20); for (int i = 0; i < num_args; i++) { if (args[i].type_code() 
== cinn_pod_value_t::type_code()) { - arr[i] = &((cinn_buffer_t *)args[i])->host_memory; + arr[i] = &((cinn_buffer_t *)args[i])->memory; } else { arr[i] = args[i].data_addr(); } diff --git a/tests/test01_elementwise_add_case.cc b/tests/test01_elementwise_add_case.cc index e616f1aed7022..6f8d0092b42ba 100644 --- a/tests/test01_elementwise_add_case.cc +++ b/tests/test01_elementwise_add_case.cc @@ -14,9 +14,9 @@ TEST(test01, basic) { auto* B = cinn::common::BufferBuilder(Float(32), {100, 32}).set_align(32).set_random().Build(); auto* C = cinn::common::BufferBuilder(Float(32), {100, 32}).set_align(32).set_zero().Build(); - float* Ad = reinterpret_cast(A->host_memory); - float* Bd = reinterpret_cast(B->host_memory); - float* Cd = reinterpret_cast(C->host_memory); + float* Ad = reinterpret_cast(A->memory); + float* Bd = reinterpret_cast(B->memory); + float* Cd = reinterpret_cast(C->memory); ASSERT_EQ(C->num_elements(), A->num_elements()); auto check = [&] { @@ -44,9 +44,9 @@ TEST(test01, compute_at) { auto* B = cinn::common::BufferBuilder(Float(32), {M, N}).set_align(32).set_random().Build(); auto* C = cinn::common::BufferBuilder(Float(32), {M, N}).set_align(32).set_zero().Build(); - float* Ad = reinterpret_cast(A->host_memory); - float* Bd = reinterpret_cast(B->host_memory); - float* Cd = reinterpret_cast(C->host_memory); + float* Ad = reinterpret_cast(A->memory); + float* Bd = reinterpret_cast(B->memory); + float* Cd = reinterpret_cast(C->memory); ASSERT_EQ(C->num_elements(), A->num_elements()); auto check_add = [&] { @@ -91,9 +91,9 @@ TEST(test01, compute_at_level1) { auto* B = cinn::common::BufferBuilder(Float(32), {M, N}).set_align(32).set_random().Build(); auto* C = cinn::common::BufferBuilder(Float(32), {M, N}).set_align(32).set_zero().Build(); - float* Ad = reinterpret_cast(A->host_memory); - float* Bd = reinterpret_cast(B->host_memory); - float* Cd = reinterpret_cast(C->host_memory); + float* Ad = reinterpret_cast(A->memory); + float* Bd = reinterpret_cast(B->memory); + float* Cd = reinterpret_cast(C->memory); ASSERT_EQ(C->num_elements(), A->num_elements()); auto check_add = [&] { diff --git a/tests/test02_matmul_case.cc b/tests/test02_matmul_case.cc index 42c0cd3534532..d1ba4b879417f 100644 --- a/tests/test02_matmul_case.cc +++ b/tests/test02_matmul_case.cc @@ -33,10 +33,10 @@ TEST(test02, basic) { cinn_buffer_malloc(nullptr, C); cinn_buffer_malloc(nullptr, packedB); - float* Ad = reinterpret_cast(A->host_memory); - float* Bd = reinterpret_cast(B->host_memory); - float* Cd_target = reinterpret_cast(C_target->host_memory); - float* Cd = reinterpret_cast(C->host_memory); + float* Ad = reinterpret_cast(A->memory); + float* Bd = reinterpret_cast(B->memory); + float* Cd_target = reinterpret_cast(C_target->memory); + float* Cd = reinterpret_cast(C->memory); for (int i = 0; i < M; i++) { for (int k = 0; k < K; k++) {