diff --git a/core/src/eval/mod.rs b/core/src/eval/mod.rs index b8f7bd6ba5..fec09cf7d0 100644 --- a/core/src/eval/mod.rs +++ b/core/src/eval/mod.rs @@ -90,8 +90,8 @@ use crate::{ make as mk_term, pattern::compile::Compile, record::{Field, RecordData}, - BinaryOp, BindingType, LetAttrs, MatchData, RecordOpKind, RichTerm, RuntimeContract, - StrChunk, Term, UnaryOp, + BinaryOp, BindingType, LetAttrs, MatchBranch, MatchData, RecordOpKind, RichTerm, + RuntimeContract, StrChunk, Term, UnaryOp, }, }; @@ -1191,11 +1191,12 @@ pub fn subst( Term::Match(data) => { let branches = data.branches .into_iter() - .map(|(pat, branch)| { - ( - pat, - subst(cache, branch, initial_env, env), - ) + .map(|MatchBranch { pattern, guard, body} | { + MatchBranch { + pattern, + guard: guard.map(|cond| subst(cache, cond, initial_env, env)), + body: subst(cache, body, initial_env, env), + } }) .collect(); diff --git a/core/src/parser/grammar.lalrpop b/core/src/parser/grammar.lalrpop index 9368a8931f..cf638910df 100644 --- a/core/src/parser/grammar.lalrpop +++ b/core/src/parser/grammar.lalrpop @@ -307,10 +307,10 @@ Applicative: UniTerm = { > > => UniTerm::from(mk_term::op2(op, t1, t2)), NOpPre>, - "match" "{" "}" => { + "match" "{" "}" => { let branches = branches .into_iter() - .map(|(case, _comma)| case) + .map(|(branch, _comma)| branch) .chain(last) .collect(); @@ -876,9 +876,11 @@ UOp: UnaryOp = { "enum_get_tag" => UnaryOp::EnumGetTag(), } -MatchCase: (Pattern, RichTerm) = { - "=>" => (pat, t), -}; +PatternGuard: RichTerm = "if" => <>; + +MatchBranch: MatchBranch = + "=>" => + MatchBranch { pattern, guard, body}; // Infix operators by precedence levels. Lowest levels take precedence over // highest ones. diff --git a/core/src/pretty.rs b/core/src/pretty.rs index 3bd179007e..ecdcc56ee5 100644 --- a/core/src/pretty.rs +++ b/core/src/pretty.rs @@ -884,20 +884,12 @@ where allocator, allocator.line(), allocator.intersperse( - data.branches - .iter() - .map(|(pat, t)| (pat.pretty(allocator), t)) - .map(|(lhs, t)| docs![ - allocator, - lhs, - allocator.space(), - "=>", - allocator.line(), - t, - "," - ] - .nest(2)), - allocator.line() + data.branches.iter().map(|branch| docs![ + allocator, + branch.pretty(allocator), + "," + ]), + allocator.line(), ), ] .nest(2) @@ -1185,6 +1177,30 @@ where } } +impl<'a, D, A> Pretty<'a, D, A> for &MatchBranch +where + D: NickelAllocatorExt<'a, A>, + D::Doc: Clone, + A: Clone + 'a, +{ + fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> { + let guard = if let Some(guard) = &self.guard { + docs![allocator, allocator.line(), "if", allocator.space(), guard] + } else { + allocator.nil() + }; + + docs![ + allocator, + &self.pattern, + guard, + allocator.space(), + "=>", + docs![allocator, allocator.line(), self.body.pretty(allocator),].nest(2), + ] + } +} + /// Generate an implementation of `fmt::Display` for types that implement `Pretty`. #[macro_export] macro_rules! impl_display_from_pretty { diff --git a/core/src/term/mod.rs b/core/src/term/mod.rs index d63496d4d1..2bce81ba21 100644 --- a/core/src/term/mod.rs +++ b/core/src/term/mod.rs @@ -604,12 +604,24 @@ impl fmt::Display for MergePriority { } } +/// A branch of a match expression. +#[derive(Debug, PartialEq, Clone)] +pub struct MatchBranch { + /// The pattern on the left hand side of `=>`. + pub pattern: Pattern, + /// A potential guard, which is an additional side-condition defined as `if cond`. The value + /// stored in this field is the boolean condition itself. + pub guard: Option, + /// The body of the branch, on the right hand side of `=>`. + pub body: RichTerm, +} + /// Content of a match expression. #[derive(Debug, PartialEq, Clone)] pub struct MatchData { /// Branches of the match expression, where the first component is the pattern on the left hand /// side of `=>` and the second component is the body of the branch. - pub branches: Vec<(Pattern, RichTerm)>, + pub branches: Vec, } /// A type or a contract together with its corresponding label. @@ -2024,11 +2036,26 @@ impl Traverse for RichTerm { Term::Match(data) => { // The annotation on `map_res` use Result's corresponding trait to convert from // Iterator to a Result - let branches: Result, E> = data + let branches: Result, E> = data .branches .into_iter() // For the conversion to work, note that we need a Result<(Ident,RichTerm), E> - .map(|(pat, t)| t.traverse(f, order).map(|t_ok| (pat, t_ok))) + .map( + |MatchBranch { + pattern, + guard, + body, + }| { + let guard = guard.map(|cond| cond.traverse(f, order)).transpose()?; + let body = body.traverse(f, order)?; + + Ok(MatchBranch { + pattern, + guard, + body, + }) + }, + ) .collect(); RichTerm::new( @@ -2203,10 +2230,19 @@ impl Traverse for RichTerm { .or_else(|| field.traverse_ref(f, state)) }) }), - Term::Match(data) => data - .branches - .iter() - .find_map(|(_pat, t)| t.traverse_ref(f, state)), + Term::Match(data) => data.branches.iter().find_map( + |MatchBranch { + pattern: _, + guard, + body, + }| { + if let Some(cond) = guard.as_ref() { + cond.traverse_ref(f, state)?; + } + + body.traverse_ref(f, state) + }, + ), Term::Array(ts, _) => ts.iter().find_map(|t| t.traverse_ref(f, state)), Term::OpN(_, ts) => ts.iter().find_map(|t| t.traverse_ref(f, state)), Term::Annotated(annot, t) => t diff --git a/core/src/term/pattern/compile.rs b/core/src/term/pattern/compile.rs index 16d1fe4df2..b1bb675720 100644 --- a/core/src/term/pattern/compile.rs +++ b/core/src/term/pattern/compile.rs @@ -26,8 +26,8 @@ use super::*; use crate::{ mk_app, term::{ - make, record::FieldMetadata, BinaryOp, MatchData, RecordExtKind, RecordOpKind, RichTerm, - Term, UnaryOp, + make, record::FieldMetadata, BinaryOp, MatchBranch, MatchData, RecordExtKind, RecordOpKind, + RichTerm, Term, UnaryOp, }, }; @@ -668,23 +668,32 @@ impl Compile for MatchData { // # this primop evaluates body with an environment extended with bindings_id // %pattern_branch% body bindings_id fn compile(mut self, value: RichTerm, pos: TermPos) -> RichTerm { - if self.branches.iter().all(|(pat, _)| { + if self.branches.iter().all(|branch| { + // While we could get something working even with a guard, it's a bit more work and + // there's no current incentive to do so (a guard on a tags-only match is arguably less + // common, as such patterns don't bind any variable). For the time being, we just + // exclude guards from the tags-only optimization. matches!( - pat.data, + branch.pattern.data, PatternData::Enum(EnumPattern { pattern: None, .. }) | PatternData::Wildcard - ) + ) && branch.guard.is_none() }) { - let wildcard_pat = self - .branches - .iter() - .enumerate() - .find_map(|(idx, (pat, body))| { - if let PatternData::Wildcard = pat.data { + let wildcard_pat = self.branches.iter().enumerate().find_map( + |( + idx, + MatchBranch { + pattern, + guard, + body, + }, + )| { + if matches!((&pattern.data, guard), (PatternData::Wildcard, None)) { Some((idx, body.clone())) } else { None } - }); + }, + ); // If we find a wildcard pattern, we record its index in order to discard all the // patterns coming after the wildcard, because they are unreachable. @@ -698,13 +707,19 @@ impl Compile for MatchData { let tags_only = self .branches .into_iter() - .filter_map(|(pat, body)| { - if let PatternData::Enum(EnumPattern { tag, .. }) = pat.data { - Some((tag, body)) - } else { - None - } - }) + .filter_map( + |MatchBranch { + pattern, + guard: _, + body, + }| { + if let PatternData::Enum(EnumPattern { tag, .. }) = pattern.data { + Some((tag, body)) + } else { + None + } + }, + ) .collect(); return TagsOnlyMatch { @@ -726,14 +741,14 @@ impl Compile for MatchData { // The fold block: // - // // let init_bindings_id = {} in // let bindings_id = in // - // if bindings_id == null then + // if bindings_id == null || ! then // cont // else // # this primop evaluates body with an environment extended with bindings_id @@ -742,10 +757,31 @@ impl Compile for MatchData { .branches .into_iter() .rev() - .fold(error_case, |cont, (pat, body)| { + .fold(error_case, |cont, branch| { let init_bindings_id = LocIdent::fresh(); let bindings_id = LocIdent::fresh(); + // inner if condition: + // bindings_id == null || ! + let inner_if_cond = make::op2(BinaryOp::Eq(), Term::Var(bindings_id), Term::Null); + let inner_if_cond = if let Some(guard) = branch.guard { + // the guard must be evaluated in the same environment as the body of the + // branch, as it might use bindings introduced by the pattern. Since `||` is + // lazy in Nickel, we know that `bindings_id` is not null if the guard + // condition is ever evaluated. + let guard_cond = mk_app!( + make::op1(UnaryOp::PatternBranch(), Term::Var(bindings_id)), + guard + ); + + mk_app!( + make::op1(UnaryOp::BoolOr(), inner_if_cond), + make::op1(UnaryOp::BoolNot(), guard_cond) + ) + } else { + inner_if_cond + }; + // inner if block: // // if bindings_id == null then @@ -754,11 +790,11 @@ impl Compile for MatchData { // # this primop evaluates body with an environment extended with bindings_id // %pattern_branch% bindings_id body let inner = make::if_then_else( - make::op2(BinaryOp::Eq(), Term::Var(bindings_id), Term::Null), + inner_if_cond, cont, mk_app!( make::op1(UnaryOp::PatternBranch(), Term::Var(bindings_id),), - body + branch.body ), ); @@ -772,7 +808,7 @@ impl Compile for MatchData { Term::Record(RecordData::empty()), make::let_in( bindings_id, - pat.compile_part(value_id, init_bindings_id), + branch.pattern.compile_part(value_id, init_bindings_id), inner, ), ) diff --git a/core/src/transform/desugar_destructuring.rs b/core/src/transform/desugar_destructuring.rs index 77059e6ef8..69639b2152 100644 --- a/core/src/transform/desugar_destructuring.rs +++ b/core/src/transform/desugar_destructuring.rs @@ -10,7 +10,7 @@ use crate::{ identifier::LocIdent, match_sharedterm, - term::{pattern::*, MatchData, RichTerm, Term}, + term::{pattern::*, MatchBranch, MatchData, RichTerm, Term}, }; /// Entry point of the destructuring desugaring transformation. @@ -48,16 +48,20 @@ pub fn desugar_fun(mut pat: Pattern, body: RichTerm) -> Term { /// Desugar a destructuring let-binding. /// /// A let-binding `let = bound in body` is desugared to ` |> match { => body }`. -pub fn desugar_let(pat: Pattern, bound: RichTerm, body: RichTerm) -> Term { +pub fn desugar_let(pattern: Pattern, bound: RichTerm, body: RichTerm) -> Term { // the position of the match expression is used during error reporting, so we try to provide a // sensible one. - let match_expr_pos = pat.pos.fuse(bound.pos); + let match_expr_pos = pattern.pos.fuse(bound.pos); // `(match { => }) ` Term::App( RichTerm::new( Term::Match(MatchData { - branches: vec![(pat, body)], + branches: vec![MatchBranch { + pattern, + guard: None, + body, + }], }), match_expr_pos, ), diff --git a/core/src/transform/free_vars.rs b/core/src/transform/free_vars.rs index e8f86c7d92..414a4f4dfe 100644 --- a/core/src/transform/free_vars.rs +++ b/core/src/transform/free_vars.rs @@ -8,7 +8,7 @@ use crate::{ term::pattern::*, term::{ record::{Field, FieldDeps, RecordDeps}, - IndexMap, RichTerm, SharedTerm, StrChunk, Term, + IndexMap, MatchBranch, RichTerm, SharedTerm, StrChunk, Term, }, typ::{RecordRowF, RecordRows, RecordRowsF, Type, TypeF}, }; @@ -87,11 +87,20 @@ impl CollectFreeVars for RichTerm { t2.collect_free_vars(free_vars); } Term::Match(data) => { - for (pat, branch) in data.branches.iter_mut() { + for MatchBranch { + pattern, + guard, + body, + } in data.branches.iter_mut() + { let mut fresh = HashSet::new(); - branch.collect_free_vars(&mut fresh); - pat.remove_bindings(&mut fresh); + if let Some(guard) = guard { + guard.collect_free_vars(&mut fresh); + } + body.collect_free_vars(&mut fresh); + + pattern.remove_bindings(&mut fresh); free_vars.extend(fresh); } diff --git a/core/src/typ.rs b/core/src/typ.rs index 10f9caade6..ae0f10693f 100644 --- a/core/src/typ.rs +++ b/core/src/typ.rs @@ -51,7 +51,7 @@ use crate::{ stdlib::internals, term::{ array::Array, make as mk_term, record::RecordData, string::NickelString, IndexMap, - MatchData, RichTerm, Term, Traverse, TraverseControl, TraverseOrder, + MatchBranch, MatchData, RichTerm, Term, Traverse, TraverseControl, TraverseOrder, }, }; @@ -1022,7 +1022,11 @@ impl Subcontract for EnumRows { pos: row.id.pos, }; - branches.push((pattern, body)); + branches.push(MatchBranch { + pattern, + guard: None, + body, + }); } EnumRowsIteratorItem::TailVar(var) => { tail_var = Some(var); @@ -1049,14 +1053,15 @@ impl Subcontract for EnumRows { ) }; - branches.push(( - Pattern { + branches.push(MatchBranch { + pattern: Pattern { data: PatternData::Wildcard, alias: None, pos: default_pos, }, - default, - )); + guard: None, + body: default, + }); let match_expr = mk_app!(Term::Match(MatchData { branches }), mk_term::var(value_arg)); diff --git a/core/src/typecheck/mod.rs b/core/src/typecheck/mod.rs index 3fa71b905c..1f9c97dc56 100644 --- a/core/src/typecheck/mod.rs +++ b/core/src/typecheck/mod.rs @@ -60,8 +60,8 @@ use crate::{ identifier::{Ident, LocIdent}, stdlib as nickel_stdlib, term::{ - pattern::Pattern, record::Field, LabeledType, RichTerm, StrChunk, Term, Traverse, - TraverseOrder, TypeAnnotation, + record::Field, LabeledType, MatchBranch, RichTerm, StrChunk, Term, Traverse, TraverseOrder, + TypeAnnotation, }, typ::*, {mk_uty_arrow, mk_uty_enum, mk_uty_record, mk_uty_record_row}, @@ -1534,11 +1534,11 @@ fn walk( walk(state, ctxt, visitor, t) } Term::Match(data) => { - data.branches.iter().try_for_each(|(pat, branch)| { + data.branches.iter().try_for_each(|MatchBranch { pattern, guard, body }| { let mut local_ctxt = ctxt.clone(); - let PatternTypeData { bindings: pat_bindings, .. } = pat.data.pattern_types(state, &ctxt, pattern::TypecheckMode::Walk)?; + let PatternTypeData { bindings: pat_bindings, .. } = pattern.data.pattern_types(state, &ctxt, pattern::TypecheckMode::Walk)?; - if let Some(alias) = &pat.alias { + if let Some(alias) = &pattern.alias { visitor.visit_ident(alias, mk_uniftype::dynamic()); local_ctxt.type_env.insert(alias.ident(), mk_uniftype::dynamic()); } @@ -1548,7 +1548,11 @@ fn walk( local_ctxt.type_env.insert(id.ident(), typ); } - walk(state, local_ctxt, visitor, branch) + if let Some(guard) = guard { + walk(state, local_ctxt.clone(), visitor, guard)?; + } + + walk(state, local_ctxt, visitor, body) })?; Ok(()) @@ -1995,18 +1999,21 @@ fn check( // introduced to open enum rows and close the corresponding rows at the end of the // procedure). - // We zip the pattern types with each case + // We zip the pattern types with each branch let with_pat_types = data .branches .iter() - .map(|(pat, branch)| -> Result<_, TypecheckError> { + .map(|branch| -> Result<_, TypecheckError> { Ok(( - pat, - pat.pattern_types(state, &ctxt, pattern::TypecheckMode::Enforce)?, branch, + branch.pattern.pattern_types( + state, + &ctxt, + pattern::TypecheckMode::Enforce, + )?, )) }) - .collect::, &RichTerm)>, _>>()?; + .collect::)>, _>>()?; // A match expression is a special kind of function. Thus it's typed as `a -> b`, where // `a` is a type determined by the patterns and `b` is the type of each match arm. @@ -2014,9 +2021,17 @@ fn check( let return_type = state.table.fresh_type_uvar(ctxt.var_level); // Express the constraint that all the arms of the match expression should have a - // compatible type. - for (pat, pat_types, arm) in with_pat_types.iter() { - if let Some(alias) = &pat.alias { + // compatible type and that each guard must be a boolean. + for ( + MatchBranch { + pattern, + guard, + body, + }, + pat_types, + ) in with_pat_types.iter() + { + if let Some(alias) = &pattern.alias { visitor.visit_ident(alias, return_type.clone()); ctxt.type_env.insert(alias.ident(), return_type.clone()); } @@ -2026,12 +2041,14 @@ fn check( ctxt.type_env.insert(id.ident(), typ.clone()); } - check(state, ctxt.clone(), visitor, arm, return_type.clone())?; + if let Some(guard) = guard { + check(state, ctxt.clone(), visitor, guard, mk_uniftype::bool())?; + } + + check(state, ctxt.clone(), visitor, body, return_type.clone())?; } - let pat_types = with_pat_types - .into_iter() - .map(|(_, pat_types, _)| pat_types); + let pat_types = with_pat_types.into_iter().map(|(_, pat_types)| pat_types); // Unify all the pattern types with the argument's type, and build the list of all open // tail vars @@ -2073,6 +2090,7 @@ fn check( // occurrences, we can finally close the tails that need to be. pattern::close_enums(enum_open_tails, &wildcard_occurrences, state); + // And finally fail if there was an error. pat_unif_result.map_err(|err| err.into_typecheck_err(state, rt.pos))?; // We unify the expected type of the match expression with `arg_type -> return_type`. @@ -2092,7 +2110,9 @@ fn check( // as desired. // // As a safety net, the tail closing code panics (in debug mode) if it finds a rigid - // type variable at the end of the tail of a pattern type. + // type variable at the end of the tail of a pattern type, which would happen if we + // somehow generalized an enum row type variable before properly closing the tails + // before. ty.unify( mk_uty_arrow!(arg_type.clone(), return_type.clone()), state, diff --git a/core/tests/integration/inputs/pattern-matching/guards.ncl b/core/tests/integration/inputs/pattern-matching/guards.ncl new file mode 100644 index 0000000000..f37b1bcc8b --- /dev/null +++ b/core/tests/integration/inputs/pattern-matching/guards.ncl @@ -0,0 +1,32 @@ +# test.type = 'pass' +let {check, ..} = import "../lib/assert.ncl" in + +[ + 'Ok |> match { + 'Ok if false => false, + 'Ok if true => true, + _ => false, + }, + + 'Some "true" |> match { + 'Some x if std.is_number x => false, + 'Some x if std.is_bool x => false, + 'Some x if std.is_string x => x == "true", + _ => false, + }, + + { + hello = ["hello"], + world=["world"] + } + |> match { + {hello, world} if std.array.length hello == 0 => false, + {hello, universe} if true => false, + {hello, world} + if (world |> (@) hello + |> std.string.join ", ") + == "hello, world" => true, + _ => false + } +] +|> check diff --git a/core/tests/integration/inputs/pattern-matching/non_bool_guard.ncl b/core/tests/integration/inputs/pattern-matching/non_bool_guard.ncl new file mode 100644 index 0000000000..3088c176cd --- /dev/null +++ b/core/tests/integration/inputs/pattern-matching/non_bool_guard.ncl @@ -0,0 +1,7 @@ +# test.type = 'error' +# +# [test.metadata] +# error = 'EvalError::UnaryPrimopTypeError' +{foo = 'Foo 5, bar = 5} |> match { + {foo = 'Foo x, bar} if x => x, +} diff --git a/core/tests/integration/inputs/typecheck/pattern_non_bool_guard.ncl b/core/tests/integration/inputs/typecheck/pattern_non_bool_guard.ncl new file mode 100644 index 0000000000..5696fccef2 --- /dev/null +++ b/core/tests/integration/inputs/typecheck/pattern_non_bool_guard.ncl @@ -0,0 +1,16 @@ +# test.type = 'error' +# eval = 'typecheck' +# +# [test.metadata] +# error = 'TypecheckError::TypeMismatch' +# +# [test.metadata.expectation] +# expected = 'Bool' +# inferred = 'Number' +( + {foo = 1, bar = 2} + |> match { + {foo, bar} if 1+1 => foo + bar, + _ => 0, + } +) : _ diff --git a/core/tests/integration/inputs/typecheck/pattern_unbound_identifier_guard.ncl b/core/tests/integration/inputs/typecheck/pattern_unbound_identifier_guard.ncl new file mode 100644 index 0000000000..ae23ccbe0b --- /dev/null +++ b/core/tests/integration/inputs/typecheck/pattern_unbound_identifier_guard.ncl @@ -0,0 +1,15 @@ +# test.type = 'error' +# eval = 'typecheck' +# +# [test.metadata] +# error = 'TypecheckError::UnboundIdentifier' +# +# [test.metadata.expectation] +# identifier = 'baz' +( + {foo = 1, bar = 2} + |> match { + {foo, bar} if foo+bar+baz == 0 => foo + bar, + _ => 0, + } +) : _ diff --git a/core/tests/integration/main.rs b/core/tests/integration/main.rs index 1f399eda53..ac45508396 100644 --- a/core/tests/integration/main.rs +++ b/core/tests/integration/main.rs @@ -149,6 +149,8 @@ enum ErrorExpectation { EvalEqError, #[serde(rename = "EvalError::Other")] EvalOther, + #[serde(rename = "EvalError::UnaryPrimopTypeError")] + EvalUnaryPrimopTypeError, #[serde(rename = "EvalError::NAryPrimopTypeError")] EvalNAryPrimopTypeError, #[serde(rename = "EvalError::BlameError")] @@ -228,6 +230,10 @@ impl PartialEq for ErrorExpectation { | (EvalTypeError, Error::EvalError(EvalError::TypeError(..))) | (EvalEqError, Error::EvalError(EvalError::EqError { .. })) | (EvalNAryPrimopTypeError, Error::EvalError(EvalError::NAryPrimopTypeError { .. })) + | ( + EvalUnaryPrimopTypeError, + Error::EvalError(EvalError::UnaryPrimopTypeError { .. }), + ) | (EvalInfiniteRecursion, Error::EvalError(EvalError::InfiniteRecursion(..))) | ( EvalMergeIncompatibleArgs, @@ -385,6 +391,7 @@ impl std::fmt::Display for ErrorExpectation { EvalOther => "EvalError::Other".to_owned(), EvalMergeIncompatibleArgs => "EvalError::MergeIncompatibleArgs".to_owned(), EvalNAryPrimopTypeError => "EvalError::NAryPrimopTypeError".to_owned(), + EvalUnaryPrimopTypeError => "EvalError::UnaryPrimopTypeError".to_owned(), EvalInfiniteRecursion => "EvalError::InfiniteRecursion".to_owned(), EvalIllegalPolymorphicTailAccess => { "EvalError::IllegalPolymorphicTailAccess".to_owned() diff --git a/lsp/nls/src/position.rs b/lsp/nls/src/position.rs index 87fc9021cc..a25796957d 100644 --- a/lsp/nls/src/position.rs +++ b/lsp/nls/src/position.rs @@ -135,7 +135,7 @@ impl PositionLookup { let ids = data .branches .iter() - .flat_map(|(pat, _branch)| pat.bindings().into_iter()) + .flat_map(|branch| branch.pattern.bindings().into_iter()) .map(|(_path, id, _)| id); idents.extend(ids); }