Skip to content

Commit

Permalink
generalize methods
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinbarabash committed Oct 12, 2023
1 parent da69b61 commit 87f6a55
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 25 deletions.
104 changes: 85 additions & 19 deletions crates/escalier_hm/src/infer_class.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use generational_arena::Index;
use generational_arena::{Arena, Index};
use itertools::Itertools;

use escalier_ast::{self as syntax, *};

use crate::ast_utils::{find_returns, find_throws};
use crate::checker::Checker;
use crate::context::*;
use crate::infer::generalize_func;
use crate::infer_pattern::pattern_to_tpat;
use crate::key_value_store::KeyValueStore;
use crate::type_error::TypeError;
use crate::types::{self, *};
use crate::visitor::{walk_index, Visitor};

impl Checker {
pub fn infer_class(
Expand All @@ -18,6 +21,8 @@ impl Checker {
) -> Result<Index, TypeError> {
let mut cls_ctx = ctx.clone();

// TODO: mutate the instance_scheme since only the methods need
// further type checking.
// TODO: unify _static_type with the static type of the class
let (instance_scheme, _static_type) = self.infer_class_interface(class, &mut cls_ctx)?;

Expand Down Expand Up @@ -189,16 +194,59 @@ impl Checker {
}
ClassMember::Getter(_) => todo!(),
ClassMember::Setter(_) => todo!(),
ClassMember::Field(_) => {
// If there's an initializer, infer its type and then
// unify with the type annotation of the field.
ClassMember::Field(Field {
span: _,
name,
is_public: _, // TODO
is_static,
type_ann,
init: _, // TODO: unify in `infer_class`
// If there's an initializer, infer its type and then
// unify with the type annotation of the field.
}) => {
let mut sig_ctx = cls_ctx.clone();

let type_ann_t = match type_ann {
Some(type_ann) => self.infer_type_ann(type_ann, &mut sig_ctx)?,
None => self.new_type_var(None),

Check warning on line 211 in crates/escalier_hm/src/infer_class.rs

View check run for this annotation

Codecov / codecov/patch

crates/escalier_hm/src/infer_class.rs#L211

Added line #L211 was not covered by tests
};

let field = TObjElem::Prop(TProp {
name: TPropKey::StringKey(name.name.to_owned()),
t: type_ann_t,
optional: false, // TODO
readonly: false, // TODO
});

match is_static {
true => static_elems.push(field),

Check warning on line 222 in crates/escalier_hm/src/infer_class.rs

View check run for this annotation

Codecov / codecov/patch

crates/escalier_hm/src/infer_class.rs#L222

Added line #L222 was not covered by tests
false => instance_elems.push(field),
};
}
}
}

// We generalize methods after all of them have been inferred so
// that mutually recursive method calls can be handled correctly.
for elem in instance_elems.iter_mut() {
if let TObjElem::Method(method) = elem {
let func = generalize_func(self, &method.function);
method.function = func;
}
}

let instance_type = self.new_object_type(&instance_elems);
let static_type = self.new_object_type(&static_elems);

let self_scheme = Scheme {
type_params: None,
t: instance_type,
is_type_param: false,
};

replace_self_type_refs(&mut self.arena, &instance_type, &self_scheme);
replace_self_type_refs(&mut self.arena, &static_type, &self_scheme);

self.unify(&cls_ctx, instance_scheme.t, instance_type)?;

Ok(static_type)
Expand Down Expand Up @@ -388,21 +436,6 @@ impl Checker {
}
}

// TODO: iterate over instance_elems and generalize all methods
for elem in instance_elems.iter_mut() {
// if let TObjElem::Method(method) = elem {
// let func = generalize_func(self, &method.function);
// method.function = func;
// }

// let pruned_index = self.prune(binding.index);
// if let TypeKind::Function(func) = &self.arena[pruned_index].kind.clone() {
// let func = generalize_func(self, func);
// let gen_func_index = self.arena.insert(Type::from(TypeKind::Function(func)));
// self.bind(ctx, binding.index, gen_func_index)?;
// }
}

let instance_scheme = Scheme {
t: self.new_object_type(&instance_elems),
// TODO: add type params
Expand Down Expand Up @@ -444,3 +477,36 @@ impl Checker {
})
}
}

pub struct ReplaceVisitor<'a> {
pub arena: &'a mut Arena<Type>,
pub scheme: &'a Scheme,
}

impl<'a> KeyValueStore<Index, Type> for ReplaceVisitor<'a> {
fn get_type(&mut self, idx: &Index) -> Type {
self.arena[*idx].clone()
}
fn put_type(&mut self, t: Type) -> Index {
self.arena.insert(t)

Check warning on line 491 in crates/escalier_hm/src/infer_class.rs

View check run for this annotation

Codecov / codecov/patch

crates/escalier_hm/src/infer_class.rs#L490-L491

Added lines #L490 - L491 were not covered by tests
}
}

impl<'a> Visitor for ReplaceVisitor<'a> {
fn visit_index(&mut self, index: &Index) {
match &mut self.arena[*index].kind {
TypeKind::TypeRef(tref) => {
if tref.name == "Self" {
tref.scheme = Some(self.scheme.clone());
}
}
_ => walk_index(self, index),
}
}
}

pub fn replace_self_type_refs(arena: &mut Arena<Type>, t: &Index, scheme: &Scheme) {
let mut replace_visitor = ReplaceVisitor { arena, scheme };

replace_visitor.visit_index(t)
}
28 changes: 22 additions & 6 deletions crates/escalier_hm/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5550,13 +5550,21 @@ fn infer_class_with_generic_method() -> Result<(), TypeError> {
// TODO: Allow comments in class bodies
let src = r#"
let Foo = class {
fn constructor(mut self) {}
x: number
fn constructor(mut self) {
self.x = 0
}
fn fst(self, a, b) {
return a
}
fn inc(mut self) {
self.x += 1
return self
}
}
let foo = new Foo()
// let x = foo.fst(5, "hello")
let mut foo = new Foo()
let bar = foo.inc()
let x = foo.inc().fst(5, "hello")
"#;
let mut script = parse_script(src).unwrap();

Expand All @@ -5567,11 +5575,19 @@ fn infer_class_with_generic_method() -> Result<(), TypeError> {
let t = checker.expand_type(&my_ctx, binding.index)?;
assert_eq!(
checker.print_type(&t),
r#"{fst<T, U>(self, a: T, b: U) -> T}"#
r#"{x: number, fst<B, A>(self, a: A, b: B) -> A, inc(mut self) -> Self}"#
);

let binding = my_ctx.values.get("bar").unwrap();
assert_eq!(checker.print_type(&binding.index), r#"Self"#);
let t = checker.expand_type(&my_ctx, binding.index)?;
assert_eq!(
checker.print_type(&t),
r#"{x: number, fst<B, A>(self, a: A, b: B) -> A, inc(mut self) -> Self}"#
);

// let binding = my_ctx.values.get("x").unwrap();
// assert_eq!(checker.print_type(&binding.index), r#"5"#);
let binding = my_ctx.values.get("x").unwrap();
assert_eq!(checker.print_type(&binding.index), r#"5"#);

assert_no_errors(&checker)
}
Expand Down

0 comments on commit 87f6a55

Please sign in to comment.