diff --git a/vlib/v/ast/table.v b/vlib/v/ast/table.v index 30715d6d8d..cc8c8df274 100644 --- a/vlib/v/ast/table.v +++ b/vlib/v/ast/table.v @@ -1139,9 +1139,20 @@ pub fn (t &Table) sumtype_has_variant(parent Type, variant Type) bool { parent_sym := t.get_type_symbol(parent) if parent_sym.kind == .sum_type { parent_info := parent_sym.info as SumType - for v in parent_info.variants { - if v.idx() == variant.idx() { - return true + var_sym := t.get_type_symbol(variant) + if var_sym.kind == .aggregate { + var_info := var_sym.info as Aggregate + for var_type in var_info.types { + if !t.sumtype_has_variant(parent, var_type) { + return false + } + } + return true + } else { + for v in parent_info.variants { + if v.idx() == variant.idx() { + return true + } } } } diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 94b8f439a0..cdd0aad7d7 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -1555,6 +1555,10 @@ pub fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type { } return ast.void_type } + if left_value_sym.kind == .sum_type + && c.table.sumtype_has_variant(left_value_type, right_type) { + return ast.void_type + } // []T << T or []T << []T unwrapped_right_type := c.unwrap_generic(right_type) if c.check_types(unwrapped_right_type, left_value_type) diff --git a/vlib/v/checker/tests/sumtype_mismatch_of_aggregate_err.out b/vlib/v/checker/tests/sumtype_mismatch_of_aggregate_err.out new file mode 100644 index 0000000000..14cfcb56c3 --- /dev/null +++ b/vlib/v/checker/tests/sumtype_mismatch_of_aggregate_err.out @@ -0,0 +1,7 @@ +vlib/v/checker/tests/sumtype_mismatch_of_aggregate_err.vv:7:11: error: cannot use `(i8 | i16 | int | i64)` as type `SimpleInt` in return argument + 5 | match s { + 6 | i8, i16, int, i64 { + 7 | return s + | ^ + 8 | } + 9 | } diff --git a/vlib/v/checker/tests/sumtype_mismatch_of_aggregate_err.vv b/vlib/v/checker/tests/sumtype_mismatch_of_aggregate_err.vv new file mode 100644 index 0000000000..0aea56cd10 --- /dev/null +++ b/vlib/v/checker/tests/sumtype_mismatch_of_aggregate_err.vv @@ -0,0 +1,13 @@ +type SimpleInt = i64 | int +type SuperInt = i16 | i64 | i8 | int + +fn ret_super(s SuperInt) SimpleInt { + match s { + i8, i16, int, i64 { + return s + } + } +} + +fn main() { +} diff --git a/vlib/v/tests/array_of_sumtype_append_aggregate_type_test.v b/vlib/v/tests/array_of_sumtype_append_aggregate_type_test.v new file mode 100644 index 0000000000..efae04056b --- /dev/null +++ b/vlib/v/tests/array_of_sumtype_append_aggregate_type_test.v @@ -0,0 +1,29 @@ +struct StructA { + value int +} + +struct StructB { + value int + offset int +} + +type AB = StructA | StructB + +fn test_sumtype_array_append_aggregate_type() { + mut arr := []AB{} + arr << StructA{0} + arr << StructB{0, 1} + + mut arr2 := []AB{} + + for a_or_b in arr { + match a_or_b { + StructA, StructB { + arr2 << a_or_b + } + } + } + + println(arr2) + assert arr2.len == 2 +} diff --git a/vlib/v/tests/match_sumtype_var_aggregate_test.v b/vlib/v/tests/match_sumtype_var_aggregate_test.v new file mode 100644 index 0000000000..9478b21dcf --- /dev/null +++ b/vlib/v/tests/match_sumtype_var_aggregate_test.v @@ -0,0 +1,21 @@ +struct Foo {} + +struct Bar {} + +struct Baz {} + +type Sum = Bar | Baz | Foo + +fn foo(s Sum) Sum { + match s { + Foo, Bar { return s } + Baz { return Baz{} } + } +} + +fn test_match_sumtype_var_aggregate() { + a := Foo{} + ret := foo(a) + println(ret) + assert '$ret' == 'Sum(Foo{})' +}