From 356392578bb86d2aa287940bfe95ae6355f286a3 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Mon, 8 Nov 2021 06:32:03 -0500 Subject: [PATCH] new parallel friendly algorithm --- Cargo.toml | 2 + src/derived.rs | 176 ++-- src/derived/execute.rs | 134 +++ src/derived/fetch.rs | 115 +++ src/derived/key_to_key_index.rs | 62 ++ src/derived/lru.rs | 39 + src/derived/maybe_changed_after.rs | 179 ++++ src/derived/memo.rs | 107 +++ src/derived/slot.rs | 871 ------------------ src/derived/sync.rs | 87 ++ src/hash.rs | 1 + src/lib.rs | 1 - src/lru.rs | 335 ------- src/runtime/local_state.rs | 14 +- tests/incremental/memoized_volatile.rs | 1 + tests/on_demand_inputs.rs | 2 + tests/panic_safely.rs | 2 +- tests/parallel/parallel_cycle_none_recover.rs | 2 +- 18 files changed, 812 insertions(+), 1318 deletions(-) create mode 100644 src/derived/execute.rs create mode 100644 src/derived/fetch.rs create mode 100644 src/derived/key_to_key_index.rs create mode 100644 src/derived/lru.rs create mode 100644 src/derived/maybe_changed_after.rs create mode 100644 src/derived/memo.rs delete mode 100644 src/derived/slot.rs create mode 100644 src/derived/sync.rs delete mode 100644 src/lru.rs diff --git a/Cargo.toml b/Cargo.toml index c391f02c..cca0b32b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,10 @@ repository = "https://github.com/salsa-rs/salsa" description = "A generic framework for on-demand, incrementalized computation (experimental)" [dependencies] +arc-swap = "1.4.0" crossbeam-utils = { version = "0.8", default-features = false } dashmap = "4.0.2" +hashlink = "0.7.0" indexmap = "1.0.1" lock_api = "0.4" log = "0.4.5" diff --git a/src/derived.rs b/src/derived.rs index 5228451e..48a36236 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -1,23 +1,28 @@ use crate::debug::TableEntry; use crate::durability::Durability; -use crate::hash::FxDashMap; -use crate::lru::Lru; use crate::plumbing::DerivedQueryStorageOps; use crate::plumbing::LruQueryStorageOps; use crate::plumbing::QueryFunction; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; -use crate::runtime::StampedValue; +use crate::runtime::local_state::QueryInputs; +use crate::runtime::local_state::QueryRevisions; use crate::Runtime; use crate::{Database, DatabaseKeyIndex, QueryDb, Revision}; -use crossbeam_utils::atomic::AtomicCell; use std::borrow::Borrow; use std::hash::Hash; use std::marker::PhantomData; -use std::sync::Arc; -mod slot; -use slot::Slot; +mod execute; +mod fetch; +mod key_to_key_index; +mod lru; +mod maybe_changed_after; +mod memo; +mod sync; + +//mod slot; +//use slot::Slot; /// Memoized queries store the result plus a list of the other queries /// that they invoked. 
This means we can avoid recomputing them when
@@ -37,22 +42,13 @@ where
     MP: MemoizationPolicy<Q>,
 {
     group_index: u16,
-    lru_list: Lru<Slot<Q, MP>>,
-    indices: AtomicCell<DerivedKeyIndex>,
-    index_map: FxDashMap<Q::Key, DerivedKeyIndex>,
-    slot_map: FxDashMap<DerivedKeyIndex, KeySlot<Q, MP>>,
+    lru: lru::Lru,
+    key_map: key_to_key_index::KeyToKeyIndex<Q::Key>,
+    memo_map: memo::MemoMap<Q::Value>,
+    sync_map: sync::SyncMap,
     policy: PhantomData<MP>,
 }
 
-struct KeySlot<Q, MP>
-where
-    Q: QueryFunction,
-    MP: MemoizationPolicy<Q>,
-{
-    key: Q::Key,
-    slot: Arc<Slot<Q, MP>>,
-}
-
 type DerivedKeyIndex = u32;
 
 impl<Q, MP> std::panic::RefUnwindSafe for DerivedStorage<Q, MP>
@@ -107,52 +103,22 @@ where
     Q: QueryFunction,
     MP: MemoizationPolicy<Q>,
 {
-    fn slot_for_key(&self, key: &Q::Key) -> Arc<Slot<Q, MP>> {
-        // Common case: get an existing key
-        if let Some(v) = self.index_map.get(key) {
-            let index = *v;
-
-            // release the read-write lock early, for no particular reason
-            // apart from it bothers me
-            drop(v);
-
-            return self.slot_for_key_index(index);
-        }
-
-        // Less common case: (potentially) create a new slot
-        match self.index_map.entry(key.clone()) {
-            dashmap::mapref::entry::Entry::Occupied(entry) => self.slot_for_key_index(*entry.get()),
-            dashmap::mapref::entry::Entry::Vacant(entry) => {
-                let key_index = self.indices.fetch_add(1);
-                let database_key_index = DatabaseKeyIndex {
-                    group_index: self.group_index,
-                    query_index: Q::QUERY_INDEX,
-                    key_index,
-                };
-                let slot = Arc::new(Slot::new(key.clone(), database_key_index));
-                // Subtle: store the new slot *before* the new index, so that
-                // other threads only see the new index once the slot is also available.
-                self.slot_map.insert(
-                    key_index,
-                    KeySlot {
-                        key: key.clone(),
-                        slot: slot.clone(),
-                    },
-                );
-                entry.insert(key_index);
-                slot
-            }
+    fn database_key_index(&self, key_index: DerivedKeyIndex) -> DatabaseKeyIndex {
+        DatabaseKeyIndex {
+            group_index: self.group_index,
+            query_index: Q::QUERY_INDEX,
+            key_index,
         }
     }
 
-    fn slot_for_key_index(&self, index: DerivedKeyIndex) -> Arc<Slot<Q, MP>> {
-        return self.slot_map.get(&index).unwrap().slot.clone();
-    }
-
-    fn slot_for_db_index(&self, index: DatabaseKeyIndex) -> Arc<Slot<Q, MP>> {
+    fn assert_our_key_index(&self, index: DatabaseKeyIndex) {
         assert_eq!(index.group_index, self.group_index);
         assert_eq!(index.query_index, Q::QUERY_INDEX);
-        self.slot_for_key_index(index.key_index)
+    }
+
+    fn key_index(&self, index: DatabaseKeyIndex) -> DerivedKeyIndex {
+        self.assert_our_key_index(index);
+        index.key_index
     }
 }
 
@@ -166,11 +132,11 @@ where
     fn new(group_index: u16) -> Self {
         DerivedStorage {
             group_index,
-            index_map: Default::default(),
-            slot_map: Default::default(),
-            lru_list: Default::default(),
+            lru: Default::default(),
+            key_map: Default::default(),
+            memo_map: Default::default(),
+            sync_map: Default::default(),
             policy: PhantomData,
-            indices: Default::default(),
         }
     }
 
@@ -180,58 +146,47 @@ where
         index: DatabaseKeyIndex,
         fmt: &mut std::fmt::Formatter<'_>,
     ) -> std::fmt::Result {
-        assert_eq!(index.group_index, self.group_index);
-        assert_eq!(index.query_index, Q::QUERY_INDEX);
-        let key_slot = self.slot_map.get(&index.key_index).unwrap();
-        write!(fmt, "{}({:?})", Q::QUERY_NAME, key_slot.key)
+        let key_index = self.key_index(index);
+        let key = self.key_map.key_for_key_index(key_index);
+        write!(fmt, "{}({:?})", Q::QUERY_NAME, key)
     }
 
     fn maybe_changed_after(
         &self,
         db: &<Q as QueryDb<'_>>::DynDb,
-        input: DatabaseKeyIndex,
+        database_key_index: DatabaseKeyIndex,
         revision: Revision,
     ) -> bool {
         debug_assert!(revision < db.salsa_runtime().current_revision());
-        let slot = self.slot_for_db_index(input);
-        slot.maybe_changed_after(db, revision)
+        let key_index = self.key_index(database_key_index);
+        self.maybe_changed_after(db, key_index, revision)
     }
 
     fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
         db.unwind_if_cancelled();
-
-        let slot = self.slot_for_key(key);
-        let StampedValue {
-            value,
-            durability,
-            changed_at,
-        } = slot.read(db);
-
-        if let Some(evicted) = self.lru_list.record_use(&slot) {
-            evicted.evict();
-        }
-
-        db.salsa_runtime()
-            .report_query_read_and_unwind_if_cycle_resulted(
-                slot.database_key_index(),
-                durability,
-                changed_at,
-            );
-
-        value
+        let key_index = self.key_map.key_index_for_key(key);
+        self.fetch(db, key_index)
     }
 
-    fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
-        self.slot_for_key(key).durability(db)
+    fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
+        let key_index = self.key_map.key_index_for_key(key);
+        if let Some(memo) = self.memo_map.get(key_index) {
+            memo.revisions.durability
+        } else {
+            Durability::LOW
+        }
     }
 
     fn entries<C>(&self, _db: &<Q as QueryDb<'_>>::DynDb) -> C
     where
         C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
     {
-        self.slot_map
+        self.memo_map
             .iter()
-            .filter_map(|r| r.value().slot.as_table_entry())
+            .map(|(key_index, memo)| {
+                let key = self.key_map.key_for_key_index(key_index);
+                TableEntry::new(key, memo.value.clone())
+            })
             .collect()
     }
 }
@@ -242,10 +197,8 @@ where
     MP: MemoizationPolicy<Q>,
 {
     fn purge(&self) {
-        self.lru_list.purge();
-        self.indices.store(0);
-        self.index_map.clear();
-        self.slot_map.clear();
+        self.lru.set_capacity(0);
+        self.memo_map.clear();
     }
 }
 
@@ -255,7 +208,7 @@ where
     MP: MemoizationPolicy<Q>,
 {
     fn set_lru_capacity(&self, new_capacity: usize) {
-        self.lru_list.set_lru_capacity(new_capacity);
+        self.lru.set_capacity(new_capacity);
     }
 }
 
@@ -270,13 +223,20 @@ where
     Q::Key: Borrow<S>,
 {
     runtime.with_incremented_revision(|new_revision| {
-        if let Some(key_index) = self.index_map.get(key) {
-            let slot = self.slot_for_key_index(*key_index);
-            if let Some(durability) = slot.invalidate(new_revision) {
-                return Some(durability);
-            }
-        }
-        None
+        let key_index = self.key_map.existing_key_index_for_key(key)?;
+        let memo = self.memo_map.get(key_index)?;
+        let invalidated_revisions = QueryRevisions {
+            changed_at: new_revision,
+            durability: memo.revisions.durability,
+            inputs: QueryInputs::Untracked,
+        };
+        let new_memo = memo::Memo::new(
+            memo.value.clone(),
+            memo.verified_at.load(),
+            invalidated_revisions,
+        );
+        self.memo_map.insert(key_index, new_memo);
+        Some(memo.revisions.durability)
     })
 }
 }
diff --git a/src/derived/execute.rs b/src/derived/execute.rs
new file mode 100644
index 00000000..229a565d
--- /dev/null
+++ b/src/derived/execute.rs
@@ -0,0 +1,134 @@
+use std::sync::Arc;
+
+use crate::{
+    plumbing::QueryFunction,
+    runtime::{local_state::ActiveQueryGuard, StampedValue},
+    Cycle, Database, Event, EventKind, QueryDb,
+};
+
+use super::{memo::Memo, DerivedStorage, MemoizationPolicy};
+
+impl<Q, MP> DerivedStorage<Q, MP>
+where
+    Q: QueryFunction,
+    MP: MemoizationPolicy<Q>,
+{
+    /// Executes the query function for the given `active_query`. Creates and stores
+    /// a new memo with the result, backdated if possible. Once this completes,
+    /// the query will have been popped off the active query stack.
+    ///
+    /// # Parameters
+    ///
+    /// * `db`, the database.
+    /// * `active_query`, the active stack frame for the query to execute.
+    /// * `opt_old_memo`, the older memo, if any existed. Used for backdating.
+    pub(super) fn execute(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        active_query: ActiveQueryGuard<'_>,
+        opt_old_memo: Option<Arc<Memo<Q::Value>>>,
+    ) -> StampedValue<Q::Value> {
+        let runtime = db.salsa_runtime();
+        let revision_now = runtime.current_revision();
+        let database_key_index = active_query.database_key_index;
+
+        log::info!("{:?}: executing query", database_key_index.debug(db));
+
+        db.salsa_event(Event {
+            runtime_id: db.salsa_runtime().id(),
+            kind: EventKind::WillExecute {
+                database_key: database_key_index,
+            },
+        });
+
+        // Query was not previously executed, or value is potentially
+        // stale, or value is absent. Let's execute!
+        let database_key_index = active_query.database_key_index;
+        let key_index = database_key_index.key_index;
+        let key = self.key_map.key_for_key_index(key_index);
+        let value = match Cycle::catch(|| Q::execute(db, key.clone())) {
+            Ok(v) => v,
+            Err(cycle) => {
+                log::debug!(
+                    "{:?}: caught cycle {:?}, have strategy {:?}",
+                    database_key_index.debug(db),
+                    cycle,
+                    Q::CYCLE_STRATEGY,
+                );
+                match Q::CYCLE_STRATEGY {
+                    crate::plumbing::CycleRecoveryStrategy::Panic => cycle.throw(),
+                    crate::plumbing::CycleRecoveryStrategy::Fallback => {
+                        if let Some(c) = active_query.take_cycle() {
+                            assert!(c.is(&cycle));
+                            Q::cycle_fallback(db, &cycle, &key)
+                        } else {
+                            // we are not a participant in this cycle
+                            debug_assert!(!cycle
+                                .participant_keys()
+                                .any(|k| k == database_key_index));
+                            cycle.throw()
+                        }
+                    }
+                }
+            }
+        };
+        let mut revisions = active_query.pop();
+
+        // We assume that query is side-effect free -- that is, does
+        // not mutate the "inputs" to the query system. Sanity check
+        // that assumption here, at least to the best of our ability.
+        assert_eq!(
+            runtime.current_revision(),
+            revision_now,
+            "revision altered during query execution",
+        );
+
+        // If the new value is equal to the old one, then it didn't
+        // really change, even if some of its inputs have. So we can
+        // "backdate" its `changed_at` revision to be the same as the
+        // old value.
+        if let Some(old_memo) = &opt_old_memo {
+            if let Some(old_value) = &old_memo.value {
+                // Careful: if the value became less durable than it
+                // used to be, that is a "breaking change" that our
+                // consumers must be aware of. Becoming *more* durable
+                // is not. See the test `constant_to_non_constant`.
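+                // (For example, a memo once computed purely from
+                // high-durability inputs that now also reads a
+                // low-durability input must not be backdated, or
+                // consumers validated on durability alone would
+                // miss the change.)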
+                if revisions.durability >= old_memo.revisions.durability
+                    && MP::memoized_value_eq(old_value, &value)
+                {
+                    log::debug!(
+                        "{:?}: read_upgrade: value is equal, back-dating to {:?}",
+                        database_key_index.debug(db),
+                        old_memo.revisions.changed_at,
+                    );
+
+                    assert!(old_memo.revisions.changed_at <= revisions.changed_at);
+                    revisions.changed_at = old_memo.revisions.changed_at;
+                }
+            }
+        }
+
+        let stamped_value = revisions.stamped_value(value);
+
+        log::debug!(
+            "{:?}: read_upgrade: result.revisions = {:#?}",
+            database_key_index.debug(db),
+            revisions
+        );
+
+        self.memo_map.insert(
+            key_index,
+            Memo::new(
+                if MP::should_memoize_value(&key) {
+                    Some(stamped_value.value.clone())
+                } else {
+                    None
+                },
+                revision_now,
+                revisions,
+            ),
+        );
+
+        stamped_value
+    }
+}
diff --git a/src/derived/fetch.rs b/src/derived/fetch.rs
new file mode 100644
index 00000000..0d07415e
--- /dev/null
+++ b/src/derived/fetch.rs
@@ -0,0 +1,115 @@
+use arc_swap::Guard;
+
+use crate::{
+    plumbing::{DatabaseOps, QueryFunction},
+    runtime::{local_state::QueryInputs, StampedValue},
+    Database, QueryDb,
+};
+
+use super::{DerivedKeyIndex, DerivedStorage, MemoizationPolicy};
+
+impl<Q, MP> DerivedStorage<Q, MP>
+where
+    Q: QueryFunction,
+    MP: MemoizationPolicy<Q>,
+{
+    #[inline]
+    pub(super) fn fetch(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        key_index: DerivedKeyIndex,
+    ) -> Q::Value {
+        let StampedValue {
+            value,
+            durability,
+            changed_at,
+        } = self.compute_value(db, key_index);
+
+        if let Some(evicted) = self.lru.record_use(key_index) {
+            self.evict(evicted);
+        }
+
+        db.salsa_runtime()
+            .report_query_read_and_unwind_if_cycle_resulted(
+                self.database_key_index(key_index),
+                durability,
+                changed_at,
+            );
+
+        value
+    }
+
+    #[inline]
+    fn compute_value(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        key_index: DerivedKeyIndex,
+    ) -> StampedValue<Q::Value> {
+        loop {
+            if let Some(value) = self
+                .fetch_hot(db, key_index)
+                .or_else(|| self.fetch_cold(db, key_index))
+            {
+                return value;
+            }
+        }
+    }
+
+    #[inline]
+    fn fetch_hot(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        key_index: DerivedKeyIndex,
+    ) -> Option<StampedValue<Q::Value>> {
+        let memo_guard = self.memo_map.get(key_index);
+        if let Some(memo) = &memo_guard {
+            if let Some(value) = &memo.value {
+                let runtime = db.salsa_runtime();
+                if self.shallow_verify_memo(db, runtime, self.database_key_index(key_index), memo) {
+                    return Some(memo.revisions.stamped_value(value.clone()));
+                }
+            }
+        }
+        None
+    }
+
+    fn fetch_cold(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        key_index: DerivedKeyIndex,
+    ) -> Option<StampedValue<Q::Value>> {
+        let runtime = db.salsa_runtime();
+        let database_key_index = self.database_key_index(key_index);
+
+        // Try to claim this query: if someone else has claimed it already, go back and start again.
+        let _claim_guard = self.sync_map.claim(db.ops_database(), database_key_index)?;
+
+        // Push the query on the stack.
+        let active_query = runtime.push_query(database_key_index);
+
+        // Now that we've claimed the item, check again to see if there's a "hot" value.
+        // This time we can do a *deep* verify. Because this can recurse, don't hold the arcswap guard.
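+        // (`Guard::into_inner` upgrades the transient arc-swap guard into a
+        // plain `Arc`, which is safe to hold across the recursive calls below.)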
+        let opt_old_memo = self.memo_map.get(key_index).map(Guard::into_inner);
+        if let Some(old_memo) = &opt_old_memo {
+            if let Some(value) = &old_memo.value {
+                if self.deep_verify_memo(db, old_memo, &active_query) {
+                    return Some(old_memo.revisions.stamped_value(value.clone()));
+                }
+            }
+        }
+
+        Some(self.execute(db, active_query, opt_old_memo))
+    }
+
+    fn evict(&self, key_index: DerivedKeyIndex) {
+        if let Some(memo) = self.memo_map.get(key_index) {
+            // Careful: we can't evict memos with untracked inputs
+            // as their values cannot be reconstructed.
+            if let QueryInputs::Untracked = memo.revisions.inputs {
+                return;
+            }
+
+            self.memo_map.remove(key_index);
+        }
+    }
+}
diff --git a/src/derived/key_to_key_index.rs b/src/derived/key_to_key_index.rs
new file mode 100644
index 00000000..a6343d15
--- /dev/null
+++ b/src/derived/key_to_key_index.rs
@@ -0,0 +1,62 @@
+use crossbeam_utils::atomic::AtomicCell;
+use std::borrow::Borrow;
+use std::hash::Hash;
+
+use crate::hash::FxDashMap;
+
+use super::DerivedKeyIndex;
+
+pub(super) struct KeyToKeyIndex<K> {
+    index_map: FxDashMap<K, DerivedKeyIndex>,
+    key_map: FxDashMap<DerivedKeyIndex, K>,
+    indices: AtomicCell<DerivedKeyIndex>,
+}
+
+impl<K> Default for KeyToKeyIndex<K>
+where
+    K: Hash + Eq,
+{
+    fn default() -> Self {
+        Self {
+            index_map: Default::default(),
+            key_map: Default::default(),
+            indices: Default::default(),
+        }
+    }
+}
+
+impl<K> KeyToKeyIndex<K>
+where
+    K: Hash + Eq + Clone,
+{
+    pub(super) fn key_index_for_key(&self, key: &K) -> DerivedKeyIndex {
+        // Common case: get an existing key
+        if let Some(v) = self.index_map.get(key) {
+            return *v;
+        }
+
+        // Less common case: (potentially) create a new slot
+        *self.index_map.entry(key.clone()).or_insert_with(|| {
+            let key_index = self.indices.fetch_add(1);
+            self.key_map.insert(key_index, key.clone());
+            key_index
+        })
+    }
+
+    pub(super) fn existing_key_index_for_key<S>(&self, key: &S) -> Option<DerivedKeyIndex>
+    where
+        S: Eq + Hash,
+        K: Borrow<S>,
+    {
+        // Common case: get an existing key
+        if let Some(v) = self.index_map.get(key) {
+            Some(*v)
+        } else {
+            None
+        }
+    }
+
+    pub(super) fn key_for_key_index(&self, key_index: DerivedKeyIndex) -> K {
+        self.key_map.get(&key_index).unwrap().clone()
+    }
+}
diff --git a/src/derived/lru.rs b/src/derived/lru.rs
new file mode 100644
index 00000000..4a9af335
--- /dev/null
+++ b/src/derived/lru.rs
@@ -0,0 +1,39 @@
+use crate::hash::FxLinkedHashSet;
+
+use super::DerivedKeyIndex;
+use crossbeam_utils::atomic::AtomicCell;
+use parking_lot::Mutex;
+
+#[derive(Default)]
+pub(super) struct Lru {
+    capacity: AtomicCell<usize>,
+    set: Mutex<FxLinkedHashSet<DerivedKeyIndex>>,
+}
+
+impl Lru {
+    pub(super) fn record_use(&self, index: DerivedKeyIndex) -> Option<DerivedKeyIndex> {
+        let capacity = self.capacity.load();
+
+        if capacity == 0 {
+            // LRU is disabled
+            return None;
+        }
+
+        let mut set = self.set.lock();
+        set.insert(index);
+        if set.len() > capacity {
+            return set.pop_front();
+        }
+
+        None
+    }
+
+    pub(super) fn set_capacity(&self, capacity: usize) {
+        self.capacity.store(capacity);
+
+        if capacity == 0 {
+            let mut set = self.set.lock();
+            *set = FxLinkedHashSet::default();
+        }
+    }
+}
diff --git a/src/derived/maybe_changed_after.rs b/src/derived/maybe_changed_after.rs
new file mode 100644
index 00000000..3c571e5a
--- /dev/null
+++ b/src/derived/maybe_changed_after.rs
@@ -0,0 +1,179 @@
+use arc_swap::Guard;
+
+use crate::{
+    plumbing::{DatabaseOps, QueryFunction},
+    runtime::{
+        local_state::{ActiveQueryGuard, QueryInputs},
+        StampedValue,
+    },
+    Database, DatabaseKeyIndex, QueryDb, Revision, Runtime,
+};
+
+use super::{memo::Memo, DerivedKeyIndex, DerivedStorage, MemoizationPolicy};
+
+impl<Q, MP> DerivedStorage<Q, MP>
+where
+    Q: QueryFunction,
+    MP: MemoizationPolicy<Q>,
+{
+    pub(super) fn maybe_changed_after(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        key_index: DerivedKeyIndex,
+        revision: Revision,
+    ) -> bool {
+        loop {
+            let runtime = db.salsa_runtime();
+            let database_key_index = self.database_key_index(key_index);
+
+            log::debug!(
+                "{:?}: maybe_changed_after(revision = {:?})",
+                database_key_index.debug(db),
+                revision,
+            );
+
+            // Check if we have a verified version: this is the hot path.
+            let memo_guard = self.memo_map.get(key_index);
+            if let Some(memo) = &memo_guard {
+                if self.shallow_verify_memo(db, runtime, database_key_index, memo) {
+                    return memo.revisions.changed_at > revision;
+                }
+                drop(memo_guard); // release the arc-swap guard before cold path
+                if let Some(mcs) = self.maybe_changed_after_cold(db, key_index, revision) {
+                    return mcs;
+                } else {
+                    // We failed to claim, have to retry.
+                }
+            } else {
+                // No memo? Assume has changed.
+                return true;
+            }
+        }
+    }
+
+    fn maybe_changed_after_cold(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        key_index: DerivedKeyIndex,
+        revision: Revision,
+    ) -> Option<bool> {
+        let runtime = db.salsa_runtime();
+        let database_key_index = self.database_key_index(key_index);
+
+        let _claim_guard = self.sync_map.claim(db.ops_database(), database_key_index)?;
+        let active_query = runtime.push_query(database_key_index);
+
+        // Load the current memo, if any. Use a real arc, not an arc-swap guard,
+        // since we may recurse.
+        let old_memo = match self.memo_map.get(key_index) {
+            Some(m) => Guard::into_inner(m),
+            None => return Some(true),
+        };
+
+        log::debug!(
+            "{:?}: maybe_changed_after_cold, successful claim, revision = {:?}, old_memo = {:#?}",
+            database_key_index.debug(db),
+            revision,
+            old_memo
+        );
+
+        // Check if the inputs are still valid and we can just compare `changed_at`.
+        if self.deep_verify_memo(db, &old_memo, &active_query) {
+            return Some(old_memo.revisions.changed_at > revision);
+        }
+
+        // If inputs have changed, but we have an old value, we can re-execute.
+        // It is possible the result will be equal to the old value and hence
+        // backdated. In that case, although we will have computed a new memo,
+        // the value has not logically changed.
+        if old_memo.value.is_some() {
+            let StampedValue { changed_at, .. } = self.execute(db, active_query, Some(old_memo));
+            return Some(changed_at > revision);
+        }
+
+        // Otherwise, nothing for it: have to consider the value to have changed.
+        Some(true)
+    }
+
+    /// True if the memo's value and `changed_at` time is still valid in this revision.
+    /// Does only a shallow O(1) check, doesn't walk the dependencies.
+    #[inline]
+    pub(super) fn shallow_verify_memo(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        runtime: &Runtime,
+        database_key_index: DatabaseKeyIndex,
+        memo: &Memo<Q::Value>,
+    ) -> bool {
+        let verified_at = memo.verified_at.load();
+        let revision_now = runtime.current_revision();
+
+        log::debug!(
+            "{:?}: shallow_verify_memo(memo = {:#?})",
+            database_key_index.debug(db),
+            memo,
+        );
+
+        if verified_at == revision_now {
+            // Already verified.
+            return true;
+        }
+
+        if memo.check_durability(runtime) {
+            // No input of the suitable durability has changed since last verified.
+            memo.mark_as_verified(db.ops_database(), runtime, database_key_index);
+            return true;
+        }
+
+        false
+    }
+
+    /// True if the memo's value and `changed_at` time is up to date in the current
+    /// revision. When this returns true, it also updates the memo's `verified_at`
+    /// field if needed to make future calls cheaper.
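+    /// Returns false if the memo has untracked inputs, or if some
+    /// dependency reports a change since the memo was last verified.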
+    ///
+    /// Takes an [`ActiveQueryGuard`] argument because this function recursively
+    /// walks dependencies of `old_memo` and may even execute them to see if their
+    /// outputs have changed. As that could lead to cycles, it is important that the
+    /// query is on the stack.
+    pub(super) fn deep_verify_memo(
+        &self,
+        db: &<Q as QueryDb<'_>>::DynDb,
+        old_memo: &Memo<Q::Value>,
+        active_query: &ActiveQueryGuard<'_>,
+    ) -> bool {
+        let runtime = db.salsa_runtime();
+        let database_key_index = active_query.database_key_index;
+
+        log::debug!(
+            "{:?}: deep_verify_memo(old_memo = {:#?})",
+            database_key_index.debug(db),
+            old_memo
+        );
+
+        if self.shallow_verify_memo(db, runtime, database_key_index, old_memo) {
+            return true;
+        }
+
+        match &old_memo.revisions.inputs {
+            QueryInputs::Untracked => {
+                // Untracked inputs? Have to assume that it changed.
+                return false;
+            }
+            QueryInputs::NoInputs => {
+                // No inputs, cannot have changed.
+            }
+            QueryInputs::Tracked { inputs } => {
+                let last_verified_at = old_memo.verified_at.load();
+                for &input in inputs.iter() {
+                    if db.maybe_changed_after(input, last_verified_at) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        old_memo.mark_as_verified(db.ops_database(), runtime, database_key_index);
+        true
+    }
+}
diff --git a/src/derived/memo.rs b/src/derived/memo.rs
new file mode 100644
index 00000000..c3eaeb21
--- /dev/null
+++ b/src/derived/memo.rs
@@ -0,0 +1,107 @@
+use std::sync::Arc;
+
+use arc_swap::{ArcSwap, Guard};
+use crossbeam_utils::atomic::AtomicCell;
+
+use crate::{
+    hash::FxDashMap, runtime::local_state::QueryRevisions, DatabaseKeyIndex, Event, EventKind,
+    Revision, Runtime,
+};
+
+use super::DerivedKeyIndex;
+
+pub(super) struct MemoMap<V> {
+    map: FxDashMap<DerivedKeyIndex, ArcSwap<Memo<V>>>,
+}
+
+impl<V> Default for MemoMap<V> {
+    fn default() -> Self {
+        Self {
+            map: Default::default(),
+        }
+    }
+}
+
+impl<V> MemoMap<V> {
+    /// Inserts the memo for the given key; (atomically) overwrites any previously existing memo.
+    pub(super) fn insert(&self, key: DerivedKeyIndex, memo: Memo<V>) {
+        self.map.insert(key, ArcSwap::from(Arc::new(memo)));
+    }
+
+    /// Removes any existing memo for the given key.
+    pub(super) fn remove(&self, key: DerivedKeyIndex) {
+        self.map.remove(&key);
+    }
+
+    /// Loads the current memo for `key_index`. This does not hold any sort of
+    /// lock on the `memo_map` once it returns, so this memo could immediately
+    /// become outdated if other threads store into the `memo_map`.
+    pub(super) fn get(&self, key: DerivedKeyIndex) -> Option<Guard<Arc<Memo<V>>>> {
+        self.map.get(&key).map(|v| v.load())
+    }
+
+    /// Iterates over the entries in the map. This holds a read lock while iteration continues.
+    pub(super) fn iter(&self) -> impl Iterator<Item = (DerivedKeyIndex, Arc<Memo<V>>)> + '_ {
+        self.map
+            .iter()
+            .map(move |r| (*r.key(), r.value().load_full()))
+    }
+
+    /// Clears the map of all entries.
+    pub(super) fn clear(&self) {
+        self.map.clear()
+    }
+}
+
+#[derive(Debug)]
+pub(super) struct Memo<V> {
+    /// The result of the query, if we decide to memoize it.
+    pub(super) value: Option<V>,
+
+    /// Last revision when this memo was verified; this begins
+    /// as the current revision.
+    pub(super) verified_at: AtomicCell<Revision>,
+
+    /// Revision information
+    pub(super) revisions: QueryRevisions,
+}
+
+impl<V> Memo<V> {
+    pub(super) fn new(value: Option<V>, revision_now: Revision, revisions: QueryRevisions) -> Self {
+        Memo {
+            value,
+            verified_at: AtomicCell::new(revision_now),
+            revisions,
+        }
+    }
+    /// True if this memo is known not to have changed based on its durability.
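+    /// (That is, no input with durability at least as high as this memo's
+    /// own durability has changed since the memo was last verified.)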
+ pub(super) fn check_durability(&self, runtime: &Runtime) -> bool { + let last_changed = runtime.last_changed_revision(self.revisions.durability); + let verified_at = self.verified_at.load(); + log::debug!( + "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", + last_changed, + self.verified_at, + last_changed <= verified_at, + ); + last_changed <= verified_at + } + + /// Mark memo as having been verified in the `revision_now`, which should + /// be the current revision. + pub(super) fn mark_as_verified( + &self, + db: &dyn crate::Database, + runtime: &crate::Runtime, + database_key_index: DatabaseKeyIndex, + ) { + db.salsa_event(Event { + runtime_id: runtime.id(), + kind: EventKind::DidValidateMemoizedValue { + database_key: database_key_index, + }, + }); + + self.verified_at.store(runtime.current_revision()); + } +} diff --git a/src/derived/slot.rs b/src/derived/slot.rs deleted file mode 100644 index 41e788bb..00000000 --- a/src/derived/slot.rs +++ /dev/null @@ -1,871 +0,0 @@ -use crate::debug::TableEntry; -use crate::derived::MemoizationPolicy; -use crate::durability::Durability; -use crate::lru::LruIndex; -use crate::lru::LruNode; -use crate::plumbing::{DatabaseOps, QueryFunction}; -use crate::revision::Revision; -use crate::runtime::local_state::ActiveQueryGuard; -use crate::runtime::local_state::QueryInputs; -use crate::runtime::local_state::QueryRevisions; -use crate::runtime::Runtime; -use crate::runtime::RuntimeId; -use crate::runtime::StampedValue; -use crate::runtime::WaitResult; -use crate::Cycle; -use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; -use log::{debug, info}; -use parking_lot::{RawRwLock, RwLock}; -use std::marker::PhantomData; -use std::ops::Deref; -use std::sync::atomic::{AtomicBool, Ordering}; - -pub(super) struct Slot -where - Q: QueryFunction, - MP: MemoizationPolicy, -{ - key: Q::Key, - database_key_index: DatabaseKeyIndex, - state: RwLock>, - policy: PhantomData, - lru_index: LruIndex, -} - -/// Defines the "current state" of query's memoized results. -enum QueryState -where - Q: QueryFunction, -{ - NotComputed, - - /// The runtime with the given id is currently computing the - /// result of this query. - InProgress { - id: RuntimeId, - - /// Set to true if any other queries are blocked, - /// waiting for this query to complete. - anyone_waiting: AtomicBool, - }, - - /// We have computed the query already, and here is the result. - Memoized(Memo), -} - -struct Memo { - /// The result of the query, if we decide to memoize it. - value: Option, - - /// Last revision when this memo was verified; this begins - /// as the current revision. - pub(crate) verified_at: Revision, - - /// Revision information - revisions: QueryRevisions, -} - -/// Return value of `probe` helper. -enum ProbeState { - /// Another thread was active but has completed. - /// Try again! - Retry, - - /// No entry for this key at all. - NotComputed(G), - - /// There is an entry, but its contents have not been - /// verified in this revision. - Stale(G), - - /// There is an entry, and it has been verified - /// in this revision, but it has no cached - /// value. The `Revision` is the revision where the - /// value last changed (if we were to recompute it). - NoValue(G, Revision), - - /// There is an entry which has been verified, - /// and it has the following value-- or, we blocked - /// on another thread, and that resulted in a cycle. - UpToDate(V), -} - -/// Return value of `maybe_changed_after_probe` helper. 
-enum MaybeChangedSinceProbeState { - /// Another thread was active but has completed. - /// Try again! - Retry, - - /// Value may have changed in the given revision. - ChangedAt(Revision), - - /// There is a stale cache entry that has not been - /// verified in this revision, so we can't say. - Stale(G), -} - -impl Slot -where - Q: QueryFunction, - MP: MemoizationPolicy, -{ - pub(super) fn new(key: Q::Key, database_key_index: DatabaseKeyIndex) -> Self { - Self { - key, - database_key_index, - state: RwLock::new(QueryState::NotComputed), - lru_index: LruIndex::default(), - policy: PhantomData, - } - } - - pub(super) fn database_key_index(&self) -> DatabaseKeyIndex { - self.database_key_index - } - - pub(super) fn read(&self, db: &>::DynDb) -> StampedValue { - let runtime = db.salsa_runtime(); - - // NB: We don't need to worry about people modifying the - // revision out from under our feet. Either `db` is a frozen - // database, in which case there is a lock, or the mutator - // thread is the current thread, and it will be prevented from - // doing any `set` invocations while the query function runs. - let revision_now = runtime.current_revision(); - - info!("{:?}: invoked at {:?}", self, revision_now,); - - // First, do a check with a read-lock. - loop { - match self.probe(db, self.state.read(), runtime, revision_now) { - ProbeState::UpToDate(v) => return v, - ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => { - break - } - ProbeState::Retry => continue, - } - } - - self.read_upgrade(db, revision_now) - } - - /// Second phase of a read operation: acquires an upgradable-read - /// and -- if needed -- validates whether inputs have changed, - /// recomputes value, etc. This is invoked after our initial probe - /// shows a potentially out of date value. - fn read_upgrade( - &self, - db: &>::DynDb, - revision_now: Revision, - ) -> StampedValue { - let runtime = db.salsa_runtime(); - - debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); - - // Check with an upgradable read to see if there is a value - // already. (This permits other readers but prevents anyone - // else from running `read_upgrade` at the same time.) - let mut old_memo = loop { - match self.probe(db, self.state.upgradable_read(), runtime, revision_now) { - ProbeState::UpToDate(v) => return v, - ProbeState::Stale(state) - | ProbeState::NotComputed(state) - | ProbeState::NoValue(state, _) => { - type RwLockUpgradableReadGuard<'a, T> = - lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; - - let mut state = RwLockUpgradableReadGuard::upgrade(state); - match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { - QueryState::Memoized(old_memo) => break Some(old_memo), - QueryState::InProgress { .. } => unreachable!(), - QueryState::NotComputed => break None, - } - } - ProbeState::Retry => continue, - } - }; - - let panic_guard = PanicGuard::new(self.database_key_index, self, runtime); - let active_query = runtime.push_query(self.database_key_index); - - // If we have an old-value, it *may* now be stale, since there - // has been a new revision since the last time we checked. So, - // first things first, let's walk over each of our previous - // inputs and check whether they are out of date. 
- if let Some(memo) = &mut old_memo { - if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) { - info!("{:?}: validated old memoized value", self,); - - db.salsa_event(Event { - runtime_id: runtime.id(), - kind: EventKind::DidValidateMemoizedValue { - database_key: self.database_key_index, - }, - }); - - panic_guard.proceed(old_memo); - - return value; - } - } - - self.execute( - db, - runtime, - revision_now, - active_query, - panic_guard, - old_memo, - ) - } - - fn execute( - &self, - db: &>::DynDb, - runtime: &Runtime, - revision_now: Revision, - active_query: ActiveQueryGuard<'_>, - panic_guard: PanicGuard<'_, Q, MP>, - old_memo: Option>, - ) -> StampedValue { - log::info!("{:?}: executing query", self.database_key_index.debug(db)); - - db.salsa_event(Event { - runtime_id: db.salsa_runtime().id(), - kind: EventKind::WillExecute { - database_key: self.database_key_index, - }, - }); - - // Query was not previously executed, or value is potentially - // stale, or value is absent. Let's execute! - let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) { - Ok(v) => v, - Err(cycle) => { - log::debug!( - "{:?}: caught cycle {:?}, have strategy {:?}", - self.database_key_index.debug(db), - cycle, - Q::CYCLE_STRATEGY, - ); - match Q::CYCLE_STRATEGY { - crate::plumbing::CycleRecoveryStrategy::Panic => { - panic_guard.proceed(None); - cycle.throw() - } - crate::plumbing::CycleRecoveryStrategy::Fallback => { - if let Some(c) = active_query.take_cycle() { - assert!(c.is(&cycle)); - Q::cycle_fallback(db, &cycle, &self.key) - } else { - // we are not a participant in this cycle - debug_assert!(!cycle - .participant_keys() - .any(|k| k == self.database_key_index)); - cycle.throw() - } - } - } - } - }; - - let mut revisions = active_query.pop(); - - // We assume that query is side-effect free -- that is, does - // not mutate the "inputs" to the query system. Sanity check - // that assumption here, at least to the best of our ability. - assert_eq!( - runtime.current_revision(), - revision_now, - "revision altered during query execution", - ); - - // If the new value is equal to the old one, then it didn't - // really change, even if some of its inputs have. So we can - // "backdate" its `changed_at` revision to be the same as the - // old value. - if let Some(old_memo) = &old_memo { - if let Some(old_value) = &old_memo.value { - // Careful: if the value became less durable than it - // used to be, that is a "breaking change" that our - // consumers must be aware of. Becoming *more* durable - // is not. See the test `constant_to_non_constant`. - if revisions.durability >= old_memo.revisions.durability - && MP::memoized_value_eq(old_value, &value) - { - debug!( - "read_upgrade({:?}): value is equal, back-dating to {:?}", - self, old_memo.revisions.changed_at, - ); - - assert!(old_memo.revisions.changed_at <= revisions.changed_at); - revisions.changed_at = old_memo.revisions.changed_at; - } - } - } - - let new_value = StampedValue { - value, - durability: revisions.durability, - changed_at: revisions.changed_at, - }; - - let memo_value = if self.should_memoize_value(&self.key) { - Some(new_value.value.clone()) - } else { - None - }; - - debug!( - "read_upgrade({:?}): result.revisions = {:#?}", - self, revisions, - ); - - panic_guard.proceed(Some(Memo { - value: memo_value, - verified_at: revision_now, - revisions, - })); - - new_value - } - - /// Helper for `read` that does a shallow check (not recursive) if we have an up-to-date value. 
- /// - /// Invoked with the guard `state` corresponding to the `QueryState` of some `Slot` (the guard - /// can be either read or write). Returns a suitable `ProbeState`: - /// - /// - `ProbeState::UpToDate(r)` if the table has an up-to-date value (or we blocked on another - /// thread that produced such a value). - /// - `ProbeState::StaleOrAbsent(g)` if either (a) there is no memo for this key, (b) the memo - /// has no value; or (c) the memo has not been verified at the current revision. - /// - /// Note that in case `ProbeState::UpToDate`, the lock will have been released. - fn probe( - &self, - db: &>::DynDb, - state: StateGuard, - runtime: &Runtime, - revision_now: Revision, - ) -> ProbeState, StateGuard> - where - StateGuard: Deref>, - { - match &*state { - QueryState::NotComputed => ProbeState::NotComputed(state), - - QueryState::InProgress { id, anyone_waiting } => { - let other_id = *id; - - // NB: `Ordering::Relaxed` is sufficient here, - // as there are no loads that are "gated" on this - // value. Everything that is written is also protected - // by a lock that must be acquired. The role of this - // boolean is to decide *whether* to acquire the lock, - // not to gate future atomic reads. - anyone_waiting.store(true, Ordering::Relaxed); - - self.block_on_or_unwind(db, runtime, other_id, state); - - // Other thread completely normally, so our value may be available now. - ProbeState::Retry - } - - QueryState::Memoized(memo) => { - debug!( - "{:?}: found memoized value, verified_at={:?}, changed_at={:?}", - self, memo.verified_at, memo.revisions.changed_at, - ); - - if memo.verified_at < revision_now { - return ProbeState::Stale(state); - } - - if let Some(value) = &memo.value { - let value = StampedValue { - durability: memo.revisions.durability, - changed_at: memo.revisions.changed_at, - value: value.clone(), - }; - - info!( - "{:?}: returning memoized value changed at {:?}", - self, value.changed_at - ); - - ProbeState::UpToDate(value) - } else { - let changed_at = memo.revisions.changed_at; - ProbeState::NoValue(state, changed_at) - } - } - } - } - - pub(super) fn durability(&self, db: &>::DynDb) -> Durability { - match &*self.state.read() { - QueryState::NotComputed => Durability::LOW, - QueryState::InProgress { .. } => panic!("query in progress"), - QueryState::Memoized(memo) => { - if memo.check_durability(db.salsa_runtime()) { - memo.revisions.durability - } else { - Durability::LOW - } - } - } - } - - pub(super) fn as_table_entry(&self) -> Option> { - match &*self.state.read() { - QueryState::NotComputed => None, - QueryState::InProgress { .. } => Some(TableEntry::new(self.key.clone(), None)), - QueryState::Memoized(memo) => { - Some(TableEntry::new(self.key.clone(), memo.value.clone())) - } - } - } - - pub(super) fn evict(&self) { - let mut state = self.state.write(); - if let QueryState::Memoized(memo) = &mut *state { - // Evicting a value with an untracked input could - // lead to inconsistencies. Note that we can't check - // `has_untracked_input` when we add the value to the cache, - // because inputs can become untracked in the next revision. 
- if memo.has_untracked_input() { - return; - } - memo.value = None; - } - } - - pub(super) fn invalidate(&self, new_revision: Revision) -> Option { - log::debug!("Slot::invalidate(new_revision = {:?})", new_revision); - match &mut *self.state.write() { - QueryState::Memoized(memo) => { - memo.revisions.inputs = QueryInputs::Untracked; - memo.revisions.changed_at = new_revision; - Some(memo.revisions.durability) - } - QueryState::NotComputed => None, - QueryState::InProgress { .. } => unreachable!(), - } - } - - pub(super) fn maybe_changed_after( - &self, - db: &>::DynDb, - revision: Revision, - ) -> bool { - let runtime = db.salsa_runtime(); - let revision_now = runtime.current_revision(); - - db.unwind_if_cancelled(); - - debug!( - "maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}", - self, revision, revision_now, - ); - - // Do an initial probe with just the read-lock. - // - // If we find that a cache entry for the value is present - // but hasn't been verified in this revision, we'll have to - // do more. - loop { - match self.maybe_changed_after_probe(db, self.state.read(), runtime, revision_now) { - MaybeChangedSinceProbeState::Retry => continue, - MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, - MaybeChangedSinceProbeState::Stale(state) => { - drop(state); - return self.maybe_changed_after_upgrade(db, revision); - } - } - } - } - - fn maybe_changed_after_probe( - &self, - db: &>::DynDb, - state: StateGuard, - runtime: &Runtime, - revision_now: Revision, - ) -> MaybeChangedSinceProbeState - where - StateGuard: Deref>, - { - match self.probe(db, state, runtime, revision_now) { - ProbeState::Retry => MaybeChangedSinceProbeState::Retry, - - ProbeState::Stale(state) => MaybeChangedSinceProbeState::Stale(state), - - // If we know when value last changed, we can return right away. - // Note that we don't need the actual value to be available. - ProbeState::NoValue(_, changed_at) - | ProbeState::UpToDate(StampedValue { - value: _, - durability: _, - changed_at, - }) => MaybeChangedSinceProbeState::ChangedAt(changed_at), - - // If we have nothing cached, then value may have changed. - ProbeState::NotComputed(_) => MaybeChangedSinceProbeState::ChangedAt(revision_now), - } - } - - fn maybe_changed_after_upgrade( - &self, - db: &>::DynDb, - revision: Revision, - ) -> bool { - let runtime = db.salsa_runtime(); - let revision_now = runtime.current_revision(); - - // Get an upgradable read lock, which permits other reads but no writers. - // Probe again. If the value is stale (needs to be verified), then upgrade - // to a write lock and swap it with InProgress while we work. - let mut old_memo = match self.maybe_changed_after_probe( - db, - self.state.upgradable_read(), - runtime, - revision_now, - ) { - MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, - - // If another thread was active, then the cache line is going to be - // either verified or cleared out. Just recurse to figure out which. - // Note that we don't need an upgradable read. 
- MaybeChangedSinceProbeState::Retry => return self.maybe_changed_after(db, revision), - - MaybeChangedSinceProbeState::Stale(state) => { - type RwLockUpgradableReadGuard<'a, T> = - lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; - - let mut state = RwLockUpgradableReadGuard::upgrade(state); - match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { - QueryState::Memoized(old_memo) => old_memo, - QueryState::NotComputed | QueryState::InProgress { .. } => unreachable!(), - } - } - }; - - let panic_guard = PanicGuard::new(self.database_key_index, self, runtime); - let active_query = runtime.push_query(self.database_key_index); - - if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) { - let maybe_changed = old_memo.revisions.changed_at > revision; - panic_guard.proceed(Some(old_memo)); - maybe_changed - } else if old_memo.value.is_some() { - // We found that this memoized value may have changed - // but we have an old value. We can re-run the code and - // actually *check* if it has changed. - let StampedValue { changed_at, .. } = self.execute( - db, - runtime, - revision_now, - active_query, - panic_guard, - Some(old_memo), - ); - changed_at > revision - } else { - // We found that inputs to this memoized value may have chanced - // but we don't have an old value to compare against or re-use. - // No choice but to drop the memo and say that its value may have changed. - panic_guard.proceed(None); - true - } - } - - /// Helper: see [`Runtime::try_block_on_or_unwind`]. - fn block_on_or_unwind( - &self, - db: &>::DynDb, - runtime: &Runtime, - other_id: RuntimeId, - mutex_guard: MutexGuard, - ) { - runtime.block_on_or_unwind( - db.ops_database(), - self.database_key_index, - other_id, - mutex_guard, - ) - } - - fn should_memoize_value(&self, key: &Q::Key) -> bool { - MP::should_memoize_value(key) - } -} - -impl QueryState -where - Q: QueryFunction, -{ - fn in_progress(id: RuntimeId) -> Self { - QueryState::InProgress { - id, - anyone_waiting: Default::default(), - } - } -} - -struct PanicGuard<'me, Q, MP> -where - Q: QueryFunction, - MP: MemoizationPolicy, -{ - database_key_index: DatabaseKeyIndex, - slot: &'me Slot, - runtime: &'me Runtime, -} - -impl<'me, Q, MP> PanicGuard<'me, Q, MP> -where - Q: QueryFunction, - MP: MemoizationPolicy, -{ - fn new( - database_key_index: DatabaseKeyIndex, - slot: &'me Slot, - runtime: &'me Runtime, - ) -> Self { - Self { - database_key_index, - slot, - runtime, - } - } - - /// Indicates that we have concluded normally (without panicking). - /// If `opt_memo` is some, then this memo is installed as the new - /// memoized value. If `opt_memo` is `None`, then the slot is cleared - /// and has no value. - fn proceed(mut self, opt_memo: Option>) { - self.overwrite_placeholder(WaitResult::Completed, opt_memo); - std::mem::forget(self) - } - - /// Overwrites the `InProgress` placeholder for `key` that we - /// inserted; if others were blocked, waiting for us to finish, - /// then notify them. - fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option>) { - let mut write = self.slot.state.write(); - - let old_value = match opt_memo { - // Replace the `InProgress` marker that we installed with the new - // memo, thus releasing our unique access to this key. - Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)), - - // We had installed an `InProgress` marker, but we panicked before - // it could be removed. 
At this point, we therefore "own" unique - // access to our slot, so we can just remove the key. - None => std::mem::replace(&mut *write, QueryState::NotComputed), - }; - - match old_value { - QueryState::InProgress { id, anyone_waiting } => { - assert_eq!(id, self.runtime.id()); - - // NB: As noted on the `store`, `Ordering::Relaxed` is - // sufficient here. This boolean signals us on whether to - // acquire a mutex; the mutex will guarantee that all writes - // we are interested in are visible. - if anyone_waiting.load(Ordering::Relaxed) { - self.runtime - .unblock_queries_blocked_on(self.database_key_index, wait_result); - } - } - _ => panic!( - "\ -Unexpected panic during query evaluation, aborting the process. - -Please report this bug to https://github.com/salsa-rs/salsa/issues." - ), - } - } -} - -impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP> -where - Q: QueryFunction, - MP: MemoizationPolicy, -{ - fn drop(&mut self) { - if std::thread::panicking() { - // We panicked before we could proceed and need to remove `key`. - self.overwrite_placeholder(WaitResult::Panicked, None) - } else { - // If no panic occurred, then panic guard ought to be - // "forgotten" and so this Drop code should never run. - panic!(".forget() was not called") - } - } -} - -impl Memo -where - V: Clone, -{ - /// Determines whether the value stored in this memo (if any) is still - /// valid in the current revision. If so, returns a stamped value. - /// - /// If needed, this will walk each dependency and - /// recursively invoke `maybe_changed_after`, which may in turn - /// re-execute the dependency. This can cause cycles to occur, - /// so the current query must be pushed onto the - /// stack to permit cycle detection and recovery: therefore, - /// takes the `active_query` argument as evidence. - fn verify_value( - &mut self, - db: &dyn Database, - revision_now: Revision, - active_query: &ActiveQueryGuard<'_>, - ) -> Option> { - // If we don't have a memoized value, nothing to validate. - if self.value.is_none() { - return None; - } - if self.verify_revisions(db, revision_now, active_query) { - Some(StampedValue { - durability: self.revisions.durability, - changed_at: self.revisions.changed_at, - value: self.value.as_ref().unwrap().clone(), - }) - } else { - None - } - } - - /// Determines whether the value represented by this memo is still - /// valid in the current revision; note that the value itself is - /// not needed for this check. If needed, this will walk each - /// dependency and recursively invoke `maybe_changed_after`, which - /// may in turn re-execute the dependency. This can cause cycles to occur, - /// so the current query must be pushed onto the - /// stack to permit cycle detection and recovery: therefore, - /// takes the `active_query` argument as evidence. - fn verify_revisions( - &mut self, - db: &dyn Database, - revision_now: Revision, - _active_query: &ActiveQueryGuard<'_>, - ) -> bool { - assert!(self.verified_at != revision_now); - let verified_at = self.verified_at; - - debug!( - "verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}", - verified_at, revision_now, self.revisions.inputs - ); - - if self.check_durability(db.salsa_runtime()) { - return self.mark_value_as_verified(revision_now); - } - - match &self.revisions.inputs { - // We can't validate values that had untracked inputs; just have to - // re-execute. 
-            QueryInputs::Untracked => {
-                return false;
-            }
-
-            QueryInputs::NoInputs => {}
-
-            // Check whether any of our inputs changed since the
-            // **last point where we were verified** (not since we
-            // last changed). This is important: if we have
-            // memoized values, then an input may have changed in
-            // revision R2, but we found that *our* value was the
-            // same regardless, so our change date is still
-            // R1. But our *verification* date will be R2, and we
-            // are only interested in finding out whether the
-            // input changed *again*.
-            QueryInputs::Tracked { inputs } => {
-                let changed_input = inputs
-                    .iter()
-                    .find(|&&input| db.maybe_changed_after(input, verified_at));
-                if let Some(input) = changed_input {
-                    debug!("validate_memoized_value: `{:?}` may have changed", input);
-
-                    return false;
-                }
-            }
-        };
-
-        self.mark_value_as_verified(revision_now)
-    }
-
-    /// True if this memo is known not to have changed based on its durability.
-    fn check_durability(&self, runtime: &Runtime) -> bool {
-        let last_changed = runtime.last_changed_revision(self.revisions.durability);
-        debug!(
-            "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}",
-            last_changed,
-            self.verified_at,
-            last_changed <= self.verified_at,
-        );
-        last_changed <= self.verified_at
-    }
-
-    fn mark_value_as_verified(&mut self, revision_now: Revision) -> bool {
-        self.verified_at = revision_now;
-        true
-    }
-
-    fn has_untracked_input(&self) -> bool {
-        matches!(self.revisions.inputs, QueryInputs::Untracked)
-    }
-}
-
-impl<Q, MP> std::fmt::Debug for Slot<Q, MP>
-where
-    Q: QueryFunction,
-    MP: MemoizationPolicy<Q>,
-{
-    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(fmt, "{:?}({:?})", Q::default(), self.key)
-    }
-}
-
-impl<Q, MP> LruNode for Slot<Q, MP>
-where
-    Q: QueryFunction,
-    MP: MemoizationPolicy<Q>,
-{
-    fn lru_index(&self) -> &LruIndex {
-        &self.lru_index
-    }
-}
-
-/// Check that `Slot<Q, MP>: Send + Sync` as long as
-/// `DB::DatabaseData: Send + Sync`, which in turn implies that
-/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
-#[allow(dead_code)]
-fn check_send_sync<Q, MP>()
-where
-    Q: QueryFunction,
-    MP: MemoizationPolicy<Q>,
-    Q::Key: Send + Sync,
-    Q::Value: Send + Sync,
-{
-    fn is_send_sync<T: Send + Sync>() {}
-    is_send_sync::<Slot<Q, MP>>();
-}
-
-/// Check that `Slot<Q, MP>: 'static` as long as
-/// `DB::DatabaseData: 'static`, which in turn implies that
-/// `Q::Key: 'static`, `Q::Value: 'static`.
-#[allow(dead_code)]
-fn check_static<Q, MP>()
-where
-    Q: QueryFunction + 'static,
-    MP: MemoizationPolicy<Q> + 'static,
-    Q::Key: 'static,
-    Q::Value: 'static,
-{
-    fn is_static<T: 'static>() {}
-    is_static::<Slot<Q, MP>>();
-}
diff --git a/src/derived/sync.rs b/src/derived/sync.rs
new file mode 100644
index 00000000..31d9e47c
--- /dev/null
+++ b/src/derived/sync.rs
@@ -0,0 +1,87 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use crate::{hash::FxDashMap, runtime::WaitResult, Database, DatabaseKeyIndex, Runtime, RuntimeId};
+
+use super::DerivedKeyIndex;
+
+#[derive(Default)]
+pub(super) struct SyncMap {
+    sync_map: FxDashMap<DerivedKeyIndex, SyncState>,
+}
+
+struct SyncState {
+    id: RuntimeId,
+
+    /// Set to true if any other queries are blocked,
+    /// waiting for this query to complete.
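+    /// (Read when the claim is released, to decide whether
+    /// `unblock_queries_blocked_on` needs to run.)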
+    anyone_waiting: AtomicBool,
+}
+
+impl SyncMap {
+    pub(super) fn claim<'me>(
+        &'me self,
+        db: &'me dyn Database,
+        database_key_index: DatabaseKeyIndex,
+    ) -> Option<ClaimGuard<'me>> {
+        let runtime = db.salsa_runtime();
+        match self.sync_map.entry(database_key_index.key_index) {
+            dashmap::mapref::entry::Entry::Vacant(entry) => {
+                entry.insert(SyncState {
+                    id: runtime.id(),
+                    anyone_waiting: AtomicBool::new(false),
+                });
+                Some(ClaimGuard {
+                    database_key: database_key_index,
+                    runtime,
+                    sync_map: &self.sync_map,
+                })
+            }
+            dashmap::mapref::entry::Entry::Occupied(entry) => {
+                // NB: `Ordering::Relaxed` is sufficient here,
+                // as there are no loads that are "gated" on this
+                // value. Everything that is written is also protected
+                // by a lock that must be acquired. The role of this
+                // boolean is to decide *whether* to acquire the lock,
+                // not to gate future atomic reads.
+                entry.get().anyone_waiting.store(true, Ordering::Relaxed);
+                let other_id = entry.get().id;
+                runtime.block_on_or_unwind(db, database_key_index, other_id, entry);
+                None
+            }
+        }
+    }
+}
+
+/// Marks an active 'claim' in the synchronization map. The claim is
+/// released when this value is dropped.
+#[must_use]
+pub(super) struct ClaimGuard<'me> {
+    database_key: DatabaseKeyIndex,
+    runtime: &'me Runtime,
+    sync_map: &'me FxDashMap<DerivedKeyIndex, SyncState>,
+}
+
+impl<'me> ClaimGuard<'me> {
+    fn remove_from_map_and_unblock_queries(&self, wait_result: WaitResult) {
+        let (_, SyncState { anyone_waiting, .. }) =
+            self.sync_map.remove(&self.database_key.key_index).unwrap();
+
+        // NB: `Ordering::Relaxed` is sufficient here,
+        // see `store` above for explanation.
+        if anyone_waiting.load(Ordering::Relaxed) {
+            self.runtime
+                .unblock_queries_blocked_on(self.database_key, wait_result)
+        }
+    }
+}
+
+impl<'me> Drop for ClaimGuard<'me> {
+    fn drop(&mut self) {
+        let wait_result = if std::thread::panicking() {
+            WaitResult::Panicked
+        } else {
+            WaitResult::Completed
+        };
+        self.remove_from_map_and_unblock_queries(wait_result)
+    }
+}
diff --git a/src/hash.rs b/src/hash.rs
index 4c7d2da7..ec94909e 100644
--- a/src/hash.rs
+++ b/src/hash.rs
@@ -2,3 +2,4 @@ pub(crate) type FxHasher = std::hash::BuildHasherDefault<rustc_hash::FxHasher>;
 pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, FxHasher>;
 pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, FxHasher>;
 pub(crate) type FxDashMap<K, V> = dashmap::DashMap<K, V, FxHasher>;
+pub(crate) type FxLinkedHashSet<K> = hashlink::LinkedHashSet<K, FxHasher>;
diff --git a/src/lib.rs b/src/lib.rs
index a177b910..0496b1f7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -16,7 +16,6 @@ mod hash;
 mod input;
 mod intern_id;
 mod interned;
-mod lru;
 mod revision;
 mod runtime;
 mod storage;
diff --git a/src/lru.rs b/src/lru.rs
deleted file mode 100644
index d6194bc4..00000000
--- a/src/lru.rs
+++ /dev/null
@@ -1,335 +0,0 @@
-use oorandom::Rand64;
-use parking_lot::Mutex;
-use std::fmt::Debug;
-use std::sync::atomic::AtomicUsize;
-use std::sync::atomic::Ordering;
-use std::sync::Arc;
-
-/// A simple and approximate concurrent lru list.
-///
-/// We assume but do not verify that each node is only used with one
-/// list. If this is not the case, it is not *unsafe*, but panics and
-/// weird results will ensue.
-///
-/// Each "node" in the list is of type `Node` and must implement
-/// `LruNode`, which is a trait that gives access to a field that
-/// stores the index in the list. This index gives us a rough idea of
-/// how recently the node has been used.
-#[derive(Debug)] -pub(crate) struct Lru -where - Node: LruNode, -{ - green_zone: AtomicUsize, - data: Mutex>, -} - -#[derive(Debug)] -struct LruData { - end_red_zone: usize, - end_yellow_zone: usize, - end_green_zone: usize, - rng: Rand64, - entries: Vec>, -} - -pub(crate) trait LruNode: Sized + Debug { - fn lru_index(&self) -> &LruIndex; -} - -#[derive(Debug)] -pub(crate) struct LruIndex { - /// Index in the approprate LRU list, or std::usize::MAX if not a - /// member. - index: AtomicUsize, -} - -impl Default for Lru -where - Node: LruNode, -{ - fn default() -> Self { - Lru::new() - } -} - -// We always use a fixed seed for our randomness so that we have -// predictable results. -const LRU_SEED: &str = "Hello, Rustaceans"; - -impl Lru -where - Node: LruNode, -{ - /// Creates a new LRU list where LRU caching is disabled. - pub fn new() -> Self { - Self::with_seed(LRU_SEED) - } - - #[cfg_attr(not(test), allow(dead_code))] - fn with_seed(seed: &str) -> Self { - Lru { - green_zone: AtomicUsize::new(0), - data: Mutex::new(LruData::with_seed(seed)), - } - } - - /// Adjust the total number of nodes permitted to have a value at - /// once. If `len` is zero, this disables LRU caching completely. - pub fn set_lru_capacity(&self, len: usize) { - let mut data = self.data.lock(); - - // We require each zone to have at least 1 slot. Therefore, - // the length cannot be just 1 or 2. - if len == 0 { - self.green_zone.store(0, Ordering::Release); - data.resize(0, 0, 0); - } else { - let len = std::cmp::max(len, 3); - - // Top 10% is the green zone. This must be at least length 1. - let green_zone = std::cmp::max(len / 10, 1); - - // Next 20% is the yellow zone. - let yellow_zone = std::cmp::max(len / 5, 1); - - // Remaining 70% is the red zone. - let red_zone = len - yellow_zone - green_zone; - - // We need quick access to the green zone. - self.green_zone.store(green_zone, Ordering::Release); - - // Resize existing array. - data.resize(green_zone, yellow_zone, red_zone); - } - } - - /// Records that `node` was used. This may displace an old node (if the LRU limits are - pub fn record_use(&self, node: &Arc) -> Option> { - log::debug!("record_use(node={:?})", node); - - // Load green zone length and check if the LRU cache is even enabled. - let green_zone = self.green_zone.load(Ordering::Acquire); - log::debug!("record_use: green_zone={}", green_zone); - if green_zone == 0 { - return None; - } - - // Find current index of list (if any) and the current length - // of our green zone. - let index = node.lru_index().load(); - log::debug!("record_use: index={}", index); - - // Already a member of the list, and in the green zone -- nothing to do! 
-        if index < green_zone {
-            return None;
-        }
-
-        self.data.lock().record_use(node)
-    }
-
-    pub fn purge(&self) {
-        self.green_zone.store(0, Ordering::SeqCst);
-        *self.data.lock() = LruData::with_seed(LRU_SEED);
-    }
-}
-
-impl<Node> LruData<Node>
-where
-    Node: LruNode,
-{
-    fn with_seed(seed_str: &str) -> Self {
-        Self::with_rng(rng_with_seed(seed_str))
-    }
-
-    fn with_rng(rng: Rand64) -> Self {
-        LruData {
-            end_yellow_zone: 0,
-            end_green_zone: 0,
-            end_red_zone: 0,
-            entries: Vec::new(),
-            rng,
-        }
-    }
-
-    fn green_zone(&self) -> std::ops::Range<usize> {
-        0..self.end_green_zone
-    }
-
-    fn yellow_zone(&self) -> std::ops::Range<usize> {
-        self.end_green_zone..self.end_yellow_zone
-    }
-
-    fn red_zone(&self) -> std::ops::Range<usize> {
-        self.end_yellow_zone..self.end_red_zone
-    }
-
-    fn resize(&mut self, len_green_zone: usize, len_yellow_zone: usize, len_red_zone: usize) {
-        self.end_green_zone = len_green_zone;
-        self.end_yellow_zone = self.end_green_zone + len_yellow_zone;
-        self.end_red_zone = self.end_yellow_zone + len_red_zone;
-        let entries = std::mem::replace(&mut self.entries, Vec::with_capacity(self.end_red_zone));
-
-        log::debug!("green_zone = {:?}", self.green_zone());
-        log::debug!("yellow_zone = {:?}", self.yellow_zone());
-        log::debug!("red_zone = {:?}", self.red_zone());
-
-        // We expect to resize when the LRU cache is basically empty.
-        // So just forget all the old LRU indices to start.
-        for entry in entries {
-            entry.lru_index().clear();
-        }
-    }
-
-    /// Records that a node was used. If it is already a member of the
-    /// LRU list, it is promoted to the green zone (unless it's
-    /// already there). Otherwise, it is added to the list first and
-    /// *then* promoted to the green zone. Adding a new node to the
-    /// list may displace an old member of the red zone, in which case
-    /// that is returned.
-    fn record_use(&mut self, node: &Arc<Node>) -> Option<Arc<Node>> {
-        log::debug!("record_use(node={:?})", node);
-
-        // NB: When this is invoked, we have typically already loaded
-        // the LRU index (to check if it is in the green zone). But that
-        // check was done outside the lock and -- for all we know --
-        // the index may have changed since. So we always reload.
-        let index = node.lru_index().load();
-
-        if index < self.end_green_zone {
-            None
-        } else if index < self.end_yellow_zone {
-            self.promote_yellow_to_green(node, index);
-            None
-        } else if index < self.end_red_zone {
-            self.promote_red_to_green(node, index);
-            None
-        } else {
-            self.insert_new(node)
-        }
-    }
-
-    /// Inserts a node that is not yet a member of the LRU list. If
-    /// the list is at capacity, this can displace an existing member.
-    fn insert_new(&mut self, node: &Arc<Node>) -> Option<Arc<Node>> {
-        debug_assert!(!node.lru_index().is_in_lru());
-
-        // Easy case: we still have capacity. Push it, and then promote
-        // it up to the appropriate zone.
-        let len = self.entries.len();
-        if len < self.end_red_zone {
-            self.entries.push(node.clone());
-            node.lru_index().store(len);
-            log::debug!("inserted node {:?} at {}", node, len);
-            return self.record_use(node);
-        }
-
-        // Harder case: no capacity. Create some by evicting somebody
-        // from the red zone and then promoting.
-        let victim_index = self.pick_index(self.red_zone());
-        let victim_node = std::mem::replace(&mut self.entries[victim_index], node.clone());
-        log::debug!("evicting red node {:?} from {}", victim_node, victim_index);
-        victim_node.lru_index().clear();
-        self.promote_red_to_green(node, victim_index);
-        Some(victim_node)
-    }
-
-    /// Promotes the node `node`, stored at `red_index` (in the red
-    /// zone), into a green index, demoting yellow/green nodes at
-    /// random.
-    ///
-    /// NB: It is not required that `node.lru_index()` is up-to-date
-    /// when entering this method.
-    fn promote_red_to_green(&mut self, node: &Arc<Node>, red_index: usize) {
-        debug_assert!(self.red_zone().contains(&red_index));
-
-        // Pick a yellow at random and switch places with it.
-        //
-        // Subtle: we do not update `node.lru_index` *yet* -- we're
-        // going to invoke `self.promote_yellow_to_green` next, and it
-        // will get updated then.
-        let yellow_index = self.pick_index(self.yellow_zone());
-        log::debug!(
-            "demoting yellow node {:?} from {} to red at {}",
-            self.entries[yellow_index],
-            yellow_index,
-            red_index,
-        );
-        self.entries.swap(yellow_index, red_index);
-        self.entries[red_index].lru_index().store(red_index);
-
-        // Now move ourselves up into the green zone.
-        self.promote_yellow_to_green(node, yellow_index);
-    }
-
-    /// Promotes the node `node`, stored at `yellow_index` (in the
-    /// yellow zone), into a green index, demoting a green node at
-    /// random to replace it.
-    ///
-    /// NB: It is not required that `node.lru_index()` is up-to-date
-    /// when entering this method.
-    fn promote_yellow_to_green(&mut self, node: &Arc<Node>, yellow_index: usize) {
-        debug_assert!(self.yellow_zone().contains(&yellow_index));
-
-        // Pick a green at random and switch places with it.
-        let green_index = self.pick_index(self.green_zone());
-        log::debug!(
-            "demoting green node {:?} from {} to yellow at {}",
-            self.entries[green_index],
-            green_index,
-            yellow_index
-        );
-        self.entries.swap(green_index, yellow_index);
-        self.entries[yellow_index].lru_index().store(yellow_index);
-        node.lru_index().store(green_index);
-
-        log::debug!("promoted {:?} to green index {}", node, green_index);
-    }
-
-    fn pick_index(&mut self, zone: std::ops::Range<usize>) -> usize {
-        let end_index = std::cmp::min(zone.end, self.entries.len());
-        self.rng.rand_range(zone.start as u64..end_index as u64) as usize
-    }
-}
-
-impl Default for LruIndex {
-    fn default() -> Self {
-        Self {
-            index: AtomicUsize::new(std::usize::MAX),
-        }
-    }
-}
-
-impl LruIndex {
-    fn load(&self) -> usize {
-        self.index.load(Ordering::Acquire) // see note on ordering below
-    }
-
-    fn store(&self, value: usize) {
-        self.index.store(value, Ordering::Release) // see note on ordering below
-    }
-
-    fn clear(&self) {
-        self.store(std::usize::MAX);
-    }
-
-    fn is_in_lru(&self) -> bool {
-        self.load() != std::usize::MAX
-    }
-}
-
-fn rng_with_seed(seed_str: &str) -> Rand64 {
-    let mut seed: [u8; 16] = [0; 16];
-    for (i, &b) in seed_str.as_bytes().iter().take(16).enumerate() {
-        seed[i] = b;
-    }
-    Rand64::new(u128::from_le_bytes(seed))
-}
-
-// A note on ordering:
-//
-// I chose to use Acquire/Release for the ordering, but I don't think
-// it's strictly needed. All writes occur under a lock, so they should
-// be ordered with respect to one another. As for the reads, they can
-// occur outside the lock, but they don't themselves enable dependent
-// reads -- if the reads are out of bounds, we would acquire the lock.
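
For intuition about the zone arithmetic in the `set_lru_capacity` being
deleted above (which this patch presumably replaces with the new
`src/derived/lru.rs`, backed by the `FxLinkedHashSet` alias added to
`src/hash.rs`): 10% green, 20% yellow, the remainder red, with every zone
guaranteed at least one slot. A throwaway re-statement of just that split --
`zone_sizes` is illustrative, not a salsa function:

// Illustrative re-statement of the zone split in `set_lru_capacity`:
// ~10% green, ~20% yellow, remainder red; each zone gets at least one
// slot, and a nonzero `len` is rounded up to at least 3.
fn zone_sizes(len: usize) -> (usize, usize, usize) {
    assert!(len > 0, "len == 0 disables the LRU entirely");
    let len = len.max(3);
    let green = (len / 10).max(1);
    let yellow = (len / 5).max(1);
    let red = len - yellow - green;
    (green, yellow, red)
}

fn main() {
    assert_eq!(zone_sizes(100), (10, 20, 70));
    assert_eq!(zone_sizes(10), (1, 2, 7));
    // Tiny capacities are rounded up so that no zone is empty:
    assert_eq!(zone_sizes(1), (1, 1, 1));
}

The promotion scheme is probabilistic on purpose: `record_use` swaps with a
randomly chosen member of the next zone up, so recency is tracked only
approximately, but without rewriting a global list on every hit.
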
diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs
index d84dafa1..bf357d4f 100644
--- a/src/runtime/local_state.rs
+++ b/src/runtime/local_state.rs
@@ -8,6 +8,8 @@ use crate::DatabaseKeyIndex;
 use std::cell::RefCell;
 use std::sync::Arc;
 
+use super::StampedValue;
+
 /// State that is specific to a single execution thread.
 ///
 /// Internally, this type uses ref-cells.
@@ -38,6 +40,16 @@ pub(crate) struct QueryRevisions {
     pub(crate) inputs: QueryInputs,
 }
 
+impl QueryRevisions {
+    pub(crate) fn stamped_value<V>(&self, value: V) -> StampedValue<V> {
+        StampedValue {
+            value,
+            durability: self.durability,
+            changed_at: self.changed_at,
+        }
+    }
+}
+
 /// Every input.
 #[derive(Debug, Clone)]
 pub(crate) enum QueryInputs {
@@ -180,7 +192,7 @@ impl std::panic::RefUnwindSafe for LocalState {}
 pub(crate) struct ActiveQueryGuard<'me> {
     local_state: &'me LocalState,
     push_len: usize,
-    database_key_index: DatabaseKeyIndex,
+    pub(crate) database_key_index: DatabaseKeyIndex,
 }
 
 impl ActiveQueryGuard<'_> {
diff --git a/tests/incremental/memoized_volatile.rs b/tests/incremental/memoized_volatile.rs
index 6dc50300..7e979777 100644
--- a/tests/incremental/memoized_volatile.rs
+++ b/tests/incremental/memoized_volatile.rs
@@ -1,5 +1,6 @@
 use crate::implementation::{TestContext, TestContextImpl};
 use salsa::{Database, Durability};
+use test_env_log::test;
 
 #[salsa::query_group(MemoizedVolatile)]
 pub(crate) trait MemoizedVolatileContext: TestContext {
diff --git a/tests/on_demand_inputs.rs b/tests/on_demand_inputs.rs
index 99092025..ef4b6b5a 100644
--- a/tests/on_demand_inputs.rs
+++ b/tests/on_demand_inputs.rs
@@ -144,8 +144,10 @@ fn on_demand_input_durability() {
         RefCell {
             value: [
                 "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
+                "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }",
                 "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }",
                 "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
+                "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }",
                 "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }",
             ],
         }
diff --git a/tests/panic_safely.rs b/tests/panic_safely.rs
index e51a74e1..551e6c7a 100644
--- a/tests/panic_safely.rs
+++ b/tests/panic_safely.rs
@@ -70,7 +70,7 @@ fn should_panic_safely() {
         db.set_one(1);
         db.outer();
-        assert_eq!(OUTER_CALLS.load(SeqCst), 2);
+        assert_eq!(OUTER_CALLS.load(SeqCst), 1);
     }
 }
diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs
index 9cb8f2f9..c33d3b97 100644
--- a/tests/parallel/parallel_cycle_none_recover.rs
+++ b/tests/parallel/parallel_cycle_none_recover.rs
@@ -40,7 +40,7 @@ fn parallel_cycle_none_recover() {
     assert!(thread_a
         .join()
         .unwrap_err()
-        .downcast_ref::()
+        .downcast_ref::()
         .is_some());
 }
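
A closing note on the `stamped_value` helper added to `QueryRevisions`
above: it is the seam between the metadata recorded when a query executes
and the `StampedValue` handed back to callers, copying the query's
durability and changed-at revision onto the value so that the pairing cannot
drift across call sites. A minimal model of that pairing, with stand-in
`Durability` and `Revision` types (the real ones live elsewhere in salsa):

// Stand-ins for salsa's `Durability` and `Revision` (simplified).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Durability(u8);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Revision(u32);

// Modeled on the diff's `QueryRevisions`: metadata recorded when a
// query executes.
struct QueryRevisions {
    changed_at: Revision,
    durability: Durability,
}

// Modeled on salsa's `StampedValue`: a value plus the metadata needed
// to answer later "has this changed?" probes without recomputing.
#[derive(Debug)]
struct StampedValue<V> {
    value: V,
    durability: Durability,
    changed_at: Revision,
}

impl QueryRevisions {
    // Mirrors the helper added in this patch: stamp `value` with this
    // query's durability and changed-at revision.
    fn stamped_value<V>(&self, value: V) -> StampedValue<V> {
        StampedValue {
            value,
            durability: self.durability,
            changed_at: self.changed_at,
        }
    }
}

fn main() {
    let revisions = QueryRevisions {
        changed_at: Revision(42),
        durability: Durability(1),
    };
    let stamped = revisions.stamped_value("result");
    assert_eq!(stamped.changed_at, Revision(42));
    println!("{:?}", stamped);
}
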