Skip to content

Commit

Permalink
v: allow sumtype init by variant comptime var T(v) / SumType(v) (v…
Browse files Browse the repository at this point in the history
  • Loading branch information
felipensp authored Oct 27, 2024
1 parent 14b1a14 commit 731d07d
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
3 changes: 2 additions & 1 deletion vlib/v/checker/checker.v
Original file line number Diff line number Diff line change
Expand Up @@ -3197,8 +3197,9 @@ fn (mut c Checker) cast_expr(mut node ast.CastExpr) ast.Type {

if mut node.expr is ast.ComptimeSelector {
node.expr_type = c.comptime.get_comptime_selector_type(node.expr, node.expr_type)
} else if node.expr is ast.Ident && c.comptime.is_comptime_variant_var(node.expr) {
node.expr_type = c.comptime.type_map['${c.comptime.comptime_for_variant_var}.typ']
}

mut from_type := c.unwrap_generic(node.expr_type)
from_sym := c.table.sym(from_type)
final_from_sym := c.table.final_sym(from_type)
Expand Down
6 changes: 6 additions & 0 deletions vlib/v/comptime/comptimeinfo.v
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ pub fn (mut ct ComptimeInfo) is_comptime_var(node ast.Expr) bool {
return ct.get_ct_type_var(node) != .no_comptime
}

// is_comptime_variant_var checks if the node is related to a comptime variant variable
@[inline]
pub fn (mut ct ComptimeInfo) is_comptime_variant_var(node ast.Ident) bool {
return node.name == ct.comptime_for_variant_var
}

// get_ct_type_var gets the comptime type of the variable (.generic_param, .key_var, etc)
@[inline]
pub fn (mut ct ComptimeInfo) get_ct_type_var(node ast.Expr) ast.ComptimeVarKind {
Expand Down
10 changes: 9 additions & 1 deletion vlib/v/gen/c/cgen.v
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,7 @@ fn cgen_process_one_file_cb(mut p pool.PoolProcessor, idx int, wid int) &Gen {
inner_loop: &ast.empty_stmt
field_data_type: global_g.table.find_type('FieldData')
enum_data_type: global_g.table.find_type('EnumData')
variant_data_type: global_g.table.find_type('VariantData')
array_sort_fn: global_g.array_sort_fn
waiter_fns: global_g.waiter_fns
threaded_fns: global_g.threaded_fns
Expand Down Expand Up @@ -2512,13 +2513,15 @@ fn (mut g Gen) write_sumtype_casting_fn(fun SumtypeCastingFn) {
fn (mut g Gen) call_cfn_for_casting_expr(fname string, expr ast.Expr, exp_is_ptr bool, exp_styp string,
got_is_ptr bool, got_is_fn bool, got_styp string) {
mut rparen_n := 1
is_comptime_variant := expr is ast.Ident && g.comptime.is_comptime_variant_var(expr)
if exp_is_ptr {
g.write('HEAP(${exp_styp}, ')
rparen_n++
}
g.write('${fname}(')
if !got_is_ptr && !got_is_fn {
if !expr.is_lvalue() || (expr is ast.Ident && expr.obj.is_simple_define_const()) {
if is_comptime_variant || !expr.is_lvalue()
|| (expr is ast.Ident && expr.obj.is_simple_define_const()) {
// Note: the `_to_sumtype_` family of functions do call memdup internally, making
// another duplicate with the HEAP macro is redundant, so use ADDR instead:
promotion_macro_name := if fname.contains('_to_sumtype_') { 'ADDR' } else { 'HEAP' }
Expand All @@ -2530,6 +2533,8 @@ fn (mut g Gen) call_cfn_for_casting_expr(fname string, expr ast.Expr, exp_is_ptr
}
if got_styp == 'none' && !g.cur_fn.return_type.has_flag(.option) {
g.write('(none){EMPTY_STRUCT_INITIALIZATION}')
} else if is_comptime_variant {
g.write(g.type_default(g.comptime.type_map['${g.comptime.comptime_for_variant_var}.typ']))
} else {
g.expr(expr)
}
Expand Down Expand Up @@ -5109,6 +5114,9 @@ fn (mut g Gen) cast_expr(node ast.CastExpr) {
if sym.kind in [.sum_type, .interface] {
if node.typ.has_flag(.option) && node.expr is ast.None {
g.gen_option_error(node.typ, node.expr)
} else if node.expr is ast.Ident && g.comptime.is_comptime_variant_var(node.expr) {
g.expr_with_cast(node.expr, g.comptime.type_map['${g.comptime.comptime_for_variant_var}.typ'],
node_typ)
} else if node.typ.has_flag(.option) {
g.expr_with_opt(node.expr, expr_type, node.typ)
} else {
Expand Down
31 changes: 31 additions & 0 deletions vlib/v/tests/sumtype_init_by_name_test.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
type Sum = int | string

fn get[T](val T, type_name string) T {
$if T is $sumtype {
$for v in val.variants {
if type_name == typeof(v.typ).name {
return T(v)
}
}
}
return T{}
}

fn get2[T](val T, type_name string) T {
$if T is $sumtype {
$for v in val.variants {
if type_name == typeof(v.typ).name {
return Sum(v)
}
}
}
return T{}
}

fn test_main() {
assert dump(get(Sum{}, 'int')) == Sum(0)
assert dump(get(Sum{}, 'string')) == Sum('')

assert dump(get2(Sum{}, 'int')) == Sum(0)
assert dump(get2(Sum{}, 'string')) == Sum('')
}

0 comments on commit 731d07d

Please sign in to comment.