mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-12 16:35:21 +00:00
Merge pull request #169 from matklad/strong-panic-safety
implement strong panic safety
This commit is contained in:
commit
8d0f5ffd95
2 changed files with 49 additions and 25 deletions
|
@ -319,7 +319,7 @@ where
|
|||
// FIXME(Amanieu/parking_lot#101) -- we are using a write-lock
|
||||
// and not an upgradable read here because upgradable reads
|
||||
// can sometimes encounter deadlocks.
|
||||
let mut old_memo = match self.probe(
|
||||
let old_memo = match self.probe(
|
||||
db,
|
||||
self.map.write(),
|
||||
runtime,
|
||||
|
@ -337,13 +337,13 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
let panic_guard = PanicGuard::new(&self.map, key, database_key, runtime);
|
||||
let mut panic_guard = PanicGuard::new(&self.map, key, old_memo, database_key, runtime);
|
||||
|
||||
// If we have an old-value, it *may* now be stale, since there
|
||||
// has been a new revision since the last time we checked. So,
|
||||
// first things first, let's walk over each of our previous
|
||||
// inputs and check whether they are out of date.
|
||||
if let Some(memo) = &mut old_memo {
|
||||
if let Some(memo) = &mut panic_guard.memo {
|
||||
if let Some(value) = memo.validate_memoized_value(db, revision_now) {
|
||||
info!(
|
||||
"{:?}({:?}): validated old memoized value",
|
||||
|
@ -358,7 +358,7 @@ where
|
|||
},
|
||||
});
|
||||
|
||||
panic_guard.proceed(old_memo.unwrap(), &value);
|
||||
panic_guard.proceed(&value);
|
||||
|
||||
return Ok(value);
|
||||
}
|
||||
|
@ -389,7 +389,7 @@ where
|
|||
// really change, even if some of its inputs have. So we can
|
||||
// "backdate" its `changed_at` revision to be the same as the
|
||||
// old value.
|
||||
if let Some(old_memo) = &old_memo {
|
||||
if let Some(old_memo) = &panic_guard.memo {
|
||||
if let Some(old_value) = &old_memo.value {
|
||||
if MP::memoized_value_eq(&old_value, &result.value) {
|
||||
debug!(
|
||||
|
@ -445,16 +445,14 @@ where
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
panic_guard.proceed(
|
||||
Memo {
|
||||
panic_guard.memo = Some(Memo {
|
||||
value,
|
||||
changed_at: result.changed_at.revision,
|
||||
verified_at: revision_now,
|
||||
inputs,
|
||||
},
|
||||
&new_value,
|
||||
);
|
||||
});
|
||||
|
||||
panic_guard.proceed(&new_value);
|
||||
|
||||
Ok(new_value)
|
||||
}
|
||||
|
@ -589,6 +587,7 @@ where
|
|||
{
|
||||
database_key: &'db DB::DatabaseKey,
|
||||
key: &'db Q::Key,
|
||||
memo: Option<Memo<DB, Q>>,
|
||||
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
||||
runtime: &'db Runtime<DB>,
|
||||
}
|
||||
|
@ -601,12 +600,14 @@ where
|
|||
fn new(
|
||||
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
||||
key: &'db Q::Key,
|
||||
memo: Option<Memo<DB, Q>>,
|
||||
database_key: &'db DB::DatabaseKey,
|
||||
runtime: &'db Runtime<DB>,
|
||||
) -> Self {
|
||||
Self {
|
||||
database_key,
|
||||
key,
|
||||
memo,
|
||||
map,
|
||||
runtime,
|
||||
}
|
||||
|
@ -615,22 +616,18 @@ where
|
|||
/// Proceed with our panic guard by overwriting the placeholder for `key`.
|
||||
/// Once that completes, ensure that our deconstructor is not run once we
|
||||
/// are out of scope.
|
||||
fn proceed(self, memo: Memo<DB, Q>, new_value: &StampedValue<Q::Value>) {
|
||||
self.overwrite_placeholder(Some(memo), Some(new_value));
|
||||
fn proceed(mut self, new_value: &StampedValue<Q::Value>) {
|
||||
self.overwrite_placeholder(Some(new_value));
|
||||
std::mem::forget(self)
|
||||
}
|
||||
|
||||
/// Overwrites the `InProgress` placeholder for `key` that we
|
||||
/// inserted; if others were blocked, waiting for us to finish,
|
||||
/// then notify them.
|
||||
fn overwrite_placeholder(
|
||||
&self,
|
||||
memo: Option<Memo<DB, Q>>,
|
||||
new_value: Option<&StampedValue<Q::Value>>,
|
||||
) {
|
||||
fn overwrite_placeholder(&mut self, new_value: Option<&StampedValue<Q::Value>>) {
|
||||
let mut write = self.map.write();
|
||||
|
||||
let old_value = match memo {
|
||||
let old_value = match self.memo.take() {
|
||||
// Replace the `InProgress` marker that we installed with the new
|
||||
// memo, thus releasing our unique access to this key.
|
||||
Some(memo) => write.insert(self.key.clone(), QueryState::Memoized(memo)),
|
||||
|
@ -682,7 +679,7 @@ where
|
|||
fn drop(&mut self) {
|
||||
if std::thread::panicking() {
|
||||
// We panicked before we could proceed and need to remove `key`.
|
||||
self.overwrite_placeholder(None, None)
|
||||
self.overwrite_placeholder(None)
|
||||
} else {
|
||||
// If no panic occurred, then panic guard ought to be
|
||||
// "forgotten" and so this Drop code should never run.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use salsa::{Database, ParallelDatabase, Snapshot};
|
||||
use std::panic::{self, AssertUnwindSafe};
|
||||
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};
|
||||
|
||||
#[salsa::query_group(PanicSafelyStruct)]
|
||||
trait PanicSafelyDatabase: salsa::Database {
|
||||
|
@ -7,12 +8,21 @@ trait PanicSafelyDatabase: salsa::Database {
|
|||
fn one(&self) -> usize;
|
||||
|
||||
fn panic_safely(&self) -> ();
|
||||
|
||||
fn outer(&self) -> ();
|
||||
}
|
||||
|
||||
fn panic_safely(db: &impl PanicSafelyDatabase) -> () {
|
||||
assert_eq!(db.one(), 1);
|
||||
}
|
||||
|
||||
static OUTER_CALLS: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
fn outer(db: &impl PanicSafelyDatabase) -> () {
|
||||
OUTER_CALLS.fetch_add(1, SeqCst);
|
||||
db.panic_safely();
|
||||
}
|
||||
|
||||
#[salsa::database(PanicSafelyStruct)]
|
||||
#[derive(Default)]
|
||||
struct DatabaseStruct {
|
||||
|
@ -36,9 +46,10 @@ impl salsa::ParallelDatabase for DatabaseStruct {
|
|||
#[test]
|
||||
fn should_panic_safely() {
|
||||
let mut db = DatabaseStruct::default();
|
||||
db.set_one(0);
|
||||
|
||||
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
|
||||
// default to 0 and we should catch the panic.
|
||||
// return 0 and we should catch the panic.
|
||||
let result = panic::catch_unwind(AssertUnwindSafe({
|
||||
let db = db.snapshot();
|
||||
move || db.panic_safely()
|
||||
|
@ -48,7 +59,23 @@ fn should_panic_safely() {
|
|||
// Set `db.one` to 1 and assert ok
|
||||
db.set_one(1);
|
||||
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||
assert!(result.is_ok())
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Check, that memoized outer is not invalidated by a panic
|
||||
{
|
||||
assert_eq!(OUTER_CALLS.load(SeqCst), 0);
|
||||
db.outer();
|
||||
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
|
||||
|
||||
db.set_one(0);
|
||||
let result = panic::catch_unwind(AssertUnwindSafe(|| db.outer()));
|
||||
assert!(result.is_err());
|
||||
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
|
||||
|
||||
db.set_one(1);
|
||||
db.outer();
|
||||
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
Loading…
Reference in a new issue