use salsa::{Database, ParallelDatabase, Snapshot}; use std::panic::{self, AssertUnwindSafe}; use std::sync::atomic::{AtomicU32, Ordering::SeqCst}; #[salsa::query_group(PanicSafelyStruct)] trait PanicSafelyDatabase: salsa::Database { #[salsa::input] fn one(&self) -> usize; fn panic_safely(&self) -> (); fn outer(&self) -> (); } fn panic_safely(db: &impl PanicSafelyDatabase) -> () { assert_eq!(db.one(), 1); } static OUTER_CALLS: AtomicU32 = AtomicU32::new(0); fn outer(db: &impl PanicSafelyDatabase) -> () { OUTER_CALLS.fetch_add(1, SeqCst); db.panic_safely(); } #[salsa::database(PanicSafelyStruct)] #[derive(Default)] struct DatabaseStruct { runtime: salsa::Runtime, } impl salsa::Database for DatabaseStruct { fn salsa_runtime(&self) -> &salsa::Runtime { &self.runtime } fn salsa_runtime_mut(&mut self) -> &mut salsa::Runtime { &mut self.runtime } } impl salsa::ParallelDatabase for DatabaseStruct { fn snapshot(&self) -> Snapshot { Snapshot::new(DatabaseStruct { runtime: self.runtime.snapshot(self), }) } } #[test] fn should_panic_safely() { let mut db = DatabaseStruct::default(); db.set_one(0); // Invoke `db.panic_safely() without having set `db.one`. `db.one` will // return 0 and we should catch the panic. let result = panic::catch_unwind(AssertUnwindSafe({ let db = db.snapshot(); move || db.panic_safely() })); assert!(result.is_err()); // Set `db.one` to 1 and assert ok db.set_one(1); let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); assert!(result.is_ok()); // Check, that memoized outer is not invalidated by a panic { assert_eq!(OUTER_CALLS.load(SeqCst), 0); db.outer(); assert_eq!(OUTER_CALLS.load(SeqCst), 1); db.set_one(0); let result = panic::catch_unwind(AssertUnwindSafe(|| db.outer())); assert!(result.is_err()); assert_eq!(OUTER_CALLS.load(SeqCst), 1); db.set_one(1); db.outer(); assert_eq!(OUTER_CALLS.load(SeqCst), 1); } } #[test] fn storages_are_unwind_safe() { fn check_unwind_safe() {} check_unwind_safe::<&DatabaseStruct>(); } #[test] fn panics_clear_query_stack() { let db = DatabaseStruct::default(); // Invoke `db.panic_if_not_one() without having set `db.input`. `db.input` // will default to 0 and we should catch the panic. let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); assert!(result.is_err()); // The database has been poisoned and any attempt to increment the // revision should panic. assert_eq!(db.salsa_runtime().active_query(), None); }