Skip to content

Commit

Permalink
Simple type checking with simple constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
ITesserakt committed Apr 16, 2024
1 parent 16fbdc3 commit 27ea1d7
Show file tree
Hide file tree
Showing 14 changed files with 371 additions and 280 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
2 changes: 0 additions & 2 deletions crates/kodept-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::any::type_name;

pub mod code_point;
pub mod code_source;
pub mod file_relative;
Expand Down
8 changes: 4 additions & 4 deletions crates/kodept-inference/src/algorithm_w.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<'e> AlgorithmW<'e> {
let v = self.env.new_var();
let var_bind = lambda.bind.clone();
let mut new_context = self.context.clone();
new_context.push(var_bind, MonomorphicType::Var(v.clone()).into());
new_context.push(var_bind, Rc::new(MonomorphicType::Var(v.clone()).into()));
let (s1, t1) = AlgorithmW {
context: &mut new_context,
env: self.env,
Expand All @@ -81,7 +81,7 @@ impl<'e> AlgorithmW<'e> {
Language::Var(v) => v,
_ => unreachable!(),
})
.push(var_bind, poly_type);
.push(var_bind, Rc::new(poly_type));

let (s2, t2) = AlgorithmW {
context: &mut new_context,
Expand Down Expand Up @@ -130,14 +130,14 @@ impl Language {
context.apply(self)
}

pub fn infer_with_env<'l>(
pub fn infer_with_env(
self: Rc<Self>,
context: &mut Assumptions,
env: &mut Environment,
) -> Result<MonomorphicType, AlgorithmWError> {
let (s, t) = AlgorithmW { context, env }.apply(&self)?;
let poly_type = context.generalize(t.clone());
context.substitute_mut(&s).push(self, poly_type);
context.substitute_mut(&s).push(self, Rc::new(poly_type));
Ok(t)
}

Expand Down
52 changes: 45 additions & 7 deletions crates/kodept-inference/src/assumption.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,55 @@
use itertools::Itertools;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::ops::Sub;
use std::rc::Rc;

use crate::language;
use crate::language::Language;
use crate::r#type::{MonomorphicType, PolymorphicType};
use crate::substitution::Substitutions;
use crate::{language, Environment};

type RLanguage = Rc<Language>;
type RPolymorphicType = Rc<PolymorphicType>;

#[derive(Clone, Debug, Default)]
pub struct Assumptions {
value: HashMap<RLanguage, PolymorphicType>,
value: HashMap<RLanguage, RPolymorphicType>,
}

impl Assumptions {
pub fn substitute_mut(&mut self, substitutions: &Substitutions) -> &mut Self {
for t in self.value.values_mut() {
*t = t.substitute(substitutions);
*t = Rc::new(t.substitute(substitutions));
}
self
}

pub fn push(&mut self, expr: RLanguage, t: PolymorphicType) -> &mut Self {
self.value.insert(expr, t);
pub fn push(&mut self, expr: RLanguage, t: RPolymorphicType) -> &mut Self {
match self.value.entry(expr) {
Entry::Occupied(slot) if slot.get() == &t => {}
Entry::Occupied(mut slot) => {
let mut env = Environment::default();
let s0 = slot
.get()
.instantiate(&mut env)
.unify(&t.instantiate(&mut env))
.expect("Given assumption cannot be unified with the old one");
slot.insert(Rc::new(t.substitute(&s0)));
}
Entry::Vacant(slot) => {
slot.insert(t);
}
};
self
}

#[must_use]
pub fn generalize(&self, t: MonomorphicType) -> PolymorphicType {
let occupied_types = self.value.iter().flat_map(|it| it.1.free_types()).collect();
let vars = t.free_types().sub(&occupied_types);
vars.iter()
vars.into_iter()
.fold(PolymorphicType::Monomorphic(t), |acc, next| {
PolymorphicType::Binding {
bind: next.clone(),
Expand All @@ -42,7 +60,7 @@ impl Assumptions {

#[must_use]
pub fn get(&self, key: &Language) -> Option<&PolymorphicType> {
self.value.get(key)
self.value.get(key).map(|it| it.as_ref())
}

#[must_use]
Expand All @@ -56,4 +74,24 @@ impl Assumptions {
.retain(|it, _| !matches!(it.as_ref(), Language::Var(v) if v == var));
self
}

pub fn merge(mut self, other: Assumptions) -> Self {
for (key, value) in other.value {
self.push(key, value);
}
self
}
}

impl Display for Assumptions {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let out = self
.value
.iter()
.map(|(key, value)| format!("{key} => {value}"))
.intersperse(", ".to_string())
.collect::<String>();

write!(f, "[{out}]")
}
}
22 changes: 11 additions & 11 deletions crates/kodept-inference/src/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ use std::fmt::{Display, Formatter};
use std::iter::once;
use std::ops::{BitOr, Sub};

use derive_more::{Display as DeriveDisplay, From};
use derive_more::{Constructor, Display as DeriveDisplay, From};
use itertools::Itertools;
use nonempty_collections::NEVec;

use crate::{LOWER_ALPHABET, UPPER_ALPHABET};
use crate::Environment;
use crate::substitution::Substitutions;
use crate::{Environment, LOWER_ALPHABET, UPPER_ALPHABET};

fn expand_to_string(id: usize, alphabet: &'static str) -> String {
if id == 0 {
Expand Down Expand Up @@ -39,11 +38,11 @@ pub enum PrimitiveType {
#[derive(Debug, Clone, PartialEq, Hash, Eq, From)]
pub struct Var(pub(crate) usize);

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Constructor)]

pub struct Tuple(pub(crate) Vec<MonomorphicType>);

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Constructor)]

pub struct Union(pub(crate) Vec<MonomorphicType>);

Expand All @@ -63,6 +62,7 @@ pub enum MonomorphicType {
}

#[derive(Debug, Clone, PartialEq, From)]
#[from(forward)]
pub enum PolymorphicType {
Monomorphic(MonomorphicType),
#[from(ignore)]
Expand Down Expand Up @@ -194,17 +194,17 @@ impl MonomorphicType {

impl PolymorphicType {
fn collect(&self) -> (Vec<Var>, MonomorphicType) {
fn step(mut acc: Vec<Var>, current: &PolymorphicType) -> (Vec<Var>, MonomorphicType) {
let mut result = vec![];
let mut current = self;
loop {
match current {
PolymorphicType::Monomorphic(t) => (vec![], t.clone()),
PolymorphicType::Monomorphic(ty) => return (result, ty.clone()),
PolymorphicType::Binding { bind, binding_type } => {
acc.push(bind.clone());
step(acc, binding_type.as_ref())
result.push(bind.clone());
current = binding_type.as_ref();
}
}
}

step(vec![], self)
}

#[must_use]
Expand Down
21 changes: 14 additions & 7 deletions crates/kodept-interpret/src/convert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use kodept_ast::graph::{GhostToken, SyntaxTree};
use kodept_ast::traits::Identifiable;
use kodept_ast::{
BlockLevel, BodiedFunctionDeclaration, Body, Expression, ExpressionBlock, Identifier,
InitializedVariable, Lambda, Literal, Operation, Term,
IfExpression, InitializedVariable, Lambda, Literal, Operation, Term,
};
use kodept_inference::algorithm_w::AlgorithmWError;
use kodept_inference::assumption::Assumptions;
Expand Down Expand Up @@ -44,11 +44,11 @@ impl ToModelFrom<TypeDerivableNode> for ConversionHelper<'_> {
self.convert(node)
} else if let Some(node) = node.as_lambda() {
self.convert(node)
} else if let Some(node) = node.as_application() {
} else if let Some(_node) = node.as_application() {
todo!()
} else if let Some(node) = node.as_if_expr() {
} else if let Some(_node) = node.as_if_expr() {
todo!()
} else if let Some(node) = node.as_reference() {
} else if let Some(_node) = node.as_reference() {
todo!()
} else if let Some(node) = node.as_literal() {
self.convert(node)
Expand Down Expand Up @@ -169,9 +169,9 @@ impl ToModelFrom<Operation> for ConversionHelper<'_> {
} else {
Ok(app(unit(), expr).into())
};
} else if let Some(node) = node.as_access() {
} else if let Some(node) = node.as_binary() {
} else if let Some(node) = node.as_unary() {
} else if let Some(_node) = node.as_access() {
} else if let Some(_node) = node.as_binary() {
} else if let Some(_node) = node.as_unary() {
}
unreachable!()
}
Expand Down Expand Up @@ -214,6 +214,7 @@ impl ToModelFrom<Expression> for ConversionHelper<'_> {
if let Some(node) = node.as_term() {
return self.convert(node);
} else if let Some(node) = node.as_if() {
return self.convert(node);
} else if let Some(node) = node.as_literal() {
return self.convert(node);
} else if let Some(node) = node.as_lambda() {
Expand All @@ -223,6 +224,12 @@ impl ToModelFrom<Expression> for ConversionHelper<'_> {
}
}

impl ToModelFrom<IfExpression> for ConversionHelper<'_> {
fn convert(self, node: &IfExpression) -> Result<Language, InferError> {
Ok(unit())
}
}

impl ToModelFrom<Literal> for ConversionHelper<'_> {
fn convert(self, node: &Literal) -> Result<Language, InferError> {
if let Some(node) = node.as_number() {
Expand Down
Loading

0 comments on commit 27ea1d7

Please sign in to comment.