From f9a034caeec1601cb806533d7c6f6cb522c5b6b9 Mon Sep 17 00:00:00 2001 From: InversionSpaces Date: Thu, 13 Jul 2023 21:52:17 +0000 Subject: [PATCH] Refactor binary ops --- .../scala/aqua/parser/lexer/ValueToken.scala | 2 +- .../aqua/semantics/rules/ValuesAlgebra.scala | 98 +++++++++++-------- .../semantics/rules/types/TypesAlgebra.scala | 6 ++ .../rules/types/TypesInterpreter.scala | 16 +++ .../main/scala/aqua/types/CompareTypes.scala | 84 ++++++++-------- types/src/main/scala/aqua/types/Type.scala | 8 +- .../main/scala/aqua/types/UniteTypes.scala | 5 +- 7 files changed, 128 insertions(+), 91 deletions(-) diff --git a/parser/src/main/scala/aqua/parser/lexer/ValueToken.scala b/parser/src/main/scala/aqua/parser/lexer/ValueToken.scala index 6ef8e8e5c..25733e68f 100644 --- a/parser/src/main/scala/aqua/parser/lexer/ValueToken.scala +++ b/parser/src/main/scala/aqua/parser/lexer/ValueToken.scala @@ -351,7 +351,7 @@ object ValueToken { (minus.?.with1 ~ Numbers.nonNegativeIntString).lift.map(fu => fu.extract match { case (Some(_), n) ⇒ LiteralToken(fu.as(s"-$n"), LiteralType.signed) - case (None, n) ⇒ LiteralToken(fu.as(n), LiteralType.number) + case (None, n) ⇒ LiteralToken(fu.as(n), LiteralType.unsigned) } ) diff --git a/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala b/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala index a699c8730..6d8620e9b 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala @@ -129,18 +129,17 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](implicit typeFromFieldsWithData = rawFields .map(rf => resolvedType match { - case struct@StructType(_, _) => + case struct @ StructType(_, _) => ( StructType(typeName.value, rf.map(_.`type`)), Some(MakeStructRaw(rf, struct)) ) - case scope@AbilityType(_, _) => + case scope @ AbilityType(_, _) => ( AbilityType(typeName.value, rf.map(_.`type`)), Some(AbilityRaw(rf, scope)) ) } - ) .getOrElse(BottomType -> None) (typeFromFields, data) = typeFromFieldsWithData @@ -175,50 +174,59 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](implicit callArrowToRaw(ca).map(_.widen[ValueRaw]) case it @ InfixToken(l, r, _) => - (valueToRaw(l), valueToRaw(r)).mapN((ll, rr) => ll -> rr).flatMap { + (valueToRaw(l), valueToRaw(r)).flatMapN { case (Some(leftRaw), Some(rightRaw)) => - // TODO handle literal types - val hasFloats = ScalarType.float - .exists(ft => leftRaw.`type`.acceptsValueOf(ft) || rightRaw.`type`.acceptsValueOf(ft)) + val lType = leftRaw.`type` + val rType = rightRaw.`type` + lazy val uType = lType `∪` rType + val hasFloat = List(lType, rType).exists( + _ acceptsValueOf LiteralType.float + ) + + // See https://github.com/fluencelabs/aqua-lib/blob/main/math.aqua + val (id, fn) = it.op match { + case Op.Add => ("math", "add") + case Op.Sub => ("math", "sub") + case Op.Mul if hasFloat => ("math", "fmul") + case Op.Mul => ("math", "mul") + case Op.Div => ("math", "div") + case Op.Rem => ("math", "rem") + case Op.Pow => ("math", "pow") + case Op.Gt => ("cmp", "gt") + case Op.Gte => ("cmp", "gte") + case Op.Lt => ("cmp", "lt") + case Op.Lte => ("cmp", "lte") + } - // https://github.com/fluencelabs/aqua-lib/blob/main/math.aqua - // Expected types of left and right operands, result type if known, service ID and function name - val (leftType, rightType, res, id, fn) = it.op match { - case Op.Add => - (ScalarType.i64, ScalarType.i64, None, "math", "add") - case Op.Sub => (ScalarType.i64, ScalarType.i64, None, "math", "sub") - case Op.Mul if hasFloats => - // TODO may it be i32? - (ScalarType.f64, ScalarType.f64, Some(ScalarType.i64), "math", "fmul") + // Expected type sets of left and right operands, result type + val (leftExp, rightExp, resType) = it.op match { + case Op.Add | Op.Sub | Op.Div | Op.Rem => + (ScalarType.integer, ScalarType.integer, uType) + case Op.Pow => + (ScalarType.integer, ScalarType.unsigned, uType) + case Op.Mul if hasFloat => + (ScalarType.float, ScalarType.float, ScalarType.i64) case Op.Mul => - (ScalarType.i64, ScalarType.i64, None, "math", "mul") - case Op.Div => (ScalarType.i64, ScalarType.i64, None, "math", "div") - case Op.Rem => (ScalarType.i64, ScalarType.i64, None, "math", "rem") - case Op.Pow => (ScalarType.i64, ScalarType.u32, None, "math", "pow") - case Op.Gt => (ScalarType.i64, ScalarType.i64, Some(ScalarType.bool), "cmp", "gt") - case Op.Gte => (ScalarType.i64, ScalarType.i64, Some(ScalarType.bool), "cmp", "gte") - case Op.Lt => (ScalarType.i64, ScalarType.i64, Some(ScalarType.bool), "cmp", "lt") - case Op.Lte => (ScalarType.i64, ScalarType.i64, Some(ScalarType.bool), "cmp", "lte") + (ScalarType.integer, ScalarType.integer, uType) + case Op.Gt | Op.Lt | Op.Gte | Op.Lte => + (ScalarType.integer, ScalarType.integer, ScalarType.bool) } + for { - ltm <- T.ensureTypeMatches(l, leftType, leftRaw.`type`) - rtm <- T.ensureTypeMatches(r, rightType, rightRaw.`type`) - } yield Option.when(ltm && rtm)( + leftChecked <- T.ensureTypeOneOf(l, leftExp, lType) + rightChecked <- T.ensureTypeOneOf(r, rightExp, rType) + } yield Option.when( + leftChecked.isDefined && rightChecked.isDefined + )( CallArrowRaw( - Some(id), - fn, - leftRaw :: rightRaw :: Nil, - ArrowType( - ProductType(leftType :: rightType :: Nil), - ProductType( - res.getOrElse( - // If result type is not known/enforced, then assume it's the widest type of operands - // E.g. 1:i8 + 1:i8 -> i8, not i64 - leftRaw.`type` `∪` rightRaw.`type` - ) :: Nil - ) + ability = Some(id), + name = fn, + arguments = leftRaw :: rightRaw :: Nil, + baseType = ArrowType( + ProductType(lType :: rType :: Nil), + ProductType(resType :: Nil) ), - Some(LiteralRaw.quote(id)) + serviceId = Some(LiteralRaw.quote(id)) ) ) @@ -228,9 +236,14 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](implicit } // Generate CallArrowRaw for arrow in ability - def callAbType(ab: String, abType: AbilityType, ca: CallArrowToken[S]): Alg[Option[CallArrowRaw]] = + def callAbType( + ab: String, + abType: AbilityType, + ca: CallArrowToken[S] + ): Alg[Option[CallArrowRaw]] = abType.arrows.get(ca.funcName.value) match { - case Some(arrowType) => Option(CallArrowRaw(None, s"$ab.${ca.funcName.value}", Nil, arrowType, None)).pure[Alg] + case Some(arrowType) => + Option(CallArrowRaw(None, s"$ab.${ca.funcName.value}", Nil, arrowType, None)).pure[Alg] case None => None.pure[Alg] } @@ -279,7 +292,6 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](implicit case _ => none } } - ) result <- raw.flatTraverse(r => val arr = r.baseType diff --git a/semantics/src/main/scala/aqua/semantics/rules/types/TypesAlgebra.scala b/semantics/src/main/scala/aqua/semantics/rules/types/TypesAlgebra.scala index 59b867637..b278ab0ec 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/types/TypesAlgebra.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/types/TypesAlgebra.scala @@ -36,6 +36,12 @@ trait TypesAlgebra[S[_], Alg[_]] { def ensureTypeMatches(token: Token[S], expected: Type, givenType: Type): Alg[Boolean] + def ensureTypeOneOf[T <: Type]( + token: Token[S], + expected: Set[T], + givenType: Type + ): Alg[Option[Type]] + def expectNoExport(token: Token[S]): Alg[Unit] def checkArgumentsNumber(token: Token[S], expected: Int, givenNum: Int): Alg[Boolean] diff --git a/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala b/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala index ea671d63a..95273aa37 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala @@ -36,6 +36,7 @@ import cats.syntax.flatMap.* import cats.syntax.functor.* import cats.syntax.traverse.* import cats.{~>, Applicative} +import cats.syntax.option.* import monocle.Lens import monocle.macros.GenLens @@ -296,6 +297,21 @@ class TypesInterpreter[S[_], X](implicit } } + override def ensureTypeOneOf[T <: Type]( + token: Token[S], + expected: Set[T], + givenType: Type + ): State[X, Option[Type]] = expected + .find(_ acceptsValueOf givenType) + .fold( + reportError( + token, + "Types mismatch." :: + s"expected one of: ${expected.mkString(", ")}" :: + s"given: $givenType" :: Nil + ).as(none) + )(_.some.pure) + override def expectNoExport(token: Token[S]): State[X, Unit] = report( token, diff --git a/types/src/main/scala/aqua/types/CompareTypes.scala b/types/src/main/scala/aqua/types/CompareTypes.scala index c32fd2c2c..6ad257696 100644 --- a/types/src/main/scala/aqua/types/CompareTypes.scala +++ b/types/src/main/scala/aqua/types/CompareTypes.scala @@ -106,46 +106,50 @@ object CompareTypes { * -1 if left is a subtype of the right */ def apply(l: Type, r: Type): Double = - if (l == r) 0.0 - else - (l, r) match { - case (TopType, _) | (_, BottomType) => 1.0 - case (BottomType, _) | (_, TopType) => -1.0 - - // Literals and scalars - case (x: ScalarType, y: ScalarType) => scalarOrder.partialCompare(x, y) - case (LiteralType(xs, _), y: ScalarType) if xs == Set(y) => 0.0 - case (LiteralType(xs, _), y: ScalarType) if xs(y) => -1.0 - case (x: ScalarType, LiteralType(ys, _)) if ys == Set(x) => 0.0 - case (x: ScalarType, LiteralType(ys, _)) if ys(x) => 1.0 - - // Collections - case (x: ArrayType, y: ArrayType) => apply(x.element, y.element) - case (x: ArrayType, y: StreamType) => apply(x.element, y.element) - case (x: ArrayType, y: OptionType) => apply(x.element, y.element) - case (x: OptionType, y: OptionType) => apply(x.element, y.element) - case (x: OptionType, y: StreamType) => apply(x.element, y.element) - case (x: OptionType, y: ArrayType) => apply(x.element, y.element) - case (x: StreamType, y: StreamType) => apply(x.element, y.element) - case (lnt: AbilityType, rnt: AbilityType) => compareNamed(lnt.fields, rnt.fields) - case (lnt: StructType, rnt: StructType) => compareNamed(lnt.fields, rnt.fields) - - // Products - case (l: ProductType, r: ProductType) => compareProducts(l, r) - - // Arrows - case (ArrowType(ldom, lcodom), ArrowType(rdom, rcodom)) => - val cmpDom = apply(ldom, rdom) - val cmpCodom = apply(lcodom, rcodom) - - if (cmpDom == 0 && cmpCodom == 0) 0 - else if (cmpDom <= 0 && cmpCodom >= 0) 1.0 - else if (cmpDom >= 0 && cmpCodom <= 0) -1.0 - else NaN - - case _ => - Double.NaN - } + (l, r) match { + case _ if l == r => 0.0 + + case (TopType, _) | (_, BottomType) => 1.0 + case (BottomType, _) | (_, TopType) => + -1.0 + + // Collections + case (x: ArrayType, y: ArrayType) => apply(x.element, y.element) + case (x: ArrayType, y: StreamType) => apply(x.element, y.element) + case (x: ArrayType, y: OptionType) => apply(x.element, y.element) + case (x: OptionType, y: OptionType) => apply(x.element, y.element) + case (x: OptionType, y: StreamType) => apply(x.element, y.element) + case (x: OptionType, y: ArrayType) => apply(x.element, y.element) + case (x: StreamType, y: StreamType) => apply(x.element, y.element) + case (lnt: AbilityType, rnt: AbilityType) => compareNamed(lnt.fields, rnt.fields) + case (lnt: StructType, rnt: StructType) => compareNamed(lnt.fields, rnt.fields) + + // Literals and scalars + case (x: ScalarType, y: ScalarType) => scalarOrder.partialCompare(x, y) + case (LiteralType(xs, _), y: ScalarType) if xs == Set(y) => 0.0 + case (LiteralType(xs, _), y: ScalarType) if xs(y) => -1.0 + case (x: ScalarType, LiteralType(ys, _)) if ys == Set(x) => 0.0 + case (x: ScalarType, LiteralType(ys, _)) if ys(x) => 1.0 + case (LiteralType(xs, _), LiteralType(ys, _)) if xs == ys => 0.0 + case (LiteralType(xs, _), LiteralType(ys, _)) if xs subsetOf ys => 1.0 + case (LiteralType(xs, _), LiteralType(ys, _)) if ys subsetOf xs => -1.0 + + // Products + case (l: ProductType, r: ProductType) => compareProducts(l, r) + + // Arrows + case (ArrowType(ldom, lcodom), ArrowType(rdom, rcodom)) => + val cmpDom = apply(ldom, rdom) + val cmpCodom = apply(lcodom, rcodom) + + if (cmpDom == 0 && cmpCodom == 0) 0 + else if (cmpDom <= 0 && cmpCodom >= 0) 1.0 + else if (cmpDom >= 0 && cmpCodom <= 0) -1.0 + else NaN + + case _ => + Double.NaN + } implicit val partialOrder: PartialOrder[Type] = PartialOrder.from(CompareTypes.apply) diff --git a/types/src/main/scala/aqua/types/Type.scala b/types/src/main/scala/aqua/types/Type.scala index de4b6b0d1..c9d99153c 100644 --- a/types/src/main/scala/aqua/types/Type.scala +++ b/types/src/main/scala/aqua/types/Type.scala @@ -172,19 +172,21 @@ object ScalarType { val string = ScalarType("string") val float = Set(f32, f64) - val signed = float ++ Set(i8, i16, i32, i64) + val signed = Set(i8, i16, i32, i64) val unsigned = Set(u8, u16, u32, u64) - val number = signed ++ unsigned + val integer = signed ++ unsigned + val number = float ++ integer val all = number ++ Set(bool, string) } case class LiteralType private (oneOf: Set[ScalarType], name: String) extends DataType { - override def toString: String = s"$name:lt" + override def toString: String = s"$name literal" } object LiteralType { val float = LiteralType(ScalarType.float, "float") val signed = LiteralType(ScalarType.signed, "signed") + val unsigned = LiteralType(ScalarType.unsigned, "unsigned") val number = LiteralType(ScalarType.number, "number") val bool = LiteralType(Set(ScalarType.bool), "bool") val string = LiteralType(Set(ScalarType.string), "string") diff --git a/types/src/main/scala/aqua/types/UniteTypes.scala b/types/src/main/scala/aqua/types/UniteTypes.scala index 32c2297f6..edbc74f68 100644 --- a/types/src/main/scala/aqua/types/UniteTypes.scala +++ b/types/src/main/scala/aqua/types/UniteTypes.scala @@ -30,8 +30,6 @@ case class UniteTypes(scalarsCombine: ScalarsCombine.T) extends Monoid[Type]: override def combine(a: Type, b: Type): Type = (a, b) match { - case _ if CompareTypes(a, b) == 0.0 => a - case (ap: ProductType, bp: ProductType) => combineProducts(ap, bp) @@ -75,8 +73,7 @@ case class UniteTypes(scalarsCombine: ScalarsCombine.T) extends Monoid[Type]: case 1.0 => a case -1.0 => b case 0.0 => a - case _ => - TopType + case _ => TopType } }