Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
Add basic transform benchmarking infrastructure and speed up Resolve …
Browse files Browse the repository at this point in the history
…Kinds (bp #1475) (#1622)

* Add benchmark for ResolveKinds with hot JIT

* Use HashMap instead of LinkedHashMap in ResolveKinds

* Modify to deprecate old methods for backport

(cherry picked from commit 0f78e2d)

Co-authored-by: Albert Magyar <albert.magyar@gmail.com>

* Eliminate unnecessary traversals in ResolveKinds

* Modify for backporting

* Make find_port return Unit and use Foreachers in ResolveKinds

* Modify for backporting

Co-authored-by: Jack Koenig <koenig@sifive.com>
Co-authored-by: Albert Magyar <albert.magyar@gmail.com>
  • Loading branch information
3 people authored May 14, 2020
1 parent dac5d51 commit 83c36db
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

package firrtl
package benchmark
package hot

import passes.ResolveKinds
import stage.TransformManager

import firrtl.benchmark.util._

object ResolveKindsBenchmark extends App {
val inputFile = args(0)
val warmup = args(1).toInt
val runs = args(2).toInt

val input = filenameToCircuit(inputFile)
val state = CircuitState(input, ChirrtlForm)
val prereqs = ResolveKinds.prerequisites
val manager = new TransformManager(prereqs)
val preState = manager.execute(state)

hot.util.benchmark(warmup, runs)(ResolveKinds.run(preState.circuit))
}
28 changes: 28 additions & 0 deletions benchmark/src/main/scala/firrtl/benchmark/hot/util/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

package firrtl.benchmark.hot

import firrtl.Utils.time
import firrtl.benchmark.util._

package object util {
def benchmark(nWarmup: Int, nRun: Int)(f: => Unit): Unit = {
// Warmup
for (i <- 0 until nWarmup) {
val (t, res) = time(f)
println(f"Warmup run $i took $t%.1f ms")
}

// Benchmark
val times: Array[Double] = Array.fill(nRun)(0.0)
for (i <- 0 until nRun) {
val (t, res) = time(f)
times(i) = t
println(f"Benchmark run $i took $t%.1f ms")
}

println(f"Mean: ${mean(times)}%.1f ms")
println(f"Median: ${median(times)}%.1f ms")
println(f"Stddev: ${stdDev(times)}%.1f ms")
}

}
34 changes: 34 additions & 0 deletions benchmark/src/main/scala/firrtl/benchmark/util/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

package firrtl
package benchmark

import firrtl.ir.Circuit
import scala.util.control.NonFatal

package object util {
def filenameToCircuit(filename: String): Circuit = try {
proto.FromProto.fromFile(filename)
} catch {
case NonFatal(_) => Parser.parseFile(filename, Parser.IgnoreInfo)
}

def mean(xs: Iterable[Double]): Double = xs.sum / xs.size

def median(xs: Iterable[Double]): Double = {
val size = xs.size
val sorted = xs.toSeq.sorted
if (size % 2 == 1) sorted(size / 2)
else {
val a = sorted(size / 2)
val b = sorted((size / 2) - 1)
(a + b) / 2
}
}

def variance(xs: Iterable[Double]): Double = {
val avg = mean(xs)
xs.map(a => math.pow(a - avg, 2)).sum / xs.size
}

def stdDev(xs: Iterable[Double]): Double = math.sqrt(variance(xs))
}
8 changes: 8 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,11 @@ lazy val firrtl = (project in file("."))
.settings(publishSettings)
.settings(docSettings)
.settings(mimaSettings)

lazy val benchmark = (project in file("benchmark"))
.dependsOn(firrtl)
.settings(
assemblyJarName in assembly := "firrtl-benchmark.jar",
test in assembly := {},
assemblyOutputPath in assembly := file("./utils/bin/firrtl-benchmark.jar")
)
48 changes: 38 additions & 10 deletions src/main/scala/firrtl/passes/ResolveKinds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,53 @@ package firrtl.passes
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.options.PreservesAll

object ResolveKinds extends Pass with PreservesAll[Transform] {

override def prerequisites = firrtl.stage.Forms.WorkingIR

private def find_port(kinds: collection.mutable.HashMap[String, Kind])(p: Port): Unit = {
kinds(p.name) = PortKind
}

def resolve_expr(kinds: collection.mutable.HashMap[String, Kind])(e: Expression): Expression = e match {
case ex: WRef => ex copy (kind = kinds(ex.name))
case _ => e map resolve_expr(kinds)
}

def resolve_stmt(kinds: collection.mutable.HashMap[String, Kind])(s: Statement): Statement = {
s match {
case sx: DefWire => kinds(sx.name) = WireKind
case sx: DefNode => kinds(sx.name) = NodeKind
case sx: DefRegister => kinds(sx.name) = RegKind
case sx: WDefInstance => kinds(sx.name) = InstanceKind
case sx: DefMemory => kinds(sx.name) = MemKind
case _ =>
}
s.map(resolve_stmt(kinds))
.map(resolve_expr(kinds))
}

def resolve_kinds(m: DefModule): DefModule = {
val kinds = new collection.mutable.HashMap[String, Kind]
m.foreach(find_port(kinds))
m.map(resolve_stmt(kinds))
}

def run(c: Circuit): Circuit =
c copy (modules = c.modules map resolve_kinds)

@deprecated("This internal type alias will change in 1.4", "1.3.1")
type KindMap = collection.mutable.LinkedHashMap[String, Kind]

@deprecated("This internal method's signature will change in 1.4", "1.3.1")
def find_port(kinds: KindMap)(p: Port): Port = {
kinds(p.name) = PortKind ; p
}

@deprecated("This internal method's signature will change in 1.4", "1.3.1")
def find_stmt(kinds: KindMap)(s: Statement):Statement = {
s match {
case sx: DefWire => kinds(sx.name) = WireKind
Expand All @@ -29,21 +64,14 @@ object ResolveKinds extends Pass with PreservesAll[Transform] {
s map find_stmt(kinds)
}

@deprecated("This internal method's signature will change in 1.4", "1.3.1")
def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match {
case ex: WRef => ex copy (kind = kinds(ex.name))
case _ => e map resolve_expr(kinds)
}

@deprecated("This internal method's signature will change in 1.4", "1.3.1")
def resolve_stmt(kinds: KindMap)(s: Statement): Statement =
s map resolve_stmt(kinds) map resolve_expr(kinds)

def resolve_kinds(m: DefModule): DefModule = {
val kinds = new KindMap
(m map find_port(kinds)
map find_stmt(kinds)
map resolve_stmt(kinds))
}

def run(c: Circuit): Circuit =
c copy (modules = c.modules map resolve_kinds)
}

0 comments on commit 83c36db

Please sign in to comment.