Skip to content

Commit

Permalink
Implement pattern guards (#1910)
Browse files Browse the repository at this point in the history
* Implement pattern guards

This commit implement pattern guards, which are side conditions that can
be added to the branch of a match expression, following the pattern, to
further constrain the matching. This condition is introduced by the
(already existing) `if` keyword. For example:

`match { {list} if list != [] => body }` will match a record with a
unique field `list` that isn't empty.

The compilation is rather straightforward, as a pattern is already
compiled to a tree of if-then-else while also building the bindings
introduced by pattern variables. After all the conditions coming from
the pattern have been tested, we just additionally check for the guard
(injecting the bindings in the condition since the guard can - and most
often does - use variables bound by the pattern).

* Update core/src/term/mod.rs

Co-authored-by: jneem <joeneeman@gmail.com>

* Exclude guarded patterns from tag-only optimization

---------

Co-authored-by: jneem <joeneeman@gmail.com>
  • Loading branch information
yannham and jneem authored May 13, 2024
1 parent aca39ef commit 7225b72
Show file tree
Hide file tree
Showing 15 changed files with 298 additions and 92 deletions.
15 changes: 8 additions & 7 deletions core/src/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ use crate::{
make as mk_term,
pattern::compile::Compile,
record::{Field, RecordData},
BinaryOp, BindingType, LetAttrs, MatchData, RecordOpKind, RichTerm, RuntimeContract,
StrChunk, Term, UnaryOp,
BinaryOp, BindingType, LetAttrs, MatchBranch, MatchData, RecordOpKind, RichTerm,
RuntimeContract, StrChunk, Term, UnaryOp,
},
};

