move local state into thread local

Niko Matsakis 2024-07-23 07:52:52 -04:00
parent 74ef66dbb4
commit 21af3a2009
18 changed files with 726 additions and 747 deletions


@@ -265,26 +265,24 @@ macro_rules! setup_tracked_fn {
} }
}
$zalsa::attach_database($db, || {
let result = $zalsa::macro_if! {
if $needs_interner {
{
let key = $Configuration::intern_ingredient($db).intern_id($db.as_salsa_database(), ($($input_id),*));
$Configuration::fn_ingredient($db).fetch($db, key)
}
} else {
$Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*)))
}
};
$zalsa::macro_if! {
if $return_ref {
result
} else {
<$output_ty as std::clone::Clone>::clone(result)
let result = $zalsa::macro_if! {
if $needs_interner {
{
let key = $Configuration::intern_ingredient($db).intern_id($db.as_salsa_database(), ($($input_id),*));
$Configuration::fn_ingredient($db).fetch($db, key)
}
} else {
$Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*)))
}
})
};
$zalsa::macro_if! {
if $return_ref {
result
} else {
<$output_ty as std::clone::Clone>::clone(result)
}
}
}
};
}


@@ -10,9 +10,9 @@ use crate::{
hash::FxDashMap,
ingredient::{fmt_index, Ingredient, Jar},
key::DependencyIndex,
local_state::QueryOrigin,
local_state::{self, LocalState, QueryOrigin},
storage::IngredientIndex,
Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, Runtime,
Database, DatabaseKeyIndex, Event, EventKind, Id, Revision,
};
pub trait Accumulator: Clone + Debug + Send + Sync + 'static + Sized {
@@ -79,44 +79,47 @@ impl<A: Accumulator> IngredientImpl<A> {
}
pub fn push(&self, db: &dyn crate::Database, value: A) {
let runtime = db.runtime();
let current_revision = runtime.current_revision();
let (active_query, _) = match runtime.active_query() {
Some(pair) => pair,
None => {
panic!("cannot accumulate values outside of an active query")
local_state::attach(db, |state| {
let runtime = db.runtime();
let current_revision = runtime.current_revision();
let (active_query, _) = match state.active_query() {
Some(pair) => pair,
None => {
panic!("cannot accumulate values outside of an active query")
}
};
let mut accumulated_values =
self.map.entry(active_query).or_insert(AccumulatedValues {
values: vec![],
produced_at: current_revision,
});
// When we call `push` in a query, we will add the accumulator to the output of the query.
// If we find here that this accumulator is not the output of the query,
// we can say that the accumulated values we stored for this query are out of date.
if !state.is_output_of_active_query(self.dependency_index()) {
accumulated_values.values.truncate(0);
accumulated_values.produced_at = current_revision;
}
};
let mut accumulated_values = self.map.entry(active_query).or_insert(AccumulatedValues {
values: vec![],
produced_at: current_revision,
});
// When we call `push` in a query, we will add the accumulator to the output of the query.
// If we find here that this accumulator is not the output of the query,
// we can say that the accumulated values we stored for this query are out of date.
if !runtime.is_output_of_active_query(self.dependency_index()) {
accumulated_values.values.truncate(0);
accumulated_values.produced_at = current_revision;
}
runtime.add_output(self.dependency_index());
accumulated_values.values.push(value);
state.add_output(self.dependency_index());
accumulated_values.values.push(value);
})
}
pub(crate) fn produced_by(
&self,
runtime: &Runtime,
current_revision: Revision,
local_state: &LocalState,
query: DatabaseKeyIndex,
output: &mut Vec<A>,
) {
let current_revision = runtime.current_revision();
if let Some(v) = self.map.get(&query) {
// FIXME: We don't currently have a good way to identify the value that was read.
// You can't report it as a tracked read of `query`, because the return value of query is not being read here --
// instead it is the set of values accumulated by `query`.
runtime.report_untracked_read();
local_state.report_untracked_read(current_revision);
let AccumulatedValues {
values,

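The `push` hunk above keys accumulated values by the active query and stamps them with the revision in which they were produced; if the accumulator is not currently recorded as an output of the active query, the previously stored values are discarded before the new one is appended. A minimal, self-contained sketch of that rule, using toy types (a plain `HashMap`, string values, and an explicit flag) rather than salsa's real `IngredientImpl`, `LocalState`, or `DatabaseKeyIndex`:

use std::collections::HashMap;

/// Values accumulated by one query, stamped with the revision they were produced in
/// (mirrors the `AccumulatedValues` struct in the diff, with `String` as the toy value type).
struct AccumulatedValues {
    values: Vec<String>,
    produced_at: u64,
}

/// Toy accumulator map: query name -> accumulated values.
/// The real code keys a lock-free map by `DatabaseKeyIndex` instead.
struct Accumulator {
    map: HashMap<&'static str, AccumulatedValues>,
}

impl Accumulator {
    fn push(
        &mut self,
        active_query: &'static str,
        current_revision: u64,
        is_output_of_active_query: bool,
        value: String,
    ) {
        let entry = self.map.entry(active_query).or_insert(AccumulatedValues {
            values: vec![],
            produced_at: current_revision,
        });
        // If this accumulator is not (yet) recorded as an output of the active query,
        // whatever we stored for that query came from an older execution and is out of
        // date, so discard it before accumulating the new value.
        if !is_output_of_active_query {
            entry.values.clear();
            entry.produced_at = current_revision;
        }
        entry.values.push(value);
    }
}

fn main() {
    let mut acc = Accumulator { map: HashMap::new() };
    // First push in this execution: the accumulator is not yet an output of the query,
    // so any stale values would be cleared here.
    acc.push("parse", 1, false, "diagnostic A".to_string());
    // Later pushes in the same execution see the accumulator as an output already.
    acc.push("parse", 1, true, "diagnostic B".to_string());
    assert_eq!(acc.map["parse"].values.len(), 2);
    assert_eq!(acc.map["parse"].produced_at, 1);
}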

@@ -1,4 +1,4 @@
use crate::{database, key::DatabaseKeyIndex, Database};
use crate::{key::DatabaseKeyIndex, local_state, Database};
use std::{panic::AssertUnwindSafe, sync::Arc};
/// Captures the participants of a cycle that occurred when executing a query.
@@ -74,7 +74,7 @@ impl Cycle {
impl std::fmt::Debug for Cycle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
database::with_attached_database(|db| {
local_state::with_attached_database(|db| {
f.debug_struct("UnexpectedCycle")
.field("all_participants", &self.all_participants(db))
.field("unexpected_participants", &self.unexpected_participants(db))


@@ -1,6 +1,4 @@
use std::{cell::Cell, ptr::NonNull};
use crate::{storage::DatabaseGen, Durability, Event, Revision};
use crate::{local_state, storage::DatabaseGen, Durability, Event, Revision};
#[salsa_macros::db]
pub trait Database: DatabaseGen {
@@ -31,7 +29,10 @@ pub trait Database: DatabaseGen {
/// Queries which report untracked reads will be re-executed in the next
/// revision.
fn report_untracked_read(&self) {
self.runtime().report_untracked_read();
let db = self.as_salsa_database();
local_state::attach(db, |state| {
state.report_untracked_read(db.runtime().current_revision())
})
}
/// Execute `op` with the database in thread-local storage for debug print-outs.
@@ -39,73 +40,7 @@ pub trait Database: DatabaseGen {
where
Self: Sized,
{
attach_database(self, || op(self))
}
}
thread_local! {
static DATABASE: Cell<AttachedDatabase> = const { Cell::new(AttachedDatabase::null()) };
}
/// Access the "attached" database. Returns `None` if no database is attached.
/// Databases are attached with `attach_database`.
pub fn with_attached_database<R>(op: impl FnOnce(&dyn Database) -> R) -> Option<R> {
// SAFETY: We always attach the database for the entire duration of a function,
// so it cannot become "unattached" while this function is running.
let db = DATABASE.get();
Some(op(unsafe { db.ptr?.as_ref() }))
}
/// Attach database and returns a guard that will un-attach the database when dropped.
/// Has no effect if a database is already attached.
pub fn attach_database<Db: ?Sized + Database, R>(db: &Db, op: impl FnOnce() -> R) -> R {
let _guard = AttachedDb::new(db);
op()
}
#[derive(Copy, Clone, PartialEq, Eq)]
struct AttachedDatabase {
ptr: Option<NonNull<dyn Database>>,
}
impl AttachedDatabase {
pub const fn null() -> Self {
Self { ptr: None }
}
pub fn from<Db: ?Sized + Database>(db: &Db) -> Self {
unsafe {
let db: *const dyn Database = db.as_salsa_database();
Self {
ptr: Some(NonNull::new_unchecked(db as *mut dyn Database)),
}
}
}
}
struct AttachedDb<'db, Db: ?Sized + Database> {
db: &'db Db,
previous: AttachedDatabase,
}
impl<'db, Db: ?Sized + Database> AttachedDb<'db, Db> {
pub fn new(db: &'db Db) -> Self {
let previous = DATABASE.replace(AttachedDatabase::from(db));
AttachedDb { db, previous }
}
}
impl<Db: ?Sized + Database> Drop for AttachedDb<'_, Db> {
fn drop(&mut self) {
DATABASE.set(self.previous);
}
}
impl<Db: ?Sized + Database> std::ops::Deref for AttachedDb<'_, Db> {
type Target = Db;
fn deref(&self) -> &Db {
self.db
local_state::attach(self, |_state| op(self))
}
}

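The block removed above is the thread-local attachment machinery (`AttachedDatabase`, `AttachedDb`) that this commit relocates into `local_state`. The underlying pattern is an RAII guard around a thread-local `Cell` holding a type-erased pointer: attach on entry, restore the previous value on drop so unwinding cannot leave a dangling attachment. A condensed, self-contained sketch of that pattern with a stand-in `Database` trait; the save/restore behaviour mirrors the removed guard, whereas the new `LocalState::attach` instead asserts that the database cannot change mid-query:

use std::cell::Cell;
use std::ptr::NonNull;

// Stand-in for salsa's object-safe `dyn Database` trait.
trait Database {
    fn name(&self) -> &str;
}

thread_local! {
    // The database currently attached to this thread, if any.
    static ATTACHED: Cell<Option<NonNull<dyn Database>>> = const { Cell::new(None) };
}

/// RAII guard: attaches `db` for its lifetime and restores the previous value on drop,
/// so the attachment is undone even if `op` unwinds.
struct AttachGuard {
    previous: Option<NonNull<dyn Database>>,
}

impl AttachGuard {
    fn new(db: &dyn Database) -> Self {
        let previous = ATTACHED.replace(Some(NonNull::from(db)));
        AttachGuard { previous }
    }
}

impl Drop for AttachGuard {
    fn drop(&mut self) {
        ATTACHED.set(self.previous);
    }
}

fn attach<R>(db: &dyn Database, op: impl FnOnce() -> R) -> R {
    let _guard = AttachGuard::new(db);
    op()
}

fn with_attached<R>(op: impl FnOnce(&dyn Database) -> R) -> Option<R> {
    // SAFETY: the pointer is only stored while `attach` is on the stack,
    // so it is valid for the duration of `op`.
    ATTACHED.get().map(|ptr| op(unsafe { ptr.as_ref() }))
}

struct MyDb;
impl Database for MyDb {
    fn name(&self) -> &str {
        "MyDb"
    }
}

fn main() {
    assert!(with_attached(|_| ()).is_none());
    let db = MyDb;
    attach(&db, || {
        let name = with_attached(|db| db.name().to_string()).unwrap();
        assert_eq!(name, "MyDb");
    });
    // The guard has been dropped, so nothing is attached any more.
    assert!(with_attached(|_| ()).is_none());
}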

@@ -1,4 +1,6 @@
use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id};
use crate::{
accumulator, hash::FxHashSet, local_state, storage::DatabaseGen, DatabaseKeyIndex, Id,
};
use super::{Configuration, IngredientImpl};
@@ -12,36 +14,41 @@ where
where
A: accumulator::Accumulator,
{
let Some(accumulator) = <accumulator::IngredientImpl<A>>::from_db(db) else {
return vec![];
};
let runtime = db.runtime();
let mut output = vec![];
local_state::attach(db, |local_state| {
let current_revision = db.runtime().current_revision();
// First ensure the result is up to date
self.fetch(db, key);
let Some(accumulator) = <accumulator::IngredientImpl<A>>::from_db(db) else {
return vec![];
};
let mut output = vec![];
let db_key = self.database_key_index(key);
let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default();
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];
// First ensure the result is up to date
self.fetch(db, key);
while let Some(k) = stack.pop() {
if visited.insert(k) {
accumulator.produced_by(runtime, k, &mut output);
let db_key = self.database_key_index(key);
let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default();
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];
let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index);
let inputs = origin.iter().flat_map(|origin| origin.inputs());
// Careful: we want to push in execution order, so reverse order to
// ensure the first child that was executed will be the first child popped
// from the stack.
stack.extend(
inputs
.flat_map(|input| TryInto::<DatabaseKeyIndex>::try_into(input).into_iter())
.rev(),
);
while let Some(k) = stack.pop() {
if visited.insert(k) {
accumulator.produced_by(current_revision, local_state, k, &mut output);
let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index);
let inputs = origin.iter().flat_map(|origin| origin.inputs());
// Careful: we want to push in execution order, so reverse order to
// ensure the first child that was executed will be the first child popped
// from the stack.
stack.extend(
inputs
.flat_map(|input| {
TryInto::<DatabaseKeyIndex>::try_into(input).into_iter()
})
.rev(),
);
}
}
}
output
output
})
}
}


@@ -1,6 +1,11 @@
use arc_swap::Guard;
use crate::{runtime::StampedValue, storage::DatabaseGen, Id};
use crate::{
local_state::{self, LocalState},
runtime::StampedValue,
storage::DatabaseGen,
Id,
};
use super::{Configuration, IngredientImpl};
@@ -9,37 +14,41 @@ where
C: Configuration,
{
pub fn fetch<'db>(&'db self, db: &'db C::DbView, key: Id) -> &C::Output<'db> {
let runtime = db.runtime();
local_state::attach(db.as_salsa_database(), |local_state| {
local_state.unwind_if_revision_cancelled(db.as_salsa_database());
runtime.unwind_if_revision_cancelled(db);
let StampedValue {
value,
durability,
changed_at,
} = self.compute_value(db, local_state, key);
let StampedValue {
value,
durability,
changed_at,
} = self.compute_value(db, key);
if let Some(evicted) = self.lru.record_use(key) {
self.evict(evicted);
}
if let Some(evicted) = self.lru.record_use(key) {
self.evict(evicted);
}
local_state.report_tracked_read(
self.database_key_index(key).into(),
durability,
changed_at,
);
db.runtime().report_tracked_read(
self.database_key_index(key).into(),
durability,
changed_at,
);
value
value
})
}
#[inline]
fn compute_value<'db>(
&'db self,
db: &'db C::DbView,
local_state: &LocalState,
key: Id,
) -> StampedValue<&'db C::Output<'db>> {
loop {
if let Some(value) = self.fetch_hot(db, key).or_else(|| self.fetch_cold(db, key)) {
if let Some(value) = self
.fetch_hot(db, key)
.or_else(|| self.fetch_cold(db, local_state, key))
{
return value;
}
}
@@ -70,18 +79,18 @@ where
fn fetch_cold<'db>(
&'db self,
db: &'db C::DbView,
local_state: &LocalState,
key: Id,
) -> Option<StampedValue<&'db C::Output<'db>>> {
let runtime = db.runtime();
let database_key_index = self.database_key_index(key);
// Try to claim this query: if someone else has claimed it already, go back and start again.
let _claim_guard = self
.sync_map
.claim(db.as_salsa_database(), database_key_index)?;
let _claim_guard =
self.sync_map
.claim(db.as_salsa_database(), local_state, database_key_index)?;
// Push the query on the stack.
let active_query = runtime.push_query(database_key_index);
let active_query = local_state.push_query(database_key_index);
// Now that we've claimed the item, check again to see if there's a "hot" value.
// This time we can do a *deep* verify. Because this can recurse, don't hold the arcswap guard.

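`fetch` above splits into a hot path (an already verified memo) and a cold path that must first claim the key in the `sync_map`; a failed claim makes `fetch_cold` return `None`, and the surrounding `loop` goes back to the hot path once the other thread has published its memo. A toy, single-threaded sketch of that control flow, with a `HashSet` standing in for the claim map and strings standing in for memoized values (none of salsa's verification, LRU, or dependency tracking is modeled):

use std::collections::{HashMap, HashSet};

/// Toy memo table illustrating the hot/cold split: the hot path returns an existing
/// memo; the cold path "claims" the key, computes, stores, and releases the claim.
struct MemoTable {
    memos: HashMap<u32, String>,
    in_progress: HashSet<u32>, // stands in for the `SyncMap` claim
}

impl MemoTable {
    fn fetch_hot(&self, key: u32) -> Option<String> {
        self.memos.get(&key).cloned()
    }

    fn fetch_cold(&mut self, key: u32, compute: impl Fn(u32) -> String) -> Option<String> {
        // Try to claim the key; in the real code a failed claim means another thread
        // is already computing this key and we must retry once it finishes.
        if !self.in_progress.insert(key) {
            return None;
        }
        let value = compute(key);
        self.memos.insert(key, value.clone());
        self.in_progress.remove(&key);
        Some(value)
    }

    fn fetch(&mut self, key: u32, compute: impl Fn(u32) -> String) -> String {
        loop {
            if let Some(v) = self.fetch_hot(key) {
                return v;
            }
            if let Some(v) = self.fetch_cold(key, &compute) {
                return v;
            }
            // Claim failed: loop back to the hot path and try again.
        }
    }
}

fn main() {
    let mut table = MemoTable {
        memos: HashMap::new(),
        in_progress: HashSet::new(),
    };
    let square = |k: u32| format!("{}", k * k);
    assert_eq!(table.fetch(7, square), "49"); // cold path computes and stores
    assert_eq!(table.fetch(7, square), "49"); // hot path reuses the memo
}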

@@ -2,7 +2,7 @@ use arc_swap::Guard;
use crate::{
key::DatabaseKeyIndex,
local_state::{ActiveQueryGuard, EdgeKind, QueryOrigin},
local_state::{self, ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin},
runtime::StampedValue,
storage::DatabaseGen,
Id, Revision, Runtime,
@@ -20,46 +20,51 @@ where
key: Id,
revision: Revision,
) -> bool {
let runtime = db.runtime();
runtime.unwind_if_revision_cancelled(db);
local_state::attach(db.as_salsa_database(), |local_state| {
let runtime = db.runtime();
local_state.unwind_if_revision_cancelled(db.as_salsa_database());
loop {
let database_key_index = self.database_key_index(key);
loop {
let database_key_index = self.database_key_index(key);
tracing::debug!("{database_key_index:?}: maybe_changed_after(revision = {revision:?})");
tracing::debug!(
"{database_key_index:?}: maybe_changed_after(revision = {revision:?})"
);
// Check if we have a verified version: this is the hot path.
let memo_guard = self.memo_map.get(key);
if let Some(memo) = &memo_guard {
if self.shallow_verify_memo(db, runtime, database_key_index, memo) {
return memo.revisions.changed_at > revision;
}
drop(memo_guard); // release the arc-swap guard before cold path
if let Some(mcs) = self.maybe_changed_after_cold(db, key, revision) {
return mcs;
// Check if we have a verified version: this is the hot path.
let memo_guard = self.memo_map.get(key);
if let Some(memo) = &memo_guard {
if self.shallow_verify_memo(db, runtime, database_key_index, memo) {
return memo.revisions.changed_at > revision;
}
drop(memo_guard); // release the arc-swap guard before cold path
if let Some(mcs) = self.maybe_changed_after_cold(db, local_state, key, revision)
{
return mcs;
} else {
// We failed to claim, have to retry.
}
} else {
// We failed to claim, have to retry.
// No memo? Assume has changed.
return true;
}
} else {
// No memo? Assume has changed.
return true;
}
}
})
}
fn maybe_changed_after_cold<'db>(
&'db self,
db: &'db C::DbView,
local_state: &LocalState,
key_index: Id,
revision: Revision,
) -> Option<bool> {
let runtime = db.runtime();
let database_key_index = self.database_key_index(key_index);
let _claim_guard = self
.sync_map
.claim(db.as_salsa_database(), database_key_index)?;
let active_query = runtime.push_query(database_key_index);
let _claim_guard =
self.sync_map
.claim(db.as_salsa_database(), local_state, database_key_index)?;
let active_query = local_state.push_query(database_key_index);
// Load the current memo, if any. Use a real arc, not an arc-swap guard,
// since we may recurse.
@@ -70,7 +75,7 @@ where
tracing::debug!(
"{database_key_index:?}: maybe_changed_after_cold, successful claim, \
revision = {revision:?}, old_memo = {old_memo:#?}",
revision = {revision:?}, old_memo = {old_memo:#?}",
);
// Check if the inputs are still valid and we can just compare `changed_at`.


@@ -1,7 +1,7 @@ use crossbeam::atomic::AtomicCell;
use crossbeam::atomic::AtomicCell;
use crate::{
local_state::{QueryOrigin, QueryRevisions},
local_state::{self, QueryOrigin, QueryRevisions},
storage::DatabaseGen,
tracked_struct::TrackedStructInDb,
Database, DatabaseKeyIndex, Id,
@@ -13,97 +13,83 @@ impl<C> IngredientImpl<C>
where
C: Configuration,
{
/// Specifies the value of the function for the given key.
/// This is a way to imperatively set the value of a function.
/// It only works if the key is a tracked struct created in the current query.
fn specify<'db>(
&'db self,
db: &'db C::DbView,
key: Id,
value: C::Output<'db>,
origin: impl Fn(DatabaseKeyIndex) -> QueryOrigin,
) where
C::Input<'db>: TrackedStructInDb,
{
let runtime = db.runtime();
let (active_query_key, current_deps) = match runtime.active_query() {
Some(v) => v,
None => panic!("can only use `specify` inside a tracked function"),
};
// `specify` only works if the key is a tracked struct created in the current query.
//
// The reason is this. We want to ensure that the same result is reached regardless of
// the "path" that the user takes through the execution graph.
// If you permit values to be specified from other queries, you can have a situation like this:
// * Q0 creates the tracked struct T0
// * Q1 specifies the value for F(T0)
// * Q2 invokes F(T0)
// * Q3 invokes Q1 and then Q2
// * Q4 invokes Q2 and then Q1
//
// Now, if we invoke Q3 first, we get one result for Q2, but if we invoke Q4 first, we get a different value. That's no good.
let database_key_index = <C::Input<'db>>::database_key_index(db.as_salsa_database(), key);
let dependency_index = database_key_index.into();
if !runtime.is_output_of_active_query(dependency_index) {
panic!("can only use `specify` on salsa structs created during the current tracked fn");
}
// Subtle: we treat the "input" to a set query as if it were
// volatile.
//
// The idea is this. You have the current query C that
// created the entity E, and it is setting the value F(E) of the function F.
// When some other query R reads the field F(E), in order to have obtained
// the entity E, it has to have executed the query C.
//
// This will have forced C to either:
//
// - not create E this time, in which case R shouldn't have it (some kind of leak has occurred)
// - assign a value to F(E), in which case `verified_at` will be the current revision and `changed_at` will be updated appropriately
// - NOT assign a value to F(E), in which case we need to re-execute the function (which typically panics).
//
// So, ruling out the case of a leak having occurred, that means that the reader R will either see:
//
// - a result that is verified in the current revision, because it was set, which will use the set value
// - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic)
let revision = runtime.current_revision();
let mut revisions = QueryRevisions {
changed_at: current_deps.changed_at,
durability: current_deps.durability,
origin: origin(active_query_key),
};
if let Some(old_memo) = self.memo_map.get(key) {
self.backdate_if_appropriate(&old_memo, &mut revisions, &value);
self.diff_outputs(db, database_key_index, &old_memo, &revisions);
}
let memo = Memo {
value: Some(value),
verified_at: AtomicCell::new(revision),
revisions,
};
tracing::debug!("specify: about to add memo {:#?} for key {:?}", memo, key);
self.insert_memo(db, key, memo);
}
/// Specify the value for `key` *and* record that we did so.
/// Used for explicit calls to `specify`, but not needed for pre-declared tracked struct fields.
pub fn specify_and_record<'db>(&'db self, db: &'db C::DbView, key: Id, value: C::Output<'db>)
where
C::Input<'db>: TrackedStructInDb,
{
self.specify(db, key, value, |database_key_index| {
QueryOrigin::Assigned(database_key_index)
});
local_state::attach(db.as_salsa_database(), |state| {
let (active_query_key, current_deps) = match state.active_query() {
Some(v) => v,
None => panic!("can only use `specify` inside a tracked function"),
};
// Record that the current query *specified* a value for this cell.
let database_key_index = self.database_key_index(key);
db.runtime().add_output(database_key_index.into());
// `specify` only works if the key is a tracked struct created in the current query.
//
// The reason is this. We want to ensure that the same result is reached regardless of
// the "path" that the user takes through the execution graph.
// If you permit values to be specified from other queries, you can have a situation like this:
// * Q0 creates the tracked struct T0
// * Q1 specifies the value for F(T0)
// * Q2 invokes F(T0)
// * Q3 invokes Q1 and then Q2
// * Q4 invokes Q2 and then Q1
//
// Now, if we invoke Q3 first, we get one result for Q2, but if we invoke Q4 first, we get a different value. That's no good.
let database_key_index =
<C::Input<'db>>::database_key_index(db.as_salsa_database(), key);
let dependency_index = database_key_index.into();
if !state.is_output_of_active_query(dependency_index) {
panic!(
"can only use `specify` on salsa structs created during the current tracked fn"
);
}
// Subtle: we treat the "input" to a set query as if it were
// volatile.
//
// The idea is this. You have the current query C that
// created the entity E, and it is setting the value F(E) of the function F.
// When some other query R reads the field F(E), in order to have obtained
// the entity E, it has to have executed the query C.
//
// This will have forced C to either:
//
// - not create E this time, in which case R shouldn't have it (some kind of leak has occurred)
// - assign a value to F(E), in which case `verified_at` will be the current revision and `changed_at` will be updated appropriately
// - NOT assign a value to F(E), in which case we need to re-execute the function (which typically panics).
//
// So, ruling out the case of a leak having occurred, that means that the reader R will either see:
//
// - a result that is verified in the current revision, because it was set, which will use the set value
// - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic)
let revision = db.runtime().current_revision();
let mut revisions = QueryRevisions {
changed_at: current_deps.changed_at,
durability: current_deps.durability,
origin: QueryOrigin::Assigned(active_query_key),
};
if let Some(old_memo) = self.memo_map.get(key) {
self.backdate_if_appropriate(&old_memo, &mut revisions, &value);
self.diff_outputs(db, database_key_index, &old_memo, &revisions);
}
let memo = Memo {
value: Some(value),
verified_at: AtomicCell::new(revision),
revisions,
};
tracing::debug!("specify: about to add memo {:#?} for key {:?}", memo, key);
self.insert_memo(db, key, memo);
// Record that the current query *specified* a value for this cell.
let database_key_index = self.database_key_index(key);
state.add_output(database_key_index.into());
})
}
/// Invoked when the query `executor` has been validated as having green inputs


@@ -3,7 +3,10 @@ use std::{
thread::ThreadId,
};
use crate::{hash::FxDashMap, key::DatabaseKeyIndex, runtime::WaitResult, Database, Id, Runtime};
use crate::{
hash::FxDashMap, key::DatabaseKeyIndex, local_state::LocalState, runtime::WaitResult, Database,
Id, Runtime,
};
#[derive(Default)]
pub(super) struct SyncMap {
@@ -22,6 +25,7 @@ impl SyncMap {
pub(super) fn claim<'me>(
&'me self,
db: &'me dyn Database,
local_state: &LocalState,
database_key_index: DatabaseKeyIndex,
) -> Option<ClaimGuard<'me>> {
let runtime = db.runtime();
@@ -47,7 +51,7 @@ impl SyncMap {
// not to gate future atomic reads.
entry.get().anyone_waiting.store(true, Ordering::Relaxed);
let other_id = entry.get().id;
runtime.block_on_or_unwind(db, database_key_index, other_id, entry);
runtime.block_on_or_unwind(db, local_state, database_key_index, other_id, entry);
None
}
}

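The `claim` hunk above is the per-key synchronization point: only one thread may execute a given query key, and with this commit the blocked thread's query stack comes from the thread-local `LocalState` (now passed in explicitly) rather than from the `Runtime`. A simplified sketch of the claim idea using a `Mutex<HashMap>` plus a `Condvar`; note that the real `claim` returns `None` on contention and defers the blocking (and deadlock/cycle detection) to `Runtime::block_on_or_unwind`, which this sketch collapses into a plain wait:

use std::collections::HashMap;
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, ThreadId};

/// Toy version of the claim map: at most one thread may compute a given key;
/// other threads block until the claim is released.
struct SyncMap {
    claims: Mutex<HashMap<u32, ThreadId>>,
    cond: Condvar,
}

impl SyncMap {
    fn new() -> Self {
        SyncMap {
            claims: Mutex::new(HashMap::new()),
            cond: Condvar::new(),
        }
    }

    /// Claim `key`, blocking while another thread holds it.
    fn claim(&self, key: u32) -> ClaimGuard<'_> {
        let mut claims = self.claims.lock().unwrap();
        while claims.contains_key(&key) {
            // Another thread owns the key: wait until it releases its claim.
            claims = self.cond.wait(claims).unwrap();
        }
        claims.insert(key, thread::current().id());
        ClaimGuard { map: self, key }
    }
}

/// Releases the claim (and wakes waiters) when dropped, even on unwind.
struct ClaimGuard<'me> {
    map: &'me SyncMap,
    key: u32,
}

impl Drop for ClaimGuard<'_> {
    fn drop(&mut self) {
        self.map.claims.lock().unwrap().remove(&self.key);
        self.map.cond.notify_all();
    }
}

fn main() {
    let map = Arc::new(SyncMap::new());
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let map = Arc::clone(&map);
            thread::spawn(move || {
                // Only one thread at a time holds the claim for key 42.
                let _guard = map.claim(42);
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
}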

@@ -17,7 +17,7 @@ use crate::{
id::{AsId, FromId},
ingredient::{fmt_index, Ingredient},
key::{DatabaseKeyIndex, DependencyIndex},
local_state::QueryOrigin,
local_state::{self, QueryOrigin},
plumbing::{Jar, Stamp},
runtime::Runtime,
storage::IngredientIndex,
@@ -154,19 +154,21 @@ impl<C: Configuration> IngredientImpl<C> {
id: C::Struct,
field_index: usize,
) -> &'db C::Fields {
let field_ingredient_index = self.ingredient_index.successor(field_index);
let id = id.as_id();
let value = self.struct_map.get(id);
let stamp = &value.stamps[field_index];
db.runtime().report_tracked_read(
DependencyIndex {
ingredient_index: field_ingredient_index,
key_index: Some(id),
},
stamp.durability,
stamp.changed_at,
);
&value.fields
local_state::attach(db, |state| {
let field_ingredient_index = self.ingredient_index.successor(field_index);
let id = id.as_id();
let value = self.struct_map.get(id);
let stamp = &value.stamps[field_index];
state.report_tracked_read(
DependencyIndex {
ingredient_index: field_ingredient_index,
key_index: Some(id),
},
stamp.durability,
stamp.changed_at,
);
&value.fields
})
}
/// Peek at the field values without recording any read dependency.


@@ -9,7 +9,7 @@ use crate::durability::Durability;
use crate::id::AsId;
use crate::ingredient::fmt_index;
use crate::key::DependencyIndex;
use crate::local_state::QueryOrigin;
use crate::local_state::{self, QueryOrigin};
use crate::plumbing::Jar;
use crate::storage::IngredientIndex;
use crate::{Database, DatabaseKeyIndex, Id};
@@ -136,44 +136,46 @@ where
db: &'db dyn crate::Database,
data: C::Data<'db>,
) -> C::Struct<'db> {
db.runtime().report_tracked_read(
DependencyIndex::for_table(self.ingredient_index),
Durability::MAX,
self.reset_at,
);
local_state::attach(db, |state| {
state.report_tracked_read(
DependencyIndex::for_table(self.ingredient_index),
Durability::MAX,
self.reset_at,
);
// Optimisation to only get read lock on the map if the data has already
// been interned.
let internal_data = unsafe { self.to_internal_data(data) };
if let Some(guard) = self.key_map.get(&internal_data) {
let id = *guard;
drop(guard);
return self.interned_value(id);
}
match self.key_map.entry(internal_data.clone()) {
// Data has been interned by a racing call, use that ID instead
dashmap::mapref::entry::Entry::Occupied(entry) => {
let id = *entry.get();
drop(entry);
self.interned_value(id)
// Optimisation to only get read lock on the map if the data has already
// been interned.
let internal_data = unsafe { self.to_internal_data(data) };
if let Some(guard) = self.key_map.get(&internal_data) {
let id = *guard;
drop(guard);
return self.interned_value(id);
}
// We won any races so should intern the data
dashmap::mapref::entry::Entry::Vacant(entry) => {
let next_id = self.counter.fetch_add(1);
let next_id = crate::id::Id::from_u32(next_id);
let value = self.value_map.entry(next_id).or_insert(Alloc::new(Value {
id: next_id,
fields: internal_data,
}));
let value_raw = value.as_raw();
drop(value);
entry.insert(next_id);
// SAFETY: Items are only removed from the `value_map` with an `&mut self` reference.
unsafe { C::struct_from_raw(value_raw) }
match self.key_map.entry(internal_data.clone()) {
// Data has been interned by a racing call, use that ID instead
dashmap::mapref::entry::Entry::Occupied(entry) => {
let id = *entry.get();
drop(entry);
self.interned_value(id)
}
// We won any races so should intern the data
dashmap::mapref::entry::Entry::Vacant(entry) => {
let next_id = self.counter.fetch_add(1);
let next_id = crate::id::Id::from_u32(next_id);
let value = self.value_map.entry(next_id).or_insert(Alloc::new(Value {
id: next_id,
fields: internal_data,
}));
let value_raw = value.as_raw();
drop(value);
entry.insert(next_id);
// SAFETY: Items are only removed from the `value_map` with an `&mut self` reference.
unsafe { C::struct_from_raw(value_raw) }
}
}
}
})
}
pub fn interned_value(&self, id: Id) -> C::Struct<'_> {

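The interning hunk above does a cheap read-only lookup first and only falls back to an entry-based insert, which must handle the case where a racing call interned the same data in the meantime. A self-contained sketch of that two-phase lookup using `RwLock<HashMap>` in place of the sharded `DashMap` (the separate `value_map`, `Alloc` indirection, and revision stamping are elided):

use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::RwLock;

/// Toy interner: maps data to a small integer id, handing out ids in insertion order.
struct Interner {
    key_map: RwLock<HashMap<String, u32>>,
    counter: AtomicU32,
}

impl Interner {
    fn new() -> Self {
        Interner {
            key_map: RwLock::new(HashMap::new()),
            counter: AtomicU32::new(0),
        }
    }

    fn intern(&self, data: &str) -> u32 {
        // Optimization from the diff: only take a read lock if the data
        // has already been interned.
        if let Some(&id) = self.key_map.read().unwrap().get(data) {
            return id;
        }
        // Slow path: take the write lock and re-check, because a racing call
        // may have interned `data` after we released the read lock.
        let mut map = self.key_map.write().unwrap();
        if let Some(&id) = map.get(data) {
            return id; // lost the race; reuse the id the other call created
        }
        // We won any race, so allocate a fresh id and record it.
        let id = self.counter.fetch_add(1, Ordering::Relaxed);
        map.insert(data.to_owned(), id);
        id
    }
}

fn main() {
    let interner = Interner::new();
    let a = interner.intern("hello");
    let b = interner.intern("hello");
    let c = interner.intern("world");
    assert_eq!(a, b);
    assert_ne!(a, c);
}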

@@ -1,4 +1,4 @@
use crate::{cycle::CycleRecoveryStrategy, database, storage::IngredientIndex, Database, Id};
use crate::{cycle::CycleRecoveryStrategy, local_state, storage::IngredientIndex, Database, Id};
/// An integer that uniquely identifies a particular query instance within the
/// database. Used to track dependencies between queries. Fully ordered and
@@ -57,7 +57,7 @@ impl DependencyIndex {
impl std::fmt::Debug for DependencyIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
database::with_attached_database(|db| {
local_state::with_attached_database(|db| {
let ingredient = db.lookup_ingredient(self.ingredient_index);
ingredient.fmt_index(self.key_index, f)
})


@@ -41,7 +41,7 @@ pub use self::revision::Revision;
pub use self::runtime::Runtime;
pub use self::storage::Storage;
pub use self::update::Update;
pub use crate::database::with_attached_database;
pub use crate::local_state::with_attached_database;
pub use salsa_macros::accumulator;
pub use salsa_macros::db;
pub use salsa_macros::input;
@@ -79,9 +79,7 @@ pub mod plumbing {
pub use crate::array::Array;
pub use crate::cycle::Cycle;
pub use crate::cycle::CycleRecoveryStrategy;
pub use crate::database::attach_database;
pub use crate::database::current_revision;
pub use crate::database::with_attached_database;
pub use crate::database::Database;
pub use crate::function::should_backdate_value;
pub use crate::id::AsId;
@@ -91,6 +89,7 @@ pub mod plumbing {
pub use crate::ingredient::Ingredient;
pub use crate::ingredient::Jar;
pub use crate::key::DatabaseKeyIndex;
pub use crate::local_state::with_attached_database;
pub use crate::revision::Revision;
pub use crate::runtime::stamp;
pub use crate::runtime::Runtime;


@@ -5,13 +5,48 @@ use crate::durability::Durability;
use crate::key::DatabaseKeyIndex;
use crate::key::DependencyIndex;
use crate::runtime::StampedValue;
use crate::storage::IngredientIndex;
use crate::tracked_struct::Disambiguator;
use crate::Cancelled;
use crate::Cycle;
use crate::Database;
use crate::Event;
use crate::EventKind;
use crate::Revision;
use crate::Runtime;
use std::cell::Cell;
use std::cell::RefCell;
use std::ptr::NonNull;
use std::sync::Arc;
thread_local! {
/// The thread-local state salsa requires for a given thread
static LOCAL_STATE: LocalState = const { LocalState::new() }
}
/// Attach the database to the current thread and execute `op`.
/// Panics if a different database has already been attached.
pub(crate) fn attach<R, DB>(db: &DB, op: impl FnOnce(&LocalState) -> R) -> R
where
DB: ?Sized + Database,
{
LOCAL_STATE.with(|state| state.attach(db.as_salsa_database(), || op(state)))
}
/// Access the "attached" database. Returns `None` if no database is attached.
/// Databases are attached with `attach_database`.
pub fn with_attached_database<R>(op: impl FnOnce(&dyn Database) -> R) -> Option<R> {
LOCAL_STATE.with(|state| {
if let Some(db) = state.database.get() {
// SAFETY: We always attach the database for the entire duration of a function,
// so it cannot become "unattached" while this function is running.
Some(op(unsafe { db.as_ref() }))
} else {
None
}
})
}
/// State that is specific to a single execution thread.
///
/// Internally, this type uses ref-cells.
@@ -19,6 +54,9 @@ use std::sync::Arc;
/// **Note also that all mutations to the database handle (and hence
/// to the local-state) must be undone during unwinding.**
pub(crate) struct LocalState {
/// Pointer to the currently attached database.
database: Cell<Option<NonNull<dyn Database>>>,
/// Vector of active queries.
///
/// This is normally `Some`, but it is set to `None`
@@ -29,6 +67,282 @@ pub(crate) struct LocalState {
query_stack: RefCell<Option<Vec<ActiveQuery>>>,
}
impl LocalState {
const fn new() -> Self {
LocalState {
database: Cell::new(None),
query_stack: RefCell::new(Some(vec![])),
}
}
fn attach<R>(&self, db: &dyn Database, op: impl FnOnce() -> R) -> R {
struct DbGuard<'s> {
state: Option<&'s LocalState>,
}
impl<'s> DbGuard<'s> {
fn new(state: &'s LocalState, db: &dyn Database) -> Self {
if let Some(current_db) = state.database.get() {
// Already attached? Assert that the database has not changed.
assert_eq!(
current_db,
NonNull::from(db),
"cannot change database mid-query",
);
Self { state: None }
} else {
// Otherwise, set the database.
state.database.set(Some(NonNull::from(db)));
Self { state: Some(state) }
}
}
}
impl Drop for DbGuard<'_> {
fn drop(&mut self) {
// Reset database to null if we did anything in `DbGuard::new`.
if let Some(state) = self.state {
state.database.set(None);
// All stack frames should have been popped from the local stack.
assert!(state.query_stack.borrow().as_ref().unwrap().is_empty());
}
}
}
let _guard = DbGuard::new(self, db);
op()
}
#[inline]
pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> {
let mut query_stack = self.query_stack.borrow_mut();
let query_stack = query_stack.as_mut().expect("local stack taken");
query_stack.push(ActiveQuery::new(database_key_index));
ActiveQueryGuard {
local_state: self,
database_key_index,
push_len: query_stack.len(),
}
}
fn with_query_stack<R>(&self, c: impl FnOnce(&mut Vec<ActiveQuery>) -> R) -> R {
c(self
.query_stack
.borrow_mut()
.as_mut()
.expect("query stack taken"))
}
fn query_in_progress(&self) -> bool {
self.with_query_stack(|stack| !stack.is_empty())
}
/// Returns the index of the active query along with its *current* durability/changed-at
/// information. As the query continues to execute, naturally, that information may change.
pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, StampedValue<()>)> {
self.with_query_stack(|stack| {
stack.last().map(|active_query| {
(
active_query.database_key_index,
StampedValue {
value: (),
durability: active_query.durability,
changed_at: active_query.changed_at,
},
)
})
})
}
/// Add an output to the current query's list of dependencies
pub(crate) fn add_output(&self, entity: DependencyIndex) {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_output(entity)
}
})
}
/// Check whether `entity` is an output of the currently active query (if any)
pub(crate) fn is_output_of_active_query(&self, entity: DependencyIndex) -> bool {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.is_output(entity)
} else {
false
}
})
}
/// Register that currently active query reads the given input
pub(crate) fn report_tracked_read(
&self,
input: DependencyIndex,
durability: Durability,
changed_at: Revision,
) {
debug!(
"report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})",
input, durability, changed_at
);
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_read(input, durability, changed_at);
// We are a cycle participant:
//
// C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0
// ^ ^
// : |
// This edge -----+ |
// |
// |
// N0
//
// In this case, the value we have just read from `Ci+1`
// is actually the cycle fallback value and not especially
// interesting. We unwind now with `CycleParticipant` to avoid
// executing the rest of our query function. This unwinding
// will be caught and our own fallback value will be used.
//
// Note that `Ci+1` may have *other* callers who are not
// participants in the cycle (e.g., N0 in the graph above).
// They will not have the `cycle` marker set in their
// stack frames, so they will just read the fallback value
// from `Ci+1` and continue on their merry way.
if let Some(cycle) = &top_query.cycle {
cycle.clone().throw()
}
}
})
}
/// Register that the current query read an untracked value
///
/// # Parameters
///
/// * `current_revision`, the current revision
pub(crate) fn report_untracked_read(&self, current_revision: Revision) {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_untracked_read(current_revision);
}
})
}
/// Update the top query on the stack to act as though it read a value
/// of durability `durability` which changed in `revision`.
// FIXME: Use or remove this.
#[allow(dead_code)]
pub(crate) fn report_synthetic_read(&self, durability: Durability, revision: Revision) {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_synthetic_read(durability, revision);
}
})
}
/// Takes the query stack and returns it. This is used when
/// the current thread is blocking. The stack must be restored
/// with [`Self::restore_query_stack`] when the thread unblocks.
pub(crate) fn take_query_stack(&self) -> Vec<ActiveQuery> {
assert!(
self.query_stack.borrow().is_some(),
"query stack already taken"
);
self.query_stack.take().unwrap()
}
/// Restores a query stack taken with [`Self::take_query_stack`] once
/// the thread unblocks.
pub(crate) fn restore_query_stack(&self, stack: Vec<ActiveQuery>) {
assert!(self.query_stack.borrow().is_none(), "query stack not taken");
self.query_stack.replace(Some(stack));
}
/// Called when the active query creates an index from the
/// entity table with the index `entity_index`. Has the following effects:
///
/// * Add a query read on `DatabaseKeyIndex::for_table(entity_index)`
/// * Identify a unique disambiguator for the hash within the current query,
/// adding the hash to the current query's disambiguator table.
/// * Returns a tuple of:
/// * the id of the current query
/// * the current dependencies (durability, changed_at) of current query
/// * the disambiguator index
#[track_caller]
pub(crate) fn disambiguate(
&self,
entity_index: IngredientIndex,
reset_at: Revision,
data_hash: u64,
) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) {
assert!(
self.query_in_progress(),
"cannot create a tracked struct disambiguator outside of a tracked function"
);
self.report_tracked_read(
DependencyIndex::for_table(entity_index),
Durability::MAX,
reset_at,
);
self.with_query_stack(|stack| {
let top_query = stack.last_mut().unwrap();
let disambiguator = top_query.disambiguate(data_hash);
(
top_query.database_key_index,
StampedValue {
value: (),
durability: top_query.durability,
changed_at: top_query.changed_at,
},
disambiguator,
)
})
}
/// Starts unwinding the stack if the current revision is cancelled.
///
/// This method can be called by query implementations that perform
/// potentially expensive computations, in order to speed up propagation of
/// cancellation.
///
/// Cancellation will automatically be triggered by salsa on any query
/// invocation.
///
/// This method should not be overridden by `Database` implementors. A
/// `salsa_event` is emitted when this method is called, so that should be
/// used instead.
pub(crate) fn unwind_if_revision_cancelled(&self, db: &dyn Database) {
let runtime = db.runtime();
let thread_id = std::thread::current().id();
db.salsa_event(Event {
thread_id,
kind: EventKind::WillCheckCancellation,
});
if runtime.load_cancellation_flag() {
db.salsa_event(Event {
thread_id,
kind: EventKind::WillCheckCancellation,
});
self.unwind_cancelled(runtime);
}
}
#[cold]
pub(crate) fn unwind_cancelled(&self, runtime: &Runtime) {
let current_revision = runtime.current_revision();
self.report_untracked_read(current_revision);
Cancelled::PendingWrite.throw();
}
}
impl std::panic::RefUnwindSafe for LocalState {}
/// Summarizes "all the inputs that a query used"
#[derive(Debug, Clone)]
pub(crate) struct QueryRevisions {
@@ -149,197 +463,6 @@ impl QueryEdges {
}
}
impl Default for LocalState {
fn default() -> Self {
LocalState {
query_stack: RefCell::new(Some(Vec::new())),
}
}
}
impl LocalState {
#[inline]
pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> {
let mut query_stack = self.query_stack.borrow_mut();
let query_stack = query_stack.as_mut().expect("local stack taken");
query_stack.push(ActiveQuery::new(database_key_index));
ActiveQueryGuard {
local_state: self,
database_key_index,
push_len: query_stack.len(),
}
}
fn with_query_stack<R>(&self, c: impl FnOnce(&mut Vec<ActiveQuery>) -> R) -> R {
c(self
.query_stack
.borrow_mut()
.as_mut()
.expect("query stack taken"))
}
pub(crate) fn query_in_progress(&self) -> bool {
self.with_query_stack(|stack| !stack.is_empty())
}
/// Returns the index of the active query along with its *current* durability/changed-at
/// information. As the query continues to execute, naturally, that information may change.
pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, StampedValue<()>)> {
self.with_query_stack(|stack| {
stack.last().map(|active_query| {
(
active_query.database_key_index,
StampedValue {
value: (),
durability: active_query.durability,
changed_at: active_query.changed_at,
},
)
})
})
}
/// Add an output to the current query's list of dependencies
pub(crate) fn add_output(&self, entity: DependencyIndex) {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_output(entity)
}
})
}
/// Check whether `entity` is an output of the currently active query (if any)
pub(crate) fn is_output(&self, entity: DependencyIndex) -> bool {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.is_output(entity)
} else {
false
}
})
}
/// Register that currently active query reads the given input
pub(crate) fn report_tracked_read(
&self,
input: DependencyIndex,
durability: Durability,
changed_at: Revision,
) {
debug!(
"report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})",
input, durability, changed_at
);
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_read(input, durability, changed_at);
// We are a cycle participant:
//
// C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0
// ^ ^
// : |
// This edge -----+ |
// |
// |
// N0
//
// In this case, the value we have just read from `Ci+1`
// is actually the cycle fallback value and not especially
// interesting. We unwind now with `CycleParticipant` to avoid
// executing the rest of our query function. This unwinding
// will be caught and our own fallback value will be used.
//
// Note that `Ci+1` may have *other* callers who are not
// participants in the cycle (e.g., N0 in the graph above).
// They will not have the `cycle` marker set in their
// stack frames, so they will just read the fallback value
// from `Ci+1` and continue on their merry way.
if let Some(cycle) = &top_query.cycle {
cycle.clone().throw()
}
}
})
}
/// Register that the current query read an untracked value
///
/// # Parameters
///
/// * `current_revision`, the current revision
pub(crate) fn report_untracked_read(&self, current_revision: Revision) {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_untracked_read(current_revision);
}
})
}
/// Update the top query on the stack to act as though it read a value
/// of durability `durability` which changed in `revision`.
// FIXME: Use or remove this.
#[allow(dead_code)]
pub(crate) fn report_synthetic_read(&self, durability: Durability, revision: Revision) {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_synthetic_read(durability, revision);
}
})
}
/// Takes the query stack and returns it. This is used when
/// the current thread is blocking. The stack must be restored
/// with [`Self::restore_query_stack`] when the thread unblocks.
pub(crate) fn take_query_stack(&self) -> Vec<ActiveQuery> {
assert!(
self.query_stack.borrow().is_some(),
"query stack already taken"
);
self.query_stack.take().unwrap()
}
/// Restores a query stack taken with [`Self::take_query_stack`] once
/// the thread unblocks.
pub(crate) fn restore_query_stack(&self, stack: Vec<ActiveQuery>) {
assert!(self.query_stack.borrow().is_none(), "query stack not taken");
self.query_stack.replace(Some(stack));
}
/// Given the hash of the id fields of a tracked struct, returns:
///
/// * database-key-index of currently active query
/// * durability/changed-at info for the inputs read thus far by said query
/// * a `Disambiguator` that uniquely identifies the tracked struct about to be created
///
/// The disambiguator is basically an integer that increments each time
/// a tracked struct with this `data_hash` is created.
#[track_caller]
pub(crate) fn disambiguate(
&self,
data_hash: u64,
) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) {
assert!(
self.query_in_progress(),
"cannot create a tracked struct disambiguator outside of a tracked function"
);
self.with_query_stack(|stack| {
let top_query = stack.last_mut().unwrap();
let disambiguator = top_query.disambiguate(data_hash);
(
top_query.database_key_index,
StampedValue {
value: (),
durability: top_query.durability,
changed_at: top_query.changed_at,
},
disambiguator,
)
})
}
}
impl std::panic::RefUnwindSafe for LocalState {}
/// When a query is pushed onto the `active_query` stack, this guard
/// is returned to represent its slot. The guard can be used to pop
/// the query from the stack -- in the case of unwinding, the guard's

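The new `LocalState` above stores the query stack as `RefCell<Option<Vec<ActiveQuery>>>` precisely so the whole stack can be handed off while the thread blocks on another thread (`take_query_stack`) and put back when it resumes (`restore_query_stack`). A condensed model of just that stack-handling piece, with a toy `ActiveQuery` frame and the same thread-local placement (the attached-database pointer, guards, and dependency tracking are omitted):

use std::cell::RefCell;

/// Toy active-query frame (stands in for salsa's `ActiveQuery`).
#[derive(Debug)]
struct ActiveQuery {
    key: u32,
}

/// Per-thread query stack. The `Option` exists so the whole stack can be *taken*
/// while the thread blocks on another thread and restored afterward.
struct LocalState {
    query_stack: RefCell<Option<Vec<ActiveQuery>>>,
}

impl LocalState {
    const fn new() -> Self {
        LocalState {
            query_stack: RefCell::new(Some(Vec::new())),
        }
    }

    fn push_query(&self, key: u32) {
        self.query_stack
            .borrow_mut()
            .as_mut()
            .expect("query stack taken")
            .push(ActiveQuery { key });
    }

    fn pop_query(&self) -> Option<ActiveQuery> {
        self.query_stack
            .borrow_mut()
            .as_mut()
            .expect("query stack taken")
            .pop()
    }

    /// Hand the stack over while this thread blocks.
    fn take_query_stack(&self) -> Vec<ActiveQuery> {
        self.query_stack.take().expect("query stack already taken")
    }

    /// Put the stack back once the thread unblocks.
    fn restore_query_stack(&self, stack: Vec<ActiveQuery>) {
        assert!(self.query_stack.borrow().is_none(), "query stack not taken");
        self.query_stack.replace(Some(stack));
    }
}

thread_local! {
    static LOCAL_STATE: LocalState = const { LocalState::new() };
}

fn main() {
    LOCAL_STATE.with(|state| {
        state.push_query(1);
        state.push_query(2);
        // Simulate blocking on another thread: hand the stack over...
        let stack = state.take_query_stack();
        assert_eq!(stack.len(), 2);
        // ...and restore it once unblocked.
        state.restore_query_stack(stack);
        assert_eq!(state.pop_query().map(|q| q.key), Some(2));
    });
}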

@@ -12,22 +12,16 @@ use crate::{
cycle::CycleRecoveryStrategy,
durability::Durability,
key::{DatabaseKeyIndex, DependencyIndex},
local_state::{self, ActiveQueryGuard, EdgeKind},
local_state::{EdgeKind, LocalState},
revision::AtomicRevision,
storage::IngredientIndex,
Cancelled, Cycle, Database, Event, EventKind, Revision,
};
use self::dependency_graph::DependencyGraph;
use super::tracked_struct::Disambiguator;
mod dependency_graph;
pub struct Runtime {
/// Local state that is specific to this runtime (thread).
local_state: local_state::LocalState,
/// Stores the next id to use for a snapshotted runtime (starts at 1).
next_id: AtomicUsize,
@@ -91,7 +85,6 @@ impl<V> StampedValue<V> {
impl Default for Runtime {
fn default() -> Self {
Runtime {
local_state: Default::default(),
revisions: (0..Durability::LEN)
.map(|_| AtomicRevision::start())
.collect(),
@@ -119,35 +112,10 @@ impl Runtime {
self.revisions[0].load()
}
/// Returns the index of the active query along with its *current* durability/changed-at
/// information. As the query continues to execute, naturally, that information may change.
pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, StampedValue<()>)> {
self.local_state.active_query()
}
pub(crate) fn empty_dependencies(&self) -> Arc<[(EdgeKind, DependencyIndex)]> {
self.empty_dependencies.clone()
}
pub(crate) fn report_tracked_read(
&self,
key_index: DependencyIndex,
durability: Durability,
changed_at: Revision,
) {
self.local_state
.report_tracked_read(key_index, durability, changed_at)
}
/// Reports that the query depends on some state unknown to salsa.
///
/// Queries which report untracked reads will be re-executed in the next
/// revision.
pub fn report_untracked_read(&self) {
self.local_state
.report_untracked_read(self.current_revision());
}
/// Reports that an input with durability `durability` changed.
/// This will update the 'last changed at' values for every durability
/// less than or equal to `durability` to the current revision.
@@ -158,41 +126,6 @@ impl Runtime {
}
}
/// Adds `key` to the list of output created by the current query
/// (if not already present).
pub(crate) fn add_output(&self, key: DependencyIndex) {
self.local_state.add_output(key);
}
/// Check whether `entity` is contained in the list of outputs written by the current query.
pub(super) fn is_output_of_active_query(&self, entity: DependencyIndex) -> bool {
self.local_state.is_output(entity)
}
/// Called when the active query creates an index from the
/// entity table with the index `entity_index`. Has the following effects:
///
/// * Add a query read on `DatabaseKeyIndex::for_table(entity_index)`
/// * Identify a unique disambiguator for the hash within the current query,
/// adding the hash to the current query's disambiguator table.
/// * Returns a tuple of:
/// * the id of the current query
/// * the current dependencies (durability, changed_at) of current query
/// * the disambiguator index
pub(crate) fn disambiguate_entity(
&self,
entity_index: IngredientIndex,
reset_at: Revision,
data_hash: u64,
) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) {
self.report_tracked_read(
DependencyIndex::for_table(entity_index),
Durability::MAX,
reset_at,
);
self.local_state.disambiguate(data_hash)
}
/// The revision in which values with durability `d` may have last
/// changed. For D0, this is just the current revision. But for
/// higher levels of durability, this value may lag behind the
@@ -205,38 +138,8 @@ impl Runtime {
self.revisions[d.index()].load()
}
/// Starts unwinding the stack if the current revision is cancelled.
///
/// This method can be called by query implementations that perform
/// potentially expensive computations, in order to speed up propagation of
/// cancellation.
///
/// Cancellation will automatically be triggered by salsa on any query
/// invocation.
///
/// This method should not be overridden by `Database` implementors. A
/// `salsa_event` is emitted when this method is called, so that should be
/// used instead.
pub(crate) fn unwind_if_revision_cancelled<DB: ?Sized + Database>(&self, db: &DB) {
let thread_id = std::thread::current().id();
db.salsa_event(Event {
thread_id,
kind: EventKind::WillCheckCancellation,
});
if self.revision_canceled.load() {
db.salsa_event(Event {
thread_id,
kind: EventKind::WillCheckCancellation,
});
self.unwind_cancelled();
}
}
#[cold]
pub(crate) fn unwind_cancelled(&self) {
self.report_untracked_read();
Cancelled::PendingWrite.throw();
pub(crate) fn load_cancellation_flag(&self) -> bool {
self.revision_canceled.load()
}
pub(crate) fn set_cancellation_flag(&self) {
@@ -255,11 +158,6 @@ impl Runtime {
r_new
}
#[inline]
pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> {
self.local_state.push_query(database_key_index)
}
/// Block until `other_id` completes executing `database_key`;
/// panic or unwind in the case of a cycle.
///
@@ -285,6 +183,7 @@ impl Runtime {
pub(crate) fn block_on_or_unwind<QueryMutexGuard>(
&self,
db: &dyn Database,
local_state: &LocalState,
database_key: DatabaseKeyIndex,
other_id: ThreadId,
query_mutex_guard: QueryMutexGuard,
@@ -293,7 +192,7 @@ impl Runtime {
let thread_id = std::thread::current().id();
if dg.depends_on(other_id, thread_id) {
self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id);
self.unblock_cycle_and_maybe_throw(db, local_state, &mut dg, database_key, other_id);
// If the above fn returns, then (via cycle recovery) it has unblocked the
// cycle, so we can continue.
@@ -308,7 +207,7 @@ impl Runtime {
},
});
let stack = self.local_state.take_query_stack();
let stack = local_state.take_query_stack();
let (stack, result) = DependencyGraph::block_on(
dg,
@@ -319,7 +218,7 @@ impl Runtime {
query_mutex_guard,
);
self.local_state.restore_query_stack(stack);
local_state.restore_query_stack(stack);
match result {
WaitResult::Completed => (),
@@ -344,6 +243,7 @@ impl Runtime {
fn unblock_cycle_and_maybe_throw(
&self,
db: &dyn Database,
local_state: &LocalState,
dg: &mut DependencyGraph,
database_key_index: DatabaseKeyIndex,
to_id: ThreadId,
@@ -353,7 +253,7 @@ impl Runtime {
database_key_index
);
let mut from_stack = self.local_state.take_query_stack();
let mut from_stack = local_state.take_query_stack();
let from_id = std::thread::current().id();
// Make a "dummy stack frame". As we iterate through the cycle, we will collect the
@@ -426,7 +326,7 @@ impl Runtime {
let (me_recovered, others_recovered) =
dg.maybe_unblock_runtimes_in_cycle(from_id, &from_stack, database_key_index, to_id);
self.local_state.restore_query_stack(from_stack);
local_state.restore_query_stack(from_stack);
if me_recovered {
// If the current thread has recovery, we want to throw

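After this commit the `Runtime` only exposes the raw flag (`load_cancellation_flag`) while the unwinding itself lives in `LocalState::unwind_cancelled`, which throws `Cancelled::PendingWrite` as a panic payload that the query driver later catches. A generic sketch of that flag-plus-unwind pattern using `std::panic::panic_any` and `catch_unwind`; salsa's `Event` reporting and `Cancelled` enum are replaced by a bare marker type, and the flag is a global atomic only for brevity:

use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{AtomicBool, Ordering};

/// Marker payload used to unwind a long-running query when the revision
/// is cancelled (plays the role of salsa's `Cancelled` here).
#[derive(Debug)]
struct Cancelled;

/// Stand-in for the `revision_canceled` flag behind
/// `Runtime::load_cancellation_flag` / `set_cancellation_flag`.
static CANCELLATION_FLAG: AtomicBool = AtomicBool::new(false);

fn set_cancellation_flag() {
    CANCELLATION_FLAG.store(true, Ordering::SeqCst);
}

/// What `unwind_if_revision_cancelled` boils down to: check the flag and,
/// if it is set, unwind with a payload that the query driver catches.
fn unwind_if_cancelled() {
    if CANCELLATION_FLAG.load(Ordering::SeqCst) {
        // Note: the default panic hook still prints a message for this unwind.
        panic::panic_any(Cancelled);
    }
}

fn long_running_query() -> u64 {
    let mut sum = 0;
    for i in 0..1_000_000u64 {
        if i % 1000 == 0 {
            // Queries poll for cancellation at convenient points.
            unwind_if_cancelled();
        }
        sum += i;
    }
    sum
}

fn main() {
    set_cancellation_flag();
    let result = panic::catch_unwind(AssertUnwindSafe(long_running_query));
    match result {
        Err(payload) if payload.downcast_ref::<Cancelled>().is_some() => {
            println!("query was cancelled");
        }
        // Any other panic is not ours to swallow: keep unwinding.
        Err(payload) => panic::resume_unwind(payload),
        Ok(sum) => println!("query finished: {sum}"),
    }
}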

@@ -11,7 +11,7 @@ use crate::{
ingredient::{fmt_index, Ingredient, Jar},
ingredient_list::IngredientList,
key::{DatabaseKeyIndex, DependencyIndex},
local_state::QueryOrigin,
local_state::{self, QueryOrigin},
runtime::Runtime,
salsa_struct::SalsaStructInDb,
storage::IngredientIndex,
@@ -291,83 +291,84 @@ where
db: &'db dyn Database,
fields: C::Fields<'db>,
) -> C::Struct<'db> {
let data_hash = crate::hash::hash(&C::id_fields(&fields));
local_state::attach(db, |local_state| {
let data_hash = crate::hash::hash(&C::id_fields(&fields));
let runtime = db.runtime();
let (query_key, current_deps, disambiguator) =
runtime.disambiguate_entity(self.ingredient_index, Revision::start(), data_hash);
let (query_key, current_deps, disambiguator) =
local_state.disambiguate(self.ingredient_index, Revision::start(), data_hash);
let entity_key = KeyStruct {
query_key,
disambiguator,
data_hash,
};
let entity_key = KeyStruct {
query_key,
disambiguator,
data_hash,
};
let (id, new_id) = self.intern(entity_key);
runtime.add_output(self.database_key_index(id).into());
let (id, new_id) = self.intern(entity_key);
local_state.add_output(self.database_key_index(id).into());
let current_revision = runtime.current_revision();
if new_id {
// This is a new tracked struct, so create an entry in the struct map.
let current_revision = db.runtime().current_revision();
if new_id {
// This is a new tracked struct, so create an entry in the struct map.
self.struct_map.insert(
runtime,
Value {
id,
key: entity_key,
struct_ingredient_index: self.ingredient_index,
created_at: current_revision,
durability: current_deps.durability,
fields: unsafe { self.to_static(fields) },
revisions: C::new_revisions(current_deps.changed_at),
},
)
} else {
// The struct already exists in the intern map.
// Note that we assume there is at most one executing copy of
// the current query at a time, which implies that the
// struct must exist in `self.struct_map` already
// (if the same query could execute twice in parallel,
// then it would potentially create the same struct twice in parallel,
// which means the interned key could exist but `struct_map` not yet have
// been updated).
self.struct_map.insert(
current_revision,
Value {
id,
key: entity_key,
struct_ingredient_index: self.ingredient_index,
created_at: current_revision,
durability: current_deps.durability,
fields: unsafe { self.to_static(fields) },
revisions: C::new_revisions(current_deps.changed_at),
},
)
} else {
// The struct already exists in the intern map.
// Note that we assume there is at most one executing copy of
// the current query at a time, which implies that the
// struct must exist in `self.struct_map` already
// (if the same query could execute twice in parallel,
// then it would potentially create the same struct twice in parallel,
// which means the interned key could exist but `struct_map` not yet have
// been updated).
match self.struct_map.update(runtime, id) {
Update::Current(r) => {
// All inputs up to this point were previously
// observed to be green and this struct was already
// verified. Therefore, the durability ought not to have
// changed (nor the field values, but the user could've
// done something stupid, so we can't *assert* this is true).
assert!(C::deref_struct(r).durability == current_deps.durability);
match self.struct_map.update(current_revision, id) {
Update::Current(r) => {
// All inputs up to this point were previously
// observed to be green and this struct was already
// verified. Therefore, the durability ought not to have
// changed (nor the field values, but the user could've
// done something stupid, so we can't *assert* this is true).
assert!(C::deref_struct(r).durability == current_deps.durability);
r
}
Update::Outdated(mut data_ref) => {
let data = &mut *data_ref;
// SAFETY: We assert that the pointer to `data.revisions`
// is a pointer into the database referencing a value
// from a previous revision. As such, it continues to meet
// its validity invariant and any owned content also continues
// to meet its safety invariant.
unsafe {
C::update_fields(
current_revision,
&mut data.revisions,
self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)),
fields,
);
r
}
if current_deps.durability < data.durability {
data.revisions = C::new_revisions(current_revision);
Update::Outdated(mut data_ref) => {
let data = &mut *data_ref;
// SAFETY: We assert that the pointer to `data.revisions`
// is a pointer into the database referencing a value
// from a previous revision. As such, it continues to meet
// its validity invariant and any owned content also continues
// to meet its safety invariant.
unsafe {
C::update_fields(
current_revision,
&mut data.revisions,
self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)),
fields,
);
}
if current_deps.durability < data.durability {
data.revisions = C::new_revisions(current_revision);
}
data.durability = current_deps.durability;
data.created_at = current_revision;
data_ref.freeze()
}
data.durability = current_deps.durability;
data.created_at = current_revision;
data_ref.freeze()
}
}
}
})
}
/// Given the id of a tracked struct created in this revision,
@@ -377,7 +378,8 @@ where
///
/// If the struct has not been created in this revision.
pub fn lookup_struct<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> {
self.struct_map.get(runtime, id)
let current_revision = runtime.current_revision();
self.struct_map.get(current_revision, id)
}
/// Deletes the given entities. This is used after a query `Q` executes and we can compare
@@ -507,21 +509,26 @@ where
/// Access to this value field.
/// Note that this function returns the entire tuple of value fields.
/// The caller is responsible for selecting the appropriate element.
pub fn field<'db>(&'db self, db: &'db dyn Database, field_index: usize) -> &'db C::Fields<'db> {
let runtime = db.runtime();
let field_ingredient_index = self.struct_ingredient_index.successor(field_index);
let changed_at = self.revisions[field_index];
pub fn field<'db>(
&'db self,
db: &dyn crate::Database,
field_index: usize,
) -> &'db C::Fields<'db> {
local_state::attach(db, |local_state| {
let field_ingredient_index = self.struct_ingredient_index.successor(field_index);
let changed_at = self.revisions[field_index];
runtime.report_tracked_read(
DependencyIndex {
ingredient_index: field_ingredient_index,
key_index: Some(self.id.as_id()),
},
self.durability,
changed_at,
);
local_state.report_tracked_read(
DependencyIndex {
ingredient_index: field_ingredient_index,
key_index: Some(self.id.as_id()),
},
self.durability,
changed_at,
);
unsafe { self.to_self_ref(&self.fields) }
unsafe { self.to_self_ref(&self.fields) }
})
}
unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> {

View file

@@ -6,7 +6,7 @@ use std::{
use crossbeam::queue::SegQueue;
use dashmap::mapref::one::RefMut;
use crate::{alloc::Alloc, hash::FxDashMap, Id, Runtime};
use crate::{alloc::Alloc, hash::FxDashMap, Id, Revision, Runtime};
use super::{Configuration, KeyStruct, Value};
@@ -80,8 +80,8 @@ where
///
/// * If value with same `value.id` is already present in the map.
/// * If value not created in current revision.
pub fn insert<'db>(&'db self, runtime: &'db Runtime, value: Value<C>) -> C::Struct<'db> {
assert_eq!(value.created_at, runtime.current_revision());
pub fn insert<'db>(&'db self, current_revision: Revision, value: Value<C>) -> C::Struct<'db> {
assert_eq!(value.created_at, current_revision);
let id = value.id;
let boxed_value = Alloc::new(value);
@@ -119,12 +119,9 @@ where
///
/// * If the value is not present in the map.
/// * If the value is already updated in this revision.
pub fn update<'db>(&'db self, runtime: &'db Runtime, id: Id) -> Update<'db, C> {
pub fn update<'db>(&'db self, current_revision: Revision, id: Id) -> Update<'db, C> {
let mut data = self.map.get_mut(&id).unwrap();
// Never update a struct twice in the same revision.
let current_revision = runtime.current_revision();
// UNSAFE: We never permit `&`-access in the current revision until data.created_at
// has been updated to the current revision (which we check below).
let data_ref = unsafe { data.as_mut() };
@@ -154,7 +151,7 @@ where
// code cannot violate that `&`-reference.
if data_ref.created_at == current_revision {
drop(data);
return Update::Current(Self::get_from_map(&self.map, runtime, id));
return Update::Current(Self::get_from_map(&self.map, current_revision, id));
}
data_ref.created_at = current_revision;
@@ -167,8 +164,8 @@ where
///
/// * If the value is not present in the map.
/// * If the value has not been updated in this revision.
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> {
Self::get_from_map(&self.map, runtime, id)
pub fn get<'db>(&'db self, current_revision: Revision, id: Id) -> C::Struct<'db> {
Self::get_from_map(&self.map, current_revision, id)
}
/// Helper function, provides shared functionality for [`StructMapView`][]
@@ -179,7 +176,7 @@ where
/// * If the value has not been updated in this revision.
fn get_from_map<'db>(
map: &'db FxDashMap<Id, Alloc<Value<C>>>,
runtime: &'db Runtime,
current_revision: Revision,
id: Id,
) -> C::Struct<'db> {
let data = map.get(&id).unwrap();
@@ -190,7 +187,6 @@ where
// Before we drop the lock, check that the value has
// been updated in this revision. This is what allows us to return a ``
let current_revision = runtime.current_revision();
let created_at = data_ref.created_at;
assert!(
created_at == current_revision,
@@ -235,8 +231,8 @@ where
///
/// * If the value is not present in the map.
/// * If the value has not been updated in this revision.
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> {
StructMap::get_from_map(&self.map, runtime, id)
pub fn get<'db>(&'db self, current_revision: Revision, id: Id) -> C::Struct<'db> {
StructMap::get_from_map(&self.map, current_revision, id)
}
}


@@ -1,5 +1,6 @@
use crate::{
id::AsId, ingredient::Ingredient, key::DependencyIndex, storage::IngredientIndex, Database, Id,
id::AsId, ingredient::Ingredient, key::DependencyIndex, local_state, storage::IngredientIndex,
Database, Id,
};
use super::{struct_map::StructMapView, Configuration};
@@ -46,21 +47,23 @@ where
/// Note that this function returns the entire tuple of value fields.
/// The caller is responsible for selecting the appropriate element.
pub fn field<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> {
let runtime = db.runtime();
let data = self.struct_map.get(runtime, id);
let data = C::deref_struct(data);
let changed_at = data.revisions[self.field_index];
local_state::attach(db, |local_state| {
let current_revision = db.runtime().current_revision();
let data = self.struct_map.get(current_revision, id);
let data = C::deref_struct(data);
let changed_at = data.revisions[self.field_index];
runtime.report_tracked_read(
DependencyIndex {
ingredient_index: self.ingredient_index,
key_index: Some(id.as_id()),
},
data.durability,
changed_at,
);
local_state.report_tracked_read(
DependencyIndex {
ingredient_index: self.ingredient_index,
key_index: Some(id.as_id()),
},
data.durability,
changed_at,
);
unsafe { self.to_self_ref(&data.fields) }
unsafe { self.to_self_ref(&data.fields) }
})
}
}
@@ -82,9 +85,9 @@ where
input: Option<Id>,
revision: crate::Revision,
) -> bool {
let runtime = db.runtime();
let current_revision = db.runtime().current_revision();
let id = input.unwrap();
let data = self.struct_map.get(runtime, id);
let data = self.struct_map.get(current_revision, id);
let data = C::deref_struct(data);
let field_changed_at = data.revisions[self.field_index];
field_changed_at > revision