diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 78213f49e8..4702a87f7c 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -12,8 +12,11 @@ object InferTypes extends Pass with PreservesAll[Transform] { override def prerequisites = Dependency(ResolveKinds) +: firrtl.stage.Forms.WorkingIR + @deprecated("This should never have been public", "1.3.2") type TypeMap = collection.mutable.LinkedHashMap[String, Type] + private type TypeLookup = collection.mutable.HashMap[String, Type] + def run(c: Circuit): Circuit = { val namespace = Namespace() val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap @@ -36,7 +39,7 @@ object InferTypes extends Pass with PreservesAll[Transform] { } } - def infer_types_e(types: TypeMap)(e: Expression): Expression = + def infer_types_e(types: TypeLookup)(e: Expression): Expression = e map infer_types_e(types) match { case e: WRef => e copy (tpe = types(e.name)) case e: WSubField => e copy (tpe = field_type(e.expr.tpe, e.name)) @@ -48,7 +51,7 @@ object InferTypes extends Pass with PreservesAll[Transform] { case e @ (_: UIntLiteral | _: SIntLiteral) => e } - def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { + def infer_types_s(types: TypeLookup)(s: Statement): Statement = s match { case sx: WDefInstance => val t = mtypes(sx.module) types(sx.name) = t @@ -61,7 +64,7 @@ object InferTypes extends Pass with PreservesAll[Transform] { val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode] val t = remove_unknowns(sxx.value.tpe) types(sx.name) = t - sxx map infer_types_e(types) + sxx case sx: DefRegister => val t = remove_unknowns(sx.tpe) types(sx.name) = t @@ -73,14 +76,14 @@ object InferTypes extends Pass with PreservesAll[Transform] { case sx => sx map infer_types_s(types) map infer_types_e(types) } - def infer_types_p(types: TypeMap)(p: Port): Port = { + def infer_types_p(types: TypeLookup)(p: Port): Port = { val t = remove_unknowns(p.tpe) types(p.name) = t p copy (tpe = t) } def infer_types(m: DefModule): DefModule = { - val types = new TypeMap + val types = new TypeLookup m map infer_types_p(types) map infer_types_s(types) } @@ -92,12 +95,15 @@ object CInferTypes extends Pass with PreservesAll[Transform] { override def prerequisites = firrtl.stage.Forms.ChirrtlForm + @deprecated("This should never have been public", "1.3.2") type TypeMap = collection.mutable.LinkedHashMap[String, Type] + private type TypeLookup = collection.mutable.HashMap[String, Type] + def run(c: Circuit): Circuit = { val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap - def infer_types_e(types: TypeMap)(e: Expression) : Expression = + def infer_types_e(types: TypeLookup)(e: Expression) : Expression = e map infer_types_e(types) match { case (e: Reference) => e copy (tpe = types.getOrElse(e.name, UnknownType)) case (e: SubField) => e copy (tpe = field_type(e.expr.tpe, e.name)) @@ -109,7 +115,7 @@ object CInferTypes extends Pass with PreservesAll[Transform] { case e @ (_: UIntLiteral | _: SIntLiteral) => e } - def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { + def infer_types_s(types: TypeLookup)(s: Statement): Statement = s match { case sx: DefRegister => types(sx.name) = sx.tpe sx map infer_types_e(types) @@ -136,13 +142,13 @@ object CInferTypes extends Pass with PreservesAll[Transform] { case sx => sx map infer_types_s(types) map infer_types_e(types) } - def infer_types_p(types: TypeMap)(p: Port): Port = { + def infer_types_p(types: TypeLookup)(p: Port): Port = { types(p.name) = p.tpe p } def infer_types(m: DefModule): DefModule = { - val types = new TypeMap + val types = new TypeLookup m map infer_types_p(types) map infer_types_s(types) }