From edf2689196052bdc46cd2697fa4fb7f1abae330f Mon Sep 17 00:00:00 2001 From: jekky <11986158+jac3km4@users.noreply.github.com> Date: Mon, 27 May 2024 18:51:42 +0100 Subject: [PATCH] feat: syntax sugar for calling static methods with a receiver --- compiler/src/assembler.rs | 39 ++++------- compiler/src/error.rs | 12 ++-- compiler/src/scope.rs | 3 - compiler/src/typechecker.rs | 126 +++++++++++++++++++++++++++--------- compiler/src/unit.rs | 14 +--- compiler/tests/bytecode.rs | 46 +++++++++++++ compiler/tests/compile.rs | 47 ++++++++++++++ compiler/tests/utils.rs | 6 +- 8 files changed, 213 insertions(+), 80 deletions(-) diff --git a/compiler/src/assembler.rs b/compiler/src/assembler.rs index c72b74c..915764e 100644 --- a/compiler/src/assembler.rs +++ b/compiler/src/assembler.rs @@ -252,34 +252,19 @@ impl<'a> Assembler<'a> { self.assemble_intrinsic(op, args.into_vec(), &type_, scope, pool, span)?; } }, - Expr::MethodCall(expr, fun_idx, args, span) => { - let fun = pool.function(fun_idx)?; - match *expr { - Expr::Ident(Reference::Symbol(Symbol::Class(_, _) | Symbol::Struct(_, _)), span) => { - if fun.flags.is_static() { - self.assemble_call(fun_idx, args, scope, pool, true, span)?; - } else { - return Err( - Cause::InvalidNonStaticMethodCall(Ident::from_heap(pool.def_name(fun_idx)?)) - .with_span(span), - ); - } - } - _ if fun.flags.is_static() => { - return Err( - Cause::InvalidStaticMethodCall(Ident::from_heap(pool.def_name(fun_idx)?)).with_span(span), - ); - } - expr => { - let force_static_call = matches!(&expr, Expr::Super(_)); - let exit_label = self.new_label(); - self.emit(Instr::Context(exit_label)); - self.assemble(expr, scope, pool, None)?; - self.assemble_call(fun_idx, args, scope, pool, force_static_call, span)?; - self.emit_label(exit_label); - } + Expr::MethodCall(expr, fun_idx, args, span) => match *expr { + Expr::Ident(Reference::Symbol(Symbol::Class(_, _) | Symbol::Struct(_, _)), span) => { + self.assemble_call(fun_idx, args, scope, pool, true, span)?; } - } + expr => { + let force_static_call = matches!(&expr, Expr::Super(_)); + let exit_label = self.new_label(); + self.emit(Instr::Context(exit_label)); + self.assemble(expr, scope, pool, None)?; + self.assemble_call(fun_idx, args, scope, pool, force_static_call, span)?; + self.emit_label(exit_label); + } + }, Expr::Null(_) => { self.emit(Instr::Null); diff --git a/compiler/src/error.rs b/compiler/src/error.rs index 175d879..529df33 100644 --- a/compiler/src/error.rs +++ b/compiler/src/error.rs @@ -71,10 +71,10 @@ pub enum Cause { UnexpectedValueReturn, #[error("invalid use of {0}, unexpected {1}")] InvalidIntrinsicUse(Intrinsic, Ident), - #[error("method {0} is static")] - InvalidStaticMethodCall(Ident), - #[error("method {0} is not static")] - InvalidNonStaticMethodCall(Ident), + #[error("this method is static, it cannot be used on an instance of an object")] + InvalidStaticMethodCall, + #[error("this method is not static, it must be used on an instance of an object")] + InvalidNonStaticMethodCall, #[error("no 'this' available in a static context")] UnexpectedThis, #[error("{0} is not supported")] @@ -149,8 +149,8 @@ impl Cause { Self::InvalidAnnotationArgs => "INVALID_ANN_USE", Self::InvalidMemberAccess(_) => "INVALID_MEMBER_ACCESS", Self::VoidCannotBeUsed => "INVALID_VOID_USE", - Self::InvalidStaticMethodCall(_) => "INVALID_STATIC_USE", - Self::InvalidNonStaticMethodCall(_) => "INVALID_NONSTATIC_USE", + Self::InvalidStaticMethodCall => "INVALID_STATIC_USE", + Self::InvalidNonStaticMethodCall => "INVALID_NONSTATIC_USE", Self::UnexpectedThis => "UNEXPECTED_THIS", Self::SymbolRedefinition(_) => "SYM_REDEFINITION", Self::FieldRedefinition => "FIELD_REDEFINITION", diff --git a/compiler/src/scope.rs b/compiler/src/scope.rs index 0c4ebbf..76980b2 100644 --- a/compiler/src/scope.rs +++ b/compiler/src/scope.rs @@ -109,7 +109,6 @@ impl Scope { } pub fn resolve_enum_member( - &self, ident: Ident, enum_idx: PoolIndex, pool: &ConstantPool, @@ -124,7 +123,6 @@ impl Scope { } pub fn resolve_method( - &self, ident: Ident, class_idx: PoolIndex, pool: &ConstantPool, @@ -152,7 +150,6 @@ impl Scope { } pub fn resolve_direct_method( - &self, ident: Ident, class_idx: PoolIndex, pool: &ConstantPool, diff --git a/compiler/src/typechecker.rs b/compiler/src/typechecker.rs index 9ae3bc2..efc5052 100644 --- a/compiler/src/typechecker.rs +++ b/compiler/src/typechecker.rs @@ -1,4 +1,5 @@ use std::iter; +use std::ops::Not; use std::str::FromStr; use itertools::{izip, Itertools}; @@ -142,8 +143,15 @@ impl<'a> TypeChecker<'a> { self.check_intrinsic(intrinsic, args, expected.as_ref(), scope, *span)? } else { let candidates = scope.resolve_function(name.clone()).with_span(*span)?; - let match_ = - self.resolve_overload(name.clone(), candidates, args.iter(), expected.as_ref(), scope, *span)?; + let match_ = self.resolve_overload( + name.clone(), + candidates, + args.iter(), + expected.as_ref(), + None, + scope, + *span, + )?; Expr::Call( Callable::Function(match_.index), [].into(), @@ -152,28 +160,49 @@ impl<'a> TypeChecker<'a> { ) } } - Expr::MethodCall(context, name, args, span) => { - let checked_context = self.check(context, None, scope)?; - let type_ = type_of(&checked_context, scope, self.pool)?; - let class = match type_.unwrapped() { + Expr::MethodCall(receiver, name, args, span) => { + let checked_receiver = self.check(receiver, None, scope)?; + let receiver_type = type_of(&checked_receiver, scope, self.pool)?; + let class = match receiver_type.unwrapped() { TypeId::Struct(class) | TypeId::Class(class) => *class, type_ => return Err(Cause::InvalidMemberAccess(type_.pretty(self.pool)?).with_span(*span)), }; - let candidates = scope.resolve_method(name.clone(), class, self.pool).with_span(*span)?; - let match_ = self.resolve_overload(name.clone(), candidates, args.iter(), expected, scope, *span)?; + let receiver = matches!(checked_receiver, Expr::Ident(Reference::Symbol(_), _)) + .not() + .then_some(&receiver_type); + let candidates = Scope::resolve_method(name.clone(), class, self.pool).with_span(*span)?; - let converted_context = if let TypeId::WeakRef(inner) = type_ { - insert_conversion(checked_context, &TypeId::Ref(inner), Conversion::WeakRefToRef) + let match_ = + self.resolve_overload(name.clone(), candidates, args.iter(), expected, receiver, scope, *span)?; + + if match_.insert_receiver { + Expr::Call( + Callable::Function(match_.index), + [].into(), + iter::once(checked_receiver).chain(match_.args).collect(), + *span, + ) } else { - checked_context - }; - Expr::MethodCall(Box::new(converted_context), match_.index, match_.args, *span) + let is_static = self.pool.function(match_.index).is_ok_and(|f| f.flags.is_static()); + if receiver.is_some() && is_static { + return Err(Cause::InvalidStaticMethodCall.with_span(*span)); + } else if receiver.is_none() && !is_static { + return Err(Cause::InvalidNonStaticMethodCall.with_span(*span)); + } + + let converted_receiver = if let TypeId::WeakRef(inner) = receiver_type { + insert_conversion(checked_receiver, &TypeId::Ref(inner), Conversion::WeakRefToRef) + } else { + checked_receiver + }; + Expr::MethodCall(Box::new(converted_receiver), match_.index, match_.args, *span) + } } Expr::BinOp(lhs, rhs, op, span) => { let name = Ident::from_static(op.into()); let args = IntoIterator::into_iter([lhs.as_ref(), rhs.as_ref()]); let candidates = scope.resolve_function(name.clone()).with_span(*span)?; - let match_ = self.resolve_overload(name, candidates, args, expected, scope, *span)?; + let match_ = self.resolve_overload(name, candidates, args, expected, None, scope, *span)?; Expr::Call( Callable::Function(match_.index), [].into(), @@ -185,7 +214,7 @@ impl<'a> TypeChecker<'a> { let name = Ident::from_static(op.into()); let args = iter::once(expr.as_ref()); let candidates = scope.resolve_function(name.clone()).with_span(*span)?; - let match_ = self.resolve_overload(name, candidates, args, expected, scope, *span)?; + let match_ = self.resolve_overload(name, candidates, args, expected, None, scope, *span)?; Expr::Call( Callable::Function(match_.index), [].into(), @@ -207,9 +236,7 @@ impl<'a> TypeChecker<'a> { Member::StructField(field) } TypeId::Enum(enum_) => { - let member = scope - .resolve_enum_member(name.clone(), *enum_, self.pool) - .with_span(*span)?; + let member = Scope::resolve_enum_member(name.clone(), *enum_, self.pool).with_span(*span)?; Member::EnumMember(*enum_, member) } type_ => return Err(Cause::InvalidMemberAccess(type_.pretty(self.pool)?).with_span(*span)), @@ -558,12 +585,14 @@ impl<'a> TypeChecker<'a> { } } + #[allow(clippy::too_many_arguments)] fn resolve_overload<'b>( &mut self, name: Ident, overloads: FunctionCandidates, args: impl ExactSizeIterator> + Clone, expected: Option<&TypeId>, + receiver: Option<&TypeId>, scope: &mut Scope, span: Span, ) -> Result { @@ -572,7 +601,7 @@ impl<'a> TypeChecker<'a> { let mut overload_errors = vec![]; for overload in &overloads.functions { - match Self::validate_call(*overload, arg_count, scope, self.pool, span) { + match Self::validate_call(*overload, arg_count, receiver, scope, self.pool, span) { Ok(res) => eligible.push(res), Err(MatcherError::MatchError(err)) => overload_errors.push(err), Err(MatcherError::Other(err)) => return Err(err), @@ -580,7 +609,7 @@ impl<'a> TypeChecker<'a> { } let match_ = match eligible.into_iter().exactly_one() { - Ok((fun_index, types)) => { + Ok((fun_index, types, has_static_receiver)) => { let checked_args: Vec<_> = args .clone() .zip(&types) @@ -588,7 +617,7 @@ impl<'a> TypeChecker<'a> { .try_collect()?; match Self::validate_args(fun_index, &checked_args, &types, expected, scope, self.pool, span) { - Ok(conversions) => Ok(FunctionMatch::new(fun_index, checked_args, conversions)), + Ok(convs) => Ok(FunctionMatch::new(fun_index, checked_args, convs, has_static_receiver)), Err(MatcherError::MatchError(err)) => { overload_errors.push(err); Err(Cause::NoMatchingOverload(name, overload_errors.into_boxed_slice()).with_span(span)) @@ -600,9 +629,9 @@ impl<'a> TypeChecker<'a> { let checked_args: Vec<_> = args.clone().map(|expr| self.check(expr, None, scope)).try_collect()?; let mut matches = vec![]; - for (fun_index, types) in eligible { + for (fun_index, types, has_static_receiver) in eligible { match Self::validate_args(fun_index, &checked_args, &types, expected, scope, self.pool, span) { - Ok(convs) => matches.push((fun_index, convs)), + Ok(convs) => matches.push((fun_index, convs, has_static_receiver)), Err(MatcherError::MatchError(err)) => overload_errors.push(err), Err(MatcherError::Other(err)) => return Err(err), } @@ -610,8 +639,10 @@ impl<'a> TypeChecker<'a> { let mut it = matches.into_iter(); match it.next() { - None => Err(Cause::NoMatchingOverload(name, overload_errors.into_boxed_slice()).with_span(span)), - Some((fun_index, convs)) => Ok(FunctionMatch::new(fun_index, checked_args, convs)), + None => Err(Cause::NoMatchingOverload(name, overload_errors.into()).with_span(span)), + Some((fun_index, convs, has_static_receiver)) => { + Ok(FunctionMatch::new(fun_index, checked_args, convs, has_static_receiver)) + } } } }; @@ -622,7 +653,7 @@ impl<'a> TypeChecker<'a> { let dummy_args: Vec<_> = args.map(|expr| self.check(expr, None, scope)).try_collect()?; let convs = iter::repeat(ArgConversion::identity()).take(dummy_args.len()).collect(); - Ok(FunctionMatch::new(overloads.functions[0], dummy_args, convs)) + Ok(FunctionMatch::new(overloads.functions[0], dummy_args, convs, false)) } Err(err) => Err(err), } @@ -631,16 +662,37 @@ impl<'a> TypeChecker<'a> { fn validate_call( fun_index: PoolIndex, arg_count: usize, + receiver: Option<&TypeId>, scope: &Scope, pool: &ConstantPool, span: Span, - ) -> Result<(PoolIndex, Vec), MatcherError> { + ) -> Result<(PoolIndex, Vec, bool), MatcherError> { let fun = pool.function(fun_index)?; - let params = fun + let mut params = fun .parameters .iter() .map(|idx| pool.parameter(*idx).map_err(Error::PoolError)); - let min_params = params.clone().filter_ok(|param| !param.flags.is_optional()).count(); + + let has_static_receiver = receiver.is_some_and(|receiver| { + fun.flags.is_static() + && params + .by_ref() + .peekable() + .next_if(|p| { + p.as_ref().is_ok_and(|p| { + scope + .resolve_type_from_pool(p.type_, pool) + .is_ok_and(|t| &t == receiver) + }) + }) + .is_some() + }); + + let min_params = params + .clone() + .rev() + .skip_while(|p| p.as_ref().is_ok_and(|p| p.flags.is_optional())) + .count(); if arg_count < min_params || arg_count > fun.parameters.len() { let err = FunctionMatchError::ArgumentCountMismatch { @@ -654,7 +706,7 @@ impl<'a> TypeChecker<'a> { let types = params .map(|res| res.and_then(|param| scope.resolve_type_from_pool(param.type_, pool).with_span(span))) .try_collect()?; - Ok((fun_index, types)) + Ok((fun_index, types, has_static_receiver)) } fn validate_args( @@ -972,16 +1024,26 @@ impl ArgConversion { pub struct FunctionMatch { pub index: PoolIndex, pub args: Vec, + pub insert_receiver: bool, } impl FunctionMatch { - fn new(index: PoolIndex, args: Vec, conversions: Vec) -> Self { + fn new( + index: PoolIndex, + args: Vec, + conversions: Vec, + has_static_receiver: bool, + ) -> Self { let args = args .into_iter() .zip(conversions) .map(|(expr, conv)| insert_conversion(expr, &conv.target, conv.conversion)) .collect(); - Self { index, args } + Self { + index, + args, + insert_receiver: has_static_receiver, + } } } diff --git a/compiler/src/unit.rs b/compiler/src/unit.rs index 6c67fdc..3868687 100644 --- a/compiler/src/unit.rs +++ b/compiler/src/unit.rs @@ -838,10 +838,7 @@ impl<'a> CompilationUnit<'a> { Symbol::Struct(idx, _) | Symbol::Class(idx, _) => idx, _ => return Err(Cause::ClassNotFound(class_name.clone()).with_span(ann.span)), }; - let candidates = self - .scope - .resolve_direct_method(name.clone(), target_class_idx, self.pool) - .ok(); + let candidates = Scope::resolve_direct_method(name.clone(), target_class_idx, self.pool).ok(); let fun_idx = candidates .as_ref() .and_then(|cd| cd.by_id(&sig, self.pool)) @@ -889,10 +886,7 @@ impl<'a> CompilationUnit<'a> { Symbol::Struct(idx, _) | Symbol::Class(idx, _) => idx, _ => return Err(Cause::ClassNotFound(class_name.clone()).with_span(ann.span)), }; - let candidates = self - .scope - .resolve_direct_method(name.clone(), target_class_idx, self.pool) - .ok(); + let candidates = Scope::resolve_direct_method(name.clone(), target_class_idx, self.pool).ok(); let fun_idx = candidates .as_ref() .and_then(|cd| cd.by_id(&sig, self.pool)) @@ -946,9 +940,7 @@ impl<'a> CompilationUnit<'a> { _ => return Err(Cause::ClassNotFound(class_name.clone()).with_span(ann.span)), }; - if self - .scope - .resolve_direct_method(name.clone(), target_class_idx, self.pool) + if Scope::resolve_direct_method(name.clone(), target_class_idx, self.pool) .ok() .and_then(|cd| cd.by_id(&sig, self.pool)) .is_some() diff --git a/compiler/tests/bytecode.rs b/compiler/tests/bytecode.rs index 6222cde..b0b4542 100644 --- a/compiler/tests/bytecode.rs +++ b/compiler/tests/bytecode.rs @@ -904,3 +904,49 @@ fn compile_initializers() { ]; TestContext::compiled(vec![sources]).unwrap().run("Testing", check); } + +#[test] +fn compile_static_receivers() { + let sources = r#" + struct Dummy1 { + static func Test(self: Dummy1) {} + } + + class Dummy2 { + static func Test(self: ref) {} + } + + func Testing() { + let a = new Dummy1(); + a.Test(); + Dummy1.Test(a); + + let b = new Dummy2(); + b.Test(); + Dummy2.Test(b); + } + "#; + + let check = check_code![ + pat!(Assign), + mem!(Local(a)), + pat!(Construct(_, _)), + mem!(InvokeStatic(_0, _1, func1, _2)), + mem!(Local(a)), + pat!(ParamEnd), + mem!(InvokeStatic(_0, _1, func1, _2)), + mem!(Local(a)), + pat!(ParamEnd), + pat!(Assign), + mem!(Local(b)), + mem!(New(class)), + mem!(InvokeStatic(_0, _1, func2, _2)), + mem!(Local(b)), + pat!(ParamEnd), + mem!(InvokeStatic(_0, _1, func2, _2)), + mem!(Local(b)), + pat!(ParamEnd), + pat!(Nop) + ]; + TestContext::compiled(vec![sources]).unwrap().run("Testing", check); +} diff --git a/compiler/tests/compile.rs b/compiler/tests/compile.rs index fc552a0..8e8b9c8 100644 --- a/compiler/tests/compile.rs +++ b/compiler/tests/compile.rs @@ -298,3 +298,50 @@ fn compile_defaults() { ] ); } + +#[test] +fn fail_on_static_receiver_mismatch() { + let sources = r#" + class Dummy { + static func Test(self: wref) {} + } + + func Testing() { + let a = new Dummy(); + a.Test(); + } + "#; + + let (_, errs) = compiled(vec![sources]).unwrap(); + assert!( + matches!( + &errs[..], + &[Diagnostic::CompileError(Cause::InvalidStaticMethodCall, _)] + ), + "{:?}", + errs + ); +} + +#[test] +fn fail_on_static_call_of_instance_method() { + let sources = r#" + class Dummy { + func Test() {} + } + + func Testing() { + Dummy.Test(); + } + "#; + + let (_, errs) = compiled(vec![sources]).unwrap(); + assert!( + matches!( + &errs[..], + &[Diagnostic::CompileError(Cause::InvalidNonStaticMethodCall, _)] + ), + "{:?}", + errs + ); +} diff --git a/compiler/tests/utils.rs b/compiler/tests/utils.rs index 264be32..d88d1d5 100644 --- a/compiler/tests/utils.rs +++ b/compiler/tests/utils.rs @@ -81,7 +81,7 @@ macro_rules! pat { } } -/// macro for matching an instruction and memorizing it's arguments by names +/// macro for matching an instruction and memorizing its arguments by names #[macro_export] macro_rules! mem { ($id:ident($($args:ident),+)) => { @@ -98,6 +98,10 @@ macro_rules! mem { #[macro_export] macro_rules! match_index { (__, $ctx:ident) => {}; + (_0, $ctx:ident) => {}; + (_1, $ctx:ident) => {}; + (_2, $ctx:ident) => {}; + (_3, $ctx:ident) => {}; ($id:ident, $ctx:ident) => { $ctx.match_index($id.cast(), stringify!($id)) };