diff --git a/components/salsa-2022-macros/src/configuration.rs b/components/salsa-2022-macros/src/configuration.rs index 65ad7f37..f4942f2d 100644 --- a/components/salsa-2022-macros/src/configuration.rs +++ b/components/salsa-2022-macros/src/configuration.rs @@ -1,3 +1,5 @@ +use crate::xform::ChangeLt; + pub(crate) struct Configuration { pub(crate) db_lt: syn::Lifetime, pub(crate) jar_ty: syn::Type, @@ -88,9 +90,9 @@ pub(crate) fn panic_cycle_recovery_fn() -> syn::ImplItemFn { } } -pub(crate) fn value_ty(sig: &syn::Signature) -> syn::Type { +pub(crate) fn value_ty(db_lt: &syn::Lifetime, sig: &syn::Signature) -> syn::Type { match &sig.output { syn::ReturnType::Default => parse_quote!(()), - syn::ReturnType::Type(_, ty) => syn::Type::clone(ty), + syn::ReturnType::Type(_, ty) => ChangeLt::elided_to(db_lt).in_type(ty), } } diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index 29744200..2af3482f 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -6,6 +6,7 @@ use syn::{ReturnType, Token}; use crate::configuration::{self, Configuration, CycleRecoveryStrategy}; use crate::db_lifetime::{self, db_lifetime, require_optional_db_lifetime}; use crate::options::Options; +use crate::xform::ChangeLt; pub(crate) fn tracked_fn( args: proc_macro::TokenStream, @@ -340,6 +341,8 @@ fn interned_configuration_impl( (#(#arg_tys),*) ); + let intern_data_ty = ChangeLt::elided_to(db_lt).in_type(&intern_data_ty); + parse_quote!( impl salsa::interned::Configuration for #config_ty { type Data<#db_lt> = #intern_data_ty; @@ -421,7 +424,8 @@ fn fn_configuration(args: &FnArgs, item_fn: &syn::ItemFn) -> Configuration { FunctionType::SalsaStruct => salsa_struct_ty.clone(), FunctionType::RequiresInterning => parse_quote!(salsa::id::Id), }; - let value_ty = configuration::value_ty(&item_fn.sig); + let key_ty = ChangeLt::elided_to(&db_lt).in_type(&key_ty); + let value_ty = configuration::value_ty(&db_lt, &item_fn.sig); let fn_ty = item_fn.sig.ident.clone(); @@ -693,12 +697,21 @@ fn setter_fn( item_fn: &syn::ItemFn, config_ty: &syn::Type, ) -> syn::Result { + let mut setter_sig = item_fn.sig.clone(); + + require_optional_db_lifetime(&item_fn.sig.generics)?; + let db_lt = &db_lifetime(&item_fn.sig.generics); + match setter_sig.generics.lifetimes().count() { + 0 => setter_sig.generics.params.push(parse_quote!(#db_lt)), + 1 => (), + _ => panic!("unreachable -- would have generated an error earlier"), + }; + // The setter has *always* the same signature as the original: // but it takes a value arg and has no return type. let jar_ty = args.jar_ty(); let (db_var, arg_names) = fn_args(item_fn)?; - let mut setter_sig = item_fn.sig.clone(); - let value_ty = configuration::value_ty(&item_fn.sig); + let value_ty = configuration::value_ty(db_lt, &item_fn.sig); setter_sig.ident = syn::Ident::new("set", item_fn.sig.ident.span()); match &mut setter_sig.inputs[0] { // change from `&dyn ...` to `&mut dyn...` @@ -706,6 +719,7 @@ fn setter_fn( syn::FnArg::Typed(pat_ty) => match &mut *pat_ty.ty { syn::Type::Reference(ty) => { ty.mutability = Some(Token![mut](ty.and_token.span())); + ty.lifetime = Some(db_lt.clone()); } _ => unreachable!(), // early fns should have detected }, @@ -771,12 +785,21 @@ fn specify_fn( return Ok(None); } + let mut setter_sig = item_fn.sig.clone(); + + require_optional_db_lifetime(&item_fn.sig.generics)?; + let db_lt = &db_lifetime(&item_fn.sig.generics); + match setter_sig.generics.lifetimes().count() { + 0 => setter_sig.generics.params.push(parse_quote!(#db_lt)), + 1 => (), + _ => panic!("unreachable -- would have generated an error earlier"), + }; + // `specify` has the same signature as the original, // but it takes a value arg and has no return type. let jar_ty = args.jar_ty(); let (db_var, arg_names) = fn_args(item_fn)?; - let mut setter_sig = item_fn.sig.clone(); - let value_ty = configuration::value_ty(&item_fn.sig); + let value_ty = configuration::value_ty(db_lt, &item_fn.sig); setter_sig.ident = syn::Ident::new("specify", item_fn.sig.ident.span()); let value_arg = syn::Ident::new("__value", item_fn.sig.output.span()); setter_sig.inputs.push(parse_quote!(#value_arg: #value_ty)); diff --git a/components/salsa-2022-macros/src/xform.rs b/components/salsa-2022-macros/src/xform.rs index ec00d5c5..9097d617 100644 --- a/components/salsa-2022-macros/src/xform.rs +++ b/components/salsa-2022-macros/src/xform.rs @@ -2,21 +2,28 @@ use syn::visit_mut::VisitMut; pub(crate) struct ChangeLt<'a> { from: Option<&'a str>, - to: &'a str, + to: String, } impl<'a> ChangeLt<'a> { pub fn elided_to_static() -> Self { ChangeLt { from: Some("_"), - to: "static", + to: "static".to_string(), + } + } + + pub fn elided_to(db_lt: &syn::Lifetime) -> Self { + ChangeLt { + from: Some("_"), + to: db_lt.ident.to_string(), } } pub fn to_elided() -> Self { ChangeLt { from: None, - to: "_", + to: "_".to_string(), } } @@ -30,7 +37,7 @@ impl<'a> ChangeLt<'a> { impl syn::visit_mut::VisitMut for ChangeLt<'_> { fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) { if self.from.map(|f| i.ident == f).unwrap_or(true) { - i.ident = syn::Ident::new(self.to, i.ident.span()); + i.ident = syn::Ident::new(&self.to, i.ident.span()); } } } diff --git a/salsa-2022-tests/tests/elided-lifetime-in-tracked-fn.rs b/salsa-2022-tests/tests/elided-lifetime-in-tracked-fn.rs new file mode 100644 index 00000000..3bae2667 --- /dev/null +++ b/salsa-2022-tests/tests/elided-lifetime-in-tracked-fn.rs @@ -0,0 +1,81 @@ +//! Test that a `tracked` fn on a `salsa::input` +//! compiles and executes successfully. + +use salsa_2022_tests::{HasLogger, Logger}; + +use expect_test::expect; +use test_log::test; + +#[salsa::jar(db = Db)] +struct Jar(MyInput, MyTracked<'_>, final_result, intermediate_result); + +trait Db: salsa::DbWithJar + HasLogger {} + +#[salsa::input(jar = Jar)] +struct MyInput { + field: u32, +} + +#[salsa::tracked(jar = Jar)] +fn final_result(db: &dyn Db, input: MyInput) -> u32 { + db.push_log(format!("final_result({:?})", input)); + intermediate_result(db, input).field(db) * 2 +} + +#[salsa::tracked(jar = Jar)] +struct MyTracked<'db> { + field: u32, +} + +#[salsa::tracked] +fn intermediate_result(db: &dyn Db, input: MyInput) -> MyTracked<'_> { + db.push_log(format!("intermediate_result({:?})", input)); + MyTracked::new(db, input.field(db) / 2) +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, + logger: Logger, +} + +impl salsa::Database for Database {} + +impl Db for Database {} + +impl HasLogger for Database { + fn logger(&self) -> &Logger { + &self.logger + } +} + +#[test] +fn execute() { + let mut db = Database::default(); + + let input = MyInput::new(&db, 22); + assert_eq!(final_result(&db, input), 22); + db.assert_logs(expect![[r#" + [ + "final_result(MyInput { [salsa id]: 0 })", + "intermediate_result(MyInput { [salsa id]: 0 })", + ]"#]]); + + // Intermediate result is the same, so final result does + // not need to be recomputed: + input.set_field(&mut db).to(23); + assert_eq!(final_result(&db, input), 22); + db.assert_logs(expect![[r#" + [ + "intermediate_result(MyInput { [salsa id]: 0 })", + ]"#]]); + + input.set_field(&mut db).to(24); + assert_eq!(final_result(&db, input), 24); + db.assert_logs(expect![[r#" + [ + "intermediate_result(MyInput { [salsa id]: 0 })", + "final_result(MyInput { [salsa id]: 0 })", + ]"#]]); +}