Skip to content

Commit

Permalink
feat: syntax sugar for calling static methods with a receiver
Browse files Browse the repository at this point in the history
  • Loading branch information
jac3km4 committed May 27, 2024
1 parent 9a886d1 commit edf2689
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 80 deletions.
39 changes: 12 additions & 27 deletions compiler/src/assembler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,34 +252,19 @@ impl<'a> Assembler<'a> {
self.assemble_intrinsic(op, args.into_vec(), &type_, scope, pool, span)?;
}
},
Expr::MethodCall(expr, fun_idx, args, span) => {
let fun = pool.function(fun_idx)?;
match *expr {
Expr::Ident(Reference::Symbol(Symbol::Class(_, _) | Symbol::Struct(_, _)), span) => {
if fun.flags.is_static() {
self.assemble_call(fun_idx, args, scope, pool, true, span)?;
} else {
return Err(
Cause::InvalidNonStaticMethodCall(Ident::from_heap(pool.def_name(fun_idx)?))
.with_span(span),
);
}
}
_ if fun.flags.is_static() => {
return Err(
Cause::InvalidStaticMethodCall(Ident::from_heap(pool.def_name(fun_idx)?)).with_span(span),
);
}
expr => {
let force_static_call = matches!(&expr, Expr::Super(_));
let exit_label = self.new_label();
self.emit(Instr::Context(exit_label));
self.assemble(expr, scope, pool, None)?;
self.assemble_call(fun_idx, args, scope, pool, force_static_call, span)?;
self.emit_label(exit_label);
}
Expr::MethodCall(expr, fun_idx, args, span) => match *expr {
Expr::Ident(Reference::Symbol(Symbol::Class(_, _) | Symbol::Struct(_, _)), span) => {
self.assemble_call(fun_idx, args, scope, pool, true, span)?;
}
}
expr => {
let force_static_call = matches!(&expr, Expr::Super(_));
let exit_label = self.new_label();
self.emit(Instr::Context(exit_label));
self.assemble(expr, scope, pool, None)?;
self.assemble_call(fun_idx, args, scope, pool, force_static_call, span)?;
self.emit_label(exit_label);
}
},

Expr::Null(_) => {
self.emit(Instr::Null);
Expand Down
12 changes: 6 additions & 6 deletions compiler/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ pub enum Cause {
UnexpectedValueReturn,
#[error("invalid use of {0}, unexpected {1}")]
InvalidIntrinsicUse(Intrinsic, Ident),
#[error("method {0} is static")]
InvalidStaticMethodCall(Ident),
#[error("method {0} is not static")]
InvalidNonStaticMethodCall(Ident),
#[error("this method is static, it cannot be used on an instance of an object")]
InvalidStaticMethodCall,
#[error("this method is not static, it must be used on an instance of an object")]
InvalidNonStaticMethodCall,
#[error("no 'this' available in a static context")]
UnexpectedThis,
#[error("{0} is not supported")]
Expand Down Expand Up @@ -149,8 +149,8 @@ impl Cause {
Self::InvalidAnnotationArgs => "INVALID_ANN_USE",
Self::InvalidMemberAccess(_) => "INVALID_MEMBER_ACCESS",
Self::VoidCannotBeUsed => "INVALID_VOID_USE",
Self::InvalidStaticMethodCall(_) => "INVALID_STATIC_USE",
Self::InvalidNonStaticMethodCall(_) => "INVALID_NONSTATIC_USE",
Self::InvalidStaticMethodCall => "INVALID_STATIC_USE",
Self::InvalidNonStaticMethodCall => "INVALID_NONSTATIC_USE",
Self::UnexpectedThis => "UNEXPECTED_THIS",
Self::SymbolRedefinition(_) => "SYM_REDEFINITION",
Self::FieldRedefinition => "FIELD_REDEFINITION",
Expand Down
3 changes: 0 additions & 3 deletions compiler/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ impl Scope {
}

pub fn resolve_enum_member(
&self,
ident: Ident,
enum_idx: PoolIndex<Enum>,
pool: &ConstantPool,
Expand All @@ -124,7 +123,6 @@ impl Scope {
}

pub fn resolve_method(
&self,
ident: Ident,
class_idx: PoolIndex<Class>,
pool: &ConstantPool,
Expand Down Expand Up @@ -152,7 +150,6 @@ impl Scope {
}

pub fn resolve_direct_method(
&self,
ident: Ident,
class_idx: PoolIndex<Class>,
pool: &ConstantPool,
Expand Down
126 changes: 94 additions & 32 deletions compiler/src/typechecker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::iter;
use std::ops::Not;
use std::str::FromStr;

use itertools::{izip, Itertools};
Expand Down Expand Up @@ -142,8 +143,15 @@ impl<'a> TypeChecker<'a> {
self.check_intrinsic(intrinsic, args, expected.as_ref(), scope, *span)?
} else {
let candidates = scope.resolve_function(name.clone()).with_span(*span)?;
let match_ =
self.resolve_overload(name.clone(), candidates, args.iter(), expected.as_ref(), scope, *span)?;
let match_ = self.resolve_overload(
name.clone(),
candidates,
args.iter(),
expected.as_ref(),
None,
scope,
*span,
)?;
Expr::Call(
Callable::Function(match_.index),
[].into(),
Expand All @@ -152,28 +160,49 @@ impl<'a> TypeChecker<'a> {
)
}
}
Expr::MethodCall(context, name, args, span) => {
let checked_context = self.check(context, None, scope)?;
let type_ = type_of(&checked_context, scope, self.pool)?;
let class = match type_.unwrapped() {
Expr::MethodCall(receiver, name, args, span) => {
let checked_receiver = self.check(receiver, None, scope)?;
let receiver_type = type_of(&checked_receiver, scope, self.pool)?;
let class = match receiver_type.unwrapped() {
TypeId::Struct(class) | TypeId::Class(class) => *class,
type_ => return Err(Cause::InvalidMemberAccess(type_.pretty(self.pool)?).with_span(*span)),
};
let candidates = scope.resolve_method(name.clone(), class, self.pool).with_span(*span)?;
let match_ = self.resolve_overload(name.clone(), candidates, args.iter(), expected, scope, *span)?;
let receiver = matches!(checked_receiver, Expr::Ident(Reference::Symbol(_), _))
.not()
.then_some(&receiver_type);
let candidates = Scope::resolve_method(name.clone(), class, self.pool).with_span(*span)?;

let converted_context = if let TypeId::WeakRef(inner) = type_ {
insert_conversion(checked_context, &TypeId::Ref(inner), Conversion::WeakRefToRef)
let match_ =
self.resolve_overload(name.clone(), candidates, args.iter(), expected, receiver, scope, *span)?;

if match_.insert_receiver {
Expr::Call(
Callable::Function(match_.index),
[].into(),
iter::once(checked_receiver).chain(match_.args).collect(),
*span,
)
} else {
checked_context
};
Expr::MethodCall(Box::new(converted_context), match_.index, match_.args, *span)
let is_static = self.pool.function(match_.index).is_ok_and(|f| f.flags.is_static());
if receiver.is_some() && is_static {
return Err(Cause::InvalidStaticMethodCall.with_span(*span));
} else if receiver.is_none() && !is_static {
return Err(Cause::InvalidNonStaticMethodCall.with_span(*span));
}

let converted_receiver = if let TypeId::WeakRef(inner) = receiver_type {
insert_conversion(checked_receiver, &TypeId::Ref(inner), Conversion::WeakRefToRef)
} else {
checked_receiver
};
Expr::MethodCall(Box::new(converted_receiver), match_.index, match_.args, *span)
}
}
Expr::BinOp(lhs, rhs, op, span) => {
let name = Ident::from_static(op.into());
let args = IntoIterator::into_iter([lhs.as_ref(), rhs.as_ref()]);
let candidates = scope.resolve_function(name.clone()).with_span(*span)?;
let match_ = self.resolve_overload(name, candidates, args, expected, scope, *span)?;
let match_ = self.resolve_overload(name, candidates, args, expected, None, scope, *span)?;
Expr::Call(
Callable::Function(match_.index),
[].into(),
Expand All @@ -185,7 +214,7 @@ impl<'a> TypeChecker<'a> {
let name = Ident::from_static(op.into());
let args = iter::once(expr.as_ref());
let candidates = scope.resolve_function(name.clone()).with_span(*span)?;
let match_ = self.resolve_overload(name, candidates, args, expected, scope, *span)?;
let match_ = self.resolve_overload(name, candidates, args, expected, None, scope, *span)?;
Expr::Call(
Callable::Function(match_.index),
[].into(),
Expand All @@ -207,9 +236,7 @@ impl<'a> TypeChecker<'a> {
Member::StructField(field)
}
TypeId::Enum(enum_) => {
let member = scope
.resolve_enum_member(name.clone(), *enum_, self.pool)
.with_span(*span)?;
let member = Scope::resolve_enum_member(name.clone(), *enum_, self.pool).with_span(*span)?;
Member::EnumMember(*enum_, member)
}
type_ => return Err(Cause::InvalidMemberAccess(type_.pretty(self.pool)?).with_span(*span)),
Expand Down Expand Up @@ -558,12 +585,14 @@ impl<'a> TypeChecker<'a> {
}
}

#[allow(clippy::too_many_arguments)]
fn resolve_overload<'b>(
&mut self,
name: Ident,
overloads: FunctionCandidates,
args: impl ExactSizeIterator<Item = &'b Expr<SourceAst>> + Clone,
expected: Option<&TypeId>,
receiver: Option<&TypeId>,
scope: &mut Scope,
span: Span,
) -> Result<FunctionMatch, Error> {
Expand All @@ -572,23 +601,23 @@ impl<'a> TypeChecker<'a> {
let mut overload_errors = vec![];

for overload in &overloads.functions {
match Self::validate_call(*overload, arg_count, scope, self.pool, span) {
match Self::validate_call(*overload, arg_count, receiver, scope, self.pool, span) {
Ok(res) => eligible.push(res),
Err(MatcherError::MatchError(err)) => overload_errors.push(err),
Err(MatcherError::Other(err)) => return Err(err),
}
}

let match_ = match eligible.into_iter().exactly_one() {
Ok((fun_index, types)) => {
Ok((fun_index, types, has_static_receiver)) => {
let checked_args: Vec<_> = args
.clone()
.zip(&types)
.map(|(arg, typ)| self.check(arg, Some(typ), scope))
.try_collect()?;

match Self::validate_args(fun_index, &checked_args, &types, expected, scope, self.pool, span) {
Ok(conversions) => Ok(FunctionMatch::new(fun_index, checked_args, conversions)),
Ok(convs) => Ok(FunctionMatch::new(fun_index, checked_args, convs, has_static_receiver)),
Err(MatcherError::MatchError(err)) => {
overload_errors.push(err);
Err(Cause::NoMatchingOverload(name, overload_errors.into_boxed_slice()).with_span(span))
Expand All @@ -600,18 +629,20 @@ impl<'a> TypeChecker<'a> {
let checked_args: Vec<_> = args.clone().map(|expr| self.check(expr, None, scope)).try_collect()?;

let mut matches = vec![];
for (fun_index, types) in eligible {
for (fun_index, types, has_static_receiver) in eligible {
match Self::validate_args(fun_index, &checked_args, &types, expected, scope, self.pool, span) {
Ok(convs) => matches.push((fun_index, convs)),
Ok(convs) => matches.push((fun_index, convs, has_static_receiver)),
Err(MatcherError::MatchError(err)) => overload_errors.push(err),
Err(MatcherError::Other(err)) => return Err(err),
}
}

let mut it = matches.into_iter();
match it.next() {
None => Err(Cause::NoMatchingOverload(name, overload_errors.into_boxed_slice()).with_span(span)),
Some((fun_index, convs)) => Ok(FunctionMatch::new(fun_index, checked_args, convs)),
None => Err(Cause::NoMatchingOverload(name, overload_errors.into()).with_span(span)),
Some((fun_index, convs, has_static_receiver)) => {
Ok(FunctionMatch::new(fun_index, checked_args, convs, has_static_receiver))
}
}
}
};
Expand All @@ -622,7 +653,7 @@ impl<'a> TypeChecker<'a> {

let dummy_args: Vec<_> = args.map(|expr| self.check(expr, None, scope)).try_collect()?;
let convs = iter::repeat(ArgConversion::identity()).take(dummy_args.len()).collect();
Ok(FunctionMatch::new(overloads.functions[0], dummy_args, convs))
Ok(FunctionMatch::new(overloads.functions[0], dummy_args, convs, false))
}
Err(err) => Err(err),
}
Expand All @@ -631,16 +662,37 @@ impl<'a> TypeChecker<'a> {
fn validate_call(
fun_index: PoolIndex<Function>,
arg_count: usize,
receiver: Option<&TypeId>,
scope: &Scope,
pool: &ConstantPool,
span: Span,
) -> Result<(PoolIndex<Function>, Vec<TypeId>), MatcherError> {
) -> Result<(PoolIndex<Function>, Vec<TypeId>, bool), MatcherError> {
let fun = pool.function(fun_index)?;
let params = fun
let mut params = fun
.parameters
.iter()
.map(|idx| pool.parameter(*idx).map_err(Error::PoolError));
let min_params = params.clone().filter_ok(|param| !param.flags.is_optional()).count();

let has_static_receiver = receiver.is_some_and(|receiver| {
fun.flags.is_static()
&& params
.by_ref()
.peekable()
.next_if(|p| {
p.as_ref().is_ok_and(|p| {
scope
.resolve_type_from_pool(p.type_, pool)
.is_ok_and(|t| &t == receiver)
})
})
.is_some()
});

let min_params = params
.clone()
.rev()
.skip_while(|p| p.as_ref().is_ok_and(|p| p.flags.is_optional()))
.count();

if arg_count < min_params || arg_count > fun.parameters.len() {
let err = FunctionMatchError::ArgumentCountMismatch {
Expand All @@ -654,7 +706,7 @@ impl<'a> TypeChecker<'a> {
let types = params
.map(|res| res.and_then(|param| scope.resolve_type_from_pool(param.type_, pool).with_span(span)))
.try_collect()?;
Ok((fun_index, types))
Ok((fun_index, types, has_static_receiver))
}

fn validate_args(
Expand Down Expand Up @@ -972,16 +1024,26 @@ impl ArgConversion {
pub struct FunctionMatch {
pub index: PoolIndex<Function>,
pub args: Vec<TypedExpr>,
pub insert_receiver: bool,
}

impl FunctionMatch {
fn new(index: PoolIndex<Function>, args: Vec<TypedExpr>, conversions: Vec<ArgConversion>) -> Self {
fn new(
index: PoolIndex<Function>,
args: Vec<TypedExpr>,
conversions: Vec<ArgConversion>,
has_static_receiver: bool,
) -> Self {
let args = args
.into_iter()
.zip(conversions)
.map(|(expr, conv)| insert_conversion(expr, &conv.target, conv.conversion))
.collect();
Self { index, args }
Self {
index,
args,
insert_receiver: has_static_receiver,
}
}
}

Expand Down
14 changes: 3 additions & 11 deletions compiler/src/unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -838,10 +838,7 @@ impl<'a> CompilationUnit<'a> {
Symbol::Struct(idx, _) | Symbol::Class(idx, _) => idx,
_ => return Err(Cause::ClassNotFound(class_name.clone()).with_span(ann.span)),
};
let candidates = self
.scope
.resolve_direct_method(name.clone(), target_class_idx, self.pool)
.ok();
let candidates = Scope::resolve_direct_method(name.clone(), target_class_idx, self.pool).ok();
let fun_idx = candidates
.as_ref()
.and_then(|cd| cd.by_id(&sig, self.pool))
Expand Down Expand Up @@ -889,10 +886,7 @@ impl<'a> CompilationUnit<'a> {
Symbol::Struct(idx, _) | Symbol::Class(idx, _) => idx,
_ => return Err(Cause::ClassNotFound(class_name.clone()).with_span(ann.span)),
};
let candidates = self
.scope
.resolve_direct_method(name.clone(), target_class_idx, self.pool)
.ok();
let candidates = Scope::resolve_direct_method(name.clone(), target_class_idx, self.pool).ok();
let fun_idx = candidates
.as_ref()
.and_then(|cd| cd.by_id(&sig, self.pool))
Expand Down Expand Up @@ -946,9 +940,7 @@ impl<'a> CompilationUnit<'a> {
_ => return Err(Cause::ClassNotFound(class_name.clone()).with_span(ann.span)),
};

if self
.scope
.resolve_direct_method(name.clone(), target_class_idx, self.pool)
if Scope::resolve_direct_method(name.clone(), target_class_idx, self.pool)
.ok()
.and_then(|cd| cd.by_id(&sig, self.pool))
.is_some()
Expand Down
Loading

0 comments on commit edf2689

Please sign in to comment.