Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VCLLVM] Add partial contract support for LLVM #1034

Merged
merged 8 commits into from
Jun 28, 2023
4 changes: 3 additions & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ object parsers extends VercorsModule {
Seq(
antlr.c.generate(),
antlr.java.generate(),
antlr.pvl.generate()
antlr.pvl.generate(),
antlr.llvm.generate(),
)
}
def deps = Agg(
Expand All @@ -60,6 +61,7 @@ object rewrite extends VercorsModule {
def deps = Agg(
ivy"org.sosy-lab:java-smt:3.14.3",
ivy"com.lihaoyi::upickle:2.0.0",
ivy"org.antlr:antlr4-runtime:4.8",
)
def moduleDeps = Seq(hre, col)
}
Expand Down
9 changes: 9 additions & 0 deletions project/antlr.sc
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,13 @@ object pvl extends GenModule {
"SpecParser.g4", "SpecLexer.g4",
"LangPVLParser.g4", "LangPVLLexer.g4",
)
}

object llvm extends GenModule {
def lexer = "LangLLVMSpecLexer.g4"
def parser = "LLVMSpecParser.g4"
def deps = Seq(
"SpecParser.g4", "SpecLexer.g4",
"LangLLVMSpecParser.g4", "LangLLVMSpecLexer.g4"
)
}
19 changes: 17 additions & 2 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,7 @@ final case class JavaNamedType[G](names: Seq[(String, Option[Seq[Type[G]]])])(im
var ref: Option[JavaTypeNameTarget[G]] = None
}
final case class JavaTClass[G](ref: Ref[G, JavaClassOrInterface[G]], typeArgs: Seq[Type[G]])(implicit val o: Origin = DiagnosticOrigin) extends JavaType[G] with JavaTClassImpl[G]

final case class JavaWildcard[G]()(implicit val o: Origin = DiagnosticOrigin) extends JavaType[G] with JavaWildcardImpl[G]

sealed trait JavaExpr[G] extends Expr[G] with JavaExprImpl[G]
Expand Down Expand Up @@ -1080,9 +1081,11 @@ final case class BipInternal[G]()(implicit val o: Origin = DiagnosticOrigin) ext
final case class BipPortSynchronization[G](ports: Seq[Ref[G, BipPort[G]]], wires: Seq[BipGlueDataWire[G]])(val blame: Blame[BipSynchronizationFailure])(implicit val o: Origin) extends GlobalDeclaration[G] with BipPortSynchronizationImpl[G]
final case class BipTransitionSynchronization[G](transitions: Seq[Ref[G, BipTransition[G]]], wires: Seq[BipGlueDataWire[G]])(val blame: Blame[BipSynchronizationFailure])(implicit val o: Origin) extends GlobalDeclaration[G] with BipTransitionSynchronizationImpl[G]

final class LlvmFunctionContract[G](val value:String, val references:Seq[(String, Ref[G, Declaration[G]])])
final class LlvmFunctionContract[G](val value:String, val variableRefs:Seq[(String, Ref[G, Variable[G]])], val invokableRefs:Seq[(String, Ref[G, LlvmFunctionDefinition[G]])])
(val blame: Blame[NontrivialUnsatisfiable])
(implicit val o: Origin) extends NodeFamily[G] with LLVMFunctionContractImpl[G]
(implicit val o: Origin) extends NodeFamily[G] with LLVMFunctionContractImpl[G] {
var data: Option[ApplicableContract[G]] = None
}

