Skip to content

Commit

Permalink
Enum tag destructuring (#1813)
Browse files Browse the repository at this point in the history
* Support enum tags in patterns

Until now, only enum variants were supported within patterns (applied
enums). This commit adds support for bare enum tags as well.

* Improve enum variants handling in stdlib

Make `typeof` returns `'Enum` for enum variants as well, and update
`std.contract.Equal` to properly handle enum tags and enum variants.

* Add tests for enum tag patterns

* Update manual's sample to make tests pass
  • Loading branch information
yannham authored Feb 13, 2024
1 parent b01f9b8 commit 0768f05
Show file tree
Hide file tree
Showing 13 changed files with 114 additions and 48 deletions.
2 changes: 1 addition & 1 deletion core/src/eval/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
Term::Num(_) => "Number",
Term::Bool(_) => "Bool",
Term::Str(_) => "String",
Term::Enum(_) => "Enum",
Term::Enum(_) | Term::EnumVariant { .. } => "Enum",
Term::Fun(..) | Term::Match { .. } => "Function",
Term::Array(..) => "Array",
Term::Record(..) | Term::RecRecord(..) => "Record",
Expand Down
21 changes: 13 additions & 8 deletions core/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ Pattern: Pattern = {
#[inline]
PatternData: PatternData = {
RecordPattern => PatternData::Record(<>),
EnumVariantPattern => PatternData::EnumVariant(<>),
EnumPattern => PatternData::Enum(<>),
Ident => PatternData::Any(<>),
};

Expand All @@ -560,13 +560,18 @@ RecordPattern: RecordPattern = {
},
};

EnumVariantPattern: EnumVariantPattern =
<start: @L> <tag: EnumTag> ".." "(" <pattern: Pattern> ")" <end: @R> =>
EnumVariantPattern {
tag,
pattern: Box::new(pattern),
span: mk_span(src_id, start, end),
};
EnumPattern: EnumPattern = {
<start: @L> <tag: EnumTag> <end: @R> => EnumPattern {
tag,
pattern: None,
span: mk_span(src_id, start, end),
},
<start: @L> <tag: EnumTag> ".." "(" <pattern: Pattern> ")" <end: @R> => EnumPattern {
tag,
pattern: Some(Box::new(pattern)),
span: mk_span(src_id, start, end),
},
};

// A binding `ident = <pattern>` inside a record pattern.
FieldPattern: FieldPattern = {
Expand Down
16 changes: 8 additions & 8 deletions core/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ use std::fmt;

use crate::identifier::LocIdent;
use crate::parser::lexer::KEYWORDS;
use crate::term::pattern::{
EnumVariantPattern, Pattern, PatternData, RecordPattern, RecordPatternTail,
};
use crate::term::pattern::{EnumPattern, Pattern, PatternData, RecordPattern, RecordPatternTail};
use crate::term::record::RecordData;
use crate::term::{
record::{Field, FieldMetadata},
Expand Down Expand Up @@ -574,12 +572,12 @@ where
match self {
PatternData::Any(id) => allocator.as_string(id),
PatternData::Record(rp) => rp.pretty(allocator),
PatternData::EnumVariant(evp) => evp.pretty(allocator),
PatternData::Enum(evp) => evp.pretty(allocator),
}
}
}

impl<'a, D, A> Pretty<'a, D, A> for &EnumVariantPattern
impl<'a, D, A> Pretty<'a, D, A> for &EnumPattern
where
D: NickelAllocatorExt<'a, A>,
D::Doc: Clone,
Expand All @@ -590,9 +588,11 @@ where
allocator,
"'",
ident_quoted(&self.tag),
"..(",
&*self.pattern,
")"
if let Some(ref arg_pat) = self.pattern {
docs![allocator, "..(", &**arg_pat, ")"]
} else {
allocator.nil()
}
]
}
}
Expand Down
25 changes: 17 additions & 8 deletions core/src/term/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ pub enum PatternData {
Any(LocIdent),
/// A record pattern as in `{ a = { b, c } }`
Record(RecordPattern),
/// An enum pattern as in `let 'Foo x = y in ...`
EnumVariant(EnumVariantPattern),
/// An enum pattern as in `'Foo x` or `'Foo`
Enum(EnumPattern),
}

/// A generic pattern, that can appear in a match expression (not yet implemented) or in a
Expand All @@ -39,11 +39,11 @@ pub struct Pattern {
pub span: RawSpan,
}

/// An enum variant pattern.
/// An enum pattern, including both an enum tag and an enum variant.
#[derive(Debug, PartialEq, Clone)]
pub struct EnumVariantPattern {
pub struct EnumPattern {
pub tag: LocIdent,
pub pattern: Box<Pattern>,
pub pattern: Option<Box<Pattern>>,
pub span: RawSpan,
}

Expand Down Expand Up @@ -197,7 +197,7 @@ impl ElaborateContract for PatternData {
match self {
PatternData::Any(_) => None,
PatternData::Record(pat) => pat.elaborate_contract(),
PatternData::EnumVariant(pat) => pat.elaborate_contract(),
PatternData::Enum(pat) => pat.elaborate_contract(),
}
}
}
Expand All @@ -208,12 +208,21 @@ impl ElaborateContract for Pattern {
}
}

