diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 5688ab4278a2ea..b038bef8f9a189 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -3197,8 +3197,9 @@ fn (mut c Checker) cast_expr(mut node ast.CastExpr) ast.Type { if mut node.expr is ast.ComptimeSelector { node.expr_type = c.comptime.get_comptime_selector_type(node.expr, node.expr_type) + } else if node.expr is ast.Ident && c.comptime.is_comptime_variant_var(node.expr) { + node.expr_type = c.comptime.type_map['${c.comptime.comptime_for_variant_var}.typ'] } - mut from_type := c.unwrap_generic(node.expr_type) from_sym := c.table.sym(from_type) final_from_sym := c.table.final_sym(from_type) diff --git a/vlib/v/comptime/comptimeinfo.v b/vlib/v/comptime/comptimeinfo.v index 58d72383fbadc2..c623365cbd0a3e 100644 --- a/vlib/v/comptime/comptimeinfo.v +++ b/vlib/v/comptime/comptimeinfo.v @@ -34,6 +34,12 @@ pub fn (mut ct ComptimeInfo) is_comptime_var(node ast.Expr) bool { return ct.get_ct_type_var(node) != .no_comptime } +// is_comptime_variant_var checks if the node is related to a comptime variant variable +@[inline] +pub fn (mut ct ComptimeInfo) is_comptime_variant_var(node ast.Ident) bool { + return node.name == ct.comptime_for_variant_var +} + // get_ct_type_var gets the comptime type of the variable (.generic_param, .key_var, etc) @[inline] pub fn (mut ct ComptimeInfo) get_ct_type_var(node ast.Expr) ast.ComptimeVarKind { diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 09c2655c29a1af..53b4682ce00243 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -704,6 +704,7 @@ fn cgen_process_one_file_cb(mut p pool.PoolProcessor, idx int, wid int) &Gen { inner_loop: &ast.empty_stmt field_data_type: global_g.table.find_type('FieldData') enum_data_type: global_g.table.find_type('EnumData') + variant_data_type: global_g.table.find_type('VariantData') array_sort_fn: global_g.array_sort_fn waiter_fns: global_g.waiter_fns threaded_fns: global_g.threaded_fns @@ -2512,13 +2513,15 @@ fn (mut g Gen) write_sumtype_casting_fn(fun SumtypeCastingFn) { fn (mut g Gen) call_cfn_for_casting_expr(fname string, expr ast.Expr, exp_is_ptr bool, exp_styp string, got_is_ptr bool, got_is_fn bool, got_styp string) { mut rparen_n := 1 + is_comptime_variant := expr is ast.Ident && g.comptime.is_comptime_variant_var(expr) if exp_is_ptr { g.write('HEAP(${exp_styp}, ') rparen_n++ } g.write('${fname}(') if !got_is_ptr && !got_is_fn { - if !expr.is_lvalue() || (expr is ast.Ident && expr.obj.is_simple_define_const()) { + if is_comptime_variant || !expr.is_lvalue() + || (expr is ast.Ident && expr.obj.is_simple_define_const()) { // Note: the `_to_sumtype_` family of functions do call memdup internally, making // another duplicate with the HEAP macro is redundant, so use ADDR instead: promotion_macro_name := if fname.contains('_to_sumtype_') { 'ADDR' } else { 'HEAP' } @@ -2530,6 +2533,8 @@ fn (mut g Gen) call_cfn_for_casting_expr(fname string, expr ast.Expr, exp_is_ptr } if got_styp == 'none' && !g.cur_fn.return_type.has_flag(.option) { g.write('(none){EMPTY_STRUCT_INITIALIZATION}') + } else if is_comptime_variant { + g.write(g.type_default(g.comptime.type_map['${g.comptime.comptime_for_variant_var}.typ'])) } else { g.expr(expr) } @@ -5109,6 +5114,9 @@ fn (mut g Gen) cast_expr(node ast.CastExpr) { if sym.kind in [.sum_type, .interface] { if node.typ.has_flag(.option) && node.expr is ast.None { g.gen_option_error(node.typ, node.expr) + } else if node.expr is ast.Ident && g.comptime.is_comptime_variant_var(node.expr) { + g.expr_with_cast(node.expr, g.comptime.type_map['${g.comptime.comptime_for_variant_var}.typ'], + node_typ) } else if node.typ.has_flag(.option) { g.expr_with_opt(node.expr, expr_type, node.typ) } else { diff --git a/vlib/v/tests/sumtype_init_by_name_test.v b/vlib/v/tests/sumtype_init_by_name_test.v new file mode 100644 index 00000000000000..69163379b289c8 --- /dev/null +++ b/vlib/v/tests/sumtype_init_by_name_test.v @@ -0,0 +1,31 @@ +type Sum = int | string + +fn get[T](val T, type_name string) T { + $if T is $sumtype { + $for v in val.variants { + if type_name == typeof(v.typ).name { + return T(v) + } + } + } + return T{} +} + +fn get2[T](val T, type_name string) T { + $if T is $sumtype { + $for v in val.variants { + if type_name == typeof(v.typ).name { + return Sum(v) + } + } + } + return T{} +} + +fn test_main() { + assert dump(get(Sum{}, 'int')) == Sum(0) + assert dump(get(Sum{}, 'string')) == Sum('') + + assert dump(get2(Sum{}, 'int')) == Sum(0) + assert dump(get2(Sum{}, 'string')) == Sum('') +}