diff --git a/components/salsa-macros/src/query_group.rs b/components/salsa-macros/src/query_group.rs index 7682a029..380b7bfd 100644 --- a/components/salsa-macros/src/query_group.rs +++ b/components/salsa-macros/src/query_group.rs @@ -39,6 +39,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream match item { TraitItem::Method(method) => { let mut storage = QueryStorage::Memoized; + let mut cycle = None; let mut invoke = None; let mut query_type = Ident::new( &format!("{}Query", method.sig.ident.to_string().to_camel_case()), @@ -66,6 +67,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream storage = QueryStorage::Interned; num_storages += 1; } + "cycle" => { + cycle = Some(parse_macro_input!(tts as Parenthesized).0); + } "invoke" => { invoke = Some(parse_macro_input!(tts as Parenthesized).0); } @@ -150,6 +154,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream keys: lookup_keys, value: lookup_value, invoke: None, + cycle: cycle.clone(), }) } else { None @@ -163,6 +168,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream keys, value, invoke, + cycle, }); queries.extend(lookup_query); @@ -354,9 +360,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream QueryStorage::Dependencies => quote!(salsa::plumbing::DependencyStorage<#db, Self>), QueryStorage::Input => quote!(salsa::plumbing::InputStorage<#db, Self>), QueryStorage::Interned => quote!(salsa::plumbing::InternedStorage<#db, Self>), - QueryStorage::InternedLookup { intern_query_type } => { - quote!(salsa::plumbing::LookupInternedStorage<#db, Self, #intern_query_type>) - } + QueryStorage::InternedLookup { intern_query_type } => quote!(salsa::plumbing::LookupInternedStorage<#db, Self, #intern_query_type>), QueryStorage::Transparent => continue, }; let keys = &query.keys; @@ -404,6 +408,22 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream quote! { (#(#key_names),*) } }; let invoke = query.invoke_tt(); + + let recover = if let Some(cycle_recovery_fn) = &query.cycle { + quote! { + fn recover(db: &DB, cycle: &[DB::DatabaseKey], #key_pattern: &>::Key) + -> Option<>::Value> { + Some(#cycle_recovery_fn( + db, + &cycle.iter().map(|k| format!("{:?}", k)).collect::>(), + #(#key_names),* + )) + } + } + } else { + quote! {} + }; + output.extend(quote_spanned! {span=> impl salsa::plumbing::QueryFunction for #qt where @@ -415,6 +435,8 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream -> >::Value { #invoke(db, #(#key_names),*) } + + #recover } }); } @@ -541,6 +563,7 @@ struct Query { keys: Vec, value: syn::Type, invoke: Option, + cycle: Option, } impl Query { diff --git a/src/derived.rs b/src/derived.rs index 768b78af..12214cf2 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -1,14 +1,13 @@ use crate::debug::TableEntry; use crate::durability::Durability; use crate::lru::Lru; -use crate::plumbing::CycleDetected; use crate::plumbing::HasQueryGroup; use crate::plumbing::LruQueryStorageOps; use crate::plumbing::QueryFunction; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; use crate::runtime::StampedValue; -use crate::{Database, SweepStrategy}; +use crate::{CycleError, Database, SweepStrategy}; use parking_lot::RwLock; use rustc_hash::FxHashMap; use std::marker::PhantomData; @@ -131,7 +130,7 @@ where DB: Database + HasQueryGroup, MP: MemoizationPolicy, { - fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result { + fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result> { let slot = self.slot(key); let StampedValue { value, diff --git a/src/derived/slot.rs b/src/derived/slot.rs index 76d2d553..d87f3bd6 100644 --- a/src/derived/slot.rs +++ b/src/derived/slot.rs @@ -14,7 +14,7 @@ use crate::runtime::FxIndexSet; use crate::runtime::Runtime; use crate::runtime::RuntimeId; use crate::runtime::StampedValue; -use crate::{Database, DiscardIf, DiscardWhat, Event, EventKind, SweepStrategy}; +use crate::{CycleError, Database, DiscardIf, DiscardWhat, Event, EventKind, SweepStrategy}; use log::{debug, info}; use parking_lot::Mutex; use parking_lot::RwLock; @@ -36,6 +36,12 @@ where lru_index: LruIndex, } +#[derive(Clone)] +struct WaitResult { + value: StampedValue, + cycle: Vec, +} + /// Defines the "current state" of query's memoized results. enum QueryState where @@ -49,7 +55,7 @@ where /// indeeds a cycle. InProgress { id: RuntimeId, - waiting: Mutex>; 2]>>, + waiting: Mutex>; 2]>>, }, /// We have computed the query already, and here is the result. @@ -95,8 +101,8 @@ pub(super) enum MemoInputs { } /// Return value of `probe` helper. -enum ProbeState { - UpToDate(Result), +enum ProbeState { + UpToDate(Result>), StaleOrAbsent(G), } @@ -119,7 +125,10 @@ where >::database_key(db, self.key.clone()) } - pub(super) fn read(&self, db: &DB) -> Result, CycleDetected> { + pub(super) fn read( + &self, + db: &DB, + ) -> Result, CycleError> { let runtime = db.salsa_runtime(); // NB: We don't need to worry about people modifying the @@ -148,7 +157,7 @@ where &self, db: &DB, revision_now: Revision, - ) -> Result, CycleDetected> { + ) -> Result, CycleError> { let runtime = db.salsa_runtime(); debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); @@ -189,7 +198,15 @@ where }, }); - panic_guard.proceed(&value); + panic_guard.proceed( + &value, + // The returned value could have been produced as part of a cycle but since + // we returned the memoized value we know we short-circuited the execution + // just as we entered the cycle. Therefore there is no values to invalidate + // and no need to call a cycle handler so we do not need to return the + // actual cycle + Vec::new(), + ); return Ok(value); } @@ -203,6 +220,21 @@ where Q::execute(db, self.key.clone()) }); + if !result.cycle.is_empty() { + result.value = match Q::recover(db, &result.cycle, &self.key) { + Some(v) => v, + None => { + let err = CycleError { + cycle: result.cycle, + durability: result.durability, + changed_at: result.changed_at, + }; + panic_guard.report_unexpected_cycle(); + return Err(err); + } + }; + } + // We assume that query is side-effect free -- that is, does // not mutate the "inputs" to the query system. Sanity check // that assumption here, at least to the best of our ability. @@ -277,7 +309,7 @@ where durability: result.durability, }); - panic_guard.proceed(&new_value); + panic_guard.proceed(&new_value, result.cycle); Ok(new_value) } @@ -308,7 +340,7 @@ where state: StateGuard, runtime: &Runtime, revision_now: Revision, - ) -> ProbeState, StateGuard> + ) -> ProbeState, DB::DatabaseKey, StateGuard> where StateGuard: Deref>, { @@ -331,11 +363,42 @@ where }, }); - let value = rx.recv().unwrap_or_else(|_| db.on_propagated_panic()); - ProbeState::UpToDate(Ok(value)) + let result = rx.recv().unwrap_or_else(|_| db.on_propagated_panic()); + ProbeState::UpToDate(if result.cycle.is_empty() { + Ok(result.value) + } else { + let err = CycleError { + cycle: result.cycle, + changed_at: result.value.changed_at, + durability: result.value.durability, + }; + runtime.mark_cycle_participants(&err); + Q::recover(db, &err.cycle, &self.key) + .map(|value| StampedValue { + value, + durability: err.durability, + changed_at: err.changed_at, + }) + .ok_or_else(|| err) + }) } - Err(CycleDetected) => ProbeState::UpToDate(Err(CycleDetected)), + Err(err) => { + let err = runtime.report_unexpected_cycle( + &self.database_key(db), + err, + revision_now, + ); + ProbeState::UpToDate( + Q::recover(db, &err.cycle, &self.key) + .map(|value| StampedValue { + value, + changed_at: err.changed_at, + durability: err.durability, + }) + .ok_or_else(|| err), + ) + } }; } @@ -483,13 +546,17 @@ where db: &DB, runtime: &Runtime, other_id: RuntimeId, - waiting: &Mutex>; 2]>>, - ) -> Result>, CycleDetected> { - if other_id == runtime.id() { - return Err(CycleDetected); + waiting: &Mutex>; 2]>>, + ) -> Result>, CycleDetected> { + let id = runtime.id(); + if other_id == id { + return Err(CycleDetected { from: id, to: id }); } else { if !runtime.try_block_on(&self.database_key(db), other_id) { - return Err(CycleDetected); + return Err(CycleDetected { + from: id, + to: other_id, + }); } let (tx, rx) = mpsc::channel(); @@ -555,15 +622,23 @@ where /// Proceed with our panic guard by overwriting the placeholder for `key`. /// Once that completes, ensure that our deconstructor is not run once we /// are out of scope. - fn proceed(mut self, new_value: &StampedValue) { - self.overwrite_placeholder(Some(new_value)); + fn proceed(mut self, new_value: &StampedValue, cycle: Vec) { + self.overwrite_placeholder(Some((new_value, cycle))); + std::mem::forget(self) + } + + fn report_unexpected_cycle(mut self) { + self.overwrite_placeholder(None); std::mem::forget(self) } /// Overwrites the `InProgress` placeholder for `key` that we /// inserted; if others were blocked, waiting for us to finish, /// then notify them. - fn overwrite_placeholder(&mut self, new_value: Option<&StampedValue>) { + fn overwrite_placeholder( + &mut self, + new_value: Option<(&StampedValue, Vec)>, + ) { let mut write = self.slot.state.write(); let old_value = match self.memo.take() { @@ -587,9 +662,13 @@ where match new_value { // If anybody has installed themselves in our "waiting" // list, notify them that the value is available. - Some(new_value) => { + Some((new_value, ref cycle)) => { for tx in waiting.into_inner() { - tx.send(new_value.clone()).unwrap() + tx.send(WaitResult { + value: new_value.clone(), + cycle: cycle.clone(), + }) + .unwrap(); } } @@ -811,12 +890,12 @@ where // can complete. std::mem::drop(state); - let value = rx.recv().unwrap_or_else(|_| db.on_propagated_panic()); - return value.changed_at > revision; + let result = rx.recv().unwrap_or_else(|_| db.on_propagated_panic()); + return !result.cycle.is_empty() || result.value.changed_at > revision; } // Consider a cycle to have changed. - Err(CycleDetected) => return true, + Err(_) => return true, } } @@ -873,14 +952,14 @@ where return match self.read_upgrade(db, revision_now) { Ok(v) => { debug!( - "maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}", - self, + "maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}", + self, v.changed_at > revision, - v.changed_at, - ); + v.changed_at, + ); v.changed_at > revision } - Err(CycleDetected) => true, + Err(_) => true, }; } @@ -968,6 +1047,7 @@ where DB: Database + HasQueryGroup, MP: MemoizationPolicy, DB::DatabaseData: Send + Sync, + DB::DatabaseKey: Send + Sync, Q::Key: Send + Sync, Q::Value: Send + Sync, { diff --git a/src/doctest.rs b/src/doctest.rs index ff8de9a1..9d4e020d 100644 --- a/src/doctest.rs +++ b/src/doctest.rs @@ -45,6 +45,7 @@ fn test_key_not_send_db_not_send() {} /// /// ```compile_fail,E0277 /// use std::rc::Rc; +/// use std::cell::Cell; /// /// #[salsa::query_group(NoSendSyncStorage)] /// trait NoSendSyncDatabase: salsa::Database { diff --git a/src/input.rs b/src/input.rs index cb60fa71..ed2fcb01 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,12 +1,12 @@ use crate::debug::TableEntry; use crate::dependency::DatabaseSlot; use crate::durability::Durability; -use crate::plumbing::CycleDetected; use crate::plumbing::InputQueryStorageOps; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; use crate::revision::Revision; use crate::runtime::StampedValue; +use crate::CycleError; use crate::Database; use crate::Event; use crate::EventKind; @@ -74,10 +74,10 @@ where Q: Query, DB: Database, { - fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result { - let slot = self.slot(key).unwrap_or_else(|| { - panic!("no value set for {:?}({:?})", Q::default(), key) - }); + fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result> { + let slot = self + .slot(key) + .unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key)); let StampedValue { value, diff --git a/src/interned.rs b/src/interned.rs index 26c76497..aaffdc77 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -2,13 +2,12 @@ use crate::debug::TableEntry; use crate::dependency::DatabaseSlot; use crate::durability::Durability; use crate::intern_id::InternId; -use crate::plumbing::CycleDetected; use crate::plumbing::HasQueryGroup; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; use crate::revision::Revision; use crate::Query; -use crate::{Database, DiscardIf, SweepStrategy}; +use crate::{CycleError, Database, DiscardIf, SweepStrategy}; use crossbeam::atomic::AtomicCell; use parking_lot::RwLock; use rustc_hash::FxHashMap; @@ -321,7 +320,7 @@ where Q::Value: InternKey, DB: Database, { - fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result { + fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result> { let slot = self.intern_index(db, key); let changed_at = slot.interned_at; let index = slot.index; @@ -420,7 +419,7 @@ where >, DB: Database + HasQueryGroup, { - fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result { + fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result> { let index = key.as_intern_id(); let group_storage = >::group_storage(db); let interned_storage = IQ::query_storage(group_storage); diff --git a/src/lib.rs b/src/lib.rs index a9577fa3..4a0d79c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,11 +24,11 @@ pub mod debug; #[doc(hidden)] pub mod plumbing; -use crate::plumbing::CycleDetected; use crate::plumbing::InputQueryStorageOps; use crate::plumbing::LruQueryStorageOps; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; +use crate::revision::Revision; use derive_new::new; use std::fmt::{self, Debug}; use std::hash::Hash; @@ -468,14 +468,11 @@ where /// queries (those with no inputs, or those with more than one /// input) the key will be a tuple. pub fn get(&self, key: Q::Key) -> Q::Value { - self.storage - .try_fetch(self.db, &key) - .unwrap_or_else(|CycleDetected| { - let database_key = self.database_key(&key); - self.db - .salsa_runtime() - .report_unexpected_cycle(database_key) - }) + self.try_get(key).unwrap_or_else(|err| panic!("{}", err)) + } + + fn try_get(&self, key: Q::Key) -> Result> { + self.storage.try_fetch(self.db, &key) } /// Remove all values for this query that have not been used in @@ -486,10 +483,6 @@ where { self.storage.sweep(self.db, strategy); } - - fn database_key(&self, key: &Q::Key) -> DB::DatabaseKey { - >::database_key(&self.db, key.clone()) - } } /// Return value from [the `query_mut` method] on `Database`. @@ -561,6 +554,28 @@ where } } +/// The error returned when a query could not be resolved due to a cycle +#[derive(Eq, PartialEq, Clone, Debug)] +pub struct CycleError { + /// The queries that were part of the cycle + cycle: Vec, + changed_at: Revision, + durability: Durability, +} + +impl fmt::Display for CycleError +where + K: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Internal error, cycle detected:\n")?; + for i in &self.cycle { + writeln!(f, "{:?}", i)?; + } + Ok(()) + } +} + // Re-export the procedural macros. #[allow(unused_imports)] #[macro_use] diff --git a/src/plumbing.rs b/src/plumbing.rs index d4dd8bbf..b18c239b 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -2,10 +2,12 @@ use crate::debug::TableEntry; use crate::durability::Durability; +use crate::CycleError; use crate::Database; use crate::Query; use crate::QueryTable; use crate::QueryTableMut; +use crate::RuntimeId; use crate::SweepStrategy; use std::fmt::Debug; use std::hash::Hash; @@ -17,7 +19,11 @@ pub use crate::interned::InternedStorage; pub use crate::interned::LookupInternedStorage; pub use crate::revision::Revision; -pub struct CycleDetected; +#[derive(Clone, Debug)] +pub struct CycleDetected { + pub(crate) from: RuntimeId, + pub(crate) to: RuntimeId, +} /// Defines various associated types. An impl of this /// should be generated for your query-context type automatically by @@ -61,6 +67,10 @@ pub trait DatabaseKey: Clone + Debug + Eq + Hash {} pub trait QueryFunction: Query { fn execute(db: &DB, key: Self::Key) -> Self::Value; + fn recover(db: &DB, cycle: &[DB::DatabaseKey], key: &Self::Key) -> Option { + let _ = (db, cycle, key); + None + } } /// The `GetQueryTable` trait makes the connection the *database type* @@ -146,7 +156,7 @@ where /// Returns `Err` in the event of a cycle, meaning that computing /// the value for this `key` is recursively attempting to fetch /// itself. - fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result; + fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result>; /// Returns the durability associated with a given key. fn durability(&self, db: &DB, key: &Q::Key) -> Durability; diff --git a/src/runtime.rs b/src/runtime.rs index 06f2ced4..f6d521ae 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,15 +1,15 @@ use crate::dependency::DatabaseSlot; use crate::dependency::Dependency; use crate::durability::Durability; +use crate::plumbing::CycleDetected; use crate::revision::{AtomicRevision, Revision}; -use crate::{Database, Event, EventKind, SweepStrategy}; +use crate::{CycleError, Database, Event, EventKind, SweepStrategy}; use log::debug; -use parking_lot::{Mutex, RwLock}; use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive}; +use parking_lot::{Mutex, RwLock}; use rustc_hash::{FxHashMap, FxHasher}; use smallvec::SmallVec; -use std::fmt::Write; -use std::hash::BuildHasherDefault; +use std::hash::{BuildHasherDefault, Hash}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; @@ -355,6 +355,7 @@ where dependencies, changed_at, durability, + cycle, .. } = active_query.complete(); @@ -363,6 +364,7 @@ where durability, changed_at, dependencies, + cycle, } } @@ -407,30 +409,100 @@ where } /// Obviously, this should be user configurable at some point. - pub(crate) fn report_unexpected_cycle(&self, database_key: DB::DatabaseKey) -> ! { + pub(crate) fn report_unexpected_cycle( + &self, + database_key: &DB::DatabaseKey, + error: CycleDetected, + changed_at: Revision, + ) -> crate::CycleError { debug!("report_unexpected_cycle(database_key={:?})", database_key); - let query_stack = self.local_state.borrow_query_stack(); - let start_index = (0..query_stack.len()) - .rev() - .filter(|&i| query_stack[i].database_key == database_key) - .next() - .unwrap(); + let mut query_stack = self.local_state.borrow_query_stack_mut(); - let mut message = format!("Internal error, cycle detected:\n"); - for active_query in &query_stack[start_index..] { - writeln!(message, "- {:?}\n", active_query.database_key).unwrap(); + if error.from == error.to { + // All queries in the cycle is local + let start_index = query_stack + .iter() + .rposition(|active_query| active_query.database_key == *database_key) + .unwrap(); + let mut cycle = Vec::new(); + let cycle_participants = &mut query_stack[start_index..]; + for active_query in &mut *cycle_participants { + cycle.push(active_query.database_key.clone()); + } + + assert!(!cycle.is_empty()); + + for active_query in cycle_participants { + active_query.cycle = cycle.clone(); + } + + crate::CycleError { + cycle, + changed_at, + durability: Durability::MAX, + } + } else { + // Part of the cycle is on another thread so we need to lock and inspect the shared + // state + let dependency_graph = self.shared_state.dependency_graph.lock(); + + let mut cycle = Vec::new(); + { + let cycle_iter = dependency_graph + .get_cycle_path( + database_key, + error.to, + query_stack.iter().map(|query| &query.database_key), + ) + .chain(Some(database_key)); + + for key in cycle_iter { + cycle.push(key.clone()); + } + } + + assert!(!cycle.is_empty()); + + for active_query in query_stack + .iter_mut() + .filter(|query| cycle.iter().any(|key| *key == query.database_key)) + { + active_query.cycle = cycle.clone(); + } + + crate::CycleError { + cycle, + changed_at, + durability: Durability::MAX, + } + } + } + + pub(crate) fn mark_cycle_participants(&self, err: &CycleError) { + for active_query in self + .local_state + .borrow_query_stack_mut() + .iter_mut() + .rev() + .take_while(|active_query| err.cycle.iter().any(|e| *e == active_query.database_key)) + { + active_query.cycle = err.cycle.clone(); } - panic!(message) } /// Try to make this runtime blocked on `other_id`. Returns true /// upon success or false if `other_id` is already blocked on us. pub(crate) fn try_block_on(&self, database_key: &DB::DatabaseKey, other_id: RuntimeId) -> bool { - self.shared_state - .dependency_graph - .lock() - .add_edge(self.id(), database_key, other_id) + self.shared_state.dependency_graph.lock().add_edge( + self.id(), + database_key, + other_id, + self.local_state + .borrow_query_stack() + .iter() + .map(|query| query.database_key.clone()), + ) } pub(crate) fn unblock_queries_blocked_on_self(&self, database_key: &DB::DatabaseKey) { @@ -508,7 +580,7 @@ struct SharedState { /// The dependency graph tracks which runtimes are blocked on one /// another, waiting for queries to terminate. - dependency_graph: Mutex>, + dependency_graph: Mutex>, } impl SharedState { @@ -571,6 +643,9 @@ struct ActiveQuery { /// Set of subqueries that were accessed thus far, or `None` if /// there was an untracked the read. dependencies: Option>>, + + /// Stores the entire cycle, if one is found and this query is part of it. + cycle: Vec, } pub(crate) struct ComputedQueryResult { @@ -587,6 +662,9 @@ pub(crate) struct ComputedQueryResult { /// Complete set of subqueries that were accessed, or `None` if /// there was an untracked the read. pub(crate) dependencies: Option>>, + + /// The cycle if one occured while computing this value + pub(crate) cycle: Vec, } impl ActiveQuery { @@ -596,6 +674,7 @@ impl ActiveQuery { durability: max_durability, changed_at: Revision::start(), dependencies: Some(FxIndexSet::default()), + cycle: Vec::new(), } } @@ -634,16 +713,26 @@ pub(crate) struct StampedValue { pub(crate) changed_at: Revision, } -struct DependencyGraph { +#[derive(Debug)] +struct Edge { + id: RuntimeId, + path: Vec, +} + +#[derive(Debug)] +struct DependencyGraph { /// A `(K -> V)` pair in this map indicates that the the runtime /// `K` is blocked on some query executing in the runtime `V`. /// This encodes a graph that must be acyclic (or else deadlock /// will result). - edges: FxHashMap, - labels: FxHashMap>, + edges: FxHashMap>, + labels: FxHashMap>, } -impl Default for DependencyGraph { +impl Default for DependencyGraph +where + K: Hash + Eq, +{ fn default() -> Self { DependencyGraph { edges: Default::default(), @@ -652,13 +741,17 @@ impl Default for DependencyGraph { } } -impl DependencyGraph { +impl DependencyGraph +where + K: Hash + Eq + Clone, +{ /// Attempt to add an edge `from_id -> to_id` into the result graph. fn add_edge( &mut self, from_id: RuntimeId, - database_key: &DB::DatabaseKey, + database_key: &K, to_id: RuntimeId, + path: impl IntoIterator, ) -> bool { assert_ne!(from_id, to_id); debug_assert!(!self.edges.contains_key(&from_id)); @@ -666,7 +759,7 @@ impl DependencyGraph { // First: walk the chain of things that `to_id` depends on, // looking for us. let mut p = to_id; - while let Some(&q) = self.edges.get(&p) { + while let Some(q) = self.edges.get(&p).map(|edge| edge.id) { if q == from_id { return false; } @@ -674,7 +767,13 @@ impl DependencyGraph { p = q; } - self.edges.insert(from_id, to_id); + self.edges.insert( + from_id, + Edge { + id: to_id, + path: path.into_iter().chain(Some(database_key.clone())).collect(), + }, + ); self.labels .entry(database_key.clone()) .or_default() @@ -682,17 +781,54 @@ impl DependencyGraph { true } - fn remove_edge(&mut self, database_key: &DB::DatabaseKey, to_id: RuntimeId) { - let vec = self - .labels - .remove(database_key) - .unwrap_or_default(); + fn remove_edge(&mut self, database_key: &K, to_id: RuntimeId) { + let vec = self.labels.remove(database_key).unwrap_or_default(); for from_id in &vec { - let to_id1 = self.edges.remove(from_id); + let to_id1 = self.edges.remove(from_id).map(|edge| edge.id); assert_eq!(Some(to_id), to_id1); } } + + fn get_cycle_path<'a>( + &'a self, + database_key: &'a K, + to: RuntimeId, + local_path: impl IntoIterator, + ) -> impl Iterator + where + K: std::fmt::Debug, + { + let mut current = Some((to, std::slice::from_ref(database_key))); + let mut last = None; + let mut local_path = Some(local_path); + std::iter::from_fn(move || match current.take() { + Some((id, path)) => { + let link_key = path.last().unwrap(); + + current = self.edges.get(&id).map(|edge| { + let i = edge.path.iter().rposition(|p| p == link_key).unwrap(); + (edge.id, &edge.path[i + 1..]) + }); + + if current.is_none() { + last = local_path.take().map(|local_path| { + local_path + .into_iter() + .skip_while(move |p| *p != link_key) + .skip(1) + }); + } + + Some(path) + } + None => match &mut last { + Some(iter) => iter.next().map(std::slice::from_ref), + None => None, + }, + }) + .flat_map(|x| x) + } } struct RevisionGuard { @@ -742,3 +878,42 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dependency_graph_path1() { + let mut graph = DependencyGraph::default(); + let a = RuntimeId { counter: 0 }; + let b = RuntimeId { counter: 1 }; + assert!(graph.add_edge(a, &2, b, vec![1])); + // assert!(graph.add_edge(b, &1, a, vec![3, 2])); + assert_eq!( + graph + .get_cycle_path(&1, a, &[3, 2][..]) + .cloned() + .collect::>(), + vec![1, 2] + ); + } + + #[test] + fn dependency_graph_path2() { + let mut graph = DependencyGraph::default(); + let a = RuntimeId { counter: 0 }; + let b = RuntimeId { counter: 1 }; + let c = RuntimeId { counter: 2 }; + assert!(graph.add_edge(a, &3, b, vec![1])); + assert!(graph.add_edge(b, &4, c, vec![2, 3])); + // assert!(graph.add_edge(c, &1, a, vec![5, 6, 4, 7])); + assert_eq!( + graph + .get_cycle_path(&1, a, &[5, 6, 4, 7][..]) + .cloned() + .collect::>(), + vec![1, 3, 4, 7] + ); + } +} diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs index be148999..1a81f247 100644 --- a/src/runtime/local_state.rs +++ b/src/runtime/local_state.rs @@ -3,8 +3,7 @@ use crate::durability::Durability; use crate::runtime::ActiveQuery; use crate::runtime::Revision; use crate::Database; -use std::cell::Ref; -use std::cell::RefCell; +use std::cell::{Ref, RefCell, RefMut}; /// State that is specific to a single execution thread. /// @@ -51,6 +50,10 @@ impl LocalState { self.query_stack.borrow() } + pub(super) fn borrow_query_stack_mut(&self) -> RefMut<'_, Vec>> { + self.query_stack.borrow_mut() + } + pub(super) fn query_in_progress(&self) -> bool { !self.query_stack.borrow().is_empty() } diff --git a/tests/cycles.rs b/tests/cycles.rs index 77095aa4..a07134dc 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -1,3 +1,10 @@ +use salsa::{ParallelDatabase, Snapshot}; + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +struct Error { + cycle: Vec, +} + #[salsa::database(GroupStruct)] #[derive(Default)] struct DatabaseImpl { @@ -10,6 +17,14 @@ impl salsa::Database for DatabaseImpl { } } +impl ParallelDatabase for DatabaseImpl { + fn snapshot(&self) -> Snapshot { + Snapshot::new(DatabaseImpl { + runtime: self.runtime.snapshot(self), + }) + } +} + #[salsa::query_group(GroupStruct)] trait Database: salsa::Database { // `a` and `b` depend on each other and form a cycle @@ -17,6 +32,27 @@ trait Database: salsa::Database { fn memoized_b(&self) -> (); fn volatile_a(&self) -> (); fn volatile_b(&self) -> (); + + fn cycle_leaf(&self) -> (); + + #[salsa::cycle(recover_a)] + fn cycle_a(&self) -> Result<(), Error>; + #[salsa::cycle(recover_b)] + fn cycle_b(&self) -> Result<(), Error>; + + fn cycle_c(&self) -> Result<(), Error>; +} + +fn recover_a(_db: &impl Database, cycle: &[String]) -> Result<(), Error> { + Err(Error { + cycle: cycle.to_owned(), + }) +} + +fn recover_b(_db: &impl Database, cycle: &[String]) -> Result<(), Error> { + Err(Error { + cycle: cycle.to_owned(), + }) } fn memoized_a(db: &impl Database) -> () { @@ -37,6 +73,23 @@ fn volatile_b(db: &impl Database) -> () { db.volatile_a() } +fn cycle_leaf(_db: &impl Database) -> () {} + +fn cycle_a(db: &impl Database) -> Result<(), Error> { + let _ = db.cycle_b(); + Ok(()) +} + +fn cycle_b(db: &impl Database) -> Result<(), Error> { + db.cycle_leaf(); + let _ = db.cycle_a(); + Ok(()) +} + +fn cycle_c(db: &impl Database) -> Result<(), Error> { + db.cycle_b() +} + #[test] #[should_panic(expected = "cycle detected")] fn cycle_memoized() { @@ -50,3 +103,64 @@ fn cycle_volatile() { let query = DatabaseImpl::default(); query.volatile_a(); } + +#[test] +fn cycle_cycle() { + let query = DatabaseImpl::default(); + assert!(query.cycle_a().is_err()); +} + +#[test] +fn inner_cycle() { + let query = DatabaseImpl::default(); + let err = query.cycle_c(); + assert!(err.is_err()); + let cycle = err.unwrap_err().cycle; + assert!( + cycle + .iter() + .zip(&["cycle_b", "cycle_a"]) + .all(|(l, r)| l.contains(r)), + "{:#?}", + cycle + ); +} + +#[test] +fn parallel_cycle() { + let _ = env_logger::try_init(); + + let db = DatabaseImpl::default(); + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || { + let result = db.cycle_a(); + assert!(result.is_err(), "Expected cycle error"); + let cycle = result.unwrap_err().cycle; + assert!( + cycle + .iter() + .all(|l| ["cycle_b", "cycle_a"].iter().any(|r| l.contains(r))), + "{:#?}", + cycle + ); + } + }); + + let thread2 = std::thread::spawn(move || { + let result = db.cycle_c(); + assert!(result.is_err(), "Expected cycle error"); + let cycle = result.unwrap_err().cycle; + assert!( + cycle + .iter() + .all(|l| ["cycle_b", "cycle_a"].iter().any(|r| l.contains(r))), + "{:#?}", + cycle + ); + }); + + thread1.join().unwrap(); + thread2.join().unwrap(); + eprintln!("OK"); +} diff --git a/tests/parallel/cancellation.rs b/tests/parallel/cancellation.rs index 2e45f74c..2cfe6bba 100644 --- a/tests/parallel/cancellation.rs +++ b/tests/parallel/cancellation.rs @@ -1,6 +1,4 @@ -use crate::setup::{ - CancelationFlag, Canceled, Knobs, ParDatabase, ParDatabaseImpl, WithValue, -}; +use crate::setup::{CancelationFlag, Canceled, Knobs, ParDatabase, ParDatabaseImpl, WithValue}; use salsa::ParallelDatabase; macro_rules! assert_canceled { diff --git a/tests/parallel/independent.rs b/tests/parallel/independent.rs index 5ca76101..8f099725 100644 --- a/tests/parallel/independent.rs +++ b/tests/parallel/independent.rs @@ -22,7 +22,7 @@ fn in_par_two_independent_queries() { let thread2 = std::thread::spawn({ let db = db.snapshot(); move || db.sum("def") - });; + }); assert_eq!(thread1.join().unwrap(), 111); assert_eq!(thread2.join().unwrap(), 222); diff --git a/tests/requires.rs b/tests/requires.rs index ea4502a5..2737aba9 100644 --- a/tests/requires.rs +++ b/tests/requires.rs @@ -1,7 +1,6 @@ //! Test `salsa::requires` attribute for private query dependencies //! https://github.com/salsa-rs/salsa-rfcs/pull/3 - mod queries { #[salsa::query_group(InputGroupStorage)] pub trait InputGroup { @@ -14,7 +13,7 @@ mod queries { fn private_a(&self, x: u32) -> u32; } - fn private_a(db: &impl PrivGroupA, x: u32) -> u32{ + fn private_a(db: &impl PrivGroupA, x: u32) -> u32 { db.input(x) } @@ -23,7 +22,7 @@ mod queries { fn private_b(&self, x: u32) -> u32; } - fn private_b(db: &impl PrivGroupB, x: u32) -> u32{ + fn private_b(db: &impl PrivGroupB, x: u32) -> u32 { db.input(x) } @@ -34,7 +33,6 @@ mod queries { fn public(&self, x: u32) -> u32; } - fn public(db: &(impl PubGroup + PrivGroupA + PrivGroupB), x: u32) -> u32 { db.private_a(x) + db.private_b(x) } @@ -44,7 +42,7 @@ mod queries { queries::InputGroupStorage, queries::PrivGroupAStorage, queries::PrivGroupBStorage, - queries::PubGroupStorage, + queries::PubGroupStorage )] #[derive(Default)] struct Database { diff --git a/tests/transparent.rs b/tests/transparent.rs index 025ec03a..d9e342b4 100644 --- a/tests/transparent.rs +++ b/tests/transparent.rs @@ -17,7 +17,6 @@ fn get(db: &impl QueryGroup, x: u32) -> u32 { db.wrap(x) } - #[salsa::database(QueryGroupStorage)] #[derive(Default)] struct Database {