diff --git a/crates/lowering/src/mappers/functions.rs b/crates/lowering/src/mappers/functions.rs index 048d9745c5..b3303be097 100644 --- a/crates/lowering/src/mappers/functions.rs +++ b/crates/lowering/src/mappers/functions.rs @@ -22,6 +22,7 @@ pub fn func_def(context: &mut ModuleContext, function: FunctionId) -> Node Node ast::Function { unsafe_: None, name: names::list_expr_generator_fn_name(array).into_node(), args, + generic_params: Vec::new().into_node(), return_type, body: [vec![var_decl], assignments, vec![return_stmt]].concat(), } diff --git a/crates/parser/src/ast.rs b/crates/parser/src/ast.rs index 639ac2a485..ebaaa5b053 100644 --- a/crates/parser/src/ast.rs +++ b/crates/parser/src/ast.rs @@ -130,6 +130,24 @@ impl Spanned for GenericArg { } } +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] +pub enum GenericParameter { + Unbounded(Node), + Bounded { + name: Node, + bound: Node, + }, +} + +impl Spanned for GenericParameter { + fn span(&self) -> Span { + match self { + GenericParameter::Unbounded(node) => node.span, + GenericParameter::Bounded { name, bound } => name.span + bound.span, + } + } +} + /// struct or contract field, with optional 'pub' and 'const' qualifiers #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] pub struct Field { @@ -160,6 +178,7 @@ pub struct Function { pub pub_: Option, pub unsafe_: Option, pub name: Node, + pub generic_params: Node>, pub args: Vec>, pub return_type: Option>, pub body: Vec>, diff --git a/crates/parser/src/grammar/functions.rs b/crates/parser/src/grammar/functions.rs index 99ef6085b2..df956557a4 100644 --- a/crates/parser/src/grammar/functions.rs +++ b/crates/parser/src/grammar/functions.rs @@ -2,11 +2,11 @@ use super::expressions::{parse_call_args, parse_expr}; use super::types::parse_type_desc; use crate::ast::{ - BinOperator, Expr, FuncStmt, Function, FunctionArg, RegularFunctionArg, VarDeclTarget, + BinOperator, Expr, FuncStmt, Function, FunctionArg, GenericParameter, RegularFunctionArg, + VarDeclTarget, }; -use crate::lexer::TokenKind; use crate::node::{Node, Span}; -use crate::{Label, ParseFailed, ParseResult, Parser}; +use crate::{Label, ParseFailed, ParseResult, Parser, TokenKind}; /// Parse a function definition. The optional `pub` qualifier must be parsed by /// the caller, and passed in. Next token must be `unsafe` or `fn`. @@ -28,7 +28,14 @@ pub fn parse_fn_def(par: &mut Parser, mut pub_qual: Option) -> ParseResult } let fn_tok = par.expect(TokenKind::Fn, "failed to parse function definition")?; let name = par.expect(TokenKind::Name, "failed to parse function definition")?; - let mut span = fn_tok.span + unsafe_qual + pub_qual + name.span; + + let generic_params = if par.peek() == Some(TokenKind::Lt) { + parse_generic_params(par)? + } else { + Node::new(vec![], name.span) + }; + + let mut span = fn_tok.span + unsafe_qual + pub_qual + name.span + generic_params.span; let args = match par.peek_or_err()? { TokenKind::ParenOpen => { @@ -87,6 +94,7 @@ pub fn parse_fn_def(par: &mut Parser, mut pub_qual: Option) -> ParseResult unsafe_: unsafe_qual, name: name.into(), args, + generic_params, return_type, body, }, @@ -94,6 +102,87 @@ pub fn parse_fn_def(par: &mut Parser, mut pub_qual: Option) -> ParseResult )) } +pub fn parse_generic_param(par: &mut Parser) -> ParseResult { + use TokenKind::*; + + let name = par.assert(Name); + match par.optional(Colon) { + Some(_) => { + let bound = par.assert(Name); + return Ok(GenericParameter::Bounded { + name: Node::new(name.text.into(), name.span), + bound: Node::new(bound.text.into(), bound.span), + }); + } + None => { + return Ok(GenericParameter::Unbounded(Node::new( + name.text.into(), + name.span, + ))) + } + } +} + +/// Parse an angle-bracket-wrapped list of generic arguments (eg. the part wrapped in angle brackets +/// of `fn foo(some_arg: u256) -> bool`). +/// # Panics +/// Panics if the first token isn't `<`. +pub fn parse_generic_params(par: &mut Parser) -> ParseResult>> { + use TokenKind::*; + let mut span = par.assert(Lt).span; + + let mut args = vec![]; + + let expect_end = |par: &mut Parser| { + // If there's no comma, the next token must be `>` + match par.peek_or_err()? { + Gt => Ok(par.next()?.span), + _ => { + let tok = par.next()?; + par.unexpected_token_error( + tok.span, + "Unexpected token while parsing generic arg list", + vec![], + ); + Err(ParseFailed) + } + } + }; + + + + loop { + match par.peek_or_err()? { + Gt => { + span += par.next()?.span; + break; + } + Name => { + let typ = parse_generic_param(par)?; + args.push(typ); + if par.peek() == Some(Comma) { + par.next()?; + } else { + span += expect_end(par)?; + break; + } + } + + // Invalid generic argument. + _ => { + let tok = par.next()?; + par.unexpected_token_error( + tok.span, + "failed to parse generic type argument list", + vec![], + ); + return Err(ParseFailed); + } + } + } + Ok(Node::new(args, span)) +} + fn parse_fn_param_list(par: &mut Parser) -> ParseResult>>> { let mut span = par.assert(TokenKind::ParenOpen).span; let mut params = vec![]; diff --git a/crates/parser/src/grammar/traits.rs b/crates/parser/src/grammar/traits.rs index 7c397279e6..147367aceb 100644 --- a/crates/parser/src/grammar/traits.rs +++ b/crates/parser/src/grammar/traits.rs @@ -6,10 +6,7 @@ use crate::{ParseFailed, ParseResult, Parser, TokenKind}; /// Parse a trait definition. /// # Panics /// Panics if the next token isn't `trait`. -pub fn parse_trait_def( - par: &mut Parser, - trait_pub_qual: Option, -) -> ParseResult> { +pub fn parse_trait_def(par: &mut Parser, trait_pub_qual: Option) -> ParseResult> { let trait_tok = par.assert(TokenKind::Trait); // trait Event: @@ -26,7 +23,6 @@ pub fn parse_trait_def( par.enter_block(header_span, "trait definition")?; loop { - match par.peek() { Some(TokenKind::Pass) => { parse_single_word_stmt(par)?; diff --git a/crates/parser/tests/cases/parse_ast.rs b/crates/parser/tests/cases/parse_ast.rs index 443342719a..c4fe39fa47 100644 --- a/crates/parser/tests/cases/parse_ast.rs +++ b/crates/parser/tests/cases/parse_ast.rs @@ -141,6 +141,8 @@ test_parse! { type_tuple, types::parse_type_desc, "(u8, u16, address, Map bool:\n false"} + +test_parse! { fn_def_generic, try_parse_module, "fn transfer(from sender: address, to recip: address, _ val: u64) -> bool:\n false"} test_parse! { fn_def_pub, try_parse_module, "pub fn foo21(x: bool, y: address,) -> bool:\n x"} test_parse! { fn_def_unsafe, try_parse_module, "unsafe fn foo21(x: bool, y: address,) -> bool:\n x"} test_parse! { fn_def_pub_unsafe, try_parse_module, "pub unsafe fn foo21(x: bool, y: address,) -> bool:\n x"} diff --git a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def.snap b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def.snap index ebed258f4b..24be74401c 100644 --- a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def.snap +++ b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def.snap @@ -1,6 +1,6 @@ --- source: crates/parser/tests/cases/parse_ast.rs -expression: "ast_string(stringify!(fn_def), try_parse_module,\n \"fn transfer(from sender: address, to recip: address, _ val: u64) -> bool:\\n false\")" +expression: "ast_string(stringify!(fn_def), try_parse_module,\n \"fn transfer(from sender: address, to recip: address, _ val: u64) -> bool:\\n false\")" --- Node( @@ -17,6 +17,13 @@ Node( end: 11, ), ), + generic_params: Node( + kind: [], + span: Span( + start: 3, + end: 11, + ), + ), args: [ Node( kind: Regular(RegularFunctionArg( diff --git a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_generic.snap b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_generic.snap new file mode 100644 index 0000000000..a28838a913 --- /dev/null +++ b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_generic.snap @@ -0,0 +1,184 @@ +--- +source: crates/parser/tests/cases/parse_ast.rs +expression: "ast_string(stringify!(fn_def_generic), try_parse_module,\n \"fn transfer(from sender: address, to recip: address, _ val: u64) -> bool:\\n false\")" + +--- +Node( + kind: Module( + body: [ + Function(Node( + kind: Function( + pub_: None, + unsafe_: None, + name: Node( + kind: "transfer", + span: Span( + start: 3, + end: 11, + ), + ), + generic_params: Node( + kind: [ + Unbounded(Node( + kind: "T", + span: Span( + start: 12, + end: 13, + ), + )), + Bounded( + name: Node( + kind: "R", + span: Span( + start: 15, + end: 16, + ), + ), + bound: Node( + kind: "Event", + span: Span( + start: 18, + end: 23, + ), + ), + ), + ], + span: Span( + start: 11, + end: 24, + ), + ), + args: [ + Node( + kind: Regular(RegularFunctionArg( + label: Some(Node( + kind: "from", + span: Span( + start: 25, + end: 29, + ), + )), + name: Node( + kind: "sender", + span: Span( + start: 30, + end: 36, + ), + ), + typ: Node( + kind: Base( + base: "address", + ), + span: Span( + start: 38, + end: 45, + ), + ), + )), + span: Span( + start: 30, + end: 45, + ), + ), + Node( + kind: Regular(RegularFunctionArg( + label: Some(Node( + kind: "to", + span: Span( + start: 47, + end: 49, + ), + )), + name: Node( + kind: "recip", + span: Span( + start: 50, + end: 55, + ), + ), + typ: Node( + kind: Base( + base: "address", + ), + span: Span( + start: 57, + end: 64, + ), + ), + )), + span: Span( + start: 50, + end: 64, + ), + ), + Node( + kind: Regular(RegularFunctionArg( + label: Some(Node( + kind: "_", + span: Span( + start: 66, + end: 67, + ), + )), + name: Node( + kind: "val", + span: Span( + start: 68, + end: 71, + ), + ), + typ: Node( + kind: Base( + base: "u64", + ), + span: Span( + start: 73, + end: 76, + ), + ), + )), + span: Span( + start: 68, + end: 76, + ), + ), + ], + return_type: Some(Node( + kind: Base( + base: "bool", + ), + span: Span( + start: 81, + end: 85, + ), + )), + body: [ + Node( + kind: Expr( + value: Node( + kind: Bool(false), + span: Span( + start: 88, + end: 93, + ), + ), + ), + span: Span( + start: 88, + end: 93, + ), + ), + ], + ), + span: Span( + start: 0, + end: 93, + ), + )), + ], + ), + span: Span( + start: 0, + end: 93, + ), +) diff --git a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_pub.snap b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_pub.snap index 1110f57e14..aba7965c16 100644 --- a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_pub.snap +++ b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_pub.snap @@ -1,6 +1,6 @@ --- source: crates/parser/tests/cases/parse_ast.rs -expression: "ast_string(stringify!(fn_def_pub), try_parse_module,\n \"pub fn foo21(x: bool, y: address,) -> bool:\\n x\")" +expression: "ast_string(stringify!(fn_def_pub), try_parse_module,\n \"pub fn foo21(x: bool, y: address,) -> bool:\\n x\")" --- Node( @@ -20,6 +20,13 @@ Node( end: 12, ), ), + generic_params: Node( + kind: [], + span: Span( + start: 7, + end: 12, + ), + ), args: [ Node( kind: Regular(RegularFunctionArg( diff --git a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_pub_unsafe.snap b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_pub_unsafe.snap index 4d3486bddf..23de356ec9 100644 --- a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_pub_unsafe.snap +++ b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_pub_unsafe.snap @@ -1,6 +1,6 @@ --- source: crates/parser/tests/cases/parse_ast.rs -expression: "ast_string(stringify!(fn_def_pub_unsafe), try_parse_module,\n \"pub unsafe fn foo21(x: bool, y: address,) -> bool:\\n x\")" +expression: "ast_string(stringify!(fn_def_pub_unsafe), try_parse_module,\n \"pub unsafe fn foo21(x: bool, y: address,) -> bool:\\n x\")" --- Node( @@ -23,6 +23,13 @@ Node( end: 19, ), ), + generic_params: Node( + kind: [], + span: Span( + start: 14, + end: 19, + ), + ), args: [ Node( kind: Regular(RegularFunctionArg( diff --git a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_unsafe.snap b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_unsafe.snap index bad0eae436..c63c32db1e 100644 --- a/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_unsafe.snap +++ b/crates/parser/tests/cases/snapshots/cases__parse_ast__fn_def_unsafe.snap @@ -1,6 +1,6 @@ --- source: crates/parser/tests/cases/parse_ast.rs -expression: "ast_string(stringify!(fn_def_unsafe), try_parse_module,\n \"unsafe fn foo21(x: bool, y: address,) -> bool:\\n x\")" +expression: "ast_string(stringify!(fn_def_unsafe), try_parse_module,\n \"unsafe fn foo21(x: bool, y: address,) -> bool:\\n x\")" --- Node( @@ -20,6 +20,13 @@ Node( end: 15, ), ), + generic_params: Node( + kind: [], + span: Span( + start: 10, + end: 15, + ), + ), args: [ Node( kind: Regular(RegularFunctionArg(