new parallel friendly algorithm

This commit is contained in:
Niko Matsakis 2021-11-08 06:32:03 -05:00
parent 685fccc9c5
commit 356392578b
18 changed files with 812 additions and 1318 deletions

View file

@ -8,8 +8,10 @@ repository = "https://github.com/salsa-rs/salsa"
description = "A generic framework for on-demand, incrementalized computation (experimental)"
[dependencies]
arc-swap = "1.4.0"
crossbeam-utils = { version = "0.8", default-features = false }
dashmap = "4.0.2"
hashlink = "0.7.0"
indexmap = "1.0.1"
lock_api = "0.4"
log = "0.4.5"

View file

@ -1,23 +1,28 @@
use crate::debug::TableEntry;
use crate::durability::Durability;
use crate::hash::FxDashMap;
use crate::lru::Lru;
use crate::plumbing::DerivedQueryStorageOps;
use crate::plumbing::LruQueryStorageOps;
use crate::plumbing::QueryFunction;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
use crate::runtime::StampedValue;
use crate::runtime::local_state::QueryInputs;
use crate::runtime::local_state::QueryRevisions;
use crate::Runtime;
use crate::{Database, DatabaseKeyIndex, QueryDb, Revision};
use crossbeam_utils::atomic::AtomicCell;
use std::borrow::Borrow;
use std::hash::Hash;
use std::marker::PhantomData;
use std::sync::Arc;
mod slot;
use slot::Slot;
mod execute;
mod fetch;
mod key_to_key_index;
mod lru;
mod maybe_changed_after;
mod memo;
mod sync;
//mod slot;
//use slot::Slot;
/// Memoized queries store the result plus a list of the other queries
/// that they invoked. This means we can avoid recomputing them when
@ -37,22 +42,13 @@ where
MP: MemoizationPolicy<Q>,
{
group_index: u16,
lru_list: Lru<Slot<Q, MP>>,
indices: AtomicCell<u32>,
index_map: FxDashMap<Q::Key, DerivedKeyIndex>,
slot_map: FxDashMap<DerivedKeyIndex, KeySlot<Q, MP>>,
lru: lru::Lru,
key_map: key_to_key_index::KeyToKeyIndex<Q::Key>,
memo_map: memo::MemoMap<Q::Value>,
sync_map: sync::SyncMap,
policy: PhantomData<MP>,
}
struct KeySlot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
key: Q::Key,
slot: Arc<Slot<Q, MP>>,
}
type DerivedKeyIndex = u32;
impl<Q, MP> std::panic::RefUnwindSafe for DerivedStorage<Q, MP>
@ -107,52 +103,22 @@ where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn slot_for_key(&self, key: &Q::Key) -> Arc<Slot<Q, MP>> {
// Common case: get an existing key
if let Some(v) = self.index_map.get(key) {
let index = *v;
// release the read-write lock early, for no particular reason
// apart from it bothers me
drop(v);
return self.slot_for_key_index(index);
}
// Less common case: (potentially) create a new slot
match self.index_map.entry(key.clone()) {
dashmap::mapref::entry::Entry::Occupied(entry) => self.slot_for_key_index(*entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
let key_index = self.indices.fetch_add(1);
let database_key_index = DatabaseKeyIndex {
group_index: self.group_index,
query_index: Q::QUERY_INDEX,
key_index,
};
let slot = Arc::new(Slot::new(key.clone(), database_key_index));
// Subtle: store the new slot *before* the new index, so that
// other threads only see the new index once the slot is also available.
self.slot_map.insert(
key_index,
KeySlot {
key: key.clone(),
slot: slot.clone(),
},
);
entry.insert(key_index);
slot
}
fn database_key_index(&self, key_index: DerivedKeyIndex) -> DatabaseKeyIndex {
DatabaseKeyIndex {
group_index: self.group_index,
query_index: Q::QUERY_INDEX,
key_index: key_index,
}
}
fn slot_for_key_index(&self, index: DerivedKeyIndex) -> Arc<Slot<Q, MP>> {
return self.slot_map.get(&index).unwrap().slot.clone();
}
fn slot_for_db_index(&self, index: DatabaseKeyIndex) -> Arc<Slot<Q, MP>> {
fn assert_our_key_index(&self, index: DatabaseKeyIndex) {
assert_eq!(index.group_index, self.group_index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
self.slot_for_key_index(index.key_index)
}
fn key_index(&self, index: DatabaseKeyIndex) -> DerivedKeyIndex {
self.assert_our_key_index(index);
index.key_index
}
}
@ -166,11 +132,11 @@ where
fn new(group_index: u16) -> Self {
DerivedStorage {
group_index,
index_map: Default::default(),
slot_map: Default::default(),
lru_list: Default::default(),
lru: Default::default(),
key_map: Default::default(),
memo_map: Default::default(),
sync_map: Default::default(),
policy: PhantomData,
indices: Default::default(),
}
}
@ -180,58 +146,47 @@ where
index: DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
assert_eq!(index.group_index, self.group_index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
let key_slot = self.slot_map.get(&index.key_index).unwrap();
write!(fmt, "{}({:?})", Q::QUERY_NAME, key_slot.key)
let key_index = self.key_index(index);
let key = self.key_map.key_for_key_index(key_index);
write!(fmt, "{}({:?})", Q::QUERY_NAME, key)
}
fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex,
database_key_index: DatabaseKeyIndex,
revision: Revision,
) -> bool {
debug_assert!(revision < db.salsa_runtime().current_revision());
let slot = self.slot_for_db_index(input);
slot.maybe_changed_after(db, revision)
let key_index = self.key_index(database_key_index);
self.maybe_changed_after(db, key_index, revision)
}
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
db.unwind_if_cancelled();
let slot = self.slot_for_key(key);
let StampedValue {
value,
durability,
changed_at,
} = slot.read(db);
if let Some(evicted) = self.lru_list.record_use(&slot) {
evicted.evict();
}
db.salsa_runtime()
.report_query_read_and_unwind_if_cycle_resulted(
slot.database_key_index(),
durability,
changed_at,
);
value
let key_index = self.key_map.key_index_for_key(key);
self.fetch(db, key_index)
}
fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
self.slot_for_key(key).durability(db)
fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
let key_index = self.key_map.key_index_for_key(key);
if let Some(memo) = self.memo_map.get(key_index) {
memo.revisions.durability
} else {
Durability::LOW
}
}
fn entries<C>(&self, _db: &<Q as QueryDb<'_>>::DynDb) -> C
where
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
{
self.slot_map
self.memo_map
.iter()
.filter_map(|r| r.value().slot.as_table_entry())
.map(|(key_index, memo)| {
let key = self.key_map.key_for_key_index(key_index);
TableEntry::new(key, memo.value.clone())
})
.collect()
}
}
@ -242,10 +197,8 @@ where
MP: MemoizationPolicy<Q>,
{
fn purge(&self) {
self.lru_list.purge();
self.indices.store(0);
self.index_map.clear();
self.slot_map.clear();
self.lru.set_capacity(0);
self.memo_map.clear();
}
}
@ -255,7 +208,7 @@ where
MP: MemoizationPolicy<Q>,
{
fn set_lru_capacity(&self, new_capacity: usize) {
self.lru_list.set_lru_capacity(new_capacity);
self.lru.set_capacity(new_capacity);
}
}
@ -270,13 +223,20 @@ where
Q::Key: Borrow<S>,
{
runtime.with_incremented_revision(|new_revision| {
if let Some(key_index) = self.index_map.get(key) {
let slot = self.slot_for_key_index(*key_index);
if let Some(durability) = slot.invalidate(new_revision) {
return Some(durability);
}
}
None
let key_index = self.key_map.existing_key_index_for_key(key)?;
let memo = self.memo_map.get(key_index)?;
let invalidated_revisions = QueryRevisions {
changed_at: new_revision,
durability: memo.revisions.durability,
inputs: QueryInputs::Untracked,
};
let new_memo = memo::Memo::new(
memo.value.clone(),
memo.verified_at.load(),
invalidated_revisions,
);
self.memo_map.insert(key_index, new_memo);
Some(memo.revisions.durability)
})
}
}

