diff --git a/components/salsa-macros/src/database_storage.rs b/components/salsa-macros/src/database_storage.rs index e2909cb..5669593 100644 --- a/components/salsa-macros/src/database_storage.rs +++ b/components/salsa-macros/src/database_storage.rs @@ -106,6 +106,7 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { // ANCHOR:DatabaseOps let mut fmt_ops = proc_macro2::TokenStream::new(); let mut maybe_changed_ops = proc_macro2::TokenStream::new(); + let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new(); let mut for_each_ops = proc_macro2::TokenStream::new(); for ((QueryGroup { group_path }, group_storage), group_index) in query_groups .iter() @@ -126,6 +127,13 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { storage.maybe_changed_since(self, input, revision) } }); + cycle_recovery_strategy_ops.extend(quote! { + #group_index => { + let storage: &#group_storage = + >::group_storage(self); + storage.cycle_recovery_strategy(self, input) + } + }); for_each_ops.extend(quote! { let storage: &#group_storage = >::group_storage(self); @@ -168,6 +176,16 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { } } + fn cycle_recovery_strategy( + &self, + input: salsa::DatabaseKeyIndex, + ) -> salsa::plumbing::CycleRecoveryStrategy { + match input.group_index() { + #cycle_recovery_strategy_ops + i => panic!("salsa: invalid group index {}", i) + } + } + fn for_each_query( &self, mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps), diff --git a/components/salsa-macros/src/query_group.rs b/components/salsa-macros/src/query_group.rs index a241561..9b988e0 100644 --- a/components/salsa-macros/src/query_group.rs +++ b/components/salsa-macros/src/query_group.rs @@ -539,6 +539,17 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream }); } + let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new(); + for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) { + cycle_recovery_strategy_ops.extend(quote! { + #query_index => { + salsa::plumbing::QueryStorageOps::cycle_recovery_strategy( + &*self.#fn_name + ) + } + }); + } + let mut for_each_ops = proc_macro2::TokenStream::new(); for Query { fn_name, .. } in non_transparent_queries() { for_each_ops.extend(quote! { @@ -591,6 +602,17 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream } } + #trait_vis fn cycle_recovery_strategy( + &self, + db: &(#dyn_db + '_), + input: salsa::DatabaseKeyIndex, + ) -> salsa::plumbing::CycleRecoveryStrategy { + match input.query_index() { + #cycle_recovery_strategy_ops + i => panic!("salsa: impossible query index {}", i), + } + } + #trait_vis fn for_each_query( &self, _runtime: &salsa::Runtime, diff --git a/src/plumbing.rs b/src/plumbing.rs index edd06b2..1f4720b 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -56,6 +56,9 @@ pub trait DatabaseOps { /// True if the computed value for `input` may have changed since `revision`. fn maybe_changed_since(&self, input: DatabaseKeyIndex, revision: Revision) -> bool; + /// Find the `CycleRecoveryStrategy` for a given input. + fn cycle_recovery_strategy(&self, input: DatabaseKeyIndex) -> CycleRecoveryStrategy; + /// Executes the callback for each kind of query. fn for_each_query(&self, op: &mut dyn FnMut(&dyn QueryStorageMassOps)); } @@ -172,6 +175,10 @@ where revision: Revision, ) -> bool; + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { + Self::CYCLE_STRATEGY + } + /// Execute the query, returning the result (often, the result /// will be memoized). This is the "main method" for /// queries.