diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index c3149e5554..199c492189 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -5,12 +5,15 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.analyses.InstanceGraph import firrtl.annotations._ import firrtl.passes.{InferTypes, MemPortUtils} -import firrtl.Utils.throwInternalError +import firrtl.Utils.{kind, splitRef, throwInternalError} import firrtl.options.{HasShellOptions, PreservesAll, ShellOption} +import scala.annotation.tailrec + // Datastructures import scala.collection.mutable @@ -276,6 +279,48 @@ object DedupModules { changeInternals({n => n}, retype, {i => i}, renameOfModule)(module) } + @tailrec + private def hasBundleType(tpe: Type): Boolean = tpe match { + case _: BundleType => true + case _: GroundType => false + case VectorType(t, _) => hasBundleType(t) + } + + // Find modules that should not have their ports agnostified to avoid bug in + // https://github.com/freechipsproject/firrtl/issues/1703 + // Marks modules that have a port of BundleType that are connected via an aggregate connect or + // partial connect in an instantiating parent + // Order of modules does not matter + private def modsToNotAgnostifyPorts(modules: Seq[DefModule]): Set[String] = { + val dontDedup = mutable.HashSet.empty[String] + def onModule(mod: DefModule): Unit = { + val instToModule = mutable.HashMap.empty[String, String] + def markAggregatePorts(expr: Expression): Unit = { + if (kind(expr) == InstanceKind && hasBundleType(expr.tpe)) { + val (WRef(inst, _, _, _), _) = splitRef(expr) + dontDedup += instToModule(inst) + } + } + def onStmt(stmt: Statement): Unit = { + stmt.foreach(onStmt) + stmt match { + case inst: WDefInstance => + instToModule(inst.name) = inst.module + case Connect(_, lhs, rhs) => + markAggregatePorts(lhs) + markAggregatePorts(rhs) + case PartialConnect(_, lhs, rhs) => + markAggregatePorts(lhs) + markAggregatePorts(rhs) + case _ => + } + } + mod.foreach(onStmt) + } + modules.foreach(onModule) + dontDedup.toSet + } + //scalastyle:off /** Returns * 1) map of tag to all matching module names, @@ -340,6 +385,8 @@ object DedupModules { val agnosticRename = RenameMap() + val dontAgnostifyPorts = modsToNotAgnostifyPorts(moduleLinearization) + moduleLinearization.foreach { originalModule => // Replace instance references to new deduped modules val dontcare = RenameMap() @@ -360,9 +407,15 @@ object DedupModules { // Build tag val builder = new mutable.ArrayBuffer[Any]() - agnosticModule.ports.foreach { builder ++= _.serialize } builder += agnosticAnnos + // It may seem weird to use non-agnostified ports with an agnostified body because + // technically it would be invalid FIRRTL, but it is logically sound for the purpose of + // calculating deduplication tags + val ports = + if (dontAgnostifyPorts(originalModule.name)) originalModule.ports else agnosticModule.ports + ports.foreach { builder ++= _.serialize } + agnosticModule match { case Module(i, n, ps, b) => builder ++= fastSerializedHash(b).toString()//.serialize case ExtModule(i, n, ps, dn, p) => diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala index 4cc19c9dde..4709051a7f 100644 --- a/src/test/scala/firrtlTests/transforms/DedupTests.scala +++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala @@ -253,6 +253,61 @@ class DedupModuleTests extends HighTransformSpec { val diff_params = mkfir(("BB", "BB"), ("0", "1")) execute(diff_params, diff_params, Seq.empty) } + + "Modules with aggregate ports that are bulk connected" should "NOT dedup if their port names differ" in { + val input = + """ + |circuit FooAndBarModule : + | module FooModule : + | output io : {flip foo : UInt<1>, fuzz : UInt<1>} + | io.fuzz <= io.foo + | module BarModule : + | output io : {flip bar : UInt<1>, buzz : UInt<1>} + | io.buzz <= io.bar + | module FooAndBarModule : + | output io : {foo : {flip foo : UInt<1>, fuzz : UInt<1>}, bar : {flip bar : UInt<1>, buzz : UInt<1>}} + | inst foo of FooModule + | inst bar of BarModule + | io.foo <- foo.io + | io.bar <- bar.io + |""".stripMargin + val check = input + execute(input, check, Seq.empty) + } + + "Modules with aggregate ports that are bulk connected" should "dedup if their port names are the same" in { + val input = + """ + |circuit FooAndBarModule : + | module FooModule : + | output io : {flip foo : UInt<1>, fuzz : UInt<1>} + | io.fuzz <= io.foo + | module BarModule : + | output io : {flip foo : UInt<1>, fuzz : UInt<1>} + | io.fuzz <= io.foo + | module FooAndBarModule : + | output io : {foo : {flip foo : UInt<1>, fuzz : UInt<1>}, bar : {flip bar : UInt<1>, buzz : UInt<1>}} + | inst foo of FooModule + | inst bar of BarModule + | io.foo <- foo.io + | io.bar <- bar.io + |""".stripMargin + val check = + """ + |circuit FooAndBarModule : + | module FooModule : + | output io : {flip foo : UInt<1>, fuzz : UInt<1>} + | io.fuzz <= io.foo + | module FooAndBarModule : + | output io : {foo : {flip foo : UInt<1>, fuzz : UInt<1>}, bar : {flip bar : UInt<1>, buzz : UInt<1>}} + | inst foo of FooModule + | inst bar of FooModule + | io.foo <- foo.io + | io.bar <- bar.io + |""".stripMargin + execute(input, check, Seq.empty) + } + "The module A and B" should "be deduped with the first module in order" in { val input = """circuit Top :