From 6f629d1a6a53b8157c58c4eeedf5b6cec03ea4c7 Mon Sep 17 00:00:00 2001 From: ChAoS_UnItY <43753315+ChAoSUnItY@users.noreply.github.com> Date: Thu, 14 Oct 2021 07:15:52 +0800 Subject: [PATCH] transformer: eliminate unreachable branches & redundant branch expressions in MatchExpr (#12174) --- vlib/v/ast/ast.v | 2 +- vlib/v/transformer/transformer.v | 247 +++++++++++++++++++++++++++++-- 2 files changed, 234 insertions(+), 15 deletions(-) diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 59c59a166f..20f85b4a83 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -866,12 +866,12 @@ pub mut: pub struct MatchBranch { pub: ecmnts [][]Comment // inline comments for each left side expr - stmts []Stmt // right side pos token.Position is_else bool post_comments []Comment // comments below ´... }´ branch_pos token.Position // for checker errors about invalid branches pub mut: + stmts []Stmt // right side exprs []Expr // left side scope &Scope } diff --git a/vlib/v/transformer/transformer.v b/vlib/v/transformer/transformer.v index 3b79b27aee..63ae283e7c 100644 --- a/vlib/v/transformer/transformer.v +++ b/vlib/v/transformer/transformer.v @@ -56,18 +56,19 @@ pub fn (t Transformer) stmt(mut node ast.Stmt) { ast.DeferStmt {} ast.EnumDecl {} ast.ExprStmt { - if node.expr is ast.IfExpr { - mut untrans_expr := node.expr as ast.IfExpr - expr := t.if_expr(mut untrans_expr) - node = &ast.ExprStmt{ - ...node - expr: expr - } - } else { - expr := t.expr(node.expr) - node = &ast.ExprStmt{ - ...node - expr: expr + expr := node.expr + node = &ast.ExprStmt{ + ...node + expr: match mut expr { + ast.IfExpr { + t.if_expr(mut expr) + } + ast.MatchExpr { + t.match_expr(mut expr) + } + else { + t.expr(expr) + } } } } @@ -116,10 +117,77 @@ pub fn (t Transformer) expr(node ast.Expr) ast.Expr { index: t.expr(node.index) } } - ast.MatchExpr { + ast.IfExpr { for mut branch in node.branches { - for mut stmt in branch.stmts { + branch = ast.IfBranch{ + ...(*branch) + cond: t.expr(branch.cond) + } + for i, mut stmt in branch.stmts { t.stmt(mut stmt) + + if i == branch.stmts.len - 1 { + if stmt is ast.ExprStmt { + expr := (stmt as ast.ExprStmt).expr + + match expr { + ast.IfExpr { + if expr.branches.len == 1 { + branch.stmts.pop() + branch.stmts << expr.branches[0].stmts + break + } + } + ast.MatchExpr { + if expr.branches.len == 1 { + branch.stmts.pop() + branch.stmts << expr.branches[0].stmts + break + } + } + else {} + } + } + } + } + } + return node + } + ast.MatchExpr { + node = ast.MatchExpr{ + ...node + cond: t.expr(node.cond) + } + for mut branch in node.branches { + for mut expr in branch.exprs { + expr = t.expr(expr) + } + for i, mut stmt in branch.stmts { + t.stmt(mut stmt) + + if i == branch.stmts.len - 1 { + if stmt is ast.ExprStmt { + expr := (stmt as ast.ExprStmt).expr + + match expr { + ast.IfExpr { + if expr.branches.len == 1 { + branch.stmts.pop() + branch.stmts << expr.branches[0].stmts + break + } + } + ast.MatchExpr { + if expr.branches.len == 1 { + branch.stmts.pop() + branch.stmts << expr.branches[0].stmts + break + } + } + else {} + } + } + } } } return node @@ -132,6 +200,9 @@ pub fn (t Transformer) expr(node ast.Expr) ast.Expr { pub fn (t Transformer) if_expr(mut original ast.IfExpr) ast.Expr { mut stop_index, mut unreachable_branches := -1, []int{cap: original.branches.len} + if original.is_comptime { + return *original + } for i, mut branch in original.branches { for mut stmt in branch.stmts { t.stmt(mut stmt) @@ -171,6 +242,84 @@ pub fn (t Transformer) if_expr(mut original ast.IfExpr) ast.Expr { return *original } +pub fn (t Transformer) match_expr(mut original ast.MatchExpr) ast.Expr { + cond, mut terminate := t.expr(original.cond), false + original = ast.MatchExpr{ + ...(*original) + cond: cond + } + for mut branch in original.branches { + if branch.is_else { + continue + } + + for mut stmt in branch.stmts { + t.stmt(mut stmt) + } + + for mut expr in branch.exprs { + expr = t.expr(expr) + + match cond { + ast.BoolLiteral { + if expr is ast.BoolLiteral { + if cond.val == (expr as ast.BoolLiteral).val { + branch.exprs = [expr] + original = ast.MatchExpr{ + ...(*original) + branches: [branch] + } + terminate = true + } + } + } + ast.IntegerLiteral { + if expr is ast.IntegerLiteral { + if cond.val.int() == (expr as ast.IntegerLiteral).val.int() { + branch.exprs = [expr] + original = ast.MatchExpr{ + ...(*original) + branches: [branch] + } + terminate = true + } + } + } + ast.FloatLiteral { + if expr is ast.FloatLiteral { + if cond.val.f32() == (expr as ast.FloatLiteral).val.f32() { + branch.exprs = [expr] + original = ast.MatchExpr{ + ...(*original) + branches: [branch] + } + terminate = true + } + } + } + ast.StringLiteral { + if expr is ast.StringLiteral { + if cond.val == (expr as ast.StringLiteral).val { + branch.exprs = [expr] + original = ast.MatchExpr{ + ...(*original) + branches: [branch] + } + terminate = true + } + } + } + else {} + } + } + + if terminate { + break + } + } + return *original +} + pub fn (t Transformer) infix_expr(original ast.InfixExpr) ast.Expr { mut node := original node.left = t.expr(node.left) @@ -357,6 +506,76 @@ pub fn (t Transformer) infix_expr(original ast.InfixExpr) ast.Expr { } } } + ast.FloatLiteral { + match right_node { + ast.FloatLiteral { + left_val := left_node.val.f32() + right_val := right_node.val.f32() + match node.op { + .eq { + return ast.BoolLiteral{ + val: left_node.val == right_node.val + } + } + .ne { + return ast.BoolLiteral{ + val: left_node.val != right_node.val + } + } + .gt { + return ast.BoolLiteral{ + val: left_node.val > right_node.val + } + } + .ge { + return ast.BoolLiteral{ + val: left_node.val >= right_node.val + } + } + .lt { + return ast.BoolLiteral{ + val: left_node.val < right_node.val + } + } + .le { + return ast.BoolLiteral{ + val: left_node.val <= right_node.val + } + } + .plus { + return ast.FloatLiteral{ + val: (left_val + right_val).str() + pos: pos + } + } + .mul { + return ast.FloatLiteral{ + val: (left_val * right_val).str() + pos: pos + } + } + .minus { + return ast.FloatLiteral{ + val: (left_val - right_val).str() + pos: pos + } + } + .div { + return ast.FloatLiteral{ + val: (left_val / right_val).str() + pos: pos + } + } + else { + return node + } + } + } + else { + return node + } + } + } else { return node }