diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp index bf051e5894f40..b8cccd81630ca 100644 --- a/taichi/codegen/codegen.cpp +++ b/taichi/codegen/codegen.cpp @@ -6,6 +6,8 @@ #if defined(TI_WITH_LLVM) #include "taichi/codegen/cpu/codegen_cpu.h" #include "taichi/codegen/wasm/codegen_wasm.h" +#include "taichi/runtime/llvm/llvm_offline_cache.h" +#include "taichi/runtime/program_impls/llvm/llvm_program.h" #endif #if defined(TI_WITH_CUDA) #include "taichi/codegen/cuda/codegen_cuda.h" @@ -53,6 +55,34 @@ std::unique_ptr KernelCodeGen::create(Arch arch, } #ifdef TI_WITH_LLVM +bool KernelCodeGen::maybe_read_compilation_from_cache( + const std::string &kernel_key, + std::vector &data) { + const auto &config = prog->config; + auto reader = + LlvmOfflineCacheFileReader::make(config.offline_cache_file_path); + if (!reader) { + return false; + } + + LlvmOfflineCache::KernelCacheData cache_data; + auto *tlctx = get_llvm_program(prog)->get_llvm_context(config.arch); + auto &llvm_ctx = *tlctx->get_this_thread_context(); + + if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) { + return false; + } + data.swap(cache_data.compiled_data_list); + kernel->set_from_offline_cache(); + return true; +} + +void KernelCodeGen::cache_module(const std::string &kernel_key, + const std::vector &data) { + get_llvm_program(prog)->cache_kernel(kernel_key, data, + infer_launch_args(kernel)); +} + ModuleToFunctionConverter::ModuleToFunctionConverter( TaichiLLVMContext *tlctx, LlvmRuntimeExecutor *executor) diff --git a/taichi/codegen/codegen.h b/taichi/codegen/codegen.h index bda8e0a203c5c..6dbeb21060637 100644 --- a/taichi/codegen/codegen.h +++ b/taichi/codegen/codegen.h @@ -28,12 +28,20 @@ class KernelCodeGen { Stmt *stmt = nullptr); virtual FunctionType codegen() = 0; + virtual bool supports_offline_cache() const { + return false; + } + #ifdef TI_WITH_LLVM virtual LLVMCompiledData modulegen( std::unique_ptr &&module = nullptr, OffloadedStmt *stmt = nullptr) { 
TI_NOT_IMPLEMENTED } + bool maybe_read_compilation_from_cache(const std::string &kernel_key, + std::vector &data); + void cache_module(const std::string &kernel_key, + const std::vector &data); #endif }; diff --git a/taichi/codegen/cpu/codegen_cpu.cpp b/taichi/codegen/cpu/codegen_cpu.cpp index 75428471246ab..245596cf57ce7 100644 --- a/taichi/codegen/cpu/codegen_cpu.cpp +++ b/taichi/codegen/cpu/codegen_cpu.cpp @@ -8,7 +8,9 @@ #include "taichi/ir/ir.h" #include "taichi/ir/statements.h" #include "taichi/util/statistics.h" - +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/analysis/offline_cache_util.h" TLANG_NAMESPACE_BEGIN namespace { @@ -22,10 +24,6 @@ class CodeGenLLVMCPU : public CodeGenLLVM { TI_AUTO_PROF } - bool supports_offline_cache() const override { - return true; - } - void create_offload_range_for(OffloadedStmt *stmt) override { int step = 1; @@ -221,16 +219,6 @@ class CodeGenLLVMCPU : public CodeGenLLVM { TI_NOT_IMPLEMENTED } } - - FunctionType gen() override { - auto compiled_res = run_compilation(); - - CPUModuleToFunctionConverter converter{ - tlctx, get_llvm_program(prog)->get_runtime_executor()}; - std::vector data; - data.push_back(std::move(compiled_res)); - return converter.convert(kernel, std::move(data)); - } }; } // namespace @@ -286,11 +274,45 @@ FunctionType CPUModuleToFunctionConverter::convert( }; } +LLVMCompiledData CodeGenCPU::modulegen(std::unique_ptr &&module, + OffloadedStmt *stmt) { + CodeGenLLVMCPU gen(kernel, stmt); + return gen.run_compilation(); +} #endif // TI_WITH_LLVM FunctionType CodeGenCPU::codegen() { TI_AUTO_PROF; - return CodeGenLLVMCPU(kernel, ir).gen(); -} + auto *llvm_prog = get_llvm_program(prog); + auto *tlctx = llvm_prog->get_llvm_context(kernel->arch); + auto &config = prog->config; + std::string kernel_key = get_hashed_offline_cache_key(&config, kernel); + kernel->set_kernel_key_for_cache(kernel_key); + if (config.offline_cache && !config.async_mode && + 
this->supports_offline_cache() && !kernel->is_evaluator) { + std::vector res; + const bool ok = maybe_read_compilation_from_cache(kernel_key, res); + if (ok) { + CPUModuleToFunctionConverter converter( + tlctx, get_llvm_program(prog)->get_runtime_executor()); + return converter.convert(kernel, std::move(res)); + } + } + if (!kernel->lowered()) { + kernel->lower(/*to_executable=*/false); + } + + CodeGenLLVMCPU gen(kernel, ir); + auto compiled_res = gen.run_compilation(); + CPUModuleToFunctionConverter converter{gen.tlctx, + llvm_prog->get_runtime_executor()}; + std::vector data_list; + data_list.push_back(std::move(compiled_res)); + if (!kernel->is_evaluator) { + cache_module(kernel_key, data_list); + } + + return converter.convert(this->kernel, std::move(data_list)); +} TLANG_NAMESPACE_END diff --git a/taichi/codegen/cpu/codegen_cpu.h b/taichi/codegen/cpu/codegen_cpu.h index 7aa0369a29d23..6cf2b728c8c02 100644 --- a/taichi/codegen/cpu/codegen_cpu.h +++ b/taichi/codegen/cpu/codegen_cpu.h @@ -20,6 +20,12 @@ class CodeGenCPU : public KernelCodeGen { IRNode *ir); #endif // TI_WITH_LLVM + bool supports_offline_cache() const override { + return true; + } + LLVMCompiledData modulegen(std::unique_ptr &&module = nullptr, + OffloadedStmt *stmt = nullptr) override; + FunctionType codegen() override; }; diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index b35017ef795c3..140ff8e740c4e 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -15,7 +15,7 @@ #include "taichi/rhi/cuda/cuda_context.h" #include "taichi/runtime/program_impls/llvm/llvm_program.h" #include "taichi/util/action_recorder.h" - +#include "taichi/analysis/offline_cache_util.h" TLANG_NAMESPACE_BEGIN using namespace llvm; @@ -31,21 +31,6 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { : CodeGenLLVM(kernel, ir) { } - bool supports_offline_cache() const override { - return true; - } - - FunctionType gen() override { - auto 
compiled_res = run_compilation(); - - auto *llvm_prog = get_llvm_program(kernel->program); - CUDAModuleToFunctionConverter converter{tlctx, - llvm_prog->get_runtime_executor()}; - std::vector data; - data.push_back(std::move(compiled_res)); - return converter.convert(this->kernel, std::move(data)); - } - llvm::Value *create_print(std::string tag, DataType dt, llvm::Value *value) override { @@ -737,7 +722,38 @@ static void set_arg_external_array(RuntimeContext *ctx, FunctionType CodeGenCUDA::codegen() { TI_AUTO_PROF - return CodeGenLLVMCUDA(kernel, ir).gen(); + // TODO: move the offline cache part to the base class + auto *llvm_prog = get_llvm_program(prog); + auto *tlctx = llvm_prog->get_llvm_context(kernel->arch); + auto &config = prog->config; + std::string kernel_key = get_hashed_offline_cache_key(&config, kernel); + kernel->set_kernel_key_for_cache(kernel_key); + if (config.offline_cache && !config.async_mode && + this->supports_offline_cache() && !kernel->is_evaluator) { + std::vector res; + const bool ok = maybe_read_compilation_from_cache(kernel_key, res); + if (ok) { + CUDAModuleToFunctionConverter converter( + tlctx, get_llvm_program(prog)->get_runtime_executor()); + return converter.convert(kernel, std::move(res)); + } + } + if (!kernel->lowered()) { + kernel->lower(/*to_executable=*/false); + } + + CodeGenLLVMCUDA gen(kernel, ir); + auto compiled_res = gen.run_compilation(); + + CUDAModuleToFunctionConverter converter{gen.tlctx, + llvm_prog->get_runtime_executor()}; + std::vector data_list; + data_list.push_back(std::move(compiled_res)); + if (!kernel->is_evaluator) { + cache_module(kernel_key, data_list); + } + + return converter.convert(this->kernel, std::move(data_list)); } FunctionType CUDAModuleToFunctionConverter::convert( diff --git a/taichi/codegen/cuda/codegen_cuda.h b/taichi/codegen/cuda/codegen_cuda.h index ab1a47f4baceb..4cac2d4f40227 100644 --- a/taichi/codegen/cuda/codegen_cuda.h +++ b/taichi/codegen/cuda/codegen_cuda.h @@ -19,6 +19,10 @@ 
class CodeGenCUDA : public KernelCodeGen { IRNode *ir); #endif // TI_WITH_LLVM + bool supports_offline_cache() const override { + return true; + } + FunctionType codegen() override; }; diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 1b033660ef6a1..f8e7cbc5d13a3 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -2547,58 +2547,15 @@ void CodeGenLLVM::emit_to_module() { } LLVMCompiledData CodeGenLLVM::run_compilation() { - const auto &config = prog->config; - std::string kernel_key = - get_hashed_offline_cache_key(&kernel->program->config, kernel); - kernel->set_kernel_key_for_cache(kernel_key); - if (config.offline_cache && !config.async_mode && - this->supports_offline_cache() && !kernel->is_evaluator) { - LLVMCompiledData res; - const bool ok = maybe_read_compilation_from_cache(kernel_key, &res); - if (ok) { - return res; - } - } + // Final lowering + + auto config = kernel->program->config; + kernel->offload_to_executable(ir); - if (!kernel->lowered()) { - kernel->lower(); - } emit_to_module(); eliminate_unused_functions(); - // Updates LlvmProgramImpl->cache_data_ to save the compiled kernel - // information for successive uses in AOT or CGraph. 
- if (!kernel->is_evaluator) { - cache_module(kernel_key); - } - - LLVMCompiledData res; - res.tasks = std::move(this->offloaded_tasks); - res.module = std::move(this->module); - return res; -} - -bool CodeGenLLVM::maybe_read_compilation_from_cache( - const std::string &kernel_key, - LLVMCompiledData *data) { - const auto &config = prog->config; - auto reader = - LlvmOfflineCacheFileReader::make(config.offline_cache_file_path); - if (!reader) { - return false; - } - - LlvmOfflineCache::KernelCacheData cache_data; - auto *tlctx = get_llvm_program(prog)->get_llvm_context(config.arch); - auto &llvm_ctx = *tlctx->get_this_thread_context(); - - if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) { - return false; - } - data->tasks = std::move(cache_data.compiled_data_list[0].tasks); - data->module = std::move(cache_data.compiled_data_list[0].module); - kernel->set_from_offline_cache(); - return true; + return {std::move(this->offloaded_tasks), std::move(this->module)}; } llvm::Value *CodeGenLLVM::create_xlogue(std::unique_ptr &block) { @@ -2672,13 +2629,6 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) { } } -void CodeGenLLVM::cache_module(const std::string &kernel_key) { - std::vector data; - data.emplace_back(offloaded_tasks, llvm::CloneModule(*module)); - get_llvm_program(prog)->cache_kernel(kernel_key, data, - infer_launch_args(kernel)); -} - LLVMCompiledData LLVMCompiledData::clone() const { return {tasks, llvm::CloneModule(*module)}; } diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index 3b2328e7b594e..6c05a6f5e117f 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -139,15 +139,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { * * @return LLVMCompiledData */ - LLVMCompiledData run_compilation(); - - // TODO: This function relies largely on `run_compilation()`. Name it better. 
- virtual FunctionType gen(){TI_NOT_IMPLEMENTED}; - - virtual bool supports_offline_cache() const { - return false; - } - + virtual LLVMCompiledData run_compilation(); // For debugging only virtual llvm::Value *create_print(std::string tag, DataType dt, @@ -429,12 +421,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type); ~CodeGenLLVM() override = default; - - private: - bool maybe_read_compilation_from_cache(const std::string &kernel_key, - LLVMCompiledData *data); - - void cache_module(const std::string &kernel_key); }; } // namespace lang diff --git a/taichi/codegen/wasm/codegen_wasm.cpp b/taichi/codegen/wasm/codegen_wasm.cpp index 46c7f29df65e7..0b92c03cd571f 100644 --- a/taichi/codegen/wasm/codegen_wasm.cpp +++ b/taichi/codegen/wasm/codegen_wasm.cpp @@ -215,8 +215,7 @@ class CodeGenLLVMWASM : public CodeGenLLVM { TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs())); } - FunctionType gen() override { - TI_AUTO_PROF + LLVMCompiledData run_compilation() override { // lower kernel if (!kernel->lowered()) { kernel->lower(); @@ -236,19 +235,24 @@ class CodeGenLLVMWASM : public CodeGenLLVM { } return func_name == offloaded_task_name; }); - tlctx->add_module(std::move(module)); - auto kernel_symbol = tlctx->lookup_function_pointer(offloaded_task_name); - return [=](RuntimeContext &context) { - TI_TRACE("Launching Taichi Kernel Function"); - auto func = (int32(*)(void *))kernel_symbol; - func(&context); - }; + LLVMCompiledData res; + res.tasks.emplace_back(offloaded_task_name); + res.module = std::move(this->module); + return res; } }; FunctionType CodeGenWASM::codegen() { TI_AUTO_PROF - return CodeGenLLVMWASM(kernel, ir).gen(); + CodeGenLLVMWASM gen(kernel, ir); + auto res = gen.run_compilation(); + gen.tlctx->add_module(std::move(res.module)); + auto kernel_symbol = gen.tlctx->lookup_function_pointer(res.tasks[0].name); + return [=](RuntimeContext &context) { + TI_TRACE("Launching Taichi 
Kernel Function"); + auto func = (int32(*)(void *))kernel_symbol; + func(&context); + }; } LLVMCompiledData CodeGenWASM::modulegen(std::unique_ptr &&module, diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index a608315d40c5f..92e050af1a57f 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -454,4 +454,20 @@ bool Kernel::supports_lowering(Arch arch) { return arch_is_cpu(arch) || (arch == Arch::cuda) || (arch == Arch::metal); } +void Kernel::offload_to_executable(IRNode *stmt) { + CurrentCallableGuard _(program, this); + auto config = program->config; + bool verbose = config.print_ir; + if ((is_accessor && !config.print_accessor_ir) || + (is_evaluator && !config.print_evaluator_ir)) + verbose = false; + irpass::offload_to_executable( + stmt, config, this, verbose, + /*determine_ad_stack_size=*/autodiff_mode == AutodiffMode::kReverse, + /*lower_global_access=*/true, + /*make_thread_local=*/config.make_thread_local, + /*make_block_local=*/ + is_extension_supported(config.arch, Extension::bls) && + config.make_block_local); +} TLANG_NAMESPACE_END diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index ffac88a969a57..3311e5e22ec43 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -148,6 +148,7 @@ class TI_DLL_EXPORT Kernel : public Callable { const std::string &get_cached_kernel_key() { return kernel_key_; } + void offload_to_executable(IRNode *stmt); private: void init(Program &program,