diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index be31ee0b..a4fdeee9 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -31,6 +31,7 @@ - [Plumbing](./plumbing.md) - [Jars and ingredients](./plumbing/jars_and_ingredients.md) - [Databases and runtime](./plumbing/database_and_runtime.md) + - [Tracked structures](./plumbing/tracked_structs.md) - [Query operations](./plumbing/query_ops.md) - [maybe changed after](./plumbing/maybe_changed_after.md) - [Fetch](./plumbing/fetch.md) diff --git a/book/src/plumbing/tracked_structs.md b/book/src/plumbing/tracked_structs.md new file mode 100644 index 00000000..4444b463 --- /dev/null +++ b/book/src/plumbing/tracked_structs.md @@ -0,0 +1,50 @@ +# Tracked structs + +Tracked structs are stored in a special way to reduce their costs. + +Tracked structs are created via a `new` operation. + +## The tracked struct and tracked field ingredients + +For a single tracked struct we create multiple ingredients. +The **tracked struct ingredient** is the ingredient created first. +It offers methods to create new instances of the struct and therefore +has unique access to the interner and hashtables used to create the struct id. +It also shares access to a hashtable that stores the `TrackedStructValue` that +contains the field data. + +For each field, we create a **tracked field ingredient** that moderates access +to a particular field. All of these ingredients use that same shared hashtable +to access the `TrackedStructValue` instance for a given id. The `TrackedStructValue` +contains both the field values but also the revisions when they last changed value. + +## Each tracked struct has a globally unique id + +This will begin by creating a *globally unique, 32-bit id* for the tracked struct. It is created by interning a combination of + +* the currently executing query; +* a u64 hash of the `#[id]` fields; +* a *disambiguator* that makes this hash unique within the current query. i.e., when a query starts executing, it creates an empty map, and the first time a tracked struct with a given hash is created, it gets disambiguator 0. The next one will be given 1, etc. + +## Each tracked struct has a `TrackedStructValue` storing its data + +The struct and field ingredients share access to a hashmap that maps +each field id to a value struct: + +```rust,ignore +{{#include ../../../components/salsa-2022/src/tracked_struct.rs:TrackedStructValue}} +``` + +The value struct stores the values of the fields but also the revisions when +that field last changed. Each time the struct is recreated in a new revision, +the old and new values for its fields are compared and a new revision is created. + +## The macro generates the tracked struct `Configuration` + +The "configuration" for a tracked struct defines not only the types of the fields, +but also various important operations such as extracting the hashable id fields +and updating the "revisions" to track when a field last changed: + +```rust,ignore +{{#include ../../../components/salsa-2022/src/tracked_struct.rs:Configuration}} +``` diff --git a/components/salsa-2022-macros/src/salsa_struct.rs b/components/salsa-2022-macros/src/salsa_struct.rs index 3c3378b6..5b20b761 100644 --- a/components/salsa-2022-macros/src/salsa_struct.rs +++ b/components/salsa-2022-macros/src/salsa_struct.rs @@ -25,12 +25,8 @@ //! * data method `impl Foo { fn data(&self, db: &dyn crate::Db) -> FooData { FooData { f: self.f(db), ... } } }` //! * this could be optimized, particularly for interned fields -use crate::{ - configuration, - options::{AllowedOptions, Options}, -}; -use heck::ToUpperCamelCase; -use proc_macro2::{Ident, Literal, Span, TokenStream}; +use crate::options::{AllowedOptions, Options}; +use proc_macro2::{Ident, Span, TokenStream}; use syn::spanned::Spanned; pub(crate) enum SalsaStructKind { @@ -217,69 +213,6 @@ impl SalsaStruct { } } - /// For each of the fields passed as an argument, - /// generate a struct named `Ident_Field` and an impl - /// of `salsa::function::Configuration` for that struct. - pub(crate) fn field_config_structs_and_impls<'a>( - &self, - fields: impl Iterator, - ) -> (Vec, Vec) { - let ident = &self.id_ident(); - let jar_ty = self.jar_ty(); - let visibility = self.visibility(); - fields - .map(|ef| { - let value_field_name = ef.name(); - let value_field_ty = ef.ty(); - let value_field_backdate = ef.is_backdate_field(); - let config_name = syn::Ident::new( - &format!( - "__{}", - format!("{}_{}", ident, value_field_name).to_upper_camel_case() - ), - value_field_name.span(), - ); - let item_struct: syn::ItemStruct = parse_quote! { - #[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)] - #visibility struct #config_name(std::convert::Infallible); - }; - - let execute_string = Literal::string(&format!("`execute` method for field `{}::{}` invoked", - ident, - ef.name(), - )); - - let recover_from_cycle_string = Literal::string(&format!("`execute` method for field `{}::{}` invoked", - ident, - ef.name(), - )); - - let should_backdate_value_fn = configuration::should_backdate_value_fn(value_field_backdate); - let item_impl: syn::ItemImpl = parse_quote! { - impl salsa::function::Configuration for #config_name { - type Jar = #jar_ty; - type SalsaStruct = #ident; - type Key = #ident; - type Value = #value_field_ty; - const CYCLE_STRATEGY: salsa::cycle::CycleRecoveryStrategy = salsa::cycle::CycleRecoveryStrategy::Panic; - - #should_backdate_value_fn - - fn execute(db: &salsa::function::DynDb, key: Self::Key) -> Self::Value { - panic!(#execute_string) - } - - fn recover_from_cycle(db: &salsa::function::DynDb, cycle: &salsa::Cycle, key: Self::Key) -> Self::Value { - panic!(#recover_from_cycle_string) - } - } - }; - - (item_struct, item_impl) - }) - .unzip() - } - /// Generate `impl salsa::AsId for Foo` pub(crate) fn as_id_impl(&self) -> syn::ItemImpl { let ident = self.id_ident(); @@ -307,7 +240,6 @@ impl SalsaStruct { // `::salsa::debug::helper::SalsaDebug` will use `DebugWithDb` or fallbak to `Debug` let fields = self .all_fields() - .into_iter() .map(|field| -> TokenStream { let field_name_string = field.name().to_string(); let field_getter = field.get_name(); @@ -405,7 +337,7 @@ impl SalsaField { if BANNED_FIELD_NAMES.iter().any(|n| *n == field_name_str) { return Err(syn::Error::new( field_name.span(), - &format!( + format!( "the field name `{}` is disallowed in salsa structs", field_name_str ), @@ -435,6 +367,10 @@ impl SalsaField { Ok(result) } + pub(crate) fn span(&self) -> Span { + self.field.span() + } + /// The name of this field (all `SalsaField` instances are named). pub(crate) fn name(&self) -> &syn::Ident { self.field.ident.as_ref().unwrap() diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index b15d6ae6..d6dec496 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -105,6 +105,7 @@ pub(crate) fn tracked_impl( ), None => format!("{}", self_type_name), }; + #[allow(clippy::manual_try_fold)] // we accumulate errors let extra_impls = item_impl .items .iter_mut() diff --git a/components/salsa-2022-macros/src/tracked_struct.rs b/components/salsa-2022-macros/src/tracked_struct.rs index 98f36cf9..feb36031 100644 --- a/components/salsa-2022-macros/src/tracked_struct.rs +++ b/components/salsa-2022-macros/src/tracked_struct.rs @@ -1,4 +1,4 @@ -use proc_macro2::{Literal, TokenStream}; +use proc_macro2::{Literal, Span, TokenStream}; use crate::salsa_struct::{SalsaField, SalsaStruct, SalsaStructKind}; @@ -51,18 +51,18 @@ impl TrackedStruct { fn generate_tracked(&self) -> syn::Result { self.validate_tracked()?; - let (config_structs, config_impls) = - self.field_config_structs_and_impls(self.value_fields()); - let id_struct = self.id_struct(); + let config_struct = self.config_struct(); + let config_impl = self.config_impl(&config_struct); let inherent_impl = self.tracked_inherent_impl(); - let ingredients_for_impl = self.tracked_struct_ingredients(&config_structs); + let ingredients_for_impl = self.tracked_struct_ingredients(&config_struct); let salsa_struct_in_db_impl = self.salsa_struct_in_db_impl(); let tracked_struct_in_db_impl = self.tracked_struct_in_db_impl(); let as_id_impl = self.as_id_impl(); let as_debug_with_db_impl = self.as_debug_with_db_impl(); Ok(quote! { - #(#config_structs)* + #config_struct + #config_impl #id_struct #inherent_impl #ingredients_for_impl @@ -70,7 +70,6 @@ impl TrackedStruct { #tracked_struct_in_db_impl #as_id_impl #as_debug_with_db_impl - #(#config_impls)* }) } @@ -78,27 +77,104 @@ impl TrackedStruct { Ok(()) } + fn config_struct(&self) -> syn::ItemStruct { + let config_ident = syn::Ident::new( + &format!("__{}Config", self.id_ident()), + self.id_ident().span(), + ); + let visibility = self.visibility(); + + parse_quote! { + #visibility struct #config_ident { + _uninhabited: std::convert::Infallible, + } + } + } + + fn config_impl(&self, config_struct: &syn::ItemStruct) -> syn::ItemImpl { + let id_ident = self.id_ident(); + let config_ident = &config_struct.ident; + let field_tys: Vec<_> = self.all_fields().map(SalsaField::ty).collect(); + let id_field_indices = self.id_field_indices(); + let arity = self.all_field_count(); + + // Create the function body that will update the revisions for each field. + // If a field is a "backdate field" (the default), then we first check if + // the new value is `==` to the old value. If so, we leave the revision unchanged. + let old_value = syn::Ident::new("old_value_", Span::call_site()); + let new_value = syn::Ident::new("new_value_", Span::call_site()); + let revisions = syn::Ident::new("revisions_", Span::call_site()); + let current_revision = syn::Ident::new("current_revision_", Span::call_site()); + let update_revisions: TokenStream = self + .all_fields() + .zip(0..) + .map(|(field, i)| { + let field_index = Literal::u32_unsuffixed(i); + if field.is_backdate_field() { + quote_spanned! { field.span() => + if #old_value.#field_index != #new_value.#field_index { + #revisions[#field_index] = #current_revision; + } + } + } else { + quote_spanned! { field.span() => + #revisions[#field_index] = #current_revision; + } + } + }) + .collect(); + + parse_quote! { + impl salsa::tracked_struct::Configuration for #config_ident { + type Id = #id_ident; + type Fields = ( #(#field_tys,)* ); + type Revisions = [salsa::Revision; #arity]; + + #[allow(clippy::unused_unit)] + fn id_fields(fields: &Self::Fields) -> impl std::hash::Hash { + ( #( &fields.#id_field_indices ),* ) + } + + fn revision(revisions: &Self::Revisions, field_index: u32) -> salsa::Revision { + revisions[field_index as usize] + } + + fn new_revisions(current_revision: salsa::Revision) -> Self::Revisions { + [current_revision; #arity] + } + + fn update_revisions( + #current_revision: salsa::Revision, + #old_value: &Self::Fields, + #new_value: &Self::Fields, + #revisions: &mut Self::Revisions, + ) { + #update_revisions + } + } + } + } + /// Generate an inherent impl with methods on the tracked type. fn tracked_inherent_impl(&self) -> syn::ItemImpl { let ident = self.id_ident(); let jar_ty = self.jar_ty(); let db_dyn_ty = self.db_dyn_ty(); - let struct_index = self.tracked_struct_index(); + let tracked_field_ingredients: Literal = self.tracked_field_ingredients_index(); - let id_field_indices: Vec<_> = self.id_field_indices(); - let id_field_names: Vec<_> = self.id_fields().map(SalsaField::name).collect(); - let id_field_get_names: Vec<_> = self.id_fields().map(SalsaField::get_name).collect(); - let id_field_tys: Vec<_> = self.id_fields().map(SalsaField::ty).collect(); - let id_field_vises: Vec<_> = self.id_fields().map(SalsaField::vis).collect(); - let id_field_clones: Vec<_> = self.id_fields().map(SalsaField::is_clone_field).collect(); - let id_field_getters: Vec = id_field_indices.iter().zip(&id_field_get_names).zip(&id_field_tys).zip(&id_field_vises).zip(&id_field_clones).map(|((((field_index, field_get_name), field_ty), field_vis), is_clone_field)| + let field_indices = self.all_field_indices(); + let field_vises: Vec<_> = self.all_fields().map(SalsaField::vis).collect(); + let field_tys: Vec<_> = self.all_fields().map(SalsaField::ty).collect(); + let field_get_names: Vec<_> = self.all_fields().map(SalsaField::get_name).collect(); + let field_clones: Vec<_> = self.all_fields().map(SalsaField::is_clone_field).collect(); + let field_getters: Vec = field_indices.iter().zip(&field_get_names).zip(&field_tys).zip(&field_vises).zip(&field_clones).map(|((((field_index, field_get_name), field_ty), field_vis), is_clone_field)| if !*is_clone_field { parse_quote! { #field_vis fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); - &__ingredients.#struct_index.tracked_struct_data(__runtime, self).#field_index + &__ingredients.#tracked_field_ingredients[#field_index].field(__runtime, self).#field_index } } } else { @@ -107,65 +183,32 @@ impl TrackedStruct { { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); - __ingredients.#struct_index.tracked_struct_data(__runtime, self).#field_index.clone() + __ingredients.#tracked_field_ingredients[#field_index].field(__runtime, self).#field_index.clone() } } } ) .collect(); - let value_field_indices = self.value_field_indices(); - let value_field_names: Vec<_> = self.value_fields().map(SalsaField::name).collect(); - let value_field_vises: Vec<_> = self.value_fields().map(SalsaField::vis).collect(); - let value_field_tys: Vec<_> = self.value_fields().map(SalsaField::ty).collect(); - let value_field_get_names: Vec<_> = self.value_fields().map(SalsaField::get_name).collect(); - let value_field_clones: Vec<_> = self - .value_fields() - .map(SalsaField::is_clone_field) - .collect(); - let value_field_getters: Vec = value_field_indices.iter().zip(&value_field_get_names).zip(&value_field_tys).zip(&value_field_vises).zip(&value_field_clones).map(|((((field_index, field_get_name), field_ty), field_vis), is_clone_field)| - if !*is_clone_field { - parse_quote! { - #field_vis fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty - { - let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); - let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); - __ingredients.#field_index.fetch(__db, self) - } - } - } else { - parse_quote! { - #field_vis fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty - { - let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); - let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); - __ingredients.#field_index.fetch(__db, self).clone() - } - } - } - ) - .collect(); - - let all_field_names = self.all_field_names(); - let all_field_tys = self.all_field_tys(); + let field_names = self.all_field_names(); + let field_tys = self.all_field_tys(); let constructor_name = self.constructor_name(); parse_quote! { + #[allow(clippy::too_many_arguments)] impl #ident { - pub fn #constructor_name(__db: &#db_dyn_ty, #(#all_field_names: #all_field_tys,)*) -> Self + pub fn #constructor_name(__db: &#db_dyn_ty, #(#field_names: #field_tys,)*) -> Self { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); - let __id = __ingredients.#struct_index.new_struct(__runtime, (#(#id_field_names,)*)); - #( - __ingredients.#value_field_indices.specify_and_record(__db, __id, #value_field_names); - )* + let __id = __ingredients.0.new_struct( + __runtime, + (#(#field_names,)*), + ); __id } - #(#id_field_getters)* - - #(#value_field_getters)* + #(#field_getters)* } } } @@ -174,14 +217,15 @@ impl TrackedStruct { /// /// The tracked struct's ingredients include both the main tracked struct ingredient along with a /// function ingredient for each of the value fields. - fn tracked_struct_ingredients(&self, config_structs: &[syn::ItemStruct]) -> syn::ItemImpl { + fn tracked_struct_ingredients(&self, config_struct: &syn::ItemStruct) -> syn::ItemImpl { use crate::literal; let ident = self.id_ident(); let jar_ty = self.jar_ty(); - let id_field_tys: Vec<&syn::Type> = self.id_fields().map(SalsaField::ty).collect(); - let value_field_indices: Vec = self.value_field_indices(); - let tracked_struct_index: Literal = self.tracked_struct_index(); - let config_struct_names = config_structs.iter().map(|s| &s.ident); + let config_struct_name = &config_struct.ident; + let field_indices: Vec = self.all_field_indices(); + let arity = self.all_field_count(); + let tracked_struct_ingredient: Literal = self.tracked_struct_ingredient_index(); + let tracked_fields_ingredients: Literal = self.tracked_field_ingredients_index(); let debug_name_struct = literal(self.id_ident()); let debug_name_fields: Vec<_> = self.all_field_names().into_iter().map(literal).collect(); @@ -189,10 +233,8 @@ impl TrackedStruct { impl salsa::storage::IngredientsFor for #ident { type Jar = #jar_ty; type Ingredients = ( - #( - salsa::function::FunctionIngredient<#config_struct_names>, - )* - salsa::tracked_struct::TrackedStructIngredient<#ident, (#(#id_field_tys,)*)>, + salsa::tracked_struct::TrackedStructIngredient<#config_struct_name>, + [salsa::tracked_struct::TrackedFieldIngredient<#config_struct_name>; #arity], ); fn create_ingredients( @@ -201,40 +243,43 @@ impl TrackedStruct { where DB: salsa::DbWithJar + salsa::storage::JarFromJars, { - ( + let struct_ingredient = { + let index = routes.push( + |jars| { + let jar = >::jar_from_jars(jars); + let ingredients = <_ as salsa::storage::HasIngredientsFor>::ingredient(jar); + &ingredients.#tracked_struct_ingredient + }, + |jars| { + let jar = >::jar_from_jars_mut(jars); + let ingredients = <_ as salsa::storage::HasIngredientsFor>::ingredient_mut(jar); + &mut ingredients.#tracked_struct_ingredient + }, + ); + salsa::tracked_struct::TrackedStructIngredient::new(index, #debug_name_struct) + }; + + let field_ingredients = [ #( { let index = routes.push( |jars| { let jar = >::jar_from_jars(jars); let ingredients = <_ as salsa::storage::HasIngredientsFor>::ingredient(jar); - &ingredients.#value_field_indices + &ingredients.#tracked_fields_ingredients[#field_indices] }, |jars| { let jar = >::jar_from_jars_mut(jars); let ingredients = <_ as salsa::storage::HasIngredientsFor>::ingredient_mut(jar); - &mut ingredients.#value_field_indices + &mut ingredients.#tracked_fields_ingredients[#field_indices] }, ); - salsa::function::FunctionIngredient::new(index, #debug_name_fields) + struct_ingredient.new_field_ingredient(index, #field_indices, #debug_name_fields) }, )* - { - let index = routes.push( - |jars| { - let jar = >::jar_from_jars(jars); - let ingredients = <_ as salsa::storage::HasIngredientsFor>::ingredient(jar); - &ingredients.#tracked_struct_index - }, - |jars| { - let jar = >::jar_from_jars_mut(jars); - let ingredients = <_ as salsa::storage::HasIngredientsFor>::ingredient_mut(jar); - &mut ingredients.#tracked_struct_index - }, - ); - salsa::tracked_struct::TrackedStructIngredient::new(index, #debug_name_struct) - }, - ) + ]; + + (struct_ingredient, field_ingredients) } } } @@ -244,7 +289,7 @@ impl TrackedStruct { fn salsa_struct_in_db_impl(&self) -> syn::ItemImpl { let ident = self.id_ident(); let jar_ty = self.jar_ty(); - let tracked_struct_index: Literal = self.tracked_struct_index(); + let tracked_struct_ingredient = self.tracked_struct_ingredient_index(); parse_quote! { impl salsa::salsa_struct::SalsaStructInDb for #ident where @@ -253,7 +298,7 @@ impl TrackedStruct { fn register_dependent_fn(db: &DB, index: salsa::routes::IngredientIndex) { let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident>>::ingredient(jar); - ingredients.#tracked_struct_index.register_dependent_fn(index) + ingredients.#tracked_struct_ingredient.register_dependent_fn(index) } } } @@ -263,7 +308,7 @@ impl TrackedStruct { fn tracked_struct_in_db_impl(&self) -> syn::ItemImpl { let ident = self.id_ident(); let jar_ty = self.jar_ty(); - let tracked_struct_index = self.tracked_struct_index(); + let tracked_struct_ingredient = self.tracked_struct_ingredient_index(); parse_quote! { impl salsa::tracked_struct::TrackedStructInDb for #ident where @@ -272,46 +317,44 @@ impl TrackedStruct { fn database_key_index(self, db: &DB) -> salsa::DatabaseKeyIndex { let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident>>::ingredient(jar); - ingredients.#tracked_struct_index.database_key_index(self) + ingredients.#tracked_struct_ingredient.database_key_index(self) } } } } - /// List of id fields (fields that are part of the tracked struct's identity across revisions). - /// - /// If this is an enum, empty iterator. - fn id_fields(&self) -> impl Iterator { - self.all_fields().filter(|ef| ef.is_id_field()) + /// The index of the tracked struct ingredient in the ingredient tuple. + fn tracked_struct_ingredient_index(&self) -> Literal { + Literal::usize_unsuffixed(0) } - /// List of value fields (fields that are not part of the tracked struct's identity across revisions). - /// - /// If this is an enum, empty iterator. - fn value_fields(&self) -> impl Iterator { - self.all_fields().filter(|ef| !ef.is_id_field()) + /// The index of the tracked field ingredients array in the ingredient tuple. + fn tracked_field_ingredients_index(&self) -> Literal { + Literal::usize_unsuffixed(1) } /// For this struct, we create a tuple that contains the function ingredients - /// for each "other" field and the tracked-struct ingredient. This is the index of - /// the entity ingredient within that tuple. - fn tracked_struct_index(&self) -> Literal { - Literal::usize_unsuffixed(self.value_fields().count()) + /// for each field and the tracked-struct ingredient. These are the indices + /// of the function ingredients within that tuple. + fn all_field_indices(&self) -> Vec { + (0..self.all_fields().count()) + .map(Literal::usize_unsuffixed) + .collect() } /// For this struct, we create a tuple that contains the function ingredients /// for each "other" field and the tracked-struct ingredient. These are the indices /// of the function ingredients within that tuple. - fn value_field_indices(&self) -> Vec { - (0..self.value_fields().count()) - .map(Literal::usize_unsuffixed) - .collect() + fn all_field_count(&self) -> Literal { + Literal::usize_unsuffixed(self.all_fields().count()) } /// Indices of each of the id fields fn id_field_indices(&self) -> Vec { - (0..self.id_fields().count()) - .map(Literal::usize_unsuffixed) + self.all_fields() + .zip(0..) + .filter(|(field, _)| field.is_id_field()) + .map(|(_, index)| Literal::usize_unsuffixed(index)) .collect() } } diff --git a/components/salsa-2022/src/debug.rs b/components/salsa-2022/src/debug.rs index 98149c9a..5050a227 100644 --- a/components/salsa-2022/src/debug.rs +++ b/components/salsa-2022/src/debug.rs @@ -221,11 +221,7 @@ pub mod helper { use std::{fmt, marker::PhantomData}; pub trait Fallback { - fn salsa_debug<'a, 'b>( - a: &'a T, - _db: &'b Db, - _include_all_fields: bool, - ) -> &'a dyn fmt::Debug { + fn salsa_debug<'a>(a: &'a T, _db: &Db, _include_all_fields: bool) -> &'a dyn fmt::Debug { a } } diff --git a/components/salsa-2022/src/function/maybe_changed_after.rs b/components/salsa-2022/src/function/maybe_changed_after.rs index 5ab74d60..f8f9d8fd 100644 --- a/components/salsa-2022/src/function/maybe_changed_after.rs +++ b/components/salsa-2022/src/function/maybe_changed_after.rs @@ -168,6 +168,12 @@ where // then we would have updated the `verified_at` field already. // So the fact that we are here means that it was not specified // during this revision or is otherwise stale. + // + // Example of how this can happen: + // + // Conditionally specified queries + // where the value is specified + // in rev 1 but not in rev 2. return false; } QueryOrigin::BaseInput => { diff --git a/components/salsa-2022/src/hash.rs b/components/salsa-2022/src/hash.rs index 9d6a708e..61055b75 100644 --- a/components/salsa-2022/src/hash.rs +++ b/components/salsa-2022/src/hash.rs @@ -1,4 +1,4 @@ -use std::hash::{BuildHasher, Hash, Hasher}; +use std::hash::{BuildHasher, Hash}; pub(crate) type FxHasher = std::hash::BuildHasherDefault; pub(crate) type FxIndexSet = indexmap::IndexSet; @@ -8,7 +8,5 @@ pub(crate) type FxLinkedHashSet = hashlink::LinkedHashSet; pub(crate) type FxHashSet = std::collections::HashSet; pub(crate) fn hash(t: &T) -> u64 { - let mut hasher = FxHasher::default().build_hasher(); - t.hash(&mut hasher); - hasher.finish() + FxHasher::default().hash_one(t) } diff --git a/components/salsa-2022/src/input_field.rs b/components/salsa-2022/src/input_field.rs index fcd9c706..187dafa7 100644 --- a/components/salsa-2022/src/input_field.rs +++ b/components/salsa-2022/src/input_field.rs @@ -1,6 +1,7 @@ use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{fmt_index, Ingredient, IngredientRequiresReset}; use crate::key::DependencyIndex; +use crate::plumbing::transmute_lifetime; use crate::runtime::local_state::QueryOrigin; use crate::runtime::StampedValue; use crate::{AsId, DatabaseKeyIndex, Durability, Id, IngredientIndex, Revision, Runtime}; @@ -104,14 +105,6 @@ where } } -// Returns `u` but with the lifetime of `t`. -// -// Safe if you know that data at `u` will remain shared -// until the reference `t` expires. -unsafe fn transmute_lifetime<'t, 'u, T, U>(_t: &'t T, u: &'u U) -> &'t U { - std::mem::transmute(u) -} - impl Ingredient for InputFieldIngredient where K: AsId, diff --git a/components/salsa-2022/src/interned.rs b/components/salsa-2022/src/interned.rs index 80d78f87..47d17619 100644 --- a/components/salsa-2022/src/interned.rs +++ b/components/salsa-2022/src/interned.rs @@ -8,6 +8,7 @@ use crate::durability::Durability; use crate::id::AsId; use crate::ingredient::{fmt_index, IngredientRequiresReset}; use crate::key::DependencyIndex; +use crate::plumbing::transmute_lifetime; use crate::runtime::local_state::QueryOrigin; use crate::runtime::Runtime; use crate::DatabaseKeyIndex; @@ -76,7 +77,12 @@ where } } - pub fn intern(&self, runtime: &Runtime, data: Data) -> Id { + /// Intern `data` and return `(id, b`) where + /// + /// * `id` is the interned id + /// * `b` is a boolean, `true` indicates this fn call added `data` to the interning table; + /// `false` indicates it was already present + pub(crate) fn intern_full(&self, runtime: &Runtime, data: Data) -> (Id, bool) { runtime.report_tracked_read( DependencyIndex::for_table(self.ingredient_index), Durability::MAX, @@ -86,12 +92,12 @@ where // Optimisation to only get read lock on the map if the data has already // been interned. if let Some(id) = self.key_map.get(&data) { - return *id; + return (*id, false); } match self.key_map.entry(data.clone()) { // Data has been interned by a racing call, use that ID instead - dashmap::mapref::entry::Entry::Occupied(entry) => *entry.get(), + dashmap::mapref::entry::Entry::Occupied(entry) => (*entry.get(), false), // We won any races so should intern the data dashmap::mapref::entry::Entry::Vacant(entry) => { let next_id = self.counter.fetch_add(1); @@ -102,11 +108,15 @@ where "next_id is guaranteed to be unique, bar overflow" ); entry.insert(next_id); - next_id + (next_id, true) } } } + pub fn intern(&self, runtime: &Runtime, data: Data) -> Id { + self.intern_full(runtime, data).0 + } + pub(crate) fn reset_at(&self) -> Revision { self.reset_at } @@ -181,14 +191,6 @@ where } } -// Returns `u` but with the lifetime of `t`. -// -// Safe if you know that data at `u` will remain shared -// until the reference `t` expires. -unsafe fn transmute_lifetime<'t, 'u, T, U>(_t: &'t T, u: &'u U) -> &'t U { - std::mem::transmute(u) -} - impl Ingredient for InternedIngredient where Id: InternedId, diff --git a/components/salsa-2022/src/lib.rs b/components/salsa-2022/src/lib.rs index d3fe4ddd..2a96a9c8 100644 --- a/components/salsa-2022/src/lib.rs +++ b/components/salsa-2022/src/lib.rs @@ -22,7 +22,6 @@ pub mod runtime; pub mod salsa_struct; pub mod setter; pub mod storage; -#[doc(hidden)] pub mod tracked_struct; pub use self::cancelled::Cancelled; @@ -43,8 +42,6 @@ pub use self::routes::IngredientIndex; pub use self::runtime::Runtime; pub use self::storage::DbWithJar; pub use self::storage::Storage; -pub use self::tracked_struct::TrackedStructData; -pub use self::tracked_struct::TrackedStructId; pub use salsa_2022_macros::accumulator; pub use salsa_2022_macros::db; pub use salsa_2022_macros::input; diff --git a/components/salsa-2022/src/plumbing.rs b/components/salsa-2022/src/plumbing.rs index 65a06451..ddf672d2 100644 --- a/components/salsa-2022/src/plumbing.rs +++ b/components/salsa-2022/src/plumbing.rs @@ -4,7 +4,7 @@ use crate::storage::HasJars; /// Initializes the `DB`'s jars in-place /// -/// # Safety: +/// # Safety /// /// `init` must fully initialize all of jars fields pub unsafe fn create_jars_inplace(init: impl FnOnce(*mut DB::Jars)) -> Box { @@ -26,3 +26,12 @@ pub unsafe fn create_jars_inplace(init: impl FnOnce(*mut DB::Jars)) unsafe { Box::from_raw(place) } } } + +// Returns `u` but with the lifetime of `t`. +// +// Safe if you know that data at `u` will remain shared +// until the reference `t` expires. +#[allow(clippy::needless_lifetimes)] +pub(crate) unsafe fn transmute_lifetime<'t, 'u, T, U>(_t: &'t T, u: &'u U) -> &'t U { + std::mem::transmute(u) +} diff --git a/components/salsa-2022/src/runtime.rs b/components/salsa-2022/src/runtime.rs index ad2cc879..87beb740 100644 --- a/components/salsa-2022/src/runtime.rs +++ b/components/salsa-2022/src/runtime.rs @@ -166,13 +166,16 @@ impl Runtime { /// * Add a query read on `DatabaseKeyIndex::for_table(entity_index)` /// * Identify a unique disambiguator for the hash within the current query, /// adding the hash to the current query's disambiguator table. - /// * Return that hash + id of the current query. + /// * Returns a tuple of: + /// * the id of the current query + /// * the current dependencies (durability, changed_at) of current query + /// * the disambiguator index pub(crate) fn disambiguate_entity( &self, entity_index: IngredientIndex, reset_at: Revision, data_hash: u64, - ) -> (DatabaseKeyIndex, Disambiguator) { + ) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) { self.report_tracked_read( DependencyIndex::for_table(entity_index), Durability::MAX, diff --git a/components/salsa-2022/src/runtime/active_query.rs b/components/salsa-2022/src/runtime/active_query.rs index ebe33a48..67ea202e 100644 --- a/components/salsa-2022/src/runtime/active_query.rs +++ b/components/salsa-2022/src/runtime/active_query.rs @@ -123,7 +123,7 @@ impl ActiveQuery { pub(super) fn remove_cycle_participants(&mut self, cycle: &Cycle) { for p in cycle.participant_keys() { let p: DependencyIndex = p.into(); - self.input_outputs.remove(&(EdgeKind::Input, p)); + self.input_outputs.shift_remove(&(EdgeKind::Input, p)); } } diff --git a/components/salsa-2022/src/runtime/local_state.rs b/components/salsa-2022/src/runtime/local_state.rs index 57724d3c..6405f850 100644 --- a/components/salsa-2022/src/runtime/local_state.rs +++ b/components/salsa-2022/src/runtime/local_state.rs @@ -290,7 +290,10 @@ impl LocalState { } #[track_caller] - pub(crate) fn disambiguate(&self, data_hash: u64) -> (DatabaseKeyIndex, Disambiguator) { + pub(crate) fn disambiguate( + &self, + data_hash: u64, + ) -> (DatabaseKeyIndex, StampedValue<()>, Disambiguator) { assert!( self.query_in_progress(), "cannot create a tracked struct disambiguator outside of a tracked function" @@ -298,7 +301,15 @@ impl LocalState { self.with_query_stack(|stack| { let top_query = stack.last_mut().unwrap(); let disambiguator = top_query.disambiguate(data_hash); - (top_query.database_key_index, disambiguator) + ( + top_query.database_key_index, + StampedValue { + value: (), + durability: top_query.durability, + changed_at: top_query.changed_at, + }, + disambiguator, + ) }) } } diff --git a/components/salsa-2022/src/tracked_struct.rs b/components/salsa-2022/src/tracked_struct.rs index ac0b39c2..25589bc6 100644 --- a/components/salsa-2022/src/tracked_struct.rs +++ b/components/salsa-2022/src/tracked_struct.rs @@ -1,21 +1,65 @@ -use std::fmt; +use std::{fmt, hash::Hash, sync::Arc}; + +use crossbeam::queue::SegQueue; use crate::{ cycle::CycleRecoveryStrategy, + hash::FxDashMap, + id::AsId, ingredient::{fmt_index, Ingredient, IngredientRequiresReset}, ingredient_list::IngredientList, - interned::{InternedData, InternedId, InternedIngredient}, + interned::{InternedId, InternedIngredient}, key::{DatabaseKeyIndex, DependencyIndex}, runtime::{local_state::QueryOrigin, Runtime}, salsa_struct::SalsaStructInDb, - Database, Event, IngredientIndex, Revision, + Database, Durability, Event, IngredientIndex, Revision, }; -pub trait TrackedStructId: InternedId {} -impl TrackedStructId for T {} +pub use self::tracked_field::TrackedFieldIngredient; -pub trait TrackedStructData: InternedData {} -impl TrackedStructData for T {} +mod tracked_field; + +// ANCHOR: Configuration +/// Trait that defines the key properties of a tracked struct. +/// Implemented by the `#[salsa::tracked]` macro when applied +/// to a struct. +pub trait Configuration { + /// The id type used to define instances of this struct. + /// The [`TrackedStructIngredient`][] contains the interner + /// that will create the id values. + type Id: InternedId; + + /// A (possibly empty) tuple of the fields for this struct. + type Fields; + + /// A array of [`Revision`][] values, one per each of the value fields. + /// When a struct is re-recreated in a new revision, the corresponding + /// entries for each field are updated to the new revision if their + /// values have changed (or if the field is marked as `#[no_eq]`). + type Revisions; + + fn id_fields(fields: &Self::Fields) -> impl Hash; + + /// Access the revision of a given value field. + /// `field_index` will be between 0 and the number of value fields. + fn revision(revisions: &Self::Revisions, field_index: u32) -> Revision; + + /// Create a new value revision array where each element is set to `current_revision`. + fn new_revisions(current_revision: Revision) -> Self::Revisions; + + /// Update an existing value revision array `revisions`, + /// given the tuple of the old values (`old_value`) + /// and the tuple of the values (`new_value`). + /// If a value has changed, then its element is + /// updated to `current_revision`. + fn update_revisions( + current_revision: Revision, + old_value: &Self::Fields, + new_value: &Self::Fields, + revisions: &mut Self::Revisions, + ); +} +// ANCHOR_END: Configuration pub trait TrackedStructInDb: SalsaStructInDb { /// Converts the identifier for this tracked struct into a `DatabaseKeyIndex`. @@ -30,12 +74,13 @@ pub trait TrackedStructInDb: SalsaStructInDb { /// Unlike normal interners, tracked struct indices can be deleted and reused aggressively: /// when a tracked function re-executes, /// any tracked structs that it created before but did not create this time can be deleted. -pub struct TrackedStructIngredient +pub struct TrackedStructIngredient where - Id: TrackedStructId, - Data: TrackedStructData, + C: Configuration, { - interned: InternedIngredient>, + interned: InternedIngredient, + + entity_data: Arc>>>, /// A list of each tracked function `f` whose key is this /// tracked struct. @@ -45,58 +90,142 @@ where /// so they can remove any data tied to that instance. dependent_fns: IngredientList, + /// When specific entities are deleted, their data is added + /// to this vector rather than being immediately freed. This is because we may` have + /// references to that data floating about that are tied to the lifetime of some + /// `&db` reference. This queue itself is not freed until we have an `&mut db` reference, + /// guaranteeing that there are no more references to it. + deleted_entries: SegQueue>>, + debug_name: &'static str, } #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] -struct TrackedStructKey { - query_key: Option, +struct TrackedStructKey { + query_key: DatabaseKeyIndex, disambiguator: Disambiguator, - data: Data, + data_hash: u64, } +// ANCHOR: TrackedStructValue +#[derive(Debug)] +struct TrackedStructValue +where + C: Configuration, +{ + /// The durability minimum durability of all inputs consumed + /// by the creator query prior to creating this tracked struct. + /// If any of those inputs changes, then the creator query may + /// create this struct with different values. + durability: Durability, + + /// The revision when this entity was most recently created. + /// Typically the current revision. + /// Used to detect "leaks" outside of the salsa system -- i.e., + /// access to tracked structs that have not (yet?) been created in the + /// current revision. This should be impossible within salsa queries + /// but it can happen through "leaks" like thread-local data or storing + /// values outside of the root salsa query. + created_at: Revision, + + /// Fields of this tracked struct. They can change across revisions, + /// but they do not change within a particular revision. + fields: C::Fields, + + /// The revision information for each field: when did this field last change. + /// When tracked structs are re-created, this revision may be updated to the + /// current revision if the value is different. + revisions: C::Revisions, +} +// ANCHOR_END: TrackedStructValue + #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] pub struct Disambiguator(pub u32); -impl TrackedStructIngredient +impl TrackedStructIngredient where - Id: TrackedStructId, - Data: TrackedStructData, + C: Configuration, { pub fn new(index: IngredientIndex, debug_name: &'static str) -> Self { Self { interned: InternedIngredient::new(index, debug_name), + entity_data: Default::default(), dependent_fns: IngredientList::new(), + deleted_entries: SegQueue::default(), debug_name, } } - pub fn database_key_index(&self, id: Id) -> DatabaseKeyIndex { + pub fn new_field_ingredient( + &self, + field_ingredient_index: IngredientIndex, + field_index: u32, + field_debug_name: &'static str, + ) -> TrackedFieldIngredient { + TrackedFieldIngredient { + ingredient_index: field_ingredient_index, + field_index, + entity_data: self.entity_data.clone(), + struct_debug_name: self.debug_name, + field_debug_name, + } + } + + pub fn database_key_index(&self, id: C::Id) -> DatabaseKeyIndex { DatabaseKeyIndex { ingredient_index: self.interned.ingredient_index(), key_index: id.as_id(), } } - pub fn new_struct(&self, runtime: &Runtime, data: Data) -> Id { - let data_hash = crate::hash::hash(&data); - let (query_key, disambiguator) = runtime.disambiguate_entity( + pub fn new_struct(&self, runtime: &Runtime, fields: C::Fields) -> C::Id { + let data_hash = crate::hash::hash(&C::id_fields(&fields)); + + let (query_key, current_deps, disambiguator) = runtime.disambiguate_entity( self.interned.ingredient_index(), self.interned.reset_at(), data_hash, ); - let entity_key = TrackedStructKey { - query_key: Some(query_key), - disambiguator, - data, - }; - let result = self.interned.intern(runtime, entity_key); - runtime.add_output(self.database_key_index(result).into()); - result - } - pub fn tracked_struct_data<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db Data { - &self.interned.data(runtime, id).data + let entity_key = TrackedStructKey { + query_key, + disambiguator, + data_hash, + }; + let (id, new_id) = self.interned.intern_full(runtime, entity_key); + runtime.add_output(self.database_key_index(id).into()); + + let current_revision = runtime.current_revision(); + if new_id { + let old_value = self.entity_data.insert( + id, + Box::new(TrackedStructValue { + created_at: current_revision, + durability: current_deps.durability, + fields, + revisions: C::new_revisions(current_deps.changed_at), + }), + ); + assert!(old_value.is_none()); + } else { + let mut data = self.entity_data.get_mut(&id).unwrap(); + let data = &mut *data; + if current_deps.durability < data.durability { + data.revisions = C::new_revisions(current_revision); + } else { + C::update_revisions(current_revision, &data.fields, &fields, &mut data.revisions); + } + data.created_at = current_revision; + data.durability = current_deps.durability; + + // Subtle but important: we *always* update the values of the fields, + // even if they are `==` to the old values. This is because the `==` + // operation might not mean tha tthe fields are bitwise equal, and we + // want to take the new value. + data.fields = fields; + } + + id } /// Deletes the given entities. This is used after a query `Q` executes and we can compare @@ -109,7 +238,7 @@ where /// Using this method on an entity id that MAY be used in the current revision will lead to /// unspecified results (but not UB). See [`InternedIngredient::delete_index`] for more /// discussion and important considerations. - pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: Id) { + pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: C::Id) { db.salsa_event(Event { runtime_id: db.runtime().id(), kind: crate::EventKind::DidDiscard { @@ -118,6 +247,10 @@ where }); self.interned.delete_index(id); + if let Some((_, data)) = self.entity_data.remove(&id) { + self.deleted_entries.push(data); + } + for dependent_fn in self.dependent_fns.iter() { db.salsa_struct_deleted(dependent_fn, id.as_id()); } @@ -131,11 +264,10 @@ where } } -impl Ingredient for TrackedStructIngredient +impl Ingredient for TrackedStructIngredient where - Id: TrackedStructId, - Data: TrackedStructData, - DB: crate::Database, + DB: Database, + C: Configuration, { fn ingredient_index(&self) -> IngredientIndex { self.interned.ingredient_index() @@ -172,12 +304,13 @@ where // `executor` creates a tracked struct `salsa_output_key`, // but it did not in the current revision. // In that case, we can delete `stale_output_key` and any data associated with it. - let stale_output_key: Id = Id::from_id(stale_output_key.unwrap()); + let stale_output_key: C::Id = ::from_id(stale_output_key.unwrap()); self.delete_entity(db.as_salsa_database(), stale_output_key); } fn reset_for_new_revision(&mut self) { self.interned.clear_deleted_indices(); + std::mem::take(&mut self.deleted_entries); } fn salsa_struct_deleted(&self, _db: &DB, _id: crate::Id) { @@ -189,10 +322,9 @@ where } } -impl IngredientRequiresReset for TrackedStructIngredient +impl IngredientRequiresReset for TrackedStructIngredient where - Id: TrackedStructId, - Data: TrackedStructData, + C: Configuration, { const RESET_ON_NEW_REVISION: bool = true; } diff --git a/components/salsa-2022/src/tracked_struct/tracked_field.rs b/components/salsa-2022/src/tracked_struct/tracked_field.rs new file mode 100644 index 00000000..6fa55eee --- /dev/null +++ b/components/salsa-2022/src/tracked_struct/tracked_field.rs @@ -0,0 +1,153 @@ +use std::sync::Arc; + +use crate::{ + hash::FxDashMap, + id::AsId, + ingredient::{Ingredient, IngredientRequiresReset}, + key::DependencyIndex, + plumbing::transmute_lifetime, + tracked_struct::TrackedStructValue, + IngredientIndex, Runtime, +}; + +use super::Configuration; + +/// Created for each tracked struct. +/// This ingredient only stores the "id" fields. +/// It is a kind of "dressed up" interner; +/// the active query + values of id fields are hashed to create the tracked struct id. +/// The value fields are stored in [`crate::function::FunctionIngredient`] instances keyed by the tracked struct id. +/// Unlike normal interners, tracked struct indices can be deleted and reused aggressively: +/// when a tracked function re-executes, +/// any tracked structs that it created before but did not create this time can be deleted. +pub struct TrackedFieldIngredient +where + C: Configuration, +{ + /// Index of this ingredient in the database (used to construct database-ids, etc). + pub(super) ingredient_index: IngredientIndex, + pub(super) field_index: u32, + pub(super) entity_data: Arc>>>, + pub(super) struct_debug_name: &'static str, + pub(super) field_debug_name: &'static str, +} + +impl TrackedFieldIngredient +where + C: Configuration, +{ + /// Access to this value field. + /// Note that this function returns the entire tuple of value fields. + /// The caller is responible for selecting the appropriate element. + pub fn field<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db C::Fields { + let Some(data) = self.entity_data.get(&id) else { + panic!("no data found for entity id {id:?}"); + }; + + let current_revision = runtime.current_revision(); + let created_at = data.created_at; + assert!( + created_at == current_revision, + "access to tracked struct from previous revision" + ); + + let changed_at = C::revision(&data.revisions, self.field_index); + + runtime.report_tracked_read( + DependencyIndex { + ingredient_index: self.ingredient_index, + key_index: Some(id.as_id()), + }, + data.durability, + changed_at, + ); + + // Unsafety clause: + // + // * Values are only removed or altered when we have `&mut self` + unsafe { transmute_lifetime(self, &data.fields) } + } +} + +impl Ingredient for TrackedFieldIngredient +where + C: Configuration, +{ + fn ingredient_index(&self) -> IngredientIndex { + self.ingredient_index + } + + fn cycle_recovery_strategy(&self) -> crate::cycle::CycleRecoveryStrategy { + crate::cycle::CycleRecoveryStrategy::Panic + } + + fn maybe_changed_after( + &self, + _db: &DB, + input: crate::key::DependencyIndex, + revision: crate::Revision, + ) -> bool { + let id = ::from_id(input.key_index.unwrap()); + eprintln!("maybe_changed_after({id:?}, {revision:?})"); + match self.entity_data.get(&id) { + Some(data) => { + let field_changed_at = C::revision(&data.revisions, self.field_index); + field_changed_at > revision + } + None => { + panic!("no data found for field `{id:?}`"); + } + } + } + + fn origin(&self, _key_index: crate::Id) -> Option { + None + } + + fn mark_validated_output( + &self, + _db: &DB, + _executor: crate::DatabaseKeyIndex, + _output_key: Option, + ) { + panic!("tracked field ingredients have no outputs") + } + + fn remove_stale_output( + &self, + _db: &DB, + _executor: crate::DatabaseKeyIndex, + _stale_output_key: Option, + ) { + panic!("tracked field ingredients have no outputs") + } + + fn salsa_struct_deleted(&self, _db: &DB, _id: crate::Id) { + panic!("tracked field ingredients are not registered as dependent") + } + + fn reset_for_new_revision(&mut self) { + panic!("tracked field ingredients do not require reset") + } + + fn fmt_index( + &self, + index: Option, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!( + fmt, + "{}.{}({:?})", + self.struct_debug_name, + self.field_debug_name, + index.unwrap() + ) + } +} + +impl IngredientRequiresReset for TrackedFieldIngredient +where + C: Configuration, +{ + const RESET_ON_NEW_REVISION: bool = false; +} diff --git a/examples-2022/calc/Cargo.toml b/examples-2022/calc/Cargo.toml index 85aa0f7e..6301adce 100644 --- a/examples-2022/calc/Cargo.toml +++ b/examples-2022/calc/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" derive-new = "0.5.9" salsa = { path = "../../components/salsa-2022", package = "salsa-2022" } ordered-float = "3.0" +test-log = { version = "0.2.15", features = ["trace"] } [dev-dependencies] -expect-test = "1.4.0" \ No newline at end of file +expect-test = "1.4.0" diff --git a/examples-2022/calc/src/db.rs b/examples-2022/calc/src/db.rs index 11da9cb1..03af0223 100644 --- a/examples-2022/calc/src/db.rs +++ b/examples-2022/calc/src/db.rs @@ -38,6 +38,7 @@ impl Database { // ANCHOR: db_impl impl salsa::Database for Database { fn salsa_event(&self, event: salsa::Event) { + eprintln!("Event: {event:?}"); // Log interesting events, if logging is enabled if let Some(logs) = &self.logs { // don't log boring events diff --git a/examples-2022/calc/src/type_check.rs b/examples-2022/calc/src/type_check.rs index 24db01cf..62df3133 100644 --- a/examples-2022/calc/src/type_check.rs +++ b/examples-2022/calc/src/type_check.rs @@ -5,6 +5,8 @@ use crate::ir::{ use derive_new::new; #[cfg(test)] use expect_test::expect; +#[cfg(test)] +use test_log::test; // ANCHOR: parse_statements #[salsa::tracked] diff --git a/salsa-2022-tests/tests/compile-fail/tracked_fn_incompatibles.stderr b/salsa-2022-tests/tests/compile-fail/tracked_fn_incompatibles.stderr index 2d9f13c9..7a21ce06 100644 --- a/salsa-2022-tests/tests/compile-fail/tracked_fn_incompatibles.stderr +++ b/salsa-2022-tests/tests/compile-fail/tracked_fn_incompatibles.stderr @@ -54,11 +54,3 @@ error[E0412]: cannot find type `tracked_fn_with_receiver_not_applied_to_impl_blo | 2 | ...r, tracked_fn_with_one_input, tracked_fn_with_receiver_not_applied_to_impl_block); | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope - -error[E0308]: mismatched types - --> tests/compile-fail/tracked_fn_incompatibles.rs:29:46 - | -29 | fn tracked_fn_with_one_input(db: &dyn Db) -> u32 { - | ------------------------- ^^^ expected `u32`, found `()` - | | - | implicitly returns `()` as its body has no tail or `return` expression diff --git a/salsa-2022-tests/tests/create-empty-database.rs b/salsa-2022-tests/tests/create-empty-database.rs index 9064a244..a417d88e 100644 --- a/salsa-2022-tests/tests/create-empty-database.rs +++ b/salsa-2022-tests/tests/create-empty-database.rs @@ -31,5 +31,6 @@ fn ensure_init(place: *const ::Jars) { // SAFETY: Intentionally tries to access potentially uninitialized memory, // so that miri can catch if we accidentally forget to initialize the memory. + #[allow(clippy::forget_non_drop)] forget(unsafe { addr_of!((*place).0).read() }); } diff --git a/salsa-2022-tests/tests/deletion-cascade.rs b/salsa-2022-tests/tests/deletion-cascade.rs index d6343203..db935f8e 100644 --- a/salsa-2022-tests/tests/deletion-cascade.rs +++ b/salsa-2022-tests/tests/deletion-cascade.rs @@ -120,12 +120,9 @@ fn basic() { "intermediate_result(MyInput(Id { value: 1 }))", "salsa_event(WillDiscardStaleOutput { execute_key: create_tracked_structs(0), output_key: MyTracked(2) })", "salsa_event(DidDiscard { key: MyTracked(2) })", - "salsa_event(DidDiscard { key: field(2) })", "salsa_event(DidDiscard { key: contribution_from_struct(2) })", "salsa_event(DidDiscard { key: MyTracked(5) })", - "salsa_event(DidDiscard { key: field(5) })", "salsa_event(DidDiscard { key: copy_field(5) })", - "salsa_event(WillDiscardStaleOutput { execute_key: create_tracked_structs(0), output_key: field(2) })", "final_result(MyInput(Id { value: 1 }))", ]"#]]); } diff --git a/salsa-2022-tests/tests/deletion.rs b/salsa-2022-tests/tests/deletion.rs index e732d40c..5d981ff5 100644 --- a/salsa-2022-tests/tests/deletion.rs +++ b/salsa-2022-tests/tests/deletion.rs @@ -106,9 +106,7 @@ fn basic() { "intermediate_result(MyInput(Id { value: 1 }))", "salsa_event(WillDiscardStaleOutput { execute_key: create_tracked_structs(0), output_key: MyTracked(2) })", "salsa_event(DidDiscard { key: MyTracked(2) })", - "salsa_event(DidDiscard { key: field(2) })", "salsa_event(DidDiscard { key: contribution_from_struct(2) })", - "salsa_event(WillDiscardStaleOutput { execute_key: create_tracked_structs(0), output_key: field(2) })", "final_result(MyInput(Id { value: 1 }))", ]"#]]); } diff --git a/salsa-2022-tests/tests/lru.rs b/salsa-2022-tests/tests/lru.rs index a1ad9a9e..2b84f3cc 100644 --- a/salsa-2022-tests/tests/lru.rs +++ b/salsa-2022-tests/tests/lru.rs @@ -19,7 +19,7 @@ trait Db: salsa::DbWithJar + HasLogger {} struct HotPotato(u32); thread_local! { - static N_POTATOES: AtomicUsize = AtomicUsize::new(0) + static N_POTATOES: AtomicUsize = const { AtomicUsize::new(0) } } impl HotPotato { diff --git a/salsa-2022-tests/tests/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs b/salsa-2022-tests/tests/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs index 8288f948..aebb2f34 100644 --- a/salsa-2022-tests/tests/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs +++ b/salsa-2022-tests/tests/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs @@ -31,7 +31,7 @@ impl salsa::Database for Database {} impl Db for Database {} #[test] -#[should_panic(expected = "`execute` method for field")] +#[should_panic(expected = "access to tracked struct from previous revision")] fn execute() { let mut db = Database::default(); diff --git a/salsa-2022-tests/tests/specify_tracked_fn_in_rev_1_but_not_2.rs b/salsa-2022-tests/tests/specify_tracked_fn_in_rev_1_but_not_2.rs index db6a8322..00ddef1a 100644 --- a/salsa-2022-tests/tests/specify_tracked_fn_in_rev_1_but_not_2.rs +++ b/salsa-2022-tests/tests/specify_tracked_fn_in_rev_1_but_not_2.rs @@ -147,7 +147,6 @@ fn test_run_10() { "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: maybe_specified(0) } }", "maybe_specified(MyTracked(Id { value: 1 }))", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", ]"#]]); } @@ -171,7 +170,6 @@ fn test_run_20() { "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: maybe_specified(0) } }", "maybe_specified(MyTracked(Id { value: 1 }))", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", ]"#]]); } @@ -209,7 +207,6 @@ fn test_run_0_then_5_then_20() { [ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: input(0) } }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: create_tracked(0) } }", "create_tracked(MyInput(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", @@ -229,7 +226,6 @@ fn test_run_0_then_5_then_20() { [ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: input(0) } }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: create_tracked(0) } }", "create_tracked(MyInput(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillDiscardStaleOutput { execute_key: create_tracked(0), output_key: maybe_specified(0) } }", @@ -237,7 +233,6 @@ fn test_run_0_then_5_then_20() { "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: maybe_specified(0) } }", "maybe_specified(MyTracked(Id { value: 1 }))", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: read_maybe_specified(0) } }", "read_maybe_specified(MyTracked(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", @@ -282,7 +277,6 @@ fn test_run_0_then_5_then_10_then_20() { [ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: input(0) } }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: create_tracked(0) } }", "create_tracked(MyInput(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", @@ -302,7 +296,6 @@ fn test_run_0_then_5_then_10_then_20() { [ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: input(0) } }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: create_tracked(0) } }", "create_tracked(MyInput(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillDiscardStaleOutput { execute_key: create_tracked(0), output_key: maybe_specified(0) } }", @@ -310,7 +303,6 @@ fn test_run_0_then_5_then_10_then_20() { "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: maybe_specified(0) } }", "maybe_specified(MyTracked(Id { value: 1 }))", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: read_maybe_specified(0) } }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: final_result(0) } }", ]"#]]); @@ -324,15 +316,12 @@ fn test_run_0_then_5_then_10_then_20() { [ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: input(0) } }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: create_tracked(0) } }", "create_tracked(MyInput(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: maybe_specified(0) } }", "maybe_specified(MyTracked(Id { value: 1 }))", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: read_maybe_specified(0) } }", "read_maybe_specified(MyTracked(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", @@ -369,7 +358,6 @@ fn test_run_5_then_20() { [ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: input(0) } }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: create_tracked(0) } }", "create_tracked(MyInput(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillDiscardStaleOutput { execute_key: create_tracked(0), output_key: maybe_specified(0) } }", @@ -377,7 +365,6 @@ fn test_run_5_then_20() { "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: maybe_specified(0) } }", "maybe_specified(MyTracked(Id { value: 1 }))", - "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: read_maybe_specified(0) } }", "read_maybe_specified(MyTracked(Id { value: 1 }))", "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", diff --git a/salsa-2022-tests/tests/tracked-struct-field-bad-eq.rs b/salsa-2022-tests/tests/tracked-struct-field-bad-eq.rs new file mode 100644 index 00000000..bd487a8f --- /dev/null +++ b/salsa-2022-tests/tests/tracked-struct-field-bad-eq.rs @@ -0,0 +1,65 @@ +//! Test a field whose `PartialEq` impl is always true. +//! This can our "last changed" data to be wrong +//! but we *should* always reflect the final values. + +use test_log::test; + +#[salsa::jar(db = Db)] +struct Jar(MyInput, MyTracked, the_fn); + +trait Db: salsa::DbWithJar {} + +#[salsa::input] +struct MyInput { + field: bool, +} + +#[allow(clippy::derived_hash_with_manual_eq)] +#[derive(Eq, Hash, Debug, Clone)] +struct BadEq { + field: bool, +} + +impl PartialEq for BadEq { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl From for BadEq { + fn from(value: bool) -> Self { + Self { field: value } + } +} + +#[salsa::tracked] +struct MyTracked { + #[id] + field: BadEq, +} + +#[salsa::tracked] +fn the_fn(db: &dyn Db, input: MyInput) { + let tracked0 = MyTracked::new(db, BadEq::from(input.field(db))); + assert_eq!(tracked0.field(db).field, input.field(db)); +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database {} + +impl Db for Database {} + +#[test] +fn execute() { + let mut db = Database::default(); + + let input = MyInput::new(&db, true); + the_fn(&db, input); + input.set_field(&mut db).to(false); + the_fn(&db, input); +} diff --git a/salsa-2022-tests/tests/tracked-struct-field-not-eq.rs b/salsa-2022-tests/tests/tracked-struct-field-not-eq.rs new file mode 100644 index 00000000..2822611b --- /dev/null +++ b/salsa-2022-tests/tests/tracked-struct-field-not-eq.rs @@ -0,0 +1,58 @@ +//! Test a field whose `PartialEq` impl is always true. +//! This can our "last changed" data to be wrong +//! but we *should* always reflect the final values. + +use test_log::test; + +#[salsa::jar(db = Db)] +struct Jar(MyInput, MyTracked, the_fn); + +trait Db: salsa::DbWithJar {} + +#[salsa::input] +struct MyInput { + field: bool, +} + +#[derive(Hash, Debug, Clone)] +struct NotEq { + field: bool, +} + +impl From for NotEq { + fn from(value: bool) -> Self { + Self { field: value } + } +} + +#[salsa::tracked] +struct MyTracked { + #[no_eq] + field: NotEq, +} + +#[salsa::tracked] +fn the_fn(db: &dyn Db, input: MyInput) { + let tracked0 = MyTracked::new(db, NotEq::from(input.field(db))); + assert_eq!(tracked0.field(db).field, input.field(db)); +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database {} + +impl Db for Database {} + +#[test] +fn execute() { + let mut db = Database::default(); + + let input = MyInput::new(&db, true); + the_fn(&db, input); + input.set_field(&mut db).to(false); + the_fn(&db, input); +} diff --git a/salsa-2022-tests/tests/tracked-struct-id-field-bad-hash.rs b/salsa-2022-tests/tests/tracked-struct-id-field-bad-hash.rs new file mode 100644 index 00000000..98c1458f --- /dev/null +++ b/salsa-2022-tests/tests/tracked-struct-id-field-bad-hash.rs @@ -0,0 +1,67 @@ +//! Test for a tracked struct where the id field has a +//! very poorly chosen hash impl (always returns 0). +//! This demonstrates that the `#[id]` fields on a struct +//! can change values and yet the struct can have the same +//! id (because struct ids are based on the *hash* of the +//! `#[id]` fields). + +use test_log::test; + +#[salsa::jar(db = Db)] +struct Jar(MyInput, MyTracked, the_fn); + +trait Db: salsa::DbWithJar {} + +#[salsa::input] +struct MyInput { + field: bool, +} + +#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone)] +struct BadHash { + field: bool, +} + +impl From for BadHash { + fn from(value: bool) -> Self { + Self { field: value } + } +} + +impl std::hash::Hash for BadHash { + fn hash(&self, state: &mut H) { + state.write_i16(0); + } +} + +#[salsa::tracked] +struct MyTracked { + #[id] + field: BadHash, +} + +#[salsa::tracked] +fn the_fn(db: &dyn Db, input: MyInput) { + let tracked0 = MyTracked::new(db, BadHash::from(input.field(db))); + assert_eq!(tracked0.field(db).field, input.field(db)); +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database {} + +impl Db for Database {} + +#[test] +fn execute() { + let mut db = Database::default(); + + let input = MyInput::new(&db, true); + the_fn(&db, input); + input.set_field(&mut db).to(false); + the_fn(&db, input); +} diff --git a/src/runtime.rs b/src/runtime.rs index 20bf20ed..d0b616a4 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -618,7 +618,7 @@ impl ActiveQuery { fn remove_cycle_participants(&mut self, cycle: &Cycle) { if let Some(my_dependencies) = &mut self.dependencies { for p in cycle.participant_keys() { - my_dependencies.remove(&p); + my_dependencies.shift_remove(&p); } } } diff --git a/tests/lru.rs b/tests/lru.rs index 7d42522c..66f3498c 100644 --- a/tests/lru.rs +++ b/tests/lru.rs @@ -17,7 +17,7 @@ trait LruPeek { struct HotPotato(u32); thread_local! { - static N_POTATOES: AtomicUsize = AtomicUsize::new(0) + static N_POTATOES: AtomicUsize = const { AtomicUsize::new(0) } } impl HotPotato { diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 203e3633..9acc34c8 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -57,18 +57,13 @@ impl WithValue for Cell { } } -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Default)] pub(crate) enum CancellationFlag { + #[default] Down, Panic, } -impl Default for CancellationFlag { - fn default() -> CancellationFlag { - CancellationFlag::Down - } -} - /// Various "knobs" that can be used to customize how the queries /// behave on one specific thread. Note that this state is /// intentionally thread-local (apart from `signal`).