diff --git a/include/circt/Dialect/Moore/MooreOps.td b/include/circt/Dialect/Moore/MooreOps.td index 5bd357ef0383..5a16a896117d 100644 --- a/include/circt/Dialect/Moore/MooreOps.td +++ b/include/circt/Dialect/Moore/MooreOps.td @@ -186,6 +186,7 @@ def ProcedureOp : MooreOp<"procedure", [ //===----------------------------------------------------------------------===// def VariableOp : MooreOp<"variable", [ + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OptionalTypesMatchWith<"initial value and variable types match", @@ -412,6 +413,7 @@ def ConstantOp : MooreOp<"constant", [Pure]> { let results = (outs IntType:$result); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; + let hasFolder = 1; let builders = [ OpBuilder<(ins "IntType":$type, "const APInt &":$value)>, OpBuilder<(ins "IntType":$type, "int64_t":$value)>, @@ -965,7 +967,7 @@ def ExtractRefOp : MooreOp<"extract_ref"> { }]; } -def StructCreateOp : MooreOp<"struct_create", [SameOperandsAndResultType]> { +def StructCreateOp : MooreOp<"struct_create", [Pure]> { let summary = "Struct Create operation"; let description = [{ A structure represents a collection of data types @@ -990,15 +992,18 @@ def StructCreateOp : MooreOp<"struct_create", [SameOperandsAndResultType]> { ``` See IEEE 1800-2017 § 7.2. "Structures". }]; - let arguments = (ins UnpackedType:$input); - let results = (outs UnpackedType:$result); - let hasCustomAssemblyFormat = 1; + let arguments = (ins Variadic:$input); + let results = (outs RefType:$result); let assemblyFormat = [{ - $input attr-dict `:` type($input) + $input attr-dict `:` type($input) `->` type($result) }]; + let hasVerifier = 1; + let hasFolder = 1; } -def StructExtractOp : MooreOp<"struct_extract"> { +def StructExtractOp : MooreOp<"struct_extract", [ + DeclareOpInterfaceMethods +]> { let summary = "Struct Extract operation"; let description = [{ Structures can be converted to bits preserving the bit pattern. @@ -1026,16 +1031,29 @@ def StructExtractOp : MooreOp<"struct_extract"> { ``` See IEEE 1800-2017 § 7.2.1 "Assigning to structures". }]; - let arguments = (ins StrAttr:$memberName, UnpackedType:$input); + let arguments = (ins StrAttr:$fieldName, Arg:$input); let results = (outs UnpackedType:$result); let assemblyFormat = [{ - $input `,` $memberName attr-dict `:` - type($input) `->` - type($result) + $input `,` $fieldName attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasFolder = 1; +} + +def StructExtractRefOp : MooreOp<"struct_extract_ref", [ + DeclareOpInterfaceMethods +]> { + let summary = "Struct Extract operation"; + let arguments = (ins StrAttr:$fieldName, RefType:$input); + let results = (outs RefType:$result); + let assemblyFormat = [{ + $input `,` $fieldName attr-dict `:` + type($input) `->` type($result) }]; + let hasVerifier = 1; } -def StructInjectOp : MooreOp<"struct_inject"> { +def StructInjectOp : MooreOp<"struct_inject", [Pure]> { let summary = "Struct Field operation"; let description = [{ A structure can be assigned as a whole and passed to @@ -1059,16 +1077,16 @@ def StructInjectOp : MooreOp<"struct_inject"> { ``` See IEEE 1800-2017 § 7.2. "Assigning to structures". }]; - let arguments = (ins UnpackedType:$LHS, StrAttr:$memberName, - UnpackedType:$RHS); - let results = (outs UnpackedType:$result); - let assemblyFormat = [{ - $LHS `,` $memberName `,` $RHS - attr-dict `:` type($LHS) type($RHS) `->` type($result) - }]; + let arguments = (ins RefType:$input, StrAttr:$fieldName, + UnpackedType:$newValue); + let results = (outs RefType:$result); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + let hasFolder = 1; + let hasCanonicalizeMethod = true; } -def UnionCreateOp : MooreOp<"union_create"> { +def UnionCreateOp : MooreOp<"union_create", [Pure]> { let summary = "Union Create operation"; let description = [{ A union is a data type that represents a single piece @@ -1093,12 +1111,12 @@ def UnionCreateOp : MooreOp<"union_create"> { ``` See IEEE 1800-2017 § 7.3 "Unions" }]; - let arguments = (ins StrAttr:$unionName, UnpackedType:$input); + let arguments = (ins UnpackedType:$input, StrAttr:$fieldName); let results = (outs UnpackedType:$result); let assemblyFormat = [{ - $unionName `,` $input attr-dict `:` - type($input) `->` type($result) + $input attr-dict `:` type($input) `->` type($result) }]; + let hasVerifier = 1; } def UnionExtractOp : MooreOp<"union_extract"> { @@ -1124,17 +1142,29 @@ def UnionExtractOp : MooreOp<"union_extract"> { ``` See IEEE 1800-2017 § 7.3.1 "Packed unions" }]; - let arguments = (ins StrAttr:$memberName, UnpackedType:$input); + let arguments = (ins StrAttr:$fieldName, UnpackedType:$input); let results = (outs UnpackedType:$result); let assemblyFormat = [{ - $input `,` $memberName attr-dict `:` - type($input) `->` type($result) + $input `,` $fieldName attr-dict `:` + type($input) `->` type($result) }]; + let hasVerifier = 1; +} + +def UnionExtractRefOp : MooreOp<"union_extract_ref"> { + let summary = "Union Extract operation"; + let arguments = (ins StrAttr:$fieldName, RefType:$input); + let results = (outs RefType:$result); + let assemblyFormat = [{ + $input `,` $fieldName attr-dict `:` + type($input) `->` type($result) + }]; + let hasVerifier = 1; } def ConditionalOp : MooreOp<"conditional",[ RecursiveMemoryEffects, - NoRegionArguments, + NoRegionArguments, SingleBlockImplicitTerminator<"moore::YieldOp"> ]> { let summary = "Conditional operation"; @@ -1174,8 +1204,8 @@ def ConditionalOp : MooreOp<"conditional",[ } def YieldOp : MooreOp<"yield", [ - Pure, - Terminator, + Pure, + Terminator, HasParent<"ConditionalOp"> ]> { let summary = "conditional yield and termination operation"; diff --git a/include/circt/Dialect/Moore/MooreTypes.h b/include/circt/Dialect/Moore/MooreTypes.h index 2a2a93376598..48b90901e49a 100644 --- a/include/circt/Dialect/Moore/MooreTypes.h +++ b/include/circt/Dialect/Moore/MooreTypes.h @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include namespace circt { diff --git a/include/circt/Dialect/Moore/MooreTypes.td b/include/circt/Dialect/Moore/MooreTypes.td index 7ef3c380eb61..700308b35a0b 100644 --- a/include/circt/Dialect/Moore/MooreTypes.td +++ b/include/circt/Dialect/Moore/MooreTypes.td @@ -16,6 +16,7 @@ include "circt/Dialect/Moore/MooreDialect.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" class MooreTypeDef traits = [], string baseCppClass = "::mlir::Type"> @@ -284,9 +285,16 @@ class StructLikeType< let assemblyFormat = [{ `<` custom($members) `>` }]; + let extraClassDeclaration = [{ + std::optional> getSubelementIndexMap(); + Type getTypeAtIndex(Attribute index); + std::optional getFieldIndex(StringAttr nameField); + }]; } -def StructType : StructLikeType<"Struct", [], "moore::PackedType"> { +def StructType : StructLikeType<"Struct", [ + DeclareTypeInterfaceMethods], "moore::PackedType"> { let mnemonic = "struct"; let summary = "a packed struct type"; let description = [{ @@ -296,7 +304,9 @@ def StructType : StructLikeType<"Struct", [], "moore::PackedType"> { } def UnpackedStructType : StructLikeType< - "UnpackedStruct", [], "moore::UnpackedType" + "UnpackedStruct", [ + DeclareTypeInterfaceMethods], "moore::UnpackedType" > { let mnemonic = "ustruct"; let summary = "an unpacked struct type"; @@ -305,7 +315,9 @@ def UnpackedStructType : StructLikeType< }]; } -def UnionType : StructLikeType<"Union", [], "moore::PackedType"> { +def UnionType : StructLikeType<"Union", [ + DeclareTypeInterfaceMethods], "moore::PackedType"> { let mnemonic = "union"; let summary = "a packed union type"; let description = [{ @@ -316,7 +328,9 @@ def UnionType : StructLikeType<"Union", [], "moore::PackedType"> { def UnpackedUnionType : StructLikeType< - "UnpackedUnion", [], "moore::UnpackedType" + "UnpackedUnion", [ + DeclareTypeInterfaceMethods], "moore::UnpackedType" > { let mnemonic = "uunion"; let summary = "an unpacked union type"; @@ -329,7 +343,9 @@ def UnpackedUnionType : StructLikeType< // Reference type wrapper //===----------------------------------------------------------------------===// -def RefType : MooreTypeDef<"Ref", [], "moore::UnpackedType">{ +def RefType : MooreTypeDef<"Ref", [ + DeclareTypeInterfaceMethods], "moore::UnpackedType">{ let mnemonic = "ref"; let description = [{ A wrapper is used to wrap any SystemVerilog type. It's aimed to work for @@ -352,6 +368,9 @@ def RefType : MooreTypeDef<"Ref", [], "moore::UnpackedType">{ std::optional getBitSize() { return getNestedType().getBitSize(); }; + std::optional> getSubelementIndexMap(); + Type getTypeAtIndex(Attribute index); + std::optional getFieldIndex(StringAttr nameField); }]; } diff --git a/lib/Conversion/ImportVerilog/Expressions.cpp b/lib/Conversion/ImportVerilog/Expressions.cpp index d79f7c509977..01ebbb3e81db 100644 --- a/lib/Conversion/ImportVerilog/Expressions.cpp +++ b/lib/Conversion/ImportVerilog/Expressions.cpp @@ -99,6 +99,16 @@ struct RvalueExprVisitor { return {}; } + if (auto refOp = lhs.getDefiningOp()) { + auto input = refOp.getInput(); + if (isa(input.getDefiningOp()->getParentOp())) { + refOp.getInputMutable(); + refOp->erase(); + builder.create(loc, input.getType(), input, + refOp.getFieldNameAttr(), rhs); + return rhs; + } + } if (expr.isNonBlocking()) builder.create(loc, lhs, rhs); else @@ -467,7 +477,7 @@ struct RvalueExprVisitor { Value visit(const slang::ast::MemberAccessExpression &expr) { auto type = context.convertType(*expr.type); auto valueType = expr.value().type; - auto value = context.convertRvalueExpression(expr.value()); + auto value = context.convertLvalueExpression(expr.value()); if (!type || !value) return {}; if (valueType->isStruct()) { @@ -478,7 +488,9 @@ struct RvalueExprVisitor { return builder.create( loc, type, builder.getStringAttr(expr.member.name), value); } - llvm_unreachable("unsupported symbol kind"); + mlir::emitError(loc, "expression of type ") + << value.getType() << " cannot be accessed"; + return {}; } // Handle set membership operator. @@ -659,6 +671,27 @@ struct LvalueExprVisitor { lowBit); } + Value visit(const slang::ast::MemberAccessExpression &expr) { + auto type = context.convertType(*expr.type); + auto valueType = expr.value().type; + auto value = context.convertLvalueExpression(expr.value()); + if (!type || !value) + return {}; + if (valueType->isStruct()) { + return builder.create( + loc, moore::RefType::get(cast(type)), + builder.getStringAttr(expr.member.name), value); + } + if (valueType->isPackedUnion() || valueType->isUnpackedUnion()) { + return builder.create( + loc, moore::RefType::get(cast(type)), + builder.getStringAttr(expr.member.name), value); + } + mlir::emitError(loc, "expression of type ") + << value.getType() << " cannot be accessed"; + return {}; + } + // Handle range bits selections. Value visit(const slang::ast::RangeSelectExpression &expr) { auto type = context.convertType(*expr.type); diff --git a/lib/Conversion/ImportVerilog/Structure.cpp b/lib/Conversion/ImportVerilog/Structure.cpp index d890bd592e14..d8c9dc43b139 100644 --- a/lib/Conversion/ImportVerilog/Structure.cpp +++ b/lib/Conversion/ImportVerilog/Structure.cpp @@ -407,6 +407,17 @@ struct ModuleVisitor : public BaseVisitor { if (!lhs || !rhs) return failure(); + if (auto refOp = lhs.getDefiningOp()) { + auto input = refOp.getInput(); + if (isa(input.getDefiningOp()->getParentOp())) { + refOp.getInputMutable(); + refOp->erase(); + builder.create(loc, input.getType(), input, + refOp.getFieldNameAttr(), rhs); + return success(); + } + } + builder.create(loc, lhs, rhs); return success(); } diff --git a/lib/Dialect/Moore/MooreOps.cpp b/lib/Dialect/Moore/MooreOps.cpp index 761599c132af..fb9001a82dde 100644 --- a/lib/Dialect/Moore/MooreOps.cpp +++ b/lib/Dialect/Moore/MooreOps.cpp @@ -16,6 +16,7 @@ #include "circt/Support/CustomDirectiveImpl.h" #include "mlir/IR/Builders.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/TypeSwitch.h" using namespace circt; using namespace circt::moore; @@ -305,6 +306,50 @@ LogicalResult VariableOp::canonicalize(VariableOp op, return failure(); } +SmallVector VariableOp::getDestructurableSlots() { + if (isa(getOperation()->getParentOp())) + return {}; + + auto refType = getType(); + auto destructurable = llvm::dyn_cast(refType); + if (!destructurable) + return {}; + + auto destructuredType = destructurable.getSubelementIndexMap(); + if (!destructuredType) + return {}; + + return {DestructurableMemorySlot{{getResult(), refType}, *destructuredType}}; +} + +DenseMap VariableOp::destructure( + const DestructurableMemorySlot &slot, + const SmallPtrSetImpl &usedIndices, OpBuilder &builder, + SmallVectorImpl &newAllocators) { + assert(slot.ptr == getResult()); + builder.setInsertionPointAfter(*this); + + auto destructurableType = cast(getType()); + DenseMap slotMap; + for (Attribute index : usedIndices) { + auto elemType = cast(destructurableType.getTypeAtIndex(index)); + assert(elemType && "used index must exist"); + auto varOp = builder.create(getLoc(), elemType, + cast(index), Value()); + newAllocators.push_back(varOp); + slotMap.try_emplace(index, {varOp.getResult(), elemType}); + } + + return slotMap; +} + +std::optional +VariableOp::handleDestructuringComplete(const DestructurableMemorySlot &slot, + OpBuilder &builder) { + assert(slot.ptr == getResult()); + this->erase(); + return std::nullopt; +} //===----------------------------------------------------------------------===// // NetOp //===----------------------------------------------------------------------===// @@ -401,6 +446,11 @@ void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type, APInt(type.getWidth(), (uint64_t)value, /*isSigned=*/true)); } +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); + return getValueAttr(); +} + //===----------------------------------------------------------------------===// // NamedConstantOp //===----------------------------------------------------------------------===// @@ -449,6 +499,397 @@ LogicalResult ConcatRefOp::inferReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// StructCreateOp +//===----------------------------------------------------------------------===// + +LogicalResult StructCreateOp::verify() { + /// checks if the types of the inputs are exactly equal to the types of the + /// result struct fields + return TypeSwitch(getType().getNestedType()) + .Case([this](auto &type) { + auto members = type.getMembers(); + auto inputs = getInput(); + if (inputs.size() != members.size()) + return failure(); + for (size_t i = 0; i < members.size(); i++) { + auto memberType = cast(members[i].type); + auto inputType = inputs[i].getType(); + if (inputType != memberType) { + emitOpError("input types must match struct field types and orders"); + return failure(); + } + } + return success(); + }) + .Default([this](auto &) { + emitOpError("Result type must be StructType or UnpackedStructType"); + return failure(); + }); +} + +OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) { + auto inputs = adaptor.getInput(); + + if (llvm::any_of(inputs, [](Attribute attr) { return !attr; })) + return {}; + + auto members = TypeSwitch>( + cast(getType()).getNestedType()) + .Case( + [](auto &type) { return type.getMembers(); }) + .Default([](auto) { return std::nullopt; }); + SmallVector namedInputs; + for (auto [input, member] : llvm::zip(inputs, members)) + namedInputs.push_back(NamedAttribute(member.name, input)); + + return DictionaryAttr::get(getContext(), namedInputs); +} + +//===----------------------------------------------------------------------===// +// StructExtractOp +//===----------------------------------------------------------------------===// + +LogicalResult StructExtractOp::verify() { + /// checks if the type of the result match field type in this struct + return TypeSwitch(getInput().getType().getNestedType()) + .Case([this](auto &type) { + auto members = type.getMembers(); + auto filedName = getFieldName(); + auto resultType = getType(); + for (const auto &member : members) { + if (member.name == filedName && member.type == resultType) { + return success(); + } + } + emitOpError("result type must match struct field type"); + return failure(); + }) + .Default([this](auto &) { + emitOpError("input type must be StructType or UnpackedStructType"); + return failure(); + }); +} + +bool StructExtractOp::canRewire(const DestructurableMemorySlot &slot, + SmallPtrSetImpl &usedIndices, + SmallVectorImpl &mustBeSafelyUsed, + const DataLayout &dataLayout) { + if (slot.ptr == getInput()) { + usedIndices.insert(getFieldNameAttr()); + return true; + } + return false; +} + +DeletionKind StructExtractOp::rewire(const DestructurableMemorySlot &slot, + DenseMap &subslots, + OpBuilder &builder, + const DataLayout &dataLayout) { + auto index = getFieldNameAttr(); + const auto &memorySlot = subslots.at(index); + auto readOp = builder.create( + getLoc(), cast(memorySlot.elemType).getNestedType(), + memorySlot.ptr); + replaceAllUsesWith(readOp.getResult()); + getInputMutable().drop(); + erase(); + return DeletionKind::Keep; +} + +OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) { + if (auto constOperand = adaptor.getInput()) { + auto operandAttr = llvm::cast(constOperand); + for (const auto &ele : operandAttr) + if (ele.getName() == getFieldNameAttr()) + return ele.getValue(); + } + + if (auto structInject = getInput().getDefiningOp()) + return structInject.getFieldNameAttr() == getFieldNameAttr() + ? structInject.getNewValue() + : Value(); + if (auto structCreate = getInput().getDefiningOp()) { + auto ind = TypeSwitch>( + getInput().getType().getNestedType()) + .Case([this](auto &type) { + return type.getFieldIndex(getFieldNameAttr()); + }) + .Default([](auto &) { return std::nullopt; }); + return ind.has_value() ? structCreate->getOperand(ind.value()) : Value(); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// StructExtractRefOp +//===----------------------------------------------------------------------===// + +LogicalResult StructExtractRefOp::verify() { + /// checks if the type of the result match field type in this struct + return TypeSwitch(getInput().getType().getNestedType()) + .Case([this](auto &type) { + auto members = type.getMembers(); + auto filedName = getFieldName(); + auto resultType = getType().getNestedType(); + for (const auto &member : members) { + if (member.name == filedName && member.type == resultType) { + return success(); + } + } + emitOpError("result type must match struct field type"); + return failure(); + }) + .Default([this](auto &) { + emitOpError("input type must be refrence of StructType or " + "UnpackedStructType"); + return failure(); + }); +} + +bool StructExtractRefOp::canRewire( + const DestructurableMemorySlot &slot, + SmallPtrSetImpl &usedIndices, + SmallVectorImpl &mustBeSafelyUsed, + const DataLayout &dataLayout) { + if (slot.ptr != getInput()) + return false; + auto index = getFieldNameAttr(); + if (!index || !slot.subelementTypes.contains(index)) + return false; + usedIndices.insert(index); + return true; +} + +DeletionKind +StructExtractRefOp::rewire(const DestructurableMemorySlot &slot, + DenseMap &subslots, + OpBuilder &builder, const DataLayout &dataLayout) { + auto index = getFieldNameAttr(); + const MemorySlot &memorySlot = subslots.at(index); + replaceAllUsesWith(memorySlot.ptr); + getInputMutable().drop(); + erase(); + return DeletionKind::Keep; +} + +//===----------------------------------------------------------------------===// +// StructInjectOp +//===----------------------------------------------------------------------===// + +LogicalResult StructInjectOp::verify() { + /// checks if the type of the new value match field type in this struct + return TypeSwitch(getInput().getType().getNestedType()) + .Case([this](auto &type) { + auto members = type.getMembers(); + auto filedName = getFieldName(); + auto newValueType = getNewValue().getType(); + for (const auto &member : members) { + if (member.name == filedName && member.type == newValueType) { + return success(); + } + } + emitOpError("new value type must match struct field type"); + return failure(); + }) + .Default([this](auto &) { + emitOpError("input type must be StructType or UnpackedStructType"); + return failure(); + }); +} + +void StructInjectOp::print(OpAsmPrinter &p) { + p << " "; + p.printOperand(getInput()); + p << ", " << getFieldNameAttr() << ", "; + p.printOperand(getNewValue()); + p << " : " << getInput().getType(); +} + +ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); + OpAsmParser::UnresolvedOperand operand, val; + StringAttr fieldName; + Type declType; + + if (parser.parseOperand(operand) || parser.parseComma() || + parser.parseAttribute(fieldName) || parser.parseComma() || + parser.parseOperand(val) || parser.parseColonType(declType)) + return failure(); + + return TypeSwitch(cast(declType).getNestedType()) + .Case([&parser, &result, &declType, + &fieldName, &operand, &val, + &inputOperandsLoc](auto &type) { + auto members = type.getMembers(); + Type fieldType; + for (const auto &member : members) + if (member.name == fieldName) + fieldType = member.type; + if (!fieldType) { + parser.emitError(parser.getNameLoc(), + "field name '" + fieldName.getValue() + + "' not found in struct type"); + return failure(); + } + + auto fieldNameAttr = + StringAttr::get(parser.getContext(), Twine(fieldName)); + result.addAttribute("fieldName", fieldNameAttr); + result.addTypes(declType); + if (parser.resolveOperands({operand, val}, {declType, fieldType}, + inputOperandsLoc, result.operands)) + return failure(); + + return success(); + }) + .Default([&parser, &inputOperandsLoc](auto &) { + return parser.emitError(inputOperandsLoc, + "invalid kind of type specified"); + }); +} + +OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) { + auto input = adaptor.getInput(); + auto newValue = adaptor.getNewValue(); + if (!input || !newValue) + return {}; + SmallVector array; + llvm::copy(cast(input), std::back_inserter(array)); + for (auto &ele : array) { + if (ele.getName() == getFieldName()) + ele.setValue(newValue); + } + return DictionaryAttr::get(getContext(), array); +} + +LogicalResult StructInjectOp::canonicalize(StructInjectOp op, + PatternRewriter &rewriter) { + // Canonicalize multiple injects into a create op and eliminate overwrites. + SmallPtrSet injects; + DenseMap fields; + + // Chase a chain of injects. Bail out if cycles are present. + StructInjectOp inject = op; + Value input; + do { + if (!injects.insert(inject).second) + return failure(); + + fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue()); + input = inject.getInput(); + inject = input.getDefiningOp(); + } while (inject); + assert(input && "missing input to inject chain"); + + auto members = TypeSwitch>( + cast(op.getType()).getNestedType()) + .Case( + [](auto &type) { return type.getMembers(); }) + .Default([](auto) { return std::nullopt; }); + + // If the inject chain sets all fields, canonicalize to create. + if (fields.size() == members.size()) { + SmallVector createFields; + for (const auto &member : members) { + auto it = fields.find(member.name); + assert(it != fields.end() && "missing field"); + createFields.push_back(it->second); + } + op.getInputMutable(); + rewriter.replaceOpWithNewOp(op, op.getType(), createFields); + return success(); + } + + // Nothing to canonicalize, only the original inject in the chain. + if (injects.size() == fields.size()) + return failure(); + + // Eliminate overwrites. The hash map contains the last write to each field. + for (const auto &member : members) { + auto it = fields.find(member.name); + if (it == fields.end()) + continue; + input = rewriter.create(op.getLoc(), op.getType(), input, + member.name, it->second); + } + + rewriter.replaceOp(op, input); + return success(); +} + +//===----------------------------------------------------------------------===// +// UnionCreateOp +//===----------------------------------------------------------------------===// + +LogicalResult UnionCreateOp::verify() { + /// checks if the types of the input is exactly equal to the union field + /// type + return TypeSwitch(getType()) + .Case([this](auto &type) { + auto members = type.getMembers(); + auto resultType = getType(); + auto fieldName = getFieldName(); + for (const auto &member : members) + if (member.name == fieldName && member.type == resultType) + return success(); + emitOpError("input type must match the union field type"); + return failure(); + }) + .Default([this](auto &) { + emitOpError("input type must be UnionType or UnpackedUnionType"); + return failure(); + }); +} + +//===----------------------------------------------------------------------===// +// UnionExtractOp +//===----------------------------------------------------------------------===// + +LogicalResult UnionExtractOp::verify() { + /// checks if the types of the input is exactly equal to the one of the + /// types of the result union fields + return TypeSwitch(getInput().getType()) + .Case([this](auto &type) { + auto members = type.getMembers(); + auto fieldName = getFieldName(); + auto resultType = getType(); + for (const auto &member : members) + if (member.name == fieldName && member.type == resultType) + return success(); + emitOpError("result type must match the union field type"); + return failure(); + }) + .Default([this](auto &) { + emitOpError("input type must be UnionType or UnpackedUnionType"); + return failure(); + }); +} + +//===----------------------------------------------------------------------===// +// UnionExtractOp +//===----------------------------------------------------------------------===// + +LogicalResult UnionExtractRefOp::verify() { + /// checks if the types of the result is exactly equal to the type of the + /// refe union field + return TypeSwitch(getInput().getType().getNestedType()) + .Case([this](auto &type) { + auto members = type.getMembers(); + auto fieldName = getFieldName(); + auto resultType = getType().getNestedType(); + for (const auto &member : members) + if (member.name == fieldName && member.type == resultType) + return success(); + emitOpError("result type must match the union field type"); + return failure(); + }) + .Default([this](auto &) { + emitOpError("input type must be UnionType or UnpackedUnionType"); + return failure(); + }); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Moore/MooreTypes.cpp b/lib/Dialect/Moore/MooreTypes.cpp index f96afb5c3c6a..980ab78dc9ee 100644 --- a/lib/Dialect/Moore/MooreTypes.cpp +++ b/lib/Dialect/Moore/MooreTypes.cpp @@ -198,6 +198,115 @@ LogicalResult UnionType::verify(function_ref emitError, return verifyAllMembersPacked(emitError, members); } +//===----------------------------------------------------------------------===// +// Interfaces for destructurable +//===----------------------------------------------------------------------===// + +static std::optional> +getAllSubelementIndexMap(ArrayRef members) { + DenseMap destructured; + for (const auto &member : members) + destructured.insert({member.name, RefType::get(member.type)}); + return destructured; +} + +static Type getTypeAtAllIndex(ArrayRef members, + Attribute index) { + auto indexAttr = cast(index); + if (!indexAttr) + return {}; + for (const auto &member : members) { + if (member.name == indexAttr) { + return RefType::get(member.type); + } + } + return Type(); +} + +static std::optional +getFieldAllIndex(ArrayRef members, StringAttr nameField) { + for (uint32_t fieldIndex = 0; fieldIndex < members.size(); fieldIndex++) + if (members[fieldIndex].name == nameField) + return fieldIndex; + return std::nullopt; +} + +std::optional> StructType::getSubelementIndexMap() { + return getAllSubelementIndexMap(getMembers()); +} + +Type StructType::getTypeAtIndex(Attribute index) { + return getTypeAtAllIndex(getMembers(), index); +} + +std::optional StructType::getFieldIndex(StringAttr nameField) { + return getFieldAllIndex(getMembers(), nameField); +} + +std::optional> +UnpackedStructType::getSubelementIndexMap() { + return getAllSubelementIndexMap(getMembers()); +} + +Type UnpackedStructType::getTypeAtIndex(Attribute index) { + return getTypeAtAllIndex(getMembers(), index); +} + +std::optional +UnpackedStructType::getFieldIndex(StringAttr nameField) { + return getFieldAllIndex(getMembers(), nameField); +} + +std::optional> UnionType::getSubelementIndexMap() { + return getAllSubelementIndexMap(getMembers()); +} + +Type UnionType::getTypeAtIndex(Attribute index) { + return getTypeAtAllIndex(getMembers(), index); +} + +std::optional UnionType::getFieldIndex(StringAttr nameField) { + return getFieldAllIndex(getMembers(), nameField); +} + +std::optional> +UnpackedUnionType::getSubelementIndexMap() { + return getAllSubelementIndexMap(getMembers()); +} + +Type UnpackedUnionType::getTypeAtIndex(Attribute index) { + return getTypeAtAllIndex(getMembers(), index); +} + +std::optional UnpackedUnionType::getFieldIndex(StringAttr nameField) { + return getFieldAllIndex(getMembers(), nameField); +} + +std::optional> RefType::getSubelementIndexMap() { + return TypeSwitch>>( + getNestedType()) + .Case([](auto &type) { + return getAllSubelementIndexMap(type.getMembers()); + }) + .Default([](auto) { return std::nullopt; }); +} + +Type RefType::getTypeAtIndex(Attribute index) { + return TypeSwitch(getNestedType()) + .Case([&index](auto &type) { + return getTypeAtAllIndex(type.getMembers(), index); + }) + .Default([](auto) { return Type(); }); +} + +std::optional RefType::getFieldIndex(StringAttr nameField) { + return TypeSwitch>(getNestedType()) + .Case([&nameField](auto &type) { + return getFieldAllIndex(type.getMembers(), nameField); + }) + .Default([](auto) { return std::nullopt; }); +} + //===----------------------------------------------------------------------===// // Generated logic //===----------------------------------------------------------------------===// diff --git a/test/Conversion/ImportVerilog/basic.sv b/test/Conversion/ImportVerilog/basic.sv index 44a478ffbc64..7d0d8eeb9ae6 100644 --- a/test/Conversion/ImportVerilog/basic.sv +++ b/test/Conversion/ImportVerilog/basic.sv @@ -351,7 +351,7 @@ module Statements; // CHECK: scf.yield // CHECK: } for (y = x; x; x = z) x = y; - + // CHECK: [[TMP1:%.+]] = moore.read %i : i32 // CHECK: scf.while (%arg0 = [[TMP1]]) : (!moore.i32) -> !moore.i32 { // CHECK: [[TMP2:%.+]] = moore.bool_cast %arg0 : i32 -> i1 @@ -410,7 +410,7 @@ module Statements; // CHECK: moore.blocking_assign %y, [[TMP1]] : i1 // CHECK: moore.blocking_assign %x, [[TMP1]] : i1 x = (y = z); - + // CHECK: [[TMP1:%.+]] = moore.read %y : i1 // CHECK: moore.nonblocking_assign %x, [[TMP1]] : i1 x <= y; @@ -1001,7 +1001,14 @@ module Expressions; // CHECK: [[TMP2:%.+]] = moore.add [[A_ADD]], [[TMP1]] // CHECK: moore.blocking_assign %a, [[TMP2]] a += (a *= a--); - + + // CHECK: [[TMP1:%.+]] = moore.read %a : i32 + // CHECK: [[TMP2:%.+]] = moore.struct_inject %struct0, "a", [[TMP1]] : !moore.ref> + struct0.a = a; + + // CHECK: [[TMP3:%.+]] = moore.struct_extract %struct0, "b" : > -> i32 + // CHECK: moore.blocking_assign %b, [[TMP3]] : i32 + b = struct0.b; end endmodule diff --git a/test/Dialect/Moore/canonicalizers.mlir b/test/Dialect/Moore/canonicalizers.mlir index b8ebd6d6c8b1..a59d9872a22a 100644 --- a/test/Dialect/Moore/canonicalizers.mlir +++ b/test/Dialect/Moore/canonicalizers.mlir @@ -37,3 +37,66 @@ moore.module @MultiAssign() { moore.assign %a, %1 : i32 moore.output } + +// CHECK-LABEL: moore.module @structAssign +moore.module @structAssign(out a : !moore.ref>) { + %x = moore.variable : + %y = moore.variable : + %z = moore.variable : + // CHECK: %0 = moore.constant 4 : i32 + // CHECK: %1 = moore.read %x : i32 + // CHECK: %2 = moore.constant 1 : i32 + // CHECK: %3 = moore.add %1, %2 : i32 + // CHECK: %4 = moore.struct_create %3, %0 : !moore.i32, !moore.i32 -> > + %ii = moore.variable : > + %0 = moore.constant 4 : i32 + %1 = moore.conversion %0 : !moore.i32 -> !moore.i32 + %2 = moore.struct_inject %ii, "b", %1 : !moore.ref> + %3 = moore.read %x : i32 + %4 = moore.constant 1 : i32 + %5 = moore.add %3, %4 : i32 + %6 = moore.struct_inject %2, "a", %5 : !moore.ref> + %7 = moore.struct_extract %6, "a" : > -> i32 + // CHECK: moore.assign %y, %3 : i32 + moore.assign %y, %7 : i32 + %8 = moore.struct_extract %6, "a" : > -> i32 + // CHECK: moore.assign %z, %3 : i32 + moore.assign %z, %8 : i32 + // CHECK: moore.output %4 : !moore.ref> + moore.output %6 : !moore.ref> +} + +// CHECK-LABEL: moore.module @structInjectFold +moore.module @structInjectFold(out a : !moore.ref>) { + %x = moore.variable : + %y = moore.variable : + %z = moore.variable : + %ii = moore.variable : > + // CHECK: %0 = moore.read %x : i32 + // CHECK: %1 = moore.constant 1 : i32 + // CHECK: %2 = moore.add %0, %1 : i32 + // CHECK: %3 = moore.struct_inject %ii, "a", %2 : !moore.ref> + %0 = moore.constant 4 : i32 + %1 = moore.conversion %0 : !moore.i32 -> !moore.i32 + %2 = moore.struct_inject %ii, "a", %1 : !moore.ref> + %3 = moore.read %x : i32 + %4 = moore.constant 1 : i32 + %5 = moore.add %3, %4 : i32 + %6 = moore.struct_inject %2, "a", %5 : !moore.ref> + %7 = moore.struct_extract %6, "a" : > -> i32 + // CHECK: moore.assign %y, %2 : i32 + moore.assign %y, %7 : i32 + %8 = moore.struct_extract %6, "a" : > -> i32 + // CHECK: moore.assign %z, %2 : i32 + moore.assign %z, %8 : i32 + // CHECK: moore.output %3 : !moore.ref> + moore.output %6 : !moore.ref> +} + +// CHECK-LABEL: moore.module @structCreateFold +moore.module @structCreateFold(in %a : !moore.i1, out b : !moore.i1) { + %0 = moore.struct_create %a : !moore.i1 -> > + %1 = moore.struct_extract %0, "a" : > -> i1 + // CHECK: moore.output %a : !moore.i1 + moore.output %1 : !moore.i1 + } diff --git a/test/Dialect/Moore/sroa.mlir b/test/Dialect/Moore/sroa.mlir new file mode 100644 index 000000000000..f11a8867c2dd --- /dev/null +++ b/test/Dialect/Moore/sroa.mlir @@ -0,0 +1,60 @@ +// RUN: circt-opt --sroa %s | FileCheck %s + +// CHECK-LABEL: moore.module @LocalVar() { +moore.module @LocalVar() { +// CHECK: %x = moore.variable : +// CHECK: %y = moore.variable : +// CHECK: %z = moore.variable : +%x = moore.variable : +%y = moore.variable : +%z = moore.variable : +moore.procedure always_comb { + // CHECK: %a = moore.variable : + // CHECK: %b = moore.variable : + // CHECK: %0 = moore.constant 1 : i32 + // CHECK: %1 = moore.conversion %0 : !moore.i32 -> !moore.i32 + // CHECK: moore.blocking_assign %a, %1 : i32 + // CHECK: %2 = moore.constant 4 : i32 + // CHECK: %3 = moore.conversion %2 : !moore.i32 -> !moore.i32 + // CHECK: moore.blocking_assign %b, %3 : i32 + // CHECK: %4 = moore.read %x : i32 + // CHECK: %5 = moore.constant 1 : i32 + // CHECK: %6 = moore.add %4, %5 : i32 + // CHECK: moore.blocking_assign %a, %6 : i32 + // CHECK: %7 = moore.read %a : i32 + // CHECK: moore.blocking_assign %y, %7 : i32 + // CHECK: %8 = moore.read %a : i32 + // CHECK: %9 = moore.constant 1 : i32 + // CHECK: %10 = moore.add %8, %9 : i32 + // CHECK: moore.blocking_assign %a, %10 : i32 + // CHECK: %11 = moore.read %a : i32 + // CHECK: moore.blocking_assign %z, %11 : i32 + %ii = moore.variable : > + %0 = moore.struct_extract_ref %ii, "a" : > -> + %1 = moore.constant 1 : i32 + %2 = moore.conversion %1 : !moore.i32 -> !moore.i32 + moore.blocking_assign %0, %2 : i32 + %3 = moore.struct_extract_ref %ii, "b" : > -> + %4 = moore.constant 4 : i32 + %5 = moore.conversion %4 : !moore.i32 -> !moore.i32 + moore.blocking_assign %3, %5 : i32 + %6 = moore.struct_extract_ref %ii, "a" : > -> + %7 = moore.read %x : i32 + %8 = moore.constant 1 : i32 + %9 = moore.add %7, %8 : i32 + moore.blocking_assign %6, %9 : i32 + %10 = moore.struct_extract %ii, "a" : > -> i32 + moore.blocking_assign %y, %10 : i32 + %11 = moore.struct_extract_ref %ii, "a" : > -> + %12 = moore.struct_extract %ii, "a" : > -> i32 + %13 = moore.constant 1 : i32 + %14 = moore.add %12, %13 : i32 + moore.blocking_assign %11, %14 : i32 + %15 = moore.struct_extract %ii, "a" : > -> i32 + moore.blocking_assign %z, %15 : i32 +} +// CHECK: moore.output +moore.output +} + + diff --git a/tools/circt-opt/circt-opt.cpp b/tools/circt-opt/circt-opt.cpp index 3d9b87e59467..9c4be5a329d1 100644 --- a/tools/circt-opt/circt-opt.cpp +++ b/tools/circt-opt/circt-opt.cpp @@ -74,6 +74,7 @@ int main(int argc, char **argv) { // Register test passes circt::test::registerAnalysisTestPasses(); + mlir::registerSROA(); mlir::registerMem2RegPass(); return mlir::failed(mlir::MlirOptMain( diff --git a/tools/circt-verilog/circt-verilog.cpp b/tools/circt-verilog/circt-verilog.cpp index 0b3684a159b8..36c84f67f9d1 100644 --- a/tools/circt-verilog/circt-verilog.cpp +++ b/tools/circt-verilog/circt-verilog.cpp @@ -223,6 +223,8 @@ static LogicalResult populateMooreTransforms(mlir::PassManager &pm) { auto &modulePM = pm.nest(); modulePM.addPass(moore::createLowerConcatRefPass()); modulePM.addPass(moore::createSimplifyProceduresPass()); + + pm.addPass(mlir::createSROA()); pm.addPass(mlir::createMem2Reg()); // TODO: like dedup pass.