diff --git a/components/salsa-macros/src/query_group.rs b/components/salsa-macros/src/query_group.rs index 5a69bd05..ba5af0ed 100644 --- a/components/salsa-macros/src/query_group.rs +++ b/components/salsa-macros/src/query_group.rs @@ -243,7 +243,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream #trait_vis struct #group_struct { } impl salsa::plumbing::QueryGroup for #group_struct - where DB__: #trait_name + where + DB__: #trait_name, + DB__: salsa::Database, { type GroupStorage = #group_storage; type GroupKey = #group_key; @@ -288,6 +290,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream impl salsa::Query for #qt where DB: #trait_name, + DB: salsa::Database, { type Key = (#(#keys),*); type Value = #value; @@ -325,6 +328,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream impl salsa::plumbing::QueryFunction for #qt where DB: #trait_name, + DB: salsa::Database, { fn execute(db: &DB, #key_pattern: >::Key) -> >::Value { @@ -371,11 +375,19 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream // It would derive Default, but then all database structs would have to implement Default // as the derived version includes an unused `+ Default` constraint. output.extend(quote! { - #trait_vis struct #group_storage { + #trait_vis struct #group_storage + where + DB__: #trait_name, + DB__: salsa::Database, + { #storage_fields } - impl Default for #group_storage { + impl Default for #group_storage + where + DB__: #trait_name, + DB__: salsa::Database, + { #[inline] fn default() -> Self { #group_storage { diff --git a/tests/dyn_trait.rs b/tests/dyn_trait.rs new file mode 100644 index 00000000..c232c303 --- /dev/null +++ b/tests/dyn_trait.rs @@ -0,0 +1,32 @@ +//! Test that you can implement a query using a `dyn Trait` setup. + +#[salsa::database(DynTraitStorage)] +#[derive(Default)] +struct DynTraitDatabase { + runtime: salsa::Runtime, +} + +impl salsa::Database for DynTraitDatabase { + fn salsa_runtime(&self) -> &salsa::Runtime { + &self.runtime + } +} + +#[salsa::query_group(DynTraitStorage)] +trait DynTrait { + #[salsa::input] + fn input(&self, x: u32) -> u32; + + fn output(&self, x: u32) -> u32; +} + +fn output(db: &dyn DynTrait, x: u32) -> u32 { + db.input(x) * 2 +} + +#[test] +fn dyn_trait() { + let mut query = DynTraitDatabase::default(); + query.set_input(22, 23); + assert_eq!(query.output(22), 46); +}