134
src/derived/execute.rs Normal file
View file

@ -0,0 +1,134 @@
use std::sync::Arc;
use crate::{
plumbing::QueryFunction,
runtime::{local_state::ActiveQueryGuard, StampedValue},
Cycle, Database, Event, EventKind, QueryDb,
};
use super::{memo::Memo, DerivedStorage, MemoizationPolicy};
impl<Q, MP> DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
/// Executes the query function for the given `active_query`. Creates and stores
/// a new memo with the result, backdated if possible. Once this completes,
/// the query will have been popped off the active query stack.
///
/// # Parameters
///
/// * `db`, the database.
/// * `active_query`, the active stack frame for the query to execute.
/// * `opt_old_memo`, the older memo, if any existed. Used for backdated.
pub(super) fn execute(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
active_query: ActiveQueryGuard<'_>,
opt_old_memo: Option<Arc<Memo<Q::Value>>>,
) -> StampedValue<Q::Value> {
let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision();
let database_key_index = active_query.database_key_index;
log::info!("{:?}: executing query", database_key_index.debug(db));
db.salsa_event(Event {
runtime_id: db.salsa_runtime().id(),
kind: EventKind::WillExecute {
database_key: database_key_index,
},
});
// Query was not previously executed, or value is potentially
// stale, or value is absent. Let's execute!
let database_key_index = active_query.database_key_index;
let key_index = database_key_index.key_index;
let key = self.key_map.key_for_key_index(key_index);
let value = match Cycle::catch(|| Q::execute(db, key.clone())) {
Ok(v) => v,
Err(cycle) => {
log::debug!(
"{:?}: caught cycle {:?}, have strategy {:?}",
database_key_index.debug(db),
cycle,
Q::CYCLE_STRATEGY,
);
match Q::CYCLE_STRATEGY {
crate::plumbing::CycleRecoveryStrategy::Panic => cycle.throw(),
crate::plumbing::CycleRecoveryStrategy::Fallback => {
if let Some(c) = active_query.take_cycle() {
assert!(c.is(&cycle));
Q::cycle_fallback(db, &cycle, &key)
} else {
// we are not a participant in this cycle
debug_assert!(!cycle
.participant_keys()
.any(|k| k == database_key_index));
cycle.throw()
}
}
}
}
};
let mut revisions = active_query.pop();
// We assume that query is side-effect free -- that is, does
// not mutate the "inputs" to the query system. Sanity check
// that assumption here, at least to the best of our ability.
assert_eq!(
runtime.current_revision(),
revision_now,
"revision altered during query execution",
);
// If the new value is equal to the old one, then it didn't
// really change, even if some of its inputs have. So we can
// "backdate" its `changed_at` revision to be the same as the
// old value.
if let Some(old_memo) = &opt_old_memo {
if let Some(old_value) = &old_memo.value {
// Careful: if the value became less durable than it
// used to be, that is a "breaking change" that our
// consumers must be aware of. Becoming *more* durable
// is not. See the test `constant_to_non_constant`.
if revisions.durability >= old_memo.revisions.durability
&& MP::memoized_value_eq(old_value, &value)
{
log::debug!(
"{:?}: read_upgrade: value is equal, back-dating to {:?}",
database_key_index.debug(db),
old_memo.revisions.changed_at,
);
assert!(old_memo.revisions.changed_at <= revisions.changed_at);
revisions.changed_at = old_memo.revisions.changed_at;
}
}
}
let stamped_value = revisions.stamped_value(value);
log::debug!(
"{:?}: read_upgrade: result.revisions = {:#?}",
database_key_index.debug(db),
revisions
);
self.memo_map.insert(
key_index,
Memo::new(
if MP::should_memoize_value(&key) {
Some(stamped_value.value.clone())
} else {
None
},
revision_now,
revisions,
),
);
stamped_value
}
}

115
src/derived/fetch.rs Normal file
View file

@ -0,0 +1,115 @@
use arc_swap::Guard;
use crate::{
plumbing::{DatabaseOps, QueryFunction},
runtime::{local_state::QueryInputs, StampedValue},
Database, QueryDb,
};
use super::{DerivedKeyIndex, DerivedStorage, MemoizationPolicy};
impl<Q, MP> DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
#[inline]
pub(super) fn fetch(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
key_index: DerivedKeyIndex,
) -> Q::Value {
let StampedValue {
value,
durability,
changed_at,
} = self.compute_value(db, key_index);
if let Some(evicted) = self.lru.record_use(key_index) {
self.evict(evicted);
}
db.salsa_runtime()
.report_query_read_and_unwind_if_cycle_resulted(
self.database_key_index(key_index),
durability,
changed_at,
);
value
}
#[inline]
fn compute_value(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
key_index: DerivedKeyIndex,
) -> StampedValue<Q::Value> {
loop {
if let Some(value) = self
.fetch_hot(db, key_index)
.or_else(|| self.fetch_cold(db, key_index))
{
return value;
}
}
}
#[inline]
fn fetch_hot(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
key_index: DerivedKeyIndex,
) -> Option<StampedValue<Q::Value>> {
let memo_guard = self.memo_map.get(key_index);
if let Some(memo) = &memo_guard {
if let Some(value) = &memo.value {
let runtime = db.salsa_runtime();
if self.shallow_verify_memo(db, runtime, self.database_key_index(key_index), memo) {
return Some(memo.revisions.stamped_value(value.clone()));
}
}
}
None
}
fn fetch_cold(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
key_index: DerivedKeyIndex,
) -> Option<StampedValue<Q::Value>> {
let runtime = db.salsa_runtime();
let database_key_index = self.database_key_index(key_index);
// Try to claim this query: if someone else has claimed it already, go back and start again.
let _claim_guard = self.sync_map.claim(db.ops_database(), database_key_index)?;
// Push the query on the stack.
let active_query = runtime.push_query(database_key_index);
// Now that we've claimed the item, check again to see if there's a "hot" value.
// This time we can do a *deep* verify. Because this can recurse, don't hold the arcswap guard.
let opt_old_memo = self.memo_map.get(key_index).map(Guard::into_inner);
if let Some(old_memo) = &opt_old_memo {
if let Some(value) = &old_memo.value {
if self.deep_verify_memo(db, old_memo, &active_query) {
return Some(old_memo.revisions.stamped_value(value.clone()));
}
}
}
Some(self.execute(db, active_query, opt_old_memo))
}
fn evict(&self, key_index: DerivedKeyIndex) {
if let Some(memo) = self.memo_map.get(key_index) {
// Careful: we can't evict memos with untracked inputs
// as their values cannot be reconstructed.
if let QueryInputs::Untracked = memo.revisions.inputs {
return;
}
self.memo_map.remove(key_index);
}
}
}

View file

@ -0,0 +1,62 @@
use crossbeam_utils::atomic::AtomicCell;
use std::borrow::Borrow;
use std::hash::Hash;
use crate::hash::FxDashMap;
use super::DerivedKeyIndex;
pub(super) struct KeyToKeyIndex<K> {
index_map: FxDashMap<K, DerivedKeyIndex>,
key_map: FxDashMap<DerivedKeyIndex, K>,
indices: AtomicCell<u32>,
}
impl<K> Default for KeyToKeyIndex<K>
where
K: Hash + Eq,
{
fn default() -> Self {
Self {
index_map: Default::default(),
key_map: Default::default(),
indices: Default::default(),
}
}
}
impl<K> KeyToKeyIndex<K>
where
K: Hash + Eq + Clone,
{
pub(super) fn key_index_for_key(&self, key: &K) -> DerivedKeyIndex {
// Common case: get an existing key
if let Some(v) = self.index_map.get(key) {
return *v;
}
// Less common case: (potentially) create a new slot
*self.index_map.entry(key.clone()).or_insert_with(|| {
let key_index = self.indices.fetch_add(1);
self.key_map.insert(key_index, key.clone());
key_index
})
}
pub(super) fn existing_key_index_for_key<S>(&self, key: &S) -> Option<DerivedKeyIndex>
where
S: Eq + Hash,
K: Borrow<S>,
{
// Common case: get an existing key
if let Some(v) = self.index_map.get(key) {
Some(*v)
} else {
None
}
}
pub(super) fn key_for_key_index(&self, key_index: DerivedKeyIndex) -> K {
self.key_map.get(&key_index).unwrap().clone()
}
}

