Skip to content

Commit

Permalink
Encode that malloc may return NULL in C
Browse files Browse the repository at this point in the history
The new field 'fallible' on the NewPointerArray AST node stores whether you are guaranteed to get a new pointer array (false) or merely probably get one (true).

Updates tests, fixes #1233.
  • Loading branch information
wandernauta committed Sep 16, 2024
1 parent 8fab87a commit c5c9cf6
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 21 deletions.
8 changes: 8 additions & 0 deletions examples/concepts/c/malloc_free.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,39 @@ struct e{

int main(){
int* xs = (int*) malloc(sizeof(int)*3);
if (xs == NULL) return 1;

xs[0] = 3;
xs[1] = 2;
xs[2] = 1;
free(xs);

int** xxs = (int * *) malloc(sizeof(int *)*3);
if (xxs == NULL) return 1;

int temp[3] = {1,2,3};
xxs[0] = temp;
assert(xxs[0][0] == 1);
free(xxs);

struct d* ys = (struct d*) malloc(3*sizeof(struct d));
if (ys == NULL) return 1;

ys[0].x = 3;
ys[1].x = 2;
ys[2].x = 1;
free(ys);

struct e* a = (struct e*) malloc(1*sizeof(struct e));
if (a == NULL) return 1;

a->s.x = 1;
struct d* b = &(a->s);
free(a);

float * z = (float *) malloc(sizeof(float));
if (z == NULL) return 1;

z[0] = 3.0;
*z = 2.0;
free(z);
Expand Down
8 changes: 5 additions & 3 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1820,9 +1820,11 @@ final case class NewArray[G](
initialize: Boolean,
)(val blame: Blame[ArraySizeError])(implicit val o: Origin)
extends Expr[G] with NewArrayImpl[G]
final case class NewPointerArray[G](element: Type[G], size: Expr[G])(
val blame: Blame[ArraySizeError]
)(implicit val o: Origin)
final case class NewPointerArray[G](
element: Type[G],
size: Expr[G],
fallible: Boolean,
)(val blame: Blame[ArraySizeError])(implicit val o: Origin)
extends Expr[G] with NewPointerArrayImpl[G]
final case class FreePointer[G](pointer: Expr[G])(
val blame: Blame[PointerFreeError]
Expand Down
4 changes: 2 additions & 2 deletions src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1552,8 +1552,8 @@ abstract class CoercingRewriter[Pre <: Generation]()
Neq(coerce(left, sharedType), coerce(right, sharedType))
case na @ NewArray(element, dims, moreDims, initialize) =>
NewArray(element, dims.map(int), moreDims, initialize)(na.blame)
case na @ NewPointerArray(element, size) =>
NewPointerArray(element, size)(na.blame)
case na @ NewPointerArray(element, size, fallible) =>
NewPointerArray(element, size, fallible)(na.blame)
case NewObject(cls) => NewObject(cls)
case NoPerm() => NoPerm()
case Not(arg) => Not(bool(arg))
Expand Down
22 changes: 14 additions & 8 deletions src/rewrite/vct/rewrite/EncodeArrayValues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
: mutable.Map[(Type[Pre], Int, Int, Boolean), Procedure[Post]] = mutable
.Map()

val pointerArrayCreationMethods: mutable.Map[Type[Pre], Procedure[Post]] =
val pointerArrayCreationMethods: mutable.Map[(Type[Pre], Boolean), Procedure[Post]] =
mutable.Map()

val freeMethods: mutable.Map[Type[
Expand Down Expand Up @@ -504,10 +504,10 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
case _ => false
}

def makePointerCreationMethodFor(elementType: Type[Pre]) = {
def makePointerCreationMethodFor(elementType: Type[Pre], fallible: Boolean) = {
implicit val o: Origin = arrayCreationOrigin
// ar != null
// ar.length == dim0
// fallible? then 'ar != null ==> ...'; otherwise 'ar != null ** ...'
// ar.length == size
// forall ar[i] :: Perm(ar[i], write)
// (if type ar[i] is pointer or struct):
// forall i,j :: i!=j ==> ar[i] != ar[j]
Expand All @@ -529,7 +529,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
Seq(access(i), access(j)),
)

var ensures = (result !== Null()) &*
var ensures =
(PointerBlockLength(result)(FramedPtrBlockLength) === sizeArg.get) &*
(PointerBlockOffset(result)(FramedPtrBlockOffset) === zero)
// Pointer location needs pointer add, not pointer subscript
Expand Down Expand Up @@ -558,14 +558,20 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
else
ensures &* foldStar(permFields.map(_._1))

ensures =
if (!fallible)
(result !== Null()) &* ensures
else
Star(Implies(result !== Null(), ensures), tt)

procedure(
blame = AbstractApplicable,
contractBlame = TrueSatisfiable,
returnType = TPointer(dispatch(elementType)),
args = Seq(sizeArg),
requires = UnitAccountedPredicate(requires),
ensures = UnitAccountedPredicate(ensures),
)(o.where(name = "make_pointer_array_" + elementType.toString))
)(o.where(name = "make_pointer_array_" + elementType.toString + (if (fallible) "_fallible" else "")))
}))
}

