autogenerate interned case

This commit is contained in:
Niko Matsakis 2024-07-13 07:01:31 -04:00
parent dde7341f97
commit a4e550065f
8 changed files with 1227 additions and 906 deletions

View file

@ -45,10 +45,6 @@ macro_rules! setup_interned_fn {
// (e.g., `(a, b): (u32, u32)`) we will synthesize an identifier.
input_ids: [$($input_id:ident),*],
// Patterns that the user gave for each argument EXCEPT the database.
// May be identifiers, but could be something else.
input_pats: [$($input_pat:pat),*],
// Types of the function arguments (may reference `$generics`).
input_tys: [$($input_ty:ty),*],
@ -56,7 +52,7 @@ macro_rules! setup_interned_fn {
output_ty: $output_ty:ty,
// Function body, may reference identifiers defined in `$input_pats` and the generics from `$generics`
body: $body:block,
inner_fn: $inner_fn:item,
// Path to the cycle recovery function to use.
cycle_recovery_fn: ($($cycle_recovery_fn:tt)*),
@ -121,12 +117,7 @@ macro_rules! setup_interned_fn {
}
fn execute<'db>($db: &'db Self::DbView, ($($input_id),*): ($($input_ty),*)) -> Self::Output<'db> {
$vis fn $inner<$db_lt>(
$db: &$db_lt dyn $Db,
$($input_pat: $input_ty,)*
) -> $output_ty {
$body
}
$inner_fn
$inner($db, $($input_id),*)
}

View file

