diff --git a/cinn/CMakeLists.txt b/cinn/CMakeLists.txt index a1608f14cea24..70a9047e2e1ed 100644 --- a/cinn/CMakeLists.txt +++ b/cinn/CMakeLists.txt @@ -9,5 +9,4 @@ add_subdirectory(backends) add_subdirectory(lang) add_subdirectory(optim) add_subdirectory(hlir) -#add_subdirectory(python) add_subdirectory(pybind) diff --git a/cinn/backends/compiler.cc b/cinn/backends/compiler.cc index bbdf3b7bea666..57cb4697070df 100644 --- a/cinn/backends/compiler.cc +++ b/cinn/backends/compiler.cc @@ -68,7 +68,7 @@ void Compiler::CompileCudaModule(const Module& module) { void Compiler::CompileX86Module(const Module& module) { engine_->Link(module); } -lower_func_ptr_t Compiler::GetFn(std::string_view fn_name) { +lower_func_ptr_t Compiler::Lookup(std::string_view fn_name) { CHECK(engine_); return reinterpret_cast(engine_->Lookup(fn_name)); } diff --git a/cinn/backends/compiler.h b/cinn/backends/compiler.h index b603aa446f91d..e503e710f91d8 100644 --- a/cinn/backends/compiler.h +++ b/cinn/backends/compiler.h @@ -7,6 +7,7 @@ #include "cinn/backends/llvm/codegen_llvm.h" #include "cinn/backends/llvm/execution_engine.h" #include "cinn/backends/llvm/simple_jit.h" +#include "cinn/ir/packed_func.h" #ifdef CINN_WITH_CUDA #include "cinn/runtime/cuda/cuda_module.h" #endif @@ -29,7 +30,7 @@ class Compiler final { * Retrieve a function by \p fn_name. * @return function address or null if not exists. */ - lower_func_ptr_t GetFn(std::string_view fn_name); + lower_func_ptr_t Lookup(std::string_view fn_name); private: void CompileCudaModule(const lang::Module& module); diff --git a/cinn/backends/compiler_test.cc b/cinn/backends/compiler_test.cc index 7ca137791bd02..bc04acda2ec8c 100644 --- a/cinn/backends/compiler_test.cc +++ b/cinn/backends/compiler_test.cc @@ -34,7 +34,7 @@ TEST(Compiler, x86) { auto compiler = Compiler::Create(common::DefaultHostTarget()); compiler->Build(builder.Build()); - auto* fnp = compiler->GetFn("fn"); + auto* fnp = compiler->Lookup("fn"); ASSERT_TRUE(fnp); auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); @@ -67,7 +67,7 @@ TEST(Compiler, x86) { auto compiler = Compiler::Create(common::DefaultNVGPUTarget()); compiler->Build(builder.Build()); - auto* fnp = compiler->GetFn("fn"); + auto* fnp = compiler->Lookup("fn"); ASSERT_TRUE(fnp); auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); diff --git a/cinn/backends/cuda_util.h b/cinn/backends/cuda_util.h index 516593f853d7d..3fda6c4908ea8 100644 --- a/cinn/backends/cuda_util.h +++ b/cinn/backends/cuda_util.h @@ -1,4 +1,5 @@ #pragma once +#ifdef CINN_WITH_CUDA #include #include @@ -45,3 +46,5 @@ std::string cuda_block_axis_name(int level); } // namespace backends } // namespace cinn + +#endif // CINN_WITH_CUDA diff --git a/cinn/backends/nvrtc_util.h b/cinn/backends/nvrtc_util.h index 1b3b20943bfe0..e89e3c4085285 100644 --- a/cinn/backends/nvrtc_util.h +++ b/cinn/backends/nvrtc_util.h @@ -1,5 +1,5 @@ #pragma once - +#ifdef CINN_WITH_CUDA #if defined(__linux__) #include #endif @@ -47,3 +47,5 @@ class NVRTC_Compiler { } // namespace backends } // namespace cinn + +#endif // CINN_WITH_CUDA diff --git a/cinn/common/cinn_value.h b/cinn/common/cinn_value.h index b496f65a5b2bc..6984cd6f7178b 100644 --- a/cinn/common/cinn_value.h +++ b/cinn/common/cinn_value.h @@ -1,9 +1,9 @@ #pragma once #include +#include #include -#include #include "cinn/common/common.h" #include "cinn/common/macros.h" #include "cinn/common/object.h" diff --git a/cinn/common/context.h b/cinn/common/context.h index 0d0435a5a6d41..a4b39c9cf5255 100644 --- a/cinn/common/context.h +++ b/cinn/common/context.h @@ -1,10 +1,12 @@ #pragma once #include #include + #include #include #include #include + #include "cinn/common/debug_manager.h" #include "cinn/common/info_registry.h" diff --git a/cinn/hlir/framework/memory.h b/cinn/hlir/framework/memory.h index 2ec9cdebe5163..d7078b1a14aa1 100644 --- a/cinn/hlir/framework/memory.h +++ b/cinn/hlir/framework/memory.h @@ -1,8 +1,10 @@ #pragma once #include + #include #include + #include "cinn/common/macros.h" #include "cinn/common/target.h" diff --git a/cinn/hlir/framework/print_graph_pass_test.cc b/cinn/hlir/framework/print_graph_pass_test.cc index 6a950c938ae7f..d7e87d1a5721f 100644 --- a/cinn/hlir/framework/print_graph_pass_test.cc +++ b/cinn/hlir/framework/print_graph_pass_test.cc @@ -1,6 +1,8 @@ #include + #include #include + #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" diff --git a/cinn/hlir/framework/scope_test.cc b/cinn/hlir/framework/scope_test.cc index fa289bd6dd515..37fdc38a15f71 100644 --- a/cinn/hlir/framework/scope_test.cc +++ b/cinn/hlir/framework/scope_test.cc @@ -1,4 +1,5 @@ #include "cinn/hlir/framework/scope.h" + #include namespace cinn { diff --git a/cinn/poly/poly_scheduler.cc b/cinn/poly/poly_scheduler.cc index f252ad3a33636..fe56aea26d1c9 100644 --- a/cinn/poly/poly_scheduler.cc +++ b/cinn/poly/poly_scheduler.cc @@ -6,9 +6,9 @@ #include #include #include - #include #include + #include "cinn/poly/isl_utils.h" namespace cinn { diff --git a/cinn/pybind/CMakeLists.txt b/cinn/pybind/CMakeLists.txt index 0d8b5d3e87485..f500590524c9a 100644 --- a/cinn/pybind/CMakeLists.txt +++ b/cinn/pybind/CMakeLists.txt @@ -1,12 +1,20 @@ -cc_library(core_api SHARED - SRCS runtime.cc common.cc lang.cc ir.cc poly.cc backends.cc bind.cc optim.cc - DEPS core cinn_runtime pybind) - +if (WITH_CUDA) + message(STATUS "Compile core_api with CUDA support") + nv_library(core_api SHARED + SRCS runtime.cc common.cc lang.cc ir.cc poly.cc backends.cc bind.cc optim.cc pe.cc + DEPS core cinn_runtime pybind) + message("cuda_nvrtc: ${CUDA_NVRTC}") + target_link_libraries(core_api ${CUDA_NVRTC_LIB} ${CUDA_LIBRARIES} cuda) +else() + message(STATUS "Compile core_api without CUDA support") + cc_library(core_api SHARED + SRCS runtime.cc common.cc lang.cc ir.cc poly.cc backends.cc bind.cc optim.cc pe.cc + DEPS core cinn_runtime pybind) +endif() execute_process(COMMAND python3 -m pybind11 --includes OUTPUT_VARIABLE pybind_includes) string(REGEX REPLACE "\n$" "" pybind_includes "${pybind_includes}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${pybind_includes}") -#execute_process(COMMAND python3-config --extension-suffix OUTPUT_VARIABLE python3_module_suffix) SET_TARGET_PROPERTIES(core_api PROPERTIES PREFIX "") diff --git a/cinn/pybind/backends.cc b/cinn/pybind/backends.cc index 3fcb06f81e62d..724814d54ac0b 100644 --- a/cinn/pybind/backends.cc +++ b/cinn/pybind/backends.cc @@ -3,6 +3,7 @@ #include +#include "cinn/backends/compiler.h" #include "cinn/backends/llvm/execution_engine.h" #include "cinn/pybind/bind.h" @@ -11,9 +12,13 @@ namespace py = pybind11; struct cinn_pod_value_t; namespace cinn::pybind { + +using backends::Compiler; using backends::ExecutionEngine; using backends::ExecutionOptions; + namespace { + void BindExecutionEngine(py::module *); void BindExecutionEngine(py::module *m) { @@ -35,7 +40,24 @@ void BindExecutionEngine(py::module *m) { .def(py::init(&ExecutionEngine::Create), py::arg("options") = ExecutionOptions()) .def("lookup", lookup) .def("link", &ExecutionEngine::Link); + + { + auto lookup = [](Compiler &self, std::string_view name) { + auto *function_ptr = reinterpret_cast(self.Lookup(name)); + auto function_wrapper = [function_ptr](std::vector &args) { + function_ptr(reinterpret_cast(args.data()), args.size()); + }; + return std::function(function_wrapper); + }; + + py::class_ compiler(*m, "Compiler"); + compiler + .def_static("create", &Compiler::Create) // + .def("build", &Compiler::Build) // + .def("lookup", lookup); + } } + } // namespace void BindBackends(py::module *m) { BindExecutionEngine(m); } diff --git a/cinn/pybind/bind.cc b/cinn/pybind/bind.cc index ceff6fd2d0228..bd572998c8dd0 100644 --- a/cinn/pybind/bind.cc +++ b/cinn/pybind/bind.cc @@ -14,6 +14,7 @@ PYBIND11_MODULE(core_api, m) { py::module poly = m.def_submodule("poly", "namespace cinn::poly, polyhedral"); py::module backends = m.def_submodule("backends", "namespace cinn::backends, execution backends"); py::module optim = m.def_submodule("optim", "namespace cinn::optim, CINN IR optimization"); + py::module pe = m.def_submodule("pe", "namespace cinn::hlir::pe, CINN Primitive Emitters"); BindRuntime(&runtime); BindCommon(&common); @@ -22,5 +23,6 @@ PYBIND11_MODULE(core_api, m) { BindPoly(&poly); BindBackends(&backends); BindOptim(&optim); + BindPE(&pe); } } // namespace cinn::pybind diff --git a/cinn/pybind/bind.h b/cinn/pybind/bind.h index e55c8e801907b..597f5b3d72959 100644 --- a/cinn/pybind/bind.h +++ b/cinn/pybind/bind.h @@ -12,5 +12,6 @@ void BindIr(pybind11::module *m); void BindBackends(pybind11::module *m); void BindPoly(pybind11::module *m); void BindOptim(pybind11::module *m); +void BindPE(pybind11::module *m); } // namespace cinn::pybind diff --git a/cinn/pybind/bind_utils.h b/cinn/pybind/bind_utils.h index 2ab2c51b3dd42..6f9ca605db058 100644 --- a/cinn/pybind/bind_utils.h +++ b/cinn/pybind/bind_utils.h @@ -28,7 +28,7 @@ struct Visitor : Ts... { }; template -Visitor(Ts...)->Visitor; +Visitor(Ts...) -> Visitor; using ExprOp = std::variantdef(#name__, &hlir::pe::fn__, py::arg("x"), py::arg("out") = "T_" #name__ "_out") + BIND_UNARY(exp, Exp); + BIND_UNARY(erf, Erf); + BIND_UNARY(sqrt, Sqrt); + BIND_UNARY(log, Log); + BIND_UNARY(log2, Log2); + BIND_UNARY(log10, Log10); + BIND_UNARY(floor, Floor); + BIND_UNARY(ceil, Ceil); + BIND_UNARY(round, Round); + BIND_UNARY(trunc, Trunc); + BIND_UNARY(cos, Cos); + BIND_UNARY(cosh, Cosh); + BIND_UNARY(tan, Tan); + BIND_UNARY(sin, Sin); + BIND_UNARY(sinh, Sinh); + BIND_UNARY(acos, Acos); + BIND_UNARY(acosh, Acosh); + BIND_UNARY(asin, Asin); + BIND_UNARY(asinh, Asinh); + BIND_UNARY(atan, Atan); + BIND_UNARY(atanh, Atanh); + BIND_UNARY(isnan, Isnan); + BIND_UNARY(tanh, Tanh); + BIND_UNARY(isfinite, Isfinite); + BIND_UNARY(isinf, Isinf); + + BIND_UNARY(negative, Negative); + BIND_UNARY(identity, Identity); + BIND_UNARY(logical_not, LogicalNot); + BIND_UNARY(bitwise_not, BitwiseNot); + BIND_UNARY(sigmoid, Sigmoid); + BIND_UNARY(sign, Sign); + BIND_UNARY(abs, Abs); + BIND_UNARY(rsqrt, Rsqrt); +} + +} // namespace pybind +} // namespace cinn diff --git a/cinn/runtime/cinn_runtime.h b/cinn/runtime/cinn_runtime.h index 5b3a2120484df..a90d968075253 100644 --- a/cinn/runtime/cinn_runtime.h +++ b/cinn/runtime/cinn_runtime.h @@ -360,6 +360,9 @@ struct cinn_pod_value_t { template static int type_code(); + void set_type_code(int x) { type_code_ = x; } + void set_value(union cinn_value_t x) { value_ = x; } + protected: // @} #endif // __cplusplus diff --git a/cinn/runtime/cuda/cuda_module.h b/cinn/runtime/cuda/cuda_module.h index 2e342ac62a91c..ef2753cad6f87 100644 --- a/cinn/runtime/cuda/cuda_module.h +++ b/cinn/runtime/cuda/cuda_module.h @@ -1,4 +1,6 @@ #pragma once +#ifdef CINN_WITH_CUDA + #include #include #include @@ -87,3 +89,5 @@ class CUDAModule { } // namespace cuda } // namespace runtime } // namespace cinn + +#endif // CINN_WITH_CUDA diff --git a/python/cinn/__init__.py b/python/cinn/__init__.py index 6fe534db56993..7898462042186 100644 --- a/python/cinn/__init__.py +++ b/python/cinn/__init__.py @@ -3,3 +3,4 @@ from .core_api.common import CINNValue from .core_api.backends import ExecutionOptions from .core_api.backends import ExecutionEngine +from .core_api.backends import Compiler diff --git a/python/cinn/pe.py b/python/cinn/pe.py new file mode 100644 index 0000000000000..2df124ab65ce4 --- /dev/null +++ b/python/cinn/pe.py @@ -0,0 +1 @@ +from .core_api.pe import * diff --git a/python/tests/test_pe.py b/python/tests/test_pe.py new file mode 100644 index 0000000000000..aeda3385a7723 --- /dev/null +++ b/python/tests/test_pe.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +import unittest +import cinn +import numpy as np +from cinn import runtime +from cinn import ir +from cinn import lang +from cinn import Target +from cinn import pe +from cinn.common import * + + +class TestPE(unittest.TestCase): + def setUp(self): + self.m = 32 + self.n = 32 + + self.target = Target() + self.target.arch = Target.Arch.X86 + self.target.bits = Target.Bit.k32 + self.target.os = Target.OS.Linux + + self.unary_data = [] + + def test_unary(self): + for (fn_name, pe_fn, np_fn) in [ + ("tanh", pe.tanh, np.tanh), + ("cos", pe.cos, np.cos), + ("exp", pe.exp, np.exp), + ]: + self.compiler = cinn.Compiler.create(self.target) + self.union_tester(fn_name, pe_fn, np_fn) + + def union_tester(self, fn_name, cinn_fn, np_fn): + m, n = [ir.Expr(_) for _ in ( + self.m, + self.n, + )] + x = lang.Placeholder("float32", "x", [m, n]) + y = cinn_fn(x.to_tensor()) + + func_name = "test_" + fn_name + + func = lang.lower(func_name, [x.to_tensor(), y]) + + builder = lang.Module.Builder("elementwise_module", self.target) + builder.add_function(func) + + module = builder.build() + self.compiler.build(module) + + fn = self.compiler.lookup(func_name) + + x_data, x_buf, out_buf, *args = self.create_data() + fn(args) + + self.assertTrue( + np.allclose( + out_buf.numpy(), + self.create_target_data(x_data, np_fn), + atol=1e-4)) + + def create_target_data(self, x_data, np_target_fn): + return np_target_fn(x_data) + + def create_data(self): + if not self.unary_data: + x_data = np.around( + np.random.randn(self.m, self.n).astype("float32"), 2) + x = runtime.cinn_buffer_t(x_data, runtime.cinn_x86_device) + out = runtime.cinn_buffer_t( + np.zeros([self.m, self.n]).astype("float32"), + runtime.cinn_x86_device) + self.unary_data = [ + x_data, x, out, + runtime.cinn_pod_value_t(x), + runtime.cinn_pod_value_t(out) + ] + + return self.unary_data + + +if __name__ == "__main__": + unittest.main()