From ca99a1d355a9f8009095f97a25cf00e989439698 Mon Sep 17 00:00:00 2001 From: yuyi Date: Wed, 7 Sep 2022 01:33:39 +0800 Subject: [PATCH] cgen: fix sumtype with fntype using fn directly (fix #15674) (#15679) --- vlib/v/gen/c/cgen.v | 40 ++++++++++++++++++------- vlib/v/tests/sumtype_with_fntype_test.v | 32 ++++++++++++++++++++ 2 files changed, 61 insertions(+), 11 deletions(-) create mode 100644 vlib/v/tests/sumtype_with_fntype_test.v diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 5aaaa4da81..0035f9d204 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -2098,16 +2098,22 @@ fn (mut g Gen) write_sumtype_casting_fn(fun SumtypeCastingFn) { got_sym, exp_sym := g.table.sym(got), g.table.sym(exp) mut got_cname, exp_cname := got_sym.cname, exp_sym.cname mut type_idx := g.type_sidx(got) + mut sb := strings.new_builder(128) + mut is_anon_fn := false if got_sym.info is ast.FnType { - if got_sym.info.is_anon { + if got_sym.info.is_anon || g.table.known_fn(got_sym.name) { got_name := 'fn ${g.table.fn_type_source_signature(got_sym.info.func)}' got_cname = 'anon_fn_${g.table.fn_type_signature(got_sym.info.func)}' type_idx = g.table.type_idxs[got_name].str() + sb.writeln('static inline $exp_cname ${fun.fn_name}($got_cname x) {') + sb.writeln('\t$got_cname ptr = x;') + is_anon_fn = true } } - mut sb := strings.new_builder(128) - sb.writeln('static inline $exp_cname ${fun.fn_name}($got_cname* x) {') - sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));') + if !is_anon_fn { + sb.writeln('static inline $exp_cname ${fun.fn_name}($got_cname* x) {') + sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));') + } for embed_hierarchy in g.table.get_embeds(got_sym) { // last embed in the hierarchy mut embed_cname := '' @@ -2149,14 +2155,14 @@ fn (mut g Gen) write_sumtype_casting_fn(fun SumtypeCastingFn) { g.auto_fn_definitions << sb.str() } -fn (mut g Gen) call_cfn_for_casting_expr(fname string, expr ast.Expr, exp_is_ptr bool, exp_styp string, got_is_ptr bool, got_styp string) { +fn (mut g Gen) call_cfn_for_casting_expr(fname string, expr ast.Expr, exp_is_ptr bool, exp_styp string, got_is_ptr bool, got_is_fn bool, got_styp string) { mut rparen_n := 1 if exp_is_ptr { g.write('HEAP($exp_styp, ') rparen_n++ } g.write('${fname}(') - if !got_is_ptr { + if !got_is_ptr && !got_is_fn { if !expr.is_lvalue() || (expr is ast.Ident && (expr as ast.Ident).obj.is_simple_define_const()) { // Note: the `_to_sumtype_` family of functions do call memdup internally, making @@ -2207,7 +2213,7 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw ast.Type, expected_typ fname = g.generic_fn_name(exp_sym.info.concrete_types, fname) } g.call_cfn_for_casting_expr(fname, expr, expected_is_ptr, exp_styp, true, - got_styp) + false, got_styp) g.inside_cast_in_heap-- } else { got_styp := g.cc_type(got_type, true) @@ -2227,16 +2233,18 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw ast.Type, expected_typ fname = g.generic_fn_name(exp_sym.info.concrete_types, fname) } g.call_cfn_for_casting_expr(fname, expr, expected_is_ptr, exp_styp, got_is_ptr, - got_styp) + false, got_styp) } return } // cast to sum type exp_styp := g.typ(expected_type) mut got_styp := g.typ(got_type) + mut got_is_fn := false if got_sym.info is ast.FnType { if got_sym.info.is_anon { got_styp = 'anon_fn_${g.table.fn_type_signature(got_sym.info.func)}' + got_is_fn = true } } if expected_type != ast.void_type { @@ -2276,7 +2284,7 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw ast.Type, expected_typ } fname := g.get_sumtype_casting_fn(unwrapped_got_type, unwrapped_expected_type) g.call_cfn_for_casting_expr(fname, expr, expected_is_ptr, unwrapped_exp_sym.cname, - got_is_ptr, got_styp) + got_is_ptr, got_is_fn, got_styp) } return } @@ -5074,7 +5082,13 @@ fn (mut g Gen) write_types(symbols []&ast.TypeSymbol) { g.type_definitions.writeln('\tunion {') for variant in sym.info.variants { variant_sym := g.table.sym(variant) - g.type_definitions.writeln('\t\t${g.typ(variant.ref())} _$variant_sym.cname;') + mut var := variant.ref() + if variant_sym.info is ast.FnType { + if variant_sym.info.is_anon { + var = variant + } + } + g.type_definitions.writeln('\t\t${g.typ(var)} _$variant_sym.cname;') } g.type_definitions.writeln('\t};') g.type_definitions.writeln('\tint _typ;') @@ -5608,7 +5622,11 @@ fn (mut g Gen) as_cast(node ast.AsCast) { mut expr_type_sym := g.table.sym(g.unwrap_generic(node.expr_type)) if mut expr_type_sym.info is ast.SumType { dot := if node.expr_type.is_ptr() { '->' } else { '.' } - g.write('/* as */ *($styp*)__as_cast(') + if sym.info is ast.FnType { + g.write('/* as */ ($styp)__as_cast(') + } else { + g.write('/* as */ *($styp*)__as_cast(') + } g.write('(') g.expr(node.expr) g.write(')') diff --git a/vlib/v/tests/sumtype_with_fntype_test.v b/vlib/v/tests/sumtype_with_fntype_test.v new file mode 100644 index 0000000000..306eb88ac4 --- /dev/null +++ b/vlib/v/tests/sumtype_with_fntype_test.v @@ -0,0 +1,32 @@ +type Expr = fn () int | fn (int) int + +fn id(n int) int { + return n +} + +fn func() Expr { + return id +} + +fn test_sumtype_with_fntype() { + f := func() + + f1 := f as fn (int) int + println(123) + assert f1(123) == 123 + + if f is fn (int) int { + ret := f(123) + println(ret) + assert ret == 123 + } + + match f { + fn (int) int { + ret := f(321) + println(ret) + assert ret == 321 + } + else {} + } +}