diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td index 91c9283de8bd41..d176b36068f7a5 100644 --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -256,7 +256,7 @@ class AttrDef traits = [], AttrOrTypeDef<"Attr", name, traits, baseCppClass> { // The name of the C++ Attribute class. string cppClassName = name # "Attr"; - let storageType = dialect.cppNamespace # "::" # name # "Attr"; + let storageType = dialect.cppNamespace # "::" # cppClassName; // The underlying C++ value type let returnType = dialect.cppNamespace # "::" # cppClassName; @@ -275,12 +275,10 @@ class AttrDef traits = [], // // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will // expand to `getAttrOfType("val").getValue().getSExtValue()`. - let convertFromStorage = "::llvm::cast<" # dialect.cppNamespace # - "::" # cppClassName # ">($_self)"; + let convertFromStorage = "::llvm::cast<" # cppType # ">($_self)"; // The predicate for when this def is used as a constraint. - let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace # - "::" # cppClassName # ">($_self)">; + let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">; } // Define a new type, named `name`, belonging to `dialect` that inherits from @@ -289,6 +287,9 @@ class TypeDef traits = [], string baseCppClass = "::mlir::Type"> : DialectType, /*descr*/"", name # "Type">, AttrOrTypeDef<"Type", name, traits, baseCppClass> { + // The name of the C++ Type class. + string cppClassName = name # "Type"; + // Make it possible to use such type as parameters for other types. string cppType = dialect.cppNamespace # "::" # cppClassName; @@ -297,12 +298,11 @@ class TypeDef traits = [], // A constant builder provided when the type has no parameters. let builderCall = !if(!empty(parameters), - "$_builder.getType<" # dialect.cppNamespace # - "::" # cppClassName # ">()", + "$_builder.getType<" # cppType # ">()", ""); + // The predicate for when this def is used as a constraint. - let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace # - "::" # cppClassName # ">($_self)">; + let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index d99bde1f87ef00..6774a7c568315d 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -169,14 +169,14 @@ def AnyAttr : Attr, "any attribute"> { // Any attribute from the given list class AnyAttrOf allowedAttrs, string summary = "", - string cppClassName = "::mlir::Attribute", + string cppType = "::mlir::Attribute", string fromStorage = "$_self"> : Attr< // Satisfy any of the allowed attribute's condition Or, !if(!eq(summary, ""), !interleave(!foreach(t, allowedAttrs, t.summary), " or "), summary)> { - let returnType = cppClassName; + let returnType = cppType; let convertFromStorage = fromStorage; } @@ -369,7 +369,7 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> { } class TypeAttrOf - : TypeAttrBase { let constBuilderCall = "::mlir::TypeAttr::get($0)"; } diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 2493f212a356a4..70c3e485679e57 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -98,23 +98,23 @@ def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">; // A type, carries type constraints. class Type : - TypeConstraint { + string cppType = "::mlir::Type"> : + TypeConstraint { string description = ""; string builderCall = ""; } // Allows providing an alternative name and summary to an existing type def. class TypeAlias : - Type { + Type { let description = t.description; let builderCall = t.builderCall; } // A type of a specific dialect. class DialectType : - Type { + string cppType = "::mlir::Type"> : + Type { Dialect dialect = d; } @@ -122,7 +122,7 @@ class DialectType : TypeConstraint { + type.cppType> { Type baseType = type; int minSize = 0; } @@ -140,7 +140,7 @@ class VariadicOfVariadic // An optional type constraint. It expands to either zero or one of the base // type. This class is used for supporting optional operands/results. class Optional : TypeConstraint { + type.cppType> { Type baseType = type; } @@ -172,33 +172,33 @@ def NoneType : Type($_self)">, "none type", // Any type from the given list class AnyTypeOf allowedTypeList, string summary = "", - string cppClassName = "::mlir::Type"> : Type< + string cppType = "::mlir::Type"> : Type< // Satisfy any of the allowed types' conditions. Or, !if(!eq(summary, ""), !interleave(!foreach(t, allowedTypeList, t.summary), " or "), summary), - cppClassName> { + cppType> { list allowedTypes = allowedTypeList; } // A type that satisfies the constraints of all given types. class AllOfType allowedTypeList, string summary = "", - string cppClassName = "::mlir::Type"> : Type< + string cppType = "::mlir::Type"> : Type< // Satisfy all of the allowed types' conditions. And, !if(!eq(summary, ""), !interleave(!foreach(t, allowedTypeList, t.summary), " and "), summary), - cppClassName> { + cppType> { list allowedTypes = allowedTypeList; } // A type that satisfies additional predicates. class ConfinedType predicates, string summary = "", - string cppClassName = type.cppClassName> : Type< + string cppType = type.cppType> : Type< And, - summary, cppClassName>; + summary, cppType>; // Integer types. @@ -375,23 +375,23 @@ def FunctionType : Type($_self)">, // A container type is a type that has another type embedded within it. class ContainerType : + string descr, string cppType = "::mlir::Type"> : // First, check the container predicate. Then, substitute the extracted // element into the element type checker. Type(elementTypeCall), etype.predicate>]>, - descr # " of " # etype.summary # " values", cppClassName>; + descr # " of " # etype.summary # " values", cppType>; class ShapedContainerType allowedTypes, Pred containerPred, string descr, - string cppClassName = "::mlir::Type"> : + string cppType = "::mlir::Type"> : Type.predicate>, "; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>, - descr # " of " # AnyTypeOf.summary # " values", cppClassName>; + descr # " of " # AnyTypeOf.summary # " values", cppType>; // Whether a shaped type is ranked. def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">; diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td index a026d58ccffb8e..39bc55db63da1a 100644 --- a/mlir/include/mlir/IR/Constraints.td +++ b/mlir/include/mlir/IR/Constraints.td @@ -149,10 +149,10 @@ class Constraint { // Subclass for constraints on a type. class TypeConstraint : + string cppTypeParam = "::mlir::Type"> : Constraint { // The name of the C++ Type class if known, or Type if not. - string cppClassName = cppClassNameParam; + string cppType = cppTypeParam; } // Subclass for constraints on an attribute. diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h index 06cf4f5730d565..b974ac281041bc 100644 --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -56,8 +56,8 @@ class TypeConstraint : public Constraint { // returns std::nullopt otherwise. std::optional getBuilderCall() const; - // Return the C++ class name for this type (which may just be ::mlir::Type). - std::string getCPPClassName() const; + // Return the C++ type for this type (which may just be ::mlir::Type). + StringRef getCppType() const; }; // Wrapper class with helper methods for accessing Types defined in TableGen. diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index e9f2394dd540af..cda752297988bb 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -59,20 +59,9 @@ std::optional TypeConstraint::getBuilderCall() const { .Default([](auto *) { return std::nullopt; }); } -// Return the C++ class name for this type (which may just be ::mlir::Type). -std::string TypeConstraint::getCPPClassName() const { - StringRef className = def->getValueAsString("cppClassName"); - - // If the class name is already namespace resolved, use it. - if (className.contains("::")) - return className.str(); - - // Otherwise, check to see if there is a namespace from a dialect to prepend. - if (const llvm::RecordVal *value = def->getValue("dialect")) { - Dialect dialect(cast(value->getValue())->getDef()); - return (dialect.getCppNamespace() + "::" + className).str(); - } - return className.str(); +// Return the C++ type for this type (which may just be ::mlir::Type). +StringRef TypeConstraint::getCppType() const { + return def->getValueAsString("cppType"); } Type::Type(const llvm::Record *record) : TypeConstraint(record) {} diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 1f0df033d43398..01c78e280080ee 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -879,8 +879,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, -> const ods::TypeConstraint & { return odsContext.insertTypeConstraint( cst.constraint.getUniqueDefName(), - processDoc(cst.constraint.getSummary()), - cst.constraint.getCPPClassName()); + processDoc(cst.constraint.getSummary()), cst.constraint.getCppType()); }; auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; @@ -944,7 +943,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, tblgen::TypeConstraint constraint(def); decls.push_back(createODSNativePDLLConstraintDecl( constraint, convertLocToRange(def->getLoc().front()), typeTy, - constraint.getCPPClassName())); + constraint.getCppType())); } /// OpInterfaces. ast::Type opTy = ast::OperationType::get(ctx); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index a2ceefb34db453..66dbb16760ebb0 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2085,21 +2085,8 @@ static void generateValueRangeStartAndEnd( } static std::string generateTypeForGetter(const NamedTypeConstraint &value) { - std::string str = "::mlir::Value"; - /// If the CPPClassName is not a fully qualified type. Uses of types - /// across Dialect fail because they are not in the correct namespace. So we - /// dont generate TypedValue unless the type is fully qualified. - /// getCPPClassName doesn't return the fully qualified path for - /// `mlir::pdl::OperationType` see - /// https://github.com/llvm/llvm-project/issues/57279. - /// Adaptor will have values that are not from the type of their operation and - /// this is expected, so we dont generate TypedValue for Adaptor - if (value.constraint.getCPPClassName() != "::mlir::Type" && - StringRef(value.constraint.getCPPClassName()).starts_with("::")) - str = llvm::formatv("::mlir::TypedValue<{0}>", - value.constraint.getCPPClassName()) - .str(); - return str; + return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType()) + .str(); } // Generates the named operand getter methods for the given Operator `op` and @@ -3944,7 +3931,7 @@ void OpEmitter::genTraits() { // For single result ops with a known specific type, generate a OneTypedResult // trait. if (numResults == 1 && numVariadicResults == 0) { - auto cppName = op.getResults().begin()->constraint.getCPPClassName(); + auto cppName = op.getResults().begin()->constraint.getCppType(); opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl"); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 27ad79a5c1efed..9a95f495b77658 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1657,7 +1657,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, TypeSwitch(dir->getArg()) .Case([&](auto operand) { body << formatv(parserCode, - operand->getVar()->constraint.getCPPClassName(), + operand->getVar()->constraint.getCppType(), listName); }) .Default([&](auto operand) { @@ -2603,7 +2603,7 @@ void OperationFormat::genElementPrinter(FormatElement *element, } if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && !var->isOptional()) { - std::string cppClass = var->constraint.getCPPClassName(); + StringRef cppType = var->constraint.getCppType(); if (dir->shouldBeQualified()) { body << " _odsPrinter << " << op.getGetterName(var->name) << "().getType();\n"; @@ -2612,7 +2612,7 @@ void OperationFormat::genElementPrinter(FormatElement *element, body << " {\n" << " auto type = " << op.getGetterName(var->name) << "().getType();\n" - << " if (auto validType = ::llvm::dyn_cast<" << cppClass + << " if (auto validType = ::llvm::dyn_cast<" << cppType << ">(type))\n" << " _odsPrinter.printStrippedAttrOrType(validType);\n" << " else\n"