diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs
index 1a25b378..2c11d9ca 100644
--- a/components/salsa-macro-rules/src/setup_tracked_fn.rs
+++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs
@@ -227,9 +227,8 @@ macro_rules! setup_tracked_fn {
$zalsa::AsId::as_id(&($($input_id),*))
}
};
- $Configuration::fn_ingredient($db).fetch($db, key);
- let database_key_index = $Configuration::fn_ingredient($db).database_key_index(key);
- $zalsa::accumulated_by($db.as_salsa_database(), database_key_index)
+
+ $Configuration::fn_ingredient($db).accumulated_by::($db, key)
}
$zalsa::macro_if! { $is_specifiable =>
diff --git a/src/accumulator.rs b/src/accumulator.rs
index cc96d29b..dad727ca 100644
--- a/src/accumulator.rs
+++ b/src/accumulator.rs
@@ -210,16 +210,3 @@ where
.finish()
}
}
-
-pub fn accumulated_by(db: &dyn Database, database_key_index: DatabaseKeyIndex) -> Vec
-where
- A: Accumulator,
-{
- let Some(accumulator) = >::from_db(db) else {
- return vec![];
- };
- let runtime = db.runtime();
- let mut output = vec![];
- accumulator.produced_by(runtime, database_key_index, &mut output);
- output
-}
diff --git a/src/function.rs b/src/function.rs
index 51c42587..7ee0cd04 100644
--- a/src/function.rs
+++ b/src/function.rs
@@ -16,6 +16,7 @@ use self::delete::DeletedEntries;
use super::ingredient::Ingredient;
+mod accumulated;
mod backdate;
mod delete;
mod diff_outputs;
diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs
new file mode 100644
index 00000000..8bd86140
--- /dev/null
+++ b/src/function/accumulated.rs
@@ -0,0 +1,37 @@
+use crate::{accumulator, storage::DatabaseGen, Database, Id};
+
+use super::{Configuration, IngredientImpl};
+
+impl IngredientImpl
+where
+ C: Configuration,
+{
+ /// Helper used by `accumulate` functions. Computes the results accumulated by `database_key_index`
+ /// and its inputs.
+ pub fn accumulated_by(&self, db: &C::DbView, key: Id) -> Vec
+ where
+ A: accumulator::Accumulator,
+ {
+ let Some(accumulator) = >::from_db(db) else {
+ return vec![];
+ };
+ let runtime = db.runtime();
+ let mut output = vec![];
+
+ // First ensure the result is up to date
+ self.fetch(db, key);
+
+ let database_key_index = self.database_key_index(key);
+ accumulator.produced_by(runtime, database_key_index, &mut output);
+
+ if let Some(origin) = self.origin(key) {
+ for input in origin.inputs() {
+ if let Ok(input) = input.try_into() {
+ accumulator.produced_by(runtime, input, &mut output);
+ }
+ }
+ }
+
+ output
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 299714e9..a68dd73c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -73,7 +73,6 @@ pub mod prelude {
///
/// The contents of this module are NOT subject to semver.
pub mod plumbing {
- pub use crate::accumulator::accumulated_by;
pub use crate::accumulator::Accumulator;
pub use crate::array::Array;
pub use crate::cycle::Cycle;
diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs
index 2fc21cb3..2213bf1f 100644
--- a/src/runtime/local_state.rs
+++ b/src/runtime/local_state.rs
@@ -76,7 +76,16 @@ pub enum QueryOrigin {
}
impl QueryOrigin {
- /// Indices for queries *written* by this query (or `vec![]` if its value was assigned).
+ /// Indices for queries *read* by this query
+ pub(crate) fn inputs(&self) -> impl Iterator- + '_ {
+ let opt_edges = match self {
+ QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
+ QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
+ };
+ opt_edges.into_iter().flat_map(|edges| edges.inputs())
+ }
+
+ /// Indices for queries *written* by this query (if any)
pub(crate) fn outputs(&self) -> impl Iterator
- + '_ {
let opt_edges = match self {
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
diff --git a/tests/accumulate-dag.rs b/tests/accumulate-dag.rs
new file mode 100644
index 00000000..5aa5b0df
--- /dev/null
+++ b/tests/accumulate-dag.rs
@@ -0,0 +1,68 @@
+mod common;
+use common::{HasLogger, Logger};
+
+use expect_test::expect;
+use salsa::{Accumulator, Database};
+use test_log::test;
+
+#[salsa::input]
+struct MyInput {
+ field_a: u32,
+ field_b: u32,
+}
+
+#[salsa::accumulator]
+#[derive(Clone, Debug)]
+struct Log(#[allow(dead_code)] String);
+
+#[salsa::tracked]
+fn push_logs(db: &dyn Database, input: MyInput) {
+ push_a_logs(db, input);
+ push_b_logs(db, input);
+}
+
+#[salsa::tracked]
+fn push_a_logs(db: &dyn Database, input: MyInput) {
+ let count = input.field_a(db);
+ for i in 0..count {
+ Log(format!("log_a({} of {})", i, count)).accumulate(db);
+ }
+}
+
+#[salsa::tracked]
+fn push_b_logs(db: &dyn Database, input: MyInput) {
+ // Note that b calls a
+ push_a_logs(db, input);
+ let count = input.field_b(db);
+ for i in 0..count {
+ Log(format!("log_b({} of {})", i, count)).accumulate(db);
+ }
+}
+
+#[test]
+fn accumulate_a_called_twice() {
+ salsa::default_database().attach(|db| {
+ let input = MyInput::new(db, 2, 3);
+ let logs = push_logs::accumulated::(db, input);
+ // Check that we don't see logs from `a` appearing twice in the input.
+ expect![[r#"
+ [
+ Log(
+ "log_a(0 of 2)",
+ ),
+ Log(
+ "log_a(1 of 2)",
+ ),
+ Log(
+ "log_b(0 of 3)",
+ ),
+ Log(
+ "log_b(1 of 3)",
+ ),
+ Log(
+ "log_b(2 of 3)",
+ ),
+ ]"#]]
+ .assert_eq(&format!("{:#?}", logs));
+ })
+}
diff --git a/tests/accumulate.rs b/tests/accumulate.rs
index 2a5ddeb4..6fb4e05f 100644
--- a/tests/accumulate.rs
+++ b/tests/accumulate.rs
@@ -20,7 +20,7 @@ struct MyInput {
#[salsa::accumulator]
#[derive(Clone, Debug)]
-struct Logs(#[allow(dead_code)] String);
+struct Log(#[allow(dead_code)] String);
#[salsa::tracked]
fn push_logs(db: &dyn Db, input: MyInput) {
@@ -30,13 +30,13 @@ fn push_logs(db: &dyn Db, input: MyInput) {
input.field_b(db)
));
- // We don't invoke `push_a_logs` (or `push_b_logs`) with a value of 1 or less.
+ // We don't invoke `push_a_logs` (or `push_b_logs`) with a value of 0.
// This allows us to test what happens a change in inputs causes a function not to be called at all.
- if input.field_a(db) > 1 {
+ if input.field_a(db) > 0 {
push_a_logs(db, input);
}
- if input.field_b(db) > 1 {
+ if input.field_b(db) > 0 {
push_b_logs(db, input);
}
}
@@ -47,7 +47,7 @@ fn push_a_logs(db: &dyn Db, input: MyInput) {
db.push_log(format!("push_a_logs({})", field_a));
for i in 0..field_a {
- Logs(format!("log_a({} of {})", i, field_a)).accumulate(db);
+ Log(format!("log_a({} of {})", i, field_a)).accumulate(db);
}
}
@@ -57,7 +57,7 @@ fn push_b_logs(db: &dyn Db, input: MyInput) {
db.push_log(format!("push_b_logs({})", field_a));
for i in 0..field_a {
- Logs(format!("log_b({} of {})", i, field_a)).accumulate(db);
+ Log(format!("log_b({} of {})", i, field_a)).accumulate(db);
}
}
@@ -88,38 +88,115 @@ fn accumulate_once() {
// Just call accumulate on a base input to see what happens.
let input = MyInput::new(&db, 2, 3);
- let logs = push_logs::accumulated::(&db, input);
+ let logs = push_logs::accumulated::(&db, input);
db.assert_logs(expect![[r#"
[
"push_logs(a = 2, b = 3)",
"push_a_logs(2)",
"push_b_logs(3)",
]"#]]);
+ // Check that we see logs from `a` first and then logs from `b`
+ // (execution order).
expect![[r#"
[
- "log_b(0 of 3)",
- "log_b(1 of 3)",
- "log_b(2 of 3)",
- "log_a(0 of 2)",
- "log_a(1 of 2)",
+ Log(
+ "log_a(0 of 2)",
+ ),
+ Log(
+ "log_a(1 of 2)",
+ ),
+ Log(
+ "log_b(0 of 3)",
+ ),
+ Log(
+ "log_b(1 of 3)",
+ ),
+ Log(
+ "log_b(2 of 3)",
+ ),
]"#]]
.assert_eq(&format!("{:#?}", logs));
}
#[test]
-fn change_a_and_reaccumulate() {
+fn change_a_from_2_to_0() {
let mut db = Database::default();
// Accumulate logs for `a = 2` and `b = 3`
let input = MyInput::new(&db, 2, 3);
- let logs = push_logs::accumulated::(&db, input);
+ let logs = push_logs::accumulated::(&db, input);
expect![[r#"
[
- "log_b(0 of 3)",
- "log_b(1 of 3)",
- "log_b(2 of 3)",
- "log_a(0 of 2)",
- "log_a(1 of 2)",
+ Log(
+ "log_a(0 of 2)",
+ ),
+ Log(
+ "log_a(1 of 2)",
+ ),
+ Log(
+ "log_b(0 of 3)",
+ ),
+ Log(
+ "log_b(1 of 3)",
+ ),
+ Log(
+ "log_b(2 of 3)",
+ ),
+ ]"#]]
+ .assert_eq(&format!("{:#?}", logs));
+ db.assert_logs(expect![[r#"
+ [
+ "push_logs(a = 2, b = 3)",
+ "push_a_logs(2)",
+ "push_b_logs(3)",
+ ]"#]]);
+
+ // Change to `a = 0`, which means `push_logs` does not call `push_a_logs` at all
+ input.set_field_a(&mut db).to(0);
+ let logs = push_logs::accumulated::(&db, input);
+ expect![[r#"
+ [
+ Log(
+ "log_b(0 of 3)",
+ ),
+ Log(
+ "log_b(1 of 3)",
+ ),
+ Log(
+ "log_b(2 of 3)",
+ ),
+ ]"#]]
+ .assert_eq(&format!("{:#?}", logs));
+ db.assert_logs(expect![[r#"
+ [
+ "push_logs(a = 1, b = 3)",
+ ]"#]]);
+}
+
+#[test]
+fn change_a_from_2_to_1() {
+ let mut db = Database::default();
+
+ // Accumulate logs for `a = 2` and `b = 3`
+ let input = MyInput::new(&db, 2, 3);
+ let logs = push_logs::accumulated::(&db, input);
+ expect![[r#"
+ [
+ Log(
+ "log_a(0 of 2)",
+ ),
+ Log(
+ "log_a(1 of 2)",
+ ),
+ Log(
+ "log_b(0 of 3)",
+ ),
+ Log(
+ "log_b(1 of 3)",
+ ),
+ Log(
+ "log_b(2 of 3)",
+ ),
]"#]]
.assert_eq(&format!("{:#?}", logs));
db.assert_logs(expect![[r#"
@@ -131,17 +208,27 @@ fn change_a_and_reaccumulate() {
// Change to `a = 1`, which means `push_logs` does not call `push_a_logs` at all
input.set_field_a(&mut db).to(1);
- let logs = push_logs::accumulated::(&db, input);
+ let logs = push_logs::accumulated::(&db, input);
expect![[r#"
[
- "log_b(0 of 3)",
- "log_b(1 of 3)",
- "log_b(2 of 3)",
+ Log(
+ "log_a(0 of 1)",
+ ),
+ Log(
+ "log_b(0 of 3)",
+ ),
+ Log(
+ "log_b(1 of 3)",
+ ),
+ Log(
+ "log_b(2 of 3)",
+ ),
]"#]]
.assert_eq(&format!("{:#?}", logs));
db.assert_logs(expect![[r#"
[
"push_logs(a = 1, b = 3)",
+ "push_a_logs(1)",
]"#]]);
}
@@ -151,11 +238,15 @@ fn get_a_logs_after_changing_b() {
// Invoke `push_a_logs` with `a = 2` and `b = 3` (but `b` doesn't matter)
let input = MyInput::new(&db, 2, 3);
- let logs = push_a_logs::accumulated::(&db, input);
+ let logs = push_a_logs::accumulated::(&db, input);
expect![[r#"
[
- "log_a(0 of 2)",
- "log_a(1 of 2)",
+ Log(
+ "log_a(0 of 2)",
+ ),
+ Log(
+ "log_a(1 of 2)",
+ ),
]"#]]
.assert_eq(&format!("{:#?}", logs));
db.assert_logs(expect![[r#"
@@ -166,11 +257,15 @@ fn get_a_logs_after_changing_b() {
// Changing `b` does not cause `push_a_logs` to re-execute
// and we still get the same result
input.set_field_b(&mut db).to(5);
- let logs = push_a_logs::accumulated::(&db, input);
+ let logs = push_a_logs::accumulated::(&db, input);
expect![[r#"
[
- "log_a(0 of 2)",
- "log_a(1 of 2)",
+ Log(
+ "log_a(0 of 2)",
+ ),
+ Log(
+ "log_a(1 of 2)",
+ ),
]
"#]]
.assert_debug_eq(&logs);