diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 1a25b378..2c11d9ca 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -227,9 +227,8 @@ macro_rules! setup_tracked_fn { $zalsa::AsId::as_id(&($($input_id),*)) } }; - $Configuration::fn_ingredient($db).fetch($db, key); - let database_key_index = $Configuration::fn_ingredient($db).database_key_index(key); - $zalsa::accumulated_by($db.as_salsa_database(), database_key_index) + + $Configuration::fn_ingredient($db).accumulated_by::($db, key) } $zalsa::macro_if! { $is_specifiable => diff --git a/src/accumulator.rs b/src/accumulator.rs index cc96d29b..dad727ca 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -210,16 +210,3 @@ where .finish() } } - -pub fn accumulated_by(db: &dyn Database, database_key_index: DatabaseKeyIndex) -> Vec -where - A: Accumulator, -{ - let Some(accumulator) = >::from_db(db) else { - return vec![]; - }; - let runtime = db.runtime(); - let mut output = vec![]; - accumulator.produced_by(runtime, database_key_index, &mut output); - output -} diff --git a/src/function.rs b/src/function.rs index 51c42587..7ee0cd04 100644 --- a/src/function.rs +++ b/src/function.rs @@ -16,6 +16,7 @@ use self::delete::DeletedEntries; use super::ingredient::Ingredient; +mod accumulated; mod backdate; mod delete; mod diff_outputs; diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs new file mode 100644 index 00000000..8bd86140 --- /dev/null +++ b/src/function/accumulated.rs @@ -0,0 +1,37 @@ +use crate::{accumulator, storage::DatabaseGen, Database, Id}; + +use super::{Configuration, IngredientImpl}; + +impl IngredientImpl +where + C: Configuration, +{ + /// Helper used by `accumulate` functions. Computes the results accumulated by `database_key_index` + /// and its inputs. + pub fn accumulated_by(&self, db: &C::DbView, key: Id) -> Vec + where + A: accumulator::Accumulator, + { + let Some(accumulator) = >::from_db(db) else { + return vec![]; + }; + let runtime = db.runtime(); + let mut output = vec![]; + + // First ensure the result is up to date + self.fetch(db, key); + + let database_key_index = self.database_key_index(key); + accumulator.produced_by(runtime, database_key_index, &mut output); + + if let Some(origin) = self.origin(key) { + for input in origin.inputs() { + if let Ok(input) = input.try_into() { + accumulator.produced_by(runtime, input, &mut output); + } + } + } + + output + } +} diff --git a/src/lib.rs b/src/lib.rs index 299714e9..a68dd73c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,7 +73,6 @@ pub mod prelude { /// /// The contents of this module are NOT subject to semver. pub mod plumbing { - pub use crate::accumulator::accumulated_by; pub use crate::accumulator::Accumulator; pub use crate::array::Array; pub use crate::cycle::Cycle; diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs index 2fc21cb3..2213bf1f 100644 --- a/src/runtime/local_state.rs +++ b/src/runtime/local_state.rs @@ -76,7 +76,16 @@ pub enum QueryOrigin { } impl QueryOrigin { - /// Indices for queries *written* by this query (or `vec![]` if its value was assigned). + /// Indices for queries *read* by this query + pub(crate) fn inputs(&self) -> impl Iterator + '_ { + let opt_edges = match self { + QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), + QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, + }; + opt_edges.into_iter().flat_map(|edges| edges.inputs()) + } + + /// Indices for queries *written* by this query (if any) pub(crate) fn outputs(&self) -> impl Iterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), diff --git a/tests/accumulate-dag.rs b/tests/accumulate-dag.rs new file mode 100644 index 00000000..5aa5b0df --- /dev/null +++ b/tests/accumulate-dag.rs @@ -0,0 +1,68 @@ +mod common; +use common::{HasLogger, Logger}; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +#[salsa::input] +struct MyInput { + field_a: u32, + field_b: u32, +} + +#[salsa::accumulator] +#[derive(Clone, Debug)] +struct Log(#[allow(dead_code)] String); + +#[salsa::tracked] +fn push_logs(db: &dyn Database, input: MyInput) { + push_a_logs(db, input); + push_b_logs(db, input); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database, input: MyInput) { + let count = input.field_a(db); + for i in 0..count { + Log(format!("log_a({} of {})", i, count)).accumulate(db); + } +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database, input: MyInput) { + // Note that b calls a + push_a_logs(db, input); + let count = input.field_b(db); + for i in 0..count { + Log(format!("log_b({} of {})", i, count)).accumulate(db); + } +} + +#[test] +fn accumulate_a_called_twice() { + salsa::default_database().attach(|db| { + let input = MyInput::new(db, 2, 3); + let logs = push_logs::accumulated::(db, input); + // Check that we don't see logs from `a` appearing twice in the input. + expect![[r#" + [ + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), + Log( + "log_b(0 of 3)", + ), + Log( + "log_b(1 of 3)", + ), + Log( + "log_b(2 of 3)", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +} diff --git a/tests/accumulate.rs b/tests/accumulate.rs index 2a5ddeb4..6fb4e05f 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -20,7 +20,7 @@ struct MyInput { #[salsa::accumulator] #[derive(Clone, Debug)] -struct Logs(#[allow(dead_code)] String); +struct Log(#[allow(dead_code)] String); #[salsa::tracked] fn push_logs(db: &dyn Db, input: MyInput) { @@ -30,13 +30,13 @@ fn push_logs(db: &dyn Db, input: MyInput) { input.field_b(db) )); - // We don't invoke `push_a_logs` (or `push_b_logs`) with a value of 1 or less. + // We don't invoke `push_a_logs` (or `push_b_logs`) with a value of 0. // This allows us to test what happens a change in inputs causes a function not to be called at all. - if input.field_a(db) > 1 { + if input.field_a(db) > 0 { push_a_logs(db, input); } - if input.field_b(db) > 1 { + if input.field_b(db) > 0 { push_b_logs(db, input); } } @@ -47,7 +47,7 @@ fn push_a_logs(db: &dyn Db, input: MyInput) { db.push_log(format!("push_a_logs({})", field_a)); for i in 0..field_a { - Logs(format!("log_a({} of {})", i, field_a)).accumulate(db); + Log(format!("log_a({} of {})", i, field_a)).accumulate(db); } } @@ -57,7 +57,7 @@ fn push_b_logs(db: &dyn Db, input: MyInput) { db.push_log(format!("push_b_logs({})", field_a)); for i in 0..field_a { - Logs(format!("log_b({} of {})", i, field_a)).accumulate(db); + Log(format!("log_b({} of {})", i, field_a)).accumulate(db); } } @@ -88,38 +88,115 @@ fn accumulate_once() { // Just call accumulate on a base input to see what happens. let input = MyInput::new(&db, 2, 3); - let logs = push_logs::accumulated::(&db, input); + let logs = push_logs::accumulated::(&db, input); db.assert_logs(expect![[r#" [ "push_logs(a = 2, b = 3)", "push_a_logs(2)", "push_b_logs(3)", ]"#]]); + // Check that we see logs from `a` first and then logs from `b` + // (execution order). expect![[r#" [ - "log_b(0 of 3)", - "log_b(1 of 3)", - "log_b(2 of 3)", - "log_a(0 of 2)", - "log_a(1 of 2)", + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), + Log( + "log_b(0 of 3)", + ), + Log( + "log_b(1 of 3)", + ), + Log( + "log_b(2 of 3)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); } #[test] -fn change_a_and_reaccumulate() { +fn change_a_from_2_to_0() { let mut db = Database::default(); // Accumulate logs for `a = 2` and `b = 3` let input = MyInput::new(&db, 2, 3); - let logs = push_logs::accumulated::(&db, input); + let logs = push_logs::accumulated::(&db, input); expect![[r#" [ - "log_b(0 of 3)", - "log_b(1 of 3)", - "log_b(2 of 3)", - "log_a(0 of 2)", - "log_a(1 of 2)", + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), + Log( + "log_b(0 of 3)", + ), + Log( + "log_b(1 of 3)", + ), + Log( + "log_b(2 of 3)", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + db.assert_logs(expect![[r#" + [ + "push_logs(a = 2, b = 3)", + "push_a_logs(2)", + "push_b_logs(3)", + ]"#]]); + + // Change to `a = 0`, which means `push_logs` does not call `push_a_logs` at all + input.set_field_a(&mut db).to(0); + let logs = push_logs::accumulated::(&db, input); + expect![[r#" + [ + Log( + "log_b(0 of 3)", + ), + Log( + "log_b(1 of 3)", + ), + Log( + "log_b(2 of 3)", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + db.assert_logs(expect![[r#" + [ + "push_logs(a = 1, b = 3)", + ]"#]]); +} + +#[test] +fn change_a_from_2_to_1() { + let mut db = Database::default(); + + // Accumulate logs for `a = 2` and `b = 3` + let input = MyInput::new(&db, 2, 3); + let logs = push_logs::accumulated::(&db, input); + expect![[r#" + [ + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), + Log( + "log_b(0 of 3)", + ), + Log( + "log_b(1 of 3)", + ), + Log( + "log_b(2 of 3)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" @@ -131,17 +208,27 @@ fn change_a_and_reaccumulate() { // Change to `a = 1`, which means `push_logs` does not call `push_a_logs` at all input.set_field_a(&mut db).to(1); - let logs = push_logs::accumulated::(&db, input); + let logs = push_logs::accumulated::(&db, input); expect![[r#" [ - "log_b(0 of 3)", - "log_b(1 of 3)", - "log_b(2 of 3)", + Log( + "log_a(0 of 1)", + ), + Log( + "log_b(0 of 3)", + ), + Log( + "log_b(1 of 3)", + ), + Log( + "log_b(2 of 3)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" [ "push_logs(a = 1, b = 3)", + "push_a_logs(1)", ]"#]]); } @@ -151,11 +238,15 @@ fn get_a_logs_after_changing_b() { // Invoke `push_a_logs` with `a = 2` and `b = 3` (but `b` doesn't matter) let input = MyInput::new(&db, 2, 3); - let logs = push_a_logs::accumulated::(&db, input); + let logs = push_a_logs::accumulated::(&db, input); expect![[r#" [ - "log_a(0 of 2)", - "log_a(1 of 2)", + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" @@ -166,11 +257,15 @@ fn get_a_logs_after_changing_b() { // Changing `b` does not cause `push_a_logs` to re-execute // and we still get the same result input.set_field_b(&mut db).to(5); - let logs = push_a_logs::accumulated::(&db, input); + let logs = push_a_logs::accumulated::(&db, input); expect![[r#" [ - "log_a(0 of 2)", - "log_a(1 of 2)", + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), ] "#]] .assert_debug_eq(&logs);