Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN][New Hardware Update] extend SplitCudaAndHostModule #64345

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions paddle/cinn/backends/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ gather_srcs(
extern_func_protos.cc
extern_func_jit_register.cc
modular.cc
compiler.cc)
compiler.cc
codegen_device_util.cc)

if(WITH_CUDA)
add_subdirectory(nvrtc)
list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc codegen_cuda_util.cc)
list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc)
endif()

if(WITH_OPENMP)
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_generate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/extern_func_jit_register.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/simple_jit.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <string>
#include <unordered_map>

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/extern_func_emitter_builtin.h"
#include "paddle/cinn/backends/extern_func_jit_register.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"

#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/cas.h"
Expand All @@ -22,7 +22,7 @@ PD_DECLARE_bool(cinn_bucket_compile);
namespace cinn {
namespace backends {

std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module) {
std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module) {
if (FLAGS_cinn_bucket_compile) {
detail::CollectBucketStrategyHostFunctionVisitor visitor(module->name);
Expr expr(module);
Expand Down Expand Up @@ -91,7 +91,16 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
ir::Var kernel_ptr(GenDeviceKernelName(func_node->name, predicate),
type_of<std::string>());

Expr shared_mem_bytes = CalculateSharedMemory(func);
std::optional<Expr> shared_mem_bytes;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
CINN_NOT_IMPLEMENTED;
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
shared_mem_bytes = CalculateSharedMemory(func);
#endif
});

VLOG(6) << "Add a call node for func_node->name " << func_node->name << "\n"
<< "grid_dim: (" << func_node->cuda_axis_info.grid_dim(0) << ", "
Expand All @@ -100,10 +109,18 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
<< "block_dim: (" << func_node->cuda_axis_info.block_dim(0) << ", "
<< func_node->cuda_axis_info.block_dim(1) << ", "
<< func_node->cuda_axis_info.block_dim(2) << "), "
<< "shared_mem: " << shared_mem_bytes;
<< "shared_mem: " << shared_mem_bytes.value();
std::optional<const char *> call_kernel;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
CINN_NOT_IMPLEMENTED;
},
[&](common::NVGPUArch) {
call_kernel = runtime::intrinsic::call_cuda_kernel;
});
ir::Expr call_extern_api =
ir::Call::Make(Void(),
runtime::intrinsic::call_cuda_kernel,
call_kernel.value(),
{kernel_ptr,
kernel_args_,
kernel_args_num_,
Expand All @@ -113,7 +130,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
func_node->cuda_axis_info.block_dim(0), // block_x
func_node->cuda_axis_info.block_dim(1), // block_y
func_node->cuda_axis_info.block_dim(2), // block_z
shared_mem_bytes, // shared_mem
shared_mem_bytes.value(), // shared_mem
kernel_stream_},
{},
ir::CallType::Extern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
#include <string>
#include <tuple>
#include <vector>

#ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#endif
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/flags.h"

namespace cinn {
namespace backends {
Expand All @@ -43,7 +45,7 @@ namespace backends {
* - replace the original kernel function with a Call node and add it to the
* first module, add a device kernel function to the second module.
*/
std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module);
std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module);

namespace detail {

Expand All @@ -52,7 +54,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
: host_module_builder(module_name + "_host",
cinn::common::DefaultHostTarget()),
device_module_builder(module_name + "_gpu_device",
cinn::common::DefaultNVGPUTarget()) {}
cinn::common::DefaultDeviceTarget()) {}

std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
ir::IRMutator<>::Visit(expr, expr);
Expand Down Expand Up @@ -109,9 +111,18 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
// shared_mem_bytes can be calculated after codegen_cuda_dev buffer creation;
// however, this makes CodeGenCUDA_Dev run before splitting the host and
// device modules. Maybe we could reorder the process.
CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget());
codegen_dev.Compile(ir::LoweredFunc(func));
Expr shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
std::optional<Expr> shared_mem_bytes;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget());
codegen_dev.Compile(ir::LoweredFunc(func));
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
#endif
});

VLOG(6) << "Add a call node for func->name " << func->name << "\n"
<< "grid_dim: (" << func->cuda_axis_info.grid_dim(0) << ", "
Expand All @@ -120,10 +131,20 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
<< "block_dim: (" << func->cuda_axis_info.block_dim(0) << ", "
<< func->cuda_axis_info.block_dim(1) << ", "
<< func->cuda_axis_info.block_dim(2) << "), "
<< "shared_mem: " << shared_mem_bytes;
<< "shared_mem: " << shared_mem_bytes.value();

std::optional<const char*> call_kernel;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
call_kernel = runtime::intrinsic::call_cuda_kernel;
});

