Improve Canceled API

This commit is contained in:
Jonas Schievink 2021-05-18 14:41:45 +02:00
parent 197b01fa4b
commit 49a1184bcf
3 changed files with 23 additions and 22 deletions

View file

@ -647,13 +647,28 @@ where
/// A panic payload indicating that a salsa revision was canceled. /// A panic payload indicating that a salsa revision was canceled.
#[derive(Debug)] #[derive(Debug)]
pub struct Canceled { #[non_exhaustive]
_private: (), pub struct Canceled;
}
impl Canceled { impl Canceled {
fn throw() -> ! { fn throw() -> ! {
std::panic::resume_unwind(Box::new(Self { _private: () })); // We use resume and not panic here to avoid running the panic
// hook (that is, to avoid collecting and printing backtrace).
std::panic::resume_unwind(Box::new(Self));
}
/// Runs `f`, and catches any salsa cancellation.
pub fn catch<F, T>(f: F) -> Result<T, Canceled>
where
F: FnOnce() -> T + UnwindSafe,
{
match panic::catch_unwind(f) {
Ok(t) => Ok(t),
Err(payload) => match payload.downcast() {
Ok(canceled) => Err(*canceled),
Err(payload) => panic::resume_unwind(payload),
},
}
} }
} }
@ -665,20 +680,6 @@ impl std::fmt::Display for Canceled {
impl std::error::Error for Canceled {} impl std::error::Error for Canceled {}
/// Runs `f`, and catches any salsa cancelation.
pub fn catch_cancellation<F, T>(f: F) -> Result<T, Canceled>
where
F: FnOnce() -> T + UnwindSafe,
{
match panic::catch_unwind(f) {
Ok(t) => Ok(t),
Err(payload) => match payload.downcast() {
Ok(canceled) => Err(*canceled),
Err(payload) => panic::resume_unwind(payload),
},
}
}
// Re-export the procedural macros. // Re-export the procedural macros.
#[allow(unused_imports)] #[allow(unused_imports)]
#[macro_use] #[macro_use]

View file

@ -1,7 +1,7 @@
use std::panic::AssertUnwindSafe; use std::panic::AssertUnwindSafe;
use crate::setup::{ParDatabase, ParDatabaseImpl}; use crate::setup::{ParDatabase, ParDatabaseImpl};
use salsa::ParallelDatabase; use salsa::{Canceled, ParallelDatabase};
/// Test where a read and a set are racing with one another. /// Test where a read and a set are racing with one another.
/// Should be atomic. /// Should be atomic.
@ -16,7 +16,7 @@ fn in_par_get_set_race() {
let thread1 = std::thread::spawn({ let thread1 = std::thread::spawn({
let db = db.snapshot(); let db = db.snapshot();
move || { move || {
salsa::catch_cancellation(AssertUnwindSafe(|| { Canceled::catch(AssertUnwindSafe(|| {
let v = db.sum("abc"); let v = db.sum("abc");
v v
})) }))

View file

@ -1,10 +1,10 @@
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand::Rng; use rand::Rng;
use salsa::Database;
use salsa::ParallelDatabase; use salsa::ParallelDatabase;
use salsa::Snapshot; use salsa::Snapshot;
use salsa::SweepStrategy; use salsa::SweepStrategy;
use salsa::{Canceled, Database};
// Number of operations a reader performs // Number of operations a reader performs
const N_MUTATOR_OPS: usize = 100; const N_MUTATOR_OPS: usize = 100;
@ -191,7 +191,7 @@ fn stress_test() {
check_cancellation, check_cancellation,
} => all_threads.push(std::thread::spawn({ } => all_threads.push(std::thread::spawn({
let db = db.snapshot(); let db = db.snapshot();
move || salsa::catch_cancellation(|| db_reader_thread(&db, ops, check_cancellation)) move || Canceled::catch(|| db_reader_thread(&db, ops, check_cancellation))
})), })),
} }
} }