39
src/derived/lru.rs Normal file
View file

@ -0,0 +1,39 @@
use crate::hash::FxLinkedHashSet;
use super::DerivedKeyIndex;
use crossbeam_utils::atomic::AtomicCell;
use parking_lot::Mutex;
#[derive(Default)]
pub(super) struct Lru {
capacity: AtomicCell<usize>,
set: Mutex<FxLinkedHashSet<DerivedKeyIndex>>,
}
impl Lru {
pub(super) fn record_use(&self, index: DerivedKeyIndex) -> Option<DerivedKeyIndex> {
let capacity = self.capacity.load();
if capacity == 0 {
// LRU is disabled
return None;
}
let mut set = self.set.lock();
set.insert(index);
if set.len() > capacity {
return set.pop_front();
}
None
}
pub(super) fn set_capacity(&self, capacity: usize) {
self.capacity.store(capacity);
if capacity == 0 {
let mut set = self.set.lock();
*set = FxLinkedHashSet::default();
}
}
}

View file

@ -0,0 +1,179 @@
use arc_swap::Guard;
use crate::{
plumbing::{DatabaseOps, QueryFunction},
runtime::{
local_state::{ActiveQueryGuard, QueryInputs},
StampedValue,
},
Database, DatabaseKeyIndex, QueryDb, Revision, Runtime,
};
use super::{memo::Memo, DerivedKeyIndex, DerivedStorage, MemoizationPolicy};
impl<Q, MP> DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
pub(super) fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
key_index: DerivedKeyIndex,
revision: Revision,
) -> bool {
loop {
let runtime = db.salsa_runtime();
let database_key_index = self.database_key_index(key_index);
log::debug!(
"{:?}: maybe_changed_after(revision = {:?})",
database_key_index.debug(db),
revision,
);
// Check if we have a verified version: this is the hot path.
let memo_guard = self.memo_map.get(key_index);
if let Some(memo) = &memo_guard {
if self.shallow_verify_memo(db, runtime, database_key_index, memo) {
return memo.revisions.changed_at > revision;
}
drop(memo_guard); // release the arc-swap guard before cold path
if let Some(mcs) = self.maybe_changed_after_cold(db, key_index, revision) {
return mcs;
} else {
// We failed to claim, have to retry.
}
} else {
// No memo? Assume has changed.
return true;
}
}
}
fn maybe_changed_after_cold(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
key_index: DerivedKeyIndex,
revision: Revision,
) -> Option<bool> {
let runtime = db.salsa_runtime();
let database_key_index = self.database_key_index(key_index);
let _claim_guard = self.sync_map.claim(db.ops_database(), database_key_index)?;
let active_query = runtime.push_query(database_key_index);
// Load the current memo, if any. Use a real arc, not an arc-swap guard,
// since we may recurse.
let old_memo = match self.memo_map.get(key_index) {
Some(m) => Guard::into_inner(m),
None => return Some(true),
};
log::debug!(
"{:?}: maybe_changed_after_cold, successful claim, revision = {:?}, old_memo = {:#?}",
database_key_index.debug(db),
revision,
old_memo
);
// Check if the inputs are still valid and we can just compare `changed_at`.
if self.deep_verify_memo(db, &old_memo, &active_query) {
return Some(old_memo.revisions.changed_at > revision);
}
// If inputs have changed, but we have an old value, we can re-execute.
// It is possible the result will be equal to the old value and hence
// backdated. In that case, although we will have computed a new memo,
// the value has not logically changed.
if old_memo.value.is_some() {
let StampedValue { changed_at, .. } = self.execute(db, active_query, Some(old_memo));
return Some(changed_at > revision);
}
// Otherwise, nothing for it: have to consider the value to have changed.
Some(true)
}
/// True if the memo's value and `changed_at` time is still valid in this revision.
/// Does only a shallow O(1) check, doesn't walk the dependencies.
#[inline]
pub(super) fn shallow_verify_memo(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
runtime: &Runtime,
database_key_index: DatabaseKeyIndex,
memo: &Memo<Q::Value>,
) -> bool {
let verified_at = memo.verified_at.load();
let revision_now = runtime.current_revision();
log::debug!(
"{:?}: shallow_verify_memo(memo = {:#?})",
database_key_index.debug(db),
memo,
);
if verified_at == revision_now {
// Already verified.
return true;
}
if memo.check_durability(runtime) {
// No input of the suitable durability has changed since last verified.
memo.mark_as_verified(db.ops_database(), runtime, database_key_index);
return true;
}
false
}
/// True if the memo's value and `changed_at` time is up to date in the current
/// revision. When this returns true, it also updates the memo's `verified_at`
/// field if needed to make future calls cheaper.
///
/// Takes an [`ActiveQueryGuard`] argument because this function recursively
/// walks dependencies of `old_memo` and may even execute them to see if their
/// outputs have changed. As that could lead to cycles, it is important that the
/// query is on the stack.
pub(super) fn deep_verify_memo(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
old_memo: &Memo<Q::Value>,
active_query: &ActiveQueryGuard<'_>,
) -> bool {
let runtime = db.salsa_runtime();
let database_key_index = active_query.database_key_index;
log::debug!(
"{:?}: deep_verify_memo(old_memo = {:#?})",
database_key_index.debug(db),
old_memo
);
if self.shallow_verify_memo(db, runtime, database_key_index, old_memo) {
return true;
}
match &old_memo.revisions.inputs {
QueryInputs::Untracked => {
// Untracked inputs? Have to assume that it changed.
return false;
}
QueryInputs::NoInputs => {
// No inputs, cannot have changed.
}
QueryInputs::Tracked { inputs } => {
let last_verified_at = old_memo.verified_at.load();
for &input in inputs.iter() {
if db.maybe_changed_after(input, last_verified_at) {
return false;
}
}
}
}
old_memo.mark_as_verified(db.ops_database(), runtime, database_key_index);
true
}
}

107
src/derived/memo.rs Normal file
View file

@ -0,0 +1,107 @@
use std::sync::Arc;
use arc_swap::{ArcSwap, Guard};
use crossbeam_utils::atomic::AtomicCell;
use crate::{
hash::FxDashMap, runtime::local_state::QueryRevisions, DatabaseKeyIndex, Event, EventKind,
Revision, Runtime,
};
use super::DerivedKeyIndex;
pub(super) struct MemoMap<V> {
map: FxDashMap<DerivedKeyIndex, ArcSwap<Memo<V>>>,
}
impl<V> Default for MemoMap<V> {
fn default() -> Self {
Self {
map: Default::default(),
}
}
}
impl<V> MemoMap<V> {
/// Inserts the memo for the given key; (atomically) overwrites any previously existing memo.-
pub(super) fn insert(&self, key: DerivedKeyIndex, memo: Memo<V>) {
self.map.insert(key, ArcSwap::from(Arc::new(memo)));
}
/// Removes any existing memo for the given key.
pub(super) fn remove(&self, key: DerivedKeyIndex) {
self.map.remove(&key);
}
/// Loads the current memo for `key_index`. This does not hold any sort of
/// lock on the `memo_map` once it returns, so this memo could immediately
/// become outdated if other threads store into the `memo_map`.
pub(super) fn get(&self, key: DerivedKeyIndex) -> Option<Guard<Arc<Memo<V>>>> {
self.map.get(&key).map(|v| v.load())
}
/// Iterates over the entries in the map. This holds a read lock while iteration continues.
pub(super) fn iter(&self) -> impl Iterator<Item = (DerivedKeyIndex, Arc<Memo<V>>)> + '_ {
self.map
.iter()
.map(move |r| (*r.key(), r.value().load_full()))
}
/// Clears the memo of all entries.
pub(super) fn clear(&self) {
self.map.clear()
}
}
#[derive(Debug)]
pub(super) struct Memo<V> {
/// The result of the query, if we decide to memoize it.
pub(super) value: Option<V>,
/// Last revision when this memo was verified; this begins
/// as the current revision.
pub(super) verified_at: AtomicCell<Revision>,
/// Revision information
pub(super) revisions: QueryRevisions,
}
impl<V> Memo<V> {
pub(super) fn new(value: Option<V>, revision_now: Revision, revisions: QueryRevisions) -> Self {
Memo {
value,
verified_at: AtomicCell::new(revision_now),
revisions,
}
}
/// True if this memo is known not to have changed based on its durability.
pub(super) fn check_durability(&self, runtime: &Runtime) -> bool {
let last_changed = runtime.last_changed_revision(self.revisions.durability);
let verified_at = self.verified_at.load();
log::debug!(
"check_durability(last_changed={:?} <= verified_at={:?}) = {:?}",
last_changed,
self.verified_at,
last_changed <= verified_at,
);
last_changed <= verified_at
}
/// Mark memo as having been verified in the `revision_now`, which should
/// be the current revision.
pub(super) fn mark_as_verified(
&self,
db: &dyn crate::Database,
runtime: &crate::Runtime,
database_key_index: DatabaseKeyIndex,
) {
db.salsa_event(Event {
runtime_id: runtime.id(),
kind: EventKind::DidValidateMemoizedValue {
database_key: database_key_index,
},
});
self.verified_at.store(runtime.current_revision());
}
}

