Merge pull request #147 from Marwes/cycles

feat: Allow queries to avoid panics on cycles
This commit is contained in:
Niko Matsakis 2019-09-19 05:56:39 -04:00 committed by GitHub
commit a9860bf37f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 520 additions and 106 deletions

View file

@ -39,6 +39,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
match item {
TraitItem::Method(method) => {
let mut storage = QueryStorage::Memoized;
let mut cycle = None;
let mut invoke = None;
let mut query_type = Ident::new(
&format!("{}Query", method.sig.ident.to_string().to_camel_case()),
@ -66,6 +67,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
storage = QueryStorage::Interned;
num_storages += 1;
}
"cycle" => {
cycle = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
"invoke" => {
invoke = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
@ -150,6 +154,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
keys: lookup_keys,
value: lookup_value,
invoke: None,
cycle: cycle.clone(),
})
} else {
None
@ -163,6 +168,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
keys,
value,
invoke,
cycle,
});
queries.extend(lookup_query);
@ -354,9 +360,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
QueryStorage::Dependencies => quote!(salsa::plumbing::DependencyStorage<#db, Self>),
QueryStorage::Input => quote!(salsa::plumbing::InputStorage<#db, Self>),
QueryStorage::Interned => quote!(salsa::plumbing::InternedStorage<#db, Self>),
QueryStorage::InternedLookup { intern_query_type } => {
quote!(salsa::plumbing::LookupInternedStorage<#db, Self, #intern_query_type>)
}
QueryStorage::InternedLookup { intern_query_type } => quote!(salsa::plumbing::LookupInternedStorage<#db, Self, #intern_query_type>),
QueryStorage::Transparent => continue,
};
let keys = &query.keys;
@ -404,6 +408,22 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
quote! { (#(#key_names),*) }
};
let invoke = query.invoke_tt();
let recover = if let Some(cycle_recovery_fn) = &query.cycle {
quote! {
fn recover(db: &DB, cycle: &[DB::DatabaseKey], #key_pattern: &<Self as salsa::Query<DB>>::Key)
-> Option<<Self as salsa::Query<DB>>::Value> {
Some(#cycle_recovery_fn(
db,
&cycle.iter().map(|k| format!("{:?}", k)).collect::<Vec<String>>(),
#(#key_names),*
))
}
}
} else {
quote! {}
};
output.extend(quote_spanned! {span=>
impl<DB> salsa::plumbing::QueryFunction<DB> for #qt
where
@ -415,6 +435,8 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
-> <Self as salsa::Query<DB>>::Value {
#invoke(db, #(#key_names),*)
}
#recover
}
});
}
@ -541,6 +563,7 @@ struct Query {
keys: Vec<syn::Type>,
value: syn::Type,
invoke: Option<syn::Path>,
cycle: Option<syn::Path>,
}
impl Query {

View file

@ -1,14 +1,13 @@
use crate::debug::TableEntry;
use crate::durability::Durability;
use crate::lru::Lru;
use crate::plumbing::CycleDetected;
use crate::plumbing::HasQueryGroup;
use crate::plumbing::LruQueryStorageOps;
use crate::plumbing::QueryFunction;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
use crate::runtime::StampedValue;
use crate::{Database, SweepStrategy};
use crate::{CycleError, Database, SweepStrategy};
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
use std::marker::PhantomData;
@ -131,7 +130,7 @@ where
DB: Database + HasQueryGroup<Q::Group>,
MP: MemoizationPolicy<DB, Q>,
{
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected> {
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
let slot = self.slot(key);
let StampedValue {
value,

View file

@ -14,7 +14,7 @@ use crate::runtime::FxIndexSet;
use crate::runtime::Runtime;
use crate::runtime::RuntimeId;
use crate::runtime::StampedValue;
use crate::{Database, DiscardIf, DiscardWhat, Event, EventKind, SweepStrategy};
use crate::{CycleError, Database, DiscardIf, DiscardWhat, Event, EventKind, SweepStrategy};
use log::{debug, info};
use parking_lot::Mutex;
use parking_lot::RwLock;
@ -36,6 +36,12 @@ where
lru_index: LruIndex,
}
#[derive(Clone)]
struct WaitResult<V, K> {
value: StampedValue<V>,
cycle: Vec<K>,
}
/// Defines the "current state" of query's memoized results.
enum QueryState<DB, Q>
where
@ -49,7 +55,7 @@ where
/// indeeds a cycle.
InProgress {
id: RuntimeId,
waiting: Mutex<SmallVec<[Sender<StampedValue<Q::Value>>; 2]>>,
waiting: Mutex<SmallVec<[Sender<WaitResult<Q::Value, DB::DatabaseKey>>; 2]>>,
},
/// We have computed the query already, and here is the result.
@ -95,8 +101,8 @@ pub(super) enum MemoInputs<DB: Database> {
}
/// Return value of `probe` helper.
enum ProbeState<V, G> {
UpToDate(Result<V, CycleDetected>),
enum ProbeState<V, K, G> {
UpToDate(Result<V, CycleError<K>>),
StaleOrAbsent(G),
}
@ -119,7 +125,10 @@ where
<DB as GetQueryTable<Q>>::database_key(db, self.key.clone())
}
pub(super) fn read(&self, db: &DB) -> Result<StampedValue<Q::Value>, CycleDetected> {
pub(super) fn read(
&self,
db: &DB,
) -> Result<StampedValue<Q::Value>, CycleError<DB::DatabaseKey>> {
let runtime = db.salsa_runtime();
// NB: We don't need to worry about people modifying the
@ -148,7 +157,7 @@ where
&self,
db: &DB,
revision_now: Revision,
) -> Result<StampedValue<Q::Value>, CycleDetected> {
) -> Result<StampedValue<Q::Value>, CycleError<DB::DatabaseKey>> {
let runtime = db.salsa_runtime();
debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,);
@ -189,7 +198,15 @@ where
},
});
panic_guard.proceed(&value);
panic_guard.proceed(
&value,
// The returned value could have been produced as part of a cycle but since
// we returned the memoized value we know we short-circuited the execution
// just as we entered the cycle. Therefore there is no values to invalidate
// and no need to call a cycle handler so we do not need to return the
// actual cycle
Vec::new(),
);
return Ok(value);
}
@ -203,6 +220,21 @@ where
Q::execute(db, self.key.clone())
});
if !result.cycle.is_empty() {
result.value = match Q::recover(db, &result.cycle, &self.key) {
Some(v) => v,
None => {
let err = CycleError {
cycle: result.cycle,
durability: result.durability,
changed_at: result.changed_at,
};
panic_guard.report_unexpected_cycle();
return Err(err);
}
};
}
// We assume that query is side-effect free -- that is, does
// not mutate the "inputs" to the query system. Sanity check
// that assumption here, at least to the best of our ability.
@ -277,7 +309,7 @@ where
durability: result.durability,
});
panic_guard.proceed(&new_value);
panic_guard.proceed(&new_value, result.cycle);
Ok(new_value)
}
@ -308,7 +340,7 @@ where
state: StateGuard,
runtime: &Runtime<DB>,
revision_now: Revision,
) -> ProbeState<StampedValue<Q::Value>, StateGuard>
) -> ProbeState<StampedValue<Q::Value>, DB::DatabaseKey, StateGuard>
where
StateGuard: Deref<Target = QueryState<DB, Q>>,
{
@ -331,11 +363,42 @@ where
},
});
let value = rx.recv().unwrap_or_else(|_| db.on_propagated_panic());
ProbeState::UpToDate(Ok(value))
let result = rx.recv().unwrap_or_else(|_| db.on_propagated_panic());
ProbeState::UpToDate(if result.cycle.is_empty() {
Ok(result.value)
} else {
let err = CycleError {
cycle: result.cycle,
changed_at: result.value.changed_at,
durability: result.value.durability,
};
runtime.mark_cycle_participants(&err);
Q::recover(db, &err.cycle, &self.key)
.map(|value| StampedValue {
value,
durability: err.durability,
changed_at: err.changed_at,
})
.ok_or_else(|| err)
})
}
Err(CycleDetected) => ProbeState::UpToDate(Err(CycleDetected)),
Err(err) => {
let err = runtime.report_unexpected_cycle(
&self.database_key(db),
err,
revision_now,
);
ProbeState::UpToDate(
Q::recover(db, &err.cycle, &self.key)
.map(|value| StampedValue {
value,
changed_at: err.changed_at,
durability: err.durability,
})
.ok_or_else(|| err),
)
}
};
}
@ -483,13 +546,17 @@ where
db: &DB,
runtime: &Runtime<DB>,
other_id: RuntimeId,
waiting: &Mutex<SmallVec<[Sender<StampedValue<Q::Value>>; 2]>>,
) -> Result<Receiver<StampedValue<Q::Value>>, CycleDetected> {
if other_id == runtime.id() {
return Err(CycleDetected);
waiting: &Mutex<SmallVec<[Sender<WaitResult<Q::Value, DB::DatabaseKey>>; 2]>>,
) -> Result<Receiver<WaitResult<Q::Value, DB::DatabaseKey>>, CycleDetected> {
let id = runtime.id();
if other_id == id {
return Err(CycleDetected { from: id, to: id });
} else {
if !runtime.try_block_on(&self.database_key(db), other_id) {
return Err(CycleDetected);
return Err(CycleDetected {
from: id,
to: other_id,
});
}
let (tx, rx) = mpsc::channel();
@ -555,15 +622,23 @@ where
/// Proceed with our panic guard by overwriting the placeholder for `key`.
/// Once that completes, ensure that our deconstructor is not run once we
/// are out of scope.
fn proceed(mut self, new_value: &StampedValue<Q::Value>) {
self.overwrite_placeholder(Some(new_value));
fn proceed(mut self, new_value: &StampedValue<Q::Value>, cycle: Vec<DB::DatabaseKey>) {
self.overwrite_placeholder(Some((new_value, cycle)));
std::mem::forget(self)
}
fn report_unexpected_cycle(mut self) {
self.overwrite_placeholder(None);
std::mem::forget(self)
}
/// Overwrites the `InProgress` placeholder for `key` that we
/// inserted; if others were blocked, waiting for us to finish,
/// then notify them.
fn overwrite_placeholder(&mut self, new_value: Option<&StampedValue<Q::Value>>) {
fn overwrite_placeholder(
&mut self,
new_value: Option<(&StampedValue<Q::Value>, Vec<DB::DatabaseKey>)>,
) {
let mut write = self.slot.state.write();
let old_value = match self.memo.take() {
@ -587,9 +662,13 @@ where
match new_value {
// If anybody has installed themselves in our "waiting"
// list, notify them that the value is available.
Some(new_value) => {
Some((new_value, ref cycle)) => {
for tx in waiting.into_inner() {
tx.send(new_value.clone()).unwrap()
tx.send(WaitResult {
value: new_value.clone(),
cycle: cycle.clone(),
})
.unwrap();
}
}
@ -811,12 +890,12 @@ where
// can complete.
std::mem::drop(state);
let value = rx.recv().unwrap_or_else(|_| db.on_propagated_panic());
return value.changed_at > revision;
let result = rx.recv().unwrap_or_else(|_| db.on_propagated_panic());
return !result.cycle.is_empty() || result.value.changed_at > revision;
}
// Consider a cycle to have changed.
Err(CycleDetected) => return true,
Err(_) => return true,
}
}
@ -873,14 +952,14 @@ where
return match self.read_upgrade(db, revision_now) {
Ok(v) => {
debug!(
"maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}",
self,
"maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}",
self,
v.changed_at > revision,
v.changed_at,
);
v.changed_at,
);
v.changed_at > revision
}
Err(CycleDetected) => true,
Err(_) => true,
};
}
@ -968,6 +1047,7 @@ where
DB: Database + HasQueryGroup<Q::Group>,
MP: MemoizationPolicy<DB, Q>,
DB::DatabaseData: Send + Sync,
DB::DatabaseKey: Send + Sync,
Q::Key: Send + Sync,
Q::Value: Send + Sync,
{

View file

@ -45,6 +45,7 @@ fn test_key_not_send_db_not_send() {}
///
/// ```compile_fail,E0277
/// use std::rc::Rc;
/// use std::cell::Cell;
///
/// #[salsa::query_group(NoSendSyncStorage)]
/// trait NoSendSyncDatabase: salsa::Database {

View file

@ -1,12 +1,12 @@
use crate::debug::TableEntry;
use crate::dependency::DatabaseSlot;
use crate::durability::Durability;
use crate::plumbing::CycleDetected;
use crate::plumbing::InputQueryStorageOps;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
use crate::revision::Revision;
use crate::runtime::StampedValue;
use crate::CycleError;
use crate::Database;
use crate::Event;
use crate::EventKind;
@ -74,10 +74,10 @@ where
Q: Query<DB>,
DB: Database,
{
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected> {
let slot = self.slot(key).unwrap_or_else(|| {
panic!("no value set for {:?}({:?})", Q::default(), key)
});
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
let slot = self
.slot(key)
.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key));
let StampedValue {
value,

View file

@ -2,13 +2,12 @@ use crate::debug::TableEntry;
use crate::dependency::DatabaseSlot;
use crate::durability::Durability;
use crate::intern_id::InternId;
use crate::plumbing::CycleDetected;
use crate::plumbing::HasQueryGroup;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
use crate::revision::Revision;
use crate::Query;
use crate::{Database, DiscardIf, SweepStrategy};
use crate::{CycleError, Database, DiscardIf, SweepStrategy};
use crossbeam::atomic::AtomicCell;
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
@ -321,7 +320,7 @@ where
Q::Value: InternKey,
DB: Database,
{
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected> {
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
let slot = self.intern_index(db, key);
let changed_at = slot.interned_at;
let index = slot.index;
@ -420,7 +419,7 @@ where
>,
DB: Database + HasQueryGroup<Q::Group>,
{
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected> {
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
let index = key.as_intern_id();
let group_storage = <DB as HasQueryGroup<Q::Group>>::group_storage(db);
let interned_storage = IQ::query_storage(group_storage);

View file

@ -24,11 +24,11 @@ pub mod debug;
#[doc(hidden)]
pub mod plumbing;
use crate::plumbing::CycleDetected;
use crate::plumbing::InputQueryStorageOps;
use crate::plumbing::LruQueryStorageOps;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
use crate::revision::Revision;
use derive_new::new;
use std::fmt::{self, Debug};
use std::hash::Hash;
@ -468,14 +468,11 @@ where
/// queries (those with no inputs, or those with more than one
/// input) the key will be a tuple.
pub fn get(&self, key: Q::Key) -> Q::Value {
self.storage
.try_fetch(self.db, &key)
.unwrap_or_else(|CycleDetected| {
let database_key = self.database_key(&key);
self.db
.salsa_runtime()
.report_unexpected_cycle(database_key)
})
self.try_get(key).unwrap_or_else(|err| panic!("{}", err))
}
fn try_get(&self, key: Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>> {
self.storage.try_fetch(self.db, &key)
}
/// Remove all values for this query that have not been used in
@ -486,10 +483,6 @@ where
{
self.storage.sweep(self.db, strategy);
}
fn database_key(&self, key: &Q::Key) -> DB::DatabaseKey {
<DB as plumbing::GetQueryTable<Q>>::database_key(&self.db, key.clone())
}
}
/// Return value from [the `query_mut` method] on `Database`.
@ -561,6 +554,28 @@ where
}
}
/// The error returned when a query could not be resolved due to a cycle
#[derive(Eq, PartialEq, Clone, Debug)]
pub struct CycleError<K> {
/// The queries that were part of the cycle
cycle: Vec<K>,
changed_at: Revision,
durability: Durability,
}
impl<K> fmt::Display for CycleError<K>
where
K: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Internal error, cycle detected:\n")?;
for i in &self.cycle {
writeln!(f, "{:?}", i)?;
}
Ok(())
}
}
// Re-export the procedural macros.
#[allow(unused_imports)]
#[macro_use]

View file

@ -2,10 +2,12 @@
use crate::debug::TableEntry;
use crate::durability::Durability;
use crate::CycleError;
use crate::Database;
use crate::Query;
use crate::QueryTable;
use crate::QueryTableMut;
use crate::RuntimeId;
use crate::SweepStrategy;
use std::fmt::Debug;
use std::hash::Hash;
@ -17,7 +19,11 @@ pub use crate::interned::InternedStorage;
pub use crate::interned::LookupInternedStorage;
pub use crate::revision::Revision;
pub struct CycleDetected;
#[derive(Clone, Debug)]
pub struct CycleDetected {
pub(crate) from: RuntimeId,
pub(crate) to: RuntimeId,
}
/// Defines various associated types. An impl of this
/// should be generated for your query-context type automatically by
@ -61,6 +67,10 @@ pub trait DatabaseKey<DB>: Clone + Debug + Eq + Hash {}
pub trait QueryFunction<DB: Database>: Query<DB> {
fn execute(db: &DB, key: Self::Key) -> Self::Value;
fn recover(db: &DB, cycle: &[DB::DatabaseKey], key: &Self::Key) -> Option<Self::Value> {
let _ = (db, cycle, key);
None
}
}
/// The `GetQueryTable` trait makes the connection the *database type*
@ -146,7 +156,7 @@ where
/// Returns `Err` in the event of a cycle, meaning that computing
/// the value for this `key` is recursively attempting to fetch
/// itself.
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleDetected>;
fn try_fetch(&self, db: &DB, key: &Q::Key) -> Result<Q::Value, CycleError<DB::DatabaseKey>>;
/// Returns the durability associated with a given key.
fn durability(&self, db: &DB, key: &Q::Key) -> Durability;

View file

@ -1,15 +1,15 @@
use crate::dependency::DatabaseSlot;
use crate::dependency::Dependency;
use crate::durability::Durability;
use crate::plumbing::CycleDetected;
use crate::revision::{AtomicRevision, Revision};
use crate::{Database, Event, EventKind, SweepStrategy};
use crate::{CycleError, Database, Event, EventKind, SweepStrategy};
use log::debug;
use parking_lot::{Mutex, RwLock};
use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
use parking_lot::{Mutex, RwLock};
use rustc_hash::{FxHashMap, FxHasher};
use smallvec::SmallVec;
use std::fmt::Write;
use std::hash::BuildHasherDefault;
use std::hash::{BuildHasherDefault, Hash};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
@ -355,6 +355,7 @@ where
dependencies,
changed_at,
durability,
cycle,
..
} = active_query.complete();
@ -363,6 +364,7 @@ where
durability,
changed_at,
dependencies,
cycle,
}
}
@ -407,30 +409,100 @@ where
}
/// Obviously, this should be user configurable at some point.
pub(crate) fn report_unexpected_cycle(&self, database_key: DB::DatabaseKey) -> ! {
pub(crate) fn report_unexpected_cycle(
&self,
database_key: &DB::DatabaseKey,
error: CycleDetected,
changed_at: Revision,
) -> crate::CycleError<DB::DatabaseKey> {
debug!("report_unexpected_cycle(database_key={:?})", database_key);
let query_stack = self.local_state.borrow_query_stack();
let start_index = (0..query_stack.len())
.rev()
.filter(|&i| query_stack[i].database_key == database_key)
.next()
.unwrap();
let mut query_stack = self.local_state.borrow_query_stack_mut();
let mut message = format!("Internal error, cycle detected:\n");
for active_query in &query_stack[start_index..] {
writeln!(message, "- {:?}\n", active_query.database_key).unwrap();
if error.from == error.to {
// All queries in the cycle is local
let start_index = query_stack
.iter()
.rposition(|active_query| active_query.database_key == *database_key)
.unwrap();
let mut cycle = Vec::new();
let cycle_participants = &mut query_stack[start_index..];
for active_query in &mut *cycle_participants {
cycle.push(active_query.database_key.clone());
}
assert!(!cycle.is_empty());
for active_query in cycle_participants {
active_query.cycle = cycle.clone();
}
crate::CycleError {
cycle,
changed_at,
durability: Durability::MAX,
}
} else {
// Part of the cycle is on another thread so we need to lock and inspect the shared
// state
let dependency_graph = self.shared_state.dependency_graph.lock();
let mut cycle = Vec::new();
{
let cycle_iter = dependency_graph
.get_cycle_path(
database_key,
error.to,
query_stack.iter().map(|query| &query.database_key),
)
.chain(Some(database_key));
for key in cycle_iter {
cycle.push(key.clone());
}
}
assert!(!cycle.is_empty());
for active_query in query_stack
.iter_mut()
.filter(|query| cycle.iter().any(|key| *key == query.database_key))
{
active_query.cycle = cycle.clone();
}
crate::CycleError {
cycle,
changed_at,
durability: Durability::MAX,
}
}
}
pub(crate) fn mark_cycle_participants(&self, err: &CycleError<DB::DatabaseKey>) {
for active_query in self
.local_state
.borrow_query_stack_mut()
.iter_mut()
.rev()
.take_while(|active_query| err.cycle.iter().any(|e| *e == active_query.database_key))
{
active_query.cycle = err.cycle.clone();
}
panic!(message)
}
/// Try to make this runtime blocked on `other_id`. Returns true
/// upon success or false if `other_id` is already blocked on us.
pub(crate) fn try_block_on(&self, database_key: &DB::DatabaseKey, other_id: RuntimeId) -> bool {
self.shared_state
.dependency_graph
.lock()
.add_edge(self.id(), database_key, other_id)
self.shared_state.dependency_graph.lock().add_edge(
self.id(),
database_key,
other_id,
self.local_state
.borrow_query_stack()
.iter()
.map(|query| query.database_key.clone()),
)
}
pub(crate) fn unblock_queries_blocked_on_self(&self, database_key: &DB::DatabaseKey) {
@ -508,7 +580,7 @@ struct SharedState<DB: Database> {
/// The dependency graph tracks which runtimes are blocked on one
/// another, waiting for queries to terminate.
dependency_graph: Mutex<DependencyGraph<DB>>,
dependency_graph: Mutex<DependencyGraph<DB::DatabaseKey>>,
}
impl<DB: Database> SharedState<DB> {
@ -571,6 +643,9 @@ struct ActiveQuery<DB: Database> {
/// Set of subqueries that were accessed thus far, or `None` if
/// there was an untracked the read.
dependencies: Option<FxIndexSet<Dependency<DB>>>,
/// Stores the entire cycle, if one is found and this query is part of it.
cycle: Vec<DB::DatabaseKey>,
}
pub(crate) struct ComputedQueryResult<DB: Database, V> {
@ -587,6 +662,9 @@ pub(crate) struct ComputedQueryResult<DB: Database, V> {
/// Complete set of subqueries that were accessed, or `None` if
/// there was an untracked the read.
pub(crate) dependencies: Option<FxIndexSet<Dependency<DB>>>,
/// The cycle if one occured while computing this value
pub(crate) cycle: Vec<DB::DatabaseKey>,
}
impl<DB: Database> ActiveQuery<DB> {
@ -596,6 +674,7 @@ impl<DB: Database> ActiveQuery<DB> {
durability: max_durability,
changed_at: Revision::start(),
dependencies: Some(FxIndexSet::default()),
cycle: Vec::new(),
}
}
@ -634,16 +713,26 @@ pub(crate) struct StampedValue<V> {
pub(crate) changed_at: Revision,
}
struct DependencyGraph<DB: Database> {
#[derive(Debug)]
struct Edge<K> {
id: RuntimeId,
path: Vec<K>,
}
#[derive(Debug)]
struct DependencyGraph<K: Hash + Eq> {
/// A `(K -> V)` pair in this map indicates that the the runtime
/// `K` is blocked on some query executing in the runtime `V`.
/// This encodes a graph that must be acyclic (or else deadlock
/// will result).
edges: FxHashMap<RuntimeId, RuntimeId>,
labels: FxHashMap<DB::DatabaseKey, SmallVec<[RuntimeId; 4]>>,
edges: FxHashMap<RuntimeId, Edge<K>>,
labels: FxHashMap<K, SmallVec<[RuntimeId; 4]>>,
}
impl<DB: Database> Default for DependencyGraph<DB> {
impl<K> Default for DependencyGraph<K>
where
K: Hash + Eq,
{
fn default() -> Self {
DependencyGraph {
edges: Default::default(),
@ -652,13 +741,17 @@ impl<DB: Database> Default for DependencyGraph<DB> {
}
}
impl<DB: Database> DependencyGraph<DB> {
impl<K> DependencyGraph<K>
where
K: Hash + Eq + Clone,
{
/// Attempt to add an edge `from_id -> to_id` into the result graph.
fn add_edge(
&mut self,
from_id: RuntimeId,
database_key: &DB::DatabaseKey,
database_key: &K,
to_id: RuntimeId,
path: impl IntoIterator<Item = K>,
) -> bool {
assert_ne!(from_id, to_id);
debug_assert!(!self.edges.contains_key(&from_id));
@ -666,7 +759,7 @@ impl<DB: Database> DependencyGraph<DB> {
// First: walk the chain of things that `to_id` depends on,
// looking for us.
let mut p = to_id;
while let Some(&q) = self.edges.get(&p) {
while let Some(q) = self.edges.get(&p).map(|edge| edge.id) {
if q == from_id {
return false;
}
@ -674,7 +767,13 @@ impl<DB: Database> DependencyGraph<DB> {
p = q;
}
self.edges.insert(from_id, to_id);
self.edges.insert(
from_id,
Edge {
id: to_id,
path: path.into_iter().chain(Some(database_key.clone())).collect(),
},
);
self.labels
.entry(database_key.clone())
.or_default()
@ -682,17 +781,54 @@ impl<DB: Database> DependencyGraph<DB> {
true
}
fn remove_edge(&mut self, database_key: &DB::DatabaseKey, to_id: RuntimeId) {
let vec = self
.labels
.remove(database_key)
.unwrap_or_default();
fn remove_edge(&mut self, database_key: &K, to_id: RuntimeId) {
let vec = self.labels.remove(database_key).unwrap_or_default();
for from_id in &vec {
let to_id1 = self.edges.remove(from_id);
let to_id1 = self.edges.remove(from_id).map(|edge| edge.id);
assert_eq!(Some(to_id), to_id1);
}
}
fn get_cycle_path<'a>(
&'a self,
database_key: &'a K,
to: RuntimeId,
local_path: impl IntoIterator<Item = &'a K>,
) -> impl Iterator<Item = &'a K>
where
K: std::fmt::Debug,
{
let mut current = Some((to, std::slice::from_ref(database_key)));
let mut last = None;
let mut local_path = Some(local_path);
std::iter::from_fn(move || match current.take() {
Some((id, path)) => {
let link_key = path.last().unwrap();
current = self.edges.get(&id).map(|edge| {
let i = edge.path.iter().rposition(|p| p == link_key).unwrap();
(edge.id, &edge.path[i + 1..])
});
if current.is_none() {
last = local_path.take().map(|local_path| {
local_path
.into_iter()
.skip_while(move |p| *p != link_key)
.skip(1)
});
}
Some(path)
}
None => match &mut last {
Some(iter) => iter.next().map(std::slice::from_ref),
None => None,
},
})
.flat_map(|x| x)
}
}
struct RevisionGuard<DB: Database> {
@ -742,3 +878,42 @@ where
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dependency_graph_path1() {
let mut graph = DependencyGraph::default();
let a = RuntimeId { counter: 0 };
let b = RuntimeId { counter: 1 };
assert!(graph.add_edge(a, &2, b, vec![1]));
// assert!(graph.add_edge(b, &1, a, vec![3, 2]));
assert_eq!(
graph
.get_cycle_path(&1, a, &[3, 2][..])
.cloned()
.collect::<Vec<i32>>(),
vec![1, 2]
);
}
#[test]
fn dependency_graph_path2() {
let mut graph = DependencyGraph::default();
let a = RuntimeId { counter: 0 };
let b = RuntimeId { counter: 1 };
let c = RuntimeId { counter: 2 };
assert!(graph.add_edge(a, &3, b, vec![1]));
assert!(graph.add_edge(b, &4, c, vec![2, 3]));
// assert!(graph.add_edge(c, &1, a, vec![5, 6, 4, 7]));
assert_eq!(
graph
.get_cycle_path(&1, a, &[5, 6, 4, 7][..])
.cloned()
.collect::<Vec<i32>>(),
vec![1, 3, 4, 7]
);
}
}

View file

@ -3,8 +3,7 @@ use crate::durability::Durability;
use crate::runtime::ActiveQuery;
use crate::runtime::Revision;
use crate::Database;
use std::cell::Ref;
use std::cell::RefCell;
use std::cell::{Ref, RefCell, RefMut};
/// State that is specific to a single execution thread.
///
@ -51,6 +50,10 @@ impl<DB: Database> LocalState<DB> {
self.query_stack.borrow()
}
pub(super) fn borrow_query_stack_mut(&self) -> RefMut<'_, Vec<ActiveQuery<DB>>> {
self.query_stack.borrow_mut()
}
pub(super) fn query_in_progress(&self) -> bool {
!self.query_stack.borrow().is_empty()
}

View file

@ -1,3 +1,10 @@
use salsa::{ParallelDatabase, Snapshot};
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
struct Error {
cycle: Vec<String>,
}
#[salsa::database(GroupStruct)]
#[derive(Default)]
struct DatabaseImpl {
@ -10,6 +17,14 @@ impl salsa::Database for DatabaseImpl {
}
}
impl ParallelDatabase for DatabaseImpl {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(DatabaseImpl {
runtime: self.runtime.snapshot(self),
})
}
}
#[salsa::query_group(GroupStruct)]
trait Database: salsa::Database {
// `a` and `b` depend on each other and form a cycle
@ -17,6 +32,27 @@ trait Database: salsa::Database {
fn memoized_b(&self) -> ();
fn volatile_a(&self) -> ();
fn volatile_b(&self) -> ();
fn cycle_leaf(&self) -> ();
#[salsa::cycle(recover_a)]
fn cycle_a(&self) -> Result<(), Error>;
#[salsa::cycle(recover_b)]
fn cycle_b(&self) -> Result<(), Error>;
fn cycle_c(&self) -> Result<(), Error>;
}
fn recover_a(_db: &impl Database, cycle: &[String]) -> Result<(), Error> {
Err(Error {
cycle: cycle.to_owned(),
})
}
fn recover_b(_db: &impl Database, cycle: &[String]) -> Result<(), Error> {
Err(Error {
cycle: cycle.to_owned(),
})
}
fn memoized_a(db: &impl Database) -> () {
@ -37,6 +73,23 @@ fn volatile_b(db: &impl Database) -> () {
db.volatile_a()
}
fn cycle_leaf(_db: &impl Database) -> () {}
fn cycle_a(db: &impl Database) -> Result<(), Error> {
let _ = db.cycle_b();
Ok(())
}
fn cycle_b(db: &impl Database) -> Result<(), Error> {
db.cycle_leaf();
let _ = db.cycle_a();
Ok(())
}
fn cycle_c(db: &impl Database) -> Result<(), Error> {
db.cycle_b()
}
#[test]
#[should_panic(expected = "cycle detected")]
fn cycle_memoized() {
@ -50,3 +103,64 @@ fn cycle_volatile() {
let query = DatabaseImpl::default();
query.volatile_a();
}
#[test]
fn cycle_cycle() {
let query = DatabaseImpl::default();
assert!(query.cycle_a().is_err());
}
#[test]
fn inner_cycle() {
let query = DatabaseImpl::default();
let err = query.cycle_c();
assert!(err.is_err());
let cycle = err.unwrap_err().cycle;
assert!(
cycle
.iter()
.zip(&["cycle_b", "cycle_a"])
.all(|(l, r)| l.contains(r)),
"{:#?}",
cycle
);
}
#[test]
fn parallel_cycle() {
let _ = env_logger::try_init();
let db = DatabaseImpl::default();
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let result = db.cycle_a();
assert!(result.is_err(), "Expected cycle error");
let cycle = result.unwrap_err().cycle;
assert!(
cycle
.iter()
.all(|l| ["cycle_b", "cycle_a"].iter().any(|r| l.contains(r))),
"{:#?}",
cycle
);
}
});
let thread2 = std::thread::spawn(move || {
let result = db.cycle_c();
assert!(result.is_err(), "Expected cycle error");
let cycle = result.unwrap_err().cycle;
assert!(
cycle
.iter()
.all(|l| ["cycle_b", "cycle_a"].iter().any(|r| l.contains(r))),
"{:#?}",
cycle
);
});
thread1.join().unwrap();
thread2.join().unwrap();
eprintln!("OK");
}

View file

@ -1,6 +1,4 @@
use crate::setup::{
CancelationFlag, Canceled, Knobs, ParDatabase, ParDatabaseImpl, WithValue,
};
use crate::setup::{CancelationFlag, Canceled, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::ParallelDatabase;
macro_rules! assert_canceled {

View file

@ -22,7 +22,7 @@ fn in_par_two_independent_queries() {
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("def")
});;
});
assert_eq!(thread1.join().unwrap(), 111);
assert_eq!(thread2.join().unwrap(), 222);

View file

@ -1,7 +1,6 @@
//! Test `salsa::requires` attribute for private query dependencies
//! https://github.com/salsa-rs/salsa-rfcs/pull/3
mod queries {
#[salsa::query_group(InputGroupStorage)]
pub trait InputGroup {
@ -14,7 +13,7 @@ mod queries {
fn private_a(&self, x: u32) -> u32;
}
fn private_a(db: &impl PrivGroupA, x: u32) -> u32{
fn private_a(db: &impl PrivGroupA, x: u32) -> u32 {
db.input(x)
}
@ -23,7 +22,7 @@ mod queries {
fn private_b(&self, x: u32) -> u32;
}
fn private_b(db: &impl PrivGroupB, x: u32) -> u32{
fn private_b(db: &impl PrivGroupB, x: u32) -> u32 {
db.input(x)
}
@ -34,7 +33,6 @@ mod queries {
fn public(&self, x: u32) -> u32;
}
fn public(db: &(impl PubGroup + PrivGroupA + PrivGroupB), x: u32) -> u32 {
db.private_a(x) + db.private_b(x)
}
@ -44,7 +42,7 @@ mod queries {
queries::InputGroupStorage,
queries::PrivGroupAStorage,
queries::PrivGroupBStorage,
queries::PubGroupStorage,
queries::PubGroupStorage
)]
#[derive(Default)]
struct Database {

View file

@ -17,7 +17,6 @@ fn get(db: &impl QueryGroup, x: u32) -> u32 {
db.wrap(x)
}
#[salsa::database(QueryGroupStorage)]
#[derive(Default)]
struct Database {