diff --git a/benchmark/src/main/scala/firrtl/benchmark/hot/ResolveKindsBenchmark.scala b/benchmark/src/main/scala/firrtl/benchmark/hot/ResolveKindsBenchmark.scala new file mode 100644 index 0000000000..61e9c1a64b --- /dev/null +++ b/benchmark/src/main/scala/firrtl/benchmark/hot/ResolveKindsBenchmark.scala @@ -0,0 +1,23 @@ + +package firrtl +package benchmark +package hot + +import passes.ResolveKinds +import stage.TransformManager + +import firrtl.benchmark.util._ + +object ResolveKindsBenchmark extends App { + val inputFile = args(0) + val warmup = args(1).toInt + val runs = args(2).toInt + + val input = filenameToCircuit(inputFile) + val state = CircuitState(input, ChirrtlForm) + val prereqs = ResolveKinds.prerequisites + val manager = new TransformManager(prereqs) + val preState = manager.execute(state) + + hot.util.benchmark(warmup, runs)(ResolveKinds.run(preState.circuit)) +} diff --git a/benchmark/src/main/scala/firrtl/benchmark/hot/util/package.scala b/benchmark/src/main/scala/firrtl/benchmark/hot/util/package.scala new file mode 100644 index 0000000000..c05b83597c --- /dev/null +++ b/benchmark/src/main/scala/firrtl/benchmark/hot/util/package.scala @@ -0,0 +1,28 @@ + +package firrtl.benchmark.hot + +import firrtl.Utils.time +import firrtl.benchmark.util._ + +package object util { + def benchmark(nWarmup: Int, nRun: Int)(f: => Unit): Unit = { + // Warmup + for (i <- 0 until nWarmup) { + val (t, res) = time(f) + println(f"Warmup run $i took $t%.1f ms") + } + + // Benchmark + val times: Array[Double] = Array.fill(nRun)(0.0) + for (i <- 0 until nRun) { + val (t, res) = time(f) + times(i) = t + println(f"Benchmark run $i took $t%.1f ms") + } + + println(f"Mean: ${mean(times)}%.1f ms") + println(f"Median: ${median(times)}%.1f ms") + println(f"Stddev: ${stdDev(times)}%.1f ms") + } + +} diff --git a/benchmark/src/main/scala/firrtl/benchmark/util/package.scala b/benchmark/src/main/scala/firrtl/benchmark/util/package.scala new file mode 100644 index 0000000000..2923d8b5fe --- /dev/null +++ b/benchmark/src/main/scala/firrtl/benchmark/util/package.scala @@ -0,0 +1,34 @@ + +package firrtl +package benchmark + +import firrtl.ir.Circuit +import scala.util.control.NonFatal + +package object util { + def filenameToCircuit(filename: String): Circuit = try { + proto.FromProto.fromFile(filename) + } catch { + case NonFatal(_) => Parser.parseFile(filename, Parser.IgnoreInfo) + } + + def mean(xs: Iterable[Double]): Double = xs.sum / xs.size + + def median(xs: Iterable[Double]): Double = { + val size = xs.size + val sorted = xs.toSeq.sorted + if (size % 2 == 1) sorted(size / 2) + else { + val a = sorted(size / 2) + val b = sorted((size / 2) - 1) + (a + b) / 2 + } + } + + def variance(xs: Iterable[Double]): Double = { + val avg = mean(xs) + xs.map(a => math.pow(a - avg, 2)).sum / xs.size + } + + def stdDev(xs: Iterable[Double]): Double = math.sqrt(variance(xs)) +} diff --git a/build.sbt b/build.sbt index 89eda7a5b5..37a353ecbd 100644 --- a/build.sbt +++ b/build.sbt @@ -192,3 +192,11 @@ lazy val firrtl = (project in file(".")) .settings(publishSettings) .settings(docSettings) .settings(mimaSettings) + +lazy val benchmark = (project in file("benchmark")) + .dependsOn(firrtl) + .settings( + assemblyJarName in assembly := "firrtl-benchmark.jar", + test in assembly := {}, + assemblyOutputPath in assembly := file("./utils/bin/firrtl-benchmark.jar") + ) diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala index 73c8646ac8..a3c23a8cb4 100644 --- a/src/main/scala/firrtl/passes/ResolveKinds.scala +++ b/src/main/scala/firrtl/passes/ResolveKinds.scala @@ -5,18 +5,53 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.options.PreservesAll object ResolveKinds extends Pass with PreservesAll[Transform] { override def prerequisites = firrtl.stage.Forms.WorkingIR + private def find_port(kinds: collection.mutable.HashMap[String, Kind])(p: Port): Unit = { + kinds(p.name) = PortKind + } + + def resolve_expr(kinds: collection.mutable.HashMap[String, Kind])(e: Expression): Expression = e match { + case ex: WRef => ex copy (kind = kinds(ex.name)) + case _ => e map resolve_expr(kinds) + } + + def resolve_stmt(kinds: collection.mutable.HashMap[String, Kind])(s: Statement): Statement = { + s match { + case sx: DefWire => kinds(sx.name) = WireKind + case sx: DefNode => kinds(sx.name) = NodeKind + case sx: DefRegister => kinds(sx.name) = RegKind + case sx: WDefInstance => kinds(sx.name) = InstanceKind + case sx: DefMemory => kinds(sx.name) = MemKind + case _ => + } + s.map(resolve_stmt(kinds)) + .map(resolve_expr(kinds)) + } + + def resolve_kinds(m: DefModule): DefModule = { + val kinds = new collection.mutable.HashMap[String, Kind] + m.foreach(find_port(kinds)) + m.map(resolve_stmt(kinds)) + } + + def run(c: Circuit): Circuit = + c copy (modules = c.modules map resolve_kinds) + + @deprecated("This internal type alias will change in 1.4", "1.3.1") type KindMap = collection.mutable.LinkedHashMap[String, Kind] + @deprecated("This internal method's signature will change in 1.4", "1.3.1") def find_port(kinds: KindMap)(p: Port): Port = { kinds(p.name) = PortKind ; p } + @deprecated("This internal method's signature will change in 1.4", "1.3.1") def find_stmt(kinds: KindMap)(s: Statement):Statement = { s match { case sx: DefWire => kinds(sx.name) = WireKind @@ -29,21 +64,14 @@ object ResolveKinds extends Pass with PreservesAll[Transform] { s map find_stmt(kinds) } + @deprecated("This internal method's signature will change in 1.4", "1.3.1") def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match { case ex: WRef => ex copy (kind = kinds(ex.name)) case _ => e map resolve_expr(kinds) } + @deprecated("This internal method's signature will change in 1.4", "1.3.1") def resolve_stmt(kinds: KindMap)(s: Statement): Statement = s map resolve_stmt(kinds) map resolve_expr(kinds) - - def resolve_kinds(m: DefModule): DefModule = { - val kinds = new KindMap - (m map find_port(kinds) - map find_stmt(kinds) - map resolve_stmt(kinds)) - } - - def run(c: Circuit): Circuit = - c copy (modules = c.modules map resolve_kinds) } +