From 612cec6703bec4e906c3ebfc2b2978d6b3eefc04 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Mon, 15 Jul 2024 20:29:36 -0400 Subject: [PATCH] wip --- components/salsa-macro-rules/src/lib.rs | 1 + .../src/setup_accumulator_impl.rs | 43 ++++ .../src/setup_interned_fn.rs | 53 +++-- .../salsa-macro-rules/src/setup_struct_fn.rs | 33 ++- components/salsa-macros/src/accumulator.rs | 191 ++++-------------- components/salsa-macros/src/input.rs | 2 - components/salsa-macros/src/interned.rs | 2 - components/salsa-macros/src/options.rs | 26 --- components/salsa-macros/src/tracked_fn.rs | 2 - components/salsa-macros/src/tracked_struct.rs | 2 - examples/lazy-input/main.rs | 24 +-- src/accumulator.rs | 50 +++-- src/function.rs | 3 +- src/function/accumulated.rs | 79 -------- src/input/input_field.rs | 1 - src/lib.rs | 14 ++ src/runtime.rs | 2 +- src/storage.rs | 3 +- tests/accumulate-from-tracked-fn.rs | 10 +- tests/accumulate-reuse.rs | 16 +- 20 files changed, 227 insertions(+), 330 deletions(-) create mode 100644 components/salsa-macro-rules/src/setup_accumulator_impl.rs delete mode 100644 src/function/accumulated.rs diff --git a/components/salsa-macro-rules/src/lib.rs b/components/salsa-macro-rules/src/lib.rs index f4723dae..0cf993b4 100644 --- a/components/salsa-macro-rules/src/lib.rs +++ b/components/salsa-macro-rules/src/lib.rs @@ -14,6 +14,7 @@ mod maybe_backdate; mod maybe_clone; +mod setup_accumulator_impl; mod setup_input_struct; mod setup_interned_fn; mod setup_interned_struct; diff --git a/components/salsa-macro-rules/src/setup_accumulator_impl.rs b/components/salsa-macro-rules/src/setup_accumulator_impl.rs new file mode 100644 index 00000000..cd639b7f --- /dev/null +++ b/components/salsa-macro-rules/src/setup_accumulator_impl.rs @@ -0,0 +1,43 @@ +/// Macro for setting up a function that must intern its arguments. +#[macro_export] +macro_rules! setup_accumulator_impl { + ( + // Name of the struct + Struct: $Struct:ident, + + // Annoyingly macro-rules hygiene does not extend to items defined in the macro. + // We have the procedural macro generate names for those items that are + // not used elsewhere in the user's code. + unused_names: [ + $zalsa:ident, + $zalsa_struct:ident, + $CACHE:ident, + $ingredient:ident, + ] + ) => { + const _: () = { + use salsa::plumbing as $zalsa; + use salsa::plumbing::accumulator as $zalsa_struct; + + static $CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Struct>> = + $zalsa::IngredientCache::new(); + + fn $ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl<$Struct> { + $CACHE.get_or_create(db, || { + db.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Struct>>::default()) + }) + } + + impl $zalsa::Accumulator for $Struct { + const DEBUG_NAME: &'static str = stringify!($Struct); + + fn accumulate(self, db: &Db) + where + Db: ?Sized + $zalsa::Database, + { + $ingredient(db.as_salsa_database()).push(db.runtime(), self); + } + } + }; + }; +} diff --git a/components/salsa-macro-rules/src/setup_interned_fn.rs b/components/salsa-macro-rules/src/setup_interned_fn.rs index 4053001b..25730d5e 100644 --- a/components/salsa-macro-rules/src/setup_interned_fn.rs +++ b/components/salsa-macro-rules/src/setup_interned_fn.rs @@ -52,6 +52,11 @@ macro_rules! setup_interned_fn { $inner:ident, ] ) => { + #[allow(non_camel_case_types)] + $vis struct $fn_name { + _priv: std::convert::Infallible, + } + $(#[$attr])* $vis fn $fn_name<$db_lt>( $db: &$db_lt dyn $Db, @@ -62,7 +67,7 @@ macro_rules! setup_interned_fn { struct $Configuration; #[derive(Copy, Clone)] - struct $InternedData<'db>( + struct $InternedData<$db_lt>( std::ptr::NonNull<$zalsa::interned::Value<$Configuration>>, std::marker::PhantomData<&'db $zalsa::interned::Value<$Configuration>>, ); @@ -73,6 +78,23 @@ macro_rules! setup_interned_fn { static $INTERN_CACHE: $zalsa::IngredientCache<$zalsa::interned::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); + impl $Configuration { + fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { + $FN_CACHE.get_or_create(db.as_salsa_database(), || { + ::zalsa_db(db); + db.add_or_lookup_jar_by_type(&$Configuration) + }) + } + + fn intern_ingredient( + db: &dyn $Db, + ) -> &$zalsa::interned::IngredientImpl<$Configuration> { + $INTERN_CACHE.get_or_create(db.as_salsa_database(), || { + db.add_or_lookup_jar_by_type(&$Configuration).successor(0) + }) + } + } + impl $zalsa::SalsaStructInDb for $InternedData<'_> { fn register_dependent_fn(_db: &dyn $zalsa::Database, _index: $zalsa::IngredientIndex) {} } @@ -112,10 +134,7 @@ macro_rules! setup_interned_fn { } fn id_to_input<'db>(db: &'db Self::DbView, key: salsa::Id) -> Self::Input<'db> { - let ingredient = $INTERN_CACHE.get_or_create(db.as_salsa_database(), || { - db.add_or_lookup_jar_by_type(&$Configuration).successor(0) - }); - ingredient.data(key).clone() + $Configuration::intern_ingredient(db).data(key).clone() } } @@ -153,17 +172,21 @@ macro_rules! setup_interned_fn { } } - $zalsa::attach_database($db, || { - let intern_ingredient = $INTERN_CACHE.get_or_create($db.as_salsa_database(), || { - $db.add_or_lookup_jar_by_type(&$Configuration).successor(0) - }); - let key = intern_ingredient.intern_id($db.runtime(), ($($input_id),*)); + impl $fn_name { + pub fn accumulated<$db_lt, A: salsa::Accumulator>( + $db: &$db_lt dyn $Db, + $($input_id: $input_ty,)* + ) -> Vec { + use salsa::plumbing as $zalsa; + let key = $Configuration::intern_ingredient($db).intern_id($db.runtime(), ($($input_id),*)); + let database_key_index = $Configuration::fn_ingredient($db).database_key_index(key); + $zalsa::accumulated_by($db.as_salsa_database(), database_key_index) + } + } - let fn_ingredient = $FN_CACHE.get_or_create($db.as_salsa_database(), || { - ::zalsa_db($db); - $db.add_or_lookup_jar_by_type(&$Configuration) - }); - fn_ingredient.fetch($db, key).clone() + $zalsa::attach_database($db, || { + let key = $Configuration::intern_ingredient($db).intern_id($db.runtime(), ($($input_id),*)); + $Configuration::fn_ingredient($db).fetch($db, key).clone() }) } }; diff --git a/components/salsa-macro-rules/src/setup_struct_fn.rs b/components/salsa-macro-rules/src/setup_struct_fn.rs index a8405fa1..3bd0d073 100644 --- a/components/salsa-macro-rules/src/setup_struct_fn.rs +++ b/components/salsa-macro-rules/src/setup_struct_fn.rs @@ -50,6 +50,11 @@ macro_rules! setup_struct_fn { $inner:ident, ] ) => { + #[allow(non_camel_case_types)] + $vis struct $fn_name { + _priv: std::convert::Infallible, + } + $(#[$attr])* $vis fn $fn_name<$db_lt>( $db: &$db_lt dyn $Db, @@ -62,6 +67,15 @@ macro_rules! setup_struct_fn { static $FN_CACHE: $zalsa::IngredientCache<$zalsa::function::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); + impl $Configuration { + fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { + $FN_CACHE.get_or_create(db.as_salsa_database(), || { + ::zalsa_db(db); + db.add_or_lookup_jar_by_type(&$Configuration) + }) + } + } + impl $zalsa::function::Configuration for $Configuration { const DEBUG_NAME: &'static str = stringify!($fn_name); @@ -114,13 +128,20 @@ macro_rules! setup_struct_fn { } } - $zalsa::attach_database($db, || { - let fn_ingredient = $FN_CACHE.get_or_create($db.as_salsa_database(), || { - ::zalsa_db($db); - $db.add_or_lookup_jar_by_type(&$Configuration) - }); + impl $fn_name { + pub fn accumulated<$db_lt, A: salsa::Accumulator>( + $db: &$db_lt dyn $Db, + $input_id: $input_ty, + ) -> Vec { + use salsa::plumbing as $zalsa; + let key = $zalsa::AsId::as_id(&$input_id); + let database_key_index = $Configuration::fn_ingredient($db).database_key_index(key); + $zalsa::accumulated_by($db.as_salsa_database(), database_key_index) + } + } - fn_ingredient.fetch($db, $zalsa::AsId::as_id(&$input_id)).clone() + $zalsa::attach_database($db, || { + $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&$input_id)).clone() }) } }; diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index 31bdf11a..b2d21d31 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -1,4 +1,7 @@ -use syn::{spanned::Spanned, ItemStruct}; +use proc_macro2::TokenStream; +use syn::{parse::Nothing, spanned::Spanned}; + +use crate::hygiene::Hygiene; // #[salsa::accumulator(jar = Jar0)] // struct Accumulator(DataType); @@ -7,159 +10,49 @@ pub(crate) fn accumulator( args: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - let args = syn::parse_macro_input!(args as Args); - let struct_impl = syn::parse_macro_input!(input as ItemStruct); - accumulator_contents(&args, &struct_impl) - .unwrap_or_else(syn::Error::into_compile_error) - .into() -} - -type Args = crate::options::Options; - -struct Accumulator; - -impl crate::options::AllowedOptions for Accumulator { - const RETURN_REF: bool = false; - - const SPECIFY: bool = false; - - const NO_EQ: bool = false; - - const SINGLETON: bool = false; - - const JAR: bool = true; - - const DATA: bool = false; - - const DB: bool = false; - - const RECOVERY_FN: bool = false; - - const LRU: bool = false; - - const CONSTRUCTOR_NAME: bool = false; -} - -fn accumulator_contents( - args: &Args, - struct_item: &syn::ItemStruct, -) -> syn::Result { - // We expect a single anonymous field. - let data_ty = data_ty(struct_item)?; - let struct_name = &struct_item.ident; - let struct_ty = &parse_quote! {#struct_name}; - - let inherent_impl = inherent_impl(args, struct_ty, data_ty); - let ingredients_for_impl = ingredients_for_impl(args, struct_name, data_ty); - let struct_item_out = struct_item_out(args, struct_item, data_ty); - let accumulator_impl = accumulator_impl(args, struct_ty, data_ty); - - Ok(quote! { - #inherent_impl - #ingredients_for_impl - #struct_item_out - #accumulator_impl - }) -} - -fn data_ty(struct_item: &syn::ItemStruct) -> syn::Result<&syn::Type> { - if let syn::Fields::Unnamed(fields) = &struct_item.fields { - if fields.unnamed.len() != 1 { - Err(syn::Error::new( - struct_item.ident.span(), - "accumulator structs should have only one anonymous field", - )) - } else { - Ok(&fields.unnamed[0].ty) - } - } else { - Err(syn::Error::new( - struct_item.ident.span(), - "accumulator structs should have only one anonymous field", - )) + let hygiene = Hygiene::from1(&input); + let _ = syn::parse_macro_input!(args as Nothing); + let struct_item = syn::parse_macro_input!(input as syn::ItemStruct); + let ident = struct_item.ident.clone(); + let m = StructMacro { + hygiene, + struct_item, + }; + match m.try_expand() { + Ok(v) => crate::debug::dump_tokens(&ident, v).into(), + Err(e) => e.to_compile_error().into(), } } -fn struct_item_out( - _args: &Args, - struct_item: &syn::ItemStruct, - data_ty: &syn::Type, -) -> syn::ItemStruct { - let mut struct_item_out = struct_item.clone(); - struct_item_out.fields = syn::Fields::Unnamed(parse_quote_spanned! { data_ty.span() => - (std::marker::PhantomData<#data_ty>) - }); - struct_item_out +struct StructMacro { + hygiene: Hygiene, + struct_item: syn::ItemStruct, } -fn inherent_impl(args: &Args, struct_ty: &syn::Type, data_ty: &syn::Type) -> syn::ItemImpl { - let jar_ty = args.jar_ty(); - parse_quote_spanned! { struct_ty.span() => - #[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)] - impl #struct_ty { - pub fn push(db: &DB, data: #data_ty) - where - DB: salsa::storage::HasJar<#jar_ty>, - { - let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); - let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #struct_ty >>::ingredient(jar); - ingredients.push(runtime, data) +#[allow(non_snake_case)] +impl StructMacro { + fn try_expand(self) -> syn::Result { + let ident = self.struct_item.ident.clone(); + + let zalsa = self.hygiene.ident("zalsa"); + let zalsa_struct = self.hygiene.ident("zalsa_struct"); + let CACHE = self.hygiene.ident("CACHE"); + let ingredient = self.hygiene.ident("ingredient"); + + let struct_item = self.struct_item; + + Ok(quote! { + #struct_item + + salsa::plumbing::setup_accumulator_impl! { + Struct: #ident, + unused_names: [ + #zalsa, + #zalsa_struct, + #CACHE, + #ingredient, + ] } - } - } -} - -fn ingredients_for_impl( - args: &Args, - struct_name: &syn::Ident, - data_ty: &syn::Type, -) -> syn::ItemImpl { - let jar_ty = args.jar_ty(); - let debug_name = crate::literal(struct_name); - parse_quote_spanned! { struct_name.span() => - #[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)] - impl salsa::storage::IngredientsFor for #struct_name { - type Ingredients = salsa::accumulator::AccumulatorIngredient<#data_ty>; - type Jar = #jar_ty; - - fn create_ingredients(routes: &mut salsa::routes::Routes) -> Self::Ingredients - where - DB: salsa::DbWithJar + salsa::storage::JarFromJars, - { - let index = routes.push( - |jars| { - let jar = >::jar_from_jars(jars); - <_ as salsa::storage::HasIngredientsFor>::ingredient(jar) - }, - |jars| { - let jar = >::jar_from_jars_mut(jars); - <_ as salsa::storage::HasIngredientsFor>::ingredient_mut(jar) - }, - ); - salsa::accumulator::AccumulatorIngredient::new(index, #debug_name) - } - } - } -} - -fn accumulator_impl(args: &Args, struct_ty: &syn::Type, data_ty: &syn::Type) -> syn::ItemImpl { - let jar_ty = args.jar_ty(); - parse_quote_spanned! { struct_ty.span() => - #[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)] - impl salsa::accumulator::Accumulator for #struct_ty { - type Data = #data_ty; - type Jar = #jar_ty; - - fn accumulator_ingredient<'db, Db>( - db: &'db Db, - ) -> &'db salsa::accumulator::AccumulatorIngredient - where - Db: ?Sized + salsa::storage::HasJar - { - let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); - let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#struct_ty>>::ingredient(jar); - ingredients - } - } + }) } } diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index 2d979377..cd4bb8c2 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -40,8 +40,6 @@ impl crate::options::AllowedOptions for InputStruct { const NO_EQ: bool = false; const SINGLETON: bool = true; - const JAR: bool = true; - const DATA: bool = true; const DB: bool = false; diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index e6a64298..afcc29ac 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -42,8 +42,6 @@ impl crate::options::AllowedOptions for InternedStruct { const SINGLETON: bool = true; - const JAR: bool = true; - const DATA: bool = true; const DB: bool = false; diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index 66fca672..0fbecdc2 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -89,7 +89,6 @@ pub(crate) trait AllowedOptions { const SPECIFY: bool; const NO_EQ: bool; const SINGLETON: bool; - const JAR: bool; const DATA: bool; const DB: bool; const RECOVERY_FN: bool; @@ -100,18 +99,6 @@ pub(crate) trait AllowedOptions { type Equals = syn::Token![=]; type Comma = syn::Token![,]; -impl Options { - /// Returns the `jar type` given by the user; if none is given, - /// returns the default `crate::Jar`. - pub(crate) fn jar_ty(&self) -> syn::Type { - if let Some(jar_ty) = &self.jar_ty { - return jar_ty.clone(); - } - - parse_quote! {crate::Jar} - } -} - impl syn::parse::Parse for Options { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut options = Options::default(); @@ -171,19 +158,6 @@ impl syn::parse::Parse for Options { "`specify` option not allowed here", )); } - } else if ident == "jar" { - if A::JAR { - let _eq = Equals::parse(input)?; - let ty = syn::Type::parse(input)?; - if let Some(old) = std::mem::replace(&mut options.jar_ty, Some(ty)) { - return Err(syn::Error::new(old.span(), "option `jar` provided twice")); - } - } else { - return Err(syn::Error::new( - ident.span(), - "`jar` option not allowed here", - )); - } } else if ident == "db" { if A::DB { let _eq = Equals::parse(input)?; diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 31e95b30..c01056e4 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -31,8 +31,6 @@ impl crate::options::AllowedOptions for TrackedFn { const SINGLETON: bool = false; - const JAR: bool = false; - const DATA: bool = false; const DB: bool = false; diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index 98ffb1a3..5c121875 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -37,8 +37,6 @@ impl crate::options::AllowedOptions for TrackedStruct { const SINGLETON: bool = true; - const JAR: bool = true; - const DATA: bool = true; const DB: bool = false; diff --git a/examples/lazy-input/main.rs b/examples/lazy-input/main.rs index e9850842..203e54d9 100644 --- a/examples/lazy-input/main.rs +++ b/examples/lazy-input/main.rs @@ -8,6 +8,7 @@ use notify_debouncer_mini::{ notify::{RecommendedWatcher, RecursiveMode}, DebounceEventResult, Debouncer, }; +use salsa::{Accumulator, Setter}; // ANCHOR: main fn main() -> Result<()> { @@ -31,7 +32,7 @@ fn main() -> Result<()> { println!("Sum is: {}", sum); } else { for diagnostic in diagnostics { - println!("{}", diagnostic); + println!("{}", diagnostic.0); } } @@ -132,21 +133,20 @@ impl salsa::Database for Database { } #[salsa::accumulator] +#[derive(Clone, Debug)] struct Diagnostic(String); impl Diagnostic { fn push_error(db: &dyn Db, file: File, error: Report) { - Diagnostic::push( - db, - format!( - "Error in file {}: {:?}\n", - file.path(db) - .file_name() - .unwrap_or_else(|| "".as_ref()) - .to_string_lossy(), - error, - ), - ) + Diagnostic(format!( + "Error in file {}: {:?}\n", + file.path(db) + .file_name() + .unwrap_or_else(|| "".as_ref()) + .to_string_lossy(), + error, + )) + .accumulate(db); } } diff --git a/src/accumulator.rs b/src/accumulator.rs index 70c03dc5..cc96d29b 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -15,17 +15,20 @@ use crate::{ Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, Runtime, }; -pub trait Accumulator: Jar { +pub trait Accumulator: Clone + Debug + Send + Sync + 'static + Sized { const DEBUG_NAME: &'static str; - type Data: Clone + Debug + Send + Sync; + /// Accumulate an instance of this in the database for later retrieval. + fn accumulate(self, db: &Db) + where + Db: ?Sized + Database; } -pub struct AccumulatorJar { +pub struct JarImpl { phantom: PhantomData, } -impl Default for AccumulatorJar { +impl Default for JarImpl { fn default() -> Self { Self { phantom: Default::default(), @@ -33,29 +36,29 @@ impl Default for AccumulatorJar { } } -impl Jar for AccumulatorJar { +impl Jar for JarImpl { fn create_ingredients(&self, first_index: IngredientIndex) -> Vec> { - vec![Box::new(>::new(first_index))] + vec![Box::new(>::new(first_index))] } } -pub struct AccumulatorIngredient { +pub struct IngredientImpl { index: IngredientIndex, - map: FxDashMap>, + map: FxDashMap>, } -struct AccumulatedValues { +struct AccumulatedValues { produced_at: Revision, - values: Vec, + values: Vec, } -impl AccumulatorIngredient { +impl IngredientImpl { /// Find the accumulator ingrediate for `A` in the database, if any. pub fn from_db(db: &Db) -> Option<&Self> where Db: ?Sized + Database, { - let jar: AccumulatorJar = Default::default(); + let jar: JarImpl = Default::default(); let index = db.add_or_lookup_jar_by_type(&jar); let ingredient = db.lookup_ingredient(index).assert_type::(); Some(ingredient) @@ -75,7 +78,7 @@ impl AccumulatorIngredient { } } - pub fn push(&self, runtime: &Runtime, value: A::Data) { + pub fn push(&self, runtime: &Runtime, value: A) { let current_revision = runtime.current_revision(); let (active_query, _) = match runtime.active_query() { Some(pair) => pair, @@ -105,7 +108,7 @@ impl AccumulatorIngredient { &self, runtime: &Runtime, query: DatabaseKeyIndex, - output: &mut Vec, + output: &mut Vec, ) { let current_revision = runtime.current_revision(); if let Some(v) = self.map.get(&query) { @@ -126,7 +129,7 @@ impl AccumulatorIngredient { } } -impl Ingredient for AccumulatorIngredient { +impl Ingredient for IngredientImpl { fn ingredient_index(&self) -> IngredientIndex { self.index } @@ -193,11 +196,11 @@ impl Ingredient for AccumulatorIngredient { } } -impl IngredientRequiresReset for AccumulatorIngredient { +impl IngredientRequiresReset for IngredientImpl { const RESET_ON_NEW_REVISION: bool = false; } -impl std::fmt::Debug for AccumulatorIngredient +impl std::fmt::Debug for IngredientImpl where A: Accumulator, { @@ -207,3 +210,16 @@ where .finish() } } + +pub fn accumulated_by(db: &dyn Database, database_key_index: DatabaseKeyIndex) -> Vec +where + A: Accumulator, +{ + let Some(accumulator) = >::from_db(db) else { + return vec![]; + }; + let runtime = db.runtime(); + let mut output = vec![]; + accumulator.produced_by(runtime, database_key_index, &mut output); + output +} diff --git a/src/function.rs b/src/function.rs index 9039c1da..51c42587 100644 --- a/src/function.rs +++ b/src/function.rs @@ -16,7 +16,6 @@ use self::delete::DeletedEntries; use super::ingredient::Ingredient; -mod accumulated; mod backdate; mod delete; mod diff_outputs; @@ -146,7 +145,7 @@ where } } - fn database_key_index(&self, k: Id) -> DatabaseKeyIndex { + pub fn database_key_index(&self, k: Id) -> DatabaseKeyIndex { DatabaseKeyIndex { ingredient_index: self.index, key_index: k, diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs deleted file mode 100644 index 4c9139ce..00000000 --- a/src/function/accumulated.rs +++ /dev/null @@ -1,79 +0,0 @@ -use crate::{ - accumulator::AccumulatorIngredient, hash::FxHashSet, runtime::local_state::QueryOrigin, - storage::DatabaseGen, DatabaseKeyIndex, Id, -}; - -use super::{Configuration, IngredientImpl}; -use crate::accumulator::Accumulator; - -impl IngredientImpl -where - C: Configuration, -{ - /// Returns all the values accumulated into `accumulator` by this query and its - /// transitive inputs. - pub fn accumulated<'db, A>(&'db self, db: &'db C::DbView, key: Id) -> Vec - where - A: Accumulator, - { - // To start, ensure that the value is up to date: - self.fetch(db, key); - - let Some(accumulator_ingredient) = >::from_db(db) else { - return vec![]; - }; - - // Now walk over all the things that the value depended on - // and find the values they accumulated into the given - // accumulator: - let runtime = db.runtime(); - let mut result = vec![]; - let mut stack = Stack::new(self.database_key_index(key)); - while let Some(input) = stack.pop() { - accumulator_ingredient.produced_by(runtime, input, &mut result); - stack.extend(input.origin(db.as_salsa_database())); - } - result - } -} - -/// The stack is used to execute a DFS across all the queries -/// that were transitively executed by some given start query. -/// When we visit a query Q0, we look at its dependencies Q1...Qn, -/// and if they have not already been visited, we push them on the stack. -struct Stack { - /// Stack of queries left to visit. - v: Vec, - - /// Set of all queries we've seen. - s: FxHashSet, -} - -impl Stack { - fn new(start: DatabaseKeyIndex) -> Self { - Self { - v: vec![start], - s: FxHashSet::default(), - } - } - - fn pop(&mut self) -> Option { - self.v.pop() - } - - /// Extend the stack of queries with the dependencies from `origin`. - fn extend(&mut self, origin: Option) { - match origin { - None | Some(QueryOrigin::Assigned(_)) | Some(QueryOrigin::BaseInput) => {} - Some(QueryOrigin::Derived(edges)) | Some(QueryOrigin::DerivedUntracked(edges)) => { - for dependency_index in edges.inputs() { - if let Ok(i) = DatabaseKeyIndex::try_from(dependency_index) { - if self.s.insert(i) { - self.v.push(i) - } - } - } - } - } - } -} diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 078482d4..4cf78152 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,5 +1,4 @@ use crate::cycle::CycleRecoveryStrategy; -use crate::id::AsId; use crate::ingredient::{fmt_index, Ingredient, IngredientRequiresReset}; use crate::input::Configuration; use crate::runtime::local_state::QueryOrigin; diff --git a/src/lib.rs b/src/lib.rs index ce581e42..edb5dbcd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,7 @@ mod tracked_struct; mod update; mod views; +pub use self::accumulator::Accumulator; pub use self::cancelled::Cancelled; pub use self::cycle::Cycle; pub use self::database::Database; @@ -44,12 +45,19 @@ pub use salsa_macros::interned; pub use salsa_macros::tracked; pub use salsa_macros::Update; +pub mod prelude { + pub use crate::Accumulator; + pub use crate::Setter; +} + /// Internal names used by salsa macros. /// /// # WARNING /// /// The contents of this module are NOT subject to semver. pub mod plumbing { + pub use crate::accumulator::accumulated_by; + pub use crate::accumulator::Accumulator; pub use crate::array::Array; pub use crate::cycle::Cycle; pub use crate::cycle::CycleRecoveryStrategy; @@ -84,6 +92,7 @@ pub mod plumbing { pub use salsa_macro_rules::maybe_backdate; pub use salsa_macro_rules::maybe_clone; pub use salsa_macro_rules::maybe_cloned_ty; + pub use salsa_macro_rules::setup_accumulator_impl; pub use salsa_macro_rules::setup_input_struct; pub use salsa_macro_rules::setup_interned_fn; pub use salsa_macro_rules::setup_interned_struct; @@ -91,6 +100,11 @@ pub mod plumbing { pub use salsa_macro_rules::setup_tracked_struct; pub use salsa_macro_rules::unexpected_cycle_recovery; + pub mod accumulator { + pub use crate::accumulator::IngredientImpl; + pub use crate::accumulator::JarImpl; + } + pub mod input { pub use crate::input::input_field::FieldIngredientImpl; pub use crate::input::setter::SetterImpl; diff --git a/src/runtime.rs b/src/runtime.rs index f043f3a3..300f3d49 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -50,7 +50,7 @@ pub struct RuntimeId { counter: usize, } -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub struct StampedValue { pub value: V, pub durability: Durability, diff --git a/src/storage.rs b/src/storage.rs index 0581da25..bc646c99 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,8 +1,7 @@ use std::any::{Any, TypeId}; -use std::sync::Arc; use orx_concurrent_vec::ConcurrentVec; -use parking_lot::{Condvar, Mutex}; +use parking_lot::Mutex; use rustc_hash::FxHashMap; use crate::cycle::CycleRecoveryStrategy; diff --git a/tests/accumulate-from-tracked-fn.rs b/tests/accumulate-from-tracked-fn.rs index 767b62cf..c66ac428 100644 --- a/tests/accumulate-from-tracked-fn.rs +++ b/tests/accumulate-from-tracked-fn.rs @@ -8,10 +8,8 @@ use common::{HasLogger, Logger}; use expect_test::expect; use test_log::test; -#[salsa::jar(db = Db)] -struct Jar(List, Integers, compute); - -trait Db: salsa::DbWithJar + HasLogger {} +#[salsa::db] +trait Db: salsa::Database + HasLogger {} #[salsa::input] struct List { @@ -43,17 +41,19 @@ fn compute(db: &dyn Db, input: List) { eprintln!("pushed result {:?}", result); } -#[salsa::db(Jar)] +#[salsa::db] #[derive(Default)] struct Database { storage: salsa::Storage, logger: Logger, } +#[salsa::db] impl salsa::Database for Database { fn salsa_event(&self, _event: salsa::Event) {} } +#[salsa::db] impl Db for Database {} impl HasLogger for Database { diff --git a/tests/accumulate-reuse.rs b/tests/accumulate-reuse.rs index fc333fa4..3f83d516 100644 --- a/tests/accumulate-reuse.rs +++ b/tests/accumulate-reuse.rs @@ -7,12 +7,11 @@ mod common; use common::{HasLogger, Logger}; use expect_test::expect; +use salsa::prelude::*; use test_log::test; -#[salsa::jar(db = Db)] -struct Jar(List, Integers, compute); - -trait Db: salsa::DbWithJar + HasLogger {} +#[salsa::db] +trait Db: salsa::Database + HasLogger {} #[salsa::input] struct List { @@ -21,6 +20,7 @@ struct List { } #[salsa::accumulator] +#[derive(Clone, Debug)] struct Integers(u32); #[salsa::tracked] @@ -28,11 +28,11 @@ fn compute(db: &dyn Db, input: List) -> u32 { db.push_log(format!("compute({:?})", input,)); // always pushes 0 - Integers::push(db, 0); + Integers(0).accumulate(db); let result = if let Some(next) = input.next(db) { let next_integers = compute::accumulated::(db, next); - let v = input.value(db) + next_integers.iter().sum::(); + let v = input.value(db) + next_integers.iter().map(|i| i.0).sum::(); v } else { input.value(db) @@ -42,17 +42,19 @@ fn compute(db: &dyn Db, input: List) -> u32 { result } -#[salsa::db(Jar)] +#[salsa::db] #[derive(Default)] struct Database { storage: salsa::Storage, logger: Logger, } +#[salsa::db] impl salsa::Database for Database { fn salsa_event(&self, _event: salsa::Event) {} } +#[salsa::db] impl Db for Database {} impl HasLogger for Database {