use table for tracked structs and their fields

This also retools a tiny bit how deletion works.
We will reuse ids faster than before, actually.
This commit is contained in:
Niko Matsakis 2024-08-10 10:13:38 +03:00
parent 01d4ef86b2
commit 188f759555
17 changed files with 561 additions and 575 deletions

View file

@ -63,7 +63,7 @@ macro_rules! setup_tracked_struct {
$(#[$attr])*
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
$vis struct $Struct<$db_lt>(
std::ptr::NonNull<salsa::plumbing::tracked_struct::Value < $Struct<'static> >>,
salsa::Id,
std::marker::PhantomData < & $db_lt salsa::plumbing::tracked_struct::Value < $Struct<'static> > >
);
@ -90,12 +90,12 @@ macro_rules! setup_tracked_struct {
type Struct<$db_lt> = $Struct<$db_lt>;
unsafe fn struct_from_raw<$db_lt>(ptr: $NonNull<$zalsa_struct::Value<Self>>) -> Self::Struct<$db_lt> {
$Struct(ptr, std::marker::PhantomData)
fn struct_from_id<$db_lt>(id: salsa::Id) -> Self::Struct<$db_lt> {
$Struct(id, std::marker::PhantomData)
}
fn deref_struct(s: Self::Struct<'_>) -> &$zalsa_struct::Value<Self> {
unsafe { s.0.as_ref() }
fn deref_struct(s: Self::Struct<'_>) -> salsa::Id {
s.0
}
fn id_fields(fields: &Self::Fields<'_>) -> impl std::hash::Hash {
@ -141,13 +141,13 @@ macro_rules! setup_tracked_struct {
impl<$db_lt> $zalsa::LookupId<$db_lt> for $Struct<$db_lt> {
fn lookup_id(id: salsa::Id, db: &$db_lt dyn $zalsa::Database) -> Self {
$Configuration::ingredient(db).lookup_struct(db, id)
$Struct(id, std::marker::PhantomData)
}
}
impl $zalsa::AsId for $Struct<'_> {
fn as_id(&self) -> $zalsa::Id {
unsafe { self.0.as_ref() }.as_id()
self.0
}
}
@ -199,12 +199,13 @@ macro_rules! setup_tracked_struct {
}
$(
$field_getter_vis fn $field_getter_id<$Db>(&self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($field_option, $db_lt, $field_ty)
$field_getter_vis fn $field_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($field_option, $db_lt, $field_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
{
let fields = unsafe { self.0.as_ref() }.field(db.as_dyn_database(), $field_index);
let db = db.as_dyn_database();
let fields = $Configuration::ingredient(db).field(db, self, $field_index);
$crate::maybe_clone!(
$field_option,
$field_ty,
@ -216,7 +217,7 @@ macro_rules! setup_tracked_struct {
/// Default debug formatting for this struct (may be useful if you define your own `Debug` impl)
pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
$zalsa::with_attached_database(|db| {
let fields = $Configuration::ingredient(db).leak_fields(this);
let fields = $Configuration::ingredient(db).leak_fields(db, this);
let mut f = f.debug_struct(stringify!($Struct));
let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this));
$(

View file

@ -378,7 +378,7 @@ fn parse_print() {
let expected = expect_test::expect![[r#"
(
Program {
[salsa id]: Id(0),
[salsa id]: Id(400),
statements: [
Statement {
span: Span {
@ -441,7 +441,7 @@ fn parse_example() {
let expected = expect_test::expect![[r#"
(
Program {
[salsa id]: Id(0),
[salsa id]: Id(1000),
statements: [
Statement {
span: Span {
@ -451,7 +451,7 @@ fn parse_example() {
},
data: Function(
Function {
[salsa id]: Id(0),
[salsa id]: Id(c00),
name: FunctionId {
text: "area_rectangle",
},
@ -513,7 +513,7 @@ fn parse_example() {
},
data: Function(
Function {
[salsa id]: Id(1),
[salsa id]: Id(c01),
name: FunctionId {
text: "area_circle",
},
@ -714,7 +714,7 @@ fn parse_error() {
let expected = expect_test::expect![[r#"
(
Program {
[salsa id]: Id(0),
[salsa id]: Id(400),
statements: [],
},
[
@ -736,7 +736,7 @@ fn parse_precedence() {
let expected = expect_test::expect![[r#"
(
Program {
[salsa id]: Id(0),
[salsa id]: Id(400),
statements: [
Statement {
span: Span {

View file

@ -268,7 +268,7 @@ fn fix_bad_variable_in_function() {
expect![[r#"
[
"Event: Event { thread_id: ThreadId(11), kind: WillExecute { database_key: parse_statements(Id(0)) } }",
"Event: Event { thread_id: ThreadId(11), kind: WillExecute { database_key: type_check_function(Id(0)) } }",
"Event: Event { thread_id: ThreadId(11), kind: WillExecute { database_key: type_check_function(Id(1400)) } }",
]
"#]],
)],

View file

@ -1,10 +1,12 @@
use rustc_hash::FxHashMap;
use crate::{
durability::Durability,
hash::{FxIndexMap, FxIndexSet},
key::{DatabaseKeyIndex, DependencyIndex},
tracked_struct::Disambiguator,
tracked_struct::{Disambiguator, KeyStruct},
zalsa_local::EMPTY_DEPENDENCIES,
Cycle, Revision,
Cycle, Id, Revision,
};
use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions};
@ -35,10 +37,18 @@ pub(crate) struct ActiveQuery {
/// Stores the entire cycle, if one is found and this query is part of it.
pub(crate) cycle: Option<Cycle>,
/// When new entities are created, their data is hashed, and the resulting
/// When new tracked structs are created, their data is hashed, and the resulting
/// hash is added to this map. If it is not present, then the disambiguator is 0.
/// Otherwise it is 1 more than the current value (which is incremented).
///
/// This table starts empty as the query begins and is gradually populated.
/// Note that if a query executes in 2 different revisions but creates the same
/// set of tracked structs, they will get the same disambiguator values.
disambiguator_map: FxIndexMap<u64, Disambiguator>,
/// Map from tracked struct keys (which include the hash + disambiguator) to their
/// final id.
pub(crate) tracked_struct_ids: FxHashMap<KeyStruct, Id>,
}
impl ActiveQuery {
@ -51,6 +61,7 @@ impl ActiveQuery {
untracked_read: false,
cycle: None,
disambiguator_map: Default::default(),
tracked_struct_ids: Default::default(),
}
}
@ -106,6 +117,7 @@ impl ActiveQuery {
changed_at: self.changed_at,
origin,
durability: self.durability,
tracked_struct_ids: self.tracked_struct_ids.clone(),
}
}

View file

@ -14,10 +14,6 @@ impl<T> Alloc<T> {
}
}
pub fn as_raw(&self) -> NonNull<T> {
self.data
}
pub unsafe fn as_ref(&self) -> &T {
unsafe { self.data.as_ref() }
}

View file

@ -39,6 +39,12 @@ where
},
});
// If we already executed this query once, then use the tracked-struct ids from the
// previous execution as the starting point for the new one.
if let Some(old_memo) = &opt_old_memo {
active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids);
}
// Query was not previously executed, or value is potentially
// stale, or value is absent. Let's execute!
let database_key_index = active_query.database_key_index;

View file

@ -68,6 +68,7 @@ where
changed_at: current_deps.changed_at,
durability: current_deps.durability,
origin: QueryOrigin::Assigned(active_query_key),
tracked_struct_ids: Default::default(),
};
if let Some(old_memo) = self.memo_map.get(key) {

View file

@ -1,6 +1,7 @@
use std::sync::Arc;
use crossbeam::atomic::AtomicCell;
use rustc_hash::FxHashMap;
use crate::{
durability::Durability,
@ -29,6 +30,7 @@ where
changed_at: revision,
durability,
origin: QueryOrigin::BaseInput,
tracked_struct_ids: FxHashMap::default(),
},
};

View file

@ -1,4 +1,4 @@
use std::{any::Any, mem::MaybeUninit};
use std::{any::Any, cell::UnsafeCell, panic::RefUnwindSafe};
use append_only_vec::AppendOnlyVec;
use crossbeam::atomic::AtomicCell;
@ -16,7 +16,7 @@ pub struct Table {
pub struct Page<T: Any + Send + Sync> {
/// The ingredient for elements on this page.
#[allow(dead_code)] // pretty sure we'll need this eventually
#[allow(dead_code)] // pretty sure we'll need this
ingredient: IngredientIndex,
/// Number of elements of `data` that are initialized.
@ -34,9 +34,15 @@ pub struct Page<T: Any + Send + Sync> {
/// Vector with data. This is always created with the capacity/length of `PAGE_LEN`
/// and uninitialized data. As we initialize new entries, we increment `allocated`.
data: Vec<MaybeUninit<T>>,
data: Vec<UnsafeCell<T>>,
}
unsafe impl<T: Any + Send + Sync> Send for Page<T> {}
unsafe impl<T: Any + Send + Sync> Sync for Page<T> {}
impl<T: Any + Send + Sync> RefUnwindSafe for Page<T> {}
#[derive(Copy, Clone)]
pub struct PageIndex(usize);
@ -58,6 +64,12 @@ impl Table {
page_ref.get(slot)
}
/// Returns a raw pointer to the value stored for `id`.
///
/// The id is split into a page index and a slot index; the page is looked up
/// as a `Page<T>` and the pointer is taken from that slot.
pub fn get_raw<T: Any + Send + Sync>(&self, id: Id) -> *mut T {
    let (page_index, slot_index) = split_id(id);
    self.page::<T>(page_index).get_raw(slot_index)
}
/// Returns the page at index `page`, viewed as a `Page<T>`.
///
/// # Panics
///
/// If the index is out of range or the page does not hold values of type `T`.
pub fn page<T: Any + Send + Sync>(&self, page: PageIndex) -> &Page<T> {
    let erased_page = &self.pages[page.0];
    erased_page.downcast_ref::<Page<T>>().unwrap()
}
@ -85,7 +97,15 @@ impl<T: Any + Send + Sync> Page<T> {
pub(crate) fn get(&self, slot: SlotIndex) -> &T {
let len = self.allocated.load();
assert!(slot.0 < len);
unsafe { self.data[slot.0].assume_init_ref() }
unsafe { &*self.data[slot.0].get() }
}
/// Hands out a raw pointer to the contents of `slot`.
///
/// Any read or write through the pointer must be coordinated with calls to `get`,
/// which produces shared references to the same slots.
///
/// # Panics
///
/// If `slot` is not below the number of allocated entries on this page.
pub(crate) fn get_raw(&self, slot: SlotIndex) -> *mut T {
    let allocated = self.allocated.load();
    assert!(slot.0 < allocated);
    let cell = &self.data[slot.0];
    cell.get()
}
pub(crate) fn allocate(&self, page: PageIndex, value: T) -> Result<Id, T> {
@ -95,10 +115,11 @@ impl<T: Any + Send + Sync> Page<T> {
return Err(value);
}
// Initialize entry `index`
let data = &self.data[index];
let data = data.as_ptr() as *mut T;
unsafe { std::ptr::write(data, value) };
unsafe { std::ptr::write(data.get(), value) };
// Update the length (this must be done after initialization!)
self.allocated.store(index + 1);
drop(guard);
@ -108,11 +129,14 @@ impl<T: Any + Send + Sync> Page<T> {
impl<T: Any + Send + Sync> Drop for Page<T> {
fn drop(&mut self) {
// Free `self.data` and the data within: to do this, we swap it out with an empty vector
// and then convert it from a `Vec<UnsafeCell<T>>` with partially uninitialized values
// to a `Vec<T>` with the correct length. This way the `Vec` drop impl can do its job.
let mut data = std::mem::replace(&mut self.data, vec![]);
let len = self.allocated.load();
unsafe {
data.set_len(len);
drop(std::mem::transmute::<Vec<MaybeUninit<T>>, Vec<T>>(data));
drop(std::mem::transmute::<Vec<UnsafeCell<T>>, Vec<T>>(data));
}
}
}

View file

@ -1,25 +1,22 @@
use std::{fmt, hash::Hash, marker::PhantomData, ops::DerefMut, ptr::NonNull};
use std::{fmt, hash::Hash, marker::PhantomData, ops::DerefMut};
use crossbeam::atomic::AtomicCell;
use dashmap::mapref::entry::Entry;
use crossbeam::{atomic::AtomicCell, queue::SegQueue};
use tracked_field::FieldIngredientImpl;
use crate::{
cycle::CycleRecoveryStrategy,
hash::FxDashMap,
id::AsId,
ingredient::{fmt_index, Ingredient, Jar},
ingredient_list::IngredientList,
key::{DatabaseKeyIndex, DependencyIndex},
plumbing::ZalsaLocal,
runtime::StampedValue,
salsa_struct::SalsaStructInDb,
zalsa::IngredientIndex,
table::Table,
zalsa::{IngredientIndex, Zalsa},
zalsa_local::QueryOrigin,
Database, Durability, Event, Id, Revision,
};
use self::struct_map::{StructMap, Update};
mod struct_map;
pub mod tracked_field;
// ANCHOR: Configuration
@ -47,18 +44,10 @@ pub trait Configuration: Sized + 'static {
/// process in a given revision: it occurs only when the struct is newly
/// created or, if a struct is being reused, after we have updated its
/// fields (or confirmed it is green and no updates are required).
///
/// # Safety
///
/// Requires that `ptr` represents a "confirmed" value in this revision,
/// which means that it will remain valid and immutable for the remainder of this
/// revision, represented by the lifetime `'db`.
unsafe fn struct_from_raw<'db>(ptr: NonNull<Value<Self>>) -> Self::Struct<'db>;
fn struct_from_id<'db>(id: Id) -> Self::Struct<'db>;
/// Deref the struct to yield the underlying value struct.
/// Since we are still part of the `'db` lifetime in which the struct was created,
/// this deref is safe, and the value-struct fields are immutable and verified.
fn deref_struct(s: Self::Struct<'_>) -> &Value<Self>;
/// Deref the struct to yield the underlying id.
fn deref_struct(s: Self::Struct<'_>) -> Id;
fn id_fields(fields: &Self::Fields<'_>) -> impl Hash;
@ -115,16 +104,11 @@ impl<C: Configuration> Jar for JarImpl<C> {
&self,
struct_index: crate::zalsa::IngredientIndex,
) -> Vec<Box<dyn Ingredient>> {
let struct_ingredient = IngredientImpl::new(struct_index);
let struct_map = &struct_ingredient.struct_map.view();
let struct_ingredient = <IngredientImpl<C>>::new(struct_index);
std::iter::once(Box::new(struct_ingredient) as _)
.chain((0..C::FIELD_DEBUG_NAMES.len()).map(|field_index| {
Box::new(FieldIngredientImpl::<C>::new(
struct_index,
field_index,
struct_map,
)) as _
Box::new(<FieldIngredientImpl<C>>::new(struct_index, field_index)) as _
}))
.collect()
}
@ -150,18 +134,6 @@ where
/// Our index in the database.
ingredient_index: IngredientIndex,
/// Defines the set of live tracked structs.
/// Entries are added to this map when a new struct is created.
/// They are removed when that struct is deleted
/// (i.e., a query completes without having recreated the struct).
keys: FxDashMap<KeyStruct, Id>,
/// The number of tracked structs created.
counter: AtomicCell<u32>,
/// Map from the [`Id`][] of each struct to its fields/values.
struct_map: struct_map::StructMap<C>,
/// A list of each tracked function `f` whose key is this
/// tracked struct.
///
@ -169,14 +141,20 @@ where
/// each of these functions will be notified
/// so they can remove any data tied to that instance.
dependent_fns: IngredientList,
/// Phantom data: we fetch `Value<C>` out from `Table`
phantom: PhantomData<fn() -> Value<C>>,
/// Store freed ids
free_list: SegQueue<Id>,
}
/// Defines the identity of a tracked struct.
/// This is the key to a hashmap that is (initially)
/// stored in the [`ActiveQuery`](`crate::active_query::ActiveQuery`)
/// struct and later moved to the [`Memo`](`crate::function::memo::Memo`).
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)]
struct KeyStruct {
/// The active query (i.e., tracked function) that created this tracked struct.
query_key: DatabaseKeyIndex,
pub(crate) struct KeyStruct {
/// The hash of the `#[id]` fields of this struct.
/// Note that multiple structs may share the same hash.
data_hash: u64,
@ -192,29 +170,32 @@ pub struct Value<C>
where
C: Configuration,
{
/// Index of the struct ingredient.
struct_ingredient_index: IngredientIndex,
/// The id of this struct in the ingredient.
id: Id,
/// The key used to create the id.
key: KeyStruct,
/// The minimum durability of all inputs consumed
/// by the creator query prior to creating this tracked struct.
/// If any of those inputs changes, then the creator query may
/// create this struct with different values.
durability: Durability,
/// The revision when this entity was most recently created.
/// Typically the current revision.
/// Used to detect "leaks" outside of the salsa system -- i.e.,
/// access to tracked structs that have not (yet?) been created in the
/// current revision. This should be impossible within salsa queries
/// but it can happen through "leaks" like thread-local data or storing
/// values outside of the root salsa query.
created_at: Revision,
/// The revision when this tracked struct was last updated.
/// This field also acts as a kind of "lock". Once it is equal
/// to `Some(current_revision)`, the fields are locked and
/// cannot change further. This makes it safe to give out `&`-references
/// so long as they do not live longer than the current revision
/// (which is assured by tying their lifetime to the lifetime of an `&`-ref
/// to the database).
///
/// The struct is updated from an older revision `R0` to the current revision `R1`
/// when the struct is first accessed in `R1`, whether that be because the original
/// query re-created the struct (i.e., by user calling `Struct::new`) or because
/// the struct was read from. (Structs may not be recreated in the new revision if
/// the inputs to the query have not changed.)
///
/// When re-creating the struct, the field is temporarily set to `None`.
/// This is a signal that there is an active `&mut` modifying the other fields:
/// even reading from those fields in that situation would create UB.
/// This `None` value should never be observable by users unless they have
/// leaked a reference across threads somehow.
updated_at: AtomicCell<Option<Revision>>,
/// Fields of this tracked struct. They can change across revisions,
/// but they do not change within a particular revision.
@ -240,6 +221,10 @@ where
unsafe { std::mem::transmute(fields) }
}
/// Convert a reference to the stored `'static` fields into a reference to fields
/// with the `'db` lifetime. Only the lifetime parameter of `C::Fields` changes.
///
/// # Safety
///
/// NOTE(review): presumably the caller must ensure that `fields` belongs to a value
/// that stays valid and unmodified for all of `'db` -- confirm against the callers.
unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> {
unsafe { std::mem::transmute(fields) }
}
/// Convert from static back to the db lifetime; used when returning data
/// out from this ingredient.
unsafe fn to_self_ptr<'db>(&'db self, fields: *mut C::Fields<'static>) -> *mut C::Fields<'db> {
@ -251,10 +236,9 @@ where
fn new(index: IngredientIndex) -> Self {
Self {
ingredient_index: index,
keys: FxDashMap::default(),
counter: AtomicCell::new(0),
struct_map: StructMap::new(),
dependent_fns: IngredientList::new(),
phantom: PhantomData,
free_list: Default::default(),
}
}
@ -266,120 +250,184 @@ where
}
}
/// Intern a tracked struct key to get a unique tracked struct id.
/// Also returns a bool indicating whether this id was newly created or whether it already existed.
fn intern(&self, key: KeyStruct) -> (Id, bool) {
let (id, new_id) = if let Some(g) = self.keys.get(&key) {
(*g.value(), false)
} else {
match self.keys.entry(key) {
Entry::Occupied(o) => (*o.get(), false),
Entry::Vacant(v) => {
let id = Id::from_u32(self.counter.fetch_add(1));
v.insert(id);
(id, true)
}
}
};
(id, new_id)
}
pub fn new_struct<'db>(
&'db self,
db: &'db dyn Database,
fields: C::Fields<'db>,
) -> C::Struct<'db> {
let zalsa = db.zalsa();
let zalsa_local = db.zalsa_local();
let (zalsa, zalsa_local) = db.zalsas();
let data_hash = crate::hash::hash(&C::id_fields(&fields));
let (query_key, current_deps, disambiguator) =
let (current_deps, disambiguator) =
zalsa_local.disambiguate(self.ingredient_index, Revision::start(), data_hash);
let entity_key = KeyStruct {
query_key,
let key_struct = KeyStruct {
disambiguator,
data_hash,
};
let (id, new_id) = self.intern(entity_key);
zalsa_local.add_output(self.database_key_index(id).into());
let current_revision = zalsa.current_revision();
if new_id {
// This is a new tracked struct, so create an entry in the struct map.
match zalsa_local.tracked_struct_id(&key_struct) {
Some(id) => {
// The struct already exists in the intern map.
zalsa_local.add_output(self.database_key_index(id).into());
self.update(zalsa, current_revision, id, &current_deps, fields);
C::struct_from_id(id)
}
self.struct_map.insert(
current_revision,
Value {
id,
key: entity_key,
struct_ingredient_index: self.ingredient_index,
created_at: current_revision,
durability: current_deps.durability,
fields: unsafe { self.to_static(fields) },
revisions: C::new_revisions(current_deps.changed_at),
},
)
} else {
// The struct already exists in the intern map.
// Note that we assume there is at most one executing copy of
// the current query at a time, which implies that the
// struct must exist in `self.struct_map` already
// (if the same query could execute twice in parallel,
// then it would potentially create the same struct twice in parallel,
// which means the interned key could exist but `struct_map` not yet have
// been updated).
match self.struct_map.update(current_revision, id) {
Update::Current(r) => {
// All inputs up to this point were previously
// observed to be green and this struct was already
// verified. Therefore, the durability ought not to have
// changed (nor the field values, but the user could've
// done something stupid, so we can't *assert* this is true).
assert!(C::deref_struct(r).durability == current_deps.durability);
r
}
Update::Outdated(mut data_ref) => {
let data = &mut *data_ref;
// SAFETY: We assert that the pointer to `data.revisions`
// is a pointer into the database referencing a value
// from a previous revision. As such, it continues to meet
// its validity invariant and any owned content also continues
// to meet its safety invariant.
unsafe {
C::update_fields(
current_revision,
&mut data.revisions,
self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)),
fields,
);
}
if current_deps.durability < data.durability {
data.revisions = C::new_revisions(current_revision);
}
data.durability = current_deps.durability;
data.created_at = current_revision;
data_ref.freeze()
}
None => {
// This is a new tracked struct, so create an entry in the struct map.
let id = self.allocate(zalsa, zalsa_local, current_revision, &current_deps, fields);
zalsa_local.add_output(self.database_key_index(id).into());
zalsa_local.store_tracked_struct_id(key_struct, id);
C::struct_from_id(id)
}
}
}
/// Given the id of a tracked struct created in this revision,
/// returns a pointer to the struct.
/// Stores the data for a newly created tracked struct and returns its id.
///
/// An id from the free list is reused when one is available; otherwise a
/// fresh slot is allocated in the table. Either way the stored value has
/// `updated_at` set to `Some(current_revision)`.
///
/// # Panics
///
/// If an id popped from the free list does not have `None` in `updated_at`.
fn allocate<'db>(
    &'db self,
    zalsa: &'db Zalsa,
    zalsa_local: &'db ZalsaLocal,
    current_revision: Revision,
    current_deps: &StampedValue<()>,
    fields: C::Fields<'db>,
) -> Id {
    // A closure so that `fields` is consumed on whichever path is taken.
    let make_value = || Value {
        updated_at: AtomicCell::new(Some(current_revision)),
        durability: current_deps.durability,
        fields: unsafe { self.to_static(fields) },
        revisions: C::new_revisions(current_deps.changed_at),
    };

    match self.free_list.pop() {
        Some(id) => {
            let slot = Self::data_raw(zalsa.table(), id);
            assert!(
                unsafe { (*slot).updated_at.load().is_none() },
                "free list entry for `{id:?}` does not have `None` for `updated_at`"
            );

            // The slot was initialized earlier, so assign with `*slot = ...`
            // (not a raw write): that drops the old contents.
            unsafe {
                *slot = make_value();
            }

            id
        }
        None => {
            zalsa_local.allocate::<Value<C>>(zalsa.table(), self.ingredient_index, make_value())
        }
    }
}
/// Update the data for `id` in place with the new `fields` -- the value is write-locked
/// (its `updated_at` is set to `None`) for the duration of the update.
///
/// # Panics
///
/// If the struct has not been created in this revision.
pub fn lookup_struct<'db>(&'db self, db: &'db dyn Database, id: Id) -> C::Struct<'db> {
let current_revision = db.zalsa().current_revision();
self.struct_map.get(current_revision, id)
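/// * If the value is not present in the table.
/// * If the value is write-locked, or the write lock cannot be acquired
///   (which means `id` was leaked across threads).
///
/// If the value is already updated in this revision, it is left untouched.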
fn update<'db>(
&'db self,
zalsa: &'db Zalsa,
current_revision: Revision,
id: Id,
current_deps: &StampedValue<()>,
fields: C::Fields<'db>,
) {
let data_raw = Self::data_raw(zalsa.table(), id);
// The protocol is:
//
// * When we begin updating, we store `None` in the `updated_at` field
// * When completed, we store `Some(current_revision)` in `updated_at`
//
// No matter what mischief users get up to, it should be impossible for us to
// observe `None` in `updated_at`. The `id` should only be associated with one
// query and that query can only be running in one thread at a time.
//
// We *can* observe `Some(current_revision)` however, which means that this
// tracked struct is already updated for this revision.
// In that case we should not modify or touch it because there may be
// `&`-references to its contents floating around.
//
// Observing `Some(current_revision)` can happen in two scenarios: leaks (tsk tsk)
// but also the scenario embodied by the test `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`:
//
// * Revision 1:
// * Tracked function F creates tracked struct S
// * F reads input I
// * Revision 2: I is changed, F is re-executed
//
// When F is re-executed in rev 2, we first try to validate F's inputs/outputs,
// which is the list [output: S, input: I]. As no inputs have changed by the time
// we reach S, we mark it as verified. But then input I is seen to have changed,
// and so we re-execute F. Note that we *know* that S will have the same value
// (barring program bugs).
//
// Further complicating things: it is possible that F calls F2
// and gives it (e.g.) S as one of its arguments. Validating F2 may cause F2 to
// re-execute which means that it may indeed have read from S's fields
// during the current revision and thus obtained an `&` reference to those fields
// that is still live.
// Nothing has been claimed yet: this only loads the atomic `updated_at` cell
// to find out which state the value is in before we try to write-lock it.
let last_updated_at = unsafe { (*data_raw).updated_at.load() };
// `None` means some other update is in progress right now (see protocol above).
assert!(
last_updated_at.is_some(),
"two concurrent writers to {id:?}, should not be possible"
);
if last_updated_at == Some(current_revision) {
// already read-locked
return;
}
// Acquire the write-lock. This can only fail if there is a parallel thread
// reading from this same `id`, which can only happen if the user has leaked it.
// Tsk tsk.
let swapped_out = unsafe { (*data_raw).updated_at.swap(None) };
if swapped_out != last_updated_at {
panic!(
"failed to acquire write lock, id `{id:?}` must have been leaked across threads"
);
}
// SAFETY: Marking as mut requires exclusive access for the duration of
// the `mut`. We have now *claimed* this data by swapping in `None`,
// any attempt to read concurrently will panic.
let data = unsafe { &mut *data_raw };
// SAFETY: We assert that the pointer to `data.revisions`
// is a pointer into the database referencing a value
// from a previous revision. As such, it continues to meet
// its validity invariant and any owned content also continues
// to meet its safety invariant.
unsafe {
C::update_fields(
current_revision,
&mut data.revisions,
self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)),
fields,
);
}
// If the durability went down, replace the per-field revisions with fresh
// ones created for the current revision.
if current_deps.durability < data.durability {
data.revisions = C::new_revisions(current_revision);
}
data.durability = current_deps.durability;
// Release the write-lock: the value now counts as updated in this revision.
let swapped_out = data.updated_at.swap(Some(current_revision));
assert!(swapped_out.is_none());
}
/// Fetch the data for a given id created by this ingredient from the table,
/// giving it the appropriate type.
fn data<'t>(table: &'t Table, id: Id) -> &'t Value<C> {
table.get(id)
}
/// Like `data`, but yields a raw pointer to the `Value<C>` stored for `id`
/// instead of a shared reference.
fn data_raw(table: &Table, id: Id) -> *mut Value<C> {
    table.get_raw::<Value<C>>(id)
}
/// Deletes the given entities. This is used after a query `Q` executes and we can compare
@ -400,8 +448,29 @@ where
},
});
if let Some(key) = self.struct_map.delete(id) {
self.keys.remove(&key);
let zalsa = db.zalsa();
let current_revision = zalsa.current_revision();
let data = Self::data(zalsa.table(), id);
// We want to set `updated_at` to `None`, signalling that other field values
// cannot be read. The current value should be `Some(R0)` for some older revision.
match data.updated_at.load() {
None => {
panic!("cannot delete write-locked id `{id:?}`; value leaked across threads");
}
Some(r) => {
if r == current_revision {
panic!(
"cannot delete read-locked id `{id:?}`; \
value leaked across threads or user functions not deterministic"
)
}
if data.updated_at.compare_exchange(Some(r), None).is_err() {
panic!("race occurred when deleting value `{id:?}`")
}
}
}
for dependent_fn in self.dependent_fns.iter() {
@ -409,6 +478,9 @@ where
.lookup_ingredient(dependent_fn)
.salsa_struct_deleted(db, id);
}
// now that all cleanup has occurred, make available for re-use
self.free_list.push(id);
}
/// Adds a dependent function (one keyed by this tracked struct) to our list.
@ -420,9 +492,67 @@ where
/// Return reference to the field data ignoring dependency tracking.
/// Used for debugging.
pub fn leak_fields<'db>(&'db self, s: C::Struct<'db>) -> &'db C::Fields<'db> {
let value = C::deref_struct(s);
unsafe { value.to_self_ref(&value.fields) }
pub fn leak_fields<'db>(
&'db self,
db: &'db dyn Database,
s: C::Struct<'db>,
) -> &'db C::Fields<'db> {
let id = C::deref_struct(s);
let value = Self::data(db.zalsa().table(), id);
unsafe { self.to_self_ref(&value.fields) }
}
/// Reads a field of the tracked struct `s`, recording the read as a dependency.
///
/// The value is first read-locked for the current revision; then a tracked read
/// of the field ingredient for `field_index` is reported, using the value's
/// durability and the revision stored for that field.
///
/// Note that this function returns the entire tuple of value fields.
/// The caller is responsible for selecting the appropriate element.
pub fn field<'db>(
    &'db self,
    db: &'db dyn crate::Database,
    s: C::Struct<'db>,
    field_index: usize,
) -> &'db C::Fields<'db> {
    let (zalsa, zalsa_local) = db.zalsas();
    let id = C::deref_struct(s);
    let dependency = DependencyIndex {
        ingredient_index: self.ingredient_index.successor(field_index),
        key_index: Some(id),
    };

    let value = Self::data(zalsa.table(), id);

    // Lock before touching any other part of the value.
    self.read_lock(value, zalsa.current_revision());

    let changed_at = value.revisions[field_index];
    zalsa_local.report_tracked_read(dependency, value.durability, changed_at);

    unsafe { self.to_self_ref(&value.fields) }
}
fn read_lock(&self, data: &Value<C>, current_revision: Revision) {
loop {
match data.updated_at.load() {
None => {
panic!("access to field whilst the value is being initialized");
}
Some(r) => {
if r == current_revision {
return;
}
if data
.updated_at
.compare_exchange(Some(r), Some(current_revision))
.is_ok()
{
break;
}
}
}
}
}
}
@ -453,13 +583,13 @@ where
fn mark_validated_output<'db>(
&'db self,
db: &'db dyn Database,
_db: &'db dyn Database,
_executor: DatabaseKeyIndex,
output_key: Option<crate::Id>,
_output_key: Option<crate::Id>,
) {
let current_revision = db.zalsa().current_revision();
let output_key = output_key.unwrap();
self.struct_map.validate(current_revision, output_key);
// we used to update the `updated_at` field here but now we do it lazily when data is accessed
//
// FIXME: delete this method
}
fn remove_stale_output(
@ -475,14 +605,6 @@ where
self.delete_entity(db.as_dyn_database(), stale_output_key.unwrap());
}
fn requires_reset_for_new_revision(&self) -> bool {
true
}
fn reset_for_new_revision(&mut self) {
self.struct_map.drop_deleted_entries();
}
fn salsa_struct_deleted(&self, _db: &dyn Database, _id: crate::Id) {
panic!("unexpected call: interned ingredients do not register for salsa struct deletion events");
}
@ -494,6 +616,12 @@ where
fn debug_name(&self) -> &'static str {
C::DEBUG_NAME
}
/// This ingredient asks for no reset when a new revision starts;
/// the matching `reset_for_new_revision` does nothing.
fn requires_reset_for_new_revision(&self) -> bool {
false
}
/// Intentionally empty: `requires_reset_for_new_revision` returns `false`.
fn reset_for_new_revision(&mut self) {}
}
impl<C> std::fmt::Debug for IngredientImpl<C>
@ -506,45 +634,3 @@ where
.finish()
}
}
impl<C> Value<C>
where
C: Configuration,
{
/// Access to this value field.
/// Note that this function returns the entire tuple of value fields.
/// The caller is responsible for selecting the appropriate element.
pub fn field<'db>(
&'db self,
db: &dyn crate::Database,
field_index: usize,
) -> &'db C::Fields<'db> {
let zalsa_local = db.zalsa_local();
let field_ingredient_index = self.struct_ingredient_index.successor(field_index);
let changed_at = self.revisions[field_index];
zalsa_local.report_tracked_read(
DependencyIndex {
ingredient_index: field_ingredient_index,
key_index: Some(self.id.as_id()),
},
self.durability,
changed_at,
);
unsafe { self.to_self_ref(&self.fields) }
}
unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> {
unsafe { std::mem::transmute(fields) }
}
}
impl<C> AsId for Value<C>
where
C: Configuration,
{
fn as_id(&self) -> Id {
self.id
}
}

View file

@ -1,280 +0,0 @@
use std::{
ops::{Deref, DerefMut},
sync::Arc,
};
use crossbeam::queue::SegQueue;
use dashmap::mapref::one::RefMut;
use crate::{alloc::Alloc, hash::FxDashMap, Id, Revision};
use super::{Configuration, KeyStruct, Value};
pub(crate) struct StructMap<C>
where
C: Configuration,
{
map: Arc<FxDashMap<Id, Alloc<Value<C>>>>,
/// When specific entities are deleted, their data is added
/// to this vector rather than being immediately freed. This is because we may have
/// references to that data floating about that are tied to the lifetime of some
/// `&db` reference. This queue itself is not freed until we have an `&mut db` reference,
/// guaranteeing that there are no more references to it.
deleted_entries: SegQueue<Alloc<Value<C>>>,
}
pub(crate) struct StructMapView<C>
where
C: Configuration,
{
map: Arc<FxDashMap<Id, Alloc<Value<C>>>>,
}
impl<C: Configuration> Clone for StructMapView<C> {
fn clone(&self) -> Self {
Self {
map: self.map.clone(),
}
}
}
/// Outcome of [`StructMap`][]'s `update` method.
pub(crate) enum Update<'db, C: Configuration> {
    /// The struct had not yet been verified in this revision.
    /// The wrapped [`UpdateRef`][] grants mutable access to the stored value
    /// so that the caller can compare its fields and overwrite them.
    Outdated(UpdateRef<'db, C>),

    /// The struct was already verified in this revision: every input read
    /// before its creation was up-to-date, so its field contents ought not
    /// to have changed (barring user error). Only a shared handle is
    /// returned because the caller cannot safely modify the fields anymore.
    Current(C::Struct<'db>),
}
impl<C> StructMap<C>
where
C: Configuration,
{
/// Create an empty map with an empty deleted-entries queue.
pub fn new() -> Self {
Self {
map: Arc::new(FxDashMap::default()),
deleted_entries: SegQueue::new(),
}
}
/// Get a secondary view onto this struct-map that can be used to fetch entries.
pub fn view(&self) -> StructMapView<C> {
StructMapView {
map: self.map.clone(),
}
}
/// Insert the given tracked struct value into the map.
///
/// # Panics
///
/// * If value with same `value.id` is already present in the map.
/// * If value not created in current revision.
pub fn insert(&self, current_revision: Revision, value: Value<C>) -> C::Struct<'_> {
assert_eq!(value.created_at, current_revision);
let id = value.id;
let boxed_value = Alloc::new(value);
let pointer = boxed_value.as_raw();
let old_value = self.map.insert(id, boxed_value);
assert!(old_value.is_none()); // ...strictly speaking we probably need to abort here
// Unsafety clause:
//
// * The box is owned by self and, although the box has been moved,
// the pointer is to the contents of the box, which have a stable
// address.
// * Values are only removed or altered when we have `&mut self`.
unsafe { C::struct_from_raw(pointer) }
}
/// Mark the struct `id` as validated in `current_revision` by setting its
/// `created_at` to that revision.
///
/// # Panics
///
/// * If the value is not present in the map.
/// * If the value's `created_at` is not older than `current_revision`.
pub fn validate(&self, current_revision: Revision, id: Id) {
let mut data = self.map.get_mut(&id).unwrap();
// UNSAFE: We never permit `&`-access in the current revision until data.created_at
// has been updated to the current revision (which we check below).
let data = unsafe { data.as_mut() };
// Never update a struct twice in the same revision.
assert!(data.created_at < current_revision);
data.created_at = current_revision;
}
/// Get mutable access to the data for `id` -- this holds a write lock for the duration
/// of the returned value.
///
/// Returns [`Update::Current`] (without mutable access) if the value was already
/// updated in this revision, and [`Update::Outdated`] otherwise.
///
/// # Panics
///
/// * If the value is not present in the map.
pub fn update(&self, current_revision: Revision, id: Id) -> Update<'_, C> {
let mut data = self.map.get_mut(&id).unwrap();
// UNSAFE: We never permit `&`-access in the current revision until data.created_at
// has been updated to the current revision (which we check below).
let data_ref = unsafe { data.as_mut() };
// Subtle: it's possible that this struct was already validated
// in this revision. What can happen (e.g., in the test
// `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`)
// is that
//
// * Revision 1:
// * Tracked function F creates tracked struct S
// * F reads input I
//
// In Revision 2, I is changed, and F is re-executed.
// We try to validate F's inputs/outputs, which is the list [output: S, input: I].
// As no inputs have changed by the time we reach S, we mark it as verified.
// But then input I is seen to have changed, and so we re-execute F.
// Note that we *know* that S will have the same value (barring program bugs).
//
// Further complicating things: it is possible that F calls F2
// and gives it (e.g.) S as one of its arguments. Validating F2 may cause F2 to
// re-execute which means that it may indeed have read from S's fields
// during the current revision and thus obtained an `&` reference to those fields
// that is still live.
//
// For this reason, we just return `Update::Current` in this case, which gives
// the caller no mutable access, ensuring that the calling
// code cannot violate that `&`-reference.
if data_ref.created_at == current_revision {
drop(data);
return Update::Current(Self::get_from_map(&self.map, current_revision, id));
}
data_ref.created_at = current_revision;
Update::Outdated(UpdateRef { guard: data })
}
/// Lookup an existing tracked struct from the map.
///
/// # Panics
///
/// * If the value is not present in the map.
/// * If the value has not been updated in this revision.
pub fn get(&self, current_revision: Revision, id: Id) -> C::Struct<'_> {
Self::get_from_map(&self.map, current_revision, id)
}
/// Helper function, provides shared functionality for [`StructMapView`][]
///
/// # Panics
///
/// * If the value is not present in the map.
/// * If the value has not been updated in this revision.
fn get_from_map(
map: &FxDashMap<Id, Alloc<Value<C>>>,
current_revision: Revision,
id: Id,
) -> C::Struct<'_> {
let data = map.get(&id).unwrap();
// UNSAFE: We permit `&`-access in the current revision once data.created_at
// has been updated to the current revision (which we check below).
let data_ref: &Value<C> = unsafe { data.as_ref() };
// Before we drop the lock, check that the value has
// been updated in this revision. This is what allows us to hand out
// the `C::Struct<'_>` returned below.
let created_at = data_ref.created_at;
assert!(
created_at == current_revision,
"access to tracked struct from previous revision"
);
// Unsafety clause:
//
// * Value will not be updated again in this revision,
// and revision will not change so long as runtime is shared
// * We only remove values from the map when we have `&mut self`
unsafe { C::struct_from_raw(data.as_raw()) }
}
/// Remove the entry for `id` from the map.
///
/// Returns the key of the removed entry, or `None` if `id` was not in the map.
///
/// NB. the data won't actually be freed until `drop_deleted_entries` is called.
pub fn delete(&self, id: Id) -> Option<KeyStruct> {
if let Some((_, data)) = self.map.remove(&id) {
// UNSAFE: The `key` field is immutable once the `Value` is created.
let key = unsafe { data.as_ref() }.key;
self.deleted_entries.push(data);
Some(key)
} else {
None
}
}
/// Drop all entries deleted until now.
pub fn drop_deleted_entries(&mut self) {
// Replaces the queue with an empty one; the old queue (and its entries) is dropped here.
std::mem::take(&mut self.deleted_entries);
}
}
impl<C: Configuration> StructMapView<C> {
    /// Fetch the tracked struct for the given `id` from the shared map.
    ///
    /// # Panics
    ///
    /// * If the value is not present in the map.
    /// * If the value has not been updated in this revision.
    pub fn get(&self, current_revision: Revision, id: Id) -> C::Struct<'_> {
        let map = &self.map;
        StructMap::<C>::get_from_map(map, current_revision, id)
    }
}
/// A mutable reference to the data for a single struct.
/// Can be "frozen" to yield an `&` that will remain valid
/// until the end of the revision.
pub(crate) struct UpdateRef<'db, C>
where
C: Configuration,
{
/// Write guard on the map entry; held for as long as this `UpdateRef` is alive.
guard: RefMut<'db, Id, Alloc<Value<C>>>,
}
impl<'db, C: Configuration> UpdateRef<'db, C> {
    /// Finalize this update, freezing the value for the rest of the revision.
    ///
    /// Consumes the `UpdateRef` and returns the struct handle built from the
    /// raw pointer to the entry.
    pub fn freeze(self) -> C::Struct<'db> {
        let raw = self.guard.as_raw();
        // Unsafety clause: same reasoning as for `StructMap::get`.
        unsafe { C::struct_from_raw(raw) }
    }
}
/// Shared access to the `Value` behind the guard.
impl<C: Configuration> Deref for UpdateRef<'_, C> {
    type Target = Value<C>;

    fn deref(&self) -> &Value<C> {
        // NOTE(review): relies on the guard giving this `UpdateRef` exclusive
        // access to the entry -- confirm against `Alloc::as_ref`'s contract.
        unsafe { self.guard.as_ref() }
    }
}
impl<C> DerefMut for UpdateRef<'_, C>
where
C: Configuration,
{
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.guard.as_mut() }
}
}

View file

@ -1,8 +1,8 @@
use crate::{
id::AsId, ingredient::Ingredient, key::DependencyIndex, zalsa::IngredientIndex, Database, Id,
};
use std::marker::PhantomData;
use super::{struct_map::StructMapView, Configuration};
use crate::{ingredient::Ingredient, zalsa::IngredientIndex, Database, Id};
use super::{Configuration, Value};
/// Created for each tracked struct.
/// This ingredient only stores the "id" fields.
@ -19,50 +19,20 @@ where
/// Index of this ingredient in the database (used to construct database-ids, etc).
ingredient_index: IngredientIndex,
field_index: usize,
struct_map: StructMapView<C>,
phantom: PhantomData<fn() -> Value<C>>,
}
impl<C> FieldIngredientImpl<C>
where
C: Configuration,
{
pub(super) fn new(
struct_index: IngredientIndex,
field_index: usize,
struct_map: &StructMapView<C>,
) -> Self {
pub(super) fn new(struct_index: IngredientIndex, field_index: usize) -> Self {
Self {
ingredient_index: struct_index.successor(field_index),
field_index,
struct_map: struct_map.clone(),
phantom: PhantomData,
}
}
unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> {
unsafe { std::mem::transmute(fields) }
}
/// Access to this value field.
/// Note that this function returns the entire tuple of value fields.
/// The caller is responsible for selecting the appropriate element.
pub fn field<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> {
let zalsa_local = db.zalsa_local();
let current_revision = db.zalsa().current_revision();
let data = self.struct_map.get(current_revision, id);
let data = C::deref_struct(data);
let changed_at = data.revisions[self.field_index];
zalsa_local.report_tracked_read(
DependencyIndex {
ingredient_index: self.ingredient_index,
key_index: Some(id.as_id()),
},
data.durability,
changed_at,
);
unsafe { self.to_self_ref(&data.fields) }
}
}
impl<C> Ingredient for FieldIngredientImpl<C>
@ -83,10 +53,9 @@ where
input: Option<Id>,
revision: crate::Revision,
) -> bool {
let current_revision = db.zalsa().current_revision();
let zalsa = db.zalsa();
let id = input.unwrap();
let data = self.struct_map.get(current_revision, id);
let data = C::deref_struct(data);
let data = <super::IngredientImpl<C>>::data(zalsa.table(), id);
let field_changed_at = data.revisions[self.field_index];
field_changed_at > revision
}

View file

@ -24,6 +24,13 @@ use crate::{Database, DatabaseKeyIndex, Durability, Revision};
/// Do not implement this yourself, instead, apply the [`salsa::db`](`crate::db`) macro
/// to your database.
pub unsafe trait ZalsaDatabase: Any {
/// Plumbing method: fetch the `Zalsa` and the `ZalsaLocal` together.
/// Cheaper than two separate calls when both are needed, since only a
/// single vtable dispatch is performed.
#[doc(hidden)]
fn zalsas(&self) -> (&Zalsa, &ZalsaLocal) {
    let zalsa = self.zalsa();
    let zalsa_local = self.zalsa_local();
    (zalsa, zalsa_local)
}
/// Plumbing method: Access the internal salsa methods.
#[doc(hidden)]
fn zalsa(&self) -> &Zalsa;

View file

@ -9,6 +9,7 @@ use crate::runtime::StampedValue;
use crate::table::PageIndex;
use crate::table::Table;
use crate::tracked_struct::Disambiguator;
use crate::tracked_struct::KeyStruct;
use crate::zalsa::IngredientIndex;
use crate::Cancelled;
use crate::Cycle;
@ -246,7 +247,7 @@ impl ZalsaLocal {
entity_index: IngredientIndex,
reset_at: Revision,
data_hash: u64,
) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) {
) -> (StampedValue<()>, Disambiguator) {
assert!(
self.query_in_progress(),
"cannot create a tracked struct disambiguator outside of a tracked function"
@ -262,7 +263,6 @@ impl ZalsaLocal {
let top_query = stack.last_mut().unwrap();
let disambiguator = top_query.disambiguate(data_hash);
(
top_query.database_key_index,
StampedValue {
value: (),
durability: top_query.durability,
@ -273,6 +273,34 @@ impl ZalsaLocal {
})
}
/// Look up the id recorded for `key_struct` in the query at the top of the
/// query stack.
///
/// Returns `None` if that query has no id stored for this key.
///
/// # Panics
///
/// Panics if the query stack is empty (debug builds report this through the
/// `debug_assert!`, release builds through the `unwrap`).
#[track_caller]
pub(crate) fn tracked_struct_id(&self, key_struct: &KeyStruct) -> Option<Id> {
    debug_assert!(
        self.query_in_progress(),
        "cannot create a tracked struct disambiguator outside of a tracked function"
    );

    self.with_query_stack(|stack| {
        let top_query = stack.last().unwrap();
        // `key_struct` is already a reference, so it is passed as-is rather
        // than re-borrowed; `Id` is `Copy`, so `copied` states the intent
        // more precisely than `cloned`.
        top_query.tracked_struct_ids.get(key_struct).copied()
    })
}
/// Record `id` as the id of the tracked struct identified by `key_struct`
/// in the query at the top of the query stack.
///
/// # Panics
///
/// * If the query stack is empty.
/// * If an id was already stored for `key_struct`.
#[track_caller]
pub(crate) fn store_tracked_struct_id(&self, key_struct: KeyStruct, id: Id) {
    debug_assert!(
        self.query_in_progress(),
        "cannot create a tracked struct disambiguator outside of a tracked function"
    );

    self.with_query_stack(|stack| {
        let frame = stack.last_mut().unwrap();
        let previous = frame.tracked_struct_ids.insert(key_struct, id);
        // Each key may be stored at most once per query execution.
        assert!(
            previous.is_none(),
            "overwrote a previous id for `{key_struct:?}`"
        );
    })
}
/// Starts unwinding the stack if the current revision is cancelled.
///
/// This method can be called by query implementations that perform
@ -308,6 +336,7 @@ impl ZalsaLocal {
impl std::panic::RefUnwindSafe for ZalsaLocal {}
/// Summarizes "all the inputs that a query used"
/// and "all the outputs it wrote to"
#[derive(Debug, Clone)]
pub(crate) struct QueryRevisions {
/// The most recent revision in which some input changed.
@ -318,6 +347,11 @@ pub(crate) struct QueryRevisions {
/// How was this query computed?
pub(crate) origin: QueryOrigin,
/// The ids of tracked structs created by this query.
/// This is used to seed the next round if the query is
/// re-executed.
pub(super) tracked_struct_ids: FxHashMap<KeyStruct, Id>,
}
impl QueryRevisions {
@ -454,6 +488,16 @@ impl ActiveQueryGuard<'_> {
})
}
/// Seed the active query's tracked struct ids with a copy of the ids
/// recorded by the prior execution.
///
/// # Panics
///
/// * If the query stack length differs from the length recorded when this
///   guard's query was pushed.
/// * If the frame already contains tracked struct ids.
pub(crate) fn seed_tracked_struct_ids(&self, tracked_struct_ids: &FxHashMap<KeyStruct, Id>) {
    self.local_state.with_query_stack(|stack| {
        // The guarded query must still be the top of the stack.
        assert_eq!(stack.len(), self.push_len);

        let top_query = stack.last_mut().unwrap();
        assert!(top_query.tracked_struct_ids.is_empty());
        top_query.tracked_struct_ids = tracked_struct_ids.clone();
    })
}
/// Invoked when the query has successfully completed execution.
pub(crate) fn complete(self) -> ActiveQuery {
let query = self.pop_helper();

View file

@ -80,6 +80,7 @@ fn deletion_drops() {
"#]]
.assert_debug_eq(&dropped());
// Now that we execute with rev = 44, the old id is put on the free list
let tracked_struct = input.create_tracked_struct(&db);
assert_eq!(tracked_struct.field(&db).identity, 44);
@ -88,7 +89,9 @@ fn deletion_drops() {
"#]]
.assert_debug_eq(&dropped());
input.set_identity(&mut db).to(66);
// When we execute again with `input1`, that id is re-used, so the old value is deleted
let input1 = MyInput::new(&db, 66);
let _tracked_struct1 = input1.create_tracked_struct(&db);
expect_test::expect![[r#"
[

View file

@ -0,0 +1,100 @@
//! Test that a `tracked` fn on a `salsa::input`
//! compiles and executes successfully.
use std::cell::Cell;
use common::LogDatabase;
use expect_test::expect;
mod common;
use salsa::{Database, Setter};
use test_log::test;
// Thread-local counter, initially 0. `function` reads it directly (not through
// salsa), and the test changes it between revisions.
thread_local! {
static COUNTER: Cell<usize> = const { Cell::new(0) };
}
// Input with two fields: `function` reads both, `counter_field` reads `field2`;
// the test modifies only `field2` between revisions.
#[salsa::input]
struct MyInput {
field1: u32,
field2: u32,
}
// Tracked struct whose single field is filled from the leaked `COUNTER` value.
#[salsa::tracked]
struct MyTracked<'db> {
counter: usize,
}
// Query under test. Returns the result of `counter_field` together with the
// tracked struct's `counter` field as read directly at the end of this function.
// The order of the reads below is the point of the test: `field1` is read first,
// the tracked struct is created and passed to `counter_field`, and `field2` is
// read only afterwards.
#[salsa::tracked]
fn function(db: &dyn Database, input: MyInput) -> (usize, usize) {
// Read input 1
let _field1 = input.field1(db);
// **BAD:** Leak in the value of the counter non-deterministically
let counter = COUNTER.with(|c| c.get());
// Create the tracked struct, which (from salsa's POV), only depends on field1;
// but which actually depends on the leaked value.
let tracked = MyTracked::new(db, counter);
// Read the tracked field
let result = counter_field(db, input, tracked);
// Read input 2. This will cause us to re-execute on revision 2.
let _field2 = input.field2(db);
(result, tracked.counter(db))
}
// Helper query: reads `field2` of the input (so that it depends on it) and then
// returns the `counter` field of the given tracked struct.
#[salsa::tracked]
fn counter_field<'db>(db: &'db dyn Database, input: MyInput, tracked: MyTracked<'db>) -> usize {
// Read input 2. This will cause us to re-execute on revision 2.
let _field2 = input.field2(db);
tracked.counter(db)
}
// Runs `function` in two revisions, changing both `field2` and the leaked
// `COUNTER` in between, and checks the logged events and the returned values.
#[test]
fn test_leaked_inputs_ignored() {
let mut db = common::EventLoggerDatabase::default();
let input = MyInput::new(&db, 10, 20);
let result_in_rev_1 = function(&db, input);
db.assert_logs(expect![[r#"
[
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
"Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(Id(0)) } }",
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
"Event { thread_id: ThreadId(2), kind: WillExecute { database_key: counter_field(Id(400)) } }",
]"#]]);
assert_eq!(result_in_rev_1, (0, 0));
// Modify field2 so that `function` is seen to have changed --
// but only *after* the tracked struct is created.
input.set_field2(&mut db).to(30);
// Also modify the thread-local counter
COUNTER.with(|c| c.set(100));
let result_in_rev_2 = function(&db, input);
db.assert_logs(expect![[r#"
[
"Event { thread_id: ThreadId(2), kind: DidSetCancellationFlag }",
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
"Event { thread_id: ThreadId(2), kind: WillExecute { database_key: counter_field(Id(400)) } }",
"Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(Id(0)) } }",
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
]"#]]);
// Salsa will re-execute `counter_field` before re-executing
// `function` since, from what it can see, no inputs have changed
// before `counter_field` is called. This will read the field of
// the tracked struct which means it will be *fixed* at `0`.
// When we re-execute `function` later, we ignore the new
// value of 100 since the struct has already been read during
// this revision.
//
// Contrast with preverify-struct-with-leaked-data-2.rs.
assert_eq!(result_in_rev_2, (0, 0));
}

View file

@ -25,7 +25,7 @@ struct MyTracked<'db> {
}
#[salsa::tracked]
fn function(db: &dyn Database, input: MyInput) -> usize {
fn function(db: &dyn Database, input: MyInput) -> (usize, usize) {
// Read input 1
let _field1 = input.field1(db);
@ -36,9 +36,17 @@ fn function(db: &dyn Database, input: MyInput) -> usize {
// but which actually depends on the leaked value.
let tracked = MyTracked::new(db, counter);
// Read the tracked field
let result = counter_field(db, tracked);
// Read input 2. This will cause us to re-execute on revision 2.
let _field2 = input.field2(db);
(result, tracked.counter(db))
}
// Helper query: returns the `counter` field of the given tracked struct.
// It reads nothing else, so the tracked struct is its only argument.
#[salsa::tracked]
fn counter_field<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> usize {
tracked.counter(db)
}
@ -52,9 +60,11 @@ fn test_leaked_inputs_ignored() {
[
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
"Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(Id(0)) } }",
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
"Event { thread_id: ThreadId(2), kind: WillExecute { database_key: counter_field(Id(0)) } }",
]"#]]);
assert_eq!(result_in_rev_1, 0);
assert_eq!(result_in_rev_1, (0, 0));
// Modify field2 so that `function` is seen to have changed --
// but only *after* the tracked struct is created.
@ -68,12 +78,17 @@ fn test_leaked_inputs_ignored() {
[
"Event { thread_id: ThreadId(2), kind: DidSetCancellationFlag }",
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
"Event { thread_id: ThreadId(2), kind: DidValidateMemoizedValue { database_key: counter_field(Id(0)) } }",
"Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(Id(0)) } }",
"Event { thread_id: ThreadId(2), kind: WillCheckCancellation }",
]"#]]);
// Because salsa did not see any way for the tracked
// struct to have changed, its field values will not have
// been updated, even though in theory they would have
// the leaked value from the counter.
assert_eq!(result_in_rev_2, 0);
// Because salsa does not see any way for the tracked
// struct to have changed, it will re-use the cached return value
// from `counter_field` (`0`) but when we actually recreate
// the cached struct we get the new value (`100`).
//
// Contrast with preverify-struct-with-leaked-data-2.rs.
assert_eq!(result_in_rev_2, (0, 100));
}