Skip to content

Commit

Permalink
Refactor binary ops
Browse files Browse the repository at this point in the history
  • Loading branch information
InversionSpaces committed Jul 24, 2023
1 parent cb539f1 commit f9a034c
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 91 deletions.
2 changes: 1 addition & 1 deletion parser/src/main/scala/aqua/parser/lexer/ValueToken.scala
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ object ValueToken {
(minus.?.with1 ~ Numbers.nonNegativeIntString).lift.map(fu =>
fu.extract match {
case (Some(_), n) LiteralToken(fu.as(s"-$n"), LiteralType.signed)
case (None, n) LiteralToken(fu.as(n), LiteralType.number)
case (None, n) LiteralToken(fu.as(n), LiteralType.unsigned)
}
)

Expand Down
98 changes: 55 additions & 43 deletions semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,17 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](implicit
typeFromFieldsWithData = rawFields
.map(rf =>
resolvedType match {
case struct@StructType(_, _) =>
case struct @ StructType(_, _) =>
(
StructType(typeName.value, rf.map(_.`type`)),
Some(MakeStructRaw(rf, struct))
)
case scope@AbilityType(_, _) =>
case scope @ AbilityType(_, _) =>
(
AbilityType(typeName.value, rf.map(_.`type`)),
Some(AbilityRaw(rf, scope))
)
}

)
.getOrElse(BottomType -> None)
(typeFromFields, data) = typeFromFieldsWithData
Expand Down Expand Up @@ -175,50 +174,59 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](implicit
callArrowToRaw(ca).map(_.widen[ValueRaw])

case it @ InfixToken(l, r, _) =>
(valueToRaw(l), valueToRaw(r)).mapN((ll, rr) => ll -> rr).flatMap {
(valueToRaw(l), valueToRaw(r)).flatMapN {
case (Some(leftRaw), Some(rightRaw)) =>
// TODO handle literal types
val hasFloats = ScalarType.float
.exists(ft => leftRaw.`type`.acceptsValueOf(ft) || rightRaw.`type`.acceptsValueOf(ft))
val lType = leftRaw.`type`
val rType = rightRaw.`type`
lazy val uType = lType `∪` rType
val hasFloat = List(lType, rType).exists(
_ acceptsValueOf LiteralType.float
)

// See https://github.com/fluencelabs/aqua-lib/blob/main/math.aqua
val (id, fn) = it.op match {
case Op.Add => ("math", "add")
case Op.Sub => ("math", "sub")
case Op.Mul if hasFloat => ("math", "fmul")
case Op.Mul => ("math", "mul")
case Op.Div => ("math", "div")
case Op.Rem => ("math", "rem")
case Op.Pow => ("math", "pow")
case Op.Gt => ("cmp", "gt")
case Op.Gte => ("cmp", "gte")
case Op.Lt => ("cmp", "lt")
case Op.Lte => ("cmp", "lte")
}

// https://github.com/fluencelabs/aqua-lib/blob/main/math.aqua
// Expected types of left and right operands, result type if known, service ID and function name
val (leftType, rightType, res, id, fn) = it.op match {
case Op.Add =>
(ScalarType.i64, ScalarType.i64, None, "math", "add")
case Op.Sub => (ScalarType.i64, ScalarType.i64, None, "math", "sub")
case Op.Mul if hasFloats =>
// TODO may it be i32?
(ScalarType.f64, ScalarType.f64, Some(ScalarType.i64), "math", "fmul")
// Expected type sets of left and right operands, result type
val (leftExp, rightExp, resType) = it.op match {
case Op.Add | Op.Sub | Op.Div | Op.Rem =>
(ScalarType.integer, ScalarType.integer, uType)
case Op.Pow =>
(ScalarType.integer, ScalarType.unsigned, uType)
case Op.Mul if hasFloat =>
(ScalarType.float, ScalarType.float, ScalarType.i64)
case Op.Mul =>
(ScalarType.i64, ScalarType.i64, None, "math", "mul")
case Op.Div => (ScalarType.i64, ScalarType.i64, None, "math", "div")
case Op.Rem => (ScalarType.i64, ScalarType.i64, None, "math", "rem")
case Op.Pow => (ScalarType.i64, ScalarType.u32, None, "math", "pow")
case Op.Gt => (ScalarType.i64, ScalarType.i64, Some(ScalarType.bool), "cmp", "gt")
case Op.Gte => (ScalarType.i64, ScalarType.i64, Some(ScalarType.bool), "cmp", "gte")
case Op.Lt => (ScalarType.i64, ScalarType.i64, Some(ScalarType.bool), "cmp", "lt")
case Op.Lte => (ScalarType.i64, ScalarType.i64, Some(ScalarType.bool), "cmp", "lte")
(ScalarType.integer, ScalarType.integer, uType)
case Op.Gt | Op.Lt | Op.Gte | Op.Lte =>
(ScalarType.integer, ScalarType.integer, ScalarType.bool)
}

for {
ltm <- T.ensureTypeMatches(l, leftType, leftRaw.`type`)
rtm <- T.ensureTypeMatches(r, rightType, rightRaw.`type`)
} yield Option.when(ltm && rtm)(
leftChecked <- T.ensureTypeOneOf(l, leftExp, lType)
rightChecked <- T.ensureTypeOneOf(r, rightExp, rType)
} yield Option.when(
leftChecked.isDefined && rightChecked.isDefined
)(
CallArrowRaw(
Some(id),
fn,
leftRaw :: rightRaw :: Nil,
ArrowType(
ProductType(leftType :: rightType :: Nil),
ProductType(
res.getOrElse(
// If result type is not known/enforced, then assume it's the widest type of operands
// E.g. 1:i8 + 1:i8 -> i8, not i64
leftRaw.`type` `∪` rightRaw.`type`
) :: Nil
)
ability = Some(id),
name = fn,
arguments = leftRaw :: rightRaw :: Nil,
baseType = ArrowType(
ProductType(lType :: rType :: Nil),
ProductType(resType :: Nil)
),
Some(LiteralRaw.quote(id))
serviceId = Some(LiteralRaw.quote(id))
)
)

Expand All @@ -228,9 +236,14 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](implicit
}

// Generate CallArrowRaw for arrow in ability
def callAbType(ab: String, abType: AbilityType, ca: CallArrowToken[S]): Alg[Option[CallArrowRaw]] =
def callAbType(
ab: String,
abType: AbilityType,
ca: CallArrowToken[S]
): Alg[Option[CallArrowRaw]] =
abType.arrows.get(ca.funcName.value) match {
case Some(arrowType) => Option(CallArrowRaw(None, s"$ab.${ca.funcName.value}", Nil, arrowType, None)).pure[Alg]
case Some(arrowType) =>
Option(CallArrowRaw(None, s"$ab.${ca.funcName.value}", Nil, arrowType, None)).pure[Alg]
case None => None.pure[Alg]
}

Expand Down Expand Up @@ -279,7 +292,6 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](implicit
case _ => none
}
}

)
result <- raw.flatTraverse(r =>
val arr = r.baseType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ trait TypesAlgebra[S[_], Alg[_]] {

def ensureTypeMatches(token: Token[S], expected: Type, givenType: Type): Alg[Boolean]

def ensureTypeOneOf[T <: Type](
token: Token[S],
expected: Set[T],
givenType: Type
): Alg[Option[Type]]

def expectNoExport(token: Token[S]): Alg[Unit]

def checkArgumentsNumber(token: Token[S], expected: Int, givenNum: Int): Alg[Boolean]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import cats.syntax.flatMap.*
import cats.syntax.functor.*
import cats.syntax.traverse.*
import cats.{~>, Applicative}
import cats.syntax.option.*
import monocle.Lens
import monocle.macros.GenLens

Expand Down Expand Up @@ -296,6 +297,21 @@ class TypesInterpreter[S[_], X](implicit
}
}

override def ensureTypeOneOf[T <: Type](
token: Token[S],
expected: Set[T],
givenType: Type
): State[X, Option[Type]] = expected
.find(_ acceptsValueOf givenType)
.fold(
reportError(
token,
"Types mismatch." ::
s"expected one of: ${expected.mkString(", ")}" ::
s"given: $givenType" :: Nil
).as(none)
)(_.some.pure)

override def expectNoExport(token: Token[S]): State[X, Unit] =
report(
token,
Expand Down
84 changes: 44 additions & 40 deletions types/src/main/scala/aqua/types/CompareTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,46 +106,50 @@ object CompareTypes {
* -1 if left is a subtype of the right
*/
def apply(l: Type, r: Type): Double =
if (l == r) 0.0
else
(l, r) match {
case (TopType, _) | (_, BottomType) => 1.0
case (BottomType, _) | (_, TopType) => -1.0

// Literals and scalars
case (x: ScalarType, y: ScalarType) => scalarOrder.partialCompare(x, y)
case (LiteralType(xs, _), y: ScalarType) if xs == Set(y) => 0.0
case (LiteralType(xs, _), y: ScalarType) if xs(y) => -1.0
case (x: ScalarType, LiteralType(ys, _)) if ys == Set(x) => 0.0
case (x: ScalarType, LiteralType(ys, _)) if ys(x) => 1.0

// Collections
case (x: ArrayType, y: ArrayType) => apply(x.element, y.element)
case (x: ArrayType, y: StreamType) => apply(x.element, y.element)
case (x: ArrayType, y: OptionType) => apply(x.element, y.element)
case (x: OptionType, y: OptionType) => apply(x.element, y.element)
case (x: OptionType, y: StreamType) => apply(x.element, y.element)
case (x: OptionType, y: ArrayType) => apply(x.element, y.element)
case (x: StreamType, y: StreamType) => apply(x.element, y.element)
case (lnt: AbilityType, rnt: AbilityType) => compareNamed(lnt.fields, rnt.fields)
case (lnt: StructType, rnt: StructType) => compareNamed(lnt.fields, rnt.fields)

// Products
case (l: ProductType, r: ProductType) => compareProducts(l, r)

// Arrows
case (ArrowType(ldom, lcodom), ArrowType(rdom, rcodom)) =>
val cmpDom = apply(ldom, rdom)
val cmpCodom = apply(lcodom, rcodom)

if (cmpDom == 0 && cmpCodom == 0) 0
else if (cmpDom <= 0 && cmpCodom >= 0) 1.0
else if (cmpDom >= 0 && cmpCodom <= 0) -1.0
else NaN

case _ =>
Double.NaN
}
(l, r) match {
case _ if l == r => 0.0

case (TopType, _) | (_, BottomType) => 1.0
case (BottomType, _) | (_, TopType) =>
-1.0

// Collections
case (x: ArrayType, y: ArrayType) => apply(x.element, y.element)
case (x: ArrayType, y: StreamType) => apply(x.element, y.element)
case (x: ArrayType, y: OptionType) => apply(x.element, y.element)
case (x: OptionType, y: OptionType) => apply(x.element, y.element)
case (x: OptionType, y: StreamType) => apply(x.element, y.element)
case (x: OptionType, y: ArrayType) => apply(x.element, y.element)
case (x: StreamType, y: StreamType) => apply(x.element, y.element)
case (lnt: AbilityType, rnt: AbilityType) => compareNamed(lnt.fields, rnt.fields)
case (lnt: StructType, rnt: StructType) => compareNamed(lnt.fields, rnt.fields)

// Literals and scalars
case (x: ScalarType, y: ScalarType) => scalarOrder.partialCompare(x, y)
case (LiteralType(xs, _), y: ScalarType) if xs == Set(y) => 0.0
case (LiteralType(xs, _), y: ScalarType) if xs(y) => -1.0
case (x: ScalarType, LiteralType(ys, _)) if ys == Set(x) => 0.0
case (x: ScalarType, LiteralType(ys, _)) if ys(x) => 1.0
case (LiteralType(xs, _), LiteralType(ys, _)) if xs == ys => 0.0
case (LiteralType(xs, _), LiteralType(ys, _)) if xs subsetOf ys => 1.0
case (LiteralType(xs, _), LiteralType(ys, _)) if ys subsetOf xs => -1.0

// Products
case (l: ProductType, r: ProductType) => compareProducts(l, r)

// Arrows
case (ArrowType(ldom, lcodom), ArrowType(rdom, rcodom)) =>
val cmpDom = apply(ldom, rdom)
val cmpCodom = apply(lcodom, rcodom)

if (cmpDom == 0 && cmpCodom == 0) 0
else if (cmpDom <= 0 && cmpCodom >= 0) 1.0
else if (cmpDom >= 0 && cmpCodom <= 0) -1.0
else NaN

case _ =>
Double.NaN
}

implicit val partialOrder: PartialOrder[Type] =
PartialOrder.from(CompareTypes.apply)
Expand Down
8 changes: 5 additions & 3 deletions types/src/main/scala/aqua/types/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,21 @@ object ScalarType {
val string = ScalarType("string")

val float = Set(f32, f64)
val signed = float ++ Set(i8, i16, i32, i64)
val signed = Set(i8, i16, i32, i64)
val unsigned = Set(u8, u16, u32, u64)
val number = signed ++ unsigned
val integer = signed ++ unsigned
val number = float ++ integer
val all = number ++ Set(bool, string)
}

case class LiteralType private (oneOf: Set[ScalarType], name: String) extends DataType {
override def toString: String = s"$name:lt"
override def toString: String = s"$name literal"
}

object LiteralType {
val float = LiteralType(ScalarType.float, "float")
val signed = LiteralType(ScalarType.signed, "signed")
val unsigned = LiteralType(ScalarType.unsigned, "unsigned")
val number = LiteralType(ScalarType.number, "number")
val bool = LiteralType(Set(ScalarType.bool), "bool")
val string = LiteralType(Set(ScalarType.string), "string")
Expand Down
5 changes: 1 addition & 4 deletions types/src/main/scala/aqua/types/UniteTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ case class UniteTypes(scalarsCombine: ScalarsCombine.T) extends Monoid[Type]:

override def combine(a: Type, b: Type): Type =
(a, b) match {
case _ if CompareTypes(a, b) == 0.0 => a

case (ap: ProductType, bp: ProductType) =>
combineProducts(ap, bp)

Expand Down Expand Up @@ -75,8 +73,7 @@ case class UniteTypes(scalarsCombine: ScalarsCombine.T) extends Monoid[Type]:
case 1.0 => a
case -1.0 => b
case 0.0 => a
case _ =>
TopType
case _ => TopType
}

}
Expand Down

0 comments on commit f9a034c

Please sign in to comment.