diff --git a/src/runtime.rs b/src/runtime.rs index 3d5d997..9d32a27 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -7,9 +7,12 @@ use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive}; use parking_lot::{Mutex, RwLock}; use rustc_hash::{FxHashMap, FxHasher}; use smallvec::SmallVec; -use std::hash::{BuildHasherDefault, Hash}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::{ + hash::{BuildHasherDefault, Hash}, + panic::RefUnwindSafe, +}; pub(crate) type FxIndexSet = indexmap::IndexSet>; pub(crate) type FxIndexMap = indexmap::IndexMap>; @@ -38,6 +41,8 @@ pub struct Runtime { /// Shared state that is accessible via all runtimes. shared_state: Arc, + + on_cancelation_check: Option>, } impl Default for Runtime { @@ -47,6 +52,7 @@ impl Default for Runtime { revision_guard: None, shared_state: Default::default(), local_state: Default::default(), + on_cancelation_check: None, } } } @@ -85,6 +91,7 @@ impl Runtime { revision_guard: Some(revision_guard), shared_state: self.shared_state.clone(), local_state: Default::default(), + on_cancelation_check: None, } } @@ -166,6 +173,10 @@ impl Runtime { /// invocation. #[inline] pub fn unwind_if_canceled(&self) { + if let Some(callback) = &self.on_cancelation_check { + callback(); + } + let current_revision = self.current_revision(); let pending_revision = self.pending_revision(); debug!( @@ -183,6 +194,15 @@ impl Runtime { Canceled::throw(); } + /// Registers a callback to be invoked every time [`Runtime::unwind_if_canceled`] is called + /// (either automatically by salsa, or manually by user code). + pub fn set_cancelation_check_callback(&mut self, callback: F) + where + F: Fn() + Send + RefUnwindSafe + 'static, + { + self.on_cancelation_check = Some(Box::new(callback)); + } + /// Acquires the **global query write lock** (ensuring that no queries are /// executing) and then increments the current revision counter; invokes /// `op` with the global query write lock still held.