View file

@ -1,871 +0,0 @@
use crate::debug::TableEntry;
use crate::derived::MemoizationPolicy;
use crate::durability::Durability;
use crate::lru::LruIndex;
use crate::lru::LruNode;
use crate::plumbing::{DatabaseOps, QueryFunction};
use crate::revision::Revision;
use crate::runtime::local_state::ActiveQueryGuard;
use crate::runtime::local_state::QueryInputs;
use crate::runtime::local_state::QueryRevisions;
use crate::runtime::Runtime;
use crate::runtime::RuntimeId;
use crate::runtime::StampedValue;
use crate::runtime::WaitResult;
use crate::Cycle;
use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb};
use log::{debug, info};
use parking_lot::{RawRwLock, RwLock};
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, Ordering};
pub(super) struct Slot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
key: Q::Key,
database_key_index: DatabaseKeyIndex,
state: RwLock<QueryState<Q>>,
policy: PhantomData<MP>,
lru_index: LruIndex,
}
/// Defines the "current state" of query's memoized results.
enum QueryState<Q>
where
Q: QueryFunction,
{
NotComputed,
/// The runtime with the given id is currently computing the
/// result of this query.
InProgress {
id: RuntimeId,
/// Set to true if any other queries are blocked,
/// waiting for this query to complete.
anyone_waiting: AtomicBool,
},
/// We have computed the query already, and here is the result.
Memoized(Memo<Q::Value>),
}
struct Memo<V> {
/// The result of the query, if we decide to memoize it.
value: Option<V>,
/// Last revision when this memo was verified; this begins
/// as the current revision.
pub(crate) verified_at: Revision,
/// Revision information
revisions: QueryRevisions,
}
/// Return value of `probe` helper.
enum ProbeState<V, G> {
/// Another thread was active but has completed.
/// Try again!
Retry,
/// No entry for this key at all.
NotComputed(G),
/// There is an entry, but its contents have not been
/// verified in this revision.
Stale(G),
/// There is an entry, and it has been verified
/// in this revision, but it has no cached
/// value. The `Revision` is the revision where the
/// value last changed (if we were to recompute it).
NoValue(G, Revision),
/// There is an entry which has been verified,
/// and it has the following value-- or, we blocked
/// on another thread, and that resulted in a cycle.
UpToDate(V),
}
/// Return value of `maybe_changed_after_probe` helper.
enum MaybeChangedSinceProbeState<G> {
/// Another thread was active but has completed.
/// Try again!
Retry,
/// Value may have changed in the given revision.
ChangedAt(Revision),
/// There is a stale cache entry that has not been
/// verified in this revision, so we can't say.
Stale(G),
}
impl<Q, MP> Slot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
pub(super) fn new(key: Q::Key, database_key_index: DatabaseKeyIndex) -> Self {
Self {
key,
database_key_index,
state: RwLock::new(QueryState::NotComputed),
lru_index: LruIndex::default(),
policy: PhantomData,
}
}
pub(super) fn database_key_index(&self) -> DatabaseKeyIndex {
self.database_key_index
}
pub(super) fn read(&self, db: &<Q as QueryDb<'_>>::DynDb) -> StampedValue<Q::Value> {
let runtime = db.salsa_runtime();
// NB: We don't need to worry about people modifying the
// revision out from under our feet. Either `db` is a frozen
// database, in which case there is a lock, or the mutator
// thread is the current thread, and it will be prevented from
// doing any `set` invocations while the query function runs.
let revision_now = runtime.current_revision();
info!("{:?}: invoked at {:?}", self, revision_now,);
// First, do a check with a read-lock.
loop {
match self.probe(db, self.state.read(), runtime, revision_now) {
ProbeState::UpToDate(v) => return v,
ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => {
break
}
ProbeState::Retry => continue,
}
}
self.read_upgrade(db, revision_now)
}
/// Second phase of a read operation: acquires an upgradable-read
/// and -- if needed -- validates whether inputs have changed,
/// recomputes value, etc. This is invoked after our initial probe
/// shows a potentially out of date value.
fn read_upgrade(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
revision_now: Revision,
) -> StampedValue<Q::Value> {
let runtime = db.salsa_runtime();
debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,);
// Check with an upgradable read to see if there is a value
// already. (This permits other readers but prevents anyone
// else from running `read_upgrade` at the same time.)
let mut old_memo = loop {
match self.probe(db, self.state.upgradable_read(), runtime, revision_now) {
ProbeState::UpToDate(v) => return v,
ProbeState::Stale(state)
| ProbeState::NotComputed(state)
| ProbeState::NoValue(state, _) => {
type RwLockUpgradableReadGuard<'a, T> =
lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>;
let mut state = RwLockUpgradableReadGuard::upgrade(state);
match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) {
QueryState::Memoized(old_memo) => break Some(old_memo),
QueryState::InProgress { .. } => unreachable!(),
QueryState::NotComputed => break None,
}
}
ProbeState::Retry => continue,
}
};
let panic_guard = PanicGuard::new(self.database_key_index, self, runtime);
let active_query = runtime.push_query(self.database_key_index);
// If we have an old-value, it *may* now be stale, since there
// has been a new revision since the last time we checked. So,
// first things first, let's walk over each of our previous
// inputs and check whether they are out of date.
if let Some(memo) = &mut old_memo {
if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) {
info!("{:?}: validated old memoized value", self,);
db.salsa_event(Event {
runtime_id: runtime.id(),
kind: EventKind::DidValidateMemoizedValue {
database_key: self.database_key_index,
},
});
panic_guard.proceed(old_memo);
return value;
}
}
self.execute(
db,
runtime,
revision_now,
active_query,
panic_guard,
old_memo,
)
}
fn execute(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
runtime: &Runtime,
revision_now: Revision,
active_query: ActiveQueryGuard<'_>,
panic_guard: PanicGuard<'_, Q, MP>,
old_memo: Option<Memo<Q::Value>>,
) -> StampedValue<Q::Value> {
log::info!("{:?}: executing query", self.database_key_index.debug(db));
db.salsa_event(Event {
runtime_id: db.salsa_runtime().id(),
kind: EventKind::WillExecute {
database_key: self.database_key_index,
},
});
// Query was not previously executed, or value is potentially
// stale, or value is absent. Let's execute!
let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) {
Ok(v) => v,
Err(cycle) => {
log::debug!(
"{:?}: caught cycle {:?}, have strategy {:?}",
self.database_key_index.debug(db),
cycle,
Q::CYCLE_STRATEGY,
);
match Q::CYCLE_STRATEGY {
crate::plumbing::CycleRecoveryStrategy::Panic => {
panic_guard.proceed(None);
cycle.throw()
}
crate::plumbing::CycleRecoveryStrategy::Fallback => {
if let Some(c) = active_query.take_cycle() {
assert!(c.is(&cycle));
Q::cycle_fallback(db, &cycle, &self.key)
} else {
// we are not a participant in this cycle
debug_assert!(!cycle
.participant_keys()
.any(|k| k == self.database_key_index));
cycle.throw()
}
}
}
}
};
let mut revisions = active_query.pop();
// We assume that query is side-effect free -- that is, does
// not mutate the "inputs" to the query system. Sanity check
// that assumption here, at least to the best of our ability.
assert_eq!(
runtime.current_revision(),
revision_now,
"revision altered during query execution",
);
// If the new value is equal to the old one, then it didn't
// really change, even if some of its inputs have. So we can
// "backdate" its `changed_at` revision to be the same as the
// old value.
if let Some(old_memo) = &old_memo {
if let Some(old_value) = &old_memo.value {
// Careful: if the value became less durable than it
// used to be, that is a "breaking change" that our
// consumers must be aware of. Becoming *more* durable
// is not. See the test `constant_to_non_constant`.
if revisions.durability >= old_memo.revisions.durability
&& MP::memoized_value_eq(old_value, &value)
{
debug!(
"read_upgrade({:?}): value is equal, back-dating to {:?}",
self, old_memo.revisions.changed_at,
);
assert!(old_memo.revisions.changed_at <= revisions.changed_at);
revisions.changed_at = old_memo.revisions.changed_at;
}
}
}
let new_value = StampedValue {
value,
durability: revisions.durability,
changed_at: revisions.changed_at,
};
let memo_value = if self.should_memoize_value(&self.key) {
Some(new_value.value.clone())
} else {
None
};
debug!(
"read_upgrade({:?}): result.revisions = {:#?}",
self, revisions,
);
panic_guard.proceed(Some(Memo {
value: memo_value,
verified_at: revision_now,
revisions,
}));
new_value
}
/// Helper for `read` that does a shallow check (not recursive) if we have an up-to-date value.
///
/// Invoked with the guard `state` corresponding to the `QueryState` of some `Slot` (the guard
/// can be either read or write). Returns a suitable `ProbeState`:
///
/// - `ProbeState::UpToDate(r)` if the table has an up-to-date value (or we blocked on another
/// thread that produced such a value).
/// - `ProbeState::StaleOrAbsent(g)` if either (a) there is no memo for this key, (b) the memo
/// has no value; or (c) the memo has not been verified at the current revision.
///
/// Note that in case `ProbeState::UpToDate`, the lock will have been released.
fn probe<StateGuard>(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
state: StateGuard,
runtime: &Runtime,
revision_now: Revision,
) -> ProbeState<StampedValue<Q::Value>, StateGuard>
where
StateGuard: Deref<Target = QueryState<Q>>,
{
match &*state {
QueryState::NotComputed => ProbeState::NotComputed(state),
QueryState::InProgress { id, anyone_waiting } => {
let other_id = *id;
// NB: `Ordering::Relaxed` is sufficient here,
// as there are no loads that are "gated" on this
// value. Everything that is written is also protected
// by a lock that must be acquired. The role of this
// boolean is to decide *whether* to acquire the lock,
// not to gate future atomic reads.
anyone_waiting.store(true, Ordering::Relaxed);
self.block_on_or_unwind(db, runtime, other_id, state);
// Other thread completely normally, so our value may be available now.
ProbeState::Retry
}
QueryState::Memoized(memo) => {
debug!(
"{:?}: found memoized value, verified_at={:?}, changed_at={:?}",
self, memo.verified_at, memo.revisions.changed_at,
);
if memo.verified_at < revision_now {
return ProbeState::Stale(state);
}
if let Some(value) = &memo.value {
let value = StampedValue {
durability: memo.revisions.durability,
changed_at: memo.revisions.changed_at,
value: value.clone(),
};
info!(
"{:?}: returning memoized value changed at {:?}",
self, value.changed_at
);
ProbeState::UpToDate(value)
} else {
let changed_at = memo.revisions.changed_at;
ProbeState::NoValue(state, changed_at)
}
}
}
}
pub(super) fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb) -> Durability {
match &*self.state.read() {
QueryState::NotComputed => Durability::LOW,
QueryState::InProgress { .. } => panic!("query in progress"),
QueryState::Memoized(memo) => {
if memo.check_durability(db.salsa_runtime()) {
memo.revisions.durability
} else {
Durability::LOW
}
}
}
}
pub(super) fn as_table_entry(&self) -> Option<TableEntry<Q::Key, Q::Value>> {
match &*self.state.read() {
QueryState::NotComputed => None,
QueryState::InProgress { .. } => Some(TableEntry::new(self.key.clone(), None)),
QueryState::Memoized(memo) => {
Some(TableEntry::new(self.key.clone(), memo.value.clone()))
}
}
}
pub(super) fn evict(&self) {
let mut state = self.state.write();
if let QueryState::Memoized(memo) = &mut *state {
// Evicting a value with an untracked input could
// lead to inconsistencies. Note that we can't check
// `has_untracked_input` when we add the value to the cache,
// because inputs can become untracked in the next revision.
if memo.has_untracked_input() {
return;
}
memo.value = None;
}
}
pub(super) fn invalidate(&self, new_revision: Revision) -> Option<Durability> {
log::debug!("Slot::invalidate(new_revision = {:?})", new_revision);
match &mut *self.state.write() {
QueryState::Memoized(memo) => {
memo.revisions.inputs = QueryInputs::Untracked;
memo.revisions.changed_at = new_revision;
Some(memo.revisions.durability)
}
QueryState::NotComputed => None,
QueryState::InProgress { .. } => unreachable!(),
}
}
pub(super) fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
revision: Revision,
) -> bool {
let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision();
db.unwind_if_cancelled();
debug!(
"maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}",
self, revision, revision_now,
);
// Do an initial probe with just the read-lock.
//
// If we find that a cache entry for the value is present
// but hasn't been verified in this revision, we'll have to
// do more.
loop {
match self.maybe_changed_after_probe(db, self.state.read(), runtime, revision_now) {
MaybeChangedSinceProbeState::Retry => continue,
MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision,
MaybeChangedSinceProbeState::Stale(state) => {
drop(state);
return self.maybe_changed_after_upgrade(db, revision);
}
}
}
}
fn maybe_changed_after_probe<StateGuard>(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
state: StateGuard,
runtime: &Runtime,
revision_now: Revision,
) -> MaybeChangedSinceProbeState<StateGuard>
where
StateGuard: Deref<Target = QueryState<Q>>,
{
match self.probe(db, state, runtime, revision_now) {
ProbeState::Retry => MaybeChangedSinceProbeState::Retry,
ProbeState::Stale(state) => MaybeChangedSinceProbeState::Stale(state),
// If we know when value last changed, we can return right away.
// Note that we don't need the actual value to be available.
ProbeState::NoValue(_, changed_at)
| ProbeState::UpToDate(StampedValue {
value: _,
durability: _,
changed_at,
}) => MaybeChangedSinceProbeState::ChangedAt(changed_at),
// If we have nothing cached, then value may have changed.
ProbeState::NotComputed(_) => MaybeChangedSinceProbeState::ChangedAt(revision_now),
}
}
fn maybe_changed_after_upgrade(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
revision: Revision,
) -> bool {
let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision();
// Get an upgradable read lock, which permits other reads but no writers.
// Probe again. If the value is stale (needs to be verified), then upgrade
// to a write lock and swap it with InProgress while we work.
let mut old_memo = match self.maybe_changed_after_probe(
db,
self.state.upgradable_read(),
runtime,
revision_now,
) {
MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision,
// If another thread was active, then the cache line is going to be
// either verified or cleared out. Just recurse to figure out which.
// Note that we don't need an upgradable read.
MaybeChangedSinceProbeState::Retry => return self.maybe_changed_after(db, revision),
MaybeChangedSinceProbeState::Stale(state) => {
type RwLockUpgradableReadGuard<'a, T> =
lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>;
let mut state = RwLockUpgradableReadGuard::upgrade(state);
match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) {
QueryState::Memoized(old_memo) => old_memo,
QueryState::NotComputed | QueryState::InProgress { .. } => unreachable!(),
}
}
};
let panic_guard = PanicGuard::new(self.database_key_index, self, runtime);
let active_query = runtime.push_query(self.database_key_index);
if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) {
let maybe_changed = old_memo.revisions.changed_at > revision;
panic_guard.proceed(Some(old_memo));
maybe_changed
} else if old_memo.value.is_some() {
// We found that this memoized value may have changed
// but we have an old value. We can re-run the code and
// actually *check* if it has changed.
let StampedValue { changed_at, .. } = self.execute(
db,
runtime,
revision_now,
active_query,
panic_guard,
Some(old_memo),
);
changed_at > revision
} else {
// We found that inputs to this memoized value may have chanced
// but we don't have an old value to compare against or re-use.
// No choice but to drop the memo and say that its value may have changed.
panic_guard.proceed(None);
true
}
}
/// Helper: see [`Runtime::try_block_on_or_unwind`].
fn block_on_or_unwind<MutexGuard>(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
runtime: &Runtime,
other_id: RuntimeId,
mutex_guard: MutexGuard,
) {
runtime.block_on_or_unwind(
db.ops_database(),
self.database_key_index,
other_id,
mutex_guard,
)
}
fn should_memoize_value(&self, key: &Q::Key) -> bool {
MP::should_memoize_value(key)
}
}
impl<Q> QueryState<Q>
where
Q: QueryFunction,
{
fn in_progress(id: RuntimeId) -> Self {
QueryState::InProgress {
id,
anyone_waiting: Default::default(),
}
}
}
struct PanicGuard<'me, Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
database_key_index: DatabaseKeyIndex,
slot: &'me Slot<Q, MP>,
runtime: &'me Runtime,
}
impl<'me, Q, MP> PanicGuard<'me, Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn new(
database_key_index: DatabaseKeyIndex,
slot: &'me Slot<Q, MP>,
runtime: &'me Runtime,
) -> Self {
Self {
database_key_index,
slot,
runtime,
}
}
/// Indicates that we have concluded normally (without panicking).
/// If `opt_memo` is some, then this memo is installed as the new
/// memoized value. If `opt_memo` is `None`, then the slot is cleared
/// and has no value.
fn proceed(mut self, opt_memo: Option<Memo<Q::Value>>) {
self.overwrite_placeholder(WaitResult::Completed, opt_memo);
std::mem::forget(self)
}
/// Overwrites the `InProgress` placeholder for `key` that we
/// inserted; if others were blocked, waiting for us to finish,
/// then notify them.
fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option<Memo<Q::Value>>) {
let mut write = self.slot.state.write();
let old_value = match opt_memo {
// Replace the `InProgress` marker that we installed with the new
// memo, thus releasing our unique access to this key.
Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)),
// We had installed an `InProgress` marker, but we panicked before
// it could be removed. At this point, we therefore "own" unique
// access to our slot, so we can just remove the key.
None => std::mem::replace(&mut *write, QueryState::NotComputed),
};
match old_value {
QueryState::InProgress { id, anyone_waiting } => {
assert_eq!(id, self.runtime.id());
// NB: As noted on the `store`, `Ordering::Relaxed` is
// sufficient here. This boolean signals us on whether to
// acquire a mutex; the mutex will guarantee that all writes
// we are interested in are visible.
if anyone_waiting.load(Ordering::Relaxed) {
self.runtime
.unblock_queries_blocked_on(self.database_key_index, wait_result);
}
}
_ => panic!(
"\
Unexpected panic during query evaluation, aborting the process.
Please report this bug to https://github.com/salsa-rs/salsa/issues."
),
}
}
}
impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn drop(&mut self) {
if std::thread::panicking() {
// We panicked before we could proceed and need to remove `key`.
self.overwrite_placeholder(WaitResult::Panicked, None)
} else {
// If no panic occurred, then panic guard ought to be
// "forgotten" and so this Drop code should never run.
panic!(".forget() was not called")
}
}
}
impl<V> Memo<V>
where
V: Clone,
{
/// Determines whether the value stored in this memo (if any) is still
/// valid in the current revision. If so, returns a stamped value.
///
/// If needed, this will walk each dependency and
/// recursively invoke `maybe_changed_after`, which may in turn
/// re-execute the dependency. This can cause cycles to occur,
/// so the current query must be pushed onto the
/// stack to permit cycle detection and recovery: therefore,
/// takes the `active_query` argument as evidence.
fn verify_value(
&mut self,
db: &dyn Database,
revision_now: Revision,
active_query: &ActiveQueryGuard<'_>,
) -> Option<StampedValue<V>> {
// If we don't have a memoized value, nothing to validate.
if self.value.is_none() {
return None;
}
if self.verify_revisions(db, revision_now, active_query) {
Some(StampedValue {
durability: self.revisions.durability,
changed_at: self.revisions.changed_at,
value: self.value.as_ref().unwrap().clone(),
})
} else {
None
}
}
/// Determines whether the value represented by this memo is still
/// valid in the current revision; note that the value itself is
/// not needed for this check. If needed, this will walk each
/// dependency and recursively invoke `maybe_changed_after`, which
/// may in turn re-execute the dependency. This can cause cycles to occur,
/// so the current query must be pushed onto the
/// stack to permit cycle detection and recovery: therefore,
/// takes the `active_query` argument as evidence.
fn verify_revisions(
&mut self,
db: &dyn Database,
revision_now: Revision,
_active_query: &ActiveQueryGuard<'_>,
) -> bool {
assert!(self.verified_at != revision_now);
let verified_at = self.verified_at;
debug!(
"verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}",
verified_at, revision_now, self.revisions.inputs
);
if self.check_durability(db.salsa_runtime()) {
return self.mark_value_as_verified(revision_now);
}
match &self.revisions.inputs {
// We can't validate values that had untracked inputs; just have to
// re-execute.
QueryInputs::Untracked => {
return false;
}
QueryInputs::NoInputs => {}
// Check whether any of our inputs changed since the
// **last point where we were verified** (not since we
// last changed). This is important: if we have
// memoized values, then an input may have changed in
// revision R2, but we found that *our* value was the
// same regardless, so our change date is still
// R1. But our *verification* date will be R2, and we
// are only interested in finding out whether the
// input changed *again*.
QueryInputs::Tracked { inputs } => {
let changed_input = inputs
.iter()
.find(|&&input| db.maybe_changed_after(input, verified_at));
if let Some(input) = changed_input {
debug!("validate_memoized_value: `{:?}` may have changed", input);
return false;
}
}
};
self.mark_value_as_verified(revision_now)
}
/// True if this memo is known not to have changed based on its durability.
fn check_durability(&self, runtime: &Runtime) -> bool {
let last_changed = runtime.last_changed_revision(self.revisions.durability);
debug!(
"check_durability(last_changed={:?} <= verified_at={:?}) = {:?}",
last_changed,
self.verified_at,
last_changed <= self.verified_at,
);
last_changed <= self.verified_at
}
fn mark_value_as_verified(&mut self, revision_now: Revision) -> bool {
self.verified_at = revision_now;
true
}
fn has_untracked_input(&self) -> bool {
matches!(self.revisions.inputs, QueryInputs::Untracked)
}
}
impl<Q, MP> std::fmt::Debug for Slot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(fmt, "{:?}({:?})", Q::default(), self.key)
}
}
impl<Q, MP> LruNode for Slot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn lru_index(&self) -> &LruIndex {
&self.lru_index
}
}
/// Check that `Slot<Q, MP>: Send + Sync` as long as
/// `DB::DatabaseData: Send + Sync`, which in turn implies that
/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
#[allow(dead_code)]
fn check_send_sync<Q, MP>()
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
Q::Key: Send + Sync,
Q::Value: Send + Sync,
{
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<Slot<Q, MP>>();
}
/// Check that `Slot<Q, MP>: 'static` as long as
/// `DB::DatabaseData: 'static`, which in turn implies that
/// `Q::Key: 'static`, `Q::Value: 'static`.
#[allow(dead_code)]
fn check_static<Q, MP>()
where
Q: QueryFunction + 'static,
MP: MemoizationPolicy<Q> + 'static,
Q::Key: 'static,
Q::Value: 'static,
{
fn is_static<T: 'static>() {}
is_static::<Slot<Q, MP>>();
}

