make dyn Trait query implementations work

This commit is contained in:
Niko Matsakis 2019-01-25 18:35:16 -05:00
parent 5e2fcc2a17
commit 6f15a440ca
2 changed files with 47 additions and 3 deletions

View file

@ -243,7 +243,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
#trait_vis struct #group_struct { }
impl<DB__> salsa::plumbing::QueryGroup<DB__> for #group_struct
where DB__: #trait_name
where
DB__: #trait_name,
DB__: salsa::Database,
{
type GroupStorage = #group_storage<DB__>;
type GroupKey = #group_key;
@ -288,6 +290,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
impl<DB> salsa::Query<DB> for #qt
where
DB: #trait_name,
DB: salsa::Database,
{
type Key = (#(#keys),*);
type Value = #value;
@ -325,6 +328,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
impl<DB> salsa::plumbing::QueryFunction<DB> for #qt
where
DB: #trait_name,
DB: salsa::Database,
{
fn execute(db: &DB, #key_pattern: <Self as salsa::Query<DB>>::Key)
-> <Self as salsa::Query<DB>>::Value {
@ -371,11 +375,19 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
// It would derive Default, but then all database structs would have to implement Default
// as the derived version includes an unused `+ Default` constraint.
output.extend(quote! {
#trait_vis struct #group_storage<DB__: #trait_name> {
#trait_vis struct #group_storage<DB__>
where
DB__: #trait_name,
DB__: salsa::Database,
{
#storage_fields
}
impl<DB__: #trait_name> Default for #group_storage<DB__> {
impl<DB__> Default for #group_storage<DB__>
where
DB__: #trait_name,
DB__: salsa::Database,
{
#[inline]
fn default() -> Self {
#group_storage {

32
tests/dyn_trait.rs Normal file
View file

@ -0,0 +1,32 @@
//! Test that you can implement a query using a `dyn Trait` setup.
#[salsa::database(DynTraitStorage)]
#[derive(Default)]
struct DynTraitDatabase {
runtime: salsa::Runtime<DynTraitDatabase>,
}
impl salsa::Database for DynTraitDatabase {
fn salsa_runtime(&self) -> &salsa::Runtime<DynTraitDatabase> {
&self.runtime
}
}
#[salsa::query_group(DynTraitStorage)]
trait DynTrait {
#[salsa::input]
fn input(&self, x: u32) -> u32;
fn output(&self, x: u32) -> u32;
}
fn output(db: &dyn DynTrait, x: u32) -> u32 {
db.input(x) * 2
}
#[test]
fn dyn_trait() {
let mut query = DynTraitDatabase::default();
query.set_input(22, 23);
assert_eq!(query.output(22), 46);
}