diff --git a/vlib/orm/orm_insert_test.v b/vlib/orm/orm_insert_test.v index 670b259430..8ee51a41c4 100644 --- a/vlib/orm/orm_insert_test.v +++ b/vlib/orm/orm_insert_test.v @@ -21,16 +21,40 @@ mut: text string } +pub fn insert_parent(db sqlite.DB, mut parent Parent) { + sql db { + insert parent into Parent + } +} + +fn test_orm_insert_mut_object() { + db := sqlite.connect(':memory:') or { panic(err) } + + sql db { + create table Parent + create table Child + create table Note + } + + mut parent := Parent{ + name: 'test' + } + + insert_parent(db, mut parent) + + parents := sql db { + select from Parent + } + + assert parents.len == 1 +} + fn test_orm_insert_with_multiple_child_elements() { mut db := sqlite.connect(':memory:') or { panic(err) } sql db { create table Parent - } - sql db { create table Child - } - sql db { create table Note } diff --git a/vlib/v/checker/orm.v b/vlib/v/checker/orm.v index f562f043a4..a321f0e41a 100644 --- a/vlib/v/checker/orm.v +++ b/vlib/v/checker/orm.v @@ -185,14 +185,19 @@ fn (mut c Checker) sql_stmt_line(mut node ast.SqlStmtLine) ast.Type { if node.kind == .insert && node.is_top_level { inserting_object_name := node.object_var_name - inserting_object_var := node.scope.find(inserting_object_name) or { + inserting_object := node.scope.find(inserting_object_name) or { c.error('undefined ident: `${inserting_object_name}`', node.pos) return ast.void_type } + mut inserting_object_type := inserting_object.typ - if inserting_object_var.typ != node.table_expr.typ { + if inserting_object_type.is_ptr() { + inserting_object_type = inserting_object.typ.deref() + } + + if inserting_object_type != node.table_expr.typ { table_name := table_sym.name - inserting_type_name := c.table.sym(inserting_object_var.typ).name + inserting_type_name := c.table.sym(inserting_object_type).name c.error('cannot use `${inserting_type_name}` as `${table_name}`', node.pos) return ast.void_type diff --git a/vlib/v/gen/c/sql.v b/vlib/v/gen/c/sql.v index 5e09d7d84b..93ca50b13a 100644 --- a/vlib/v/gen/c/sql.v +++ b/vlib/v/gen/c/sql.v @@ -174,6 +174,16 @@ fn (mut g Gen) sql_insert(node ast.SqlStmtLine, expr string, table_name string, } g.write('),') + mut member_access_type := '.' + + if node.scope != unsafe { nil } { + inserting_object := node.scope.find(node.object_var_name) or { verror(err.str()) } + + if inserting_object.typ.is_ptr() { + member_access_type = '->' + } + } + g.write('.data = new_array_from_c_array(${fields.len}, ${fields.len}, sizeof(orm__Primitive),') if fields.len > 0 { g.write(' _MOV((orm__Primitive[${fields.len}]){') @@ -193,7 +203,8 @@ fn (mut g Gen) sql_insert(node ast.SqlStmtLine, expr string, table_name string, if typ == 'time__Time' { typ = 'time' } - g.write('orm__${typ}_to_primitive(${node.object_var_name}.${f.name}),') + + g.write('orm__${typ}_to_primitive(${node.object_var_name}${member_access_type}${f.name}),') } g.write('})') } else { @@ -210,12 +221,12 @@ fn (mut g Gen) sql_insert(node ast.SqlStmtLine, expr string, table_name string, g.writeln('orm__Primitive ${id_name} = orm__int_to_primitive(orm__Connection_name_table[${expr}._typ]._method_last_id(${expr}._object));') for i, mut arr in arrs { idx := g.new_tmp_var() - g.writeln('for (int ${idx} = 0; ${idx} < ${arr.object_var_name}.${field_names[i]}.len; ${idx}++) {') + g.writeln('for (int ${idx} = 0; ${idx} < ${arr.object_var_name}${member_access_type}${field_names[i]}.len; ${idx}++) {') last_ids := g.new_tmp_var() res_ := g.new_tmp_var() tmp_var := g.new_tmp_var() ctyp := g.typ(arr.table_expr.typ) - g.writeln('${ctyp} ${tmp_var} = (*(${ctyp}*)array_get(${arr.object_var_name}.${field_names[i]}, ${idx}));') + g.writeln('${ctyp} ${tmp_var} = (*(${ctyp}*)array_get(${arr.object_var_name}${member_access_type}${field_names[i]}, ${idx}));') arr.object_var_name = tmp_var mut fff := []ast.StructField{} for f in arr.fields {