Skip to content

Commit

Permalink
[CINN] Update Buffer Axis (#59209)
Browse files Browse the repository at this point in the history
Make Resize Buffer Safer
  • Loading branch information
zhhsplendid authored Dec 8, 2023
1 parent caa0d78 commit 180b64d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 106 deletions.
12 changes: 10 additions & 2 deletions paddle/cinn/optim/replace_mod_to_max.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,29 @@

#include <unordered_map>

#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"

namespace cinn {
namespace optim {

/**
* Replace Mod to possible max value.
* a % b -> min(b - 1, a)
* either b - 1 or a is the possible max value of the mod expression.
*/
class ReplaceModToMaxMutator : public ir::IRMutator<> {
 public:
  // Entry point: rewrites every Mod node inside `expr` in place.
  void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  void Visit(const ir::Mod* op, ir::Expr* expr) override {
    ir::Mod* node = expr->As<ir::Mod>();
    // Upper bound of `a % b`: the result is strictly less than b (so at most
    // b - 1) and never exceeds a itself. Replace `a % b` with
    // min(a, b - 1), then simplify the resulting expression.
    Expr base = ir::Sub::Make(node->operand(1), Expr(1));
    Expr min_expr = ir::Min::Make(node->operand(0), base);
    *expr = cinn::common::AutoSimplify(min_expr);
    // Recurse into the rewritten expression to handle nested Mod nodes.
    ir::IRMutator<>::Visit(expr, expr);
  }
};

Expand Down
106 changes: 2 additions & 104 deletions paddle/cinn/optim/transform_gpu_forloop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -521,106 +521,6 @@ class LocalAxisVisitor : public ir::IRMutator<> {
"threadIdx.z"};
};

// Shrinks the shapes of tensors backed by non-heap buffers (e.g. GPU
// local/shared memory) so that each dimension is just large enough to hold
// the store indices actually used, based on the constant extents of the
// enclosing For loops recorded while traversing.
class ResizeBufferSizeVisitor : public ir::IRMutator<> {
 public:
  // Entry point: walks `expr` in place.
  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }

 private:
  // For a store into a non-heap buffer, recompute both the tensor shape and
  // the buffer shape: one dimension per store index, sized to the maximum
  // value that index expression can take (see BufferSize).
  void Visit(const ir::Store *op, Expr *expr) override {
    auto store = expr->As<ir::Store>();
    auto store_tensor = store->tensor.as_tensor_ref();

    // Tensors without a materialized buffer need no resizing.
    if (!store_tensor->buffer.defined()) {
      return;
    }
    // Heap (global) buffers keep their original size; only recurse.
    if (store_tensor->buffer->memory_type == ir::MemoryType::Heap) {
      ir::IRMutator<>::Visit(op, expr);
      return;
    }

    auto &indices = store->indices;
    auto &shape = store_tensor->shape;
    auto &buffer = store_tensor->buffer->shape;

    // Rebuild shape and buffer shape in lock-step from the store indices.
    shape.clear();
    buffer.clear();
    for (int idx = 0; idx < indices.size(); ++idx) {
      shape.push_back(ir::Expr(BufferSize(indices[idx])));
      buffer.push_back(shape.back());
    }
    ir::IRMutator<>::Visit(op, expr);
  }

  // Align a load from a non-heap buffer with the (possibly shrunk) buffer
  // shape set by the Store visitor above.
  void Visit(const ir::Load *op, Expr *expr) override {
    auto load = expr->As<ir::Load>();
    if (!load->tensor.as_tensor_ref()->buffer.defined()) {
      return;
    }

    // Heap buffers are untouched; only recurse.
    if (load->tensor.as_tensor_ref()->buffer->memory_type ==
        ir::MemoryType::Heap) {
      ir::IRMutator<>::Visit(op, expr);
      return;
    }

    // Adopt the buffer's shape as the tensor shape.
    load->tensor.as_tensor_ref()->shape =
        load->tensor.as_tensor_ref()->buffer->shape;

    // For the moment, align the load tensor indices with the tensor shape using
    // the trick method. A better way would be to modify the FlattenLoop
    // Schedule.
    // NOTE(review): this drops the *leading* indices when there are more
    // indices than shape dimensions — assumes the leading ones are the
    // redundant outer axes; verify against the FlattenLoop schedule.
    int cnt = load->indices.size() - load->tensor.as_tensor_ref()->shape.size();
    for (int i = 0; i < cnt; i++) {
      load->indices.erase(load->indices.begin());
    }
    ir::IRMutator<>::Visit(op, expr);
  }

  // Record the constant extent of each For loop, keyed by loop-var name, for
  // use by BufferSize. Loops with non-constant extents are skipped.
  void Visit(const ir::For *op, Expr *expr) override {
    CHECK(expr->As<ir::For>());
    auto for_ir = expr->As<ir::For>();
    auto var_name = for_ir->loop_var->name;
    auto extent_i = for_ir->extent;

    if (extent_i.is_constant()) loop_2_extent_[var_name] = extent_i.as_int32();
    ir::IRMutator<>::Visit(op, expr);
  }

  // Compute the required dimension size for one index expression: the maximum
  // value of `indice` over all combinations of its loop variables, plus one.
  // Brute-force: enumerates every combination of recorded loop extents, so
  // cost is the product of the extents of the vars appearing in `indice`.
  int BufferSize(ir::Expr indice) {
    auto copy = ir::ir_utils::IRCopy(indice);
    auto vars = ir::ir_utils::CollectIRNodesInOrder(
        copy, [](const ir::Expr *expr) { return expr->As<ir::_Var_>(); });

    int max_range = 1;
    // Recursively substitute each var (depth `deep`) with every value in its
    // loop extent and track the maximum constant the index evaluates to.
    std::function<void(int, ir::Expr)> compute_range = [&](const int deep,
                                                           ir::Expr index) {
      auto var = vars[deep].as_var_ref();
      // Every var in the index must come from a loop with a known constant
      // extent, otherwise the size cannot be computed.
      CHECK(loop_2_extent_.count(var->name)) << var->name;
      auto extent = loop_2_extent_.find(var->name)->second;

      for (int idx = 0; idx < extent; ++idx) {
        auto tmp = ir::ir_utils::IRCopy(index);
        ReplaceVarWithExpr(&tmp, var, Expr(idx));

        if (deep == vars.size() - 1) {
          // All vars substituted: the index should now fold to a constant.
          // NOTE(review): the second AutoSimplify looks redundant — the
          // expression is already simplified; confirm before removing.
          auto simplify = cinn::common::AutoSimplify(tmp);
          auto range = cinn::common::AutoSimplify(simplify);
          CHECK(range.is_constant());
          max_range = std::max(max_range, range.as_int32() + 1);
        } else {
          compute_range(deep + 1, tmp);
        }
      }
    };

    if (vars.size()) compute_range(0, copy);
    // An index with no vars (pure constant 0) needs a dimension of size 1.
    return max_range;
  }

  // Maps loop-var name -> constant loop extent, filled by the For visitor.
  std::unordered_map<std::string, int> loop_2_extent_;
};

class ReplaceVarToZero : public ir::IRMutator<> {
public:
void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
Expand Down Expand Up @@ -685,8 +585,8 @@ void OptimizeExprGPU(Expr *expr) {
replace_index_to_bind_expr(expr);

// resize buffer axis
// UpdateBufferAxisPass(expr);
UpdateBufferAxisPassOld(expr);
UpdateBufferAxisPass(expr);
// UpdateBufferAxisPassOld(expr);

// replace var name with block/thread
ReplaceLoopVarToGpu replace_loop_var_to_gpu;
Expand All @@ -701,8 +601,6 @@ void OptimizeExprGPU(Expr *expr) {
local_axis_visitor(expr);

ResizeBufferToMaxVarRange(expr);
// ResizeBufferSizeVisitor resize_buffer_size_visitor;
// resize_buffer_size_visitor(expr);

ReplaceVarToZero replace_var_to_zero;
replace_var_to_zero(expr);
Expand Down
10 changes: 10 additions & 0 deletions paddle/cinn/optim/update_buffer_axis_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ class AnalyzeBufferAxis : public ir::IRMutator<> {
void Visit(const ir::Store* op, Expr* expr) override {
ir::Store* store = expr->As<ir::Store>();
ir::Tensor tensor = store->tensor.as_tensor_ref();
if (!tensor->buffer.defined() ||
tensor->buffer->memory_type == ir::MemoryType::Heap) {
ir::IRMutator<>::Visit(op, expr);
return;
}
FormalizeSingleIndex(tensor, &(store->indices));
AnalyzeTensorAxis(store->indices, tensor);
ir::IRMutator<>::Visit(op, expr);
Expand All @@ -162,6 +167,11 @@ class AnalyzeBufferAxis : public ir::IRMutator<> {
void Visit(const ir::Load* op, Expr* expr) override {
ir::Load* load = expr->As<ir::Load>();
ir::Tensor tensor = load->tensor.as_tensor_ref();
if (!tensor->buffer.defined() ||
tensor->buffer->memory_type == ir::MemoryType::Heap) {
ir::IRMutator<>::Visit(op, expr);
return;
}
FormalizeSingleIndex(tensor, &(load->indices));
AnalyzeTensorAxis(load->indices, tensor);
ir::IRMutator<>::Visit(op, expr);
Expand Down

0 comments on commit 180b64d

Please sign in to comment.