From 88d18f3303ce77ce3ff99333802cc0ba796cdece Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20D=C3=A4schle?= Date: Thu, 7 Jan 2021 21:35:32 +0100 Subject: [PATCH] checker: smartcast in for loops (#7942) --- CHANGELOG.md | 1 + vlib/v/ast/ast.v | 1 + vlib/v/checker/checker.v | 239 +++++++++++++++--------------- vlib/v/parser/for.v | 3 + vlib/v/tests/for_smartcast_test.v | 53 +++++++ 5 files changed, 174 insertions(+), 123 deletions(-) create mode 100644 vlib/v/tests/for_smartcast_test.v diff --git a/CHANGELOG.md b/CHANGELOG.md index b1dc1efd06..e193adcc45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Overloading of `>`, `<`, `!=`, and `==` operators. - New struct updating syntax: `User{ ...u, name: 'new' }` to replace `{ u | name: 'new' }`. - `byte.str()` has been fixed and works like with all other numbers. `byte.ascii_str()` has been added. +- Smart cast in for-loops: `for mut x is string {}` ## V 0.2.1 *30 Dec 2020* diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 57b25d25c3..e92618e894 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -120,6 +120,7 @@ pub mut: pub fn (e &SelectorExpr) root_ident() Ident { mut root := e.expr for root is SelectorExpr { + // TODO: remove this line selector_expr := root as SelectorExpr root = selector_expr.expr } diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 5f91476c2d..fc91ee7626 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -2886,19 +2886,7 @@ fn (mut c Checker) stmt(node ast.Stmt) { c.in_for_count-- } ast.ForStmt { - c.in_for_count++ - prev_loop_label := c.loop_label - c.expected_type = table.bool_type - typ := c.expr(node.cond) - if !node.is_inf && typ.idx() != table.bool_type_idx && !c.pref.translated { - c.error('non-bool used as for condition', node.pos) - } - // TODO: update loop var type - // how does this work currenly? - c.check_loop_label(node.label, node.pos) - c.stmts(node.stmts) - c.loop_label = prev_loop_label - c.in_for_count-- + c.for_stmt(mut node) } ast.GlobalDecl { for field in node.fields { @@ -3884,60 +3872,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym table.TypeS } else { expr_type = expr_types[0].typ } - match mut node.cond { - ast.SelectorExpr { - mut is_mut := false - mut sum_type_casts := []table.Type{} - expr_sym := c.table.get_type_symbol(node.cond.expr_type) - if field := c.table.struct_find_field(expr_sym, node.cond.field_name) { - if field.is_mut { - root_ident := node.cond.root_ident() - if v := branch.scope.find_var(root_ident.name) { - is_mut = v.is_mut - } - } - } - if field := branch.scope.find_struct_field(node.cond.expr_type, - node.cond.field_name) { - sum_type_casts << field.sum_type_casts - } - // smartcast either if the value is immutable or if the mut argument is explicitly given - if !is_mut || node.cond.is_mut { - sum_type_casts << expr_type - branch.scope.register_struct_field(ast.ScopeStructField{ - struct_type: node.cond.expr_type - name: node.cond.field_name - typ: node.cond_type - sum_type_casts: sum_type_casts - pos: node.cond.pos - }) - } - } - ast.Ident { - mut is_mut := false - mut sum_type_casts := []table.Type{} - mut is_already_casted := false - if node.cond.obj is ast.Var { - v := node.cond.obj as ast.Var - is_mut = v.is_mut - sum_type_casts << v.sum_type_casts - is_already_casted = v.pos.pos == node.cond.pos.pos - } - // smartcast either if the value is immutable or if the mut argument is explicitly given - if (!is_mut || node.cond.is_mut) && !is_already_casted { - sum_type_casts << expr_type - branch.scope.register(ast.Var{ - name: node.cond.name - typ: node.cond_type - pos: node.cond.pos - is_used: true - is_mut: node.cond.is_mut - sum_type_casts: sum_type_casts - }) - } - } - else {} - } + c.smartcast_sumtype(node.cond, node.cond_type, expr_type, mut branch.scope) } } } @@ -4017,6 +3952,63 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym table.TypeS c.error(err_details, node.pos) } +// smartcast takes the expression with the current type which should be smartcasted to the target type in the given scope +fn (c Checker) smartcast_sumtype(expr ast.Expr, cur_type table.Type, to_type table.Type, mut scope ast.Scope) { + match mut expr { + ast.SelectorExpr { + mut is_mut := false + mut sum_type_casts := []table.Type{} + expr_sym := c.table.get_type_symbol(expr.expr_type) + if field := c.table.struct_find_field(expr_sym, expr.field_name) { + if field.is_mut { + root_ident := expr.root_ident() + if v := scope.find_var(root_ident.name) { + is_mut = v.is_mut + } + } + } + if field := scope.find_struct_field(expr.expr_type, expr.field_name) { + sum_type_casts << field.sum_type_casts + } + // smartcast either if the value is immutable or if the mut argument is explicitly given + if !is_mut || expr.is_mut { + sum_type_casts << to_type + scope.register_struct_field(ast.ScopeStructField{ + struct_type: expr.expr_type + name: expr.field_name + typ: cur_type + sum_type_casts: sum_type_casts + pos: expr.pos + }) + } + } + ast.Ident { + mut is_mut := false + mut sum_type_casts := []table.Type{} + mut is_already_casted := false + if expr.obj is ast.Var { + v := expr.obj as ast.Var + is_mut = v.is_mut + sum_type_casts << v.sum_type_casts + is_already_casted = v.pos.pos == expr.pos.pos + } + // smartcast either if the value is immutable or if the mut argument is explicitly given + if (!is_mut || expr.is_mut) && !is_already_casted { + sum_type_casts << to_type + scope.register(ast.Var{ + name: expr.name + typ: cur_type + pos: expr.pos + is_used: true + is_mut: expr.is_mut + sum_type_casts: sum_type_casts + }) + } + } + else {} + } +} + pub fn (mut c Checker) select_expr(mut node ast.SelectExpr) table.Type { node.is_expr = c.expected_type != table.void_type node.expected_type = c.expected_type @@ -4114,6 +4106,45 @@ pub fn (mut c Checker) unsafe_expr(mut node ast.UnsafeExpr) table.Type { return t } +fn (mut c Checker) for_stmt(mut node ast.ForStmt) { + c.in_for_count++ + prev_loop_label := c.loop_label + c.expected_type = table.bool_type + typ := c.expr(node.cond) + if !node.is_inf && typ.idx() != table.bool_type_idx && !c.pref.translated { + c.error('non-bool used as for condition', node.pos) + } + if node.cond is ast.InfixExpr { + infix := node.cond + if infix.op == .key_is { + if (infix.left is ast.Ident || + infix.left is ast.SelectorExpr) && + infix.right is ast.Type { + right_expr := infix.right as ast.Type + is_variable := if mut infix.left is ast.Ident { + infix.left.kind == .variable + } else { + true + } + left_type := c.expr(infix.left) + left_sym := c.table.get_type_symbol(left_type) + if is_variable { + if left_sym.kind == .sum_type { + c.smartcast_sumtype(infix.left, infix.left_type, right_expr.typ, mut + node.scope) + } + } + } + } + } + // TODO: update loop var type + // how does this work currenly? + c.check_loop_label(node.label, node.pos) + c.stmts(node.stmts) + c.loop_label = prev_loop_label + c.in_for_count-- +} + pub fn (mut c Checker) if_expr(mut node ast.IfExpr) table.Type { if_kind := if node.is_comptime { '\$if' } else { 'if' } expr_required := c.expected_type != table.void_type @@ -4167,68 +4198,30 @@ pub fn (mut c Checker) if_expr(mut node ast.IfExpr) table.Type { } else { true } - // Register shadow variable or `as` variable with actual type if is_variable { - // TODO: merge this code with match_expr because it has the same logic implemented if left_sym.kind in [.interface_, .sum_type] { - mut is_mut := false - if mut infix.left is ast.Ident { + if infix.left is ast.Ident && left_sym.kind == .interface_ { + // TODO: rewrite interface smartcast + left := infix.left as ast.Ident + mut is_mut := false mut sum_type_casts := []table.Type{} - if v := branch.scope.find_var(infix.left.name) { + if v := branch.scope.find_var(left.name) { is_mut = v.is_mut sum_type_casts << v.sum_type_casts } - if left_sym.kind == .sum_type { - // smartcast either if the value is immutable or if the mut argument is explicitly given - if !is_mut || infix.left.is_mut { - sum_type_casts << right_expr.typ - branch.scope.register(ast.Var{ - name: infix.left.name - typ: infix.left_type - sum_type_casts: sum_type_casts - pos: infix.left.pos - is_used: true - is_mut: is_mut - }) - } - } else if left_sym.kind == .interface_ { - branch.scope.register(ast.Var{ - name: infix.left.name - typ: right_expr.typ.to_ptr() - sum_type_casts: sum_type_casts - pos: infix.left.pos - is_used: true - is_mut: is_mut - }) - // TODO: remove that later @danieldaeschle - node.branches[i].smartcast = true - } - } else if mut infix.left is ast.SelectorExpr { - mut sum_type_casts := []table.Type{} - expr_sym := c.table.get_type_symbol(infix.left.expr_type) - if field := c.table.struct_find_field(expr_sym, infix.left.field_name) { - if field.is_mut { - root_ident := infix.left.root_ident() - if root_ident.obj is ast.Var { - is_mut = root_ident.obj.is_mut - } - } - } - if field := branch.scope.find_struct_field(infix.left.expr_type, - infix.left.field_name) { - sum_type_casts << field.sum_type_casts - } - // smartcast either if the value is immutable or if the mut argument is explicitly given - if (!is_mut || infix.left.is_mut) && left_sym.kind == .sum_type { - sum_type_casts << right_expr.typ - branch.scope.register_struct_field(ast.ScopeStructField{ - struct_type: infix.left.expr_type - name: infix.left.field_name - typ: infix.left_type - sum_type_casts: sum_type_casts - pos: infix.left.pos - }) - } + branch.scope.register(ast.Var{ + name: left.name + typ: right_expr.typ.to_ptr() + sum_type_casts: sum_type_casts + pos: left.pos + is_used: true + is_mut: is_mut + }) + // TODO: needs to be removed + node.branches[i].smartcast = true + } else { + c.smartcast_sumtype(infix.left, infix.left_type, right_expr.typ, mut + branch.scope) } } } diff --git a/vlib/v/parser/for.v b/vlib/v/parser/for.v index adfca6be5f..33ba8cb76e 100644 --- a/vlib/v/parser/for.v +++ b/vlib/v/parser/for.v @@ -177,6 +177,8 @@ fn (mut p Parser) for_stmt() ast.Stmt { // `for cond {` cond := p.expr(0) p.inside_for = false + // extra scope for the body + p.open_scope() stmts := p.parse_block_no_scope(false) pos.last_line = p.prev_tok.line_nr - 1 for_stmt := ast.ForStmt{ @@ -186,5 +188,6 @@ fn (mut p Parser) for_stmt() ast.Stmt { scope: p.scope } p.close_scope() + p.close_scope() return for_stmt } diff --git a/vlib/v/tests/for_smartcast_test.v b/vlib/v/tests/for_smartcast_test.v new file mode 100644 index 0000000000..1e2d66c0bd --- /dev/null +++ b/vlib/v/tests/for_smartcast_test.v @@ -0,0 +1,53 @@ +type Node = Expr | string +type Expr = IfExpr | IntegerLiteral + +struct IntegerLiteral {} +struct IfExpr { + pos int +} + +struct NodeWrapper { + node Node +} + +fn test_nested_sumtype_selector() { + c := NodeWrapper{Node(Expr(IfExpr{pos: 1}))} + for c.node is Expr { + assert typeof(c.node).name == 'Expr' + break + } +} + +struct Milk { +mut: + name string +} + +struct Eggs { +mut: + name string +} + +type Food = Milk | Eggs + +struct FoodWrapper { +mut: + food Food +} + +fn test_match_mut() { + mut f := Food(Eggs{'test'}) + for mut f is Eggs { + f.name = 'eggs' + assert f.name == 'eggs' + break + } +} + +fn test_conditional_break() { + mut f := Food(Eggs{'test'}) + for mut f is Eggs { + f = Milk{'test'} + } + assert true +}