Skip to content

Commit

Permalink
Support pattern contracts in match statement (#1823)
Browse files Browse the repository at this point in the history
* Apply pattern contracts in match expressions

Match expressions have been extended to accept the full range of
patterns (that was reserved to destructuring before). Patterns allow, in
particular, to annotate fields with metadata, such as contract
annotations or default values. Those were, until now, simply ignored by
the new match expressions.

This commit modifies the pattern compilation scheme to handle contracts
and default values properly.

* Add tests for contracts in match patterns

* Fix clippy warning

* Renaming a couple test files
  • Loading branch information
yannham authored Feb 19, 2024
1 parent 4bc05f5 commit 0b294a2
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 50 deletions.
2 changes: 2 additions & 0 deletions core/src/eval/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,8 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
match_sharedterm!(match (t) {
Term::Record(data) => {
for (id, field) in data.fields {
debug_assert!(field.metadata.is_empty());

if let Some(value) = field.value {
match_sharedterm!(match (value.term) {
Term::Closure(idx) => {
Expand Down
15 changes: 11 additions & 4 deletions core/src/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,13 +712,20 @@ impl TypeAnnotation {
}

/// Convert all the contracts of this annotation, including the potential type annotation as
/// the first element, to a runtime representation.
/// the first element, to a runtime representation. Apply contract optimizations to the static
/// type annotation.
pub fn all_contracts(&self) -> Result<Vec<RuntimeContract>, UnboundTypeVariableError> {
self.typ
.iter()
.chain(self.contracts.iter())
.as_ref()
.cloned()
.map(RuntimeContract::try_from)
.map(RuntimeContract::from_static_type)
.into_iter()
.chain(
self.contracts
.iter()
.cloned()
.map(RuntimeContract::try_from),
)
.collect::<Result<Vec<_>, _>>()
}

Expand Down
188 changes: 142 additions & 46 deletions core/src/term/pattern/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,48 +165,68 @@ impl CompilePart for RecordPattern {
//
// if %typeof% value_id == 'Record
// let final_bindings_id =
// <fold (field, value) in fields
// - cont is the accumulator
// - initial accumulator is `%record_insert% "<REST_FIELD>" bindings_id value_id`
// >
// # if this field pattern has some extra annotations (contract, default value, etc.)
// # we let merging take care of it by assembling a one-field record and merging it
// # with the matched value
// #
// # If there is a default value, we must merge it BEFORE the %field_is_defined%,
// # because the default acts as if the original matched value always have this field
// # defined
// <if field.extra.value.is_some() >
// let value_id = value_id & { "<field>" = <field.extra> } in
// <end if>
//
// <fold (field, value) in fields
// - cont is the accumulator
// - initial accumulator is `%record_insert% "<REST_FIELD>" bindings_id value_id`
// >
// if %field_is_defined% field value_id then
// let local_bindings_id = cont in
// if %field_is_defined% field value_id then
// # However, if there are extra annotations without a default value, we must first
// # check that the field is defined before merging the extra annotations.
// <if !field.extra.metadata.is_empty() && field.extra.value.is_none() >
// let value_id = value_id & { "<field>" = <field.extra> } in
// <end if>
//
// if local_bindings_id == null then
// null
// else
// let local_value_id = %static_access(field)% (%static_access(REST_FIELD)% local_bindings_id)
// let local_bindings_id = <remove_from_rest(field, local_bindings_id)> in
// <field.compile_part(local_value_id, local_bindings_id)>
// else
// null
// <end fold>
// let local_bindings_id = cont in
//
// if local_bindings_id == null then
// null
// else
// let local_value_id = %static_access(field)% (%static_access(REST_FIELD)% local_bindings_id)
// let local_bindings_id = <remove_from_rest(field, local_bindings_id)> in
// <field.compile_part(local_value_id, local_bindings_id)>
// else
// null
// <end fold>
// in
//
// if final_bindings_id == null then
// null
// else
// <if self.tail is empty>
// # if tail is empty, check that the value doesn't contain extra fields
// if (%static_access% <REST_FIELD> final_bindings_id) != {} then
// null
// else
// %record_remove% "<REST>" final_bindings_id
// <else if self.tail is capture(rest)>
// # move the rest from REST_FIELD to rest, and remove REST_FIELD
// %record_remove% "<REST>"
// (%record_insert% <rest>
// final_bindings_id
// (%static_access% <REST_FIELD> final_bindings_id)
// )
// <else if self.tail is open>
// %record_remove% "<REST>" final_bindings_id
// <end if>
// if final_bindings_id == null then
// null
// else
// <if self.tail is empty>
// # if tail is empty, check that the value doesn't contain extra fields
// if (%static_access% <REST_FIELD> final_bindings_id) != {} then
// null
// else
// %record_remove% "<REST>" final_bindings_id
// <else if self.tail is capture(rest)>
// # move the rest from REST_FIELD to rest, and remove REST_FIELD
// %record_remove% "<REST>"
// (%record_insert% <rest>
// final_bindings_id
// (%static_access% <REST_FIELD> final_bindings_id)
// )
// <else if self.tail is open>
// %record_remove% "<REST>" final_bindings_id
// <end if>
// else
// null
fn compile_part(&self, value_id: LocIdent, bindings_id: LocIdent) -> RichTerm {
use crate::{
label::{MergeKind, MergeLabel},
term::IndexMap,
};

let rest_field = LocIdent::fresh();

// `%record_insert% "<REST>" bindings_id value_id`
Expand All @@ -226,14 +246,30 @@ impl CompilePart for RecordPattern {
// - initial accumulator is `%record_insert% "<REST>" bindings_id value_id`
// >
//
// # if this field pattern has some extra annotations (contract, default value, etc.)
// # we let merging take care of it by assembling a one-field record and merging it
// # with the matched value
// #
// # If there is a default value, we must merge it BEFORE the %field_is_defined%,
// # because the default acts like the original matched value always have this field
// # defined
// <if field.extra.value.is_some() >
// let value_id = value_id & { "<field>" = <field.extra> } in
// <end if>
//
// if %field_is_defined% field value_id then
// # However, if there are extra annotations without a default value, we must first
// # check that the field is defined before merging the extra annotations.
// <if !field.extra.metadata.is_empty() && field.extra.value.is_none() >
// let value_id = value_id & { "<field>" = <field.extra> } in
// <end if>
//
// let local_bindings_id = cont in
//
// if local_bindings_id == null then
// null
// else
// let local_value_id = %static_access(field)% (%static_access(REST_FIELD)% local_bindings_id)
// let local_value_id = %static_access(field)% (%static_access(REST_FIELD)% local_bindings_id) in
// let local_bindings_id = <remove_from_rest(field, local_bindings_id)> in
// <field.compile_part(local_value_id, local_bindings_id)>
let fold_block: RichTerm = self.patterns.iter().fold(init_bindings, |cont, field_pat| {
Expand All @@ -251,20 +287,19 @@ impl CompilePart for RecordPattern {
.compile_part(local_value_id, local_bindings_id),
);

// let value_id = %static_access(field)% (%static_access(REST_FIELD)% local_bindings_id)
// in <updated_bindings_let>
let inner_else_block = make::let_in(
local_value_id,
// %static_access(field)% (%static_access(REST_FIELD)% local_bindings_id)
let extracted_value = make::op1(
UnaryOp::StaticAccess(field),
make::op1(
UnaryOp::StaticAccess(field),
make::op1(
UnaryOp::StaticAccess(rest_field),
Term::Var(local_bindings_id),
),
UnaryOp::StaticAccess(rest_field),
Term::Var(local_bindings_id),
),
updated_bindings_let,
);

// let local_value_id = <extracted_value> in <updated_bindings_let>
let inner_else_block =
make::let_in(local_value_id, extracted_value, updated_bindings_let);

// The innermost if:
//
// if local_bindings_id == null then
Expand All @@ -280,14 +315,75 @@ impl CompilePart for RecordPattern {
// let local_bindings_id = cont in <value_let>
let binding_cont_let = make::let_in(local_bindings_id, cont, inner_if);

// let value_id = value_id & { "<field>" = <field.extra> } in <merge_cont>
let mk_merge = |id: LocIdent, field: &Field, merge_cont: RichTerm| {
let singleton = Term::Record(RecordData {
fields: IndexMap::from([(id, field.clone())]),
..Default::default()
});
// Right now, patterns are compiled on-the-fly during evaluation. We thus need to
// perform the gen_pending_contract transformation manually, or the contracts will
// just be ignored. One step suffices, as we create a singleton record that doesn't
// contain other non-transformed records (the default value, if any, has been
// transformed normally).
//
// unwrap(): typechecking ensures that there are no unbound variables at this point
let singleton =
crate::transform::gen_pending_contracts::transform_one(singleton.into())
.unwrap();

let span = field
.metadata
.annotation
.iter()
.map(|labeled_ty| labeled_ty.label.span)
.chain(field.value.as_ref().and_then(|v| v.pos.into_opt()))
// We fuse all the definite spans together.
// unwrap(): all span should come from the same file
// unwrap(): we hope that at least one position is defined
.reduce(|span1, span2| crate::position::RawSpan::fuse(span1, span2).unwrap())
.unwrap();

let merge_label = MergeLabel {
span,
kind: MergeKind::Standard,
};

make::let_in(
value_id,
make::op2(BinaryOp::Merge(merge_label), Term::Var(value_id), singleton),
merge_cont,
)
};

// <if !field.extra.metadata.is_empty() && field.extra.value.is_none() >
// <mk_merge ...>
// <end if>
let optional_merge =
if !field_pat.extra.metadata.is_empty() && field_pat.extra.value.is_none() {
mk_merge(field, &field_pat.extra, binding_cont_let)
} else {
binding_cont_let
};

// %field_is_defined% field value_id
let has_field = make::op2(
BinaryOp::FieldIsDefined(RecordOpKind::ConsiderAllFields),
Term::Str(field.label().into()),
Term::Var(value_id),
);

make::if_then_else(has_field, binding_cont_let, Term::Null)
// if <has_field> then <optional_merge> else null
let enclosing_if = make::if_then_else(has_field, optional_merge, Term::Null);

// <if field.extra.value.is_some() >
// <mk_merge ...>
// <end if>
if field_pat.extra.value.is_some() {
mk_merge(field, &field_pat.extra, enclosing_if)
} else {
enclosing_if
}
});

// %typeof% value_id == 'Record
Expand Down
8 changes: 8 additions & 0 deletions core/src/term/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ impl FieldMetadata {
pub fn new() -> Self {
Default::default()
}

pub fn is_empty(&self) -> bool {
self.doc.is_none()
&& self.annotation.is_empty()
&& !self.opt
&& !self.not_exported
&& matches!(self.priority, MergePriority::Neutral)
}
}

/// A record field with its metadata.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# test.type = 'error'
#
# [test.metadata]
# error = 'EvalError::BlameError'
{foo.bar = 5} |> match {
{foo={bar | String}} => bar,
}
9 changes: 9 additions & 0 deletions core/tests/integration/inputs/pattern-matching/contracts.ncl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# test.type = 'pass'

let {check, ..} = import "../lib/assert.ncl" in

[
{foo = {}} |> match { {foo = {bar ? 5}} => true},
{foo = {}} |> match { {foo = {bar | String }} => false, {foo} => true},
]
|> check
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# test.type = 'error'
#
# [test.metadata]
# error = 'EvalError::BlameError'
{foo.bar | default = []} |> match {
{foo={bar ? [1]}} => bar,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# test.type = 'error'
#
# [test.metadata]
# error = 'EvalError::MergeIncompatibleArgs'
{foo.bar | default = 5} |> match {
{foo={bar ? 6}} => bar,
}

0 comments on commit 0b294a2

Please sign in to comment.