Skip to content

Commit

Permalink
Contract init_fn, public_fns queries. Reject calls to init, non-pub f…
Browse files Browse the repository at this point in the history
…ns on external contracts
  • Loading branch information
sbillig committed Aug 10, 2021
1 parent 61bbb2f commit 31b1e72
Show file tree
Hide file tree
Showing 19 changed files with 329 additions and 113 deletions.
91 changes: 44 additions & 47 deletions crates/abi/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::elements::{
ModuleAbis,
};
use crate::AbiError;
use fe_analyzer::namespace::items::{ContractId, ModuleId};
use fe_analyzer::namespace::items::{ContractId, FunctionId, ModuleId};
use fe_analyzer::namespace::types;
use fe_analyzer::AnalyzerDb;

Expand Down Expand Up @@ -50,55 +50,52 @@ fn contract_def(db: &dyn AnalyzerDb, contract: ContractId) -> Contract {
})
.collect();

let functions = contract
.functions(db)
let mut functions = contract
.public_functions(db)
.iter()
.filter_map(|(name, func)| {
if func.is_public(db) {
let sig = func.signature(db);
let (name, func_type) = match name.as_str() {
"__init__" => ("".to_owned(), FuncType::Constructor),
_ => (name.to_owned(), FuncType::Function),
};
.map(|(name, func)| function_def(db, name, *func, FuncType::Function))
.collect::<Vec<_>>();

let inputs = sig
.params
.iter()
.map(|param| {
let typ = param.typ.clone().expect("function parameter type error");
if let Some(init_fn) = contract.init_function(db) {
functions.push(function_def(db, "", init_fn, FuncType::Constructor));
}

FuncInput {
name: param.name.to_owned(),
typ: typ.abi_json_name(),
components: components(db, &typ),
}
})
.collect();
Contract { events, functions }
}

let return_type = sig.return_type.clone().expect("function return type error");
let outputs = if return_type.is_unit() {
vec![]
} else {
vec![FuncOutput {
name: "".to_string(),
typ: return_type.abi_json_name(),
components: components(db, &return_type),
}]
};
fn function_def(db: &dyn AnalyzerDb, name: &str, fn_id: FunctionId, typ: FuncType) -> Function {
let sig = fn_id.signature(db);
let inputs = sig
.params
.iter()
.map(|param| {
let typ = param.typ.clone().expect("function parameter type error");

Some(Function {
name,
typ: func_type,
inputs,
outputs,
})
} else {
None
FuncInput {
name: param.name.to_owned(),
typ: typ.abi_json_name(),
components: components(db, &typ),
}
})
.collect();

Contract { events, functions }
let return_type = sig.return_type.clone().expect("function return type error");
let outputs = if return_type.is_unit() {
vec![]
} else {
vec![FuncOutput {
name: "".to_string(),
typ: return_type.abi_json_name(),
components: components(db, &return_type),
}]
};

Function {
name: name.to_string(),
typ,
inputs,
outputs,
}
}

fn components(db: &dyn AnalyzerDb, typ: &types::FixedSize) -> Vec<Component> {
Expand Down Expand Up @@ -157,13 +154,13 @@ mod tests {
assert_eq!(abi.events[0].name, "Food");
// function count
assert_eq!(abi.functions.len(), 2);
// __init__
assert_eq!(abi.functions[0].name, "");
assert_eq!(abi.functions[0].inputs[0].typ, "address",);
// bar
assert_eq!(abi.functions[1].name, "bar",);
assert_eq!(abi.functions[1].inputs[0].typ, "uint256",);
assert_eq!(abi.functions[1].outputs[0].typ, "uint256[10]",);
assert_eq!(abi.functions[0].name, "bar",);
assert_eq!(abi.functions[0].inputs[0].typ, "uint256",);
assert_eq!(abi.functions[0].outputs[0].typ, "uint256[10]",);
// __init__ always comes after normal functions
assert_eq!(abi.functions[1].name, "");
assert_eq!(abi.functions[1].inputs[0].typ, "address",);
} else {
panic!("contract \"Foo\" not found in module")
}
Expand Down
12 changes: 12 additions & 0 deletions crates/analyzer/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub trait AnalyzerDb {
#[salsa::interned]
fn intern_event(&self, data: Rc<items::Event>) -> EventId;

// Module
#[salsa::invoke(queries::module::module_all_type_defs)]
fn module_all_type_defs(&self, module: ModuleId) -> Rc<Vec<TypeDefId>>;
#[salsa::invoke(queries::module::module_type_def_map)]
Expand All @@ -61,14 +62,21 @@ pub trait AnalyzerDb {
#[salsa::invoke(queries::module::module_structs)]
fn module_structs(&self, module: ModuleId) -> Rc<Vec<StructId>>;

// Contract
#[salsa::invoke(queries::contracts::contract_all_functions)]
fn contract_all_functions(&self, id: ContractId) -> Rc<Vec<FunctionId>>;
#[salsa::invoke(queries::contracts::contract_function_map)]
fn contract_function_map(&self, id: ContractId) -> Analysis<Rc<IndexMap<String, FunctionId>>>;
#[salsa::invoke(queries::contracts::contract_public_function_map)]
fn contract_public_function_map(&self, id: ContractId) -> Rc<IndexMap<String, FunctionId>>;
#[salsa::invoke(queries::contracts::contract_init_function)]
fn contract_init_function(&self, id: ContractId) -> Analysis<Option<FunctionId>>;

#[salsa::invoke(queries::contracts::contract_all_events)]
fn contract_all_events(&self, id: ContractId) -> Rc<Vec<EventId>>;
#[salsa::invoke(queries::contracts::contract_event_map)]
fn contract_event_map(&self, id: ContractId) -> Analysis<Rc<IndexMap<String, EventId>>>;

#[salsa::invoke(queries::contracts::contract_all_fields)]
fn contract_all_fields(&self, id: ContractId) -> Rc<Vec<ContractFieldId>>;
#[salsa::invoke(queries::contracts::contract_field_map)]
Expand All @@ -80,11 +88,13 @@ pub trait AnalyzerDb {
field: ContractFieldId,
) -> Analysis<Result<types::Type, TypeError>>;

// Function
#[salsa::invoke(queries::functions::function_signature)]
fn function_signature(&self, id: FunctionId) -> Analysis<Rc<types::FunctionSignature>>;
#[salsa::invoke(queries::functions::function_body)]
fn function_body(&self, id: FunctionId) -> Analysis<Rc<FunctionBody>>;

// Struct
#[salsa::invoke(queries::structs::struct_type)]
fn struct_type(&self, id: StructId) -> Rc<types::Struct>;
#[salsa::invoke(queries::structs::struct_all_fields)]
Expand All @@ -97,9 +107,11 @@ pub trait AnalyzerDb {
field: StructFieldId,
) -> Analysis<Result<types::FixedSize, TypeError>>;

// Event
#[salsa::invoke(queries::events::event_type)]
fn event_type(&self, event: EventId) -> Analysis<Rc<types::Event>>;

// Type alias
#[salsa::invoke(queries::types::type_alias_type)]
#[salsa::cycle(queries::types::type_alias_type_cycle)]
fn type_alias_type(&self, id: TypeAliasId) -> Analysis<Result<types::Type, TypeError>>;
Expand Down
72 changes: 70 additions & 2 deletions crates/analyzer/src/db/queries/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::namespace::items::{self, ContractFieldId, ContractId, EventId, Functi
use crate::namespace::scopes::ItemScope;
use crate::namespace::types;
use crate::traversal::types::type_desc;
use fe_common::diagnostics::Label;
use fe_parser::ast;
use indexmap::map::{Entry, IndexMap};

Expand Down Expand Up @@ -36,9 +37,11 @@ pub fn contract_function_map(

for func in contract.all_functions(db).iter() {
let def = &func.data(db).ast;
let name = def.name().to_string();
if def.name() == "__init__" {
continue;
}

match map.entry(name) {
match map.entry(def.name().to_string()) {
Entry::Occupied(entry) => {
scope.duplicate_name_error(
&format!(
Expand All @@ -61,6 +64,71 @@ pub fn contract_function_map(
}
}

pub fn contract_public_function_map(
db: &dyn AnalyzerDb,
contract: ContractId,
) -> Rc<IndexMap<String, FunctionId>> {
Rc::new(
contract
.functions(db)
.iter()
.filter_map(|(name, func)| func.is_public(db).then(|| (name.clone(), *func)))
.collect(),
)
}

pub fn contract_init_function(
db: &dyn AnalyzerDb,
contract: ContractId,
) -> Analysis<Option<FunctionId>> {
let all_fns = contract.all_functions(db);
let mut init_fns = all_fns.iter().filter_map(|func| {
let def = &func.data(db).ast;
(def.name() == "__init__").then(|| (func, def.span))
});

let mut diagnostics = vec![];

let first_def = init_fns.next();
if let Some((_, dupe_span)) = init_fns.next() {
let mut labels = vec![
Label::primary(first_def.unwrap().1, "`__init__` first defined here"),
Label::secondary(dupe_span, "`init` redefined here"),
];
while let Some((_, dupe_span)) = init_fns.next() {
labels.push(Label::secondary(dupe_span, "`init` redefined here"));
}
diagnostics.push(errors::fancy_error(
&format!(
"`fn __init__()` is defined multiple times in `contract {}`",
contract.name(db),
),
labels,
vec![],
));
}

if let Some((id, span)) = first_def {
// `__init__` must be `pub`.
// Return type is checked in `queries::functions::function_signature`.
if !id.data(db).ast.kind.is_pub {
diagnostics.push(errors::fancy_error(
"`__init__` function is not public",
vec![Label::primary(span, "`__init__` function must be public")],
vec![
"Hint: Add the `pub` modifier.".to_string(),
"Example: `pub fn __init__():`".to_string(),
],
));
}
}

Analysis {
value: first_def.map(|(id, _span)| *id),
diagnostics: Rc::new(diagnostics),
}
}

pub fn contract_all_events(db: &dyn AnalyzerDb, contract: ContractId) -> Rc<Vec<EventId>> {
let body = &contract.data(db).ast.kind.body;
Rc::new(
Expand Down
15 changes: 0 additions & 15 deletions crates/analyzer/src/db/queries/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,6 @@ pub fn function_signature(
})
.collect();

// `__init__` must be `pub`.
if def.name.kind == "__init__" && !def.is_pub {
scope.fancy_error(
"`__init__` function is not public",
vec![Label::primary(
node.span,
"`__init__` function must be public",
)],
vec![
"Hint: Add the `pub` modifier.".to_string(),
"Example: `pub fn __init__():`".to_string(),
],
);
}

let return_type = def
.return_type
.as_ref()
Expand Down
27 changes: 23 additions & 4 deletions crates/analyzer/src/namespace/items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,28 +218,46 @@ impl ContractId {
Some((field.typ(db), index))
}

pub fn function(&self, db: &dyn AnalyzerDb, name: &str) -> Option<FunctionId> {
self.functions(db).get(name).copied()
pub fn init_function(&self, db: &dyn AnalyzerDb) -> Option<FunctionId> {
db.contract_init_function(*self).value
}

/// User functions, public and not. Excludes `__init__`.
pub fn functions(&self, db: &dyn AnalyzerDb) -> Rc<IndexMap<String, FunctionId>> {
db.contract_function_map(*self).value
}

/// All functions, including duplicates
/// Lookup a function by name. Searches all user functions, private or not. Excludes init function.
pub fn function(&self, db: &dyn AnalyzerDb, name: &str) -> Option<FunctionId> {
self.functions(db).get(name).copied()
}

/// Excludes `__init__`.
pub fn public_functions(&self, db: &dyn AnalyzerDb) -> Rc<IndexMap<String, FunctionId>> {
db.contract_public_function_map(*self)
}

/// Lookup a function by name. Matches on public and private functions, excludes init function.
pub fn public_function(&self, db: &dyn AnalyzerDb, name: &str) -> Option<FunctionId> {
self.public_functions(db).get(name).copied()
}

/// A `Vec` of every function defined in the contract, including duplicates and the init function.
pub fn all_functions(&self, db: &dyn AnalyzerDb) -> Rc<Vec<FunctionId>> {
db.contract_all_functions(*self)
}

/// Lookup an event by name.
pub fn event(&self, db: &dyn AnalyzerDb, name: &str) -> Option<EventId> {
self.events(db).get(name).copied()
}

/// A map of events defined within the contract.
pub fn events(&self, db: &dyn AnalyzerDb) -> Rc<IndexMap<String, EventId>> {
db.contract_event_map(*self).value
}

/// All events, including duplicates
/// A `Vec` of all events defined within the contract, including those with duplicate names.
pub fn all_events(&self, db: &dyn AnalyzerDb) -> Rc<Vec<EventId>> {
db.contract_all_events(*self)
}
Expand All @@ -258,6 +276,7 @@ impl ContractId {
.for_each(|event| event.sink_diagnostics(db, sink));

// functions
db.contract_init_function(*self).sink_diagnostics(sink);
db.contract_function_map(*self).sink_diagnostics(sink);
db.contract_all_functions(*self)
.iter()
Expand Down
Loading

0 comments on commit 31b1e72

Please sign in to comment.