diff --git a/include/matxscript/ir/printer/doc.h b/include/matxscript/ir/printer/doc.h index 7f6c8e7d..076796db 100644 --- a/include/matxscript/ir/printer/doc.h +++ b/include/matxscript/ir/printer/doc.h @@ -310,6 +310,19 @@ class LiteralDoc : public ExprDoc { return LiteralDoc::Str(dtype, p); } + /*! + * \brief Create a LiteralDoc to represent string. + * \param v The string value. + * \param p The object path + */ + static LiteralDoc HLOType(const Type& v, const Optional& p) { + if (auto const* pt = v.as()) { + return LiteralDoc::DataType(pt->dtype, p); + } + StringRef dtype = v->GetPythonTypeName().encode(); + return LiteralDoc::Str(dtype, p); + } + MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); }; diff --git a/src/ir/_base/cow_array_ref.cc b/src/ir/_base/cow_array_ref.cc index 1a65b946..61853417 100644 --- a/src/ir/_base/cow_array_ref.cc +++ b/src/ir/_base/cow_array_ref.cc @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include @@ -197,5 +199,19 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ']'; }); +using namespace ::matxscript::ir::printer; +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch>( // + "", + [](Array array, ObjectPath p, IRDocsifier d) -> Doc { + int n = array.size(); + Array results; + results.reserve(n); + for (int i = 0; i < n; ++i) { + results.push_back(d->AsDoc(array[i], p->ArrayIndex(i))); + } + return ListDoc(results); + }); + } // namespace ir } // namespace matxscript diff --git a/src/ir/_base/cow_map_ref.cc b/src/ir/_base/cow_map_ref.cc index ad6a7d6a..2fa0e684 100644 --- a/src/ir/_base/cow_map_ref.cc +++ b/src/ir/_base/cow_map_ref.cc @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include @@ -364,5 +366,41 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) MATX_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; +using namespace ::matxscript::ir::printer; +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch>( // + "", + [](Map dict, ObjectPath p, IRDocsifier d) -> Doc { + using POO = std::pair; + std::vector items{dict.begin(), dict.end()}; + bool is_str_map = true; + for (const auto& kv : items) { + if (!kv.first.as()) { + is_str_map = false; + break; + } + } + if (is_str_map) { + std::sort(items.begin(), items.end(), [](const POO& lhs, const POO& rhs) { + return runtime::Downcast(lhs.first) < + runtime::Downcast(rhs.first); + }); + } else { + std::sort(items.begin(), items.end(), [](const POO& lhs, const POO& rhs) { + return lhs.first.get() < rhs.first.get(); + }); + } + int n = dict.size(); + Array ks; + Array vs; + ks.reserve(n); + vs.reserve(n); + for (int i = 0; i < n; ++i) { + ks.push_back(d->AsDoc(items[i].first, p->MissingMapEntry())); + vs.push_back(d->AsDoc(items[i].second, p->MapValue(items[i].first))); + } + return DictDoc(ks, vs); + }); + } // namespace ir } // namespace matxscript diff --git a/src/ir/_base/string_ref.cc b/src/ir/_base/string_ref.cc index 49d72a47..a251fba0 100644 --- a/src/ir/_base/string_ref.cc +++ b/src/ir/_base/string_ref.cc @@ -23,6 +23,8 @@ #include #include +#include +#include #include #include #include @@ -262,5 +264,12 @@ typename StringRef::const_reverse_iterator StringRef::rend() const { return const_reverse_iterator(begin()); } +using namespace ::matxscript::ir::printer; +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](StringRef s, ObjectPath p, IRDocsifier d) -> Doc { + // TODO: optimize MultipleLines + return LiteralDoc::Str(s, p); + }); + } // namespace ir } // namespace matxscript diff --git a/src/ir/adt.cc b/src/ir/adt.cc index fbe6a2d7..77915482 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -30,6 +30,8 @@ */ #include +#include +#include #include #include @@ -37,6 +39,7 @@ namespace matxscript { namespace ir { using namespace runtime; +using namespace ::matxscript::ir::printer; Constructor::Constructor(Type ret_type, StringRef name_hint, @@ -66,6 +69,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->belong_to << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](Constructor node, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(node->name_hint); + }); + // ClassType ClassType::ClassType(uint64_t py_type_id, GlobalTypeVar header, diff --git a/src/ir/hlo_expr.cc b/src/ir/hlo_expr.cc index e7ba526b..8a15b187 100644 --- a/src/ir/hlo_expr.cc +++ b/src/ir/hlo_expr.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ namespace matxscript { namespace ir { using namespace ::matxscript::runtime; +using namespace ::matxscript::ir::printer; StringImm::StringImm(StringRef value, Span span) { ObjectPtr node = runtime::make_object(); @@ -52,6 +54,12 @@ MATXSCRIPT_REGISTER_GLOBAL("ir.StringImm").set_body_typed([](StringRef s, Span s return StringImm(std::move(s), std::move(span)); }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](StringImm s, ObjectPath p, IRDocsifier d) -> Doc { + // TODO: fix bytes print + return LiteralDoc::Str(s->value, p->Attr("value")); + }); + UnicodeImm::UnicodeImm(StringRef value, Span span) { ObjectPtr node = runtime::make_object(); node->value = std::move(value); @@ -70,6 +78,12 @@ MATXSCRIPT_REGISTER_GLOBAL("ir.UnicodeImm").set_body_typed([](StringRef s, Span return UnicodeImm(std::move(s), std::move(span)); }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](UnicodeImm s, ObjectPath p, IRDocsifier d) -> Doc { + // TODO: fix unicode + return LiteralDoc::Str(s->value, p->Attr("value")); + }); + #define MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(Name) \ Name::Name(BaseExpr a, BaseExpr b, Span span) { \ using T = Name::ContainerType; \ @@ -205,6 +219,14 @@ static Type InferAddOpType(const Type& lhs_raw, const Type& rhs_raw) { return ObjectType(); } +#define MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(NodeType, OpKind) \ + MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ + .set_dispatch("", [](ir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ + return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ + }); + // HLOAdd HLOAdd::HLOAdd(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined\n"; @@ -236,6 +258,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOAdd, kAdd); + // HLOSub static Type InferSubOpType(const Type& lhs_raw, const Type& rhs_raw) { const auto& lhs_type = RemoveReference(lhs_raw); @@ -329,6 +353,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOSub, kSub); + // HLOMul static Type InferMulOpType(const Type& lhs_raw, const Type& rhs_raw) { const auto& lhs_type = RemoveReference(lhs_raw); @@ -451,6 +477,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOMul, kMult); + // HLOFloorDiv static Type InferFloorDivOpType(const Type& lhs_raw, const Type& rhs_raw) { const auto& lhs_type = RemoveReference(lhs_raw); @@ -520,6 +548,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "floordiv(" << op->a << ", " << op->b << ")"; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOFloorDiv, kFloorDiv); + // HLOFloorMod static Type InferFloorModOpType(const Type& lhs_raw, const Type& rhs_raw) { const auto& lhs_type = RemoveReference(lhs_raw); @@ -589,6 +619,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "floormod(" << op->a << ", " << op->b << ")"; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOFloorMod, kMod); + // HLOEqual HLOEqual::HLOEqual(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined\n"; @@ -617,6 +649,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOEqual, kEq); + // HLONotEqual HLONotEqual::HLONotEqual(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined\n"; @@ -645,6 +679,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLONotEqual, kNotEq); + // HLOLessThan HLOLessThan::HLOLessThan(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined\n"; @@ -673,6 +709,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOLessThan, kLt); + // HLOLessEqual HLOLessEqual::HLOLessEqual(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined\n"; @@ -701,6 +739,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOLessEqual, kLtE); + // HLOGreaterThan HLOGreaterThan::HLOGreaterThan(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined\n"; @@ -730,6 +770,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOGreaterThan, kGt); + // HLOGreaterEqual HLOGreaterEqual::HLOGreaterEqual(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined\n"; @@ -759,6 +801,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOGreaterEqual, kGtE); + // HLOAnd HLOAnd::HLOAnd(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined"; @@ -799,6 +843,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOAnd, kAnd); + // HLOOr HLOOr::HLOOr(BaseExpr a, BaseExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined"; @@ -832,6 +878,10 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY(HLOOr, kOr); + +#undef MATXSCRIPT_SCRIPT_PRINTER_DEF_HLO_BINARY + // HLONot HLONot::HLONot(BaseExpr a, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined"; @@ -857,6 +907,12 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::HLONot node, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); + return OperationDoc(OperationDocNode::Kind::kNot, {a}); + }); + // Call Call::Call(Type ret_type, HLOExpr op, Array args, Span span, Array type_args) { ObjectPtr n = make_object(); @@ -886,6 +942,27 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ") -> " << node->checked_type_; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::Call call, ObjectPath call_p, IRDocsifier d) -> Doc { + ExprDoc prefix{nullptr}; + if (const auto* op = call->op.as()) { + // TODO: fix prim op name + StringRef name = op->name; + prefix = Dialect(d, name); + } else if (const auto* gv = call->op.as()) { + prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); + } else { + MXLOG(FATAL) << "call: " << call; + } + Array args; + int n_args = call->args.size(); + args.reserve(n_args + 1); + for (int i = 0; i < n_args; ++i) { + args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); + } + return prefix->Call(args); + }); + // HLOIterator HLOIterator::HLOIterator(BaseExpr container, IntImm method, Span span) { ObjectPtr n = make_object(); @@ -909,6 +986,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "HLOIterator(" << node->container << "." << node->method << ")"; }); +// TODO: remove unused HLOIterator + // InitializerList InitializerList::InitializerList(Array fields, Span span) { ObjectPtr n = make_object(); @@ -948,6 +1027,12 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "InitializerList(" << node->fields << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](ir::InitializerList li, ObjectPath li_p, IRDocsifier d) -> Doc { + return d->AsDoc(li->fields, li_p->Attr("fields")); + }); + // InitializerDict InitializerDict::InitializerDict(Map fields, Span span) { ObjectPtr n = make_object(); @@ -1004,6 +1089,12 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "InitializerDict(" << node->fields << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](ir::InitializerDict di, ObjectPath di_p, IRDocsifier d) -> Doc { + return d->AsDoc(di->fields, di_p->Attr("fields")); + }); + // EnumAttr EnumAttr::EnumAttr(StringRef enum_str, Span span) { ObjectPtr n = make_object(); @@ -1025,6 +1116,12 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "EnumAttr(" << node->enum_str << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::EnumAttr en, ObjectPath en_p, IRDocsifier d) -> Doc { + // TODO: fixme + return d->AsDoc(en->enum_str, en_p->Attr("enum_str")); + }); + // ClassGetItem ClassGetItem::ClassGetItem(HLOExpr self, StringImm attr, Span span) { ObjectPtr n = make_object(); @@ -1053,6 +1150,13 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "ClassGetItemNode(" << node->self << "." << node->attr << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](ir::ClassGetItem cls_getitem, ObjectPath cls_getitem_p, IRDocsifier d) -> Doc { + auto self = d->AsDoc(cls_getitem->self, cls_getitem_p->Attr("self")); + return self->Attr(cls_getitem->attr->value); + }); + // HLOCast HLOCast::HLOCast(Type t, BaseExpr value, Span span) { MXCHECK(value.defined()); @@ -1077,6 +1181,13 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::HLOCast e, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc dtype = LiteralDoc::HLOType(e->checked_type_, p->Attr("checked_type_")); + ExprDoc value = d->AsDoc(e->value, p->Attr("value")); + return Dialect(d, "HLOCast")->Call({dtype, value}); + }); + // HLOMove HLOMove::HLOMove(BaseExpr value, Span span) { MXCHECK(value.defined()); @@ -1101,6 +1212,12 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::HLOMove e, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc value = d->AsDoc(e->value, p->Attr("value")); + return Dialect(d, "move")->Call({value}); + }); + // HLOEnumerate HLOEnumerate::HLOEnumerate(BaseExpr value, BaseExpr start, Span span) { MXCHECK(value.defined()); @@ -1134,6 +1251,12 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::HLOEnumerate e, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc value = d->AsDoc(e->value, p->Attr("value")); + return IdDoc("enumerate")->Call({value}); + }); + // HLOZip HLOZip::HLOZip(Array values, Span span) { MXCHECK(values.defined()); @@ -1167,5 +1290,17 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::HLOZip e, ObjectPath p, IRDocsifier d) -> Doc { + int n = e->values.size(); + Array results; + results.reserve(n); + p = p->Attr("values"); + for (int i = 0; i < n; ++i) { + results.push_back(d->AsDoc(e->values[i], p->ArrayIndex(i))); + } + return IdDoc("zip")->Call(results); + }); + } // namespace ir } // namespace matxscript diff --git a/src/ir/hlo_var.cc b/src/ir/hlo_var.cc index c6cdfcdc..e20e53f4 100644 --- a/src/ir/hlo_var.cc +++ b/src/ir/hlo_var.cc @@ -25,6 +25,8 @@ #include #include #include +#include +#include #include #include @@ -32,6 +34,7 @@ namespace matxscript { namespace ir { using runtime::make_object; +using namespace matxscript::ir::printer; MATXSCRIPT_REGISTER_NODE_TYPE(IdNode); @@ -68,6 +71,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](HLOVar var, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(var->name_hint()); + }); + GlobalVar::GlobalVar(StringRef name_hint, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); @@ -87,5 +95,10 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "GlobalVar(" << node->name_hint << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](GlobalVar var, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(var->name_hint); + }); + } // namespace ir } // namespace matxscript diff --git a/src/ir/none_expr.cc b/src/ir/none_expr.cc index 8b46a701..d53a8582 100644 --- a/src/ir/none_expr.cc +++ b/src/ir/none_expr.cc @@ -24,6 +24,8 @@ #include #include +#include +#include #include #include @@ -31,6 +33,7 @@ namespace matxscript { namespace ir { using namespace runtime; +using namespace ::matxscript::ir::printer; NoneExpr::NoneExpr(Span span) { ObjectPtr n = make_object(); @@ -51,5 +54,10 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "NoneExpr"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](NoneExpr t, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::None(p); + }); + } // namespace ir } // namespace matxscript diff --git a/src/ir/prim_expr.cc b/src/ir/prim_expr.cc index a319f903..3ec3bed4 100644 --- a/src/ir/prim_expr.cc +++ b/src/ir/prim_expr.cc @@ -615,6 +615,9 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimOr, PrimOrNode, logic_or, "Or", kOr); +#undef MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY +#undef MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR + // PrimNot PrimNot::PrimNot(PrimExpr a, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined"; @@ -641,6 +644,12 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->a); }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::PrimNot node, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); + return OperationDoc(OperationDocNode::Kind::kNot, {a}); + }); + // PrimSelect PrimSelect::PrimSelect(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { MXCHECK(condition.defined()) << "ValueError: condition is undefined"; diff --git a/src/ir/range_expr.cc b/src/ir/range_expr.cc index 71d91df3..b48a6a9b 100644 --- a/src/ir/range_expr.cc +++ b/src/ir/range_expr.cc @@ -21,6 +21,8 @@ #include #include +#include +#include #include #include #include @@ -29,6 +31,7 @@ namespace matxscript { namespace ir { using namespace runtime; +using namespace ::matxscript::ir::printer; RangeExpr::RangeExpr(PrimExpr start, PrimExpr stop, PrimExpr step, Span span) { auto n = make_object(); @@ -55,5 +58,14 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ')'; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](RangeExpr r, ObjectPath p, IRDocsifier d) -> Doc { + // TODO: optimize range args + ExprDoc start = d->AsDoc(r->start, p->Attr("start")); + ExprDoc stop = d->AsDoc(r->stop, p->Attr("stop")); + ExprDoc step = d->AsDoc(r->step, p->Attr("step")); + return IdDoc("range")->Call({start, stop, step}); + }); + } // namespace ir } // namespace matxscript diff --git a/src/ir/tuple_expr.cc b/src/ir/tuple_expr.cc index b5f7d1fa..8d2ba447 100644 --- a/src/ir/tuple_expr.cc +++ b/src/ir/tuple_expr.cc @@ -24,6 +24,8 @@ #include #include +#include +#include #include #include @@ -31,6 +33,7 @@ namespace matxscript { namespace ir { using namespace runtime; +using namespace ::matxscript::ir::printer; TupleExpr::TupleExpr(Array fields, Span span) { ObjectPtr n = make_object(); @@ -51,5 +54,17 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TupleExpr(" << node->fields << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](TupleExpr t, ObjectPath p, IRDocsifier d) -> Doc { + int n = t->fields.size(); + Array results; + results.reserve(n); + p = p->Attr("fields"); + for (int i = 0; i < n; ++i) { + results.push_back(d->AsDoc(t->fields[i], p->ArrayIndex(i))); + } + return TupleDoc(std::move(results)); + }); + } // namespace ir } // namespace matxscript