diff --git a/vlib/v/checker/check_types.v b/vlib/v/checker/check_types.v index a2d54cbf1c..41556588fa 100644 --- a/vlib/v/checker/check_types.v +++ b/vlib/v/checker/check_types.v @@ -102,21 +102,8 @@ pub fn (c &Checker) check_basic(got, expected table.Type) bool { return true } // sum type - // TODO: there is a bug when casting sumtype the other way if its pointer - // so until fixed at least show v (not C) error `x(variant) = y(SumType*)` - // if got_type_sym.kind == .sum_type { - // sum_info := got_type_sym.info as table.SumType - // // TODO: handle `match SumType { &PtrVariant {} }` currently just checking base - // if expected.set_nr_muls(0) in sum_info.variants { - // return true - // } - // } - if exp_type_sym.kind == .sum_type { - sum_info := exp_type_sym.info as table.SumType - // TODO: handle `match SumType { &PtrVariant {} }` currently just checking base - if got.set_nr_muls(0) in sum_info.variants { - return true - } + if c.table.check_sumtype_compatibility(got, expected) { + return true } // fn type if got_type_sym.kind == .function && exp_type_sym.kind == .function { diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 00a2c0f1c5..ab797c704e 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -2462,10 +2462,7 @@ pub fn (mut c Checker) match_expr(mut node ast.MatchExpr) table.Type { typ_sym := c.table.get_type_symbol(typ) if node.is_sum_type || node.is_interface { ok := if cond_type_sym.kind == .sum_type { - // TODO verify sum type - // true // c.check_types(typ, cond_type) - info := cond_type_sym.info as table.SumType - typ in info.variants + c.table.check_sumtype_has_variant(cond_type, typ) } else { // interface match c.type_implements(typ, cond_type, node.pos) diff --git a/vlib/v/gen/cgen.v b/vlib/v/gen/cgen.v index c4deb0dc72..381dd5e8ae 100644 --- a/vlib/v/gen/cgen.v +++ b/vlib/v/gen/cgen.v @@ -949,19 +949,15 @@ fn (mut g Gen) for_in(it ast.ForInStmt) { fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type, expected_type table.Type) { // cast to sum type if expected_type != table.void_type { - exp_sym := g.table.get_type_symbol(expected_type) - if exp_sym.kind == .sum_type { - sum_info := exp_sym.info as table.SumType - if got_type in sum_info.variants { - got_sym := g.table.get_type_symbol(got_type) - got_styp := g.typ(got_type) - exp_styp := g.typ(expected_type) - got_idx := got_type.idx() - g.write('/* sum type cast */ ($exp_styp) {.obj = memdup(&($got_styp[]) {') - g.expr(expr) - g.write('}, sizeof($got_styp)), .typ = $got_idx /* $got_sym.name */}') - return - } + if g.table.check_sumtype_has_variant(expected_type, got_type) { + got_sym := g.table.get_type_symbol(got_type) + got_styp := g.typ(got_type) + exp_styp := g.typ(expected_type) + got_idx := got_type.idx() + g.write('/* sum type cast */ ($exp_styp) {.obj = memdup(&($got_styp[]) {') + g.expr(expr) + g.write('}, sizeof($got_styp)), .typ = $got_idx /* $got_sym.name */}') + return } } // Generic dereferencing logic diff --git a/vlib/v/table/table.v b/vlib/v/table/table.v index c23e5aa9b5..715981c08e 100644 --- a/vlib/v/table/table.v +++ b/vlib/v/table/table.v @@ -491,3 +491,32 @@ pub fn (table &Table) register_fn_gen_type(fn_name string, typ Type) { // println('registering fn ($fn_name) gen type $sym.name') table.fn_gen_types[fn_name] = a } + + +// TODO: there is a bug when casting sumtype the other way if its pointer +// so until fixed at least show v (not C) error `x(variant) = y(SumType*)` +pub fn (table &Table) check_sumtype_has_variant(parent Type, variant Type) bool { + parent_sym := table.get_type_symbol(parent) + if parent_sym.kind ==.sum_type { + parent_info := parent_sym.info as SumType + for v in parent_info.variants { + if v.idx() == variant.idx() { + return true + } + if table.check_sumtype_has_variant(v, variant) { + return true + } + } + } + return false +} + +pub fn (table &Table) check_sumtype_compatibility(a Type, b Type) bool { + if table.check_sumtype_has_variant(a, b) { + return true + } + if table.check_sumtype_has_variant(b, a) { + return true + } + return false +} diff --git a/vlib/v/tests/sum_type_test.v b/vlib/v/tests/sum_type_test.v index 425d95f2e3..f797e38da9 100644 --- a/vlib/v/tests/sum_type_test.v +++ b/vlib/v/tests/sum_type_test.v @@ -1,4 +1,15 @@ type Expr = IfExpr | IntegerLiteral +type Stmt = FnDecl | StructDecl +type Node = Expr | Stmt + +struct FnDecl { + pos int +} + +struct StructDecl { + pos int +} + struct IfExpr { pos int @@ -94,3 +105,26 @@ fn test_converting_down() { assert res[1].val == 3 assert res[1].name == 'three' } + +fn test_nested_sumtype() { + mut a := Node{} + mut b := Node{} + a = StructDecl{pos: 1} + b = IfExpr {pos: 1} + match a { + StructDecl { + assert true + } + else { + assert false + } + } + // TODO: not working + // assert b is IfExpr + if b is IfExpr { + assert true + } + else { + assert false + } +}