diff --git a/.gitignore b/.gitignore index da1950f2b3..e2d90adbb1 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,6 @@ /assets/themes/Internal/*.json /assets/themes/Experiments/*.json **/venv -<<<<<<< HEAD .build Packages *.xcodeproj @@ -19,6 +18,4 @@ DerivedData/ .swiftpm/config/registries.json .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata .netrc -======= crates/db/test-db.db ->>>>>>> 9d9ad38ce (Successfully detecting workplace IDs :D) diff --git a/crates/db/src/kvp.rs b/crates/db/src/kvp.rs index 93be5e10c0..6f1230f7b8 100644 --- a/crates/db/src/kvp.rs +++ b/crates/db/src/kvp.rs @@ -15,24 +15,19 @@ pub(crate) const KVP_MIGRATION: Migration = Migration::new( impl Db { pub fn read_kvp(&self, key: &str) -> Result> { - self.0 - .prepare("SELECT value FROM kv_store WHERE key = (?)")? - .with_bindings(key)? - .maybe_row() + self.select_row_bound("SELECT value FROM kv_store WHERE key = (?)")?(key) } pub fn write_kvp(&self, key: &str, value: &str) -> Result<()> { - self.0 - .prepare("INSERT OR REPLACE INTO kv_store(key, value) VALUES ((?), (?))")? - .with_bindings((key, value))? - .exec() + self.exec_bound("INSERT OR REPLACE INTO kv_store(key, value) VALUES ((?), (?))")?(( + key, value, + ))?; + + Ok(()) } pub fn delete_kvp(&self, key: &str) -> Result<()> { - self.0 - .prepare("DELETE FROM kv_store WHERE key = (?)")? - .with_bindings(key)? - .exec() + self.exec_bound("DELETE FROM kv_store WHERE key = (?)")?(key) } } diff --git a/crates/db/src/workspace.rs b/crates/db/src/workspace.rs index 9b2d9e4563..c4e4873dce 100644 --- a/crates/db/src/workspace.rs +++ b/crates/db/src/workspace.rs @@ -22,7 +22,7 @@ pub(crate) const WORKSPACES_MIGRATION: Migration = Migration::new( "}], ); -use self::model::{SerializedWorkspace, WorkspaceId, WorkspaceRow}; +use self::model::{SerializedWorkspace, WorkspaceId}; use super::Db; @@ -40,21 +40,19 @@ impl Db { // and we've grabbed the most recent workspace let (workspace_id, dock_anchor, dock_visible) = iife!({ if worktree_roots.len() == 0 { - self.prepare(indoc! {" + self.select_row(indoc! {" SELECT workspace_id, dock_anchor, dock_visible FROM workspaces - ORDER BY timestamp DESC LIMIT 1"})? - .maybe_row::() + ORDER BY timestamp DESC LIMIT 1"})?()? } else { - self.prepare(indoc! {" + self.select_row_bound(indoc! {" SELECT workspace_id, dock_anchor, dock_visible FROM workspaces - WHERE workspace_id = ?"})? - .with_bindings(&workspace_id)? - .maybe_row::() + WHERE workspace_id = ?"})?(&workspace_id)? } + .context("No workspaces found") }) - .log_err() + .warn_on_err() .flatten()?; Some(SerializedWorkspace { @@ -85,23 +83,17 @@ impl Db { if let Some(old_roots) = old_roots { let old_id: WorkspaceId = old_roots.into(); - self.prepare("DELETE FROM WORKSPACES WHERE workspace_id = ?")? - .with_bindings(&old_id)? - .exec()?; + self.exec_bound("DELETE FROM WORKSPACES WHERE workspace_id = ?")?(&old_id)?; } // Delete any previous workspaces with the same roots. This cascades to all // other tables that are based on the same roots set. // Insert new workspace into workspaces table if none were found - self.prepare("DELETE FROM workspaces WHERE workspace_id = ?;")? - .with_bindings(&workspace_id)? - .exec()?; + self.exec_bound("DELETE FROM workspaces WHERE workspace_id = ?;")?(&workspace_id)?; - self.prepare( + self.exec_bound( "INSERT INTO workspaces(workspace_id, dock_anchor, dock_visible) VALUES (?, ?, ?)", - )? - .with_bindings((&workspace_id, workspace.dock_anchor, workspace.dock_visible))? - .exec()?; + )?((&workspace_id, workspace.dock_anchor, workspace.dock_visible))?; // Save center pane group and dock pane self.save_pane_group(&workspace_id, &workspace.center_group, None)?; @@ -126,11 +118,9 @@ impl Db { iife!({ // TODO, upgrade anyhow: https://docs.rs/anyhow/1.0.66/anyhow/fn.Ok.html Ok::<_, anyhow::Error>( - self.prepare( + self.select_bound::( "SELECT workspace_id FROM workspaces ORDER BY timestamp DESC LIMIT ?", - )? - .with_bindings(limit)? - .rows::()? + )?(limit)? .into_iter() .map(|id| id.paths()) .collect::>>(), diff --git a/crates/db/src/workspace/items.rs b/crates/db/src/workspace/items.rs index 25873a7f9b..9e859ffdad 100644 --- a/crates/db/src/workspace/items.rs +++ b/crates/db/src/workspace/items.rs @@ -3,7 +3,7 @@ use indoc::indoc; use sqlez::migrations::Migration; use crate::{ - model::{ItemId, PaneId, SerializedItem, SerializedItemKind, WorkspaceId}, + model::{PaneId, SerializedItem, SerializedItemKind, WorkspaceId}, Db, }; @@ -29,19 +29,16 @@ pub(crate) const ITEM_MIGRATIONS: Migration = Migration::new( impl Db { pub(crate) fn get_items(&self, pane_id: PaneId) -> Result> { - Ok(self - .prepare(indoc! {" + Ok(self.select_bound(indoc! {" SELECT item_id, kind FROM items WHERE pane_id = ? - ORDER BY position"})? - .with_bindings(pane_id)? - .rows::<(ItemId, SerializedItemKind)>()? - .into_iter() - .map(|(item_id, kind)| match kind { - SerializedItemKind::Terminal => SerializedItem::Terminal { item_id }, - _ => unimplemented!(), - }) - .collect()) + ORDER BY position"})?(pane_id)? + .into_iter() + .map(|(item_id, kind)| match kind { + SerializedItemKind::Terminal => SerializedItem::Terminal { item_id }, + _ => unimplemented!(), + }) + .collect()) } pub(crate) fn save_items( @@ -51,19 +48,14 @@ impl Db { items: &[SerializedItem], ) -> Result<()> { let mut delete_old = self - .prepare("DELETE FROM items WHERE workspace_id = ? AND pane_id = ? AND item_id = ?") + .exec_bound("DELETE FROM items WHERE workspace_id = ? AND pane_id = ? AND item_id = ?") .context("Preparing deletion")?; - let mut insert_new = self.prepare( + let mut insert_new = self.exec_bound( "INSERT INTO items(item_id, workspace_id, pane_id, kind, position) VALUES (?, ?, ?, ?, ?)", ).context("Preparing insertion")?; for (position, item) in items.iter().enumerate() { - delete_old - .with_bindings((workspace_id, pane_id, item.item_id()))? - .exec()?; - - insert_new - .with_bindings((item.item_id(), workspace_id, pane_id, item.kind(), position))? - .exec()?; + delete_old((workspace_id, pane_id, item.item_id()))?; + insert_new((item.item_id(), workspace_id, pane_id, item.kind(), position))?; } Ok(()) diff --git a/crates/db/src/workspace/model.rs b/crates/db/src/workspace/model.rs index 1d9065f6d9..36099f66e6 100644 --- a/crates/db/src/workspace/model.rs +++ b/crates/db/src/workspace/model.rs @@ -80,8 +80,6 @@ impl Column for DockAnchor { } } -pub(crate) type WorkspaceRow = (WorkspaceId, DockAnchor, bool); - #[derive(Debug, PartialEq, Eq)] pub struct SerializedWorkspace { pub dock_anchor: DockAnchor, @@ -240,23 +238,20 @@ mod tests { workspace_id BLOB, dock_anchor TEXT );"}) - .unwrap(); + .unwrap()() + .unwrap(); let workspace_id: WorkspaceId = WorkspaceId::from(&["\test2", "\test1"]); - db.prepare("INSERT INTO workspace_id_test(workspace_id, dock_anchor) VALUES (?,?)") - .unwrap() - .with_bindings((&workspace_id, DockAnchor::Bottom)) - .unwrap() - .exec() - .unwrap(); + db.exec_bound("INSERT INTO workspace_id_test(workspace_id, dock_anchor) VALUES (?,?)") + .unwrap()((&workspace_id, DockAnchor::Bottom)) + .unwrap(); assert_eq!( - db.prepare("SELECT workspace_id, dock_anchor FROM workspace_id_test LIMIT 1") - .unwrap() - .row::<(WorkspaceId, DockAnchor)>() - .unwrap(), - (WorkspaceId::from(&["\test1", "\test2"]), DockAnchor::Bottom) + db.select_row("SELECT workspace_id, dock_anchor FROM workspace_id_test LIMIT 1") + .unwrap()() + .unwrap(), + Some((WorkspaceId::from(&["\test1", "\test2"]), DockAnchor::Bottom)) ); } } diff --git a/crates/db/src/workspace/pane.rs b/crates/db/src/workspace/pane.rs index 8528acb8af..24d6a3f938 100644 --- a/crates/db/src/workspace/pane.rs +++ b/crates/db/src/workspace/pane.rs @@ -1,6 +1,6 @@ use anyhow::{bail, Context, Result}; use indoc::indoc; -use sqlez::{migrations::Migration, statement::Statement}; +use sqlez::migrations::Migration; use util::unzip_option; use crate::model::{Axis, GroupId, PaneId, SerializedPane}; @@ -39,38 +39,29 @@ impl Db { &self, workspace_id: &WorkspaceId, ) -> Result { - let mut query = self.prepare(indoc! {" - SELECT group_id, axis, pane_id - FROM (SELECT group_id, axis, NULL as pane_id, position, parent_group_id, workspace_id - FROM pane_groups - UNION - SELECT NULL, NULL, pane_id, position, parent_group_id, workspace_id - FROM panes - -- Remove the dock panes from the union - WHERE parent_group_id IS NOT NULL and position IS NOT NULL) - WHERE parent_group_id IS ? AND workspace_id = ? - ORDER BY position - "})?; - - self.get_pane_group_children(workspace_id, None, &mut query)? + self.get_pane_group_children(workspace_id, None)? .into_iter() .next() .context("No center pane group") } - fn get_pane_group_children( + fn get_pane_group_children<'a>( &self, workspace_id: &WorkspaceId, group_id: Option, - query: &mut Statement, ) -> Result> { - let children = query.with_bindings((group_id, workspace_id))?.rows::<( - Option, - Option, - Option, - )>()?; - - children + self.select_bound::<(Option, &WorkspaceId), (Option, Option, Option)>(indoc! {" + SELECT group_id, axis, pane_id + FROM (SELECT group_id, axis, NULL as pane_id, position, parent_group_id, workspace_id + FROM pane_groups + UNION + SELECT NULL, NULL, pane_id, position, parent_group_id, workspace_id + FROM panes + -- Remove the dock panes from the union + WHERE parent_group_id IS NOT NULL and position IS NOT NULL) + WHERE parent_group_id IS ? AND workspace_id = ? + ORDER BY position + "})?((group_id, workspace_id))? .into_iter() .map(|(group_id, axis, pane_id)| { if let Some((group_id, axis)) = group_id.zip(axis) { @@ -79,7 +70,6 @@ impl Db { children: self.get_pane_group_children( workspace_id, Some(group_id), - query, )?, }) } else if let Some(pane_id) = pane_id { @@ -107,9 +97,8 @@ impl Db { match pane_group { SerializedPaneGroup::Group { axis, children } => { - let parent_id = self.prepare("INSERT INTO pane_groups(workspace_id, parent_group_id, position, axis) VALUES (?, ?, ?, ?)")? - .with_bindings((workspace_id, parent_id, position, *axis))? - .insert()? as GroupId; + let parent_id = self.insert_bound("INSERT INTO pane_groups(workspace_id, parent_group_id, position, axis) VALUES (?, ?, ?, ?)")? + ((workspace_id, parent_id, position, *axis))?; for (position, group) in children.iter().enumerate() { self.save_pane_group(workspace_id, group, Some((parent_id, position)))? @@ -121,12 +110,12 @@ impl Db { } pub(crate) fn get_dock_pane(&self, workspace_id: &WorkspaceId) -> Result { - let pane_id = self - .prepare(indoc! {" + let pane_id = self.select_row_bound(indoc! {" SELECT pane_id FROM panes - WHERE workspace_id = ? AND parent_group_id IS NULL AND position IS NULL"})? - .with_bindings(workspace_id)? - .row::()?; + WHERE workspace_id = ? AND parent_group_id IS NULL AND position IS NULL"})?( + workspace_id, + )? + .context("No dock pane for workspace")?; Ok(SerializedPane::new( self.get_items(pane_id).context("Reading items")?, @@ -141,10 +130,9 @@ impl Db { ) -> Result<()> { let (parent_id, order) = unzip_option(parent); - let pane_id = self - .prepare("INSERT INTO panes(workspace_id, parent_group_id, position) VALUES (?, ?, ?)")? - .with_bindings((workspace_id, parent_id, order))? - .insert()? as PaneId; + let pane_id = self.insert_bound( + "INSERT INTO panes(workspace_id, parent_group_id, position) VALUES (?, ?, ?)", + )?((workspace_id, parent_id, order))?; self.save_items(workspace_id, pane_id, &pane.children) .context("Saving items") diff --git a/crates/db/test.db b/crates/db/test.db index 09a0bc8f11..cedefe5f83 100644 Binary files a/crates/db/test.db and b/crates/db/test.db differ diff --git a/crates/sqlez/src/connection.rs b/crates/sqlez/src/connection.rs index 04a12cfc97..b673167c86 100644 --- a/crates/sqlez/src/connection.rs +++ b/crates/sqlez/src/connection.rs @@ -6,8 +6,6 @@ use std::{ use anyhow::{anyhow, Result}; use libsqlite3_sys::*; -use crate::statement::Statement; - pub struct Connection { pub(crate) sqlite3: *mut sqlite3, persistent: bool, @@ -60,30 +58,6 @@ impl Connection { unsafe { sqlite3_last_insert_rowid(self.sqlite3) } } - pub fn insert(&self, query: impl AsRef) -> Result { - self.exec(query)?; - Ok(self.last_insert_id()) - } - - pub fn exec(&self, query: impl AsRef) -> Result<()> { - unsafe { - sqlite3_exec( - self.sqlite3, - CString::new(query.as_ref())?.as_ptr(), - None, - 0 as *mut _, - 0 as *mut _, - ); - sqlite3_errcode(self.sqlite3); - self.last_error()?; - } - Ok(()) - } - - pub fn prepare>(&self, query: T) -> Result { - Statement::prepare(&self, query) - } - pub fn backup_main(&self, destination: &Connection) -> Result<()> { unsafe { let backup = sqlite3_backup_init( @@ -136,7 +110,7 @@ mod test { use anyhow::Result; use indoc::indoc; - use crate::{connection::Connection, migrations::Migration}; + use crate::connection::Connection; #[test] fn string_round_trips() -> Result<()> { @@ -146,25 +120,19 @@ mod test { CREATE TABLE text ( text TEXT );"}) - .unwrap(); + .unwrap()() + .unwrap(); let text = "Some test text"; connection - .prepare("INSERT INTO text (text) VALUES (?);") - .unwrap() - .with_bindings(text) - .unwrap() - .exec() - .unwrap(); + .insert_bound("INSERT INTO text (text) VALUES (?);") + .unwrap()(text) + .unwrap(); assert_eq!( - &connection - .prepare("SELECT text FROM text;") - .unwrap() - .row::() - .unwrap(), - text + connection.select_row("SELECT text FROM text;").unwrap()().unwrap(), + Some(text.to_string()) ); Ok(()) @@ -180,32 +148,26 @@ mod test { integer INTEGER, blob BLOB );"}) - .unwrap(); + .unwrap()() + .unwrap(); let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]); let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]); let mut insert = connection - .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)") + .insert_bound::<(String, usize, Vec)>( + "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)", + ) .unwrap(); - insert - .with_bindings(tuple1.clone()) - .unwrap() - .exec() - .unwrap(); - insert - .with_bindings(tuple2.clone()) - .unwrap() - .exec() - .unwrap(); + insert(tuple1.clone()).unwrap(); + insert(tuple2.clone()).unwrap(); assert_eq!( connection - .prepare("SELECT * FROM test") - .unwrap() - .rows::<(String, usize, Vec)>() - .unwrap(), + .select::<(String, usize, Vec)>("SELECT * FROM test") + .unwrap()() + .unwrap(), vec![tuple1, tuple2] ); } @@ -219,23 +181,20 @@ mod test { t INTEGER, f INTEGER );"}) - .unwrap(); + .unwrap()() + .unwrap(); connection - .prepare("INSERT INTO bools(t, f) VALUES (?, ?);") - .unwrap() - .with_bindings((true, false)) - .unwrap() - .exec() - .unwrap(); + .insert_bound("INSERT INTO bools(t, f) VALUES (?, ?);") + .unwrap()((true, false)) + .unwrap(); assert_eq!( - &connection - .prepare("SELECT * FROM bools;") - .unwrap() - .row::<(bool, bool)>() - .unwrap(), - &(true, false) + connection + .select_row::<(bool, bool)>("SELECT * FROM bools;") + .unwrap()() + .unwrap(), + Some((true, false)) ); } @@ -247,13 +206,13 @@ mod test { CREATE TABLE blobs ( data BLOB );"}) - .unwrap(); - let blob = &[0, 1, 2, 4, 8, 16, 32, 64]; - let mut write = connection1 - .prepare("INSERT INTO blobs (data) VALUES (?);") - .unwrap(); - write.bind_blob(1, blob).unwrap(); - write.exec().unwrap(); + .unwrap()() + .unwrap(); + let blob = vec![0, 1, 2, 4, 8, 16, 32, 64]; + connection1 + .insert_bound::>("INSERT INTO blobs (data) VALUES (?);") + .unwrap()(blob.clone()) + .unwrap(); // Backup connection1 to connection2 let connection2 = Connection::open_memory("backup_works_other"); @@ -261,40 +220,36 @@ mod test { // Delete the added blob and verify its deleted on the other side let read_blobs = connection1 - .prepare("SELECT * FROM blobs;") - .unwrap() - .rows::>() - .unwrap(); + .select::>("SELECT * FROM blobs;") + .unwrap()() + .unwrap(); assert_eq!(read_blobs, vec![blob]); } #[test] - fn test_kv_store() -> anyhow::Result<()> { - let connection = Connection::open_memory("kv_store"); + fn multi_step_statement_works() { + let connection = Connection::open_memory("multi_step_statement_works"); - Migration::new( - "kv", - &["CREATE TABLE kv_store( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - ) STRICT;"], - ) - .run(&connection) + connection + .exec(indoc! {" + CREATE TABLE test ( + col INTEGER + )"}) + .unwrap()() .unwrap(); - let mut stmt = connection.prepare("INSERT INTO kv_store(key, value) VALUES(?, ?)")?; - stmt.bind_text(1, "a").unwrap(); - stmt.bind_text(2, "b").unwrap(); - stmt.exec().unwrap(); - let id = connection.last_insert_id(); + connection + .exec(indoc! {" + INSERT INTO test(col) VALUES (2)"}) + .unwrap()() + .unwrap(); - let res = connection - .prepare("SELECT key, value FROM kv_store WHERE rowid = ?")? - .with_bindings(id)? - .row::<(String, String)>()?; - - assert_eq!(res, ("a".to_string(), "b".to_string())); - - Ok(()) + assert_eq!( + connection + .select_row::("SELECt * FROM test") + .unwrap()() + .unwrap(), + Some(2) + ); } } diff --git a/crates/sqlez/src/lib.rs b/crates/sqlez/src/lib.rs index 3bed7a06cb..155fb28901 100644 --- a/crates/sqlez/src/lib.rs +++ b/crates/sqlez/src/lib.rs @@ -4,3 +4,4 @@ pub mod migrations; pub mod savepoint; pub mod statement; pub mod thread_safe_connection; +pub mod typed_statements; diff --git a/crates/sqlez/src/migrations.rs b/crates/sqlez/src/migrations.rs index 9f3bd333ca..89eaebb494 100644 --- a/crates/sqlez/src/migrations.rs +++ b/crates/sqlez/src/migrations.rs @@ -18,7 +18,7 @@ const MIGRATIONS_MIGRATION: Migration = Migration::new( domain TEXT, step INTEGER, migration TEXT - ); + ) "}], ); @@ -34,24 +34,26 @@ impl Migration { } fn run_unchecked(&self, connection: &Connection) -> Result<()> { - connection.exec(self.migrations.join(";\n")) + for migration in self.migrations { + connection.exec(migration)?()?; + } + + Ok(()) } pub fn run(&self, connection: &Connection) -> Result<()> { // Setup the migrations table unconditionally MIGRATIONS_MIGRATION.run_unchecked(connection)?; - let completed_migrations = connection - .prepare(indoc! {" - SELECT domain, step, migration FROM migrations - WHERE domain = ? - ORDER BY step - "})? - .with_bindings(self.domain)? - .rows::<(String, usize, String)>()?; + let completed_migrations = + connection.select_bound::<&str, (String, usize, String)>(indoc! {" + SELECT domain, step, migration FROM migrations + WHERE domain = ? + ORDER BY step + "})?(self.domain)?; let mut store_completed_migration = connection - .prepare("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?; + .insert_bound("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?; for (index, migration) in self.migrations.iter().enumerate() { if let Some((_, _, completed_migration)) = completed_migrations.get(index) { @@ -70,10 +72,8 @@ impl Migration { } } - connection.exec(migration)?; - store_completed_migration - .with_bindings((self.domain, index, *migration))? - .exec()?; + connection.exec(migration)?()?; + store_completed_migration((self.domain, index, *migration))?; } Ok(()) @@ -97,17 +97,16 @@ mod test { CREATE TABLE test1 ( a TEXT, b TEXT - );"}], + )"}], ); migration.run(&connection).unwrap(); // Verify it got added to the migrations table assert_eq!( &connection - .prepare("SELECT (migration) FROM migrations") - .unwrap() - .rows::() - .unwrap()[..], + .select::("SELECT (migration) FROM migrations") + .unwrap()() + .unwrap()[..], migration.migrations ); @@ -117,22 +116,21 @@ mod test { CREATE TABLE test1 ( a TEXT, b TEXT - );"}, + )"}, indoc! {" CREATE TABLE test2 ( c TEXT, d TEXT - );"}, + )"}, ]; migration.run(&connection).unwrap(); // Verify it is also added to the migrations table assert_eq!( &connection - .prepare("SELECT (migration) FROM migrations") - .unwrap() - .rows::() - .unwrap()[..], + .select::("SELECT (migration) FROM migrations") + .unwrap()() + .unwrap()[..], migration.migrations ); } @@ -142,15 +140,17 @@ mod test { let connection = Connection::open_memory("migration_setup_works"); connection - .exec(indoc! {"CREATE TABLE IF NOT EXISTS migrations ( + .exec(indoc! {" + CREATE TABLE IF NOT EXISTS migrations ( domain TEXT, step INTEGER, migration TEXT );"}) - .unwrap(); + .unwrap()() + .unwrap(); let mut store_completed_migration = connection - .prepare(indoc! {" + .insert_bound::<(&str, usize, String)>(indoc! {" INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)"}) .unwrap(); @@ -159,14 +159,11 @@ mod test { for i in 0..5 { // Create a table forcing a schema change connection - .exec(format!("CREATE TABLE table{} ( test TEXT );", i)) - .unwrap(); + .exec(&format!("CREATE TABLE table{} ( test TEXT );", i)) + .unwrap()() + .unwrap(); - store_completed_migration - .with_bindings((domain, i, i.to_string())) - .unwrap() - .exec() - .unwrap(); + store_completed_migration((domain, i, i.to_string())).unwrap(); } } @@ -180,46 +177,49 @@ mod test { // Manually create the table for that migration with a row connection .exec(indoc! {" - CREATE TABLE test_table ( - test_column INTEGER - ); - INSERT INTO test_table (test_column) VALUES (1)"}) - .unwrap(); + CREATE TABLE test_table ( + test_column INTEGER + );"}) + .unwrap()() + .unwrap(); + connection + .exec(indoc! {" + INSERT INTO test_table (test_column) VALUES (1);"}) + .unwrap()() + .unwrap(); assert_eq!( connection - .prepare("SELECT * FROM test_table") - .unwrap() - .row::() - .unwrap(), - 1 + .select_row::("SELECT * FROM test_table") + .unwrap()() + .unwrap(), + Some(1) ); // Run the migration verifying that the row got dropped migration.run(&connection).unwrap(); assert_eq!( connection - .prepare("SELECT * FROM test_table") - .unwrap() - .rows::() - .unwrap(), - Vec::new() + .select_row::("SELECT * FROM test_table") + .unwrap()() + .unwrap(), + None ); // Recreate the dropped row connection .exec("INSERT INTO test_table (test_column) VALUES (2)") - .unwrap(); + .unwrap()() + .unwrap(); // Run the same migration again and verify that the table was left unchanged migration.run(&connection).unwrap(); assert_eq!( connection - .prepare("SELECT * FROM test_table") - .unwrap() - .row::() - .unwrap(), - 2 + .select_row::("SELECT * FROM test_table") + .unwrap()() + .unwrap(), + Some(2) ); } diff --git a/crates/sqlez/src/savepoint.rs b/crates/sqlez/src/savepoint.rs index ba4b1e774b..b78358deb9 100644 --- a/crates/sqlez/src/savepoint.rs +++ b/crates/sqlez/src/savepoint.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use indoc::{formatdoc, indoc}; use crate::connection::Connection; @@ -10,16 +11,17 @@ impl Connection { where F: FnOnce() -> Result, { - let name = name.as_ref().to_owned(); - self.exec(format!("SAVEPOINT {}", &name))?; + let name = name.as_ref(); + self.exec(&format!("SAVEPOINT {name}"))?()?; let result = f(); match result { Ok(_) => { - self.exec(format!("RELEASE {}", name))?; + self.exec(&format!("RELEASE {name}"))?()?; } Err(_) => { - self.exec(format!("ROLLBACK TO {}", name))?; - self.exec(format!("RELEASE {}", name))?; + self.exec(&formatdoc! {" + ROLLBACK TO {name}; + RELEASE {name}"})?()?; } } result @@ -32,16 +34,17 @@ impl Connection { where F: FnOnce() -> Result>, { - let name = name.as_ref().to_owned(); - self.exec(format!("SAVEPOINT {}", &name))?; + let name = name.as_ref(); + self.exec(&format!("SAVEPOINT {name}"))?()?; let result = f(); match result { Ok(Some(_)) => { - self.exec(format!("RELEASE {}", name))?; + self.exec(&format!("RELEASE {name}"))?()?; } Ok(None) | Err(_) => { - self.exec(format!("ROLLBACK TO {}", name))?; - self.exec(format!("RELEASE {}", name))?; + self.exec(&formatdoc! {" + ROLLBACK TO {name}; + RELEASE {name}"})?()?; } } result @@ -64,28 +67,25 @@ mod tests { text TEXT, idx INTEGER );"}) - .unwrap(); + .unwrap()() + .unwrap(); let save1_text = "test save1"; let save2_text = "test save2"; connection.with_savepoint("first", || { - connection - .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")? - .with_bindings((save1_text, 1))? - .exec()?; + connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((save1_text, 1))?; assert!(connection .with_savepoint("second", || -> Result, anyhow::Error> { - connection - .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")? - .with_bindings((save2_text, 2))? - .exec()?; + connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?(( + save2_text, 2, + ))?; assert_eq!( connection - .prepare("SELECT text FROM text ORDER BY text.idx ASC")? - .rows::()?, + .select::("SELECT text FROM text ORDER BY text.idx ASC")?( + )?, vec![save1_text, save2_text], ); @@ -95,22 +95,17 @@ mod tests { .is_some()); assert_eq!( - connection - .prepare("SELECT text FROM text ORDER BY text.idx ASC")? - .rows::()?, + connection.select::("SELECT text FROM text ORDER BY text.idx ASC")?()?, vec![save1_text], ); connection.with_savepoint_rollback::<(), _>("second", || { - connection - .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")? - .with_bindings((save2_text, 2))? - .exec()?; + connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?(( + save2_text, 2, + ))?; assert_eq!( - connection - .prepare("SELECT text FROM text ORDER BY text.idx ASC")? - .rows::()?, + connection.select::("SELECT text FROM text ORDER BY text.idx ASC")?()?, vec![save1_text, save2_text], ); @@ -118,22 +113,17 @@ mod tests { })?; assert_eq!( - connection - .prepare("SELECT text FROM text ORDER BY text.idx ASC")? - .rows::()?, + connection.select::("SELECT text FROM text ORDER BY text.idx ASC")?()?, vec![save1_text], ); connection.with_savepoint_rollback("second", || { - connection - .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")? - .with_bindings((save2_text, 2))? - .exec()?; + connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?(( + save2_text, 2, + ))?; assert_eq!( - connection - .prepare("SELECT text FROM text ORDER BY text.idx ASC")? - .rows::()?, + connection.select::("SELECT text FROM text ORDER BY text.idx ASC")?()?, vec![save1_text, save2_text], ); @@ -141,9 +131,7 @@ mod tests { })?; assert_eq!( - connection - .prepare("SELECT text FROM text ORDER BY text.idx ASC")? - .rows::()?, + connection.select::("SELECT text FROM text ORDER BY text.idx ASC")?()?, vec![save1_text, save2_text], ); @@ -151,9 +139,7 @@ mod tests { })?; assert_eq!( - connection - .prepare("SELECT text FROM text ORDER BY text.idx ASC")? - .rows::()?, + connection.select::("SELECT text FROM text ORDER BY text.idx ASC")?()?, vec![save1_text, save2_text], ); diff --git a/crates/sqlez/src/statement.rs b/crates/sqlez/src/statement.rs index f0de8703ab..e0b284e628 100644 --- a/crates/sqlez/src/statement.rs +++ b/crates/sqlez/src/statement.rs @@ -1,6 +1,6 @@ -use std::ffi::{c_int, CString}; +use std::ffi::{c_int, CStr, CString}; use std::marker::PhantomData; -use std::{slice, str}; +use std::{ptr, slice, str}; use anyhow::{anyhow, Context, Result}; use libsqlite3_sys::*; @@ -9,7 +9,8 @@ use crate::bindable::{Bind, Column}; use crate::connection::Connection; pub struct Statement<'a> { - raw_statement: *mut sqlite3_stmt, + raw_statements: Vec<*mut sqlite3_stmt>, + current_statement: usize, connection: &'a Connection, phantom: PhantomData, } @@ -34,19 +35,31 @@ pub enum SqlType { impl<'a> Statement<'a> { pub fn prepare>(connection: &'a Connection, query: T) -> Result { let mut statement = Self { - raw_statement: 0 as *mut _, + raw_statements: Default::default(), + current_statement: 0, connection, phantom: PhantomData, }; unsafe { - sqlite3_prepare_v2( - connection.sqlite3, - CString::new(query.as_ref())?.as_ptr(), - -1, - &mut statement.raw_statement, - 0 as *mut _, - ); + let sql = CString::new(query.as_ref())?; + let mut remaining_sql = sql.as_c_str(); + while { + let remaining_sql_str = remaining_sql.to_str()?; + remaining_sql_str.trim() != ";" && !remaining_sql_str.is_empty() + } { + let mut raw_statement = 0 as *mut sqlite3_stmt; + let mut remaining_sql_ptr = ptr::null(); + sqlite3_prepare_v2( + connection.sqlite3, + remaining_sql.as_ptr(), + -1, + &mut raw_statement, + &mut remaining_sql_ptr, + ); + remaining_sql = CStr::from_ptr(remaining_sql_ptr); + statement.raw_statements.push(raw_statement); + } connection .last_error() @@ -56,131 +69,138 @@ impl<'a> Statement<'a> { Ok(statement) } + fn current_statement(&self) -> *mut sqlite3_stmt { + *self.raw_statements.get(self.current_statement).unwrap() + } + pub fn reset(&mut self) { unsafe { - sqlite3_reset(self.raw_statement); + for raw_statement in self.raw_statements.iter() { + sqlite3_reset(*raw_statement); + } } + self.current_statement = 0; } pub fn parameter_count(&self) -> i32 { - unsafe { sqlite3_bind_parameter_count(self.raw_statement) } + unsafe { + self.raw_statements + .iter() + .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement)) + .max() + .unwrap_or(0) + } } pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> { - // dbg!("bind blob", index); let index = index as c_int; let blob_pointer = blob.as_ptr() as *const _; let len = blob.len() as c_int; unsafe { - sqlite3_bind_blob( - self.raw_statement, - index, - blob_pointer, - len, - SQLITE_TRANSIENT(), - ); + for raw_statement in self.raw_statements.iter() { + sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT()); + } } self.connection.last_error() } pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> { let index = index as c_int; - let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) }; + let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) }; self.connection.last_error()?; if pointer.is_null() { return Ok(&[]); } - let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize }; + let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize }; self.connection.last_error()?; unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) } } pub fn bind_double(&self, index: i32, double: f64) -> Result<()> { - // dbg!("bind double", index); let index = index as c_int; unsafe { - sqlite3_bind_double(self.raw_statement, index, double); + for raw_statement in self.raw_statements.iter() { + sqlite3_bind_double(*raw_statement, index, double); + } } self.connection.last_error() } pub fn column_double(&self, index: i32) -> Result { let index = index as c_int; - let result = unsafe { sqlite3_column_double(self.raw_statement, index) }; + let result = unsafe { sqlite3_column_double(self.current_statement(), index) }; self.connection.last_error()?; Ok(result) } pub fn bind_int(&self, index: i32, int: i32) -> Result<()> { - // dbg!("bind int", index); let index = index as c_int; unsafe { - sqlite3_bind_int(self.raw_statement, index, int); + for raw_statement in self.raw_statements.iter() { + sqlite3_bind_int(*raw_statement, index, int); + } }; self.connection.last_error() } pub fn column_int(&self, index: i32) -> Result { let index = index as c_int; - let result = unsafe { sqlite3_column_int(self.raw_statement, index) }; + let result = unsafe { sqlite3_column_int(self.current_statement(), index) }; self.connection.last_error()?; Ok(result) } pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> { - // dbg!("bind int64", index); let index = index as c_int; unsafe { - sqlite3_bind_int64(self.raw_statement, index, int); + for raw_statement in self.raw_statements.iter() { + sqlite3_bind_int64(*raw_statement, index, int); + } } self.connection.last_error() } pub fn column_int64(&self, index: i32) -> Result { let index = index as c_int; - let result = unsafe { sqlite3_column_int64(self.raw_statement, index) }; + let result = unsafe { sqlite3_column_int64(self.current_statement(), index) }; self.connection.last_error()?; Ok(result) } pub fn bind_null(&self, index: i32) -> Result<()> { - // dbg!("bind null", index); let index = index as c_int; unsafe { - sqlite3_bind_null(self.raw_statement, index); + for raw_statement in self.raw_statements.iter() { + sqlite3_bind_null(*raw_statement, index); + } } self.connection.last_error() } pub fn bind_text(&self, index: i32, text: &str) -> Result<()> { - // dbg!("bind text", index, text); let index = index as c_int; let text_pointer = text.as_ptr() as *const _; let len = text.len() as c_int; unsafe { - sqlite3_bind_text( - self.raw_statement, - index, - text_pointer, - len, - SQLITE_TRANSIENT(), - ); + for raw_statement in self.raw_statements.iter() { + sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT()); + } } self.connection.last_error() } pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> { let index = index as c_int; - let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) }; + let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) }; self.connection.last_error()?; if pointer.is_null() { return Ok(""); } - let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize }; + let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize }; self.connection.last_error()?; let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) }; @@ -198,7 +218,7 @@ impl<'a> Statement<'a> { } pub fn column_type(&mut self, index: i32) -> Result { - let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT FROM TABLE + let result = unsafe { sqlite3_column_type(self.current_statement(), index) }; self.connection.last_error()?; match result { SQLITE_INTEGER => Ok(SqlType::Integer), @@ -217,9 +237,16 @@ impl<'a> Statement<'a> { fn step(&mut self) -> Result { unsafe { - match sqlite3_step(self.raw_statement) { + match sqlite3_step(self.current_statement()) { SQLITE_ROW => Ok(StepResult::Row), - SQLITE_DONE => Ok(StepResult::Done), + SQLITE_DONE => { + if self.current_statement >= self.raw_statements.len() - 1 { + Ok(StepResult::Done) + } else { + self.current_statement += 1; + self.step() + } + } SQLITE_MISUSE => Ok(StepResult::Misuse), other => self .connection @@ -311,7 +338,11 @@ impl<'a> Statement<'a> { impl<'a> Drop for Statement<'a> { fn drop(&mut self) { - unsafe { sqlite3_finalize(self.raw_statement) }; + unsafe { + for raw_statement in self.raw_statements.iter() { + sqlite3_finalize(*raw_statement); + } + } } } @@ -319,7 +350,10 @@ impl<'a> Drop for Statement<'a> { mod test { use indoc::indoc; - use crate::{connection::Connection, statement::StepResult}; + use crate::{ + connection::Connection, + statement::{Statement, StepResult}, + }; #[test] fn blob_round_trips() { @@ -327,28 +361,28 @@ mod test { connection1 .exec(indoc! {" CREATE TABLE blobs ( - data BLOB - );"}) - .unwrap(); + data BLOB + )"}) + .unwrap()() + .unwrap(); let blob = &[0, 1, 2, 4, 8, 16, 32, 64]; - let mut write = connection1 - .prepare("INSERT INTO blobs (data) VALUES (?);") - .unwrap(); + let mut write = + Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap(); write.bind_blob(1, blob).unwrap(); assert_eq!(write.step().unwrap(), StepResult::Done); // Read the blob from the let connection2 = Connection::open_memory("blob_round_trips"); - let mut read = connection2.prepare("SELECT * FROM blobs;").unwrap(); + let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap(); assert_eq!(read.step().unwrap(), StepResult::Row); assert_eq!(read.column_blob(0).unwrap(), blob); assert_eq!(read.step().unwrap(), StepResult::Done); // Delete the added blob and verify its deleted on the other side - connection2.exec("DELETE FROM blobs;").unwrap(); - let mut read = connection1.prepare("SELECT * FROM blobs;").unwrap(); + connection2.exec("DELETE FROM blobs").unwrap()().unwrap(); + let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap(); assert_eq!(read.step().unwrap(), StepResult::Done); } @@ -359,32 +393,25 @@ mod test { .exec(indoc! {" CREATE TABLE texts ( text TEXT - );"}) - .unwrap(); + )"}) + .unwrap()() + .unwrap(); assert!(connection - .prepare("SELECT text FROM texts") - .unwrap() - .maybe_row::() - .unwrap() - .is_none()); + .select_row::("SELECT text FROM texts") + .unwrap()() + .unwrap() + .is_none()); let text_to_insert = "This is a test"; connection - .prepare("INSERT INTO texts VALUES (?)") - .unwrap() - .with_bindings(text_to_insert) - .unwrap() - .exec() - .unwrap(); + .exec_bound("INSERT INTO texts VALUES (?)") + .unwrap()(text_to_insert) + .unwrap(); assert_eq!( - connection - .prepare("SELECT text FROM texts") - .unwrap() - .maybe_row::() - .unwrap(), + connection.select_row("SELECT text FROM texts").unwrap()().unwrap(), Some(text_to_insert.to_string()) ); } diff --git a/crates/sqlez/src/thread_safe_connection.rs b/crates/sqlez/src/thread_safe_connection.rs index f4f759cd6c..45e22e4b3f 100644 --- a/crates/sqlez/src/thread_safe_connection.rs +++ b/crates/sqlez/src/thread_safe_connection.rs @@ -79,7 +79,8 @@ impl Deref for ThreadSafeConnection { connection.exec(initialize_query).expect(&format!( "Initialize query failed to execute: {}", initialize_query - )); + ))() + .unwrap(); } if let Some(migrations) = self.migrations { diff --git a/crates/sqlez/src/typed_statements.rs b/crates/sqlez/src/typed_statements.rs new file mode 100644 index 0000000000..f2d66a781f --- /dev/null +++ b/crates/sqlez/src/typed_statements.rs @@ -0,0 +1,67 @@ +use anyhow::Result; + +use crate::{ + bindable::{Bind, Column}, + connection::Connection, + statement::Statement, +}; + +impl Connection { + pub fn exec<'a>(&'a self, query: &str) -> Result Result<()>> { + let mut statement = Statement::prepare(&self, query)?; + Ok(move || statement.exec()) + } + + pub fn exec_bound<'a, B: Bind>( + &'a self, + query: &str, + ) -> Result Result<()>> { + let mut statement = Statement::prepare(&self, query)?; + Ok(move |bindings| statement.with_bindings(bindings)?.exec()) + } + + pub fn insert<'a>(&'a self, query: &str) -> Result Result> { + let mut statement = Statement::prepare(&self, query)?; + Ok(move || statement.insert()) + } + + pub fn insert_bound<'a, B: Bind>( + &'a self, + query: &str, + ) -> Result Result> { + let mut statement = Statement::prepare(&self, query)?; + Ok(move |bindings| statement.with_bindings(bindings)?.insert()) + } + + pub fn select<'a, C: Column>( + &'a self, + query: &str, + ) -> Result Result>> { + let mut statement = Statement::prepare(&self, query)?; + Ok(move || statement.rows::()) + } + + pub fn select_bound<'a, B: Bind, C: Column>( + &'a self, + query: &str, + ) -> Result Result>> { + let mut statement = Statement::prepare(&self, query)?; + Ok(move |bindings| statement.with_bindings(bindings)?.rows::()) + } + + pub fn select_row<'a, C: Column>( + &'a self, + query: &str, + ) -> Result Result>> { + let mut statement = Statement::prepare(&self, query)?; + Ok(move || statement.maybe_row::()) + } + + pub fn select_row_bound<'a, B: Bind, C: Column>( + &'a self, + query: &str, + ) -> Result Result>> { + let mut statement = Statement::prepare(&self, query)?; + Ok(move |bindings| statement.with_bindings(bindings)?.maybe_row::()) + } +}