Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][Bugfix] FCallPacked not checked in CodegenVMTIR #17073

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 1 addition & 23 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,6 @@ using namespace relax;
using namespace tvm::runtime;
using namespace tvm::runtime::relax_vm;

namespace {
// Helper function to get the function name of the registered packed function implementation of
// relax operator.
FCallPacked GetPackedFuncName(const Call& call) {
static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
if (call->op.as<OpNode>()) {
Op op = Downcast<Op>(call->op);
if (op_map.count(op)) {
return op_map[op];
}
}
return {};
}
} // namespace

/*!
* \brief A class to generate VM executable for Relax functions.
*/
Expand Down Expand Up @@ -156,14 +141,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
// allocate dst register.
RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister();
if (call->op.as<OpNode>()) {
// special case generate for the intrinsics whose attribute fields
// cannot be represented by args in the CallNode
FCallPacked name = GetPackedFuncName(call);
if (!name.empty()) {
// If the operator has a registered packed function implementation, emit call to that packed
// function.
EmitPackedFuncCall(call, name, dst_reg);
} else if (call_node->op == call_builtin_with_ctx_op_) {
if (call_node->op == call_builtin_with_ctx_op_) {
// TODO(relax-team) migrate most handling of op to
// directly map to call_builtin_with_ctx before codegen and simplify vm codegen.
EmitCallBuiltinWithCtx(call, dst_reg);
Expand Down
24 changes: 1 addition & 23 deletions src/relax/backend/vm/codegen_vm_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,6 @@ namespace relax_vm {

using vm::VMFuncInfo;

namespace {
// Helper function to get the function name of the registered packed function implementation of
// relax operator.
FCallPacked GetPackedFuncName(const Call& call) {
static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
if (call->op.as<OpNode>()) {
Op op = Downcast<Op>(call->op);
if (op_map.count(op)) {
return op_map[op];
}
}
return {};
}
} // namespace

/*!
* \brief A class to generate VMTIR for Relax functions.
*
Expand Down Expand Up @@ -247,14 +232,7 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
}
int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister();
if (call->op.as<OpNode>()) {
// special case generate for the intrinsics whose attribute fields
// cannot be represented by args in the CallNode
FCallPacked name = GetPackedFuncName(call);
if (name.size()) {
// If the operator has a registered packed function implementation, emit call to that packed
// function.
EmitCallPacked(name, VisitArray(call->args), dst_reg);
} else if (call_node->op == call_builtin_with_ctx_op_) {
if (call_node->op == call_builtin_with_ctx_op_) {
EmitCallBuiltinWithCtx(call, dst_reg);
} else if (call_node->op == alloc_storage_op_) {
EmitAllocStorage(call, dst_reg);
Expand Down
25 changes: 17 additions & 8 deletions src/relax/transform/legalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ class LegalizeMutator : public ExprMutator {
Expr VisitExpr_(const CallNode* call) final {
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
static const auto& call_packed_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
static const auto& requires_arg_shapes_map = Op::GetAttrMap<Bool>("RequiresArgumentShapes");
static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
static const Op& call_tir_op = Op::Get("relax.call_tir");
Expand All @@ -236,7 +237,7 @@ class LegalizeMutator : public ExprMutator {
}
auto op = GetRef<Op>(op_node);

bool can_legalize = [&]() -> bool {
bool shapes_are_known_if_required = [&]() -> bool {
bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value;
if (!requires_arg_shapes) {
// This operator does not require its arguments to have a
Expand Down Expand Up @@ -299,23 +300,31 @@ class LegalizeMutator : public ExprMutator {
return true;
}();

if (!can_legalize) {
return visited_call;
}

FLegalize legalization_func;

if (auto opt_custom_legalize = cmap_.Get(op->name)) {
if (auto opt_custom_legalize = cmap_.Get(op->name);
opt_custom_legalize && shapes_are_known_if_required) {
// First choice, use a custom legalization function
legalization_func = opt_custom_legalize.value();
} else if (legalize_map.count(op)) {
} else if (legalize_map.count(op) && shapes_are_known_if_required) {
// Second choice, use a default legalization
legalization_func = legalize_map[op];
} else if (call_packed_map.count(op)) {
// Third choice, use an explicit FCallPacked replacement. This does not require the shape
String packed_func_name = call_packed_map[op];
legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr {
return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetStructInfo(call)});
};
} else {
// No legalization.
if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
op != call_pure_packed_op) {
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
if (shapes_are_known_if_required) {
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
} else {
LOG(WARNING) << "Cannot legalize " << visited_call
<< ", missing known shapes for arguments and return value";
}
}
return visited_call;
}
Expand Down
Loading
Loading