diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index e3f60aa..d91ba72 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -265,26 +265,24 @@ macro_rules! setup_tracked_fn { } } } - $zalsa::attach_database($db, || { - let result = $zalsa::macro_if! { - if $needs_interner { - { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_salsa_database(), ($($input_id),*)); - $Configuration::fn_ingredient($db).fetch($db, key) - } - } else { - $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) - } - }; - - $zalsa::macro_if! { - if $return_ref { - result - } else { - <$output_ty as std::clone::Clone>::clone(result) + let result = $zalsa::macro_if! { + if $needs_interner { + { + let key = $Configuration::intern_ingredient($db).intern_id($db.as_salsa_database(), ($($input_id),*)); + $Configuration::fn_ingredient($db).fetch($db, key) } + } else { + $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) } - }) + }; + + $zalsa::macro_if! { + if $return_ref { + result + } else { + <$output_ty as std::clone::Clone>::clone(result) + } + } } }; } diff --git a/src/accumulator.rs b/src/accumulator.rs index 52bb71b..4731861 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -10,9 +10,9 @@ use crate::{ hash::FxDashMap, ingredient::{fmt_index, Ingredient, Jar}, key::DependencyIndex, - local_state::QueryOrigin, + local_state::{self, LocalState, QueryOrigin}, storage::IngredientIndex, - Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, Runtime, + Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, }; pub trait Accumulator: Clone + Debug + Send + Sync + 'static + Sized { @@ -79,44 +79,47 @@ impl IngredientImpl { } pub fn push(&self, db: &dyn crate::Database, value: A) { - let runtime = db.runtime(); - let current_revision = runtime.current_revision(); - let (active_query, _) = match runtime.active_query() { - Some(pair) => pair, - None => { - panic!("cannot accumulate values outside of an active query") + local_state::attach(db, |state| { + let runtime = db.runtime(); + let current_revision = runtime.current_revision(); + let (active_query, _) = match state.active_query() { + Some(pair) => pair, + None => { + panic!("cannot accumulate values outside of an active query") + } + }; + + let mut accumulated_values = + self.map.entry(active_query).or_insert(AccumulatedValues { + values: vec![], + produced_at: current_revision, + }); + + // When we call `push' in a query, we will add the accumulator to the output of the query. + // If we find here that this accumulator is not the output of the query, + // we can say that the accumulated values we stored for this query is out of date. + if !state.is_output_of_active_query(self.dependency_index()) { + accumulated_values.values.truncate(0); + accumulated_values.produced_at = current_revision; } - }; - let mut accumulated_values = self.map.entry(active_query).or_insert(AccumulatedValues { - values: vec![], - produced_at: current_revision, - }); - - // When we call `push' in a query, we will add the accumulator to the output of the query. - // If we find here that this accumulator is not the output of the query, - // we can say that the accumulated values we stored for this query is out of date. - if !runtime.is_output_of_active_query(self.dependency_index()) { - accumulated_values.values.truncate(0); - accumulated_values.produced_at = current_revision; - } - - runtime.add_output(self.dependency_index()); - accumulated_values.values.push(value); + state.add_output(self.dependency_index()); + accumulated_values.values.push(value); + }) } pub(crate) fn produced_by( &self, - runtime: &Runtime, + current_revision: Revision, + local_state: &LocalState, query: DatabaseKeyIndex, output: &mut Vec, ) { - let current_revision = runtime.current_revision(); if let Some(v) = self.map.get(&query) { // FIXME: We don't currently have a good way to identify the value that was read. // You can't report is as a tracked read of `query`, because the return value of query is not being read here -- // instead it is the set of values accumuated by `query`. - runtime.report_untracked_read(); + local_state.report_untracked_read(current_revision); let AccumulatedValues { values, diff --git a/src/cycle.rs b/src/cycle.rs index e9c5a4e..4a8a56f 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,4 +1,4 @@ -use crate::{database, key::DatabaseKeyIndex, Database}; +use crate::{key::DatabaseKeyIndex, local_state, Database}; use std::{panic::AssertUnwindSafe, sync::Arc}; /// Captures the participants of a cycle that occurred when executing a query. @@ -74,7 +74,7 @@ impl Cycle { impl std::fmt::Debug for Cycle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - database::with_attached_database(|db| { + local_state::with_attached_database(|db| { f.debug_struct("UnexpectedCycle") .field("all_participants", &self.all_participants(db)) .field("unexpected_participants", &self.unexpected_participants(db)) diff --git a/src/database.rs b/src/database.rs index 8a9a092..ce4514f 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,6 +1,4 @@ -use std::{cell::Cell, ptr::NonNull}; - -use crate::{storage::DatabaseGen, Durability, Event, Revision}; +use crate::{local_state, storage::DatabaseGen, Durability, Event, Revision}; #[salsa_macros::db] pub trait Database: DatabaseGen { @@ -31,7 +29,10 @@ pub trait Database: DatabaseGen { /// Queries which report untracked reads will be re-executed in the next /// revision. fn report_untracked_read(&self) { - self.runtime().report_untracked_read(); + let db = self.as_salsa_database(); + local_state::attach(db, |state| { + state.report_untracked_read(db.runtime().current_revision()) + }) } /// Execute `op` with the database in thread-local storage for debug print-outs. @@ -39,73 +40,7 @@ pub trait Database: DatabaseGen { where Self: Sized, { - attach_database(self, || op(self)) - } -} - -thread_local! { - static DATABASE: Cell = const { Cell::new(AttachedDatabase::null()) }; -} - -/// Access the "attached" database. Returns `None` if no database is attached. -/// Databases are attached with `attach_database`. -pub fn with_attached_database(op: impl FnOnce(&dyn Database) -> R) -> Option { - // SAFETY: We always attach the database in for the entire duration of a function, - // so it cannot become "unattached" while this function is running. - let db = DATABASE.get(); - Some(op(unsafe { db.ptr?.as_ref() })) -} - -/// Attach database and returns a guard that will un-attach the database when dropped. -/// Has no effect if a database is already attached. -pub fn attach_database(db: &Db, op: impl FnOnce() -> R) -> R { - let _guard = AttachedDb::new(db); - op() -} - -#[derive(Copy, Clone, PartialEq, Eq)] -struct AttachedDatabase { - ptr: Option>, -} - -impl AttachedDatabase { - pub const fn null() -> Self { - Self { ptr: None } - } - - pub fn from(db: &Db) -> Self { - unsafe { - let db: *const dyn Database = db.as_salsa_database(); - Self { - ptr: Some(NonNull::new_unchecked(db as *mut dyn Database)), - } - } - } -} - -struct AttachedDb<'db, Db: ?Sized + Database> { - db: &'db Db, - previous: AttachedDatabase, -} - -impl<'db, Db: ?Sized + Database> AttachedDb<'db, Db> { - pub fn new(db: &'db Db) -> Self { - let previous = DATABASE.replace(AttachedDatabase::from(db)); - AttachedDb { db, previous } - } -} - -impl Drop for AttachedDb<'_, Db> { - fn drop(&mut self) { - DATABASE.set(self.previous); - } -} - -impl std::ops::Deref for AttachedDb<'_, Db> { - type Target = Db; - - fn deref(&self) -> &Db { - self.db + local_state::attach(self, |_state| op(self)) } } diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index ea457e2..8ff145a 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,4 +1,6 @@ -use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id}; +use crate::{ + accumulator, hash::FxHashSet, local_state, storage::DatabaseGen, DatabaseKeyIndex, Id, +}; use super::{Configuration, IngredientImpl}; @@ -12,36 +14,41 @@ where where A: accumulator::Accumulator, { - let Some(accumulator) = >::from_db(db) else { - return vec![]; - }; - let runtime = db.runtime(); - let mut output = vec![]; + local_state::attach(db, |local_state| { + let current_revision = db.runtime().current_revision(); - // First ensure the result is up to date - self.fetch(db, key); + let Some(accumulator) = >::from_db(db) else { + return vec![]; + }; + let mut output = vec![]; - let db_key = self.database_key_index(key); - let mut visited: FxHashSet = FxHashSet::default(); - let mut stack: Vec = vec![db_key]; + // First ensure the result is up to date + self.fetch(db, key); - while let Some(k) = stack.pop() { - if visited.insert(k) { - accumulator.produced_by(runtime, k, &mut output); + let db_key = self.database_key_index(key); + let mut visited: FxHashSet = FxHashSet::default(); + let mut stack: Vec = vec![db_key]; - let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); - let inputs = origin.iter().flat_map(|origin| origin.inputs()); - // Careful: we want to push in execution order, so reverse order to - // ensure the first child that was executed will be the first child popped - // from the stack. - stack.extend( - inputs - .flat_map(|input| TryInto::::try_into(input).into_iter()) - .rev(), - ); + while let Some(k) = stack.pop() { + if visited.insert(k) { + accumulator.produced_by(current_revision, local_state, k, &mut output); + + let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); + let inputs = origin.iter().flat_map(|origin| origin.inputs()); + // Careful: we want to push in execution order, so reverse order to + // ensure the first child that was executed will be the first child popped + // from the stack. + stack.extend( + inputs + .flat_map(|input| { + TryInto::::try_into(input).into_iter() + }) + .rev(), + ); + } } - } - output + output + }) } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index b780126..c5bbf7f 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,6 +1,11 @@ use arc_swap::Guard; -use crate::{runtime::StampedValue, storage::DatabaseGen, Id}; +use crate::{ + local_state::{self, LocalState}, + runtime::StampedValue, + storage::DatabaseGen, + Id, +}; use super::{Configuration, IngredientImpl}; @@ -9,37 +14,41 @@ where C: Configuration, { pub fn fetch<'db>(&'db self, db: &'db C::DbView, key: Id) -> &C::Output<'db> { - let runtime = db.runtime(); + local_state::attach(db.as_salsa_database(), |local_state| { + local_state.unwind_if_revision_cancelled(db.as_salsa_database()); - runtime.unwind_if_revision_cancelled(db); + let StampedValue { + value, + durability, + changed_at, + } = self.compute_value(db, local_state, key); - let StampedValue { - value, - durability, - changed_at, - } = self.compute_value(db, key); + if let Some(evicted) = self.lru.record_use(key) { + self.evict(evicted); + } - if let Some(evicted) = self.lru.record_use(key) { - self.evict(evicted); - } + local_state.report_tracked_read( + self.database_key_index(key).into(), + durability, + changed_at, + ); - db.runtime().report_tracked_read( - self.database_key_index(key).into(), - durability, - changed_at, - ); - - value + value + }) } #[inline] fn compute_value<'db>( &'db self, db: &'db C::DbView, + local_state: &LocalState, key: Id, ) -> StampedValue<&'db C::Output<'db>> { loop { - if let Some(value) = self.fetch_hot(db, key).or_else(|| self.fetch_cold(db, key)) { + if let Some(value) = self + .fetch_hot(db, key) + .or_else(|| self.fetch_cold(db, local_state, key)) + { return value; } } @@ -70,18 +79,18 @@ where fn fetch_cold<'db>( &'db self, db: &'db C::DbView, + local_state: &LocalState, key: Id, ) -> Option>> { - let runtime = db.runtime(); let database_key_index = self.database_key_index(key); // Try to claim this query: if someone else has claimed it already, go back and start again. - let _claim_guard = self - .sync_map - .claim(db.as_salsa_database(), database_key_index)?; + let _claim_guard = + self.sync_map + .claim(db.as_salsa_database(), local_state, database_key_index)?; // Push the query on the stack. - let active_query = runtime.push_query(database_key_index); + let active_query = local_state.push_query(database_key_index); // Now that we've claimed the item, check again to see if there's a "hot" value. // This time we can do a *deep* verify. Because this can recurse, don't hold the arcswap guard. diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index f838d74..41e57ae 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,7 +2,7 @@ use arc_swap::Guard; use crate::{ key::DatabaseKeyIndex, - local_state::{ActiveQueryGuard, EdgeKind, QueryOrigin}, + local_state::{self, ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, runtime::StampedValue, storage::DatabaseGen, Id, Revision, Runtime, @@ -20,46 +20,51 @@ where key: Id, revision: Revision, ) -> bool { - let runtime = db.runtime(); - runtime.unwind_if_revision_cancelled(db); + local_state::attach(db.as_salsa_database(), |local_state| { + let runtime = db.runtime(); + local_state.unwind_if_revision_cancelled(db.as_salsa_database()); - loop { - let database_key_index = self.database_key_index(key); + loop { + let database_key_index = self.database_key_index(key); - tracing::debug!("{database_key_index:?}: maybe_changed_after(revision = {revision:?})"); + tracing::debug!( + "{database_key_index:?}: maybe_changed_after(revision = {revision:?})" + ); - // Check if we have a verified version: this is the hot path. - let memo_guard = self.memo_map.get(key); - if let Some(memo) = &memo_guard { - if self.shallow_verify_memo(db, runtime, database_key_index, memo) { - return memo.revisions.changed_at > revision; - } - drop(memo_guard); // release the arc-swap guard before cold path - if let Some(mcs) = self.maybe_changed_after_cold(db, key, revision) { - return mcs; + // Check if we have a verified version: this is the hot path. + let memo_guard = self.memo_map.get(key); + if let Some(memo) = &memo_guard { + if self.shallow_verify_memo(db, runtime, database_key_index, memo) { + return memo.revisions.changed_at > revision; + } + drop(memo_guard); // release the arc-swap guard before cold path + if let Some(mcs) = self.maybe_changed_after_cold(db, local_state, key, revision) + { + return mcs; + } else { + // We failed to claim, have to retry. + } } else { - // We failed to claim, have to retry. + // No memo? Assume has changed. + return true; } - } else { - // No memo? Assume has changed. - return true; } - } + }) } fn maybe_changed_after_cold<'db>( &'db self, db: &'db C::DbView, + local_state: &LocalState, key_index: Id, revision: Revision, ) -> Option { - let runtime = db.runtime(); let database_key_index = self.database_key_index(key_index); - let _claim_guard = self - .sync_map - .claim(db.as_salsa_database(), database_key_index)?; - let active_query = runtime.push_query(database_key_index); + let _claim_guard = + self.sync_map + .claim(db.as_salsa_database(), local_state, database_key_index)?; + let active_query = local_state.push_query(database_key_index); // Load the current memo, if any. Use a real arc, not an arc-swap guard, // since we may recurse. @@ -70,7 +75,7 @@ where tracing::debug!( "{database_key_index:?}: maybe_changed_after_cold, successful claim, \ - revision = {revision:?}, old_memo = {old_memo:#?}", + revision = {revision:?}, old_memo = {old_memo:#?}", ); // Check if the inputs are still valid and we can just compare `changed_at`. diff --git a/src/function/specify.rs b/src/function/specify.rs index a066403..97577d1 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,7 +1,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ - local_state::{QueryOrigin, QueryRevisions}, + local_state::{self, QueryOrigin, QueryRevisions}, storage::DatabaseGen, tracked_struct::TrackedStructInDb, Database, DatabaseKeyIndex, Id, @@ -13,97 +13,83 @@ impl IngredientImpl where C: Configuration, { - /// Specifies the value of the function for the given key. - /// This is a way to imperatively set the value of a function. - /// It only works if the key is a tracked struct created in the current query. - fn specify<'db>( - &'db self, - db: &'db C::DbView, - key: Id, - value: C::Output<'db>, - origin: impl Fn(DatabaseKeyIndex) -> QueryOrigin, - ) where - C::Input<'db>: TrackedStructInDb, - { - let runtime = db.runtime(); - - let (active_query_key, current_deps) = match runtime.active_query() { - Some(v) => v, - None => panic!("can only use `specify` inside a tracked function"), - }; - - // `specify` only works if the key is a tracked struct created in the current query. - // - // The reason is this. We want to ensure that the same result is reached regardless of - // the "path" that the user takes through the execution graph. - // If you permit values to be specified from other queries, you can have a situation like this: - // * Q0 creates the tracked struct T0 - // * Q1 specifies the value for F(T0) - // * Q2 invokes F(T0) - // * Q3 invokes Q1 and then Q2 - // * Q4 invokes Q2 and then Q1 - // - // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. - let database_key_index = >::database_key_index(db.as_salsa_database(), key); - let dependency_index = database_key_index.into(); - if !runtime.is_output_of_active_query(dependency_index) { - panic!("can only use `specify` on salsa structs created during the current tracked fn"); - } - - // Subtle: we treat the "input" to a set query as if it were - // volatile. - // - // The idea is this. You have the current query C that - // created the entity E, and it is setting the value F(E) of the function F. - // When some other query R reads the field F(E), in order to have obtained - // the entity E, it has to have executed the query C. - // - // This will have forced C to either: - // - // - not create E this time, in which case R shouldn't have it (some kind of leak has occurred) - // - assign a value to F(E), in which case `verified_at` will be the current revision and `changed_at` will be updated appropriately - // - NOT assign a value to F(E), in which case we need to re-execute the function (which typically panics). - // - // So, ruling out the case of a leak having occurred, that means that the reader R will either see: - // - // - a result that is verified in the current revision, because it was set, which will use the set value - // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) - - let revision = runtime.current_revision(); - let mut revisions = QueryRevisions { - changed_at: current_deps.changed_at, - durability: current_deps.durability, - origin: origin(active_query_key), - }; - - if let Some(old_memo) = self.memo_map.get(key) { - self.backdate_if_appropriate(&old_memo, &mut revisions, &value); - self.diff_outputs(db, database_key_index, &old_memo, &revisions); - } - - let memo = Memo { - value: Some(value), - verified_at: AtomicCell::new(revision), - revisions, - }; - - tracing::debug!("specify: about to add memo {:#?} for key {:?}", memo, key); - self.insert_memo(db, key, memo); - } - /// Specify the value for `key` *and* record that we did so. /// Used for explicit calls to `specify`, but not needed for pre-declared tracked struct fields. pub fn specify_and_record<'db>(&'db self, db: &'db C::DbView, key: Id, value: C::Output<'db>) where C::Input<'db>: TrackedStructInDb, { - self.specify(db, key, value, |database_key_index| { - QueryOrigin::Assigned(database_key_index) - }); + local_state::attach(db.as_salsa_database(), |state| { + let (active_query_key, current_deps) = match state.active_query() { + Some(v) => v, + None => panic!("can only use `specify` inside a tracked function"), + }; - // Record that the current query *specified* a value for this cell. - let database_key_index = self.database_key_index(key); - db.runtime().add_output(database_key_index.into()); + // `specify` only works if the key is a tracked struct created in the current query. + // + // The reason is this. We want to ensure that the same result is reached regardless of + // the "path" that the user takes through the execution graph. + // If you permit values to be specified from other queries, you can have a situation like this: + // * Q0 creates the tracked struct T0 + // * Q1 specifies the value for F(T0) + // * Q2 invokes F(T0) + // * Q3 invokes Q1 and then Q2 + // * Q4 invokes Q2 and then Q1 + // + // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. + let database_key_index = + >::database_key_index(db.as_salsa_database(), key); + let dependency_index = database_key_index.into(); + if !state.is_output_of_active_query(dependency_index) { + panic!( + "can only use `specify` on salsa structs created during the current tracked fn" + ); + } + + // Subtle: we treat the "input" to a set query as if it were + // volatile. + // + // The idea is this. You have the current query C that + // created the entity E, and it is setting the value F(E) of the function F. + // When some other query R reads the field F(E), in order to have obtained + // the entity E, it has to have executed the query C. + // + // This will have forced C to either: + // + // - not create E this time, in which case R shouldn't have it (some kind of leak has occurred) + // - assign a value to F(E), in which case `verified_at` will be the current revision and `changed_at` will be updated appropriately + // - NOT assign a value to F(E), in which case we need to re-execute the function (which typically panics). + // + // So, ruling out the case of a leak having occurred, that means that the reader R will either see: + // + // - a result that is verified in the current revision, because it was set, which will use the set value + // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) + + let revision = db.runtime().current_revision(); + let mut revisions = QueryRevisions { + changed_at: current_deps.changed_at, + durability: current_deps.durability, + origin: QueryOrigin::Assigned(active_query_key), + }; + + if let Some(old_memo) = self.memo_map.get(key) { + self.backdate_if_appropriate(&old_memo, &mut revisions, &value); + self.diff_outputs(db, database_key_index, &old_memo, &revisions); + } + + let memo = Memo { + value: Some(value), + verified_at: AtomicCell::new(revision), + revisions, + }; + + tracing::debug!("specify: about to add memo {:#?} for key {:?}", memo, key); + self.insert_memo(db, key, memo); + + // Record that the current query *specified* a value for this cell. + let database_key_index = self.database_key_index(key); + state.add_output(database_key_index.into()); + }) } /// Invoked when the query `executor` has been validated as having green inputs diff --git a/src/function/sync.rs b/src/function/sync.rs index 79b504f..f9b1d1f 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -3,7 +3,10 @@ use std::{ thread::ThreadId, }; -use crate::{hash::FxDashMap, key::DatabaseKeyIndex, runtime::WaitResult, Database, Id, Runtime}; +use crate::{ + hash::FxDashMap, key::DatabaseKeyIndex, local_state::LocalState, runtime::WaitResult, Database, + Id, Runtime, +}; #[derive(Default)] pub(super) struct SyncMap { @@ -22,6 +25,7 @@ impl SyncMap { pub(super) fn claim<'me>( &'me self, db: &'me dyn Database, + local_state: &LocalState, database_key_index: DatabaseKeyIndex, ) -> Option> { let runtime = db.runtime(); @@ -47,7 +51,7 @@ impl SyncMap { // not to gate future atomic reads. entry.get().anyone_waiting.store(true, Ordering::Relaxed); let other_id = entry.get().id; - runtime.block_on_or_unwind(db, database_key_index, other_id, entry); + runtime.block_on_or_unwind(db, local_state, database_key_index, other_id, entry); None } } diff --git a/src/input.rs b/src/input.rs index cd6e508..d41d3c4 100644 --- a/src/input.rs +++ b/src/input.rs @@ -17,7 +17,7 @@ use crate::{ id::{AsId, FromId}, ingredient::{fmt_index, Ingredient}, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::QueryOrigin, + local_state::{self, QueryOrigin}, plumbing::{Jar, Stamp}, runtime::Runtime, storage::IngredientIndex, @@ -154,19 +154,21 @@ impl IngredientImpl { id: C::Struct, field_index: usize, ) -> &'db C::Fields { - let field_ingredient_index = self.ingredient_index.successor(field_index); - let id = id.as_id(); - let value = self.struct_map.get(id); - let stamp = &value.stamps[field_index]; - db.runtime().report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(id), - }, - stamp.durability, - stamp.changed_at, - ); - &value.fields + local_state::attach(db, |state| { + let field_ingredient_index = self.ingredient_index.successor(field_index); + let id = id.as_id(); + let value = self.struct_map.get(id); + let stamp = &value.stamps[field_index]; + state.report_tracked_read( + DependencyIndex { + ingredient_index: field_ingredient_index, + key_index: Some(id), + }, + stamp.durability, + stamp.changed_at, + ); + &value.fields + }) } /// Peek at the field values without recording any read dependency. diff --git a/src/interned.rs b/src/interned.rs index 18be1eb..81c5794 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -9,7 +9,7 @@ use crate::durability::Durability; use crate::id::AsId; use crate::ingredient::fmt_index; use crate::key::DependencyIndex; -use crate::local_state::QueryOrigin; +use crate::local_state::{self, QueryOrigin}; use crate::plumbing::Jar; use crate::storage::IngredientIndex; use crate::{Database, DatabaseKeyIndex, Id}; @@ -136,44 +136,46 @@ where db: &'db dyn crate::Database, data: C::Data<'db>, ) -> C::Struct<'db> { - db.runtime().report_tracked_read( - DependencyIndex::for_table(self.ingredient_index), - Durability::MAX, - self.reset_at, - ); + local_state::attach(db, |state| { + state.report_tracked_read( + DependencyIndex::for_table(self.ingredient_index), + Durability::MAX, + self.reset_at, + ); - // Optimisation to only get read lock on the map if the data has already - // been interned. - let internal_data = unsafe { self.to_internal_data(data) }; - if let Some(guard) = self.key_map.get(&internal_data) { - let id = *guard; - drop(guard); - return self.interned_value(id); - } - - match self.key_map.entry(internal_data.clone()) { - // Data has been interned by a racing call, use that ID instead - dashmap::mapref::entry::Entry::Occupied(entry) => { - let id = *entry.get(); - drop(entry); - self.interned_value(id) + // Optimisation to only get read lock on the map if the data has already + // been interned. + let internal_data = unsafe { self.to_internal_data(data) }; + if let Some(guard) = self.key_map.get(&internal_data) { + let id = *guard; + drop(guard); + return self.interned_value(id); } - // We won any races so should intern the data - dashmap::mapref::entry::Entry::Vacant(entry) => { - let next_id = self.counter.fetch_add(1); - let next_id = crate::id::Id::from_u32(next_id); - let value = self.value_map.entry(next_id).or_insert(Alloc::new(Value { - id: next_id, - fields: internal_data, - })); - let value_raw = value.as_raw(); - drop(value); - entry.insert(next_id); - // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. - unsafe { C::struct_from_raw(value_raw) } + match self.key_map.entry(internal_data.clone()) { + // Data has been interned by a racing call, use that ID instead + dashmap::mapref::entry::Entry::Occupied(entry) => { + let id = *entry.get(); + drop(entry); + self.interned_value(id) + } + + // We won any races so should intern the data + dashmap::mapref::entry::Entry::Vacant(entry) => { + let next_id = self.counter.fetch_add(1); + let next_id = crate::id::Id::from_u32(next_id); + let value = self.value_map.entry(next_id).or_insert(Alloc::new(Value { + id: next_id, + fields: internal_data, + })); + let value_raw = value.as_raw(); + drop(value); + entry.insert(next_id); + // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. + unsafe { C::struct_from_raw(value_raw) } + } } - } + }) } pub fn interned_value(&self, id: Id) -> C::Struct<'_> { diff --git a/src/key.rs b/src/key.rs index e52631f..e3e43c0 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,4 +1,4 @@ -use crate::{cycle::CycleRecoveryStrategy, database, storage::IngredientIndex, Database, Id}; +use crate::{cycle::CycleRecoveryStrategy, local_state, storage::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. Fully ordered and @@ -57,7 +57,7 @@ impl DependencyIndex { impl std::fmt::Debug for DependencyIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - database::with_attached_database(|db| { + local_state::with_attached_database(|db| { let ingredient = db.lookup_ingredient(self.ingredient_index); ingredient.fmt_index(self.key_index, f) }) diff --git a/src/lib.rs b/src/lib.rs index 88700b2..7d1397a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ pub use self::revision::Revision; pub use self::runtime::Runtime; pub use self::storage::Storage; pub use self::update::Update; -pub use crate::database::with_attached_database; +pub use crate::local_state::with_attached_database; pub use salsa_macros::accumulator; pub use salsa_macros::db; pub use salsa_macros::input; @@ -79,9 +79,7 @@ pub mod plumbing { pub use crate::array::Array; pub use crate::cycle::Cycle; pub use crate::cycle::CycleRecoveryStrategy; - pub use crate::database::attach_database; pub use crate::database::current_revision; - pub use crate::database::with_attached_database; pub use crate::database::Database; pub use crate::function::should_backdate_value; pub use crate::id::AsId; @@ -91,6 +89,7 @@ pub mod plumbing { pub use crate::ingredient::Ingredient; pub use crate::ingredient::Jar; pub use crate::key::DatabaseKeyIndex; + pub use crate::local_state::with_attached_database; pub use crate::revision::Revision; pub use crate::runtime::stamp; pub use crate::runtime::Runtime; diff --git a/src/local_state.rs b/src/local_state.rs index cc0c839..6c5941a 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -5,13 +5,48 @@ use crate::durability::Durability; use crate::key::DatabaseKeyIndex; use crate::key::DependencyIndex; use crate::runtime::StampedValue; +use crate::storage::IngredientIndex; use crate::tracked_struct::Disambiguator; +use crate::Cancelled; use crate::Cycle; +use crate::Database; +use crate::Event; +use crate::EventKind; use crate::Revision; use crate::Runtime; +use std::cell::Cell; use std::cell::RefCell; +use std::ptr::NonNull; use std::sync::Arc; +thread_local! { + /// The thread-local state salsa requires for a given thread + static LOCAL_STATE: LocalState = const { LocalState::new() } +} + +/// Attach the database to the current thread and execute `op`. +/// Panics if a different database has already been attached. +pub(crate) fn attach(db: &DB, op: impl FnOnce(&LocalState) -> R) -> R +where + DB: ?Sized + Database, +{ + LOCAL_STATE.with(|state| state.attach(db.as_salsa_database(), || op(state))) +} + +/// Access the "attached" database. Returns `None` if no database is attached. +/// Databases are attached with `attach_database`. +pub fn with_attached_database(op: impl FnOnce(&dyn Database) -> R) -> Option { + LOCAL_STATE.with(|state| { + if let Some(db) = state.database.get() { + // SAFETY: We always attach the database in for the entire duration of a function, + // so it cannot become "unattached" while this function is running. + Some(op(unsafe { db.as_ref() })) + } else { + None + } + }) +} + /// State that is specific to a single execution thread. /// /// Internally, this type uses ref-cells. @@ -19,6 +54,9 @@ use std::sync::Arc; /// **Note also that all mutations to the database handle (and hence /// to the local-state) must be undone during unwinding.** pub(crate) struct LocalState { + /// Pointer to the currently attached database. + database: Cell>>, + /// Vector of active queries. /// /// This is normally `Some`, but it is set to `None` @@ -29,6 +67,282 @@ pub(crate) struct LocalState { query_stack: RefCell>>, } +impl LocalState { + const fn new() -> Self { + LocalState { + database: Cell::new(None), + query_stack: RefCell::new(Some(vec![])), + } + } + + fn attach(&self, db: &dyn Database, op: impl FnOnce() -> R) -> R { + struct DbGuard<'s> { + state: Option<&'s LocalState>, + } + + impl<'s> DbGuard<'s> { + fn new(state: &'s LocalState, db: &dyn Database) -> Self { + if let Some(current_db) = state.database.get() { + // Already attached? Assert that the database has not changed. + assert_eq!( + current_db, + NonNull::from(db), + "cannot change database mid-query", + ); + Self { state: None } + } else { + // Otherwise, set the database. + state.database.set(Some(NonNull::from(db))); + Self { state: Some(state) } + } + } + } + + impl Drop for DbGuard<'_> { + fn drop(&mut self) { + // Reset database to null if we did anything in `DbGuard::new`. + if let Some(state) = self.state { + state.database.set(None); + + // All stack frames should have been popped from the local stack. + assert!(state.query_stack.borrow().as_ref().unwrap().is_empty()); + } + } + } + + let _guard = DbGuard::new(self, db); + op() + } + + #[inline] + pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { + let mut query_stack = self.query_stack.borrow_mut(); + let query_stack = query_stack.as_mut().expect("local stack taken"); + query_stack.push(ActiveQuery::new(database_key_index)); + ActiveQueryGuard { + local_state: self, + database_key_index, + push_len: query_stack.len(), + } + } + + fn with_query_stack(&self, c: impl FnOnce(&mut Vec) -> R) -> R { + c(self + .query_stack + .borrow_mut() + .as_mut() + .expect("query stack taken")) + } + + fn query_in_progress(&self) -> bool { + self.with_query_stack(|stack| !stack.is_empty()) + } + + /// Returns the index of the active query along with its *current* durability/changed-at + /// information. As the query continues to execute, naturally, that information may change. + pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, StampedValue<()>)> { + self.with_query_stack(|stack| { + stack.last().map(|active_query| { + ( + active_query.database_key_index, + StampedValue { + value: (), + durability: active_query.durability, + changed_at: active_query.changed_at, + }, + ) + }) + }) + } + + /// Add an output to the current query's list of dependencies + pub(crate) fn add_output(&self, entity: DependencyIndex) { + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_output(entity) + } + }) + } + + /// Check whether `entity` is an output of the currently active query (if any) + pub(crate) fn is_output_of_active_query(&self, entity: DependencyIndex) -> bool { + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.is_output(entity) + } else { + false + } + }) + } + + /// Register that currently active query reads the given input + pub(crate) fn report_tracked_read( + &self, + input: DependencyIndex, + durability: Durability, + changed_at: Revision, + ) { + debug!( + "report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})", + input, durability, changed_at + ); + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_read(input, durability, changed_at); + + // We are a cycle participant: + // + // C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0 + // ^ ^ + // : | + // This edge -----+ | + // | + // | + // N0 + // + // In this case, the value we have just read from `Ci+1` + // is actually the cycle fallback value and not especially + // interesting. We unwind now with `CycleParticipant` to avoid + // executing the rest of our query function. This unwinding + // will be caught and our own fallback value will be used. + // + // Note that `Ci+1` may` have *other* callers who are not + // participants in the cycle (e.g., N0 in the graph above). + // They will not have the `cycle` marker set in their + // stack frames, so they will just read the fallback value + // from `Ci+1` and continue on their merry way. + if let Some(cycle) = &top_query.cycle { + cycle.clone().throw() + } + } + }) + } + + /// Register that the current query read an untracked value + /// + /// # Parameters + /// + /// * `current_revision`, the current revision + pub(crate) fn report_untracked_read(&self, current_revision: Revision) { + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_untracked_read(current_revision); + } + }) + } + + /// Update the top query on the stack to act as though it read a value + /// of durability `durability` which changed in `revision`. + // FIXME: Use or remove this. + #[allow(dead_code)] + pub(crate) fn report_synthetic_read(&self, durability: Durability, revision: Revision) { + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_synthetic_read(durability, revision); + } + }) + } + + /// Takes the query stack and returns it. This is used when + /// the current thread is blocking. The stack must be restored + /// with [`Self::restore_query_stack`] when the thread unblocks. + pub(crate) fn take_query_stack(&self) -> Vec { + assert!( + self.query_stack.borrow().is_some(), + "query stack already taken" + ); + self.query_stack.take().unwrap() + } + + /// Restores a query stack taken with [`Self::take_query_stack`] once + /// the thread unblocks. + pub(crate) fn restore_query_stack(&self, stack: Vec) { + assert!(self.query_stack.borrow().is_none(), "query stack not taken"); + self.query_stack.replace(Some(stack)); + } + + /// Called when the active queries creates an index from the + /// entity table with the index `entity_index`. Has the following effects: + /// + /// * Add a query read on `DatabaseKeyIndex::for_table(entity_index)` + /// * Identify a unique disambiguator for the hash within the current query, + /// adding the hash to the current query's disambiguator table. + /// * Returns a tuple of: + /// * the id of the current query + /// * the current dependencies (durability, changed_at) of current query + /// * the disambiguator index + #[track_caller] + pub(crate) fn disambiguate( + &self, + entity_index: IngredientIndex, + reset_at: Revision, + data_hash: u64, + ) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) { + assert!( + self.query_in_progress(), + "cannot create a tracked struct disambiguator outside of a tracked function" + ); + + self.report_tracked_read( + DependencyIndex::for_table(entity_index), + Durability::MAX, + reset_at, + ); + + self.with_query_stack(|stack| { + let top_query = stack.last_mut().unwrap(); + let disambiguator = top_query.disambiguate(data_hash); + ( + top_query.database_key_index, + StampedValue { + value: (), + durability: top_query.durability, + changed_at: top_query.changed_at, + }, + disambiguator, + ) + }) + } + + /// Starts unwinding the stack if the current revision is cancelled. + /// + /// This method can be called by query implementations that perform + /// potentially expensive computations, in order to speed up propagation of + /// cancellation. + /// + /// Cancellation will automatically be triggered by salsa on any query + /// invocation. + /// + /// This method should not be overridden by `Database` implementors. A + /// `salsa_event` is emitted when this method is called, so that should be + /// used instead. + pub(crate) fn unwind_if_revision_cancelled(&self, db: &dyn Database) { + let runtime = db.runtime(); + let thread_id = std::thread::current().id(); + db.salsa_event(Event { + thread_id, + + kind: EventKind::WillCheckCancellation, + }); + if runtime.load_cancellation_flag() { + db.salsa_event(Event { + thread_id, + kind: EventKind::WillCheckCancellation, + }); + self.unwind_cancelled(runtime); + } + } + + #[cold] + pub(crate) fn unwind_cancelled(&self, runtime: &Runtime) { + let current_revision = runtime.current_revision(); + self.report_untracked_read(current_revision); + Cancelled::PendingWrite.throw(); + } +} + +impl std::panic::RefUnwindSafe for LocalState {} + /// Summarizes "all the inputs that a query used" #[derive(Debug, Clone)] pub(crate) struct QueryRevisions { @@ -149,197 +463,6 @@ impl QueryEdges { } } -impl Default for LocalState { - fn default() -> Self { - LocalState { - query_stack: RefCell::new(Some(Vec::new())), - } - } -} - -impl LocalState { - #[inline] - pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { - let mut query_stack = self.query_stack.borrow_mut(); - let query_stack = query_stack.as_mut().expect("local stack taken"); - query_stack.push(ActiveQuery::new(database_key_index)); - ActiveQueryGuard { - local_state: self, - database_key_index, - push_len: query_stack.len(), - } - } - - fn with_query_stack(&self, c: impl FnOnce(&mut Vec) -> R) -> R { - c(self - .query_stack - .borrow_mut() - .as_mut() - .expect("query stack taken")) - } - - pub(crate) fn query_in_progress(&self) -> bool { - self.with_query_stack(|stack| !stack.is_empty()) - } - - /// Returns the index of the active query along with its *current* durability/changed-at - /// information. As the query continues to execute, naturally, that information may change. - pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, StampedValue<()>)> { - self.with_query_stack(|stack| { - stack.last().map(|active_query| { - ( - active_query.database_key_index, - StampedValue { - value: (), - durability: active_query.durability, - changed_at: active_query.changed_at, - }, - ) - }) - }) - } - - /// Add an output to the current query's list of dependencies - pub(crate) fn add_output(&self, entity: DependencyIndex) { - self.with_query_stack(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_output(entity) - } - }) - } - - /// Check whether `entity` is an output of the currently active query (if any) - pub(crate) fn is_output(&self, entity: DependencyIndex) -> bool { - self.with_query_stack(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.is_output(entity) - } else { - false - } - }) - } - - /// Register that currently active query reads the given input - pub(crate) fn report_tracked_read( - &self, - input: DependencyIndex, - durability: Durability, - changed_at: Revision, - ) { - debug!( - "report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})", - input, durability, changed_at - ); - self.with_query_stack(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_read(input, durability, changed_at); - - // We are a cycle participant: - // - // C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0 - // ^ ^ - // : | - // This edge -----+ | - // | - // | - // N0 - // - // In this case, the value we have just read from `Ci+1` - // is actually the cycle fallback value and not especially - // interesting. We unwind now with `CycleParticipant` to avoid - // executing the rest of our query function. This unwinding - // will be caught and our own fallback value will be used. - // - // Note that `Ci+1` may` have *other* callers who are not - // participants in the cycle (e.g., N0 in the graph above). - // They will not have the `cycle` marker set in their - // stack frames, so they will just read the fallback value - // from `Ci+1` and continue on their merry way. - if let Some(cycle) = &top_query.cycle { - cycle.clone().throw() - } - } - }) - } - - /// Register that the current query read an untracked value - /// - /// # Parameters - /// - /// * `current_revision`, the current revision - pub(crate) fn report_untracked_read(&self, current_revision: Revision) { - self.with_query_stack(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_untracked_read(current_revision); - } - }) - } - - /// Update the top query on the stack to act as though it read a value - /// of durability `durability` which changed in `revision`. - // FIXME: Use or remove this. - #[allow(dead_code)] - pub(crate) fn report_synthetic_read(&self, durability: Durability, revision: Revision) { - self.with_query_stack(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_synthetic_read(durability, revision); - } - }) - } - - /// Takes the query stack and returns it. This is used when - /// the current thread is blocking. The stack must be restored - /// with [`Self::restore_query_stack`] when the thread unblocks. - pub(crate) fn take_query_stack(&self) -> Vec { - assert!( - self.query_stack.borrow().is_some(), - "query stack already taken" - ); - self.query_stack.take().unwrap() - } - - /// Restores a query stack taken with [`Self::take_query_stack`] once - /// the thread unblocks. - pub(crate) fn restore_query_stack(&self, stack: Vec) { - assert!(self.query_stack.borrow().is_none(), "query stack not taken"); - self.query_stack.replace(Some(stack)); - } - - /// Given the hash of the id fields of a tracked struct, returns: - /// - /// * database-key-index of currently active query - /// * durability/changed-at info for the inputs read thus far by said query - /// * a `Disambiguator` that uniquely identifies the tracked struct about to be created - /// - /// The disambiguator is basically an integer that increments each time - /// a tracked struct with this `data_hash` is created. - #[track_caller] - pub(crate) fn disambiguate( - &self, - data_hash: u64, - ) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) { - assert!( - self.query_in_progress(), - "cannot create a tracked struct disambiguator outside of a tracked function" - ); - self.with_query_stack(|stack| { - let top_query = stack.last_mut().unwrap(); - let disambiguator = top_query.disambiguate(data_hash); - ( - top_query.database_key_index, - StampedValue { - value: (), - durability: top_query.durability, - changed_at: top_query.changed_at, - }, - disambiguator, - ) - }) - } -} - -impl std::panic::RefUnwindSafe for LocalState {} - /// When a query is pushed onto the `active_query` stack, this guard /// is returned to represent its slot. The guard can be used to pop /// the query from the stack -- in the case of unwinding, the guard's diff --git a/src/runtime.rs b/src/runtime.rs index 1439f00..d47d519 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -12,22 +12,16 @@ use crate::{ cycle::CycleRecoveryStrategy, durability::Durability, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::{self, ActiveQueryGuard, EdgeKind}, + local_state::{EdgeKind, LocalState}, revision::AtomicRevision, - storage::IngredientIndex, Cancelled, Cycle, Database, Event, EventKind, Revision, }; use self::dependency_graph::DependencyGraph; -use super::tracked_struct::Disambiguator; - mod dependency_graph; pub struct Runtime { - /// Local state that is specific to this runtime (thread). - local_state: local_state::LocalState, - /// Stores the next id to use for a snapshotted runtime (starts at 1). next_id: AtomicUsize, @@ -91,7 +85,6 @@ impl StampedValue { impl Default for Runtime { fn default() -> Self { Runtime { - local_state: Default::default(), revisions: (0..Durability::LEN) .map(|_| AtomicRevision::start()) .collect(), @@ -119,35 +112,10 @@ impl Runtime { self.revisions[0].load() } - /// Returns the index of the active query along with its *current* durability/changed-at - /// information. As the query continues to execute, naturally, that information may change. - pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, StampedValue<()>)> { - self.local_state.active_query() - } - pub(crate) fn empty_dependencies(&self) -> Arc<[(EdgeKind, DependencyIndex)]> { self.empty_dependencies.clone() } - pub(crate) fn report_tracked_read( - &self, - key_index: DependencyIndex, - durability: Durability, - changed_at: Revision, - ) { - self.local_state - .report_tracked_read(key_index, durability, changed_at) - } - - /// Reports that the query depends on some state unknown to salsa. - /// - /// Queries which report untracked reads will be re-executed in the next - /// revision. - pub fn report_untracked_read(&self) { - self.local_state - .report_untracked_read(self.current_revision()); - } - /// Reports that an input with durability `durability` changed. /// This will update the 'last changed at' values for every durability /// less than or equal to `durability` to the current revision. @@ -158,41 +126,6 @@ impl Runtime { } } - /// Adds `key` to the list of output created by the current query - /// (if not already present). - pub(crate) fn add_output(&self, key: DependencyIndex) { - self.local_state.add_output(key); - } - - /// Check whether `entity` is contained the list of outputs written by the current query. - pub(super) fn is_output_of_active_query(&self, entity: DependencyIndex) -> bool { - self.local_state.is_output(entity) - } - - /// Called when the active queries creates an index from the - /// entity table with the index `entity_index`. Has the following effects: - /// - /// * Add a query read on `DatabaseKeyIndex::for_table(entity_index)` - /// * Identify a unique disambiguator for the hash within the current query, - /// adding the hash to the current query's disambiguator table. - /// * Returns a tuple of: - /// * the id of the current query - /// * the current dependencies (durability, changed_at) of current query - /// * the disambiguator index - pub(crate) fn disambiguate_entity( - &self, - entity_index: IngredientIndex, - reset_at: Revision, - data_hash: u64, - ) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) { - self.report_tracked_read( - DependencyIndex::for_table(entity_index), - Durability::MAX, - reset_at, - ); - self.local_state.disambiguate(data_hash) - } - /// The revision in which values with durability `d` may have last /// changed. For D0, this is just the current revision. But for /// higher levels of durability, this value may lag behind the @@ -205,38 +138,8 @@ impl Runtime { self.revisions[d.index()].load() } - /// Starts unwinding the stack if the current revision is cancelled. - /// - /// This method can be called by query implementations that perform - /// potentially expensive computations, in order to speed up propagation of - /// cancellation. - /// - /// Cancellation will automatically be triggered by salsa on any query - /// invocation. - /// - /// This method should not be overridden by `Database` implementors. A - /// `salsa_event` is emitted when this method is called, so that should be - /// used instead. - pub(crate) fn unwind_if_revision_cancelled(&self, db: &DB) { - let thread_id = std::thread::current().id(); - db.salsa_event(Event { - thread_id, - - kind: EventKind::WillCheckCancellation, - }); - if self.revision_canceled.load() { - db.salsa_event(Event { - thread_id, - kind: EventKind::WillCheckCancellation, - }); - self.unwind_cancelled(); - } - } - - #[cold] - pub(crate) fn unwind_cancelled(&self) { - self.report_untracked_read(); - Cancelled::PendingWrite.throw(); + pub(crate) fn load_cancellation_flag(&self) -> bool { + self.revision_canceled.load() } pub(crate) fn set_cancellation_flag(&self) { @@ -255,11 +158,6 @@ impl Runtime { r_new } - #[inline] - pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { - self.local_state.push_query(database_key_index) - } - /// Block until `other_id` completes executing `database_key`; /// panic or unwind in the case of a cycle. /// @@ -285,6 +183,7 @@ impl Runtime { pub(crate) fn block_on_or_unwind( &self, db: &dyn Database, + local_state: &LocalState, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, @@ -293,7 +192,7 @@ impl Runtime { let thread_id = std::thread::current().id(); if dg.depends_on(other_id, thread_id) { - self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id); + self.unblock_cycle_and_maybe_throw(db, local_state, &mut dg, database_key, other_id); // If the above fn returns, then (via cycle recovery) it has unblocked the // cycle, so we can continue. @@ -308,7 +207,7 @@ impl Runtime { }, }); - let stack = self.local_state.take_query_stack(); + let stack = local_state.take_query_stack(); let (stack, result) = DependencyGraph::block_on( dg, @@ -319,7 +218,7 @@ impl Runtime { query_mutex_guard, ); - self.local_state.restore_query_stack(stack); + local_state.restore_query_stack(stack); match result { WaitResult::Completed => (), @@ -344,6 +243,7 @@ impl Runtime { fn unblock_cycle_and_maybe_throw( &self, db: &dyn Database, + local_state: &LocalState, dg: &mut DependencyGraph, database_key_index: DatabaseKeyIndex, to_id: ThreadId, @@ -353,7 +253,7 @@ impl Runtime { database_key_index ); - let mut from_stack = self.local_state.take_query_stack(); + let mut from_stack = local_state.take_query_stack(); let from_id = std::thread::current().id(); // Make a "dummy stack frame". As we iterate through the cycle, we will collect the @@ -426,7 +326,7 @@ impl Runtime { let (me_recovered, others_recovered) = dg.maybe_unblock_runtimes_in_cycle(from_id, &from_stack, database_key_index, to_id); - self.local_state.restore_query_stack(from_stack); + local_state.restore_query_stack(from_stack); if me_recovered { // If the current thread has recovery, we want to throw diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index b768075..5fb2420 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -11,7 +11,7 @@ use crate::{ ingredient::{fmt_index, Ingredient, Jar}, ingredient_list::IngredientList, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::QueryOrigin, + local_state::{self, QueryOrigin}, runtime::Runtime, salsa_struct::SalsaStructInDb, storage::IngredientIndex, @@ -291,83 +291,84 @@ where db: &'db dyn Database, fields: C::Fields<'db>, ) -> C::Struct<'db> { - let data_hash = crate::hash::hash(&C::id_fields(&fields)); + local_state::attach(db, |local_state| { + let data_hash = crate::hash::hash(&C::id_fields(&fields)); - let runtime = db.runtime(); - let (query_key, current_deps, disambiguator) = - runtime.disambiguate_entity(self.ingredient_index, Revision::start(), data_hash); + let (query_key, current_deps, disambiguator) = + local_state.disambiguate(self.ingredient_index, Revision::start(), data_hash); - let entity_key = KeyStruct { - query_key, - disambiguator, - data_hash, - }; + let entity_key = KeyStruct { + query_key, + disambiguator, + data_hash, + }; - let (id, new_id) = self.intern(entity_key); - runtime.add_output(self.database_key_index(id).into()); + let (id, new_id) = self.intern(entity_key); + local_state.add_output(self.database_key_index(id).into()); - let current_revision = runtime.current_revision(); - if new_id { - // This is a new tracked struct, so create an entry in the struct map. + let current_revision = db.runtime().current_revision(); + if new_id { + // This is a new tracked struct, so create an entry in the struct map. - self.struct_map.insert( - runtime, - Value { - id, - key: entity_key, - struct_ingredient_index: self.ingredient_index, - created_at: current_revision, - durability: current_deps.durability, - fields: unsafe { self.to_static(fields) }, - revisions: C::new_revisions(current_deps.changed_at), - }, - ) - } else { - // The struct already exists in the intern map. - // Note that we assume there is at most one executing copy of - // the current query at a time, which implies that the - // struct must exist in `self.struct_map` already - // (if the same query could execute twice in parallel, - // then it would potentially create the same struct twice in parallel, - // which means the interned key could exist but `struct_map` not yet have - // been updated). + self.struct_map.insert( + current_revision, + Value { + id, + key: entity_key, + struct_ingredient_index: self.ingredient_index, + created_at: current_revision, + durability: current_deps.durability, + fields: unsafe { self.to_static(fields) }, + revisions: C::new_revisions(current_deps.changed_at), + }, + ) + } else { + // The struct already exists in the intern map. + // Note that we assume there is at most one executing copy of + // the current query at a time, which implies that the + // struct must exist in `self.struct_map` already + // (if the same query could execute twice in parallel, + // then it would potentially create the same struct twice in parallel, + // which means the interned key could exist but `struct_map` not yet have + // been updated). - match self.struct_map.update(runtime, id) { - Update::Current(r) => { - // All inputs up to this point were previously - // observed to be green and this struct was already - // verified. Therefore, the durability ought not to have - // changed (nor the field values, but the user could've - // done something stupid, so we can't *assert* this is true). - assert!(C::deref_struct(r).durability == current_deps.durability); + match self.struct_map.update(current_revision, id) { + Update::Current(r) => { + // All inputs up to this point were previously + // observed to be green and this struct was already + // verified. Therefore, the durability ought not to have + // changed (nor the field values, but the user could've + // done something stupid, so we can't *assert* this is true). + assert!(C::deref_struct(r).durability == current_deps.durability); - r - } - Update::Outdated(mut data_ref) => { - let data = &mut *data_ref; - - // SAFETY: We assert that the pointer to `data.revisions` - // is a pointer into the database referencing a value - // from a previous revision. As such, it continues to meet - // its validity invariant and any owned content also continues - // to meet its safety invariant. - unsafe { - C::update_fields( - current_revision, - &mut data.revisions, - self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)), - fields, - ); + r } - if current_deps.durability < data.durability { - data.revisions = C::new_revisions(current_revision); + Update::Outdated(mut data_ref) => { + let data = &mut *data_ref; + + // SAFETY: We assert that the pointer to `data.revisions` + // is a pointer into the database referencing a value + // from a previous revision. As such, it continues to meet + // its validity invariant and any owned content also continues + // to meet its safety invariant. + unsafe { + C::update_fields( + current_revision, + &mut data.revisions, + self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)), + fields, + ); + } + if current_deps.durability < data.durability { + data.revisions = C::new_revisions(current_revision); + } + data.durability = current_deps.durability; + data.created_at = current_revision; + data_ref.freeze() } - data.durability = current_deps.durability; - data.created_at = current_revision; - data_ref.freeze() } } - } + }) } /// Given the id of a tracked struct created in this revision, @@ -377,7 +378,8 @@ where /// /// If the struct has not been created in this revision. pub fn lookup_struct<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> { - self.struct_map.get(runtime, id) + let current_revision = runtime.current_revision(); + self.struct_map.get(current_revision, id) } /// Deletes the given entities. This is used after a query `Q` executes and we can compare @@ -507,21 +509,26 @@ where /// Access to this value field. /// Note that this function returns the entire tuple of value fields. /// The caller is responible for selecting the appropriate element. - pub fn field<'db>(&'db self, db: &'db dyn Database, field_index: usize) -> &'db C::Fields<'db> { - let runtime = db.runtime(); - let field_ingredient_index = self.struct_ingredient_index.successor(field_index); - let changed_at = self.revisions[field_index]; + pub fn field<'db>( + &'db self, + db: &dyn crate::Database, + field_index: usize, + ) -> &'db C::Fields<'db> { + local_state::attach(db, |local_state| { + let field_ingredient_index = self.struct_ingredient_index.successor(field_index); + let changed_at = self.revisions[field_index]; - runtime.report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(self.id.as_id()), - }, - self.durability, - changed_at, - ); + local_state.report_tracked_read( + DependencyIndex { + ingredient_index: field_ingredient_index, + key_index: Some(self.id.as_id()), + }, + self.durability, + changed_at, + ); - unsafe { self.to_self_ref(&self.fields) } + unsafe { self.to_self_ref(&self.fields) } + }) } unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> { diff --git a/src/tracked_struct/struct_map.rs b/src/tracked_struct/struct_map.rs index 9f982a3..4779a27 100644 --- a/src/tracked_struct/struct_map.rs +++ b/src/tracked_struct/struct_map.rs @@ -6,7 +6,7 @@ use std::{ use crossbeam::queue::SegQueue; use dashmap::mapref::one::RefMut; -use crate::{alloc::Alloc, hash::FxDashMap, Id, Runtime}; +use crate::{alloc::Alloc, hash::FxDashMap, Id, Revision, Runtime}; use super::{Configuration, KeyStruct, Value}; @@ -80,8 +80,8 @@ where /// /// * If value with same `value.id` is already present in the map. /// * If value not created in current revision. - pub fn insert<'db>(&'db self, runtime: &'db Runtime, value: Value) -> C::Struct<'db> { - assert_eq!(value.created_at, runtime.current_revision()); + pub fn insert<'db>(&'db self, current_revision: Revision, value: Value) -> C::Struct<'db> { + assert_eq!(value.created_at, current_revision); let id = value.id; let boxed_value = Alloc::new(value); @@ -119,12 +119,9 @@ where /// /// * If the value is not present in the map. /// * If the value is already updated in this revision. - pub fn update<'db>(&'db self, runtime: &'db Runtime, id: Id) -> Update<'db, C> { + pub fn update<'db>(&'db self, current_revision: Revision, id: Id) -> Update<'db, C> { let mut data = self.map.get_mut(&id).unwrap(); - // Never update a struct twice in the same revision. - let current_revision = runtime.current_revision(); - // UNSAFE: We never permit `&`-access in the current revision until data.created_at // has been updated to the current revision (which we check below). let data_ref = unsafe { data.as_mut() }; @@ -154,7 +151,7 @@ where // code cannot violate that `&`-reference. if data_ref.created_at == current_revision { drop(data); - return Update::Current(Self::get_from_map(&self.map, runtime, id)); + return Update::Current(Self::get_from_map(&self.map, current_revision, id)); } data_ref.created_at = current_revision; @@ -167,8 +164,8 @@ where /// /// * If the value is not present in the map. /// * If the value has not been updated in this revision. - pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> { - Self::get_from_map(&self.map, runtime, id) + pub fn get<'db>(&'db self, current_revision: Revision, id: Id) -> C::Struct<'db> { + Self::get_from_map(&self.map, current_revision, id) } /// Helper function, provides shared functionality for [`StructMapView`][] @@ -179,7 +176,7 @@ where /// * If the value has not been updated in this revision. fn get_from_map<'db>( map: &'db FxDashMap>>, - runtime: &'db Runtime, + current_revision: Revision, id: Id, ) -> C::Struct<'db> { let data = map.get(&id).unwrap(); @@ -190,7 +187,6 @@ where // Before we drop the lock, check that the value has // been updated in this revision. This is what allows us to return a `` - let current_revision = runtime.current_revision(); let created_at = data_ref.created_at; assert!( created_at == current_revision, @@ -235,8 +231,8 @@ where /// /// * If the value is not present in the map. /// * If the value has not been updated in this revision. - pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> { - StructMap::get_from_map(&self.map, runtime, id) + pub fn get<'db>(&'db self, current_revision: Revision, id: Id) -> C::Struct<'db> { + StructMap::get_from_map(&self.map, current_revision, id) } } diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 73c5dd6..a42db5b 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,5 +1,6 @@ use crate::{ - id::AsId, ingredient::Ingredient, key::DependencyIndex, storage::IngredientIndex, Database, Id, + id::AsId, ingredient::Ingredient, key::DependencyIndex, local_state, storage::IngredientIndex, + Database, Id, }; use super::{struct_map::StructMapView, Configuration}; @@ -46,21 +47,23 @@ where /// Note that this function returns the entire tuple of value fields. /// The caller is responible for selecting the appropriate element. pub fn field<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { - let runtime = db.runtime(); - let data = self.struct_map.get(runtime, id); - let data = C::deref_struct(data); - let changed_at = data.revisions[self.field_index]; + local_state::attach(db, |local_state| { + let current_revision = db.runtime().current_revision(); + let data = self.struct_map.get(current_revision, id); + let data = C::deref_struct(data); + let changed_at = data.revisions[self.field_index]; - runtime.report_tracked_read( - DependencyIndex { - ingredient_index: self.ingredient_index, - key_index: Some(id.as_id()), - }, - data.durability, - changed_at, - ); + local_state.report_tracked_read( + DependencyIndex { + ingredient_index: self.ingredient_index, + key_index: Some(id.as_id()), + }, + data.durability, + changed_at, + ); - unsafe { self.to_self_ref(&data.fields) } + unsafe { self.to_self_ref(&data.fields) } + }) } } @@ -82,9 +85,9 @@ where input: Option, revision: crate::Revision, ) -> bool { - let runtime = db.runtime(); + let current_revision = db.runtime().current_revision(); let id = input.unwrap(); - let data = self.struct_map.get(runtime, id); + let data = self.struct_map.get(current_revision, id); let data = C::deref_struct(data); let field_changed_at = data.revisions[self.field_index]; field_changed_at > revision