From 02008d51a76237b83aa105cbaed4e008bad5817f Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Mon, 22 Jul 2024 11:52:55 +0100 Subject: [PATCH] Add a test, fix a bug, refactor --- src/function/accumulated.rs | 20 +++++----- tests/accumulate-chain.rs | 2 +- tests/accumulate-execution-order.rs | 61 +++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 12 deletions(-) create mode 100644 tests/accumulate-execution-order.rs diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 7d54a9f8..f0f04500 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -26,18 +26,16 @@ where let mut stack: Vec = vec![db_key]; while let Some(k) = stack.pop() { - visited.insert(k); - accumulator.produced_by(runtime, k, &mut output); + if visited.insert(k) { + accumulator.produced_by(runtime, k, &mut output); - let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); - let inputs = origin.iter().flat_map(|origin| origin.inputs()); - - for input in inputs.rev() { - if let Ok(input) = input.try_into() { - if !visited.contains(&input) { - stack.push(input); - } - } + let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); + let inputs = origin.iter().flat_map(|origin| origin.inputs()); + stack.extend( + inputs + .flat_map(|input| TryInto::::try_into(input).into_iter()) + .rev(), + ); } } diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs index aa3a8ea8..7cf3d3b3 100644 --- a/tests/accumulate-chain.rs +++ b/tests/accumulate-chain.rs @@ -39,7 +39,7 @@ fn push_d_logs(db: &dyn Database) { fn accumulate_chain() { salsa::default_database().attach(|db| { let logs = push_logs::accumulated::(db); - // Check that we don't see logs from `a` appearing twice in the input. + // Check that we get all the logs. expect![[r#" [ Log( diff --git a/tests/accumulate-execution-order.rs b/tests/accumulate-execution-order.rs new file mode 100644 index 00000000..ddc2e023 --- /dev/null +++ b/tests/accumulate-execution-order.rs @@ -0,0 +1,61 @@ +mod common; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +#[salsa::accumulator] +struct Log(#[allow(dead_code)] String); + +#[salsa::tracked] +fn push_logs(db: &dyn Database) { + push_a_logs(db); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database) { + Log("log a".to_string()).accumulate(db); + push_b_logs(db); + push_c_logs(db); + push_d_logs(db); +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database) { + Log("log b".to_string()).accumulate(db); + push_d_logs(db); +} + +#[salsa::tracked] +fn push_c_logs(db: &dyn Database) { + Log("log c".to_string()).accumulate(db); +} + +#[salsa::tracked] +fn push_d_logs(db: &dyn Database) { + Log("log d".to_string()).accumulate(db); +} + +#[test] +fn accumulate_chain() { + salsa::default_database().attach(|db| { + let logs = push_logs::accumulated::(db); + // Check that we get logs in execution order + expect![[r#" + [ + Log( + "log a", + ), + Log( + "log b", + ), + Log( + "log d", + ), + Log( + "log c", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +}