diff --git a/src/dependencies.rs b/src/dependencies.rs deleted file mode 100644 index a7e9ca7..0000000 --- a/src/dependencies.rs +++ /dev/null @@ -1,245 +0,0 @@ -use crate::runtime::QueryDescriptorSet; -use crate::runtime::Revision; -use crate::runtime::StampedValue; -use crate::CycleDetected; -use crate::Database; -use crate::QueryDescriptor; -use crate::QueryFunction; -use crate::QueryStorageOps; -use crate::QueryTable; -use log::debug; -use parking_lot::{RwLock, RwLockUpgradableReadGuard}; -use rustc_hash::FxHashMap; -use std::any::Any; -use std::cell::RefCell; -use std::collections::hash_map::Entry; -use std::fmt::Debug; -use std::fmt::Display; -use std::fmt::Write; -use std::hash::Hash; - -/// "Dependency" queries just track their dependencies and not the -/// actual value (which they produce on demand). This lessens the -/// storage requirements. -pub struct DependencyStorage -where - Q: QueryFunction, - DB: Database, -{ - map: RwLock>>, -} - -/// Defines the "current state" of query's memoized results. -enum QueryState -where - DB: Database, -{ - /// We are currently computing the result of this query; if we see - /// this value in the table, it indeeds a cycle. - InProgress, - - /// We have computed the query already, and here is the result. - Memoized(Memo), -} - -struct Memo -where - DB: Database, -{ - inputs: QueryDescriptorSet, - - /// Last time that we checked our inputs to see if they have - /// changed. If this is equal to the current revision, then the - /// value is up to date. If not, we need to check our inputs and - /// see if any of them have changed since our last check -- if so, - /// we'll need to re-execute. - verified_at: Revision, - - /// Last time that our value changed. - changed_at: Revision, -} - -impl Default for DependencyStorage -where - Q: QueryFunction, - DB: Database, -{ - fn default() -> Self { - DependencyStorage { - map: RwLock::new(FxHashMap::default()), - } - } -} - -impl DependencyStorage -where - Q: QueryFunction, - DB: Database, -{ - fn read( - &self, - db: &DB, - key: &Q::Key, - descriptor: &DB::QueryDescriptor, - ) -> Result, CycleDetected> { - let revision_now = db.salsa_runtime().current_revision(); - - debug!( - "{:?}({:?}): invoked at {:?}", - Q::default(), - key, - revision_now, - ); - - { - let map_read = self.map.upgradable_read(); - if let Some(value) = map_read.get(key) { - match value { - QueryState::InProgress => return Err(CycleDetected), - QueryState::Memoized(_) => {} - } - } - - let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read); - map_write.insert(key.clone(), QueryState::InProgress); - } - - // Note that, unlike with a memoized query, we must always - // re-execute. - let (stamped_value, inputs) = db - .salsa_runtime() - .execute_query_implementation::(db, descriptor, key); - - // We assume that query is side-effect free -- that is, does - // not mutate the "inputs" to the query system. Sanity check - // that assumption here, at least to the best of our ability. - assert_eq!( - db.salsa_runtime().current_revision(), - revision_now, - "revision altered during query execution", - ); - - { - let mut map_write = self.map.write(); - - let old_value = map_write.insert( - key.clone(), - QueryState::Memoized(Memo { - inputs, - verified_at: revision_now, - changed_at: stamped_value.changed_at, - }), - ); - assert!( - match old_value { - Some(QueryState::InProgress) => true, - _ => false, - }, - "expected in-progress state", - ); - } - - Ok(stamped_value) - } - - fn overwrite_placeholder( - &self, - map_write: &mut FxHashMap>, - key: &Q::Key, - value: Option>, - ) { - let old_value = if let Some(v) = value { - map_write.insert(key.clone(), v) - } else { - map_write.remove(key) - }; - - assert!( - match old_value { - Some(QueryState::InProgress) => true, - _ => false, - }, - "expected in-progress state", - ); - } -} - -impl QueryStorageOps for DependencyStorage -where - Q: QueryFunction, - DB: Database, -{ - fn try_fetch<'q>( - &self, - db: &'q DB, - key: &Q::Key, - descriptor: &DB::QueryDescriptor, - ) -> Result { - let StampedValue { value, changed_at } = self.read(db, key, &descriptor)?; - - db.salsa_runtime().report_query_read(descriptor, changed_at); - - Ok(value) - } - - fn maybe_changed_since( - &self, - db: &'q DB, - revision: Revision, - key: &Q::Key, - _descriptor: &DB::QueryDescriptor, - ) -> bool { - let revision_now = db.salsa_runtime().current_revision(); - - debug!( - "{:?}({:?})::maybe_changed_since(revision={:?}, revision_now={:?})", - Q::default(), - key, - revision, - revision_now, - ); - - let value = { - let map_read = self.map.upgradable_read(); - match map_read.get(key) { - None | Some(QueryState::InProgress) => return true, - Some(QueryState::Memoized(memo)) => { - // If our memo is still up to date, then check if we've - // changed since the revision. - if memo.verified_at == revision_now { - return memo.changed_at > revision; - } - } - } - - let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read); - map_write.insert(key.clone(), QueryState::InProgress) - }; - - // Otherwise, walk the inputs we had and check them. Note that - // we don't want to hold the lock while we do this. - let mut memo = match value { - Some(QueryState::Memoized(memo)) => memo, - _ => unreachable!(), - }; - - if memo - .inputs - .iter() - .all(|old_input| !old_input.maybe_changed_since(db, memo.verified_at)) - { - memo.verified_at = revision_now; - self.overwrite_placeholder( - &mut self.map.write(), - key, - Some(QueryState::Memoized(memo)), - ); - return false; - } - - // Just remove the existing entry. It's out of date. - self.overwrite_placeholder(&mut self.map.write(), key, None); - - true - } -} diff --git a/src/lib.rs b/src/lib.rs index 54d2a70..4a5b513 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,6 @@ use std::fmt::Display; use std::fmt::Write; use std::hash::Hash; -pub mod dependencies; pub mod input; pub mod memoized; pub mod runtime; @@ -435,7 +434,7 @@ macro_rules! query_group { ( @storage_ty[$DB:ident, $Self:ident, dependencies] ) => { - $crate::dependencies::DependencyStorage<$DB, $Self> + $crate::memoized::DependencyStorage<$DB, $Self> }; ( diff --git a/src/memoized.rs b/src/memoized.rs index 64ab247..5a6e123 100644 --- a/src/memoized.rs +++ b/src/memoized.rs @@ -18,16 +18,56 @@ use std::fmt::Debug; use std::fmt::Display; use std::fmt::Write; use std::hash::Hash; +use std::marker::PhantomData; /// Memoized queries store the result plus a list of the other queries /// that they invoked. This means we can avoid recomputing them when /// none of those inputs have changed. -pub struct MemoizedStorage +pub type MemoizedStorage = WeakMemoizedStorage; + +/// "Dependency" queries just track their dependencies and not the +/// actual value (which they produce on demand). This lessens the +/// storage requirements. +pub type DependencyStorage = WeakMemoizedStorage; + +pub struct WeakMemoizedStorage +where + Q: QueryFunction, + DB: Database, + M: ShouldMemoizeValue, +{ + map: RwLock>>, + m: PhantomData, +} + +pub trait ShouldMemoizeValue where Q: QueryFunction, DB: Database, { - map: RwLock>>, + fn should_memoize_value(key: &Q::Key) -> bool; +} + +pub enum AlwaysMemoizeValue {} +impl ShouldMemoizeValue for AlwaysMemoizeValue +where + Q: QueryFunction, + DB: Database, +{ + fn should_memoize_value(_key: &Q::Key) -> bool { + true + } +} + +pub enum NeverMemoizeValue {} +impl ShouldMemoizeValue for NeverMemoizeValue +where + Q: QueryFunction, + DB: Database, +{ + fn should_memoize_value(_key: &Q::Key) -> bool { + false + } } /// Defines the "current state" of query's memoized results. @@ -49,7 +89,11 @@ where Q: QueryFunction, DB: Database, { - stamped_value: StampedValue, + /// Last time the value has actually changed. + /// changed_at can be less than verified_at. + changed_at: Revision, + /// The result of the query, if we decide to memoize it. + value: Option, inputs: QueryDescriptorSet, @@ -61,22 +105,25 @@ where verified_at: Revision, } -impl Default for MemoizedStorage +impl Default for WeakMemoizedStorage where Q: QueryFunction, DB: Database, + M: ShouldMemoizeValue, { fn default() -> Self { - MemoizedStorage { + WeakMemoizedStorage { map: RwLock::new(FxHashMap::default()), + m: PhantomData, } } } -impl MemoizedStorage +impl WeakMemoizedStorage where Q: QueryFunction, DB: Database, + M: ShouldMemoizeValue, { fn read( &self, @@ -105,16 +152,22 @@ where key, m.verified_at, ); - + // We've found that the query is defenitelly up-to-date. + // If the value is also memoized, return it. + // Otherwise fallback to recomputing the value. if m.verified_at == revision_now { - debug!( - "{:?}({:?}): returning memoized value (changed_at={:?})", - Q::default(), - key, - m.stamped_value.changed_at, - ); - - return Ok(m.stamped_value.clone()); + if let Some(value) = &m.value { + debug!( + "{:?}({:?}): returning memoized value (changed_at={:?})", + Q::default(), + key, + m.changed_at, + ); + return Ok(StampedValue { + value: value.clone(), + changed_at: m.changed_at, + }); + }; } } } @@ -129,25 +182,29 @@ where // first things first, let's walk over each of our previous // inputs and check whether they are out of date. if let Some(QueryState::Memoized(old_memo)) = &mut old_value { - if old_memo.inputs.iter().all(|old_input| { - !old_input.maybe_changed_since(db, old_memo.stamped_value.changed_at) - }) { + if old_memo + .inputs + .iter() + .all(|old_input| !old_input.maybe_changed_since(db, old_memo.changed_at)) + { debug!("{:?}({:?}): inputs still valid", Q::default(), key); + if old_memo.value.is_some() { + // If none of out inputs have changed since the last time we refreshed + // our value, then our value must still be good. We'll just patch + // the verified-at date and re-use it. + old_memo.verified_at = revision_now; + let value = old_memo.value.clone().unwrap(); + let changed_at = old_memo.changed_at; - // If none of out inputs have changed since the last time we refreshed - // our value, then our value must still be good. We'll just patch - // the verified-at date and re-use it. - old_memo.verified_at = revision_now; - let stamped_value = old_memo.stamped_value.clone(); - - let mut map_write = self.map.write(); - self.overwrite_placeholder(&mut map_write, key, old_value.unwrap()); - return Ok(stamped_value); + let mut map_write = self.map.write(); + self.overwrite_placeholder(&mut map_write, key, old_value.unwrap()); + return Ok(StampedValue { value, changed_at }); + } } } - // Query was not previously executed or value is potentially - // stale. Let's execute! + // Query was not previously executed, or value is potentially + // stale, or value is absent. Let's execute! let (mut stamped_value, inputs) = db .salsa_runtime() .execute_query_implementation::(db, descriptor, key); @@ -166,19 +223,25 @@ where // "backdate" its `changed_at` revision to be the same as the // old value. if let Some(QueryState::Memoized(old_memo)) = &old_value { - if old_memo.stamped_value.value == stamped_value.value { - assert!(old_memo.stamped_value.changed_at <= stamped_value.changed_at); - stamped_value.changed_at = old_memo.stamped_value.changed_at; + if old_memo.value.as_ref() == Some(&stamped_value.value) { + assert!(old_memo.changed_at <= stamped_value.changed_at); + stamped_value.changed_at = old_memo.changed_at; } } { + let value = if self.should_memoize_value(key) { + Some(stamped_value.value.clone()) + } else { + None + }; let mut map_write = self.map.write(); self.overwrite_placeholder( &mut map_write, key, QueryState::Memoized(Memo { - stamped_value: stamped_value.clone(), + changed_at: stamped_value.changed_at, + value, inputs, verified_at: revision_now, }), @@ -203,12 +266,17 @@ where "expected in-progress state", ); } + + fn should_memoize_value(&self, key: &Q::Key) -> bool { + M::should_memoize_value(key) + } } -impl QueryStorageOps for MemoizedStorage +impl QueryStorageOps for WeakMemoizedStorage where Q: QueryFunction, DB: Database, + M: ShouldMemoizeValue, { fn try_fetch<'q>( &self, @@ -240,32 +308,60 @@ where revision_now, ); - // Check for the case where we have no cache entry, or our cache - // entry is up to date (common case): - { - let map_read = self.map.read(); + let value = { + let map_read = self.map.upgradable_read(); match map_read.get(key) { None | Some(QueryState::InProgress) => return true, Some(QueryState::Memoized(memo)) => { - if memo.verified_at >= revision_now { - return memo.stamped_value.changed_at > revision; + // If our memo is still up to date, then check if we've + // changed since the revision. + if memo.verified_at == revision_now { + return memo.changed_at > revision; + } + if memo.value.is_some() { + // Otherwise, if we cache values, fall back to the full read to compute the result. + drop(memo); + drop(map_read); + return match self.read(db, key, descriptor) { + Ok(v) => v.changed_at > revision, + Err(CycleDetected) => true, + }; } } - } + }; + // If, however, we don't cache values, then optimistically + // try to advance `verified_at` by walking the inputs. + let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read); + map_write.insert(key.clone(), QueryState::InProgress) + }; + + let mut memo = match value { + Some(QueryState::Memoized(memo)) => memo, + _ => unreachable!(), + }; + + if memo + .inputs + .iter() + .all(|old_input| !old_input.maybe_changed_since(db, memo.verified_at)) + { + memo.verified_at = revision_now; + self.overwrite_placeholder(&mut self.map.write(), key, QueryState::Memoized(memo)); + return false; } - // Otherwise fall back to the full read to compute the result. - match self.read(db, key, descriptor) { - Ok(v) => v.changed_at > revision, - Err(CycleDetected) => true, - } + // Just remove the existing entry. It's out of date. + self.map.write().remove(key); + + true } } -impl UncheckedMutQueryStorageOps for MemoizedStorage +impl UncheckedMutQueryStorageOps for WeakMemoizedStorage where Q: QueryFunction, DB: Database, + M: ShouldMemoizeValue, { fn set_unchecked(&self, db: &DB, key: &Q::Key, value: Q::Value) { let key = key.clone(); @@ -277,7 +373,8 @@ where map_write.insert( key, QueryState::Memoized(Memo { - stamped_value: StampedValue { value, changed_at }, + value: Some(value), + changed_at, inputs: QueryDescriptorSet::new(), verified_at: changed_at, }),