auto call_extern_api =
ir::Call::Make(Void(),
runtime::intrinsic::call_cuda_kernel,
call_kernel.value(),
{kernel_ptr,
kernel_args,
kernel_args_num,
Expand All @@ -133,7 +154,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
func->cuda_axis_info.block_dim(0), // block_x
func->cuda_axis_info.block_dim(1), // block_y
func->cuda_axis_info.block_dim(2), // block_z
shared_mem_bytes,
shared_mem_bytes.value(),
kernel_stream},
{},
ir::CallType::Extern,
Expand Down
7 changes: 4 additions & 3 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/cuda/cuda_util.h"
Expand Down Expand Up @@ -246,7 +246,7 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
[&](common::NVGPUArch) -> std::string {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ =
SplitCudaAndHostModule(module); // NOLINT
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
CodeGenCUDA_Dev codegen(target_);
Expand All @@ -270,7 +270,8 @@ void Compiler::BuildDefault(const Module& module) {
void Compiler::CompileCudaModule(const Module& module,
const std::string& code) {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
VLOG(3) << "[CUDA] host module:\n" << host_module;
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/common/arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <functional>
#include <ostream>
#include <variant>
#include "paddle/common/overloaded.h"

namespace cinn {
namespace common {
Expand Down Expand Up @@ -45,6 +46,8 @@ struct Arch final : public ArchBase {
return static_cast<const ArchBase&>(*this);
}

DEFINE_MATCH_METHOD();

bool operator==(const auto& other) const {
return this->index() == other.index();
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/common/cuda_test_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/cuda/cuda_util.h"
Expand All @@ -28,7 +28,7 @@ namespace common {
void CudaModuleTester::Compile(const ir::Module& m,
const std::string& rewrite_cuda_code) {
auto _host_module_device_module_ =
backends::SplitCudaAndHostModule(m); // NOLINT
backends::SplitDeviceAndHostModule(m); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
CHECK(!host_module.functions().empty());
Expand Down
6 changes: 6 additions & 0 deletions paddle/cinn/common/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,12 @@ const Target &DefaultNVGPUTarget() {
return target;
}

// Returns the default device target for the device backend this build
// supports. Only the CUDA (NVGPU) backend is currently wired in; a build
// without any device backend aborts explicitly, because falling off the end
// of a reference-returning function is undefined behavior.
const Target &DefaultDeviceTarget() {
#ifdef CINN_WITH_CUDA
  return DefaultNVGPUTarget();
#else
  LOG(FATAL) << "DefaultDeviceTarget: no device backend was compiled in "
                "(build with CINN_WITH_CUDA).";
  return DefaultHostTarget();  // Unreachable; satisfies the return type.
#endif
}

int GetMaxThreads() {
// cudaDeviceGetAttribute ( int* value, cudaDeviceAttr attr, int device )
int max_threads = 1;
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/common/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ const Target& DefaultHostTarget();

const Target& DefaultNVGPUTarget();

const Target& DefaultDeviceTarget();

const Target& DefaultTarget();

int GetMaxThreads();
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/frontend/paddle/model_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/frontend/paddle/compatible_pb.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/graph_compiler_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ struct CompilationContext {
void* stream = nullptr;

// Set attached source code, if code is not empty, these codes will replace
// the device_module code after SplitCudaAndHostModule.
// the device_module code after SplitDeviceAndHostModule.
void ApplySourceCode(const std::string& code);
// Apply results of auto-tune to compile.
// Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/op_lowering_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include "paddle/cinn/backends/codegen_c_x86.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/framework/parallel_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
Expand Down Expand Up @@ -238,7 +238,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
auto ir_module = builder.Build();
if (context->target == cinn::common::DefaultNVGPUTarget()) {
#ifdef CINN_WITH_CUDA
auto splited_module = backends::SplitCudaAndHostModule(ir_module);
auto splited_module = backends::SplitDeviceAndHostModule(ir_module);
auto hmodule = std::get<0>(splited_module);
auto dmodule = std::get<1>(splited_module);

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include "paddle/cinn/adt/map_expr_ctx.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/custom_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/op/reduction_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
Expand Down Expand Up @@ -116,7 +116,7 @@ std::pair<ir::Module, std::string> GenReduceCode(
// now.
auto module = builder.Build();
auto host_module_device_module =
backends::SplitCudaAndHostModule(module); // NOLINT
backends::SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(host_module_device_module);
auto& device_module = std::get<1>(host_module_device_module);

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
Expand Down
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/pe/pe_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <gtest/gtest.h>

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
Expand Down Expand Up @@ -132,7 +132,7 @@ TEST(ScatterAssign, ScatterAssign) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);

Expand Down Expand Up @@ -176,7 +176,7 @@ TEST(SliceAssign, SliceAssign) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);

Expand Down Expand Up @@ -217,7 +217,7 @@ TEST(Concat, ConcatCase0) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);

Expand Down