rework cycle to permit partial recovery (wip)

This is "wip" because it does not yet handle cross-thread recovery
correctly. That is coming in a later commit.
Niko Matsakis 2021-11-10 21:01:24 -05:00
parent c14a3d47ea
commit 45434cfa93
7 changed files with 131 additions and 173 deletions
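
In outline, the rework drops the dedicated `CycleParticipant` wrapper: a recoverable cycle now unwinds the stack with the `Cycle` value itself as the panic payload (`Cycle::throw`), and the query boundary downcasts it back out (`Cycle::catch`), re-raising any payload that is not a `Cycle`. Below is a self-contained sketch of that mechanism with simplified types (not salsa's own; the real `Cycle` holds an `Arc<Vec<DatabaseKeyIndex>>`):

```rust
use std::panic::{self, AssertUnwindSafe};
use std::sync::Arc;

// Simplified stand-in for salsa's `Cycle`.
#[derive(Debug, Clone)]
struct Cycle {
    participants: Arc<Vec<&'static str>>,
}

impl Cycle {
    // Unwind with the cycle as the panic payload. `resume_unwind`
    // (unlike `panic!`) does not invoke the panic hook, so nothing is
    // printed for a cycle that is about to be recovered.
    fn throw(self) -> ! {
        panic::resume_unwind(Box::new(self))
    }

    // Catch a thrown `Cycle` at the query boundary; any other panic
    // (a genuine bug, cancellation, ...) is propagated unchanged.
    fn catch<T>(execute: impl FnOnce() -> T) -> Result<T, Cycle> {
        match panic::catch_unwind(AssertUnwindSafe(execute)) {
            Ok(v) => Ok(v),
            Err(err) => match err.downcast::<Cycle>() {
                Ok(cycle) => Err(*cycle),
                Err(other) => panic::resume_unwind(other),
            },
        }
    }
}

fn main() {
    let result: Result<(), Cycle> = Cycle::catch(|| {
        // A query that finds itself in a cycle throws...
        Cycle {
            participants: Arc::new(vec!["cycle_a", "cycle_b"]),
        }
        .throw();
    });
    // ...and the caller gets it back as an ordinary `Err` value.
    assert_eq!(result.unwrap_err().participants.len(), 2);
}
```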


@@ -5,7 +5,6 @@ use crate::lru::LruIndex;
 use crate::lru::LruNode;
 use crate::plumbing::{DatabaseOps, QueryFunction};
 use crate::revision::Revision;
-use crate::runtime::cycle_participant::CycleParticipant;
 use crate::runtime::local_state::ActiveQueryGuard;
 use crate::runtime::local_state::QueryInputs;
 use crate::runtime::local_state::QueryRevisions;
@@ -13,6 +12,7 @@ use crate::runtime::Runtime;
 use crate::runtime::RuntimeId;
 use crate::runtime::StampedValue;
 use crate::runtime::WaitResult;
+use crate::Cycle;
 use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb};
 use log::{debug, info};
 use parking_lot::{RawRwLock, RwLock};
@@ -239,12 +239,29 @@ where
         // Query was not previously executed, or value is potentially
         // stale, or value is absent. Let's execute!
-        let value = CycleParticipant::recover(
-            || Q::execute(db, self.key.clone()),
-            // If a recoverable cycle occurs, `Q::execute` will throw
-            // and this closure will be executed with the cycle information.
-            |cycle| Q::cycle_fallback(db, &cycle, &self.key),
-        );
+        let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) {
+            Ok(v) => v,
+            Err(cycle) => {
+                log::debug!(
+                    "{:?}: caught cycle {:?}, have strategy {:?}",
+                    self.database_key_index.debug(db),
+                    cycle,
+                    Q::CYCLE_STRATEGY,
+                );
+                match Q::CYCLE_STRATEGY {
+                    crate::plumbing::CycleRecoveryStrategy::Panic => {
+                        panic_guard.proceed(None);
+                        cycle.throw()
+                    }
+                    crate::plumbing::CycleRecoveryStrategy::Fallback => {
+                        if let Some(c) = active_query.take_cycle() {
+                            assert!(c.is(&cycle));
+                        }
+                        Q::cycle_fallback(db, &cycle, &self.key)
+                    }
+                }
+            }
+        };
         let mut revisions = active_query.pop();
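
The hunk above is the consumer side of that mechanism: `Cycle::catch` turns the in-flight unwinding back into a `Result`, and the query's cycle-recovery strategy then decides between re-throwing (`Panic`) and substituting a recovery value (`Fallback`). A minimal self-contained sketch of just this dispatch, with illustrative names (`run_query`, `execute`, and `fallback` stand in for the slot's execute path, `Q::execute`, and `Q::cycle_fallback`):

```rust
use std::panic::{self, AssertUnwindSafe};

#[derive(Debug)]
struct Cycle; // stub: the real type carries the participant list

impl Cycle {
    fn throw(self) -> ! {
        panic::resume_unwind(Box::new(self))
    }
    fn catch<T>(execute: impl FnOnce() -> T) -> Result<T, Cycle> {
        match panic::catch_unwind(AssertUnwindSafe(execute)) {
            Ok(v) => Ok(v),
            Err(err) => match err.downcast::<Cycle>() {
                Ok(cycle) => Err(*cycle),
                Err(other) => panic::resume_unwind(other),
            },
        }
    }
}

#[derive(Clone, Copy)]
enum CycleRecoveryStrategy {
    Panic,
    Fallback,
}

fn run_query<T>(
    strategy: CycleRecoveryStrategy,
    execute: impl FnOnce() -> T,
    fallback: impl FnOnce(&Cycle) -> T,
) -> T {
    match Cycle::catch(execute) {
        Ok(v) => v,
        Err(cycle) => match strategy {
            // No recovery here: re-throw so an enclosing query (or
            // ultimately the user) observes the cycle.
            CycleRecoveryStrategy::Panic => cycle.throw(),
            // Recover: substitute the fallback value and continue as
            // if the query had computed it.
            CycleRecoveryStrategy::Fallback => fallback(&cycle),
        },
    }
}

fn main() {
    let v = run_query(
        CycleRecoveryStrategy::Fallback,
        || Cycle.throw(), // the query hits a cycle
        |_cycle| -1,      // a fallback value stands in for the result
    );
    assert_eq!(v, -1);
}
```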


@@ -35,6 +35,7 @@ use crate::plumbing::QueryStorageOps;
 pub use crate::revision::Revision;
 use std::fmt::{self, Debug};
 use std::hash::Hash;
+use std::panic::AssertUnwindSafe;
 use std::panic::{self, UnwindSafe};
 use std::sync::Arc;
@@ -645,7 +646,7 @@ impl std::fmt::Display for Cancelled {
 impl std::error::Error for Cancelled {}
 
-/// Captuers the participants of a cycle that occurred when executing a query.
+/// Captures the participants of a cycle that occurred when executing a query.
 ///
 /// This type is meant to be used to help give meaningful error messages to the
 /// user or to help salsa developers figure out why their program is resulting
@@ -670,10 +671,26 @@ impl Cycle {
         Self { participants }
     }
 
+    /// True if two `Cycle` values represent the same cycle.
+    pub(crate) fn is(&self, cycle: &Cycle) -> bool {
+        Arc::ptr_eq(&self.participants, &cycle.participants)
+    }
+
+    pub(crate) fn throw(self) -> ! {
+        log::debug!("throwing cycle {:?}", self);
+        std::panic::resume_unwind(Box::new(self))
+    }
+
+    pub(crate) fn catch<T>(execute: impl FnOnce() -> T) -> Result<T, Cycle> {
+        match std::panic::catch_unwind(AssertUnwindSafe(execute)) {
+            Ok(v) => Ok(v),
+            Err(err) => match err.downcast::<Cycle>() {
+                Ok(cycle) => Err(*cycle),
+                Err(other) => std::panic::resume_unwind(other),
+            },
+        }
+    }
+
     /// Iterate over the [`DatabaseKeyIndex`] for each query participating
     /// in the cycle. The start point of this iteration within the cycle
     /// is arbitrary but deterministic, but the ordering is otherwise determined


@@ -13,15 +13,12 @@ use std::sync::Arc;
 pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
 pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>;
 
-pub(crate) mod cycle_participant;
 mod dependency_graph;
 use dependency_graph::DependencyGraph;
 
 pub(crate) mod local_state;
 use local_state::LocalState;
 
-use self::cycle_participant::CycleParticipant;
 use self::local_state::{ActiveQueryGuard, QueryInputs, QueryRevisions};
@@ -269,13 +266,13 @@ impl Runtime {
             .report_synthetic_read(durability, changed_at);
     }
 
-    fn create_cycle_error(
+    fn throw_cycle_error(
         &self,
         db: &dyn Database,
         mut dg: MutexGuard<'_, DependencyGraph>,
         database_key_index: DatabaseKeyIndex,
         to_id: RuntimeId,
-    ) -> (CycleRecoveryStrategy, Cycle) {
+    ) -> ! {
         debug!("create_cycle_error(database_key={:?})", database_key_index);
 
         let mut from_stack = self.local_state.take_query_stack();
@@ -319,80 +316,32 @@ impl Runtime {
             cycle_query,
         );
 
-        // Identify cycle recovery strategy:
-        let recovery_strategy = self.mutual_cycle_recovery_strategy(db, &cycle);
-        debug!("cycle recovery strategy {:?}", recovery_strategy);
-
-        // If using fallback, we have to mark the cycle participants, so they know to recover.
-        // In this case, we also want to modify their changed-at and durability values to
-        // be the max/min we computed across the entire cycle.
-        match recovery_strategy {
-            CycleRecoveryStrategy::Panic => {}
-            CycleRecoveryStrategy::Fallback => {
-                // We don't need to include the cycle participants in the input.
-                // Everyone depends on all the external dependencies.
-                cycle_query.remove_cycle_participants(&cycle);
-
-                // Mark the cycle participants so they know to recover.
-                // This only matters for queries that have a fallback value specified;
-                // the others will just unwind without storing any recovery information.
-                dg.for_each_cycle_participant(from_id, &mut from_stack, database_key_index, to_id, |aq| {
-                    match db.cycle_recovery_strategy(aq.database_key_index) {
-                        CycleRecoveryStrategy::Fallback => {
-                            debug!("marking {:?} for fallback", aq.database_key_index.debug(db));
-                            aq.take_inputs_from(&cycle_query);
-                            aq.cycle = Some(cycle.clone());
-                        }
-                        CycleRecoveryStrategy::Panic => {
-                            // NB: Don't mark these frames!
-                        }
-                    }
-                });
-
-                // The top of the current stack is a bit of a special case:
-                //
-                //     C0 --> ... --> Cn-1 -> Cn --> C0
-                //                                   ^
-                //                                   :
-                //     This edge --------------------+
-                //
-                // Each query Ci in C0..=Cn-1 will recover the "usual" way:
-                // the query Ci+1 will recover from [`CycleParticipant::unwind`]
-                // and store a recovery value in the table. This value will
-                // be returned to Ci, which will invoke
-                // `report_query_read_and_unwind_if_cycle_resulted` to record
-                // the dependency Ci -> Ci+1. This method will see the `cycle` flag
-                // set on the `ActiveQuery` for `Ci` and will unwind.
-                //
-                // However, the cyclic edge `Cn -> C0` is a bit different:
-                // since C0 has not recovered yet, we don't have an easy value
-                // to return and propagate upwards. Instead, we just add the dependency
-                // *here* and then (in our caller) unwind with `CycleParticipant`.
-                from_stack.last_mut().unwrap().cycle = None;
-            }
-        }
+        // We can remove the cycle participants from the list of dependencies;
+        // they are a strongly connected component (SCC) and we only care about
+        // dependencies to things outside the SCC that control whether it will
+        // form again.
+        cycle_query.remove_cycle_participants(&cycle);
+
+        // Mark the cycle participants, so they know to recover:
+        dg.for_each_cycle_participant(
+            from_id,
+            &mut from_stack,
+            database_key_index,
+            to_id,
+            |aq| {
+                aq.take_inputs_from(&cycle_query);
+                aq.cycle = Some(cycle.clone());
+            },
+        );
 
         self.local_state.restore_query_stack(from_stack);
 
-        (recovery_strategy, cycle)
+        cycle.throw()
     }
-
-    fn mutual_cycle_recovery_strategy(
-        &self,
-        db: &dyn Database,
-        cycle: &Cycle,
-    ) -> CycleRecoveryStrategy {
-        let participants = &cycle.participants;
-        let crs = db.cycle_recovery_strategy(participants[0]);
-        if let Some(key) = participants[1..]
-            .iter()
-            .copied()
-            .find(|&key| db.cycle_recovery_strategy(key) != crs)
-        {
-            debug!("mutual_cycle_recovery_strategy: cycle had multiple strategies ({:?} for {:?} vs {:?} for {:?})",
-                crs, participants[0],
-                db.cycle_recovery_strategy(key), key
-            );
-            CycleRecoveryStrategy::Panic
-        } else {
-            crs
-        }
-    }
 
     /// Block until `other_id` completes executing `database_key`;
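
A note on the dependency bookkeeping in the hunk above: the cycle participants form a strongly connected component (SCC), so recording them as one another's inputs carries no information; only edges that leave the SCC determine whether it can form again, which is why `remove_cycle_participants` is now applied unconditionally rather than only on the fallback path. A toy sketch of the pruning idea (illustrative types; salsa's real method operates on the accumulated `ActiveQuery` inputs, not a plain set):

```rust
use std::collections::BTreeSet;

// Hypothetical stand-in for `cycle_query.remove_cycle_participants`:
// drop intra-SCC edges from a query's recorded inputs.
fn remove_cycle_participants(inputs: &mut BTreeSet<&'static str>, cycle: &[&'static str]) {
    for participant in cycle {
        inputs.remove(participant);
    }
}

fn main() {
    // cycle_a and cycle_b call each other; both also read input_x.
    let mut inputs = BTreeSet::from(["cycle_a", "cycle_b", "input_x"]);
    remove_cycle_participants(&mut inputs, &["cycle_a", "cycle_b"]);
    // Only the edge leaving the SCC survives: if input_x changes, the
    // component re-executes, and the cycle may or may not form again.
    assert_eq!(inputs, BTreeSet::from(["input_x"]));
}
```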
@@ -427,40 +376,37 @@ impl Runtime {
         let mut dg = self.shared_state.dependency_graph.lock();
 
         if self.id() == other_id || dg.depends_on(other_id, self.id()) {
-            match self.create_cycle_error(db, dg, database_key, other_id) {
-                (CycleRecoveryStrategy::Panic, cycle) => cycle.throw(),
-                (CycleRecoveryStrategy::Fallback, cycle) => CycleParticipant::new(cycle).unwind(),
-            }
-        } else {
-            db.salsa_event(Event {
-                runtime_id: self.id(),
-                kind: EventKind::WillBlockOn {
-                    other_runtime_id: other_id,
-                    database_key,
-                },
-            });
-
-            let stack = self.local_state.take_query_stack();
-
-            let (stack, result) = DependencyGraph::block_on(
-                dg,
-                self.id(),
-                database_key,
-                other_id,
-                stack,
-                query_mutex_guard,
-            );
-
-            self.local_state.restore_query_stack(stack);
-
-            match result {
-                WaitResult::Completed => (),
-
-                // If the other thread panicked, then we consider this thread
-                // cancelled. The assumption is that the panic will be detected
-                // by the other thread and responded to appropriately.
-                WaitResult::Panicked => Cancelled::PropagatedPanic.throw(),
-            }
+            self.throw_cycle_error(db, dg, database_key, other_id)
         }
+
+        db.salsa_event(Event {
+            runtime_id: self.id(),
+            kind: EventKind::WillBlockOn {
+                other_runtime_id: other_id,
+                database_key,
+            },
+        });
+
+        let stack = self.local_state.take_query_stack();
+
+        let (stack, result) = DependencyGraph::block_on(
+            dg,
+            self.id(),
+            database_key,
+            other_id,
+            stack,
+            query_mutex_guard,
+        );
+
+        self.local_state.restore_query_stack(stack);
+
+        match result {
+            WaitResult::Completed => (),
+
+            // If the other thread panicked, then we consider this thread
+            // cancelled. The assumption is that the panic will be detected
+            // by the other thread and responded to appropriately.
+            WaitResult::Panicked => Cancelled::PropagatedPanic.throw(),
+        }
     }
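
The `depends_on` check at the top of this hunk is what turns a would-be cross-thread deadlock into a cycle error: before this runtime blocks on `other_id`, it verifies that `other_id` is not already (transitively) blocked on us. A toy model of that reachability walk, assuming (as in salsa's dependency graph) that a blocked runtime waits on exactly one other runtime, so the edges form a chain:

```rust
use std::collections::HashMap;

type RuntimeId = u32;

// `edges[r]` is the runtime that `r` is currently blocked on.
fn depends_on(edges: &HashMap<RuntimeId, RuntimeId>, from: RuntimeId, on: RuntimeId) -> bool {
    let mut current = from;
    while let Some(&next) = edges.get(&current) {
        if next == on {
            return true;
        }
        current = next;
    }
    false
}

fn main() {
    let mut edges = HashMap::new();
    edges.insert(2, 1); // runtime 2 is already blocked on runtime 1

    // Runtime 1 now wants to block on runtime 2. Since 2 depends on 1,
    // blocking would deadlock, so this is reported as a cycle instead
    // (the `throw_cycle_error` path above).
    assert!(depends_on(&edges, 2, 1));
}
```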


@@ -1,29 +0,0 @@
-use std::panic::AssertUnwindSafe;
-
-use crate::Cycle;
-
-pub(crate) struct CycleParticipant {
-    cycle: Cycle,
-}
-
-impl CycleParticipant {
-    pub(crate) fn new(cycle: Cycle) -> Self {
-        Self { cycle }
-    }
-
-    /// Initiate unwinding. This is called `unwind` and not `throw` or `panic`
-    /// because every call to `unwind` here ought to be caught by a
-    /// matching call to [`recover`].
-    pub(crate) fn unwind(self) -> ! {
-        std::panic::resume_unwind(Box::new(self));
-    }
-
-    pub(crate) fn recover<T>(execute: impl FnOnce() -> T, recover: impl FnOnce(Cycle) -> T) -> T {
-        std::panic::catch_unwind(AssertUnwindSafe(execute)).unwrap_or_else(|err| {
-            match err.downcast::<CycleParticipant>() {
-                Ok(participant) => recover(participant.cycle),
-                Err(v) => std::panic::resume_unwind(v),
-            }
-        })
-    }
-}


@@ -1,12 +1,13 @@
+use log::debug;
+
 use crate::durability::Durability;
 use crate::runtime::ActiveQuery;
 use crate::runtime::Revision;
+use crate::Cycle;
 use crate::DatabaseKeyIndex;
 use std::cell::RefCell;
 use std::sync::Arc;
 
-use super::cycle_participant::CycleParticipant;
-
 /// State that is specific to a single execution thread.
 ///
 /// Internally, this type uses ref-cells.
@@ -97,6 +98,10 @@ impl LocalState {
         durability: Durability,
         changed_at: Revision,
     ) {
+        debug!(
+            "report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})",
+            input, durability, changed_at
+        );
         self.with_query_stack(|stack| {
             if let Some(top_query) = stack.last_mut() {
                 top_query.add_read(input, durability, changed_at);
@@ -123,7 +128,7 @@ impl LocalState {
                     // stack frames, so they will just read the fallback value
                     // from `Ci+1` and continue on their merry way.
                     if let Some(cycle) = top_query.cycle.take() {
-                        CycleParticipant::new(cycle).unwind()
+                        cycle.throw()
                     }
                 }
             })
@@ -211,6 +216,13 @@ impl ActiveQueryGuard<'_> {
         popped_query.revisions()
     }
 
+    /// If the active query is registered as a cycle participant, remove and
+    /// return that cycle.
+    pub(crate) fn take_cycle(&self) -> Option<Cycle> {
+        self.local_state
+            .with_query_stack(|stack| stack.last_mut()?.cycle.take())
+    }
 }
 
 impl Drop for ActiveQueryGuard<'_> {
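
`take_cycle` is the hand-off point between the two halves of this commit: `throw_cycle_error` (in the runtime, above) stamps the pending `Cycle` onto the active query frame, and the slot's execute path (first hunk) takes it back off when choosing how to recover. A toy model of that stamp-and-take flow (illustrative types; the real guard goes through `LocalState::with_query_stack`):

```rust
use std::cell::RefCell;

#[derive(Debug, PartialEq)]
struct Cycle(&'static str); // stub for salsa's `Cycle`

struct ActiveQuery {
    cycle: Option<Cycle>, // set by the runtime when a cycle is found
}

struct LocalState {
    stack: RefCell<Vec<ActiveQuery>>,
}

impl LocalState {
    // Remove and return the cycle stamped onto the active query, if any.
    fn take_cycle(&self) -> Option<Cycle> {
        self.stack.borrow_mut().last_mut()?.cycle.take()
    }
}

fn main() {
    let state = LocalState {
        stack: RefCell::new(vec![ActiveQuery {
            cycle: Some(Cycle("a -> b -> a")),
        }]),
    };
    // The first call removes the pending cycle; a second call sees none,
    // matching the `if let Some(c) = active_query.take_cycle()` check in
    // the first hunk.
    assert_eq!(state.take_cycle(), Some(Cycle("a -> b -> a")));
    assert_eq!(state.take_cycle(), None);
}
```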


@@ -172,6 +172,7 @@ fn cycle_c(db: &dyn Database) -> Result<(), Error> {
     db.c_invokes().invoke(db)
 }
 
+#[track_caller]
 fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle {
     let v = std::panic::catch_unwind(f);
     if let Err(d) = &v {
@@ -352,20 +353,17 @@ fn cycle_mixed_1() {
     db.set_b_invokes(CycleQuery::C);
     db.set_c_invokes(CycleQuery::B);
 
-    let u = extract_cycle(|| {
-        let _ = db.cycle_a();
-    });
-    insta::assert_debug_snapshot!((u.all_participants(&db), u.unexpected_participants(&db)), @r###"
-    (
-        [
-            "cycle_b(())",
-            "cycle_c(())",
-        ],
-        [
-            "cycle_c(())",
-        ],
-    )
-    "###);
+    let u = db.cycle_c();
+    insta::assert_debug_snapshot!(u, @r###"
+    Err(
+        Error {
+            cycle: [
+                "cycle_b(())",
+                "cycle_c(())",
+            ],
+        },
+    )
+    "###);
 }
 
 #[test]
@@ -381,21 +379,18 @@ fn cycle_mixed_2() {
     db.set_b_invokes(CycleQuery::C);
     db.set_c_invokes(CycleQuery::A);
 
-    let u = extract_cycle(|| {
-        let _ = db.cycle_a();
-    });
-    insta::assert_debug_snapshot!((u.all_participants(&db), u.unexpected_participants(&db)), @r###"
-    (
-        [
-            "cycle_a(())",
-            "cycle_b(())",
-            "cycle_c(())",
-        ],
-        [
-            "cycle_c(())",
-        ],
-    )
-    "###);
+    let u = db.cycle_a();
+    insta::assert_debug_snapshot!(u, @r###"
+    Err(
+        Error {
+            cycle: [
+                "cycle_a(())",
+                "cycle_b(())",
+                "cycle_c(())",
+            ],
+        },
+    )
+    "###);
 }
 
 #[test]


@@ -3,7 +3,7 @@
 //! both intra and cross thread.
 
 use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl};
-use salsa::{Cancelled, ParallelDatabase};
+use salsa::ParallelDatabase;
 use test_env_log::test;
 
 // Recover cycle test:
@@ -164,6 +164,6 @@ fn panic_parallel_cycle() {
     assert!(thread_a
         .join()
         .unwrap_err()
-        .downcast_ref::<Cancelled>()
+        .downcast_ref::<salsa::Cycle>()
         .is_some());
 }