@ -15,7 +15,7 @@ pub(crate) fn db(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let _nothing = syn::parse_macro_input!(args as Nothing);
let hygiene = Hygiene::from(&input);
let hygiene = Hygiene::from1(&input);
let input = syn::parse_macro_input!(input as syn::Item);
let db_macro = DbMacro { hygiene };
match db_macro.try_db(input) {

View file

@ -1,21 +1,29 @@
use std::collections::HashSet;
use quote::ToTokens;
pub struct Hygiene {
user_tokens: HashSet<String>,
}
impl From<&proc_macro::TokenStream> for Hygiene {
fn from(input: &proc_macro::TokenStream) -> Self {
impl Hygiene {
pub fn from1(tokens: &proc_macro::TokenStream) -> Self {
let mut user_tokens = HashSet::new();
push_idents(input.clone(), &mut user_tokens);
push_idents1(tokens.clone(), &mut user_tokens);
Self { user_tokens }
}
pub fn from2(tokens: &impl ToTokens) -> Self {
let mut user_tokens = HashSet::new();
push_idents2(tokens.to_token_stream(), &mut user_tokens);
Self { user_tokens }
}
}
fn push_idents(input: proc_macro::TokenStream, user_tokens: &mut HashSet<String>) {
fn push_idents1(input: proc_macro::TokenStream, user_tokens: &mut HashSet<String>) {
input.into_iter().for_each(|token| match token {
proc_macro::TokenTree::Group(g) => {
push_idents(g.stream(), user_tokens);
push_idents1(g.stream(), user_tokens);
}
proc_macro::TokenTree::Ident(ident) => {
user_tokens.insert(ident.to_string());
@ -25,12 +33,27 @@ fn push_idents(input: proc_macro::TokenStream, user_tokens: &mut HashSet<String>
})
}
fn push_idents2(input: proc_macro2::TokenStream, user_tokens: &mut HashSet<String>) {
input.into_iter().for_each(|token| match token {
proc_macro2::TokenTree::Group(g) => {
push_idents2(g.stream(), user_tokens);
}
proc_macro2::TokenTree::Ident(ident) => {
user_tokens.insert(ident.to_string());
}
proc_macro2::TokenTree::Punct(_) => (),
proc_macro2::TokenTree::Literal(_) => (),
})
}
impl Hygiene {
/// Generates an identifier similar to `text` but
/// distinct from any identifiers that appear in the user's
/// code.
pub(crate) fn ident(&self, text: &str) -> syn::Ident {
let mut buffer = String::from(text);
// Make the default be `foo_` rather than `foo` -- this helps detect
// cases where people wrote `foo` instead of `#foo` or `$foo` in the generated code.
let mut buffer = format!("{}_", text);
while self.user_tokens.contains(&buffer) {
buffer.push('_');

View file

@ -50,6 +50,7 @@ mod options;
mod salsa_struct;
mod tracked;
mod tracked_fn;
mod tracked_fn1;
mod tracked_struct;
mod update;
mod xform;

View file

@ -8,7 +8,7 @@ pub(crate) fn tracked(
let res = match item {
syn::Item::Struct(item) => crate::tracked_struct::tracked(args, item),
syn::Item::Fn(item) => crate::tracked_fn::tracked_fn(args, item),
syn::Item::Impl(item) => crate::tracked_fn::tracked_impl(args, item),
syn::Item::Impl(item) => crate::tracked_fn1::tracked_impl(args, item),
_ => Err(syn::Error::new(
item.span(),
"tracked can only be applied to structs, functions, and impls",

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,925 @@
use proc_macro2::{Literal, Span, TokenStream};
use syn::spanned::Spanned;
use syn::visit_mut::VisitMut;
use syn::ReturnType;
use crate::configuration::{self, Configuration, CycleRecoveryStrategy};
use crate::db_lifetime::{self, db_lifetime, require_optional_db_lifetime};
use crate::options::Options;
use crate::xform::ChangeLt;
pub(crate) fn tracked_fn(
args: proc_macro::TokenStream,
mut item_fn: syn::ItemFn,
) -> syn::Result<TokenStream> {
db_lifetime::require_optional_db_lifetime(&item_fn.sig.generics)?;
let fn_ident = item_fn.sig.ident.clone();
let args: FnArgs = syn::parse(args)?;
if item_fn.sig.inputs.is_empty() {
return Err(syn::Error::new(
item_fn.sig.ident.span(),
"tracked functions must have at least a database argument",
));
}
if let syn::FnArg::Receiver(receiver) = &item_fn.sig.inputs[0] {
return Err(syn::Error::new(
receiver.span(),
"#[salsa::tracked] must also be applied to the impl block for tracked methods",
));
}
if let Some(s) = &args.specify {
if function_type(&item_fn) == FunctionType::RequiresInterning {
return Err(syn::Error::new(
s.span(),
"tracked function takes too many arguments to have its value set with `specify`",
));
}
if args.lru.is_some() {
return Err(syn::Error::new(
s.span(),
"`specify` and `lru` cannot be used together",
));
}
}
let (config_ty, fn_struct) = fn_struct(&args, &item_fn)?;
*item_fn.block = getter_fn(&args, &mut item_fn.sig, item_fn.block.span(), &config_ty)?;
Ok(crate::debug::dump_tokens(
fn_ident,
quote! {
#fn_struct
// we generate a `'db` lifetime that clippy
// sometimes doesn't like
#[allow(clippy::needless_lifetimes)]
#item_fn
},
))
}
type FnArgs = Options<TrackedFn>;
struct TrackedFn;
impl crate::options::AllowedOptions for TrackedFn {
const RETURN_REF: bool = true;
const SPECIFY: bool = true;
const NO_EQ: bool = true;
const SINGLETON: bool = false;
const JAR: bool = true;
const DATA: bool = false;
const DB: bool = false;
const RECOVERY_FN: bool = true;
const LRU: bool = true;
const CONSTRUCTOR_NAME: bool = false;
}
type ImplArgs = Options<TrackedImpl>;
pub(crate) fn tracked_impl(
args: proc_macro::TokenStream,
mut item_impl: syn::ItemImpl,
) -> syn::Result<TokenStream> {
let args: ImplArgs = syn::parse(args)?;
let self_type = match &*item_impl.self_ty {
syn::Type::Path(path) => path,
_ => {
return Err(syn::Error::new(
item_impl.self_ty.span(),
"#[salsa::tracked] can only be applied to salsa structs",
))
}
};
let self_type_name = &self_type.path.segments.last().unwrap().ident;
let name_prefix = match &item_impl.trait_ {
Some((_, trait_name, _)) => format!(
"{}_{}",
self_type_name,
trait_name.segments.last().unwrap().ident
),
None => format!("{}", self_type_name),
};
#[allow(clippy::manual_try_fold)] // we accumulate errors
let extra_impls = item_impl
.items
.iter_mut()
.filter_map(|item| {
let item_method = match item {
syn::ImplItem::Fn(item_method) => item_method,
_ => return None,
};
let salsa_tracked_attr = item_method.attrs.iter().position(|attr| {
let path = &attr.path().segments;
path.len() == 2
&& path[0].arguments.is_none()
&& path[0].ident == "salsa"
&& path[1].arguments.is_none()
&& path[1].ident == "tracked"
})?;
let salsa_tracked_attr = item_method.attrs.remove(salsa_tracked_attr);
let inner_args = match salsa_tracked_attr.meta {
syn::Meta::Path(_) => Ok(FnArgs::default()),
syn::Meta::List(_) | syn::Meta::NameValue(_) => salsa_tracked_attr.parse_args(),
};
let inner_args = match inner_args {
Ok(inner_args) => inner_args,
Err(err) => return Some(Err(err)),
};
let name = format!("{}_{}", name_prefix, item_method.sig.ident);
Some(tracked_method(
&item_impl.generics,
&args,
inner_args,
item_method,
self_type,
&name,
))
})
// Collate all the errors so we can display them all at once
.fold(Ok(Vec::new()), |mut acc, res| {
match (&mut acc, res) {
(Ok(extra_impls), Ok(impls)) => extra_impls.push(impls),
(Ok(_), Err(err)) => acc = Err(err),
(Err(_), Ok(_)) => {}
(Err(errors), Err(err)) => errors.combine(err),
}
acc
})?;
Ok(crate::debug::dump_tokens(
self_type_name,
quote! {
#item_impl
#(#extra_impls)*
},
))
}
struct TrackedImpl;
impl crate::options::AllowedOptions for TrackedImpl {
const RETURN_REF: bool = false;
const SPECIFY: bool = false;
const NO_EQ: bool = false;
const JAR: bool = true;
const DATA: bool = false;
const DB: bool = false;
const RECOVERY_FN: bool = false;
const LRU: bool = false;
const CONSTRUCTOR_NAME: bool = false;
const SINGLETON: bool = false;
}
fn tracked_method(
impl_generics: &syn::Generics,
outer_args: &ImplArgs,
mut args: FnArgs,
item_method: &mut syn::ImplItemFn,
self_type: &syn::TypePath,
name: &str,
) -> syn::Result<TokenStream> {
args.jar_ty = args.jar_ty.or_else(|| outer_args.jar_ty.clone());
if item_method.sig.inputs.len() <= 1 {
return Err(syn::Error::new(
item_method.sig.ident.span(),
"tracked methods must have at least self and a database argument",
));
}
let mut item_fn = syn::ItemFn {
attrs: item_method.attrs.clone(),
vis: item_method.vis.clone(),
sig: item_method.sig.clone(),
block: Box::new(rename_self_in_block(item_method.block.clone())?),
};
item_fn.sig.ident = syn::Ident::new(name, item_fn.sig.ident.span());
// Insert the generics from impl at the start of the fn generics
for parameter in impl_generics.params.iter().rev() {
item_fn.sig.generics.params.insert(0, parameter.clone());
}
// Flip the first and second arguments as the rest of the code expects the
// database to come first and the struct to come second. We also need to
// change the self argument to a normal typed argument called __salsa_self.
let mut original_inputs = item_fn.sig.inputs.into_pairs();
let self_param = match original_inputs.next().unwrap().into_value() {
syn::FnArg::Receiver(r) if r.reference.is_none() => r,
arg => return Err(syn::Error::new(arg.span(), "first argument must be self")),
};
let db_param = original_inputs.next().unwrap().into_value();
let mut inputs = syn::punctuated::Punctuated::new();
inputs.push(db_param);
inputs.push(syn::FnArg::Typed(syn::PatType {
attrs: self_param.attrs,
pat: Box::new(syn::Pat::Ident(syn::PatIdent {
attrs: Vec::new(),
by_ref: None,
mutability: self_param.mutability,
ident: syn::Ident::new("__salsa_self", self_param.self_token.span),
subpat: None,
})),
colon_token: Default::default(),
ty: Box::new(syn::Type::Path(self_type.clone())),
}));
inputs.push_punct(Default::default());
inputs.extend(original_inputs);
item_fn.sig.inputs = inputs;
let (config_ty, fn_struct) = crate::tracked_fn1::fn_struct(&args, &item_fn)?;
// we generate a `'db` lifetime that clippy
// sometimes doesn't like
item_method
.attrs
.push(syn::parse_quote! {#[allow(clippy::needless_lifetimes)]});
item_method.block = getter_fn(
&args,
&mut item_method.sig,
item_method.block.span(),
&config_ty,
)?;
Ok(fn_struct)
}
/// Rename all occurrences of `self` to `__salsa_self` in a block
/// so that it can be used in a free function.
fn rename_self_in_block(mut block: syn::Block) -> syn::Result<syn::Block> {
struct RenameIdent(syn::Result<()>);
impl syn::visit_mut::VisitMut for RenameIdent {
fn visit_ident_mut(&mut self, i: &mut syn::Ident) {
if i == "__salsa_self" {
let err = syn::Error::new(
i.span(),
"Existing variable name clashes with 'self' -> '__salsa_self' renaming",
);
match &mut self.0 {
Ok(()) => self.0 = Err(err),
Err(errors) => errors.combine(err),
}
}
if i == "self" {
*i = syn::Ident::new("__salsa_self", i.span());
}
}
}
let mut rename = RenameIdent(Ok(()));
rename.visit_block_mut(&mut block);
rename.0.map(move |()| block)
}
/// Create the struct representing the function and all of its impls.
///
/// This returns the name of the constructed type and the code defining everything.
fn fn_struct(args: &FnArgs, item_fn: &syn::ItemFn) -> syn::Result<(syn::Type, TokenStream)> {
require_optional_db_lifetime(&item_fn.sig.generics)?;
let db_lt = &db_lifetime(&item_fn.sig.generics);
let struct_item = configuration_struct(item_fn);
let configuration = fn_configuration(args, item_fn);
let struct_item_ident = &struct_item.ident;
let config_ty: syn::Type = parse_quote!(#struct_item_ident);
let configuration_impl = configuration.to_impl(&config_ty);
let interned_configuration_impl = interned_configuration_impl(db_lt, item_fn, &config_ty);
let ingredients_for_impl = ingredients_for_impl(args, item_fn, &config_ty);
let item_impl = setter_impl(args, item_fn, &config_ty)?;
Ok((
config_ty,
quote! {
#struct_item
#configuration_impl
#interned_configuration_impl
#ingredients_for_impl
#item_impl
},
))
}
fn interned_configuration_impl(
db_lt: &syn::Lifetime,
item_fn: &syn::ItemFn,
config_ty: &syn::Type,
) -> syn::ItemImpl {
let arg_tys = item_fn.sig.inputs.iter().skip(1).map(|arg| match arg {
syn::FnArg::Receiver(_) => unreachable!(),
syn::FnArg::Typed(pat_ty) => pat_ty.ty.clone(),
});
let intern_data_ty: syn::Type = parse_quote!(
(#(#arg_tys),*)
);
let intern_data_ty = ChangeLt::elided_to(db_lt).in_type(&intern_data_ty);
let debug_name = crate::literal(&item_fn.sig.ident);
parse_quote!(
impl salsa::interned::Configuration for #config_ty {
const DEBUG_NAME: &'static str = #debug_name;
type Data<#db_lt> = #intern_data_ty;
type Struct<#db_lt> = & #db_lt salsa::interned::ValueStruct<Self>;
unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull<salsa::interned::ValueStruct<Self>>) -> Self::Struct<'db> {
unsafe { ptr.as_ref() }
}
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::interned::ValueStruct<Self> {
s
}
}
)
}
fn configuration_struct(item_fn: &syn::ItemFn) -> syn::ItemStruct {
let fn_name = item_fn.sig.ident.clone();
let visibility = &item_fn.vis;
let intern_map: syn::Type = match function_type(item_fn) {
FunctionType::Constant => {
parse_quote! { salsa::interned::IdentityInterner<Self> }
}
FunctionType::SalsaStruct => {
parse_quote! { salsa::interned::IdentityInterner<Self> }
}
FunctionType::RequiresInterning => {
parse_quote! { salsa::interned::InternedIngredient<Self> }
}
};
parse_quote! {
#[allow(non_camel_case_types)]
#visibility struct #fn_name {
intern_map: #intern_map,
function: salsa::function::FunctionIngredient<Self>,
}
}
}
#[derive(Debug, PartialEq, Eq, Hash)]
enum FunctionType {
Constant,
SalsaStruct,
RequiresInterning,
}
fn function_type(item_fn: &syn::ItemFn) -> FunctionType {
match item_fn.sig.inputs.len() {
0 => unreachable!(
"functions have been checked to have at least a database argument by this point"
),
1 => FunctionType::Constant,
2 => FunctionType::SalsaStruct,
_ => FunctionType::RequiresInterning,
}
}
/// Every tracked fn takes a salsa struct as its second argument.
/// This fn returns the type of that second argument.
fn salsa_struct_ty(item_fn: &syn::ItemFn) -> syn::Type {
if item_fn.sig.inputs.len() == 1 {
return parse_quote! { salsa::salsa_struct::Singleton };
}
match &item_fn.sig.inputs[1] {
syn::FnArg::Receiver(_) => panic!("receiver not expected"),
syn::FnArg::Typed(pat_ty) => (*pat_ty.ty).clone(),
}
}
fn fn_configuration(args: &FnArgs, item_fn: &syn::ItemFn) -> Configuration {
let jar_ty = args.jar_ty();
let db_lt = db_lifetime(&item_fn.sig.generics);
let salsa_struct_ty = salsa_struct_ty(item_fn);
let key_ty = match function_type(item_fn) {
FunctionType::Constant => parse_quote!(()),
FunctionType::SalsaStruct => salsa_struct_ty.clone(),
FunctionType::RequiresInterning => parse_quote!(salsa::id::Id),
};
let key_ty = ChangeLt::elided_to(&db_lt).in_type(&key_ty);
let value_ty = configuration::value_ty(&db_lt, &item_fn.sig);
let fn_ty = item_fn.sig.ident.clone();
// During recovery or execution, we are invoked with a `salsa::Id`
// that represents the interned value. We convert it back to a key
// which is either a single value (if there is one argument)
// or a tuple of values (if multiple arugments). `key_var` is the variable
// name that will store this result, and `key_splat` is a set of tokens
// that will convert it into one or multiple arguments (e.g., `key_var` if there
// is one argument or `key_var.0, key_var.1` if 2) that can be pasted into a function call.
let key_var = syn::Ident::new("__key", item_fn.span());
let key_fields = item_fn.sig.inputs.len() - 1;
let key_splat = if key_fields == 1 {
quote!(#key_var)
} else {
let indices = (0..key_fields)
.map(Literal::usize_unsuffixed)
.collect::<Vec<_>>();
quote!(#(__key.#indices),*)
};
let (cycle_strategy, recover_fn) = if let Some(recovery_fn) = &args.recovery_fn {
// Create the `recover_from_cycle` function, which (a) maps from the interned id to the actual
// keys and then (b) invokes the recover function itself.
let cycle_strategy = CycleRecoveryStrategy::Fallback;
let cycle_fullback = parse_quote! {
fn recover_from_cycle<'db>(__db: &'db salsa::function::DynDb<Self>, __cycle: &salsa::Cycle, __id: salsa::Id) -> Self::Value<'db> {
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients =
<_ as salsa::storage::HasIngredientsFor<#fn_ty>>::ingredient(__jar);
let #key_var = __ingredients.intern_map.data_with_db(__id, __db).clone();
#recovery_fn(__db, __cycle, #key_splat)
}
};
(cycle_strategy, cycle_fullback)
} else {
// When the `recovery_fn` attribute is not set, set `cycle_strategy` to `Panic`
let cycle_strategy = CycleRecoveryStrategy::Panic;
let cycle_panic = configuration::panic_cycle_recovery_fn();
(cycle_strategy, cycle_panic)
};
let backdate_fn = configuration::should_backdate_value_fn(args.should_backdate());
// The type of the configuration struct; this has the same name as the fn itself.
// Make a copy of the fn with a different name; we will invoke this from `execute`.
// We need to change the name because, otherwise, if the function invoked itself
// recursively it would not go through the query system.
let inner_fn_name = &syn::Ident::new("__fn", item_fn.sig.ident.span());
let mut inner_fn = item_fn.clone();
inner_fn.sig.ident = inner_fn_name.clone();
// Create the `execute` function, which (a) maps from the interned id to the actual
// keys and then (b) invokes the function itself (which we embed within).
let execute_fn = parse_quote! {
fn execute<'db>(__db: &'db salsa::function::DynDb<Self>, __id: salsa::Id) -> Self::Value<'db> {
#inner_fn
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients =
<_ as salsa::storage::HasIngredientsFor<#fn_ty>>::ingredient(__jar);
let #key_var = __ingredients.intern_map.data_with_db(__id, __db).clone();
#inner_fn_name(__db, #key_splat)
}
};
// get the name of the function as a string literal
let debug_name = crate::literal(&item_fn.sig.ident);
Configuration {
debug_name,
db_lt,
jar_ty,
salsa_struct_ty,
input_ty: key_ty,
value_ty,
cycle_strategy,
backdate_fn,
execute_fn,
recover_fn,
}
}
fn ingredients_for_impl(
args: &FnArgs,
item_fn: &syn::ItemFn,
config_ty: &syn::Type,
) -> syn::ItemImpl {
let jar_ty = args.jar_ty();
let intern_map: syn::Expr = match function_type(item_fn) {
FunctionType::Constant | FunctionType::SalsaStruct => {
parse_quote! {
salsa::interned::IdentityInterner::new()
}
}
FunctionType::RequiresInterning => {
parse_quote! {
{
let index = routes.push(
|jars| {
let jar = <DB as salsa::storage::JarFromJars<Self::Jar>>::jar_from_jars(jars);
let ingredients =
<_ as salsa::storage::HasIngredientsFor<Self::Ingredients>>::ingredient(jar);
&ingredients.intern_map
},
|jars| {
let jar = <DB as salsa::storage::JarFromJars<Self::Jar>>::jar_from_jars_mut(jars);
let ingredients =
<_ as salsa::storage::HasIngredientsFor<Self::Ingredients>>::ingredient_mut(jar);
&mut ingredients.intern_map
}
);
salsa::interned::InternedIngredient::new(index)
}
}
}
};
// set 0 as default to disable LRU
let lru = args.lru.unwrap_or(0);
parse_quote! {
impl salsa::storage::IngredientsFor for #config_ty {
type Ingredients = Self;
type Jar = #jar_ty;
fn create_ingredients<DB>(routes: &mut salsa::routes::Routes<DB>) -> Self::Ingredients
where
DB: salsa::DbWithJar<Self::Jar> + salsa::storage::JarFromJars<Self::Jar>,
{
Self {
intern_map: #intern_map,
function: {
let index = routes.push(
|jars| {
let jar = <DB as salsa::storage::JarFromJars<Self::Jar>>::jar_from_jars(jars);
let ingredients =
<_ as salsa::storage::HasIngredientsFor<Self::Ingredients>>::ingredient(jar);
&ingredients.function
},
|jars| {
let jar = <DB as salsa::storage::JarFromJars<Self::Jar>>::jar_from_jars_mut(jars);
let ingredients =
<_ as salsa::storage::HasIngredientsFor<Self::Ingredients>>::ingredient_mut(jar);
&mut ingredients.function
});
let ingredient = salsa::function::FunctionIngredient::new(index);
ingredient.set_capacity(#lru);
ingredient
}
}
}
}
}
}
fn setter_impl(
args: &FnArgs,
item_fn: &syn::ItemFn,
config_ty: &syn::Type,
) -> syn::Result<syn::ItemImpl> {
let ref_getter_fn = ref_getter_fn(args, item_fn, config_ty)?;
let accumulated_fn = accumulated_fn(args, item_fn, config_ty)?;
let specify_fn = specify_fn(args, item_fn, config_ty)?.map(|f| quote! { #f });
let set_lru_fn = set_lru_capacity_fn(args, config_ty)?.map(|f| quote! { #f });
let setter_impl: syn::ItemImpl = parse_quote! {
impl #config_ty {
#[allow(dead_code, clippy::needless_lifetimes)]
#ref_getter_fn
#[allow(dead_code, clippy::needless_lifetimes)]
#accumulated_fn
#set_lru_fn
#specify_fn
}
};
Ok(setter_impl)
}
/// Creates the shim function that looks like the original function but calls
/// into the machinery we've just generated rather than executing the code.
fn getter_fn(
args: &FnArgs,
fn_sig: &mut syn::Signature,
block_span: proc_macro2::Span,
config_ty: &syn::Type,
) -> syn::Result<syn::Block> {
let mut is_method = false;
let mut arg_idents: Vec<_> = fn_sig
.inputs
.iter()
.map(|arg| -> syn::Result<syn::Ident> {
match arg {
syn::FnArg::Receiver(receiver) => {
is_method = true;
Ok(syn::Ident::new("self", receiver.self_token.span()))
}
syn::FnArg::Typed(pat_ty) => Ok(match &*pat_ty.pat {
syn::Pat::Ident(ident) => ident.ident.clone(),
_ => return Err(syn::Error::new(arg.span(), "unsupported argument kind")),
}),
}
})
.collect::<Result<_, _>>()?;
// If this is a method then the order of the database and the salsa struct are reversed
// because the self argument must always come first.
if is_method {
arg_idents.swap(0, 1);
}
Ok(if args.return_ref.is_some() {
make_fn_return_ref(fn_sig)?;
parse_quote_spanned! {
block_span => {
#config_ty::get(#(#arg_idents,)*)
}
}
} else {
parse_quote_spanned! {
block_span => {
Clone::clone(#config_ty::get(#(#arg_idents,)*))
}
}
})
}
/// Creates a `get` associated function that returns `&Value`
/// (to be used when `return_ref` is specified).
///
/// (Helper for `getter_fn`)
fn ref_getter_fn(
args: &FnArgs,
item_fn: &syn::ItemFn,
config_ty: &syn::Type,
) -> syn::Result<syn::ItemFn> {
let jar_ty = args.jar_ty();
let mut ref_getter_fn = item_fn.clone();
ref_getter_fn.sig.ident = syn::Ident::new("get", item_fn.sig.ident.span());
make_fn_return_ref(&mut ref_getter_fn.sig)?;
let (db_var, arg_names) = fn_args(item_fn)?;
ref_getter_fn.block = parse_quote! {
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var);
let __ingredients = <_ as salsa::storage::HasIngredientsFor<#config_ty>>::ingredient(__jar);
let __key = __ingredients.intern_map.intern_id(__runtime, (#(#arg_names),*));
__ingredients.function.fetch(#db_var, __key)
}
};
Ok(ref_getter_fn)
}
/// Create a `set_lru_capacity` associated function that can be used to change LRU
/// capacity at runtime.
/// Note that this function is only generated if the tracked function has the lru option set.
///
/// # Examples
///
/// ```rust,ignore
/// #[salsa::tracked(lru=32)]
/// fn my_tracked_fn(db: &dyn crate::Db, ...) { }
///
/// my_tracked_fn::set_lru_capacity(16)
/// ```
fn set_lru_capacity_fn(
args: &FnArgs,
config_ty: &syn::Type,
) -> syn::Result<Option<syn::ImplItemFn>> {
if args.lru.is_none() {
return Ok(None);
}
let jar_ty = args.jar_ty();
let lru_fn = parse_quote! {
#[allow(dead_code, clippy::needless_lifetimes)]
fn set_lru_capacity(__db: &salsa::function::DynDb<Self>, __value: usize) {
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients =
<_ as salsa::storage::HasIngredientsFor<#config_ty>>::ingredient(__jar);
__ingredients.function.set_capacity(__value);
}
};
Ok(Some(lru_fn))
}
fn specify_fn(
args: &FnArgs,
item_fn: &syn::ItemFn,
config_ty: &syn::Type,
) -> syn::Result<Option<syn::ImplItemFn>> {
if args.specify.is_none() {
return Ok(None);
}
let mut setter_sig = item_fn.sig.clone();
require_optional_db_lifetime(&item_fn.sig.generics)?;
let db_lt = &db_lifetime(&item_fn.sig.generics);
match setter_sig.generics.lifetimes().count() {
0 => setter_sig.generics.params.push(parse_quote!(#db_lt)),
1 => (),
_ => panic!("unreachable -- would have generated an error earlier"),
};
// `specify` has the same signature as the original,
// but it takes a value arg and has no return type.
let jar_ty = args.jar_ty();
let (db_var, arg_names) = fn_args(item_fn)?;
let value_ty = configuration::value_ty(db_lt, &item_fn.sig);
setter_sig.ident = syn::Ident::new("specify", item_fn.sig.ident.span());
let value_arg = syn::Ident::new("__value", item_fn.sig.output.span());
setter_sig.inputs.push(parse_quote!(#value_arg: #value_ty));
setter_sig.output = ReturnType::Default;
Ok(Some(syn::ImplItemFn {
attrs: vec![],
vis: item_fn.vis.clone(),
defaultness: None,
sig: setter_sig,
block: parse_quote! {
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var);
let __ingredients = <_ as salsa::storage::HasIngredientsFor<#config_ty>>::ingredient(__jar);
let __key = __ingredients.intern_map.intern_id(__runtime, (#(#arg_names),*));
__ingredients.function.specify_and_record(#db_var, __key, #value_arg)
}
},
}))
}
/// Given a function def tagged with `#[return_ref]`, modifies `fn_sig` so that
/// it returns an `&Value` instead of `Value`. May introduce a name for the
/// database lifetime if required.
fn make_fn_return_ref(fn_sig: &mut syn::Signature) -> syn::Result<()> {
// An input should be a `&dyn Db`.
// We need to ensure it has a named lifetime parameter.
let (db_lifetime, _) = db_lifetime_and_ty(fn_sig)?;
let (right_arrow, elem) = match fn_sig.output.clone() {
ReturnType::Default => (
syn::Token![->]([Span::call_site(), Span::call_site()]),
parse_quote!(()),
),
ReturnType::Type(rarrow, ty) => (rarrow, ty),
};
let ref_output = syn::TypeReference {
and_token: syn::Token![&](right_arrow.span()),
lifetime: Some(db_lifetime),
mutability: None,
elem,
};
fn_sig.output = syn::ReturnType::Type(right_arrow, Box::new(ref_output.into()));
Ok(())
}
/// Given a function signature, identifies the name given to the `&dyn Db` reference
/// and returns it, along with the type of the database.
/// If the database lifetime did not have a name, then modifies the item function
/// so that it is called `'__db` and returns that.
fn db_lifetime_and_ty(func: &mut syn::Signature) -> syn::Result<(syn::Lifetime, &syn::Type)> {
// If this is a method, then the database should be the second argument.
let db_loc = if matches!(func.inputs[0], syn::FnArg::Receiver(_)) {
1
} else {
0
};
match &mut func.inputs[db_loc] {
syn::FnArg::Receiver(r) => Err(syn::Error::new(r.span(), "two self arguments")),
syn::FnArg::Typed(pat_ty) => match &mut *pat_ty.ty {
syn::Type::Reference(ty) => match &ty.lifetime {
Some(lt) => Ok((lt.clone(), &pat_ty.ty)),
None => {
let and_token_span = ty.and_token.span();
let ident = syn::Ident::new("__db", and_token_span);
func.generics.params.insert(
0,
syn::LifetimeParam {
attrs: vec![],
lifetime: syn::Lifetime {
apostrophe: and_token_span,
ident: ident.clone(),
},
colon_token: None,
bounds: Default::default(),
}
.into(),
);
let db_lifetime = syn::Lifetime {
apostrophe: and_token_span,
ident,
};
ty.lifetime = Some(db_lifetime.clone());
Ok((db_lifetime, &pat_ty.ty))
}
},
_ => Err(syn::Error::new(
pat_ty.span(),
"expected database to be a `&` type",
)),
},
}
}
/// Generates the `accumulated` function, which invokes `accumulated`
/// on the function ingredient to extract the values pushed (transitively)
/// into an accumulator.
fn accumulated_fn(
args: &FnArgs,
item_fn: &syn::ItemFn,
config_ty: &syn::Type,
) -> syn::Result<syn::ItemFn> {
let jar_ty = args.jar_ty();
let mut accumulated_fn = item_fn.clone();
accumulated_fn.sig.ident = syn::Ident::new("accumulated", item_fn.sig.ident.span());
accumulated_fn.sig.generics.params.push(parse_quote! {
__A: salsa::accumulator::Accumulator
});
accumulated_fn.sig.output = parse_quote! {
-> Vec<<__A as salsa::accumulator::Accumulator>::Data>
};
let predicate: syn::WherePredicate = parse_quote!(<#jar_ty as salsa::jar::Jar>::DynDb: salsa::storage::HasJar<<__A as salsa::accumulator::Accumulator>::Jar>);
if let Some(where_clause) = &mut accumulated_fn.sig.generics.where_clause {
where_clause.predicates.push(predicate);
} else {
accumulated_fn.sig.generics.where_clause = parse_quote!(where #predicate);
}
let (db_var, arg_names) = fn_args(item_fn)?;
accumulated_fn.block = parse_quote! {
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var);
let __ingredients = <_ as salsa::storage::HasIngredientsFor<#config_ty>>::ingredient(__jar);
let __key = __ingredients.intern_map.intern_id(__runtime, (#(#arg_names),*));
__ingredients.function.accumulated::<__A>(#db_var, __key)
}
};
Ok(accumulated_fn)
}
/// Examines the function arguments and returns a tuple of:
///
/// * the name of the database argument
/// * the name(s) of the key arguments
fn fn_args(item_fn: &syn::ItemFn) -> syn::Result<(proc_macro2::Ident, Vec<proc_macro2::Ident>)> {
// Check that we have no receiver and that all arguments have names
if item_fn.sig.inputs.is_empty() {
return Err(syn::Error::new(
item_fn.sig.span(),
"method needs a database argument",
));
}
let mut input_names = vec![];
for input in &item_fn.sig.inputs {
match input {
syn::FnArg::Receiver(r) => {
return Err(syn::Error::new(r.span(), "no self argument expected"));
}
syn::FnArg::Typed(pat_ty) => match &*pat_ty.pat {
syn::Pat::Ident(ident) => {
input_names.push(ident.ident.clone());
}
_ => {
return Err(syn::Error::new(
pat_ty.pat.span(),
"all arguments must be given names",
));
}
},
}
}
// Database is the first argument
let db_var = input_names[0].clone();
let arg_names = input_names[1..].to_owned();
Ok((db_var, arg_names))
}

View file

@ -29,34 +29,33 @@ impl HasLogger for Database {
}
}
// #[salsa::tracked]
// fn final_result(db: &dyn Db, input: i32, (b, c): (i32, i32)) -> i32 {
// db.push_log(format!("final_result({:?})", input));
// input
// }
#[salsa::tracked]
fn final_result(dbx: &dyn Db, a: i32, (b, c): (i32, i32)) -> i32 {
dbx.push_log(format!("final_result({a}, {b}, {c})"));
a + b * c
}
salsa::plumbing::setup_interned_fn!(
vis: ,
fn_name: identity,
db_lt: 'db,
Db: Db,
db: dbx,
input_ids: [input1, input2],
input_pats: [a, (b, c)],
input_tys: [i32, (i32, i32)],
output_ty: i32,
body: {
dbx.push_log(format!("final_result({a}, {b}, {c})"));
a + b * c
},
cycle_recovery_fn: (salsa::plumbing::unexpected_cycle_recovery!),
cycle_recovery_strategy: Panic,
unused_names: [
zalsa1,
Configuration1,
InternedData1,
FN_CACHE1,
INTERN_CACHE1,
inner1,
]
);
// salsa::plumbing::setup_interned_fn!(
// vis: ,
// fn_name: identity,
// db_lt: 'db,
// Db: Db,
// db: dbx,
// input_ids: [input1, input2],
// input_tys: [i32, (i32, i32)],
// output_ty: i32,
// inner_fn: fn inner1(dbx: &dyn Db, a: i32, (b, c): (i32, i32)) -> i32 {
// dbx.push_log(format!("final_result({a}, {b}, {c})"));
// a + b * c
// },
// cycle_recovery_fn: (salsa::plumbing::unexpected_cycle_recovery!),
// cycle_recovery_strategy: Panic,
// unused_names: [
// zalsa1,
// Configuration1,
// InternedData1,
// FN_CACHE1,
// INTERN_CACHE1,
// inner1,
// ]
// );