diff --git a/taichi/ir/scratch_pad.cpp b/taichi/ir/scratch_pad.cpp
index dd64bcd32653b..6d6a086d6325d 100644
--- a/taichi/ir/scratch_pad.cpp
+++ b/taichi/ir/scratch_pad.cpp
@@ -14,7 +14,7 @@ std::string ScratchPad::global_to_linearized_local(
     TI_ASSERT(step_size % pad_size[i] == 0);
     step_size /= pad_size[i];
     ret += fmt::format(" + ({} - {}_base - {}) * {}", indices[i]->raw_name(),
-                       loop_vars[i]->raw_name(), bounds[0][i], step_size);
+                       loop_vars[i]->raw_name(), bounds[i].low, step_size);
   }
   return ret;
 }
diff --git a/taichi/ir/scratch_pad.h b/taichi/ir/scratch_pad.h
index eb1538aae2db7..aac3d641b4e39 100644
--- a/taichi/ir/scratch_pad.h
+++ b/taichi/ir/scratch_pad.h
@@ -29,11 +29,26 @@ inline AccessFlag operator|=(AccessFlag &a, AccessFlag &b) {
 class ScratchPad {
  public:
-  SNode *snode;
+  // The lowest and highest index in each dimension.
+  struct BoundRange {
+    int low{0};
+    int high{0};
+
+    int range() const {
+      return high - low;
+    }
+
+    TI_IO_DEF(low, high);
+  };
+
+  SNode *snode{nullptr};
 
   using AccessFlag = taichi::lang::AccessFlag;
 
-  std::vector<int> bounds[2];
+  std::vector<BoundRange> bounds;
+  // pad_size[i] := bounds[i].high - bounds[i].low
+  // TODO: This can be replaced by a function call to bounds[i].range()
   std::vector<int> pad_size;
+  // block_size[i] := (1 << snode.extractor[i].num_bits)
   std::vector<int> block_size;
   bool finalized;
   int dim;
@@ -48,17 +63,16 @@ class ScratchPad {
 
   ScratchPad(SNode *snode) : snode(snode) {
     TI_ASSERT(snode != nullptr);
     dim = snode->num_active_indices;
-    bounds[0].resize(dim);
-    bounds[1].resize(dim);
+    bounds.resize(dim);
     pad_size.resize(dim);
 
     finalized = false;
 
     total_flags = AccessFlag(0);
-    std::fill(bounds[0].begin(), bounds[0].end(),
-              std::numeric_limits<int>::max());
-    std::fill(bounds[1].begin(), bounds[1].end(),
-              std::numeric_limits<int>::min());
+    BoundRange init_bound;
+    init_bound.low = std::numeric_limits<int>::max();
+    init_bound.high = std::numeric_limits<int>::min();
+    std::fill(bounds.begin(), bounds.end(), init_bound);
     empty = false;
   }
@@ -67,9 +81,9 @@ class ScratchPad {
     empty = true;
     TI_ASSERT((int)indices.size() == dim);
     for (int i = 0; i < dim; i++) {
-      bounds[0][i] = std::min(bounds[0][i], indices[i]);
-      bounds[1][i] = std::max(bounds[1][i], indices[i] + 1);
-      pad_size[i] = bounds[1][i] - bounds[0][i];
+      bounds[i].low = std::min(bounds[i].low, indices[i]);
+      bounds[i].high = std::max(bounds[i].high, indices[i] + 1);
+      pad_size[i] = bounds[i].range();
     }
     accesses.emplace_back(indices, flags);
   }
@@ -86,8 +100,8 @@ class ScratchPad {
       block_size[i] =
           1 << snode->parent->extractors[snode->physical_index_position[i]]
                    .num_bits;
-      TI_ASSERT(bounds[0][i] != std::numeric_limits<int>::max());
-      TI_ASSERT(bounds[1][i] != std::numeric_limits<int>::min());
+      TI_ASSERT(bounds[i].low != std::numeric_limits<int>::max());
+      TI_ASSERT(bounds[i].high != std::numeric_limits<int>::min());
     }
 
     finalized = true;
@@ -132,8 +146,8 @@ class ScratchPad {
     int ret = 0;
     TI_ASSERT(finalized);
     for (int i = 0; i < dim; i++) {
-      ret *= (bounds[1][i] - bounds[0][i]);
-      ret += indices[i] - bounds[0][i];
+      ret *= (bounds[i].high - bounds[i].low);
+      ret += indices[i] - bounds[i].low;
     }
     return ret;
   }
@@ -144,7 +158,7 @@ class ScratchPad {
       div *= pad_size[i];
     }
     return fmt::format("({} / {} % {} + {})", var, div, pad_size[d],
-                       bounds[0][d]);
+                       bounds[d].low);
   }
 
   /*
@@ -209,9 +223,8 @@ class ScratchPads {
     if (pads.find(snode) != pads.end()) {
      auto &pad = pads[snode];
       int offset = 0;
-      // for (int i = pad.dim - 1; i >= 0; i--) {
       for (int i = 0; i < pad.dim; i++) {
-        offset = offset + (indices[i] - pad.bounds[0][i]);
+        offset = offset + (indices[i] - pad.bounds[i].low);
         if (i > 0)
           offset = offset * pad.pad_size[i - 1];
       }
@@ -224,8 +237,7 @@
 
   void print() {
     for (auto &it : pads) {
       TI_P(it.first->node_type_name);
-      TI_P(it.second.bounds[0]);
-      TI_P(it.second.bounds[1]);
+      TI_P(it.second.bounds);
     }
   }
diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h
index ef9e4be086ea9..14ba6f21e4b9d 100644
--- a/taichi/ir/statements.h
+++ b/taichi/ir/statements.h
@@ -1066,7 +1066,8 @@ class LoopIndexStmt : public Stmt {
 };
 
 /**
- * All loop indices of the |loop| fused together.
+ * The thread index within a CUDA block.
+ * TODO: Remove this. Have a better way of retrieving the thread index.
  */
 class LoopLinearIndexStmt : public Stmt {
  public:
@@ -1080,10 +1081,6 @@ class LoopLinearIndexStmt : public Stmt {
     return false;
   }
 
-  // Return the number of bits of the loop, or -1 if unknown.
-  // TODO: implement
-  // int max_num_bits() const;
-
   TI_STMT_DEF_FIELDS(ret_type, loop);
   TI_DEFINE_ACCEPT_AND_CLONE
 };
diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp
index 9fe9fd39a4530..f28899eb505ff 100644
--- a/taichi/transforms/make_block_local.cpp
+++ b/taichi/transforms/make_block_local.cpp
@@ -19,7 +19,7 @@ void make_block_local_offload(OffloadedStmt *offload,
 
   auto pads = irpass::initialize_scratch_pad(offload);
 
-  std::size_t bls_offset = 0;
+  std::size_t bls_offset_in_bytes = 0;
 
   for (auto &pad : pads->pads) {
     auto snode = pad.first;
@@ -35,10 +35,10 @@
         "BLS with both read and accumulation is not supported.")
 
     // dim = Dimensionality of the BLS buffer and the block
-    auto dim = (int)pad.second.pad_size.size();
+    const auto dim = (int)pad.second.pad_size.size();
     TI_ASSERT(dim == snode->num_active_indices);
 
-    auto bls_num_elements = pad.second.pad_size_linear();
+    const auto bls_num_elements = pad.second.pad_size_linear();
 
     std::vector<int> block_strides(dim);
     std::vector<int> bls_strides(dim);
@@ -55,7 +55,8 @@
     // TODO: improve IR builder to make this part easier to read
 
     // Ensure BLS alignment
-    bls_offset += (dtype_size - bls_offset % dtype_size) % dtype_size;
+    bls_offset_in_bytes +=
+        (dtype_size - bls_offset_in_bytes % dtype_size) % dtype_size;
 
     // This lambda is used for both BLS prologue and epilogue creation
     auto create_xlogue =
@@ -67,7 +68,8 @@
             block = std::make_unique<Block>();
             block->parent_stmt = offload;
           }
-          Stmt *block_linear_index =
+          // Equivalent to CUDA threadIdx
+          Stmt *thread_idx_stmt =
               block->push_back<LoopLinearIndexStmt>(offload);
 
           /*
@@ -75,14 +77,25 @@
           each thread may have to fetch more than one element to BLS.
           Therefore on CUDA we need something like
 
-          auto bls_element_id = block_linear_index;
-          while (bls_element_id < bls_size) {
+          auto bls_element_id = thread_idx_stmt;
+          while (bls_element_id < bls_num_elements) {
            i, j, k = bls_to_global(bls_element_id)
            bls[bls_element_id] = x[i, j, k]  // or x[i, j, k] = bls[bls_element_id]
            bls_element_id += block_dim;
          }
 
+         func bls_to_global(bls_element_id):
+           partial = bls_element_id
+           global_indices = []  // "i, j, k"
+           for i in reversed(range(0, dim)):
+             pad_size = pad.pad_size[i]  // a.k.a. bounds[i].range()
+             bls_coord = partial % pad_size
+             partial = partial / pad_size
+             global_index_at_i = BlockCorner[i] + bls_coord
+             global_index_at_i += pad.bounds[i].low
+             global_indices[i] = global_index_at_i
+
          Since we know block_dim and bls_size at compile time and there's
          usually not too many iterations, we directly unroll this while loop
          for performance when constructing prologues/epilogues.
 
@@ -90,14 +103,14 @@
 
           // Unroll the while-loop
           int loop_offset = 0;
-          int block_dim = offload->block_dim;
+          const int block_dim = offload->block_dim;
           while (loop_offset < bls_num_elements) {
             Block *element_block = nullptr;
             auto loop_offset_stmt =
                 block->push_back<ConstStmt>(TypedConstant(loop_offset));
 
             auto bls_element_id_this_iteration = block->push_back<BinaryOpStmt>(
-                BinaryOpType::add, loop_offset_stmt, block_linear_index);
+                BinaryOpType::add, loop_offset_stmt, thread_idx_stmt);
 
             auto bls_element_offset_bytes = block->push_back<BinaryOpStmt>(
                 BinaryOpType::mul, bls_element_id_this_iteration,
@@ -105,7 +118,8 @@
 
             bls_element_offset_bytes = block->push_back<BinaryOpStmt>(
                 BinaryOpType::add, bls_element_offset_bytes,
-                block->push_back<ConstStmt>(TypedConstant((int32)bls_offset)));
+                block->push_back<ConstStmt>(
+                    TypedConstant((int32)bls_offset_in_bytes)));
 
             if (loop_offset + block_dim > bls_num_elements) {
               // Need to create an IfStmt to safeguard since bls size may not be
@@ -130,25 +144,26 @@
             // via a series of % and /.
             auto bls_element_id_partial = bls_element_id_this_iteration;
             for (int i = dim - 1; i >= 0; i--) {
-              auto size = element_block->push_back<ConstStmt>(
+              auto pad_size_stmt = element_block->push_back<ConstStmt>(
                   TypedConstant(pad.second.pad_size[i]));
 
               auto bls_coord = element_block->push_back<BinaryOpStmt>(
-                  BinaryOpType::mod, bls_element_id_partial, size);
+                  BinaryOpType::mod, bls_element_id_partial, pad_size_stmt);
               bls_element_id_partial = element_block->push_back<BinaryOpStmt>(
-                  BinaryOpType::div, bls_element_id_partial, size);
+                  BinaryOpType::div, bls_element_id_partial, pad_size_stmt);
 
-              auto global_index = element_block->push_back<BinaryOpStmt>(
-                  BinaryOpType::add,
-                  element_block->push_back<ConstStmt>(
-                      TypedConstant(pad.second.bounds[0][i])),
-                  bls_coord);
+              auto global_index_this_dim =
+                  element_block->push_back<BinaryOpStmt>(
+                      BinaryOpType::add,
+                      element_block->push_back<ConstStmt>(
+                          TypedConstant(pad.second.bounds[i].low)),
+                      bls_coord);
 
-              global_index = element_block->push_back<BinaryOpStmt>(
-                  BinaryOpType::add, global_index,
+              global_index_this_dim = element_block->push_back<BinaryOpStmt>(
+                  BinaryOpType::add, global_index_this_dim,
                   element_block->push_back<BlockCornerIndexStmt>(offload, i));
 
-              global_indices[i] = global_index;
+              global_indices[i] = global_index_this_dim;
             }
 
             operation(element_block, global_indices, bls_element_offset_bytes);
@@ -209,7 +224,7 @@
     for (int i = 0; i < dim; i++) {
       // BLS index = sum_i inc_i
       // where inc_i =
-      //   bls_stride_i * (gbl_idx_i - loop_base_i - bls_lower_bound_i)
+      //   bls_stride_i * (gbl_idx_i - block_corner_i - bls_lower_bound_i)
       // Note that when index offsets are used, the offset contributions are
       // already included in bls_lower_bound_i.
       auto block_corner = bls.push_back<BlockCornerIndexStmt>(offload, i);
@@ -218,13 +233,14 @@
           BinaryOpType::sub, global_indices[i], block_corner);
       inc = bls.push_back<BinaryOpStmt>(
           BinaryOpType::sub, inc,
-          bls.push_back<ConstStmt>(TypedConstant(pad.second.bounds[0][i])));
+          bls.push_back<ConstStmt>(
+              TypedConstant(pad.second.bounds[i].low)));
 
       if (debug) {
         // This part insert an assertion to make sure BLS access is within
         // the bound.
         auto bls_axis_size =
-            pad.second.bounds[1][i] - pad.second.bounds[0][i];
+            pad.second.bounds[i].high - pad.second.bounds[i].low;
         std::string msg = fmt::format(
             "(kernel={}, body) Access out of bound: BLS buffer axis {} "
             "(size {}) with "
@@ -267,7 +283,8 @@
       // add array offset
       bls_element_offset = bls.push_back<BinaryOpStmt>(
           BinaryOpType::add, bls_element_offset,
-          bls.push_back<ConstStmt>(TypedConstant((int32)bls_offset)));
+          bls.push_back<ConstStmt>(
+              TypedConstant((int32)bls_offset_in_bytes)));
 
       bls.push_back<BlockLocalPtrStmt>(
           bls_element_offset,
@@ -297,10 +314,10 @@
     }
 
     // allocate storage for the BLS variable
-    bls_offset += dtype_size * bls_num_elements;
-  }
+    bls_offset_in_bytes += dtype_size * bls_num_elements;
+  }  // for (auto &pad : pads->pads)
 
-  offload->bls_size = std::max(std::size_t(1), bls_offset);
+  offload->bls_size = std::max(std::size_t(1), bls_offset_in_bytes);
 }
 
 }  // namespace
diff --git a/tests/cpp/analysis/bls_analyzer_test.cpp b/tests/cpp/analysis/bls_analyzer_test.cpp
index fc3c60d0b494a..6919c438c66f3 100644
--- a/tests/cpp/analysis/bls_analyzer_test.cpp
+++ b/tests/cpp/analysis/bls_analyzer_test.cpp
@@ -81,13 +81,11 @@ TEST_F(BLSAnalyzerTest, Basic) {
   BLSAnalyzer bls(for_stmt_.get(), &pads_);
   pads_.finalize();
   const auto &pad = pads_.get(child_snode_);
-  EXPECT_EQ(pad.bounds[0].size(), 2);
-  constexpr int kLow = 0;
-  constexpr int kHigh = 1;
-  EXPECT_EQ(pad.bounds[kLow][0], 0);
-  EXPECT_EQ(pad.bounds[kHigh][0], 1 + kBlockSize);
-  EXPECT_EQ(pad.bounds[kLow][1], -3);
-  EXPECT_EQ(pad.bounds[kHigh][1], kBlockSize);
+  EXPECT_EQ(pad.bounds.size(), 2);
+  EXPECT_EQ(pad.bounds[0].low, 0);
+  EXPECT_EQ(pad.bounds[0].high, 1 + kBlockSize);
+  EXPECT_EQ(pad.bounds[1].low, -3);
+  EXPECT_EQ(pad.bounds[1].high, kBlockSize);
 }
 
 }  // namespace
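Side note (not part of the diff above): the refactor replaces the parallel `bounds[0]`/`bounds[1]` vectors with a single `std::vector<BoundRange>`, and make_block_local relies on two mappings between a flat BLS element id and per-dimension pad indices: the row-major linearization in scratch_pad.h and the `bls_to_global` decomposition sketched in the comment. The snippet below is a minimal, self-contained illustration of those two mappings; the function names `linearize`/`decode` and the sample bounds are made up for this sketch and are not code from the repository.

```cpp
// Illustrative sketch of BoundRange-based BLS indexing (hypothetical names).
#include <cassert>
#include <vector>

struct BoundRange {
  int low{0};
  int high{0};
  int range() const {
    return high - low;
  }
};

// Multi-dimensional pad indices -> linear BLS element id (row-major),
// following the loop in the scratch_pad.h hunk at line 146.
int linearize(const std::vector<BoundRange> &bounds,
              const std::vector<int> &indices) {
  int ret = 0;
  for (int i = 0; i < (int)bounds.size(); i++) {
    ret *= bounds[i].range();           // == bounds[i].high - bounds[i].low
    ret += indices[i] - bounds[i].low;  // offset within dimension i
  }
  return ret;
}

// Linear BLS element id -> per-dimension pad indices, the inverse mapping
// (the bls_to_global pseudo-code, minus the BlockCorner term).
std::vector<int> decode(const std::vector<BoundRange> &bounds, int element_id) {
  const int dim = (int)bounds.size();
  std::vector<int> indices(dim);
  for (int i = dim - 1; i >= 0; i--) {
    const int pad_size = bounds[i].range();
    indices[i] = element_id % pad_size + bounds[i].low;
    element_id /= pad_size;
  }
  return indices;
}

int main() {
  // A 2D pad covering i in [0, 9) and j in [-3, 5), i.e. pad_size = {9, 8}.
  const std::vector<BoundRange> bounds{{0, 9}, {-3, 5}};
  assert(linearize(bounds, {0, -3}) == 0);               // lowest corner
  assert(linearize(bounds, {8, 4}) == 9 * 8 - 1);        // highest corner
  assert(decode(bounds, 71) == (std::vector<int>{8, 4}));  // round-trips
  return 0;
}
```

Decoding walks the dimensions from innermost to outermost with `%` and `/`, which is why the prologue/epilogue construction in make_block_local.cpp iterates `for (int i = dim - 1; i >= 0; i--)`.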