Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Separate SNode read/write kernels into a dedicated class #2205

Merged
merged 3 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 0 additions & 69 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@

TLANG_NAMESPACE_BEGIN

namespace {

void set_kernel_args(const std::vector<int> &I,
int num_active_indices,
Kernel::LaunchContextBuilder *launch_ctx) {
for (int i = 0; i < num_active_indices; i++) {
launch_ctx->set_arg_int(i, I[i]);
}
}

} // namespace

std::atomic<int> SNode::counter{0};

SNode &SNode::insert_children(SNodeType t) {
Expand Down Expand Up @@ -197,63 +185,6 @@ SNode *SNode::get_least_sparse_ancestor() const {
return result;
}

// for float and double
void SNode::write_float(const std::vector<int> &I, float64 val) {
if (writer_kernel == nullptr) {
writer_kernel = &get_current_program().get_snode_writer(this);
}
auto launch_ctx = writer_kernel->make_launch_context();
set_kernel_args(I, num_active_indices, &launch_ctx);
for (int i = 0; i < num_active_indices; i++) {
launch_ctx.set_arg_int(i, I[i]);
}
launch_ctx.set_arg_float(num_active_indices, val);
get_current_program().synchronize();
(*writer_kernel)(launch_ctx);
}

float64 SNode::read_float(const std::vector<int> &I) {
if (reader_kernel == nullptr) {
reader_kernel = &get_current_program().get_snode_reader(this);
}
get_current_program().synchronize();
auto launch_ctx = reader_kernel->make_launch_context();
set_kernel_args(I, num_active_indices, &launch_ctx);
(*reader_kernel)(launch_ctx);
get_current_program().synchronize();
auto ret = reader_kernel->get_ret_float(0);
return ret;
}

// for int32 and int64
void SNode::write_int(const std::vector<int> &I, int64 val) {
if (writer_kernel == nullptr) {
writer_kernel = &get_current_program().get_snode_writer(this);
}
auto launch_ctx = writer_kernel->make_launch_context();
set_kernel_args(I, num_active_indices, &launch_ctx);
launch_ctx.set_arg_int(num_active_indices, val);
get_current_program().synchronize();
(*writer_kernel)(launch_ctx);
}

int64 SNode::read_int(const std::vector<int> &I) {
if (reader_kernel == nullptr) {
reader_kernel = &get_current_program().get_snode_reader(this);
}
get_current_program().synchronize();
auto launch_ctx = reader_kernel->make_launch_context();
set_kernel_args(I, num_active_indices, &launch_ctx);
(*reader_kernel)(launch_ctx);
get_current_program().synchronize();
auto ret = reader_kernel->get_ret_int(0);
return ret;
}

uint64 SNode::read_uint(const std::vector<int> &I) {
return (uint64)read_int(I);
}

int SNode::shape_along_axis(int i) const {
const auto &extractor = extractors[physical_index_position[i]];
return extractor.num_elements * (1 << extractor.trailing_bits);
Expand Down
9 changes: 0 additions & 9 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,6 @@ class SNode {
return *this;
}

// for float and double
void write_float(const std::vector<int> &I, float64);
float64 read_float(const std::vector<int> &I);

// for int32 and int64
void write_int(const std::vector<int> &I, int64);
int64 read_int(const std::vector<int> &I);
uint64 read_uint(const std::vector<int> &I);

int child_id(SNode *c) {
for (int i = 0; i < (int)ch.size(); i++) {
if (ch[i].get() == c) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ inline uint64 *allocate_result_buffer_default(Program *prog) {
Program *current_program = nullptr;
std::atomic<int> Program::num_instances;

Program::Program(Arch desired_arch) {
Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) {
TI_TRACE("Program initializing...");

// For performance considerations and correctness of CustomFloatType
Expand Down
6 changes: 6 additions & 0 deletions taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "taichi/backends/cc/cc_program.h"
#include "taichi/program/kernel.h"
#include "taichi/program/kernel_profiler.h"
#include "taichi/program/snode_rw_accessors_bank.h"
#include "taichi/program/context.h"
#include "taichi/runtime/runtime.h"
#include "taichi/backends/metal/struct_metal.h"
Expand Down Expand Up @@ -275,13 +276,18 @@ class Program {

~Program();

inline SNodeRwAccessorsBank &get_snode_rw_accessors_bank() {
return snode_rw_accessors_bank_;
}

private:
// Metal related data structures
std::optional<metal::CompiledStructs> metal_compiled_structs_;
std::unique_ptr<metal::KernelManager> metal_kernel_mgr_;
// OpenGL related data structures
std::optional<opengl::StructCompiledResult> opengl_struct_compiled_;
std::unique_ptr<opengl::GLSLLauncher> opengl_kernel_launcher_;
SNodeRwAccessorsBank snode_rw_accessors_bank_;

public:
#ifdef TI_WITH_CC
Expand Down
86 changes: 86 additions & 0 deletions taichi/program/snode_rw_accessors_bank.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "taichi/program/snode_rw_accessors_bank.h"

#include "taichi/program/program.h"

namespace taichi {
namespace lang {

namespace {
void set_kernel_args(const std::vector<int> &I,
int num_active_indices,
Kernel::LaunchContextBuilder *launch_ctx) {
for (int i = 0; i < num_active_indices; i++) {
launch_ctx->set_arg_int(i, I[i]);
}
}
} // namespace

SNodeRwAccessorsBank::Accessors SNodeRwAccessorsBank::get(SNode *snode) {
auto &kernels = snode_to_kernels_[snode];
if (kernels.reader == nullptr) {
kernels.reader = &(program_->get_snode_reader(snode));
}
if (kernels.writer == nullptr) {
kernels.writer = &(program_->get_snode_writer(snode));
}
return Accessors(snode, kernels, program_);
}

SNodeRwAccessorsBank::Accessors::Accessors(const SNode *snode,
const RwKernels &kernels,
Program *prog)
: snode_(snode),
prog_(prog),
reader_(kernels.reader),
writer_(kernels.writer) {
TI_ASSERT(reader_ != nullptr);
TI_ASSERT(writer_ != nullptr);
}
void SNodeRwAccessorsBank::Accessors::write_float(const std::vector<int> &I,
float64 val) {
auto launch_ctx = writer_->make_launch_context();
set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
for (int i = 0; i < snode_->num_active_indices; i++) {
launch_ctx.set_arg_int(i, I[i]);
}
launch_ctx.set_arg_float(snode_->num_active_indices, val);
prog_->synchronize();
(*writer_)(launch_ctx);
}

float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector<int> &I) {
prog_->synchronize();
auto launch_ctx = reader_->make_launch_context();
set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
(*reader_)(launch_ctx);
prog_->synchronize();
auto ret = reader_->get_ret_float(0);
return ret;
}

// for int32 and int64
void SNodeRwAccessorsBank::Accessors::write_int(const std::vector<int> &I,
int64 val) {
auto launch_ctx = writer_->make_launch_context();
set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
launch_ctx.set_arg_int(snode_->num_active_indices, val);
prog_->synchronize();
(*writer_)(launch_ctx);
}

int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector<int> &I) {
prog_->synchronize();
auto launch_ctx = reader_->make_launch_context();
set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
(*reader_)(launch_ctx);
prog_->synchronize();
auto ret = reader_->get_ret_int(0);
return ret;
}

uint64 SNodeRwAccessorsBank::Accessors::read_uint(const std::vector<int> &I) {
return (uint64)read_int(I);
}

} // namespace lang
} // namespace taichi
60 changes: 60 additions & 0 deletions taichi/program/snode_rw_accessors_bank.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <unordered_map>

#include "taichi/program/kernel.h"
#include "taichi/ir/snode.h"

namespace taichi {
namespace lang {

class Program;

/** A mapping from an SNode to its read/write access kernels.
*
* The main purpose of this class is to decouple the accessor kernels from the
* SNode class itself. Ideally, SNode should be nothing more than a group of
* plain data.
*/
class SNodeRwAccessorsBank {
private:
struct RwKernels {
Kernel *reader{nullptr};
Kernel *writer{nullptr};
};

public:
class Accessors {
public:
explicit Accessors(const SNode *snode,
const RwKernels &kernels,
Program *prog);

// for float and double
void write_float(const std::vector<int> &I, float64 val);
float64 read_float(const std::vector<int> &I);

// for int32 and int64
void write_int(const std::vector<int> &I, int64 val);
int64 read_int(const std::vector<int> &I);
uint64 read_uint(const std::vector<int> &I);

private:
const SNode *snode_;
Program *prog_;
Kernel *reader_;
Kernel *writer_;
};

explicit SNodeRwAccessorsBank(Program *program) : program_(program) {
}

Accessors get(SNode *snode);

private:
Program *const program_;
std::unordered_map<const SNode *, RwKernels> snode_to_kernels_;
};

} // namespace lang
} // namespace taichi
30 changes: 25 additions & 5 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "taichi/ir/statements.h"
#include "taichi/program/extension.h"
#include "taichi/program/async_engine.h"
#include "taichi/program/snode_rw_accessors_bank.h"
#include "taichi/common/interface.h"
#include "taichi/python/export.h"
#include "taichi/gui/gui.h"
Expand Down Expand Up @@ -59,6 +60,10 @@ void compile_runtimes();
std::string libdevice_path();
std::string get_runtime_dir();

SNodeRwAccessorsBank::Accessors get_snode_rw_accessors(SNode *snode) {
return get_current_program().get_snode_rw_accessors_bank().get(snode);
}

TLANG_NAMESPACE_END

TI_NAMESPACE_BEGIN
Expand Down Expand Up @@ -268,15 +273,30 @@ void export_lang(py::module &m) {
[](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); },
py::return_value_policy::reference)
.def("lazy_grad", &SNode::lazy_grad)
.def("read_int", &SNode::read_int)
.def("read_uint", &SNode::read_uint)
.def("read_float", &SNode::read_float)
.def("read_int",
[](SNode *snode, const std::vector<int> &I) -> int64 {
return get_snode_rw_accessors(snode).read_int(I);
})
.def("read_uint",
[](SNode *snode, const std::vector<int> &I) -> uint64 {
return get_snode_rw_accessors(snode).read_uint(I);
})
.def("read_float",
[](SNode *snode, const std::vector<int> &I) -> float64 {
return get_snode_rw_accessors(snode).read_float(I);
})
.def("has_grad", &SNode::has_grad)
.def("is_primal", &SNode::is_primal)
.def("is_place", &SNode::is_place)
.def("get_expr", &SNode::get_expr, py::return_value_policy::reference)
.def("write_int", &SNode::write_int)
.def("write_float", &SNode::write_float)
.def("write_int",
[](SNode *snode, const std::vector<int> &I, int64 val) {
get_snode_rw_accessors(snode).write_int(I, val);
})
.def("write_float",
[](SNode *snode, const std::vector<int> &I, float64 val) {
get_snode_rw_accessors(snode).write_float(I, val);
})
.def("get_shape_along_axis", &SNode::shape_along_axis)
.def("get_physical_index_position",
[](SNode *snode) {
Expand Down