diff --git a/components/salsa-macros/Cargo.toml b/components/salsa-macros/Cargo.toml index 1b5fd400..bccdf4e4 100644 --- a/components/salsa-macros/Cargo.toml +++ b/components/salsa-macros/Cargo.toml @@ -16,3 +16,4 @@ heck = "0.3" proc-macro2 = "0.4" quote = "0.6" syn = { version = "0.15", features = ["full", "extra-traits"] } + diff --git a/components/salsa-macros/src/database_storage.rs b/components/salsa-macros/src/database_storage.rs index fdd10725..bab485f3 100644 --- a/components/salsa-macros/src/database_storage.rs +++ b/components/salsa-macros/src/database_storage.rs @@ -50,28 +50,32 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { .collect(); let each_query = || { - query_groups - .iter() - .zip(&query_group_names_snake) - .flat_map(|(query_group, name_snake)| { - query_group.queries.iter().map(move |q| (name_snake, q)) - }) + query_groups.iter().flat_map(|query_group| { + query_group + .queries + .iter() + .map(move |q| (query_group.group_storage(), q)) + }) }; // For each query group `foo::MyGroup` create a link to its // `foo::MyGroupGroupStorage` let mut storage_fields = proc_macro2::TokenStream::new(); + let mut storage_impls = proc_macro2::TokenStream::new(); for (query_group, query_group_name_snake) in query_groups.iter().zip(&query_group_names_snake) { + let group_storage = query_group.group_storage(); + // rewrite the last identifier (`MyGroup`, above) to // (e.g.) `MyGroupGroupStorage`. - let mut group_storage = query_group.query_group.clone(); - let last_ident = &group_storage.segments.last().unwrap().value().ident; - let storage_ident = Ident::new( - &format!("{}GroupStorage", last_ident.to_string()), - Span::call_site(), - ); - group_storage.segments.last_mut().unwrap().value_mut().ident = storage_ident; storage_fields.extend(quote! { #query_group_name_snake: #group_storage<#database_name>, }); + storage_impls.extend(quote! { + impl ::salsa::plumbing::GetQueryGroupStorage<#group_storage<#database_name>> for #database_name { + fn from(db: &Self) -> &#group_storage<#database_name> { + let runtime = ::salsa::Database::salsa_runtime(db); + &runtime.storage().#query_group_name_snake + } + } + }); } let mut attrs = proc_macro2::TokenStream::new(); @@ -132,9 +136,10 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { // let mut for_each_ops = proc_macro2::TokenStream::new(); - for (group_index, Query { query_name, .. }) in each_query() { + for (ref group_storage, Query { query_name, .. }) in each_query() { for_each_ops.extend(quote! { - op(&::salsa::Database::salsa_runtime(self).storage().#group_index.#query_name); + let storage: &#group_storage<#database_name> = ::salsa::plumbing::GetQueryGroupStorage::from(self); + op(&storage.#query_name); }); } output.extend(quote! { @@ -150,7 +155,7 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { let mut for_each_query_desc = proc_macro2::TokenStream::new(); for ( - group_index, + ref group_storage, Query { query_name, query_type, @@ -159,8 +164,8 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { { for_each_query_desc.extend(quote! { __SalsaQueryDescriptorKind::#query_name(key) => { - let runtime = ::salsa::Database::salsa_runtime(db); - let storage = &runtime.storage().#group_index.#query_name; + let group_storage = ::salsa::plumbing::GetQueryGroupStorage::<#group_storage<#database_name>>::from(db); + let storage = &group_storage.#query_name; <_ as ::salsa::plumbing::QueryStorageOps<#database_name, #query_type>>::maybe_changed_since( storage, db, @@ -188,7 +193,7 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { let mut for_each_query_table = proc_macro2::TokenStream::new(); for ( - group_index, + ref group_storage, Query { query_name, query_type, @@ -200,12 +205,10 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { fn get_query_table( db: &Self, ) -> ::salsa::QueryTable<'_, Self, #query_type> { + let storage: &#group_storage<#database_name> = ::salsa::plumbing::GetQueryGroupStorage::from(db); ::salsa::QueryTable::new( db, - &::salsa::Database::salsa_runtime(db) - .storage() - .#group_index - .#query_name, + &storage.#query_name, ) } @@ -213,12 +216,10 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { db: &mut Self, ) -> ::salsa::QueryTableMut<'_, Self, #query_type> { let db = &*db; + let storage: &#group_storage<#database_name> = ::salsa::plumbing::GetQueryGroupStorage::from(db); ::salsa::QueryTableMut::new( db, - &::salsa::Database::salsa_runtime(db) - .storage() - .#group_index - .#query_name, + &storage.#query_name, ) } @@ -235,6 +236,13 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { } output.extend(for_each_query_table); + output.extend(storage_impls); + + if std::env::var("SALSA_DUMP").is_ok() { + println!("~~~ database_storage"); + println!("{}", output.to_string()); + println!("~~~ database_storage"); + } output.into() } @@ -251,6 +259,22 @@ struct QueryGroup { queries: Vec, } +impl QueryGroup { + /// Construct the path to the group storage for a query group. For + /// a query group at the path `foo::MyQuery`, this would be + /// `foo::MyQueryGroupStorage`. + fn group_storage(&self) -> Path { + let mut group_storage = self.query_group.clone(); + let last_ident = &group_storage.segments.last().unwrap().value().ident; + let storage_ident = Ident::new( + &format!("{}GroupStorage", last_ident.to_string()), + Span::call_site(), + ); + group_storage.segments.last_mut().unwrap().value_mut().ident = storage_ident; + group_storage + } +} + struct Query { query_name: Ident, query_type: Path, diff --git a/src/plumbing.rs b/src/plumbing.rs index 61768bed..fa8480da 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -81,6 +81,10 @@ pub trait GetQueryTable>: Database { fn descriptor(db: &Self, key: Q::Key) -> Self::QueryDescriptor; } +pub trait GetQueryGroupStorage: Database { + fn from(db: &Self) -> &S; +} + pub trait QueryStorageOps: Default where DB: Database,