Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
yannham committed May 13, 2024
1 parent f3b3632 commit ed58cdb
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 2 deletions.
15 changes: 13 additions & 2 deletions core/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -729,8 +729,19 @@ ExtendedIdent: LocIdent = {
<Ident>,
};

// Parses a single "identifier" token into a `LocIdent` positioned at the token's span.
Ident: LocIdent = <l:@L> <i: "identifier"> <r:@R> =>
LocIdent::new_with_pos(i, mk_pos(src_id, l, r));
// Contextual keywords: tokens that must be lexed separately because they are
// keywords in some specific contexts, but aren't ambiguous otherwise, and can in
// particular be used as identifiers without any issue.
//
// Each alternative turns the keyword token back into the `Ident` it spells out.
ContextualKeyword: Ident = {
"or" => Ident::from("or"),
};

// An identifier together with its source position. Either a plain "identifier"
// token, or a contextual keyword used in identifier position.
Ident: LocIdent = {
    // Plain identifier token.
    <start: @L> <name: "identifier"> <end: @R> =>
        LocIdent::new_with_pos(name, mk_pos(src_id, start, end)),
    // Contextual keyword reused as an identifier.
    <start: @L> <keyword: ContextualKeyword> <end: @R> =>
        LocIdent::new_with_pos(keyword, mk_pos(src_id, start, end)),
};

Bool: bool = {
"true" => true,
Expand Down
4 changes: 4 additions & 0 deletions core/src/parser/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ pub enum NormalToken<'input> {
True,
#[token("false")]
False,
/// `or` isn't a reserved keyword. It is a contextual keyword: it acts as a keyword within
/// patterns, but can still be used as an identifier because doing so isn't ambiguous.
#[token("or")]
Or,

#[token("?")]
QuestionMark,
Expand Down
37 changes: 37 additions & 0 deletions core/src/term/pattern/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ impl CompilePart for PatternData {
PatternData::Array(pat) => pat.compile_part(value_id, bindings_id),
PatternData::Enum(pat) => pat.compile_part(value_id, bindings_id),
PatternData::Constant(pat) => pat.compile_part(value_id, bindings_id),
PatternData::OrPattern(pat) => pat.compile_part(value_id, bindings_id),
}
}
}
Expand Down Expand Up @@ -304,6 +305,42 @@ impl CompilePart for ConstantPatternData {
}
}

impl CompilePart for OrPattern {
    // Compilation of or patterns.
    //
    // The alternatives are chained with a fold: each step binds the result of the previous
    // alternatives to a fresh variable and only tries the current alternative if none of the
    // previous ones matched (that is, if the previous result is `null`).
    //
    // <fold pattern in patterns
    //  - cont is the accumulator
    //  - initial accumulator is `null`
    // >
    //
    // let prev_bindings = cont in
    //
    // # if one of the previous patterns already matched, we just stop here and return the
    // # resulting updated bindings. Otherwise, we try the current one
    // if prev_bindings != null then
    //   prev_bindings
    // else
    //   <pattern.compile(value_id, bindings_id)>
    // <end fold>
    fn compile_part(&self, value_id: LocIdent, bindings_id: LocIdent) -> RichTerm {
        self.patterns
            .iter()
            .fold(Term::Null.into(), |cont, pattern| {
                // Fresh name for the result of the alternatives compiled so far.
                let prev_bindings = LocIdent::fresh();

                // NOTE(review): written `BoolNot()` to mirror `BinaryOp::Eq()` right below;
                // confirm against the `UnaryOp` definition.
                let is_prev_not_null = make::op1(
                    UnaryOp::BoolNot(),
                    make::op2(BinaryOp::Eq(), Term::Var(prev_bindings), Term::Null),
                );

                // `prev_bindings` must actually be bound to the accumulator: without this
                // `let`, the generated term refers to an unbound variable and the previously
                // compiled alternatives are silently dropped.
                make::let_in(
                    prev_bindings,
                    cont,
                    make::if_then_else(
                        is_prev_not_null,
                        Term::Var(prev_bindings),
                        pattern.compile_part(value_id, bindings_id),
                    ),
                )
            })
    }
}

impl CompilePart for RecordPattern {
// Compilation of the top-level record pattern wrapper.
//
Expand Down
5 changes: 5 additions & 0 deletions core/src/term/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub enum PatternData {
Enum(EnumPattern),
/// A constant pattern as in `42` or `true`.
Constant(ConstantPattern),
/// A sequence of alternative patterns as in `'Foo _ or 'Bar _ or 'Baz _`.
OrPattern(OrPattern),
}

/// A generic pattern, that can appear in a match expression (not yet implemented) or in a
Expand Down Expand Up @@ -139,6 +141,9 @@ pub enum ConstantPatternData {
Null,
}

/// An or-pattern: a sequence of alternative patterns.
#[derive(Debug, PartialEq, Clone)]
pub struct OrPattern {
    /// The alternative patterns making up the disjunction.
    pub patterns: Vec<Pattern>,
}

/// The tail of a data structure pattern (record or array) which might capture the rest of said
/// data structure.
#[derive(Debug, PartialEq, Clone)]
Expand Down
81 changes: 81 additions & 0 deletions core/src/typecheck/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ impl PatternTypes for PatternData {
PatternData::Constant(constant_pat) => {
constant_pat.pattern_types_inj(pt_state, path, state, ctxt, mode)
}
PatternData::OrPattern(or_pat) => {
or_pat.pattern_types_inj(pt_state, path, state, ctxt, mode)
}
}
}
}
Expand Down Expand Up @@ -522,3 +525,81 @@ impl PatternTypes for EnumPattern {
})
}
}

