Skip to content

Commit

Permalink
[refactor] Remove legacy num_bits and acc_offsets from AxisExtractor (#…
Browse files Browse the repository at this point in the history
…7227)

### Brief Summary

This is a follow-up to #7104 that fully removes the legacy code for the
SNode padding behavior. Note that the experimental struct-for support on
bitmasked SNodes in the SPIR-V backend is also removed here, because it
relies on the existence of the padding behavior.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Jan 28, 2023
1 parent 01ed3fb commit 0dff195
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 295 deletions.
4 changes: 0 additions & 4 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ static void get_offline_cache_key_of_snode_impl(
serializer(extractor.num_elements_from_root);
serializer(extractor.shape);
serializer(extractor.acc_shape);
serializer(extractor.num_bits);
serializer(extractor.acc_offset);
serializer(extractor.active);
}
serializer(snode->index_offsets);
Expand All @@ -97,8 +95,6 @@ static void get_offline_cache_key_of_snode_impl(
serializer(snode->depth);
serializer(snode->name);
serializer(snode->num_cells_per_container);
serializer(snode->total_num_bits);
serializer(snode->total_bit_start);
serializer(snode->chunk_size);
serializer(snode->cell_size_bytes);
serializer(snode->offset_bytes_in_parent_cell);
Expand Down
33 changes: 4 additions & 29 deletions taichi/codegen/spirv/snode_struct_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ class StructCompiler {
sn_desc.cell_stride = cell_stride;

if (sn->type == SNodeType::bitmasked) {
size_t num_cells = sn_desc.cells_per_container_pot();
size_t num_cells = sn_desc.snode->num_cells_per_container;
size_t bitmask_num_words =
num_cells % 32 == 0 ? (num_cells / 32) : (num_cells / 32 + 1);
sn_desc.container_stride =
cell_stride * num_cells + bitmask_num_words * 4;
} else {
sn_desc.container_stride =
cell_stride * sn_desc.cells_per_container_pot();
cell_stride * sn_desc.snode->num_cells_per_container;
}
}

Expand All @@ -121,34 +121,13 @@ class StructCompiler {
sn_desc.total_num_cells_from_root *= e.num_elements_from_root;
}

// Sum the bits per axis
SNode *snode_head = sn;
do {
for (int i = 0; i < taichi_max_num_indices; i++) {
const AxisExtractor &extractor = snode_head->extractors[i];
if (extractor.active) {
sn_desc.axis_bits_sum[i] += extractor.num_bits;
}
}
} while ((snode_head = snode_head->parent));
// Find the start bit
sn_desc.axis_start_bit[0] = 0;
for (int i = 1; i < taichi_max_num_indices; i++) {
sn_desc.axis_start_bit[i] =
sn_desc.axis_bits_sum[i - 1] + sn_desc.axis_start_bit[i - 1];
}
TI_TRACE("Indices at SNode {}", sn->get_name());
for (int i = 0; i < taichi_max_num_indices; i++) {
TI_TRACE("Index {}: {}..{}", i, sn_desc.axis_start_bit[i],
sn_desc.axis_start_bit[i] + sn_desc.axis_bits_sum[i]);
}

TI_TRACE("SNodeDescriptor");
TI_TRACE("* snode={}", sn_desc.snode->id);
TI_TRACE("* type={} (is_place={})", sn_desc.snode->node_type_name,
is_place);
TI_TRACE("* cell_stride={}", sn_desc.cell_stride);
TI_TRACE("* cells_per_container_pot={}", sn_desc.cells_per_container_pot());
TI_TRACE("* num_cells_per_container={}",
sn_desc.snode->num_cells_per_container);
TI_TRACE("* container_stride={}", sn_desc.container_stride);
TI_TRACE("* total_num_cells_from_root={}",
sn_desc.total_num_cells_from_root);
Expand All @@ -164,10 +143,6 @@ class StructCompiler {

} // namespace

// Returns the wrapped SNode's `num_cells_per_container` as a size_t.
// Despite the `_pot` suffix, no power-of-two rounding is applied here: the
// value is forwarded unchanged.
size_t SNodeDescriptor::cells_per_container_pot() const {
  const size_t cell_count = snode->num_cells_per_container;
  return cell_count;
}

CompiledSNodeStructs compile_snode_structs(SNode &root) {
StructCompiler compiler;
return compiler.run(root);
Expand Down
8 changes: 1 addition & 7 deletions taichi/codegen/spirv/snode_struct_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@ struct SNodeDescriptor {
// Stride (bytes) of a single cell.
size_t cell_stride = 0;

// Number of cells per container, padded to Power of Two (pot).
size_t cells_per_container_pot() const;

// Bytes of a single container.
size_t container_stride = 0;

// Total number of CELLS of this SNode, NOT padded to PoT.
// Total number of CELLS of this SNode
// For example, for a layout of
// ti.root
// .dense(ti.ij, (3, 2)) // S1
Expand All @@ -33,9 +30,6 @@ struct SNodeDescriptor {
// starts at a fixed offset in its parent cell's memory.
size_t mem_offset_in_parent_cell = 0;

int axis_bits_sum[taichi_max_num_indices] = {0};
int axis_start_bit[taichi_max_num_indices] = {0};

// Looks up entry `ch_i` of the wrapped SNode's `ch` list and returns the raw
// SNode pointer obtained from it via `.get()`.
SNode *get_child(int ch_i) const {
  auto &&child_slot = snode->ch[ch_i];
  return child_slot.get();
}
Expand Down
174 changes: 1 addition & 173 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ class TaskCodegen : public IRVisitor {
} else if (task_ir_->task_type == OffloadedTaskType::range_for) {
// struct_for is automatically lowered to ranged_for for dense snodes
generate_range_for_kernel(task_ir_);
} else if (task_ir_->task_type == OffloadedTaskType::listgen) {
generate_listgen_kernel(task_ir_);
} else if (task_ir_->task_type == OffloadedTaskType::struct_for) {
generate_struct_for_kernel(task_ir_);
} else {
Expand Down Expand Up @@ -336,7 +334,7 @@ class TaskCodegen : public IRVisitor {
ir_->uint_immediate_number(ir_->u32_type(), 2));
bitmask_word_ptr = ir_->add(
bitmask_word_ptr,
make_pointer(desc.cell_stride * desc.cells_per_container_pot()));
make_pointer(desc.cell_stride * desc.snode->num_cells_per_container));
bitmask_word_ptr = ir_->add(parent_ptr, bitmask_word_ptr);
bitmask_word_ptr = ir_->make_value(
spv::OpShiftRightLogical, ir_->u32_type(), bitmask_word_ptr,
Expand Down Expand Up @@ -480,24 +478,6 @@ class TaskCodegen : public IRVisitor {
spirv::Value loop_var = ir_->query_value("ii");
// spirv::Value val = ir_->add(loop_var, ir_->const_i32_zero_);
ir_->register_value(stmt_name, loop_var);
} else if (type == OffloadedTaskType::struct_for) {
SNode *snode = stmt->loop->as<OffloadedStmt>()->snode;
spirv::Value val = ir_->query_value("ii");
// FIXME: packed layout (non POT)
int root_id = snode_to_root_[snode->id];
const auto &snode_descs = compiled_structs_[root_id].snode_descriptors;
const int *axis_start_bit = snode_descs.at(snode->id).axis_start_bit;
const int *axis_bits_sum = snode_descs.at(snode->id).axis_bits_sum;
val =
ir_->make_value(spv::OpShiftRightLogical, ir_->u32_type(), val,
ir_->uint_immediate_number(
ir_->u32_type(), axis_start_bit[stmt->index]));
val = ir_->make_value(
spv::OpBitwiseAnd, ir_->u32_type(), val,
ir_->uint_immediate_number(ir_->u32_type(),
(1 << axis_bits_sum[stmt->index]) - 1));
val = ir_->cast(ir_->i32_type(), val);
ir_->register_value(stmt_name, val);
} else {
TI_NOT_IMPLEMENTED;
}
Expand Down Expand Up @@ -1945,158 +1925,6 @@ class TaskCodegen : public IRVisitor {
task_attribs_.texture_binds = get_texture_binds();
}

// Emits the kernel body for a `listgen` offloaded task on `stmt->snode`.
//
// For a bitmasked SNode, each invocation decodes a per-level linear cell
// index from its global invocation id, queries the SNode's bitmask through
// `bitmasked_activation`, and, when the query is true, atomically adds 1 to
// the ListGen buffer slot addressed by index 0 and stores its invocation id
// in the slot addressed by `1 + <value returned by the atomic add>`.
// For a dense SNode nothing is emitted between function start and return;
// any other SNode type hits TI_NOT_IMPLEMENTED.
void generate_listgen_kernel(OffloadedStmt *stmt) {
task_attribs_.name = task_name_;
task_attribs_.task_type = OffloadedTaskType::listgen;
// Defaults; `advisory_total_num_threads` is overwritten below for bitmasked
// SNodes.
task_attribs_.advisory_total_num_threads = 1;
task_attribs_.advisory_num_threads_per_group = 32;

auto snode = stmt->snode;

TI_TRACE("Listgen for {}", snode->get_name());

// All three vectors are indexed the same way: entry 0 describes `snode`
// itself and the last entry describes the SNode that has no parent.
std::vector<SNode *> snode_path;
std::vector<int> snode_path_num_cells;
std::vector<std::array<int, taichi_max_num_indices>>
snode_path_index_start_bit;
int total_num_cells = 1;
int root_id = 0;
{
// Construct the SNode path to the chosen node
auto snode_head = snode;
std::array<int, taichi_max_num_indices> start_indices{0};
do {
snode_path.push_back(snode_head);
// Values are recorded BEFORE this level's contribution is accumulated, so
// each entry holds the product / per-axis bit sum of the levels visited
// earlier in this walk (i.e. the levels between it and `snode`).
snode_path_num_cells.push_back(total_num_cells);
snode_path_index_start_bit.push_back(start_indices);
total_num_cells *= snode_head->num_cells_per_container;
// Overwritten every iteration; after the loop it is the id of the last
// SNode visited (the one whose `parent` is null).
root_id = snode_head->id;
for (int i = 0; i < taichi_max_num_indices; i++) {
start_indices[i] += snode_head->extractors[i].num_bits;
}
} while ((snode_head = snode_head->parent));
}

// NOTE(review): `root_id` is an SNode id here and is used directly as the
// index into `compiled_structs_` — confirm this matches how
// `compiled_structs_` is indexed elsewhere.
const auto &snode_descs = compiled_structs_[root_id].snode_descriptors;
// Descriptor of the target SNode (taken by value, i.e. copied).
const auto sn_desc = snode_descs.at(snode->id);

// Trace the path from the parentless SNode down to `snode`.
for (int i = snode_path.size() - 1; i >= 0; i--) {
const auto &desc = snode_descs.at(snode_path[i]->id);
TI_TRACE("- {} ({})", snode_path[i]->get_name(),
snode_path[i]->type_name());
TI_TRACE(" is_place: {}, num_axis: {}, num_cells: {}",
snode_path[i]->is_place(), snode_path[i]->num_active_indices,
desc.cells_per_container_pot());
}

ir_->start_function(kernel_function_);

if (snode->type == SNodeType::bitmasked) {
// Request one thread per cell: `total_num_cells` is the product of
// `num_cells_per_container` over every SNode on the path.
task_attribs_.advisory_total_num_threads = total_num_cells;
int num_cells = snode->num_cells_per_container;

TI_INFO("ListGen {} * {}", total_num_cells / num_cells, num_cells);

auto listgen_buffer =
get_buffer_value(BufferInfo(BufferType::ListGen), PrimitiveType::i32);
auto invoc_index = ir_->get_global_invocation_id(0);

// Offset accumulated while descending the path; it is later handed to
// `bitmasked_activation` as the container pointer of `snode`.
auto container_ptr = make_pointer(0);
std::vector<spirv::Value> linear_indices(snode_path.size());
for (int i = snode_path.size() - 1; i >= 0; i--) {
// Offset the ptr to the cell on layer up
SNode *this_snode = snode_path[i];
const auto &this_snode_desc = snode_descs.at(this_snode->id);

// Build this level's linear index from the invocation id: for each active
// axis, shift the id right by
// `sn_desc.axis_start_bit[idx] + snode_path_index_start_bit[i][idx]`, mask
// off `num_bits` bits, then shift the accumulator left by `num_bits` and OR
// the extracted bits in. Levels without active indices keep the value 0.
auto snode_linear_index =
ir_->uint_immediate_number(ir_->u32_type(), 0);
if (this_snode->num_active_indices) {
for (int idx = 0; idx < taichi_max_num_indices; idx++) {
if (this_snode->extractors[idx].active) {
auto axis_local_index = ir_->make_value(
spv::OpShiftRightLogical, ir_->u32_type(), invoc_index,
ir_->uint_immediate_number(
ir_->u32_type(), sn_desc.axis_start_bit[idx] +
snode_path_index_start_bit[i][idx]));
axis_local_index = ir_->make_value(
spv::OpBitwiseAnd, ir_->u32_type(), axis_local_index,
ir_->uint_immediate_number(
ir_->u32_type(),
(1 << this_snode->extractors[idx].num_bits) - 1));
snode_linear_index = ir_->make_value(
spv::OpBitwiseOr, ir_->u32_type(),
ir_->make_value(spv::OpShiftLeftLogical, ir_->u32_type(),
snode_linear_index,
ir_->uint_immediate_number(
ir_->u32_type(),
this_snode->extractors[idx].num_bits)),
axis_local_index);
}
}
}

// For every level except `snode` itself (i == 0), advance the offset:
// by `linear_index * cell_stride` when the level has active indices,
// otherwise by the next level's `mem_offset_in_parent_cell`.
if (i > 0) {
const auto &next_snode_desc = snode_descs.at(snode_path[i - 1]->id);
if (this_snode->num_active_indices) {
container_ptr = ir_->add(
container_ptr,
ir_->mul(snode_linear_index,
ir_->uint_immediate_number(
ir_->u32_type(), this_snode_desc.cell_stride)));
} else {
container_ptr = ir_->add(
container_ptr,
make_pointer(next_snode_desc.mem_offset_in_parent_cell));
}
}

linear_indices[i] = snode_linear_index;
}

// Check current bitmask mask within the cell
// (`linear_indices[0]` is the linear index computed for `snode` itself.)
auto index_is_active =
bitmasked_activation(ActivationOp::query, container_ptr, root_id,
snode, linear_indices[0]);

// Conditional on `index_is_active`: the "then" label holds the list append,
// and both paths continue at the merge label.
auto if_then_label = ir_->new_label();
auto if_merge_label = ir_->new_label();

ir_->make_inst(spv::OpSelectionMerge, if_merge_label,
spv::SelectionControlMaskNone);
ir_->make_inst(spv::OpBranchConditional, index_is_active, if_then_label,
if_merge_label);
// if (is_active)
{
ir_->start_label(if_then_label);

// Atomically add 1 to the counter addressed by index 0 of the ListGen
// buffer; the instruction's result is used as this invocation's position
// in the list.
auto listgen_count_ptr = ir_->struct_array_access(
ir_->u32_type(), listgen_buffer, ir_->const_i32_zero_);
auto index_count = ir_->make_value(
spv::OpAtomicIAdd, ir_->u32_type(), listgen_count_ptr,
/*scope=*/ir_->const_i32_one_,
/*semantics=*/ir_->const_i32_zero_,
ir_->uint_immediate_number(ir_->u32_type(), 1));
// List entries follow the counter, hence the `1 +` in the index.
auto listgen_index_ptr = ir_->struct_array_access(
ir_->u32_type(), listgen_buffer,
ir_->add(ir_->uint_immediate_number(ir_->u32_type(), 1),
index_count));
ir_->store_variable(listgen_index_ptr, invoc_index);
ir_->make_inst(spv::OpBranch, if_merge_label);
}
ir_->start_label(if_merge_label);
} else if (snode->type == SNodeType::dense) {
// Why??
// (Nothing is emitted for dense SNodes; the kernel only returns.)
} else {
TI_NOT_IMPLEMENTED;
}

ir_->make_inst(spv::OpReturn); // return;
ir_->make_inst(spv::OpFunctionEnd); // } Close kernel

// Record the buffer and texture bindings gathered while emitting the kernel.
task_attribs_.buffer_binds = get_buffer_binds();
task_attribs_.texture_binds = get_texture_binds();
}

void generate_struct_for_kernel(OffloadedStmt *stmt) {
task_attribs_.name = task_name_;
task_attribs_.task_type = OffloadedTaskType::struct_for;
Expand Down
33 changes: 3 additions & 30 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ SNode &SNode::create_node(std::vector<Axis> axes,
"best performance, we recommend that you set it to a power of two.",
sizes[i], char('i' + ind), tb);
}
new_node.extractors[ind].activate(
bit::log2int(bit::least_pot_bound(sizes[i])));
new_node.extractors[ind].active = true;
new_node.extractors[ind].num_elements_from_root *= sizes[i];
new_node.extractors[ind].shape = sizes[i];
}
Expand All @@ -87,30 +86,16 @@ SNode &SNode::create_node(std::vector<Axis> axes,
"supported yet. Struct fors might not work either.");
}
new_node.num_cells_per_container = acc_shape;
// infer extractors (only for POT)
int acc_offsets = 0;
for (int i = taichi_max_num_indices - 1; i >= 0; i--) {
new_node.extractors[i].acc_offset = acc_offsets;
acc_offsets += new_node.extractors[i].num_bits;
}
new_node.total_num_bits = acc_offsets;

constexpr int kMaxTotalNumBits = 64;
TI_ERROR_IF(
new_node.total_num_bits >= kMaxTotalNumBits,
"SNode={}: total_num_bits={} exceeded limit={}. This implies that "
"your requested shape is too large.",
new_node.id, new_node.total_num_bits, kMaxTotalNumBits);

if (new_node.type == SNodeType::dynamic) {
int active_extractor_counder = 0;
for (int i = 0; i < taichi_max_num_indices; i++) {
if (new_node.extractors[i].num_bits != 0) {
if (new_node.extractors[i].active) {
active_extractor_counder += 1;
SNode *p = new_node.parent;
while (p) {
TI_ASSERT_INFO(
p->extractors[i].num_bits == 0,
!p->extractors[i].active,
"Dynamic SNode must have a standalone dimensionality.");
p = p->parent;
}
Expand Down Expand Up @@ -213,8 +198,6 @@ SNode::SNode(int depth,
snode_rw_accessors_bank_(snode_rw_accessors_bank) {
id = counter++;
node_type_name = get_node_type_name();
total_num_bits = 0;
total_bit_start = 0;
num_active_indices = 0;
std::memset(physical_index_position, -1, sizeof(physical_index_position));
parent = nullptr;
Expand All @@ -241,16 +224,6 @@ std::string SNode::get_node_type_name_hinted() const {
return fmt::format("S{}{}{}", id, snode_type_name(type), suffix);
}

int SNode::get_num_bits(int physical_index) const {
int result = 0;
const SNode *snode = this;
while (snode) {
result += snode->extractors[physical_index].num_bits;
snode = snode->parent;
}
return result;
}

void SNode::print() {
for (int i = 0; i < depth; i++) {
fmt::print(" ");
Expand Down
Loading

0 comments on commit 0dff195

Please sign in to comment.