
Commit 25fb9ce

oneline -> twolines
PGZXB committed Apr 21, 2023
1 parent 4233048 commit 25fb9ce
Showing 5 changed files with 71 additions and 64 deletions.
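
Every file below gets the same mechanical change: a one-line call that feeds the result of Program::compile_kernel straight into Program::launch_kernel is split into two lines, with the compiled kernel data first bound to a named const reference. A minimal sketch of the before/after shape, where program, config, kernel, and ctx are placeholders standing in for the concrete objects at each call site:

// Before ("oneline"): the result of compile_kernel is consumed inline.
program.launch_kernel(
    program.compile_kernel(config, program.get_device_caps(), *kernel), ctx);

// After ("twolines"): bind the compiled kernel data to a named reference,
// then pass it to launch_kernel.
const auto &compiled_kernel_data =
    program.compile_kernel(config, program.get_device_caps(), *kernel);
program.launch_kernel(compiled_kernel_data, ctx);

The launch semantics are unchanged; only the shape of the call sites differs.
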
33 changes: 20 additions & 13 deletions cpp_examples/autograd.cpp
@@ -195,19 +195,26 @@ void autograd() {
   ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n,
                                             {n});
 
-  program.launch_kernel(
-      program.compile_kernel(config, program.get_device_caps(), *kernel_init),
-      ctx_init);
-  program.launch_kernel(program.compile_kernel(
-                            config, program.get_device_caps(), *kernel_forward),
-                        ctx_forward);
-  program.launch_kernel(
-      program.compile_kernel(config, program.get_device_caps(),
-                             *kernel_backward),
-      ctx_backward);
-  program.launch_kernel(
-      program.compile_kernel(config, program.get_device_caps(), *kernel_ext),
-      ctx_ext);
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_init);
+    program.launch_kernel(compiled_kernel_data, ctx_init);
+  }
+  {
+    const auto &compiled_kernel_data = program.compile_kernel(
+        config, program.get_device_caps(), *kernel_forward);
+    program.launch_kernel(compiled_kernel_data, ctx_forward);
+  }
+  {
+    const auto &compiled_kernel_data = program.compile_kernel(
+        config, program.get_device_caps(), *kernel_backward);
+    program.launch_kernel(compiled_kernel_data, ctx_backward);
+  }
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_ext);
+    program.launch_kernel(compiled_kernel_data, ctx_ext);
+  }
   for (int i = 0; i < n; i++)
     std::cout << ext_a[i] << " ";
   std::cout << std::endl;
32 changes: 19 additions & 13 deletions cpp_examples/run_snode.cpp
@@ -134,17 +134,23 @@ void run_snode() {
   ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()),
                                             n, {n});
 
-  program.launch_kernel(
-      program.compile_kernel(config, program.get_device_caps(), *kernel_init),
-      ctx_init);
-  program.launch_kernel(
-      program.compile_kernel(config, program.get_device_caps(), *kernel_ret),
-      ctx_ret);
-  std::cout << program.fetch_result<int>(0) << std::endl;
-  program.launch_kernel(
-      program.compile_kernel(config, program.get_device_caps(), *kernel_ext),
-      ctx_ext);
-  for (int i = 0; i < n; i++)
-    std::cout << ext_arr[i] << " ";
-  std::cout << std::endl;
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_init);
+    program.launch_kernel(compiled_kernel_data, ctx_init);
+  }
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_ret);
+    program.launch_kernel(compiled_kernel_data, ctx_ret);
+    std::cout << program.fetch_result<int>(0) << std::endl;
+  }
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_ext);
+    program.launch_kernel(compiled_kernel_data, ctx_ext);
+    for (int i = 0; i < n; i++)
+      std::cout << ext_arr[i] << " ";
+    std::cout << std::endl;
+  }
 }
4 changes: 2 additions & 2 deletions taichi/aot/graph_data.cpp
@@ -31,9 +31,9 @@ void CompiledGraph::jit_run(
     // worry that the kernels dispatched by this cgraph will be compiled
     // repeatedly.
     auto *prog = dispatch.ti_kernel->program;
-    const auto &ckd = prog->compile_kernel(
+    const auto &compiled_kernel_data = prog->compile_kernel(
         compile_config, prog->get_device_caps(), *dispatch.ti_kernel);
-    prog->launch_kernel(ckd, launch_ctx);
+    prog->launch_kernel(compiled_kernel_data, launch_ctx);
   }
 }
 
42 changes: 18 additions & 24 deletions taichi/program/snode_rw_accessors_bank.cpp
@@ -41,20 +41,18 @@ void SNodeRwAccessorsBank::Accessors::write_float(const std::vector<int> &I,
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
   launch_ctx.set_arg_float(snode_->num_active_indices, val);
   prog_->synchronize();
-  prog_->launch_kernel(
-      prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(),
-                            *writer_),
-      launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *writer_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
 }
 
 float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector<int> &I) {
   prog_->synchronize();
   auto launch_ctx = reader_->make_launch_context();
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
-  prog_->launch_kernel(
-      prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(),
-                            *reader_),
-      launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *reader_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
   prog_->synchronize();
   if (arch_uses_llvm(prog_->compile_config().arch)) {
     return launch_ctx.get_struct_ret_float({0});
@@ -70,10 +68,9 @@ void SNodeRwAccessorsBank::Accessors::write_int(const std::vector<int> &I,
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
   launch_ctx.set_arg_int(snode_->num_active_indices, val);
   prog_->synchronize();
-  prog_->launch_kernel(
-      prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(),
-                            *writer_),
-      launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *writer_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
 }
 
 // for int32 and int64
@@ -83,20 +80,18 @@ void SNodeRwAccessorsBank::Accessors::write_uint(const std::vector<int> &I,
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
   launch_ctx.set_arg_uint(snode_->num_active_indices, val);
   prog_->synchronize();
-  prog_->launch_kernel(
-      prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(),
-                            *writer_),
-      launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *writer_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
 }
 
 int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector<int> &I) {
   prog_->synchronize();
   auto launch_ctx = reader_->make_launch_context();
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
-  prog_->launch_kernel(
-      prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(),
-                            *reader_),
-      launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *reader_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
   prog_->synchronize();
   if (arch_uses_llvm(prog_->compile_config().arch)) {
     return launch_ctx.get_struct_ret_int({0});
@@ -109,10 +104,9 @@ uint64 SNodeRwAccessorsBank::Accessors::read_uint(const std::vector<int> &I) {
   prog_->synchronize();
   auto launch_ctx = reader_->make_launch_context();
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
-  prog_->launch_kernel(
-      prog_->compile_kernel(prog_->compile_config(), prog_->get_device_caps(),
-                            *reader_),
-      launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *reader_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
   prog_->synchronize();
   if (arch_uses_llvm(prog_->compile_config().arch)) {
     return launch_ctx.get_struct_ret_uint({0});
24 changes: 12 additions & 12 deletions tests/cpp/ir/ir_builder_test.cpp
@@ -117,9 +117,9 @@ TEST(IRBuilder, ExternalPtr) {
   launch_ctx.set_arg_external_array_with_shape(
       /*arg_id=*/0, (uint64)array.get(), size, {size});
   auto *prog = test_prog.prog();
-  prog->launch_kernel(prog->compile_kernel(prog->compile_config(),
-                                           prog->get_device_caps(), *ker),
-                      launch_ctx);
+  const auto &compiled_kernel_data = prog->compile_kernel(
+      prog->compile_config(), prog->get_device_caps(), *ker);
+  prog->launch_kernel(compiled_kernel_data, launch_ctx);
   EXPECT_EQ(array[0], 2);
   EXPECT_EQ(array[1], 1);
   EXPECT_EQ(array[2], 42);
@@ -145,9 +145,9 @@ TEST(IRBuilder, Ndarray) {
   auto ker1 = setup_kernel1(test_prog.prog());
   auto launch_ctx1 = ker1->make_launch_context();
   launch_ctx1.set_arg_ndarray(/*arg_id=*/0, array);
-  prog->launch_kernel(prog->compile_kernel(prog->compile_config(),
-                                           prog->get_device_caps(), *ker1),
-                      launch_ctx1);
+  const auto &compiled_kernel_data = prog->compile_kernel(
+      prog->compile_config(), prog->get_device_caps(), *ker1);
+  prog->launch_kernel(compiled_kernel_data, launch_ctx1);
   EXPECT_EQ(array.read_int({0}), 2);
   EXPECT_EQ(array.read_int({1}), 1);
   EXPECT_EQ(array.read_int({2}), 42);
@@ -156,9 +156,9 @@
   auto launch_ctx2 = ker2->make_launch_context();
   launch_ctx2.set_arg_ndarray(/*arg_id=*/0, array);
   launch_ctx2.set_arg_int(/*arg_id=*/1, 3);
-  prog->launch_kernel(prog->compile_kernel(prog->compile_config(),
-                                           prog->get_device_caps(), *ker2),
-                      launch_ctx2);
+  const auto &compiled_kernel_data2 = prog->compile_kernel(
+      prog->compile_config(), prog->get_device_caps(), *ker2);
+  prog->launch_kernel(compiled_kernel_data2, launch_ctx2);
   EXPECT_EQ(array.read_int({0}), 2);
   EXPECT_EQ(array.read_int({1}), 3);
   EXPECT_EQ(array.read_int({2}), 42);
@@ -187,9 +187,9 @@ TEST(IRBuilder, AtomicOp) {
   launch_ctx.set_arg_external_array_with_shape(
       /*arg_id=*/0, (uint64)array.get(), size, {size});
   auto *prog = test_prog.prog();
-  prog->launch_kernel(prog->compile_kernel(prog->compile_config(),
-                                           prog->get_device_caps(), *ker),
-                      launch_ctx);
+  const auto &compiled_kernel_data = prog->compile_kernel(
+      prog->compile_config(), prog->get_device_caps(), *ker);
+  prog->launch_kernel(compiled_kernel_data, launch_ctx);
   EXPECT_EQ(array[0], 3);
 }
 } // namespace taichi::lang