Expand Down Expand Up @@ -1191,11 +1191,12 @@ pub fn subst<C: Cache>(
Term::Match(data) => {
let branches = data.branches
.into_iter()
.map(|(pat, branch)| {
(
pat,
subst(cache, branch, initial_env, env),
)
.map(|MatchBranch { pattern, guard, body} | {
MatchBranch {
pattern,
guard: guard.map(|cond| subst(cache, cond, initial_env, env)),
body: subst(cache, body, initial_env, env),
}
})
.collect();

Expand Down
12 changes: 7 additions & 5 deletions core/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ Applicative: UniTerm = {
<op: BOpPre> <t1: AsTerm<Atom>> <t2: AsTerm<Atom>>
=> UniTerm::from(mk_term::op2(op, t1, t2)),
NOpPre<AsTerm<Atom>>,
"match" "{" <branches: (MatchCase ",")*> <last: MatchCase?> "}" => {
"match" "{" <branches: (MatchBranch ",")*> <last: MatchBranch?> "}" => {
let branches = branches
.into_iter()
.map(|(case, _comma)| case)
.map(|(branch, _comma)| branch)
.chain(last)
.collect();

Expand Down Expand Up @@ -876,9 +876,11 @@ UOp: UnaryOp = {
"enum_get_tag" => UnaryOp::EnumGetTag(),
}

MatchCase: (Pattern, RichTerm) = {
<pat: Pattern> "=>" <t: Term> => (pat, t),
};
PatternGuard: RichTerm = "if" <Term> => <>;

MatchBranch: MatchBranch =
<pattern: Pattern> <guard: PatternGuard?> "=>" <body: Term> =>
MatchBranch { pattern, guard, body};

// Infix operators by precedence levels. Lowest levels take precedence over
// highest ones.
Expand Down
44 changes: 30 additions & 14 deletions core/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,20 +884,12 @@ where
allocator,
allocator.line(),
allocator.intersperse(
data.branches
.iter()
.map(|(pat, t)| (pat.pretty(allocator), t))
.map(|(lhs, t)| docs![
allocator,
lhs,
allocator.space(),
"=>",
allocator.line(),
t,
","
]
.nest(2)),
allocator.line()
data.branches.iter().map(|branch| docs![
allocator,
branch.pretty(allocator),
","
]),
allocator.line(),
),
]
.nest(2)
Expand Down Expand Up @@ -1185,6 +1177,30 @@ where
}
}

impl<'a, D, A> Pretty<'a, D, A> for &MatchBranch
where
D: NickelAllocatorExt<'a, A>,
D::Doc: Clone,
A: Clone + 'a,
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
let guard = if let Some(guard) = &self.guard {
docs![allocator, allocator.line(), "if", allocator.space(), guard]
} else {
allocator.nil()
};

docs![
allocator,
&self.pattern,
guard,
allocator.space(),
"=>",
docs![allocator, allocator.line(), self.body.pretty(allocator),].nest(2),
]
}
}

/// Generate an implementation of `fmt::Display` for types that implement `Pretty`.
#[macro_export]
macro_rules! impl_display_from_pretty {
Expand Down
50 changes: 43 additions & 7 deletions core/src/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,12 +604,24 @@ impl fmt::Display for MergePriority {
}
}

/// A branch of a match expression.
#[derive(Debug, PartialEq, Clone)]
pub struct MatchBranch {
/// The pattern on the left hand side of `=>`.
pub pattern: Pattern,
/// A potential guard, which is an additional side-condition defined as `if cond`. The value
/// stored in this field is the boolean condition itself.
pub guard: Option<RichTerm>,
/// The body of the branch, on the right hand side of `=>`.
pub body: RichTerm,
}

/// Content of a match expression.
#[derive(Debug, PartialEq, Clone)]
pub struct MatchData {
/// Branches of the match expression, where the first component is the pattern on the left hand
/// side of `=>` and the second component is the body of the branch.
pub branches: Vec<(Pattern, RichTerm)>,
pub branches: Vec<MatchBranch>,
}

/// A type or a contract together with its corresponding label.
Expand Down Expand Up @@ -2024,11 +2036,26 @@ impl Traverse<RichTerm> for RichTerm {
Term::Match(data) => {
// The annotation on `map_res` use Result's corresponding trait to convert from
// Iterator<Result> to a Result<Iterator>
let branches: Result<Vec<(Pattern, RichTerm)>, E> = data
let branches: Result<Vec<MatchBranch>, E> = data
.branches
.into_iter()
// For the conversion to work, note that we need a Result<(Ident,RichTerm), E>
.map(|(pat, t)| t.traverse(f, order).map(|t_ok| (pat, t_ok)))
.map(
|MatchBranch {
pattern,
guard,
body,
}| {
let guard = guard.map(|cond| cond.traverse(f, order)).transpose()?;
let body = body.traverse(f, order)?;

Ok(MatchBranch {
pattern,
guard,
body,
})
},
)
.collect();

RichTerm::new(
Expand Down Expand Up @@ -2203,10 +2230,19 @@ impl Traverse<RichTerm> for RichTerm {
.or_else(|| field.traverse_ref(f, state))
})
}),
Term::Match(data) => data
.branches
.iter()
.find_map(|(_pat, t)| t.traverse_ref(f, state)),
Term::Match(data) => data.branches.iter().find_map(
|MatchBranch {
pattern: _,
guard,
body,
}| {
if let Some(cond) = guard.as_ref() {
cond.traverse_ref(f, state)?;
}

body.traverse_ref(f, state)
},
),
Term::Array(ts, _) => ts.iter().find_map(|t| t.traverse_ref(f, state)),
Term::OpN(_, ts) => ts.iter().find_map(|t| t.traverse_ref(f, state)),
Term::Annotated(annot, t) => t
Expand Down
86 changes: 61 additions & 25 deletions core/src/term/pattern/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ use super::*;
use crate::{
mk_app,
term::{
make, record::FieldMetadata, BinaryOp, MatchData, RecordExtKind, RecordOpKind, RichTerm,
Term, UnaryOp,
make, record::FieldMetadata, BinaryOp, MatchBranch, MatchData, RecordExtKind, RecordOpKind,
RichTerm, Term, UnaryOp,
},
};

Expand Down Expand Up @@ -668,23 +668,32 @@ impl Compile for MatchData {
// # this primop evaluates body with an environment extended with bindings_id
// %pattern_branch% body bindings_id
fn compile(mut self, value: RichTerm, pos: TermPos) -> RichTerm {
if self.branches.iter().all(|(pat, _)| {
if self.branches.iter().all(|branch| {
// While we could get something working even with a guard, it's a bit more work and
// there's no current incentive to do so (a guard on a tags-only match is arguably less
// common, as such patterns don't bind any variable). For the time being, we just
// exclude guards from the tags-only optimization.
matches!(
pat.data,
branch.pattern.data,
PatternData::Enum(EnumPattern { pattern: None, .. }) | PatternData::Wildcard
)
) && branch.guard.is_none()
}) {
let wildcard_pat = self
.branches
.iter()
.enumerate()
.find_map(|(idx, (pat, body))| {
if let PatternData::Wildcard = pat.data {
let wildcard_pat = self.branches.iter().enumerate().find_map(
|(
idx,
MatchBranch {
pattern,
guard,
body,
},
)| {
if matches!((&pattern.data, guard), (PatternData::Wildcard, None)) {
Some((idx, body.clone()))
} else {
None
}
});
},
);

// If we find a wildcard pattern, we record its index in order to discard all the
// patterns coming after the wildcard, because they are unreachable.
Expand All @@ -698,13 +707,19 @@ impl Compile for MatchData {
let tags_only = self
.branches
.into_iter()
.filter_map(|(pat, body)| {
if let PatternData::Enum(EnumPattern { tag, .. }) = pat.data {
Some((tag, body))
} else {
None
}
})
.filter_map(
|MatchBranch {
pattern,
guard: _,
body,
}| {
if let PatternData::Enum(EnumPattern { tag, .. }) = pattern.data {
Some((tag, body))
} else {
None
}
},
)
.collect();

return TagsOnlyMatch {
Expand All @@ -726,14 +741,14 @@ impl Compile for MatchData {

// The fold block:
//
// <for (pattern, body) in branches.rev()
// <for branch in branches.rev()
// - cont is the accumulator
// - initial accumulator is the default branch (or error if not default branch)
// >
// let init_bindings_id = {} in
// let bindings_id = <pattern.compile_part(value_id, init_bindings)> in
//
// if bindings_id == null then
// if bindings_id == null || !<guard> then
// cont
// else
// # this primop evaluates body with an environment extended with bindings_id
Expand All @@ -742,10 +757,31 @@ impl Compile for MatchData {
.branches
.into_iter()
.rev()
.fold(error_case, |cont, (pat, body)| {
.fold(error_case, |cont, branch| {
let init_bindings_id = LocIdent::fresh();
let bindings_id = LocIdent::fresh();

// inner if condition:
// bindings_id == null || !<guard>
let inner_if_cond = make::op2(BinaryOp::Eq(), Term::Var(bindings_id), Term::Null);
let inner_if_cond = if let Some(guard) = branch.guard {
// the guard must be evaluated in the same environment as the body of the
// branch, as it might use bindings introduced by the pattern. Since `||` is
// lazy in Nickel, we know that `bindings_id` is not null if the guard
// condition is ever evaluated.
let guard_cond = mk_app!(
make::op1(UnaryOp::PatternBranch(), Term::Var(bindings_id)),
guard
);

mk_app!(
make::op1(UnaryOp::BoolOr(), inner_if_cond),
make::op1(UnaryOp::BoolNot(), guard_cond)
)
} else {
inner_if_cond
};

// inner if block:
//
// if bindings_id == null then
Expand All @@ -754,11 +790,11 @@ impl Compile for MatchData {
// # this primop evaluates body with an environment extended with bindings_id
// %pattern_branch% bindings_id body
let inner = make::if_then_else(
make::op2(BinaryOp::Eq(), Term::Var(bindings_id), Term::Null),
inner_if_cond,
cont,
mk_app!(
make::op1(UnaryOp::PatternBranch(), Term::Var(bindings_id),),
body
branch.body
),
);

Expand All @@ -772,7 +808,7 @@ impl Compile for MatchData {
Term::Record(RecordData::empty()),
make::let_in(
bindings_id,
pat.compile_part(value_id, init_bindings_id),
branch.pattern.compile_part(value_id, init_bindings_id),
inner,
),
)
Expand Down
12 changes: 8 additions & 4 deletions core/src/transform/desugar_destructuring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use crate::{
identifier::LocIdent,
match_sharedterm,
term::{pattern::*, MatchData, RichTerm, Term},
term::{pattern::*, MatchBranch, MatchData, RichTerm, Term},
};

/// Entry point of the destructuring desugaring transformation.
Expand Down Expand Up @@ -48,16 +48,20 @@ pub fn desugar_fun(mut pat: Pattern, body: RichTerm) -> Term {
/// Desugar a destructuring let-binding.
///
/// A let-binding `let <pat> = bound in body` is desugared to `<bound> |> match { <pat> => body }`.
pub fn desugar_let(pat: Pattern, bound: RichTerm, body: RichTerm) -> Term {
pub fn desugar_let(pattern: Pattern, bound: RichTerm, body: RichTerm) -> Term {
// the position of the match expression is used during error reporting, so we try to provide a
// sensible one.
let match_expr_pos = pat.pos.fuse(bound.pos);
let match_expr_pos = pattern.pos.fuse(bound.pos);

// `(match { <pat> => <body> }) <bound>`
Term::App(
RichTerm::new(
Term::Match(MatchData {
branches: vec![(pat, body)],
branches: vec![MatchBranch {
pattern,
guard: None,
body,
}],
}),
match_expr_pos,
),
Expand Down
Loading

0 comments on commit 7225b72

Please sign in to comment.