Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt] Better encapsulate BLS bounds #2341

Merged
merged 4 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/ir/scratch_pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ std::string ScratchPad::global_to_linearized_local(
TI_ASSERT(step_size % pad_size[i] == 0);
step_size /= pad_size[i];
ret += fmt::format(" + ({} - {}_base - {}) * {}", indices[i]->raw_name(),
loop_vars[i]->raw_name(), bounds[0][i], step_size);
loop_vars[i]->raw_name(), bounds[i].low, step_size);
}
return ret;
}
Expand Down
52 changes: 32 additions & 20 deletions taichi/ir/scratch_pad.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,26 @@ inline AccessFlag operator|=(AccessFlag &a, AccessFlag &b) {

class ScratchPad {
public:
SNode *snode;
// The lowest and highest index in each dimension.
struct BoundRange {
// Inclusive lower bound of the accessed index range in this dimension.
int low{0};
// Exclusive upper bound: one past the highest accessed index
// (updated via indices[i] + 1 in ScratchPad::access).
int high{0};

// Extent of the range along this dimension, i.e. high - low.
// pad_size[i] is kept equal to bounds[i].range().
int range() const {
return high - low;
}

// I/O / serialization support for the two bound fields.
TI_IO_DEF(low, high);
};

SNode *snode{nullptr};
using AccessFlag = taichi::lang::AccessFlag;

std::vector<int> bounds[2];
std::vector<BoundRange> bounds;
// pad_size[i] := bounds[i].high - bounds[i].low
// TODO: This can be replaced by a function call to bounds[i].range()
std::vector<int> pad_size;
// block_size[i] := (1 << snode.extractor[i].num_bits)
std::vector<int> block_size;
bool finalized;
int dim;
Expand All @@ -48,17 +63,16 @@ class ScratchPad {
ScratchPad(SNode *snode) : snode(snode) {
TI_ASSERT(snode != nullptr);
dim = snode->num_active_indices;
bounds[0].resize(dim);
bounds[1].resize(dim);
bounds.resize(dim);
pad_size.resize(dim);

finalized = false;

total_flags = AccessFlag(0);
std::fill(bounds[0].begin(), bounds[0].end(),
std::numeric_limits<int>::max());
std::fill(bounds[1].begin(), bounds[1].end(),
std::numeric_limits<int>::min());
BoundRange init_bound;
init_bound.low = std::numeric_limits<int>::max();
init_bound.high = std::numeric_limits<int>::min();
std::fill(bounds.begin(), bounds.end(), init_bound);
empty = false;
}

Expand All @@ -67,9 +81,9 @@ class ScratchPad {
empty = true;
TI_ASSERT((int)indices.size() == dim);
for (int i = 0; i < dim; i++) {
bounds[0][i] = std::min(bounds[0][i], indices[i]);
bounds[1][i] = std::max(bounds[1][i], indices[i] + 1);
pad_size[i] = bounds[1][i] - bounds[0][i];
bounds[i].low = std::min(bounds[i].low, indices[i]);
bounds[i].high = std::max(bounds[i].high, indices[i] + 1);
pad_size[i] = bounds[i].range();
}
accesses.emplace_back(indices, flags);
}
Expand All @@ -86,8 +100,8 @@ class ScratchPad {
block_size[i] =
1 << snode->parent->extractors[snode->physical_index_position[i]]
.num_bits;
TI_ASSERT(bounds[0][i] != std::numeric_limits<int>::max());
TI_ASSERT(bounds[1][i] != std::numeric_limits<int>::min());
TI_ASSERT(bounds[i].low != std::numeric_limits<int>::max());
TI_ASSERT(bounds[i].high != std::numeric_limits<int>::min());
}

finalized = true;
Expand Down Expand Up @@ -132,8 +146,8 @@ class ScratchPad {
int ret = 0;
TI_ASSERT(finalized);
for (int i = 0; i < dim; i++) {
ret *= (bounds[1][i] - bounds[0][i]);
ret += indices[i] - bounds[0][i];
ret *= (bounds[i].high - bounds[i].low);
ret += indices[i] - bounds[i].low;
}
return ret;
}
Expand All @@ -144,7 +158,7 @@ class ScratchPad {
div *= pad_size[i];
}
return fmt::format("({} / {} % {} + {})", var, div, pad_size[d],
bounds[0][d]);
bounds[d].low);
}

/*
Expand Down Expand Up @@ -209,9 +223,8 @@ class ScratchPads {
if (pads.find(snode) != pads.end()) {
auto &pad = pads[snode];
int offset = 0;
// for (int i = pad.dim - 1; i >= 0; i--) {
for (int i = 0; i < pad.dim; i++) {
offset = offset + (indices[i] - pad.bounds[0][i]);
offset = offset + (indices[i] - pad.bounds[i].low);
if (i > 0)
offset = offset * pad.pad_size[i - 1];
}
Expand All @@ -224,8 +237,7 @@ class ScratchPads {
void print() {
for (auto &it : pads) {
TI_P(it.first->node_type_name);
TI_P(it.second.bounds[0]);
TI_P(it.second.bounds[1]);
Comment on lines -227 to -228
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: If this function is merely used for debugging, why not printing bounds as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed :)

TI_P(it.second.bounds);
}
}

Expand Down
7 changes: 2 additions & 5 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,8 @@ class LoopIndexStmt : public Stmt {
};

/**
* All loop indices of the |loop| fused together.
* thread index within a CUDA block
* TODO: Remove this. Have a better way for retrieving thread index.
*/
class LoopLinearIndexStmt : public Stmt {
public:
Expand All @@ -1080,10 +1081,6 @@ class LoopLinearIndexStmt : public Stmt {
return false;
}

// Return the number of bits of the loop, or -1 if unknown.
// TODO: implement
// int max_num_bits() const;

TI_STMT_DEF_FIELDS(ret_type, loop);
TI_DEFINE_ACCEPT_AND_CLONE
};
Expand Down
73 changes: 45 additions & 28 deletions taichi/transforms/make_block_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void make_block_local_offload(OffloadedStmt *offload,

auto pads = irpass::initialize_scratch_pad(offload);

std::size_t bls_offset = 0;
std::size_t bls_offset_in_bytes = 0;

for (auto &pad : pads->pads) {
auto snode = pad.first;
Expand All @@ -35,10 +35,10 @@ void make_block_local_offload(OffloadedStmt *offload,
"BLS with both read and accumulation is not supported.")

// dim = Dimensionality of the BLS buffer and the block
auto dim = (int)pad.second.pad_size.size();
const auto dim = (int)pad.second.pad_size.size();
TI_ASSERT(dim == snode->num_active_indices);

auto bls_num_elements = pad.second.pad_size_linear();
const auto bls_num_elements = pad.second.pad_size_linear();

std::vector<int> block_strides(dim);
std::vector<int> bls_strides(dim);
Expand All @@ -55,7 +55,8 @@ void make_block_local_offload(OffloadedStmt *offload,
// TODO: improve IR builder to make this part easier to read

// Ensure BLS alignment
bls_offset += (dtype_size - bls_offset % dtype_size) % dtype_size;
bls_offset_in_bytes +=
(dtype_size - bls_offset_in_bytes % dtype_size) % dtype_size;

// This lambda is used for both BLS prologue and epilogue creation
auto create_xlogue =
Expand All @@ -67,45 +68,58 @@ void make_block_local_offload(OffloadedStmt *offload,
block = std::make_unique<Block>();
block->parent_stmt = offload;
}
Stmt *block_linear_index =
// Equivalent to CUDA threadIdx
Stmt *thread_idx_stmt =
block->push_back<LoopLinearIndexStmt>(offload);

/*
Note that since there are fewer elements in the block than in BLS,
each thread may have to fetch more than one element to BLS.
Therefore on CUDA we need something like

auto bls_element_id = block_linear_index;
while (bls_element_id < bls_size) {
auto bls_element_id = thread_idx_stmt;
while (bls_element_id < bls_num_elements) {
i, j, k = bls_to_global(bls_element_id)
bls[bls_element_id] = x[i, j, k]
// or x[i, j, k] = bls[bls_element_id]
bls_element_id += block_dim;
}

func bls_to_global(bls_element_id):
partial = bls_element_id
global_indices = [] // "i, j, k"
for i in reversed(range(0, dim)):
pad_size = pad.pad_size[i] // a.k.a. bounds[i].range()
bls_coord = partial % pad_size
partial = partial / pad_size
global_index_at_i = BlockCorner[i] + bls_coord
global_index_at_i += pad.bounds[i].low
global_indices[i] = global_index_at_i

Since we know block_dim and bls_size at compile time and there's
usually not too many iterations, we directly unroll this while loop
for performance when constructing prologues/epilogues.
*/

// Unroll the while-loop
int loop_offset = 0;
int block_dim = offload->block_dim;
const int block_dim = offload->block_dim;
while (loop_offset < bls_num_elements) {
Block *element_block = nullptr;
auto loop_offset_stmt =
block->push_back<ConstStmt>(TypedConstant(loop_offset));

auto bls_element_id_this_iteration = block->push_back<BinaryOpStmt>(
BinaryOpType::add, loop_offset_stmt, block_linear_index);
BinaryOpType::add, loop_offset_stmt, thread_idx_stmt);

auto bls_element_offset_bytes = block->push_back<BinaryOpStmt>(
BinaryOpType::mul, bls_element_id_this_iteration,
block->push_back<ConstStmt>(TypedConstant(dtype_size)));

bls_element_offset_bytes = block->push_back<BinaryOpStmt>(
BinaryOpType::add, bls_element_offset_bytes,
block->push_back<ConstStmt>(TypedConstant((int32)bls_offset)));
block->push_back<ConstStmt>(
TypedConstant((int32)bls_offset_in_bytes)));

if (loop_offset + block_dim > bls_num_elements) {
// Need to create an IfStmt to safeguard since bls size may not be
Expand All @@ -130,25 +144,26 @@ void make_block_local_offload(OffloadedStmt *offload,
// via a series of % and /.
auto bls_element_id_partial = bls_element_id_this_iteration;
for (int i = dim - 1; i >= 0; i--) {
auto size = element_block->push_back<ConstStmt>(
auto pad_size_stmt = element_block->push_back<ConstStmt>(
TypedConstant(pad.second.pad_size[i]));

auto bls_coord = element_block->push_back<BinaryOpStmt>(
BinaryOpType::mod, bls_element_id_partial, size);
BinaryOpType::mod, bls_element_id_partial, pad_size_stmt);
bls_element_id_partial = element_block->push_back<BinaryOpStmt>(
BinaryOpType::div, bls_element_id_partial, size);
BinaryOpType::div, bls_element_id_partial, pad_size_stmt);

auto global_index = element_block->push_back<BinaryOpStmt>(
BinaryOpType::add,
element_block->push_back<ConstStmt>(
TypedConstant(pad.second.bounds[0][i])),
bls_coord);
auto global_index_this_dim =
element_block->push_back<BinaryOpStmt>(
BinaryOpType::add,
element_block->push_back<ConstStmt>(
TypedConstant(pad.second.bounds[i].low)),
bls_coord);

global_index = element_block->push_back<BinaryOpStmt>(
BinaryOpType::add, global_index,
global_index_this_dim = element_block->push_back<BinaryOpStmt>(
BinaryOpType::add, global_index_this_dim,
element_block->push_back<BlockCornerIndexStmt>(offload, i));

global_indices[i] = global_index;
global_indices[i] = global_index_this_dim;
}

operation(element_block, global_indices, bls_element_offset_bytes);
Expand Down Expand Up @@ -209,7 +224,7 @@ void make_block_local_offload(OffloadedStmt *offload,
for (int i = 0; i < dim; i++) {
// BLS index = sum_i inc_i
// where inc_i =
// bls_stride_i * (gbl_idx_i - loop_base_i - bls_lower_bound_i)
// bls_stride_i * (gbl_idx_i - block_corner_i - bls_lower_bound_i)
// Note that when index offsets are used, the offset contributions are
// already included in bls_lower_bound_i.
auto block_corner = bls.push_back<BlockCornerIndexStmt>(offload, i);
Expand All @@ -218,13 +233,14 @@ void make_block_local_offload(OffloadedStmt *offload,
BinaryOpType::sub, global_indices[i], block_corner);
inc = bls.push_back<BinaryOpStmt>(
BinaryOpType::sub, inc,
bls.push_back<ConstStmt>(TypedConstant(pad.second.bounds[0][i])));
bls.push_back<ConstStmt>(
TypedConstant(pad.second.bounds[i].low)));

if (debug) {
// This part insert an assertion to make sure BLS access is within
// the bound.
auto bls_axis_size =
pad.second.bounds[1][i] - pad.second.bounds[0][i];
pad.second.bounds[i].high - pad.second.bounds[i].low;
std::string msg = fmt::format(
"(kernel={}, body) Access out of bound: BLS buffer axis {} "
"(size {}) with "
Expand Down Expand Up @@ -267,7 +283,8 @@ void make_block_local_offload(OffloadedStmt *offload,
// add array offset
bls_element_offset = bls.push_back<BinaryOpStmt>(
BinaryOpType::add, bls_element_offset,
bls.push_back<ConstStmt>(TypedConstant((int32)bls_offset)));
bls.push_back<ConstStmt>(
TypedConstant((int32)bls_offset_in_bytes)));

bls.push_back<BlockLocalPtrStmt>(
bls_element_offset,
Expand Down Expand Up @@ -297,10 +314,10 @@ void make_block_local_offload(OffloadedStmt *offload,
}

// allocate storage for the BLS variable
bls_offset += dtype_size * bls_num_elements;
}
bls_offset_in_bytes += dtype_size * bls_num_elements;
} // for (auto &pad : pads->pads)

offload->bls_size = std::max(std::size_t(1), bls_offset);
offload->bls_size = std::max(std::size_t(1), bls_offset_in_bytes);
}

} // namespace
Expand Down
12 changes: 5 additions & 7 deletions tests/cpp/analysis/bls_analyzer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,11 @@ TEST_F(BLSAnalyzerTest, Basic) {
BLSAnalyzer bls(for_stmt_.get(), &pads_);
pads_.finalize();
const auto &pad = pads_.get(child_snode_);
EXPECT_EQ(pad.bounds[0].size(), 2);
constexpr int kLow = 0;
constexpr int kHigh = 1;
EXPECT_EQ(pad.bounds[kLow][0], 0);
EXPECT_EQ(pad.bounds[kHigh][0], 1 + kBlockSize);
EXPECT_EQ(pad.bounds[kLow][1], -3);
EXPECT_EQ(pad.bounds[kHigh][1], kBlockSize);
EXPECT_EQ(pad.bounds.size(), 2);
EXPECT_EQ(pad.bounds[0].low, 0);
EXPECT_EQ(pad.bounds[0].high, 1 + kBlockSize);
EXPECT_EQ(pad.bounds[1].low, -3);
EXPECT_EQ(pad.bounds[1].high, kBlockSize);
}

} // namespace
Expand Down