mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-12 16:35:21 +00:00
introduce parallel salsa
This commit is contained in:
parent
3b7246ba9f
commit
5f0904ae4a
13 changed files with 188 additions and 15 deletions
|
@ -21,6 +21,7 @@ salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
|
|||
salsa-macros = { path = "components/salsa-macros" }
|
||||
smallvec = "1"
|
||||
lazy_static = "1"
|
||||
rayon = "1.10.0"
|
||||
|
||||
[dev-dependencies]
|
||||
annotate-snippets = "0.11.4"
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex};
|
|||
|
||||
// ANCHOR: db_struct
|
||||
#[salsa::db]
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Clone)]
|
||||
pub struct CalcDatabaseImpl {
|
||||
storage: salsa::Storage<Self>,
|
||||
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
#![allow(unreachable_patterns)]
|
||||
// FIXME(rust-lang/rust#129031): regression in nightly
|
||||
use std::{path::PathBuf, sync::Mutex, time::Duration};
|
||||
use std::{
|
||||
path::PathBuf,
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use crossbeam::channel::{unbounded, Sender};
|
||||
use dashmap::{mapref::entry::Entry, DashMap};
|
||||
|
@ -77,11 +81,12 @@ trait Db: salsa::Database {
|
|||
}
|
||||
|
||||
#[salsa::db]
|
||||
#[derive(Clone)]
|
||||
struct LazyInputDatabase {
|
||||
storage: Storage<Self>,
|
||||
logs: Mutex<Vec<String>>,
|
||||
logs: Arc<Mutex<Vec<String>>>,
|
||||
files: DashMap<PathBuf, File>,
|
||||
file_watcher: Mutex<Debouncer<RecommendedWatcher>>,
|
||||
file_watcher: Arc<Mutex<Debouncer<RecommendedWatcher>>>,
|
||||
}
|
||||
|
||||
impl LazyInputDatabase {
|
||||
|
@ -90,7 +95,9 @@ impl LazyInputDatabase {
|
|||
storage: Default::default(),
|
||||
logs: Default::default(),
|
||||
files: DashMap::new(),
|
||||
file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()),
|
||||
file_watcher: Arc::new(Mutex::new(
|
||||
new_debouncer(Duration::from_secs(1), tx).unwrap(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -90,7 +90,7 @@ impl dyn Database {
|
|||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the view has not been added to the database (see [`DatabaseView`][])
|
||||
/// If the view has not been added to the database (see [`crate::views::Views`]).
|
||||
#[track_caller]
|
||||
pub fn as_view<DbView: ?Sized + Database>(&self) -> &DbView {
|
||||
self.zalsa().views().try_view_as(self).unwrap()
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::{self as salsa, Database, Event, Storage};
|
|||
#[salsa::db]
|
||||
/// Default database implementation that you can use if you don't
|
||||
/// require any custom user data.
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Clone)]
|
||||
pub struct DatabaseImpl {
|
||||
storage: Storage<Self>,
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ mod input;
|
|||
mod interned;
|
||||
mod key;
|
||||
mod nonce;
|
||||
mod par_map;
|
||||
mod revision;
|
||||
mod runtime;
|
||||
mod salsa_struct;
|
||||
|
@ -45,6 +46,7 @@ pub use self::storage::Storage;
|
|||
pub use self::update::Update;
|
||||
pub use self::zalsa::IngredientIndex;
|
||||
pub use crate::attach::with_attached_database;
|
||||
pub use par_map::par_map;
|
||||
pub use salsa_macros::accumulator;
|
||||
pub use salsa_macros::db;
|
||||
pub use salsa_macros::input;
|
||||
|
|
54
src/par_map.rs
Normal file
54
src/par_map.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
use std::ops::Deref;
|
||||
|
||||
use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};
|
||||
|
||||
use crate::Database;
|
||||
|
||||
pub fn par_map<Db, D, E, C>(
|
||||
db: &Db,
|
||||
inputs: impl IntoParallelIterator<Item = D>,
|
||||
op: fn(&Db, D) -> E,
|
||||
) -> C
|
||||
where
|
||||
Db: Database + ?Sized,
|
||||
D: Send,
|
||||
E: Send + Sync,
|
||||
C: FromParallelIterator<E>,
|
||||
{
|
||||
let parallel_db = ParallelDb::Ref(db.as_dyn_database());
|
||||
|
||||
inputs
|
||||
.into_par_iter()
|
||||
.map_with(parallel_db, |parallel_db, element| {
|
||||
let db = parallel_db.as_view::<Db>();
|
||||
op(db, element)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// This enum _must not_ be public or used outside of `par_map`.
|
||||
enum ParallelDb<'db> {
|
||||
Ref(&'db dyn Database),
|
||||
Fork(Box<dyn Database + Send>),
|
||||
}
|
||||
|
||||
/// SAFETY: the contents of the database are never accessed on the thread
|
||||
/// where this wrapper type is created.
|
||||
unsafe impl Send for ParallelDb<'_> {}
|
||||
|
||||
impl Deref for ParallelDb<'_> {
|
||||
type Target = dyn Database;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
ParallelDb::Ref(db) => *db,
|
||||
ParallelDb::Fork(db) => db.as_dyn_database(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for ParallelDb<'_> {
|
||||
fn clone(&self) -> Self {
|
||||
ParallelDb::Fork(self.fork_db())
|
||||
}
|
||||
}
|
|
@ -15,7 +15,7 @@ use crate::{
|
|||
///
|
||||
/// The `storage` and `storage_mut` fields must both return a reference to the same
|
||||
/// storage field which must be owned by `self`.
|
||||
pub unsafe trait HasStorage: Database + Sized {
|
||||
pub unsafe trait HasStorage: Database + Clone + Sized {
|
||||
fn storage(&self) -> &Storage<Self>;
|
||||
fn storage_mut(&mut self) -> &mut Storage<Self>;
|
||||
}
|
||||
|
@ -108,6 +108,10 @@ unsafe impl<T: HasStorage> ZalsaDatabase for T {
|
|||
fn zalsa_local(&self) -> &ZalsaLocal {
|
||||
&self.storage().zalsa_local
|
||||
}
|
||||
|
||||
fn fork_db(&self) -> Box<dyn Database> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Db: Database> RefUnwindSafe for Storage<Db> {}
|
||||
|
|
|
@ -50,6 +50,10 @@ pub unsafe trait ZalsaDatabase: Any {
|
|||
/// Access the thread-local state associated with this database
|
||||
#[doc(hidden)]
|
||||
fn zalsa_local(&self) -> &ZalsaLocal;
|
||||
|
||||
/// Clone the database.
|
||||
#[doc(hidden)]
|
||||
fn fork_db(&self) -> Box<dyn Database>;
|
||||
}
|
||||
|
||||
pub fn views<Db: ?Sized + Database>(db: &Db) -> &Views {
|
||||
|
|
|
@ -2,15 +2,17 @@
|
|||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use salsa::{Database, Storage};
|
||||
|
||||
/// Logging userdata: provides [`LogDatabase`][] trait.
|
||||
///
|
||||
/// If you wish to use it along with other userdata,
|
||||
/// you can also embed it in another struct and implement [`HasLogger`][] for that struct.
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Logger {
|
||||
logs: std::sync::Mutex<Vec<String>>,
|
||||
logs: Arc<Mutex<Vec<String>>>,
|
||||
}
|
||||
|
||||
/// Trait implemented by databases that lets them log events.
|
||||
|
@ -48,7 +50,7 @@ impl<Db: HasLogger + Database> LogDatabase for Db {}
|
|||
|
||||
/// Database that provides logging but does not log salsa event.
|
||||
#[salsa::db]
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LoggerDatabase {
|
||||
storage: Storage<Self>,
|
||||
logger: Logger,
|
||||
|
@ -67,7 +69,7 @@ impl Database for LoggerDatabase {
|
|||
|
||||
/// Database that provides logging and logs salsa events.
|
||||
#[salsa::db]
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Default)]
|
||||
pub struct EventLoggerDatabase {
|
||||
storage: Storage<Self>,
|
||||
logger: Logger,
|
||||
|
@ -87,7 +89,7 @@ impl HasLogger for EventLoggerDatabase {
|
|||
}
|
||||
|
||||
#[salsa::db]
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Default)]
|
||||
pub struct DiscardLoggerDatabase {
|
||||
storage: Storage<Self>,
|
||||
logger: Logger,
|
||||
|
@ -114,7 +116,7 @@ impl HasLogger for DiscardLoggerDatabase {
|
|||
}
|
||||
|
||||
#[salsa::db]
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ExecuteValidateLoggerDatabase {
|
||||
storage: Storage<Self>,
|
||||
logger: Logger,
|
||||
|
|
|
@ -5,4 +5,5 @@ mod parallel_cycle_all_recover;
|
|||
mod parallel_cycle_mid_recover;
|
||||
mod parallel_cycle_none_recover;
|
||||
mod parallel_cycle_one_recover;
|
||||
mod parallel_map;
|
||||
mod signal;
|
||||
|
|
98
tests/parallel/parallel_map.rs
Normal file
98
tests/parallel/parallel_map.rs
Normal file
|
@ -0,0 +1,98 @@
|
|||
// test for rayon interations.
|
||||
|
||||
use salsa::Cancelled;
|
||||
use salsa::Setter;
|
||||
|
||||
use crate::setup::Knobs;
|
||||
use crate::setup::KnobsDatabase;
|
||||
|
||||
#[salsa::input]
|
||||
struct ParallelInput {
|
||||
field: Vec<u32>,
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> Vec<u32> {
|
||||
salsa::par_map(db, input.field(db), |_db, field| field + 1)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execute() {
|
||||
let db = salsa::DatabaseImpl::new();
|
||||
|
||||
let counts = (1..=10).collect::<Vec<u32>>();
|
||||
let input = ParallelInput::new(&db, counts);
|
||||
|
||||
tracked_fn(&db, input);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec<u32> {
|
||||
db.signal(1);
|
||||
salsa::par_map(db, input.field(db), |db, field| {
|
||||
db.wait_for(2);
|
||||
field + 1
|
||||
})
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn dummy(_db: &dyn KnobsDatabase, _input: ParallelInput) -> ParallelInput {
|
||||
panic!("should never get here!")
|
||||
}
|
||||
|
||||
// we expect this to panic, as `salsa::par_map` needs to be called from a query.
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn direct_calls_panic() {
|
||||
let db = salsa::DatabaseImpl::new();
|
||||
|
||||
let counts = (1..=10).collect::<Vec<u32>>();
|
||||
let input = ParallelInput::new(&db, counts);
|
||||
let _: Vec<u32> = salsa::par_map(&db, input.field(&db), |_db, field| field + 1);
|
||||
}
|
||||
|
||||
// Cancellation signalling test
|
||||
//
|
||||
// The pattern is as follows.
|
||||
//
|
||||
// Thread A Thread B
|
||||
// -------- --------
|
||||
// a1
|
||||
// | wait for stage 1
|
||||
// signal stage 1 set input, triggers cancellation
|
||||
// wait for stage 2 (blocks) triggering cancellation sends stage 2
|
||||
// |
|
||||
// (unblocked)
|
||||
// dummy
|
||||
// panics
|
||||
|
||||
#[test]
|
||||
fn execute_cancellation() {
|
||||
let mut db = Knobs::default();
|
||||
|
||||
let counts = (1..=10).collect::<Vec<u32>>();
|
||||
let input = ParallelInput::new(&db, counts);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.clone();
|
||||
move || a1(&db, input)
|
||||
});
|
||||
|
||||
let counts = (2..=20).collect::<Vec<u32>>();
|
||||
|
||||
db.signal_on_did_cancel.store(2);
|
||||
input.set_field(&mut db).to(counts);
|
||||
|
||||
// Assert thread A *should* was cancelled
|
||||
let cancelled = thread_a
|
||||
.join()
|
||||
.unwrap_err()
|
||||
.downcast::<Cancelled>()
|
||||
.unwrap();
|
||||
|
||||
// and inspect the output
|
||||
expect_test::expect![[r#"
|
||||
PendingWrite
|
||||
"#]]
|
||||
.assert_debug_eq(&cancelled);
|
||||
}
|
|
@ -83,7 +83,7 @@ fn check<'db>(db: &'db dyn Db, file: File) -> Inference<'db> {
|
|||
#[test]
|
||||
fn execute() {
|
||||
#[salsa::db]
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Clone)]
|
||||
struct Database {
|
||||
storage: salsa::Storage<Self>,
|
||||
files: Vec<File>,
|
||||
|
|
Loading…
Reference in a new issue