diff --git a/Cargo.toml b/Cargo.toml index 50dd84f4..968ba365 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,5 +26,6 @@ linked-hash-map = "0.5.2" rand = "0.7" rand_distr = "0.2.1" test-env-log = "0.2.7" +insta = "1.8.0" [workspace] diff --git a/src/derived/slot.rs b/src/derived/slot.rs index 63b179ca..bbcb2267 100644 --- a/src/derived/slot.rs +++ b/src/derived/slot.rs @@ -117,6 +117,20 @@ enum ProbeState { UpToDate(Result), } +/// Return value of `maybe_changed_since_probe` helper. +enum MaybeChangedSinceProbeState { + /// Another thread was active but has completed. + /// Try again! + Retry, + + /// Value may have changed in the given revision. + ChangedAt(Revision), + + /// There is a stale cache entry that has not been + /// verified in this revision, so we can't say. + Stale(G), +} + impl Slot where Q: QueryFunction, @@ -480,176 +494,117 @@ where self, revision, revision_now, ); - // Acquire read lock to start. In some of the arms below, we - // drop this explicitly. - let state = self.state.read(); - - // Look for a memoized value. - let memo = match &*state { - // If somebody depends on us, but we have no map - // entry, that must mean that it was found to be out - // of date and removed. - QueryState::NotComputed => { - debug!("maybe_changed_since({:?}): no value", self); - return true; + // Do an initial probe with just the read-lock. + // + // If we find that a cache entry for the value is present + // but hasn't been verified in this revision, we'll have to + // do more. + loop { + match self.maybe_changed_since_probe(db, self.state.read(), runtime, revision_now) { + MaybeChangedSinceProbeState::Retry => continue, + MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, + MaybeChangedSinceProbeState::Stale(state) => { + drop(state); + return self.maybe_changed_since_upgrade(db, revision); + } } + } + } - // This value is being actively recomputed. Wait for - // that thread to finish (assuming it's not dependent - // on us...) and check its associated revision. - QueryState::InProgress { id, anyone_waiting } => { - let other_id = *id; - debug!( - "maybe_changed_since({:?}): blocking on thread `{:?}`", - self, other_id, - ); + fn maybe_changed_since_probe( + &self, + db: &>::DynDb, + state: StateGuard, + runtime: &Runtime, + revision_now: Revision, + ) -> MaybeChangedSinceProbeState + where + StateGuard: Deref>, + { + match self.probe(db, state, runtime, revision_now) { + ProbeState::Retry => MaybeChangedSinceProbeState::Retry, - // NB: `Ordering::Relaxed` is sufficient here, - // see `probe` for more details. - anyone_waiting.store(true, Ordering::Relaxed); + ProbeState::Stale(state) => MaybeChangedSinceProbeState::Stale(state), - return match self.block_on_in_progress_thread(db, runtime, other_id, state) { - // The other thread has completed. Have to try again. We've lost our lock, - // so just recurse. (We should probably clean this up to a loop later, - // but recursing is not terrible: this shouldn't happen more than once per revision.) - Ok(WaitResult::Completed) => self.maybe_changed_since(db, revision), - Ok(WaitResult::Panicked) => Cancelled::throw(), - Err(_) => true, - }; + // If we know when value last changed, we can return right away. + // Note that we don't need the actual value to be available. + ProbeState::NoValue(_, changed_at) + | ProbeState::UpToDate(Ok(StampedValue { + value: _, + durability: _, + changed_at, + })) => MaybeChangedSinceProbeState::ChangedAt(changed_at), + + // If we have nothing cached, then value may have changed. + ProbeState::NotComputed(_) => MaybeChangedSinceProbeState::ChangedAt(revision_now), + + // Consider cycles as potentially having changed. + ProbeState::UpToDate(Err(_)) => MaybeChangedSinceProbeState::ChangedAt(revision_now), + } + } + + fn maybe_changed_since_upgrade( + &self, + db: &>::DynDb, + revision: Revision, + ) -> bool { + let runtime = db.salsa_runtime(); + let revision_now = runtime.current_revision(); + + // Get an upgradable read lock, which permits other reads but no writers. + // Probe again. If the value is stale (needs to be verified), then upgrade + // to a write lock and swap it with InProgress while we work. + let old_memo = match self.maybe_changed_since_probe( + db, + self.state.upgradable_read(), + runtime, + revision_now, + ) { + MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, + + // If another thread was active, then the cache line is going to be + // either verified or cleared out. Just recurse to figure out which. + // Note that we don't need an upgradable read. + MaybeChangedSinceProbeState::Retry => return self.maybe_changed_since(db, revision), + + MaybeChangedSinceProbeState::Stale(state) => { + type RwLockUpgradableReadGuard<'a, T> = + lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; + + let mut state = RwLockUpgradableReadGuard::upgrade(state); + match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { + QueryState::Memoized(old_memo) => old_memo, + QueryState::NotComputed | QueryState::InProgress { .. } => unreachable!(), + } } - - QueryState::Memoized(memo) => memo, }; - if memo.revisions.verified_at == revision_now { - debug!( - "maybe_changed_since({:?}): {:?} since up-to-date memo that changed at {:?}", - self, - memo.revisions.changed_at > revision, - memo.revisions.changed_at, - ); - return memo.revisions.changed_at > revision; - } + let mut panic_guard = + PanicGuard::new(self.database_key_index, self, Some(old_memo), runtime); - let maybe_changed; - - // If we only depended on constants, and no constant has been - // modified since then, we cannot have changed; no need to - // trace our inputs. - if memo.revisions.check_durability(runtime) { - std::mem::drop(state); - maybe_changed = false; - } else { - match &memo.revisions.inputs { - MemoInputs::Untracked => { - // we don't know the full set of - // inputs, so if there is a new - // revision, we must assume it is - // dirty - debug!( - "maybe_changed_since({:?}: true since untracked inputs", - self, - ); - return true; - } - - MemoInputs::NoInputs => { - std::mem::drop(state); - maybe_changed = false; - } - - MemoInputs::Tracked { inputs } => { - // At this point, the value may be dirty (we have - // to check the database-keys). If we have a cached - // value, we'll just fall back to invoking `read`, - // which will do that checking (and a bit more) -- - // note that we skip the "pure read" part as we - // already know the result. - assert!(inputs.len() > 0); - if memo.value.is_some() { - std::mem::drop(state); - return match self.read_upgrade(db, revision_now) { - Ok(v) => { - debug!( - "maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}", - self, - v.changed_at > revision, - v.changed_at, - ); - v.changed_at > revision - } - Err(_) => true, - }; - } - - // We have a **tracked set of inputs** that need to be validated. - let inputs = inputs.clone(); - // We'll need to update the state anyway (see below), so release the read-lock. - std::mem::drop(state); - - // Iterate the inputs and see if any have maybe changed. - maybe_changed = inputs - .iter() - .filter(|&&input| db.maybe_changed_since(input, revision)) - .inspect(|input| debug!("{:?}: input `{:?}` may have changed", self, input)) - .next() - .is_some(); - } - } - } - - // Either way, we have to update our entry. - // - // Keep in mind, though, that we released the lock before checking the ipnuts and a lot - // could have happened in the interim. =) Therefore, we have to probe the current - // `self.state` again and in some cases we ought to do nothing. + let memo = panic_guard.memo.as_mut().unwrap(); + if memo + .revisions + .validate_memoized_value(db.ops_database(), revision_now) { - let mut state = self.state.write(); - match &mut *state { - QueryState::Memoized(memo) => { - if memo.revisions.verified_at == revision_now { - // Since we started verifying inputs, somebody - // else has come along and updated this value - // (they may even have recomputed - // it). Therefore, we should not touch this - // memo. - // - // FIXME: Should we still return whatever - // `maybe_changed` value we computed, - // however..? It seems .. harmless to indicate - // that the value has changed, but possibly - // less efficient? (It may cause some - // downstream value to be recomputed that - // wouldn't otherwise have to be?) - } else if maybe_changed { - // We found this entry is out of date and - // nobody touch it in the meantime. Just - // remove it. - *state = QueryState::NotComputed; - } else { - // We found this entry is valid. Update the - // `verified_at` to reflect the current - // revision. - memo.revisions.verified_at = revision_now; - } - } - - QueryState::InProgress { .. } => { - // Since we started verifying inputs, somebody - // else has come along and started updated this - // value. Just leave their marker alone and return - // whatever `maybe_changed` value we computed. - } - - QueryState::NotComputed => { - // Since we started verifying inputs, somebody - // else has come along and removed this value. The - // GC can do this, for example. That's fine. - } - } + let maybe_changed = memo.revisions.changed_at > revision; + panic_guard.proceed(); + maybe_changed + } else if memo.value.is_some() { + // We found that this memoized value may have changed + // but we have an old value. We can re-run the code and + // actually *check* if it has changed. + let StampedValue { changed_at, .. } = + self.execute(db, runtime, revision_now, panic_guard); + changed_at > revision + } else { + // We found that inputs to this memoized value may have chanced + // but we don't have an old value to compare against or re-use. + // No choice but to drop the memo and say that its value may have changed. + panic_guard.memo = None; + panic_guard.proceed(); + true } - - maybe_changed } /// Helper: diff --git a/src/lib.rs b/src/lib.rs index 28bbab0f..b7bf59c7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,7 +117,7 @@ pub struct Event { impl Event { /// Returns a type that gives a user-readable debug output. /// Use like `println!("{:?}", index.debug(db))`. - pub fn debug(self, db: &D) -> impl std::fmt::Debug + '_ + pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me where D: plumbing::DatabaseOps, { @@ -134,15 +134,15 @@ impl fmt::Debug for Event { } } -struct EventDebug<'db, D: ?Sized> +struct EventDebug<'me, D: ?Sized> where D: plumbing::DatabaseOps, { - event: Event, - db: &'db D, + event: &'me Event, + db: &'me D, } -impl<'db, D: ?Sized> fmt::Debug for EventDebug<'db, D> +impl<'me, D: ?Sized> fmt::Debug for EventDebug<'me, D> where D: plumbing::DatabaseOps, { diff --git a/src/runtime.rs b/src/runtime.rs index fd875eca..6f6c327f 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -288,8 +288,9 @@ impl Runtime { /// /// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html). pub fn report_synthetic_read(&self, durability: Durability) { + let changed_at = self.last_changed_revision(durability); self.local_state - .report_synthetic_read(durability, self.current_revision()); + .report_synthetic_read(durability, changed_at); } /// Obviously, this should be user configurable at some point. @@ -550,9 +551,10 @@ impl ActiveQuery { self.changed_at = changed_at; } - fn add_synthetic_read(&mut self, durability: Durability, current_revision: Revision) { + fn add_synthetic_read(&mut self, durability: Durability, revision: Revision) { + self.dependencies = None; self.durability = self.durability.min(durability); - self.changed_at = current_revision; + self.changed_at = self.changed_at.max(revision); } } diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs index 60a4d9e9..028f9bcc 100644 --- a/src/runtime/local_state.rs +++ b/src/runtime/local_state.rs @@ -85,10 +85,12 @@ impl LocalState { }) } - pub(super) fn report_synthetic_read(&self, durability: Durability, current_revision: Revision) { + /// Update the top query on the stack to act as though it read a value + /// of durability `durability` which changed in `revision`. + pub(super) fn report_synthetic_read(&self, durability: Durability, revision: Revision) { self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { - top_query.add_synthetic_read(durability, current_revision); + top_query.add_synthetic_read(durability, revision); } }) } diff --git a/tests/incremental/memoized_volatile.rs b/tests/incremental/memoized_volatile.rs index 203c441a..6dc50300 100644 --- a/tests/incremental/memoized_volatile.rs +++ b/tests/incremental/memoized_volatile.rs @@ -60,7 +60,7 @@ fn revalidate() { // will not (still 0, as 1/2 = 0) query.salsa_runtime_mut().synthetic_write(Durability::LOW); query.memoized2(); - query.assert_log(&["Memoized1 invoked", "Volatile invoked"]); + query.assert_log(&["Volatile invoked", "Memoized1 invoked"]); query.memoized2(); query.assert_log(&[]); @@ -70,7 +70,7 @@ fn revalidate() { query.salsa_runtime_mut().synthetic_write(Durability::LOW); query.memoized2(); - query.assert_log(&["Memoized1 invoked", "Volatile invoked", "Memoized2 invoked"]); + query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]); query.memoized2(); query.assert_log(&[]); diff --git a/tests/on_demand_inputs.rs b/tests/on_demand_inputs.rs index 0d4b1987..99092025 100644 --- a/tests/on_demand_inputs.rs +++ b/tests/on_demand_inputs.rs @@ -4,9 +4,9 @@ //! via a b query with zero inputs, which uses `add_synthetic_read` to //! tweak durability and `invalidate` to clear the input. -use std::{cell::Cell, collections::HashMap, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; -use salsa::{Database as _, Durability}; +use salsa::{Database as _, Durability, EventKind}; #[salsa::query_group(QueryGroupStorage)] trait QueryGroup: salsa::Database + AsRef> { @@ -39,13 +39,15 @@ fn c(db: &dyn QueryGroup, x: u32) -> u32 { struct Database { storage: salsa::Storage, external_state: HashMap, - on_event: Option>, + on_event: Option>, } impl salsa::Database for Database { fn salsa_event(&self, event: salsa::Event) { + dbg!(event.debug(self)); + if let Some(cb) = &self.on_event { - cb(event) + cb(self, event) } } } @@ -84,30 +86,68 @@ fn on_demand_input_works() { #[test] fn on_demand_input_durability() { let mut db = Database::default(); - db.external_state.insert(1, 10); - db.external_state.insert(2, 20); - assert_eq!(db.b(1), 10); - assert_eq!(db.b(2), 20); - let validated = Rc::new(Cell::new(0)); + let events = Rc::new(RefCell::new(vec![])); db.on_event = Some(Box::new({ - let validated = Rc::clone(&validated); - move |event| { - if let salsa::EventKind::DidValidateMemoizedValue { .. } = event.kind { - validated.set(validated.get() + 1) + let events = events.clone(); + move |db, event| { + if let EventKind::WillCheckCancellation = event.kind { + // these events are not interesting + } else { + events.borrow_mut().push(format!("{:?}", event.debug(db))) } } })); - db.salsa_runtime_mut().synthetic_write(Durability::LOW); - validated.set(0); - assert_eq!(db.c(1), 10); - assert_eq!(db.c(2), 20); - assert_eq!(validated.get(), 2); + events.replace(vec![]); + db.external_state.insert(1, 10); + db.external_state.insert(2, 20); + assert_eq!(db.b(1), 10); + assert_eq!(db.b(2), 20); + insta::assert_debug_snapshot!(events, @r###" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + ], + } + "###); - db.salsa_runtime_mut().synthetic_write(Durability::HIGH); - validated.set(0); + eprintln!("------------------"); + db.salsa_runtime_mut().synthetic_write(Durability::LOW); + events.replace(vec![]); assert_eq!(db.c(1), 10); assert_eq!(db.c(2), 20); - assert_eq!(validated.get(), 4); + // Re-execute `a(2)` because that has low durability, but not `a(1)` + insta::assert_debug_snapshot!(events, @r###" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }", + ], + } + "###); + + eprintln!("------------------"); + db.salsa_runtime_mut().synthetic_write(Durability::HIGH); + events.replace(vec![]); + assert_eq!(db.c(1), 10); + assert_eq!(db.c(2), 20); + // Re-execute both `a(1)` and `a(2)`, but we don't re-execute any `b` queries as the + // result didn't actually change. + insta::assert_debug_snapshot!(events, @r###" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }", + ], + } + "###); }