87
src/derived/sync.rs Normal file
View file

@ -0,0 +1,87 @@
use std::sync::atomic::{AtomicBool, Ordering};
use crate::{hash::FxDashMap, runtime::WaitResult, Database, DatabaseKeyIndex, Runtime, RuntimeId};
use super::DerivedKeyIndex;
#[derive(Default)]
pub(super) struct SyncMap {
sync_map: FxDashMap<DerivedKeyIndex, SyncState>,
}
struct SyncState {
id: RuntimeId,
/// Set to true if any other queries are blocked,
/// waiting for this query to complete.
anyone_waiting: AtomicBool,
}
impl SyncMap {
pub(super) fn claim<'me>(
&'me self,
db: &'me dyn Database,
database_key_index: DatabaseKeyIndex,
) -> Option<ClaimGuard<'me>> {
let runtime = db.salsa_runtime();
match self.sync_map.entry(database_key_index.key_index) {
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(SyncState {
id: runtime.id(),
anyone_waiting: AtomicBool::new(false),
});
Some(ClaimGuard {
database_key: database_key_index,
runtime,
sync_map: &self.sync_map,
})
}
dashmap::mapref::entry::Entry::Occupied(entry) => {
// NB: `Ordering::Relaxed` is sufficient here,
// as there are no loads that are "gated" on this
// value. Everything that is written is also protected
// by a lock that must be acquired. The role of this
// boolean is to decide *whether* to acquire the lock,
// not to gate future atomic reads.
entry.get().anyone_waiting.store(true, Ordering::Relaxed);
let other_id = entry.get().id;
runtime.block_on_or_unwind(db, database_key_index, other_id, entry);
None
}
}
}
}
/// Marks an active 'claim' in the synchronization map. The claim is
/// released when this value is dropped.
#[must_use]
pub(super) struct ClaimGuard<'me> {
database_key: DatabaseKeyIndex,
runtime: &'me Runtime,
sync_map: &'me FxDashMap<DerivedKeyIndex, SyncState>,
}
impl<'me> ClaimGuard<'me> {
fn remove_from_map_and_unblock_queries(&self, wait_result: WaitResult) {
let (_, SyncState { anyone_waiting, .. }) =
self.sync_map.remove(&self.database_key.key_index).unwrap();
// NB: `Ordering::Relaxed` is sufficient here,
// see `store` above for explanation.
if anyone_waiting.load(Ordering::Relaxed) {
self.runtime
.unblock_queries_blocked_on(self.database_key, wait_result)
}
}
}
impl<'me> Drop for ClaimGuard<'me> {
fn drop(&mut self) {
let wait_result = if std::thread::panicking() {
WaitResult::Panicked
} else {
WaitResult::Completed
};
self.remove_from_map_and_unblock_queries(wait_result)
}
}

