Merge pull request #43 from nikomatsakis/derived-storage

Combine memoized and volatile to make "derived storage"
This commit is contained in:
Niko Matsakis 2018-10-09 15:19:14 -04:00 committed by GitHub
commit 3b5f16cbcb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 224 additions and 194 deletions

View file

@ -1,3 +1,4 @@
use crate::runtime::ChangedAt;
use crate::runtime::QueryDescriptorSet;
use crate::runtime::Revision;
use crate::runtime::StampedValue;
@ -23,14 +24,21 @@ use std::marker::PhantomData;
/// Memoized queries store the result plus a list of the other queries
/// that they invoked. This means we can avoid recomputing them when
/// none of those inputs have changed.
pub type MemoizedStorage<DB, Q> = WeakMemoizedStorage<DB, Q, AlwaysMemoizeValue>;
pub type MemoizedStorage<DB, Q> = DerivedStorage<DB, Q, AlwaysMemoizeValue>;
/// "Dependency" queries just track their dependencies and not the
/// actual value (which they produce on demand). This lessens the
/// storage requirements.
pub type DependencyStorage<DB, Q> = WeakMemoizedStorage<DB, Q, NeverMemoizeValue>;
pub type DependencyStorage<DB, Q> = DerivedStorage<DB, Q, NeverMemoizeValue>;
pub struct WeakMemoizedStorage<DB, Q, MP>
/// "Dependency" queries just track their dependencies and not the
/// actual value (which they produce on demand). This lessens the
/// storage requirements.
pub type VolatileStorage<DB, Q> = DerivedStorage<DB, Q, VolatileValue>;
/// Handles storage where the value is 'derived' by executing a
/// function (in contrast to "inputs").
pub struct DerivedStorage<DB, Q, MP>
where
Q: QueryFunction<DB>,
DB: Database,
@ -46,6 +54,8 @@ where
DB: Database,
{
fn should_memoize_value(key: &Q::Key) -> bool;
fn should_track_inputs(key: &Q::Key) -> bool;
}
pub enum AlwaysMemoizeValue {}
@ -57,6 +67,10 @@ where
fn should_memoize_value(_key: &Q::Key) -> bool {
true
}
fn should_track_inputs(_key: &Q::Key) -> bool {
true
}
}
pub enum NeverMemoizeValue {}
@ -68,6 +82,30 @@ where
fn should_memoize_value(_key: &Q::Key) -> bool {
false
}
fn should_track_inputs(_key: &Q::Key) -> bool {
true
}
}
pub enum VolatileValue {}
impl<DB, Q> MemoizationPolicy<DB, Q> for VolatileValue
where
Q: QueryFunction<DB>,
DB: Database,
{
fn should_memoize_value(_key: &Q::Key) -> bool {
// Why memoize? Well, if the "volatile" value really is
// constantly changing, we still want to capture its value
// until the next revision is triggered and ensure it doesn't
// change -- otherwise the system gets into an inconsistent
// state where the same query reports back different values.
true
}
fn should_track_inputs(_key: &Q::Key) -> bool {
false
}
}
/// Defines the "current state" of query's memoized results.
@ -91,11 +129,12 @@ where
{
/// Last time the value has actually changed.
/// changed_at can be less than verified_at.
changed_at: Revision,
changed_at: ChangedAt,
/// The result of the query, if we decide to memoize it.
value: Option<Q::Value>,
/// The inputs that went into our query, if we are tracking them.
inputs: QueryDescriptorSet<DB>,
/// Last time that we checked our inputs to see if they have
@ -106,21 +145,21 @@ where
verified_at: Revision,
}
impl<DB, Q, MP> Default for WeakMemoizedStorage<DB, Q, MP>
impl<DB, Q, MP> Default for DerivedStorage<DB, Q, MP>
where
Q: QueryFunction<DB>,
DB: Database,
MP: MemoizationPolicy<DB, Q>,
{
fn default() -> Self {
WeakMemoizedStorage {
DerivedStorage {
map: RwLock::new(FxHashMap::default()),
policy: PhantomData,
}
}
}
impl<DB, Q, MP> WeakMemoizedStorage<DB, Q, MP>
impl<DB, Q, MP> DerivedStorage<DB, Q, MP>
where
Q: QueryFunction<DB>,
DB: Database,
@ -184,32 +223,32 @@ where
// first things first, let's walk over each of our previous
// inputs and check whether they are out of date.
if let Some(QueryState::Memoized(old_memo)) = &mut old_value {
if old_memo.value.is_some() {
if old_memo
.inputs
.iter()
.all(|old_input| !old_input.maybe_changed_since(db, old_memo.changed_at))
{
debug!("{:?}({:?}): inputs still valid", Q::default(), key);
// If none of out inputs have changed since the last time we refreshed
// our value, then our value must still be good. We'll just patch
// the verified-at date and re-use it.
old_memo.verified_at = revision_now;
let value = old_memo.value.clone().unwrap();
let changed_at = old_memo.changed_at;
if let Some(value) = old_memo.verify_memoized_value(db) {
debug!("{:?}({:?}): inputs still valid", Q::default(), key);
// If none of out inputs have changed since the last time we refreshed
// our value, then our value must still be good. We'll just patch
// the verified-at date and re-use it.
old_memo.verified_at = revision_now;
let changed_at = old_memo.changed_at;
let mut map_write = self.map.write();
self.overwrite_placeholder(&mut map_write, key, old_value.unwrap());
return Ok(StampedValue { value, changed_at });
}
let mut map_write = self.map.write();
self.overwrite_placeholder(&mut map_write, key, old_value.unwrap());
return Ok(StampedValue { value, changed_at });
}
}
// Query was not previously executed, or value is potentially
// stale, or value is absent. Let's execute!
let (mut stamped_value, inputs) = db
.salsa_runtime()
.execute_query_implementation::<Q>(db, descriptor, key);
let runtime = db.salsa_runtime();
let (mut stamped_value, inputs) = runtime.execute_query_implementation(descriptor, || {
debug!("{:?}({:?}): executing query", Q::default(), key);
if !self.should_track_inputs(key) {
runtime.report_untracked_read();
}
Q::execute(db, key.clone())
});
// We assume that query is side-effect free -- that is, does
// not mutate the "inputs" to the query system. Sanity check
@ -272,9 +311,13 @@ where
fn should_memoize_value(&self, key: &Q::Key) -> bool {
MP::should_memoize_value(key)
}
fn should_track_inputs(&self, key: &Q::Key) -> bool {
MP::should_track_inputs(key)
}
}
impl<DB, Q, MP> QueryStorageOps<DB, Q> for WeakMemoizedStorage<DB, Q, MP>
impl<DB, Q, MP> QueryStorageOps<DB, Q> for DerivedStorage<DB, Q, MP>
where
Q: QueryFunction<DB>,
DB: Database,
@ -318,14 +361,14 @@ where
// If our memo is still up to date, then check if we've
// changed since the revision.
if memo.verified_at == revision_now {
return memo.changed_at > revision;
return memo.changed_at.changed_since(revision);
}
if memo.value.is_some() {
// Otherwise, if we cache values, fall back to the full read to compute the result.
drop(memo);
drop(map_read);
return match self.read(db, key, descriptor) {
Ok(v) => v.changed_at > revision,
Ok(v) => v.changed_at.changed_since(revision),
Err(CycleDetected) => true,
};
}
@ -342,11 +385,7 @@ where
_ => unreachable!(),
};
if memo
.inputs
.iter()
.all(|old_input| !old_input.maybe_changed_since(db, memo.verified_at))
{
if memo.verify_inputs(db) {
memo.verified_at = revision_now;
self.overwrite_placeholder(&mut self.map.write(), key, QueryState::Memoized(memo));
return false;
@ -359,7 +398,7 @@ where
}
}
impl<DB, Q, MP> UncheckedMutQueryStorageOps<DB, Q> for WeakMemoizedStorage<DB, Q, MP>
impl<DB, Q, MP> UncheckedMutQueryStorageOps<DB, Q> for DerivedStorage<DB, Q, MP>
where
Q: QueryFunction<DB>,
DB: Database,
@ -370,16 +409,47 @@ where
let mut map_write = self.map.write();
let changed_at = db.salsa_runtime().current_revision();
let current_revision = db.salsa_runtime().current_revision();
let changed_at = ChangedAt::Revision(current_revision);
map_write.insert(
key,
QueryState::Memoized(Memo {
value: Some(value),
changed_at,
inputs: QueryDescriptorSet::new(),
verified_at: changed_at,
inputs: QueryDescriptorSet::default(),
verified_at: current_revision,
}),
);
}
}
impl<DB, Q> Memo<DB, Q>
where
Q: QueryFunction<DB>,
DB: Database,
{
fn verify_memoized_value(&self, db: &DB) -> Option<Q::Value> {
// If we don't have a memoized value, nothing to validate.
if let Some(v) = &self.value {
// If inputs are still valid.
if self.verify_inputs(db) {
return Some(v.clone());
}
}
None
}
fn verify_inputs(&self, db: &DB) -> bool {
match self.changed_at {
ChangedAt::Revision(revision) => match &self.inputs {
QueryDescriptorSet::Tracked(inputs) => inputs
.iter()
.all(|old_input| !old_input.maybe_changed_since(db, revision)),
QueryDescriptorSet::Untracked => false,
},
}
}
}

