introduce parallel salsa

This commit is contained in:
David Barsky 2024-08-23 14:30:59 -04:00
parent 3b7246ba9f
commit 5f0904ae4a
13 changed files with 188 additions and 15 deletions

View file

@ -21,6 +21,7 @@ salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
salsa-macros = { path = "components/salsa-macros" }
smallvec = "1"
lazy_static = "1"
rayon = "1.10.0"
[dev-dependencies]
annotate-snippets = "0.11.4"

View file

@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex};
// ANCHOR: db_struct
#[salsa::db]
#[derive(Default)]
#[derive(Default, Clone)]
pub struct CalcDatabaseImpl {
storage: salsa::Storage<Self>,

View file

@ -1,6 +1,10 @@
#![allow(unreachable_patterns)]
// FIXME(rust-lang/rust#129031): regression in nightly
use std::{path::PathBuf, sync::Mutex, time::Duration};
use std::{
path::PathBuf,
sync::{Arc, Mutex},
time::Duration,
};
use crossbeam::channel::{unbounded, Sender};
use dashmap::{mapref::entry::Entry, DashMap};
@ -77,11 +81,12 @@ trait Db: salsa::Database {
}
#[salsa::db]
#[derive(Clone)]
struct LazyInputDatabase {
storage: Storage<Self>,
logs: Mutex<Vec<String>>,
logs: Arc<Mutex<Vec<String>>>,
files: DashMap<PathBuf, File>,
file_watcher: Mutex<Debouncer<RecommendedWatcher>>,
file_watcher: Arc<Mutex<Debouncer<RecommendedWatcher>>>,
}
impl LazyInputDatabase {
@ -90,7 +95,9 @@ impl LazyInputDatabase {
storage: Default::default(),
logs: Default::default(),
files: DashMap::new(),
file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()),
file_watcher: Arc::new(Mutex::new(
new_debouncer(Duration::from_secs(1), tx).unwrap(),
)),
}
}
}

View file

@ -90,7 +90,7 @@ impl dyn Database {
///
/// # Panics
///
/// If the view has not been added to the database (see [`DatabaseView`][])
/// If the view has not been added to the database (see [`crate::views::Views`]).
#[track_caller]
pub fn as_view<DbView: ?Sized + Database>(&self) -> &DbView {
self.zalsa().views().try_view_as(self).unwrap()

View file

@ -3,7 +3,7 @@ use crate::{self as salsa, Database, Event, Storage};
#[salsa::db]
/// Default database implementation that you can use if you don't
/// require any custom user data.
#[derive(Default)]
#[derive(Default, Clone)]
pub struct DatabaseImpl {
storage: Storage<Self>,
}

View file

@ -16,6 +16,7 @@ mod input;
mod interned;
mod key;
mod nonce;
mod par_map;
mod revision;
mod runtime;
mod salsa_struct;
@ -45,6 +46,7 @@ pub use self::storage::Storage;
pub use self::update::Update;
pub use self::zalsa::IngredientIndex;
pub use crate::attach::with_attached_database;
pub use par_map::par_map;
pub use salsa_macros::accumulator;
pub use salsa_macros::db;
pub use salsa_macros::input;

54
src/par_map.rs Normal file
View file

@ -0,0 +1,54 @@
use std::ops::Deref;
use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};
use crate::Database;
/// Maps `op` over every element of `inputs` in parallel (via rayon) and
/// collects the results into `C`.
///
/// Each rayon worker operates on its own database handle: the seed value is a
/// borrowed `ParallelDb::Ref`, and rayon's `map_with` clones the seed for
/// parallel splits, which forks the database (see `impl Clone for ParallelDb`).
///
/// `op` is a plain `fn` pointer, so it cannot capture environment; all state
/// must flow through `db` and the per-element value `D`.
pub fn par_map<Db, D, E, C>(
    db: &Db,
    inputs: impl IntoParallelIterator<Item = D>,
    op: fn(&Db, D) -> E,
) -> C
where
    Db: Database + ?Sized,
    D: Send,
    E: Send + Sync,
    C: FromParallelIterator<E>,
{
    // Seed handle: borrows the caller's database. Worker threads receive
    // forks of it through `ParallelDb::clone`.
    let parallel_db = ParallelDb::Ref(db.as_dyn_database());
    inputs
        .into_par_iter()
        .map_with(parallel_db, |parallel_db, element| {
            // Recover the caller's concrete database view `Db` from the
            // `dyn Database` handle so `op` can be invoked with it.
            let db = parallel_db.as_view::<Db>();
            op(db, element)
        })
        .collect()
}
/// This enum _must not_ be public or used outside of `par_map`.
enum ParallelDb<'db> {
    /// Borrowed handle to the database that initiated the parallel map.
    Ref(&'db dyn Database),
    /// Owned fork of the database, produced when rayon clones the handle
    /// for another parallel split (see the `Clone` impl).
    Fork(Box<dyn Database + Send>),
}

/// SAFETY: the contents of the database are never accessed on the thread
/// where this wrapper type is created.
unsafe impl Send for ParallelDb<'_> {}
impl Deref for ParallelDb<'_> {
    type Target = dyn Database;

    /// Borrows the underlying `dyn Database`, regardless of whether this
    /// handle is the original borrowed reference or an owned fork.
    fn deref(&self) -> &Self::Target {
        match self {
            ParallelDb::Ref(db) => *db,
            ParallelDb::Fork(db) => db.as_dyn_database(),
        }
    }
}
impl Clone for ParallelDb<'_> {
    fn clone(&self) -> Self {
        // Cloning always produces an owned fork: `fork_db` (reached through
        // `Deref` to `dyn Database`) clones the database so each rayon
        // worker gets its own handle instead of sharing the borrowed one.
        ParallelDb::Fork(self.fork_db())
    }
}

