Skip to content

Commit

Permalink
[Refactor] Move unify and variants into their own trait (#1469)
Browse files Browse the repository at this point in the history
* Structure `unify` and variants into Unify trait

* Fix comment documenting Unify/unify
  • Loading branch information
yannham authored Jul 21, 2023
1 parent aff367d commit 77ec9a0
Show file tree
Hide file tree
Showing 3 changed files with 400 additions and 366 deletions.
5 changes: 3 additions & 2 deletions core/src/typecheck/destructuring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
identifier::Ident,
mk_uty_row,
term::{IndexMap, LabeledType},
typecheck::{unify, UnifRecordRow},
typecheck::{UnifRecordRow, Unify},
types::{RecordRowF, RecordRowsF, TypeF},
};

Expand Down Expand Up @@ -114,7 +114,8 @@ fn build_pattern_type(
let types = annot_ty.types.clone();
let pos = types.pos;
let annot_ty = UnifType::from_type(types, &ctxt.term_env);
unify(state, ctxt, ty.clone(), annot_ty)
ty.clone()
.unify(annot_ty, state, ctxt)
.map_err(|e| e.into_typecheck_err(state, pos))?;
}

Expand Down
95 changes: 54 additions & 41 deletions core/src/typecheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1675,16 +1675,20 @@ fn check<L: Linearizer>(
Term::ParseError(_) => Ok(()),
Term::RuntimeError(_) => panic!("unexpected RuntimeError term during typechecking"),
// null is inferred to be of type Dyn
Term::Null => unify(state, &ctxt, ty, mk_uniftype::dynamic())
Term::Null => ty
.unify(mk_uniftype::dynamic(), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos)),
Term::Bool(_) => unify(state, &ctxt, ty, mk_uniftype::bool())
Term::Bool(_) => ty
.unify(mk_uniftype::bool(), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos)),
Term::Num(_) => unify(state, &ctxt, ty, mk_uniftype::num())
Term::Num(_) => ty
.unify(mk_uniftype::num(), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos)),
Term::Str(_) => unify(state, &ctxt, ty, mk_uniftype::str())
Term::Str(_) => ty
.unify(mk_uniftype::str(), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos)),
Term::StrChunks(chunks) => {
unify(state, &ctxt, ty, mk_uniftype::str())
ty.unify(mk_uniftype::str(), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))?;

chunks
Expand Down Expand Up @@ -1713,7 +1717,8 @@ fn check<L: Linearizer>(

linearizer.retype_ident(lin, x, src.clone());

unify(state, &ctxt, ty, arr).map_err(|err| err.into_typecheck_err(state, rt.pos))?;
ty.unify(arr, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))?;

ctxt.type_env.insert(*x, src);
check(state, ctxt, lin, linearizer, t, trg)
Expand All @@ -1730,13 +1735,14 @@ fn check<L: Linearizer>(
}

destructuring::inject_pattern_variables(state, &mut ctxt.type_env, pat, src_rows_ty);
unify(state, &ctxt, ty, arr).map_err(|err| err.into_typecheck_err(state, rt.pos))?;
ty.unify(arr, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))?;
check(state, ctxt, lin, linearizer, t, trg)
}
Term::Array(terms, _) => {
let ty_elts = state.table.fresh_type_uvar(ctxt.var_level);

unify(state, &ctxt, ty, mk_uniftype::array(ty_elts.clone()))
ty.unify(mk_uniftype::array(ty_elts.clone()), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))?;

