change #[salsa::query_group] attribute to take a struct name

This commit is contained in:
Niko Matsakis 2019-01-25 10:25:17 -05:00
parent 690a118472
commit 9b5c7eeb5e
23 changed files with 133 additions and 111 deletions

View file

@ -1,6 +1,5 @@
use heck::SnakeCase;
use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{Ident, ItemStruct, Path, Token};
@ -18,21 +17,29 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
let mut output = proc_macro2::TokenStream::new();
output.extend(quote! { #input });
let query_group_names_camel: Vec<_> = query_groups
let query_group_names_snake: Vec<_> = query_groups
.iter()
.map(|query_group| {
let group_storage = query_group.query_group.clone();
group_storage.segments.last().unwrap().value().ident.clone()
let group_name = query_group.name();
Ident::new(&group_name.to_string().to_snake_case(), group_name.span())
})
.collect();
let query_group_names_snake: Vec<_> = query_group_names_camel
let query_group_storage_names: Vec<_> = query_groups
.iter()
.map(|query_group_name_camel| {
Ident::new(
&query_group_name_camel.to_string().to_snake_case(),
query_group_name_camel.span(),
)
.map(|QueryGroup { group_path }| {
quote! {
<#group_path as salsa::plumbing::QueryGroup<#database_name>>::GroupStorage
}
})
.collect();
let query_group_key_names: Vec<_> = query_groups
.iter()
.map(|QueryGroup { group_path }| {
quote! {
<#group_path as salsa::plumbing::QueryGroup<#database_name>>::GroupKey
}
})
.collect();
@ -40,21 +47,25 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
// `foo::MyGroupGroupStorage`
let mut storage_fields = proc_macro2::TokenStream::new();
let mut has_group_impls = proc_macro2::TokenStream::new();
for (query_group, query_group_name_snake) in query_groups.iter().zip(&query_group_names_snake) {
for (((query_group, group_name_snake), group_storage), group_key) in query_groups
.iter()
.zip(&query_group_names_snake)
.zip(&query_group_storage_names)
.zip(&query_group_key_names)
{
let group_path = &query_group.group_path;
let group_name = query_group.name();
let group_storage = query_group.group_storage();
let group_key = query_group.group_key();
// rewrite the last identifier (`MyGroup`, above) to
// (e.g.) `MyGroupGroupStorage`.
storage_fields.extend(quote! { #query_group_name_snake: #group_storage<#database_name>, });
storage_fields.extend(quote! {
#group_name_snake: #group_storage,
});
has_group_impls.extend(quote! {
impl ::salsa::plumbing::HasQueryGroup<#group_storage<#database_name>, #group_key>
for #database_name
{
fn group_storage(db: &Self) -> &#group_storage<#database_name> {
impl ::salsa::plumbing::HasQueryGroup<#group_path> for #database_name {
fn group_storage(db: &Self) -> &#group_storage {
let runtime = ::salsa::Database::salsa_runtime(db);
&runtime.storage().#query_group_name_snake
&runtime.storage().#group_name_snake
}
fn database_key(group_key: #group_key) -> __SalsaDatabaseKey {
@ -90,9 +101,8 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
// foo(<FooType as ::salsa::Query<#database_name>>::Key),
// ```
let mut variants = proc_macro2::TokenStream::new();
for query_group in query_groups {
for (query_group, group_key) in query_groups.iter().zip(&query_group_key_names) {
let group_name = query_group.name();
let group_key = query_group.group_key();
variants.extend(quote!(
#group_name(#group_key),
));
@ -114,11 +124,12 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
//
let mut for_each_ops = proc_macro2::TokenStream::new();
for query_group in query_groups {
let group_storage = query_group.group_storage();
for (QueryGroup { group_path }, group_storage) in
query_groups.iter().zip(&query_group_storage_names)
{
for_each_ops.extend(quote! {
let storage: &#group_storage<#database_name> =
::salsa::plumbing::HasQueryGroup::group_storage(self);
let storage: &#group_storage =
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
storage.for_each_query(self, &mut op);
});
}
@ -184,13 +195,13 @@ impl Parse for QueryGroupList {
#[derive(Clone, Debug)]
struct QueryGroup {
query_group: Path,
group_path: Path,
}
impl QueryGroup {
/// The name of the query group trait.
fn name(&self) -> Ident {
self.query_group
self.group_path
.segments
.last()
.unwrap()
@ -198,35 +209,6 @@ impl QueryGroup {
.ident
.clone()
}
/// Construct the path to the group storage for a query group. For
/// a query group at the path `foo::MyQuery`, this would be
/// `foo::MyQueryGroupStorage`.
fn group_storage(&self) -> Path {
self.path_with_suffix("GroupStorage")
}
/// Construct the path to the group storage for a query group. For
/// a query group at the path `foo::MyQuery`, this would be
/// `foo::MyQueryGroupDatabaseKey`.
fn group_key(&self) -> Path {
self.path_with_suffix("GroupKey")
}
/// Construct a path leading to the query group, but with some
/// suffix. So, for a query group at the path `foo::MyQuery`,
/// this would be `foo::MyQueryXXX` where `XXX` is the provided
/// suffix.
fn path_with_suffix(&self, suffix: &str) -> Path {
let mut group_storage = self.query_group.clone();
let last_ident = &group_storage.segments.last().unwrap().value().ident;
let storage_ident = Ident::new(
&format!("{}{}", last_ident.to_string(), suffix),
Span::call_site(),
);
group_storage.segments.last_mut().unwrap().value_mut().ident = storage_ident;
group_storage
}
}
impl Parse for QueryGroup {
@ -234,8 +216,8 @@ impl Parse for QueryGroup {
/// impl HelloWorldDatabase;
/// ```
fn parse(input: ParseStream) -> syn::Result<Self> {
let query_group: Path = input.parse()?;
Ok(QueryGroup { query_group })
let group_path: Path = input.parse()?;
Ok(QueryGroup { group_path })
}
}

View file

@ -3,12 +3,12 @@ use heck::CamelCase;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::ToTokens;
use syn::{parse_macro_input, AttributeArgs, FnArg, Ident, ItemTrait, ReturnType, TraitItem};
use syn::{parse_macro_input, FnArg, Ident, ItemTrait, ReturnType, TraitItem};
/// Implementation for `[salsa::query_group]` decorator.
pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
let _args = parse_macro_input!(args as AttributeArgs);
let input = parse_macro_input!(input as ItemTrait);
let group_struct: Ident = parse_macro_input!(args as Ident);
let input: ItemTrait = parse_macro_input!(input as ItemTrait);
// println!("args: {:#?}", args);
// println!("input: {:#?}", input);
@ -121,12 +121,12 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
}
let group_key = Ident::new(
&format!("{}GroupKey", trait_name.to_string()),
&format!("{}GroupKey__", trait_name.to_string()),
Span::call_site(),
);
let group_storage = Ident::new(
&format!("{}GroupStorage", trait_name.to_string()),
&format!("{}GroupStorage__", trait_name.to_string()),
Span::call_site(),
);
@ -235,6 +235,19 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
}
};
// Emit the query group struct and impl of `QueryGroup`.
output.extend(quote! {
/// Representative struct for the query group.
#trait_vis struct #group_struct { }
impl<DB__> salsa::plumbing::QueryGroup<DB__> for #group_struct
where DB__: #trait_name
{
type GroupStorage = #group_storage<DB__>;
type GroupKey = #group_key;
}
});
// Emit an impl of the trait
output.extend({
let bounds = &input.supertraits;
@ -242,7 +255,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
impl<T> #trait_name for T
where
T: #bounds,
T: ::salsa::plumbing::HasQueryGroup<#group_storage<T>, #group_key>,
T: ::salsa::plumbing::HasQueryGroup<#group_struct>
{
#query_fn_definitions
}
@ -277,6 +290,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
type Key = (#(#keys),*);
type Value = #value;
type Storage = salsa::plumbing::#storage<DB, Self>;
type Group = #group_struct;
type GroupStorage = #group_storage<DB>;
type GroupKey = #group_key;
@ -335,7 +349,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
) -> bool
where
DB__: #trait_name,
DB__: ::salsa::plumbing::HasQueryGroup<#group_storage<DB__>, #group_key>,
DB__: ::salsa::plumbing::HasQueryGroup<#group_struct>,
{
match self {
#query_descriptor_maybe_change
@ -361,7 +375,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
impl<DB__> #group_storage<DB__>
where
DB__: #trait_name,
DB__: ::salsa::plumbing::HasQueryGroup<#group_storage<DB__>, #group_key>,
DB__: ::salsa::plumbing::HasQueryGroup<#group_struct>,
{
#trait_vis fn for_each_query(
&self,

View file

@ -1,7 +1,7 @@
use crate::compiler;
use std::sync::Arc;
#[salsa::query_group]
#[salsa::query_group(ClassTable)]
pub trait ClassTableDatabase: compiler::CompilerDatabase {
/// Get the fields.
fn fields(&self, class: DefId) -> Arc<Vec<DefId>>;

View file

@ -13,7 +13,7 @@ use crate::compiler::{CompilerDatabase, Interner};
/// to your context (e.g., a shared counter or some such thing). If
/// mutations to that shared state affect the results of your queries,
/// that's going to mess up the incremental results.
#[salsa::database(class_table::ClassTableDatabase)]
#[salsa::database(class_table::ClassTable)]
#[derive(Default)]
pub struct DatabaseImpl {
runtime: salsa::Runtime<DatabaseImpl>,

View file

@ -4,14 +4,23 @@ use std::sync::Arc;
// Step 1. Define the query group
// A **query group** is a collection of queries (both inputs and
// functions) that are defined in one particular spot. Each query group
// represents some subset of the full set of queries you will use in your
// application. Query groups can also depend on one another: so you might
// have some basic query group A and then another query group B that uses
// the queries from A and adds a few more. (These relationships must form
// a DAG at present, but that is due to Rust's restrictions around
// supertraits, which are likely to be lifted.)
#[salsa::query_group]
// functions) that are defined in one particular spot. Each query
// group is defined by a representative struct (used internally by
// Salsa) as well as a representative trait. By convention, for a
// query group `Foo`, the struct is named `Foo` and the trait is named
// `FooDatabase`. The name `FooDatabase` reflects the fact that the
// trait is implemented by **the database**, which stores all the data
// in the system. Each query group thus represents a subset of the
// full data.
//
// To define a query group, you annotate a trait definition with the
// `#[salsa::query_group(Foo)]` attribute macro. In addition to the
// trait definition, the macro will generate a struct with the name
// `Foo` that you provide, as well as various other bits of glue.
//
// Note that one query group can "include" another by listing the
// trait for that query group as a supertrait.
#[salsa::query_group(HelloWorld)]
trait HelloWorldDatabase: salsa::Database {
// For each query, we give the name, some input keys (here, we
// have one key, `()`) and the output type `Arc<String>`. We can
@ -52,15 +61,17 @@ fn length(db: &impl HelloWorldDatabase, (): ()) -> usize {
// Step 3. Define the database struct
// Define the actual database struct. This struct needs to be
// annotated with `#[salsa::database(..)]`, which contains a list of
// query groups that this database supports. This attribute macro will
// generate the necessary impls so that the database implements all of
// those traits.
// annotated with `#[salsa::database(..)]`. The list `..` will be the
// paths leading to the query group structs for each query group that
// this database supports. This attribute macro will generate the
// necessary impls so that the database implements the corresponding
// traits as well (so, here, `DatabaseStruct` will implement the
// `HelloWorldDatabase` trait).
//
// The database struct can contain basically anything you need, but it
// must have a `runtime` field as shown, and you must implement the
// `salsa::Database` trait (as shown below).
#[salsa::database(HelloWorldDatabase)]
#[salsa::database(HelloWorld)]
#[derive(Default)]
struct DatabaseStruct {
runtime: salsa::Runtime<DatabaseStruct>,

View file

@ -349,6 +349,13 @@ pub trait Query<DB: Database>: Debug + Default + Sized + 'static {
/// Internal struct storing the values for the query.
type Storage: plumbing::QueryStorageOps<DB, Self> + Send + Sync;
/// Associate query group struct.
type Group: plumbing::QueryGroup<
DB,
GroupStorage = Self::GroupStorage,
GroupKey = Self::GroupKey,
>;
/// Generated struct that contains storage for all queries in a group.
type GroupStorage;

View file

@ -86,7 +86,7 @@ impl<DB, Q> GetQueryTable<Q> for DB
where
DB: Database,
Q: Query<DB>,
DB: HasQueryGroup<Q::GroupStorage, Q::GroupKey>,
DB: HasQueryGroup<Q::Group>,
{
fn get_query_table(db: &DB) -> QueryTable<'_, DB, Q> {
let group_storage: &Q::GroupStorage = HasQueryGroup::group_storage(db);
@ -106,18 +106,26 @@ where
key: <Q as Query<DB>>::Key,
) -> <DB as DatabaseStorageTypes>::DatabaseKey {
let group_key = Q::group_key(key);
<DB as HasQueryGroup<_, _>>::database_key(group_key)
<DB as HasQueryGroup<_>>::database_key(group_key)
}
}
pub trait QueryGroup<DB: Database> {
type GroupStorage;
type GroupKey;
}
/// Trait implemented by a database for each group that it supports.
/// `S` and `K` are the types for *group storage* and *group key*, respectively.
pub trait HasQueryGroup<S, K>: Database {
pub trait HasQueryGroup<G>: Database
where
G: QueryGroup<Self>,
{
/// Access the group storage struct from the database.
fn group_storage(db: &Self) -> &S;
fn group_storage(db: &Self) -> &G::GroupStorage;
/// "Upcast" a group key into a database key.
fn database_key(group_key: K) -> Self::DatabaseKey;
fn database_key(group_key: G::GroupKey) -> Self::DatabaseKey;
}
pub trait QueryStorageOps<DB, Q>: Default

View file

@ -1,4 +1,4 @@
#[salsa::database(Database)]
#[salsa::database(GroupStruct)]
#[derive(Default)]
struct DatabaseImpl {
runtime: salsa::Runtime<DatabaseImpl>,
@ -10,7 +10,7 @@ impl salsa::Database for DatabaseImpl {
}
}
#[salsa::query_group]
#[salsa::query_group(GroupStruct)]
trait Database: salsa::Database {
// `a` and `b` depend on each other and form a cycle
fn memoized_a(&self) -> ();

View file

@ -1,7 +1,7 @@
use crate::group;
use crate::log::{HasLog, Log};
#[salsa::database(group::GcDatabase)]
#[salsa::database(group::Gc)]
#[derive(Default)]
pub(crate) struct DatabaseImpl {
runtime: salsa::Runtime<DatabaseImpl>,

View file

@ -1,6 +1,6 @@
use crate::log::HasLog;
#[salsa::query_group]
#[salsa::query_group(Gc)]
pub(crate) trait GcDatabase: salsa::Database + HasLog {
#[salsa::input]
fn min(&self) -> usize;

View file

@ -2,7 +2,7 @@ use crate::implementation::{TestContext, TestContextImpl};
use salsa::debug::DebugQueryTable;
use salsa::Database;
#[salsa::query_group]
#[salsa::query_group(Constants)]
pub(crate) trait ConstantsDatabase: TestContext {
#[salsa::input]
fn input(&self, key: char) -> usize;

View file

@ -11,10 +11,10 @@ pub(crate) trait TestContext: salsa::Database {
}
#[salsa::database(
constants::ConstantsDatabase,
memoized_dep_inputs::MemoizedDepInputsContext,
memoized_inputs::MemoizedInputsContext,
memoized_volatile::MemoizedVolatileContext
constants::Constants,
memoized_dep_inputs::MemoizedDepInputs,
memoized_inputs::MemoizedInputs,
memoized_volatile::MemoizedVolatile
)]
#[derive(Default)]
pub(crate) struct TestContextImpl {

View file

@ -1,7 +1,7 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::Database;
#[salsa::query_group]
#[salsa::query_group(MemoizedDepInputs)]
pub(crate) trait MemoizedDepInputsContext: TestContext {
fn dep_memoized2(&self) -> usize;
fn dep_memoized1(&self) -> usize;

View file

@ -1,7 +1,7 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::Database;
#[salsa::query_group]
#[salsa::query_group(MemoizedInputs)]
pub(crate) trait MemoizedInputsContext: TestContext {
fn max(&self) -> usize;
#[salsa::input]

View file

@ -1,7 +1,7 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::Database;
#[salsa::query_group]
#[salsa::query_group(MemoizedVolatile)]
pub(crate) trait MemoizedVolatileContext: TestContext {
// Queries for testing a "volatile" value wrapped by
// memoization.

View file

@ -1,4 +1,4 @@
#[salsa::query_group]
#[salsa::query_group(MyStruct)]
trait MyDatabase: salsa::Database {
#[salsa::invoke(another_module::another_name)]
fn my_query(&self, key: ()) -> ();

View file

@ -1,7 +1,7 @@
use salsa::{Database, ParallelDatabase, Snapshot};
use std::panic::{self, AssertUnwindSafe};
#[salsa::query_group]
#[salsa::query_group(PanicSafelyStruct)]
trait PanicSafelyDatabase: salsa::Database {
#[salsa::input]
fn one(&self) -> usize;
@ -13,7 +13,7 @@ fn panic_safely(db: &impl PanicSafelyDatabase) -> () {
assert_eq!(db.one(), 1);
}
#[salsa::database(PanicSafelyDatabase)]
#[salsa::database(PanicSafelyStruct)]
#[derive(Default)]
struct DatabaseStruct {
runtime: salsa::Runtime<DatabaseStruct>,

View file

@ -5,7 +5,7 @@ use salsa::Snapshot;
use std::cell::Cell;
use std::sync::Arc;
#[salsa::query_group]
#[salsa::query_group(Par)]
pub(crate) trait ParDatabase: Knobs + salsa::ParallelDatabase {
#[salsa::input]
fn input(&self, key: char) -> usize;
@ -184,7 +184,7 @@ fn snapshot_me(db: &impl ParDatabase) {
db.snapshot();
}
#[salsa::database(ParDatabase)]
#[salsa::database(Par)]
#[derive(Default)]
pub(crate) struct ParDatabaseImpl {
runtime: salsa::Runtime<ParDatabaseImpl>,

View file

@ -12,7 +12,7 @@ const N_READER_OPS: usize = 100;
struct Canceled;
type Cancelable<T> = Result<T, Canceled>;
#[salsa::query_group]
#[salsa::query_group(Stress)]
trait StressDatabase: salsa::Database {
#[salsa::input]
fn a(&self, key: usize) -> usize;
@ -33,7 +33,7 @@ fn c(db: &impl StressDatabase, key: usize) -> Cancelable<usize> {
db.b(key)
}
#[salsa::database(StressDatabase)]
#[salsa::database(Stress)]
#[derive(Default)]
struct StressDatabaseImpl {
runtime: salsa::Runtime<StressDatabaseImpl>,

View file

@ -1,6 +1,6 @@
use salsa::Database;
#[salsa::query_group]
#[salsa::query_group(HelloWorldStruct)]
trait HelloWorldDatabase: salsa::Database {
#[salsa::input]
fn input(&self) -> String;
@ -20,7 +20,7 @@ fn double_length(db: &impl HelloWorldDatabase) -> usize {
db.length() * 2
}
#[salsa::database(HelloWorldDatabase)]
#[salsa::database(HelloWorldStruct)]
#[derive(Default)]
struct DatabaseStruct {
runtime: salsa::Runtime<DatabaseStruct>,

View file

@ -1,7 +1,7 @@
use crate::queries;
use std::cell::Cell;
#[salsa::database(queries::Database)]
#[salsa::database(queries::GroupStruct)]
#[derive(Default)]
pub(crate) struct DatabaseImpl {
runtime: salsa::Runtime<DatabaseImpl>,

View file

@ -2,7 +2,7 @@ pub(crate) trait Counter: salsa::Database {
fn increment(&self) -> usize;
}
#[salsa::query_group]
#[salsa::query_group(GroupStruct)]
pub(crate) trait Database: Counter {
fn memoized(&self) -> usize;
#[salsa::volatile]

View file

@ -1,6 +1,6 @@
use salsa::Database;
#[salsa::query_group]
#[salsa::query_group(HelloWorld)]
trait HelloWorldDatabase: salsa::Database {
#[salsa::input]
fn input(&self, a: u32, b: u32) -> u32;
@ -30,7 +30,7 @@ fn trailing(_db: &impl HelloWorldDatabase, a: u32, b: u32) -> u32 {
a - b
}
#[salsa::database(HelloWorldDatabase)]
#[salsa::database(HelloWorld)]
#[derive(Default)]
struct DatabaseStruct {
runtime: salsa::Runtime<DatabaseStruct>,