From 30029eaf5d1e4c329ecb376b7dd77e55724d7fa7 Mon Sep 17 00:00:00 2001 From: yuyi Date: Mon, 13 Sep 2021 14:49:28 +0800 Subject: [PATCH] checker, cgen: fix generic operator overload (fix #11472) (#11479) --- vlib/v/checker/checker.v | 18 ++++++---- vlib/v/gen/c/infix_expr.v | 36 ++++++++++++------- vlib/v/tests/generic_operator_overload_test.v | 35 ++++++++++++++++++ 3 files changed, 71 insertions(+), 18 deletions(-) create mode 100644 vlib/v/tests/generic_operator_overload_test.v diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 672776e9f4..a3c5606217 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -1435,13 +1435,19 @@ pub fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type { return_type = left_type } } else { - left_name := c.table.type_to_str(left_type) - right_name := c.table.type_to_str(right_type) - if left_name == right_name { - c.error('undefined operation `$left_name` $node.op.str() `$right_name`', - left_right_pos) + if left_sym.kind == .struct_ + && (left_sym.info as ast.Struct).generic_types.len > 0 { + return_type = left_type } else { - c.error('mismatched types `$left_name` and `$right_name`', left_right_pos) + left_name := c.table.type_to_str(left_type) + right_name := c.table.type_to_str(right_type) + if left_name == right_name { + c.error('undefined operation `$left_name` $node.op.str() `$right_name`', + left_right_pos) + } else { + c.error('mismatched types `$left_name` and `$right_name`', + left_right_pos) + } } } } else if right_sym.kind in [.array, .array_fixed, .map, .struct_] { diff --git a/vlib/v/gen/c/infix_expr.v b/vlib/v/gen/c/infix_expr.v index 2ed706b412..71c4d5a7ea 100644 --- a/vlib/v/gen/c/infix_expr.v +++ b/vlib/v/gen/c/infix_expr.v @@ -468,19 +468,31 @@ fn (mut g Gen) gen_interface_is_op(node ast.InfixExpr) { fn (mut g Gen) infix_expr_arithmetic_op(node ast.InfixExpr) { left := g.unwrap(node.left_type) right := g.unwrap(node.right_type) - method := g.table.type_find_method(left.sym, node.op.str()) or { - g.gen_plain_infix_expr(node) - return + if left.sym.kind == .struct_ && (left.sym.info as ast.Struct).generic_types.len > 0 { + concrete_types := (left.sym.info as ast.Struct).concrete_types + mut method_name := left.sym.cname + '_' + util.replace_op(node.op.str()) + method_name = g.generic_fn_name(concrete_types, method_name, true) + g.write(method_name) + g.write('(') + g.expr(node.left) + g.write(', ') + g.expr(node.right) + g.write(')') + } else { + method := g.table.type_find_method(left.sym, node.op.str()) or { + g.gen_plain_infix_expr(node) + return + } + left_styp := g.typ(left.typ.set_nr_muls(0)) + g.write(left_styp) + g.write('_') + g.write(util.replace_op(node.op.str())) + g.write('(') + g.op_arg(node.left, method.params[0].typ, left.typ) + g.write(', ') + g.op_arg(node.right, method.params[1].typ, right.typ) + g.write(')') } - left_styp := g.typ(left.typ.set_nr_muls(0)) - g.write(left_styp) - g.write('_') - g.write(util.replace_op(node.op.str())) - g.write('(') - g.op_arg(node.left, method.params[0].typ, left.typ) - g.write(', ') - g.op_arg(node.right, method.params[1].typ, right.typ) - g.write(')') } // infix_expr_left_shift_op generates code for the `<<` operator diff --git a/vlib/v/tests/generic_operator_overload_test.v b/vlib/v/tests/generic_operator_overload_test.v new file mode 100644 index 0000000000..cbc227321c --- /dev/null +++ b/vlib/v/tests/generic_operator_overload_test.v @@ -0,0 +1,35 @@ +struct Matrix { + row int + col int +mut: + data [][]T +} + +fn from_array(arr [][]T) Matrix { + return Matrix{ + row: arr.len + col: arr[0].len + data: arr.clone() + } +} + +fn (m1 Matrix) + (m2 Matrix) Matrix { + if m1.row != m2.row || m1.col != m2.col { + panic('Addition can only be performed on matrix with same size') + } + mut res := m1 + for i in 0 .. m2.row { + for j in 0 .. m2.col { + res.data[i][j] += m2.data[i][j] + } + } + return res +} + +fn test_generic_operator_overload() { + result := from_array([[1, 2, 3], [4, 5, 6]]) + from_array([[7, 8, 9], [10, 11, 12]]) + println(result) + assert result.row == 2 + assert result.col == 3 + assert result.data == [[8, 10, 12], [14, 16, 18]] +}