Merge pull request #147 from Marwes/cycles

feat: Allow queries to avoid panics on cycles
This commit is contained in:
Niko Matsakis 2019-09-19 05:56:39 -04:00 committed by GitHub
commit a9860bf37f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 520 additions and 106 deletions

View file

@ -39,6 +39,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
match item { match item {
TraitItem::Method(method) => { TraitItem::Method(method) => {
let mut storage = QueryStorage::Memoized; let mut storage = QueryStorage::Memoized;
let mut cycle = None;
let mut invoke = None; let mut invoke = None;
let mut query_type = Ident::new( let mut query_type = Ident::new(
&format!("{}Query", method.sig.ident.to_string().to_camel_case()), &format!("{}Query", method.sig.ident.to_string().to_camel_case()),
@ -66,6 +67,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
storage = QueryStorage::Interned; storage = QueryStorage::Interned;
num_storages += 1; num_storages += 1;
} }
"cycle" => {
cycle = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
"invoke" => { "invoke" => {
invoke = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0); invoke = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
} }
@ -150,6 +154,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
keys: lookup_keys, keys: lookup_keys,
value: lookup_value, value: lookup_value,
invoke: None, invoke: None,
cycle: cycle.clone(),
}) })
} else { } else {
None None
@ -163,6 +168,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
keys, keys,
value, value,
invoke, invoke,
cycle,
}); });
queries.extend(lookup_query); queries.extend(lookup_query);
@ -354,9 +360,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
QueryStorage::Dependencies => quote!(salsa::plumbing::DependencyStorage<#db, Self>), QueryStorage::Dependencies => quote!(salsa::plumbing::DependencyStorage<#db, Self>),
QueryStorage::Input => quote!(salsa::plumbing::InputStorage<#db, Self>), QueryStorage::Input => quote!(salsa::plumbing::InputStorage<#db, Self>),
QueryStorage::Interned => quote!(salsa::plumbing::InternedStorage<#db, Self>), QueryStorage::Interned => quote!(salsa::plumbing::InternedStorage<#db, Self>),
QueryStorage::InternedLookup { intern_query_type } => { QueryStorage::InternedLookup { intern_query_type } => quote!(salsa::plumbing::LookupInternedStorage<#db, Self, #intern_query_type>),
quote!(salsa::plumbing::LookupInternedStorage<#db, Self, #intern_query_type>)
}
QueryStorage::Transparent => continue, QueryStorage::Transparent => continue,
}; };
let keys = &query.keys; let keys = &query.keys;
@ -404,6 +408,22 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
quote! { (#(#key_names),*) } quote! { (#(#key_names),*) }
}; };
let invoke = query.invoke_tt(); let invoke = query.invoke_tt();
let recover = if let Some(cycle_recovery_fn) = &query.cycle {
quote! {
fn recover(db: &DB, cycle: &[DB::DatabaseKey], #key_pattern: &<Self as salsa::Query<DB>>::Key)
-> Option<<Self as salsa::Query<DB>>::Value> {
Some(#cycle_recovery_fn(
db,
&cycle.iter().map(|k| format!("{:?}", k)).collect::<Vec<String>>(),
#(#key_names),*
))
}
}
} else {
quote! {}
};
output.extend(quote_spanned! {span=> output.extend(quote_spanned! {span=>
impl<DB> salsa::plumbing::QueryFunction<DB> for #qt impl<DB> salsa::plumbing::QueryFunction<DB> for #qt
where where
@ -415,6 +435,8 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
-> <Self as salsa::Query<DB>>::Value { -> <Self as salsa::Query<DB>>::Value {
#invoke(db, #(#key_names),*) #invoke(db, #(#key_names),*)
} }
#recover
} }
}); });
} }
@ -541,6 +563,7 @@ struct Query {
keys: Vec<syn::Type>, keys: Vec<syn::Type>,
value: syn::Type, value: syn::Type,
invoke: Option<syn::Path>, invoke: Option<syn::Path>,
cycle: Option<syn::Path>,
} }
impl Query { impl Query {

View file

@ -1,14 +1,13 @@
use crate::debug::TableEntry; use crate::debug::TableEntry;
use crate::durability::Durability; use crate::durability::Durability;
use crate::lru::Lru; use crate::lru::Lru;
use crate::plumbing::CycleDetected;
use crate::plumbing::HasQueryGroup; use crate::plumbing::HasQueryGroup;
use crate::plumbing::LruQueryStorageOps; use crate::plumbing::LruQueryStorageOps;
use crate::plumbing::QueryFunction; use crate::plumbing::QueryFunction;
use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps; use crate::plumbing::QueryStorageOps;
use crate::runtime::StampedValue; use crate::runtime::StampedValue;
use crate::{Database, SweepStrategy}; use crate::{CycleError, Database, SweepStrategy};
use parking_lot::RwLock; use parking_lot::RwLock;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use std::marker::PhantomData; use std::marker::PhantomData;
@ -131,7 +130,7 @@ where
DB: Database + HasQueryGroup<Q::Group>, DB: Database + HasQueryGroup<Q::Group>,
MP: MemoizationPolicy<DB, Q>, MP: MemoizationPolicy<DB, Q>,
{ {
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected> { fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
let slot = self.slot(key); let slot = self.slot(key);
let StampedValue { let StampedValue {
value, value,

View file

@ -14,7 +14,7 @@ use crate::runtime::FxIndexSet;
use crate::runtime::Runtime; use crate::runtime::Runtime;
use crate::runtime::RuntimeId; use crate::runtime::RuntimeId;
use crate::runtime::StampedValue; use crate::runtime::StampedValue;
use crate::{Database, DiscardIf, DiscardWhat, Event, EventKind, SweepStrategy}; use crate::{CycleError, Database, DiscardIf, DiscardWhat, Event, EventKind, SweepStrategy};
use log::{debug, info}; use log::{debug, info};
use parking_lot::Mutex; use parking_lot::Mutex;
use parking_lot::RwLock; use parking_lot::RwLock;
@ -36,6 +36,12 @@ where
lru_index: LruIndex, lru_index: LruIndex,
} }
#[derive(Clone)]
struct WaitResult<V, K> {
value: StampedValue<V>,
cycle: Vec<K>,
}
/// Defines the "current state" of query's memoized results. /// Defines the "current state" of query's memoized results.
enum QueryState<DB, Q> enum QueryState<DB, Q>
where where
@ -49,7 +55,7 @@ where
/// indeeds a cycle. /// indeeds a cycle.
InProgress { InProgress {
id: RuntimeId, id: RuntimeId,
waiting: Mutex<SmallVec<[Sender<StampedValue<Q::Value>>; 2]>>, waiting: Mutex<SmallVec<[Sender<WaitResult<Q::Value, DB::DatabaseKey>>; 2]>>,
}, },
/// We have computed the query already, and here is the result. /// We have computed the query already, and here is the result.
@ -95,8 +101,8 @@ pub(super) enum MemoInputs<DB: Database> {
} }
/// Return value of `probe` helper. /// Return value of `probe` helper.
enum ProbeState<V, G> { enum ProbeState<V, K, G> {
UpToDate(Result<V, CycleDetected>), UpToDate(Result<V, CycleError<K>>),
StaleOrAbsent(G), StaleOrAbsent(G),
} }
@ -119,7 +125,10 @@ where
<DB as GetQueryTable<Q>>::database_key(db, self.key.clone()) <DB as GetQueryTable<Q>>::database_key(db, self.key.clone())
} }
pub(super) fn read(&self, db: &DB) -> Result<StampedValue<Q::Value>, CycleDetected> { pub(super) fn read(
&self,
db: &DB,
) -> Result<StampedValue<Q::Value>, CycleError<DB::DatabaseKey>> {
let runtime = db.salsa_runtime(); let runtime = db.salsa_runtime();
// NB: We don't need to worry about people modifying the // NB: We don't need to worry about people modifying the
@ -148,7 +157,7 @@ where
&self, &self,
db: &DB, db: &DB,
revision_now: Revision, revision_now: Revision,
) -> Result<StampedValue<Q::Value>, CycleDetected> { ) -> Result<StampedValue<Q::Value>, CycleError<DB::DatabaseKey>> {
let runtime = db.salsa_runtime(); let runtime = db.salsa_runtime();
debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,);
@ -189,7 +198,15 @@ where
}, },
}); });
panic_guard.proceed(&value); panic_guard.proceed(
&value,
// The returned value could have been produced as part of a cycle but since
// we returned the memoized value we know we short-circuited the execution
// just as we entered the cycle. Therefore there is no values to invalidate
// and no need to call a cycle handler so we do not need to return the
// actual cycle
Vec::new(),
);
return Ok(value); return Ok(value);
} }
@ -203,6 +220,21 @@ where
Q::execute(db, self.key.clone()) Q::execute(db, self.key.clone())
}); });
if !result.cycle.is_empty() {
result.value = match Q::recover(db, &result.cycle, &self.key) {
Some(v) => v,
None => {
let err = CycleError {
cycle: result.cycle,
durability: result.durability,
changed_at: result.changed_at,
};
panic_guard.report_unexpected_cycle();
return Err(err);
}
};
}
// We assume that query is side-effect free -- that is, does // We assume that query is side-effect free -- that is, does
// not mutate the "inputs" to the query system. Sanity check // not mutate the "inputs" to the query system. Sanity check
// that assumption here, at least to the best of our ability. // that assumption here, at least to the best of our ability.
@ -277,7 +309,7 @@ where
durability: result.durability, durability: result.durability,
}); });
panic_guard.proceed(&new_value); panic_guard.proceed(&new_value, result.cycle);
Ok(new_value) Ok(new_value)
} }
@ -308,7 +340,7 @@ where
state: StateGuard, state: StateGuard,
runtime: &Runtime<DB>, runtime: &Runtime<DB>,
revision_now: Revision, revision_now: Revision,
) -> ProbeState<StampedValue<Q::Value>, StateGuard> ) -> ProbeState<StampedValue<Q::Value>, DB::DatabaseKey, StateGuard>
where where
StateGuard: Deref<Target = QueryState<DB, Q>>, StateGuard: Deref<Target = QueryState<DB, Q>>,
{ {
@ -331,11 +363,42 @@ where
}, },
}); });
let value = rx.recv().unwrap_or_else(|_| db.on_propagated_panic()); let result = rx.recv().unwrap_or_else(|_| db.on_propagated_panic());
ProbeState::UpToDate(Ok(value)) ProbeState::UpToDate(if result.cycle.is_empty() {
Ok(result.value)
} else {
let err = CycleError {
cycle: result.cycle,
changed_at: result.value.changed_at,
durability: result.value.durability,
};
runtime.mark_cycle_participants(&err);
Q::recover(db, &err.cycle, &self.key)
.map(|value| StampedValue {
value,
durability: err.durability,
changed_at: err.changed_at,
})
.ok_or_else(|| err)
})
} }
Err(CycleDetected) => ProbeState::UpToDate(Err(CycleDetected)), Err(err) => {
let err = runtime.report_unexpected_cycle(
&self.database_key(db),
err,
revision_now,
);
ProbeState::UpToDate(
Q::recover(db, &err.cycle, &self.key)
.map(|value| StampedValue {
value,
changed_at: err.changed_at,
durability: err.durability,
})
.ok_or_else(|| err),
)
}
}; };
} }
@ -483,13 +546,17 @@ where
db: &DB, db: &DB,
runtime: &Runtime<DB>, runtime: &Runtime<DB>,
other_id: RuntimeId, other_id: RuntimeId,
waiting: &Mutex<SmallVec<[Sender<StampedValue<Q::Value>>; 2]>>, waiting: &Mutex<SmallVec<[Sender<WaitResult<Q::Value, DB::DatabaseKey>>; 2]>>,
) -> Result<Receiver<StampedValue<Q::Value>>, CycleDetected> { ) -> Result<Receiver<WaitResult<Q::Value, DB::DatabaseKey>>, CycleDetected> {
if other_id == runtime.id() { let id = runtime.id();
return Err(CycleDetected); if other_id == id {
return Err(CycleDetected { from: id, to: id });
} else { } else {
if !runtime.try_block_on(&self.database_key(db), other_id) { if !runtime.try_block_on(&self.database_key(db), other_id) {
return Err(CycleDetected); return Err(CycleDetected {
from: id,
to: other_id,
});
} }
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
@ -555,15 +622,23 @@ where
/// Proceed with our panic guard by overwriting the placeholder for `key`. /// Proceed with our panic guard by overwriting the placeholder for `key`.
/// Once that completes, ensure that our deconstructor is not run once we /// Once that completes, ensure that our deconstructor is not run once we
/// are out of scope. /// are out of scope.
fn proceed(mut self, new_value: &StampedValue<Q::Value>) { fn proceed(mut self, new_value: &StampedValue<Q::Value>, cycle: Vec<DB::DatabaseKey>) {
self.overwrite_placeholder(Some(new_value)); self.overwrite_placeholder(Some((new_value, cycle)));
std::mem::forget(self)
}
fn report_unexpected_cycle(mut self) {
self.overwrite_placeholder(None);
std::mem::forget(self) std::mem::forget(self)
} }
/// Overwrites the `InProgress` placeholder for `key` that we /// Overwrites the `InProgress` placeholder for `key` that we
/// inserted; if others were blocked, waiting for us to finish, /// inserted; if others were blocked, waiting for us to finish,
/// then notify them. /// then notify them.
fn overwrite_placeholder(&mut self, new_value: Option<&StampedValue<Q::Value>>) { fn overwrite_placeholder(
&mut self,
new_value: Option<(&StampedValue<Q::Value>, Vec<DB::DatabaseKey>)>,
) {
let mut write = self.slot.state.write(); let mut write = self.slot.state.write();
let old_value = match self.memo.take() { let old_value = match self.memo.take() {
@ -587,9 +662,13 @@ where
match new_value { match new_value {
// If anybody has installed themselves in our "waiting" // If anybody has installed themselves in our "waiting"
// list, notify them that the value is available. // list, notify them that the value is available.
Some(new_value) => { Some((new_value, ref cycle)) => {
for tx in waiting.into_inner() { for tx in waiting.into_inner() {
tx.send(new_value.clone()).unwrap() tx.send(WaitResult {
value: new_value.clone(),
cycle: cycle.clone(),
})
.unwrap();
} }
} }
@ -811,12 +890,12 @@ where
// can complete. // can complete.
std::mem::drop(state); std::mem::drop(state);
let value = rx.recv().unwrap_or_else(|_| db.on_propagated_panic()); let result = rx.recv().unwrap_or_else(|_| db.on_propagated_panic());
return value.changed_at > revision; return !result.cycle.is_empty() || result.value.changed_at > revision;
} }
// Consider a cycle to have changed. // Consider a cycle to have changed.
Err(CycleDetected) => return true, Err(_) => return true,
} }
} }
@ -873,14 +952,14 @@ where
return match self.read_upgrade(db, revision_now) { return match self.read_upgrade(db, revision_now) {
Ok(v) => { Ok(v) => {
debug!( debug!(
"maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}", "maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}",
self, self,
v.changed_at > revision, v.changed_at > revision,
v.changed_at, v.changed_at,
); );
v.changed_at > revision v.changed_at > revision
} }
Err(CycleDetected) => true, Err(_) => true,
}; };
} }
@ -968,6 +1047,7 @@ where
DB: Database + HasQueryGroup<Q::Group>, DB: Database + HasQueryGroup<Q::Group>,
MP: MemoizationPolicy<DB, Q>, MP: MemoizationPolicy<DB, Q>,
DB::DatabaseData: Send + Sync, DB::DatabaseData: Send + Sync,
DB::DatabaseKey: Send + Sync,
Q::Key: Send + Sync, Q::Key: Send + Sync,
Q::Value: Send + Sync, Q::Value: Send + Sync,
{ {

View file

@ -45,6 +45,7 @@ fn test_key_not_send_db_not_send() {}
/// ///
/// ```compile_fail,E0277 /// ```compile_fail,E0277
/// use std::rc::Rc; /// use std::rc::Rc;
/// use std::cell::Cell;
/// ///
/// #[salsa::query_group(NoSendSyncStorage)] /// #[salsa::query_group(NoSendSyncStorage)]
/// trait NoSendSyncDatabase: salsa::Database { /// trait NoSendSyncDatabase: salsa::Database {

View file

@ -1,12 +1,12 @@
use crate::debug::TableEntry; use crate::debug::TableEntry;
use crate::dependency::DatabaseSlot; use crate::dependency::DatabaseSlot;
use crate::durability::Durability; use crate::durability::Durability;
use crate::plumbing::CycleDetected;
use crate::plumbing::InputQueryStorageOps; use crate::plumbing::InputQueryStorageOps;
use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps; use crate::plumbing::QueryStorageOps;
use crate::revision::Revision; use crate::revision::Revision;
use crate::runtime::StampedValue; use crate::runtime::StampedValue;
use crate::CycleError;
use crate::Database; use crate::Database;
use crate::Event; use crate::Event;
use crate::EventKind; use crate::EventKind;
@ -74,10 +74,10 @@ where
Q: Query<DB>, Q: Query<DB>,
DB: Database, DB: Database,
{ {
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected> { fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
let slot = self.slot(key).unwrap_or_else(|| { let slot = self
panic!("no value set for {:?}({:?})", Q::default(), key) .slot(key)
}); .unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key));
let StampedValue { let StampedValue {
value, value,

View file

@ -2,13 +2,12 @@ use crate::debug::TableEntry;
use crate::dependency::DatabaseSlot; use crate::dependency::DatabaseSlot;
use crate::durability::Durability; use crate::durability::Durability;
use crate::intern_id::InternId; use crate::intern_id::InternId;
use crate::plumbing::CycleDetected;
use crate::plumbing::HasQueryGroup; use crate::plumbing::HasQueryGroup;
use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps; use crate::plumbing::QueryStorageOps;
use crate::revision::Revision; use crate::revision::Revision;
use crate::Query; use crate::Query;
use crate::{Database, DiscardIf, SweepStrategy}; use crate::{CycleError, Database, DiscardIf, SweepStrategy};
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
@ -321,7 +320,7 @@ where
Q::Value: InternKey, Q::Value: InternKey,
DB: Database, DB: Database,
{ {
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected> { fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
let slot = self.intern_index(db, key); let slot = self.intern_index(db, key);
let changed_at = slot.interned_at; let changed_at = slot.interned_at;
let index = slot.index; let index = slot.index;
@ -420,7 +419,7 @@ where
>, >,
DB: Database + HasQueryGroup<Q::Group>, DB: Database + HasQueryGroup<Q::Group>,
{ {
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected> { fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
let index = key.as_intern_id(); let index = key.as_intern_id();
let group_storage = <DB as HasQueryGroup<Q::Group>>::group_storage(db); let group_storage = <DB as HasQueryGroup<Q::Group>>::group_storage(db);
let interned_storage = IQ::query_storage(group_storage); let interned_storage = IQ::query_storage(group_storage);

View file

@ -24,11 +24,11 @@ pub mod debug;
#[doc(hidden)] #[doc(hidden)]
pub mod plumbing; pub mod plumbing;
use crate::plumbing::CycleDetected;
use crate::plumbing::InputQueryStorageOps; use crate::plumbing::InputQueryStorageOps;
use crate::plumbing::LruQueryStorageOps; use crate::plumbing::LruQueryStorageOps;
use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps; use crate::plumbing::QueryStorageOps;
use crate::revision::Revision;
use derive_new::new; use derive_new::new;
use std::fmt::{self, Debug}; use std::fmt::{self, Debug};
use std::hash::Hash; use std::hash::Hash;
@ -468,14 +468,11 @@ where
/// queries (those with no inputs, or those with more than one /// queries (those with no inputs, or those with more than one
/// input) the key will be a tuple. /// input) the key will be a tuple.
pub fn get(&self, key: Q::Key) -> Q::Value { pub fn get(&self, key: Q::Key) -> Q::Value {
self.storage self.try_get(key).unwrap_or_else(|err| panic!("{}", err))
.try_fetch(self.db, &key) }
.unwrap_or_else(|CycleDetected| {
let database_key = self.database_key(&key); fn try_get(&self, key: Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
self.db self.storage.try_fetch(self.db, &key)
.salsa_runtime()
.report_unexpected_cycle(database_key)
})
} }
/// Remove all values for this query that have not been used in /// Remove all values for this query that have not been used in
@ -486,10 +483,6 @@ where
{ {
self.storage.sweep(self.db, strategy); self.storage.sweep(self.db, strategy);
} }
fn database_key(&self, key: &Q::Key) -> DB::DatabaseKey {
<DB as plumbing::GetQueryTable<Q>>::database_key(&self.db, key.clone())
}
} }
/// Return value from [the `query_mut` method] on `Database`. /// Return value from [the `query_mut` method] on `Database`.
@ -561,6 +554,28 @@ where
} }
} }
/// The error returned when a query could not be resolved due to a cycle
#[derive(Eq, PartialEq, Clone, Debug)]
pub struct CycleError<K> {
/// The queries that were part of the cycle
cycle: Vec<K>,
changed_at: Revision,
durability: Durability,
}
impl<K> fmt::Display for CycleError<K>
where
K: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Internal error, cycle detected:\n")?;
for i in &self.cycle {
writeln!(f, "{:?}", i)?;
}
Ok(())
}
}
// Re-export the procedural macros. // Re-export the procedural macros.
#[allow(unused_imports)] #[allow(unused_imports)]
#[macro_use] #[macro_use]

View file

@ -2,10 +2,12 @@
use crate::debug::TableEntry; use crate::debug::TableEntry;
use crate::durability::Durability; use crate::durability::Durability;
use crate::CycleError;
use crate::Database; use crate::Database;
use crate::Query; use crate::Query;
use crate::QueryTable; use crate::QueryTable;
use crate::QueryTableMut; use crate::QueryTableMut;
use crate::RuntimeId;
use crate::SweepStrategy; use crate::SweepStrategy;
use std::fmt::Debug; use std::fmt::Debug;
use std::hash::Hash; use std::hash::Hash;
@ -17,7 +19,11 @@ pub use crate::interned::InternedStorage;
pub use crate::interned::LookupInternedStorage; pub use crate::interned::LookupInternedStorage;
pub use crate::revision::Revision; pub use crate::revision::Revision;
pub struct CycleDetected; #[derive(Clone, Debug)]
pub struct CycleDetected {
pub(crate) from: RuntimeId,
pub(crate) to: RuntimeId,
}
/// Defines various associated types. An impl of this /// Defines various associated types. An impl of this
/// should be generated for your query-context type automatically by /// should be generated for your query-context type automatically by
@ -61,6 +67,10 @@ pub trait DatabaseKey<DB>: Clone + Debug + Eq + Hash {}
pub trait QueryFunction<DB: Database>: Query<DB> { pub trait QueryFunction<DB: Database>: Query<DB> {
fn execute(db: &DB, key: Self::Key) -> Self::Value; fn execute(db: &DB, key: Self::Key) -> Self::Value;
fn recover(db: &DB, cycle: &[DB::DatabaseKey], key: &Self::Key) -> Option<Self::Value> {
let _ = (db, cycle, key);
None
}
} }
/// The `GetQueryTable` trait makes the connection the *database type* /// The `GetQueryTable` trait makes the connection the *database type*
@ -146,7 +156,7 @@ where
/// Returns `Err` in the event of a cycle, meaning that computing /// Returns `Err` in the event of a cycle, meaning that computing
/// the value for this `key` is recursively attempting to fetch /// the value for this `key` is recursively attempting to fetch
/// itself. /// itself.
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected>; fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>>;
/// Returns the durability associated with a given key. /// Returns the durability associated with a given key.
fn durability(&self, db: &DB, key: &Q::Key) -> Durability; fn durability(&self, db: &DB, key: &Q::Key) -> Durability;

View file

@ -1,15 +1,15 @@
use crate::dependency::DatabaseSlot; use crate::dependency::DatabaseSlot;
use crate::dependency::Dependency; use crate::dependency::Dependency;
use crate::durability::Durability; use crate::durability::Durability;
use crate::plumbing::CycleDetected;
use crate::revision::{AtomicRevision, Revision}; use crate::revision::{AtomicRevision, Revision};
use crate::{Database, Event, EventKind, SweepStrategy}; use crate::{CycleError, Database, Event, EventKind, SweepStrategy};
use log::debug; use log::debug;
use parking_lot::{Mutex, RwLock};
use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive}; use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
use parking_lot::{Mutex, RwLock};
use rustc_hash::{FxHashMap, FxHasher}; use rustc_hash::{FxHashMap, FxHasher};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::fmt::Write; use std::hash::{BuildHasherDefault, Hash};
use std::hash::BuildHasherDefault;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
@ -355,6 +355,7 @@ where
dependencies, dependencies,
changed_at, changed_at,
durability, durability,
cycle,
.. ..
} = active_query.complete(); } = active_query.complete();
@ -363,6 +364,7 @@ where
durability, durability,
changed_at, changed_at,
dependencies, dependencies,
cycle,
} }
} }
@ -407,30 +409,100 @@ where
} }
/// Obviously, this should be user configurable at some point. /// Obviously, this should be user configurable at some point.
pub(crate) fn report_unexpected_cycle(&self, database_key: DB::DatabaseKey) -> ! { pub(crate) fn report_unexpected_cycle(
&self,
database_key: &DB::DatabaseKey,
error: CycleDetected,
changed_at: Revision,
) -> crate::CycleError<DB::DatabaseKey> {
debug!("report_unexpected_cycle(database_key={:?})", database_key); debug!("report_unexpected_cycle(database_key={:?})", database_key);
let query_stack = self.local_state.borrow_query_stack(); let mut query_stack = self.local_state.borrow_query_stack_mut();
let start_index = (0..query_stack.len())
.rev()
.filter(|&i| query_stack[i].database_key == database_key)
.next()
.unwrap();
let mut message = format!("Internal error, cycle detected:\n"); if error.from == error.to {
for active_query in &query_stack[start_index..] { // All queries in the cycle is local
writeln!(message, "- {:?}\n", active_query.database_key).unwrap(); let start_index = query_stack
.iter()
.rposition(|active_query| active_query.database_key == *database_key)
.unwrap();
let mut cycle = Vec::new();
let cycle_participants = &mut query_stack[start_index..];
for active_query in &mut *cycle_participants {
cycle.push(active_query.database_key.clone());
}
assert!(!cycle.is_empty());
for active_query in cycle_participants {
active_query.cycle = cycle.clone();
}
crate::CycleError {
cycle,
changed_at,
durability: Durability::MAX,
}
} else {
// Part of the cycle is on another thread so we need to lock and inspect the shared
// state
let dependency_graph = self.shared_state.dependency_graph.lock();
let mut cycle = Vec::new();
{
let cycle_iter = dependency_graph
.get_cycle_path(
database_key,
error.to,
query_stack.iter().map(|query| &query.database_key),
)
.chain(Some(database_key));
for key in cycle_iter {
cycle.push(key.clone());
}
}
assert!(!cycle.is_empty());
for active_query in query_stack
.iter_mut()
.filter(|query| cycle.iter().any(|key| *key == query.database_key))
{
active_query.cycle = cycle.clone();
}
crate::CycleError {
cycle,
changed_at,
durability: Durability::MAX,
}
}
}
pub(crate) fn mark_cycle_participants(&self, err: &CycleError<DB::DatabaseKey>) {
for active_query in self
.local_state
.borrow_query_stack_mut()
.iter_mut()
.rev()
.take_while(|active_query| err.cycle.iter().any(|e| *e == active_query.database_key))
{
active_query.cycle = err.cycle.clone();
} }
panic!(message)
} }
/// Try to make this runtime blocked on `other_id`. Returns true /// Try to make this runtime blocked on `other_id`. Returns true
/// upon success or false if `other_id` is already blocked on us. /// upon success or false if `other_id` is already blocked on us.
pub(crate) fn try_block_on(&self, database_key: &DB::DatabaseKey, other_id: RuntimeId) -> bool { pub(crate) fn try_block_on(&self, database_key: &DB::DatabaseKey, other_id: RuntimeId) -> bool {
self.shared_state self.shared_state.dependency_graph.lock().add_edge(
.dependency_graph self.id(),
.lock() database_key,
.add_edge(self.id(), database_key, other_id) other_id,
self.local_state
.borrow_query_stack()
.iter()
.map(|query| query.database_key.clone()),
)
} }
pub(crate) fn unblock_queries_blocked_on_self(&self, database_key: &DB::DatabaseKey) { pub(crate) fn unblock_queries_blocked_on_self(&self, database_key: &DB::DatabaseKey) {
@ -508,7 +580,7 @@ struct SharedState<DB: Database> {
/// The dependency graph tracks which runtimes are blocked on one /// The dependency graph tracks which runtimes are blocked on one
/// another, waiting for queries to terminate. /// another, waiting for queries to terminate.
dependency_graph: Mutex<DependencyGraph<DB>>, dependency_graph: Mutex<DependencyGraph<DB::DatabaseKey>>,
} }
impl<DB: Database> SharedState<DB> { impl<DB: Database> SharedState<DB> {
@ -571,6 +643,9 @@ struct ActiveQuery<DB: Database> {
/// Set of subqueries that were accessed thus far, or `None` if /// Set of subqueries that were accessed thus far, or `None` if
/// there was an untracked the read. /// there was an untracked the read.
dependencies: Option<FxIndexSet<Dependency<DB>>>, dependencies: Option<FxIndexSet<Dependency<DB>>>,
/// Stores the entire cycle, if one is found and this query is part of it.
cycle: Vec<DB::DatabaseKey>,
} }
pub(crate) struct ComputedQueryResult<DB: Database, V> { pub(crate) struct ComputedQueryResult<DB: Database, V> {
@ -587,6 +662,9 @@ pub(crate) struct ComputedQueryResult<DB: Database, V> {
/// Complete set of subqueries that were accessed, or `None` if /// Complete set of subqueries that were accessed, or `None` if
/// there was an untracked the read. /// there was an untracked the read.
pub(crate) dependencies: Option<FxIndexSet<Dependency<DB>>>, pub(crate) dependencies: Option<FxIndexSet<Dependency<DB>>>,
/// The cycle if one occured while computing this value
pub(crate) cycle: Vec<DB::DatabaseKey>,
} }
impl<DB: Database> ActiveQuery<DB> { impl<DB: Database> ActiveQuery<DB> {
@ -596,6 +674,7 @@ impl<DB: Database> ActiveQuery<DB> {
durability: max_durability, durability: max_durability,
changed_at: Revision::start(), changed_at: Revision::start(),
dependencies: Some(FxIndexSet::default()), dependencies: Some(FxIndexSet::default()),
cycle: Vec::new(),
} }
} }
@ -634,16 +713,26 @@ pub(crate) struct StampedValue<V> {
pub(crate) changed_at: Revision, pub(crate) changed_at: Revision,
} }
struct DependencyGraph<DB: Database> { #[derive(Debug)]
struct Edge<K> {
id: RuntimeId,
path: Vec<K>,
}
#[derive(Debug)]
struct DependencyGraph<K: Hash + Eq> {
/// A `(K -> V)` pair in this map indicates that the the runtime /// A `(K -> V)` pair in this map indicates that the the runtime
/// `K` is blocked on some query executing in the runtime `V`. /// `K` is blocked on some query executing in the runtime `V`.
/// This encodes a graph that must be acyclic (or else deadlock /// This encodes a graph that must be acyclic (or else deadlock
/// will result). /// will result).
edges: FxHashMap<RuntimeId, RuntimeId>, edges: FxHashMap<RuntimeId, Edge<K>>,
labels: FxHashMap<DB::DatabaseKey, SmallVec<[RuntimeId; 4]>>, labels: FxHashMap<K, SmallVec<[RuntimeId; 4]>>,
} }
impl<DB: Database> Default for DependencyGraph<DB> { impl<K> Default for DependencyGraph<K>
where
K: Hash + Eq,
{
fn default() -> Self { fn default() -> Self {
DependencyGraph { DependencyGraph {
edges: Default::default(), edges: Default::default(),
@ -652,13 +741,17 @@ impl<DB: Database> Default for DependencyGraph<DB> {
} }
} }
impl<DB: Database> DependencyGraph<DB> { impl<K> DependencyGraph<K>
where
K: Hash + Eq + Clone,
{
/// Attempt to add an edge `from_id -> to_id` into the result graph. /// Attempt to add an edge `from_id -> to_id` into the result graph.
fn add_edge( fn add_edge(
&mut self, &mut self,
from_id: RuntimeId, from_id: RuntimeId,
database_key: &DB::DatabaseKey, database_key: &K,
to_id: RuntimeId, to_id: RuntimeId,
path: impl IntoIterator<Item = K>,
) -> bool { ) -> bool {
assert_ne!(from_id, to_id); assert_ne!(from_id, to_id);
debug_assert!(!self.edges.contains_key(&from_id)); debug_assert!(!self.edges.contains_key(&from_id));
@ -666,7 +759,7 @@ impl<DB: Database> DependencyGraph<DB> {
// First: walk the chain of things that `to_id` depends on, // First: walk the chain of things that `to_id` depends on,
// looking for us. // looking for us.
let mut p = to_id; let mut p = to_id;
while let Some(&q) = self.edges.get(&p) { while let Some(q) = self.edges.get(&p).map(|edge| edge.id) {
if q == from_id { if q == from_id {
return false; return false;
} }
@ -674,7 +767,13 @@ impl<DB: Database> DependencyGraph<DB> {
p = q; p = q;
} }
self.edges.insert(from_id, to_id); self.edges.insert(
from_id,
Edge {
id: to_id,
path: path.into_iter().chain(Some(database_key.clone())).collect(),
},
);
self.labels self.labels
.entry(database_key.clone()) .entry(database_key.clone())
.or_default() .or_default()
@ -682,17 +781,54 @@ impl<DB: Database> DependencyGraph<DB> {
true true
} }
fn remove_edge(&mut self, database_key: &DB::DatabaseKey, to_id: RuntimeId) { fn remove_edge(&mut self, database_key: &K, to_id: RuntimeId) {
let vec = self let vec = self.labels.remove(database_key).unwrap_or_default();
.labels
.remove(database_key)
.unwrap_or_default();
for from_id in &vec { for from_id in &vec {
let to_id1 = self.edges.remove(from_id); let to_id1 = self.edges.remove(from_id).map(|edge| edge.id);
assert_eq!(Some(to_id), to_id1); assert_eq!(Some(to_id), to_id1);
} }
} }
fn get_cycle_path<'a>(
&'a self,
database_key: &'a K,
to: RuntimeId,
local_path: impl IntoIterator<Item = &'a K>,
) -> impl Iterator<Item = &'a K>
where
K: std::fmt::Debug,
{
let mut current = Some((to, std::slice::from_ref(database_key)));
let mut last = None;
let mut local_path = Some(local_path);
std::iter::from_fn(move || match current.take() {
Some((id, path)) => {
let link_key = path.last().unwrap();
current = self.edges.get(&id).map(|edge| {
let i = edge.path.iter().rposition(|p| p == link_key).unwrap();
(edge.id, &edge.path[i + 1..])
});
if current.is_none() {
last = local_path.take().map(|local_path| {
local_path
.into_iter()
.skip_while(move |p| *p != link_key)
.skip(1)
});
}
Some(path)
}
None => match &mut last {
Some(iter) => iter.next().map(std::slice::from_ref),
None => None,
},
})
.flat_map(|x| x)
}
} }
struct RevisionGuard<DB: Database> { struct RevisionGuard<DB: Database> {
@ -742,3 +878,42 @@ where
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dependency_graph_path1() {
let mut graph = DependencyGraph::default();
let a = RuntimeId { counter: 0 };
let b = RuntimeId { counter: 1 };
assert!(graph.add_edge(a, &2, b, vec![1]));
// assert!(graph.add_edge(b, &1, a, vec![3, 2]));
assert_eq!(
graph
.get_cycle_path(&1, a, &[3, 2][..])
.cloned()
.collect::<Vec<i32>>(),
vec![1, 2]
);
}
#[test]
fn dependency_graph_path2() {
let mut graph = DependencyGraph::default();
let a = RuntimeId { counter: 0 };
let b = RuntimeId { counter: 1 };
let c = RuntimeId { counter: 2 };
assert!(graph.add_edge(a, &3, b, vec![1]));
assert!(graph.add_edge(b, &4, c, vec![2, 3]));
// assert!(graph.add_edge(c, &1, a, vec![5, 6, 4, 7]));
assert_eq!(
graph
.get_cycle_path(&1, a, &[5, 6, 4, 7][..])
.cloned()
.collect::<Vec<i32>>(),
vec![1, 3, 4, 7]
);
}
}

View file

@ -3,8 +3,7 @@ use crate::durability::Durability;
use crate::runtime::ActiveQuery; use crate::runtime::ActiveQuery;
use crate::runtime::Revision; use crate::runtime::Revision;
use crate::Database; use crate::Database;
use std::cell::Ref; use std::cell::{Ref, RefCell, RefMut};
use std::cell::RefCell;
/// State that is specific to a single execution thread. /// State that is specific to a single execution thread.
/// ///
@ -51,6 +50,10 @@ impl<DB: Database> LocalState<DB> {
self.query_stack.borrow() self.query_stack.borrow()
} }
pub(super) fn borrow_query_stack_mut(&self) -> RefMut<'_, Vec<ActiveQuery<DB>>> {
self.query_stack.borrow_mut()
}
pub(super) fn query_in_progress(&self) -> bool { pub(super) fn query_in_progress(&self) -> bool {
!self.query_stack.borrow().is_empty() !self.query_stack.borrow().is_empty()
} }

View file

@ -1,3 +1,10 @@
use salsa::{ParallelDatabase, Snapshot};
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
struct Error {
cycle: Vec<String>,
}
#[salsa::database(GroupStruct)] #[salsa::database(GroupStruct)]
#[derive(Default)] #[derive(Default)]
struct DatabaseImpl { struct DatabaseImpl {
@ -10,6 +17,14 @@ impl salsa::Database for DatabaseImpl {
} }
} }
impl ParallelDatabase for DatabaseImpl {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(DatabaseImpl {
runtime: self.runtime.snapshot(self),
})
}
}
#[salsa::query_group(GroupStruct)] #[salsa::query_group(GroupStruct)]
trait Database: salsa::Database { trait Database: salsa::Database {
// `a` and `b` depend on each other and form a cycle // `a` and `b` depend on each other and form a cycle
@ -17,6 +32,27 @@ trait Database: salsa::Database {
fn memoized_b(&self) -> (); fn memoized_b(&self) -> ();
fn volatile_a(&self) -> (); fn volatile_a(&self) -> ();
fn volatile_b(&self) -> (); fn volatile_b(&self) -> ();
fn cycle_leaf(&self) -> ();
#[salsa::cycle(recover_a)]
fn cycle_a(&self) -> Result<(), Error>;
#[salsa::cycle(recover_b)]
fn cycle_b(&self) -> Result<(), Error>;
fn cycle_c(&self) -> Result<(), Error>;
}
fn recover_a(_db: &impl Database, cycle: &[String]) -> Result<(), Error> {
Err(Error {
cycle: cycle.to_owned(),
})
}
fn recover_b(_db: &impl Database, cycle: &[String]) -> Result<(), Error> {
Err(Error {
cycle: cycle.to_owned(),
})
} }
fn memoized_a(db: &impl Database) -> () { fn memoized_a(db: &impl Database) -> () {
@ -37,6 +73,23 @@ fn volatile_b(db: &impl Database) -> () {
db.volatile_a() db.volatile_a()
} }
fn cycle_leaf(_db: &impl Database) -> () {}
fn cycle_a(db: &impl Database) -> Result<(), Error> {
let _ = db.cycle_b();
Ok(())
}
fn cycle_b(db: &impl Database) -> Result<(), Error> {
db.cycle_leaf();
let _ = db.cycle_a();
Ok(())
}
fn cycle_c(db: &impl Database) -> Result<(), Error> {
db.cycle_b()
}
#[test] #[test]
#[should_panic(expected = "cycle detected")] #[should_panic(expected = "cycle detected")]
fn cycle_memoized() { fn cycle_memoized() {
@ -50,3 +103,64 @@ fn cycle_volatile() {
let query = DatabaseImpl::default(); let query = DatabaseImpl::default();
query.volatile_a(); query.volatile_a();
} }
#[test]
fn cycle_cycle() {
let query = DatabaseImpl::default();
assert!(query.cycle_a().is_err());
}
#[test]
fn inner_cycle() {
let query = DatabaseImpl::default();
let err = query.cycle_c();
assert!(err.is_err());
let cycle = err.unwrap_err().cycle;
assert!(
cycle
.iter()
.zip(&["cycle_b", "cycle_a"])
.all(|(l, r)| l.contains(r)),
"{:#?}",
cycle
);
}
#[test]
fn parallel_cycle() {
let _ = env_logger::try_init();
let db = DatabaseImpl::default();
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let result = db.cycle_a();
assert!(result.is_err(), "Expected cycle error");
let cycle = result.unwrap_err().cycle;
assert!(
cycle
.iter()
.all(|l| ["cycle_b", "cycle_a"].iter().any(|r| l.contains(r))),
"{:#?}",
cycle
);
}
});
let thread2 = std::thread::spawn(move || {
let result = db.cycle_c();
assert!(result.is_err(), "Expected cycle error");
let cycle = result.unwrap_err().cycle;
assert!(
cycle
.iter()
.all(|l| ["cycle_b", "cycle_a"].iter().any(|r| l.contains(r))),
"{:#?}",
cycle
);
});
thread1.join().unwrap();
thread2.join().unwrap();
eprintln!("OK");
}

View file

@ -1,6 +1,4 @@
use crate::setup::{ use crate::setup::{CancelationFlag, Canceled, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
CancelationFlag, Canceled, Knobs, ParDatabase, ParDatabaseImpl, WithValue,
};
use salsa::ParallelDatabase; use salsa::ParallelDatabase;
macro_rules! assert_canceled { macro_rules! assert_canceled {

View file

@ -22,7 +22,7 @@ fn in_par_two_independent_queries() {
let thread2 = std::thread::spawn({ let thread2 = std::thread::spawn({
let db = db.snapshot(); let db = db.snapshot();
move || db.sum("def") move || db.sum("def")
});; });
assert_eq!(thread1.join().unwrap(), 111); assert_eq!(thread1.join().unwrap(), 111);
assert_eq!(thread2.join().unwrap(), 222); assert_eq!(thread2.join().unwrap(), 222);

View file

@ -1,7 +1,6 @@
//! Test `salsa::requires` attribute for private query dependencies //! Test `salsa::requires` attribute for private query dependencies
//! https://github.com/salsa-rs/salsa-rfcs/pull/3 //! https://github.com/salsa-rs/salsa-rfcs/pull/3
mod queries { mod queries {
#[salsa::query_group(InputGroupStorage)] #[salsa::query_group(InputGroupStorage)]
pub trait InputGroup { pub trait InputGroup {
@ -14,7 +13,7 @@ mod queries {
fn private_a(&self, x: u32) -> u32; fn private_a(&self, x: u32) -> u32;
} }
fn private_a(db: &impl PrivGroupA, x: u32) -> u32{ fn private_a(db: &impl PrivGroupA, x: u32) -> u32 {
db.input(x) db.input(x)
} }
@ -23,7 +22,7 @@ mod queries {
fn private_b(&self, x: u32) -> u32; fn private_b(&self, x: u32) -> u32;
} }
fn private_b(db: &impl PrivGroupB, x: u32) -> u32{ fn private_b(db: &impl PrivGroupB, x: u32) -> u32 {
db.input(x) db.input(x)
} }
@ -34,7 +33,6 @@ mod queries {
fn public(&self, x: u32) -> u32; fn public(&self, x: u32) -> u32;
} }
fn public(db: &(impl PubGroup + PrivGroupA + PrivGroupB), x: u32) -> u32 { fn public(db: &(impl PubGroup + PrivGroupA + PrivGroupB), x: u32) -> u32 {
db.private_a(x) + db.private_b(x) db.private_a(x) + db.private_b(x)
} }
@ -44,7 +42,7 @@ mod queries {
queries::InputGroupStorage, queries::InputGroupStorage,
queries::PrivGroupAStorage, queries::PrivGroupAStorage,
queries::PrivGroupBStorage, queries::PrivGroupBStorage,
queries::PubGroupStorage, queries::PubGroupStorage
)] )]
#[derive(Default)] #[derive(Default)]
struct Database { struct Database {

View file

@ -17,7 +17,6 @@ fn get(db: &impl QueryGroup, x: u32) -> u32 {
db.wrap(x) db.wrap(x)
} }
#[salsa::database(QueryGroupStorage)] #[salsa::database(QueryGroupStorage)]
#[derive(Default)] #[derive(Default)]
struct Database { struct Database {