Skip to content

Commit

Permalink
[CINN] Add IntrinsicOps into ir_codes_collector (#60556) (#62245)
Browse files Browse the repository at this point in the history
This PR fixed a bug of running Resnet PaddleClas.

The bug is due to vectorize introduce an intrinsic GetAddr and we didn't collect the tensor of GetAddr in ir_node_collector, this would caused tensor alias won't create in cuda code.

TODO: we may modify IntrinsicOp in the near future
  • Loading branch information
zhhsplendid authored Mar 1, 2024
1 parent 521dc70 commit 773ea41
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
9 changes: 8 additions & 1 deletion paddle/cinn/ir/ir_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,23 @@ class Dim;
macro__(Product) \
macro__(Sum) \
macro__(PrimitiveNode) \
macro__(IntrinsicOp) \
macro__(_BufferRange_) \
macro__(ScheduleBlock) \
macro__(ScheduleBlockRealize) \
macro__(_Dim_) \

#define NODETY_CONTROL_OP_FOR_INTRINSIC(macro__) \
macro__(IntrinsicOp) \

#define NODETY_FORALL(__m) \
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
NODETY_OP_FOR_EACH(__m) \
NODETY_CONTROL_OP_FOR_INTRINSIC(__m) \
NODETY_CONTROL_OP_FOR_EACH(__m)

#define NODETY_FORALL_EXCEPT_INTRINSIC(__m) \
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
NODETY_OP_FOR_EACH(__m) \
NODETY_CONTROL_OP_FOR_EACH(__m)
// clang-format on

Expand Down
67 changes: 66 additions & 1 deletion paddle/cinn/ir/utils/ir_nodes_collector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include <glog/logging.h>

#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"

Expand Down Expand Up @@ -71,8 +73,71 @@ struct IrNodesCollector : public IRVisitorRequireReImpl<void> {
} \
}

NODETY_FORALL(__m)
NODETY_FORALL_EXCEPT_INTRINSIC(__m)
#undef __m

void Visit(const ir::IntrinsicOp* op) {
switch (op->getKind()) {
#define __(x) \
case ir::IntrinsicKind::k##x: \
Visit(llvm::dyn_cast<ir::intrinsics::x>(op)); \
break;

INTRINSIC_KIND_FOR_EACH(__)
#undef __
}
}

void Visit(const ir::intrinsics::GetAddr* x) {
if (x->data.defined()) {
Visit(&(x->data));
}
}

void Visit(const ir::intrinsics::BufferGetDataHandle* x) {
if (x->buffer.defined()) {
Visit(&(x->buffer));
}
}

void Visit(const ir::intrinsics::BufferGetDataConstHandle* x) {
if (x->buffer.defined()) {
Visit(&(x->buffer));
}
}

void Visit(const ir::intrinsics::PodValueToX* x) {
if (x->pod_value_ptr.defined()) {
Visit(&(x->pod_value_ptr));
}
}

void Visit(const ir::intrinsics::BufferCreate* x) {
if (x->buffer.defined()) {
Visit(&(x->buffer));
}
}

void Visit(const ir::intrinsics::ArgsConstruct* x) {
if (x->var.defined()) {
Expr convert = Expr(x->var);
Visit(&convert);
}
for (int i = 0; i < x->args.size(); ++i) {
if (x->args[i].defined()) {
Visit(&(x->args[i]));
}
}
}

void Visit(const ir::intrinsics::BuiltinIntrin* x) {
for (int i = 0; i < x->args.size(); ++i) {
if (x->args[i].defined()) {
Visit(&(x->args[i]));
}
}
}

std::set<void*> visited_;
};

Expand Down

0 comments on commit 773ea41

Please sign in to comment.