Skip to content

Commit

Permalink
Fix types, add semantics tests
Browse files Browse the repository at this point in the history
  • Loading branch information
InversionSpaces committed Aug 1, 2023
1 parent 69d2892 commit 3a0266d
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,7 @@ import aqua.raw.value.{
import aqua.semantics.rules.locations.LocationsAlgebra
import aqua.semantics.rules.StackInterpreter
import aqua.semantics.rules.errors.ReportErrors
import aqua.types.{
AbilityType,
ArrayType,
ArrowType,
BoxType,
LiteralType,
NamedType,
OptionType,
ProductType,
ScalarType,
StreamType,
StructType,
Type
}
import aqua.types.*
import cats.data.Validated.{Invalid, Valid}
import cats.data.{Chain, NonEmptyList, NonEmptyMap, State}
import cats.instances.list.*
Expand Down Expand Up @@ -226,17 +213,35 @@ class TypesInterpreter[S[_], X](implicit
left: Type,
right: Type
): State[X, Boolean] = {
val isComparable = (left, right) match {
case (LiteralType(xs, _), LiteralType(ys, _)) =>
xs.intersect(ys).nonEmpty
case _ =>
left.acceptsValueOf(right)
}
// TODO: This needs more comprehensive logic
def isComparable(lt: Type, rt: Type): Boolean =
(lt, rt) match {
// All numbers are comparable
case (lst: ScalarType, rst: ScalarType)
if ScalarType.number(lst) && ScalarType.number(rst) =>
true
// Hack: u64 `U` LiteralType.signed = TopType,
// but they shoudl be comparable
case (lst: ScalarType, LiteralType.signed) if ScalarType.number(lst) =>
true
case (LiteralType.signed, rst: ScalarType) if ScalarType.number(rst) =>
true
case (lbt: BoxType, rbt: BoxType) =>
isComparable(lbt.element, rbt.element)
// Prohibit comparing abilities
case (_: AbilityType, _: AbilityType) =>
false
// Prohibit comparing arrows
case (_: ArrowType, _: ArrowType) =>
false
case (LiteralType(xs, _), LiteralType(ys, _)) =>
xs.intersect(ys).nonEmpty
case _ =>
lt.uniteTop(rt) != TopType
}

if (isComparable) State.pure(true)
else
report(token, s"Cannot compare '$left' with '$right''")
.as(false)
if (isComparable(left, right)) State.pure(true)
else report(token, s"Cannot compare '$left' with '$right''").as(false)
}

private def extractToken(token: Token[S]) =
Expand Down
158 changes: 156 additions & 2 deletions semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ import aqua.semantics.rules.locations.LocationsAlgebra
import aqua.semantics.rules.locations.DummyLocationsInterpreter
import aqua.raw.value.{ApplyBinaryOpRaw, LiteralRaw}
import aqua.raw.RawContext
import aqua.types.{LiteralType, ScalarType, TopType, Type}
import aqua.types.*
import aqua.parser.lexer.{InfixToken, LiteralToken, Name, PrefixToken, ValueToken, VarToken}
import aqua.raw.value.ApplyUnaryOpRaw
import aqua.parser.lexer.ValueToken.string

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
Expand All @@ -26,8 +28,9 @@ import cats.Id
import cats.data.State
import cats.syntax.functor.*
import cats.syntax.comonad.*
import cats.data.NonEmptyMap
import monocle.syntax.all.*
import aqua.raw.value.ApplyUnaryOpRaw
import scala.collection.immutable.SortedMap

class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside {

Expand Down Expand Up @@ -75,6 +78,25 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside {
)
)

def valueOfType(t: Type)(
varName: String,
bool: String = "true",
unsigned: String = "42",
signed: String = "-42",
string: String = "string"
): ValueToken[Id] = t match {
case t: LiteralType if t == LiteralType.bool =>
literal(bool, t)
case t: LiteralType if t == LiteralType.unsigned =>
literal(unsigned, t)
case t: LiteralType if t == LiteralType.signed =>
literal(signed, t)
case t: LiteralType if t == LiteralType.string =>
literal(f"\"$string\"", t)
case _ =>
variable(varName)
}

"valueToRaw" should "handle +, -, /, *, % on number literals" in {
val types = List(
LiteralType.signed,
Expand Down Expand Up @@ -249,6 +271,85 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside {
}
}

it should "handle ==, != on values" in {
val test = (lt: Type, rt: Type) => {
InfixToken.EqOp.values.foreach { op =>
val left = valueOfType(lt)(
varName = "left",
bool = "true",
unsigned = "42",
signed = "-42",
string = "\"foo\""
)
val right = valueOfType(rt)(
varName = "right",
bool = "false",
unsigned = "37",
signed = "-37",
string = "\"bar\""
)

val alg = algebra()

val state = genState(
vars = (
List("left" -> lt).filter(_ =>
lt match {
case _: LiteralType => false
case _ => true
}
) ++ List("right" -> rt).filter(_ =>
rt match
case _: LiteralType => false
case _ => true
)
).toMap
)

val token = InfixToken[Id](left, right, InfixToken.Op.Eq(op))

val (st, res) = alg
.valueToRaw(token)
.run(state)
.value

inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _)) =>
bop shouldBe (op match {
case InfixToken.EqOp.Eq => ApplyBinaryOpRaw.Op.Eq
case InfixToken.EqOp.Neq => ApplyBinaryOpRaw.Op.Neq
})
}
}
}