impl ElaborateContract for EnumVariantPattern {
impl ElaborateContract for EnumPattern {
fn elaborate_contract(&self) -> Option<LabeledType> {
let pos = TermPos::Original(self.span);

// TODO[adts]: it would be better to simply build a type like `[| 'tag arg |]` or `[| 'tag
// |]` and to rely on its derived contract. However, for the time being, the contract
// derived from enum variants isn't implemented yet.
let contract = if self.pattern.is_some() {
mk_app!(internals::enum_variant(), Term::Enum(self.tag))
} else {
mk_app!(internals::stdlib_contract_equal(), Term::Enum(self.tag))
};

let typ = Type {
typ: TypeF::Flat(mk_app!(internals::enum_variant(), Term::Enum(self.tag))),
typ: TypeF::Flat(contract),
pos,
};

Expand Down
32 changes: 23 additions & 9 deletions core/src/transform/desugar_destructuring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
//! ```
use crate::identifier::LocIdent;
use crate::match_sharedterm;
use crate::mk_app;
use crate::term::pattern::*;
use crate::term::{
make::{op1, op2},
BinaryOp::DynRemove,
LetAttrs, RecordOpKind, RichTerm, Term, TypeAnnotation, UnaryOp,
make as mk_term, BinaryOp::DynRemove, LetAttrs, RecordOpKind, RichTerm, Term, TypeAnnotation,
UnaryOp,
};

/// Entry point of the destructuring desugaring transformation.
Expand Down Expand Up @@ -139,7 +139,7 @@ impl Desugar for PatternData {
// If the pattern is an unconstrained identifier, we just bind it to the value.
PatternData::Any(id) => Term::Let(id, destr, body, LetAttrs::default()),
PatternData::Record(pat) => pat.desugar(destr, body),
PatternData::EnumVariant(pat) => pat.desugar(destr, body),
PatternData::Enum(pat) => pat.desugar(destr, body),
}
}
}
Expand All @@ -149,7 +149,7 @@ impl Desugar for FieldPattern {
// destructured. We extract the field from `destr` and desugar the rest of the pattern against
// `destr.matched_id`.
fn desugar(self, destr: RichTerm, body: RichTerm) -> Term {
let extracted = op1(UnaryOp::StaticAccess(self.matched_id), destr.clone());
let extracted = mk_term::op1(UnaryOp::StaticAccess(self.matched_id), destr.clone());
self.pattern.desugar(extracted, body)
}
}
Expand Down Expand Up @@ -178,10 +178,24 @@ impl Desugar for RecordPattern {
}
}

