Skip to content

Commit

Permalink
add python bind for pe (PaddlePaddle#169)
Browse files Browse the repository at this point in the history
* add python bind for pe

* refine pe test
  • Loading branch information
Superjomn authored Aug 12, 2020
1 parent 75dbf7d commit 8401b57
Show file tree
Hide file tree
Showing 23 changed files with 209 additions and 14 deletions.
1 change: 0 additions & 1 deletion cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ add_subdirectory(backends)
add_subdirectory(lang)
add_subdirectory(optim)
add_subdirectory(hlir)
#add_subdirectory(python)
add_subdirectory(pybind)
2 changes: 1 addition & 1 deletion cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void Compiler::CompileCudaModule(const Module& module) {

void Compiler::CompileX86Module(const Module& module) { engine_->Link(module); }

lower_func_ptr_t Compiler::GetFn(std::string_view fn_name) {
lower_func_ptr_t Compiler::Lookup(std::string_view fn_name) {
CHECK(engine_);
return reinterpret_cast<lower_func_ptr_t>(engine_->Lookup(fn_name));
}
Expand Down
3 changes: 2 additions & 1 deletion cinn/backends/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "cinn/backends/llvm/codegen_llvm.h"
#include "cinn/backends/llvm/execution_engine.h"
#include "cinn/backends/llvm/simple_jit.h"
#include "cinn/ir/packed_func.h"
#ifdef CINN_WITH_CUDA
#include "cinn/runtime/cuda/cuda_module.h"
#endif
Expand All @@ -29,7 +30,7 @@ class Compiler final {
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
lower_func_ptr_t GetFn(std::string_view fn_name);
lower_func_ptr_t Lookup(std::string_view fn_name);

private:
void CompileCudaModule(const lang::Module& module);
Expand Down
4 changes: 2 additions & 2 deletions cinn/backends/compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ TEST(Compiler, x86) {
auto compiler = Compiler::Create(common::DefaultHostTarget());
compiler->Build(builder.Build());

auto* fnp = compiler->GetFn("fn");
auto* fnp = compiler->Lookup("fn");
ASSERT_TRUE(fnp);

auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
Expand Down Expand Up @@ -67,7 +67,7 @@ TEST(Compiler, x86) {
auto compiler = Compiler::Create(common::DefaultNVGPUTarget());
compiler->Build(builder.Build());

auto* fnp = compiler->GetFn("fn");
auto* fnp = compiler->Lookup("fn");
ASSERT_TRUE(fnp);

auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
Expand Down
3 changes: 3 additions & 0 deletions cinn/backends/cuda_util.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#ifdef CINN_WITH_CUDA

#include <cuda.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -45,3 +46,5 @@ std::string cuda_block_axis_name(int level);

} // namespace backends
} // namespace cinn

#endif // CINN_WITH_CUDA
4 changes: 3 additions & 1 deletion cinn/backends/nvrtc_util.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once

#ifdef CINN_WITH_CUDA
#if defined(__linux__)
#include <sys/stat.h>
#endif
Expand Down Expand Up @@ -47,3 +47,5 @@ class NVRTC_Compiler {

} // namespace backends
} // namespace cinn

#endif // CINN_WITH_CUDA
2 changes: 1 addition & 1 deletion cinn/common/cinn_value.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once
#include <glog/logging.h>

#include <any>
#include <vector>

#include <any>
#include "cinn/common/common.h"
#include "cinn/common/macros.h"
#include "cinn/common/object.h"
Expand Down
2 changes: 2 additions & 0 deletions cinn/common/context.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#pragma once
#include <gflags/gflags.h>
#include <isl/cpp.h>

#include <any>
#include <set>
#include <string>
#include <vector>

#include "cinn/common/debug_manager.h"
#include "cinn/common/info_registry.h"

Expand Down
2 changes: 2 additions & 0 deletions cinn/hlir/framework/memory.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#pragma once

#include <glog/logging.h>

#include <memory>
#include <unordered_map>

#include "cinn/common/macros.h"
#include "cinn/common/target.h"

Expand Down
2 changes: 2 additions & 0 deletions cinn/hlir/framework/print_graph_pass_test.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <gtest/gtest.h>

#include <any>
#include <string>

#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
Expand Down
1 change: 1 addition & 0 deletions cinn/hlir/framework/scope_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "cinn/hlir/framework/scope.h"

#include <gtest/gtest.h>

namespace cinn {
Expand Down
2 changes: 1 addition & 1 deletion cinn/poly/poly_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
#include <limits>
#include <map>
#include <set>

#include <stack>
#include <unordered_set>

#include "cinn/poly/isl_utils.h"

namespace cinn {
Expand Down
18 changes: 13 additions & 5 deletions cinn/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
cc_library(core_api SHARED
SRCS runtime.cc common.cc lang.cc ir.cc poly.cc backends.cc bind.cc optim.cc
DEPS core cinn_runtime pybind)

if (WITH_CUDA)
message(STATUS "Compile core_api with CUDA support")
nv_library(core_api SHARED
SRCS runtime.cc common.cc lang.cc ir.cc poly.cc backends.cc bind.cc optim.cc pe.cc
DEPS core cinn_runtime pybind)
message("cuda_nvrtc: ${CUDA_NVRTC}")
target_link_libraries(core_api ${CUDA_NVRTC_LIB} ${CUDA_LIBRARIES} cuda)
else()
message(STATUS "Compile core_api without CUDA support")
cc_library(core_api SHARED
SRCS runtime.cc common.cc lang.cc ir.cc poly.cc backends.cc bind.cc optim.cc pe.cc
DEPS core cinn_runtime pybind)
endif()

execute_process(COMMAND python3 -m pybind11 --includes OUTPUT_VARIABLE pybind_includes)
string(REGEX REPLACE "\n$" "" pybind_includes "${pybind_includes}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${pybind_includes}")


#execute_process(COMMAND python3-config --extension-suffix OUTPUT_VARIABLE python3_module_suffix)
SET_TARGET_PROPERTIES(core_api PROPERTIES PREFIX "")
22 changes: 22 additions & 0 deletions cinn/pybind/backends.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <functional>

#include "cinn/backends/compiler.h"
#include "cinn/backends/llvm/execution_engine.h"
#include "cinn/pybind/bind.h"

Expand All @@ -11,9 +12,13 @@ namespace py = pybind11;
struct cinn_pod_value_t;

namespace cinn::pybind {

using backends::Compiler;
using backends::ExecutionEngine;
using backends::ExecutionOptions;

namespace {

void BindExecutionEngine(py::module *);

void BindExecutionEngine(py::module *m) {
Expand All @@ -35,7 +40,24 @@ void BindExecutionEngine(py::module *m) {
.def(py::init(&ExecutionEngine::Create), py::arg("options") = ExecutionOptions())
.def("lookup", lookup)
.def("link", &ExecutionEngine::Link);

{
auto lookup = [](Compiler &self, std::string_view name) {
auto *function_ptr = reinterpret_cast<void (*)(void **, int32_t)>(self.Lookup(name));
auto function_wrapper = [function_ptr](std::vector<cinn_pod_value_t> &args) {
function_ptr(reinterpret_cast<void **>(args.data()), args.size());
};
return std::function(function_wrapper);
};

py::class_<Compiler> compiler(*m, "Compiler");
compiler
.def_static("create", &Compiler::Create) //
.def("build", &Compiler::Build) //
.def("lookup", lookup);
}
}

} // namespace

void BindBackends(py::module *m) { BindExecutionEngine(m); }
Expand Down
2 changes: 2 additions & 0 deletions cinn/pybind/bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ PYBIND11_MODULE(core_api, m) {
py::module poly = m.def_submodule("poly", "namespace cinn::poly, polyhedral");
py::module backends = m.def_submodule("backends", "namespace cinn::backends, execution backends");
py::module optim = m.def_submodule("optim", "namespace cinn::optim, CINN IR optimization");
py::module pe = m.def_submodule("pe", "namespace cinn::hlir::pe, CINN Primitive Emitters");

BindRuntime(&runtime);
BindCommon(&common);
Expand All @@ -22,5 +23,6 @@ PYBIND11_MODULE(core_api, m) {
BindPoly(&poly);
BindBackends(&backends);
BindOptim(&optim);
BindPE(&pe);
}
} // namespace cinn::pybind
1 change: 1 addition & 0 deletions cinn/pybind/bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ void BindIr(pybind11::module *m);
void BindBackends(pybind11::module *m);
void BindPoly(pybind11::module *m);
void BindOptim(pybind11::module *m);
void BindPE(pybind11::module *m);

} // namespace cinn::pybind
2 changes: 1 addition & 1 deletion cinn/pybind/bind_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct Visitor : Ts... {
};

template <class... Ts>
Visitor(Ts...)->Visitor<Ts...>;
Visitor(Ts...) -> Visitor<Ts...>;

using ExprOp = std::variant<ir::IntImm,
ir::UIntImm,
Expand Down
57 changes: 57 additions & 0 deletions cinn/pybind/pe.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "cinn/common/target.h"
#include "cinn/hlir/pe/elementwise.h"
#include "cinn/pybind/bind.h"
#include "cinn/pybind/bind_utils.h"
#include "cinn/utils/string.h"

namespace py = pybind11;

namespace cinn {
namespace pybind {

using common::Type;
using lang::Placeholder;
using py::arg;
using utils::GetStreamCnt;
using utils::StringFormat;

void BindPE(py::module* m) {
#define BIND_UNARY(name__, fn__) m->def(#name__, &hlir::pe::fn__, py::arg("x"), py::arg("out") = "T_" #name__ "_out")
BIND_UNARY(exp, Exp);
BIND_UNARY(erf, Erf);
BIND_UNARY(sqrt, Sqrt);
BIND_UNARY(log, Log);
BIND_UNARY(log2, Log2);
BIND_UNARY(log10, Log10);
BIND_UNARY(floor, Floor);
BIND_UNARY(ceil, Ceil);
BIND_UNARY(round, Round);
BIND_UNARY(trunc, Trunc);
BIND_UNARY(cos, Cos);
BIND_UNARY(cosh, Cosh);
BIND_UNARY(tan, Tan);
BIND_UNARY(sin, Sin);
BIND_UNARY(sinh, Sinh);
BIND_UNARY(acos, Acos);
BIND_UNARY(acosh, Acosh);
BIND_UNARY(asin, Asin);
BIND_UNARY(asinh, Asinh);
BIND_UNARY(atan, Atan);
BIND_UNARY(atanh, Atanh);
BIND_UNARY(isnan, Isnan);
BIND_UNARY(tanh, Tanh);
BIND_UNARY(isfinite, Isfinite);
BIND_UNARY(isinf, Isinf);

BIND_UNARY(negative, Negative);
BIND_UNARY(identity, Identity);
BIND_UNARY(logical_not, LogicalNot);
BIND_UNARY(bitwise_not, BitwiseNot);
BIND_UNARY(sigmoid, Sigmoid);
BIND_UNARY(sign, Sign);
BIND_UNARY(abs, Abs);
BIND_UNARY(rsqrt, Rsqrt);
}

} // namespace pybind
} // namespace cinn
3 changes: 3 additions & 0 deletions cinn/runtime/cinn_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ struct cinn_pod_value_t {
template <typename T>
static int type_code();

void set_type_code(int x) { type_code_ = x; }
void set_value(union cinn_value_t x) { value_ = x; }

protected:
// @}
#endif // __cplusplus
Expand Down
4 changes: 4 additions & 0 deletions cinn/runtime/cuda/cuda_module.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#pragma once
#ifdef CINN_WITH_CUDA

#include <cuda.h>
#include <cuda_runtime.h>
#include <glog/logging.h>
Expand Down Expand Up @@ -87,3 +89,5 @@ class CUDAModule {
} // namespace cuda
} // namespace runtime
} // namespace cinn

#endif // CINN_WITH_CUDA
1 change: 1 addition & 0 deletions python/cinn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .core_api.common import CINNValue
from .core_api.backends import ExecutionOptions
from .core_api.backends import ExecutionEngine
from .core_api.backends import Compiler
1 change: 1 addition & 0 deletions python/cinn/pe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .core_api.pe import *
Loading

0 comments on commit 8401b57

Please sign in to comment.