diff --git a/crates/db/src/db.rs b/crates/db/src/db.rs index 701aa57656..1ac1d1604b 100644 --- a/crates/db/src/db.rs +++ b/crates/db/src/db.rs @@ -4,6 +4,7 @@ pub mod kvp; pub use anyhow; pub use indoc::indoc; pub use lazy_static; +use parking_lot::Mutex; pub use smol; pub use sqlez; pub use sqlez_macros; @@ -59,6 +60,14 @@ pub async fn open_memory_db(db_name: &str) -> ThreadSafeConnection< ThreadSafeConnection::::builder(db_name, false) .with_db_initialization_query(DB_INITIALIZE_QUERY) .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) + // Serialize queued writes via a mutex and run them synchronously + .with_write_queue_constructor(Box::new(|connection| { + let connection = Mutex::new(connection); + Box::new(move |queued_write| { + let connection = connection.lock(); + queued_write(&connection) + }) + })) .build() .await } diff --git a/crates/sqlez/src/thread_safe_connection.rs b/crates/sqlez/src/thread_safe_connection.rs index 880a58d194..b17c87d63f 100644 --- a/crates/sqlez/src/thread_safe_connection.rs +++ b/crates/sqlez/src/thread_safe_connection.rs @@ -13,12 +13,14 @@ use crate::{ const MIGRATION_RETRIES: usize = 10; type QueuedWrite = Box; +type WriteQueueConstructor = + Box Box>; lazy_static! { /// List of queues of tasks by database uri. This lets us serialize writes to the database /// and have a single worker thread per db file. This means many thread safe connections /// (possibly with different migrations) could all be communicating with the same background /// thread. - static ref QUEUES: RwLock, UnboundedSyncSender>> = + static ref QUEUES: RwLock, Box>> = Default::default(); } @@ -38,6 +40,7 @@ unsafe impl Sync for ThreadSafeConnection {} pub struct ThreadSafeConnectionBuilder { db_initialize_query: Option<&'static str>, + write_queue_constructor: Option, connection: ThreadSafeConnection, } @@ -50,6 +53,18 @@ impl ThreadSafeConnectionBuilder { self } + /// Specifies how the thread safe connection should serialize writes. If provided + /// the connection will call the write_queue_constructor for each database file in + /// this process. The constructor is responsible for setting up a background thread or + /// async task which handles queued writes with the provided connection. + pub fn with_write_queue_constructor( + mut self, + write_queue_constructor: WriteQueueConstructor, + ) -> Self { + self.write_queue_constructor = Some(write_queue_constructor); + self + } + /// Queues an initialization query for the database file. This must be infallible /// but may cause changes to the database file such as with `PRAGMA journal_mode` pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self { @@ -58,6 +73,38 @@ impl ThreadSafeConnectionBuilder { } pub async fn build(self) -> ThreadSafeConnection { + if !QUEUES.read().contains_key(&self.connection.uri) { + let mut queues = QUEUES.write(); + if !queues.contains_key(&self.connection.uri) { + let mut write_connection = self.connection.create_connection(); + // Enable writes for this connection + write_connection.write = true; + if let Some(mut write_queue_constructor) = self.write_queue_constructor { + let write_channel = write_queue_constructor(write_connection); + queues.insert(self.connection.uri.clone(), write_channel); + } else { + use std::sync::mpsc::channel; + + let (sender, reciever) = channel::(); + thread::spawn(move || { + while let Ok(write) = reciever.recv() { + write(&write_connection) + } + }); + + let sender = UnboundedSyncSender::new(sender); + queues.insert( + self.connection.uri.clone(), + Box::new(move |queued_write| { + sender + .send(queued_write) + .expect("Could not send write action to backgorund thread"); + }), + ); + } + } + } + let db_initialize_query = self.db_initialize_query; self.connection @@ -90,6 +137,7 @@ impl ThreadSafeConnection { pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder { ThreadSafeConnectionBuilder:: { db_initialize_query: None, + write_queue_constructor: None, connection: Self { uri: Arc::from(uri), persistent, @@ -112,48 +160,21 @@ impl ThreadSafeConnection { Connection::open_memory(Some(self.uri.as_ref())) } - fn queue_write_task(&self, callback: QueuedWrite) { - // Startup write thread for this database if one hasn't already - // been started and insert a channel to queue work for it - if !QUEUES.read().contains_key(&self.uri) { - let mut queues = QUEUES.write(); - if !queues.contains_key(&self.uri) { - use std::sync::mpsc::channel; - - let (sender, reciever) = channel::(); - let mut write_connection = self.create_connection(); - // Enable writes for this connection - write_connection.write = true; - thread::spawn(move || { - while let Ok(write) = reciever.recv() { - write(&write_connection) - } - }); - - queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender)); - } - } - - // Grab the queue for this database - let queues = QUEUES.read(); - let write_channel = queues.get(&self.uri).unwrap(); - - write_channel - .send(callback) - .expect("Could not send write action to backgorund thread"); - } - pub fn write( &self, callback: impl 'static + Send + FnOnce(&Connection) -> T, ) -> impl Future { + let queues = QUEUES.read(); + let write_channel = queues + .get(&self.uri) + .expect("Queues are inserted when build is called. This should always succeed"); + // Create a one shot channel for the result of the queued write // so we can await on the result let (sender, reciever) = oneshot::channel(); - self.queue_write_task(Box::new(move |connection| { + write_channel(Box::new(move |connection| { sender.send(callback(connection)).ok(); })); - reciever.map(|response| response.expect("Background writer thread unexpectedly closed")) }