val numbers = ScalarType.integer.toList ++ List(
LiteralType.signed,
LiteralType.unsigned
)

allPairs(numbers).foreach { case (lt, rt) =>
test(lt, rt)
}

val numberStreams = ScalarType.integer.toList.map(StreamType.apply)

allPairs(numberStreams).foreach { case (lt, rt) =>
test(lt, rt)
}

val structType = StructType(
"Struct",
NonEmptyMap(
"foo" -> ScalarType.i64,
SortedMap(
"bar" -> ScalarType.bool
)
)
)

test(structType, structType)
}

it should "handle ! on bool values" in {
val types = List(LiteralType.bool, ScalarType.bool)

Expand Down Expand Up @@ -325,6 +426,59 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside {
}
}

it should "check type of (in)equality operands" in {
val structType = StructType("Struct", NonEmptyMap.one("field", ScalarType.i8))

val types =
List(
LiteralType.bool,
ScalarType.i32,
structType,
StreamType(ScalarType.i8),
StreamType(structType),
ArrowType(
domain = ProductType(ScalarType.i64 :: Nil),
codomain = ProductType(ScalarType.bool :: Nil)
)
)

allPairs(types).filterNot { case (lt, rt) => lt == rt }.foreach { case (lt, rt) =>
InfixToken.EqOp.values.foreach { op =>
val left = lt match {
case lt: LiteralType =>
literal("true", lt)
case _ =>
variable("left")
}
val right = rt match {
case rt: LiteralType =>
literal("false", rt)
case _ =>
variable("right")
}

val alg = algebra()

val state = genState(
vars = (
List("left" -> lt).filter(_ => lt != LiteralType.bool) ++
List("right" -> rt).filter(_ => rt != LiteralType.bool)
).toMap
)

val token = InfixToken[Id](left, right, InfixToken.Op.Eq(op))

val (st, res) = alg
.valueToRaw(token)
.run(state)
.value

res shouldBe None
st.errors.exists(_.isInstanceOf[RulesViolated[Id]]) shouldBe true
}
}
}

it should "check type of logical operand (unary)" in {
val types = ScalarType.integer.toList :+ LiteralType.unsigned

Expand Down
4 changes: 2 additions & 2 deletions types/src/main/scala/aqua/types/CompareTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ object CompareTypes {
// Literals and scalars
case (x: ScalarType, y: ScalarType) => scalarOrder.partialCompare(x, y)
case (LiteralType(xs, _), y: ScalarType) if xs == Set(y) => 0.0
case (LiteralType(xs, _), y: ScalarType) if xs(y) => -1.0
case (LiteralType(xs, _), y: ScalarType) if xs.exists(y acceptsValueOf _) => -1.0
case (x: ScalarType, LiteralType(ys, _)) if ys == Set(x) => 0.0
case (x: ScalarType, LiteralType(ys, _)) if ys(x) => 1.0
case (x: ScalarType, LiteralType(ys, _)) if ys.exists(x acceptsValueOf _) => 1.0
case (LiteralType(xs, _), LiteralType(ys, _)) if xs == ys => 0.0
case (LiteralType(xs, _), LiteralType(ys, _)) if xs subsetOf ys => 1.0
case (LiteralType(xs, _), LiteralType(ys, _)) if ys subsetOf xs => -1.0
Expand Down

0 comments on commit 3a0266d

Please sign in to comment.