From 25fb9ce9a3c51f29907f541f896cadc4bee37efb Mon Sep 17 00:00:00 2001 From: PGZXB Date: Fri, 21 Apr 2023 10:59:51 +0800 Subject: [PATCH] oneline -> twolines --- cpp_examples/autograd.cpp | 33 ++++++++++------- cpp_examples/run_snode.cpp | 32 ++++++++++------- taichi/aot/graph_data.cpp | 4 +-- taichi/program/snode_rw_accessors_bank.cpp | 42 ++++++++++------------ tests/cpp/ir/ir_builder_test.cpp | 24 ++++++------- 5 files changed, 71 insertions(+), 64 deletions(-) diff --git a/cpp_examples/autograd.cpp b/cpp_examples/autograd.cpp index 947b957e862fb..b03a684fc8d45 100644 --- a/cpp_examples/autograd.cpp +++ b/cpp_examples/autograd.cpp @@ -195,19 +195,26 @@ void autograd() { ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n, {n}); - program.launch_kernel( - program.compile_kernel(config, program.get_device_caps(), *kernel_init), - ctx_init); - program.launch_kernel(program.compile_kernel( - config, program.get_device_caps(), *kernel_forward), - ctx_forward); - program.launch_kernel( - program.compile_kernel(config, program.get_device_caps(), - *kernel_backward), - ctx_backward); - program.launch_kernel( - program.compile_kernel(config, program.get_device_caps(), *kernel_ext), - ctx_ext); + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_init); + program.launch_kernel(compiled_kernel_data, ctx_init); + } + { + const auto &compiled_kernel_data = program.compile_kernel( + config, program.get_device_caps(), *kernel_forward); + program.launch_kernel(compiled_kernel_data, ctx_forward); + } + { + const auto &compiled_kernel_data = program.compile_kernel( + config, program.get_device_caps(), *kernel_backward); + program.launch_kernel(compiled_kernel_data, ctx_backward); + } + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_ext); + program.launch_kernel(compiled_kernel_data, ctx_ext); + } for (int i = 0; i < n; i++) std::cout << ext_a[i] << " "; std::cout << std::endl; diff --git a/cpp_examples/run_snode.cpp b/cpp_examples/run_snode.cpp index fbc97d2600ea1..12a249ac335c5 100644 --- a/cpp_examples/run_snode.cpp +++ b/cpp_examples/run_snode.cpp @@ -134,17 +134,23 @@ void run_snode() { ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()), n, {n}); - program.launch_kernel( - program.compile_kernel(config, program.get_device_caps(), *kernel_init), - ctx_init); - program.launch_kernel( - program.compile_kernel(config, program.get_device_caps(), *kernel_ret), - ctx_ret); - std::cout << program.fetch_result(0) << std::endl; - program.launch_kernel( - program.compile_kernel(config, program.get_device_caps(), *kernel_ext), - ctx_ext); - for (int i = 0; i < n; i++) - std::cout << ext_arr[i] << " "; - std::cout << std::endl; + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_init); + program.launch_kernel(compiled_kernel_data, ctx_init); + } + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_ret); + program.launch_kernel(compiled_kernel_data, ctx_ret); + std::cout << program.fetch_result(0) << std::endl; + } + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_ext); + program.launch_kernel(compiled_kernel_data, ctx_ext); + for (int i = 0; i < n; i++) + std::cout << ext_arr[i] << " "; + std::cout << std::endl; + } } diff --git a/taichi/aot/graph_data.cpp b/taichi/aot/graph_data.cpp index 2332c24c8b668..2423f864e3cc9 100644 --- a/taichi/aot/graph_data.cpp +++ b/taichi/aot/graph_data.cpp @@ -31,9 +31,9 @@ void CompiledGraph::jit_run( // worry that the kernels dispatched by this cgraph will be compiled // repeatedly. auto *prog = dispatch.ti_kernel->program; - const auto &ckd = prog->compile_kernel( + const auto &compiled_kernel_data = prog->compile_kernel( compile_config, prog->get_device_caps(), *dispatch.ti_kernel); - prog->launch_kernel(ckd, launch_ctx); + prog->launch_kernel(compiled_kernel_data, launch_ctx); } } diff --git a/taichi/program/snode_rw_accessors_bank.cpp b/taichi/program/snode_rw_accessors_bank.cpp index 932f7046d989f..15c5fa3296f67 100644 --- a/taichi/program/snode_rw_accessors_bank.cpp +++ b/taichi/program/snode_rw_accessors_bank.cpp @@ -41,20 +41,18 @@ void SNodeRwAccessorsBank::Accessors::write_float(const std::vector &I, set_kernel_args(I, snode_->num_active_indices, &launch_ctx); launch_ctx.set_arg_float(snode_->num_active_indices, val); prog_->synchronize(); - prog_->launch_kernel( - prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(), - *writer_), - launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *writer_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); } float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector &I) { prog_->synchronize(); auto launch_ctx = reader_->make_launch_context(); set_kernel_args(I, snode_->num_active_indices, &launch_ctx); - prog_->launch_kernel( - prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(), - *reader_), - launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *reader_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); prog_->synchronize(); if (arch_uses_llvm(prog_->compile_config().arch)) { return launch_ctx.get_struct_ret_float({0}); @@ -70,10 +68,9 @@ void SNodeRwAccessorsBank::Accessors::write_int(const std::vector &I, set_kernel_args(I, snode_->num_active_indices, &launch_ctx); launch_ctx.set_arg_int(snode_->num_active_indices, val); prog_->synchronize(); - prog_->launch_kernel( - prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(), - *writer_), - launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *writer_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); } // for int32 and int64 @@ -83,20 +80,18 @@ void SNodeRwAccessorsBank::Accessors::write_uint(const std::vector &I, set_kernel_args(I, snode_->num_active_indices, &launch_ctx); launch_ctx.set_arg_uint(snode_->num_active_indices, val); prog_->synchronize(); - prog_->launch_kernel( - prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(), - *writer_), - launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *writer_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); } int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector &I) { prog_->synchronize(); auto launch_ctx = reader_->make_launch_context(); set_kernel_args(I, snode_->num_active_indices, &launch_ctx); - prog_->launch_kernel( - prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(), - *reader_), - launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *reader_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); prog_->synchronize(); if (arch_uses_llvm(prog_->compile_config().arch)) { return launch_ctx.get_struct_ret_int({0}); @@ -109,10 +104,9 @@ uint64 SNodeRwAccessorsBank::Accessors::read_uint(const std::vector &I) { prog_->synchronize(); auto launch_ctx = reader_->make_launch_context(); set_kernel_args(I, snode_->num_active_indices, &launch_ctx); - prog_->launch_kernel( - prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(), - *reader_), - launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *reader_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); prog_->synchronize(); if (arch_uses_llvm(prog_->compile_config().arch)) { return launch_ctx.get_struct_ret_uint({0}); diff --git a/tests/cpp/ir/ir_builder_test.cpp b/tests/cpp/ir/ir_builder_test.cpp index 28f3419ec64af..ac2c6624dc928 100644 --- a/tests/cpp/ir/ir_builder_test.cpp +++ b/tests/cpp/ir/ir_builder_test.cpp @@ -117,9 +117,9 @@ TEST(IRBuilder, ExternalPtr) { launch_ctx.set_arg_external_array_with_shape( /*arg_id=*/0, (uint64)array.get(), size, {size}); auto *prog = test_prog.prog(); - prog->launch_kernel(prog->compile_kernel(prog->compile_config(), - prog->get_device_caps(), *ker), - launch_ctx); + const auto &compiled_kernel_data = prog->compile_kernel( + prog->compile_config(), prog->get_device_caps(), *ker); + prog->launch_kernel(compiled_kernel_data, launch_ctx); EXPECT_EQ(array[0], 2); EXPECT_EQ(array[1], 1); EXPECT_EQ(array[2], 42); @@ -145,9 +145,9 @@ TEST(IRBuilder, Ndarray) { auto ker1 = setup_kernel1(test_prog.prog()); auto launch_ctx1 = ker1->make_launch_context(); launch_ctx1.set_arg_ndarray(/*arg_id=*/0, array); - prog->launch_kernel(prog->compile_kernel(prog->compile_config(), - prog->get_device_caps(), *ker1), - launch_ctx1); + const auto &compiled_kernel_data = prog->compile_kernel( + prog->compile_config(), prog->get_device_caps(), *ker1); + prog->launch_kernel(compiled_kernel_data, launch_ctx1); EXPECT_EQ(array.read_int({0}), 2); EXPECT_EQ(array.read_int({1}), 1); EXPECT_EQ(array.read_int({2}), 42); @@ -156,9 +156,9 @@ TEST(IRBuilder, Ndarray) { auto launch_ctx2 = ker2->make_launch_context(); launch_ctx2.set_arg_ndarray(/*arg_id=*/0, array); launch_ctx2.set_arg_int(/*arg_id=*/1, 3); - prog->launch_kernel(prog->compile_kernel(prog->compile_config(), - prog->get_device_caps(), *ker2), - launch_ctx2); + const auto &compiled_kernel_data2 = prog->compile_kernel( + prog->compile_config(), prog->get_device_caps(), *ker2); + prog->launch_kernel(compiled_kernel_data2, launch_ctx2); EXPECT_EQ(array.read_int({0}), 2); EXPECT_EQ(array.read_int({1}), 3); EXPECT_EQ(array.read_int({2}), 42); @@ -187,9 +187,9 @@ TEST(IRBuilder, AtomicOp) { launch_ctx.set_arg_external_array_with_shape( /*arg_id=*/0, (uint64)array.get(), size, {size}); auto *prog = test_prog.prog(); - prog->launch_kernel(prog->compile_kernel(prog->compile_config(), - prog->get_device_caps(), *ker), - launch_ctx); + const auto &compiled_kernel_data = prog->compile_kernel( + prog->compile_config(), prog->get_device_caps(), *ker); + prog->launch_kernel(compiled_kernel_data, launch_ctx); EXPECT_EQ(array[0], 3); } } // namespace taichi::lang