fix test failures

This commit is contained in:
Kay Simmons 2022-11-30 12:34:42 -08:00 committed by Mikayla Maki
parent a29ccb4ff8
commit 1b225fa37c
2 changed files with 64 additions and 34 deletions

View file

@ -4,6 +4,7 @@ pub mod kvp;
pub use anyhow;
pub use indoc::indoc;
pub use lazy_static;
use parking_lot::Mutex;
pub use smol;
pub use sqlez;
pub use sqlez_macros;
@ -59,6 +60,14 @@ pub async fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<
ThreadSafeConnection::<M>::builder(db_name, false)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
// Serialize queued writes via a mutex and run them synchronously
.with_write_queue_constructor(Box::new(|connection| {
let connection = Mutex::new(connection);
Box::new(move |queued_write| {
let connection = connection.lock();
queued_write(&connection)
})
}))
.build()
.await
}

View file

@ -13,12 +13,14 @@ use crate::{
const MIGRATION_RETRIES: usize = 10;
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
type WriteQueueConstructor =
Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
lazy_static! {
/// List of queues of tasks by database uri. This lets us serialize writes to the database
/// and have a single worker thread per db file. This means many thread safe connections
/// (possibly with different migrations) could all be communicating with the same background
/// thread.
static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
Default::default();
}
@ -38,6 +40,7 @@ unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
db_initialize_query: Option<&'static str>,
write_queue_constructor: Option<WriteQueueConstructor>,
connection: ThreadSafeConnection<M>,
}
@ -50,6 +53,18 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
self
}
/// Specifies how the thread safe connection should serialize writes. If provided
/// the connection will call the write_queue_constructor for each database file in
/// this process. The constructor is responsible for setting up a background thread or
/// async task which handles queued writes with the provided connection.
pub fn with_write_queue_constructor(
mut self,
write_queue_constructor: WriteQueueConstructor,
) -> Self {
self.write_queue_constructor = Some(write_queue_constructor);
self
}
/// Queues an initialization query for the database file. This must be infallible
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
@ -58,6 +73,38 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
}
pub async fn build(self) -> ThreadSafeConnection<M> {
if !QUEUES.read().contains_key(&self.connection.uri) {
let mut queues = QUEUES.write();
if !queues.contains_key(&self.connection.uri) {
let mut write_connection = self.connection.create_connection();
// Enable writes for this connection
write_connection.write = true;
if let Some(mut write_queue_constructor) = self.write_queue_constructor {
let write_channel = write_queue_constructor(write_connection);
queues.insert(self.connection.uri.clone(), write_channel);
} else {
use std::sync::mpsc::channel;
let (sender, reciever) = channel::<QueuedWrite>();
thread::spawn(move || {
while let Ok(write) = reciever.recv() {
write(&write_connection)
}
});
let sender = UnboundedSyncSender::new(sender);
queues.insert(
self.connection.uri.clone(),
Box::new(move |queued_write| {
sender
.send(queued_write)
.expect("Could not send write action to backgorund thread");
}),
);
}
}
}
let db_initialize_query = self.db_initialize_query;
self.connection
@ -90,6 +137,7 @@ impl<M: Migrator> ThreadSafeConnection<M> {
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
ThreadSafeConnectionBuilder::<M> {
db_initialize_query: None,
write_queue_constructor: None,
connection: Self {
uri: Arc::from(uri),
persistent,
@ -112,48 +160,21 @@ impl<M: Migrator> ThreadSafeConnection<M> {
Connection::open_memory(Some(self.uri.as_ref()))
}
fn queue_write_task(&self, callback: QueuedWrite) {
// Startup write thread for this database if one hasn't already
// been started and insert a channel to queue work for it
if !QUEUES.read().contains_key(&self.uri) {
let mut queues = QUEUES.write();
if !queues.contains_key(&self.uri) {
use std::sync::mpsc::channel;
let (sender, reciever) = channel::<QueuedWrite>();
let mut write_connection = self.create_connection();
// Enable writes for this connection
write_connection.write = true;
thread::spawn(move || {
while let Ok(write) = reciever.recv() {
write(&write_connection)
}
});
queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
}
}
// Grab the queue for this database
let queues = QUEUES.read();
let write_channel = queues.get(&self.uri).unwrap();
write_channel
.send(callback)
.expect("Could not send write action to backgorund thread");
}
pub fn write<T: 'static + Send + Sync>(
&self,
callback: impl 'static + Send + FnOnce(&Connection) -> T,
) -> impl Future<Output = T> {
let queues = QUEUES.read();
let write_channel = queues
.get(&self.uri)
.expect("Queues are inserted when build is called. This should always succeed");
// Create a one shot channel for the result of the queued write
// so we can await on the result
let (sender, reciever) = oneshot::channel();
self.queue_write_task(Box::new(move |connection| {
write_channel(Box::new(move |connection| {
sender.send(callback(connection)).ok();
}));
reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
}