From c27d85a9c91080c0d501f5e070026f06cfacceaa Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sat, 6 Nov 2021 07:14:32 +0000 Subject: [PATCH] Emit the boilerplate for Type printer/parser dialect dispatching from ODS Add a new `useDefaultTypePrinterParser` boolean settings on the dialect (default to false for now) that emits the boilerplate to dispatch type parsing/printing to the auto-generated method. We will likely turn this on by default in the future. Differential Revision: https://reviews.llvm.org/D113332 --- .../mlir/Dialect/EmitC/IR/EmitCBase.td | 1 + mlir/include/mlir/IR/OpBase.td | 4 +++ mlir/include/mlir/TableGen/Dialect.h | 4 +++ mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 19 ----------- mlir/lib/TableGen/Dialect.cpp | 4 +++ mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 32 +++++++++++++++++++ mlir/tools/mlir-tblgen/DialectGen.cpp | 2 +- 7 files changed, 46 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td index 40ef2c40efb1fc..01dc7796741976 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td @@ -29,6 +29,7 @@ def EmitC_Dialect : Dialect { }]; let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; } #endif // MLIR_DIALECT_EMITC_IR_EMITCBASE diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 70a5f2942a2df4..360832f9478b86 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -302,6 +302,10 @@ class Dialect { // it'll dispatch the parsing to every individual attributes directly. bit useDefaultAttributePrinterParser = 0; + // If this dialect should use default generated type parser boilerplate: + // it'll dispatch the parsing to every individual types directly. + bit useDefaultTypePrinterParser = 0; + // If this dialect overrides the hook for canonicalization patterns. bit hasCanonicalizer = 0; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h index 7eb70030785bd1..d56b2288e2ef5e 100644 --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -78,6 +78,10 @@ class Dialect { /// attribute printing/parsing. bool useDefaultAttributePrinterParser() const; + /// Returns true if this dialect should generate the default dispatch for + /// type printing/parsing. + bool useDefaultTypePrinterParser() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record. bool operator==(const Dialect &other) const; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 484903c9a70a9d..f8aebd7c6f25c7 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -227,25 +227,6 @@ Type emitc::OpaqueType::parse(DialectAsmParser &parser) { return get(parser.getContext(), value); } -Type EmitCDialect::parseType(DialectAsmParser &parser) const { - llvm::SMLoc typeLoc = parser.getCurrentLocation(); - StringRef mnemonic; - if (parser.parseKeyword(&mnemonic)) - return Type(); - Type genType; - OptionalParseResult parseResult = - generatedTypeParser(parser, mnemonic, genType); - if (parseResult.hasValue()) - return genType; - parser.emitError(typeLoc, "unknown type in EmitC dialect"); - return Type(); -} - -void EmitCDialect::printType(Type type, DialectAsmPrinter &os) const { - if (failed(generatedTypePrinter(type, os))) - llvm_unreachable("unexpected 'EmitC' type kind"); -} - void emitc::OpaqueType::print(DialectAsmPrinter &printer) const { printer << "opaque<\""; llvm::printEscapedString(getValue(), printer.getStream()); diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp index bfaf7163f45642..6970a7f8276a84 100644 --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -94,6 +94,10 @@ bool Dialect::useDefaultAttributePrinterParser() const { return def->getValueAsBit("useDefaultAttributePrinterParser"); } +bool Dialect::useDefaultTypePrinterParser() const { + return def->getValueAsBit("useDefaultTypePrinterParser"); +} + Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const { int prefix = def->getValueAsInt("emitAccessorPrefix"); if (prefix < 0 || prefix > static_cast(EmitPrefix::Both)) diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index f9e86a0eedcf6f..d6ea2f2877ce36 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -528,6 +528,32 @@ void {0}::printAttribute(::mlir::Attribute attr, } )"; +/// The code block for default type parser/printer dispatch boilerplate. +/// {0}: the dialect fully qualified class name. +static const char *const dialectDefaultTypePrinterParserDispatch = R"( +/// Parse a type registered to this dialect. +::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{ + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + StringRef mnemonic; + if (parser.parseKeyword(&mnemonic)) + return Type(); + Type genType; + OptionalParseResult parseResult = + generatedTypeParser(parser, mnemonic, genType); + if (parseResult.hasValue()) + return genType; + parser.emitError(typeLoc) << "unknown type `" + << mnemonic << "` in dialect `" << getNamespace() << "`"; + return {{}; +} +/// Print a type registered to this dialect. +void {0}::printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &printer) const {{ + if (succeeded(generatedTypePrinter(type, printer))) + return; +} +)"; + /// The code block used to start the auto-generated printer function. /// /// {0}: The name of the base value type, e.g. Attribute or Type. @@ -1020,6 +1046,12 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, defs.front().getDialect().getCppClassName()); + // Emit the default parser/printer for Types if the dialect asked for it. + if (valueType == "Type" && + defs.front().getDialect().useDefaultTypePrinterParser()) + os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, + defs.front().getDialect().getCppClassName()); + return false; } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp index 7767b257312e1e..7e3a96e5980e2e 100644 --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -210,7 +210,7 @@ static void emitDialectDecl(Dialect &dialect, // add the hooks for parsing/printing. if (!dialectAttrs.empty() || dialect.useDefaultAttributePrinterParser()) os << attrParserDecl; - if (!dialectTypes.empty()) + if (!dialectTypes.empty() || dialect.useDefaultTypePrinterParser()) os << typeParserDecl; // Add the decls for the various features of the dialect.