Skip to content

Commit

Permalink
all: support short lambda expressions like a.sorted(|x,y| x > y), in …
Browse files Browse the repository at this point in the history
…all callsites that accept a fn callback (vlang#19390)
  • Loading branch information
spytheman authored and Wertzui123 committed Oct 8, 2023
1 parent 6a61862 commit 1cceb2a
Show file tree
Hide file tree
Showing 17 changed files with 385 additions and 25 deletions.
37 changes: 37 additions & 0 deletions vlib/builtin/sorted_lambda_expr_test.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
fn test_sort_with_lambda_expr() {
a := [5, 2, 1, 9, 8]
dump(a)

sorted01 := a.sorted(a < b)
sorted02 := a.sorted(a > b)
dump(sorted01)
dump(sorted02)

sorted01_with_compare_fn := a.sorted_with_compare(fn (a &int, b &int) int {
return *a - *b
})
sorted02_with_compare_fn := a.sorted_with_compare(fn (a &int, b &int) int {
return *b - *a
})
dump(sorted01_with_compare_fn)
dump(sorted02_with_compare_fn)

///////////////////////////////////////////

sorted01_lambda_expr := a.sorted(|ix, iy| ix < iy)
sorted02_lambda_expr := a.sorted(|ii, jj| ii > jj)
dump(sorted01_lambda_expr)
dump(sorted02_lambda_expr)

sorted01_with_compare_lambda_expr := a.sorted_with_compare(|x, y| *x - *y)
sorted02_with_compare_lambda_expr := a.sorted_with_compare(|e1, e2| *e2 - *e1)
dump(sorted01_with_compare_lambda_expr)
dump(sorted02_with_compare_lambda_expr)

assert sorted01 == sorted01_with_compare_fn
assert sorted02 == sorted02_with_compare_fn
assert sorted01 == sorted01_lambda_expr
assert sorted02 == sorted02_lambda_expr
assert sorted01 == sorted01_with_compare_lambda_expr
assert sorted02 == sorted02_with_compare_lambda_expr
}
46 changes: 30 additions & 16 deletions vlib/v/ast/ast.v
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub type Expr = AnonFn
| InfixExpr
| IntegerLiteral
| IsRefType
| LambdaExpr
| Likely
| LockExpr
| MapInit
Expand Down Expand Up @@ -506,7 +507,7 @@ pub mut:
decl FnDecl
inherited_vars []Param
typ Type // the type of anonymous fn. Both .typ and .decl.name are auto generated
has_gen map[string]bool // has been generated
has_gen map[string]bool // a map of the names of all generic anon functions, generated from it
}

// function or method declaration
Expand Down Expand Up @@ -782,15 +783,15 @@ pub:
share ShareType
is_mut bool
is_autofree_tmp bool
is_arg bool // fn args should not be autofreed
is_auto_deref bool
is_inherited bool
has_inherited bool
pub mut:
expr Expr
typ Type
orig_type Type // original sumtype type; 0 if it's not a sumtype
smartcasts []Type // nested sum types require nested smart casting, for that a list of types is needed
is_arg bool // fn args should not be autofreed
is_auto_deref bool
expr Expr
typ Type
orig_type Type // original sumtype type; 0 if it's not a sumtype
smartcasts []Type // nested sum types require nested smart casting, for that a list of types is needed
// TODO: move this to a real docs site later
// 10 <- original type (orig_type)
// [11, 12, 13] <- cast order (smartcasts)
Expand Down Expand Up @@ -891,6 +892,7 @@ pub mut:
generic_fns []&FnDecl
global_labels []string // from `asm { .globl labelname }`
template_paths []string // all the .html/.md files that were processed with $tmpl
unique_prefix string // a hash of the `.path` field, used for making anon fn generation unique
}

[unsafe]
Expand Down Expand Up @@ -1253,14 +1255,6 @@ pub mut:
// ct_conds is filled by the checker, based on the current nesting of `$if cond1 {}` blocks
}

/*
// filter(), map(), sort()
pub struct Lambda {
pub:
name string
}
*/

// variable assign statement
[minify]
pub struct AssignStmt {
Expand Down Expand Up @@ -1790,6 +1784,20 @@ pub:
pos token.Pos
}

pub struct LambdaExpr {
pub:
pos token.Pos
params []Ident
pub mut:
pos_expr token.Pos
expr Expr
pos_end token.Pos
scope &Scope = unsafe { nil }
func &AnonFn = unsafe { nil }
is_checked bool
typ Type
}

