From 4f234cfbb9ae42e4e54aa9e1a3cc568e688cf373 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Fri, 5 Aug 2022 02:46:16 -0400 Subject: [PATCH] remove component and replace with `specify` option You can now do `#[salsa::tracked(specify)]` and you will get a method `some_fn::specify(...)` that can be used to specify the value. --- .../salsa-entity-macros/src/accumulator.rs | 2 + .../salsa-entity-macros/src/component.rs | 411 ------------------ components/salsa-entity-macros/src/jar.rs | 2 + components/salsa-entity-macros/src/lib.rs | 6 - components/salsa-entity-macros/src/options.rs | 22 + .../salsa-entity-macros/src/salsa_struct.rs | 2 + .../salsa-entity-macros/src/tracked_fn.rs | 177 ++++++-- components/salsa-entity-mock/src/interned.rs | 19 + components/salsa-entity-mock/src/lib.rs | 1 - salsa-2022-tests/main.rs | 1 + salsa-2022-tests/tracked_fn_on_tracked.rs | 2 +- .../tracked_fn_on_tracked_specify.rs | 64 +++ 12 files changed, 248 insertions(+), 461 deletions(-) delete mode 100644 components/salsa-entity-macros/src/component.rs create mode 100644 salsa-2022-tests/tracked_fn_on_tracked_specify.rs diff --git a/components/salsa-entity-macros/src/accumulator.rs b/components/salsa-entity-macros/src/accumulator.rs index 695b352b..ed870fe8 100644 --- a/components/salsa-entity-macros/src/accumulator.rs +++ b/components/salsa-entity-macros/src/accumulator.rs @@ -21,6 +21,8 @@ struct Accumulator; impl crate::options::AllowedOptions for Accumulator { const RETURN_REF: bool = false; + const SPECIFY: bool = false; + const NO_EQ: bool = false; const JAR: bool = true; diff --git a/components/salsa-entity-macros/src/component.rs b/components/salsa-entity-macros/src/component.rs deleted file mode 100644 index 90161a47..00000000 --- a/components/salsa-entity-macros/src/component.rs +++ /dev/null @@ -1,411 +0,0 @@ -use syn::spanned::Spanned; -use syn::{ItemFn, ReturnType}; - -use crate::configuration::{self, Configuration, CycleRecoveryStrategy}; -use crate::options::Options; - -// #[salsa::component(in Jar0)] -// fn my_func(db: &dyn Jar0Db, input1: u32, input2: u32) -> String { -// format!("Hello, world") -// } - -pub(crate) fn component( - args: proc_macro::TokenStream, - input: proc_macro::TokenStream, -) -> proc_macro::TokenStream { - let args = syn::parse_macro_input!(args as Args); - let item_fn = syn::parse_macro_input!(input as ItemFn); - match component_helper(args, item_fn) { - Ok(v) => v, - Err(e) => return e.into_compile_error().into(), - } -} - -fn component_helper(args: Args, item_fn: ItemFn) -> syn::Result { - let struct_item = configuration_struct(&item_fn); - let configuration = fn_configuration(&args, &item_fn)?; - let struct_item_ident = &struct_item.ident; - let struct_ty: syn::Type = parse_quote!(#struct_item_ident); - let configuration_impl = configuration.to_impl(&struct_ty); - let ingredients_for_impl = ingredients_for_impl(&args, &struct_ty); - let (getter, item_impl) = wrapper_fns(&args, &item_fn, &struct_ty)?; - - Ok(proc_macro::TokenStream::from(quote! { - #struct_item - #configuration_impl - #ingredients_for_impl - #getter - #item_impl - })) -} - -struct Component; - -type Args = Options; - -impl crate::options::AllowedOptions for Component { - const RETURN_REF: bool = true; - - const NO_EQ: bool = true; - - const JAR: bool = true; - - const DATA: bool = false; - - const DB: bool = false; -} - -fn configuration_struct(item_fn: &syn::ItemFn) -> syn::ItemStruct { - let fn_name = item_fn.sig.ident.clone(); - let vis = &item_fn.vis; - parse_quote! { - #[allow(non_camel_case_types)] - #vis struct #fn_name { - function: salsa::function::FunctionIngredient, - } - } -} - -fn fn_configuration(args: &Args, item_fn: &syn::ItemFn) -> syn::Result { - let jar_ty = args.jar_ty(); - let key_ty = arg_ty(item_fn)?.clone(); - let value_ty = configuration::value_ty(&item_fn.sig); - - // FIXME: these are hardcoded for now - let cycle_strategy = CycleRecoveryStrategy::Panic; - - let backdate_fn = configuration::should_backdate_value_fn(args.should_backdate()); - let recover_fn = configuration::panic_cycle_recovery_fn(); - - // The type of the configuration struct; this has the same name as the fn itself. - let fn_ty = item_fn.sig.ident.clone(); - - // Make a copy of the fn with a different name; we will invoke this from `execute`. - // We need to change the name because, otherwise, if the function invoked itself - // recursively it would not go through the query system. - let inner_fn_name = &syn::Ident::new("__fn", item_fn.sig.ident.span()); - let mut inner_fn = item_fn.clone(); - inner_fn.sig.ident = inner_fn_name.clone(); - - // Create the `execute` function, which invokes the function itself (which we embed within). - let execute_fn = parse_quote! { - fn execute(__db: &salsa::function::DynDb, __id: Self::Key) -> Self::Value { - #inner_fn - - let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); - let __ingredients = - <_ as salsa::storage::HasIngredientsFor<#fn_ty>>::ingredient(__jar); - #inner_fn_name(__db, __id) - } - }; - - Ok(Configuration { - jar_ty, - key_ty, - value_ty, - cycle_strategy, - backdate_fn, - execute_fn, - recover_fn, - }) -} - -fn ingredients_for_impl(args: &Args, struct_ty: &syn::Type) -> syn::ItemImpl { - let jar_ty = args.jar_ty(); - parse_quote! { - impl salsa::storage::IngredientsFor for #struct_ty { - type Ingredients = Self; - type Jar = #jar_ty; - - fn create_ingredients(ingredients: &mut salsa::routes::Ingredients) -> Self::Ingredients - where - DB: salsa::DbWithJar + salsa::storage::JarFromJars, - { - Self { - function: { - let index = ingredients.push(|jars| { - let jar = >::jar_from_jars(jars); - let ingredients = - <_ as salsa::storage::HasIngredientsFor>::ingredient(jar); - &ingredients.function - }); - salsa::function::FunctionIngredient::new(index) - }, - } - } - } - } -} - -fn wrapper_fns( - args: &Args, - item_fn: &syn::ItemFn, - struct_ty: &syn::Type, -) -> syn::Result<(syn::ItemFn, syn::ItemImpl)> { - // The "getter" has same signature as the original: - let getter_fn = getter_fn(args, item_fn, struct_ty)?; - - let ref_getter_fn = ref_getter_fn(args, item_fn, struct_ty)?; - let accumulated_fn = accumulated_fn(args, item_fn, struct_ty)?; - let setter_fn = setter_fn(args, item_fn, struct_ty)?; - - let item_impl: syn::ItemImpl = parse_quote! { - impl #struct_ty { - #ref_getter_fn - #setter_fn - #accumulated_fn - } - }; - - Ok((getter_fn, item_impl)) -} - -fn getter_fn( - args: &Args, - item_fn: &syn::ItemFn, - struct_ty: &syn::Type, -) -> syn::Result { - let mut getter_fn = item_fn.clone(); - let arg_idents: Vec<_> = item_fn - .sig - .inputs - .iter() - .map(|arg| -> syn::Result { - match arg { - syn::FnArg::Receiver(_) => Err(syn::Error::new(arg.span(), "unexpected receiver")), - syn::FnArg::Typed(pat_ty) => Ok(match &*pat_ty.pat { - syn::Pat::Ident(ident) => ident.ident.clone(), - _ => return Err(syn::Error::new(arg.span(), "unexpected receiver")), - }), - } - }) - .collect::>()?; - if args.return_ref.is_some() { - getter_fn = make_fn_return_ref(getter_fn)?; - getter_fn.block = Box::new(parse_quote_spanned! { - item_fn.block.span() => { - #struct_ty::get(#(#arg_idents,)*) - } - }); - } else { - getter_fn.block = Box::new(parse_quote_spanned! { - item_fn.block.span() => { - Clone::clone(#struct_ty::get(#(#arg_idents,)*)) - } - }); - } - Ok(getter_fn) -} - -fn ref_getter_fn( - args: &Args, - item_fn: &syn::ItemFn, - struct_ty: &syn::Type, -) -> syn::Result { - let jar_ty = args.jar_ty(); - let mut ref_getter_fn = item_fn.clone(); - ref_getter_fn.sig.ident = syn::Ident::new("get", item_fn.sig.ident.span()); - ref_getter_fn = make_fn_return_ref(ref_getter_fn)?; - - let (db_var, arg_names) = fn_args(item_fn)?; - ref_getter_fn.block = parse_quote! { - { - let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var); - let __ingredients = <_ as salsa::storage::HasIngredientsFor<#struct_ty>>::ingredient(__jar); - __ingredients.function.fetch(#db_var, #(#arg_names),*) - } - }; - - Ok(ref_getter_fn) -} - -fn setter_fn( - args: &Args, - item_fn: &syn::ItemFn, - struct_ty: &syn::Type, -) -> syn::Result { - // The setter has *always* the same signature as the original: - // but it takes a value arg and has no return type. - let jar_ty = args.jar_ty(); - let (db_var, arg_names) = fn_args(item_fn)?; - let mut setter_sig = item_fn.sig.clone(); - let value_ty = configuration::value_ty(&item_fn.sig); - setter_sig.ident = syn::Ident::new("set", item_fn.sig.ident.span()); - let value_arg = syn::Ident::new("__value", item_fn.sig.output.span()); - setter_sig.inputs.push(parse_quote!(#value_arg: #value_ty)); - setter_sig.output = ReturnType::Default; - Ok(syn::ImplItemMethod { - attrs: vec![], - vis: item_fn.vis.clone(), - defaultness: None, - sig: setter_sig, - block: parse_quote! { - { - let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var); - let __ingredients = <_ as salsa::storage::HasIngredientsFor<#struct_ty>>::ingredient(__jar); - __ingredients.function.set(#db_var, #(#arg_names),*, #value_arg) - } - }, - }) -} - -fn make_fn_return_ref(mut ref_getter_fn: syn::ItemFn) -> syn::Result { - // The 0th input should be a `&dyn Foo`. We need to ensure - // it has a named lifetime parameter. - let (db_lifetime, _) = db_lifetime_and_ty(&mut ref_getter_fn)?; - - let (right_arrow, elem) = match ref_getter_fn.sig.output { - ReturnType::Default => ( - syn::Token![->](ref_getter_fn.sig.paren_token.span), - parse_quote!(()), - ), - ReturnType::Type(rarrow, ty) => (rarrow, ty), - }; - - let ref_output = syn::TypeReference { - and_token: syn::Token![&](right_arrow.span()), - lifetime: Some(db_lifetime), - mutability: None, - elem, - }; - - ref_getter_fn.sig.output = syn::ReturnType::Type(right_arrow, Box::new(ref_output.into())); - - Ok(ref_getter_fn) -} - -fn db_lifetime_and_ty(func: &mut syn::ItemFn) -> syn::Result<(syn::Lifetime, &syn::Type)> { - match &mut func.sig.inputs[0] { - syn::FnArg::Receiver(r) => { - return Err(syn::Error::new(r.span(), "expected database, not self")) - } - syn::FnArg::Typed(pat_ty) => match &mut *pat_ty.ty { - syn::Type::Reference(ty) => match &ty.lifetime { - Some(lt) => Ok((lt.clone(), &pat_ty.ty)), - None => { - let and_token_span = ty.and_token.span(); - let ident = syn::Ident::new("__db", and_token_span); - func.sig.generics.params.insert( - 0, - syn::LifetimeDef { - attrs: vec![], - lifetime: syn::Lifetime { - apostrophe: and_token_span, - ident: ident.clone(), - }, - colon_token: None, - bounds: Default::default(), - } - .into(), - ); - let db_lifetime = syn::Lifetime { - apostrophe: and_token_span, - ident, - }; - ty.lifetime = Some(db_lifetime.clone()); - Ok((db_lifetime, &pat_ty.ty)) - } - }, - _ => { - return Err(syn::Error::new( - pat_ty.span(), - "expected database to be a `&` type", - )) - } - }, - } -} - -fn accumulated_fn( - args: &Args, - item_fn: &syn::ItemFn, - struct_ty: &syn::Type, -) -> syn::Result { - let jar_ty = args.jar_ty(); - - let mut accumulated_fn = item_fn.clone(); - accumulated_fn.sig.ident = syn::Ident::new("accumulated", item_fn.sig.ident.span()); - accumulated_fn.sig.generics.params.push(parse_quote! { - __A: salsa::accumulator::Accumulator - }); - accumulated_fn.sig.output = parse_quote! { - -> Vec<<__A as salsa::accumulator::Accumulator>::Data> - }; - - let (db_lifetime, _) = db_lifetime_and_ty(&mut accumulated_fn)?; - let predicate: syn::WherePredicate = parse_quote!(<#jar_ty as salsa::jar::Jar<#db_lifetime>>::DynDb: salsa::storage::HasJar<<__A as salsa::accumulator::Accumulator>::Jar>); - - if let Some(where_clause) = &mut accumulated_fn.sig.generics.where_clause { - where_clause.predicates.push(predicate); - } else { - accumulated_fn.sig.generics.where_clause = parse_quote!(where #predicate); - } - - let (db_var, arg_names) = fn_args(item_fn)?; - accumulated_fn.block = parse_quote! { - { - let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var); - let __ingredients = <_ as salsa::storage::HasIngredientsFor<#struct_ty>>::ingredient(__jar); - __ingredients.function.accumulated::<__A>(#db_var, #(#arg_names),*) - } - }; - - Ok(accumulated_fn) -} - -fn fn_args(item_fn: &syn::ItemFn) -> syn::Result<(proc_macro2::Ident, Vec)> { - // Check that we have no receiver and that all argments have names - if item_fn.sig.inputs.len() == 0 { - return Err(syn::Error::new( - item_fn.sig.span(), - "method needs a database argument", - )); - } - - let mut input_names = vec![]; - for input in &item_fn.sig.inputs { - match input { - syn::FnArg::Receiver(r) => { - return Err(syn::Error::new(r.span(), "no self argument expected")); - } - syn::FnArg::Typed(pat_ty) => match &*pat_ty.pat { - syn::Pat::Ident(ident) => { - input_names.push(ident.ident.clone()); - } - - _ => { - return Err(syn::Error::new( - pat_ty.pat.span(), - "all arguments must be given names", - )); - } - }, - } - } - - // Database is the first argument - let db_var = input_names[0].clone(); - let arg_names = input_names[1..].to_owned(); - - Ok((db_var, arg_names)) -} - -fn arg_ty(item_fn: &syn::ItemFn) -> syn::Result<&syn::Type> { - // Check that we have no receiver and that all argments have names - if item_fn.sig.inputs.len() != 2 { - return Err(syn::Error::new( - item_fn.sig.span(), - "component method needs a database argument and an entity", - )); - } - - match &item_fn.sig.inputs[1] { - syn::FnArg::Typed(pat_ty) => Ok(&pat_ty.ty), - _ => { - return Err(syn::Error::new( - item_fn.sig.inputs[1].span(), - "expected a fn parameter with a type", - )); - } - } -} diff --git a/components/salsa-entity-macros/src/jar.rs b/components/salsa-entity-macros/src/jar.rs index 358d38b8..6b474ce7 100644 --- a/components/salsa-entity-macros/src/jar.rs +++ b/components/salsa-entity-macros/src/jar.rs @@ -30,6 +30,8 @@ struct Jar; impl crate::options::AllowedOptions for Jar { const RETURN_REF: bool = false; + const SPECIFY: bool = false; + const NO_EQ: bool = false; const JAR: bool = false; diff --git a/components/salsa-entity-macros/src/lib.rs b/components/salsa-entity-macros/src/lib.rs index 774f92af..13263b3b 100644 --- a/components/salsa-entity-macros/src/lib.rs +++ b/components/salsa-entity-macros/src/lib.rs @@ -26,7 +26,6 @@ macro_rules! parse_quote_spanned { } mod accumulator; -mod component; mod configuration; mod db; mod input; @@ -63,11 +62,6 @@ pub fn input(args: TokenStream, input: TokenStream) -> TokenStream { input::input(args, input) } -#[proc_macro_attribute] -pub fn component(args: TokenStream, input: TokenStream) -> TokenStream { - component::component(args, input) -} - #[proc_macro_attribute] pub fn tracked(args: TokenStream, input: TokenStream) -> TokenStream { tracked::tracked(args, input) diff --git a/components/salsa-entity-macros/src/options.rs b/components/salsa-entity-macros/src/options.rs index f9bc7e41..e6f080f4 100644 --- a/components/salsa-entity-macros/src/options.rs +++ b/components/salsa-entity-macros/src/options.rs @@ -19,6 +19,12 @@ pub(crate) struct Options { /// If this is `Some`, the value is the `no_eq` identifier. pub no_eq: Option, + /// The `specify` option is used to signal that a tracked function can + /// have its value externally specified (at least some of the time). + /// + /// If this is `Some`, the value is the `specify` identifier. + pub specify: Option, + /// The `jar = ` option is used to indicate the jar; it defaults to `crate::jar`. /// /// If this is `Some`, the value is the ``. @@ -43,6 +49,7 @@ impl Default for Options { fn default() -> Self { Self { return_ref: Default::default(), + specify: Default::default(), no_eq: Default::default(), jar_ty: Default::default(), db_path: Default::default(), @@ -55,6 +62,7 @@ impl Default for Options { /// These flags determine which options are allowed in a given context pub(crate) trait AllowedOptions { const RETURN_REF: bool; + const SPECIFY: bool; const NO_EQ: bool; const JAR: bool; const DATA: bool; @@ -111,6 +119,20 @@ impl syn::parse::Parse for Options { "`no_eq` option not allowed here", )); } + } else if ident == "specify" { + if A::SPECIFY { + if let Some(old) = std::mem::replace(&mut options.specify, Some(ident)) { + return Err(syn::Error::new( + old.span(), + "option `specify` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`specify` option not allowed here", + )); + } } else if ident == "jar" { if A::JAR { let _eq = Equals::parse(input)?; diff --git a/components/salsa-entity-macros/src/salsa_struct.rs b/components/salsa-entity-macros/src/salsa_struct.rs index 082538c7..d88ddf23 100644 --- a/components/salsa-entity-macros/src/salsa_struct.rs +++ b/components/salsa-entity-macros/src/salsa_struct.rs @@ -38,6 +38,8 @@ pub(crate) struct SalsaStruct { impl crate::options::AllowedOptions for SalsaStruct { const RETURN_REF: bool = false; + const SPECIFY: bool = false; + const NO_EQ: bool = false; const JAR: bool = true; diff --git a/components/salsa-entity-macros/src/tracked_fn.rs b/components/salsa-entity-macros/src/tracked_fn.rs index b0afc07f..4305adc9 100644 --- a/components/salsa-entity-macros/src/tracked_fn.rs +++ b/components/salsa-entity-macros/src/tracked_fn.rs @@ -1,4 +1,4 @@ -use proc_macro2::Literal; +use proc_macro2::{Literal, TokenStream}; use syn::spanned::Spanned; use syn::{ReturnType, Token}; @@ -24,13 +24,22 @@ fn tracked_fn(args: Args, item_fn: syn::ItemFn) -> syn::Result { )); } + if let Some(s) = &args.specify { + if requires_interning(&item_fn) { + return Err(syn::Error::new( + s.span(), + "tracked functon takes too many argments to have its value set with `specify`", + )); + } + } + let struct_item = configuration_struct(&item_fn); let configuration = fn_configuration(&args, &item_fn); let struct_item_ident = &struct_item.ident; - let struct_ty: syn::Type = parse_quote!(#struct_item_ident); - let configuration_impl = configuration.to_impl(&struct_ty); - let ingredients_for_impl = ingredients_for_impl(&args, &item_fn, &struct_ty); - let (getter, setter) = wrapper_fns(&args, &item_fn, &struct_ty)?; + let config_ty: syn::Type = parse_quote!(#struct_item_ident); + let configuration_impl = configuration.to_impl(&config_ty); + let ingredients_for_impl = ingredients_for_impl(&args, &item_fn, &config_ty); + let (getter, item_impl) = wrapper_fns(&args, &item_fn, &config_ty)?; Ok(quote! { #struct_item @@ -41,7 +50,7 @@ fn tracked_fn(args: Args, item_fn: syn::ItemFn) -> syn::Result { // sometimes doesn't like #[allow(clippy::needless_lifetimes)] #getter - #setter + #item_impl }) } @@ -52,6 +61,8 @@ struct TrackedFn; impl crate::options::AllowedOptions for TrackedFn { const RETURN_REF: bool = true; + const SPECIFY: bool = true; + const NO_EQ: bool = true; const JAR: bool = true; @@ -61,9 +72,9 @@ impl crate::options::AllowedOptions for TrackedFn { const DB: bool = false; } -/// Returns the key type for this tracked function. The result is -/// a tuple of the fn argments, ignoring the database. -fn key_ty(item_fn: &syn::ItemFn) -> syn::Type { +/// Returns the key type for this tracked function. +/// This is a tuple of all the argument types (apart from the database). +fn key_tuple_ty(item_fn: &syn::ItemFn) -> syn::Type { let arg_tys = item_fn.sig.inputs.iter().skip(1).map(|arg| match arg { syn::FnArg::Receiver(_) => unreachable!(), syn::FnArg::Typed(pat_ty) => pat_ty.ty.clone(), @@ -76,20 +87,49 @@ fn key_ty(item_fn: &syn::ItemFn) -> syn::Type { fn configuration_struct(item_fn: &syn::ItemFn) -> syn::ItemStruct { let fn_name = item_fn.sig.ident.clone(); - let key_tuple_ty = key_ty(item_fn); let visibility = &item_fn.vis; + + let salsa_struct_ty = salsa_struct_ty(item_fn); + let intern_map: syn::Type = if requires_interning(item_fn) { + let key_ty = key_tuple_ty(item_fn); + parse_quote! { salsa::interned::InternedIngredient } + } else { + parse_quote! { salsa::interned::IdentityInterner<#salsa_struct_ty> } + }; + parse_quote! { #[allow(non_camel_case_types)] - #visibility struct #fn_name { - intern_map: salsa::interned::InternedIngredient, + #visibility struct #fn_name + where + #salsa_struct_ty: salsa::AsId, // require that the salsa struct is, well, a salsa struct! + { + intern_map: #intern_map, function: salsa::function::FunctionIngredient, } } } +/// True if this fn takes more arguments. +fn requires_interning(item_fn: &syn::ItemFn) -> bool { + item_fn.sig.inputs.len() > 2 +} + +/// Every tracked fn takes a salsa struct as its second argument. +/// This fn returns the type of that second argument. +fn salsa_struct_ty(item_fn: &syn::ItemFn) -> &syn::Type { + match &item_fn.sig.inputs[1] { + syn::FnArg::Receiver(_) => panic!("receiver not expected"), + syn::FnArg::Typed(pat_ty) => &pat_ty.ty, + } +} + fn fn_configuration(args: &Args, item_fn: &syn::ItemFn) -> Configuration { let jar_ty = args.jar_ty(); - let key_ty = parse_quote!(salsa::id::Id); + let key_ty = if requires_interning(item_fn) { + parse_quote!(salsa::id::Id) + } else { + salsa_struct_ty(item_fn).clone() + }; let value_ty = configuration::value_ty(&item_fn.sig); // FIXME: these are hardcoded for now @@ -134,10 +174,33 @@ fn fn_configuration(args: &Args, item_fn: &syn::ItemFn) -> Configuration { } } -fn ingredients_for_impl(args: &Args, struct_ty: &syn::Type) -> syn::ItemImpl { +fn ingredients_for_impl( + args: &Args, + item_fn: &syn::ItemFn, + config_ty: &syn::Type, +) -> syn::ItemImpl { let jar_ty = args.jar_ty(); + + let intern_map: syn::Expr = if requires_interning(item_fn) { + parse_quote! { + { + let index = ingredients.push(|jars| { + let jar = >::jar_from_jars(jars); + let ingredients = + <_ as salsa::storage::HasIngredientsFor>::ingredient(jar); + &ingredients.intern_map + }); + salsa::interned::InternedIngredient::new(index) + } + } + } else { + parse_quote! { + salsa::interned::IdentityInterner::new() + } + }; + parse_quote! { - impl salsa::storage::IngredientsFor for #struct_ty { + impl salsa::storage::IngredientsFor for #config_ty { type Ingredients = Self; type Jar = #jar_ty; @@ -146,15 +209,7 @@ fn ingredients_for_impl(args: &Args, struct_ty: &syn::Type) -> syn::ItemImpl { DB: salsa::DbWithJar + salsa::storage::JarFromJars, { Self { - intern_map: { - let index = ingredients.push(|jars| { - let jar = >::jar_from_jars(jars); - let ingredients = - <_ as salsa::storage::HasIngredientsFor>::ingredient(jar); - &ingredients.intern_map - }); - salsa::interned::InternedIngredient::new(index) - }, + intern_map: #intern_map, function: { let index = ingredients.push(|jars| { @@ -174,17 +229,18 @@ fn ingredients_for_impl(args: &Args, struct_ty: &syn::Type) -> syn::ItemImpl { fn wrapper_fns( args: &Args, item_fn: &syn::ItemFn, - struct_ty: &syn::Type, + config_ty: &syn::Type, ) -> syn::Result<(syn::ItemFn, syn::ItemImpl)> { // The "getter" has same signature as the original: - let getter_fn = getter_fn(args, item_fn, struct_ty)?; + let getter_fn = getter_fn(args, item_fn, config_ty)?; - let ref_getter_fn = ref_getter_fn(args, item_fn, struct_ty)?; - let accumulated_fn = accumulated_fn(args, item_fn, struct_ty)?; - let setter_fn = setter_fn(args, item_fn, struct_ty)?; + let ref_getter_fn = ref_getter_fn(args, item_fn, config_ty)?; + let accumulated_fn = accumulated_fn(args, item_fn, config_ty)?; + let setter_fn = setter_fn(args, item_fn, config_ty)?; + let specify_fn = specify_fn(args, item_fn, config_ty)?.map(|f| quote! { #f }); let setter_impl: syn::ItemImpl = parse_quote! { - impl #struct_ty { + impl #config_ty { #[allow(dead_code, clippy::needless_lifetimes)] #ref_getter_fn @@ -193,6 +249,8 @@ fn wrapper_fns( #[allow(dead_code, clippy::needless_lifetimes)] #accumulated_fn + + #specify_fn } }; @@ -203,7 +261,7 @@ fn wrapper_fns( fn getter_fn( args: &Args, item_fn: &syn::ItemFn, - struct_ty: &syn::Type, + config_ty: &syn::Type, ) -> syn::Result { let mut getter_fn = item_fn.clone(); let arg_idents: Vec<_> = item_fn @@ -224,13 +282,13 @@ fn getter_fn( getter_fn = make_fn_return_ref(getter_fn)?; getter_fn.block = Box::new(parse_quote_spanned! { item_fn.block.span() => { - #struct_ty::get(#(#arg_idents,)*) + #config_ty::get(#(#arg_idents,)*) } }); } else { getter_fn.block = Box::new(parse_quote_spanned! { item_fn.block.span() => { - Clone::clone(#struct_ty::get(#(#arg_idents,)*)) + Clone::clone(#config_ty::get(#(#arg_idents,)*)) } }); } @@ -244,7 +302,7 @@ fn getter_fn( fn ref_getter_fn( args: &Args, item_fn: &syn::ItemFn, - struct_ty: &syn::Type, + config_ty: &syn::Type, ) -> syn::Result { let jar_ty = args.jar_ty(); let mut ref_getter_fn = item_fn.clone(); @@ -255,8 +313,8 @@ fn ref_getter_fn( ref_getter_fn.block = parse_quote! { { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var); - let __ingredients = <_ as salsa::storage::HasIngredientsFor<#struct_ty>>::ingredient(__jar); - let __key = __ingredients.intern_map.intern(__runtime, (#(#arg_names,)*)); + let __ingredients = <_ as salsa::storage::HasIngredientsFor<#config_ty>>::ingredient(__jar); + let __key = __ingredients.intern_map.intern(__runtime, (#(#arg_names),*)); __ingredients.function.fetch(#db_var, __key) } }; @@ -269,7 +327,7 @@ fn ref_getter_fn( fn setter_fn( args: &Args, item_fn: &syn::ItemFn, - struct_ty: &syn::Type, + config_ty: &syn::Type, ) -> syn::Result { // The setter has *always* the same signature as the original: // but it takes a value arg and has no return type. @@ -299,14 +357,49 @@ fn setter_fn( block: parse_quote! { { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar_mut(#db_var); - let __ingredients = <_ as salsa::storage::HasIngredientsFor<#struct_ty>>::ingredient_mut(__jar); - let __key = __ingredients.intern_map.intern(__runtime, (#(#arg_names,)*)); + let __ingredients = <_ as salsa::storage::HasIngredientsFor<#config_ty>>::ingredient_mut(__jar); + let __key = __ingredients.intern_map.intern(__runtime, (#(#arg_names),*)); __ingredients.function.store(__runtime, __key, #value_arg, salsa::Durability::LOW) } }, }) } +fn specify_fn( + args: &Args, + item_fn: &syn::ItemFn, + config_ty: &syn::Type, +) -> syn::Result> { + let specify = match &args.specify { + Some(s) => s, + None => return Ok(None), + }; + + // `specify` has the same signature as the original, + // but it takes a value arg and has no return type. + let jar_ty = args.jar_ty(); + let (db_var, arg_names) = fn_args(item_fn)?; + let mut setter_sig = item_fn.sig.clone(); + let value_ty = configuration::value_ty(&item_fn.sig); + setter_sig.ident = syn::Ident::new("specify", item_fn.sig.ident.span()); + let value_arg = syn::Ident::new("__value", item_fn.sig.output.span()); + setter_sig.inputs.push(parse_quote!(#value_arg: #value_ty)); + setter_sig.output = ReturnType::Default; + Ok(Some(syn::ImplItemMethod { + attrs: vec![], + vis: item_fn.vis.clone(), + defaultness: None, + sig: setter_sig, + block: parse_quote! { + { + + let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var); + let __ingredients = <_ as salsa::storage::HasIngredientsFor<#config_ty>>::ingredient(__jar); + __ingredients.function.set(#db_var, #(#arg_names,)* #value_arg) + } + }, + })) +} /// Given a function def tagged with `#[return_ref]`, modifies `ref_getter_fn` /// so that it returns an `&Value` instead of `Value`. May introduce a name for the /// database lifetime if required. @@ -386,7 +479,7 @@ fn db_lifetime_and_ty(func: &mut syn::ItemFn) -> syn::Result<(syn::Lifetime, &sy fn accumulated_fn( args: &Args, item_fn: &syn::ItemFn, - struct_ty: &syn::Type, + config_ty: &syn::Type, ) -> syn::Result { let jar_ty = args.jar_ty(); @@ -412,8 +505,8 @@ fn accumulated_fn( accumulated_fn.block = parse_quote! { { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(#db_var); - let __ingredients = <_ as salsa::storage::HasIngredientsFor<#struct_ty>>::ingredient(__jar); - let __key = __ingredients.intern_map.intern(__runtime, (#(#arg_names,)*)); + let __ingredients = <_ as salsa::storage::HasIngredientsFor<#config_ty>>::ingredient(__jar); + let __key = __ingredients.intern_map.intern(__runtime, (#(#arg_names),*)); __ingredients.function.accumulated::<__A>(#db_var, __key) } }; diff --git a/components/salsa-entity-mock/src/interned.rs b/components/salsa-entity-mock/src/interned.rs index f6172df4..b8222fda 100644 --- a/components/salsa-entity-mock/src/interned.rs +++ b/components/salsa-entity-mock/src/interned.rs @@ -1,6 +1,7 @@ use crossbeam::atomic::AtomicCell; use crossbeam::queue::SegQueue; use std::hash::Hash; +use std::marker::PhantomData; use crate::durability::Durability; use crate::id::AsId; @@ -198,3 +199,21 @@ where None } } + +pub struct IdentityInterner { + data: PhantomData, +} + +impl IdentityInterner { + pub fn new() -> Self { + IdentityInterner { data: PhantomData } + } + + pub fn intern(&self, _runtime: &Runtime, id: Id) -> Id { + id + } + + pub fn data(&self, _runtime: &Runtime, id: Id) -> (Id,) { + (id,) + } +} diff --git a/components/salsa-entity-mock/src/lib.rs b/components/salsa-entity-mock/src/lib.rs index 49812bb1..9eb52298 100644 --- a/components/salsa-entity-mock/src/lib.rs +++ b/components/salsa-entity-mock/src/lib.rs @@ -42,7 +42,6 @@ pub use self::storage::Storage; pub use self::tracked_struct::TrackedStructData; pub use self::tracked_struct::TrackedStructId; pub use salsa_entity_macros::accumulator; -pub use salsa_entity_macros::component; pub use salsa_entity_macros::db; pub use salsa_entity_macros::input; pub use salsa_entity_macros::interned; diff --git a/salsa-2022-tests/main.rs b/salsa-2022-tests/main.rs index dd052d09..6ad48afc 100644 --- a/salsa-2022-tests/main.rs +++ b/salsa-2022-tests/main.rs @@ -3,5 +3,6 @@ mod tracked_fn_on_input; mod tracked_fn_on_tracked; +mod tracked_fn_on_tracked_specify; fn main() {} diff --git a/salsa-2022-tests/tracked_fn_on_tracked.rs b/salsa-2022-tests/tracked_fn_on_tracked.rs index 2cfd9404..49d4f81d 100644 --- a/salsa-2022-tests/tracked_fn_on_tracked.rs +++ b/salsa-2022-tests/tracked_fn_on_tracked.rs @@ -2,7 +2,7 @@ //! compiles and executes successfully. #[salsa::jar(db = Db)] -struct Jar(MyInput, tracked_fn); +struct Jar(MyInput, MyTracked, tracked_fn); trait Db: salsa::DbWithJar {} diff --git a/salsa-2022-tests/tracked_fn_on_tracked_specify.rs b/salsa-2022-tests/tracked_fn_on_tracked_specify.rs new file mode 100644 index 00000000..53a5ef1d --- /dev/null +++ b/salsa-2022-tests/tracked_fn_on_tracked_specify.rs @@ -0,0 +1,64 @@ +//! Test that a `tracked` fn on a `salsa::input` +//! compiles and executes successfully. + +#[salsa::jar(db = Db)] +struct Jar(MyInput, MyTracked, tracked_fn, tracked_fn_extra); + +trait Db: salsa::DbWithJar {} + +#[salsa::input(jar = Jar)] +struct MyInput { + field: u32, +} + +#[salsa::tracked(jar = Jar)] +struct MyTracked { + field: u32, +} + +#[salsa::tracked(jar = Jar)] +fn tracked_fn(db: &dyn Db, input: MyInput) -> MyTracked { + let t = MyTracked::new(db, input.field(db) * 2); + if input.field(db) != 0 { + tracked_fn_extra::specify(db, t, 2222); + } + t +} + +#[salsa::tracked(jar = Jar, specify)] +fn tracked_fn_extra(_db: &dyn Db, _input: MyTracked) -> u32 { + 0 +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + self.storage.runtime() + } +} + +impl Db for Database {} + +#[test] +fn execute_when_specified() { + let mut db = Database::default(); + let input = MyInput::new(&mut db, 22); + let tracked = tracked_fn(&db, input); + assert_eq!(tracked.field(&db), 44); + assert_eq!(tracked_fn_extra(&db, tracked), 2222); +} + + +#[test] +fn execute_when_not_specified() { + let mut db = Database::default(); + let input = MyInput::new(&mut db, 0); + let tracked = tracked_fn(&db, input); + assert_eq!(tracked.field(&db), 0); + assert_eq!(tracked_fn_extra(&db, tracked), 0); +}