impl PatternTypes for OrPattern {
    type PatType = UnifType;

    // Computes the type of an or-pattern and registers its bindings in `pt_state`.
    //
    // Each alternative is typechecked against a private list of bindings. The bindings of the
    // first alternative serve as the model: every other alternative must bind the same
    // variables and, in enforce mode, with types that unify with the model's. On success, the
    // model's bindings are appended to `pt_state.bindings` and the type of the first alternative
    // is returned. An empty or-pattern is given the type `forall a. a`.
    //
    // # Panics
    //
    // Mismatching sets of variables across alternatives currently hit a `todo!()`, because no
    // dedicated typecheck error is available yet.
    fn pattern_types_inj(
        &self,
        pt_state: &mut PatTypeState,
        path: PatternPath,
        state: &mut State,
        ctxt: &Context,
        mode: TypecheckMode,
    ) -> Result<Self::PatType, TypecheckError> {
        // When checking a sequence of disjuncts, we must combine their open tails and wildcard
        // pattern positions - in fact, when typechecking a whole match expression, this is exactly
        // what the typechecker is doing: it merges them. However, for bindings, we must ensure
        // that:
        //
        // 1. They are the same
        // 2. They have the same type
        //
        // To do so, we call to `pattern_types_inj` with a fresh vector of bindings, so that we
        // can post-process them afterward, before finally adding them to the original overall
        // bindings.
        let branches: Result<Vec<_>, TypecheckError> = self
            .patterns
            .iter()
            .map(|pat| -> Result<_, TypecheckError> {
                let mut fresh_bindings = Vec::new();

                // Open tails and wildcard paths are shared with the caller's state; only the
                // bindings are collected separately.
                let mut local_state = PatTypeState {
                    bindings: &mut fresh_bindings,
                    enum_open_tails: pt_state.enum_open_tails,
                    wildcard_pat_paths: pt_state.wildcard_pat_paths,
                };

                let typ =
                    pat.pattern_types_inj(&mut local_state, path.clone(), state, ctxt, mode)?;

                // Sort by identifier so that the bindings of two alternatives can be compared
                // position by position below.
                // NOTE(review): assumes `LocIdent: Ord` - confirm.
                fresh_bindings.sort_by_key(|(id, _typ)| *id);

                // NOTE(review): assumes `Pattern` exposes its position as a `pos` field; it is
                // only used for error reporting - confirm.
                Ok((typ, fresh_bindings, pat.pos))
            })
            .collect();

        let mut branches = branches?.into_iter();

        // The first alternative is the reference against which all the others are checked.
        let Some((model_typ, model_bindings, _model_pos)) = branches.next() else {
            // We should never generate empty `or` sequences (it's not possible in the source
            // language, at least). However, it doesn't cost much to support them: such a pattern
            // never matches anything. Thus, we return the bottom type encoded as `forall a. a`.
            let free_var = Ident::from("a");

            // NOTE(review): assumes the body of `TypeF::Forall` is boxed for `UnifType` -
            // confirm.
            return Ok(UnifType::concrete(TypeF::Forall {
                var: free_var.into(),
                var_kind: VarKind::Type,
                body: Box::new(UnifType::concrete(TypeF::Var(free_var))),
            }));
        };

        for (typ, pat_bindings, pos) in branches {
            if model_bindings.len() != pat_bindings.len() {
                // TODO: report a proper error (alternatives bind different sets of variables).
                return Err(todo!());
            }

            // All alternatives must match values of the same type. As for bindings, we must not
            // unify in walk mode.
            if let TypecheckMode::Enforce = mode {
                model_typ
                    .clone()
                    .unify(typ, state, ctxt)
                    .map_err(|e| e.into_typecheck_err(state, pos))?;
            }

            // Both lists are sorted and have the same length: compare them pairwise.
            for ((model_id, model_ty), (id, ty)) in model_bindings.iter().zip(pat_bindings) {
                if *model_id != id {
                    // TODO: report a proper error (alternatives bind different sets of
                    // variables).
                    return Err(todo!());
                }

                if let TypecheckMode::Enforce = mode {
                    // NOTE(review): assumes `LocIdent` exposes its position as a `pos` field -
                    // confirm.
                    model_ty
                        .clone()
                        .unify(ty, state, ctxt)
                        .map_err(|e| e.into_typecheck_err(state, id.pos))?;
                }
            }
        }

        // All alternatives agree on their variables: the bindings of the model stand for the
        // bindings of the whole or-pattern.
        pt_state.bindings.extend(model_bindings);

        Ok(model_typ)
    }
}

0 comments on commit ed58cdb

Please sign in to comment.