From 3a0266d0e053b0aec6ff7a3a12f7b2c29dad398b Mon Sep 17 00:00:00 2001 From: InversionSpaces Date: Tue, 1 Aug 2023 11:46:05 +0000 Subject: [PATCH] Fix types, add semantics tests --- .../rules/types/TypesInterpreter.scala | 53 +++--- .../aqua/semantics/ValuesAlgebraSpec.scala | 158 +++++++++++++++++- .../main/scala/aqua/types/CompareTypes.scala | 4 +- 3 files changed, 187 insertions(+), 28 deletions(-) diff --git a/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala b/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala index 1dde93db7..19b29248c 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala @@ -13,20 +13,7 @@ import aqua.raw.value.{ import aqua.semantics.rules.locations.LocationsAlgebra import aqua.semantics.rules.StackInterpreter import aqua.semantics.rules.errors.ReportErrors -import aqua.types.{ - AbilityType, - ArrayType, - ArrowType, - BoxType, - LiteralType, - NamedType, - OptionType, - ProductType, - ScalarType, - StreamType, - StructType, - Type -} +import aqua.types.* import cats.data.Validated.{Invalid, Valid} import cats.data.{Chain, NonEmptyList, NonEmptyMap, State} import cats.instances.list.* @@ -226,17 +213,35 @@ class TypesInterpreter[S[_], X](implicit left: Type, right: Type ): State[X, Boolean] = { - val isComparable = (left, right) match { - case (LiteralType(xs, _), LiteralType(ys, _)) => - xs.intersect(ys).nonEmpty - case _ => - left.acceptsValueOf(right) - } + // TODO: This needs more comprehensive logic + def isComparable(lt: Type, rt: Type): Boolean = + (lt, rt) match { + // All numbers are comparable + case (lst: ScalarType, rst: ScalarType) + if ScalarType.number(lst) && ScalarType.number(rst) => + true + // Hack: u64 `U` LiteralType.signed = TopType, + // but they shoudl be comparable + case (lst: ScalarType, LiteralType.signed) if ScalarType.number(lst) => + true + case (LiteralType.signed, rst: ScalarType) if ScalarType.number(rst) => + true + case (lbt: BoxType, rbt: BoxType) => + isComparable(lbt.element, rbt.element) + // Prohibit comparing abilities + case (_: AbilityType, _: AbilityType) => + false + // Prohibit comparing arrows + case (_: ArrowType, _: ArrowType) => + false + case (LiteralType(xs, _), LiteralType(ys, _)) => + xs.intersect(ys).nonEmpty + case _ => + lt.uniteTop(rt) != TopType + } - if (isComparable) State.pure(true) - else - report(token, s"Cannot compare '$left' with '$right''") - .as(false) + if (isComparable(left, right)) State.pure(true) + else report(token, s"Cannot compare '$left' with '$right''").as(false) } private def extractToken(token: Token[S]) = diff --git a/semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala b/semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala index 03fbcbfc8..381535db6 100644 --- a/semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala +++ b/semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala @@ -16,8 +16,10 @@ import aqua.semantics.rules.locations.LocationsAlgebra import aqua.semantics.rules.locations.DummyLocationsInterpreter import aqua.raw.value.{ApplyBinaryOpRaw, LiteralRaw} import aqua.raw.RawContext -import aqua.types.{LiteralType, ScalarType, TopType, Type} +import aqua.types.* import aqua.parser.lexer.{InfixToken, LiteralToken, Name, PrefixToken, ValueToken, VarToken} +import aqua.raw.value.ApplyUnaryOpRaw +import aqua.parser.lexer.ValueToken.string import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -26,8 +28,9 @@ import cats.Id import cats.data.State import cats.syntax.functor.* import cats.syntax.comonad.* +import cats.data.NonEmptyMap import monocle.syntax.all.* -import aqua.raw.value.ApplyUnaryOpRaw +import scala.collection.immutable.SortedMap class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside { @@ -75,6 +78,25 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside { ) ) + def valueOfType(t: Type)( + varName: String, + bool: String = "true", + unsigned: String = "42", + signed: String = "-42", + string: String = "string" + ): ValueToken[Id] = t match { + case t: LiteralType if t == LiteralType.bool => + literal(bool, t) + case t: LiteralType if t == LiteralType.unsigned => + literal(unsigned, t) + case t: LiteralType if t == LiteralType.signed => + literal(signed, t) + case t: LiteralType if t == LiteralType.string => + literal(f"\"$string\"", t) + case _ => + variable(varName) + } + "valueToRaw" should "handle +, -, /, *, % on number literals" in { val types = List( LiteralType.signed, @@ -249,6 +271,85 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside { } } + it should "handle ==, != on values" in { + val test = (lt: Type, rt: Type) => { + InfixToken.EqOp.values.foreach { op => + val left = valueOfType(lt)( + varName = "left", + bool = "true", + unsigned = "42", + signed = "-42", + string = "\"foo\"" + ) + val right = valueOfType(rt)( + varName = "right", + bool = "false", + unsigned = "37", + signed = "-37", + string = "\"bar\"" + ) + + val alg = algebra() + + val state = genState( + vars = ( + List("left" -> lt).filter(_ => + lt match { + case _: LiteralType => false + case _ => true + } + ) ++ List("right" -> rt).filter(_ => + rt match + case _: LiteralType => false + case _ => true + ) + ).toMap + ) + + val token = InfixToken[Id](left, right, InfixToken.Op.Eq(op)) + + val (st, res) = alg + .valueToRaw(token) + .run(state) + .value + + inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _)) => + bop shouldBe (op match { + case InfixToken.EqOp.Eq => ApplyBinaryOpRaw.Op.Eq + case InfixToken.EqOp.Neq => ApplyBinaryOpRaw.Op.Neq + }) + } + } + } + + val numbers = ScalarType.integer.toList ++ List( + LiteralType.signed, + LiteralType.unsigned + ) + + allPairs(numbers).foreach { case (lt, rt) => + test(lt, rt) + } + + val numberStreams = ScalarType.integer.toList.map(StreamType.apply) + + allPairs(numberStreams).foreach { case (lt, rt) => + test(lt, rt) + } + + val structType = StructType( + "Struct", + NonEmptyMap( + "foo" -> ScalarType.i64, + SortedMap( + "bar" -> ScalarType.bool + ) + ) + ) + + test(structType, structType) + } + it should "handle ! on bool values" in { val types = List(LiteralType.bool, ScalarType.bool) @@ -325,6 +426,59 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside { } } + it should "check type of (in)equality operands" in { + val structType = StructType("Struct", NonEmptyMap.one("field", ScalarType.i8)) + + val types = + List( + LiteralType.bool, + ScalarType.i32, + structType, + StreamType(ScalarType.i8), + StreamType(structType), + ArrowType( + domain = ProductType(ScalarType.i64 :: Nil), + codomain = ProductType(ScalarType.bool :: Nil) + ) + ) + + allPairs(types).filterNot { case (lt, rt) => lt == rt }.foreach { case (lt, rt) => + InfixToken.EqOp.values.foreach { op => + val left = lt match { + case lt: LiteralType => + literal("true", lt) + case _ => + variable("left") + } + val right = rt match { + case rt: LiteralType => + literal("false", rt) + case _ => + variable("right") + } + + val alg = algebra() + + val state = genState( + vars = ( + List("left" -> lt).filter(_ => lt != LiteralType.bool) ++ + List("right" -> rt).filter(_ => rt != LiteralType.bool) + ).toMap + ) + + val token = InfixToken[Id](left, right, InfixToken.Op.Eq(op)) + + val (st, res) = alg + .valueToRaw(token) + .run(state) + .value + + res shouldBe None + st.errors.exists(_.isInstanceOf[RulesViolated[Id]]) shouldBe true + } + } + } + it should "check type of logical operand (unary)" in { val types = ScalarType.integer.toList :+ LiteralType.unsigned diff --git a/types/src/main/scala/aqua/types/CompareTypes.scala b/types/src/main/scala/aqua/types/CompareTypes.scala index 74a9e9218..cbd70672e 100644 --- a/types/src/main/scala/aqua/types/CompareTypes.scala +++ b/types/src/main/scala/aqua/types/CompareTypes.scala @@ -126,9 +126,9 @@ object CompareTypes { // Literals and scalars case (x: ScalarType, y: ScalarType) => scalarOrder.partialCompare(x, y) case (LiteralType(xs, _), y: ScalarType) if xs == Set(y) => 0.0 - case (LiteralType(xs, _), y: ScalarType) if xs(y) => -1.0 + case (LiteralType(xs, _), y: ScalarType) if xs.exists(y acceptsValueOf _) => -1.0 case (x: ScalarType, LiteralType(ys, _)) if ys == Set(x) => 0.0 - case (x: ScalarType, LiteralType(ys, _)) if ys(x) => 1.0 + case (x: ScalarType, LiteralType(ys, _)) if ys.exists(x acceptsValueOf _) => 1.0 case (LiteralType(xs, _), LiteralType(ys, _)) if xs == ys => 0.0 case (LiteralType(xs, _), LiteralType(ys, _)) if xs subsetOf ys => 1.0 case (LiteralType(xs, _), LiteralType(ys, _)) if ys subsetOf xs => -1.0