terms
Expand All @@ -1754,7 +1760,7 @@ fn check<L: Linearizer>(
}
Term::Lbl(_) => {
// TODO implement lbl type
unify(state, &ctxt, ty, mk_uniftype::dynamic())
ty.unify(mk_uniftype::dynamic(), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))
}
Term::Let(x, re, rt, attrs) => {
Expand Down Expand Up @@ -1793,7 +1799,9 @@ fn check<L: Linearizer>(
// The inferred type of the expr being bound
let ty_let = binding_type(state, re.as_ref(), &ctxt, true);

unify(state, &ctxt, ty_let.clone(), pattern_type)
ty_let
.clone()
.unify(pattern_type, state, &ctxt)
.map_err(|e| e.into_typecheck_err(state, re.pos))?;

check(
Expand Down Expand Up @@ -1827,7 +1835,7 @@ fn check<L: Linearizer>(
let tgt = state.table.fresh_type_uvar(ctxt.var_level);
let arr = mk_uty_arrow!(src.clone(), tgt.clone());

unify(state, &ctxt, arr, function_type)
arr.unify(function_type, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, e.pos))?;

check(state, ctxt.clone(), lin, linearizer, t, src)?;
Expand All @@ -1845,11 +1853,10 @@ fn check<L: Linearizer>(
let return_type = state.table.fresh_type_uvar(ctxt.var_level);

// We unify the expected type of the match expression with `arg_type -> return_type`
unify(
ty.unify(
mk_uty_arrow!(arg_type.clone(), return_type.clone()),
state,
&ctxt,
ty,
mk_uty_arrow!(arg_type.clone(), return_type.clone()),
)
.map_err(|err| err.into_typecheck_err(state, rt.pos))?;

Expand Down Expand Up @@ -1877,7 +1884,8 @@ fn check<L: Linearizer>(
)?,
};

unify(state, &ctxt, arg_type, mk_uty_enum!(; erows))
arg_type
.unify(mk_uty_enum!(; erows), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))
}
Term::Var(x) => {
Expand All @@ -1888,12 +1896,12 @@ fn check<L: Linearizer>(
.ok_or(TypecheckError::UnboundIdentifier(*x, *pos))?;

let instantiated = instantiate_foralls(state, &mut ctxt, x_ty, ForallInst::Ptr);
unify(state, &ctxt, ty, instantiated)
ty.unify(instantiated, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))
}
Term::Enum(id) => {
let row = state.table.fresh_erows_uvar(ctxt.var_level);
unify(state, &ctxt, ty, mk_uty_enum!(*id; row))
ty.unify(mk_uty_enum!(*id; row), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))
}
// If some fields are defined dynamically, the only potential type that works is `{_ : a}`
Expand All @@ -1902,7 +1910,7 @@ fn check<L: Linearizer>(
// element type.
Term::RecRecord(record, dynamic, ..) if !dynamic.is_empty() => {
let ty_dict = state.table.fresh_type_uvar(ctxt.var_level);
unify(state, &ctxt, ty, mk_uniftype::dict(ty_dict.clone()))
ty.unify(mk_uniftype::dict(ty_dict.clone()), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))?;

//TODO: should we insert in the environment the checked type, or the actual type?
Expand Down Expand Up @@ -1981,24 +1989,24 @@ fn check<L: Linearizer>(
|acc, (id, row_ty)| mk_uty_row!((*id, row_ty.clone()); acc),
);

unify(state, &ctxt, ty, mk_uty_record!(; rows))
ty.unify(mk_uty_record!(; rows), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))?;

for (id, field) in record.fields.iter() {
if let Term::RecRecord(..) = t.as_ref() {
let affected_type = ctxt.type_env.get(id).cloned().unwrap();
unify(
state,
&ctxt,
field_types.get(id).cloned().unwrap(),
affected_type,
)
.map_err(|err| {
err.into_typecheck_err(
state,
field.value.as_ref().map(|v| v.pos).unwrap_or_default(),
)
})?;

field_types
.get(id)
.cloned()
.unwrap()
.unify(affected_type, state, &ctxt)
.map_err(|err| {
err.into_typecheck_err(
state,
field.value.as_ref().map(|v| v.pos).unwrap_or_default(),
)
})?;
}

check_field(
Expand Down Expand Up @@ -2052,7 +2060,8 @@ fn check<L: Linearizer>(
Term::OpN(op, args) => {
let (tys_op, ty_ret) = get_nop_type(state, ctxt.var_level, op)?;

unify(state, &ctxt, ty, ty_ret).map_err(|err| err.into_typecheck_err(state, rt.pos))?;
ty.unify(ty_ret, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))?;

tys_op.into_iter().zip(args.iter()).try_for_each(
|(ty_t, t)| -> Result<_, TypecheckError> {
Expand All @@ -2064,10 +2073,12 @@ fn check<L: Linearizer>(
Ok(())
}
Term::Annotated(annot, rt) => check_annotated(state, ctxt, lin, linearizer, annot, rt, ty),
Term::SealingKey(_) => unify(state, &ctxt, ty, mk_uniftype::sym())
Term::SealingKey(_) => ty
.unify(mk_uniftype::sym(), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos)),
Term::Sealed(_, t, _) => check(state, ctxt, lin, linearizer, t, ty),
Term::Import(_) => unify(state, &ctxt, ty, mk_uniftype::dynamic())
Term::Import(_) => ty
.unify(mk_uniftype::dynamic(), state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos)),
// We use the apparent type of the import for checking. This function doesn't recursively
// typecheck imports: this is the responsibility of the caller.
Expand All @@ -2080,7 +2091,8 @@ fn check<L: Linearizer>(
apparent_type(t.as_ref(), Some(&ctxt.type_env), Some(state.resolver)),
&ctxt.term_env,
);
unify(state, &ctxt, ty, ty_import).map_err(|err| err.into_typecheck_err(state, rt.pos))
ty.unify(ty_import, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))
}
Term::Types(typ) => {
if let Some(flat) = typ.find_flat() {
Expand All @@ -2103,7 +2115,7 @@ pub fn subsumption(
inferred: UnifType,
checked: UnifType,
) -> Result<(), UnifError> {
unify(state, ctxt, checked, inferred)
checked.unify(inferred, state, ctxt)
}

fn check_field<L: Linearizer>(
Expand Down Expand Up @@ -2173,7 +2185,8 @@ fn check_with_annot<L: Linearizer>(
let instantiated =
instantiate_foralls(state, &mut ctxt, uty2.clone(), ForallInst::Constant);

unify(state, &ctxt, uty2, ty).map_err(|err| err.into_typecheck_err(state, pos))?;
uty2.unify(ty, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, pos))?;
check(state, ctxt, lin, linearizer, value, instantiated)
}
// A annotation without a type but with a contract switches the typechecker back to walk
Expand All @@ -2190,11 +2203,10 @@ fn check_with_annot<L: Linearizer>(
let ctr = contracts.get(0).unwrap();
let LabeledType { types: ty2, .. } = ctr;

unify(
ty.unify(
UnifType::from_type(ty2.clone(), &ctxt.term_env),
state,
&ctxt,
ty,
UnifType::from_type(ty2.clone(), &ctxt.term_env),
)
.map_err(|err| err.into_typecheck_err(state, pos))?;

Expand Down Expand Up @@ -2223,7 +2235,8 @@ fn check_with_annot<L: Linearizer>(
.first()
.map(|labeled_ty| UnifType::from_type(labeled_ty.types.clone(), &ctxt.term_env))
.unwrap_or_else(mk_uniftype::dynamic);
unify(state, &ctxt, ty, inferred).map_err(|err| err.into_typecheck_err(state, pos))
ty.unify(inferred, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, pos))
}
}
}
Expand Down
Loading

0 comments on commit 77ec9a0

Please sign in to comment.