View file

@ -1,3 +1,4 @@
use crate::runtime::ChangedAt;
use crate::runtime::QueryDescriptorSet;
use crate::runtime::Revision;
use crate::runtime::StampedValue;
@ -66,7 +67,7 @@ where
Ok(StampedValue {
value: <Q::Value>::default(),
changed_at: Revision::ZERO,
changed_at: ChangedAt::Revision(Revision::ZERO),
})
}
}
@ -109,10 +110,10 @@ where
map_read
.get(key)
.map(|v| v.changed_at)
.unwrap_or(Revision::ZERO)
.unwrap_or(ChangedAt::Revision(Revision::ZERO))
};
changed_at > revision
changed_at.changed_since(revision)
}
}
@ -131,7 +132,7 @@ where
// racing with somebody else to modify this same cell.
// (Otherwise, someone else might write a *newer* revision
// into the same cell while we block on the lock.)
let changed_at = db.salsa_runtime().increment_revision();
let changed_at = ChangedAt::Revision(db.salsa_runtime().increment_revision());
map_write.insert(key, StampedValue { value, changed_at });
}
@ -150,7 +151,7 @@ where
// Unlike with `set`, here we use the **current revision** and
// do not create a new one.
let changed_at = db.salsa_runtime().current_revision();
let changed_at = ChangedAt::Revision(db.salsa_runtime().current_revision());
map_write.insert(key, StampedValue { value, changed_at });
}