impl Desugar for EnumVariantPattern {
impl Desugar for EnumPattern {
fn desugar(self, destr: RichTerm, body: RichTerm) -> Term {
let extracted = op1(UnaryOp::EnumUnwrapVariant(), destr.clone());
self.pattern.desugar(extracted, body)
if let Some(arg_pat) = self.pattern {
let extracted = mk_term::op1(UnaryOp::EnumUnwrapVariant(), destr.clone());
arg_pat.desugar(extracted, body)
}
// If the pattern doesn't bind any argument, it's transparent, and we just proceed with the
// body. However, because of lazyness, the associated contract will never be checked,
// because body doesn't depend on `destr`.
//
// For patterns that bind variables, it's reasonable to keep them lazy: that is, in `let
// 'Foo x = destr in body`, `destr` is checked to be an enum only when `x` is evaluated.
// However, for patterns that don't bind anything, the pattern is useless. Arguably, users
// would expect that writing `let 'Foo = x in body` would check that `x` is indeed equal to
// `'Foo`. We thus introduce an artificial dependency: it's exactly the role of `seq`
else {
mk_app!(mk_term::op1(UnaryOp::Seq(), destr), body).into()
}
}
}

Expand All @@ -197,7 +211,7 @@ fn bind_rest(pat: &RecordPattern, destr: RichTerm, body: RichTerm) -> RichTerm {
Term::Let(
capture_var,
pat.patterns.iter().fold(destr, |acc, field_pat| {
op2(
mk_term::op2(
DynRemove(RecordOpKind::default()),
Term::Str(field_pat.matched_id.ident().into()),
acc,
Expand Down
8 changes: 5 additions & 3 deletions core/src/transform/free_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ impl RemoveBindings for PatternData {
PatternData::Record(record_pat) => {
record_pat.remove_bindings(working_set);
}
PatternData::EnumVariant(enum_variant_pat) => {
PatternData::Enum(enum_variant_pat) => {
enum_variant_pat.remove_bindings(working_set);
}
}
Expand Down Expand Up @@ -278,8 +278,10 @@ impl RemoveBindings for RecordPattern {
}
}

impl RemoveBindings for EnumVariantPattern {
impl RemoveBindings for EnumPattern {
fn remove_bindings(&self, working_set: &mut HashSet<Ident>) {
self.pattern.remove_bindings(working_set);
if let Some(ref arg_pat) = self.pattern {
arg_pat.remove_bindings(working_set);
}
}
}
11 changes: 7 additions & 4 deletions core/src/typecheck/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl PatternTypes for PatternData {
PatternData::Record(record_pat) => Ok(UnifType::concrete(TypeF::Record(
record_pat.pattern_types_inj(bindings, state, ctxt, mode)?,
))),
PatternData::EnumVariant(enum_pat) => {
PatternData::Enum(enum_pat) => {
let row = enum_pat.pattern_types_inj(bindings, state, ctxt, mode)?;

// This represents the single-row, closed type `[| row |]`
Expand Down Expand Up @@ -227,7 +227,7 @@ impl PatternTypes for FieldPattern {
}
}

impl PatternTypes for EnumVariantPattern {
impl PatternTypes for EnumPattern {
type PatType = UnifEnumRow;

fn pattern_types_inj(
Expand All @@ -239,11 +239,14 @@ impl PatternTypes for EnumVariantPattern {
) -> Result<Self::PatType, TypecheckError> {
let typ_arg = self
.pattern
.pattern_types_inj(bindings, state, ctxt, mode)?;
.as_ref()
.map(|pat| pat.pattern_types_inj(bindings, state, ctxt, mode))
.transpose()?
.map(Box::new);

Ok(UnifEnumRow {
id: self.tag,
typ: Some(Box::new(typ_arg)),
typ: typ_arg,
})
}
}
17 changes: 16 additions & 1 deletion core/stdlib/std.ncl
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,22 @@
let constant_type = %typeof% constant in
let check_typeof_eq = fun ctr_label value =>
let value_type = %typeof% value in
if value_type == constant_type then
if value_type == constant_type && value_type == 'Enum then
if %enum_is_variant% value != %enum_is_variant% constant then
let enum_kind = fun x =>
if %enum_is_variant% x then
"enum variant"
else
"enum tag"
in

ctr_label
|> label.with_message "expected an %{enum_kind constant}, got an %{enum_kind value}"
|> label.append_note "`std.contract.Equal some_enum` requires that the checked value is equal to the enum `some_enum`, but they are different variants."
|> blame
else
value
else if value_type == constant_type then
value
else
ctr_label
Expand Down
4 changes: 4 additions & 0 deletions core/tests/integration/inputs/destructuring/adt_enum_tag.ncl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# test.type = 'pass'
let x = 'Foo in
let 'Foo = x in
true
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# test.type = 'error'
#
# [test.metadata]
# error = 'EvalError::BlameError'
let 'Foo = 'Bar in
true
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# test.type = 'error'
#
# [test.metadata]
# error = 'EvalError::BlameError'
let 'Foo = 'Foo..(5) in
true
6 changes: 3 additions & 3 deletions doc/manual/merging.md
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ argument), we do get a contract violation error:
{ foo | FooContract }
& { foo.required_field1 = "here" }
in
intermediate
& { foo.required_field2 = "here" }
|> std.deep_seq intermediate
Expand All @@ -621,9 +621,9 @@ error: missing definition for `required_field2`
8 │ & { foo.required_field1 = "here" }
│ ------------------------ in this record
┌─ <stdlib/std.ncl>:2997:18
┌─ <stdlib/std.ncl>:3012:18
2997 │ = fun x y => %deep_seq% x y,
3012 │ = fun x y => %deep_seq% x y,
│ ------------ accessed here
```

Expand Down
8 changes: 5 additions & 3 deletions lsp/nls/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl InjectBindings for PatternData {
PatternData::Record(record_pat) => {
record_pat.inject_bindings(bindings, path, parent_deco)
}
PatternData::EnumVariant(evariant_pat) => {
PatternData::Enum(evariant_pat) => {
evariant_pat.inject_bindings(bindings, path, parent_deco)
}
}
Expand Down Expand Up @@ -120,7 +120,7 @@ impl InjectBindings for FieldPattern {
}
}

impl InjectBindings for EnumVariantPattern {
impl InjectBindings for EnumPattern {
fn inject_bindings(
&self,
bindings: &mut Vec<(Vec<LocIdent>, LocIdent, Field)>,
Expand All @@ -129,6 +129,8 @@ impl InjectBindings for EnumVariantPattern {
) {
//TODO: I'm not sure we should just transparently forward to the variant's argument. Maybe
//we need a more complex notion of path here, that knows when we enter an enum variant?
self.pattern.inject_bindings(bindings, path, None);
if let Some(ref arg_pat) = self.pattern {
arg_pat.inject_bindings(bindings, path, None);
}
}
}

0 comments on commit 0768f05

Please sign in to comment.