WIP permit 'db on tracked struct definitions (opt)

This commit is contained in:
Niko Matsakis 2024-04-18 07:17:18 -04:00
parent 5ce5e3c374
commit b6311d8102
7 changed files with 199 additions and 84 deletions

View file

@ -49,7 +49,7 @@ impl crate::options::AllowedOptions for InputStruct {
impl InputStruct {
fn generate_input(&self) -> syn::Result<TokenStream> {
let id_struct = self.id_struct();
let id_struct = self.the_struct_id();
let inherent_impl = self.input_inherent_impl();
let ingredients_for_impl = self.input_ingredients();
let as_id_impl = self.as_id_impl();
@ -68,7 +68,7 @@ impl InputStruct {
/// Generate an inherent impl with methods on the entity type.
fn input_inherent_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let ident = self.the_ident();
let jar_ty = self.jar_ty();
let db_dyn_ty = self.db_dyn_ty();
let input_index = self.input_index();
@ -212,12 +212,12 @@ impl InputStruct {
/// function ingredient for each of the value fields.
fn input_ingredients(&self) -> syn::ItemImpl {
use crate::literal;
let ident = self.id_ident();
let ident = self.the_ident();
let field_ty = self.all_field_tys();
let jar_ty = self.jar_ty();
let all_field_indices: Vec<Literal> = self.all_field_indices();
let input_index: Literal = self.input_index();
let debug_name_struct = literal(self.id_ident());
let debug_name_struct = literal(self.the_ident());
let debug_name_fields: Vec<_> = self.all_field_names().into_iter().map(literal).collect();
parse_quote! {
@ -304,7 +304,7 @@ impl InputStruct {
/// Implementation of `SalsaStructInDb`.
fn salsa_struct_in_db_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let ident = self.the_ident();
let jar_ty = self.jar_ty();
parse_quote! {
impl<DB> salsa::salsa_struct::SalsaStructInDb<DB> for #ident

View file

@ -54,7 +54,7 @@ impl crate::options::AllowedOptions for InternedStruct {
impl InternedStruct {
fn generate_interned(&self) -> syn::Result<TokenStream> {
self.validate_interned()?;
let id_struct = self.id_struct();
let id_struct = self.the_struct_id();
let data_struct = self.data_struct();
let ingredients_for_impl = self.ingredients_for_impl();
let as_id_impl = self.as_id_impl();
@ -82,7 +82,7 @@ impl InternedStruct {
/// as well as a `new` method.
fn inherent_impl_for_named_fields(&self) -> syn::ItemImpl {
let vis = self.visibility();
let id_ident = self.id_ident();
let id_ident = self.the_ident();
let db_dyn_ty = self.db_dyn_ty();
let jar_ty = self.jar_ty();
@ -144,7 +144,7 @@ impl InternedStruct {
///
/// For a memoized type, the only ingredient is an `InternedIngredient`.
fn ingredients_for_impl(&self) -> syn::ItemImpl {
let id_ident = self.id_ident();
let id_ident = self.the_ident();
let debug_name = crate::literal(id_ident);
let jar_ty = self.jar_ty();
let data_ident = self.data_ident();
@ -177,7 +177,7 @@ impl InternedStruct {
/// Implementation of `SalsaStructInDb`.
fn salsa_struct_in_db_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let ident = self.the_ident();
let jar_ty = self.jar_ty();
parse_quote! {
impl<DB> salsa::salsa_struct::SalsaStructInDb<DB> for #ident

View file

@ -46,6 +46,17 @@ pub enum Customization {
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];
/// Classifies the kind of field stored in this salsa
/// struct.
#[derive(Debug, PartialEq, Eq)]
pub enum TheStructKind {
/// Stores an "id"
Id,
/// Stores a "pointer"
Pointer,
}
impl<A: AllowedOptions> SalsaStruct<A> {
pub(crate) fn new(
args: proc_macro::TokenStream,
@ -70,6 +81,14 @@ impl<A: AllowedOptions> SalsaStruct<A> {
})
}
pub(crate) fn the_struct_kind(&self) -> TheStructKind {
if self.struct_item.generics.params.is_empty() {
TheStructKind::Id
} else {
TheStructKind::Pointer
}
}
fn extract_customizations(struct_item: &syn::ItemStruct) -> syn::Result<Vec<Customization>> {
Ok(struct_item
.attrs
@ -142,8 +161,8 @@ impl<A: AllowedOptions> SalsaStruct<A> {
self.all_fields().map(|ef| ef.ty()).collect()
}
/// The name of the "identity" struct (this is the name the user gave, e.g., `Foo`).
pub(crate) fn id_ident(&self) -> &syn::Ident {
/// The name of "the struct" (this is the name the user gave, e.g., `Foo`).
pub(crate) fn the_ident(&self) -> &syn::Ident {
&self.struct_item.ident
}
@ -151,7 +170,7 @@ impl<A: AllowedOptions> SalsaStruct<A> {
///
/// * its list of generic parameters
/// * the generics "split for impl".
pub(crate) fn id_ident_and_generics(
pub(crate) fn the_ident_and_generics(
&self,
) -> (
&syn::Ident,
@ -195,17 +214,31 @@ impl<A: AllowedOptions> SalsaStruct<A> {
match &self.args.data {
Some(d) => d.clone(),
None => syn::Ident::new(
&format!("__{}Data", self.id_ident()),
self.id_ident().span(),
&format!("__{}Data", self.the_ident()),
self.the_ident().span(),
),
}
}
/// Create a struct that wraps the id.
/// The type used for `id` values -- this is sometimes
/// the struct type or sometimes `salsa::Id`.
pub(crate) fn id_ty(&self) -> syn::Type {
match self.the_struct_kind() {
TheStructKind::Pointer => parse_quote!(salsa::Id),
TheStructKind::Id => {
let ident = &self.struct_item.ident;
parse_quote!(#ident)
}
}
}
/// Create "the struct" whose field is an id.
/// This is the struct the user will refernece, but only if there
/// are no lifetimes.
pub(crate) fn id_struct(&self) -> syn::ItemStruct {
let ident = self.id_ident();
pub(crate) fn the_struct_id(&self) -> syn::ItemStruct {
assert_eq!(self.the_struct_kind(), TheStructKind::Id);
let ident = self.the_ident();
let visibility = &self.struct_item.vis;
// Extract the attributes the user gave, but screen out derive, since we are adding our own,
@ -227,14 +260,11 @@ impl<A: AllowedOptions> SalsaStruct<A> {
/// Create the struct that the user will reference.
/// If
pub(crate) fn id_or_ptr_struct(
&self,
config_ident: &syn::Ident,
) -> syn::Result<syn::ItemStruct> {
pub(crate) fn the_struct(&self, config_ident: &syn::Ident) -> syn::Result<syn::ItemStruct> {
if self.struct_item.generics.params.is_empty() {
Ok(self.id_struct())
Ok(self.the_struct_id())
} else {
let ident = self.id_ident();
let ident = self.the_ident();
let visibility = &self.struct_item.vis;
let generics = &self.struct_item.generics;
@ -299,38 +329,43 @@ impl<A: AllowedOptions> SalsaStruct<A> {
pub(crate) fn constructor_name(&self) -> syn::Ident {
match self.args.constructor_name.clone() {
Some(name) => name,
None => Ident::new("new", self.id_ident().span()),
None => Ident::new("new", self.the_ident().span()),
}
}
/// Generate `impl salsa::AsId for Foo`
pub(crate) fn as_id_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let (impl_generics, type_generics, where_clause) =
self.struct_item.generics.split_for_impl();
parse_quote_spanned! { ident.span() =>
impl #impl_generics salsa::AsId for #ident #type_generics
#where_clause
{
fn as_id(self) -> salsa::Id {
self.0
}
pub(crate) fn as_id_impl(&self) -> Option<syn::ItemImpl> {
match self.the_struct_kind() {
TheStructKind::Id => {
let ident = self.the_ident();
let (impl_generics, type_generics, where_clause) =
self.struct_item.generics.split_for_impl();
Some(parse_quote_spanned! { ident.span() =>
impl #impl_generics salsa::AsId for #ident #type_generics
#where_clause
{
fn as_id(self) -> salsa::Id {
self.0
}
fn from_id(id: salsa::Id) -> Self {
#ident(id)
}
fn from_id(id: salsa::Id) -> Self {
#ident(id)
}
}
})
}
TheStructKind::Pointer => None,
}
}
/// Generate `impl salsa::DebugWithDb for Foo`
/// Generate `impl salsa::DebugWithDb for Foo`, but only if this is an id struct.
pub(crate) fn as_debug_with_db_impl(&self) -> Option<syn::ItemImpl> {
if self.customizations.contains(&Customization::DebugWithDb) {
return None;
}
let ident = self.id_ident();
let ident = self.the_ident();
let (impl_generics, type_generics, where_clause) =
self.struct_item.generics.split_for_impl();
@ -367,7 +402,7 @@ impl<A: AllowedOptions> SalsaStruct<A> {
#[allow(unused_imports)]
use ::salsa::debug::helper::Fallback;
let mut debug_struct = &mut f.debug_struct(#ident_string);
debug_struct = debug_struct.field("[salsa id]", &self.0.as_u32());
// debug_struct = debug_struct.field("[salsa id]", &self.0.as_u32());
#fields
debug_struct.finish()
}

View file

@ -1,6 +1,6 @@
use proc_macro2::{Literal, Span, TokenStream};
use crate::salsa_struct::{SalsaField, SalsaStruct};
use crate::salsa_struct::{SalsaField, SalsaStruct, TheStructKind};
/// For an tracked struct `Foo` with fields `f1: T1, ..., fN: TN`, we generate...
///
@ -11,7 +11,14 @@ pub(crate) fn tracked(
args: proc_macro::TokenStream,
struct_item: syn::ItemStruct,
) -> syn::Result<TokenStream> {
SalsaStruct::with_struct(args, struct_item).and_then(|el| TrackedStruct(el).generate_tracked())
let tokens = SalsaStruct::with_struct(args, struct_item)
.and_then(|el| TrackedStruct(el).generate_tracked())?;
if std::env::var("NDM").is_ok() {
eprintln!("{}", tokens);
}
Ok(tokens)
}
struct TrackedStruct(SalsaStruct<Self>);
@ -51,7 +58,7 @@ impl TrackedStruct {
self.validate_tracked()?;
let config_struct = self.config_struct();
let the_struct = self.id_or_ptr_struct(&config_struct.ident)?;
let the_struct = self.the_struct(&config_struct.ident)?;
let config_impl = self.config_impl(&config_struct);
let inherent_impl = self.tracked_inherent_impl();
let ingredients_for_impl = self.tracked_struct_ingredients(&config_struct);
@ -80,8 +87,8 @@ impl TrackedStruct {
fn config_struct(&self) -> syn::ItemStruct {
let config_ident = syn::Ident::new(
&format!("__{}Config", self.id_ident()),
self.id_ident().span(),
&format!("__{}Config", self.the_ident()),
self.the_ident().span(),
);
let visibility = self.visibility();
@ -93,7 +100,7 @@ impl TrackedStruct {
}
fn config_impl(&self, config_struct: &syn::ItemStruct) -> syn::ItemImpl {
let id_ident = self.id_ident();
let id_ty = self.id_ty();
let config_ident = &config_struct.ident;
let field_tys: Vec<_> = self.all_fields().map(SalsaField::ty).collect();
let id_field_indices = self.id_field_indices();
@ -136,7 +143,7 @@ impl TrackedStruct {
parse_quote! {
impl salsa::tracked_struct::Configuration for #config_ident {
type Id = #id_ident;
type Id = #id_ty;
type Fields = ( #(#field_tys,)* );
type Revisions = [salsa::Revision; #arity];
@ -168,7 +175,10 @@ impl TrackedStruct {
/// Generate an inherent impl with methods on the tracked type.
fn tracked_inherent_impl(&self) -> syn::ItemImpl {
let (ident, _, impl_generics, type_generics, where_clause) = self.id_ident_and_generics();
let (ident, parameters, impl_generics, type_generics, where_clause) = self.the_ident_and_generics();
let id_ty = self.id_ty();
let lt_db = parameters.iter().next();
let jar_ty = self.jar_ty();
let db_dyn_ty = self.db_dyn_ty();
@ -180,25 +190,53 @@ impl TrackedStruct {
let field_get_names: Vec<_> = self.all_fields().map(SalsaField::get_name).collect();
let field_clones: Vec<_> = self.all_fields().map(SalsaField::is_clone_field).collect();
let field_getters: Vec<syn::ImplItemMethod> = field_indices.iter().zip(&field_get_names).zip(&field_tys).zip(&field_vises).zip(&field_clones).map(|((((field_index, field_get_name), field_ty), field_vis), is_clone_field)|
if !*is_clone_field {
parse_quote_spanned! { field_get_name.span() =>
#field_vis fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar);
&__ingredients.#tracked_field_ingredients[#field_index].field(__runtime, self).#field_index
match self.the_struct_kind() {
TheStructKind::Id => {
if !*is_clone_field {
parse_quote_spanned! { field_get_name.span() =>
#field_vis fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> &'db #field_ty
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident #type_generics >>::ingredient(__jar);
&__ingredients.#tracked_field_ingredients[#field_index].field(__runtime, self).#field_index
}
}
} else {
parse_quote_spanned! { field_get_name.span() =>
#field_vis fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident #type_generics >>::ingredient(__jar);
__ingredients.#tracked_field_ingredients[#field_index].field(__runtime, self).#field_index.clone()
}
}
}
}
} else {
parse_quote_spanned! { field_get_name.span() =>
#field_vis fn #field_get_name<'db>(self, __db: &'db #db_dyn_ty) -> #field_ty
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar);
__ingredients.#tracked_field_ingredients[#field_index].field(__runtime, self).#field_index.clone()
TheStructKind::Pointer => {
let lt_db = lt_db.unwrap();
if !*is_clone_field {
parse_quote_spanned! { field_get_name.span() =>
#field_vis fn #field_get_name(self, __db: & #lt_db #db_dyn_ty) -> & #lt_db #field_ty
{
let (_, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let fields = unsafe { &*self.0 }.field(__runtime, #field_index);
&fields.#field_index
}
}
} else {
parse_quote_spanned! { field_get_name.span() =>
#field_vis fn #field_get_name(self, __db: & #lt_db #db_dyn_ty) -> #field_ty
{
let (_, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let fields = unsafe { &*self.0 }.field(__runtime, #field_index);
fields.#field_index.clone()
}
}
}
}
}
}
)
.collect();
@ -206,6 +244,18 @@ impl TrackedStruct {
let field_tys = self.all_field_tys();
let constructor_name = self.constructor_name();
let data = syn::Ident::new("__data", Span::call_site());
let salsa_id = match self.the_struct_kind() {
TheStructKind::Id => quote!(self.0),
TheStructKind::Pointer => quote!(unsafe { &*self.0 }.id()),
};
let ctor = match self.the_struct_kind() {
TheStructKind::Id => quote!(salsa::AsId::from_id(#data.id())),
TheStructKind::Pointer => quote!(Self(#data, std::marker::PhantomData)),
};
parse_quote! {
#[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)]
impl #impl_generics #ident #type_generics
@ -213,12 +263,16 @@ impl TrackedStruct {
pub fn #constructor_name(__db: &#db_dyn_ty, #(#field_names: #field_tys,)*) -> Self
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar);
let __data = __ingredients.0.new_struct(
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< Self >>::ingredient(__jar);
let #data = __ingredients.0.new_struct(
__runtime,
(#(#field_names,)*),
);
__data.id()
#ctor
}
pub fn salsa_id(&self) -> #id_ty {
#salsa_id
}
#(#field_getters)*
@ -232,14 +286,14 @@ impl TrackedStruct {
/// function ingredient for each of the value fields.
fn tracked_struct_ingredients(&self, config_struct: &syn::ItemStruct) -> syn::ItemImpl {
use crate::literal;
let (ident, _, impl_generics, type_generics, where_clause) = self.id_ident_and_generics();
let (ident, _, impl_generics, type_generics, where_clause) = self.the_ident_and_generics();
let jar_ty = self.jar_ty();
let config_struct_name = &config_struct.ident;
let field_indices: Vec<Literal> = self.all_field_indices();
let arity = self.all_field_count();
let tracked_struct_ingredient: Literal = self.tracked_struct_ingredient_index();
let tracked_fields_ingredients: Literal = self.tracked_field_ingredients_index();
let debug_name_struct = literal(self.id_ident());
let debug_name_struct = literal(self.the_ident());
let debug_name_fields: Vec<_> = self.all_field_names().into_iter().map(literal).collect();
parse_quote! {
@ -301,7 +355,7 @@ impl TrackedStruct {
/// Implementation of `SalsaStructInDb`.
fn salsa_struct_in_db_impl(&self) -> syn::ItemImpl {
let (ident, parameters, _, type_generics, where_clause) = self.id_ident_and_generics();
let (ident, parameters, _, type_generics, where_clause) = self.the_ident_and_generics();
let db = syn::Ident::new("DB", ident.span());
let jar_ty = self.jar_ty();
let tracked_struct_ingredient = self.tracked_struct_ingredient_index();
@ -313,7 +367,7 @@ impl TrackedStruct {
{
fn register_dependent_fn(db: & #db, index: salsa::routes::IngredientIndex) {
let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident>>::ingredient(jar);
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident #type_generics>>::ingredient(jar);
ingredients.#tracked_struct_ingredient.register_dependent_fn(index)
}
}
@ -322,7 +376,7 @@ impl TrackedStruct {
/// Implementation of `TrackedStructInDb`.
fn tracked_struct_in_db_impl(&self) -> syn::ItemImpl {
let (ident, parameters, _, type_generics, where_clause) = self.id_ident_and_generics();
let (ident, parameters, _, type_generics, where_clause) = self.the_ident_and_generics();
let db = syn::Ident::new("DB", ident.span());
let jar_ty = self.jar_ty();
let tracked_struct_ingredient = self.tracked_struct_ingredient_index();
@ -334,8 +388,8 @@ impl TrackedStruct {
{
fn database_key_index(self, db: &#db) -> salsa::DatabaseKeyIndex {
let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident>>::ingredient(jar);
ingredients.#tracked_struct_ingredient.database_key_index(self)
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident #type_generics>>::ingredient(jar);
ingredients.#tracked_struct_ingredient.database_key_index(self.salsa_id())
}
}
}
@ -343,9 +397,11 @@ impl TrackedStruct {
/// Implementation of `Update`.
fn update_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let (ident, _, impl_generics, type_generics, where_clause) = self.the_ident_and_generics();
parse_quote! {
unsafe impl salsa::update::Update for #ident {
unsafe impl #impl_generics salsa::update::Update for #ident #type_generics
#where_clause
{
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
if unsafe { *old_pointer } != new_value {
unsafe { *old_pointer = new_value };

View file

@ -89,7 +89,7 @@ pub trait Configuration {
type Key: AsId;
/// The value computed by the function.
type Value: fmt::Debug;
type Value<'db>: fmt::Debug;
/// Determines whether this function can recover from being a participant in a cycle
/// (and, if so, how).
@ -101,13 +101,13 @@ pub trait Configuration {
/// even though it was recomputed).
///
/// This invokes user's code in form of the `Eq` impl.
fn should_backdate_value(old_value: &Self::Value, new_value: &Self::Value) -> bool;
fn should_backdate_value(old_value: &Self::Value<'_>, new_value: &Self::Value<'_>) -> bool;
/// Invoked when we need to compute the value for the given key, either because we've never
/// computed it before or because the old one relied on inputs that have changed.
///
/// This invokes the function the user wrote.
fn execute(db: &DynDb<Self>, key: Self::Key) -> Self::Value;
fn execute<'db>(db: &'db DynDb<Self>, key: Self::Key) -> Self::Value<'db>;
/// If the cycle strategy is `Recover`, then invoked when `key` is a participant
/// in a cycle to find out what value it should have.

View file

@ -122,7 +122,6 @@ where
C: Configuration,
{
/// Index of the struct ingredient.
#[allow(dead_code)]
struct_ingredient_index: IngredientIndex,
/// The id of this struct in the ingredient.
@ -385,4 +384,24 @@ where
pub fn id(&self) -> C::Id {
self.id
}
/// Access to this value field.
/// Note that this function returns the entire tuple of value fields.
/// The caller is responible for selecting the appropriate element.
pub fn field<'db>(&'db self, runtime: &'db Runtime, field_index: u32) -> &'db C::Fields {
let field_ingredient_index =
IngredientIndex::from(self.struct_ingredient_index.as_usize() + field_index as usize);
let changed_at = C::revision(&self.revisions, field_index);
runtime.report_tracked_read(
DependencyIndex {
ingredient_index: field_ingredient_index,
key_index: Some(self.id.as_id()),
},
self.durability,
changed_at,
);
&self.fields
}
}

View file

@ -7,7 +7,12 @@ use expect_test::expect;
use test_log::test;
#[salsa::jar(db = Db)]
struct Jar(MyInput, MyTracked, final_result, intermediate_result);
struct Jar(
MyInput,
MyTracked<'static>,
final_result,
intermediate_result,
);
trait Db: salsa::DbWithJar<Jar> + HasLogger {}
@ -23,12 +28,12 @@ fn final_result(db: &dyn Db, input: MyInput) -> u32 {
}
#[salsa::tracked(jar = Jar)]
struct MyTracked {
struct MyTracked<'db> {
field: u32,
}
#[salsa::tracked(jar = Jar)]
fn intermediate_result(db: &dyn Db, input: MyInput) -> MyTracked {
fn intermediate_result<'db>(db: &'db dyn Db, input: MyInput) -> MyTracked<'db> {
db.push_log(format!("intermediate_result({:?})", input));
MyTracked::new(db, input.field(db) / 2)
}