mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-13 16:58:52 +00:00
fix accumulation
This commit is contained in:
parent
5209735d0b
commit
b677019407
8 changed files with 242 additions and 47 deletions
|
@ -227,9 +227,8 @@ macro_rules! setup_tracked_fn {
|
||||||
$zalsa::AsId::as_id(&($($input_id),*))
|
$zalsa::AsId::as_id(&($($input_id),*))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
$Configuration::fn_ingredient($db).fetch($db, key);
|
|
||||||
let database_key_index = $Configuration::fn_ingredient($db).database_key_index(key);
|
$Configuration::fn_ingredient($db).accumulated_by::<A>($db, key)
|
||||||
$zalsa::accumulated_by($db.as_salsa_database(), database_key_index)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
$zalsa::macro_if! { $is_specifiable =>
|
$zalsa::macro_if! { $is_specifiable =>
|
||||||
|
|
|
@ -210,16 +210,3 @@ where
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn accumulated_by<A>(db: &dyn Database, database_key_index: DatabaseKeyIndex) -> Vec<A>
|
|
||||||
where
|
|
||||||
A: Accumulator,
|
|
||||||
{
|
|
||||||
let Some(accumulator) = <IngredientImpl<A>>::from_db(db) else {
|
|
||||||
return vec![];
|
|
||||||
};
|
|
||||||
let runtime = db.runtime();
|
|
||||||
let mut output = vec![];
|
|
||||||
accumulator.produced_by(runtime, database_key_index, &mut output);
|
|
||||||
output
|
|
||||||
}
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ use self::delete::DeletedEntries;
|
||||||
|
|
||||||
use super::ingredient::Ingredient;
|
use super::ingredient::Ingredient;
|
||||||
|
|
||||||
|
mod accumulated;
|
||||||
mod backdate;
|
mod backdate;
|
||||||
mod delete;
|
mod delete;
|
||||||
mod diff_outputs;
|
mod diff_outputs;
|
||||||
|
|
37
src/function/accumulated.rs
Normal file
37
src/function/accumulated.rs
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
use crate::{accumulator, storage::DatabaseGen, Database, Id};
|
||||||
|
|
||||||
|
use super::{Configuration, IngredientImpl};
|
||||||
|
|
||||||
|
impl<C> IngredientImpl<C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
/// Helper used by `accumulate` functions. Computes the results accumulated by `database_key_index`
|
||||||
|
/// and its inputs.
|
||||||
|
pub fn accumulated_by<A>(&self, db: &C::DbView, key: Id) -> Vec<A>
|
||||||
|
where
|
||||||
|
A: accumulator::Accumulator,
|
||||||
|
{
|
||||||
|
let Some(accumulator) = <accumulator::IngredientImpl<A>>::from_db(db) else {
|
||||||
|
return vec![];
|
||||||
|
};
|
||||||
|
let runtime = db.runtime();
|
||||||
|
let mut output = vec![];
|
||||||
|
|
||||||
|
// First ensure the result is up to date
|
||||||
|
self.fetch(db, key);
|
||||||
|
|
||||||
|
let database_key_index = self.database_key_index(key);
|
||||||
|
accumulator.produced_by(runtime, database_key_index, &mut output);
|
||||||
|
|
||||||
|
if let Some(origin) = self.origin(key) {
|
||||||
|
for input in origin.inputs() {
|
||||||
|
if let Ok(input) = input.try_into() {
|
||||||
|
accumulator.produced_by(runtime, input, &mut output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
}
|
|
@ -73,7 +73,6 @@ pub mod prelude {
|
||||||
///
|
///
|
||||||
/// The contents of this module are NOT subject to semver.
|
/// The contents of this module are NOT subject to semver.
|
||||||
pub mod plumbing {
|
pub mod plumbing {
|
||||||
pub use crate::accumulator::accumulated_by;
|
|
||||||
pub use crate::accumulator::Accumulator;
|
pub use crate::accumulator::Accumulator;
|
||||||
pub use crate::array::Array;
|
pub use crate::array::Array;
|
||||||
pub use crate::cycle::Cycle;
|
pub use crate::cycle::Cycle;
|
||||||
|
|
|
@ -76,7 +76,16 @@ pub enum QueryOrigin {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QueryOrigin {
|
impl QueryOrigin {
|
||||||
/// Indices for queries *written* by this query (or `vec![]` if its value was assigned).
|
/// Indices for queries *read* by this query
|
||||||
|
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
||||||
|
let opt_edges = match self {
|
||||||
|
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
||||||
|
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
|
||||||
|
};
|
||||||
|
opt_edges.into_iter().flat_map(|edges| edges.inputs())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Indices for queries *written* by this query (if any)
|
||||||
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
||||||
let opt_edges = match self {
|
let opt_edges = match self {
|
||||||
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
||||||
|
|
68
tests/accumulate-dag.rs
Normal file
68
tests/accumulate-dag.rs
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
mod common;
|
||||||
|
use common::{HasLogger, Logger};
|
||||||
|
|
||||||
|
use expect_test::expect;
|
||||||
|
use salsa::{Accumulator, Database};
|
||||||
|
use test_log::test;
|
||||||
|
|
||||||
|
#[salsa::input]
|
||||||
|
struct MyInput {
|
||||||
|
field_a: u32,
|
||||||
|
field_b: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::accumulator]
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct Log(#[allow(dead_code)] String);
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_logs(db: &dyn Database, input: MyInput) {
|
||||||
|
push_a_logs(db, input);
|
||||||
|
push_b_logs(db, input);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_a_logs(db: &dyn Database, input: MyInput) {
|
||||||
|
let count = input.field_a(db);
|
||||||
|
for i in 0..count {
|
||||||
|
Log(format!("log_a({} of {})", i, count)).accumulate(db);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_b_logs(db: &dyn Database, input: MyInput) {
|
||||||
|
// Note that b calls a
|
||||||
|
push_a_logs(db, input);
|
||||||
|
let count = input.field_b(db);
|
||||||
|
for i in 0..count {
|
||||||
|
Log(format!("log_b({} of {})", i, count)).accumulate(db);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accumulate_a_called_twice() {
|
||||||
|
salsa::default_database().attach(|db| {
|
||||||
|
let input = MyInput::new(db, 2, 3);
|
||||||
|
let logs = push_logs::accumulated::<Log>(db, input);
|
||||||
|
// Check that we don't see logs from `a` appearing twice in the input.
|
||||||
|
expect![[r#"
|
||||||
|
[
|
||||||
|
Log(
|
||||||
|
"log_a(0 of 2)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_a(1 of 2)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(0 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(1 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(2 of 3)",
|
||||||
|
),
|
||||||
|
]"#]]
|
||||||
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
|
})
|
||||||
|
}
|
|
@ -20,7 +20,7 @@ struct MyInput {
|
||||||
|
|
||||||
#[salsa::accumulator]
|
#[salsa::accumulator]
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct Logs(#[allow(dead_code)] String);
|
struct Log(#[allow(dead_code)] String);
|
||||||
|
|
||||||
#[salsa::tracked]
|
#[salsa::tracked]
|
||||||
fn push_logs(db: &dyn Db, input: MyInput) {
|
fn push_logs(db: &dyn Db, input: MyInput) {
|
||||||
|
@ -30,13 +30,13 @@ fn push_logs(db: &dyn Db, input: MyInput) {
|
||||||
input.field_b(db)
|
input.field_b(db)
|
||||||
));
|
));
|
||||||
|
|
||||||
// We don't invoke `push_a_logs` (or `push_b_logs`) with a value of 1 or less.
|
// We don't invoke `push_a_logs` (or `push_b_logs`) with a value of 0.
|
||||||
// This allows us to test what happens a change in inputs causes a function not to be called at all.
|
// This allows us to test what happens a change in inputs causes a function not to be called at all.
|
||||||
if input.field_a(db) > 1 {
|
if input.field_a(db) > 0 {
|
||||||
push_a_logs(db, input);
|
push_a_logs(db, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.field_b(db) > 1 {
|
if input.field_b(db) > 0 {
|
||||||
push_b_logs(db, input);
|
push_b_logs(db, input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ fn push_a_logs(db: &dyn Db, input: MyInput) {
|
||||||
db.push_log(format!("push_a_logs({})", field_a));
|
db.push_log(format!("push_a_logs({})", field_a));
|
||||||
|
|
||||||
for i in 0..field_a {
|
for i in 0..field_a {
|
||||||
Logs(format!("log_a({} of {})", i, field_a)).accumulate(db);
|
Log(format!("log_a({} of {})", i, field_a)).accumulate(db);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ fn push_b_logs(db: &dyn Db, input: MyInput) {
|
||||||
db.push_log(format!("push_b_logs({})", field_a));
|
db.push_log(format!("push_b_logs({})", field_a));
|
||||||
|
|
||||||
for i in 0..field_a {
|
for i in 0..field_a {
|
||||||
Logs(format!("log_b({} of {})", i, field_a)).accumulate(db);
|
Log(format!("log_b({} of {})", i, field_a)).accumulate(db);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,38 +88,115 @@ fn accumulate_once() {
|
||||||
|
|
||||||
// Just call accumulate on a base input to see what happens.
|
// Just call accumulate on a base input to see what happens.
|
||||||
let input = MyInput::new(&db, 2, 3);
|
let input = MyInput::new(&db, 2, 3);
|
||||||
let logs = push_logs::accumulated::<Logs>(&db, input);
|
let logs = push_logs::accumulated::<Log>(&db, input);
|
||||||
db.assert_logs(expect![[r#"
|
db.assert_logs(expect![[r#"
|
||||||
[
|
[
|
||||||
"push_logs(a = 2, b = 3)",
|
"push_logs(a = 2, b = 3)",
|
||||||
"push_a_logs(2)",
|
"push_a_logs(2)",
|
||||||
"push_b_logs(3)",
|
"push_b_logs(3)",
|
||||||
]"#]]);
|
]"#]]);
|
||||||
|
// Check that we see logs from `a` first and then logs from `b`
|
||||||
|
// (execution order).
|
||||||
expect![[r#"
|
expect![[r#"
|
||||||
[
|
[
|
||||||
"log_b(0 of 3)",
|
Log(
|
||||||
"log_b(1 of 3)",
|
"log_a(0 of 2)",
|
||||||
"log_b(2 of 3)",
|
),
|
||||||
"log_a(0 of 2)",
|
Log(
|
||||||
"log_a(1 of 2)",
|
"log_a(1 of 2)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(0 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(1 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(2 of 3)",
|
||||||
|
),
|
||||||
]"#]]
|
]"#]]
|
||||||
.assert_eq(&format!("{:#?}", logs));
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn change_a_and_reaccumulate() {
|
fn change_a_from_2_to_0() {
|
||||||
let mut db = Database::default();
|
let mut db = Database::default();
|
||||||
|
|
||||||
// Accumulate logs for `a = 2` and `b = 3`
|
// Accumulate logs for `a = 2` and `b = 3`
|
||||||
let input = MyInput::new(&db, 2, 3);
|
let input = MyInput::new(&db, 2, 3);
|
||||||
let logs = push_logs::accumulated::<Logs>(&db, input);
|
let logs = push_logs::accumulated::<Log>(&db, input);
|
||||||
expect![[r#"
|
expect![[r#"
|
||||||
[
|
[
|
||||||
"log_b(0 of 3)",
|
Log(
|
||||||
"log_b(1 of 3)",
|
"log_a(0 of 2)",
|
||||||
"log_b(2 of 3)",
|
),
|
||||||
"log_a(0 of 2)",
|
Log(
|
||||||
"log_a(1 of 2)",
|
"log_a(1 of 2)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(0 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(1 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(2 of 3)",
|
||||||
|
),
|
||||||
|
]"#]]
|
||||||
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
|
db.assert_logs(expect![[r#"
|
||||||
|
[
|
||||||
|
"push_logs(a = 2, b = 3)",
|
||||||
|
"push_a_logs(2)",
|
||||||
|
"push_b_logs(3)",
|
||||||
|
]"#]]);
|
||||||
|
|
||||||
|
// Change to `a = 0`, which means `push_logs` does not call `push_a_logs` at all
|
||||||
|
input.set_field_a(&mut db).to(0);
|
||||||
|
let logs = push_logs::accumulated::<Log>(&db, input);
|
||||||
|
expect![[r#"
|
||||||
|
[
|
||||||
|
Log(
|
||||||
|
"log_b(0 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(1 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(2 of 3)",
|
||||||
|
),
|
||||||
|
]"#]]
|
||||||
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
|
db.assert_logs(expect![[r#"
|
||||||
|
[
|
||||||
|
"push_logs(a = 1, b = 3)",
|
||||||
|
]"#]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn change_a_from_2_to_1() {
|
||||||
|
let mut db = Database::default();
|
||||||
|
|
||||||
|
// Accumulate logs for `a = 2` and `b = 3`
|
||||||
|
let input = MyInput::new(&db, 2, 3);
|
||||||
|
let logs = push_logs::accumulated::<Log>(&db, input);
|
||||||
|
expect![[r#"
|
||||||
|
[
|
||||||
|
Log(
|
||||||
|
"log_a(0 of 2)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_a(1 of 2)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(0 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(1 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(2 of 3)",
|
||||||
|
),
|
||||||
]"#]]
|
]"#]]
|
||||||
.assert_eq(&format!("{:#?}", logs));
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
db.assert_logs(expect![[r#"
|
db.assert_logs(expect![[r#"
|
||||||
|
@ -131,17 +208,27 @@ fn change_a_and_reaccumulate() {
|
||||||
|
|
||||||
// Change to `a = 1`, which means `push_logs` does not call `push_a_logs` at all
|
// Change to `a = 1`, which means `push_logs` does not call `push_a_logs` at all
|
||||||
input.set_field_a(&mut db).to(1);
|
input.set_field_a(&mut db).to(1);
|
||||||
let logs = push_logs::accumulated::<Logs>(&db, input);
|
let logs = push_logs::accumulated::<Log>(&db, input);
|
||||||
expect![[r#"
|
expect![[r#"
|
||||||
[
|
[
|
||||||
"log_b(0 of 3)",
|
Log(
|
||||||
"log_b(1 of 3)",
|
"log_a(0 of 1)",
|
||||||
"log_b(2 of 3)",
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(0 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(1 of 3)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_b(2 of 3)",
|
||||||
|
),
|
||||||
]"#]]
|
]"#]]
|
||||||
.assert_eq(&format!("{:#?}", logs));
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
db.assert_logs(expect![[r#"
|
db.assert_logs(expect![[r#"
|
||||||
[
|
[
|
||||||
"push_logs(a = 1, b = 3)",
|
"push_logs(a = 1, b = 3)",
|
||||||
|
"push_a_logs(1)",
|
||||||
]"#]]);
|
]"#]]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,11 +238,15 @@ fn get_a_logs_after_changing_b() {
|
||||||
|
|
||||||
// Invoke `push_a_logs` with `a = 2` and `b = 3` (but `b` doesn't matter)
|
// Invoke `push_a_logs` with `a = 2` and `b = 3` (but `b` doesn't matter)
|
||||||
let input = MyInput::new(&db, 2, 3);
|
let input = MyInput::new(&db, 2, 3);
|
||||||
let logs = push_a_logs::accumulated::<Logs>(&db, input);
|
let logs = push_a_logs::accumulated::<Log>(&db, input);
|
||||||
expect![[r#"
|
expect![[r#"
|
||||||
[
|
[
|
||||||
"log_a(0 of 2)",
|
Log(
|
||||||
"log_a(1 of 2)",
|
"log_a(0 of 2)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_a(1 of 2)",
|
||||||
|
),
|
||||||
]"#]]
|
]"#]]
|
||||||
.assert_eq(&format!("{:#?}", logs));
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
db.assert_logs(expect![[r#"
|
db.assert_logs(expect![[r#"
|
||||||
|
@ -166,11 +257,15 @@ fn get_a_logs_after_changing_b() {
|
||||||
// Changing `b` does not cause `push_a_logs` to re-execute
|
// Changing `b` does not cause `push_a_logs` to re-execute
|
||||||
// and we still get the same result
|
// and we still get the same result
|
||||||
input.set_field_b(&mut db).to(5);
|
input.set_field_b(&mut db).to(5);
|
||||||
let logs = push_a_logs::accumulated::<Logs>(&db, input);
|
let logs = push_a_logs::accumulated::<Log>(&db, input);
|
||||||
expect![[r#"
|
expect![[r#"
|
||||||
[
|
[
|
||||||
"log_a(0 of 2)",
|
Log(
|
||||||
"log_a(1 of 2)",
|
"log_a(0 of 2)",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log_a(1 of 2)",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
"#]]
|
"#]]
|
||||||
.assert_debug_eq(&logs);
|
.assert_debug_eq(&logs);
|
||||||
|
|
Loading…
Reference in a new issue