diff --git a/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala b/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala index 5e7cf45cf..974f3bddd 100644 --- a/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala @@ -4,9 +4,12 @@ import aqua.model.inline.state.{Arrows, Exports, Mangler} import aqua.model.* import aqua.model.inline.RawValueInliner.collectionToModel import aqua.model.inline.raw.CallArrowRawInliner +import aqua.raw.value.ApplyBinaryOpRaw.Op as BinOp import aqua.raw.ops.* import aqua.raw.value.* import aqua.types.{BoxType, CanonStreamType, StreamType} +import aqua.model.inline.Inline.parDesugarPrefixOpt + import cats.syntax.traverse.* import cats.syntax.applicative.* import cats.syntax.flatMap.* @@ -18,7 +21,6 @@ import cats.data.{Chain, State, StateT} import cats.syntax.show.* import cats.syntax.bifunctor.* import scribe.{log, Logging} -import aqua.model.inline.Inline.parDesugarPrefixOpt /** * [[TagInliner]] prepares a [[RawTag]] for futher processing by converting [[ValueRaw]]s into [[ValueModel]]s. @@ -202,12 +204,36 @@ object TagInliner extends Logging { prefix = parDesugarPrefix(viaF.prependedAll(pif)) ) - case IfTag(leftRaw, rightRaw, shouldMatch) => - ( - valueToModel(leftRaw) >>= canonicalizeIfStream.tupled, - valueToModel(rightRaw) >>= canonicalizeIfStream.tupled - ).mapN { case ((leftModel, leftPrefix), (rightModel, rightPrefix)) => - val prefix = parDesugarPrefixOpt(leftPrefix, rightPrefix) + case IfTag(valueRaw) => + (valueRaw match { + case ApplyBinaryOpRaw(op @ (BinOp.Eq | BinOp.Neq), left, right) => + ( + valueToModel(left), + valueToModel(right) + ).mapN { case ((lmodel, lprefix), (rmodel, rprefix)) => + val prefix = parDesugarPrefixOpt(lprefix, rprefix) + val matchModel = MatchMismatchModel( + left = lmodel, + right = rmodel, + shouldMatch = op match { + case BinOp.Eq => true + case BinOp.Neq => false + } + ) + + (prefix, matchModel) + } + case _ => + valueToModel(valueRaw).map { case (valueModel, prefix) => + val matchModel = MatchMismatchModel( + left = valueModel, + right = LiteralModel.bool(true), + shouldMatch = true + ) + + (prefix, matchModel) + } + }).map { case (prefix, matchModel) => val toModel = (children: Chain[OpModel.Tree]) => XorModel.wrap( children.uncons.map { case (ifBody, elseBody) => @@ -227,11 +253,7 @@ object TagInliner extends Logging { ) else elseBodyFiltered - MatchMismatchModel( - leftModel, - rightModel, - shouldMatch - ).wrap(ifBody) +: elseBodyAugmented + matchModel.wrap(ifBody) +: elseBodyAugmented }.getOrElse(children) ) diff --git a/model/inline/src/main/scala/aqua/model/inline/raw/ApplyBinaryOpRawInliner.scala b/model/inline/src/main/scala/aqua/model/inline/raw/ApplyBinaryOpRawInliner.scala index 4104d474b..3006123d7 100644 --- a/model/inline/src/main/scala/aqua/model/inline/raw/ApplyBinaryOpRawInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/raw/ApplyBinaryOpRawInliner.scala @@ -92,6 +92,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { val predo = (resName: String) => SeqModel.wrap( linline.predo ++ rinline.predo :+ XorModel.wrap( + // TODO: Canonicalize values if they are streams MatchMismatchModel(lmodel, rmodel, shouldMatch).wrap( FlattenModel( LiteralModel.bool(true), diff --git a/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala b/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala index 20238e47d..58722c04d 100644 --- a/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala +++ b/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala @@ -80,7 +80,7 @@ case object ParTag extends ParGroupTag { case object Par extends GroupTag } -case class IfTag(left: ValueRaw, right: ValueRaw, equal: Boolean) extends GroupTag +case class IfTag(value: ValueRaw) extends GroupTag object IfTag { diff --git a/parser/src/main/scala/aqua/parser/expr/func/IfExpr.scala b/parser/src/main/scala/aqua/parser/expr/func/IfExpr.scala index c33a1600e..7352b20c9 100644 --- a/parser/src/main/scala/aqua/parser/expr/func/IfExpr.scala +++ b/parser/src/main/scala/aqua/parser/expr/func/IfExpr.scala @@ -3,7 +3,7 @@ package aqua.parser.expr.func import aqua.parser.Expr import aqua.parser.expr.func.{ForExpr, IfExpr} import aqua.parser.lexer.Token.* -import aqua.parser.lexer.{EqOp, LiteralToken, ValueToken} +import aqua.parser.lexer.{LiteralToken, ValueToken} import aqua.parser.lift.LiftParser import aqua.types.LiteralType import cats.parse.Parser as P @@ -11,11 +11,10 @@ import cats.{~>, Comonad} import aqua.parser.lift.Span import aqua.parser.lift.Span.{P0ToSpan, PToSpan} -case class IfExpr[F[_]](left: ValueToken[F], eqOp: EqOp[F], right: ValueToken[F]) - extends Expr[F](IfExpr, eqOp) { +case class IfExpr[F[_]](value: ValueToken[F]) extends Expr[F](IfExpr, value) { override def mapK[K[_]: Comonad](fk: F ~> K): IfExpr[K] = - copy(left.mapK(fk), eqOp.mapK(fk), right.mapK(fk)) + copy(value.mapK(fk)) } object IfExpr extends Expr.AndIndented { @@ -24,10 +23,5 @@ object IfExpr extends Expr.AndIndented { override def validChildren: List[Expr.Lexem] = ForExpr.validChildren override val p: P[IfExpr[Span.S]] = - (`if` *> ` ` *> ValueToken.`value` ~ (` ` *> EqOp.p ~ (` ` *> ValueToken.`value`)).?).map { - case (left, Some((e, right))) => - IfExpr(left, e, right) - case (left, None) => - IfExpr(left, EqOp(left.as(true)), LiteralToken(left.as("true"), LiteralType.bool)) - } + (`if` *> ` ` *> ValueToken.`value`).map(IfExpr(_)) } diff --git a/parser/src/main/scala/aqua/parser/lexer/EqOp.scala b/parser/src/main/scala/aqua/parser/lexer/EqOp.scala deleted file mode 100644 index fa9b5d0e6..000000000 --- a/parser/src/main/scala/aqua/parser/lexer/EqOp.scala +++ /dev/null @@ -1,27 +0,0 @@ -package aqua.parser.lexer - -import aqua.parser.lift.LiftParser -import cats.Comonad -import cats.syntax.functor.* -import Token.* -import cats.parse.Parser as P -import LiftParser.* -import cats.syntax.comonad.* -import cats.~> -import aqua.parser.lift.Span -import aqua.parser.lift.Span.{P0ToSpan, PToSpan} - -case class EqOp[F[_]: Comonad](eq: F[Boolean]) extends Token[F] { - override def as[T](v: T): F[T] = eq.as(v) - - override def mapK[K[_]: Comonad](fk: F ~> K): EqOp[K] = - copy(fk(eq)) - - def value: Boolean = eq.extract -} - -object EqOp { - - val p: P[EqOp[Span.S]] = - (`eqs`.as(true).lift | `neq`.as(false).lift).map(EqOp(_)) -} diff --git a/semantics/src/main/scala/aqua/semantics/expr/func/IfSem.scala b/semantics/src/main/scala/aqua/semantics/expr/func/IfSem.scala index f7087a109..3a7338761 100644 --- a/semantics/src/main/scala/aqua/semantics/expr/func/IfSem.scala +++ b/semantics/src/main/scala/aqua/semantics/expr/func/IfSem.scala @@ -17,6 +17,8 @@ import cats.syntax.applicative.* import cats.syntax.flatMap.* import cats.syntax.functor.* import cats.syntax.apply.* +import cats.syntax.traverse.* +import aqua.types.ScalarType class IfSem[S[_]](val expr: IfExpr[S]) extends AnyVal { @@ -29,30 +31,26 @@ class IfSem[S[_]](val expr: IfExpr[S]) extends AnyVal { ): Prog[Alg, Raw] = Prog .around( - (V.valueToRaw(expr.left), V.valueToRaw(expr.right)).flatMapN { - case (Some(lt), Some(rt)) => - T.ensureValuesComparable( - token = expr.token, - left = lt.`type`, - right = rt.`type` - ).map(Option.when(_)(lt -> rt)) - case _ => None.pure - }, - (values: Option[(ValueRaw, ValueRaw)], ops: Raw) => - values + V.valueToRaw(expr.value) + .flatMap( + _.flatTraverse(raw => + T.ensureTypeMatches( + token = expr.value, + expected = ScalarType.bool, + givenType = raw.`type` + ).map(Option.when(_)(raw)) + ) + ), + (value: Option[ValueRaw], ops: Raw) => + value .fold( Raw.error("`if` expression errored in matching types") - ) { case (lt, rt) => + )(raw => ops match { - case FuncOp(op) => - IfTag( - left = lt, - right = rt, - equal = expr.eqOp.value - ).wrap(op).toFuncOp + case FuncOp(op) => IfTag(raw).wrap(op).toFuncOp case _ => Raw.error("Wrong body of the `if` expression") } - } + ) .pure ) .abilitiesScope[S](expr.token)