Skip to content

Commit

Permalink
tyck: Add typechecking of literal types
Browse files Browse the repository at this point in the history
  • Loading branch information
CohenArthur committed Apr 28, 2024
1 parent 1a51196 commit 7299231
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 39 deletions.
14 changes: 14 additions & 0 deletions ast/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Abstract Syntax Tree representation of jinko's source code
use std::fmt;

use error::{ErrKind, Error};
use location::SpanTuple;
use symbol::Symbol;
Expand Down Expand Up @@ -137,6 +139,18 @@ pub enum Value {
Str(String),
}

impl fmt::Display for Value {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Value::Integer(v) => write!(f, "{v}"),
Value::Float(v) => write!(f, "{v}"),
Value::Bool(v) => write!(f, "{v}"),
Value::Char(v) => write!(f, "'{v}'"),
Value::Str(v) => write!(f, "\"{v}\""),
}
}
}

// FIXME: How to keep location in there? How to do it ergonomically?
// As a "Smart pointer" type? E.g by having it implement `Deref<T = AstInner>`?
// Would that even work? If it does, it is ergonomic but boy is it not idiomatic
Expand Down
2 changes: 1 addition & 1 deletion typecheck/src/actual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<'ctx> TypeLinkResolver<'ctx> {
final_type,
} = self.find_end(fir, to_resolve);

let node = final_type.0;
let node = final_type.origin();
let tyref = self.new.types.new_type(final_type);

intermediate_nodes
Expand Down
38 changes: 25 additions & 13 deletions typecheck/src/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use colored::Colorize;
use error::{ErrKind, Error};
use fir::{Fallible, Fir, Node, RefIdx, Traversal};
use flatten::AstInfo;
use flatten::FlattenData;
use location::SpanTuple;

Expand Down Expand Up @@ -83,7 +84,7 @@ impl<'ctx> Checker<'ctx> {
BuiltinType::Bool => Type::record(self.0.primitives.bool_type),
};

