Merge pull request #169 from matklad/strong-panic-safety

implement strong panic safety
This commit is contained in:
Niko Matsakis 2019-06-06 10:57:53 -04:00 committed by GitHub
commit 8d0f5ffd95
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 25 deletions

View file

@ -319,7 +319,7 @@ where
// FIXME(Amanieu/parking_lot#101) -- we are using a write-lock
// and not an upgradable read here because upgradable reads
// can sometimes encounter deadlocks.
let mut old_memo = match self.probe(
let old_memo = match self.probe(
db,
self.map.write(),
runtime,
@ -337,13 +337,13 @@ where
}
};
let panic_guard = PanicGuard::new(&self.map, key, database_key, runtime);
let mut panic_guard = PanicGuard::new(&self.map, key, old_memo, database_key, runtime);
// If we have an old-value, it *may* now be stale, since there
// has been a new revision since the last time we checked. So,
// first things first, let's walk over each of our previous
// inputs and check whether they are out of date.
if let Some(memo) = &mut old_memo {
if let Some(memo) = &mut panic_guard.memo {
if let Some(value) = memo.validate_memoized_value(db, revision_now) {
info!(
"{:?}({:?}): validated old memoized value",
@ -358,7 +358,7 @@ where
},
});
panic_guard.proceed(old_memo.unwrap(), &value);
panic_guard.proceed(&value);
return Ok(value);
}
@ -389,7 +389,7 @@ where
// really change, even if some of its inputs have. So we can
// "backdate" its `changed_at` revision to be the same as the
// old value.
if let Some(old_memo) = &old_memo {
if let Some(old_memo) = &panic_guard.memo {
if let Some(old_value) = &old_memo.value {
if MP::memoized_value_eq(&old_value, &result.value) {
debug!(
@ -445,16 +445,14 @@ where
}
}
};
panic_guard.memo = Some(Memo {
value,
changed_at: result.changed_at.revision,
verified_at: revision_now,
inputs,
});
panic_guard.proceed(
Memo {
value,
changed_at: result.changed_at.revision,
verified_at: revision_now,
inputs,
},
&new_value,
);
panic_guard.proceed(&new_value);
Ok(new_value)
}
@ -589,6 +587,7 @@ where
{
database_key: &'db DB::DatabaseKey,
key: &'db Q::Key,
memo: Option<Memo<DB, Q>>,
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
runtime: &'db Runtime<DB>,
}
@ -601,12 +600,14 @@ where
fn new(
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
key: &'db Q::Key,
memo: Option<Memo<DB, Q>>,
database_key: &'db DB::DatabaseKey,
runtime: &'db Runtime<DB>,
) -> Self {
Self {
database_key,
key,
memo,
map,
runtime,
}
@ -615,22 +616,18 @@ where
/// Proceed with our panic guard by overwriting the placeholder for `key`.
/// Once that completes, ensure that our deconstructor is not run once we
/// are out of scope.
fn proceed(self, memo: Memo<DB, Q>, new_value: &StampedValue<Q::Value>) {
self.overwrite_placeholder(Some(memo), Some(new_value));
fn proceed(mut self, new_value: &StampedValue<Q::Value>) {
self.overwrite_placeholder(Some(new_value));
std::mem::forget(self)
}
/// Overwrites the `InProgress` placeholder for `key` that we
/// inserted; if others were blocked, waiting for us to finish,
/// then notify them.
fn overwrite_placeholder(
&self,
memo: Option<Memo<DB, Q>>,
new_value: Option<&StampedValue<Q::Value>>,
) {
fn overwrite_placeholder(&mut self, new_value: Option<&StampedValue<Q::Value>>) {
let mut write = self.map.write();
let old_value = match memo {
let old_value = match self.memo.take() {
// Replace the `InProgress` marker that we installed with the new
// memo, thus releasing our unique access to this key.
Some(memo) => write.insert(self.key.clone(), QueryState::Memoized(memo)),
@ -682,7 +679,7 @@ where
fn drop(&mut self) {
if std::thread::panicking() {
// We panicked before we could proceed and need to remove `key`.
self.overwrite_placeholder(None, None)
self.overwrite_placeholder(None)
} else {
// If no panic occurred, then panic guard ought to be
// "forgotten" and so this Drop code should never run.

View file

@ -1,5 +1,6 @@
use salsa::{Database, ParallelDatabase, Snapshot};
use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};
#[salsa::query_group(PanicSafelyStruct)]
trait PanicSafelyDatabase: salsa::Database {
@ -7,12 +8,21 @@ trait PanicSafelyDatabase: salsa::Database {
fn one(&self) -> usize;
fn panic_safely(&self) -> ();
fn outer(&self) -> ();
}
fn panic_safely(db: &impl PanicSafelyDatabase) -> () {
assert_eq!(db.one(), 1);
}
static OUTER_CALLS: AtomicU32 = AtomicU32::new(0);
fn outer(db: &impl PanicSafelyDatabase) -> () {
OUTER_CALLS.fetch_add(1, SeqCst);
db.panic_safely();
}
#[salsa::database(PanicSafelyStruct)]
#[derive(Default)]
struct DatabaseStruct {
@ -36,9 +46,10 @@ impl salsa::ParallelDatabase for DatabaseStruct {
#[test]
fn should_panic_safely() {
let mut db = DatabaseStruct::default();
db.set_one(0);
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
// default to 0 and we should catch the panic.
// return 0 and we should catch the panic.
let result = panic::catch_unwind(AssertUnwindSafe({
let db = db.snapshot();
move || db.panic_safely()
@ -48,7 +59,23 @@ fn should_panic_safely() {
// Set `db.one` to 1 and assert ok
db.set_one(1);
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_ok())
assert!(result.is_ok());
// Check, that memoized outer is not invalidated by a panic
{
assert_eq!(OUTER_CALLS.load(SeqCst), 0);
db.outer();
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
db.set_one(0);
let result = panic::catch_unwind(AssertUnwindSafe(|| db.outer()));
assert!(result.is_err());
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
db.set_one(1);
db.outer();
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
}
}
#[test]