Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
Remove unnecessary traversal in ResolveKinds
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkoenig committed May 6, 2020
1 parent a36f34c commit 69e4d06
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions src/main/scala/firrtl/passes/ResolveKinds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package firrtl.passes
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.options.PreservesAll

import scala.collection.mutable
Expand All @@ -17,25 +18,32 @@ object ResolveKinds extends Pass with PreservesAll[Transform] {
type KindMap = collection.mutable.LinkedHashMap[String, Kind]

@deprecated("This API should never have been public", "1.3.1")
def find_port(kinds: KindMap)(p: Port): Port = findPort(kinds)(p)
def find_port(kinds: KindMap)(p: Port): Port = {
recordPort(kinds)(p)
p
}

@deprecated("This API should never have been public", "1.3.1")
def find_stmt(kinds: KindMap)(s: Statement): Statement = findStmt(kinds)(s)
def find_stmt(kinds: KindMap)(s: Statement): Statement = {
recordKind(kinds, s)
s.map(find_stmt(kinds))
}

@deprecated("This API should never have been public", "1.3.1")
def resolve_expr(kinds: KindMap)(e: Expression): Expression = onExpr(kinds)(e)

@deprecated("This API should never have been public", "1.3.1")
def resolve_stmt(kinds: KindMap)(s: Statement): Statement = onStmt(kinds)(s)
def resolve_stmt(kinds: KindMap)(s: Statement): Statement =
s.map(resolve_stmt(kinds)).map(onExpr(kinds))

private type NewKindMap = mutable.Map[String, Kind]

private def findPort(kinds: NewKindMap)(p: Port): Port = {
private def recordPort(kinds: NewKindMap)(p: Port): Unit = {
kinds(p.name) = PortKind
p
}

private def findStmt(kinds: NewKindMap)(s: Statement): Statement = {
// Note: this is *not* recursive
private def recordKind(kinds: NewKindMap, s: Statement): Unit =
s match {
case sx: DefWire => kinds(sx.name) = WireKind
case sx: DefNode => kinds(sx.name) = NodeKind
Expand All @@ -44,22 +52,21 @@ object ResolveKinds extends Pass with PreservesAll[Transform] {
case sx: DefMemory => kinds(sx.name) = MemKind
case _ =>
}
s.map(findStmt(kinds))
}

private def onExpr(kinds: NewKindMap)(e: Expression): Expression = e match {
case ex: WRef => ex.copy(kind = kinds(ex.name))
case _ => e.map(onExpr(kinds))
}

private def onStmt(kinds: NewKindMap)(s: Statement): Statement =
private def onStmt(kinds: NewKindMap)(s: Statement): Statement = {
recordKind(kinds, s)
s.map(onStmt(kinds)).map(onExpr(kinds))
}

def resolve_kinds(m: DefModule): DefModule = {
val kinds = new mutable.HashMap[String, Kind]
m.map(findPort(kinds))
.map(findStmt(kinds))
.map(onStmt(kinds))
m.foreach(recordPort(kinds))
m.map(onStmt(kinds))
}

def run(c: Circuit): Circuit =
Expand Down

0 comments on commit 69e4d06

Please sign in to comment.