diff --git a/Cargo.lock b/Cargo.lock index 52e93358..4136ff96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -176,6 +176,16 @@ dependencies = [ "fir", ] +[[package]] +name = "dedup" +version = "0.1.0" +dependencies = [ + "ast", + "error", + "fir", + "flatten", +] + [[package]] name = "dirs" version = "1.0.5" @@ -366,6 +376,7 @@ dependencies = [ "builtins", "colored", "debug-fir", + "dedup", "downcast-rs", "error", "fir", @@ -516,6 +527,7 @@ dependencies = [ name = "name_resolve" version = "0.1.0" dependencies = [ + "ast", "colored", "debug-fir", "distance", diff --git a/Cargo.toml b/Cargo.toml index 5549db76..6fe38062 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "ast", "fir", "flatten", + "dedup", "name_resolve", "include_code", "loop_desugar", @@ -35,6 +36,7 @@ ast = { path = "ast" } ast-sanitizer = { path = "ast-sanitizer" } fir = { path = "fir" } debug-fir = { path = "debug-fir" } +dedup = { path = "dedup" } flatten = { path = "flatten" } name_resolve = { path = "name_resolve" } symbol = { path = "symbol" } diff --git a/ast/src/lib.rs b/ast/src/lib.rs index a6db5f6d..b6a47c86 100644 --- a/ast/src/lib.rs +++ b/ast/src/lib.rs @@ -1,5 +1,7 @@ //! Abstract Syntax Tree representation of jinko's source code +use std::fmt; + use error::{ErrKind, Error}; use location::SpanTuple; use symbol::Symbol; @@ -10,6 +12,7 @@ pub enum TypeKind { // probably simpler, right? // but then how do we handle a simple type like "int" without generics or anything -> symbol(int) and empty vec for generics so we're good Simple(Symbol), // FIXME: Can this be a symbol? Should it be something else? I think Symbol is fine, because the struct enclosing TypeKind has generics and location + Literal(Box), Multi(Vec), FunctionLike(Vec, Option>), } @@ -136,6 +139,18 @@ pub enum Value { Str(String), } +impl fmt::Display for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Value::Integer(v) => write!(f, "{v}"), + Value::Float(v) => write!(f, "{v}"), + Value::Bool(v) => write!(f, "{v}"), + Value::Char(v) => write!(f, "'{v}'"), + Value::Str(v) => write!(f, "\"{v}\""), + } + } +} + // FIXME: How to keep location in there? How to do it ergonomically? // As a "Smart pointer" type? E.g by having it implement `Deref`? // Would that even work? If it does, it is ergonomic but boy is it not idiomatic diff --git a/dedup/Cargo.toml b/dedup/Cargo.toml new file mode 100644 index 00000000..c19e8471 --- /dev/null +++ b/dedup/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "dedup" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ast = { path = "../ast" } +fir = { path = "../fir" } +flatten = { path = "../flatten" } +error = { path = "../error" } diff --git a/dedup/src/lib.rs b/dedup/src/lib.rs new file mode 100644 index 00000000..dc33fcea --- /dev/null +++ b/dedup/src/lib.rs @@ -0,0 +1,100 @@ +//! The Deduplication module aims at removing duplicate nodes from an [`Fir`]. This can be +//! done for "optimization" purposes, but not only - it is essential to type checking in order to +//! enable literal types and simplify union types. + +use std::collections::HashMap; + +use error::Error; +use fir::{Fir, Kind, Mapper, Node, OriginIdx, RefIdx}; +use flatten::{AstInfo, FlattenData}; + +pub trait DeduplicateConstants: Sized { + fn deduplicate_constants(self) -> Result; +} + +impl<'ast> DeduplicateConstants for Fir> { + fn deduplicate_constants(self) -> Result { + let mut ctx = ConstantDeduplicator(HashMap::new()); + + // we can't emit an error in the constant deduplicator + Ok(ctx.map(self).unwrap()) + } +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +enum HashableValue<'ast> { + Bool(bool), + Char(char), + Integer(i64), + String(&'ast str), +} + +impl<'ast> From<&'ast ast::Value> for HashableValue<'ast> { + fn from(value: &'ast ast::Value) -> Self { + use ast::Value::*; + + match value { + Bool(b) => HashableValue::Bool(*b), + Char(c) => HashableValue::Char(*c), + Integer(i) => HashableValue::Integer(*i), + Str(s) => HashableValue::String(s), + Float(_) => { + unreachable!("cannot hash floating point constants. this is an interpreter error") + } + } + } +} + +struct ConstantDeduplicator<'ast>(pub(crate) HashMap, OriginIdx>); + +impl<'ast> Mapper, FlattenData<'ast>, Error> for ConstantDeduplicator<'ast> { + fn map_constant( + &mut self, + data: FlattenData<'ast>, + origin: OriginIdx, + constant: RefIdx, + ) -> Result>, Error> { + match data.ast { + // if we're dealing with a floating point constant, then we can't deduplicate - equality rules + // are not great around floats and we can't use them as literal types anyway (and for that reason) + AstInfo::Node(ast::Ast { + node: ast::Node::Constant(ast::Value::Float(_)), + .. + }) => Ok(Node { + origin, + data, + kind: Kind::Constant(constant), + }), + AstInfo::Node(ast::Ast { + node: ast::Node::Constant(value), + .. + }) => { + let value = HashableValue::from(value); + + match self.0.get(&value) { + // if the value was already present in the map, then we can just transform + // our constant into a reference to that constant + Some(constant) => Ok(Node { + origin, + data, + kind: Kind::NodeRef(RefIdx::Resolved(*constant)), + }), + // otherwise, we insert the value into the map and return this node as + // the "original constant", that future duplicate constants will refer to + None => { + self.0.insert(value, origin); + + Ok(Node { + origin, + data, + kind: Kind::Constant(constant), + }) + } + } + } + _ => unreachable!( + "constant without an AST constant as its node info. this is an interpreter error." + ), + } + } +} diff --git a/error/src/lib.rs b/error/src/lib.rs index e037d674..fc731aa2 100644 --- a/error/src/lib.rs +++ b/error/src/lib.rs @@ -287,7 +287,7 @@ impl Error { } } -use std::convert::From; +use std::convert::{From, Infallible}; use std::io; /// I/O errors keep their messages @@ -330,9 +330,17 @@ impl std::convert::From for Error { } } -impl From> for Error { - fn from(errs: Vec) -> Self { - Error::new(ErrKind::Multiple(errs)) +impl> From> for Error { + fn from(errs: Vec) -> Self { + Error::new(ErrKind::Multiple( + errs.into_iter().map(Into::into).collect(), + )) + } +} + +impl From for Error { + fn from(_: Infallible) -> Self { + unreachable!() } } diff --git a/fir/src/checks.rs b/fir/src/checks.rs index 8d532961..fa74625d 100644 --- a/fir/src/checks.rs +++ b/fir/src/checks.rs @@ -62,12 +62,12 @@ impl Fir { // FIXME: This is missing a bunch of valid "checks". For example, checking that a call's argument can // point to an if-else expression. Basically, to anything that's an expression actually. // Should we split the fir::Kind into fir::Kind::Stmt and fir::Kind::Expr? Or does that not make sense? - Kind::Constant(r) => check!(r => Kind::RecordType { .. }, node), + Kind::Constant(r) => check!(r => Kind::RecordType { .. } | Kind::Constant(_), node), Kind::NodeRef(_to) => { // `to` can link to basically anything, so there is nothing to do } // FIXME: Is that okay? - Kind::TypeReference(to) => check!(to => Kind::RecordType { .. } | Kind::UnionType { .. } | Kind::Generic { .. } | Kind::TypeReference(_), node), + Kind::TypeReference(to) => check!(to => Kind::RecordType { .. } | Kind::UnionType { .. } | Kind::Generic { .. } | Kind::TypeReference(_) | Kind::Constant(_), node), Kind::Generic { default: Some(default), } => check!(default => Kind::TypeReference(_), node), @@ -103,7 +103,7 @@ impl Fir { variants, } => { check!(@generics => Kind::Generic { .. }, node); - check!(@variants => Kind::TypeReference(_), node); + check!(@variants => Kind::TypeReference(_) | Kind::Constant(_), node); } Kind::Binding { to: _, ty } => { // `to` can point to anything, correct? diff --git a/fir/src/lib.rs b/fir/src/lib.rs index fdd38003..34dcf67c 100644 --- a/fir/src/lib.rs +++ b/fir/src/lib.rs @@ -97,6 +97,7 @@ use std::fmt::{self, Debug}; use std::hash::Hash; +use std::ops::IndexMut; use std::{collections::BTreeMap, ops::Index}; mod checks; @@ -130,6 +131,18 @@ impl RefIdx { } } +impl From<&OriginIdx> for OriginIdx { + fn from(value: &OriginIdx) -> Self { + *value + } +} + +impl From<&RefIdx> for OriginIdx { + fn from(value: &RefIdx) -> Self { + value.expect_resolved() + } +} + /// Each [`Node`] in the [`Fir`] is its own [`OriginIdx`], which is an origin point. This is a bit wasteful /// since most nodes aren't definitions and instead *refer* to definitions, but it makes it easy to refer to /// call points or to emit errors. @@ -246,19 +259,17 @@ pub enum Kind { NodeRef(RefIdx), } -impl Index<&RefIdx> for Fir { +impl> Index for Fir { type Output = Node; - fn index(&self, index: &RefIdx) -> &Node { - &self.nodes[&index.expect_resolved()] + fn index(&self, index: K) -> &Node { + &self.nodes[&index.into()] } } -impl Index<&OriginIdx> for Fir { - type Output = Node; - - fn index(&self, index: &OriginIdx) -> &Node { - &self.nodes[index] +impl> IndexMut for Fir { + fn index_mut(&mut self, index: K) -> &mut Node { + self.nodes.get_mut(&index.into()).unwrap() } } diff --git a/flatten/src/lib.rs b/flatten/src/lib.rs index 5a5a8d0b..7c85be2d 100644 --- a/flatten/src/lib.rs +++ b/flatten/src/lib.rs @@ -329,7 +329,16 @@ impl<'ast> Ctx<'ast> { match &ty.kind { // FIXME: Do we need to create a type reference here as well? `handle_multi_ty` returns the actual union type TypeKind::Multi(variants) => ctx.handle_multi_type(ty, generics, variants), + TypeKind::Literal(ast) => { + let (ctx, idx) = ctx.visit(ast); + let data = FlattenData { + ast: AstInfo::Node(ast), + scope: ctx.scope, + }; + + ctx.append(data, Kind::TypeReference(idx)) + } TypeKind::Simple(_) => { let data = FlattenData { ast: AstInfo::Type(ty), diff --git a/interpreter/jinko.rs b/interpreter/jinko.rs index 92d07e75..85754bf1 100644 --- a/interpreter/jinko.rs +++ b/interpreter/jinko.rs @@ -7,6 +7,7 @@ mod repl; use colored::Colorize; use builtins::AppendAstBuiltins; +use dedup::DeduplicateConstants; use fire::instance::Instance; use fire::Interpret; use flatten::{FlattenAst, FlattenData}; @@ -153,6 +154,12 @@ fn experimental_pipeline(input: &str, file: &Path) -> InteractResult { .show_data(data_fmt) .display(&fir); + let fir = x_try!(fir.deduplicate_constants()); + FirDebug::default() + .header("deduped_constants") + .show_data(data_fmt) + .display(&fir); + let fir = x_try!(fir.type_check()); FirDebug::default() .header("typechecked") diff --git a/name_resolve/Cargo.toml b/name_resolve/Cargo.toml index 92fcebe8..7d952aad 100644 --- a/name_resolve/Cargo.toml +++ b/name_resolve/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ast = { path = "../ast" } fir = { path = "../fir" } debug-fir = { path = "../debug-fir" } symbol = { path = "../symbol" } diff --git a/name_resolve/src/declarator.rs b/name_resolve/src/declarator.rs index de54a880..d954a0d3 100644 --- a/name_resolve/src/declarator.rs +++ b/name_resolve/src/declarator.rs @@ -1,5 +1,6 @@ +use ast::{Ast, Node as AstNode}; use fir::{Fallible, Fir, Node, OriginIdx, RefIdx, Traversal}; -use flatten::FlattenData; +use flatten::{AstInfo, FlattenData}; use crate::{NameResolutionError, NameResolveCtx, UniqueError}; @@ -73,12 +74,19 @@ impl<'ast, 'ctx, 'enclosing> Traversal, NameResolutionError> node: &Node>, reference: &RefIdx, ) -> Fallible { - // if we already see resolved type references, then it means we are dealing - // with a type alias - if let RefIdx::Resolved(_) = reference { - self.define(DefinitionKind::Type, node) - } else { - Ok(()) + match node.data.ast { + // if we're dealing with a type constant, then we have nothing to do during + // name resolution (at least in the declaration pass) + AstInfo::Node(Ast { + node: AstNode::Constant(_), + .. + }) => Ok(()), + _ => match reference { + // if we already see resolved type references, then it means we are dealing + // with a type alias + RefIdx::Resolved(_) => self.define(DefinitionKind::Type, node), + _ => Ok(()), + }, } } diff --git a/name_resolve/src/lib.rs b/name_resolve/src/lib.rs index 309bcf74..e32c336f 100644 --- a/name_resolve/src/lib.rs +++ b/name_resolve/src/lib.rs @@ -674,4 +674,17 @@ mod tests { assert!(fir.is_err()) } + + #[test] + fn nameres_literal_types() { + let ast = ast! { + type Bar; + type Foo(inner: Bar | "bar"); + func f(arg: Bar | "bar" | "Bar") {} + }; + + let fir = ast.flatten().name_resolve(); + + assert!(fir.is_ok()) + } } diff --git a/typecheck/src/actual.rs b/typecheck/src/actual.rs index c5d50fa5..9c533ba2 100644 --- a/typecheck/src/actual.rs +++ b/typecheck/src/actual.rs @@ -52,7 +52,7 @@ impl<'ctx> TypeLinkResolver<'ctx> { final_type, } = self.find_end(fir, to_resolve); - let node = final_type.0; + let node = final_type.origin(); let tyref = self.new.types.new_type(final_type); intermediate_nodes diff --git a/typecheck/src/checker.rs b/typecheck/src/checker.rs index 4965e736..8c0d1752 100644 --- a/typecheck/src/checker.rs +++ b/typecheck/src/checker.rs @@ -5,6 +5,7 @@ use colored::Colorize; use error::{ErrKind, Error}; use fir::{Fallible, Fir, Node, RefIdx, Traversal}; +use flatten::AstInfo; use flatten::FlattenData; use location::SpanTuple; @@ -83,7 +84,7 @@ impl<'ctx> Checker<'ctx> { BuiltinType::Bool => Type::record(self.0.primitives.bool_type), }; - if valid_union_type.is_superset_of(&expected_ty) { + if expected_ty.can_widen_to(&valid_union_type) { Ok(vec![expected_ty; arity]) } else { Err(unexpected_arithmetic_type( @@ -121,13 +122,21 @@ impl<'fir, 'ast> Fmt<'fir, 'ast> { unreachable!() } + let ty_fmt = |node: &Node>| match &node.data.ast { + AstInfo::Node(ast::Ast { + node: ast::Node::Constant(value), + .. + }) => value.to_string(), + info => info.symbol().unwrap().access().to_string(), + }; + ty.set() .0 // FIXME: ugly .iter() - .map(|idx| self.0.nodes[&idx.expect_resolved()].data.ast.symbol()) - .fold(None, |acc, sym| match acc { - None => Some(format!("`{}`", sym.unwrap().access().purple())), - Some(acc) => Some(format!("{} | `{}`", acc, sym.unwrap().access().purple(),)), + .map(|idx| &self.0.nodes[&idx.expect_resolved()]) + .fold(None, |acc, node| match acc { + None => Some(format!("`{}`", ty_fmt(node).purple())), + Some(acc) => Some(format!("{} | `{}`", acc, ty_fmt(node).purple(),)), }) .unwrap() } @@ -144,6 +153,9 @@ fn type_mismatch( ) -> Error { let fmt = Fmt(fir); + dbg!(&expected.0); + dbg!(&got.0); + Error::new(ErrKind::TypeChecker) .with_msg(format!( "type mismatch found: expected {}, got {}", @@ -200,14 +212,14 @@ impl<'ctx> Traversal, Error> for Checker<'ctx> { return Ok(()); } - let ret_ty = return_ty - .as_ref() - .map(|b| self.get_type(b)) - .unwrap_or(self.unit()); - let block_ty = block - .as_ref() - .map(|b| self.get_type(b)) - .unwrap_or(self.unit()); + let type_or_unit = |ty: &Option| { + ty.as_ref() + .map(|ty| self.get_type(ty)) + .unwrap_or(self.unit()) + }; + + let ret_ty = type_or_unit(return_ty); + let block_ty = type_or_unit(block); if !block_ty.can_widen_to(ret_ty) { let err = type_mismatch( diff --git a/typecheck/src/collectors.rs b/typecheck/src/collectors.rs new file mode 100644 index 00000000..672c206b --- /dev/null +++ b/typecheck/src/collectors.rs @@ -0,0 +1,4 @@ +//! A collection of collectors (heh) necessary for the well-being of the type system. + +pub mod constants; +pub mod primitives; diff --git a/typecheck/src/collectors/constants.rs b/typecheck/src/collectors/constants.rs new file mode 100644 index 00000000..e5bfa39f --- /dev/null +++ b/typecheck/src/collectors/constants.rs @@ -0,0 +1,64 @@ +//! Collect all primitive union type constants in the program in order to build our primitive union types properly. There are three primitive union types: `char`, `int` and `string`, so this module collects all character, integer and string constants. + +use std::collections::HashSet; +use std::convert::Infallible; + +use fir::{Fallible, Fir, Node, OriginIdx, RefIdx, Traversal}; +use flatten::{AstInfo, FlattenData}; + +#[derive(Default)] +pub struct ConstantCollector { + pub(crate) integers: HashSet, + pub(crate) characters: HashSet, + pub(crate) strings: HashSet, + // Hopefully there's only two of those + pub(crate) bools: HashSet, +} + +impl ConstantCollector { + pub fn new() -> ConstantCollector { + ConstantCollector::default() + } + + fn add_integer(&mut self, idx: OriginIdx) { + self.integers.insert(RefIdx::Resolved(idx)); + } + + fn add_character(&mut self, idx: OriginIdx) { + self.characters.insert(RefIdx::Resolved(idx)); + } + + fn add_string(&mut self, idx: OriginIdx) { + self.strings.insert(RefIdx::Resolved(idx)); + } + + fn add_bool(&mut self, idx: OriginIdx) { + self.bools.insert(RefIdx::Resolved(idx)); + } +} + +impl Traversal, Infallible> for ConstantCollector { + fn traverse_constant( + &mut self, + _: &Fir>, + node: &Node>, + _: &RefIdx, + ) -> Fallible { + match node.data.ast { + AstInfo::Node(ast::Ast { + node: ast::Node::Constant(value), + .. + }) => match value { + ast::Value::Integer(_) => self.add_integer(node.origin), + ast::Value::Char(_) => self.add_character(node.origin), + ast::Value::Str(_) => self.add_string(node.origin), + ast::Value::Bool(_) => self.add_bool(node.origin), + // do nothing - the other constants are not part of primitive union types + _ => {} + }, + _ => unreachable!("Fir constant with non-node AST info. this is an interpreter error"), + }; + + Ok(()) + } +} diff --git a/typecheck/src/primitives.rs b/typecheck/src/collectors/primitives.rs similarity index 100% rename from typecheck/src/primitives.rs rename to typecheck/src/collectors/primitives.rs diff --git a/typecheck/src/lib.rs b/typecheck/src/lib.rs index e1e90d56..5dde117c 100644 --- a/typecheck/src/lib.rs +++ b/typecheck/src/lib.rs @@ -1,20 +1,23 @@ mod actual; mod checker; -mod primitives; +mod collectors; mod typemap; mod typer; use std::collections::{HashMap, HashSet}; use error::{ErrKind, Error}; -use fir::{Fir, Incomplete, Mapper, OriginIdx, Pass, RefIdx, Traversal}; +use fir::{Fir, Incomplete, Kind, Mapper, OriginIdx, Pass, RefIdx, Traversal}; use flatten::FlattenData; use actual::Actual; use checker::Checker; use typer::Typer; -use primitives::PrimitiveTypes; +use collectors::{ + constants::ConstantCollector, + primitives::{self, PrimitiveTypes}, +}; #[derive(Clone, Debug, Eq, PartialEq)] // FIXME: Should that be a hashset RefIdx or OriginIdx? @@ -32,12 +35,25 @@ impl TypeSet { /// This is the base structure that our typechecker - a type "interpreter" - will play with. /// In `jinko`, the type of a variable is a set of elements of kind `type`. So this structure can -/// be thought of as a simple set of actual, monomorphized types. -// TODO: for now, let's not think about optimizations - let's box and clone and blurt bytes everywhere +/// be thought of as a simple set of actual, monomorphized types. There is one complication in that +/// the language recognizes a couple of magic types: `int`, `string` and `char` should be treated as +/// sets of all possible literals of that type. So we can imagine that `char` should actually be defined +/// as such: +/// +/// ```rust,ignore +/// type char = '0' | '1' | '2' ... 'a' | 'b' | 'c' ... | ; +/// ``` +/// +/// This is of course not a realistic definition to put in our standard library (and it gets worse for `string`) +/// so these types have to be handled separately. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Type(OriginIdx, TypeSet); impl Type { + pub fn origin(&self) -> OriginIdx { + self.0 + } + pub fn builtin(set: HashSet) -> Type { Type(OriginIdx(u64::MAX), TypeSet(set)) } @@ -59,11 +75,11 @@ impl Type { } pub fn is_superset_of(&self, other: &Type) -> bool { - return self.set().contains(other.set()); + self.set().contains(other.set()) } pub fn can_widen_to(&self, superset: &Type) -> bool { - return superset.set().contains(self.set()); + superset.set().contains(self.set()) } } @@ -89,9 +105,37 @@ pub trait TypeCheck: Sized { } impl<'ast> TypeCheck>> for Fir> { - fn type_check(self) -> Result>, Error> { + fn type_check(mut self) -> Result>, Error> { let primitives = primitives::find(&self)?; + let mut const_collector = ConstantCollector::new(); + const_collector.traverse(&self)?; + + // We can now build our primitive union types. Because the first TypeCtx deals + // with [`TypeVariable`]s, it's not possible to directly create a TypeSet - so + // we can do that later on during typechecking, right before the actual + // typechecking. An alternative is to modify the [`Fir`] directly by creating + // new nodes for these primitive unions, which is probably a little cleaner and + // less spaghetti. + let mk_constant_types = |set: HashSet| set.into_iter().collect(); + + self[primitives.int_type].kind = Kind::UnionType { + generics: vec![], + variants: mk_constant_types(const_collector.integers), + }; + self[primitives.char_type].kind = Kind::UnionType { + generics: vec![], + variants: mk_constant_types(const_collector.characters), + }; + self[primitives.string_type].kind = Kind::UnionType { + generics: vec![], + variants: mk_constant_types(const_collector.strings), + }; + self[primitives.bool_type].kind = Kind::UnionType { + generics: vec![], + variants: mk_constant_types(const_collector.bools), + }; + TypeCtx { primitives, types: HashMap::new(), diff --git a/typecheck/src/typemap.rs b/typecheck/src/typemap.rs index 5cfc25a1..98381b92 100644 --- a/typecheck/src/typemap.rs +++ b/typecheck/src/typemap.rs @@ -48,7 +48,7 @@ impl TypeMap { /// Insert a new type into the typemap pub fn new_type(&mut self, ty: Type) -> TypeRef { - let ty_ref = TypeRef(ty.0); + let ty_ref = TypeRef(ty.origin()); // FIXME: Is it actually okay to ignore if the type existed or not? self.types.insert(ty_ref, ty); diff --git a/typecheck/src/typer.rs b/typecheck/src/typer.rs index 61894a97..73a11a63 100644 --- a/typecheck/src/typer.rs +++ b/typecheck/src/typer.rs @@ -1,7 +1,6 @@ // TODO: Write module documentation // TODO: Does `Typer` take care of monomorphization as well? -use ast::{Node as AstNode, Value}; use error::Error; use fir::{Kind, Mapper, Node, OriginIdx, RefIdx}; use flatten::FlattenData; @@ -66,17 +65,17 @@ impl<'ast, 'ctx> Mapper, FlattenData<'ast>, Error> for Typer<' origin: OriginIdx, _constant: RefIdx, ) -> Result>, Error> { - let ast = data.ast.node(); - - let ty = match &ast.node { - // This does not take into account that primitives are multi types and will need to be fixed - AstNode::Constant(Value::Bool(_)) => self.0.primitives.bool_type, - AstNode::Constant(Value::Char(_)) => self.0.primitives.char_type, - AstNode::Constant(Value::Integer(_)) => self.0.primitives.int_type, - AstNode::Constant(Value::Float(_)) => self.0.primitives.float_type, - AstNode::Constant(Value::Str(_)) => self.0.primitives.string_type, - _ => unreachable!(), - }; + // let ast = data.ast.node(); + + // let ty = match &ast.node { + // This does not take into account that primitives are multi types and will need to be fixed + // AstNode::Constant(Value::Bool(_)) => self.0.primitives.bool_type, + // AstNode::Constant(Value::Char(_)) => self.0.primitives.char_type, + // AstNode::Constant(Value::Integer(_)) => self.0.primitives.int_type, + // AstNode::Constant(Value::Float(_)) => self.0.primitives.float_type, + // AstNode::Constant(Value::Str(_)) => self.0.primitives.string_type, + // _ => unreachable!(), + // }; // For constants, how will we look up the basic primitive type nodes before assigning them // here? Just a traversal and we do that based on name? Or will they need to be builtin at this point? @@ -87,13 +86,14 @@ impl<'ast, 'ctx> Mapper, FlattenData<'ast>, Error> for Typer<' // the proper primitive multitype. We need to implement this. // FIXME: How do we get a TypeReference here? Or should we actually do that operation in the checker? - let new_node = Node { + + self.assign_type(origin, TypeVariable::Record(origin)); + + Ok(Node { data, origin, - kind: Kind::Constant(RefIdx::Resolved(ty)), - }; - - self.ty(new_node, RefIdx::Resolved(ty)) + kind: Kind::Constant(RefIdx::Resolved(origin)), + }) } fn map_call( diff --git a/xparser/src/constructs.rs b/xparser/src/constructs.rs index a8f80197..ebf2a2d9 100644 --- a/xparser/src/constructs.rs +++ b/xparser/src/constructs.rs @@ -31,6 +31,7 @@ use location::Location; use location::SpanTuple; use nom::sequence::tuple; use nom::Err::Error as NomError; +use nom::Parser; use nom::Slice; use nom_locate::position; use symbol::Symbol; @@ -471,25 +472,53 @@ fn type_id(input: ParseInput) -> ParseResult { }, )) } else { - let (input, (id, (start_loc, end_loc))) = spaced_identifier(input)?; - let kind = TypeKind::Simple(Symbol::from(id)); - let (input, generics) = maybe_generic_application(input)?; + fn literal(input: ParseInput) -> ParseResult { + let (input, (location, kind)) = string_constant + .or(float_constant) + .or(int_constant) + .map(|literal| { + ( + literal.location.clone(), + TypeKind::Literal(Box::new(literal)), + ) + }) + .parse(input)?; + + Ok(( + input, + Type { + kind, + generics: vec![], + location, + }, + )) + } - // FIXME: Refactor - let end_loc = if !generics.is_empty() { - position(input)?.1.into() - } else { - end_loc - }; + fn type_symbol(input: ParseInput) -> ParseResult { + let (input, (kind, (start_loc, end_loc))) = spaced_identifier + .map(|(id, loc)| (TypeKind::Simple(Symbol::from(id)), loc)) + .parse(input)?; + + let (input, generics) = maybe_generic_application(input)?; + + // FIXME: Refactor + let end_loc = if !generics.is_empty() { + position(input)?.1.into() + } else { + end_loc + }; + + Ok(( + input, + Type { + kind, + generics, + location: pos_to_loc(input, start_loc, end_loc), + }, + )) + } - Ok(( - input, - Type { - kind, - generics, - location: pos_to_loc(input, start_loc, end_loc), - }, - )) + type_symbol.or(literal).parse(input) } } @@ -2295,4 +2324,25 @@ mod tests { TypeKind::FunctionLike(..) )); } + + #[test] + fn parse_module_type() { + let input = span!("type foo = source[\"foo\"]"); + + assert!(expr(input).is_ok()); + } + + #[test] + fn parse_literal_type_in_function_arg() { + let input = span!("func only_accepts_foo_15(a: \"foo\", b: 15) {}"); + + assert!(expr(input).is_ok()); + } + + #[test] + fn parse_literal_type_in_generic_application() { + let input = span!("func foo(a: Foo[\"foo\"], b: Bar[15, 15.2]) {}"); + + assert!(expr(input).is_ok()); + } }