From 659bd1a4289af5ea0559e4f8cb45da61228f55f9 Mon Sep 17 00:00:00 2001 From: yuyi Date: Fri, 30 Sep 2022 07:40:30 +0800 Subject: [PATCH] ast, checker, cgen: fix generic fn/method with comptime for/if (fix #15905) (#15910) --- vlib/v/ast/ast.v | 7 ++- vlib/v/checker/assign.v | 3 + vlib/v/checker/checker.v | 2 + vlib/v/checker/fn.v | 14 +++++ vlib/v/checker/for.v | 23 +++++++ vlib/v/gen/c/cgen.v | 1 + vlib/v/gen/c/fn.v | 48 ++++++++++++++- vlib/v/gen/c/for.v | 21 ++++++- .../tests/generic_fn_with_comptime_for_test.v | 61 +++++++++++++++++++ 9 files changed, 173 insertions(+), 7 deletions(-) create mode 100644 vlib/v/tests/generic_fn_with_comptime_for_test.v diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 64ebe153bb..b6166ca83a 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -654,9 +654,10 @@ pub mut: // 10 <- original type (orig_type) // [11, 12, 13] <- cast order (smartcasts) // 12 <- the current casted type (typ) - pos token.Pos - is_used bool // whether the local variable was used in other expressions - is_changed bool // to detect mutable vars that are never changed + pos token.Pos + is_used bool // whether the local variable was used in other expressions + is_changed bool // to detect mutable vars that are never changed + is_comptime_field bool // comptime field var `a := t.$(field.name)` // // (for setting the position after the or block for autofree) is_or bool // `x := foo() or { ... }` diff --git a/vlib/v/checker/assign.v b/vlib/v/checker/assign.v index 56e9993478..3245c69d22 100644 --- a/vlib/v/checker/assign.v +++ b/vlib/v/checker/assign.v @@ -269,6 +269,9 @@ pub fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) { } } } + if right is ast.ComptimeSelector { + left.obj.is_comptime_field = true + } } ast.GlobalField { left.obj.typ = left_type diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 36cdf0f89c..c7f3155018 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -93,6 +93,7 @@ pub mut: inside_ct_attr bool // true inside `[if expr]` inside_x_is_type bool // true inside the Type expression of `if x is Type {` inside_comptime_for_field bool + inside_for_in_any_cond bool skip_flags bool // should `#flag` and `#include` be skipped fn_level int // 0 for the top level, 1 for `fn abc() {}`, 2 for a nested fn, etc smartcast_mut_pos token.Pos // match mut foo, if mut foo is Foo @@ -112,6 +113,7 @@ mut: loop_label string // set when inside a labelled for loop vweb_gen_types []ast.Type // vweb route checks timers &util.Timers = util.get_timers() + for_in_any_val_type ast.Type comptime_fields_default_type ast.Type comptime_fields_type map[string]ast.Type fn_scope &ast.Scope = unsafe { nil } diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index abe358b390..839c03a7a7 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -1435,6 +1435,11 @@ pub fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type { } for i, mut arg in node.args { + param_idx := if method.is_variadic && i >= method.params.len - 1 { + method.params.len - 1 + } else { + i + 1 + } if i > 0 || exp_arg_typ == ast.Type(0) { exp_arg_typ = if method.is_variadic && i >= method.params.len - 1 { method.params.last().typ @@ -1449,6 +1454,15 @@ pub fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type { mut got_arg_typ := c.check_expr_opt_call(arg.expr, c.expr(arg.expr)) node.args[i].typ = got_arg_typ + if c.inside_comptime_for_field && method.params[param_idx].typ.has_flag(.generic) { + c.table.register_fn_concrete_types(method.fkey(), [ + c.comptime_fields_default_type, + ]) + } else if c.inside_for_in_any_cond && method.params[param_idx].typ.has_flag(.generic) { + c.table.register_fn_concrete_types(method.fkey(), [ + c.for_in_any_val_type, + ]) + } if no_type_promotion { if got_arg_typ != exp_arg_typ { c.error('cannot use `${c.table.sym(got_arg_typ).name}` as argument for `$method.name` (`$exp_arg_sym.name` expected)', diff --git a/vlib/v/checker/for.v b/vlib/v/checker/for.v index 1292a758ce..c22496c321 100644 --- a/vlib/v/checker/for.v +++ b/vlib/v/checker/for.v @@ -113,6 +113,27 @@ fn (mut c Checker) for_in_stmt(mut node ast.ForInStmt) { node.scope.update_var_type(node.val_var, val_type) } else if sym.kind == .string && node.val_is_mut { c.error('string type is immutable, it cannot be changed', node.pos) + } else if sym.kind == .any { + node.cond_type = typ + node.kind = sym.kind + + unwrapped_typ := c.unwrap_generic(typ) + unwrapped_sym := c.table.sym(unwrapped_typ) + + if node.key_var.len > 0 { + key_type := match unwrapped_sym.kind { + .map { unwrapped_sym.map_info().key_type } + else { ast.int_type } + } + node.key_type = key_type + node.scope.update_var_type(node.key_var, key_type) + } + + value_type := c.table.value_type(unwrapped_typ) + node.scope.update_var_type(node.val_var, value_type) + + c.inside_for_in_any_cond = true + c.for_in_any_val_type = value_type } else { if sym.kind == .map && !(node.key_var.len > 0 && node.val_var.len > 0) { c.error( @@ -176,6 +197,8 @@ fn (mut c Checker) for_in_stmt(mut node ast.ForInStmt) { c.check_loop_label(node.label, node.pos) c.stmts(node.stmts) c.loop_label = prev_loop_label + c.inside_for_in_any_cond = false + c.for_in_any_val_type = 0 c.in_for_count-- } diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index d40d7f8236..8a6ed58c8d 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -130,6 +130,7 @@ mut: inside_const bool inside_const_opt_or_res bool inside_lambda bool + inside_for_in_any_cond bool loop_depth int ternary_names map[string]string ternary_level_names map[string][]string diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index d6cc4f08d3..0ee299e542 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -715,6 +715,8 @@ fn (mut g Gen) method_call(node ast.CallExpr) { } left_type := g.unwrap_generic(node.left_type) mut unwrapped_rec_type := node.receiver_type + mut has_comptime_field := false + mut for_in_any_var_type := ast.void_type if g.cur_fn != unsafe { nil } && g.cur_fn.generic_names.len > 0 { // in generic fn unwrapped_rec_type = g.unwrap_generic(node.receiver_type) } else { // in non-generic fn @@ -733,6 +735,31 @@ fn (mut g Gen) method_call(node ast.CallExpr) { else {} } } + if g.inside_comptime_for_field { + mut node_ := unsafe { node } + for i, mut call_arg in node_.args { + if mut call_arg.expr is ast.Ident { + if mut call_arg.expr.obj is ast.Var { + node_.args[i].typ = call_arg.expr.obj.typ + if call_arg.expr.obj.is_comptime_field { + has_comptime_field = true + } + } + } else if mut call_arg.expr is ast.ComptimeSelector { + has_comptime_field = true + } + } + } + if g.inside_for_in_any_cond { + for call_arg in node.args { + if call_arg.expr is ast.Ident { + if call_arg.expr.obj is ast.Var { + for_in_any_var_type = call_arg.expr.obj.typ + } + } + } + } + mut typ_sym := g.table.sym(unwrapped_rec_type) // alias type that undefined this method (not include `str`) need to use parent type if typ_sym.kind == .alias && node.name != 'str' && !typ_sym.has_method(node.name) { @@ -1020,8 +1047,16 @@ fn (mut g Gen) method_call(node ast.CallExpr) { } } } - concrete_types := node.concrete_types.map(g.unwrap_generic(it)) - name = g.generic_fn_name(concrete_types, name) + + if g.comptime_for_field_type != 0 && g.inside_comptime_for_field && has_comptime_field { + name = g.generic_fn_name([g.comptime_for_field_type], name) + } else if g.inside_for_in_any_cond && for_in_any_var_type != ast.void_type { + name = g.generic_fn_name([for_in_any_var_type], name) + } else { + concrete_types := node.concrete_types.map(g.unwrap_generic(it)) + name = g.generic_fn_name(concrete_types, name) + } + // TODO2 // g.generate_tmp_autofree_arg_vars(node, name) if !node.receiver_type.is_ptr() && left_type.is_ptr() && node.name == 'str' { @@ -1141,6 +1176,7 @@ fn (mut g Gen) fn_call(node ast.CallExpr) { // will be `0` for `foo()` mut is_interface_call := false mut is_selector_call := false + mut has_comptime_field := false if node.left_type != 0 { left_sym := g.table.sym(node.left_type) if left_sym.kind == .interface_ { @@ -1171,7 +1207,12 @@ fn (mut g Gen) fn_call(node ast.CallExpr) { if mut call_arg.expr is ast.Ident { if mut call_arg.expr.obj is ast.Var { node_.args[i].typ = call_arg.expr.obj.typ + if call_arg.expr.obj.is_comptime_field { + has_comptime_field = true + } } + } else if mut call_arg.expr is ast.ComptimeSelector { + has_comptime_field = true } } } @@ -1262,7 +1303,8 @@ fn (mut g Gen) fn_call(node ast.CallExpr) { if !is_selector_call { if func := g.table.find_fn(node.name) { if func.generic_names.len > 0 { - if g.comptime_for_field_type != 0 && g.inside_comptime_for_field { + if g.comptime_for_field_type != 0 && g.inside_comptime_for_field + && has_comptime_field { name = g.generic_fn_name([g.comptime_for_field_type], name) } else { concrete_types := node.concrete_types.map(g.unwrap_generic(it)) diff --git a/vlib/v/gen/c/for.v b/vlib/v/gen/c/for.v index e895ac0137..6ff23b5dc1 100644 --- a/vlib/v/gen/c/for.v +++ b/vlib/v/gen/c/for.v @@ -128,7 +128,25 @@ fn (mut g Gen) for_stmt(node ast.ForStmt) { g.loop_depth-- } -fn (mut g Gen) for_in_stmt(node ast.ForInStmt) { +fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) { + mut node := unsafe { node_ } + if node.kind == .any { + g.inside_for_in_any_cond = true + unwrapped_typ := g.unwrap_generic(node.cond_type) + unwrapped_sym := g.table.sym(unwrapped_typ) + node.kind = unwrapped_sym.kind + node.cond_type = unwrapped_typ + if node.key_var.len > 0 { + key_type := match unwrapped_sym.kind { + .map { unwrapped_sym.map_info().key_type } + else { ast.int_type } + } + node.key_type = key_type + node.scope.update_var_type(node.key_var, key_type) + } + node.val_type = g.table.value_type(unwrapped_typ) + node.scope.update_var_type(node.val_var, node.val_type) + } g.loop_depth++ if node.label.len > 0 { g.writeln('\t$node.label: {}') @@ -384,5 +402,6 @@ fn (mut g Gen) for_in_stmt(node ast.ForInStmt) { if node.label.len > 0 { g.writeln('\t${node.label}__break: {}') } + g.inside_for_in_any_cond = false g.loop_depth-- } diff --git a/vlib/v/tests/generic_fn_with_comptime_for_test.v b/vlib/v/tests/generic_fn_with_comptime_for_test.v new file mode 100644 index 0000000000..a73fe88c9b --- /dev/null +++ b/vlib/v/tests/generic_fn_with_comptime_for_test.v @@ -0,0 +1,61 @@ +fn fn_a(data T, depth int, nl bool) { + for _ in 0 .. depth { + print('\t') + } + $if T.typ is int { + print('int: $data') + } $else $if T.typ is string { + print('string: $data') + } $else $if T is $Array { + println('array: [') + for i, elem in data { + fn_a(elem, depth + 1, false) + if i < data.len - 1 { + print(', ') + } + println('') + } + print(']') + } $else $if T is $Map { + println('map: {') + for key, value in data { + print('\t(key) ') + fn_a(key, depth, false) + print(' -> (value) ') + fn_a(value, depth, true) + } + print('}') + } $else $if T is $Struct { + println('struct $T.name: {') + $for field in T.fields { + print('\t($field.name) ') + fn_a(data.$(field.name), depth, true) + // uncommenting either of these lines will cause a C error in my branch as the type is + // set manually to the $for field type, it needs to be fixed to infer the type correctly. + fn_a(['1', '2', '3', '4'], depth + 1, true) + fn_a(0.111, depth + 1, true) + fn_a('hello', depth + 1, true) + } + print('}') + } + if nl { + println('') + } +} + +struct StructA { + field_a int + field_b string +} + +fn test_generic_fn_with_comptime_for() { + fn_a(111, 0, true) + fn_a('hello', 0, true) + fn_a(['a', 'b', 'c', 'd'], 0, true) + fn_a({ + 'one': 1 + 'two': 2 + }, 0, true) + fn_a(StructA{ field_a: 111, field_b: 'vlang' }, 0, true) + assert true +}