Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wildcard patterns #1904

Merged
merged 3 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions core/src/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1189,9 +1189,6 @@ pub fn subst<C: Cache>(
RichTerm::new(Term::App(t1, t2), pos)
}
Term::Match(data) => {
let default =
data.default.map(|d| subst(cache, d, initial_env, env));

let branches = data.branches
.into_iter()
.map(|(pat, branch)| {
Expand All @@ -1202,7 +1199,7 @@ pub fn subst<C: Cache>(
})
.collect();

RichTerm::new(Term::Match(MatchData { branches, default}), pos)
RichTerm::new(Term::Match(MatchData { branches }), pos)
}
Term::Op1(op, t) => {
let t = subst(cache, t, initial_env, env);
Expand Down
30 changes: 4 additions & 26 deletions core/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -308,27 +308,13 @@ Applicative: UniTerm = {
=> UniTerm::from(mk_term::op2(op, t1, t2)),
NOpPre<AsTerm<Atom>>,
"match" "{" <branches: (MatchCase ",")*> <last: MatchCase?> "}" => {
let mut default = None;

let branches = branches
.into_iter()
.map(|(case, _comma)| case)
.chain(last)
.filter_map(|case| match case {
MatchCase::Normal(pat, branch) => Some((pat, branch)),
MatchCase::Default(default_branch) => {
default = Some(default_branch);
None
}
})
.collect();

UniTerm::from(
Term::Match(MatchData {
branches,
default,
})
)
UniTerm::from(Term::Match(MatchData { branches }))
}
};

Expand Down Expand Up @@ -578,6 +564,7 @@ PatternDataF<F>: PatternData = {
ConstantPattern => PatternData::Constant(<>),
EnumPatternF<F> => PatternData::Enum(<>),
Ident => PatternData::Any(<>),
"_" => PatternData::Wildcard,
};

// A general pattern.
Expand Down Expand Up @@ -888,17 +875,8 @@ UOp: UnaryOp = {
"enum_get_tag" => UnaryOp::EnumGetTag(),
}

// It might seem silly that a match case can always be the catch-all case
// `_ => <exp>`. It would be better to separate between a normal match case and
// a rule for the catch-call. However, it's then surprisingly annoying to
// express the rule for "match" such that it's both non-ambiguous and allow an
// optional trailing comma ",".
//
// In the end, it was simpler to just allow the catch-all case to appear
// anywhere, and then to raise an error in the action code of the "match" rule.
MatchCase: MatchCase = {
<pat: Pattern> "=>" <t: Term> => MatchCase::Normal(pat, t),
"_" "=>" <Term> => MatchCase::Default(<>),
MatchCase: (Pattern, RichTerm) = {
<pat: Pattern> "=>" <t: Term> => (pat, t),
};

// Infix operators by precedence levels. Lowest levels take precedence over
Expand Down
7 changes: 0 additions & 7 deletions core/src/parser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,6 @@ pub enum StringEndDelimiter {
Special,
}

/// Distinguish between a normal case `id => exp` and a default case `_ => exp`.
#[derive(Clone, Debug)]
pub enum MatchCase {
Normal(Pattern, RichTerm),
Default(RichTerm),
}

/// Left hand side of a record field declaration.
#[derive(Clone, Debug)]
pub enum FieldPathElem {
Expand Down
2 changes: 1 addition & 1 deletion core/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ where
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
match self {
PatternData::Wildcard => allocator.text("_"),
PatternData::Any(id) => allocator.as_string(id),
PatternData::Record(rp) => rp.pretty(allocator),
PatternData::Enum(evp) => evp.pretty(allocator),
Expand Down Expand Up @@ -886,7 +887,6 @@ where
data.branches
.iter()
.map(|(pat, t)| (pat.pretty(allocator), t))
.chain(data.default.iter().map(|d| (allocator.text("_"), d)))
.map(|(lhs, t)| docs![
allocator,
lhs,
Expand Down
7 changes: 1 addition & 6 deletions core/src/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,6 @@ pub struct MatchData {
/// Branches of the match expression, where the first component is the pattern on the left hand
/// side of `=>` and the second component is the body of the branch.
pub branches: Vec<(Pattern, RichTerm)>,
pub default: Option<RichTerm>,
}

/// A type or a contract together with its corresponding label.
Expand Down Expand Up @@ -2032,12 +2031,9 @@ impl Traverse<RichTerm> for RichTerm {
.map(|(pat, t)| t.traverse(f, order).map(|t_ok| (pat, t_ok)))
.collect();

let default = data.default.map(|t| t.traverse(f, order)).transpose()?;

RichTerm::new(
Term::Match(MatchData {
branches: branches?,
default,
}),
pos,
)
Expand Down Expand Up @@ -2210,8 +2206,7 @@ impl Traverse<RichTerm> for RichTerm {
Term::Match(data) => data
.branches
.iter()
.find_map(|(_pat, t)| t.traverse_ref(f, state))
.or_else(|| data.default.as_ref().and_then(|t| t.traverse_ref(f, state))),
.find_map(|(_pat, t)| t.traverse_ref(f, state)),
Term::Array(ts, _) => ts.iter().find_map(|t| t.traverse_ref(f, state)),
Term::OpN(_, ts) => ts.iter().find_map(|t| t.traverse_ref(f, state)),
Term::Annotated(annot, t) => t
Expand Down
118 changes: 68 additions & 50 deletions core/src/term/pattern/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ impl CompilePart for Pattern {
impl CompilePart for PatternData {
fn compile_part(&self, value_id: LocIdent, bindings_id: LocIdent) -> RichTerm {
match self {
PatternData::Wildcard => Term::Var(bindings_id).into(),
PatternData::Any(id) => {
// %record_insert% "<id>" value_id bindings_id
insert_binding(*id, value_id, bindings_id)
Expand Down Expand Up @@ -670,31 +671,48 @@ impl Compile for MatchData {
if self.branches.iter().all(|(pat, _)| {
matches!(
pat.data,
PatternData::Enum(EnumPattern { pattern: None, .. })
PatternData::Enum(EnumPattern { pattern: None, .. }) | PatternData::Wildcard
)
}) {
let tags_only = self.branches.into_iter().map(|(pat, body)| {
let PatternData::Enum(EnumPattern {tag, ..}) = pat.data else {
panic!("match compilation: just tested that all cases are enum tags, but found a non enum tag pattern");
};

(tag, body)
}).collect();
// We take the first wildcard pattern as the default case. In theory, we could discard
// all the tags coming after the default branch; in practice we expect any sane match
// expression to have the default branch at the end, so this isn't a very useful
// optimization to care about.
yannham marked this conversation as resolved.
Show resolved Hide resolved
let default = self.branches.iter().find_map(|(pat, body)| {
if let PatternData::Wildcard = pat.data {
Some(body.clone())
} else {
None
}
});

let tags_only = self
.branches
.into_iter()
.filter_map(|(pat, body)| {
if let PatternData::Enum(EnumPattern { tag, .. }) = pat.data {
Some((tag, body))
} else {
None
}
})
.collect();

return TagsOnlyMatch {
branches: tags_only,
default: self.default,
default,
}
.compile(value, pos);
}

let default_branch = self.default.unwrap_or_else(|| {
let error_case = RichTerm::new(
Term::RuntimeError(EvalError::NonExhaustiveMatch {
value: value.clone(),
pos,
})
.into()
});
}),
pos,
);

let value_id = LocIdent::fresh();

// The fold block:
Expand All @@ -711,45 +729,45 @@ impl Compile for MatchData {
// else
// # this primop evaluates body with an environment extended with bindings_id
// %pattern_branch% body bindings_id
let fold_block =
self.branches
.into_iter()
.rev()
.fold(default_branch, |cont, (pat, body)| {
let init_bindings_id = LocIdent::fresh();
let bindings_id = LocIdent::fresh();

// inner if block:
//
// if bindings_id == null then
// cont
// else
// # this primop evaluates body with an environment extended with bindings_id
// %pattern_branch% bindings_id body
let inner = make::if_then_else(
make::op2(BinaryOp::Eq(), Term::Var(bindings_id), Term::Null),
cont,
mk_app!(
make::op1(UnaryOp::PatternBranch(), Term::Var(bindings_id),),
body
),
);
let fold_block = self
.branches
.into_iter()
.rev()
.fold(error_case, |cont, (pat, body)| {
let init_bindings_id = LocIdent::fresh();
let bindings_id = LocIdent::fresh();

// inner if block:
//
// if bindings_id == null then
// cont
// else
// # this primop evaluates body with an environment extended with bindings_id
// %pattern_branch% bindings_id body
let inner = make::if_then_else(
make::op2(BinaryOp::Eq(), Term::Var(bindings_id), Term::Null),
cont,
mk_app!(
make::op1(UnaryOp::PatternBranch(), Term::Var(bindings_id),),
body
),
);

// The two initial chained let-bindings:
//
// let init_bindings_id = {} in
// let bindings_id = <pattern.compile_part(value_id, init_bindings)> in
// <inner>
// The two initial chained let-bindings:
//
// let init_bindings_id = {} in
// let bindings_id = <pattern.compile_part(value_id, init_bindings)> in
// <inner>
make::let_in(
init_bindings_id,
Term::Record(RecordData::empty()),
make::let_in(
init_bindings_id,
Term::Record(RecordData::empty()),
make::let_in(
bindings_id,
pat.compile_part(value_id, init_bindings_id),
inner,
),
)
});
bindings_id,
pat.compile_part(value_id, init_bindings_id),
inner,
),
)
});

// let value_id = value in <fold_block>
make::let_in(value_id, value, fold_block)
Expand Down
5 changes: 4 additions & 1 deletion core/src/term/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub mod compile;

#[derive(Debug, PartialEq, Clone)]
pub enum PatternData {
/// A wildcard pattern, matching any value. As opposed to any, this pattern doesn't bind any
/// variable.
Wildcard,
/// A simple pattern consisting of an identifier. Match anything and bind the result to the
/// corresponding identfier.
Any(LocIdent),
Expand Down Expand Up @@ -215,7 +218,7 @@ pub trait ElaborateContract {
impl ElaborateContract for PatternData {
fn elaborate_contract(&self) -> Option<LabeledType> {
match self {
PatternData::Any(_) => None,
PatternData::Wildcard | PatternData::Any(_) => None,
PatternData::Record(pat) => pat.elaborate_contract(),
PatternData::Enum(pat) => pat.elaborate_contract(),
PatternData::Constant(pat) => pat.elaborate_contract(),
Expand Down
1 change: 1 addition & 0 deletions core/src/transform/desugar_destructuring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ impl Desugar for Pattern {
impl Desugar for PatternData {
fn desugar(self, destr: RichTerm, body: RichTerm) -> Term {
match self {
PatternData::Wildcard => body.into(),
// If the pattern is an unconstrained identifier, we just bind it to the value.
PatternData::Any(id) => Term::Let(id, destr, body, LetAttrs::default()),
PatternData::Record(pat) => pat.desugar(destr, body),
Expand Down
8 changes: 2 additions & 6 deletions core/src/transform/free_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ impl CollectFreeVars for RichTerm {

free_vars.extend(fresh);
}

if let Some(default) = &mut data.default {
default.collect_free_vars(free_vars);
}
}
Term::Op1(_, t) | Term::Sealed(_, t, _) | Term::EnumVariant { arg: t, .. } => {
t.collect_free_vars(free_vars)
Expand Down Expand Up @@ -255,8 +251,8 @@ impl RemoveBindings for PatternData {
PatternData::Enum(enum_variant_pat) => {
enum_variant_pat.remove_bindings(working_set);
}
// A constant pattern doesn't bind any variable.
PatternData::Constant(_) => (),
// A wildcard pattern or a constant pattern doesn't bind any variable.
PatternData::Wildcard | PatternData::Constant(_) => (),
}
}
}
Expand Down
39 changes: 24 additions & 15 deletions core/src/typ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1030,26 +1030,35 @@ impl Subcontract for EnumRows {
}
}

let default = if let Some(var) = tail_var {
mk_app!(
mk_term::op2(
BinaryOp::ApplyContract(),
get_var_contract(&vars, var.ident(), var.pos)?,
mk_term::var(label_arg)
let (default, default_pos) = if let Some(var) = tail_var {
(
mk_app!(
mk_term::op2(
BinaryOp::ApplyContract(),
get_var_contract(&vars, var.ident(), var.pos)?,
mk_term::var(label_arg)
),
mk_term::var(value_arg)
),
mk_term::var(value_arg)
var.pos,
)
} else {
mk_app!(internals::enum_fail(), mk_term::var(label_arg))
(
mk_app!(internals::enum_fail(), mk_term::var(label_arg)),
TermPos::None,
)
};

let match_expr = mk_app!(
Term::Match(MatchData {
branches,
default: Some(default)
}),
mk_term::var(value_arg)
);
branches.push((
Pattern {
data: PatternData::Wildcard,
alias: None,
pos: default_pos,
},
default,
));

let match_expr = mk_app!(Term::Match(MatchData { branches }), mk_term::var(value_arg));

let case = mk_fun!(label_arg, value_arg, match_expr);
Ok(mk_app!(internals::enumeration(), case))
Expand Down
Loading
Loading