Skip to content

Commit

Permalink
[spirv] Specialize element shape for spirv codegen. (#5068)
Browse files Browse the repository at this point in the history
* Specialize element shape for spirv codegen.

* Fix index for size_var_names

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Slight changes for better code style.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
turbo0628 and pre-commit-ci[bot] authored Jun 1, 2022
1 parent c985372 commit 1514d4b
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,30 @@ class TaskCodegen : public IRVisitor {
{
const int num_indices = stmt->indices.size();
std::vector<std::string> size_var_names;
const auto &element_shape = stmt->element_shape;
enum ExternalArrayLayout { layout_AOS = 0, layout_SOA = 1 };
const auto layout = stmt->element_dim <= 0 ? layout_AOS : layout_SOA;
const auto extra_args_member_index = ctx_attribs_->args().size();

// Determine the element shape position inside the indices vector
// TODO: change the outer layout in order to remove the element layout
// guess work
int element_shape_begin = -1;
int element_shape_end = -1;
if (element_shape.size() > 0) {
if (layout == layout_SOA) {
element_shape_begin = 0;
element_shape_end = element_shape.size();
} else {
element_shape_begin = num_indices - element_shape.size();
element_shape_end = num_indices;
}
}
for (int i = 0; i < num_indices; i++) {
// Skip expressions for element shapes.
if (i >= element_shape_begin && i < element_shape_end) {
continue;
}
std::string var_name = fmt::format("{}_size{}_", stmt->raw_name(), i);
const auto extra_arg_index = (arg_id * taichi_max_num_indices) + i;
spirv::Value var_ptr = ir_->make_value(
Expand All @@ -578,8 +600,16 @@ class TaskCodegen : public IRVisitor {
ir_->register_value(var_name, var);
size_var_names.push_back(std::move(var_name));
}
int size_var_names_idx = 0;
for (int i = 0; i < num_indices; i++) {
spirv::Value size_var = ir_->query_value(size_var_names[i]);
spirv::Value size_var;
// Use immediate numbers to flatten index for element shapes.
if (i >= element_shape_begin && i < element_shape_end) {
size_var = ir_->uint_immediate_number(
ir_->i32_type(), element_shape[i - element_shape_begin]);
} else {
size_var = ir_->query_value(size_var_names[size_var_names_idx++]);
}
spirv::Value indices = ir_->query_value(stmt->indices[i]->raw_name());
linear_offset = ir_->mul(linear_offset, size_var);
linear_offset = ir_->add(linear_offset, indices);
Expand All @@ -592,7 +622,6 @@ class TaskCodegen : public IRVisitor {
ir_->decorate(spv::OpDecorate, linear_offset,
spv::DecorationNoSignedWrap);
}

if (device_->get_cap(DeviceCapability::spirv_has_physical_storage_buffer)) {
spirv::Value addr_ptr = ir_->make_value(
spv::OpAccessChain,
Expand Down

0 comments on commit 1514d4b

Please sign in to comment.