diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index f5a0f984..c9b4a430 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -48,6 +48,7 @@ macro_rules! setup_input_struct { $zalsa:ident, $zalsa_struct:ident, $Configuration:ident, + $Builder:ident, $CACHE:ident, $Db:ident, ] @@ -123,14 +124,33 @@ macro_rules! setup_input_struct { } impl $Struct { + #[inline] pub fn $new_fn<$Db>(db: &$Db, $($field_id: $field_ty),*) -> Self where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database, { - let current_revision = $zalsa::current_revision(db); - let stamps = $zalsa::Array::new([$zalsa::stamp(current_revision, Default::default()); $N]); - $Configuration::ingredient(db.as_dyn_database()).new_input(($($field_id,)*), stamps) + Self::builder($($field_id,)*).new(db) + } + + pub fn builder($($field_id: $field_ty),*) -> ::Builder + { + // Implement `new` here instead of inside the builder module + // because $Configuration can't be named in `builder`. + impl builder::$Builder { + pub fn new<$Db>(self, db: &$Db) -> $Struct + where + // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` + $Db: ?Sized + salsa::Database + { + let current_revision = $zalsa::current_revision(db); + let ingredient = $Configuration::ingredient(db.as_dyn_database()); + let (fields, stamps) = builder::builder_into_inner(self, current_revision); + ingredient.new_input(fields, stamps) + } + } + + builder::new_builder($($field_id,)*) } $( @@ -206,6 +226,41 @@ macro_rules! setup_input_struct { }) } } + + impl $zalsa_struct::HasBuilder for $Struct { + type Builder = builder::$Builder; + } + + mod builder { + use super::*; + + use salsa::plumbing as $zalsa; + use $zalsa::input as $zalsa_struct; + + // These are standalone functions instead of methods on `Builder` to prevent + // that the enclosing module can call them. + pub(super) fn new_builder($($field_id: $field_ty),*) -> $Builder { + $Builder { fields: ($($field_id,)*), durability: salsa::Durability::default() } + } + + pub(super) fn builder_into_inner(builder: $Builder, revision: $zalsa::Revision) -> (($($field_ty,)*), $zalsa::Array<$zalsa::Stamp, $N>) { + let stamps = $zalsa::Array::new([$zalsa::stamp(revision, builder.durability); $N]); + (builder.fields, stamps) + } + + pub struct $Builder { + fields: ($($field_ty,)*), + durability: salsa::Durability, + } + + impl $Builder { + /// Sets the durability of all fields + pub fn durability(mut self, durability: salsa::Durability) -> Self { + self.durability = durability; + self + } + } + } }; }; } diff --git a/components/salsa-macro-rules/src/setup_struct_fn.rs b/components/salsa-macro-rules/src/setup_struct_fn.rs deleted file mode 100644 index 8b137891..00000000 --- a/components/salsa-macro-rules/src/setup_struct_fn.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index 2ff9f18b..4040bf61 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -93,6 +93,7 @@ impl Macro { let zalsa = self.hygiene.ident("zalsa"); let zalsa_struct = self.hygiene.ident("zalsa_struct"); let Configuration = self.hygiene.ident("Configuration"); + let Builder = self.hygiene.ident("Builder"); let CACHE = self.hygiene.ident("CACHE"); let Db = self.hygiene.ident("Db"); @@ -117,6 +118,7 @@ impl Macro { #zalsa, #zalsa_struct, #Configuration, + #Builder, #CACHE, #Db, ] diff --git a/src/input.rs b/src/input.rs index 1da42d75..c086dcb6 100644 --- a/src/input.rs +++ b/src/input.rs @@ -276,3 +276,7 @@ where /// The revision and durability information for each field: when did this field last change. stamps: C::Stamps, } + +pub trait HasBuilder { + type Builder; +} diff --git a/src/lib.rs b/src/lib.rs index f701fc1e..422a632f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,6 +122,7 @@ pub mod plumbing { pub use crate::input::input_field::FieldIngredientImpl; pub use crate::input::setter::SetterImpl; pub use crate::input::Configuration; + pub use crate::input::HasBuilder; pub use crate::input::IngredientImpl; pub use crate::input::JarImpl; } diff --git a/tests/tracked_fn_on_input_with_high_durability.rs b/tests/tracked_fn_on_input_with_high_durability.rs new file mode 100644 index 00000000..4cff92bd --- /dev/null +++ b/tests/tracked_fn_on_input_with_high_durability.rs @@ -0,0 +1,75 @@ +#![allow(warnings)] + +use expect_test::expect; + +use common::{HasLogger, Logger}; +use salsa::plumbing::HasStorage; +use salsa::{Database, Durability, Event, EventKind, Setter}; + +mod common; +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::tracked] +fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> u32 { + input.field(db) * 2 +} + +#[test] +fn execute() { + #[salsa::db] + #[derive(Default)] + struct Database { + storage: salsa::Storage, + logger: Logger, + } + + #[salsa::db] + impl salsa::Database for Database { + fn salsa_event(&self, event: Event) { + match event.kind { + EventKind::WillCheckCancellation => {} + _ => { + self.push_log(format!("salsa_event({:?})", event.kind)); + } + } + } + } + + impl HasLogger for Database { + fn logger(&self) -> &Logger { + &self.logger + } + } + + let mut db = Database::default(); + let input_low = MyInput::new(&db, 22); + let input_high = MyInput::builder(2200).durability(Durability::HIGH).new(&db); + + assert_eq!(tracked_fn(&db, input_low), 44); + assert_eq!(tracked_fn(&db, input_high), 4400); + + db.assert_logs(expect![[r#" + [ + "salsa_event(WillExecute { database_key: tracked_fn(0) })", + "salsa_event(WillExecute { database_key: tracked_fn(1) })", + ]"#]]); + + db.synthetic_write(Durability::LOW); + + assert_eq!(tracked_fn(&db, input_low), 44); + assert_eq!(tracked_fn(&db, input_high), 4400); + + // There's currently no good way to verify whether an input was validated using shallow or deep comparison. + // All we can do for now is verify that the values were validated. + // Note: It maybe confusing why it validates `input_high` when the write has `Durability::LOW`. + // This is because all values must be validated whenever a write occurs. It doesn't mean that it + // executed the query. + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: tracked_fn(0) })", + "salsa_event(DidValidateMemoizedValue { database_key: tracked_fn(1) })", + ]"#]]); +}