diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 73fcefc29b..5d7d6fb3e9 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -207,7 +207,7 @@ final case class Branch[G](branches: Seq[(Expr[G], Statement[G])])(implicit val final case class IndetBranch[G](branches: Seq[Statement[G]])(implicit val o: Origin) extends CompositeStatement[G] with IndetBranchImpl[G] final case class Switch[G](expr: Expr[G], body: Statement[G])(implicit val o: Origin) extends CompositeStatement[G] with SwitchImpl[G] final case class Loop[G](init: Statement[G], cond: Expr[G], update: Statement[G], contract: LoopContract[G], body: Statement[G])(implicit val o: Origin) extends CompositeStatement[G] with LoopImpl[G] -final case class RangedFor[G](binder: Variable[G], from: Expr[G], to: Expr[G], contract: LoopContract[G], body: Statement[G])(implicit val o: Origin) extends CompositeStatement[G] +final case class RangedFor[G](iter: IterVariable[G], contract: LoopContract[G], body: Statement[G])(implicit val o: Origin) extends CompositeStatement[G] with Declarator[G] with RangedForImpl[G] final case class TryCatchFinally[G](body: Statement[G], after: Statement[G], catches: Seq[CatchClause[G]])(implicit val o: Origin) extends CompositeStatement[G] with TryCatchFinallyImpl[G] final case class Synchronized[G](obj: Expr[G], body: Statement[G])(val blame: Blame[LockRegionFailure])(implicit val o: Origin) extends CompositeStatement[G] with SynchronizedImpl[G] final case class ParInvariant[G](decl: ParInvariantDecl[G], inv: Expr[G], content: Statement[G])(val blame: Blame[ParInvariantNotEstablished])(implicit val o: Origin) extends CompositeStatement[G] with ParInvariantImpl[G] diff --git a/src/col/vct/col/ast/statement/composite/RangedForImpl.scala b/src/col/vct/col/ast/statement/composite/RangedForImpl.scala new file mode 100644 index 0000000000..f7de193f0e --- /dev/null +++ b/src/col/vct/col/ast/statement/composite/RangedForImpl.scala @@ -0,0 +1,13 @@ +package vct.col.ast.statement.composite + +import vct.col.ast.{Declaration, RangedFor} +import vct.col.print._ + +trait RangedForImpl[G] { this: RangedFor[G] => + override def declarations: Seq[Declaration[G]] = Seq(iter.variable) + + override def layout(implicit ctx: Ctx): Doc = + Doc.stack(Seq( + Group(Text("for") <> "(" <> Doc.arg(iter) <> ")"), + )) <+> body.layoutAsBlock +} diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index abfce24489..7d42cb21b7 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -1568,6 +1568,7 @@ abstract class CoercingRewriter[Pre <: Generation]() extends AbstractRewriter[Pr case bar @ ParBarrier(block, invs, requires, ensures, content) => ParBarrier(block, invs, res(requires), res(ensures), content)(bar.blame) case p @ ParInvariant(decl, inv, content) => ParInvariant(decl, res(inv), content)(p.blame) case ParStatement(impl) => ParStatement(impl) + case RangedFor(iter, contract, body) => RangedFor(iter, contract, body) case Recv(ref) => Recv(ref) case r @ Refute(assn) => Refute(res(assn))(r.blame) case Return(result) => Return(result) // TODO coerce return, make AmbiguousReturn? diff --git a/src/colhelper/ColDefs.scala b/src/colhelper/ColDefs.scala index b1cd9c6172..85ce4d714a 100644 --- a/src/colhelper/ColDefs.scala +++ b/src/colhelper/ColDefs.scala @@ -52,7 +52,7 @@ object ColDefs { "ModelDeclaration" -> Seq("Program"), "EnumConstant" -> Seq("Program"), "Variable" -> Seq( - "ParBlock", "VecBlock", "CatchClause", "Scope", "SignalsClause", // Explicit declarations + "ParBlock", "VecBlock", "CatchClause", "Scope", "SignalsClause", "RangedForLoop", // Explicit declarations "AxiomaticDataType", "JavaClass", "JavaInterface", // Type arguments "Predicate", "InstancePredicate", // Arguments "ModelProcess", "ModelAction", "ADTFunction", diff --git a/src/main/vct/main/stages/Transformation.scala b/src/main/vct/main/stages/Transformation.scala index 0561ff4ea1..acb83de09c 100644 --- a/src/main/vct/main/stages/Transformation.scala +++ b/src/main/vct/main/stages/Transformation.scala @@ -182,6 +182,7 @@ case class SilverTransformation // Normalize AST Disambiguate, // Resolve overloaded operators (+, subscript, etc.) DisambiguateLocation, // Resolve location type + EncodeRangedFor, EncodeString, // Encode spec string as seq EncodeChar, diff --git a/src/parsers/vct/parsers/transform/PVLToCol.scala b/src/parsers/vct/parsers/transform/PVLToCol.scala index 6a4a0fa754..5d19f1ba49 100644 --- a/src/parsers/vct/parsers/transform/PVLToCol.scala +++ b/src/parsers/vct/parsers/transform/PVLToCol.scala @@ -345,11 +345,9 @@ case class PVLToCol[G](override val originProvider: OriginProvider, override val convert(body) )) ) - case PvlRangedFor(contract, _, _, Iter0(t, name, _, from, _, to), _, body) => + case PvlRangedFor(contract, _, _, iter, _, body) => withContract(contract, contract => - Scope(Nil, RangedFor( - new Variable(convert(t))(origin(name)), - convert(from), convert(to), + Scope(Nil, RangedFor(convert(iter), contract.consumeLoopContract(stat), convert(body)))) case PvlBlock(inner) => convert(inner) diff --git a/src/rewrite/vct/rewrite/EncodeRangedFor.scala b/src/rewrite/vct/rewrite/EncodeRangedFor.scala index ff1725cd5c..53a89d5b0e 100644 --- a/src/rewrite/vct/rewrite/EncodeRangedFor.scala +++ b/src/rewrite/vct/rewrite/EncodeRangedFor.scala @@ -1,7 +1,11 @@ package vct.col.rewrite -import vct.col.ast.{Statement, RangedFor} +import hre.util.ScopedStack +import vct.col.ast.{Block, Assign, Expr, IterVariable, Local, LocalDecl, Loop, LoopInvariant, IterationContract, Eval, LoopContract, PostAssignExpression, Range, RangedFor, SeqMember, Statement, Variable} +import vct.col.origin.AssignLocalOk import vct.col.util.AstBuildHelpers.const +import vct.col.util.AstBuildHelpers._ +import vct.col.ast.RewriteHelpers._ case object EncodeRangedFor extends RewriterBuilder { override def key: String = "encodeRangedFor" @@ -9,8 +13,30 @@ case object EncodeRangedFor extends RewriterBuilder { } case class EncodeRangedFor[Pre <: Generation]() extends Rewriter[Pre] { + val bounds = ScopedStack[(RangedFor[Pre], Expr[Post])]() + override def dispatch(stat: Statement[Pre]): Statement[Post] = stat match { -// case RangedFor(v, from, to, contract, body) => ??? + case rf @ RangedFor(IterVariable(iVar, from, to), contract, body) => + implicit val o = iVar.o + val i = Local[Post](anySucc[Variable[Post]](iVar))(iVar.o) + Loop( + Block(Seq( + LocalDecl(variables.collect(dispatch(iVar))._1.head), + Assign(i, dispatch(from))(AssignLocalOk) + ))(iVar.o), + SeqMember(i, Range(dispatch(from), dispatch(to))), + Eval(PostAssignExpression(i, i + const(1))(AssignLocalOk)), + bounds.having((rf, SeqMember(i, Range(dispatch(from), dispatch(to) + const(1)))))(dispatch(contract)), + dispatch(body) + )(rf.o) case stat => rewriteDefault(stat) } + + override def dispatch(contract: LoopContract[Pre]): LoopContract[Post] = (bounds.topOption, contract) match { + case (Some((rf, iBounds)), l: LoopInvariant[Pre]) => + l.rewrite(invariant = (iBounds && dispatch(l.invariant))(rf.o)) + case (Some((rf, iBounds)), ic: IterationContract[Pre]) => + ic.rewrite(context_everywhere = (iBounds && dispatch(ic.context_everywhere))(rf.o)) + case (None, c) => rewriteDefault(c) + } }