pub struct Likely {
pub:
pos token.Pos
Expand Down Expand Up @@ -1977,7 +1985,7 @@ pub fn (expr Expr) pos() token.Pos {
IsRefType, Likely, LockExpr, MapInit, MatchExpr, None, OffsetOf, OrExpr, ParExpr,
PostfixExpr, PrefixExpr, RangeExpr, SelectExpr, SelectorExpr, SizeOf, SqlExpr,
StringInterLiteral, StringLiteral, StructInit, TypeNode, TypeOf, UnsafeExpr, ComptimeType,
Nil {
LambdaExpr, Nil {
return expr.pos
}
IndexExpr {
Expand Down Expand Up @@ -2169,6 +2177,12 @@ pub fn (node Node) children() []Node {
TypeOf, ArrayDecompose {
children << node.expr
}
LambdaExpr {
for p in node.params {
children << Node(Expr(p))
}
children << node.expr
}
LockExpr, OrExpr {
return node.stmts.map(Node(it))
}
Expand Down
9 changes: 9 additions & 0 deletions vlib/v/ast/str.v
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ pub fn (f &FnDecl) get_name() string {
}
}

// get_anon_fn_name returns the unique anonymous function name, based on the prefix, the func signature and its position in the source code
pub fn (table &Table) get_anon_fn_name(prefix string, func &Fn, pos int) string {
return 'anon_fn_${prefix}_${table.fn_type_signature(func)}_${pos}'
}

// get_name returns the real name for the function calling
pub fn (f &CallExpr) get_name() string {
if f.name != '' && f.name.all_after_last('.')[0].is_capital() && f.name.contains('__static__') {
Expand Down Expand Up @@ -609,6 +614,10 @@ pub fn (x Expr) str() string {
}
return 'typeof(${x.expr.str()})'
}
LambdaExpr {
ilist := x.params.map(it.name).join(', ')
return '|${ilist}| ${x.expr.str()}'
}
Likely {
return '_likely_(${x.expr.str()})'
}
Expand Down
3 changes: 3 additions & 0 deletions vlib/v/checker/checker.v
Original file line number Diff line number Diff line change
Expand Up @@ -2783,6 +2783,9 @@ pub fn (mut c Checker) expr(mut node ast.Expr) ast.Type {
ast.IntegerLiteral {
return c.int_lit(mut node)
}
ast.LambdaExpr {
return c.lambda_expr(mut node, c.expected_type)
}
ast.LockExpr {
return c.lock_expr(mut node)
}
Expand Down
8 changes: 7 additions & 1 deletion vlib/v/checker/fn.v
Original file line number Diff line number Diff line change
Expand Up @@ -2597,6 +2597,10 @@ fn (mut c Checker) array_builtin_method_call(mut node ast.CallExpr, left_type as
if method_name in ['filter', 'map', 'any', 'all'] {
// position of `it` doesn't matter
scope_register_it(mut node.scope, node.pos, elem_typ)
} else if method_name == 'sorted_with_compare' && node.args.len == 1 {
if mut node.args[0].expr is ast.LambdaExpr {
c.support_lambda_expr_in_sort(elem_typ.ref(), ast.int_type, mut node.args[0].expr)
}
} else if method_name == 'sort' || method_name == 'sorted' {
if method_name == 'sort' {
if node.left is ast.CallExpr {
Expand All @@ -2611,7 +2615,9 @@ fn (mut c Checker) array_builtin_method_call(mut node ast.CallExpr, left_type as
if node.args.len > 1 {
c.error('expected 0 or 1 argument, but got ${node.args.len}', node.pos)
} else if node.args.len == 1 {
if node.args[0].expr is ast.InfixExpr {
if mut node.args[0].expr is ast.LambdaExpr {
c.support_lambda_expr_in_sort(elem_typ.ref(), ast.bool_type, mut node.args[0].expr)
} else if node.args[0].expr is ast.InfixExpr {
if node.args[0].expr.op !in [.gt, .lt] {
c.error('`.${method_name}()` can only use `<` or `>` comparison',
node.pos)
Expand Down
123 changes: 123 additions & 0 deletions vlib/v/checker/lambda_expr.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
module checker

import v.ast

pub fn (mut c Checker) lambda_expr(mut node ast.LambdaExpr, exp_typ ast.Type) ast.Type {
// defer { eprintln('> line: ${@LINE} | exp_typ: $exp_typ | node: ${voidptr(node)} | node.typ: ${node.typ}') }
if node.is_checked {
return node.typ
}
if !c.inside_fn_arg {
c.error('lambda expressions are allowed only inside function or method callsites',
node.pos)
return ast.void_type
}
if exp_typ == 0 {
c.error('lambda expressions are allowed only in places expecting function callbacks',
node.pos)
return ast.void_type
}
exp_sym := c.table.sym(exp_typ)
if exp_sym.kind != .function {
c.error('a lambda expression was used, but `${exp_sym.kind}` was expected', node.pos)
return ast.void_type
}
if exp_sym.info is ast.FnType {
if node.params.len != exp_sym.info.func.params.len {
c.error('lambda expression has ${node.params.len} params, but the expected fn callback needs ${exp_sym.info.func.params.len} params',
node.pos)
return ast.void_type
}
mut params := []ast.Param{}
for idx, mut x in node.params {
eparam := exp_sym.info.func.params[idx]
eparam_type := eparam.typ
eparam_auto_deref := eparam.typ.is_ptr()
if mut v := node.scope.find(x.name) {
if mut v is ast.Var {
v.is_arg = true
v.typ = eparam_type
v.expr = ast.empty_expr
v.is_auto_deref = eparam_auto_deref
}
}
c.ident(mut x)
x.obj.typ = eparam_type

params << ast.Param{
pos: x.pos
name: x.name
typ: eparam_type
type_pos: x.pos
is_auto_rec: eparam_auto_deref
}
}
/////
is_variadic := false
return_type := exp_sym.info.func.return_type
return_type_pos := node.pos
mut stmts := []ast.Stmt{}
mut return_stmt := ast.Return{
pos: node.pos
exprs: [node.expr]
}
stmts << return_stmt

mut func := ast.Fn{
params: params
is_variadic: is_variadic
return_type: return_type
is_method: false
}
name := c.table.get_anon_fn_name(c.file.unique_prefix, func, node.pos.pos)
func.name = name
idx := c.table.find_or_register_fn_type(func, true, false)
typ := ast.new_type(idx)
node.func = &ast.AnonFn{
decl: ast.FnDecl{
name: name
short_name: ''
mod: c.file.mod.name
stmts: stmts
return_type: return_type
return_type_pos: return_type_pos
params: params
is_variadic: is_variadic
is_method: false
is_anon: true
no_body: false
pos: node.pos.extend(node.pos_end)
file: c.file.path
scope: node.scope.parent
}
typ: typ
}
c.anon_fn(mut node.func)
}
node.is_checked = true
node.typ = exp_typ

return exp_typ
}

pub fn (mut c Checker) support_lambda_expr_in_sort(param_type ast.Type, return_type ast.Type, mut expr ast.LambdaExpr) {
is_auto_rec := param_type.is_ptr()
mut expected_fn := ast.Fn{
params: [
ast.Param{
name: 'zza'
typ: param_type
is_auto_rec: is_auto_rec
},
ast.Param{
name: 'zzb'
typ: param_type
is_auto_rec: is_auto_rec
},
]
return_type: return_type
}
expected_fn_type := ast.new_type(c.table.find_or_register_fn_type(expected_fn, true,
false))
c.lambda_expr(mut expr, expected_fn_type)
}
2 changes: 1 addition & 1 deletion vlib/v/eval/expr.v
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ pub fn (mut e Eval) expr(expr ast.Expr, expecting ast.Type) Object {
ast.ConcatExpr, ast.DumpExpr, ast.EmptyExpr, ast.EnumVal, ast.GoExpr, ast.SpawnExpr,
ast.IfGuardExpr, ast.IsRefType, ast.Likely, ast.LockExpr, ast.MapInit, ast.MatchExpr,
ast.Nil, ast.NodeError, ast.None, ast.OffsetOf, ast.OrExpr, ast.RangeExpr, ast.SelectExpr,
ast.SqlExpr, ast.TypeNode, ast.TypeOf {
ast.SqlExpr, ast.TypeNode, ast.TypeOf, ast.LambdaExpr {
e.error('unhandled expression ${typeof(expr).name}')
}
}
Expand Down
11 changes: 11 additions & 0 deletions vlib/v/fmt/fmt.v
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,17 @@ pub fn (mut f Fmt) expr(node_ ast.Expr) {
ast.IntegerLiteral {
f.write(node.val)
}
ast.LambdaExpr {
f.write('|')
for i, x in node.params {
f.expr(x)
if i < node.params.len - 1 {
f.write(', ')
}
}
f.write('| ')
f.expr(node.expr)
}
ast.Likely {
f.likely(node)
}
Expand Down
Loading

0 comments on commit 1cceb2a

Please sign in to comment.