View file

@ -16,10 +16,9 @@ use std::fmt::Display;
use std::fmt::Write;
use std::hash::Hash;
pub mod derived;
pub mod input;
pub mod memoized;
pub mod runtime;
pub mod volatile;
pub use crate::runtime::Runtime;
@ -402,19 +401,19 @@ macro_rules! query_group {
(
@storage_ty[$DB:ident, $Self:ident, memoized]
) => {
$crate::memoized::MemoizedStorage<$DB, $Self>
$crate::derived::MemoizedStorage<$DB, $Self>
};
(
@storage_ty[$DB:ident, $Self:ident, volatile]
) => {
$crate::volatile::VolatileStorage<$DB, $Self>
$crate::derived::VolatileStorage<$DB, $Self>
};
(
@storage_ty[$DB:ident, $Self:ident, dependencies]
) => {
$crate::memoized::DependencyStorage<$DB, $Self>
$crate::derived::DependencyStorage<$DB, $Self>
};
(

View file

@ -88,16 +88,12 @@ where
result
}
crate fn execute_query_implementation<Q>(
crate fn execute_query_implementation<V>(
&self,
db: &DB,
descriptor: &DB::QueryDescriptor,
key: &Q::Key,
) -> (StampedValue<Q::Value>, QueryDescriptorSet<DB>)
where
Q: QueryFunction<DB>,
{
debug!("{:?}({:?}): executing query", Q::default(), key);
execute: impl FnOnce() -> V,
) -> (StampedValue<V>, QueryDescriptorSet<DB>) {
debug!("{:?}: execute_query_implementation invoked", descriptor);
// Push the active query onto the stack.
let push_len = {
@ -109,7 +105,7 @@ where
};
// Execute user's code, accumulating inputs etc.
let value = Q::execute(db, key.clone());
let value = execute();
// Extract accumulated inputs.
let ActiveQuery {
@ -136,12 +132,19 @@ where
/// - `descriptor`: the query whose result was read
/// - `changed_revision`: the last revision in which the result of that
/// query had changed
crate fn report_query_read(&self, descriptor: &DB::QueryDescriptor, changed_at: Revision) {
crate fn report_query_read(&self, descriptor: &DB::QueryDescriptor, changed_at: ChangedAt) {
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
top_query.add_read(descriptor, changed_at);
}
}
crate fn report_untracked_read(&self) {
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
let changed_at = ChangedAt::Revision(self.current_revision());
top_query.add_untracked_read(changed_at);
}
}
/// Obviously, this should be user configurable at some point.
crate fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! {
let local_state = self.local_state.borrow();
@ -178,7 +181,7 @@ struct ActiveQuery<DB: Database> {
descriptor: DB::QueryDescriptor,
/// Records the maximum revision where any subquery changed
changed_at: Revision,
changed_at: ChangedAt,
/// Each subquery
subqueries: QueryDescriptorSet<DB>,
@ -188,15 +191,20 @@ impl<DB: Database> ActiveQuery<DB> {
fn new(descriptor: DB::QueryDescriptor) -> Self {
ActiveQuery {
descriptor,
changed_at: Revision::ZERO,
subqueries: QueryDescriptorSet::new(),
changed_at: ChangedAt::Revision(Revision::ZERO),
subqueries: QueryDescriptorSet::default(),
}
}
fn add_read(&mut self, subquery: &DB::QueryDescriptor, changed_at: Revision) {
fn add_read(&mut self, subquery: &DB::QueryDescriptor, changed_at: ChangedAt) {
self.subqueries.insert(subquery.clone());
self.changed_at = self.changed_at.max(changed_at);
}
fn add_untracked_read(&mut self, changed_at: ChangedAt) {
self.subqueries.insert_untracked();
self.changed_at = self.changed_at.max(changed_at);
}
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
@ -214,40 +222,70 @@ impl std::fmt::Debug for Revision {
}
}
/// Records when a stamped value changed.
///
/// Note: the order of variants is significant. We sometimes use `max`
/// for example to find the "most recent revision" when something
/// changed.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChangedAt {
Revision(Revision),
}
impl ChangedAt {
/// True if this value has changed after `revision`.
pub fn changed_since(self, revision: Revision) -> bool {
match self {
ChangedAt::Revision(r) => r > revision,
}
}
}
/// An insertion-order-preserving set of queries. Used to track the
/// inputs accessed during query execution.
crate struct QueryDescriptorSet<DB: Database> {
set: FxIndexSet<DB::QueryDescriptor>,
crate enum QueryDescriptorSet<DB: Database> {
/// All reads were to tracked things:
Tracked(FxIndexSet<DB::QueryDescriptor>),
/// Some reads to an untracked thing:
Untracked,
}
impl<DB: Database> std::fmt::Debug for QueryDescriptorSet<DB> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.set, fmt)
match self {
QueryDescriptorSet::Tracked(set) => std::fmt::Debug::fmt(set, fmt),
QueryDescriptorSet::Untracked => write!(fmt, "Untracked"),
}
}
}
impl<DB: Database> Default for QueryDescriptorSet<DB> {
fn default() -> Self {
QueryDescriptorSet::Tracked(FxIndexSet::default())
}
}
impl<DB: Database> QueryDescriptorSet<DB> {
crate fn new() -> Self {
QueryDescriptorSet {
set: FxIndexSet::default(),
/// Add `descriptor` to the set. Returns true if `descriptor` is
/// newly added and false if `descriptor` was already a member.
fn insert(&mut self, descriptor: DB::QueryDescriptor) {
match self {
QueryDescriptorSet::Tracked(set) => {
set.insert(descriptor);
}
QueryDescriptorSet::Untracked => {}
}
}
/// Add `descriptor` to the set. Returns true if `descriptor` is
/// newly added and false if `descriptor` was already a member.
fn insert(&mut self, descriptor: DB::QueryDescriptor) -> bool {
self.set.insert(descriptor)
}
/// Iterate over all queries in the set, in the order of their
/// first insertion.
pub fn iter(&self) -> impl Iterator<Item = &DB::QueryDescriptor> {
self.set.iter()
fn insert_untracked(&mut self) {
*self = QueryDescriptorSet::Untracked;
}
}
#[derive(Clone, Debug)]
crate struct StampedValue<V> {
crate value: V,
crate changed_at: Revision,
crate changed_at: ChangedAt,
}

View file

@ -1,95 +0,0 @@
use crate::runtime::Revision;
use crate::runtime::StampedValue;
use crate::CycleDetected;
use crate::Database;
use crate::QueryFunction;
use crate::QueryStorageOps;
use crate::QueryTable;
use log::debug;
use parking_lot::Mutex;
use rustc_hash::FxHashSet;
use std::any::Any;
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Write;
use std::hash::Hash;
/// Volatile Storage is just **always** considered dirty. Any time you
/// ask for the result of such a query, it is recomputed.
pub struct VolatileStorage<DB, Q>
where
Q: QueryFunction<DB>,
DB: Database,
{
/// We don't store the results of volatile queries,
/// but we track in-progress set to detect cycles.
in_progress: Mutex<FxHashSet<Q::Key>>,
}
impl<DB, Q> Default for VolatileStorage<DB, Q>
where
Q: QueryFunction<DB>,
DB: Database,
{
fn default() -> Self {
VolatileStorage {
in_progress: Mutex::new(FxHashSet::default()),
}
}
}
impl<DB, Q> QueryStorageOps<DB, Q> for VolatileStorage<DB, Q>
where
Q: QueryFunction<DB>,
DB: Database,
{
fn try_fetch<'q>(
&self,
db: &'q DB,
key: &Q::Key,
descriptor: &DB::QueryDescriptor,
) -> Result<Q::Value, CycleDetected> {
if !self.in_progress.lock().insert(key.clone()) {
return Err(CycleDetected);
}
let (
StampedValue {
value,
changed_at: _,
},
_inputs,
) = db
.salsa_runtime()
.execute_query_implementation::<Q>(db, descriptor, key);
let was_in_progress = self.in_progress.lock().remove(key);
assert!(was_in_progress);
let revision_now = db.salsa_runtime().current_revision();
db.salsa_runtime()
.report_query_read(descriptor, revision_now);
Ok(value)
}
fn maybe_changed_since(
&self,
_db: &'q DB,
revision: Revision,
key: &Q::Key,
_descriptor: &DB::QueryDescriptor,
) -> bool {
debug!(
"{:?}({:?})::maybe_changed_since(revision={:?}) ==> true (volatile)",
Q::default(),
key,
revision,
);
true
}
}

View file

@ -38,10 +38,11 @@ fn volatile(db: &impl MemoizedVolatileContext, (): ()) -> usize {
fn volatile_x2() {
let query = TestContextImpl::default();
// Invoking volatile twice will simply execute twice.
// Invoking volatile twice doesn't execute twice, because volatile
// queries are memoized by default.
query.volatile(());
query.volatile(());
query.assert_log(&["Volatile invoked", "Volatile invoked"]);
query.assert_log(&["Volatile invoked"]);
}
/// Test that:
@ -67,7 +68,7 @@ fn revalidate() {
query.salsa_runtime().next_revision();
query.memoized2(());
query.assert_log(&["Memoized1 invoked", "Volatile invoked"]);
query.assert_log(&["Volatile invoked", "Memoized1 invoked"]);
query.memoized2(());
query.assert_log(&[]);
@ -78,7 +79,7 @@ fn revalidate() {
query.salsa_runtime().next_revision();
query.memoized2(());
query.assert_log(&["Memoized1 invoked", "Volatile invoked", "Memoized2 invoked"]);
query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]);
query.memoized2(());
query.assert_log(&[]);

View file

@ -1,4 +1,5 @@
#![feature(crate_visibility_modifier)]
#![feature(underscore_imports)]
mod implementation;
mod queries;

View file

@ -17,7 +17,7 @@ salsa::query_group! {
/// Because this query is memoized, we only increment the counter
/// the first time it is invoked.
fn memoized(db: &impl Database, (): ()) -> usize {
db.increment()
db.volatile(())
}
/// Because this query is volatile, each time it is invoked,

View file

@ -2,32 +2,47 @@
use crate::implementation::DatabaseImpl;
use crate::queries::Database;
use salsa::Database as _;
#[test]
fn memoized_twice() {
let query = DatabaseImpl::default();
let v1 = query.memoized(());
let v2 = query.memoized(());
let db = DatabaseImpl::default();
let v1 = db.memoized(());
let v2 = db.memoized(());
assert_eq!(v1, v2);
}
#[test]
fn volatile_twice() {
let query = DatabaseImpl::default();
let v1 = query.volatile(());
let v2 = query.volatile(());
assert_eq!(v1 + 1, v2);
let db = DatabaseImpl::default();
let v1 = db.volatile(());
let v2 = db.volatile(()); // volatiles are cached, so 2nd read returns the same
assert_eq!(v1, v2);
db.salsa_runtime().next_revision(); // clears volatile caches
let v3 = db.volatile(()); // will re-increment the counter
let v4 = db.volatile(()); // second call will be cached
assert_eq!(v1 + 1, v3);
assert_eq!(v3, v4);
}
#[test]
fn intermingled() {
let query = DatabaseImpl::default();
let v1 = query.volatile(());
let v2 = query.memoized(());
let v3 = query.volatile(());
let v4 = query.memoized(());
let db = DatabaseImpl::default();
let v1 = db.volatile(());
let v2 = db.memoized(());
let v3 = db.volatile(()); // cached
let v4 = db.memoized(()); // cached
assert_eq!(v1 + 1, v2);
assert_eq!(v2 + 1, v3);
assert_eq!(v1, v2);
assert_eq!(v1, v3);
assert_eq!(v2, v4);
db.salsa_runtime().next_revision(); // clears volatile caches
let v5 = db.memoized(()); // re-executes volatile, caches new result
let v6 = db.memoized(()); // re-use cached result
assert_eq!(v4 + 1, v5);
assert_eq!(v5, v6);
}