Skip to content

Commit

Permalink
First draft implementation of lightweight stratified permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
bobismijnnaam committed Sep 19, 2024
1 parent 89999e9 commit 0a006a8
Showing 1 changed file with 42 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,31 +127,25 @@ case class EncodePermissionStratification[Pre <: Generation](
// to the endpoint context argument.
val specializing = ScopedStack[Expr[Post]]()

type WrapperPredicateKey = (TClass[Pre], Type[Pre], InstanceField[Pre])
type WrapperPredicateKey = (Type[Pre], InstanceField[Pre])
val wrapperPredicates = mut
.LinkedHashMap[WrapperPredicateKey, Predicate[Post]]()

def wrapperPredicate(
endpoint: Endpoint[Pre],
objT: Type[Pre],
field: InstanceField[Pre],
)(implicit o: Origin): Ref[Post, Predicate[Post]] = {
val k = (endpoint.t, objT, field)
// TODO (RR): It does not really wrap anymore, rename
def wrapperPredicate(objT: Type[Pre], field: InstanceField[Pre])(
implicit o: Origin
): Ref[Post, Predicate[Post]] = {
val k = (objT, field)
wrapperPredicates.getOrElseUpdate(
k, {
logger.debug(s"Declaring wrapper predicate for $k")
val endpointArg =
new Variable(dispatch(endpoint.t))(o.where(name = "endpoint"))
new Variable[Post](TAnyValue())(o.where(name = "endpoint"))
val objectArg = new Variable(dispatch(objT))(o.where(name = "obj"))
val body = Perm[Post](
FieldLocation(objectArg.get, succ(field)),
WritePerm(),
)
new Predicate(Seq(endpointArg, objectArg), Some(body))(
o.where(indirect =
Name.names(Name("wrap"), field.o.getPreferredNameOrElse())
)
).declare()
new Predicate(Seq(endpointArg, objectArg), None)(o.where(indirect =
Name
.names(Name("ep"), Name("owner"), field.o.getPreferredNameOrElse())
)).declare()
},
).ref
}
Expand All @@ -162,31 +156,25 @@ case class EncodePermissionStratification[Pre <: Generation](
obj: Expr[Pre],
field: InstanceField[Pre],
)(implicit o: Origin): Ref[Post, Function[Post]] = {
val k = (endpoint.t, obj.t, field)
val pred = wrapperPredicate(endpoint, obj.t, field)
val k = (obj.t, field)
val pred = wrapperPredicate(obj.t, field)
readFunctions.getOrElseUpdate(
k, {
logger.debug(s"Declaring read function for $k")
val endpointArg =
new Variable(dispatch(endpoint.t))(o.where(name = "endpoint"))
new Variable[Post](TAnyValue())(o.where(name = "endpoint"))
val objArg = new Variable(dispatch(obj.t))(o.where(name = "obj"))
function(
requires =
Value(PredicateLocation(
PredicateApply(pred, Seq(endpointArg.get, objArg.get))
)).accounted,
(Value[Post](FieldLocation(objArg.get, succ(field))) &*
Value(PredicateLocation(
PredicateApply(pred, Seq(endpointArg.get, objArg.get))
))).accounted,
args = Seq(endpointArg, objArg),
returnType = dispatch(field.t),
body = Some(
Unfolding(
ValuePredicateApply(
PredicateApply(pred, Seq(endpointArg.get, objArg.get))
),
Deref[Post](objArg.get, succ(field))(PanicBlame(
"Permission is guaranteed by the predicate"
)),
)(PanicBlame("Predicate is guaranteed to be in the precondition"))
),
body = Some(Deref[Post](objArg.get, succ(field))(PanicBlame(
"Permission is guaranteed by the predicate"
))),
blame = PanicBlame("Contract is guaranteed to hold"),
contractBlame = PanicBlame("Contract is guaranteed to be satisfiable"),
)(o.where(indirect =
Expand Down Expand Up @@ -327,22 +315,22 @@ case class EncodePermissionStratification[Pre <: Generation](
perm: Expr[Pre],
endpointExpr: Expr[Post] = null,
)(implicit o: Origin): Expr[Post] = {
// TODO: This branch + parameter use is ugly and unclear
// TODO (RR): This branch + parameter use is ugly and unclear. Also, document what it does, as I don't really follow it
val expr =
if (endpointExpr == null)
EndpointName[Post](succ(endpoint))
else
endpointExpr

if (perm == ReadPerm[Pre]()) {
Value(PredicateLocation(PredicateApply(
wrapperPredicate(endpoint, loc.obj.t, loc.field.decl),
(Value(dispatch(loc)) &* Value(PredicateLocation(PredicateApply(
wrapperPredicate(loc.obj.t, loc.field.decl),
Seq(expr, dispatch(loc.obj)),
)))
))))
} else {
Perm(
Perm(dispatch(loc), dispatch(perm)) &* Perm(
PredicateLocation(PredicateApply(
wrapperPredicate(endpoint, loc.obj.t, loc.field.decl),
wrapperPredicate(loc.obj.t, loc.field.decl),
Seq(expr, dispatch(loc.obj)),
)),
dispatch(perm),
Expand Down Expand Up @@ -408,7 +396,7 @@ case class EncodePermissionStratification[Pre <: Generation](
val predicatesForClass = cls.fields.map { field =>
ScaledPredicateApply(
PredicateApply[Post](
wrapperPredicate(endpoint, baseT, field),
wrapperPredicate(baseT, field),
Seq(EndpointName(succ(endpoint)), base),
),
WritePerm(),
Expand Down Expand Up @@ -493,20 +481,17 @@ case class EncodePermissionStratification[Pre <: Generation](
implicit val o = statement.o
val apply = {
val newEndpoint: Ref[Post, Endpoint[Post]] = succ(endpoint)
val ref = wrapperPredicate(endpoint, obj.t, field)
ScaledPredicateApply[Post](
PredicateApply(
ref,
Seq(
EndpointName(newEndpoint),
currentEndpoint.having(endpoint) {
specializing.having(EndpointName[Post](succ(endpoint))) {
dispatch(obj)
}
},
),
val ref = wrapperPredicate(obj.t, field)
PredicateApply(
ref,
Seq(
EndpointName(newEndpoint),
currentEndpoint.having(endpoint) {
specializing.having(EndpointName[Post](succ(endpoint))) {
dispatch(obj)
}
},
),
WritePerm(),
)
}
val intermediate =
Expand All @@ -519,15 +504,17 @@ case class EncodePermissionStratification[Pre <: Generation](
Seq(intermediate),
Block(Seq(
assignLocal(intermediate.get, dispatch(assign.value)),
Unfold(apply)(ForwardUnfoldFailedToDeref(deref)),
// Unfold(apply)(ForwardUnfoldFailedToDeref(deref)),
Assert(Perm(PredicateLocation(apply), WritePerm()))(PanicBlame(
"TODO: Forward blame"
)),
assign.rewrite(
target =
Deref[Post](dispatch(obj), succ(field))(PanicBlame(
"Unfold succeeded, so assignment is safe"
)),
value = intermediate.get,
),
Fold(apply)(PanicBlame("Unfold succeeded, so fold is safe")),
)),
)
}
Expand Down

0 comments on commit 0a006a8

Please sign in to comment.