Add a test, fix a bug, refactor

This commit is contained in:
Phoebe Szmucer 2024-07-22 11:52:55 +01:00
parent c9f22f108a
commit 02008d51a7
3 changed files with 71 additions and 12 deletions

View file

@ -26,18 +26,16 @@ where
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];
while let Some(k) = stack.pop() {
visited.insert(k);
if visited.insert(k) {
accumulator.produced_by(runtime, k, &mut output);
let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index);
let inputs = origin.iter().flat_map(|origin| origin.inputs());
for input in inputs.rev() {
if let Ok(input) = input.try_into() {
if !visited.contains(&input) {
stack.push(input);
}
}
stack.extend(
inputs
.flat_map(|input| TryInto::<DatabaseKeyIndex>::try_into(input).into_iter())
.rev(),
);
}
}

View file

@ -39,7 +39,7 @@ fn push_d_logs(db: &dyn Database) {
fn accumulate_chain() {
salsa::default_database().attach(|db| {
let logs = push_logs::accumulated::<Log>(db);
// Check that we don't see logs from `a` appearing twice in the input.
// Check that we get all the logs.
expect![[r#"
[
Log(

View file

@ -0,0 +1,61 @@
mod common;
use expect_test::expect;
use salsa::{Accumulator, Database};
use test_log::test;
#[salsa::accumulator]
struct Log(#[allow(dead_code)] String);
#[salsa::tracked]
fn push_logs(db: &dyn Database) {
push_a_logs(db);
}
#[salsa::tracked]
fn push_a_logs(db: &dyn Database) {
Log("log a".to_string()).accumulate(db);
push_b_logs(db);
push_c_logs(db);
push_d_logs(db);
}
#[salsa::tracked]
fn push_b_logs(db: &dyn Database) {
Log("log b".to_string()).accumulate(db);
push_d_logs(db);
}
#[salsa::tracked]
fn push_c_logs(db: &dyn Database) {
Log("log c".to_string()).accumulate(db);
}
#[salsa::tracked]
fn push_d_logs(db: &dyn Database) {
Log("log d".to_string()).accumulate(db);
}
#[test]
fn accumulate_chain() {
salsa::default_database().attach(|db| {
let logs = push_logs::accumulated::<Log>(db);
// Check that we get logs in execution order
expect![[r#"
[
Log(
"log a",
),
Log(
"log b",
),
Log(
"log d",
),
Log(
"log c",
),
]"#]]
.assert_eq(&format!("{:#?}", logs));
})
}