From 69e4d06559f8050f231abe021d004e49fa9d7e4f Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Wed, 6 May 2020 11:26:35 -0700 Subject: [PATCH] Remove unnecessary traversal in ResolveKinds --- .../scala/firrtl/passes/ResolveKinds.scala | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala index 98336e0e21..077267dfc2 100644 --- a/src/main/scala/firrtl/passes/ResolveKinds.scala +++ b/src/main/scala/firrtl/passes/ResolveKinds.scala @@ -5,6 +5,7 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.options.PreservesAll import scala.collection.mutable @@ -17,25 +18,32 @@ object ResolveKinds extends Pass with PreservesAll[Transform] { type KindMap = collection.mutable.LinkedHashMap[String, Kind] @deprecated("This API should never have been public", "1.3.1") - def find_port(kinds: KindMap)(p: Port): Port = findPort(kinds)(p) + def find_port(kinds: KindMap)(p: Port): Port = { + recordPort(kinds)(p) + p + } @deprecated("This API should never have been public", "1.3.1") - def find_stmt(kinds: KindMap)(s: Statement): Statement = findStmt(kinds)(s) + def find_stmt(kinds: KindMap)(s: Statement): Statement = { + recordKind(kinds, s) + s.map(find_stmt(kinds)) + } @deprecated("This API should never have been public", "1.3.1") def resolve_expr(kinds: KindMap)(e: Expression): Expression = onExpr(kinds)(e) @deprecated("This API should never have been public", "1.3.1") - def resolve_stmt(kinds: KindMap)(s: Statement): Statement = onStmt(kinds)(s) + def resolve_stmt(kinds: KindMap)(s: Statement): Statement = + s.map(resolve_stmt(kinds)).map(onExpr(kinds)) private type NewKindMap = mutable.Map[String, Kind] - private def findPort(kinds: NewKindMap)(p: Port): Port = { + private def recordPort(kinds: NewKindMap)(p: Port): Unit = { kinds(p.name) = PortKind - p } - private def findStmt(kinds: NewKindMap)(s: Statement): Statement = { + // Note: this is *not* recursive + private def recordKind(kinds: NewKindMap, s: Statement): Unit = s match { case sx: DefWire => kinds(sx.name) = WireKind case sx: DefNode => kinds(sx.name) = NodeKind @@ -44,22 +52,21 @@ object ResolveKinds extends Pass with PreservesAll[Transform] { case sx: DefMemory => kinds(sx.name) = MemKind case _ => } - s.map(findStmt(kinds)) - } private def onExpr(kinds: NewKindMap)(e: Expression): Expression = e match { case ex: WRef => ex.copy(kind = kinds(ex.name)) case _ => e.map(onExpr(kinds)) } - private def onStmt(kinds: NewKindMap)(s: Statement): Statement = + private def onStmt(kinds: NewKindMap)(s: Statement): Statement = { + recordKind(kinds, s) s.map(onStmt(kinds)).map(onExpr(kinds)) + } def resolve_kinds(m: DefModule): DefModule = { val kinds = new mutable.HashMap[String, Kind] - m.map(findPort(kinds)) - .map(findStmt(kinds)) - .map(onStmt(kinds)) + m.foreach(recordPort(kinds)) + m.map(onStmt(kinds)) } def run(c: Circuit): Circuit =