Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Add basic transform benchmarking infrastructure and speed up Resolve Kinds #1475

Merged
merged 4 commits into from
May 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

package firrtl
package benchmark
package hot

import passes.ResolveKinds
import stage.TransformManager

import firrtl.benchmark.util._

object ResolveKindsBenchmark extends App {
val inputFile = args(0)
val warmup = args(1).toInt
val runs = args(2).toInt

val input = filenameToCircuit(inputFile)
val state = CircuitState(input, ChirrtlForm)
val prereqs = ResolveKinds.prerequisites
val manager = new TransformManager(prereqs)
val preState = manager.execute(state)

hot.util.benchmark(warmup, runs)(ResolveKinds.run(preState.circuit))
}
28 changes: 28 additions & 0 deletions benchmark/src/main/scala/firrtl/benchmark/hot/util/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

package firrtl.benchmark.hot

import firrtl.Utils.time
import firrtl.benchmark.util._

package object util {
def benchmark(nWarmup: Int, nRun: Int)(f: => Unit): Unit = {
// Warmup
for (i <- 0 until nWarmup) {
val (t, res) = time(f)
println(f"Warmup run $i took $t%.1f ms")
}

// Benchmark
val times: Array[Double] = Array.fill(nRun)(0.0)
for (i <- 0 until nRun) {
val (t, res) = time(f)
times(i) = t
println(f"Benchmark run $i took $t%.1f ms")
}

println(f"Mean: ${mean(times)}%.1f ms")
println(f"Median: ${median(times)}%.1f ms")
println(f"Stddev: ${stdDev(times)}%.1f ms")
}

}
34 changes: 34 additions & 0 deletions benchmark/src/main/scala/firrtl/benchmark/util/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

package firrtl
package benchmark

import firrtl.ir.Circuit
import scala.util.control.NonFatal

package object util {
def filenameToCircuit(filename: String): Circuit = try {
proto.FromProto.fromFile(filename)
} catch {
case NonFatal(_) => Parser.parseFile(filename, Parser.IgnoreInfo)
}

def mean(xs: Iterable[Double]): Double = xs.sum / xs.size

def median(xs: Iterable[Double]): Double = {
val size = xs.size
val sorted = xs.toSeq.sorted
if (size % 2 == 1) sorted(size / 2)
else {
val a = sorted(size / 2)
val b = sorted((size / 2) - 1)
(a + b) / 2
}
}

def variance(xs: Iterable[Double]): Double = {
val avg = mean(xs)
xs.map(a => math.pow(a - avg, 2)).sum / xs.size
}

def stdDev(xs: Iterable[Double]): Double = math.sqrt(variance(xs))
}
8 changes: 8 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,11 @@ lazy val firrtl = (project in file("."))
.settings(testAssemblySettings)
.settings(publishSettings)
.settings(docSettings)

lazy val benchmark = (project in file("benchmark"))
.dependsOn(firrtl)
.settings(
assemblyJarName in assembly := "firrtl-benchmark.jar",
test in assembly := {},
assemblyOutputPath in assembly := file("./utils/bin/firrtl-benchmark.jar")
)
31 changes: 15 additions & 16 deletions src/main/scala/firrtl/passes/ResolveKinds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,25 @@ package firrtl.passes
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.options.PreservesAll

object ResolveKinds extends Pass with PreservesAll[Transform] {

override def prerequisites = firrtl.stage.Forms.WorkingIR

type KindMap = collection.mutable.LinkedHashMap[String, Kind]
type KindMap = collection.mutable.HashMap[String, Kind]

def find_port(kinds: KindMap)(p: Port): Port = {
kinds(p.name) = PortKind ; p
private def find_port(kinds: KindMap)(p: Port): Unit = {
kinds(p.name) = PortKind
}

def find_stmt(kinds: KindMap)(s: Statement):Statement = {
def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match {
case ex: WRef => ex copy (kind = kinds(ex.name))
case _ => e map resolve_expr(kinds)
}

def resolve_stmt(kinds: KindMap)(s: Statement): Statement = {
s match {
case sx: DefWire => kinds(sx.name) = WireKind
case sx: DefNode => kinds(sx.name) = NodeKind
Expand All @@ -26,24 +32,17 @@ object ResolveKinds extends Pass with PreservesAll[Transform] {
case sx: DefMemory => kinds(sx.name) = MemKind
case _ =>
}
s map find_stmt(kinds)
}

def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match {
case ex: WRef => ex copy (kind = kinds(ex.name))
case _ => e map resolve_expr(kinds)
s.map(resolve_stmt(kinds))
.map(resolve_expr(kinds))
}

def resolve_stmt(kinds: KindMap)(s: Statement): Statement =
s map resolve_stmt(kinds) map resolve_expr(kinds)

def resolve_kinds(m: DefModule): DefModule = {
val kinds = new KindMap
(m map find_port(kinds)
map find_stmt(kinds)
map resolve_stmt(kinds))
m.foreach(find_port(kinds))
m.map(resolve_stmt(kinds))
}

def run(c: Circuit): Circuit =
c copy (modules = c.modules map resolve_kinds)
}