zed/crates/sqlez/src/savepoint.rs

149 lines
4.5 KiB
Rust
Raw Normal View History

2022-11-01 20:32:46 +00:00
use anyhow::Result;
use indoc::formatdoc;
2022-11-01 20:32:46 +00:00
use crate::connection::Connection;
impl Connection {
// Run a set of commands within the context of a `SAVEPOINT name`. If the callback
// returns Err(_), the savepoint will be rolled back. Otherwise, the save
// point is released.
2022-11-02 20:26:23 +00:00
pub fn with_savepoint<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<R>
where
F: FnOnce() -> Result<R>,
{
2022-11-07 01:00:34 +00:00
let name = name.as_ref();
self.exec(&format!("SAVEPOINT {name}"))?()?;
let result = f();
match result {
Ok(_) => {
2022-11-07 01:00:34 +00:00
self.exec(&format!("RELEASE {name}"))?()?;
}
Err(_) => {
2022-11-07 01:00:34 +00:00
self.exec(&formatdoc! {"
ROLLBACK TO {name};
RELEASE {name}"})?()?;
}
}
result
}
2022-11-01 20:32:46 +00:00
// Run a set of commands within the context of a `SAVEPOINT name`. If the callback
// returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
// point is released.
2022-11-02 20:26:23 +00:00
pub fn with_savepoint_rollback<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
2022-11-01 20:32:46 +00:00
where
F: FnOnce() -> Result<Option<R>>,
2022-11-01 20:32:46 +00:00
{
2022-11-07 01:00:34 +00:00
let name = name.as_ref();
self.exec(&format!("SAVEPOINT {name}"))?()?;
let result = f();
2022-11-01 20:32:46 +00:00
match result {
Ok(Some(_)) => {
2022-11-07 01:00:34 +00:00
self.exec(&format!("RELEASE {name}"))?()?;
2022-11-01 20:32:46 +00:00
}
Ok(None) | Err(_) => {
2022-11-07 01:00:34 +00:00
self.exec(&formatdoc! {"
ROLLBACK TO {name};
RELEASE {name}"})?()?;
2022-11-01 20:32:46 +00:00
}
}
result
}
}
#[cfg(test)]
mod tests {
    use crate::connection::Connection;
    use anyhow::Result;
    use indoc::indoc;

    /// Checks nested savepoint behavior inside an outer savepoint:
    /// - a failing inner `with_savepoint` rolls back its insert;
    /// - an inner `with_savepoint_rollback` returning `Ok(None)` rolls back its insert;
    /// - one returning `Ok(Some(_))` keeps it;
    /// - the outer savepoint's release keeps the surviving rows.
    #[test]
    fn test_nested_savepoints() -> Result<()> {
        let connection = Connection::open_memory(Some("nested_savepoints"));

        connection
            .exec(indoc! {"
                CREATE TABLE text (
                    text TEXT,
                    idx INTEGER
                );"})
            .unwrap()()
        .unwrap();

        let save1_text = "test save1";
        let save2_text = "test save2";

        // Every stored text, ordered by the index it was inserted with.
        let all_texts = || -> Result<Vec<String>> {
            Ok(connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?)
        };

        connection.with_savepoint("first", || {
            connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((save1_text, 1))?;

            // An inner savepoint whose callback fails must roll back its own insert only.
            let failed = connection.with_savepoint("second", || -> Result<()> {
                connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
                    save2_text, 2,
                ))?;
                assert_eq!(all_texts()?, vec![save1_text, save2_text]);
                anyhow::bail!("Failed second save point :(")
            });
            assert!(failed.is_err());
            assert_eq!(all_texts()?, vec![save1_text]);

            // `Ok(None)` also rolls the inner savepoint back.
            connection.with_savepoint_rollback::<(), _>("second", || {
                connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
                    save2_text, 2,
                ))?;
                assert_eq!(all_texts()?, vec![save1_text, save2_text]);
                Ok(None)
            })?;
            assert_eq!(all_texts()?, vec![save1_text]);

            // `Ok(Some(_))` keeps the inner savepoint's changes.
            connection.with_savepoint_rollback("second", || {
                connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
                    save2_text, 2,
                ))?;
                assert_eq!(all_texts()?, vec![save1_text, save2_text]);
                Ok(Some(()))
            })?;
            assert_eq!(all_texts()?, vec![save1_text, save2_text]);

            Ok(())
        })?;

        // Releasing the outer savepoint keeps everything that survived inside it.
        assert_eq!(all_texts()?, vec![save1_text, save2_text]);

        Ok(())
    }
}