From ee4090d9ecc684832e04e54ed693fea0e8dadc3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E4=BA=8E=E6=96=8C?= <1931127624@qq.com> Date: Sun, 15 Mar 2020 00:30:49 +0800 Subject: [PATCH] OpenGL range_for support (stage 2.1.1) (#594) * [skip ci] range for using GLSL threads range for and extptr (TODO: extptr not work) [skip ci] remove ext ptr fake support Conflicts: examples/opengl_backend.py taichi/codegen/codegen_opengl.cpp taichi/platform/opengl/opengl_kernel.cpp taichi/platform/opengl/opengl_kernel.h Conflicts: taichi/program/program.cpp * [skip ci] fix typo in _thread_id_ > (should be >=) Conflicts: taichi/codegen/codegen_opengl.cpp Conflicts: taichi/codegen/codegen_opengl.cpp * multi work groups & compile before launch to save time Conflicts: taichi/codegen/codegen_opengl.cpp taichi/platform/opengl/opengl_api.cpp taichi/platform/opengl/opengl_api.h taichi/platform/opengl/opengl_data_types.h fix merge Conflicts: taichi/codegen/codegen_opengl.cpp * [skip ci] non-GL build fix --- examples/opengl_range_for.py | 14 ++ taichi/codegen/codegen_opengl.cpp | 254 ++++++++++++++++++----- taichi/platform/opengl/opengl_api.cpp | 23 +- taichi/platform/opengl/opengl_api.h | 4 +- taichi/platform/opengl/opengl_kernel.cpp | 23 +- taichi/platform/opengl/opengl_kernel.h | 5 +- 6 files changed, 246 insertions(+), 77 deletions(-) create mode 100644 examples/opengl_range_for.py diff --git a/examples/opengl_range_for.py b/examples/opengl_range_for.py new file mode 100644 index 0000000000000..d6026f3ce59ff --- /dev/null +++ b/examples/opengl_range_for.py @@ -0,0 +1,14 @@ +import taichi as ti + +ti.init(arch=ti.opengl) + +x = ti.var(ti.i32, shape=(5, 5)) + +@ti.kernel +def func(): + for i in range(5): + for j in range(5): + x[i, j] = i + j + +func() +print(x[2, 3]) diff --git a/taichi/codegen/codegen_opengl.cpp b/taichi/codegen/codegen_opengl.cpp index 3b406dbd267eb..e94d2ea785df1 100644 --- a/taichi/codegen/codegen_opengl.cpp +++ b/taichi/codegen/codegen_opengl.cpp @@ -1,3 +1,4 @@ +//#define _GLSL_DEBUG 1 #include "codegen_opengl.h" #include #include @@ -39,6 +40,8 @@ class KernelGen : public IRVisitor std::string root_snode_type_name_; std::string glsl_kernel_prefix_; int glsl_kernel_count_{0}; + int num_threads_{1}; + int num_groups_{1}; void push_indent() { @@ -63,19 +66,17 @@ class KernelGen : public IRVisitor emit("#version 430 core"); emit("#extension GL_ARB_compute_shader: enable"); emit("{}", struct_compiled_->source_code); - emit("layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;"); - emit("#define NARGS {}", taichi_max_num_args); emit("layout(std430, binding = 0) buffer args_i32"); emit("{{"); - emit(" int _args_i32_[NARGS * 2];"); + emit(" int _args_i32_[];"); emit("}};"); emit("layout(std430, binding = 0) buffer args_f32"); emit("{{"); - emit(" float _args_f32_[NARGS * 2];"); + emit(" float _args_f32_[];"); emit("}};"); emit("layout(std430, binding = 0) buffer args_f64"); emit("{{"); - emit(" double _args_f64_[NARGS];"); + emit(" double _args_f64_[];"); emit("}};"); emit("layout(std430, binding = 1) buffer data_i32"); emit("{{"); @@ -89,6 +90,18 @@ class KernelGen : public IRVisitor emit("{{"); emit(" double _data_f64_[];"); emit("}};"); + emit("layout(std430, binding = 2) buffer extra_args_i32"); + emit("{{"); + emit(" int _extra_args_i32_[];"); + emit("}};"); + emit("layout(std430, binding = 2) buffer extra_args_f32"); + emit("{{"); + emit(" float _extra_args_f32_[];"); + emit("}};"); + emit("layout(std430, binding = 2) buffer extra_args_f64"); + emit("{{"); + emit(" double _extra_args_f64_[];"); + emit("}};"); emit("#define _arg_i32(x) _args_i32_[(x) << 1]"); // skip to 64bit stride emit("#define _arg_f32(x) _args_f32_[(x) << 1]"); emit("#define _arg_i64(x) _args_i64_[(x) << 0]"); @@ -97,6 +110,10 @@ class KernelGen : public IRVisitor emit("#define _mem_f32(x) _data_f32_[(x) >> 2]"); emit("#define _mem_i64(x) _data_i64_[(x) >> 3]"); emit("#define _mem_f64(x) _data_f64_[(x) >> 3]"); + emit("#define _extarg_i32(x) _extra_args_i32_[(x) >> 2]"); + emit("#define _extarg_f32(x) _extra_args_f32_[(x) >> 2]"); + emit("#define _extarg_i64(x) _extra_args_i64_[(x) >> 3]"); + emit("#define _extarg_f64(x) _extra_args_f64_[(x) >> 3]"); emit(""); } @@ -105,20 +122,36 @@ class KernelGen : public IRVisitor // TODO(archibate): () really necessary? How about just main()? emit("void main()"); emit("{{"); - emit(" {}();", glsl_kernel_name_); + if (glsl_kernel_name_.size()) + emit(" {}();", glsl_kernel_name_); emit("}}"); + emit(""); + int threads_per_group = 1792; + if (num_threads_ < 1792) + threads_per_group = num_threads_; + num_groups_ = (num_threads_ + 1791) / 1792; + emit("layout(local_size_x = {}, local_size_y = 1, local_size_z = 1) in;", threads_per_group); } void visit(Block *stmt) override { if (!is_top_level_) push_indent(); for (auto &s : stmt->statements) { - //TI_INFO("visiting sub stmt {}", typeid(*s).name()); s->accept(this); } if (!is_top_level_) pop_indent(); } + virtual void visit(Stmt *stmt) override + { + TI_WARN("[glsl] default visitor called for {}", typeid(*stmt).name()); + } + + void visit(ExternalPtrStmt *stmt) override + { + TI_ERROR("[glsl] external pointers not supported on OpenGL arch"); + } + void visit(LinearizeStmt *stmt) override { std::string val = "0"; @@ -126,7 +159,7 @@ class KernelGen : public IRVisitor val = fmt::format("({} * {} + {})", val, stmt->strides[i], stmt->inputs[i]->raw_name()); } - emit("const int {} = {};", stmt->raw_name(), val); + emit("int {} = {};", stmt->raw_name(), val); } void visit(OffsetAndExtractBitsStmt *stmt) override @@ -163,17 +196,11 @@ class KernelGen : public IRVisitor void visit(GetChStmt *stmt) override { - if (stmt->output_snode->is_place()) { - emit("{} /* place {} */ {} = {}_get{}({});", - stmt->output_snode->node_type_name, - opengl_data_type_name(stmt->output_snode->dt), - stmt->raw_name(), stmt->input_snode->node_type_name, - stmt->chid, stmt->input_ptr->raw_name()); - } else { - emit("{} {} = {}_get{}({});", stmt->output_snode->node_type_name, - stmt->raw_name(), stmt->input_snode->node_type_name, - stmt->chid, stmt->input_ptr->raw_name()); - } + emit("{} {} = {}_get{}({});", stmt->output_snode->node_type_name, + stmt->raw_name(), stmt->input_snode->node_type_name, + stmt->chid, stmt->input_ptr->raw_name()); + if (stmt->output_snode->is_place()) + emit("// place {}", opengl_data_type_name(stmt->output_snode->dt)); } void visit(GlobalStoreStmt *stmt) override @@ -192,14 +219,18 @@ class KernelGen : public IRVisitor void visit(UnaryOpStmt *stmt) override { - if (stmt->op_type != UnaryOpType::cast) { - emit("const {} {} = {}({});", opengl_data_type_name(stmt->element_type()), - stmt->raw_name(), unary_op_type_name(stmt->op_type), + if (stmt->op_type == UnaryOpType::logic_not) { + emit("{} {} = {}({} == 0);", opengl_data_type_name(stmt->element_type()), + stmt->raw_name(), opengl_data_type_name(stmt->element_type()), stmt->operand->raw_name()); + } else if (stmt->op_type != UnaryOpType::cast) { + emit("{} {} = {}({}({}));", opengl_data_type_name(stmt->element_type()), + stmt->raw_name(), opengl_data_type_name(stmt->element_type()), + unary_op_type_name(stmt->op_type), stmt->operand->raw_name()); } else { // cast if (stmt->cast_by_value) { - emit("const {} {} = {}({});", + emit("{} {} = {}({});", opengl_data_type_name(stmt->element_type()), stmt->raw_name(), opengl_data_type_name(stmt->cast_type), stmt->operand->raw_name()); } else { @@ -216,29 +247,29 @@ class KernelGen : public IRVisitor const auto bin_name = bin->raw_name(); if (bin->op_type == BinaryOpType::floordiv) { if (is_integral(bin->element_type())) { - emit("const {} {} = int(floor({} / {}));", dt_name, bin_name, lhs_name, + emit("{} {} = int(floor({} / {}));", dt_name, bin_name, lhs_name, rhs_name); } else { - emit("const {} {} = floor({} / {});", dt_name, bin_name, lhs_name, + emit("{} {} = floor({} / {});", dt_name, bin_name, lhs_name, rhs_name); } return; } const auto binop = binary_op_type_symbol(bin->op_type); if (is_opengl_binary_op_infix(bin->op_type)) { - emit("const {} {} = ({} {} {});", dt_name, bin_name, lhs_name, binop, - rhs_name); + emit("{} {} = {}({} {} {});", dt_name, bin_name, dt_name, + lhs_name, binop, rhs_name); } else { // This is a function call - emit("const {} {} = {}({}, {});", dt_name, bin_name, binop, lhs_name, - rhs_name); + emit("{} {} = {}({}, {});", dt_name, bin_name, binop, lhs_name, + rhs_name); } } void visit(TernaryOpStmt *tri) override { TI_ASSERT(tri->op_type == TernaryOpType::select); - emit("const {} {} = ({}) ? ({}) : ({});", + emit("{} {} = ({}) ? ({}) : ({});", opengl_data_type_name(tri->element_type()), tri->raw_name(), tri->op1->raw_name(), tri->op2->raw_name(), tri->op3->raw_name()); } @@ -254,7 +285,7 @@ class KernelGen : public IRVisitor if (stmt->same_source() && linear_index && stmt->width() == stmt->ptr[0].var->width()) { auto ptr = stmt->ptr[0].var; - emit("const {} {}({});", opengl_data_type_name(stmt->element_type()), + emit("{} {} = {};", opengl_data_type_name(stmt->element_type()), stmt->raw_name(), ptr->raw_name()); } else { TI_NOT_IMPLEMENTED; @@ -268,7 +299,7 @@ class KernelGen : public IRVisitor void visit(AllocaStmt *alloca) override { - emit("{} {}(0);", + emit("{} {};", // need = 0? opengl_data_type_name(alloca->element_type()), alloca->raw_name()); } @@ -276,7 +307,7 @@ class KernelGen : public IRVisitor void visit(ConstStmt *const_stmt) override { TI_ASSERT(const_stmt->width() == 1); - emit("const {} {} = {};", opengl_data_type_name(const_stmt->element_type()), + emit("{} {} = {};", opengl_data_type_name(const_stmt->element_type()), const_stmt->raw_name(), const_stmt->val[0].stringify()); } @@ -284,10 +315,10 @@ class KernelGen : public IRVisitor { const auto dt = opengl_data_type_name(stmt->element_type()); if (stmt->is_ptr) { - emit("const {} {} = _arg_{}({}); // is_ptr", dt, stmt->raw_name(), + emit("{} {} = _arg_{}({}); // is ext pointer", dt, stmt->raw_name(), data_type_short_name(stmt->element_type()), stmt->arg_id); } else { - emit("const {} {} = _arg_{}({});", dt, stmt->raw_name(), + emit("{} {} = _arg_{}({});", dt, stmt->raw_name(), data_type_short_name(stmt->element_type()), stmt->arg_id); } } @@ -314,6 +345,86 @@ class KernelGen : public IRVisitor emit("}}\n"); } + void generate_range_for_kernel(OffloadedStmt *stmt) { + TI_ASSERT(stmt->task_type == OffloadedStmt::TaskType::range_for); + const std::string glsl_kernel_name = make_kernel_name(); + emit("void {}()", glsl_kernel_name); + this->glsl_kernel_name_ = glsl_kernel_name; + emit("{{ // range for"); + + push_indent(); + if (stmt->const_begin && stmt->const_end) { + TI_ASSERT_INFO(stmt->end_value > stmt->begin_value, + "range for end value <= begin value"); + num_threads_ = stmt->end_value - stmt->begin_value; + emit("// range known at compile time"); + emit("const int _thread_id_ = int(gl_GlobalInvocationID.x);"); + emit("if (_thread_id_ >= {}) return;", num_threads_); + emit("const int _it_value_ = {} + _thread_id_ * {};", + stmt->begin_value, 1 /* stmt->step? */); + } else { + TI_ERROR("non-const range_for currently unsupported under OpenGL"); + /*range_for_attribs.begin = + (stmt->const_begin ? stmt->begin_value : stmt->begin_offset); + range_for_attribs.end = + (stmt->const_end ? stmt->end_value : stmt->end_offset);*/ + } + pop_indent(); + + stmt->body->accept(this); + emit("}}\n"); + } + + void visit(LoopIndexStmt *stmt) override + { + TI_ASSERT(!stmt->is_struct_for); + TI_ASSERT(stmt->index == 0); // TODO: multiple indices + emit("int {} = _it_value_;", stmt->raw_name()); + } + + void visit(RangeForStmt *for_stmt) override + { + TI_ASSERT(for_stmt->width() == 1); + auto *loop_var = for_stmt->loop_var; + if (loop_var->ret_type.data_type == DataType::i32) { + if (!for_stmt->reversed) { + emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{", + loop_var->raw_name(), for_stmt->begin->raw_name(), + loop_var->raw_name(), for_stmt->end->raw_name(), + loop_var->raw_name(), loop_var->raw_name(), 1); + // variable named `loop_var->raw_name()` is already allocated by alloca + emit(" {} = {}_;", loop_var->raw_name(), loop_var->raw_name()); + } else { + // reversed for loop + emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{", + loop_var->raw_name(), for_stmt->end->raw_name(), + loop_var->raw_name(), for_stmt->begin->raw_name(), + loop_var->raw_name(), loop_var->raw_name(), 1); + emit(" {} = {}_;", loop_var->raw_name(), loop_var->raw_name()); + } + } else { + TI_ASSERT(!for_stmt->reversed); + const auto type_name = opengl_data_type_name(loop_var->element_type()); + emit("for ({} {} = {}; {} < {}; {} = {} + 1) {{", type_name, + loop_var->raw_name(), for_stmt->begin->raw_name(), + loop_var->raw_name(), for_stmt->end->raw_name(), + loop_var->raw_name(), loop_var->raw_name()); + } + for_stmt->body->accept(this); + emit("}}"); + } + + void visit(WhileControlStmt *stmt) override + { + emit("if ({} == 0) break;", stmt->cond->raw_name()); + } + + void visit(WhileStmt *stmt) override + { + emit("while (true) {{"); + stmt->body->accept(this); + emit("}}"); + } void visit(OffloadedStmt *stmt) override { @@ -322,8 +433,8 @@ class KernelGen : public IRVisitor using Type = OffloadedStmt::TaskType; if (stmt->task_type == Type::serial) { generate_serial_kernel(stmt); - /*} else if (stmt->task_type == Type::range_for) { - generate_range_for_kernel(stmt);*/ + } else if (stmt->task_type == Type::range_for) { + generate_range_for_kernel(stmt); } else { // struct_for is automatically lowered to ranged_for for dense snodes // (#378). So we only need to support serial and range_for tasks. @@ -332,6 +443,22 @@ class KernelGen : public IRVisitor is_top_level_ = true; } + void visit(StructForStmt *) override + { + TI_ERROR("Struct for cannot be nested under OpenGL for now"); + } + + void visit(IfStmt *if_stmt) override { + emit("if ({} != 0) {{", if_stmt->cond->raw_name()); + if (if_stmt->true_statements) { + if_stmt->true_statements->accept(this); + } + if (if_stmt->false_statements) { + emit("}} else {{"); + if_stmt->false_statements->accept(this); + } + emit("}}"); + } public: const std::string &kernel_source_code() const @@ -339,11 +466,15 @@ class KernelGen : public IRVisitor return kernel_src_code_; } + int get_num_work_groups() const + { + return num_groups_; + } + SSBO *create_root_ssbo() { static SSBO *root_ssbo; if (!root_ssbo) { - TI_INFO("[glsl] creating root buffer of size {} B", struct_compiled_->root_size); root_ssbo = new SSBO(struct_compiled_->root_size); } return root_ssbo; @@ -351,11 +482,9 @@ class KernelGen : public IRVisitor void run(const SNode &root_snode) { - //TI_INFO("ntm:: {}", root_snode.node_type_name); root_snode_ = &root_snode; root_snode_type_name_ = root_snode.node_type_name; generate_header(); - //irpass::print(kernel->ir); kernel->ir->accept(this); generate_bottom(); } @@ -364,7 +493,7 @@ class KernelGen : public IRVisitor } // namespace void OpenglCodeGen::lower() -{ +{ // {{{ auto ir = kernel_->ir; const bool print_ir = prog_->config.print_ir; if (print_ir) { @@ -477,7 +606,11 @@ void OpenglCodeGen::lower() irpass::re_id(ir); irpass::print(ir); } -} + +#ifdef _GLSL_DEBUG + irpass::print(ir); +#endif +} // }}} FunctionType OpenglCodeGen::gen(void) { @@ -485,23 +618,38 @@ FunctionType OpenglCodeGen::gen(void) codegen.run(*prog_->snode_root); SSBO *root_sb = codegen.create_root_ssbo(); const std::string kernel_source_code = codegen.kernel_source_code(); - //TI_INFO("source of kernel [{}]:\n{}", kernel_name_, kernel_source_code); + int num_groups = codegen.get_num_work_groups(); + TI_INFO("source of kernel [{}]:\n{}", kernel_name_, kernel_source_code); + GLProgram *glsl = compile_glsl_program(kernel_source_code); - return [kernel_source_code, root_sb](Context &ctx) { - // TODO(archibate): find out where get_arg stored, and just new SSBO(ctx) + return [glsl, num_groups, root_sb](Context &ctx) { + // TODO(archibate): try implement just new_ssbo_from_buffer(ctx.args) and no free like _IOMYBUF SSBO *arg_sb = new SSBO(taichi_max_num_args * sizeof(uint64_t)); - arg_sb->load_arguments_from(ctx); - std::vector iov = {*arg_sb, *root_sb}; - /*TI_INFO("data[0] = {}", ((int*)root_sb->data)[0]); + SSBO *extarg_sb = new SSBO(Context::extra_args_size); + arg_sb->load_from((void *)ctx.args); + extarg_sb->load_from((void *)ctx.extra_args); + std::vector iov = {*arg_sb, *root_sb, *extarg_sb}; +#ifdef _GLSL_DEBUG + TI_INFO("data[0] = {}", ((int*)root_sb->data)[0]); TI_INFO("data[1] = {}", ((int*)root_sb->data)[1]); TI_INFO("args[0] = {}", ((uint64_t*)arg_sb->data)[0]); - TI_INFO("args[1] = {}", ((uint64_t*)arg_sb->data)[1]);*/ - launch_glsl_kernel(kernel_source_code, iov); - /*TI_INFO("data[0] = {}", ((int*)root_sb->data)[0]); + TI_INFO("args[1] = {}", ((uint64_t*)arg_sb->data)[1]); + TI_INFO("earg[0] = {}", ((int*)extarg_sb->data)[0]); + TI_INFO("earg[1] = {}", ((int*)extarg_sb->data)[1]); +#endif + launch_glsl_kernel(glsl, iov, num_groups); +#ifdef _GLSL_DEBUG + TI_INFO("data[0] = {}", ((int*)root_sb->data)[0]); TI_INFO("data[1] = {}", ((int*)root_sb->data)[1]); TI_INFO("args[0] = {}", ((uint64_t*)arg_sb->data)[0]); - TI_INFO("args[1] = {}", ((uint64_t*)arg_sb->data)[1]);*/ - arg_sb->save_returns_to(ctx); + TI_INFO("args[1] = {}", ((uint64_t*)arg_sb->data)[1]); + TI_INFO("earg[0] = {}", ((int*)extarg_sb->data)[0]); + TI_INFO("earg[1] = {}", ((int*)extarg_sb->data)[1]); +#endif + arg_sb->save_to((void *)ctx.args); + extarg_sb->save_to((void *)ctx.extra_args); + delete arg_sb; + delete extarg_sb; }; } diff --git a/taichi/platform/opengl/opengl_api.cpp b/taichi/platform/opengl/opengl_api.cpp index 66dbaeb9750f3..971ec17e855f4 100644 --- a/taichi/platform/opengl/opengl_api.cpp +++ b/taichi/platform/opengl/opengl_api.cpp @@ -81,7 +81,11 @@ struct GLProgram id_ = glCreateProgram(); } - GLProgram(GLShader &shader) + explicit GLProgram(GLuint id) + : id_(id) + {} + + explicit GLProgram(GLShader &shader) : GLProgram() { this->attach(shader); @@ -238,12 +242,17 @@ void initialize_opengl() } } -void launch_glsl_kernel(std::string source, std::vector iov) +GLProgram *compile_glsl_program(std::string source) { GLShader shader(source); - GLProgram program(shader); - program.link(); - program.use(); + GLProgram *program = new GLProgram(shader); + program->link(); + return program; +} + +void launch_glsl_kernel(GLProgram *program, std::vector iov, int num_groups) +{ + program->use(); std::vector ssbo(iov.size()); for (int i = 0; i < ssbo.size(); i++) { @@ -258,7 +267,7 @@ void launch_glsl_kernel(std::string source, std::vector iov) // `glDispatchCompute(X, Y, Z)` - the X*Y*Z == `Blocks` in CUDA // `layout(local_size_x = X) in;` - the X == `Threads` in CUDA // - glDispatchCompute(1, 1, 1); + glDispatchCompute(num_groups, 1, 1); glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT); // TODO(archibate): move to Program::synchroize() for (int i = 0; i < ssbo.size(); i++) { @@ -274,7 +283,7 @@ bool is_opengl_api_available() } #else -void launch_glsl_kernel(std::string source, std::vector iov) +void launch_glsl_kernel(GLProgram *program, std::vector iov, int num_groups) { TI_NOT_IMPLEMENTED } diff --git a/taichi/platform/opengl/opengl_api.h b/taichi/platform/opengl/opengl_api.h index 297ff56660ab3..1f67fd1b61a7d 100644 --- a/taichi/platform/opengl/opengl_api.h +++ b/taichi/platform/opengl/opengl_api.h @@ -11,9 +11,11 @@ TLANG_NAMESPACE_BEGIN namespace opengl { +struct GLProgram; void initialize_opengl(); bool is_opengl_api_available(); -void launch_glsl_kernel(std::string source, std::vector iov); +void launch_glsl_kernel(GLProgram *program, std::vector iov, int num_groups); +GLProgram *compile_glsl_program(std::string source); } // namespace opengl diff --git a/taichi/platform/opengl/opengl_kernel.cpp b/taichi/platform/opengl/opengl_kernel.cpp index a714b7d9d377f..5147b16fd8a38 100644 --- a/taichi/platform/opengl/opengl_kernel.cpp +++ b/taichi/platform/opengl/opengl_kernel.cpp @@ -6,23 +6,18 @@ TLANG_NAMESPACE_BEGIN namespace opengl { -SSBO::SSBO(size_t data_size_) : data_(data_size_), data_size(data_size_) { -} +SSBO::SSBO(size_t data_size_) + : data_(data_size_), data_size(data_size_) +{} -void SSBO::load_arguments_from(Context &ctx) { - uint64_t *data_i = (uint64_t *)data(); - for (int i = 0; i < taichi_max_num_args; i++) { - uint64_t value = ctx.get_arg(i); - data_i[i] = value; - } +void SSBO::load_from(const void *buffer) +{ + std::memcpy(data(), buffer, data_size); } -void SSBO::save_returns_to(Context &ctx) { - uint64_t *data_i = (uint64_t *)data(); - for (int i = 0; i < taichi_max_num_args; i++) { - uint64_t value = data_i[i]; - ctx.set_arg(i, value); - } +void SSBO::save_to(void *buffer) +{ + std::memcpy(buffer, data(), data_size); } } // namespace opengl diff --git a/taichi/platform/opengl/opengl_kernel.h b/taichi/platform/opengl/opengl_kernel.h index f7422b383129c..92b2acbcdda5d 100644 --- a/taichi/platform/opengl/opengl_kernel.h +++ b/taichi/platform/opengl/opengl_kernel.h @@ -20,8 +20,9 @@ struct SSBO SSBO(size_t data_size); - void load_arguments_from(Context &ctx); - void save_returns_to(Context &ctx); + void load_from(const void *buffer); + void save_to(void *buffer); + inline void *data() { return (void *)data_.data();