final class LlvmFunctionDefinition[G](val returnType: Type[G],
val args: Seq[Variable[G]],
Expand All @@ -1106,6 +1109,18 @@ final case class LlvmLoopInvariant[G](value:String, references:Seq[(String, Ref[
(val blame: Blame[LoopInvariantFailure])
(implicit val o: Origin) extends LlvmLoopContract[G] with LLVMLoopInvariantImpl[G]

sealed trait LlvmExpr[G] extends Expr[G] with LLVMExprImpl[G]

final case class LlvmLocal[G](name: String)(val blame: Blame[DerefInsufficientPermission])(implicit val o: Origin) extends LlvmExpr[G] with LLVMLocalImpl[G] {
var ref: Option[Ref[G, Variable[G]]] = None
}
final case class LlvmAmbiguousFunctionInvocation[G](name: String,
args: Seq[Expr[G]],
givenMap: Seq[(Ref[G, Variable[G]], Expr[G])],
yields: Seq[(Expr[G], Ref[G, Variable[G]])])
(val blame: Blame[InvocationFailure])(implicit val o: Origin) extends LlvmExpr[G] with LLVMAmbiguousFunctionInvocationImpl[G] {
var ref: Option[Ref[G, LlvmFunctionDefinition[G]]] = None
}
sealed trait PVLType[G] extends Type[G] with PVLTypeImpl[G]
final case class PVLNamedType[G](name: String, typeArgs: Seq[Type[G]])(implicit val o: Origin = DiagnosticOrigin) extends PVLType[G] with PVLNamedTypeImpl[G] {
var ref: Option[PVLTypeNameTarget[G]] = None
Expand Down
1 change: 1 addition & 0 deletions src/col/vct/col/ast/expr/context/AmbiguousResultImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ trait AmbiguousResultImpl[G] extends NodeFamilyImpl[G] { this: AmbiguousResult[G
case RefFunction(decl) => decl.returnType
case RefProcedure(decl) => decl.returnType
case RefJavaMethod(decl) => decl.returnType
case RefLlvmFunctionDefinition(decl) => decl.returnType
case RefInstanceFunction(decl) => decl.returnType
case RefInstanceMethod(decl) => decl.returnType
case RefInstanceOperatorMethod(decl) => decl.returnType
Expand Down
18 changes: 18 additions & 0 deletions src/col/vct/col/ast/lang/LLVMAmbiguousFunctionInvocationImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package vct.col.ast.lang

import vct.col.ast.{LlvmAmbiguousFunctionInvocation, Type}
import vct.col.print.{Ctx, Doc, DocUtil, Group, Precedence, Text}

trait LLVMAmbiguousFunctionInvocationImpl[G] { this: LlvmAmbiguousFunctionInvocation[G] =>
override lazy val t: Type[G] = ref match {
case Some(ref) => ref.decl.returnType
}

override def precedence: Int = Precedence.POSTFIX

override def layout(implicit ctx: Ctx): Doc =
Group(
Group(
Text(name) <> "(") <> Doc.args(args) <> ")" <> DocUtil.givenYields(givenMap, yields)
)
}
7 changes: 7 additions & 0 deletions src/col/vct/col/ast/lang/LLVMExprImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package vct.col.ast.lang


import vct.col.ast.LlvmExpr
trait LLVMExprImpl[G] { this: LlvmExpr[G] =>

}
12 changes: 12 additions & 0 deletions src/col/vct/col/ast/lang/LLVMLocalImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package vct.col.ast.lang

import vct.col.ast.{LlvmLocal, Type}
import vct.col.print.{Ctx, Doc, Text}

trait LLVMLocalImpl[G] { this: LlvmLocal[G] =>
override lazy val t: Type[G] = ref match {
case Some(ref) => ref.decl.t
}

override def layout(implicit ctx: Ctx): Doc = Text(name)
}
12 changes: 7 additions & 5 deletions src/col/vct/col/origin/Origin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,11 @@ case class SourceNameOrigin(name: String, inner: Origin) extends Origin {
}

case object RedirectOrigin {
case class StringReadable(data: String) extends Readable {
case class StringReadable(data: String, fileName:String="<unknown filename>") extends Readable {
override def isRereadable: Boolean = true

override protected def getReader: Reader =
new StringReader(data)

override def fileName: String = "<unknown filename>"
}
}

Expand Down Expand Up @@ -379,6 +377,8 @@ case class LLVMOrigin(deserializeOrigin: Deserialize.Origin) extends Origin {
case string => Some(JsonParser(string).asJsObject().fields)
}

def fileName: String = deserializeOrigin.fileName

override def preferredName: String = parsedOrigin match {
case Some(o) => o.get("preferredName") match {
case Some(JsString(jsString)) => jsString
Expand All @@ -397,10 +397,12 @@ case class LLVMOrigin(deserializeOrigin: Deserialize.Origin) extends Origin {

override def context: String = {
val atLine = f" At $shortPosition:\n"
if(contextFragment == inlineContext) {
if (contextFragment == inlineContext) {
atLine + Origin.HR + contextFragment
} else {
} else if (contextFragment.contains(inlineContext)) {
atLine + Origin.HR + markedInlineContext
} else {
deserializeOrigin.context
}
}

Expand Down
35 changes: 32 additions & 3 deletions src/col/vct/col/resolve/Resolve.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vct.col.resolve

import com.typesafe.scalalogging.LazyLogging
import hre.data.BitString
import hre.util.FuncTools
import vct.col.ast._
import vct.col.ast.util.Declarator
Expand All @@ -10,11 +11,13 @@ import vct.col.resolve.ResolveReferences.scanScope
import vct.col.ref.Ref
import vct.col.resolve.ctx._
import vct.col.resolve.lang.{C, Java, PVL, Spec}
import vct.col.resolve.Resolve.{MalformedBipAnnotation, SpecExprParser, getLit, isBip}
import vct.col.resolve.Resolve.{MalformedBipAnnotation, SpecContractParser, SpecExprParser, getLit, isBip}
import vct.col.resolve.lang.JavaAnnotationData.{BipComponent, BipData, BipGuard, BipInvariant, BipPort, BipPure, BipStatePredicate, BipTransition}
import vct.col.rewrite.InitialGeneration
import vct.result.VerificationError.UserError

import scala.collection.immutable.{AbstractSeq, LinearSeq}

case object Resolve {
case class MalformedBipAnnotation(n: Node[_], err: String) extends UserError {
override def code: String = "badBipAnnotation"
Expand All @@ -27,6 +30,10 @@ case object Resolve {
def parse[G](input: String, o: Origin): Expr[G]
}

trait SpecContractParser {
def parse[G](input: LlvmFunctionContract[G], o:Origin): ApplicableContract[G]
}

def extractLiteral(e: Expr[_]): Option[String] = e match {
case JavaStringValue(guardName, _) =>
Some(guardName)
Expand Down Expand Up @@ -170,8 +177,8 @@ case object ResolveTypes {
}

case object ResolveReferences extends LazyLogging {
def resolve[G](program: Program[G], jp: SpecExprParser): Seq[CheckError] = {
resolve(program, ReferenceResolutionContext[G](jp))
def resolve[G](program: Program[G], jp: SpecExprParser, lsp: SpecContractParser): Seq[CheckError] = {
resolve(program, ReferenceResolutionContext[G](jp, lsp))
}

def resolve[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean=false): Seq[CheckError] = {
Expand Down Expand Up @@ -304,6 +311,8 @@ case object ResolveReferences extends LazyLogging {
.declare(info.params.getOrElse(Nil))
.copy(currentResult=info.params.map(_ => RefCGlobalDeclaration(func, idx)))
}
case func: LlvmFunctionDefinition[G] => ctx
.copy(currentResult = Some(RefLlvmFunctionDefinition(func)))
case par: ParStatement[G] => ctx
.declare(scanBlocks(par.impl).map(_.decl))
case Scope(locals, body) => ctx
Expand Down Expand Up @@ -572,6 +581,26 @@ case object ResolveReferences extends LazyLogging {
case portName @ JavaBipGlueName(JavaTClass(Ref(cls: JavaClass[G]), Nil), name) =>
portName.data = Some((cls, getLit(name)))

case contract: LlvmFunctionContract[G] =>
val applicableContract = ctx.llvmSpecParser.parse(contract, contract.o)
contract.data = Some(applicableContract)
resolve(applicableContract, ctx)
case local: LlvmLocal[G] =>
local.ref = ctx.currentResult.get match {
case RefLlvmFunctionDefinition(decl) =>
decl.contract.variableRefs.find(ref => ref._1 == local.name) match {
case Some(ref) => Some(ref._2)
case None => throw NoSuchNameError("local", local.name, local)
}
}
case inv: LlvmAmbiguousFunctionInvocation[G] =>
inv.ref = ctx.currentResult.get match {
case RefLlvmFunctionDefinition(decl) =>
decl.contract.invokableRefs.find(ref => ref._1 == inv.name) match {
case Some(ref) => Some(ref._2)
case None => throw NoSuchNameError("function", inv.name, inv)
}
}
case _ =>
}
}
3 changes: 2 additions & 1 deletion src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package vct.col.resolve.ctx
import vct.col.ast._
import vct.col.check.CheckContext
import vct.col.origin.DiagnosticOrigin
import vct.col.resolve.Resolve.SpecExprParser
import vct.col.resolve.Resolve.{SpecContractParser, SpecExprParser}
import vct.col.util.SuccessionMap

import scala.collection.immutable.{HashMap, ListMap}
Expand All @@ -12,6 +12,7 @@ import scala.collection.mutable
case class ReferenceResolutionContext[G]
(
javaParser: SpecExprParser,
llvmSpecParser: SpecContractParser,
stack: Seq[Seq[Referrable[G]]] = Nil,
topLevelJavaDeref: Option[JavaDeref[G]] = None,
externallyLoadedElements: mutable.ArrayBuffer[GlobalDeclaration[G]] = mutable.ArrayBuffer[GlobalDeclaration[G]](),
Expand Down
5 changes: 3 additions & 2 deletions src/col/vct/col/resolve/ctx/Referrable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ case object Referrable {
case decl: VeyMontSeqProg[G] => RefSeqProg(decl)
case decl: VeyMontThread[G] => RefVeyMontThread(decl)
case decl: JavaBipGlueContainer[G] => RefJavaBipGlueContainer()
case decl: LlvmFunctionDefinition[G] => RefLlvmFunctionDefinition()
case decl: LlvmFunctionDefinition[G] => RefLlvmFunctionDefinition(decl)
case decl: ProverType[G] => RefProverType(decl)
case decl: ProverFunction[G] => RefProverFunction(decl)
})
Expand Down Expand Up @@ -166,6 +166,7 @@ sealed trait SpecDerefTarget[G] extends CDerefTarget[G] with JavaDerefTarget[G]
sealed trait JavaInvocationTarget[G] extends Referrable[G]
sealed trait CInvocationTarget[G] extends Referrable[G]
sealed trait PVLInvocationTarget[G] extends Referrable[G]
sealed trait LlvmInvocationTarget[G] extends Referrable[G]
sealed trait SpecInvocationTarget[G]
extends JavaInvocationTarget[G]
with CNameTarget[G]
Expand Down Expand Up @@ -227,7 +228,7 @@ case class RefPVLConstructor[G](decl: PVLConstructor[G]) extends Referrable[G] w
case class RefJavaBipStatePredicate[G](state: String, decl: JavaAnnotation[G]) extends Referrable[G] with JavaBipStatePredicateTarget[G]
case class RefJavaBipGuard[G](decl: JavaMethod[G]) extends Referrable[G] with JavaNameTarget[G]
case class RefJavaBipGlueContainer[G]() extends Referrable[G] // Bip glue jobs are not actually referrable
case class RefLlvmFunctionDefinition[G]() extends Referrable[G]
case class RefLlvmFunctionDefinition[G](decl: LlvmFunctionDefinition[G]) extends Referrable[G] with LlvmInvocationTarget[G] with ResultTarget[G]
case class RefSeqProg[G](decl: VeyMontSeqProg[G]) extends Referrable[G]
case class RefVeyMontThread[G](decl: VeyMontThread[G]) extends Referrable[G] with PVLNameTarget[G]
case class RefProverType[G](decl: ProverType[G]) extends Referrable[G] with SpecTypeNameTarget[G]
Expand Down
3 changes: 3 additions & 0 deletions src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,8 @@ abstract class CoercingRewriter[Pre <: Generation]() extends AbstractRewriter[Pr
ProcedureInvocation(ref, coerceArgs(args, ref.decl, typeArgs), outArgs, typeArgs, coerceGiven(givenMap), coerceYields(yields, inv))(inv.blame)
case inv @ LlvmFunctionInvocation(ref, args, givenMap, yields) =>
LlvmFunctionInvocation(ref, args, givenMap, yields)(inv.blame)
case inv @ LlvmAmbiguousFunctionInvocation(name, args, givenMap, yields) =>
LlvmAmbiguousFunctionInvocation(name, args, givenMap, yields)(inv.blame)
case ProcessApply(process, args) =>
ProcessApply(process, coerceArgs(args, process.decl))
case ProcessChoice(left, right) =>
Expand Down Expand Up @@ -1518,6 +1520,7 @@ abstract class CoercingRewriter[Pre <: Generation]() extends AbstractRewriter[Pr
case VeyMontCondition(c) => VeyMontCondition(c)
case localIncoming: BipLocalIncomingData[Pre] => localIncoming
case glue: JavaBipGlue[Pre] => glue
case LlvmLocal(name) => e
}
}

Expand Down
14 changes: 7 additions & 7 deletions src/colhelper/ColHelperDeserialize.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ case class ColHelperDeserialize(info: ColDescription, proto: ColProto) extends C

def makeNodeDeserialize(defn: ClassDef): List[Stat] = List(q"""
def ${Term.Name("deserialize" + defn.baseName)}(node: ${serType(defn.baseName)}): ${defn.typ}[G] =
${defn.make(defn.params.map(deserializeParam(defn)), q"LLVMOrigin(Deserialize.Origin(node.origin))", q"LLVMOrigin(Deserialize.Origin(node.origin))")}
${defn.make(defn.params.map(deserializeParam(defn)), q"LLVMOrigin(Deserialize.Origin(node.origin, fileName))", q"LLVMOrigin(Deserialize.Origin(node.origin, fileName))")}
""")

def makeDeserialize(): List[Stat] = q"""
Expand All @@ -76,18 +76,18 @@ case class ColHelperDeserialize(info: ColDescription, proto: ColProto) extends C
import vct.col.origin.LLVMOrigin

object Deserialize {
case class Origin(stringOrigin:String="{}") extends vct.col.origin.Origin {
case class Origin(stringOrigin:String="{}", fileName:String="<unknown>") extends vct.col.origin.Origin {
override def preferredName: String = "unknown"
override def context: String = "At: [deserialized node]"
override def inlineContext: String = "[Deserialized node]"
override def shortPosition: String = "serialized"
}

def deserialize[G](program: ser.Program): Program[G] =
Deserialize[G](mutable.Map()).deserializeProgram(program)
def deserializeProgram[G](program: ser.Program, fileName:String="<unknown>"): Program[G] =
Deserialize[G](mutable.Map(), fileName).deserializeProgram(program)

def deserialize[G](verification: ser.Verification): Verification[G] =
Deserialize[G](mutable.Map()).deserializeVerification(verification)
def deserializeVerification[G](verification: ser.Verification, fileName:String="<unknown>"): Verification[G] =
Deserialize[G](mutable.Map(), fileName).deserializeVerification(verification)

trait DeserializeFunc[S, N[_] <: Node[_]] {
def deserialize[G](s: Deserialize[G], n: S): N[G]
Expand All @@ -97,7 +97,7 @@ case class ColHelperDeserialize(info: ColDescription, proto: ColProto) extends C
..${info.families.flatMap(makeFamilyDispatchLut(_, decl = false)).toList}
}

case class Deserialize[G](decls: mutable.Map[Long, Declaration[G]]) {
case class Deserialize[G](decls: mutable.Map[Long, Declaration[G]], fileName:String) {
def ref[T <: Declaration[G]](id: Long)(implicit tag: ClassTag[T]): Ref[G, T] =
new LazyRef[G, T](decls(id))

Expand Down
Loading