diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 21524f8a..c8e65e98 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -63,7 +63,7 @@ macro_rules! setup_tracked_struct { $(#[$attr])* #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] $vis struct $Struct<$db_lt>( - std::ptr::NonNull >>, + salsa::Id, std::marker::PhantomData < & $db_lt salsa::plumbing::tracked_struct::Value < $Struct<'static> > > ); @@ -90,12 +90,12 @@ macro_rules! setup_tracked_struct { type Struct<$db_lt> = $Struct<$db_lt>; - unsafe fn struct_from_raw<$db_lt>(ptr: $NonNull<$zalsa_struct::Value>) -> Self::Struct<$db_lt> { - $Struct(ptr, std::marker::PhantomData) + fn struct_from_id<$db_lt>(id: salsa::Id) -> Self::Struct<$db_lt> { + $Struct(id, std::marker::PhantomData) } - fn deref_struct(s: Self::Struct<'_>) -> &$zalsa_struct::Value { - unsafe { s.0.as_ref() } + fn deref_struct(s: Self::Struct<'_>) -> salsa::Id { + s.0 } fn id_fields(fields: &Self::Fields<'_>) -> impl std::hash::Hash { @@ -141,13 +141,13 @@ macro_rules! setup_tracked_struct { impl<$db_lt> $zalsa::LookupId<$db_lt> for $Struct<$db_lt> { fn lookup_id(id: salsa::Id, db: &$db_lt dyn $zalsa::Database) -> Self { - $Configuration::ingredient(db).lookup_struct(db, id) + $Struct(id, std::marker::PhantomData) } } impl $zalsa::AsId for $Struct<'_> { fn as_id(&self) -> $zalsa::Id { - unsafe { self.0.as_ref() }.as_id() + self.0 } } @@ -199,12 +199,13 @@ macro_rules! setup_tracked_struct { } $( - $field_getter_vis fn $field_getter_id<$Db>(&self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($field_option, $db_lt, $field_ty) + $field_getter_vis fn $field_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($field_option, $db_lt, $field_ty) where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = unsafe { self.0.as_ref() }.field(db.as_dyn_database(), $field_index); + let db = db.as_dyn_database(); + let fields = $Configuration::ingredient(db).field(db, self, $field_index); $crate::maybe_clone!( $field_option, $field_ty, @@ -216,7 +217,7 @@ macro_rules! setup_tracked_struct { /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).leak_fields(this); + let fields = $Configuration::ingredient(db).leak_fields(db, this); let mut f = f.debug_struct(stringify!($Struct)); let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this)); $( diff --git a/examples/calc/parser.rs b/examples/calc/parser.rs index e5fe0d83..15f2ad91 100644 --- a/examples/calc/parser.rs +++ b/examples/calc/parser.rs @@ -378,7 +378,7 @@ fn parse_print() { let expected = expect_test::expect![[r#" ( Program { - [salsa id]: Id(0), + [salsa id]: Id(400), statements: [ Statement { span: Span { @@ -441,7 +441,7 @@ fn parse_example() { let expected = expect_test::expect![[r#" ( Program { - [salsa id]: Id(0), + [salsa id]: Id(1000), statements: [ Statement { span: Span { @@ -451,7 +451,7 @@ fn parse_example() { }, data: Function( Function { - [salsa id]: Id(0), + [salsa id]: Id(c00), name: FunctionId { text: "area_rectangle", }, @@ -513,7 +513,7 @@ fn parse_example() { }, data: Function( Function { - [salsa id]: Id(1), + [salsa id]: Id(c01), name: FunctionId { text: "area_circle", }, @@ -714,7 +714,7 @@ fn parse_error() { let expected = expect_test::expect![[r#" ( Program { - [salsa id]: Id(0), + [salsa id]: Id(400), statements: [], }, [ @@ -736,7 +736,7 @@ fn parse_precedence() { let expected = expect_test::expect![[r#" ( Program { - [salsa id]: Id(0), + [salsa id]: Id(400), statements: [ Statement { span: Span { diff --git a/examples/calc/type_check.rs b/examples/calc/type_check.rs index e84d0c9b..5276c487 100644 --- a/examples/calc/type_check.rs +++ b/examples/calc/type_check.rs @@ -268,7 +268,7 @@ fn fix_bad_variable_in_function() { expect![[r#" [ "Event: Event { thread_id: ThreadId(11), kind: WillExecute { database_key: parse_statements(Id(0)) } }", - "Event: Event { thread_id: ThreadId(11), kind: WillExecute { database_key: type_check_function(Id(0)) } }", + "Event: Event { thread_id: ThreadId(11), kind: WillExecute { database_key: type_check_function(Id(1400)) } }", ] "#]], )], diff --git a/src/active_query.rs b/src/active_query.rs index eb728597..3e0bac0a 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -1,10 +1,12 @@ +use rustc_hash::FxHashMap; + use crate::{ durability::Durability, hash::{FxIndexMap, FxIndexSet}, key::{DatabaseKeyIndex, DependencyIndex}, - tracked_struct::Disambiguator, + tracked_struct::{Disambiguator, KeyStruct}, zalsa_local::EMPTY_DEPENDENCIES, - Cycle, Revision, + Cycle, Id, Revision, }; use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions}; @@ -35,10 +37,18 @@ pub(crate) struct ActiveQuery { /// Stores the entire cycle, if one is found and this query is part of it. pub(crate) cycle: Option, - /// When new entities are created, their data is hashed, and the resulting + /// When new tracked structs are created, their data is hashed, and the resulting /// hash is added to this map. If it is not present, then the disambiguator is 0. /// Otherwise it is 1 more than the current value (which is incremented). + /// + /// This table starts empty as the query begins and is gradually populated. + /// Note that if a query executes in 2 different revisions but creates the same + /// set of tracked structs, they will get the same disambiguator values. disambiguator_map: FxIndexMap, + + /// Map from tracked struct keys (which include the hash + disambiguator) to their + /// final id. + pub(crate) tracked_struct_ids: FxHashMap, } impl ActiveQuery { @@ -51,6 +61,7 @@ impl ActiveQuery { untracked_read: false, cycle: None, disambiguator_map: Default::default(), + tracked_struct_ids: Default::default(), } } @@ -106,6 +117,7 @@ impl ActiveQuery { changed_at: self.changed_at, origin, durability: self.durability, + tracked_struct_ids: self.tracked_struct_ids.clone(), } } diff --git a/src/alloc.rs b/src/alloc.rs index 557f429a..6562bef7 100644 --- a/src/alloc.rs +++ b/src/alloc.rs @@ -14,10 +14,6 @@ impl Alloc { } } - pub fn as_raw(&self) -> NonNull { - self.data - } - pub unsafe fn as_ref(&self) -> &T { unsafe { self.data.as_ref() } } diff --git a/src/function/execute.rs b/src/function/execute.rs index ddf58736..d3962754 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -39,6 +39,12 @@ where }, }); + // If we already executed this query once, then use the tracked-struct ids from the + // previous execution as the starting point for the new one. + if let Some(old_memo) = &opt_old_memo { + active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); + } + // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! let database_key_index = active_query.database_key_index; diff --git a/src/function/specify.rs b/src/function/specify.rs index 9aabac54..bed71a13 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -68,6 +68,7 @@ where changed_at: current_deps.changed_at, durability: current_deps.durability, origin: QueryOrigin::Assigned(active_query_key), + tracked_struct_ids: Default::default(), }; if let Some(old_memo) = self.memo_map.get(key) { diff --git a/src/function/store.rs b/src/function/store.rs index 79757106..e09cd1a9 100644 --- a/src/function/store.rs +++ b/src/function/store.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use crossbeam::atomic::AtomicCell; +use rustc_hash::FxHashMap; use crate::{ durability::Durability, @@ -29,6 +30,7 @@ where changed_at: revision, durability, origin: QueryOrigin::BaseInput, + tracked_struct_ids: FxHashMap::default(), }, }; diff --git a/src/table.rs b/src/table.rs index 56682fd9..686ab615 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,4 +1,4 @@ -use std::{any::Any, mem::MaybeUninit}; +use std::{any::Any, cell::UnsafeCell, panic::RefUnwindSafe}; use append_only_vec::AppendOnlyVec; use crossbeam::atomic::AtomicCell; @@ -16,7 +16,7 @@ pub struct Table { pub struct Page { /// The ingredient for elements on this page. - #[allow(dead_code)] // pretty sure we'll need this eventually + #[allow(dead_code)] // pretty sure we'll need this ingredient: IngredientIndex, /// Number of elements of `data` that are initialized. @@ -34,9 +34,15 @@ pub struct Page { /// Vector with data. This is always created with the capacity/length of `PAGE_LEN` /// and uninitialized data. As we initialize new entries, we increment `allocated`. - data: Vec>, + data: Vec>, } +unsafe impl Send for Page {} + +unsafe impl Sync for Page {} + +impl RefUnwindSafe for Page {} + #[derive(Copy, Clone)] pub struct PageIndex(usize); @@ -58,6 +64,12 @@ impl Table { page_ref.get(slot) } + pub fn get_raw(&self, id: Id) -> *mut T { + let (page, slot) = split_id(id); + let page_ref = self.page::(page); + page_ref.get_raw(slot) + } + pub fn page(&self, page: PageIndex) -> &Page { self.pages[page.0].downcast_ref::>().unwrap() } @@ -85,7 +97,15 @@ impl Page { pub(crate) fn get(&self, slot: SlotIndex) -> &T { let len = self.allocated.load(); assert!(slot.0 < len); - unsafe { self.data[slot.0].assume_init_ref() } + unsafe { &*self.data[slot.0].get() } + } + + /// Returns a raw pointer to the given slot. + /// Reads/writes must be coordinated properly with calls to `get`. + pub(crate) fn get_raw(&self, slot: SlotIndex) -> *mut T { + let len = self.allocated.load(); + assert!(slot.0 < len); + self.data[slot.0].get() } pub(crate) fn allocate(&self, page: PageIndex, value: T) -> Result { @@ -95,10 +115,11 @@ impl Page { return Err(value); } + // Initialize entry `index` let data = &self.data[index]; - let data = data.as_ptr() as *mut T; - unsafe { std::ptr::write(data, value) }; + unsafe { std::ptr::write(data.get(), value) }; + // Update the length (this must be done after initialization!) self.allocated.store(index + 1); drop(guard); @@ -108,11 +129,14 @@ impl Page { impl Drop for Page { fn drop(&mut self) { + // Free `self.data` and the data within: to do this, we swap it out with an empty vector + // and then convert it from a `Vec>` with partially uninitialized values + // to a `Vec` with the correct length. This way the `Vec` drop impl can do its job. let mut data = std::mem::replace(&mut self.data, vec![]); let len = self.allocated.load(); unsafe { data.set_len(len); - drop(std::mem::transmute::>, Vec>(data)); + drop(std::mem::transmute::>, Vec>(data)); } } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index b0eb049e..e61825ce 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -1,25 +1,22 @@ -use std::{fmt, hash::Hash, marker::PhantomData, ops::DerefMut, ptr::NonNull}; +use std::{fmt, hash::Hash, marker::PhantomData, ops::DerefMut}; -use crossbeam::atomic::AtomicCell; -use dashmap::mapref::entry::Entry; +use crossbeam::{atomic::AtomicCell, queue::SegQueue}; use tracked_field::FieldIngredientImpl; use crate::{ cycle::CycleRecoveryStrategy, - hash::FxDashMap, - id::AsId, ingredient::{fmt_index, Ingredient, Jar}, ingredient_list::IngredientList, key::{DatabaseKeyIndex, DependencyIndex}, + plumbing::ZalsaLocal, + runtime::StampedValue, salsa_struct::SalsaStructInDb, - zalsa::IngredientIndex, + table::Table, + zalsa::{IngredientIndex, Zalsa}, zalsa_local::QueryOrigin, Database, Durability, Event, Id, Revision, }; -use self::struct_map::{StructMap, Update}; - -mod struct_map; pub mod tracked_field; // ANCHOR: Configuration @@ -47,18 +44,10 @@ pub trait Configuration: Sized + 'static { /// process in a given revision: it occurs only when the struct is newly /// created or, if a struct is being reused, after we have updated its /// fields (or confirmed it is green and no updates are required). - /// - /// # Safety - /// - /// Requires that `ptr` represents a "confirmed" value in this revision, - /// which means that it will remain valid and immutable for the remainder of this - /// revision, represented by the lifetime `'db`. - unsafe fn struct_from_raw<'db>(ptr: NonNull>) -> Self::Struct<'db>; + fn struct_from_id<'db>(id: Id) -> Self::Struct<'db>; - /// Deref the struct to yield the underlying value struct. - /// Since we are still part of the `'db` lifetime in which the struct was created, - /// this deref is safe, and the value-struct fields are immutable and verified. - fn deref_struct(s: Self::Struct<'_>) -> &Value; + /// Deref the struct to yield the underlying id. + fn deref_struct(s: Self::Struct<'_>) -> Id; fn id_fields(fields: &Self::Fields<'_>) -> impl Hash; @@ -115,16 +104,11 @@ impl Jar for JarImpl { &self, struct_index: crate::zalsa::IngredientIndex, ) -> Vec> { - let struct_ingredient = IngredientImpl::new(struct_index); - let struct_map = &struct_ingredient.struct_map.view(); + let struct_ingredient = >::new(struct_index); std::iter::once(Box::new(struct_ingredient) as _) .chain((0..C::FIELD_DEBUG_NAMES.len()).map(|field_index| { - Box::new(FieldIngredientImpl::::new( - struct_index, - field_index, - struct_map, - )) as _ + Box::new(>::new(struct_index, field_index)) as _ })) .collect() } @@ -150,18 +134,6 @@ where /// Our index in the database. ingredient_index: IngredientIndex, - /// Defines the set of live tracked structs. - /// Entries are added to this map when a new struct is created. - /// They are removed when that struct is deleted - /// (i.e., a query completes without having recreated the struct). - keys: FxDashMap, - - /// The number of tracked structs created. - counter: AtomicCell, - - /// Map from the [`Id`][] of each struct to its fields/values. - struct_map: struct_map::StructMap, - /// A list of each tracked function `f` whose key is this /// tracked struct. /// @@ -169,14 +141,20 @@ where /// each of these functions will be notified /// so they can remove any data tied to that instance. dependent_fns: IngredientList, + + /// Phantom data: we fetch `Value` out from `Table` + phantom: PhantomData Value>, + + /// Store freed ids + free_list: SegQueue, } /// Defines the identity of a tracked struct. +/// This is the key to a hashmap that is (initially) +/// stored in the [`ActiveQuery`](`crate::active_query::ActiveQuery`) +/// struct and later moved to the [`Memo`](`crate::function::memo::Memo`). #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] -struct KeyStruct { - /// The active query (i.e., tracked function) that created this tracked struct. - query_key: DatabaseKeyIndex, - +pub(crate) struct KeyStruct { /// The hash of the `#[id]` fields of this struct. /// Note that multiple structs may share the same hash. data_hash: u64, @@ -192,29 +170,32 @@ pub struct Value where C: Configuration, { - /// Index of the struct ingredient. - struct_ingredient_index: IngredientIndex, - - /// The id of this struct in the ingredient. - id: Id, - - /// The key used to create the id. - key: KeyStruct, - /// The durability minimum durability of all inputs consumed /// by the creator query prior to creating this tracked struct. /// If any of those inputs changes, then the creator query may /// create this struct with different values. durability: Durability, - /// The revision when this entity was most recently created. - /// Typically the current revision. - /// Used to detect "leaks" outside of the salsa system -- i.e., - /// access to tracked structs that have not (yet?) been created in the - /// current revision. This should be impossible within salsa queries - /// but it can happen through "leaks" like thread-local data or storing - /// values outside of the root salsa query. - created_at: Revision, + /// The revision when this tracked struct was last updated. + /// This field also acts as a kind of "lock". Once it is equal + /// to `Some(current_revision)`, the fields are locked and + /// cannot change further. This makes it safe to give out `&`-references + /// so long as they do not live longer than the current revision + /// (which is assured by tying their lifetime to the lifetime of an `&`-ref + /// to the database). + /// + /// The struct is updated from an older revision `R0` to the current revision `R1` + /// when the struct is first accessed in `R1`, whether that be because the original + /// query re-created the struct (i.e., by user calling `Struct::new`) or because + /// the struct was read from. (Structs may not be recreated in the new revision if + /// the inputs to the query have not changed.) + /// + /// When re-creating the struct, the field is temporarily set to `None`. + /// This is signal that there is an active `&mut` modifying the other fields: + /// even reading from those fields in that situation would create UB. + /// This `None` value should never be observable by users unless they have + /// leaked a reference across threads somehow. + updated_at: AtomicCell>, /// Fields of this tracked struct. They can change across revisions, /// but they do not change within a particular revision. @@ -240,6 +221,10 @@ where unsafe { std::mem::transmute(fields) } } + unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> { + unsafe { std::mem::transmute(fields) } + } + /// Convert from static back to the db lifetime; used when returning data /// out from this ingredient. unsafe fn to_self_ptr<'db>(&'db self, fields: *mut C::Fields<'static>) -> *mut C::Fields<'db> { @@ -251,10 +236,9 @@ where fn new(index: IngredientIndex) -> Self { Self { ingredient_index: index, - keys: FxDashMap::default(), - counter: AtomicCell::new(0), - struct_map: StructMap::new(), dependent_fns: IngredientList::new(), + phantom: PhantomData, + free_list: Default::default(), } } @@ -266,120 +250,184 @@ where } } - /// Intern a tracked struct key to get a unique tracked struct id. - /// Also returns a bool indicating whether this id was newly created or whether it already existed. - fn intern(&self, key: KeyStruct) -> (Id, bool) { - let (id, new_id) = if let Some(g) = self.keys.get(&key) { - (*g.value(), false) - } else { - match self.keys.entry(key) { - Entry::Occupied(o) => (*o.get(), false), - Entry::Vacant(v) => { - let id = Id::from_u32(self.counter.fetch_add(1)); - v.insert(id); - (id, true) - } - } - }; - - (id, new_id) - } - pub fn new_struct<'db>( &'db self, db: &'db dyn Database, fields: C::Fields<'db>, ) -> C::Struct<'db> { - let zalsa = db.zalsa(); - let zalsa_local = db.zalsa_local(); + let (zalsa, zalsa_local) = db.zalsas(); let data_hash = crate::hash::hash(&C::id_fields(&fields)); - let (query_key, current_deps, disambiguator) = + let (current_deps, disambiguator) = zalsa_local.disambiguate(self.ingredient_index, Revision::start(), data_hash); - let entity_key = KeyStruct { - query_key, + let key_struct = KeyStruct { disambiguator, data_hash, }; - let (id, new_id) = self.intern(entity_key); - zalsa_local.add_output(self.database_key_index(id).into()); - let current_revision = zalsa.current_revision(); - if new_id { - // This is a new tracked struct, so create an entry in the struct map. + match zalsa_local.tracked_struct_id(&key_struct) { + Some(id) => { + // The struct already exists in the intern map. + zalsa_local.add_output(self.database_key_index(id).into()); + self.update(zalsa, current_revision, id, ¤t_deps, fields); + C::struct_from_id(id) + } - self.struct_map.insert( - current_revision, - Value { - id, - key: entity_key, - struct_ingredient_index: self.ingredient_index, - created_at: current_revision, - durability: current_deps.durability, - fields: unsafe { self.to_static(fields) }, - revisions: C::new_revisions(current_deps.changed_at), - }, - ) - } else { - // The struct already exists in the intern map. - // Note that we assume there is at most one executing copy of - // the current query at a time, which implies that the - // struct must exist in `self.struct_map` already - // (if the same query could execute twice in parallel, - // then it would potentially create the same struct twice in parallel, - // which means the interned key could exist but `struct_map` not yet have - // been updated). - - match self.struct_map.update(current_revision, id) { - Update::Current(r) => { - // All inputs up to this point were previously - // observed to be green and this struct was already - // verified. Therefore, the durability ought not to have - // changed (nor the field values, but the user could've - // done something stupid, so we can't *assert* this is true). - assert!(C::deref_struct(r).durability == current_deps.durability); - - r - } - Update::Outdated(mut data_ref) => { - let data = &mut *data_ref; - - // SAFETY: We assert that the pointer to `data.revisions` - // is a pointer into the database referencing a value - // from a previous revision. As such, it continues to meet - // its validity invariant and any owned content also continues - // to meet its safety invariant. - unsafe { - C::update_fields( - current_revision, - &mut data.revisions, - self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)), - fields, - ); - } - if current_deps.durability < data.durability { - data.revisions = C::new_revisions(current_revision); - } - data.durability = current_deps.durability; - data.created_at = current_revision; - data_ref.freeze() - } + None => { + // This is a new tracked struct, so create an entry in the struct map. + let id = self.allocate(zalsa, zalsa_local, current_revision, ¤t_deps, fields); + zalsa_local.add_output(self.database_key_index(id).into()); + zalsa_local.store_tracked_struct_id(key_struct, id); + C::struct_from_id(id) } } } - /// Given the id of a tracked struct created in this revision, - /// returns a pointer to the struct. + fn allocate<'db>( + &'db self, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, + current_revision: Revision, + current_deps: &StampedValue<()>, + fields: C::Fields<'db>, + ) -> Id { + let value = || Value { + updated_at: AtomicCell::new(Some(current_revision)), + durability: current_deps.durability, + fields: unsafe { self.to_static(fields) }, + revisions: C::new_revisions(current_deps.changed_at), + }; + + if let Some(id) = self.free_list.pop() { + let data_raw = Self::data_raw(zalsa.table(), id); + assert!( + unsafe { (*data_raw).updated_at.load().is_none() }, + "free list entry for `{id:?}` does not have `None` for `updated_at`" + ); + + // Overwrite the free-list entry. Use `*foo = ` because the entry + // has been previously initialized and we want to free the old contents. + unsafe { + *data_raw = value(); + } + + id + } else { + zalsa_local.allocate::>(zalsa.table(), self.ingredient_index, value()) + } + } + + /// Get mutable access to the data for `id` -- this holds a write lock for the duration + /// of the returned value. /// /// # Panics /// - /// If the struct has not been created in this revision. - pub fn lookup_struct<'db>(&'db self, db: &'db dyn Database, id: Id) -> C::Struct<'db> { - let current_revision = db.zalsa().current_revision(); - self.struct_map.get(current_revision, id) + /// * If the value is not present in the map. + /// * If the value is already updated in this revision. + fn update<'db>( + &'db self, + zalsa: &'db Zalsa, + current_revision: Revision, + id: Id, + current_deps: &StampedValue<()>, + fields: C::Fields<'db>, + ) { + let data_raw = Self::data_raw(zalsa.table(), id); + + // The protocol is: + // + // * When we begin updating, we store `None` in the `created_at` field + // * When completed, we store `Some(current_revision)` in `created_at` + // + // No matter what mischief users get up to, it should be impossible for us to + // observe `None` in `created_at`. The `id` should only be associated with one + // query and that query can only be running in one thread at a time. + // + // We *can* observe `Some(current_revision)` however, which means that this + // tracked struct is already updated for this revision in two ways. + // In that case we should not modify or touch it because there may be + // `&`-references to its contents floating around. + // + // Observing `Some(current_revision)` can happen in two scenarios: leaks (tsk tsk) + // but also the scenario embodied by the test test `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`: + // + // * Revision 1: + // * Tracked function F creates tracked struct S + // * F reads input I + // * Revision 2: I is changed, F is re-executed + // + // When F is re-executed in rev 2, we first try to validate F's inputs/outputs, + // which is the list [output: S, input: I]. As no inputs have changed by the time + // we reach S, we mark it as verified. But then input I is seen to hvae changed, + // and so we re-execute F. Note that we *know* that S will have the same value + // (barring program bugs). + // + // Further complicating things: it is possible that F calls F2 + // and gives it (e.g.) S as one of its arguments. Validating F2 may cause F2 to + // re-execute which means that it may indeed have read from S's fields + // during the current revision and thus obtained an `&` reference to those fields + // that is still live. + + // UNSAFE: Marking as mut requires exclusive access for the duration of + // the `mut`. We have now *claimed* this data by swapping in `None`, + // any attempt to read concurrently will panic. + let last_updated_at = unsafe { (*data_raw).updated_at.load() }; + assert!( + last_updated_at.is_some(), + "two concurrent writers to {id:?}, should not be possible" + ); + if last_updated_at == Some(current_revision) { + // already read-locked + return; + } + + // Acquire the write-lock. This can only fail if there is a parallel thread + // reading from this same `id`, which can only happen if the user has leaked it. + // Tsk tsk. + let swapped_out = unsafe { (*data_raw).updated_at.swap(None) }; + if swapped_out != last_updated_at { + panic!( + "failed to acquire write lock, id `{id:?}` must have been leaked across threads" + ); + } + + // UNSAFE: Marking as mut requires exclusive access for the duration of + // the `mut`. We have now *claimed* this data by swapping in `None`, + // any attempt to read concurrently will panic. + let data = unsafe { &mut *data_raw }; + + // SAFETY: We assert that the pointer to `data.revisions` + // is a pointer into the database referencing a value + // from a previous revision. As such, it continues to meet + // its validity invariant and any owned content also continues + // to meet its safety invariant. + unsafe { + C::update_fields( + current_revision, + &mut data.revisions, + self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)), + fields, + ); + } + if current_deps.durability < data.durability { + data.revisions = C::new_revisions(current_revision); + } + data.durability = current_deps.durability; + let swapped_out = data.updated_at.swap(Some(current_revision)); + assert!(swapped_out.is_none()); + } + + /// Fetch the data for a given id created by this ingredient from the table, + /// -giving it the appropriate type. + fn data<'t>(table: &'t Table, id: Id) -> &'t Value { + table.get(id) + } + + fn data_raw(table: &Table, id: Id) -> *mut Value { + table.get_raw(id) } /// Deletes the given entities. This is used after a query `Q` executes and we can compare @@ -400,8 +448,29 @@ where }, }); - if let Some(key) = self.struct_map.delete(id) { - self.keys.remove(&key); + let zalsa = db.zalsa(); + let current_revision = zalsa.current_revision(); + let data = Self::data(zalsa.table(), id); + + // We want to set `updated_at` to `None`, signalling that other field values + // cannot be read. The current vaue should be `Some(R0)` for some older revision. + match data.updated_at.load() { + None => { + panic!("cannot delete write-locked id `{id:?}`; value leaked across threads"); + } + + Some(r) => { + if r == current_revision { + panic!( + "cannot delete read-locked id `{id:?}`; \ + value leaked across threads or user functions not deterministic" + ) + } + + if data.updated_at.compare_exchange(Some(r), None).is_err() { + panic!("race occurred when deleting value `{id:?}`") + } + } } for dependent_fn in self.dependent_fns.iter() { @@ -409,6 +478,9 @@ where .lookup_ingredient(dependent_fn) .salsa_struct_deleted(db, id); } + + // now that all cleanup has occurred, make available for re-use + self.free_list.push(id); } /// Adds a dependent function (one keyed by this tracked struct) to our list. @@ -420,9 +492,67 @@ where /// Return reference to the field data ignoring dependency tracking. /// Used for debugging. - pub fn leak_fields<'db>(&'db self, s: C::Struct<'db>) -> &'db C::Fields<'db> { - let value = C::deref_struct(s); - unsafe { value.to_self_ref(&value.fields) } + pub fn leak_fields<'db>( + &'db self, + db: &'db dyn Database, + s: C::Struct<'db>, + ) -> &'db C::Fields<'db> { + let id = C::deref_struct(s); + let value = Self::data(db.zalsa().table(), id); + unsafe { self.to_self_ref(&value.fields) } + } + + /// Access to this value field. + /// Note that this function returns the entire tuple of value fields. + /// The caller is responible for selecting the appropriate element. + pub fn field<'db>( + &'db self, + db: &'db dyn crate::Database, + s: C::Struct<'db>, + field_index: usize, + ) -> &'db C::Fields<'db> { + let (zalsa, zalsa_local) = db.zalsas(); + let id = C::deref_struct(s); + let field_ingredient_index = self.ingredient_index.successor(field_index); + let data = Self::data(zalsa.table(), id); + + self.read_lock(data, zalsa.current_revision()); + + let field_changed_at = data.revisions[field_index]; + + zalsa_local.report_tracked_read( + DependencyIndex { + ingredient_index: field_ingredient_index, + key_index: Some(id), + }, + data.durability, + field_changed_at, + ); + + unsafe { self.to_self_ref(&data.fields) } + } + + fn read_lock(&self, data: &Value, current_revision: Revision) { + loop { + match data.updated_at.load() { + None => { + panic!("access to field whilst the value is being initialized"); + } + Some(r) => { + if r == current_revision { + return; + } + + if data + .updated_at + .compare_exchange(Some(r), Some(current_revision)) + .is_ok() + { + break; + } + } + } + } } } @@ -453,13 +583,13 @@ where fn mark_validated_output<'db>( &'db self, - db: &'db dyn Database, + _db: &'db dyn Database, _executor: DatabaseKeyIndex, - output_key: Option, + _output_key: Option, ) { - let current_revision = db.zalsa().current_revision(); - let output_key = output_key.unwrap(); - self.struct_map.validate(current_revision, output_key); + // we used to update `update_at` field but now we do it lazilly when data is accessed + // + // FIXME: delete this method } fn remove_stale_output( @@ -475,14 +605,6 @@ where self.delete_entity(db.as_dyn_database(), stale_output_key.unwrap()); } - fn requires_reset_for_new_revision(&self) -> bool { - true - } - - fn reset_for_new_revision(&mut self) { - self.struct_map.drop_deleted_entries(); - } - fn salsa_struct_deleted(&self, _db: &dyn Database, _id: crate::Id) { panic!("unexpected call: interned ingredients do not register for salsa struct deletion events"); } @@ -494,6 +616,12 @@ where fn debug_name(&self) -> &'static str { C::DEBUG_NAME } + + fn requires_reset_for_new_revision(&self) -> bool { + false + } + + fn reset_for_new_revision(&mut self) {} } impl std::fmt::Debug for IngredientImpl @@ -506,45 +634,3 @@ where .finish() } } - -impl Value -where - C: Configuration, -{ - /// Access to this value field. - /// Note that this function returns the entire tuple of value fields. - /// The caller is responible for selecting the appropriate element. - pub fn field<'db>( - &'db self, - db: &dyn crate::Database, - field_index: usize, - ) -> &'db C::Fields<'db> { - let zalsa_local = db.zalsa_local(); - let field_ingredient_index = self.struct_ingredient_index.successor(field_index); - let changed_at = self.revisions[field_index]; - - zalsa_local.report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(self.id.as_id()), - }, - self.durability, - changed_at, - ); - - unsafe { self.to_self_ref(&self.fields) } - } - - unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> { - unsafe { std::mem::transmute(fields) } - } -} - -impl AsId for Value -where - C: Configuration, -{ - fn as_id(&self) -> Id { - self.id - } -} diff --git a/src/tracked_struct/struct_map.rs b/src/tracked_struct/struct_map.rs deleted file mode 100644 index 2502f485..00000000 --- a/src/tracked_struct/struct_map.rs +++ /dev/null @@ -1,280 +0,0 @@ -use std::{ - ops::{Deref, DerefMut}, - sync::Arc, -}; - -use crossbeam::queue::SegQueue; -use dashmap::mapref::one::RefMut; - -use crate::{alloc::Alloc, hash::FxDashMap, Id, Revision}; - -use super::{Configuration, KeyStruct, Value}; - -pub(crate) struct StructMap -where - C: Configuration, -{ - map: Arc>>>, - - /// When specific entities are deleted, their data is added - /// to this vector rather than being immediately freed. This is because we may` have - /// references to that data floating about that are tied to the lifetime of some - /// `&db` reference. This queue itself is not freed until we have an `&mut db` reference, - /// guaranteeing that there are no more references to it. - deleted_entries: SegQueue>>, -} - -pub(crate) struct StructMapView -where - C: Configuration, -{ - map: Arc>>>, -} - -impl Clone for StructMapView { - fn clone(&self) -> Self { - Self { - map: self.map.clone(), - } - } -} - -/// Return value for [`StructMap`][]'s `update` method. -pub(crate) enum Update<'db, C> -where - C: Configuration, -{ - /// Indicates that the given struct has not yet been verified in this revision. - /// The [`UpdateRef`][] gives mutable access to the field contents so that - /// its fields can be compared and updated. - Outdated(UpdateRef<'db, C>), - - /// Indicates that we have already verified that all the inputs accessed prior - /// to this struct creation were up-to-date, and therefore the field contents - /// ought not to have changed (barring user error). Returns a shared reference - /// because caller cannot safely modify fields at this point. - Current(C::Struct<'db>), -} - -impl StructMap -where - C: Configuration, -{ - pub fn new() -> Self { - Self { - map: Arc::new(FxDashMap::default()), - deleted_entries: SegQueue::new(), - } - } - - /// Get a secondary view onto this struct-map that can be used to fetch entries. - pub fn view(&self) -> StructMapView { - StructMapView { - map: self.map.clone(), - } - } - - /// Insert the given tracked struct value into the map. - /// - /// # Panics - /// - /// * If value with same `value.id` is already present in the map. - /// * If value not created in current revision. - pub fn insert(&self, current_revision: Revision, value: Value) -> C::Struct<'_> { - assert_eq!(value.created_at, current_revision); - - let id = value.id; - let boxed_value = Alloc::new(value); - let pointer = boxed_value.as_raw(); - - let old_value = self.map.insert(id, boxed_value); - assert!(old_value.is_none()); // ...strictly speaking we probably need to abort here - - // Unsafety clause: - // - // * The box is owned by self and, although the box has been moved, - // the pointer is to the contents of the box, which have a stable - // address. - // * Values are only removed or altered when we have `&mut self`. - unsafe { C::struct_from_raw(pointer) } - } - - pub fn validate(&self, current_revision: Revision, id: Id) { - let mut data = self.map.get_mut(&id).unwrap(); - - // UNSAFE: We never permit `&`-access in the current revision until data.created_at - // has been updated to the current revision (which we check below). - let data = unsafe { data.as_mut() }; - - // Never update a struct twice in the same revision. - assert!(data.created_at < current_revision); - data.created_at = current_revision; - } - - /// Get mutable access to the data for `id` -- this holds a write lock for the duration - /// of the returned value. - /// - /// # Panics - /// - /// * If the value is not present in the map. - /// * If the value is already updated in this revision. - pub fn update(&self, current_revision: Revision, id: Id) -> Update<'_, C> { - let mut data = self.map.get_mut(&id).unwrap(); - - // UNSAFE: We never permit `&`-access in the current revision until data.created_at - // has been updated to the current revision (which we check below). - let data_ref = unsafe { data.as_mut() }; - - // Subtle: it's possible that this struct was already validated - // in this revision. What can happen (e.g., in the test - // `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`) - // is that - // - // * Revision 1: - // * Tracked function F creates tracked struct S - // * F reads input I - // - // In Revision 2, I is changed, and F is re-executed. - // We try to validate F's inputs/outputs, which is the list [output: S, input: I]. - // As no inputs have changed by the time we reach S, we mark it as verified. - // But then input I is seen to hvae changed, and so we re-execute F. - // Note that we *know* that S will have the same value (barring program bugs). - // - // Further complicating things: it is possible that F calls F2 - // and gives it (e.g.) S as one of its arguments. Validating F2 may cause F2 to - // re-execute which means that it may indeed have read from S's fields - // during the current revision and thus obtained an `&` reference to those fields - // that is still live. - // - // For this reason, we just return `None` in this case, ensuring that the calling - // code cannot violate that `&`-reference. - if data_ref.created_at == current_revision { - drop(data); - return Update::Current(Self::get_from_map(&self.map, current_revision, id)); - } - - data_ref.created_at = current_revision; - Update::Outdated(UpdateRef { guard: data }) - } - - /// Lookup an existing tracked struct from the map. - /// - /// # Panics - /// - /// * If the value is not present in the map. - /// * If the value has not been updated in this revision. - pub fn get(&self, current_revision: Revision, id: Id) -> C::Struct<'_> { - Self::get_from_map(&self.map, current_revision, id) - } - - /// Helper function, provides shared functionality for [`StructMapView`][] - /// - /// # Panics - /// - /// * If the value is not present in the map. - /// * If the value has not been updated in this revision. - fn get_from_map( - map: &FxDashMap>>, - current_revision: Revision, - id: Id, - ) -> C::Struct<'_> { - let data = map.get(&id).unwrap(); - - // UNSAFE: We permit `&`-access in the current revision once data.created_at - // has been updated to the current revision (which we check below). - let data_ref: &Value = unsafe { data.as_ref() }; - - // Before we drop the lock, check that the value has - // been updated in this revision. This is what allows us to return a `` - let created_at = data_ref.created_at; - assert!( - created_at == current_revision, - "access to tracked struct from previous revision" - ); - - // Unsafety clause: - // - // * Value will not be updated again in this revision, - // and revision will not change so long as runtime is shared - // * We only remove values from the map when we have `&mut self` - unsafe { C::struct_from_raw(data.as_raw()) } - } - - /// Remove the entry for `id` from the map. - /// - /// NB. the data won't actually be freed until `drop_deleted_entries` is called. - pub fn delete(&self, id: Id) -> Option { - if let Some((_, data)) = self.map.remove(&id) { - // UNSAFE: The `key` field is immutable once `ValueStruct` is created. - let key = unsafe { data.as_ref() }.key; - self.deleted_entries.push(data); - Some(key) - } else { - None - } - } - - /// Drop all entries deleted until now. - pub fn drop_deleted_entries(&mut self) { - std::mem::take(&mut self.deleted_entries); - } -} - -impl StructMapView -where - C: Configuration, -{ - /// Get a pointer to the data for the given `id`. - /// - /// # Panics - /// - /// * If the value is not present in the map. - /// * If the value has not been updated in this revision. - pub fn get(&self, current_revision: Revision, id: Id) -> C::Struct<'_> { - StructMap::get_from_map(&self.map, current_revision, id) - } -} - -/// A mutable reference to the data for a single struct. -/// Can be "frozen" to yield an `&` that will remain valid -/// until the end of the revision. -pub(crate) struct UpdateRef<'db, C> -where - C: Configuration, -{ - guard: RefMut<'db, Id, Alloc>>, -} - -impl<'db, C> UpdateRef<'db, C> -where - C: Configuration, -{ - /// Finalize this update, freezing the value for the rest of the revision. - pub fn freeze(self) -> C::Struct<'db> { - // Unsafety clause: - // - // see `get` above - let data = self.guard.as_raw(); - unsafe { C::struct_from_raw(data) } - } -} - -impl Deref for UpdateRef<'_, C> -where - C: Configuration, -{ - type Target = Value; - - fn deref(&self) -> &Self::Target { - unsafe { self.guard.as_ref() } - } -} - -impl DerefMut for UpdateRef<'_, C> -where - C: Configuration, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { self.guard.as_mut() } - } -} diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 07a93c19..714f80d9 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,8 +1,8 @@ -use crate::{ - id::AsId, ingredient::Ingredient, key::DependencyIndex, zalsa::IngredientIndex, Database, Id, -}; +use std::marker::PhantomData; -use super::{struct_map::StructMapView, Configuration}; +use crate::{ingredient::Ingredient, zalsa::IngredientIndex, Database, Id}; + +use super::{Configuration, Value}; /// Created for each tracked struct. /// This ingredient only stores the "id" fields. @@ -19,50 +19,20 @@ where /// Index of this ingredient in the database (used to construct database-ids, etc). ingredient_index: IngredientIndex, field_index: usize, - struct_map: StructMapView, + phantom: PhantomData Value>, } impl FieldIngredientImpl where C: Configuration, { - pub(super) fn new( - struct_index: IngredientIndex, - field_index: usize, - struct_map: &StructMapView, - ) -> Self { + pub(super) fn new(struct_index: IngredientIndex, field_index: usize) -> Self { Self { ingredient_index: struct_index.successor(field_index), field_index, - struct_map: struct_map.clone(), + phantom: PhantomData, } } - - unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> { - unsafe { std::mem::transmute(fields) } - } - - /// Access to this value field. - /// Note that this function returns the entire tuple of value fields. - /// The caller is responible for selecting the appropriate element. - pub fn field<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { - let zalsa_local = db.zalsa_local(); - let current_revision = db.zalsa().current_revision(); - let data = self.struct_map.get(current_revision, id); - let data = C::deref_struct(data); - let changed_at = data.revisions[self.field_index]; - - zalsa_local.report_tracked_read( - DependencyIndex { - ingredient_index: self.ingredient_index, - key_index: Some(id.as_id()), - }, - data.durability, - changed_at, - ); - - unsafe { self.to_self_ref(&data.fields) } - } } impl Ingredient for FieldIngredientImpl @@ -83,10 +53,9 @@ where input: Option, revision: crate::Revision, ) -> bool { - let current_revision = db.zalsa().current_revision(); + let zalsa = db.zalsa(); let id = input.unwrap(); - let data = self.struct_map.get(current_revision, id); - let data = C::deref_struct(data); + let data = >::data(zalsa.table(), id); let field_changed_at = data.revisions[self.field_index]; field_changed_at > revision } diff --git a/src/zalsa.rs b/src/zalsa.rs index 817bf760..29e2afb2 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -24,6 +24,13 @@ use crate::{Database, DatabaseKeyIndex, Durability, Revision}; /// Do not implement this yourself, instead, apply the [`salsa::db`](`crate::db`) macro /// to your database. pub unsafe trait ZalsaDatabase: Any { + /// Plumbing method: access both zalsa and zalsa-local at once. + /// More efficient if you need both as it does only a single vtable dispatch. + #[doc(hidden)] + fn zalsas(&self) -> (&Zalsa, &ZalsaLocal) { + (self.zalsa(), self.zalsa_local()) + } + /// Plumbing method: Access the internal salsa methods. #[doc(hidden)] fn zalsa(&self) -> &Zalsa; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 725b8f27..87051352 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -9,6 +9,7 @@ use crate::runtime::StampedValue; use crate::table::PageIndex; use crate::table::Table; use crate::tracked_struct::Disambiguator; +use crate::tracked_struct::KeyStruct; use crate::zalsa::IngredientIndex; use crate::Cancelled; use crate::Cycle; @@ -246,7 +247,7 @@ impl ZalsaLocal { entity_index: IngredientIndex, reset_at: Revision, data_hash: u64, - ) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) { + ) -> (StampedValue<()>, Disambiguator) { assert!( self.query_in_progress(), "cannot create a tracked struct disambiguator outside of a tracked function" @@ -262,7 +263,6 @@ impl ZalsaLocal { let top_query = stack.last_mut().unwrap(); let disambiguator = top_query.disambiguate(data_hash); ( - top_query.database_key_index, StampedValue { value: (), durability: top_query.durability, @@ -273,6 +273,34 @@ impl ZalsaLocal { }) } + #[track_caller] + pub(crate) fn tracked_struct_id(&self, key_struct: &KeyStruct) -> Option { + debug_assert!( + self.query_in_progress(), + "cannot create a tracked struct disambiguator outside of a tracked function" + ); + self.with_query_stack(|stack| { + let top_query = stack.last().unwrap(); + top_query.tracked_struct_ids.get(&key_struct).cloned() + }) + } + + #[track_caller] + pub(crate) fn store_tracked_struct_id(&self, key_struct: KeyStruct, id: Id) { + debug_assert!( + self.query_in_progress(), + "cannot create a tracked struct disambiguator outside of a tracked function" + ); + self.with_query_stack(|stack| { + let top_query = stack.last_mut().unwrap(); + let old_id = top_query.tracked_struct_ids.insert(key_struct, id); + assert!( + old_id.is_none(), + "overwrote a previous id for `{key_struct:?}`" + ); + }) + } + /// Starts unwinding the stack if the current revision is cancelled. /// /// This method can be called by query implementations that perform @@ -308,6 +336,7 @@ impl ZalsaLocal { impl std::panic::RefUnwindSafe for ZalsaLocal {} /// Summarizes "all the inputs that a query used" +/// and "all the outputs its wrote to" #[derive(Debug, Clone)] pub(crate) struct QueryRevisions { /// The most revision in which some input changed. @@ -318,6 +347,11 @@ pub(crate) struct QueryRevisions { /// How was this query computed? pub(crate) origin: QueryOrigin, + + /// The ids of tracked structs created by this query. + /// This is used to seed the next round if the query is + /// re-executed. + pub(super) tracked_struct_ids: FxHashMap, } impl QueryRevisions { @@ -454,6 +488,16 @@ impl ActiveQueryGuard<'_> { }) } + /// Initialize the tracked struct ids with the values from the prior execution. + pub(crate) fn seed_tracked_struct_ids(&self, tracked_struct_ids: &FxHashMap) { + self.local_state.with_query_stack(|stack| { + assert_eq!(stack.len(), self.push_len); + let frame = stack.last_mut().unwrap(); + assert!(frame.tracked_struct_ids.is_empty()); + frame.tracked_struct_ids = tracked_struct_ids.clone(); + }) + } + /// Invoked when the query has successfully completed execution. pub(crate) fn complete(self) -> ActiveQuery { let query = self.pop_helper(); diff --git a/tests/deletion-drops.rs b/tests/deletion-drops.rs index b03ceda7..57811569 100644 --- a/tests/deletion-drops.rs +++ b/tests/deletion-drops.rs @@ -80,6 +80,7 @@ fn deletion_drops() { "#]] .assert_debug_eq(&dropped()); + // Now that we execute with rev = 44, the old id is put on the free list let tracked_struct = input.create_tracked_struct(&db); assert_eq!(tracked_struct.field(&db).identity, 44); @@ -88,7 +89,9 @@ fn deletion_drops() { "#]] .assert_debug_eq(&dropped()); - input.set_identity(&mut db).to(66); + // When we execute again with `input1`, that id is re-used, so the old value is deleted + let input1 = MyInput::new(&db, 66); + let _tracked_struct1 = input1.create_tracked_struct(&db); expect_test::expect![[r#" [ diff --git a/tests/preverify-struct-with-leaked-data-2.rs b/tests/preverify-struct-with-leaked-data-2.rs new file mode 100644 index 00000000..1dc68121 --- /dev/null +++ b/tests/preverify-struct-with-leaked-data-2.rs @@ -0,0 +1,100 @@ +//! Test that a `tracked` fn on a `salsa::input` +//! compiles and executes successfully. + +use std::cell::Cell; + +use common::LogDatabase; +use expect_test::expect; +mod common; +use salsa::{Database, Setter}; +use test_log::test; + +thread_local! { + static COUNTER: Cell = const { Cell::new(0) }; +} + +#[salsa::input] +struct MyInput { + field1: u32, + field2: u32, +} + +#[salsa::tracked] +struct MyTracked<'db> { + counter: usize, +} + +#[salsa::tracked] +fn function(db: &dyn Database, input: MyInput) -> (usize, usize) { + // Read input 1 + let _field1 = input.field1(db); + + // **BAD:** Leak in the value of the counter non-deterministically + let counter = COUNTER.with(|c| c.get()); + + // Create the tracked struct, which (from salsa's POV), only depends on field1; + // but which actually depends on the leaked value. + let tracked = MyTracked::new(db, counter); + + // Read the tracked field + let result = counter_field(db, input, tracked); + + // Read input 2. This will cause us to re-execute on revision 2. + let _field2 = input.field2(db); + + (result, tracked.counter(db)) +} + +#[salsa::tracked] +fn counter_field<'db>(db: &'db dyn Database, input: MyInput, tracked: MyTracked<'db>) -> usize { + // Read input 2. This will cause us to re-execute on revision 2. + let _field2 = input.field2(db); + + tracked.counter(db) +} + +#[test] +fn test_leaked_inputs_ignored() { + let mut db = common::EventLoggerDatabase::default(); + + let input = MyInput::new(&db, 10, 20); + let result_in_rev_1 = function(&db, input); + db.assert_logs(expect![[r#" + [ + "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", + "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(Id(0)) } }", + "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", + "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: counter_field(Id(400)) } }", + ]"#]]); + + assert_eq!(result_in_rev_1, (0, 0)); + + // Modify field2 so that `function` is seen to have changed -- + // but only *after* the tracked struct is created. + input.set_field2(&mut db).to(30); + + // Also modify the thread-local counter + COUNTER.with(|c| c.set(100)); + + let result_in_rev_2 = function(&db, input); + db.assert_logs(expect![[r#" + [ + "Event { thread_id: ThreadId(2), kind: DidSetCancellationFlag }", + "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", + "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", + "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: counter_field(Id(400)) } }", + "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(Id(0)) } }", + "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", + ]"#]]); + + // Salsa will re-execute `counter_field` before re-executing + // `function` since, from what it can see, no inputs have changed + // before `counter_field` is called. This will read the field of + // the tracked struct which means it will be *fixed* at `0`. + // When we re-execute `counter_field` later, we ignore the new + // value of 100 since the struct has already been read during + // this revision. + // + // Contrast with preverify-struct-with-leaked-data-2.rs. + assert_eq!(result_in_rev_2, (0, 0)); +} diff --git a/tests/preverify-struct-with-leaked-data.rs b/tests/preverify-struct-with-leaked-data.rs index 3966edaf..448fad90 100644 --- a/tests/preverify-struct-with-leaked-data.rs +++ b/tests/preverify-struct-with-leaked-data.rs @@ -25,7 +25,7 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn function(db: &dyn Database, input: MyInput) -> usize { +fn function(db: &dyn Database, input: MyInput) -> (usize, usize) { // Read input 1 let _field1 = input.field1(db); @@ -36,9 +36,17 @@ fn function(db: &dyn Database, input: MyInput) -> usize { // but which actually depends on the leaked value. let tracked = MyTracked::new(db, counter); + // Read the tracked field + let result = counter_field(db, tracked); + // Read input 2. This will cause us to re-execute on revision 2. let _field2 = input.field2(db); + (result, tracked.counter(db)) +} + +#[salsa::tracked] +fn counter_field<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> usize { tracked.counter(db) } @@ -52,9 +60,11 @@ fn test_leaked_inputs_ignored() { [ "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(Id(0)) } }", + "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", + "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: counter_field(Id(0)) } }", ]"#]]); - assert_eq!(result_in_rev_1, 0); + assert_eq!(result_in_rev_1, (0, 0)); // Modify field2 so that `function` is seen to have changed -- // but only *after* the tracked struct is created. @@ -68,12 +78,17 @@ fn test_leaked_inputs_ignored() { [ "Event { thread_id: ThreadId(2), kind: DidSetCancellationFlag }", "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", + "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", + "Event { thread_id: ThreadId(2), kind: DidValidateMemoizedValue { database_key: counter_field(Id(0)) } }", "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(Id(0)) } }", + "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", ]"#]]); - // Because salsa did not see any way for the tracked - // struct to have changed, its field values will not have - // been updated, even though in theory they would have - // the leaked value from the counter. - assert_eq!(result_in_rev_2, 0); + // Because salsa does not see any way for the tracked + // struct to have changed, it will re-use the cached return value + // from `counter_field` (`0`) but when we actually recreate + // the cached struct we get the new value (`100`). + // + // Contrast with preverify-struct-with-leaked-data-2.rs. + assert_eq!(result_in_rev_2, (0, 100)); }