diff --git a/core/src/error/mod.rs b/core/src/error/mod.rs index 70afaef9ed..0e77b8ecad 100644 --- a/core/src/error/mod.rs +++ b/core/src/error/mod.rs @@ -391,6 +391,17 @@ pub enum TypecheckError { /// The position of the expression that was being typechecked as `type_var`. pos: TermPos, }, + /// Invalid `or` pattern. + /// + /// This error is raised when the patterns composing an `or`-pattern don't have the precise + /// same set of free variables. For example, `'Foo x or 'Bar y`. + OrPatternVarsMismatch { + /// A variable which isn't present in all the other patterns (there might be more of them, + /// this is just a sample). + var: LocIdent, + /// The position of the whole or-pattern. + pos: TermPos, + }, } #[derive(Debug, PartialEq, Eq, Clone, Default)] @@ -2516,6 +2527,22 @@ impl IntoDiagnostics for TypecheckError { ), ])] } + TypecheckError::OrPatternVarsMismatch { var, pos } => { + let mut labels = vec![primary_alt(var.pos.into_opt(), var.into_label(), files) + .with_message("this variable must occur in all branches")]; + + if let Some(span) = pos.into_opt() { + labels.push(secondary(&span).with_message("in this `or` pattern")); + } + + vec![Diagnostic::error() + .with_message("`or` pattern variable mismatch".to_string()) + .with_labels(labels) + .with_notes(vec![ + "All branches of an `or` pattern must bind exactly the same set of variables" + .into(), + ])] + } } } } diff --git a/core/src/parser/grammar.lalrpop b/core/src/parser/grammar.lalrpop index b94ef7a68d..38247a91e2 100644 --- a/core/src/parser/grammar.lalrpop +++ b/core/src/parser/grammar.lalrpop @@ -520,18 +520,26 @@ FieldPathElem: FieldPathElem = { // A pattern. // -// The PatternF, PatternDataF and EnumPatternF rules are parametrized by a -// (string) flag (using LALRPOP's undocumented conditional macros). The idea is -// that those rules have two flavours: the most general one, which allow -// patterns to be unrestricted, and a version for function arguments. +// The PatternF, and in general several pattern rules ending with a capital `F`, +// are parametrized by other pattern rules. In general, depending on where those +// patterns and subpatterns occur, they need various restrictions to ensure that +// parsing is never ambiguous (at least with respect to the LALR(1)/LR(1) +// capabilities supported by LALRPOP). // -// The issue is the following: before the introduction of enum variants, -// functions have been allowed to match on several arguments using a sequence of -// patterns. For example, `fun {x} {y} z => x + y + z`. With variants, we've -// added the following pattern form: `'SomeTag argument`. Now, something like -// `fun 'SomeTag 'SomeArg => ...` is ambiguous: are we matching on a single -// argument that we expect to be `('SomeTag 'SomeArg)`, or on two separate -// arguments that are bare enum tags, as in `fun ('SomeTag) ('SomeArg)`? +// The various flavours of pattern rules and their respective motivation are +// detailed below. +// +// # Parentheses +// +// ## Enum variants +// +// Before the introduction of enum variants, functions have been allowed to +// match on several arguments using a sequence of patterns. For example, `fun +// {x} {y} z => x + y + z`. With variants, we've added the following pattern +// form: `'SomeTag argument`. Now, something like `fun 'SomeTag 'SomeArg => ...` +// is ambiguous: are we matching on a single argument that we expect to be +// `('SomeTag 'SomeArg)`, or on two separate arguments that are bare enum tags, +// as in `fun ('SomeTag) ('SomeArg)`? // // To avoid ambiguity, we force the top-level argument patterns of a function to // use a parenthesized version for enum variants. Thus `fun 'Foo 'Bar => ...` is @@ -545,12 +553,62 @@ FieldPathElem: FieldPathElem = { // readability. In practice, this means that the argument pattern of an enum // variant pattern has the same restriction as a function argument pattern. // -// The flavour parameter `F` can either be `"function"`, which is disabling the -// non-parenthesized enum variant rule, or any other string for the general -// flavour. In practice we use "". +// ## `or` patterns +// +// The same ambiguity (and solution) extends to `or` patterns, which are also +// ambiguous when used as function arguments. As we want `or` to remain a valid +// identifier, `fun x or y => ...` could be a function of 3 arguments `x`, `or` +// and `y`, or a function of one argument matching the pattern `(x or y)` (this +// isn't a valid or-pattern because the bound variables are different in each or +// branch, but that's beside the point - checking variables mismatches isn't the +// job of parsing rules). +// +// We thus reuse the exact same idea in order to force or-patterns used at the +// top-level of a function argument to be parenthesized. +// +// ## Or-patterns ambiguities +// +// Parsing `or`-patterns without ambiguity, while still allowing `or` to remain +// a valid identifier including within patterns (as in `'Foo or`) requires +// slightly more complicated constraints. +// +// The first issue is parentheses: how to parse `x or y or z`? Here we take a +// simple stance: patterns (enum variant patterns and `or`-patterns) within an +// `or`-branch must be parenthesized. So, this is parsed as a flat `or`-pattern +// `[x, y, z]`. +// +// The second, harder issue is that enum variant patterns might have a trailing +// `or` identifier. That is, when seeing `'Foo or`, the parser doesn't know if +// it should shift in the hope of seeing another pattern after the `or`, and +// parsing the overall result as an `or`-pattern `['Foo, +// or` form, where the argument is an `Any` pattern with identifier `or`, and +// everything else +// 2. Within an `or`-pattern branch, in the generic case, we restrict patterns +// so that they can't include the first form (`' or`) nor enum tag +// patterns (like `'Foo`). This way, whether in a `('Foo or)` enum variant +// pattern, or in a `'Foo or 'Bar` `or`-pattern, the `'Foo or` part is +// invariably parsed using the special rule `EnumVariantOrPattern`. +// 3. Then, in the `or`-pattern branch rule, we assemble `or` branches which are +// either `' or`, which is re-interpreted on the fly not as a enum +// variant pattern, but as an enum tag pattern followed by `or`, or the +// restricted generic form of pattern followed by an actual `or`. +// +// There are other minor details, but with enough variations of pattern rules, +// we can ensure there's only one way to parse each and every combination with +// only one look-ahead, thus satisfying the LR(1). #[inline] -PatternF: Pattern = { - "@")?> > => { +PatternF: Pattern = { + + > "@")?> + > + => { Pattern { alias, data, @@ -560,21 +618,50 @@ PatternF: Pattern = { }; #[inline] -PatternDataF: PatternData = { +PatternDataF: PatternData = { RecordPattern => PatternData::Record(<>), ArrayPattern => PatternData::Array(<>), ConstantPattern => PatternData::Constant(<>), - EnumPatternF => PatternData::Enum(<>), - Ident => PatternData::Any(<>), + EnumRule => PatternData::Enum(<>), + OrRule => PatternData::Or(<>), + IdentRule => PatternData::Any(<>), "_" => PatternData::Wildcard, }; -// A general pattern. +// A general pattern, unrestricted. #[inline] -Pattern: Pattern = PatternF<"">; +Pattern: Pattern = PatternF; -// A pattern restricted to function arguments. -PatternFun: Pattern = PatternF<"function">; +// A pattern restricted to function arguments, which requires `or`-patterns and +// enum variant patterns to be parenthesized at the top-level. +#[inline] +PatternFun: Pattern = PatternF; + +// A pattern that can be used within a branch of an `or`-pattern. To avoid a +// shift-reduce conflicts (because we want to allow `or` to remain a valid +// identifier, even inside patterns), this pattern has the following +// restrictions: +// +// 1. Enum tag patterns are forbidden (such as `'Foo or 'Bar`). +// 2. Enum variant patterns shouldn't have the "or" identifier as an argument. +// 3. Or-pattern must be parenthesized when nested in another or-pattern. +// 4. Aliases are forbidden at the top-level. Otherwise, we run into troubles +// with alias chains. Furthermore, the branches of an `or` pattern must have +// the same bound variables, so it usually makes more sense to alias the whole +// `or`-pattern instead of one specific branch. +// +// See the `PatternF` rule for an explanation of why we need those restrictions. +#[inline] +PatternOrBranch: Pattern = + + > + => { + Pattern { + alias: None, + data, + pos: mk_pos(src_id, left, right), + } + }; ConstantPattern: ConstantPattern = { => ConstantPattern { @@ -646,26 +733,146 @@ ArrayPattern: ArrayPattern = { }, }; -EnumPatternF: EnumPattern = { - => EnumPattern { - tag, - pattern: None, - pos: mk_pos(src_id, start, end), - }, - // See documentation of PatternF to see why we use the "function" variant - // here. - > if F != "function" => EnumPattern { +// A pattern for an enum tag (without argument). +EnumTagPattern: EnumPattern = => EnumPattern { + tag, + pattern: None, + pos: mk_pos(src_id, start, end), +}; + +// A rule which only matches an enum variant pattern of the form `' or`. +// Used to disambiguate between an enum variant pattern and an `or`-pattern. +EnumVariantOrPattern: EnumPattern = + + + > + => { + let pos_or = or_arg.pos; + + EnumPattern { + tag, + pattern: Some(Box::new(Pattern { + data: PatternData::Any(or_arg), + alias: None, + pos: pos_or, + })), + pos: mk_pos(src_id, start, end), + } + }; + +// An enum variant pattern, excluding the `EnumVariantPatternOr` case: that is, +// this rule doesn't match the case `' or`. +EnumVariantNoOrPattern: EnumPattern = + + + >> + => EnumPattern { tag, pattern: Some(Box::new(pattern)), pos: mk_pos(src_id, start, end), + }; + +// A pattern for an enum variant (with an argument). To avoid ambiguity, we need +// to decompose it into two disjoint rules, one that only match the `' or` +// input and everything else. +// +// The idea is that the former case can also serve as the prefix of an +// `or`-pattern, as in `'Foo or 'Bar`; but as long as we parse this common +// prefix using the same rule and only disambiguate later, there is no +// shift/reduce conflict. +EnumVariantPattern: EnumPattern = { + EnumVariantOrPattern, + EnumVariantNoOrPattern, +}; + +// A twisted version of EnumPattern made specifically for the branch of an +// `or`-pattern. As we parse `EnumVariantOrPattern` and treat it specifically in +// an `or` branch (`OrPatternBranch`), we need to remove it from the enum +// pattern rule. +EnumPatternOrBranch: EnumPattern = { + EnumVariantNoOrPattern, + // Only a top-level un-parenthesized enum variant pattern can be ambiguous. + // If it's parenthesized, we allow the general version including the "or" + // identifier + "(" ")", +}; + + +// An unparenthesized enum pattern (including both enum tags and enum +// variants). +EnumPatternUnparens: EnumPattern = { + EnumTagPattern, + EnumVariantPattern, +}; + +// A parenthesized enum pattern, including both tags and variants (note that an +// enum tag alone is never parenthesized: parentheses only applies to enum +// variant patterns). +EnumPatternParens: EnumPattern = { + EnumTagPattern, + "(" ")", +} + +// The unrestricted rule for enum patterns. Allows both enum tags and enum +// variants, and both parenthesized and un-parenthesized enum variants. +EnumPattern: EnumPattern = { + EnumTagPattern, + EnumVariantPattern, + "(" ")" +}; + +// An individual element of an or-pattern, plus a trailing "or". This rule is a +// bit artificial, and is essentially here to dispel the shift/reduce conflict +// around `'Foo or`/`'Foo or 'Bar` explained in the description of `PatternF`. +OrPatternBranch: Pattern = { + // To avoid various shift-reduce conflicts, the patterns used within an + // `or`-branch have several restrictions. See the `PatternOrBranch` rule. + "or", + // A variant pattern of the form `' or`. The trick is to instead + // consider it as the enum tag pattern `'` followed by the `or` + // contextual keyword after-the-fact. + => { + let pos = pat.pos; + + Pattern { + pos, + alias: None, + data: PatternData::Enum(EnumPattern { + tag: pat.tag, + pattern: None, + pos, + }), + } }, - "(" > ")" => EnumPattern { - tag, - pattern: Some(Box::new(pattern)), - pos: mk_pos(src_id, start, end), +}; + +// Unparenthesized `or`-pattern. +OrPatternUnparens: OrPattern = { + + + > + => { + let patterns = + patterns.into_iter().chain(std::iter::once(last)).collect(); + + OrPattern { + patterns, + pos: mk_pos(src_id, start, end), + } }, }; +// Parenthesized `or`-pattern. +OrPatternParens: OrPattern = { + "(" ")", +}; + +// Unrestricted `or`-pattern, which can be parenthesized or not. +OrPattern: OrPattern = { + OrPatternUnparens, + OrPatternParens, +} + // A binding `ident = ` inside a record pattern. FieldPattern: FieldPattern = { ?> @@ -725,12 +932,24 @@ MetadataKeyword: LocIdent = { // // Thus, for fields, ExtendedIdent is use in place of Ident. ExtendedIdent: LocIdent = { - >, - , + WithPos, + Ident, }; -Ident: LocIdent = => - LocIdent::new_with_pos(i, mk_pos(src_id, l, r)); +// The "or" keyword, parsed as an indent. +IdentOr: LocIdent = "or" => LocIdent::new("or"); + +// The set of pure identifiers, which are never keywords in any context. +RestrictedIdent: LocIdent = "identifier" => LocIdent::new(<>); + +// Identifiers allowed everywhere, which include pure identifiers and the "or" +// contextual keyword. With a bit of effort around pattern, we can make it a +// valid identifier unambiguously. +#[inline] +Ident: LocIdent = { + WithPos, + WithPos, +}; Bool: bool = { "true" => true, @@ -1255,6 +1474,7 @@ extern { "null" => Token::Normal(NormalToken::Null), "true" => Token::Normal(NormalToken::True), "false" => Token::Normal(NormalToken::False), + "or" => Token::Normal(NormalToken::Or), "?" => Token::Normal(NormalToken::QuestionMark), "," => Token::Normal(NormalToken::Comma), diff --git a/core/src/parser/lexer.rs b/core/src/parser/lexer.rs index a4602269c1..87548601bd 100644 --- a/core/src/parser/lexer.rs +++ b/core/src/parser/lexer.rs @@ -121,6 +121,10 @@ pub enum NormalToken<'input> { True, #[token("false")] False, + /// Or isn't a reserved keyword. It is a contextual keyword (a keyword that can be used as an + /// identifier because it's not ambiguous but) within patterns. + #[token("or")] + Or, #[token("?")] QuestionMark, diff --git a/core/src/pretty.rs b/core/src/pretty.rs index 5287eb99f3..0b2f9c75ef 100644 --- a/core/src/pretty.rs +++ b/core/src/pretty.rs @@ -400,19 +400,19 @@ where typ.pretty(self).parens_if(needs_parens_in_type_pos(typ)) } - /// Pretty printing of a restricted patterns that requires enum variant patterns to be - /// parenthesized (typically function pattern arguments). The only difference with a general - /// pattern is that for a function, a top-level enum variant pattern with an enum tag as an - /// argument such as `'Foo 'Bar` must be parenthesized, because `fun 'Foo 'Bar => ...` is - /// parsed as a function of two arguments, which are bare enum tags `'Foo` and `'Bar`. We must - /// print `fun ('Foo 'Bar) => ..` instead. + /// Pretty printing of a restricted patterns that requires enum variant patterns and or + /// patterns to be parenthesized (typically function pattern arguments). The only difference + /// with a general pattern is that for a function, a top-level enum variant pattern with an + /// enum tag as an argument such as `'Foo 'Bar` must be parenthesized, because `fun 'Foo 'Bar + /// => ...` is parsed as a function of two arguments, which are bare enum tags `'Foo` and + /// `'Bar`. We must print `fun ('Foo 'Bar) => ..` instead. fn pat_with_parens(&'a self, pattern: &Pattern) -> DocBuilder<'a, Self, A> { pattern.pretty(self).parens_if(matches!( pattern.data, PatternData::Enum(EnumPattern { pattern: Some(_), .. - }) + }) | PatternData::Or(_) )) } } @@ -591,6 +591,7 @@ where PatternData::Array(ap) => ap.pretty(allocator), PatternData::Enum(evp) => evp.pretty(allocator), PatternData::Constant(cp) => cp.pretty(allocator), + PatternData::Or(op) => op.pretty(allocator), } } } @@ -732,6 +733,26 @@ where } } +impl<'a, D, A> Pretty<'a, D, A> for &OrPattern +where + D: NickelAllocatorExt<'a, A>, + D::Doc: Clone, + A: Clone + 'a, +{ + fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> { + docs![ + allocator, + allocator.intersperse( + self.patterns + .iter() + .map(|pat| allocator.pat_with_parens(pat)), + docs![allocator, allocator.line(), "or", allocator.space()], + ), + ] + .group() + } +} + impl<'a, D, A> Pretty<'a, D, A> for &RichTerm where D: NickelAllocatorExt<'a, A>, diff --git a/core/src/term/pattern/compile.rs b/core/src/term/pattern/compile.rs index db987e0449..d8d15bed28 100644 --- a/core/src/term/pattern/compile.rs +++ b/core/src/term/pattern/compile.rs @@ -261,6 +261,7 @@ impl CompilePart for PatternData { PatternData::Array(pat) => pat.compile_part(value_id, bindings_id), PatternData::Enum(pat) => pat.compile_part(value_id, bindings_id), PatternData::Constant(pat) => pat.compile_part(value_id, bindings_id), + PatternData::Or(pat) => pat.compile_part(value_id, bindings_id), } } } @@ -304,6 +305,45 @@ impl CompilePart for ConstantPatternData { } } +impl CompilePart for OrPattern { + // Compilation of or patterns. + // + // + // + // let prev_bindings = cont in + // + // # if one of the previous patterns already matched, we just stop here and return the + // # resulting updated bindings. Otherwise, we try the current one + // if prev_bindings != null then + // prev_bindings + // else + // + // + fn compile_part(&self, value_id: LocIdent, bindings_id: LocIdent) -> RichTerm { + self.patterns + .iter() + .fold(Term::Null.into(), |cont, pattern| { + let prev_bindings = LocIdent::fresh(); + + let is_prev_not_null = make::op1( + UnaryOp::BoolNot(), + make::op2(BinaryOp::Eq(), Term::Var(prev_bindings), Term::Null), + ); + + let if_block = make::if_then_else( + is_prev_not_null, + Term::Var(prev_bindings), + pattern.compile_part(value_id, bindings_id), + ); + + make::let_in(prev_bindings, cont, if_block) + }) + } +} + impl CompilePart for RecordPattern { // Compilation of the top-level record pattern wrapper. // diff --git a/core/src/term/pattern/mod.rs b/core/src/term/pattern/mod.rs index 4d364344b0..52f5e9a05f 100644 --- a/core/src/term/pattern/mod.rs +++ b/core/src/term/pattern/mod.rs @@ -31,6 +31,8 @@ pub enum PatternData { Enum(EnumPattern), /// A constant pattern as in `42` or `true`. Constant(ConstantPattern), + /// A sequence alternative patterns as in `'Foo _ or 'Bar _ or 'Baz _`. + Or(OrPattern), } /// A generic pattern, that can appear in a match expression (not yet implemented) or in a @@ -139,6 +141,12 @@ pub enum ConstantPatternData { Null, } +#[derive(Debug, PartialEq, Clone)] +pub struct OrPattern { + pub patterns: Vec, + pub pos: TermPos, +} + /// The tail of a data structure pattern (record or array) which might capture the rest of said /// data structure. #[derive(Debug, PartialEq, Clone)] diff --git a/core/src/transform/free_vars.rs b/core/src/transform/free_vars.rs index f79d0732d3..283bf1c25e 100644 --- a/core/src/transform/free_vars.rs +++ b/core/src/transform/free_vars.rs @@ -254,15 +254,10 @@ impl RemoveBindings for PatternData { PatternData::Any(id) => { working_set.remove(&id.ident()); } - PatternData::Record(record_pat) => { - record_pat.remove_bindings(working_set); - } - PatternData::Array(array_pat) => { - array_pat.remove_bindings(working_set); - } - PatternData::Enum(enum_variant_pat) => { - enum_variant_pat.remove_bindings(working_set); - } + PatternData::Record(record_pat) => record_pat.remove_bindings(working_set), + PatternData::Array(array_pat) => array_pat.remove_bindings(working_set), + PatternData::Enum(enum_variant_pat) => enum_variant_pat.remove_bindings(working_set), + PatternData::Or(or_pat) => or_pat.remove_bindings(working_set), // A wildcard pattern or a constant pattern doesn't bind any variable. PatternData::Wildcard | PatternData::Constant(_) => (), } @@ -316,3 +311,18 @@ impl RemoveBindings for EnumPattern { } } } + +impl RemoveBindings for OrPattern { + fn remove_bindings(&self, working_set: &mut HashSet) { + // Theoretically, we could just remove the bindings of the first pattern, as all + // branch in an or patterns should bind exactly the same variables. However, at the + // time of writing, this condition isn't enforced at parsing type (it's enforced + // during typechecking). It doesn't cost much to be conservative and to remove all + // the bindings (removing something non-existent from a hashet is a no-op), so that + // we don't miss free variables in the case of ill-formed or-patterns, although we + // should ideally rule those out before reaching the free var transformation. + for pat in &self.patterns { + pat.remove_bindings(working_set); + } + } +} diff --git a/core/src/typecheck/pattern.rs b/core/src/typecheck/pattern.rs index 0d5d03ea2d..1b5b91c8c9 100644 --- a/core/src/typecheck/pattern.rs +++ b/core/src/typecheck/pattern.rs @@ -92,7 +92,6 @@ pub struct PatternTypeData { /// tails in [Self::enum_open_tails] should be left open. pub wildcard_occurrences: HashSet, } - /// Close all the enum row types left open when typechecking a match expression. Special case of /// `close_enums` for a single destructuring pattern (thus, where wildcard occurrences are not /// relevant). @@ -390,6 +389,7 @@ impl PatternTypes for PatternData { PatternData::Constant(constant_pat) => { constant_pat.pattern_types_inj(pt_state, path, state, ctxt, mode) } + PatternData::Or(or_pat) => or_pat.pattern_types_inj(pt_state, path, state, ctxt, mode), } } } @@ -522,3 +522,134 @@ impl PatternTypes for EnumPattern { }) } } + +impl PatternTypes for OrPattern { + type PatType = UnifType; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState, + path: PatternPath, + state: &mut State, + ctxt: &Context, + mode: TypecheckMode, + ) -> Result { + // When checking a sequence of or patterns, we must combine their open tails and wildcard + // pattern positions - in fact, when typechecking a whole match expression, this is exactly + // what the typechecker is doing: it merges all those data. And a match expression is, + // similarly to an or pattern, a disjunction of patterns. + // + // However, the treatment of bindings is different. If any of the branch in an or-pattern + // matches, the same code path (the match branch) will be run, and thus they must agree on + // pattern variables. Which means: + // + // 1. All pattern branches must have the same set of variables + // 2. Each variable must have a compatible type across all or-pattern branches + // + // To do so, we call to `pattern_types_inj` with a fresh vector of bindings, so that we can + // post-process them afterward (enforcing 1. and 2. above) before actually adding them to + // the original overall bindings. + // + // `bindings` stores, for each or pattern branch, the inferred type of the whole branch, + // the generated bindings and the position (the latter for error reporting). + let bindings: Result, _> = self + .patterns + .iter() + .map(|pat| -> Result<_, TypecheckError> { + let mut fresh_bindings = Vec::new(); + + let mut local_state = PatTypeState { + bindings: &mut fresh_bindings, + enum_open_tails: pt_state.enum_open_tails, + wildcard_pat_paths: pt_state.wildcard_pat_paths, + }; + + let typ = + pat.pattern_types_inj(&mut local_state, path.clone(), state, ctxt, mode)?; + + // We sort the bindings to check later that they are the same in all branches + fresh_bindings.sort_by_key(|(id, _typ)| *id); + + Ok((typ, fresh_bindings, pat.pos)) + }) + .collect(); + + let mut it = bindings?.into_iter(); + + // We need a reference set of variables (and their types for unification). We just pick the + // first bindings of the list. + let Some((model_typ, model, _pos)) = it.next() else { + // We should never generate empty `or` sequences (it's not possible to write them in + // the source language, at least). However, it doesn't cost much to support them: such + // a pattern never matches anything. Thus, we return the bottom type encoded as `forall + // a. a`. + let free_var = Ident::from("a"); + + return Ok(UnifType::concrete(TypeF::Forall { + var: free_var.into(), + var_kind: VarKind::Type, + body: Box::new(UnifType::concrete(TypeF::Var(free_var))), + })); + }; + + for (typ, pat_bindings, pos) in it { + if model.len() != pat_bindings.len() { + // We need to arbitrary choose a variable to report. We take the last one of the + // longest list, which is guaranteed to not be present in all branches + let witness = if model.len() > pat_bindings.len() { + // unwrap(): model.len() > pat_bindings.len() >= 0 + model.last().unwrap().0 + } else { + // unwrap(): model.len() <= pat_bindings.len() and (by the outer-if) + // pat_bindings.len() != mode.len(), so: + // 0 <= model.len() < pat_bindings.len() + pat_bindings.last().unwrap().0 + }; + + return Err(TypecheckError::OrPatternVarsMismatch { + var: witness, + pos: self.pos, + }); + } + + // We unify the type of the first or-branch with the current or-branch, to make sure + // all the subpatterns are matching values of the same type + if let TypecheckMode::Enforce = mode { + model_typ + .clone() + .unify(typ, state, ctxt) + .map_err(|e| e.into_typecheck_err(state, pos))?; + } + + // Finally, we unify the type of the bindings + for (idx, (id, typ)) in pat_bindings.into_iter().enumerate() { + let (model_id, model_ty) = &model[idx]; + + if *model_id != id { + // Once again, we must arbitrarily pick a variable to report. We take the + // smaller one, which is guaranteed to be missing (indeed, the greater one + // could still appear later in the other list, but the smaller is necessarily + // missing in the list with the greater one) + return Err(TypecheckError::OrPatternVarsMismatch { + var: std::cmp::min(*model_id, id), + pos: self.pos, + }); + } + + if let TypecheckMode::Enforce = mode { + model_ty + .clone() + .unify(typ, state, ctxt) + .map_err(|e| e.into_typecheck_err(state, id.pos))?; + } + } + } + + // Once we have checked that all the bound variables are the same and we have unified their + // types, we can add them to the overall bindings (since they are unified, it doesn't + // matter which type we use - so we just reuse the model, which is still around) + pt_state.bindings.extend(model); + + Ok(model_typ) + } +} diff --git a/core/tests/integration/inputs/pattern-matching/or_pattern_vars_mismatch.ncl b/core/tests/integration/inputs/pattern-matching/or_pattern_vars_mismatch.ncl new file mode 100644 index 0000000000..18dca0132e --- /dev/null +++ b/core/tests/integration/inputs/pattern-matching/or_pattern_vars_mismatch.ncl @@ -0,0 +1,10 @@ +# test.type = 'error' +# +# [test.metadata] +# error = 'TypecheckError::OrPatternVarsMismatch' +# +# [test.metadata.expectation] +# var = 'y' +{data = 'Foo 5} |> match { + {data = 'Foo x} or {field = y @ 'Bar x} => true, +} diff --git a/core/tests/integration/inputs/pattern-matching/or_patterns.ncl b/core/tests/integration/inputs/pattern-matching/or_patterns.ncl new file mode 100644 index 0000000000..af8545c9f0 --- /dev/null +++ b/core/tests/integration/inputs/pattern-matching/or_patterns.ncl @@ -0,0 +1,45 @@ +# test.type = 'pass' +let {check, ..} = import "../lib/assert.ncl" in + +[ + "a" |> match { + "e" or "f" or "g" => false, + "a" or "b" or "c" => true, + _ => false, + }, + + 'Foo (1+1) |> match { + ('Bar _) or ('Baz _) => false, + ('Qux x) or ('Foo x) => x == 2, + _ => false, + }, + + [1, {field = 'Foo 5}, 2] |> match { + [_, {field = 'Bar _} or {field = 'Baz _}, _] => false, + [_, {field = 'Bar _} or {field = 'Foo _}, _] => true, + _ => false, + }, + + {some = "data"} |> match { + x if std.is_number x || std.is_string x => false, + {..} or [..] => true, + _ => false, + }, + + {field = 'Marked} |> match { + {field = x} or {data = x} if x == 'Unmarked => false, + {data = x} or {field = x} if x == 'Marked => true, + _ => false, + }, + + 'Foo 1 |> match { + ('Foo or) or ('Baz or) => or == 1, + _ => false, + }, + + 'Baz |> match { + 'Foo or 'Bar or 'Baz or 'Qux => true, + _ => false, + }, +] +|> check diff --git a/core/tests/integration/inputs/typecheck/or_patterns.ncl b/core/tests/integration/inputs/typecheck/or_patterns.ncl new file mode 100644 index 0000000000..973c518290 --- /dev/null +++ b/core/tests/integration/inputs/typecheck/or_patterns.ncl @@ -0,0 +1,31 @@ +# test.type = 'pass' +let typecheck = [ + match { + ('Foo x) + or ('Bar x) + or ('Baz x) => null, + } : forall a. [| 'Foo a, 'Bar a, 'Baz a |] -> Dyn, + + # open enum rows when using wildcard in or-patterns + + match { + ('Some {foo = 'Bar 5, nested = 'One ('Two null)}) + or ('Some {foo = 'Baz "str", nested = 'One ('Three null)}) + or ('Some {foo = _, nested = 'One _}) => true, + _ => false, + } : forall r1 r2 r3. + [| 'Some { + foo: [| 'Bar Number, 'Baz String; r1 |], + nested: [| 'One [| 'Two Dyn, 'Three Dyn; r2 |] |] }; + r3 + |] -> Bool, + + match { + {foo, bar = x, baz = [y, ..rest]} + or {foo, bar = x @ rest, baz = [y]} + or {foo = y @ foo, bar = x, baz = [..rest]} => + null, + } : forall a. {foo: a, bar: Array a, baz: Array a} -> Dyn, +] in + +true diff --git a/core/tests/integration/inputs/typecheck/pattern_or_closed_enum.ncl b/core/tests/integration/inputs/typecheck/pattern_or_closed_enum.ncl new file mode 100644 index 0000000000..9e31e32858 --- /dev/null +++ b/core/tests/integration/inputs/typecheck/pattern_or_closed_enum.ncl @@ -0,0 +1,13 @@ +# test.type = 'error' +# +# [test.metadata] +# error = 'TypecheckError::ArrowTypeMismatch' +# +# [test.metadata.expectation.cause] +# error = 'TypecheckError::RecordRowMismatch' +match { + {field = 'Foo x} + or {field = 'Bar x} + or {field = 'Baz x} => + null, +}: forall a r. {field: [| 'Foo a, 'Bar a, 'Baz a; r |]} -> Dyn diff --git a/core/tests/integration/inputs/typecheck/pattern_or_type_mismatch.ncl b/core/tests/integration/inputs/typecheck/pattern_or_type_mismatch.ncl new file mode 100644 index 0000000000..6259e7b2df --- /dev/null +++ b/core/tests/integration/inputs/typecheck/pattern_or_type_mismatch.ncl @@ -0,0 +1,11 @@ +# test.type = 'error' +# +# [test.metadata] +# error = 'TypecheckError::TypeMismatch' +# +# [test.metadata.expectation] +# expected = 'Number' +# inferred = 'Bool' +match { + 1 or false => null, +}: _ diff --git a/core/tests/integration/main.rs b/core/tests/integration/main.rs index ac45508396..692fa7ddcb 100644 --- a/core/tests/integration/main.rs +++ b/core/tests/integration/main.rs @@ -204,6 +204,8 @@ enum ErrorExpectation { TypecheckFlatTypeInTermPosition, #[serde(rename = "TypecheckError::VarLevelMismatch")] TypecheckVarLevelMismatch { type_var: String }, + #[serde(rename = "TypecheckError::OrPatternVarsMismatch")] + TypecheckOrPatternVarsMismatch { var: String }, #[serde(rename = "ParseError")] AnyParseError, #[serde(rename = "ParseError::DuplicateIdentInRecordPattern")] @@ -355,6 +357,10 @@ impl PartialEq for ErrorExpectation { type_var: constant, .. }), ) => ident == constant.label(), + ( + TypecheckOrPatternVarsMismatch { var }, + Error::TypecheckError(TypecheckError::OrPatternVarsMismatch { var: id, .. }), + ) => var == id.label(), // The clone is not ideal, but currently we can't compare `TypecheckError` directly // with an ErrorExpectation. Ideally, we would implement `eq` for all error subtypes, // and have the eq with `Error` just dispatch to those sub-eq functions. @@ -436,12 +442,15 @@ impl std::fmt::Display for ErrorExpectation { TypecheckExtraDynTail => "TypecheckError::ExtraDynTail".to_owned(), TypecheckMissingDynTail => "TypecheckError::MissingDynTail".to_owned(), TypecheckArrowTypeMismatch { cause } => { - format!("TypecheckError::ArrowTypeMismatch{cause})") + format!("TypecheckError::ArrowTypeMismatch({cause})") } TypecheckFlatTypeInTermPosition => "TypecheckError::FlatTypeInTermPosition".to_owned(), TypecheckVarLevelMismatch { type_var } => { format!("TypecheckError::VarLevelMismatch({type_var})") } + TypecheckOrPatternVarsMismatch { var } => { + format!("TypecheckError::OrPatternVarsMismatch({var})") + } SerializeNumberOutOfRange => "ExportError::NumberOutOfRange".to_owned(), }; write!(f, "{}", name) diff --git a/lsp/nls/src/pattern.rs b/lsp/nls/src/pattern.rs index ed10ab2e6d..1bdf88fd90 100644 --- a/lsp/nls/src/pattern.rs +++ b/lsp/nls/src/pattern.rs @@ -90,6 +90,7 @@ impl InjectBindings for PatternData { PatternData::Enum(evariant_pat) => { evariant_pat.inject_bindings(bindings, path, parent_deco) } + PatternData::Or(or_pat) => or_pat.inject_bindings(bindings, path, parent_deco), // Wildcard and constant patterns don't bind any variable PatternData::Wildcard | PatternData::Constant(_) => (), } @@ -165,3 +166,16 @@ impl InjectBindings for EnumPattern { } } } + +impl InjectBindings for OrPattern { + fn inject_bindings( + &self, + bindings: &mut Vec<(Vec, LocIdent, Field)>, + path: Vec, + parent_extra: Option<&Field>, + ) { + for subpat in self.patterns.iter() { + subpat.inject_bindings(bindings, path.clone(), parent_extra); + } + } +}