From 1d59d35129f2f07a9f256fb6458ec31b3f15250c Mon Sep 17 00:00:00 2001 From: Ruofan XU <47302112+SleepyRoy@users.noreply.github.com> Date: Wed, 29 Jul 2020 04:17:25 +0800 Subject: [PATCH] cgen: fix sum type match (#5978) --- vlib/v/gen/cgen.v | 45 +++++++++++++++++ vlib/v/gen/fn.v | 2 +- vlib/v/tests/sum_type_test.v | 93 +++++++++++++++++++++++++++++++++++- 3 files changed, 138 insertions(+), 2 deletions(-) diff --git a/vlib/v/gen/cgen.v b/vlib/v/gen/cgen.v index 63e0129ed6..e7dd915d75 100644 --- a/vlib/v/gen/cgen.v +++ b/vlib/v/gen/cgen.v @@ -99,6 +99,8 @@ mut: inside_const bool comp_for_method string // $for method in T { comptime_var_type_map map[string]table.Type + match_sumtype_exprs []ast.Expr + match_sumtype_syms []table.TypeSymbol } const ( @@ -1760,6 +1762,9 @@ fn (mut g Gen) expr(node ast.Expr) { g.write(node.val) } ast.Ident { + if g.should_write_asterisk_due_to_match_sumtype(node) { + g.write('*') + } g.ident(node) } ast.IfExpr { @@ -2269,6 +2274,16 @@ fn (mut g Gen) match_expr(node ast.MatchExpr) { // g.write('/* EM ret type=${g.typ(node.return_type)} expected_type=${g.typ(node.expected_type)} */') } type_sym := g.table.get_type_symbol(node.cond_type) + if node.is_sum_type { + g.match_sumtype_exprs << node.cond + g.match_sumtype_syms << type_sym + } + defer { + if node.is_sum_type { + g.match_sumtype_exprs.pop() + g.match_sumtype_syms.pop() + } + } mut tmp := '' if type_sym.kind != .void { tmp = g.new_tmp_var() @@ -2417,6 +2432,36 @@ fn (mut g Gen) ident(node ast.Ident) { g.write(g.get_ternary_name(name)) } +[unlikely] +fn (mut g Gen) should_write_asterisk_due_to_match_sumtype(expr ast.Expr) bool { + if expr is ast.Ident { + typ := if expr.info is ast.IdentVar { (expr.info as ast.IdentVar).typ } else { (expr.info as ast.IdentFn).typ } + return if typ.is_ptr() && g.match_sumtype_has_no_struct_and_contains(expr) { + true + } else { + false + } + } else { + return false + } +} + +[unlikely] +fn (mut g Gen) match_sumtype_has_no_struct_and_contains(node ast.Ident) bool { + for i, expr in g.match_sumtype_exprs { + if expr is ast.Ident && node.name == (expr as ast.Ident).name { + sumtype := g.match_sumtype_syms[i].info as table.SumType + for typ in sumtype.variants { + if g.table.get_type_symbol(typ).kind == .struct_ { + return false + } + } + return true + } + } + return false +} + fn (mut g Gen) concat_expr(node ast.ConcatExpr) { styp := g.typ(node.return_type) sym := g.table.get_type_symbol(node.return_type) diff --git a/vlib/v/gen/fn.v b/vlib/v/gen/fn.v index 1e81adbd02..ff7e45a03e 100644 --- a/vlib/v/gen/fn.v +++ b/vlib/v/gen/fn.v @@ -391,7 +391,7 @@ fn (mut g Gen) method_call(node ast.CallExpr) { // g.write('/*${g.typ(node.receiver_type)}*/') // g.write('/*expr_type=${g.typ(node.left_type)} rec type=${g.typ(node.receiver_type)}*/') // } - if !node.receiver_type.is_ptr() && node.left_type.is_ptr() && node.name == 'str' { + if !node.receiver_type.is_ptr() && node.left_type.is_ptr() && node.name == 'str' && !g.should_write_asterisk_due_to_match_sumtype(node.left) { g.write('ptr_str(') } else { g.write('${name}(') diff --git a/vlib/v/tests/sum_type_test.v b/vlib/v/tests/sum_type_test.v index 45a1baf523..7e91ae3ec6 100644 --- a/vlib/v/tests/sum_type_test.v +++ b/vlib/v/tests/sum_type_test.v @@ -177,4 +177,95 @@ fn test_int_cast_to_sumtype() { assert false } } -} \ No newline at end of file +} + +// TODO: change definition once types other than any_int and any_float (int, f64, etc) are supported in sumtype +type Number = any_int | any_float + +fn is_gt_simple(val string, dst Number) bool { + match dst { + any_int { + return val.int() > dst + } + any_float { + return dst < val.f64() + } + } +} + +fn is_gt_nested(val string, dst Number) bool { + dst2 := dst + match dst { + any_int { + match dst2 { + any_int { + return val.int() > dst + } + // this branch should never been hit + else { + return val.int() < dst + } + } + } + any_float { + match dst2 { + any_float { + return dst < val.f64() + } + // this branch should never been hit + else { + return dst > val.f64() + } + } + } + } +} + +fn concat(val string, dst Number) string { + match dst { + any_int { + mut res := val + '(int)' + res += dst.str() + return res + } + any_float { + mut res := val + '(float)' + res += dst.str() + return res + } + } +} + +fn get_sum(val string, dst Number) f64 { + match dst { + any_int { + mut res := val.int() + res += dst + return res + } + any_float { + mut res := val.f64() + res += dst + return res + } + } +} + +fn test_sum_type_match() { + assert is_gt_simple('3', 2) + assert !is_gt_simple('3', 5) + assert is_gt_simple('3', 1.2) + assert !is_gt_simple('3', 3.5) + assert is_gt_nested('3', 2) + assert !is_gt_nested('3', 5) + assert is_gt_nested('3', 1.2) + assert !is_gt_nested('3', 3.5) + assert concat('3', 2) == '3(int)2' + assert concat('3', 5) == '3(int)5' + assert concat('3', 1.2) == '3(float)1.2' + assert concat('3', 3.5) == '3(float)3.5' + assert get_sum('3', 2) == 5.0 + assert get_sum('3', 5) == 8.0 + assert get_sum('3', 1.2) == 4.2 + assert get_sum('3', 3.5) == 6.5 +}