diff --git a/components/salsa-2022-macros/src/accumulator.rs b/components/salsa-2022-macros/src/accumulator.rs index 26276398..55cb3f54 100644 --- a/components/salsa-2022-macros/src/accumulator.rs +++ b/components/salsa-2022-macros/src/accumulator.rs @@ -34,6 +34,8 @@ impl crate::options::AllowedOptions for Accumulator { const RECOVERY_FN: bool = false; const LRU: bool = false; + + const CONSTRUCTOR_NAME: bool = false; } fn accumulator_contents( diff --git a/components/salsa-2022-macros/src/input.rs b/components/salsa-2022-macros/src/input.rs index a09794c9..9932226b 100644 --- a/components/salsa-2022-macros/src/input.rs +++ b/components/salsa-2022-macros/src/input.rs @@ -64,13 +64,14 @@ impl InputStruct { let input_index = self.input_index(); let field_indices = self.all_field_indices(); - let field_names: Vec<_> = self.all_field_names(); + let field_names = self.all_field_names(); let field_tys: Vec<_> = self.all_field_tys(); let field_clones: Vec<_> = self.all_fields().map(SalsaField::is_clone_field).collect(); - let field_getters: Vec = field_indices.iter().zip(&field_names).zip(&field_tys).zip(&field_clones).map(|(((field_index, field_name), field_ty), is_clone_field)| + let get_field_names: Vec<_> = self.all_get_field_names(); + let field_getters: Vec = field_indices.iter().zip(&get_field_names).zip(&field_tys).zip(&field_clones).map(|(((field_index, get_field_name), field_ty), is_clone_field)| if !*is_clone_field { parse_quote! { - pub fn #field_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty + pub fn #get_field_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); @@ -79,7 +80,7 @@ impl InputStruct { } } else { parse_quote! { - pub fn #field_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty + pub fn #get_field_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); @@ -90,8 +91,8 @@ impl InputStruct { ) .collect(); - let field_setters: Vec = field_indices.iter().zip(&field_names).zip(&field_tys).map(|((field_index, field_name), field_ty)| { - let set_field_name = syn::Ident::new(&format!("set_{}", field_name), field_name.span()); + let set_field_names = self.all_set_field_names(); + let field_setters: Vec = field_indices.iter().zip(&set_field_names).zip(&field_tys).map(|((field_index, set_field_name), field_ty)| { parse_quote! { pub fn #set_field_name<'db>(self, __db: &'db mut #db_dyn_ty, __value: #field_ty) -> #field_ty { @@ -103,9 +104,10 @@ impl InputStruct { }) .collect(); + let constructor_name = self.constructor_name(); parse_quote! { impl #ident { - pub fn new(__db: &mut #db_dyn_ty, #(#field_names: #field_tys,)*) -> Self + pub fn #constructor_name(__db: &mut #db_dyn_ty, #(#field_names: #field_tys,)*) -> Self { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar_mut(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient_mut(__jar); diff --git a/components/salsa-2022-macros/src/interned.rs b/components/salsa-2022-macros/src/interned.rs index 2d5dc0ff..629a23de 100644 --- a/components/salsa-2022-macros/src/interned.rs +++ b/components/salsa-2022-macros/src/interned.rs @@ -68,9 +68,10 @@ impl InternedStruct { .map(|field| { let field_name = field.name(); let field_ty = field.ty(); + let field_get_name = field.get_name(); if field.is_clone_field() { parse_quote! { - #vis fn #field_name(self, db: &#db_dyn_ty) -> #field_ty { + #vis fn #field_get_name(self, db: &#db_dyn_ty) -> #field_ty { let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #id_ident >>::ingredient(jar); std::clone::Clone::clone(&ingredients.data(runtime, self).#field_name) @@ -78,7 +79,7 @@ impl InternedStruct { } } else { parse_quote! { - #vis fn #field_name<'db>(self, db: &'db #db_dyn_ty) -> &'db #field_ty { + #vis fn #field_get_name<'db>(self, db: &'db #db_dyn_ty) -> &'db #field_ty { let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #id_ident >>::ingredient(jar); &ingredients.data(runtime, self).#field_name @@ -91,8 +92,9 @@ impl InternedStruct { let field_names = self.all_field_names(); let field_tys = self.all_field_tys(); let data_ident = self.data_ident(); + let constructor_name = self.constructor_name(); let new_method: syn::ImplItemMethod = parse_quote! { - #vis fn new( + #vis fn #constructor_name( db: &#db_dyn_ty, #(#field_names: #field_tys,)* ) -> Self { diff --git a/components/salsa-2022-macros/src/jar.rs b/components/salsa-2022-macros/src/jar.rs index d4839802..6a18fc40 100644 --- a/components/salsa-2022-macros/src/jar.rs +++ b/components/salsa-2022-macros/src/jar.rs @@ -43,6 +43,8 @@ impl crate::options::AllowedOptions for Jar { const RECOVERY_FN: bool = false; const LRU: bool = false; + + const CONSTRUCTOR_NAME: bool = false; } pub(crate) fn jar_struct_and_friends( diff --git a/components/salsa-2022-macros/src/options.rs b/components/salsa-2022-macros/src/options.rs index 185bde14..c24f66b6 100644 --- a/components/salsa-2022-macros/src/options.rs +++ b/components/salsa-2022-macros/src/options.rs @@ -4,7 +4,7 @@ use syn::{ext::IdentExt, spanned::Spanned}; /// "Options" are flags that can be supplied to the various salsa related /// macros. They are listed like `(ref, no_eq, foo=bar)` etc. The commas -/// are required and trailing comms are permitted. The options accepted +/// are required and trailing commas are permitted. The options accepted /// for any particular location are configured via the `AllowedOptions` /// trait. pub(crate) struct Options { @@ -51,6 +51,12 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. pub lru: Option, + /// The `constructor = ` option lets the user specify the name of + /// the constructor of a salsa struct. + /// + /// If this is `Some`, the value is the ``. + pub constructor_name: Option, + /// Remember the `A` parameter, which plays no role after parsing. phantom: PhantomData, } @@ -65,6 +71,7 @@ impl Default for Options { db_path: Default::default(), recovery_fn: Default::default(), data: Default::default(), + constructor_name: Default::default(), phantom: Default::default(), lru: Default::default(), } @@ -81,6 +88,7 @@ pub(crate) trait AllowedOptions { const DB: bool; const RECOVERY_FN: bool; const LRU: bool; + const CONSTRUCTOR_NAME: bool; } type Equals = syn::Token![=]; @@ -216,6 +224,23 @@ impl syn::parse::Parse for Options { "`lru` option not allowed here", )); } + } else if ident == "constructor" { + if A::CONSTRUCTOR_NAME { + let _eq = Equals::parse(input)?; + let ident = syn::Ident::parse(input)?; + if let Some(old) = std::mem::replace(&mut options.constructor_name, Some(ident)) + { + return Err(syn::Error::new( + old.span(), + "option `constructor` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`constructor` option not allowed here", + )); + } } else { return Err(syn::Error::new( ident.span(), diff --git a/components/salsa-2022-macros/src/salsa_struct.rs b/components/salsa-2022-macros/src/salsa_struct.rs index 36b01df9..49e5bdf0 100644 --- a/components/salsa-2022-macros/src/salsa_struct.rs +++ b/components/salsa-2022-macros/src/salsa_struct.rs @@ -26,7 +26,7 @@ //! * this could be optimized, particularly for interned fields use heck::CamelCase; -use proc_macro2::Literal; +use proc_macro2::{Ident, Literal, Span}; use crate::{configuration, options::Options}; @@ -52,6 +52,8 @@ impl crate::options::AllowedOptions for SalsaStruct { const RECOVERY_FN: bool = false; const LRU: bool = false; + + const CONSTRUCTOR_NAME: bool = true; } const BANNED_FIELD_NAMES: &[&str] = &["from", "new"]; @@ -110,6 +112,16 @@ impl SalsaStruct { self.all_fields().map(|ef| ef.name()).collect() } + /// Names of getters of all fields + pub(crate) fn all_get_field_names(&self) -> Vec<&syn::Ident> { + self.all_fields().map(|ef| ef.get_name()).collect() + } + + /// Names of setters of all fields + pub(crate) fn all_set_field_names(&self) -> Vec<&syn::Ident> { + self.all_fields().map(|ef| ef.set_name()).collect() + } + /// Types of all fields (id and value). /// /// If this is an enum, empty vec. @@ -193,6 +205,14 @@ impl SalsaStruct { &self.struct_item.vis } + /// Returns the `constructor_name` in `Options` if it is `Some`, else `new` + pub(crate) fn constructor_name(&self) -> syn::Ident { + match self.args.constructor_name.clone() { + Some(name) => name, + None => Ident::new("new", Span::call_site()), + } + } + /// For each of the fields passed as an argument, /// generate a struct named `Ident_Field` and an impl /// of `salsa::function::Configuration` for that struct. @@ -298,6 +318,12 @@ pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(&str, fn(&syn::Attribute, &mut Sals ("id", |_, ef| ef.has_id_attr = true), ("return_ref", |_, ef| ef.has_ref_attr = true), ("no_eq", |_, ef| ef.has_no_eq_attr = true), + ("get", |attr, ef| { + ef.get_name = attr.parse_args().unwrap(); + }), + ("set", |attr, ef| { + ef.set_name = attr.parse_args().unwrap(); + }), ]; pub(crate) struct SalsaField { @@ -306,6 +332,8 @@ pub(crate) struct SalsaField { pub(crate) has_id_attr: bool, pub(crate) has_ref_attr: bool, pub(crate) has_no_eq_attr: bool, + get_name: syn::Ident, + set_name: syn::Ident, } impl SalsaField { @@ -322,11 +350,15 @@ impl SalsaField { )); } + let get_name = Ident::new(&field_name_str, Span::call_site()); + let set_name = Ident::new(&format!("set_{}", field_name_str), Span::call_site()); let mut result = SalsaField { field: field.clone(), has_id_attr: false, has_ref_attr: false, has_no_eq_attr: false, + get_name, + set_name, }; // Scan the attributes and look for the salsa attributes: @@ -341,16 +373,26 @@ impl SalsaField { Ok(result) } - /// The name of this field (all `EntityField` instances are named). + /// The name of this field (all `SalsaField` instances are named). pub(crate) fn name(&self) -> &syn::Ident { self.field.ident.as_ref().unwrap() } - /// The type of this field (all `EntityField` instances are named). + /// The type of this field (all `SalsaField` instances are named). pub(crate) fn ty(&self) -> &syn::Type { &self.field.ty } + /// The name of this field's get method + pub(crate) fn get_name(&self) -> &syn::Ident { + &self.get_name + } + + /// The name of this field's get method + pub(crate) fn set_name(&self) -> &syn::Ident { + &self.set_name + } + /// Do you clone the value of this field? (True if it is not a ref field) pub(crate) fn is_clone_field(&self) -> bool { !self.has_ref_attr diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index 4ce59b3d..355a7668 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -81,6 +81,8 @@ impl crate::options::AllowedOptions for TrackedFn { const RECOVERY_FN: bool = true; const LRU: bool = true; + + const CONSTRUCTOR_NAME: bool = false; } /// Returns the key type for this tracked function. diff --git a/components/salsa-2022-macros/src/tracked_struct.rs b/components/salsa-2022-macros/src/tracked_struct.rs index b689ff6d..5212e5b0 100644 --- a/components/salsa-2022-macros/src/tracked_struct.rs +++ b/components/salsa-2022-macros/src/tracked_struct.rs @@ -67,12 +67,13 @@ impl TrackedStruct { let id_field_indices: Vec<_> = self.id_field_indices(); let id_field_names: Vec<_> = self.id_fields().map(SalsaField::name).collect(); + let id_field_get_names: Vec<_> = self.id_fields().map(SalsaField::get_name).collect(); let id_field_tys: Vec<_> = self.id_fields().map(SalsaField::ty).collect(); let id_field_clones: Vec<_> = self.id_fields().map(SalsaField::is_clone_field).collect(); - let id_field_getters: Vec = id_field_indices.iter().zip(&id_field_names).zip(&id_field_tys).zip(&id_field_clones).map(|(((field_index, field_name), field_ty), is_clone_field)| + let id_field_getters: Vec = id_field_indices.iter().zip(&id_field_get_names).zip(&id_field_tys).zip(&id_field_clones).map(|(((field_index, field_get_name), field_ty), is_clone_field)| if !*is_clone_field { parse_quote! { - pub fn #field_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty + pub fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); @@ -81,7 +82,7 @@ impl TrackedStruct { } } else { parse_quote! { - pub fn #field_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty + pub fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); @@ -95,14 +96,15 @@ impl TrackedStruct { let value_field_indices = self.value_field_indices(); let value_field_names: Vec<_> = self.value_fields().map(SalsaField::name).collect(); let value_field_tys: Vec<_> = self.value_fields().map(SalsaField::ty).collect(); + let value_field_get_names: Vec<_> = self.value_fields().map(SalsaField::get_name).collect(); let value_field_clones: Vec<_> = self .value_fields() .map(SalsaField::is_clone_field) .collect(); - let value_field_getters: Vec = value_field_indices.iter().zip(&value_field_names).zip(&value_field_tys).zip(&value_field_clones).map(|(((field_index, field_name), field_ty), is_clone_field)| + let value_field_getters: Vec = value_field_indices.iter().zip(&value_field_get_names).zip(&value_field_tys).zip(&value_field_clones).map(|(((field_index, field_get_name), field_ty), is_clone_field)| if !*is_clone_field { parse_quote! { - pub fn #field_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty + pub fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); @@ -111,7 +113,7 @@ impl TrackedStruct { } } else { parse_quote! { - pub fn #field_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty + pub fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); @@ -124,10 +126,11 @@ impl TrackedStruct { let all_field_names = self.all_field_names(); let all_field_tys = self.all_field_tys(); + let constructor_name = self.constructor_name(); parse_quote! { impl #ident { - pub fn new(__db: &#db_dyn_ty, #(#all_field_names: #all_field_tys,)*) -> Self + pub fn #constructor_name(__db: &#db_dyn_ty, #(#all_field_names: #all_field_tys,)*) -> Self { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); diff --git a/salsa-2022-tests/tests/override_new_get_set.rs b/salsa-2022-tests/tests/override_new_get_set.rs new file mode 100644 index 00000000..86a5190c --- /dev/null +++ b/salsa-2022-tests/tests/override_new_get_set.rs @@ -0,0 +1,84 @@ +//! Test that the `constructor` macro overrides +//! the `new` method's name and `get` and `set` +//! change the name of the getter and setter of the fields. +#![allow(warnings)] + +use std::fmt::Display; + +#[salsa::jar(db = Db)] +struct Jar(MyInput, MyInterned, MyTracked); + +trait Db: salsa::DbWithJar {} + +#[salsa::input(jar = Jar, constructor = from_string)] +struct MyInput { + #[get(text)] + #[set(set_text)] + field: String, +} + +impl MyInput { + pub fn new(db: &mut dyn Db, s: impl Display) -> MyInput { + MyInput::from_string(db, s.to_string()) + } + + pub fn field(self, db: &dyn Db) -> String { + self.text(db) + } + + pub fn set_field(self, db: &mut dyn Db, id: String) { + self.set_text(db, id); + } +} + +#[salsa::interned(constructor = from_string)] +struct MyInterned { + #[get(text)] + #[return_ref] + field: String, +} + +impl MyInterned { + pub fn new(db: &dyn Db, s: impl Display) -> MyInterned { + MyInterned::from_string(db, s.to_string()) + } + + pub fn field(self, db: &dyn Db) -> &str { + &self.text(db) + } +} + +#[salsa::tracked(constructor = from_string)] +struct MyTracked { + #[get(text)] + field: String, +} + +impl MyTracked { + pub fn new(db: &dyn Db, s: impl Display) -> MyTracked { + MyTracked::from_string(db, s.to_string()) + } + + pub fn field(self, db: &dyn Db) -> String { + self.text(db) + } +} + +#[test] +fn execute() { + #[salsa::db(Jar)] + #[derive(Default)] + struct Database { + storage: salsa::Storage, + } + + impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + self.storage.runtime() + } + } + + impl Db for Database {} + + let mut db = Database::default(); +}