diff --git a/tests/parallel/cancellation.rs b/tests/parallel/cancellation.rs index 5c6da7af..b9590f00 100644 --- a/tests/parallel/cancellation.rs +++ b/tests/parallel/cancellation.rs @@ -1,125 +1,6 @@ -use parking_lot::{Condvar, Mutex}; +use crate::setup::{Input, Knobs, ParDatabase, ParDatabaseImpl, WithValue}; use salsa::Database; use salsa::ParallelDatabase; -use std::cell::Cell; -use std::sync::Arc; - -#[derive(Default)] -pub struct ParDatabaseImpl { - runtime: salsa::Runtime, - signal: Arc, -} - -impl Database for ParDatabaseImpl { - fn salsa_runtime(&self) -> &salsa::Runtime { - &self.runtime - } -} - -impl ParallelDatabase for ParDatabaseImpl { - fn fork(&self) -> Self { - ParDatabaseImpl { - runtime: self.runtime.fork(), - signal: self.signal.clone(), - } - } -} - -salsa::database_storage! { - pub struct DatabaseImplStorage for ParDatabaseImpl { - impl ParDatabase { - fn input() for Input; - fn sum() for Sum; - } - } -} - -salsa::query_group! { - trait ParDatabase: HasSignal + salsa::Database { - fn input(key: char) -> usize { - type Input; - storage input; - } - - fn sum(key: &'static str) -> usize { - type Sum; - use fn sum; - } - } -} - -// This is used to force `sum` to block on the signal sometimes so -// that we can forcibly arrange race conditions we would like to test. -thread_local! { - static SUM_SHOULD_AWAIT_CANCELLATION: Cell = Cell::new(false); -} - -trait HasSignal { - fn signal(&self) -> &Signal; -} - -impl HasSignal for ParDatabaseImpl { - fn signal(&self) -> &Signal { - &self.signal - } -} - -#[derive(Default)] -struct Signal { - value: Mutex, - cond_var: Condvar, -} - -impl Signal { - fn signal(&self, stage: usize) { - log::debug!("signal({})", stage); - let mut v = self.value.lock(); - assert!( - stage > *v, - "stage should be increasing monotonically (old={}, new={})", - *v, - stage - ); - *v = stage; - self.cond_var.notify_all(); - } - - /// Waits until the given condition is true; the fn is invoked - /// with the current stage. - fn await(&self, stage: usize) { - log::debug!("await({})", stage); - let mut v = self.value.lock(); - while *v < stage { - self.cond_var.wait(&mut v); - } - } -} - -fn sum(db: &impl ParDatabase, key: &'static str) -> usize { - let mut sum = 0; - - // If we are going to await cancellation, we first *signal* when - // we have entered. This way, the other thread can wait and be - // sure that we are executing `sum`. - if SUM_SHOULD_AWAIT_CANCELLATION.with(|s| s.get()) { - db.signal().signal(1); - } - - for ch in key.chars() { - sum += db.input(ch); - } - - if SUM_SHOULD_AWAIT_CANCELLATION.with(|s| s.get()) { - log::debug!("awaiting cancellation"); - while !db.salsa_runtime().is_current_revision_canceled() { - std::thread::yield_now(); - } - log::debug!("cancellation observed"); - return std::usize::MAX; // when we are cancelled, we return usize::MAX. - } - - sum -} #[test] fn in_par() { @@ -180,12 +61,13 @@ fn in_par_get_set_cancellation() { let thread1 = std::thread::spawn({ let db = db.fork(); move || { - SUM_SHOULD_AWAIT_CANCELLATION.with(|c| c.set(true)); - let v1 = db.sum("abc"); + let v1 = db.sum_signal_on_entry().with_value(1, || { + db.sum_await_cancellation() + .with_value(true, || db.sum("abc")) + }); // check that we observed cancellation assert_eq!(v1, std::usize::MAX); - SUM_SHOULD_AWAIT_CANCELLATION.with(|c| c.set(false)); // at this point, we have observed cancellation, so let's // wait until the `set` is known to have occurred. diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index dc333d49..209dbc30 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1 +1,3 @@ +mod setup; + mod cancellation; diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs new file mode 100644 index 00000000..67f6413e --- /dev/null +++ b/tests/parallel/setup.rs @@ -0,0 +1,148 @@ +use parking_lot::{Condvar, Mutex}; +use salsa::Database; +use salsa::ParallelDatabase; +use std::cell::Cell; +use std::sync::Arc; + +salsa::query_group! { + pub(crate) trait ParDatabase: Knobs + salsa::Database { + fn input(key: char) -> usize { + type Input; + storage input; + } + + fn sum(key: &'static str) -> usize { + type Sum; + use fn sum; + } + } +} + +/// Various "knobs" and utilities used by tests to force +/// a certain behavior. +pub(crate) trait Knobs { + fn signal(&self) -> &Signal; + + /// Invocations of `sum` will signal `stage` this stage on entry. + fn sum_signal_on_entry(&self) -> &Cell; + + /// If set to true, invocations of `sum` will await cancellation + /// before they exit. + fn sum_await_cancellation(&self) -> &Cell; +} + +pub(crate) trait WithValue { + fn with_value(&self, value: T, closure: impl FnOnce() -> R) -> R; +} + +impl WithValue for Cell { + fn with_value(&self, value: T, closure: impl FnOnce() -> R) -> R { + let old_value = self.replace(value); + + let result = closure(); + + self.set(old_value); + + result + } +} + +#[derive(Clone, Default)] +struct KnobsStruct { + signal: Arc, + sum_signal_on_entry: Cell, + sum_await_cancellation: Cell, +} + +#[derive(Default)] +pub(crate) struct Signal { + value: Mutex, + cond_var: Condvar, +} + +impl Signal { + pub(crate) fn signal(&self, stage: usize) { + log::debug!("signal({})", stage); + let mut v = self.value.lock(); + if stage > *v { + *v = stage; + self.cond_var.notify_all(); + } + } + + /// Waits until the given condition is true; the fn is invoked + /// with the current stage. + pub(crate) fn await(&self, stage: usize) { + log::debug!("await({})", stage); + let mut v = self.value.lock(); + while *v < stage { + self.cond_var.wait(&mut v); + } + } +} + +fn sum(db: &impl ParDatabase, key: &'static str) -> usize { + let mut sum = 0; + + let stage = db.sum_signal_on_entry().get(); + db.signal().signal(stage); + + for ch in key.chars() { + sum += db.input(ch); + } + + if db.sum_await_cancellation().get() { + log::debug!("awaiting cancellation"); + while !db.salsa_runtime().is_current_revision_canceled() { + std::thread::yield_now(); + } + log::debug!("cancellation observed"); + return std::usize::MAX; // when we are cancelled, we return usize::MAX. + } + + sum +} + +#[derive(Default)] +pub struct ParDatabaseImpl { + runtime: salsa::Runtime, + knobs: KnobsStruct, +} + +impl Database for ParDatabaseImpl { + fn salsa_runtime(&self) -> &salsa::Runtime { + &self.runtime + } +} + +impl ParallelDatabase for ParDatabaseImpl { + fn fork(&self) -> Self { + ParDatabaseImpl { + runtime: self.runtime.fork(), + knobs: self.knobs.clone(), + } + } +} + +impl Knobs for ParDatabaseImpl { + fn signal(&self) -> &Signal { + &self.knobs.signal + } + + fn sum_signal_on_entry(&self) -> &Cell { + &self.knobs.sum_signal_on_entry + } + + fn sum_await_cancellation(&self) -> &Cell { + &self.knobs.sum_await_cancellation + } +} + +salsa::database_storage! { + pub struct DatabaseImplStorage for ParDatabaseImpl { + impl ParDatabase { + fn input() for Input; + fn sum() for Sum; + } + } +}