allow elided lifetimes in tracked fn return values

This commit is contained in:
Niko Matsakis 2024-05-30 09:40:21 -04:00
parent b9ab8fcebd
commit ce750dadf5
4 changed files with 124 additions and 11 deletions

View file

@ -1,3 +1,5 @@
use crate::xform::ChangeLt;
pub(crate) struct Configuration { pub(crate) struct Configuration {
pub(crate) db_lt: syn::Lifetime, pub(crate) db_lt: syn::Lifetime,
pub(crate) jar_ty: syn::Type, pub(crate) jar_ty: syn::Type,
@ -88,9 +90,9 @@ pub(crate) fn panic_cycle_recovery_fn() -> syn::ImplItemFn {
} }
} }
pub(crate) fn value_ty(sig: &syn::Signature) -> syn::Type { pub(crate) fn value_ty(db_lt: &syn::Lifetime, sig: &syn::Signature) -> syn::Type {
match &sig.output { match &sig.output {
syn::ReturnType::Default => parse_quote!(()), syn::ReturnType::Default => parse_quote!(()),
syn::ReturnType::Type(_, ty) => syn::Type::clone(ty), syn::ReturnType::Type(_, ty) => ChangeLt::elided_to(db_lt).in_type(ty),
} }
} }

View file

@ -6,6 +6,7 @@ use syn::{ReturnType, Token};
use crate::configuration::{self, Configuration, CycleRecoveryStrategy}; use crate::configuration::{self, Configuration, CycleRecoveryStrategy};
use crate::db_lifetime::{self, db_lifetime, require_optional_db_lifetime}; use crate::db_lifetime::{self, db_lifetime, require_optional_db_lifetime};
use crate::options::Options; use crate::options::Options;
use crate::xform::ChangeLt;
pub(crate) fn tracked_fn( pub(crate) fn tracked_fn(
args: proc_macro::TokenStream, args: proc_macro::TokenStream,
@ -340,6 +341,8 @@ fn interned_configuration_impl(
(#(#arg_tys),*) (#(#arg_tys),*)
); );
let intern_data_ty = ChangeLt::elided_to(db_lt).in_type(&intern_data_ty);
parse_quote!( parse_quote!(
impl salsa::interned::Configuration for #config_ty { impl salsa::interned::Configuration for #config_ty {
type Data<#db_lt> = #intern_data_ty; type Data<#db_lt> = #intern_data_ty;
@ -421,7 +424,8 @@ fn fn_configuration(args: &FnArgs, item_fn: &syn::ItemFn) -> Configuration {
FunctionType::SalsaStruct => salsa_struct_ty.clone(), FunctionType::SalsaStruct => salsa_struct_ty.clone(),
FunctionType::RequiresInterning => parse_quote!(salsa::id::Id), FunctionType::RequiresInterning => parse_quote!(salsa::id::Id),
}; };
let value_ty = configuration::value_ty(&item_fn.sig); let key_ty = ChangeLt::elided_to(&db_lt).in_type(&key_ty);
let value_ty = configuration::value_ty(&db_lt, &item_fn.sig);
let fn_ty = item_fn.sig.ident.clone(); let fn_ty = item_fn.sig.ident.clone();
@ -693,12 +697,21 @@ fn setter_fn(
item_fn: &syn::ItemFn, item_fn: &syn::ItemFn,
config_ty: &syn::Type, config_ty: &syn::Type,
) -> syn::Result<syn::ImplItemFn> { ) -> syn::Result<syn::ImplItemFn> {
let mut setter_sig = item_fn.sig.clone();
require_optional_db_lifetime(&item_fn.sig.generics)?;
let db_lt = &db_lifetime(&item_fn.sig.generics);
match setter_sig.generics.lifetimes().count() {
0 => setter_sig.generics.params.push(parse_quote!(#db_lt)),
1 => (),
_ => panic!("unreachable -- would have generated an error earlier"),
};
// The setter has *always* the same signature as the original: // The setter has *always* the same signature as the original:
// but it takes a value arg and has no return type. // but it takes a value arg and has no return type.
let jar_ty = args.jar_ty(); let jar_ty = args.jar_ty();
let (db_var, arg_names) = fn_args(item_fn)?; let (db_var, arg_names) = fn_args(item_fn)?;
let mut setter_sig = item_fn.sig.clone(); let value_ty = configuration::value_ty(db_lt, &item_fn.sig);
let value_ty = configuration::value_ty(&item_fn.sig);
setter_sig.ident = syn::Ident::new("set", item_fn.sig.ident.span()); setter_sig.ident = syn::Ident::new("set", item_fn.sig.ident.span());
match &mut setter_sig.inputs[0] { match &mut setter_sig.inputs[0] {
// change from `&dyn ...` to `&mut dyn...` // change from `&dyn ...` to `&mut dyn...`
@ -706,6 +719,7 @@ fn setter_fn(
syn::FnArg::Typed(pat_ty) => match &mut *pat_ty.ty { syn::FnArg::Typed(pat_ty) => match &mut *pat_ty.ty {
syn::Type::Reference(ty) => { syn::Type::Reference(ty) => {
ty.mutability = Some(Token![mut](ty.and_token.span())); ty.mutability = Some(Token![mut](ty.and_token.span()));
ty.lifetime = Some(db_lt.clone());
} }
_ => unreachable!(), // early fns should have detected _ => unreachable!(), // early fns should have detected
}, },
@ -771,12 +785,21 @@ fn specify_fn(
return Ok(None); return Ok(None);
} }
let mut setter_sig = item_fn.sig.clone();
require_optional_db_lifetime(&item_fn.sig.generics)?;
let db_lt = &db_lifetime(&item_fn.sig.generics);
match setter_sig.generics.lifetimes().count() {
0 => setter_sig.generics.params.push(parse_quote!(#db_lt)),
1 => (),
_ => panic!("unreachable -- would have generated an error earlier"),
};
// `specify` has the same signature as the original, // `specify` has the same signature as the original,
// but it takes a value arg and has no return type. // but it takes a value arg and has no return type.
let jar_ty = args.jar_ty(); let jar_ty = args.jar_ty();
let (db_var, arg_names) = fn_args(item_fn)?; let (db_var, arg_names) = fn_args(item_fn)?;
let mut setter_sig = item_fn.sig.clone(); let value_ty = configuration::value_ty(db_lt, &item_fn.sig);
let value_ty = configuration::value_ty(&item_fn.sig);
setter_sig.ident = syn::Ident::new("specify", item_fn.sig.ident.span()); setter_sig.ident = syn::Ident::new("specify", item_fn.sig.ident.span());
let value_arg = syn::Ident::new("__value", item_fn.sig.output.span()); let value_arg = syn::Ident::new("__value", item_fn.sig.output.span());
setter_sig.inputs.push(parse_quote!(#value_arg: #value_ty)); setter_sig.inputs.push(parse_quote!(#value_arg: #value_ty));

View file

@ -2,21 +2,28 @@ use syn::visit_mut::VisitMut;
pub(crate) struct ChangeLt<'a> { pub(crate) struct ChangeLt<'a> {
from: Option<&'a str>, from: Option<&'a str>,
to: &'a str, to: String,
} }
impl<'a> ChangeLt<'a> { impl<'a> ChangeLt<'a> {
pub fn elided_to_static() -> Self { pub fn elided_to_static() -> Self {
ChangeLt { ChangeLt {
from: Some("_"), from: Some("_"),
to: "static", to: "static".to_string(),
}
}
pub fn elided_to(db_lt: &syn::Lifetime) -> Self {
ChangeLt {
from: Some("_"),
to: db_lt.ident.to_string(),
} }
} }
pub fn to_elided() -> Self { pub fn to_elided() -> Self {
ChangeLt { ChangeLt {
from: None, from: None,
to: "_", to: "_".to_string(),
} }
} }
@ -30,7 +37,7 @@ impl<'a> ChangeLt<'a> {
impl syn::visit_mut::VisitMut for ChangeLt<'_> { impl syn::visit_mut::VisitMut for ChangeLt<'_> {
fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) { fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
if self.from.map(|f| i.ident == f).unwrap_or(true) { if self.from.map(|f| i.ident == f).unwrap_or(true) {
i.ident = syn::Ident::new(self.to, i.ident.span()); i.ident = syn::Ident::new(&self.to, i.ident.span());
} }
} }
} }

View file

@ -0,0 +1,81 @@
//! Test that a `tracked` fn on a `salsa::input`
//! compiles and executes successfully.
use salsa_2022_tests::{HasLogger, Logger};
use expect_test::expect;
use test_log::test;
#[salsa::jar(db = Db)]
struct Jar(MyInput, MyTracked<'_>, final_result, intermediate_result);
trait Db: salsa::DbWithJar<Jar> + HasLogger {}
#[salsa::input(jar = Jar)]
struct MyInput {
field: u32,
}
#[salsa::tracked(jar = Jar)]
fn final_result(db: &dyn Db, input: MyInput) -> u32 {
db.push_log(format!("final_result({:?})", input));
intermediate_result(db, input).field(db) * 2
}
#[salsa::tracked(jar = Jar)]
struct MyTracked<'db> {
field: u32,
}
#[salsa::tracked]
fn intermediate_result(db: &dyn Db, input: MyInput) -> MyTracked<'_> {
db.push_log(format!("intermediate_result({:?})", input));
MyTracked::new(db, input.field(db) / 2)
}
#[salsa::db(Jar)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
logger: Logger,
}
impl salsa::Database for Database {}
impl Db for Database {}
impl HasLogger for Database {
fn logger(&self) -> &Logger {
&self.logger
}
}
#[test]
fn execute() {
let mut db = Database::default();
let input = MyInput::new(&db, 22);
assert_eq!(final_result(&db, input), 22);
db.assert_logs(expect![[r#"
[
"final_result(MyInput { [salsa id]: 0 })",
"intermediate_result(MyInput { [salsa id]: 0 })",
]"#]]);
// Intermediate result is the same, so final result does
// not need to be recomputed:
input.set_field(&mut db).to(23);
assert_eq!(final_result(&db, input), 22);
db.assert_logs(expect![[r#"
[
"intermediate_result(MyInput { [salsa id]: 0 })",
]"#]]);
input.set_field(&mut db).to(24);
assert_eq!(final_result(&db, input), 24);
db.assert_logs(expect![[r#"
[
"intermediate_result(MyInput { [salsa id]: 0 })",
"final_result(MyInput { [salsa id]: 0 })",
]"#]]);
}