mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-14 17:18:20 +00:00
allow elided lifetimes in tracked fn return values
This commit is contained in:
parent
b9ab8fcebd
commit
ce750dadf5
4 changed files with 124 additions and 11 deletions
|
@ -1,3 +1,5 @@
|
|||
use crate::xform::ChangeLt;
|
||||
|
||||
pub(crate) struct Configuration {
|
||||
pub(crate) db_lt: syn::Lifetime,
|
||||
pub(crate) jar_ty: syn::Type,
|
||||
|
@ -88,9 +90,9 @@ pub(crate) fn panic_cycle_recovery_fn() -> syn::ImplItemFn {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn value_ty(sig: &syn::Signature) -> syn::Type {
|
||||
pub(crate) fn value_ty(db_lt: &syn::Lifetime, sig: &syn::Signature) -> syn::Type {
|
||||
match &sig.output {
|
||||
syn::ReturnType::Default => parse_quote!(()),
|
||||
syn::ReturnType::Type(_, ty) => syn::Type::clone(ty),
|
||||
syn::ReturnType::Type(_, ty) => ChangeLt::elided_to(db_lt).in_type(ty),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ use syn::{ReturnType, Token};
|
|||
use crate::configuration::{self, Configuration, CycleRecoveryStrategy};
|
||||
use crate::db_lifetime::{self, db_lifetime, require_optional_db_lifetime};
|
||||
use crate::options::Options;
|
||||
use crate::xform::ChangeLt;
|
||||
|
||||
pub(crate) fn tracked_fn(
|
||||
args: proc_macro::TokenStream,
|
||||
|
@ -340,6 +341,8 @@ fn interned_configuration_impl(
|
|||
(#(#arg_tys),*)
|
||||
);
|
||||
|
||||
let intern_data_ty = ChangeLt::elided_to(db_lt).in_type(&intern_data_ty);
|
||||
|
||||
parse_quote!(
|
||||
impl salsa::interned::Configuration for #config_ty {
|
||||
type Data<#db_lt> = #intern_data_ty;
|
||||
|
@ -421,7 +424,8 @@ fn fn_configuration(args: &FnArgs, item_fn: &syn::ItemFn) -> Configuration {
|
|||
FunctionType::SalsaStruct => salsa_struct_ty.clone(),
|
||||
FunctionType::RequiresInterning => parse_quote!(salsa::id::Id),
|
||||
};
|
||||
let value_ty = configuration::value_ty(&item_fn.sig);
|
||||
let key_ty = ChangeLt::elided_to(&db_lt).in_type(&key_ty);
|
||||
let value_ty = configuration::value_ty(&db_lt, &item_fn.sig);
|
||||
|
||||
let fn_ty = item_fn.sig.ident.clone();
|
||||
|
||||
|
@ -693,12 +697,21 @@ fn setter_fn(
|
|||
item_fn: &syn::ItemFn,
|
||||
config_ty: &syn::Type,
|
||||
) -> syn::Result<syn::ImplItemFn> {
|
||||
let mut setter_sig = item_fn.sig.clone();
|
||||
|
||||
require_optional_db_lifetime(&item_fn.sig.generics)?;
|
||||
let db_lt = &db_lifetime(&item_fn.sig.generics);
|
||||
match setter_sig.generics.lifetimes().count() {
|
||||
0 => setter_sig.generics.params.push(parse_quote!(#db_lt)),
|
||||
1 => (),
|
||||
_ => panic!("unreachable -- would have generated an error earlier"),
|
||||
};
|
||||
|
||||
// The setter has *always* the same signature as the original:
|
||||
// but it takes a value arg and has no return type.
|
||||
let jar_ty = args.jar_ty();
|
||||
let (db_var, arg_names) = fn_args(item_fn)?;
|
||||
let mut setter_sig = item_fn.sig.clone();
|
||||
let value_ty = configuration::value_ty(&item_fn.sig);
|
||||
let value_ty = configuration::value_ty(db_lt, &item_fn.sig);
|
||||
setter_sig.ident = syn::Ident::new("set", item_fn.sig.ident.span());
|
||||
match &mut setter_sig.inputs[0] {
|
||||
// change from `&dyn ...` to `&mut dyn...`
|
||||
|
@ -706,6 +719,7 @@ fn setter_fn(
|
|||
syn::FnArg::Typed(pat_ty) => match &mut *pat_ty.ty {
|
||||
syn::Type::Reference(ty) => {
|
||||
ty.mutability = Some(Token![mut](ty.and_token.span()));
|
||||
ty.lifetime = Some(db_lt.clone());
|
||||
}
|
||||
_ => unreachable!(), // early fns should have detected
|
||||
},
|
||||
|
@ -771,12 +785,21 @@ fn specify_fn(
|
|||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut setter_sig = item_fn.sig.clone();
|
||||
|
||||
require_optional_db_lifetime(&item_fn.sig.generics)?;
|
||||
let db_lt = &db_lifetime(&item_fn.sig.generics);
|
||||
match setter_sig.generics.lifetimes().count() {
|
||||
0 => setter_sig.generics.params.push(parse_quote!(#db_lt)),
|
||||
1 => (),
|
||||
_ => panic!("unreachable -- would have generated an error earlier"),
|
||||
};
|
||||
|
||||
// `specify` has the same signature as the original,
|
||||
// but it takes a value arg and has no return type.
|
||||
let jar_ty = args.jar_ty();
|
||||
let (db_var, arg_names) = fn_args(item_fn)?;
|
||||
let mut setter_sig = item_fn.sig.clone();
|
||||
let value_ty = configuration::value_ty(&item_fn.sig);
|
||||
let value_ty = configuration::value_ty(db_lt, &item_fn.sig);
|
||||
setter_sig.ident = syn::Ident::new("specify", item_fn.sig.ident.span());
|
||||
let value_arg = syn::Ident::new("__value", item_fn.sig.output.span());
|
||||
setter_sig.inputs.push(parse_quote!(#value_arg: #value_ty));
|
||||
|
|
|
@ -2,21 +2,28 @@ use syn::visit_mut::VisitMut;
|
|||
|
||||
pub(crate) struct ChangeLt<'a> {
|
||||
from: Option<&'a str>,
|
||||
to: &'a str,
|
||||
to: String,
|
||||
}
|
||||
|
||||
impl<'a> ChangeLt<'a> {
|
||||
pub fn elided_to_static() -> Self {
|
||||
ChangeLt {
|
||||
from: Some("_"),
|
||||
to: "static",
|
||||
to: "static".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn elided_to(db_lt: &syn::Lifetime) -> Self {
|
||||
ChangeLt {
|
||||
from: Some("_"),
|
||||
to: db_lt.ident.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_elided() -> Self {
|
||||
ChangeLt {
|
||||
from: None,
|
||||
to: "_",
|
||||
to: "_".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -30,7 +37,7 @@ impl<'a> ChangeLt<'a> {
|
|||
impl syn::visit_mut::VisitMut for ChangeLt<'_> {
|
||||
fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
|
||||
if self.from.map(|f| i.ident == f).unwrap_or(true) {
|
||||
i.ident = syn::Ident::new(self.to, i.ident.span());
|
||||
i.ident = syn::Ident::new(&self.to, i.ident.span());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
81
salsa-2022-tests/tests/elided-lifetime-in-tracked-fn.rs
Normal file
81
salsa-2022-tests/tests/elided-lifetime-in-tracked-fn.rs
Normal file
|
@ -0,0 +1,81 @@
|
|||
//! Test that a `tracked` fn on a `salsa::input`
|
||||
//! compiles and executes successfully.
|
||||
|
||||
use salsa_2022_tests::{HasLogger, Logger};
|
||||
|
||||
use expect_test::expect;
|
||||
use test_log::test;
|
||||
|
||||
#[salsa::jar(db = Db)]
|
||||
struct Jar(MyInput, MyTracked<'_>, final_result, intermediate_result);
|
||||
|
||||
trait Db: salsa::DbWithJar<Jar> + HasLogger {}
|
||||
|
||||
#[salsa::input(jar = Jar)]
|
||||
struct MyInput {
|
||||
field: u32,
|
||||
}
|
||||
|
||||
#[salsa::tracked(jar = Jar)]
|
||||
fn final_result(db: &dyn Db, input: MyInput) -> u32 {
|
||||
db.push_log(format!("final_result({:?})", input));
|
||||
intermediate_result(db, input).field(db) * 2
|
||||
}
|
||||
|
||||
#[salsa::tracked(jar = Jar)]
|
||||
struct MyTracked<'db> {
|
||||
field: u32,
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn intermediate_result(db: &dyn Db, input: MyInput) -> MyTracked<'_> {
|
||||
db.push_log(format!("intermediate_result({:?})", input));
|
||||
MyTracked::new(db, input.field(db) / 2)
|
||||
}
|
||||
|
||||
#[salsa::db(Jar)]
|
||||
#[derive(Default)]
|
||||
struct Database {
|
||||
storage: salsa::Storage<Self>,
|
||||
logger: Logger,
|
||||
}
|
||||
|
||||
impl salsa::Database for Database {}
|
||||
|
||||
impl Db for Database {}
|
||||
|
||||
impl HasLogger for Database {
|
||||
fn logger(&self) -> &Logger {
|
||||
&self.logger
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execute() {
|
||||
let mut db = Database::default();
|
||||
|
||||
let input = MyInput::new(&db, 22);
|
||||
assert_eq!(final_result(&db, input), 22);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"final_result(MyInput { [salsa id]: 0 })",
|
||||
"intermediate_result(MyInput { [salsa id]: 0 })",
|
||||
]"#]]);
|
||||
|
||||
// Intermediate result is the same, so final result does
|
||||
// not need to be recomputed:
|
||||
input.set_field(&mut db).to(23);
|
||||
assert_eq!(final_result(&db, input), 22);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"intermediate_result(MyInput { [salsa id]: 0 })",
|
||||
]"#]]);
|
||||
|
||||
input.set_field(&mut db).to(24);
|
||||
assert_eq!(final_result(&db, input), 24);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"intermediate_result(MyInput { [salsa id]: 0 })",
|
||||
"final_result(MyInput { [salsa id]: 0 })",
|
||||
]"#]]);
|
||||
}
|
Loading…
Reference in a new issue