diff --git a/tests/parallel/cancellation.rs b/tests/parallel/cancellation.rs index 73afbad1..8ff2a38f 100644 --- a/tests/parallel/cancellation.rs +++ b/tests/parallel/cancellation.rs @@ -1,6 +1,20 @@ -use crate::setup::{Input, Knobs, ParDatabase, ParDatabaseImpl, WithValue}; +use crate::setup::{ + CancelationFlag, Canceled, Input, Knobs, ParDatabase, ParDatabaseImpl, WithValue, +}; use salsa::{Database, ParallelDatabase}; +macro_rules! assert_canceled { + ($thread:expr) => { + match $thread.join() { + Ok(value) => panic!("expected cancelation, got {:?}", value), + Err(payload) => match payload.downcast::() { + Ok(_) => {} + Err(payload) => ::std::panic::resume_unwind(payload), + }, + } + }; +} + /// Add test where a call to `sum` is cancelled by a simultaneous /// write. Check that we recompute the result in next revision, even /// though none of the inputs have changed. @@ -21,7 +35,7 @@ fn in_par_get_set_cancellation_immediate() { db.knobs().sum_signal_on_entry.with_value(1, || { db.knobs() .sum_wait_for_cancellation - .with_value(true, || db.sum("abc")) + .with_value(CancelationFlag::Panic, || db.sum("abc")) }) } }); @@ -39,7 +53,7 @@ fn in_par_get_set_cancellation_immediate() { }); assert_eq!(db.sum("d"), 1000); - assert_eq!(thread1.join().unwrap(), std::usize::MAX); + assert_canceled!(thread1); assert_eq!(thread2.join().unwrap(), 111); } @@ -62,7 +76,7 @@ fn in_par_get_set_cancellation_transitive() { db.knobs().sum_signal_on_entry.with_value(1, || { db.knobs() .sum_wait_for_cancellation - .with_value(true, || db.sum2("abc")) + .with_value(CancelationFlag::Panic, || db.sum2("abc")) }) } }); @@ -80,7 +94,7 @@ fn in_par_get_set_cancellation_transitive() { }); assert_eq!(db.sum2("d"), 1000); - assert_eq!(thread1.join().unwrap(), std::usize::MAX); + assert_canceled!(thread1); assert_eq!(thread2.join().unwrap(), 111); } @@ -98,7 +112,7 @@ fn no_back_dating_in_cancellation() { db.knobs().sum_signal_on_entry.with_value(1, || { db.knobs() .sum_wait_for_cancellation - .with_value(true, || db.sum3("a")) + .with_value(CancelationFlag::Panic, || db.sum3("a")) }) } }); @@ -112,7 +126,7 @@ fn no_back_dating_in_cancellation() { // state. If we get `usize::max()` here, it is a bug! assert_eq!(db.sum3("a"), 1); - assert_eq!(thread1.join().unwrap(), std::usize::MAX); + assert_canceled!(thread1); db.query_mut(Input).set('a', 3); db.query_mut(Input).set('a', 4); @@ -137,7 +151,7 @@ fn transitive_cancellation() { db.knobs().sum_signal_on_entry.with_value(1, || { db.knobs() .sum_wait_for_cancellation - .with_value(true, || db.sum3_drop_sum("a")) + .with_value(CancelationFlag::SpecialValue, || db.sum3_drop_sum("a")) }) } }); diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 100c4e05..d32bc645 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -42,6 +42,16 @@ salsa::query_group! { } } +#[derive(PartialEq, Eq)] +pub(crate) struct Canceled; + +impl Canceled { + fn throw() -> ! { + // Don't print backtrace + std::panic::resume_unwind(Box::new(Canceled)); + } +} + /// Various "knobs" and utilities used by tests to force /// a certain behavior. pub(crate) trait Knobs { @@ -68,6 +78,19 @@ impl WithValue for Cell { } } +#[derive(Clone, Copy, PartialEq, Eq)] +pub(crate) enum CancelationFlag { + Down, + Panic, + SpecialValue, +} + +impl Default for CancelationFlag { + fn default() -> CancelationFlag { + CancelationFlag::Down + } +} + /// Various "knobs" that can be used to customize how the queries /// behave on one specific thread. Note that this state is /// intentionally thread-local (apart from `signal`). @@ -91,7 +114,7 @@ pub(crate) struct KnobsStruct { /// If true, invocations of `sum` will wait for cancellation before /// they exit. - pub(crate) sum_wait_for_cancellation: Cell, + pub(crate) sum_wait_for_cancellation: Cell, /// Invocations of `sum` will wait for this stage prior to exiting. pub(crate) sum_wait_for_on_exit: Cell, @@ -118,12 +141,25 @@ fn sum(db: &impl ParDatabase, key: &'static str) -> usize { sum += db.input(ch); } - if db.knobs().sum_wait_for_cancellation.get() { - log::debug!("waiting for cancellation"); - while !db.salsa_runtime().is_current_revision_canceled() { - std::thread::yield_now(); + match db.knobs().sum_wait_for_cancellation.get() { + CancelationFlag::Down => (), + CancelationFlag::SpecialValue => { + log::debug!("waiting for cancellation"); + while !db.salsa_runtime().is_current_revision_canceled() { + std::thread::yield_now(); + } + log::debug!("observed cancelation"); + } + CancelationFlag::Panic => { + log::debug!("waiting for cancellation"); + loop { + db.salsa_runtime().if_current_revision_is_canceled(|| { + log::debug!("observed cancelation"); + Canceled::throw() + }); + std::thread::yield_now(); + } } - log::debug!("cancellation observed"); } // Check for cancelation and return MAX if so. Note that we check @@ -191,6 +227,10 @@ impl Database for ParDatabaseImpl { _ => {} } } + + fn on_propagated_panic(&self) -> ! { + Canceled::throw() + } } impl ParallelDatabase for ParDatabaseImpl {