View file

@ -2,3 +2,4 @@ pub(crate) type FxHasher = std::hash::BuildHasherDefault<rustc_hash::FxHasher>;
pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, FxHasher>;
pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, FxHasher>;
pub(crate) type FxDashMap<K, V> = dashmap::DashMap<K, V, FxHasher>;
pub(crate) type FxLinkedHashSet<K> = hashlink::LinkedHashSet<K, FxHasher>;

View file

@ -16,7 +16,6 @@ mod hash;
mod input;
mod intern_id;
mod interned;
mod lru;
mod revision;
mod runtime;
mod storage;

View file

@ -1,335 +0,0 @@
use oorandom::Rand64;
use parking_lot::Mutex;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
/// A simple and approximate concurrent lru list.
///
/// We assume but do not verify that each node is only used with one
/// list. If this is not the case, it is not *unsafe*, but panics and
/// weird results will ensue.
///
/// Each "node" in the list is of type `Node` and must implement
/// `LruNode`, which is a trait that gives access to a field that
/// stores the index in the list. This index gives us a rough idea of
/// how recently the node has been used.
#[derive(Debug)]
pub(crate) struct Lru<Node>
where
Node: LruNode,
{
green_zone: AtomicUsize,
data: Mutex<LruData<Node>>,
}
#[derive(Debug)]
struct LruData<Node> {
end_red_zone: usize,
end_yellow_zone: usize,
end_green_zone: usize,
rng: Rand64,
entries: Vec<Arc<Node>>,
}
pub(crate) trait LruNode: Sized + Debug {
fn lru_index(&self) -> &LruIndex;
}
#[derive(Debug)]
pub(crate) struct LruIndex {
/// Index in the approprate LRU list, or std::usize::MAX if not a
/// member.
index: AtomicUsize,
}
impl<Node> Default for Lru<Node>
where
Node: LruNode,
{
fn default() -> Self {
Lru::new()
}
}
// We always use a fixed seed for our randomness so that we have
// predictable results.
const LRU_SEED: &str = "Hello, Rustaceans";
impl<Node> Lru<Node>
where
Node: LruNode,
{
/// Creates a new LRU list where LRU caching is disabled.
pub fn new() -> Self {
Self::with_seed(LRU_SEED)
}
#[cfg_attr(not(test), allow(dead_code))]
fn with_seed(seed: &str) -> Self {
Lru {
green_zone: AtomicUsize::new(0),
data: Mutex::new(LruData::with_seed(seed)),
}
}
/// Adjust the total number of nodes permitted to have a value at
/// once. If `len` is zero, this disables LRU caching completely.
pub fn set_lru_capacity(&self, len: usize) {
let mut data = self.data.lock();
// We require each zone to have at least 1 slot. Therefore,
// the length cannot be just 1 or 2.
if len == 0 {
self.green_zone.store(0, Ordering::Release);
data.resize(0, 0, 0);
} else {
let len = std::cmp::max(len, 3);
// Top 10% is the green zone. This must be at least length 1.
let green_zone = std::cmp::max(len / 10, 1);
// Next 20% is the yellow zone.
let yellow_zone = std::cmp::max(len / 5, 1);
// Remaining 70% is the red zone.
let red_zone = len - yellow_zone - green_zone;
// We need quick access to the green zone.
self.green_zone.store(green_zone, Ordering::Release);
// Resize existing array.
data.resize(green_zone, yellow_zone, red_zone);
}
}
/// Records that `node` was used. This may displace an old node (if the LRU limits are
pub fn record_use(&self, node: &Arc<Node>) -> Option<Arc<Node>> {
log::debug!("record_use(node={:?})", node);
// Load green zone length and check if the LRU cache is even enabled.
let green_zone = self.green_zone.load(Ordering::Acquire);
log::debug!("record_use: green_zone={}", green_zone);
if green_zone == 0 {
return None;
}
// Find current index of list (if any) and the current length
// of our green zone.
let index = node.lru_index().load();
log::debug!("record_use: index={}", index);
// Already a member of the list, and in the green zone -- nothing to do!
if index < green_zone {
return None;
}
self.data.lock().record_use(node)
}
pub fn purge(&self) {
self.green_zone.store(0, Ordering::SeqCst);
*self.data.lock() = LruData::with_seed(LRU_SEED);
}
}
impl<Node> LruData<Node>
where
Node: LruNode,
{
fn with_seed(seed_str: &str) -> Self {
Self::with_rng(rng_with_seed(seed_str))
}
fn with_rng(rng: Rand64) -> Self {
LruData {
end_yellow_zone: 0,
end_green_zone: 0,
end_red_zone: 0,
entries: Vec::new(),
rng,
}
}
fn green_zone(&self) -> std::ops::Range<usize> {
0..self.end_green_zone
}
fn yellow_zone(&self) -> std::ops::Range<usize> {
self.end_green_zone..self.end_yellow_zone
}
fn red_zone(&self) -> std::ops::Range<usize> {
self.end_yellow_zone..self.end_red_zone
}
fn resize(&mut self, len_green_zone: usize, len_yellow_zone: usize, len_red_zone: usize) {
self.end_green_zone = len_green_zone;
self.end_yellow_zone = self.end_green_zone + len_yellow_zone;
self.end_red_zone = self.end_yellow_zone + len_red_zone;
let entries = std::mem::replace(&mut self.entries, Vec::with_capacity(self.end_red_zone));
log::debug!("green_zone = {:?}", self.green_zone());
log::debug!("yellow_zone = {:?}", self.yellow_zone());
log::debug!("red_zone = {:?}", self.red_zone());
// We expect to resize when the LRU cache is basically empty.
// So just forget all the old LRU indices to start.
for entry in entries {
entry.lru_index().clear();
}
}
/// Records that a node was used. If it is already a member of the
/// LRU list, it is promoted to the green zone (unless it's
/// already there). Otherwise, it is added to the list first and
/// *then* promoted to the green zone. Adding a new node to the
/// list may displace an old member of the red zone, in which case
/// that is returned.
fn record_use(&mut self, node: &Arc<Node>) -> Option<Arc<Node>> {
log::debug!("record_use(node={:?})", node);
// NB: When this is invoked, we have typically already loaded
// the LRU index (to check if it is in green zone). But that
// check was done outside the lock and -- for all we know --
// the index may have changed since. So we always reload.
let index = node.lru_index().load();
if index < self.end_green_zone {
None
} else if index < self.end_yellow_zone {
self.promote_yellow_to_green(node, index);
None
} else if index < self.end_red_zone {
self.promote_red_to_green(node, index);
None
} else {
self.insert_new(node)
}
}
/// Inserts a node that is not yet a member of the LRU list. If
/// the list is at capacity, this can displace an existing member.
fn insert_new(&mut self, node: &Arc<Node>) -> Option<Arc<Node>> {
debug_assert!(!node.lru_index().is_in_lru());
// Easy case: we still have capacity. Push it, and then promote
// it up to the appropriate zone.
let len = self.entries.len();
if len < self.end_red_zone {
self.entries.push(node.clone());
node.lru_index().store(len);
log::debug!("inserted node {:?} at {}", node, len);
return self.record_use(node);
}
// Harder case: no capacity. Create some by evicting somebody from red
// zone and then promoting.
let victim_index = self.pick_index(self.red_zone());
let victim_node = std::mem::replace(&mut self.entries[victim_index], node.clone());
log::debug!("evicting red node {:?} from {}", victim_node, victim_index);
victim_node.lru_index().clear();
self.promote_red_to_green(node, victim_index);
Some(victim_node)
}
/// Promotes the node `node`, stored at `red_index` (in the red
/// zone), into a green index, demoting yellow/green nodes at
/// random.
///
/// NB: It is not required that `node.lru_index()` is up-to-date
/// when entering this method.
fn promote_red_to_green(&mut self, node: &Arc<Node>, red_index: usize) {
debug_assert!(self.red_zone().contains(&red_index));
// Pick a yellow at random and switch places with it.
//
// Subtle: we do not update `node.lru_index` *yet* -- we're
// going to invoke `self.promote_yellow` next, and it will get
// updated then.
let yellow_index = self.pick_index(self.yellow_zone());
log::debug!(
"demoting yellow node {:?} from {} to red at {}",
self.entries[yellow_index],
yellow_index,
red_index,
);
self.entries.swap(yellow_index, red_index);
self.entries[red_index].lru_index().store(red_index);
// Now move ourselves up into the green zone.
self.promote_yellow_to_green(node, yellow_index);
}
/// Promotes the node `node`, stored at `yellow_index` (in the
/// yellow zone), into a green index, demoting a green node at
/// random to replace it.
///
/// NB: It is not required that `node.lru_index()` is up-to-date
/// when entering this method.
fn promote_yellow_to_green(&mut self, node: &Arc<Node>, yellow_index: usize) {
debug_assert!(self.yellow_zone().contains(&yellow_index));
// Pick a yellow at random and switch places with it.
let green_index = self.pick_index(self.green_zone());
log::debug!(
"demoting green node {:?} from {} to yellow at {}",
self.entries[green_index],
green_index,
yellow_index
);
self.entries.swap(green_index, yellow_index);
self.entries[yellow_index].lru_index().store(yellow_index);
node.lru_index().store(green_index);
log::debug!("promoted {:?} to green index {}", node, green_index);
}
fn pick_index(&mut self, zone: std::ops::Range<usize>) -> usize {
let end_index = std::cmp::min(zone.end, self.entries.len());
self.rng.rand_range(zone.start as u64..end_index as u64) as usize
}
}
impl Default for LruIndex {
fn default() -> Self {
Self {
index: AtomicUsize::new(std::usize::MAX),
}
}
}
impl LruIndex {
fn load(&self) -> usize {
self.index.load(Ordering::Acquire) // see note on ordering below
}
fn store(&self, value: usize) {
self.index.store(value, Ordering::Release) // see note on ordering below
}
fn clear(&self) {
self.store(std::usize::MAX);
}
fn is_in_lru(&self) -> bool {
self.load() != std::usize::MAX
}
}
fn rng_with_seed(seed_str: &str) -> Rand64 {
let mut seed: [u8; 16] = [0; 16];
for (i, &b) in seed_str.as_bytes().iter().take(16).enumerate() {
seed[i] = b;
}
Rand64::new(u128::from_le_bytes(seed))
}
// A note on ordering:
//
// I chose to use AcqRel for the ordering but I don't think it's
// strictly needed. All writes occur under a lock, so they should be
// ordered w/r/t one another. As for the reads, they can occur
// outside the lock, but they don't themselves enable dependent reads
// -- if the reads are out of bounds, we would acquire a lock.

