diff --git a/src/derived/slot.rs b/src/derived/slot.rs index 1aadbd4a..8d980870 100644 --- a/src/derived/slot.rs +++ b/src/derived/slot.rs @@ -5,7 +5,6 @@ use crate::lru::LruIndex; use crate::lru::LruNode; use crate::plumbing::{DatabaseOps, QueryFunction}; use crate::revision::Revision; -use crate::runtime::cycle_participant::CycleParticipant; use crate::runtime::local_state::ActiveQueryGuard; use crate::runtime::local_state::QueryInputs; use crate::runtime::local_state::QueryRevisions; @@ -13,6 +12,7 @@ use crate::runtime::Runtime; use crate::runtime::RuntimeId; use crate::runtime::StampedValue; use crate::runtime::WaitResult; +use crate::Cycle; use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; use log::{debug, info}; use parking_lot::{RawRwLock, RwLock}; @@ -239,12 +239,29 @@ where // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! - let value = CycleParticipant::recover( - || Q::execute(db, self.key.clone()), - // If a recoverable cycle occurs, `Q::execute` will throw - // and this closure will be executed with the cycle information. - |cycle| Q::cycle_fallback(db, &cycle, &self.key), - ); + let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) { + Ok(v) => v, + Err(cycle) => { + log::debug!( + "{:?}: caught cycle {:?}, have strategy {:?}", + self.database_key_index.debug(db), + cycle, + Q::CYCLE_STRATEGY, + ); + match Q::CYCLE_STRATEGY { + crate::plumbing::CycleRecoveryStrategy::Panic => { + panic_guard.proceed(None); + cycle.throw() + } + crate::plumbing::CycleRecoveryStrategy::Fallback => { + if let Some(c) = active_query.take_cycle() { + assert!(c.is(&cycle)); + } + Q::cycle_fallback(db, &cycle, &self.key) + } + } + } + }; let mut revisions = active_query.pop(); diff --git a/src/lib.rs b/src/lib.rs index 84fbe9b3..a34df5ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,7 @@ use crate::plumbing::QueryStorageOps; pub use crate::revision::Revision; use std::fmt::{self, Debug}; use std::hash::Hash; +use std::panic::AssertUnwindSafe; use std::panic::{self, UnwindSafe}; use std::sync::Arc; @@ -645,7 +646,7 @@ impl std::fmt::Display for Cancelled { impl std::error::Error for Cancelled {} -/// Captuers the participants of a cycle that occurred when executing a query. +/// Captures the participants of a cycle that occurred when executing a query. /// /// This type is meant to be used to help give meaningful error messages to the /// user or to help salsa developers figure out why their program is resulting @@ -670,10 +671,26 @@ impl Cycle { Self { participants } } + /// True if two `Cycle` values represent the same cycle. + pub(crate) fn is(&self, cycle: &Cycle) -> bool { + Arc::ptr_eq(&self.participants, &cycle.participants) + } + pub(crate) fn throw(self) -> ! { + log::debug!("throwing cycle {:?}", self); std::panic::resume_unwind(Box::new(self)) } + pub(crate) fn catch(execute: impl FnOnce() -> T) -> Result { + match std::panic::catch_unwind(AssertUnwindSafe(execute)) { + Ok(v) => Ok(v), + Err(err) => match err.downcast::() { + Ok(cycle) => Err(*cycle), + Err(other) => std::panic::resume_unwind(other), + }, + } + } + /// Iterate over the [`DatabaseKeyIndex`] for each query participating /// in the cycle. The start point of this iteration within the cycle /// is arbitrary but deterministic, but the ordering is otherwise determined diff --git a/src/runtime.rs b/src/runtime.rs index a22e77a6..3d040f53 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -13,15 +13,12 @@ use std::sync::Arc; pub(crate) type FxIndexSet = indexmap::IndexSet>; pub(crate) type FxIndexMap = indexmap::IndexMap>; -pub(crate) mod cycle_participant; - mod dependency_graph; use dependency_graph::DependencyGraph; pub(crate) mod local_state; use local_state::LocalState; -use self::cycle_participant::CycleParticipant; use self::local_state::{ActiveQueryGuard, QueryInputs, QueryRevisions}; /// The salsa runtime stores the storage for all queries as well as @@ -269,13 +266,13 @@ impl Runtime { .report_synthetic_read(durability, changed_at); } - fn create_cycle_error( + fn throw_cycle_error( &self, db: &dyn Database, mut dg: MutexGuard<'_, DependencyGraph>, database_key_index: DatabaseKeyIndex, to_id: RuntimeId, - ) -> (CycleRecoveryStrategy, Cycle) { + ) -> ! { debug!("create_cycle_error(database_key={:?})", database_key_index); let mut from_stack = self.local_state.take_query_stack(); @@ -319,80 +316,32 @@ impl Runtime { cycle_query, ); - // Identify cycle recovery strategy: - let recovery_strategy = self.mutual_cycle_recovery_strategy(db, &cycle); - debug!("cycle recovery strategy {:?}", recovery_strategy); + // We can remove the cycle participants from the list of dependencies; + // they are a strongly connected component (SCC) and we only care about + // dependencies to things outside the SCC that control whether it will + // form again. + cycle_query.remove_cycle_participants(&cycle); - // If using fallback, we have to mark the cycle participants, so they know to recover. - // In this case, we also want to modify their changed-at and durability values to - // be the max/min we computed across the entire cycle. - match recovery_strategy { - CycleRecoveryStrategy::Panic => {} - CycleRecoveryStrategy::Fallback => { - // We don't need to include the cycle participants in the input. - // Everyone depends on all the external dependencies. - cycle_query.remove_cycle_participants(&cycle); + // Mark the cycle participants so they know to recover. + // This only matters for queries that have a fallback value specified; + // the others will just unwind without storing any recovery information. + dg.for_each_cycle_participant(from_id, &mut from_stack, database_key_index, to_id, |aq| { + match db.cycle_recovery_strategy(aq.database_key_index) { + CycleRecoveryStrategy::Fallback => { + debug!("marking {:?} for fallback", aq.database_key_index.debug(db)); + aq.take_inputs_from(&cycle_query); + aq.cycle = Some(cycle.clone()); + } - // Mark the cycle participants, so they know to recover: - dg.for_each_cycle_participant( - from_id, - &mut from_stack, - database_key_index, - to_id, - |aq| { - aq.take_inputs_from(&cycle_query); - aq.cycle = Some(cycle.clone()); - }, - ); - - // The top of the current stack is a bit of a special case: - // - // C0 --> ... --> Cn-1 -> Cn --> C0 - // ^ - // : - // This edge -------------+ - // - // Each query Ci in C0..=Cn-1 will recover the "usual" way: - // the query Ci+1 will recover from [`CycleParticipant::unwind`] - // and store a recovery value in the table. This value will - // be returned to Ci, which will invoke - // `report_query_read_and_unwind_if_cycle_resulted` to record - // the dependency Ci -> Ci+1. This method will see the `cycle` flag - // set on the `ActiveQuery` for `Ci` and will unwind. - // - // However, the cyclic edge `Cn -> C0` is a bit different: - // since C0 has not recovered yet, we don't have an easy value - // to return and propagate upwards. Instead, we just add the dependency - // *here* and then (in our caller) unwind with `CycleParticipant`. - from_stack.last_mut().unwrap().cycle = None; + CycleRecoveryStrategy::Panic => { + // NB: Don't mark these frames! + } } - } + }); self.local_state.restore_query_stack(from_stack); - (recovery_strategy, cycle) - } - - fn mutual_cycle_recovery_strategy( - &self, - db: &dyn Database, - cycle: &Cycle, - ) -> CycleRecoveryStrategy { - let participants = &cycle.participants; - let crs = db.cycle_recovery_strategy(participants[0]); - if let Some(key) = participants[1..] - .iter() - .copied() - .find(|&key| db.cycle_recovery_strategy(key) != crs) - { - debug!("mutual_cycle_recovery_strategy: cycle had multiple strategies ({:?} for {:?} vs {:?} for {:?})", - crs, participants[0], - db.cycle_recovery_strategy(key), key - ); - CycleRecoveryStrategy::Panic - } else { - crs - } + cycle.throw() } /// Block until `other_id` completes executing `database_key`; @@ -427,40 +376,37 @@ impl Runtime { let mut dg = self.shared_state.dependency_graph.lock(); if self.id() == other_id || dg.depends_on(other_id, self.id()) { - match self.create_cycle_error(db, dg, database_key, other_id) { - (CycleRecoveryStrategy::Panic, cycle) => cycle.throw(), - (CycleRecoveryStrategy::Fallback, cycle) => CycleParticipant::new(cycle).unwind(), - } - } else { - db.salsa_event(Event { - runtime_id: self.id(), - kind: EventKind::WillBlockOn { - other_runtime_id: other_id, - database_key, - }, - }); + self.throw_cycle_error(db, dg, database_key, other_id) + } - let stack = self.local_state.take_query_stack(); - - let (stack, result) = DependencyGraph::block_on( - dg, - self.id(), + db.salsa_event(Event { + runtime_id: self.id(), + kind: EventKind::WillBlockOn { + other_runtime_id: other_id, database_key, - other_id, - stack, - query_mutex_guard, - ); + }, + }); - self.local_state.restore_query_stack(stack); + let stack = self.local_state.take_query_stack(); - match result { - WaitResult::Completed => (), + let (stack, result) = DependencyGraph::block_on( + dg, + self.id(), + database_key, + other_id, + stack, + query_mutex_guard, + ); - // If the other thread panicked, then we consider this thread - // cancelled. The assumption is that the panic will be detected - // by the other thread and responded to appropriately. - WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), - } + self.local_state.restore_query_stack(stack); + + match result { + WaitResult::Completed => (), + + // If the other thread panicked, then we consider this thread + // cancelled. The assumption is that the panic will be detected + // by the other thread and responded to appropriately. + WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), } } diff --git a/src/runtime/cycle_participant.rs b/src/runtime/cycle_participant.rs deleted file mode 100644 index aaf910f9..00000000 --- a/src/runtime/cycle_participant.rs +++ /dev/null @@ -1,29 +0,0 @@ -use std::panic::AssertUnwindSafe; - -use crate::Cycle; - -pub(crate) struct CycleParticipant { - cycle: Cycle, -} - -impl CycleParticipant { - pub(crate) fn new(cycle: Cycle) -> Self { - Self { cycle } - } - - /// Initiate unwinding. This is called `unwind` and not `throw` or `panic` - /// because every call to `unwind` here ought to be caught by a - /// matching call to [`recover`]. - pub(crate) fn unwind(self) -> ! { - std::panic::resume_unwind(Box::new(self)); - } - - pub(crate) fn recover(execute: impl FnOnce() -> T, recover: impl FnOnce(Cycle) -> T) -> T { - std::panic::catch_unwind(AssertUnwindSafe(execute)).unwrap_or_else(|err| { - match err.downcast::() { - Ok(participant) => recover(participant.cycle), - Err(v) => std::panic::resume_unwind(v), - } - }) - } -} diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs index 2bf8478f..b1da8a0a 100644 --- a/src/runtime/local_state.rs +++ b/src/runtime/local_state.rs @@ -1,12 +1,13 @@ +use log::debug; + use crate::durability::Durability; use crate::runtime::ActiveQuery; use crate::runtime::Revision; +use crate::Cycle; use crate::DatabaseKeyIndex; use std::cell::RefCell; use std::sync::Arc; -use super::cycle_participant::CycleParticipant; - /// State that is specific to a single execution thread. /// /// Internally, this type uses ref-cells. @@ -97,6 +98,10 @@ impl LocalState { durability: Durability, changed_at: Revision, ) { + debug!( + "report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})", + input, durability, changed_at + ); self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { top_query.add_read(input, durability, changed_at); @@ -123,7 +128,7 @@ impl LocalState { // stack frames, so they will just read the fallback value // from `Ci+1` and continue on their merry way. if let Some(cycle) = top_query.cycle.take() { - CycleParticipant::new(cycle).unwind() + cycle.throw() } } }) @@ -211,6 +216,13 @@ impl ActiveQueryGuard<'_> { popped_query.revisions() } + + /// If the active query is registered as a cycle participant, remove and + /// return that cycle. + pub(crate) fn take_cycle(&self) -> Option { + self.local_state + .with_query_stack(|stack| stack.last_mut()?.cycle.take()) + } } impl Drop for ActiveQueryGuard<'_> { diff --git a/tests/cycles.rs b/tests/cycles.rs index eedc1566..d514b84d 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -172,6 +172,7 @@ fn cycle_c(db: &dyn Database) -> Result<(), Error> { db.c_invokes().invoke(db) } +#[track_caller] fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { let v = std::panic::catch_unwind(f); if let Err(d) = &v { @@ -352,20 +353,17 @@ fn cycle_mixed_1() { db.set_b_invokes(CycleQuery::C); db.set_c_invokes(CycleQuery::B); - let u = extract_cycle(|| { - let _ = db.cycle_a(); - }); - insta::assert_debug_snapshot!((u.all_participants(&db), u.unexpected_participants(&db)), @r###" - ( - [ - "cycle_b(())", - "cycle_c(())", - ], - [ - "cycle_c(())", - ], - ) - "###); + let u = db.cycle_c(); + insta::assert_debug_snapshot!(u, @r###" + Err( + Error { + cycle: [ + "cycle_b(())", + "cycle_c(())", + ], + }, + ) + "###); } #[test] @@ -381,21 +379,18 @@ fn cycle_mixed_2() { db.set_b_invokes(CycleQuery::C); db.set_c_invokes(CycleQuery::A); - let u = extract_cycle(|| { - let _ = db.cycle_a(); - }); - insta::assert_debug_snapshot!((u.all_participants(&db), u.unexpected_participants(&db)), @r###" - ( - [ - "cycle_a(())", - "cycle_b(())", - "cycle_c(())", - ], - [ - "cycle_c(())", - ], - ) - "###); + let u = db.cycle_a(); + insta::assert_debug_snapshot!(u, @r###" + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + "cycle_c(())", + ], + }, + ) + "###); } #[test] diff --git a/tests/parallel/cycles.rs b/tests/parallel/cycles.rs index 7b3823bb..6170fdb4 100644 --- a/tests/parallel/cycles.rs +++ b/tests/parallel/cycles.rs @@ -3,7 +3,7 @@ //! both intra and cross thread. use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl}; -use salsa::{Cancelled, ParallelDatabase}; +use salsa::ParallelDatabase; use test_env_log::test; // Recover cycle test: @@ -164,6 +164,6 @@ fn panic_parallel_cycle() { assert!(thread_a .join() .unwrap_err() - .downcast_ref::() + .downcast_ref::() .is_some()); }