mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-24 17:28:40 +00:00
fix test failures
This commit is contained in:
parent
a29ccb4ff8
commit
1b225fa37c
2 changed files with 64 additions and 34 deletions
|
@ -4,6 +4,7 @@ pub mod kvp;
|
|||
pub use anyhow;
|
||||
pub use indoc::indoc;
|
||||
pub use lazy_static;
|
||||
use parking_lot::Mutex;
|
||||
pub use smol;
|
||||
pub use sqlez;
|
||||
pub use sqlez_macros;
|
||||
|
@ -59,6 +60,14 @@ pub async fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<
|
|||
ThreadSafeConnection::<M>::builder(db_name, false)
|
||||
.with_db_initialization_query(DB_INITIALIZE_QUERY)
|
||||
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
|
||||
// Serialize queued writes via a mutex and run them synchronously
|
||||
.with_write_queue_constructor(Box::new(|connection| {
|
||||
let connection = Mutex::new(connection);
|
||||
Box::new(move |queued_write| {
|
||||
let connection = connection.lock();
|
||||
queued_write(&connection)
|
||||
})
|
||||
}))
|
||||
.build()
|
||||
.await
|
||||
}
|
||||
|
|
|
@ -13,12 +13,14 @@ use crate::{
|
|||
const MIGRATION_RETRIES: usize = 10;
|
||||
|
||||
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
|
||||
type WriteQueueConstructor =
|
||||
Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
|
||||
lazy_static! {
|
||||
/// List of queues of tasks by database uri. This lets us serialize writes to the database
|
||||
/// and have a single worker thread per db file. This means many thread safe connections
|
||||
/// (possibly with different migrations) could all be communicating with the same background
|
||||
/// thread.
|
||||
static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
|
||||
static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
|
||||
Default::default();
|
||||
}
|
||||
|
||||
|
@ -38,6 +40,7 @@ unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
|
|||
|
||||
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
|
||||
db_initialize_query: Option<&'static str>,
|
||||
write_queue_constructor: Option<WriteQueueConstructor>,
|
||||
connection: ThreadSafeConnection<M>,
|
||||
}
|
||||
|
||||
|
@ -50,6 +53,18 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
|
|||
self
|
||||
}
|
||||
|
||||
/// Specifies how the thread safe connection should serialize writes. If provided
|
||||
/// the connection will call the write_queue_constructor for each database file in
|
||||
/// this process. The constructor is responsible for setting up a background thread or
|
||||
/// async task which handles queued writes with the provided connection.
|
||||
pub fn with_write_queue_constructor(
|
||||
mut self,
|
||||
write_queue_constructor: WriteQueueConstructor,
|
||||
) -> Self {
|
||||
self.write_queue_constructor = Some(write_queue_constructor);
|
||||
self
|
||||
}
|
||||
|
||||
/// Queues an initialization query for the database file. This must be infallible
|
||||
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
|
||||
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
|
||||
|
@ -58,6 +73,38 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
|
|||
}
|
||||
|
||||
pub async fn build(self) -> ThreadSafeConnection<M> {
|
||||
if !QUEUES.read().contains_key(&self.connection.uri) {
|
||||
let mut queues = QUEUES.write();
|
||||
if !queues.contains_key(&self.connection.uri) {
|
||||
let mut write_connection = self.connection.create_connection();
|
||||
// Enable writes for this connection
|
||||
write_connection.write = true;
|
||||
if let Some(mut write_queue_constructor) = self.write_queue_constructor {
|
||||
let write_channel = write_queue_constructor(write_connection);
|
||||
queues.insert(self.connection.uri.clone(), write_channel);
|
||||
} else {
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
let (sender, reciever) = channel::<QueuedWrite>();
|
||||
thread::spawn(move || {
|
||||
while let Ok(write) = reciever.recv() {
|
||||
write(&write_connection)
|
||||
}
|
||||
});
|
||||
|
||||
let sender = UnboundedSyncSender::new(sender);
|
||||
queues.insert(
|
||||
self.connection.uri.clone(),
|
||||
Box::new(move |queued_write| {
|
||||
sender
|
||||
.send(queued_write)
|
||||
.expect("Could not send write action to backgorund thread");
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let db_initialize_query = self.db_initialize_query;
|
||||
|
||||
self.connection
|
||||
|
@ -90,6 +137,7 @@ impl<M: Migrator> ThreadSafeConnection<M> {
|
|||
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
|
||||
ThreadSafeConnectionBuilder::<M> {
|
||||
db_initialize_query: None,
|
||||
write_queue_constructor: None,
|
||||
connection: Self {
|
||||
uri: Arc::from(uri),
|
||||
persistent,
|
||||
|
@ -112,48 +160,21 @@ impl<M: Migrator> ThreadSafeConnection<M> {
|
|||
Connection::open_memory(Some(self.uri.as_ref()))
|
||||
}
|
||||
|
||||
fn queue_write_task(&self, callback: QueuedWrite) {
|
||||
// Startup write thread for this database if one hasn't already
|
||||
// been started and insert a channel to queue work for it
|
||||
if !QUEUES.read().contains_key(&self.uri) {
|
||||
let mut queues = QUEUES.write();
|
||||
if !queues.contains_key(&self.uri) {
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
let (sender, reciever) = channel::<QueuedWrite>();
|
||||
let mut write_connection = self.create_connection();
|
||||
// Enable writes for this connection
|
||||
write_connection.write = true;
|
||||
thread::spawn(move || {
|
||||
while let Ok(write) = reciever.recv() {
|
||||
write(&write_connection)
|
||||
}
|
||||
});
|
||||
|
||||
queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
|
||||
}
|
||||
}
|
||||
|
||||
// Grab the queue for this database
|
||||
let queues = QUEUES.read();
|
||||
let write_channel = queues.get(&self.uri).unwrap();
|
||||
|
||||
write_channel
|
||||
.send(callback)
|
||||
.expect("Could not send write action to backgorund thread");
|
||||
}
|
||||
|
||||
pub fn write<T: 'static + Send + Sync>(
|
||||
&self,
|
||||
callback: impl 'static + Send + FnOnce(&Connection) -> T,
|
||||
) -> impl Future<Output = T> {
|
||||
let queues = QUEUES.read();
|
||||
let write_channel = queues
|
||||
.get(&self.uri)
|
||||
.expect("Queues are inserted when build is called. This should always succeed");
|
||||
|
||||
// Create a one shot channel for the result of the queued write
|
||||
// so we can await on the result
|
||||
let (sender, reciever) = oneshot::channel();
|
||||
self.queue_write_task(Box::new(move |connection| {
|
||||
write_channel(Box::new(move |connection| {
|
||||
sender.send(callback(connection)).ok();
|
||||
}));
|
||||
|
||||
reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue