Skip to content

Commit

Permalink
Support implementing trait functions in impl blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
cburgdorf committed Jun 14, 2022
1 parent dee378a commit 588483d
Show file tree
Hide file tree
Showing 48 changed files with 1,769 additions and 962 deletions.
27 changes: 23 additions & 4 deletions crates/analyzer/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::namespace::items::{Class, ContractId, DiagnosticSink, EventId, FunctionId, Item};
use crate::namespace::types::{SelfDecl, Struct, Type};
use crate::namespace::items::{
Class, ContractId, DiagnosticSink, EventId, FunctionId, FunctionSigId, Item, TraitId,
};
use crate::namespace::types::{Generic, SelfDecl, Struct, Type};
use crate::AnalyzerDb;
use crate::{
builtins::{ContractTypeMethod, GlobalFunction, Intrinsic, ValueMethod},
Expand Down Expand Up @@ -361,7 +363,12 @@ impl Location {
pub fn assign_location(typ: &Type) -> Self {
match typ {
Type::Base(_) | Type::Contract(_) => Location::Value,
Type::Array(_) | Type::Tuple(_) | Type::String(_) | Type::Struct(_) => Location::Memory,
// For now assume that generics can only ever refer to structs
Type::Array(_)
| Type::Tuple(_)
| Type::String(_)
| Type::Struct(_)
| Type::Generic(_) => Location::Memory,
other => panic!("Type {other} can not be assigned, returned or passed"),
}
}
Expand Down Expand Up @@ -458,11 +465,21 @@ pub enum CallType {
class: Class,
function: FunctionId,
},
// some_struct_or_contract.foo()
ValueMethod {
is_self: bool,
class: Class,
method: FunctionId,
},
// some_trait.foo()
// The reason this can not use `ValueMethod` is mainly because the trait might not have a function implementation
// and even if it had it might not be the one that ends up getting executed. An `impl` block will decide that.
TraitValueMethod {
trait_id: TraitId,
method: FunctionSigId,
// Traits can not directly be used as types but can act as bounds for generics. This is the generic type
// that the method is called on.
generic_type: Generic,
},
External {
contract: ContractId,
function: FunctionId,
Expand All @@ -479,6 +496,7 @@ impl CallType {
| BuiltinValueMethod { .. }
| TypeConstructor(_)
| Intrinsic(_)
| TraitValueMethod { .. }
| BuiltinAssociatedFunction { .. } => None,
AssociatedFunction { function: id, .. }
| ValueMethod { method: id, .. }
Expand All @@ -497,6 +515,7 @@ impl CallType {
| CallType::ValueMethod { method: id, .. }
| CallType::External { function: id, .. }
| CallType::Pure(id) => id.name(db),
CallType::TraitValueMethod { method: id, .. } => id.name(db),
CallType::TypeConstructor(typ) => typ.name(),
}
}
Expand Down
22 changes: 18 additions & 4 deletions crates/analyzer/src/db.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::context::{Analysis, Constant, FunctionBody};
use crate::errors::{ConstEvalError, TypeError};
use crate::namespace::items::{
self, ContractFieldId, ContractId, DepGraphWrapper, EventId, FunctionId, Impl, IngotId, Item,
ModuleConstantId, ModuleId, StructFieldId, StructId, TraitId, TypeAliasId,
self, ContractFieldId, ContractId, DepGraphWrapper, EventId, FunctionId, FunctionSigId, ImplId,
IngotId, Item, ModuleConstantId, ModuleId, StructFieldId, StructId, TraitId, TypeAliasId,
};
use crate::namespace::types;
use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut};
Expand All @@ -26,6 +26,8 @@ pub trait AnalyzerDb: SourceDb + Upcast<dyn SourceDb> + UpcastMut<dyn SourceDb>
#[salsa::interned]
fn intern_trait(&self, data: Rc<items::Trait>) -> TraitId;
#[salsa::interned]
fn intern_impl(&self, data: Rc<items::Impl>) -> ImplId;
#[salsa::interned]
fn intern_struct_field(&self, data: Rc<items::StructField>) -> StructFieldId;
#[salsa::interned]
fn intern_type_alias(&self, data: Rc<items::TypeAlias>) -> TypeAliasId;
Expand All @@ -34,6 +36,8 @@ pub trait AnalyzerDb: SourceDb + Upcast<dyn SourceDb> + UpcastMut<dyn SourceDb>
#[salsa::interned]
fn intern_contract_field(&self, data: Rc<items::ContractField>) -> ContractFieldId;
#[salsa::interned]
fn intern_function_sig(&self, data: Rc<items::FunctionSig>) -> FunctionSigId;
#[salsa::interned]
fn intern_function(&self, data: Rc<items::Function>) -> FunctionId;
#[salsa::interned]
fn intern_event(&self, data: Rc<items::Event>) -> EventId;
Expand Down Expand Up @@ -63,7 +67,7 @@ pub trait AnalyzerDb: SourceDb + Upcast<dyn SourceDb> + UpcastMut<dyn SourceDb>
#[salsa::invoke(queries::module::module_all_items)]
fn module_all_items(&self, module: ModuleId) -> Rc<[Item]>;
#[salsa::invoke(queries::module::module_all_impls)]
fn module_all_impls(&self, module: ModuleId) -> Rc<[Impl]>;
fn module_all_impls(&self, module: ModuleId) -> Rc<[ImplId]>;
#[salsa::invoke(queries::module::module_item_map)]
fn module_item_map(&self, module: ModuleId) -> Analysis<Rc<IndexMap<SmolStr, Item>>>;
#[salsa::invoke(queries::module::module_contracts)]
Expand Down Expand Up @@ -134,7 +138,7 @@ pub trait AnalyzerDb: SourceDb + Upcast<dyn SourceDb> + UpcastMut<dyn SourceDb>

// Function
#[salsa::invoke(queries::functions::function_signature)]
fn function_signature(&self, id: FunctionId) -> Analysis<Rc<types::FunctionSignature>>;
fn function_signature(&self, id: FunctionSigId) -> Analysis<Rc<types::FunctionSignature>>;
#[salsa::invoke(queries::functions::function_body)]
fn function_body(&self, id: FunctionId) -> Analysis<Rc<FunctionBody>>;
#[salsa::cycle(queries::functions::function_dependency_graph_cycle)]
Expand All @@ -161,6 +165,16 @@ pub trait AnalyzerDb: SourceDb + Upcast<dyn SourceDb> + UpcastMut<dyn SourceDb>
// Trait
#[salsa::invoke(queries::traits::trait_type)]
fn trait_type(&self, id: TraitId) -> Rc<types::Trait>;
#[salsa::invoke(queries::traits::trait_all_functions)]
fn trait_all_functions(&self, id: TraitId) -> Rc<[FunctionSigId]>;
#[salsa::invoke(queries::traits::trait_function_map)]
fn trait_function_map(&self, id: TraitId) -> Analysis<Rc<IndexMap<SmolStr, FunctionSigId>>>;

// Impl
#[salsa::invoke(queries::impls::impl_all_functions)]
fn impl_all_functions(&self, id: ImplId) -> Rc<[FunctionId]>;
#[salsa::invoke(queries::impls::impl_function_map)]
fn impl_function_map(&self, id: ImplId) -> Analysis<Rc<IndexMap<SmolStr, FunctionId>>>;

// Event
#[salsa::invoke(queries::events::event_type)]
Expand Down
1 change: 1 addition & 0 deletions crates/analyzer/src/db/queries.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod contracts;
pub mod events;
pub mod functions;
pub mod impls;
pub mod ingots;
pub mod module;
pub mod structs;
Expand Down
14 changes: 5 additions & 9 deletions crates/analyzer/src/db/queries/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,9 @@ pub fn contract_all_functions(db: &dyn AnalyzerDb, contract: ContractId) -> Rc<[
body.iter()
.filter_map(|stmt| match stmt {
ast::ContractStmt::Event(_) => None,
ast::ContractStmt::Function(node) => {
Some(db.intern_function(Rc::new(items::Function {
ast: node.clone(),
module,
parent: Some(items::Class::Contract(contract)),
})))
}
ast::ContractStmt::Function(node) => Some(db.intern_function(Rc::new(
items::Function::new(db, node, Some(items::Class::Contract(contract)), module),
))),
})
.collect()
}
Expand All @@ -53,7 +49,7 @@ pub fn contract_function_map(
def_name,
&NamedThing::Item(Item::Event(event)),
Some(event.name_span(db)),
def.kind.name.span,
def.kind.sig.kind.name.span,
);
continue;
}
Expand All @@ -64,7 +60,7 @@ pub fn contract_function_map(
def_name,
&named_item,
named_item.name_span(db),
def.kind.name.span,
def.kind.sig.kind.name.span,
);
continue;
}
Expand Down
32 changes: 20 additions & 12 deletions crates/analyzer/src/db/queries/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::context::{AnalyzerContext, CallType, FunctionBody};
use crate::db::{Analysis, AnalyzerDb};
use crate::errors::TypeError;
use crate::namespace::items::{
Class, DepGraph, DepGraphWrapper, DepLocality, FunctionId, Item, TypeDef,
Class, DepGraph, DepGraphWrapper, DepLocality, FunctionId, FunctionSigId, Item, TypeDef,
};
use crate::namespace::scopes::{BlockScope, BlockScopeType, FunctionScope, ItemScope};
use crate::namespace::types::{self, Contract, CtxDecl, Generic, SelfDecl, Struct, Type};
Expand All @@ -19,10 +19,9 @@ use std::rc::Rc;
/// errors. Does not inspect the function body.
pub fn function_signature(
db: &dyn AnalyzerDb,
function: FunctionId,
function: FunctionSigId,
) -> Analysis<Rc<types::FunctionSignature>> {
let node = &function.data(db).ast;
let def = &node.kind;
let def = &function.data(db).ast;

let mut scope = ItemScope::new(db, function.module(db));
let fn_parent = function.class(db);
Expand All @@ -44,6 +43,7 @@ pub fn function_signature(
}

let params = def
.kind
.args
.iter()
.enumerate()
Expand Down Expand Up @@ -114,9 +114,9 @@ pub fn function_signature(
if label.kind != "_";
if let Some(dup_idx) = labels.get(&label.kind);
then {
let dup_arg: &Node<ast::FunctionArg> = &def.args[*dup_idx];
let dup_arg: &Node<ast::FunctionArg> = &def.kind.args[*dup_idx];
scope.fancy_error(
&format!("duplicate parameter labels in function `{}`", def.name.kind),
&format!("duplicate parameter labels in function `{}`", def.kind.name.kind),
vec![
Label::primary(dup_arg.span, "the label `{}` was first used here"),
Label::primary(label.span, "label `{}` used again here"),
Expand All @@ -138,9 +138,9 @@ pub fn function_signature(
);
None
} else if let Some(dup_idx) = names.get(&reg.name.kind) {
let dup_arg: &Node<ast::FunctionArg> = &def.args[*dup_idx];
let dup_arg: &Node<ast::FunctionArg> = &def.kind.args[*dup_idx];
scope.duplicate_name_error(
&format!("duplicate parameter names in function `{}`", def.name.kind),
&format!("duplicate parameter names in function `{}`", function.name(db)),
&reg.name.kind,
dup_arg.span,
arg.span,
Expand All @@ -159,10 +159,11 @@ pub fn function_signature(
.collect();

let return_type = def
.kind
.return_type
.as_ref()
.map(|type_node| {
let fn_name = &def.name.kind;
let fn_name = &function.name(db);
if fn_name == "__init__" || fn_name == "__call__" {
// `__init__` and `__call__` must not return any type other than `()`.
if type_node.kind != ast::TypeDesc::Unit {
Expand Down Expand Up @@ -207,7 +208,7 @@ pub fn function_signature(

pub fn resolve_function_param_type(
db: &dyn AnalyzerDb,
function: FunctionId,
function: FunctionSigId,
context: &mut dyn AnalyzerContext,
desc: &Node<ast::TypeDesc>,
) -> Result<Type, TypeError> {
Expand Down Expand Up @@ -254,11 +255,11 @@ pub fn function_body(db: &dyn AnalyzerDb, function: FunctionId) -> Analysis<Rc<F
"function body is missing a return or revert statement",
vec![
Label::primary(
def.name.span,
function.name_span(db),
"all paths of this function must `return` or `revert`",
),
Label::secondary(
def.return_type.as_ref().unwrap().span,
def.sig.kind.return_type.as_ref().unwrap().span,
format!("expected function to return `{}`", return_type),
),
],
Expand Down Expand Up @@ -362,6 +363,13 @@ pub fn function_dependency_graph(db: &dyn AnalyzerDb, function: FunctionId) -> D
directs.push((root, class.as_item(), DepLocality::Local));
directs.push((root, Item::Function(*method), DepLocality::Local));
}
CallType::TraitValueMethod { trait_id, .. } => {
directs.push((
root,
Item::Type(TypeDef::Trait(*trait_id)),
DepLocality::Local,
));
}
CallType::External { contract, function } => {
directs.push((root, Item::Function(*function), DepLocality::External));
// Probably redundant:
Expand Down
57 changes: 57 additions & 0 deletions crates/analyzer/src/db/queries/impls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use indexmap::map::Entry;
use indexmap::IndexMap;
use smol_str::SmolStr;

use crate::context::{Analysis, AnalyzerContext};
use crate::namespace::items::{Function, FunctionId, ImplId};
use crate::namespace::scopes::ItemScope;
use crate::namespace::types::TypeDowncast;
use crate::AnalyzerDb;
use std::rc::Rc;

pub fn impl_all_functions(db: &dyn AnalyzerDb, impl_: ImplId) -> Rc<[FunctionId]> {
let impl_data = impl_.data(db);
impl_data
.ast
.kind
.functions
.iter()
.map(|node| {
db.intern_function(Rc::new(Function::new(
db,
node,
// Not sure if setting the receiver as parent is the right thing to do. We currently do this
// so that the generated name of the YUL function will take in the receiver and avoids name collisions.
impl_.receiver(db).as_class(),
impl_data.module,
)))
})
.collect()
}

pub fn impl_function_map(
db: &dyn AnalyzerDb,
impl_: ImplId,
) -> Analysis<Rc<IndexMap<SmolStr, FunctionId>>> {
let scope = ItemScope::new(db, impl_.module(db));
let mut map = IndexMap::<SmolStr, FunctionId>::new();

for func in db.impl_all_functions(impl_).iter() {
let def_name = func.name(db);

match map.entry(def_name) {
Entry::Occupied(entry) => {
scope.duplicate_name_error(
"duplicate function names in `impl` block",
entry.key(),
entry.get().name_span(db),
func.name_span(db),
);
}
Entry::Vacant(entry) => {
entry.insert(*func);
}
}
}
Analysis::new(Rc::new(map), scope.diagnostics.take().into())
}
20 changes: 8 additions & 12 deletions crates/analyzer/src/db/queries/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::context::{Analysis, AnalyzerContext, Constant};
use crate::db::AnalyzerDb;
use crate::errors::{self, ConstEvalError, TypeError};
use crate::namespace::items::{
Contract, ContractId, Event, Function, Impl, Item, ModuleConstant, ModuleConstantId, ModuleId,
ModuleSource, Struct, StructId, Trait, TypeAlias, TypeDef,
Contract, ContractId, Event, Function, Impl, ImplId, Item, ModuleConstant, ModuleConstantId,
ModuleId, ModuleSource, Struct, StructId, Trait, TypeAlias, TypeDef,
};
use crate::namespace::scopes::ItemScope;
use crate::namespace::types::{self, Type};
Expand Down Expand Up @@ -86,13 +86,9 @@ pub fn module_all_items(db: &dyn AnalyzerDb, module: ModuleId) -> Rc<[Item]> {
module,
}),
))),
ast::ModuleStmt::Function(node) => {
Some(Item::Function(db.intern_function(Rc::new(Function {
ast: node.clone(),
module,
parent: None,
}))))
}
ast::ModuleStmt::Function(node) => Some(Item::Function(
db.intern_function(Rc::new(Function::new(db, node, None, module))),
)),
ast::ModuleStmt::Trait(node) => Some(Item::Type(TypeDef::Trait(db.intern_trait(
Rc::new(Trait {
ast: node.clone(),
Expand All @@ -110,7 +106,7 @@ pub fn module_all_items(db: &dyn AnalyzerDb, module: ModuleId) -> Rc<[Item]> {
.collect()
}

pub fn module_all_impls(db: &dyn AnalyzerDb, module: ModuleId) -> Rc<[Impl]> {
pub fn module_all_impls(db: &dyn AnalyzerDb, module: ModuleId) -> Rc<[ImplId]> {
let body = &module.ast(db).body;
body.iter()
.filter_map(|stmt| match stmt {
Expand All @@ -124,12 +120,12 @@ pub fn module_all_impls(db: &dyn AnalyzerDb, module: ModuleId) -> Rc<[Impl]> {
let receiver_type = type_desc(&mut scope, &impl_node.kind.receiver).unwrap();

if let Some(Item::Type(TypeDef::Trait(val))) = treit {
Some(Impl {
Some(db.intern_impl(Rc::new(Impl {
trait_id: val,
receiver: receiver_type,
ast: impl_node.clone(),
module,
})
})))
} else {
None
}
Expand Down
Loading

0 comments on commit 588483d

Please sign in to comment.