diff --git a/cmd/tools/vtest-self.v b/cmd/tools/vtest-self.v index 780977a3df..174adc45dd 100644 --- a/cmd/tools/vtest-self.v +++ b/cmd/tools/vtest-self.v @@ -128,6 +128,7 @@ const ( 'vlib/orm/orm_last_id_test.v', 'vlib/orm/orm_string_interpolation_in_where_test.v', 'vlib/orm/orm_interface_test.v', + 'vlib/orm/orm_mut_db_test.v', 'vlib/db/sqlite/sqlite_test.v', 'vlib/db/sqlite/sqlite_orm_test.v', 'vlib/db/sqlite/sqlite_vfs_lowlevel_test.v', @@ -199,6 +200,7 @@ const ( 'vlib/orm/orm_last_id_test.v', 'vlib/orm/orm_string_interpolation_in_where_test.v', 'vlib/orm/orm_interface_test.v', + 'vlib/orm/orm_mut_db_test.v', 'vlib/v/tests/orm_sub_struct_test.v', 'vlib/v/tests/orm_sub_array_struct_test.v', 'vlib/v/tests/orm_joined_tables_select_test.v', diff --git a/vlib/orm/orm_mut_db_test.v b/vlib/orm/orm_mut_db_test.v new file mode 100644 index 0000000000..1c05d5eb93 --- /dev/null +++ b/vlib/orm/orm_mut_db_test.v @@ -0,0 +1,36 @@ +import db.sqlite + +struct User { + id int [primary; sql: serial] + name string +} + +fn get_users(mut db sqlite.DB) []User { + return sql db { + select from User + } +} + +fn test_orm_mut_db() { + mut db := sqlite.connect(':memory:') or { panic(err) } + + sql db { + create table User + } + + first_user := User{ + name: 'first' + } + second_user := User{ + name: 'second' + } + + sql db { + insert first_user into User + insert second_user into User + } + + users := get_users(mut db) + + assert users.len == 2 +} diff --git a/vlib/v/gen/c/sql.v b/vlib/v/gen/c/sql.v index 47aa6e22f5..5c0db9a0ac 100644 --- a/vlib/v/gen/c/sql.v +++ b/vlib/v/gen/c/sql.v @@ -12,24 +12,12 @@ enum SqlExprSide { } fn (mut g Gen) sql_stmt(node ast.SqlStmt) { - conn := g.new_tmp_var() - g.writeln('') - g.writeln('// orm') - g.write('orm__Connection ${conn} = ') + connection_var_name := g.new_tmp_var() - db_expr_ctype_name := g.typ(node.db_expr_type) - - if db_expr_ctype_name == 'orm__Connection' { - g.expr(node.db_expr) - g.writeln(';') - } else { - g.write('(orm__Connection){._${db_expr_ctype_name} = &') - g.expr(node.db_expr) - g.writeln(', ._typ = _orm__Connection_${db_expr_ctype_name}_index};') - } + g.write_orm_connection_init(connection_var_name, &node.db_expr) for line in node.lines { - g.sql_stmt_line(line, conn, node.or_expr) + g.sql_stmt_line(line, connection_var_name, node.or_expr) } } @@ -539,25 +527,12 @@ fn (mut g Gen) sql_gen_where_data(where_expr ast.Expr) { fn (mut g Gen) sql_select_expr(node ast.SqlExpr) { left := g.go_before_stmt(0) - conn := g.new_tmp_var() + connection_var_name := g.new_tmp_var() g.writeln('') - g.writeln('// orm') - g.write('orm__Connection ${conn} = ') - db_expr_type := g.get_db_type(node.db_expr) or { - verror('sql orm error - unknown db type for ${node.db_expr}') - } - db_expr_ctype_name := g.typ(db_expr_type) - if db_expr_ctype_name == 'orm__Connection' { - g.expr(node.db_expr) - g.writeln(';') - } else { - g.write('(orm__Connection){._${db_expr_ctype_name} = &') - g.expr(node.db_expr) - g.writeln(', ._typ = _orm__Connection_${db_expr_ctype_name}_index};') - } + g.write_orm_connection_init(connection_var_name, &node.db_expr) - g.sql_select(node, conn, left, node.or_expr) + g.sql_select(node, connection_var_name, left, node.or_expr) } fn (mut g Gen) sql_select(node ast.SqlExpr, expr string, left string, or_expr ast.OrExpr) { @@ -899,3 +874,24 @@ fn (mut g Gen) write_error_handling_for_orm_result(expr_pos &token.Pos, result_v g.writeln('}') } + +fn (mut g Gen) write_orm_connection_init(connection_var_name string, db_expr &ast.Expr) { + db_expr_type := g.get_db_type(db_expr) or { verror('V ORM: unknown db type for ${db_expr}') } + + mut db_ctype_name := g.typ(db_expr_type) + is_pointer := db_ctype_name.ends_with('*') + reference_sign := if is_pointer { '' } else { '&' } + db_ctype_name = db_ctype_name.trim_right('*') + + g.writeln('// orm') + g.write('orm__Connection ${connection_var_name} = ') + + if db_ctype_name == 'orm__Connection' { + g.expr(db_expr) + g.writeln(';') + } else { + g.write('(orm__Connection){._${db_ctype_name} = ${reference_sign}') + g.expr(db_expr) + g.writeln(', ._typ = _orm__Connection_${db_ctype_name}_index};') + } +}