use crate::signal::Signal; use salsa::Database; use salsa::ParallelDatabase; use salsa::Snapshot; use std::sync::Arc; use std::{ cell::Cell, panic::{catch_unwind, resume_unwind, AssertUnwindSafe}, }; #[salsa::query_group(Par)] pub(crate) trait ParDatabase: Knobs { #[salsa::input] fn input(&self, key: char) -> usize; fn sum(&self, key: &'static str) -> usize; /// Invokes `sum` fn sum2(&self, key: &'static str) -> usize; /// Invokes `sum` but doesn't really care about the result. fn sum2_drop_sum(&self, key: &'static str) -> usize; /// Invokes `sum2` fn sum3(&self, key: &'static str) -> usize; /// Invokes `sum2_drop_sum` fn sum3_drop_sum(&self, key: &'static str) -> usize; #[salsa::cycle(crate::cycles::recover_from_cycle_a1)] #[salsa::invoke(crate::cycles::recover_cycle_a1)] fn recover_cycle_a1(&self, key: i32) -> i32; #[salsa::cycle(crate::cycles::recover_from_cycle_a2)] #[salsa::invoke(crate::cycles::recover_cycle_a2)] fn recover_cycle_a2(&self, key: i32) -> i32; #[salsa::cycle(crate::cycles::recover_from_cycle_b1)] #[salsa::invoke(crate::cycles::recover_cycle_b1)] fn recover_cycle_b1(&self, key: i32) -> i32; #[salsa::cycle(crate::cycles::recover_from_cycle_b2)] #[salsa::invoke(crate::cycles::recover_cycle_b2)] fn recover_cycle_b2(&self, key: i32) -> i32; #[salsa::invoke(crate::cycles::panic_cycle_a)] fn panic_cycle_a(&self, key: i32) -> i32; #[salsa::invoke(crate::cycles::panic_cycle_b)] fn panic_cycle_b(&self, key: i32) -> i32; } /// Various "knobs" and utilities used by tests to force /// a certain behavior. pub(crate) trait Knobs { fn knobs(&self) -> &KnobsStruct; fn signal(&self, stage: usize); fn wait_for(&self, stage: usize); } pub(crate) trait WithValue { fn with_value(&self, value: T, closure: impl FnOnce() -> R) -> R; } impl WithValue for Cell { fn with_value(&self, value: T, closure: impl FnOnce() -> R) -> R { let old_value = self.replace(value); let result = catch_unwind(AssertUnwindSafe(|| closure())); self.set(old_value); match result { Ok(r) => r, Err(payload) => resume_unwind(payload), } } } #[derive(Clone, Copy, PartialEq, Eq)] pub(crate) enum CancellationFlag { Down, Panic, } impl Default for CancellationFlag { fn default() -> CancellationFlag { CancellationFlag::Down } } /// Various "knobs" that can be used to customize how the queries /// behave on one specific thread. Note that this state is /// intentionally thread-local (apart from `signal`). #[derive(Clone, Default)] pub(crate) struct KnobsStruct { /// A kind of flexible barrier used to coordinate execution across /// threads to ensure we reach various weird states. pub(crate) signal: Arc, /// When this database is about to block, send a signal. pub(crate) signal_on_will_block: Cell, /// Invocations of `sum` will signal this stage on entry. pub(crate) sum_signal_on_entry: Cell, /// Invocations of `sum` will wait for this stage on entry. pub(crate) sum_wait_for_on_entry: Cell, /// If true, invocations of `sum` will panic before they exit. pub(crate) sum_should_panic: Cell, /// If true, invocations of `sum` will wait for cancellation before /// they exit. pub(crate) sum_wait_for_cancellation: Cell, /// Invocations of `sum` will wait for this stage prior to exiting. pub(crate) sum_wait_for_on_exit: Cell, /// Invocations of `sum` will signal this stage prior to exiting. pub(crate) sum_signal_on_exit: Cell, /// Invocations of `sum3_drop_sum` will panic unconditionally pub(crate) sum3_drop_sum_should_panic: Cell, } fn sum(db: &dyn ParDatabase, key: &'static str) -> usize { let mut sum = 0; db.signal(db.knobs().sum_signal_on_entry.get()); db.wait_for(db.knobs().sum_wait_for_on_entry.get()); if db.knobs().sum_should_panic.get() { panic!("query set to panic before exit") } for ch in key.chars() { sum += db.input(ch); } match db.knobs().sum_wait_for_cancellation.get() { CancellationFlag::Down => (), CancellationFlag::Panic => { log::debug!("waiting for cancellation"); loop { db.unwind_if_cancelled(); std::thread::yield_now(); } } } db.wait_for(db.knobs().sum_wait_for_on_exit.get()); db.signal(db.knobs().sum_signal_on_exit.get()); sum } fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize { db.sum(key) } fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize { let _ = db.sum(key); 22 } fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize { db.sum2(key) } fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize { if db.knobs().sum3_drop_sum_should_panic.get() { panic!("sum3_drop_sum executed") } db.sum2_drop_sum(key) } #[salsa::database(Par)] #[derive(Default)] pub(crate) struct ParDatabaseImpl { storage: salsa::Storage, knobs: KnobsStruct, } impl Database for ParDatabaseImpl { fn salsa_event(&self, event: salsa::Event) { if let salsa::EventKind::WillBlockOn { .. } = event.kind { self.signal(self.knobs().signal_on_will_block.get()); } } } impl ParallelDatabase for ParDatabaseImpl { fn snapshot(&self) -> Snapshot { Snapshot::new(ParDatabaseImpl { storage: self.storage.snapshot(), knobs: self.knobs.clone(), }) } } impl Knobs for ParDatabaseImpl { fn knobs(&self) -> &KnobsStruct { &self.knobs } fn signal(&self, stage: usize) { self.knobs.signal.signal(stage); } fn wait_for(&self, stage: usize) { self.knobs.signal.wait_for(stage); } }