Expand Down Expand Up @@ -596,9 +602,9 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
Nil,
Nil,
)(ArrayCreationFailed(newArr))
case newPointerArr @ NewPointerArray(element, size) =>
case newPointerArr @ NewPointerArray(element, size, fallible) =>
val method = pointerArrayCreationMethods
.getOrElseUpdate(element, makePointerCreationMethodFor(element))
.getOrElseUpdate((element, fallible), makePointerCreationMethodFor(element, fallible))
ProcedureInvocation[Post](
method.ref,
Seq(dispatch(size)),
Expand Down
2 changes: 1 addition & 1 deletion src/rewrite/vct/rewrite/TrivialAddrOf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] {
val newPointer = Eval(
PreAssignExpression(
newTarget,
NewPointerArray(newValue.t, const[Post](1))(PanicBlame("Size is > 0")),
NewPointerArray(newValue.t, const[Post](1), fallible=false)(PanicBlame("Size is > 0")),
)(blame)
)
(newPointer, newTarget, newValue)
Expand Down
6 changes: 3 additions & 3 deletions src/rewrite/vct/rewrite/lang/LangCPPToCol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2737,11 +2737,11 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
(sizeOption, init.init) match {
case (None, None) => throw WrongCPPType(decl)
case (Some(size), None) =>
val newArr = NewPointerArray[Post](t, rw.dispatch(size))(cta.blame)
val newArr = NewPointerArray[Post](t, rw.dispatch(size), fallible=false)(cta.blame)
Block(Seq(LocalDecl(v), assignLocal(v.get, newArr)))
case (None, Some(CPPLiteralArray(exprs))) =>
val newArr =
NewPointerArray[Post](t, c_const[Post](exprs.size))(cta.blame)
NewPointerArray[Post](t, c_const[Post](exprs.size), fallible=false)(cta.blame)
Block(
Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++
assignliteralArray(v, exprs, o)
Expand All @@ -2752,7 +2752,7 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
if (realSize < exprs.size)
logger.warn(s"Excess elements in array initializer: '${decl}'")
val newArr =
NewPointerArray[Post](t, c_const[Post](realSize))(cta.blame)
NewPointerArray[Post](t, c_const[Post](realSize), fallible=false)(cta.blame)
Block(
Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++
assignliteralArray(v, exprs.take(realSize.intValue), o)
Expand Down
10 changes: 6 additions & 4 deletions src/rewrite/vct/rewrite/lang/LangCToCol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
(t1, rw.dispatch(r))
case _ => throw UnsupportedMalloc(c)
}
NewPointerArray(rw.dispatch(t1), size)(ArrayMallocFailed(inv))(c.o)
NewPointerArray(rw.dispatch(t1), size, fallible=true)(ArrayMallocFailed(inv))(c.o)
case CCast(CInvocation(CLocal("__vercors_malloc"), _, _, _), _) =>
throw UnsupportedMalloc(c)
case CCast(n @ Null(), t) if t.asPointer.isDefined => rw.dispatch(n)
Expand Down Expand Up @@ -650,6 +650,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
NewPointerArray[Post](
getInnerType(cNameSuccessor(d).t),
Local(v.ref),
fallible=false,
)(PanicBlame("Shared memory sizes cannot be negative.")),
)
declarations ++= Seq(cNameSuccessor(d))
Expand All @@ -664,6 +665,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
NewPointerArray[Post](
getInnerType(cNameSuccessor(d).t),
CIntegerValue(size),
fallible=false
)(blame.get),
)
declarations ++= Seq(cNameSuccessor(d))
Expand Down Expand Up @@ -1127,11 +1129,11 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
(sizeOption, init.init) match {
case (None, None) => throw WrongCType(decl)
case (Some(size), None) =>
val newArr = NewPointerArray[Post](t, rw.dispatch(size))(cta.blame)
val newArr = NewPointerArray[Post](t, rw.dispatch(size), fallible=false)(cta.blame)
Block(Seq(LocalDecl(v), assignLocal(v.get, newArr)))
case (None, Some(CLiteralArray(exprs))) =>
val newArr =
NewPointerArray[Post](t, c_const[Post](exprs.size))(cta.blame)
NewPointerArray[Post](t, c_const[Post](exprs.size), fallible=false)(cta.blame)
Block(
Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++
assignliteralArray(v, exprs, o)
Expand All @@ -1142,7 +1144,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
if (realSize < exprs.size)
logger.warn(s"Excess elements in array initializer: '${decl}'")
val newArr =
NewPointerArray[Post](t, c_const[Post](realSize))(cta.blame)
NewPointerArray[Post](t, c_const[Post](realSize), fallible=false)(cta.blame)
Block(
Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++
assignliteralArray(v, exprs.take(realSize.intValue), o)
Expand Down
13 changes: 13 additions & 0 deletions test/main/vct/test/integration/examples/CSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,25 @@ class CSpec extends VercorsSpec {
int main(){
struct e* a = (struct e*) malloc(1*sizeof(struct e));
if (a == NULL) return 1;
a->s.x = 1;
struct d* b = &(a->s);
free(a);
b->x = 2;
}
"""

vercors should fail withCode "ptrNull" using silicon in "use malloc result without null check" c
"""
#include <stdlib.h>
int main(){
int* xs = (int*) malloc(1*sizeof(int));
*xs = 12;
free(xs);
}
"""

vercors should fail withCode "ptrNull" using silicon in "free null pointer" c
"""
#include <stdlib.h>
Expand Down Expand Up @@ -386,6 +398,7 @@ class CSpec extends VercorsSpec {
struct nested *np = NULL;
np = (struct nested*) NULL;
np = (struct nested*) malloc(sizeof(struct nested));
if (np == NULL) return;
np->inner = NULL;
np->inner = (struct nested*) NULL;
}
Expand Down

0 comments on commit c5c9cf6

Please sign in to comment.