diff --git a/src/rewrite/vct/rewrite/EncodeRangedFor.scala b/src/rewrite/vct/rewrite/EncodeRangedFor.scala index 71d1cdd8d9..30a5b69570 100644 --- a/src/rewrite/vct/rewrite/EncodeRangedFor.scala +++ b/src/rewrite/vct/rewrite/EncodeRangedFor.scala @@ -1,8 +1,8 @@ package vct.col.rewrite import hre.util.ScopedStack -import vct.col.ast.{Select, Block, Assign, Expr, IterVariable, Local, LocalDecl, Loop, LoopInvariant, IterationContract, Eval, LoopContract, PostAssignExpression, Range, RangedFor, SeqMember, Statement, Variable} -import vct.col.origin.AssignLocalOk +import vct.col.ast.{Assign, Block, Eval, Expr, IterVariable, IterationContract, Local, LocalDecl, Loop, LoopContract, LoopInvariant, PostAssignExpression, Range, RangedFor, Select, SeqMember, Statement, TInt, Variable} +import vct.col.origin.{AssignLocalOk, Origin, PreferredNameOrigin} import vct.col.util.AstBuildHelpers.const import vct.col.util.AstBuildHelpers._ import vct.col.ast.RewriteHelpers._ @@ -10,27 +10,42 @@ import vct.col.ast.RewriteHelpers._ case object EncodeRangedFor extends RewriterBuilder { override def key: String = "encodeRangedFor" override def desc: String = "Encodes ranged for as a regular for loop" + + case class ForeachBoundOrigin(inner: Origin, name: String) extends PreferredNameOrigin } case class EncodeRangedFor[Pre <: Generation]() extends Rewriter[Pre] { + import EncodeRangedFor._ + val bounds = ScopedStack[(RangedFor[Pre], Expr[Post])]() override def dispatch(stat: Statement[Pre]): Statement[Post] = stat match { - case rf @ RangedFor(iv @ IterVariable(iVar, from, to), contract, body) => + case rf @ RangedFor(iv @ IterVariable(iVar, fromExpr, toExpr), contract, body) => implicit val o = iVar.o val i: Local[Post] = Local(succ[Variable[Post]](iVar))(iVar.o) + + val fromVar = new Variable[Post](TInt()(fromExpr.o))(ForeachBoundOrigin(fromExpr.o, "from")) + val from = Local(fromVar.ref[Variable[Post]])(ForeachBoundOrigin(fromExpr.o, "from")) + + val toVar = new Variable[Post](TInt()(toExpr.o))(ForeachBoundOrigin(toExpr.o, "to")) + val to = Local(toVar.ref[Variable[Post]])(ForeachBoundOrigin(toExpr.o, "to")) + Loop( Block(Seq( + LocalDecl(fromVar)(fromExpr.o), + LocalDecl(toVar)(toExpr.o), LocalDecl(variables.collect(dispatch(iVar))._1.head), - Assign(i, dispatch(from))(AssignLocalOk) + Assign(from, dispatch(fromExpr))(AssignLocalOk), + Assign(to, dispatch(toExpr))(AssignLocalOk), + Assign(i, from)(AssignLocalOk) ))(iVar.o), - { implicit val o = iv.o; SeqMember(i, Range(dispatch(from), dispatch(to))) }, + { implicit val o = iv.o; SeqMember(i, Range(from, to)) }, Eval(PostAssignExpression(i, i + const(1))(AssignLocalOk)), bounds.having((rf, { implicit val o = iv.o; - Select(dispatch(from) < dispatch(to), - SeqMember(i, Range(dispatch(from), dispatch(to) + const(1))), - i === dispatch(from)) + Select(from < to, + SeqMember(i, Range(from, to + const(1))), + i === from) }))(dispatch(contract)), dispatch(body) )(rf.o) diff --git a/test/main/vct/test/integration/examples/TechnicalPvlSpec.scala b/test/main/vct/test/integration/examples/TechnicalPvlSpec.scala index ffde8c3757..97ea4a146d 100644 --- a/test/main/vct/test/integration/examples/TechnicalPvlSpec.scala +++ b/test/main/vct/test/integration/examples/TechnicalPvlSpec.scala @@ -3,7 +3,7 @@ package vct.test.integration.examples import vct.test.integration.helper.VercorsSpec class TechnicalPvlSpec extends VercorsSpec { - vercors should verify using silicon in "ranged for loop" pvl + vercors should verify using silicon in "ranged for loop based on locals" pvl """ void m() { int max = -1; @@ -12,7 +12,39 @@ class TechnicalPvlSpec extends VercorsSpec { assert 0 <= i && i < 10; max = i; } - assert max == 9; + assert max == 10 - 1; + } + """ + + vercors should verify using silicon in "ranged for loop based on heap location" pvl + """ + class C { int f; } + + requires Perm(c.f, 1) ** c.f > 0; + void heapTest1(C c) { + int max = -1; + loop_invariant i - 1 == max; + for (int i = 0 .. c.f) { + max = i; + } + assert max == c.f - 1; + } + """ + + vercors should verify using silicon in "ranged for loop based on heap location used inside" pvl + """ + class C { int f; } + + requires Perm(c.f, 1) ** c.f > 0; + void heapTest2(C c) { + int max = -1; + loop_invariant i - 1 == max; + loop_invariant Perm(c.f, 1\2) ** c.f == \old(c.f); + for (int i = 0 .. c.f) { + assert 0 <= i && i < c.f; + max = i; + } + assert max == c.f - 1; } """ }