if valid_union_type.is_superset_of(&expected_ty) {
if expected_ty.can_widen_to(&valid_union_type) {
Ok(vec![expected_ty; arity])
} else {
Err(unexpected_arithmetic_type(
Expand Down Expand Up @@ -121,13 +122,21 @@ impl<'fir, 'ast> Fmt<'fir, 'ast> {
unreachable!()
}

let ty_fmt = |node: &Node<FlattenData<'_>>| match &node.data.ast {
AstInfo::Node(ast::Ast {
node: ast::Node::Constant(value),
..
}) => value.to_string(),
info => info.symbol().unwrap().access().to_string(),
};

ty.set()
.0 // FIXME: ugly
.iter()
.map(|idx| self.0.nodes[&idx.expect_resolved()].data.ast.symbol())
.fold(None, |acc, sym| match acc {
None => Some(format!("`{}`", sym.unwrap().access().purple())),
Some(acc) => Some(format!("{} | `{}`", acc, sym.unwrap().access().purple(),)),
.map(|idx| &self.0.nodes[&idx.expect_resolved()])
.fold(None, |acc, node| match acc {
None => Some(format!("`{}`", ty_fmt(node).purple())),
Some(acc) => Some(format!("{} | `{}`", acc, ty_fmt(node).purple(),)),
})
.unwrap()
}
Expand All @@ -144,6 +153,9 @@ fn type_mismatch(
) -> Error {
let fmt = Fmt(fir);

dbg!(&expected.0);
dbg!(&got.0);

Error::new(ErrKind::TypeChecker)
.with_msg(format!(
"type mismatch found: expected {}, got {}",
Expand Down Expand Up @@ -200,14 +212,14 @@ impl<'ctx> Traversal<FlattenData<'_>, Error> for Checker<'ctx> {
return Ok(());
}

let ret_ty = return_ty
.as_ref()
.map(|b| self.get_type(b))
.unwrap_or(self.unit());
let block_ty = block
.as_ref()
.map(|b| self.get_type(b))
.unwrap_or(self.unit());
let type_or_unit = |ty: &Option<RefIdx>| {
ty.as_ref()
.map(|ty| self.get_type(ty))
.unwrap_or(self.unit())
};

let ret_ty = type_or_unit(return_ty);
let block_ty = type_or_unit(block);

if !block_ty.can_widen_to(ret_ty) {
let err = type_mismatch(
Expand Down
4 changes: 4 additions & 0 deletions typecheck/src/collectors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
//! A collection of collectors (heh) necessary for the well-being of the type system.
pub mod constants;
pub mod primitives;
57 changes: 57 additions & 0 deletions typecheck/src/collectors/constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//! Collect all primitive union type constants in the program in order to build our primitive union types properly. There are three primitive union types: `char`, `int` and `string`, so this module collects all character, integer and string constants.
use std::collections::HashSet;
use std::convert::Infallible;

use fir::{Fallible, Fir, Node, OriginIdx, RefIdx, Traversal};
use flatten::{AstInfo, FlattenData};

/// Accumulator for the constants encountered while traversing the [`Fir`].
///
/// Each set stores resolved references ([`RefIdx::Resolved`]) built from the origin of the
/// constant nodes of the matching kind. Other kinds of constants are not recorded.
#[derive(Default)]
pub struct ConstantCollector {
/// References to every integer constant seen so far
pub(crate) integers: HashSet<RefIdx>,
/// References to every character constant seen so far
pub(crate) characters: HashSet<RefIdx>,
/// References to every string constant seen so far
pub(crate) strings: HashSet<RefIdx>,
}

impl ConstantCollector {
    /// Create a collector which has not recorded any constant yet.
    pub fn new() -> ConstantCollector {
        Self::default()
    }

    /// Record the node defined at `idx` as an integer constant.
    fn add_integer(&mut self, idx: OriginIdx) {
        let constant = RefIdx::Resolved(idx);

        self.integers.insert(constant);
    }

    /// Record the node defined at `idx` as a character constant.
    fn add_character(&mut self, idx: OriginIdx) {
        let constant = RefIdx::Resolved(idx);

        self.characters.insert(constant);
    }

    /// Record the node defined at `idx` as a string constant.
    fn add_string(&mut self, idx: OriginIdx) {
        let constant = RefIdx::Resolved(idx);

        self.strings.insert(constant);
    }
}

impl Traversal<FlattenData<'_>, Infallible> for ConstantCollector {
    /// Sort a constant node into the set matching its kind.
    ///
    /// Integer, character and string constants are recorded using the node's origin. All
    /// other constants are ignored. This never returns an error, but it panics if the node
    /// does not carry a constant AST node, as that indicates a bug in the interpreter.
    fn traverse_constant(
        &mut self,
        _: &Fir<FlattenData<'_>>,
        node: &Node<FlattenData<'_>>,
        _: &RefIdx,
    ) -> Fallible<Infallible> {
        // First, dig the literal value out of the AST information attached to the node...
        let literal = match &node.data.ast {
            AstInfo::Node(ast::Ast {
                node: ast::Node::Constant(value),
                ..
            }) => value,
            _ => unreachable!("Fir constant with non-node AST info. this is an interpreter error"),
        };

        // ...then file the node under the primitive union type it belongs to
        let origin = node.origin;
        match literal {
            ast::Value::Integer(_) => self.add_integer(origin),
            ast::Value::Char(_) => self.add_character(origin),
            ast::Value::Str(_) => self.add_string(origin),
            // do nothing - the other constants are not part of primitive union types
            _ => {}
        }

        Ok(())
    }
}
File renamed without changes.
56 changes: 48 additions & 8 deletions typecheck/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
mod actual;
mod checker;
mod primitives;
mod collectors;
mod typemap;
mod typer;

use std::collections::{HashMap, HashSet};

use error::{ErrKind, Error};
use fir::{Fir, Incomplete, Mapper, OriginIdx, Pass, RefIdx, Traversal};
use fir::{Fir, Incomplete, Kind, Mapper, OriginIdx, Pass, RefIdx, Traversal};
use flatten::FlattenData;

use actual::Actual;
use checker::Checker;
use typer::Typer;

use primitives::PrimitiveTypes;
use collectors::{
constants::ConstantCollector,
primitives::{self, PrimitiveTypes},
};

#[derive(Clone, Debug, Eq, PartialEq)]
// FIXME: Should that be a hashset RefIdx or OriginIdx?
Expand All @@ -32,12 +35,25 @@ impl TypeSet {

/// This is the base structure that our typechecker - a type "interpreter" - will play with.
/// In `jinko`, the type of a variable is a set of elements of kind `type`. So this structure can
/// be thought of as a simple set of actual, monomorphized types.
// TODO: for now, let's not think about optimizations - let's box and clone and blurt bytes everywhere
/// be thought of as a simple set of actual, monomorphized types. There is one complication in that
/// the language recognizes a couple of magic types: `int`, `string` and `char` should be treated as
/// sets of all possible literals of that type. So we can imagine that `char` should actually be defined
/// as such:
///
/// ```rust,ignore
/// type char = '0' | '1' | '2' ... 'a' | 'b' | 'c' ... | <last_unicode_char_ever>;
/// ```
///
/// This is of course not a realistic definition to put in our standard library (and it gets worse for `string`)
/// so these types have to be handled separately.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Type(OriginIdx, TypeSet);

impl Type {
pub fn origin(&self) -> OriginIdx {
self.0
}

pub fn builtin(set: HashSet<RefIdx>) -> Type {
Type(OriginIdx(u64::MAX), TypeSet(set))
}
Expand All @@ -59,11 +75,11 @@ impl Type {
}

pub fn is_superset_of(&self, other: &Type) -> bool {
return self.set().contains(other.set());
self.set().contains(other.set())
}

pub fn can_widen_to(&self, superset: &Type) -> bool {
return superset.set().contains(self.set());
superset.set().contains(self.set())
}
}

Expand All @@ -89,9 +105,33 @@ pub trait TypeCheck<T>: Sized {
}

impl<'ast> TypeCheck<Fir<FlattenData<'ast>>> for Fir<FlattenData<'ast>> {
fn type_check(self) -> Result<Fir<FlattenData<'ast>>, Error> {
fn type_check(mut self) -> Result<Fir<FlattenData<'ast>>, Error> {
let primitives = primitives::find(&self)?;

let mut const_collector = ConstantCollector::new();
const_collector.traverse(&self)?;

// We can now build our primitive union types. Because the first TypeCtx deals
// with [`TypeVariable`]s, it's not possible to directly create a TypeSet - so
// we can do that later on during typechecking, right before the actual
// typechecking. An alternative is to modify the [`Fir`] directly by creating
// new nodes for these primitive unions, which is probably a little cleaner and
// less spaghetti.
let mk_constant_types = |set: HashSet<RefIdx>| set.into_iter().collect();

self[primitives.int_type].kind = Kind::UnionType {
generics: vec![],
variants: mk_constant_types(const_collector.integers),
};
self[primitives.char_type].kind = Kind::UnionType {
generics: vec![],
variants: mk_constant_types(const_collector.characters),
};
self[primitives.string_type].kind = Kind::UnionType {
generics: vec![],
variants: mk_constant_types(const_collector.strings),
};

TypeCtx {
primitives,
types: HashMap::new(),
Expand Down
2 changes: 1 addition & 1 deletion typecheck/src/typemap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl TypeMap {

/// Insert a new type into the typemap
pub fn new_type(&mut self, ty: Type) -> TypeRef {
let ty_ref = TypeRef(ty.0);
let ty_ref = TypeRef(ty.origin());
// FIXME: Is it actually okay to ignore if the type existed or not?
self.types.insert(ty_ref, ty);

Expand Down
33 changes: 17 additions & 16 deletions typecheck/src/typer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ impl<'ast, 'ctx> Mapper<FlattenData<'ast>, FlattenData<'ast>, Error> for Typer<'
origin: OriginIdx,
_constant: RefIdx,
) -> Result<Node<FlattenData<'ast>>, Error> {
let ast = data.ast.node();

let ty = match &ast.node {
// This does not take into account that primitives are multi types and will need to be fixed
AstNode::Constant(Value::Bool(_)) => self.0.primitives.bool_type,
AstNode::Constant(Value::Char(_)) => self.0.primitives.char_type,
AstNode::Constant(Value::Integer(_)) => self.0.primitives.int_type,
AstNode::Constant(Value::Float(_)) => self.0.primitives.float_type,
AstNode::Constant(Value::Str(_)) => self.0.primitives.string_type,
_ => unreachable!(),
};
// let ast = data.ast.node();

// let ty = match &ast.node {
// This does not take into account that primitives are multi types and will need to be fixed
// AstNode::Constant(Value::Bool(_)) => self.0.primitives.bool_type,
// AstNode::Constant(Value::Char(_)) => self.0.primitives.char_type,
// AstNode::Constant(Value::Integer(_)) => self.0.primitives.int_type,
// AstNode::Constant(Value::Float(_)) => self.0.primitives.float_type,
// AstNode::Constant(Value::Str(_)) => self.0.primitives.string_type,
// _ => unreachable!(),
// };

// For constants, how will we look up the basic primitive type nodes before assigning them
// here? Just a traversal and we do that based on name? Or will they need to be builtin at this point?
Expand All @@ -86,13 +86,14 @@ impl<'ast, 'ctx> Mapper<FlattenData<'ast>, FlattenData<'ast>, Error> for Typer<'
// the proper primitive multitype. We need to implement this.

// FIXME: How do we get a TypeReference here? Or should we actually do that operation in the checker?
let new_node = Node {

self.assign_type(origin, TypeVariable::Record(origin));

Ok(Node {
data,
origin,
kind: Kind::Constant(RefIdx::Resolved(ty)),
};

self.ty(new_node, RefIdx::Resolved(ty))
kind: Kind::Constant(RefIdx::Resolved(origin)),
})
}

fn map_call(
Expand Down

0 comments on commit 7299231

Please sign in to comment.