View file

@ -15,7 +15,7 @@ use crate::{
///
/// The `storage` and `storage_mut` fields must both return a reference to the same
/// storage field which must be owned by `self`.
pub unsafe trait HasStorage: Database + Sized {
pub unsafe trait HasStorage: Database + Clone + Sized {
fn storage(&self) -> &Storage<Self>;
fn storage_mut(&mut self) -> &mut Storage<Self>;
}
@ -108,6 +108,10 @@ unsafe impl<T: HasStorage> ZalsaDatabase for T {
fn zalsa_local(&self) -> &ZalsaLocal {
&self.storage().zalsa_local
}
fn fork_db(&self) -> Box<dyn Database> {
Box::new(self.clone())
}
}
impl<Db: Database> RefUnwindSafe for Storage<Db> {}

View file

@ -50,6 +50,10 @@ pub unsafe trait ZalsaDatabase: Any {
/// Access the thread-local state associated with this database
#[doc(hidden)]
fn zalsa_local(&self) -> &ZalsaLocal;
/// Clone the database.
#[doc(hidden)]
fn fork_db(&self) -> Box<dyn Database>;
}
pub fn views<Db: ?Sized + Database>(db: &Db) -> &Views {

View file

@ -2,15 +2,17 @@
#![allow(dead_code)]
use std::sync::{Arc, Mutex};
use salsa::{Database, Storage};
/// Logging userdata: provides [`LogDatabase`][] trait.
///
/// If you wish to use it along with other userdata,
/// you can also embed it in another struct and implement [`HasLogger`][] for that struct.
#[derive(Default)]
#[derive(Clone, Default)]
pub struct Logger {
logs: std::sync::Mutex<Vec<String>>,
logs: Arc<Mutex<Vec<String>>>,
}
/// Trait implemented by databases that lets them log events.
@ -48,7 +50,7 @@ impl<Db: HasLogger + Database> LogDatabase for Db {}
/// Database that provides logging but does not log salsa event.
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct LoggerDatabase {
storage: Storage<Self>,
logger: Logger,
@ -67,7 +69,7 @@ impl Database for LoggerDatabase {
/// Database that provides logging and logs salsa events.
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct EventLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
@ -87,7 +89,7 @@ impl HasLogger for EventLoggerDatabase {
}
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct DiscardLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
@ -114,7 +116,7 @@ impl HasLogger for DiscardLoggerDatabase {
}
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct ExecuteValidateLoggerDatabase {
storage: Storage<Self>,
logger: Logger,

View file

@ -5,4 +5,5 @@ mod parallel_cycle_all_recover;
mod parallel_cycle_mid_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recover;
mod parallel_map;
mod signal;

View file

@ -0,0 +1,98 @@
// Tests for rayon interactions.
use salsa::Cancelled;
use salsa::Setter;
use crate::setup::Knobs;
use crate::setup::KnobsDatabase;
/// Salsa input holding the list of numbers that the parallel-map tests
/// operate on.
#[salsa::input]
struct ParallelInput {
    field: Vec<u32>,
}
/// Tracked query that increments every element of `input` in parallel via
/// `salsa::par_map`.
#[salsa::tracked]
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> Vec<u32> {
    salsa::par_map(db, input.field(db), |_db, field| field + 1)
}
/// Smoke test: `salsa::par_map` invoked from inside a tracked query
/// completes without panicking.
#[test]
fn execute() {
    let db = salsa::DatabaseImpl::new();
    let counts = (1..=10).collect::<Vec<u32>>();
    let input = ParallelInput::new(&db, counts);
    tracked_fn(&db, input);
}
/// Tracked query used by the cancellation test: signals stage 1, then each
/// parallel worker blocks until stage 2 is signalled.
#[salsa::tracked]
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec<u32> {
    db.signal(1);
    salsa::par_map(db, input.field(db), |db, field| {
        db.wait_for(2);
        field + 1
    })
}
/// Sentinel query: it must never execute. Reaching it after cancellation
/// would indicate the cancelled computation was resumed, which is a bug.
#[salsa::tracked]
fn dummy(_db: &dyn KnobsDatabase, _input: ParallelInput) -> ParallelInput {
    panic!("should never get here!")
}
// We expect this to panic, as `salsa::par_map` must be called from within a
// query (here it is called directly on the database, outside any query).
#[test]
#[should_panic]
fn direct_calls_panic() {
    let db = salsa::DatabaseImpl::new();
    let counts = (1..=10).collect::<Vec<u32>>();
    let input = ParallelInput::new(&db, counts);
    let _: Vec<u32> = salsa::par_map(&db, input.field(&db), |_db, field| field + 1);
}
// Cancellation signalling test
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1
// | wait for stage 1
// signal stage 1 set input, triggers cancellation
// wait for stage 2 (blocks) triggering cancellation sends stage 2
// |
// (unblocked)
// dummy
// panics
#[test]
fn execute_cancellation() {
    let mut db = Knobs::default();
    let counts = (1..=10).collect::<Vec<u32>>();
    let input = ParallelInput::new(&db, counts);
    // Thread A: runs `a1`, which signals stage 1 and then blocks its
    // parallel workers waiting for stage 2.
    let thread_a = std::thread::spawn({
        let db = db.clone();
        move || a1(&db, input)
    });
    let counts = (2..=20).collect::<Vec<u32>>();
    // Arrange for stage 2 to be signalled when the pending write triggers
    // cancellation, unblocking the workers inside `a1`.
    db.signal_on_did_cancel.store(2);
    input.set_field(&mut db).to(counts);
    // Assert that thread A was cancelled
    let cancelled = thread_a
        .join()
        .unwrap_err()
        .downcast::<Cancelled>()
        .unwrap();
    // and inspect the output
    expect_test::expect![[r#"
        PendingWrite
    "#]]
    .assert_debug_eq(&cancelled);
}

View file

@ -83,7 +83,7 @@ fn check<'db>(db: &'db dyn Db, file: File) -> Inference<'db> {
#[test]
fn execute() {
#[salsa::db]
#[derive(Default)]
#[derive(Default, Clone)]
struct Database {
storage: salsa::Storage<Self>,
files: Vec<File>,