mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-12 16:35:21 +00:00
Merge pull request #36 from matklad/weak-memoized
Merge Memoized and Dependency storages
This commit is contained in:
commit
121821117d
3 changed files with 148 additions and 295 deletions
|
@ -1,245 +0,0 @@
|
|||
use crate::runtime::QueryDescriptorSet;
|
||||
use crate::runtime::Revision;
|
||||
use crate::runtime::StampedValue;
|
||||
use crate::CycleDetected;
|
||||
use crate::Database;
|
||||
use crate::QueryDescriptor;
|
||||
use crate::QueryFunction;
|
||||
use crate::QueryStorageOps;
|
||||
use crate::QueryTable;
|
||||
use log::debug;
|
||||
use parking_lot::{RwLock, RwLockUpgradableReadGuard};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::any::Any;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::fmt::Debug;
|
||||
use std::fmt::Display;
|
||||
use std::fmt::Write;
|
||||
use std::hash::Hash;
|
||||
|
||||
/// "Dependency" queries just track their dependencies and not the
|
||||
/// actual value (which they produce on demand). This lessens the
|
||||
/// storage requirements.
|
||||
pub struct DependencyStorage<DB, Q>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
{
|
||||
map: RwLock<FxHashMap<Q::Key, QueryState<DB>>>,
|
||||
}
|
||||
|
||||
/// Defines the "current state" of query's memoized results.
|
||||
enum QueryState<DB>
|
||||
where
|
||||
DB: Database,
|
||||
{
|
||||
/// We are currently computing the result of this query; if we see
|
||||
/// this value in the table, it indeeds a cycle.
|
||||
InProgress,
|
||||
|
||||
/// We have computed the query already, and here is the result.
|
||||
Memoized(Memo<DB>),
|
||||
}
|
||||
|
||||
struct Memo<DB>
|
||||
where
|
||||
DB: Database,
|
||||
{
|
||||
inputs: QueryDescriptorSet<DB>,
|
||||
|
||||
/// Last time that we checked our inputs to see if they have
|
||||
/// changed. If this is equal to the current revision, then the
|
||||
/// value is up to date. If not, we need to check our inputs and
|
||||
/// see if any of them have changed since our last check -- if so,
|
||||
/// we'll need to re-execute.
|
||||
verified_at: Revision,
|
||||
|
||||
/// Last time that our value changed.
|
||||
changed_at: Revision,
|
||||
}
|
||||
|
||||
impl<DB, Q> Default for DependencyStorage<DB, Q>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
{
|
||||
fn default() -> Self {
|
||||
DependencyStorage {
|
||||
map: RwLock::new(FxHashMap::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB, Q> DependencyStorage<DB, Q>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
{
|
||||
fn read(
|
||||
&self,
|
||||
db: &DB,
|
||||
key: &Q::Key,
|
||||
descriptor: &DB::QueryDescriptor,
|
||||
) -> Result<StampedValue<Q::Value>, CycleDetected> {
|
||||
let revision_now = db.salsa_runtime().current_revision();
|
||||
|
||||
debug!(
|
||||
"{:?}({:?}): invoked at {:?}",
|
||||
Q::default(),
|
||||
key,
|
||||
revision_now,
|
||||
);
|
||||
|
||||
{
|
||||
let map_read = self.map.upgradable_read();
|
||||
if let Some(value) = map_read.get(key) {
|
||||
match value {
|
||||
QueryState::InProgress => return Err(CycleDetected),
|
||||
QueryState::Memoized(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read);
|
||||
map_write.insert(key.clone(), QueryState::InProgress);
|
||||
}
|
||||
|
||||
// Note that, unlike with a memoized query, we must always
|
||||
// re-execute.
|
||||
let (stamped_value, inputs) = db
|
||||
.salsa_runtime()
|
||||
.execute_query_implementation::<Q>(db, descriptor, key);
|
||||
|
||||
// We assume that query is side-effect free -- that is, does
|
||||
// not mutate the "inputs" to the query system. Sanity check
|
||||
// that assumption here, at least to the best of our ability.
|
||||
assert_eq!(
|
||||
db.salsa_runtime().current_revision(),
|
||||
revision_now,
|
||||
"revision altered during query execution",
|
||||
);
|
||||
|
||||
{
|
||||
let mut map_write = self.map.write();
|
||||
|
||||
let old_value = map_write.insert(
|
||||
key.clone(),
|
||||
QueryState::Memoized(Memo {
|
||||
inputs,
|
||||
verified_at: revision_now,
|
||||
changed_at: stamped_value.changed_at,
|
||||
}),
|
||||
);
|
||||
assert!(
|
||||
match old_value {
|
||||
Some(QueryState::InProgress) => true,
|
||||
_ => false,
|
||||
},
|
||||
"expected in-progress state",
|
||||
);
|
||||
}
|
||||
|
||||
Ok(stamped_value)
|
||||
}
|
||||
|
||||
fn overwrite_placeholder(
|
||||
&self,
|
||||
map_write: &mut FxHashMap<Q::Key, QueryState<DB>>,
|
||||
key: &Q::Key,
|
||||
value: Option<QueryState<DB>>,
|
||||
) {
|
||||
let old_value = if let Some(v) = value {
|
||||
map_write.insert(key.clone(), v)
|
||||
} else {
|
||||
map_write.remove(key)
|
||||
};
|
||||
|
||||
assert!(
|
||||
match old_value {
|
||||
Some(QueryState::InProgress) => true,
|
||||
_ => false,
|
||||
},
|
||||
"expected in-progress state",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB, Q> QueryStorageOps<DB, Q> for DependencyStorage<DB, Q>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
{
|
||||
fn try_fetch<'q>(
|
||||
&self,
|
||||
db: &'q DB,
|
||||
key: &Q::Key,
|
||||
descriptor: &DB::QueryDescriptor,
|
||||
) -> Result<Q::Value, CycleDetected> {
|
||||
let StampedValue { value, changed_at } = self.read(db, key, &descriptor)?;
|
||||
|
||||
db.salsa_runtime().report_query_read(descriptor, changed_at);
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn maybe_changed_since(
|
||||
&self,
|
||||
db: &'q DB,
|
||||
revision: Revision,
|
||||
key: &Q::Key,
|
||||
_descriptor: &DB::QueryDescriptor,
|
||||
) -> bool {
|
||||
let revision_now = db.salsa_runtime().current_revision();
|
||||
|
||||
debug!(
|
||||
"{:?}({:?})::maybe_changed_since(revision={:?}, revision_now={:?})",
|
||||
Q::default(),
|
||||
key,
|
||||
revision,
|
||||
revision_now,
|
||||
);
|
||||
|
||||
let value = {
|
||||
let map_read = self.map.upgradable_read();
|
||||
match map_read.get(key) {
|
||||
None | Some(QueryState::InProgress) => return true,
|
||||
Some(QueryState::Memoized(memo)) => {
|
||||
// If our memo is still up to date, then check if we've
|
||||
// changed since the revision.
|
||||
if memo.verified_at == revision_now {
|
||||
return memo.changed_at > revision;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read);
|
||||
map_write.insert(key.clone(), QueryState::InProgress)
|
||||
};
|
||||
|
||||
// Otherwise, walk the inputs we had and check them. Note that
|
||||
// we don't want to hold the lock while we do this.
|
||||
let mut memo = match value {
|
||||
Some(QueryState::Memoized(memo)) => memo,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
if memo
|
||||
.inputs
|
||||
.iter()
|
||||
.all(|old_input| !old_input.maybe_changed_since(db, memo.verified_at))
|
||||
{
|
||||
memo.verified_at = revision_now;
|
||||
self.overwrite_placeholder(
|
||||
&mut self.map.write(),
|
||||
key,
|
||||
Some(QueryState::Memoized(memo)),
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Just remove the existing entry. It's out of date.
|
||||
self.overwrite_placeholder(&mut self.map.write(), key, None);
|
||||
|
||||
true
|
||||
}
|
||||
}
|
|
@ -16,7 +16,6 @@ use std::fmt::Display;
|
|||
use std::fmt::Write;
|
||||
use std::hash::Hash;
|
||||
|
||||
pub mod dependencies;
|
||||
pub mod input;
|
||||
pub mod memoized;
|
||||
pub mod runtime;
|
||||
|
@ -415,7 +414,7 @@ macro_rules! query_group {
|
|||
(
|
||||
@storage_ty[$DB:ident, $Self:ident, dependencies]
|
||||
) => {
|
||||
$crate::dependencies::DependencyStorage<$DB, $Self>
|
||||
$crate::memoized::DependencyStorage<$DB, $Self>
|
||||
};
|
||||
|
||||
(
|
||||
|
|
195
src/memoized.rs
195
src/memoized.rs
|
@ -18,16 +18,56 @@ use std::fmt::Debug;
|
|||
use std::fmt::Display;
|
||||
use std::fmt::Write;
|
||||
use std::hash::Hash;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Memoized queries store the result plus a list of the other queries
|
||||
/// that they invoked. This means we can avoid recomputing them when
|
||||
/// none of those inputs have changed.
|
||||
pub struct MemoizedStorage<DB, Q>
|
||||
pub type MemoizedStorage<DB, Q> = WeakMemoizedStorage<DB, Q, AlwaysMemoizeValue>;
|
||||
|
||||
/// "Dependency" queries just track their dependencies and not the
|
||||
/// actual value (which they produce on demand). This lessens the
|
||||
/// storage requirements.
|
||||
pub type DependencyStorage<DB, Q> = WeakMemoizedStorage<DB, Q, NeverMemoizeValue>;
|
||||
|
||||
pub struct WeakMemoizedStorage<DB, Q, MP>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
MP: MemoizationPolicy<DB, Q>,
|
||||
{
|
||||
map: RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
||||
policy: PhantomData<MP>,
|
||||
}
|
||||
|
||||
pub trait MemoizationPolicy<DB, Q>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
{
|
||||
map: RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
||||
fn should_memoize_value(key: &Q::Key) -> bool;
|
||||
}
|
||||
|
||||
pub enum AlwaysMemoizeValue {}
|
||||
impl<DB, Q> MemoizationPolicy<DB, Q> for AlwaysMemoizeValue
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
{
|
||||
fn should_memoize_value(_key: &Q::Key) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
pub enum NeverMemoizeValue {}
|
||||
impl<DB, Q> MemoizationPolicy<DB, Q> for NeverMemoizeValue
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
{
|
||||
fn should_memoize_value(_key: &Q::Key) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Defines the "current state" of query's memoized results.
|
||||
|
@ -49,7 +89,12 @@ where
|
|||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
{
|
||||
stamped_value: StampedValue<Q::Value>,
|
||||
/// Last time the value has actually changed.
|
||||
/// changed_at can be less than verified_at.
|
||||
changed_at: Revision,
|
||||
|
||||
/// The result of the query, if we decide to memoize it.
|
||||
value: Option<Q::Value>,
|
||||
|
||||
inputs: QueryDescriptorSet<DB>,
|
||||
|
||||
|
@ -61,22 +106,25 @@ where
|
|||
verified_at: Revision,
|
||||
}
|
||||
|
||||
impl<DB, Q> Default for MemoizedStorage<DB, Q>
|
||||
impl<DB, Q, MP> Default for WeakMemoizedStorage<DB, Q, MP>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
MP: MemoizationPolicy<DB, Q>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
MemoizedStorage {
|
||||
WeakMemoizedStorage {
|
||||
map: RwLock::new(FxHashMap::default()),
|
||||
policy: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB, Q> MemoizedStorage<DB, Q>
|
||||
impl<DB, Q, MP> WeakMemoizedStorage<DB, Q, MP>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
MP: MemoizationPolicy<DB, Q>,
|
||||
{
|
||||
fn read(
|
||||
&self,
|
||||
|
@ -106,15 +154,22 @@ where
|
|||
m.verified_at,
|
||||
);
|
||||
|
||||
// We've found that the query is definitely up-to-date.
|
||||
// If the value is also memoized, return it.
|
||||
// Otherwise fallback to recomputing the value.
|
||||
if m.verified_at == revision_now {
|
||||
debug!(
|
||||
"{:?}({:?}): returning memoized value (changed_at={:?})",
|
||||
Q::default(),
|
||||
key,
|
||||
m.stamped_value.changed_at,
|
||||
);
|
||||
|
||||
return Ok(m.stamped_value.clone());
|
||||
if let Some(value) = &m.value {
|
||||
debug!(
|
||||
"{:?}({:?}): returning memoized value (changed_at={:?})",
|
||||
Q::default(),
|
||||
key,
|
||||
m.changed_at,
|
||||
);
|
||||
return Ok(StampedValue {
|
||||
value: value.clone(),
|
||||
changed_at: m.changed_at,
|
||||
});
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -129,25 +184,29 @@ where
|
|||
// first things first, let's walk over each of our previous
|
||||
// inputs and check whether they are out of date.
|
||||
if let Some(QueryState::Memoized(old_memo)) = &mut old_value {
|
||||
if old_memo.inputs.iter().all(|old_input| {
|
||||
!old_input.maybe_changed_since(db, old_memo.stamped_value.changed_at)
|
||||
}) {
|
||||
debug!("{:?}({:?}): inputs still valid", Q::default(), key);
|
||||
if old_memo.value.is_some() {
|
||||
if old_memo
|
||||
.inputs
|
||||
.iter()
|
||||
.all(|old_input| !old_input.maybe_changed_since(db, old_memo.changed_at))
|
||||
{
|
||||
debug!("{:?}({:?}): inputs still valid", Q::default(), key);
|
||||
// If none of out inputs have changed since the last time we refreshed
|
||||
// our value, then our value must still be good. We'll just patch
|
||||
// the verified-at date and re-use it.
|
||||
old_memo.verified_at = revision_now;
|
||||
let value = old_memo.value.clone().unwrap();
|
||||
let changed_at = old_memo.changed_at;
|
||||
|
||||
// If none of out inputs have changed since the last time we refreshed
|
||||
// our value, then our value must still be good. We'll just patch
|
||||
// the verified-at date and re-use it.
|
||||
old_memo.verified_at = revision_now;
|
||||
let stamped_value = old_memo.stamped_value.clone();
|
||||
|
||||
let mut map_write = self.map.write();
|
||||
self.overwrite_placeholder(&mut map_write, key, old_value.unwrap());
|
||||
return Ok(stamped_value);
|
||||
let mut map_write = self.map.write();
|
||||
self.overwrite_placeholder(&mut map_write, key, old_value.unwrap());
|
||||
return Ok(StampedValue { value, changed_at });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query was not previously executed or value is potentially
|
||||
// stale. Let's execute!
|
||||
// Query was not previously executed, or value is potentially
|
||||
// stale, or value is absent. Let's execute!
|
||||
let (mut stamped_value, inputs) = db
|
||||
.salsa_runtime()
|
||||
.execute_query_implementation::<Q>(db, descriptor, key);
|
||||
|
@ -166,19 +225,25 @@ where
|
|||
// "backdate" its `changed_at` revision to be the same as the
|
||||
// old value.
|
||||
if let Some(QueryState::Memoized(old_memo)) = &old_value {
|
||||
if old_memo.stamped_value.value == stamped_value.value {
|
||||
assert!(old_memo.stamped_value.changed_at <= stamped_value.changed_at);
|
||||
stamped_value.changed_at = old_memo.stamped_value.changed_at;
|
||||
if old_memo.value.as_ref() == Some(&stamped_value.value) {
|
||||
assert!(old_memo.changed_at <= stamped_value.changed_at);
|
||||
stamped_value.changed_at = old_memo.changed_at;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let value = if self.should_memoize_value(key) {
|
||||
Some(stamped_value.value.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut map_write = self.map.write();
|
||||
self.overwrite_placeholder(
|
||||
&mut map_write,
|
||||
key,
|
||||
QueryState::Memoized(Memo {
|
||||
stamped_value: stamped_value.clone(),
|
||||
changed_at: stamped_value.changed_at,
|
||||
value,
|
||||
inputs,
|
||||
verified_at: revision_now,
|
||||
}),
|
||||
|
@ -203,12 +268,17 @@ where
|
|||
"expected in-progress state",
|
||||
);
|
||||
}
|
||||
|
||||
fn should_memoize_value(&self, key: &Q::Key) -> bool {
|
||||
MP::should_memoize_value(key)
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB, Q> QueryStorageOps<DB, Q> for MemoizedStorage<DB, Q>
|
||||
impl<DB, Q, MP> QueryStorageOps<DB, Q> for WeakMemoizedStorage<DB, Q, MP>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
MP: MemoizationPolicy<DB, Q>,
|
||||
{
|
||||
fn try_fetch<'q>(
|
||||
&self,
|
||||
|
@ -240,32 +310,60 @@ where
|
|||
revision_now,
|
||||
);
|
||||
|
||||
// Check for the case where we have no cache entry, or our cache
|
||||
// entry is up to date (common case):
|
||||
{
|
||||
let map_read = self.map.read();
|
||||
let value = {
|
||||
let map_read = self.map.upgradable_read();
|
||||
match map_read.get(key) {
|
||||
None | Some(QueryState::InProgress) => return true,
|
||||
Some(QueryState::Memoized(memo)) => {
|
||||
if memo.verified_at >= revision_now {
|
||||
return memo.stamped_value.changed_at > revision;
|
||||
// If our memo is still up to date, then check if we've
|
||||
// changed since the revision.
|
||||
if memo.verified_at == revision_now {
|
||||
return memo.changed_at > revision;
|
||||
}
|
||||
if memo.value.is_some() {
|
||||
// Otherwise, if we cache values, fall back to the full read to compute the result.
|
||||
drop(memo);
|
||||
drop(map_read);
|
||||
return match self.read(db, key, descriptor) {
|
||||
Ok(v) => v.changed_at > revision,
|
||||
Err(CycleDetected) => true,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
// If, however, we don't cache values, then optimistically
|
||||
// try to advance `verified_at` by walking the inputs.
|
||||
let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read);
|
||||
map_write.insert(key.clone(), QueryState::InProgress)
|
||||
};
|
||||
|
||||
let mut memo = match value {
|
||||
Some(QueryState::Memoized(memo)) => memo,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
if memo
|
||||
.inputs
|
||||
.iter()
|
||||
.all(|old_input| !old_input.maybe_changed_since(db, memo.verified_at))
|
||||
{
|
||||
memo.verified_at = revision_now;
|
||||
self.overwrite_placeholder(&mut self.map.write(), key, QueryState::Memoized(memo));
|
||||
return false;
|
||||
}
|
||||
|
||||
// Otherwise fall back to the full read to compute the result.
|
||||
match self.read(db, key, descriptor) {
|
||||
Ok(v) => v.changed_at > revision,
|
||||
Err(CycleDetected) => true,
|
||||
}
|
||||
// Just remove the existing entry. It's out of date.
|
||||
self.map.write().remove(key);
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB, Q> UncheckedMutQueryStorageOps<DB, Q> for MemoizedStorage<DB, Q>
|
||||
impl<DB, Q, MP> UncheckedMutQueryStorageOps<DB, Q> for WeakMemoizedStorage<DB, Q, MP>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
DB: Database,
|
||||
MP: MemoizationPolicy<DB, Q>,
|
||||
{
|
||||
fn set_unchecked(&self, db: &DB, key: &Q::Key, value: Q::Value) {
|
||||
let key = key.clone();
|
||||
|
@ -277,7 +375,8 @@ where
|
|||
map_write.insert(
|
||||
key,
|
||||
QueryState::Memoized(Memo {
|
||||
stamped_value: StampedValue { value, changed_at },
|
||||
value: Some(value),
|
||||
changed_at,
|
||||
inputs: QueryDescriptorSet::new(),
|
||||
verified_at: changed_at,
|
||||
}),
|
||||
|
|
Loading…
Reference in a new issue