Clean Debug Code on Previous PRs (PaddlePaddle#59839)
zhhsplendid committed Dec 13, 2023
1 parent 142c71a commit 54e46e3
Showing 2 changed files with 4 additions and 254 deletions.
158 changes: 0 additions & 158 deletions paddle/cinn/optim/transform_gpu_forloop.cc
@@ -229,163 +229,6 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> {
}
};

-using TENSOR_LOOP = std::pair<ir::Expr, std::vector<ir::Expr>>;
-class CollectTensorLoopVisitor : public ir::IRMutator<> {
- public:
-  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
-
- private:
-  void Visit(const ir::Store *op, Expr *expr) override {
-    auto tensor = op->tensor.as_tensor_ref();
-    // Record the access if the buffer is defined and not on the heap.
-    if (tensor->buffer.defined() &&
-        tensor->buffer->memory_type != ir::MemoryType::Heap) {
-      if (buffer_tensor_loop_map_.count(tensor->buffer->name)) {
-        buffer_tensor_loop_map_[tensor->buffer->name].push_back(
-            std::make_pair(*expr, loops_));
-      } else {
-        buffer_tensor_loop_map_[tensor->buffer->name] = {
-            std::make_pair(*expr, loops_)};
-      }
-    }
-
-    IRMutator::Visit(op, expr);
-  }
-
-  void Visit(const ir::Load *op, Expr *expr) override {
-    if (op->is_addr_scalar()) {
-      return;
-    }
-    auto tensor = op->tensor.as_tensor_ref();
-    // Record the access if the buffer is defined and not on the heap.
-    if (tensor->buffer.defined() &&
-        tensor->buffer->memory_type != ir::MemoryType::Heap) {
-      if (buffer_tensor_loop_map_.count(tensor->buffer->name)) {
-        buffer_tensor_loop_map_[tensor->buffer->name].push_back(
-            std::make_pair(*expr, loops_));
-      } else {
-        buffer_tensor_loop_map_[tensor->buffer->name] = {
-            std::make_pair(*expr, loops_)};
-      }
-    }
-
-    IRMutator::Visit(op, expr);
-  }
-
-  void Visit(const ir::For *op, Expr *expr) override {
-    loops_.push_back(*expr);
-    IRMutator::Visit(op, expr);
-    loops_.pop_back();
-  }
-
-  void Visit(const ir::PolyFor *op, Expr *expr) override {
-    LOG(FATAL) << "Unknown PolyFor!";
-  }
-
- public:
-  std::vector<ir::Expr> loops_;
-  std::unordered_map<std::string, std::vector<TENSOR_LOOP>>
-      buffer_tensor_loop_map_;
-};
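
For reference, a minimal standalone sketch of the grouping CollectTensorLoopVisitor performs, with strings standing in for CINN IR nodes (not part of the commit; buffer, access, and loop names are illustrative):

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Each access pairs the access expression with the stack of enclosing loops,
// mirroring TENSOR_LOOP = std::pair<ir::Expr, std::vector<ir::Expr>>.
using AccessAndLoops = std::pair<std::string, std::vector<std::string>>;

int main() {
  std::unordered_map<std::string, std::vector<AccessAndLoops>> by_buffer;
  // A store under loops [i, j] and a load under loop [i], both on buffer "A".
  by_buffer["A"].push_back({"A[i, j] = ...", {"for_i", "for_j"}});
  by_buffer["A"].push_back({"... = A[i, 0]", {"for_i"}});
  // by_buffer now plays the role of buffer_tensor_loop_map_.
  return 0;
}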

-void UpdateBufferAxisPassOld(ir::Expr *expr) {
-  CollectTensorLoopVisitor collect_tensor_loop_visitor;
-  collect_tensor_loop_visitor(expr);
-
-  auto buffer_tensor_loop = collect_tensor_loop_visitor.buffer_tensor_loop_map_;
-
-  for (auto &tmp : buffer_tensor_loop) {
-    auto tensor_loop_v = tmp.second;
-
-    // Count how many leading loops are shared by every access to this buffer.
-    auto &front = tensor_loop_v.front();
-    int count = tensor_loop_v.size() > 1 ? front.second.size() : 0;
-    for (int idx = 1; idx < tensor_loop_v.size(); ++idx) {
-      auto &other = tensor_loop_v[idx];
-      for (int idy = 0;
-           idy < std::min(front.second.size(), other.second.size());
-           ++idy) {
-        if (front.second[idy] != other.second[idy]) {
-          count = std::min(count, idy);
-          break;
-        }
-      }
-    }
-
-    auto get_thread_bind_var = [](const std::vector<ir::Expr> &loops) {
-      // Map each thread axis (threadIdx.x/y/z) to its (loop_var, extent).
-      using ThreadLoopVarExtentMap =
-          std::unordered_map<std::string, std::pair<std::string, int>>;
-      ThreadLoopVarExtentMap thread_loop_var_extent_map;
-      for (auto loop : loops) {
-        auto loop_ir = loop.As<ir::For>();
-        CHECK(loop_ir);
-        if (loop_ir->is_gpu_thread_binded()) {
-          std::string axis = "";
-          if (loop_ir->bind_info().offset == 0) {
-            axis = "threadIdx.x";
-          } else if (loop_ir->bind_info().offset == 1) {
-            axis = "threadIdx.y";
-          } else {
-            axis = "threadIdx.z";
-          }
-          // For each thread axis, keep the loop var with the smallest extent.
-          if (thread_loop_var_extent_map.count(axis)) {
-            auto &loop_var_extent = thread_loop_var_extent_map[axis];
-            if (loop_var_extent.second >= loop_ir->extent.as_int32()) {
-              thread_loop_var_extent_map[axis] = std::make_pair(
-                  loop_ir->loop_var->name, loop_ir->extent.as_int32());
-            }
-          } else {
-            thread_loop_var_extent_map[axis] = std::make_pair(
-                loop_ir->loop_var->name, loop_ir->extent.as_int32());
-          }
-        }
-      }
-
-      std::unordered_set<std::string> loop_var_map;
-      for (auto &tmp : thread_loop_var_extent_map) {
-        loop_var_map.insert(tmp.second.first);
-      }
-
-      return loop_var_map;
-    };
-
-    auto load = front.first.As<ir::Load>();
-    auto store = front.first.As<ir::Store>();
-    auto tensor =
-        load ? load->tensor.as_tensor_ref() : store->tensor.as_tensor_ref();
-    // For GPUShared buffers, record the thread-bound loop vars that each
-    // store and load must keep.
-    std::vector<std::unordered_set<std::string>> keep_loop_vars;
-    if (tensor->buffer->memory_type == ir::MemoryType::GPUShared) {
-      for (auto &tensor_loop : tensor_loop_v) {
-        keep_loop_vars.push_back(get_thread_bind_var(tensor_loop.second));
-      }
-      CHECK_EQ(keep_loop_vars.size(), tensor_loop_v.size());
-    }
-
-    auto &loops = front.second;
-    for (int idx = 0; idx < count; ++idx) {
-      auto loop_expr = loops[idx];
-      auto loop_ir = loop_expr.As<ir::For>();
-      auto loop_var = loop_ir->loop_var;
-
-      // Replace each shared, non-kept loop var with 0 in the buffer indices.
-      for (int idy = 0; idy < tensor_loop_v.size(); ++idy) {
-        auto expr = tensor_loop_v[idy].first;
-        auto load = expr.As<ir::Load>();
-        auto store = expr.As<ir::Store>();
-        if (keep_loop_vars.size() == 0 ||
-            !keep_loop_vars[idy].count(loop_var->name)) {
-          auto &indices = load ? load->indices : store->indices;
-          for (auto &index : indices) {
-            optim::ReplaceVarWithExpr(&index, loop_var, ir::Expr(0));
-            index = cinn::common::AutoSimplify(index);
-          }
-        }
-      }
-    }
-  }
-}
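
The heart of the deleted pass is a longest-common-prefix computation over the collected loop nests. A minimal sketch of the same `count` logic, assuming plain strings in place of ir::Expr loops (not part of the commit):

#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Mirrors the `count` computation above: how many leading loops are shared
// by every access to one buffer (0 if there is only a single access).
size_t CommonLoopPrefix(const std::vector<std::vector<std::string>>& nests) {
  size_t count = nests.size() > 1 ? nests.front().size() : 0;
  for (size_t i = 1; i < nests.size(); ++i) {
    size_t common = std::min(nests.front().size(), nests[i].size());
    for (size_t j = 0; j < common; ++j) {
      if (nests.front()[j] != nests[i][j]) {
        count = std::min(count, j);
        break;
      }
    }
  }
  return count;
}
// Example: {{"i", "j", "k"}, {"i", "j", "m"}} -> 2, so the loop vars "i" and
// "j" could be zeroed out in that buffer's indices (unless they are
// thread-bound vars that a GPUShared buffer must keep).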

class ReplaceLoopVarToGpu : public ir::IRMutator<> {
public:
void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
@@ -586,7 +429,6 @@ void OptimizeExprGPU(Expr *expr) {

// resize buffer axis
UpdateBufferAxisPass(expr);
-  // UpdateBufferAxisPassOld(expr);

// replace var name with block/thread
ReplaceLoopVarToGpu replace_loop_var_to_gpu;
100 changes: 4 additions & 96 deletions paddle/cinn/optim/update_buffer_axis_pass.cc
@@ -61,80 +61,6 @@ void FormalizeSingleIndex(const ir::Tensor& tensor,
}
}

-/**
- * This is a template pass to update the buffer access when using a
- * single axis of a multi-dim tensor. For example, if tensor t has
- * shape [2, 3, 4] and the buffer access is t[12 * k], it is the same
- * as t[k, 0, 0]. It is easy for a human to see that they are the
- * same, but not easy for the compiler.
- *
- * This class checks that such buffer accesses are the same and
- * updates them to use the same index expr.
- *
- * Note: this is a temporary solution. Our symbolic simplification is
- * not powerful enough to simplify expressions such as 12 * k / 4 % 3,
- * so we only handle the simplest cases. We can extend this class once
- * we can simplify such expressions well.
- */
-class AnalyzeSingleAxisOfMultDimTensor : public ir::IRMutator<> {
- public:
-  void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
-
-  // Analyze the buffer access inside store
-  void Visit(const ir::Store* op, Expr* expr) override {
-    ir::Store* store = expr->As<ir::Store>();
-    ir::Tensor tensor = store->tensor.as_tensor_ref();
-    AnalyzeSingleAxisAccess(store->indices, tensor);
-    ir::IRMutator<>::Visit(op, expr);
-  }
-
-  // Analyze the buffer access inside load
-  void Visit(const ir::Load* op, Expr* expr) override {
-    ir::Load* load = expr->As<ir::Load>();
-    ir::Tensor tensor = load->tensor.as_tensor_ref();
-    AnalyzeSingleAxisAccess(load->indices, tensor);
-    ir::IRMutator<>::Visit(op, expr);
-  }
-
-  void AnalyzeSingleAxisAccess(const std::vector<Expr>& indices,
-                               const ir::Tensor& tensor) {
-    if (!tensor->buffer.defined() ||
-        tensor->buffer->memory_type == ir::MemoryType::Heap ||
-        tensor->buffer->memory_type == ir::MemoryType::GPUShared) {
-      return;
-    }
-    CHECK(indices.size() > 0) << "Buffer access indices are empty";
-    const std::string& buffer_name = tensor->buffer->name;
-    const std::vector<ir::Expr>& shape = tensor->shape;
-
-    // Linearize a full multi-dim access into a single row-major index.
-    ir::Expr index_expr;
-    if (indices.size() == 1 && shape.size() > 1) {
-      index_expr = indices[0];
-    } else if (indices.size() == shape.size()) {
-      ir::Expr mul = Expr(1);
-      index_expr = indices.back();
-      for (int i = static_cast<int>(indices.size()) - 2; i >= 0; --i) {
-        mul = ir::Mul::Make(shape[i + 1], mul);
-        ir::Expr cur = ir::Mul::Make(indices[i], mul);
-        index_expr = ir::Add::Make(cur, index_expr);
-      }
-    }
-    index_expr = common::AutoSimplify(index_expr);
-
-    // Buffers whose accesses disagree on the single-axis index are dropped.
-    if (!buffer_name_to_same_single_axis.count(buffer_name)) {
-      buffer_name_to_same_single_axis[buffer_name] = index_expr;
-      return;
-    } else {
-      const ir::Expr& stored_index_expr =
-          buffer_name_to_same_single_axis[buffer_name];
-      if (!ExprMathEqual(index_expr, stored_index_expr)) {
-        buffer_name_to_same_single_axis.erase(buffer_name);
-      }
-    }
-  }
-
-  std::unordered_map<std::string, ir::Expr> buffer_name_to_same_single_axis;
-};
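
A worked version of the linearization in AnalyzeSingleAxisAccess, with ints in place of CINN exprs (a sketch, not part of the commit): for shape [2, 3, 4], the row-major strides are [12, 4, 1], so t[k, 0, 0] lands at offset 12 * k, which is why t[12 * k] denotes the same element.

#include <cassert>
#include <vector>

// Accumulates the index the same way the deleted loop does: start from
// indices.back(), then fold in higher axes scaled by the running product
// of the trailing shape extents.
int Linearize(const std::vector<int>& shape, const std::vector<int>& indices) {
  int mul = 1;
  int index = indices.back();
  for (int i = static_cast<int>(indices.size()) - 2; i >= 0; --i) {
    mul *= shape[i + 1];
    index += indices[i] * mul;
  }
  return index;
}

int main() {
  assert(Linearize({2, 3, 4}, {1, 0, 0}) == 12);  // t[1, 0, 0] == t[12 * 1]
  assert(Linearize({2, 3, 4}, {0, 2, 3}) == 11);  // 0 * 12 + 2 * 4 + 3
  return 0;
}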

class AnalyzeBufferAxis : public ir::IRMutator<> {
public:
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
@@ -260,11 +186,9 @@ class ReplaceSameAxisToZero : public ir::IRMutator<> {
public:
  ReplaceSameAxisToZero(
      const std::unordered_map<std::string, std::map<int, ir::Expr>>&
-          buffer_name_access_same_index_expr,
-      const std::unordered_map<std::string, ir::Expr>&
-          buffer_name_to_same_single_axis)
-      : buffer_name_access_same_index_expr_(buffer_name_access_same_index_expr),
-        buffer_name_to_same_single_axis_(buffer_name_to_same_single_axis) {}
+          buffer_name_access_same_index_expr)
+      : buffer_name_access_same_index_expr_(
+            buffer_name_access_same_index_expr) {}

void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

@@ -303,29 +227,15 @@ class ReplaceSameAxisToZero : public ir::IRMutator<> {
}
return;
}
-    if (buffer_name_to_same_single_axis_.count(buffer_name)) {
-      indices->clear();
-      indices->push_back(ir::Expr(0));
-      return;
-    }
}

const std::unordered_map<std::string, std::map<int, ir::Expr>>&
buffer_name_access_same_index_expr_;
-  const std::unordered_map<std::string, ir::Expr>&
-      buffer_name_to_same_single_axis_;
};

void UpdateBufferAxisPass(ir::Expr* expr) {
VLOG(6) << "Before UpdateBufferAxisPass, Expr = \n" << *expr;

-  // AnalyzeSingleAxisOfMultDimTensor singler_axis_analyzer;
-  // singler_axis_analyzer(expr);
-  // for (auto p : singler_axis_analyzer.buffer_name_to_same_single_axis) {
-  //   VLOG(6) << "Single axis Buffer name: " << p.first;
-  //   VLOG(6) << "Single Expr: " << p.second;
-  // }
-  std::unordered_map<std::string, ir::Expr> dump;
AnalyzeBufferAxis buffer_axis_analyzer;
buffer_axis_analyzer(expr);
for (auto p : buffer_axis_analyzer.buffer_name_access_same_index_expr) {
@@ -336,9 +246,7 @@ void UpdateBufferAxisPass(ir::Expr* expr) {
}

ReplaceSameAxisToZero replacer(
-      buffer_axis_analyzer.buffer_name_access_same_index_expr,
-      // singler_axis_analyzer.buffer_name_to_same_single_axis);
-      dump);
+      buffer_axis_analyzer.buffer_name_access_same_index_expr);
replacer(expr);
VLOG(6) << "After UpdateBufferAxisPass, Expr = \n" << *expr;
}
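
For context, an illustrative sketch (not actual pass output): the surviving pass pairs AnalyzeBufferAxis with ReplaceSameAxisToZero, so when every access to a buffer uses the same index expr on some axis, that axis is rewritten to 0 and the "resize buffer axis" step in OptimizeExprGPU can then shrink the buffer. A hypothetical before/after on a shared buffer B:

  // Before: axis 0 of B is always indexed by the same expr `i`.
  B[i, j] = X[i, j]
  Y[i, j] = B[i, j] + 1
  // After: the common axis is zeroed, so B can be resized to extent 1 there.
  B[0, j] = X[i, j]
  Y[i, j] = B[0, j] + 1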