View file

@ -8,6 +8,8 @@ use crate::DatabaseKeyIndex;
use std::cell::RefCell;
use std::sync::Arc;
use super::StampedValue;
/// State that is specific to a single execution thread.
///
/// Internally, this type uses ref-cells.
@ -38,6 +40,16 @@ pub(crate) struct QueryRevisions {
pub(crate) inputs: QueryInputs,
}
impl QueryRevisions {
pub(crate) fn stamped_value<V>(&self, value: V) -> StampedValue<V> {
StampedValue {
value,
durability: self.durability,
changed_at: self.changed_at,
}
}
}
/// Every input.
#[derive(Debug, Clone)]
pub(crate) enum QueryInputs {
@ -180,7 +192,7 @@ impl std::panic::RefUnwindSafe for LocalState {}
pub(crate) struct ActiveQueryGuard<'me> {
local_state: &'me LocalState,
push_len: usize,
database_key_index: DatabaseKeyIndex,
pub(crate) database_key_index: DatabaseKeyIndex,
}
impl ActiveQueryGuard<'_> {

View file

@ -1,5 +1,6 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::{Database, Durability};
use test_env_log::test;
#[salsa::query_group(MemoizedVolatile)]
pub(crate) trait MemoizedVolatileContext: TestContext {

View file

@ -144,8 +144,10 @@ fn on_demand_input_durability() {
RefCell {
value: [
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }",
],
}

View file

@ -70,7 +70,7 @@ fn should_panic_safely() {
db.set_one(1);
db.outer();
assert_eq!(OUTER_CALLS.load(SeqCst), 2);
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
}
}

View file

@ -40,7 +40,7 @@ fn parallel_cycle_none_recover() {
assert!(thread_a
.join()
.unwrap_err()
.downcast_ref::<salsa::Cycle>()
.downcast_ref::<salsa::Cancelled>()
.is_some());
}