diff --git a/components/salsa-2022-macros/src/accumulator.rs b/components/salsa-2022-macros/src/accumulator.rs index ed870fe8..266f3ac3 100644 --- a/components/salsa-2022-macros/src/accumulator.rs +++ b/components/salsa-2022-macros/src/accumulator.rs @@ -30,6 +30,8 @@ impl crate::options::AllowedOptions for Accumulator { const DATA: bool = false; const DB: bool = false; + + const RECOVERY_FN: bool = false; } fn accumulator_contents( diff --git a/components/salsa-2022-macros/src/jar.rs b/components/salsa-2022-macros/src/jar.rs index 6b474ce7..923ecc17 100644 --- a/components/salsa-2022-macros/src/jar.rs +++ b/components/salsa-2022-macros/src/jar.rs @@ -39,6 +39,8 @@ impl crate::options::AllowedOptions for Jar { const DATA: bool = false; const DB: bool = true; + + const RECOVERY_FN: bool = false; } pub(crate) fn jar_struct_and_friends( diff --git a/components/salsa-2022-macros/src/options.rs b/components/salsa-2022-macros/src/options.rs index e6f080f4..eb8bd937 100644 --- a/components/salsa-2022-macros/src/options.rs +++ b/components/salsa-2022-macros/src/options.rs @@ -27,14 +27,19 @@ pub(crate) struct Options { /// The `jar = ` option is used to indicate the jar; it defaults to `crate::jar`. /// - /// If this is `Some`, the value is the ``. + /// If this is `Some`, the value is the ``. pub jar_ty: Option, - /// The `db = ` option is used to indicate the db. + /// The `db = ` option is used to indicate the db. /// - /// If this is `Some`, the value is the ``. + /// If this is `Some`, the value is the ``. pub db_path: Option, + /// The `recovery_fn = ` option is used to indicate the recovery function. + /// + /// If this is `Some`, the value is the ``. + pub recovery_fn: Option, + /// The `data = ` option is used to define the name of the data type for an interned /// struct. /// @@ -53,6 +58,7 @@ impl Default for Options { no_eq: Default::default(), jar_ty: Default::default(), db_path: Default::default(), + recovery_fn: Default::default(), data: Default::default(), phantom: Default::default(), } @@ -67,6 +73,7 @@ pub(crate) trait AllowedOptions { const JAR: bool; const DATA: bool; const DB: bool; + const RECOVERY_FN: bool; } type Equals = syn::Token![=]; @@ -159,6 +166,22 @@ impl syn::parse::Parse for Options { "`db` option not allowed here", )); } + } else if ident == "recovery_fn" { + if A::RECOVERY_FN { + let _eq = Equals::parse(input)?; + let path = syn::Path::parse(input)?; + if let Some(old) = std::mem::replace(&mut options.recovery_fn, Some(path)) { + return Err(syn::Error::new( + old.span(), + "option `recovery_fn` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`recovery_fn` option not allowed here", + )); + } } else if ident == "data" { if A::DATA { let _eq = Equals::parse(input)?; diff --git a/components/salsa-2022-macros/src/salsa_struct.rs b/components/salsa-2022-macros/src/salsa_struct.rs index 6721eb34..72bf27ce 100644 --- a/components/salsa-2022-macros/src/salsa_struct.rs +++ b/components/salsa-2022-macros/src/salsa_struct.rs @@ -47,6 +47,8 @@ impl crate::options::AllowedOptions for SalsaStruct { const DATA: bool = true; const DB: bool = false; + + const RECOVERY_FN: bool = false; } const BANNED_FIELD_NAMES: &[&str] = &["from", "new"]; diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index 84be7a37..f22dd213 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -70,6 +70,8 @@ impl crate::options::AllowedOptions for TrackedFn { const DATA: bool = false; const DB: bool = false; + + const RECOVERY_FN: bool = true; } /// Returns the key type for this tracked function. @@ -132,14 +134,34 @@ fn fn_configuration(args: &Args, item_fn: &syn::ItemFn) -> Configuration { }; let value_ty = configuration::value_ty(&item_fn.sig); - // FIXME: these are hardcoded for now - let cycle_strategy = CycleRecoveryStrategy::Panic; + let fn_ty = item_fn.sig.ident.clone(); + + let indices = (0..item_fn.sig.inputs.len() - 1).map(|i| Literal::usize_unsuffixed(i)); + let (cycle_strategy, recover_fn) = if let Some(recovery_fn) = &args.recovery_fn { + // Create the `recover_from_cycle` function, which (a) maps from the interned id to the actual + // keys and then (b) invokes the recover function itself. + let cycle_strategy = CycleRecoveryStrategy::Fallback; + + let cycle_fullback = parse_quote! { + fn recover_from_cycle(__db: &salsa::function::DynDb, __cycle: &salsa::Cycle, __id: Self::Key) -> Self::Value { + let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); + let __ingredients = + <_ as salsa::storage::HasIngredientsFor<#fn_ty>>::ingredient(__jar); + let __key = __ingredients.intern_map.data(__runtime, __id).clone(); + #recovery_fn(__db, __cycle, #(__key.#indices),*) + } + }; + (cycle_strategy, cycle_fullback) + } else { + // When the `recovery_fn` attribute is not set, set `cycle_strategy` to `Panic` + let cycle_strategy = CycleRecoveryStrategy::Panic; + let cycle_panic = configuration::panic_cycle_recovery_fn(); + (cycle_strategy, cycle_panic) + }; let backdate_fn = configuration::should_backdate_value_fn(args.should_backdate()); - let recover_fn = configuration::panic_cycle_recovery_fn(); // The type of the configuration struct; this has the same name as the fn itself. - let fn_ty = item_fn.sig.ident.clone(); // Make a copy of the fn with a different name; we will invoke this from `execute`. // We need to change the name because, otherwise, if the function invoked itself diff --git a/salsa-2022-tests/Cargo.toml b/salsa-2022-tests/Cargo.toml index f0140c02..edd31a4e 100644 --- a/salsa-2022-tests/Cargo.toml +++ b/salsa-2022-tests/Cargo.toml @@ -8,4 +8,5 @@ edition = "2021" [dependencies] salsa = { path = "../components/salsa-2022", package = "salsa-2022" } expect-test = "1.4.0" +parking_lot = "0.12.1" diff --git a/salsa-2022-tests/tests/cycles.rs b/salsa-2022-tests/tests/cycles.rs new file mode 100644 index 00000000..5a88ed9a --- /dev/null +++ b/salsa-2022-tests/tests/cycles.rs @@ -0,0 +1,427 @@ +#![allow(warnings)] + +use std::panic::{RefUnwindSafe, UnwindSafe}; + +use expect_test::expect; +use salsa::storage::HasJarsDyn; + +// Axes: +// +// Threading +// * Intra-thread +// * Cross-thread -- part of cycle is on one thread, part on another +// +// Recovery strategies: +// * Panic +// * Fallback +// * Mixed -- multiple strategies within cycle participants +// +// Across revisions: +// * N/A -- only one revision +// * Present in new revision, not old +// * Present in old revision, not new +// * Present in both revisions +// +// Dependencies +// * Tracked +// * Untracked -- cycle participant(s) contain untracked reads +// +// Layers +// * Direct -- cycle participant is directly invoked from test +// * Indirect -- invoked a query that invokes the cycle +// +// +// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | +// | ------ | -------- | -------- | --------- | ------ | --------- | +// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | +// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | +// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | +// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | +// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | +// | Intra | Fallback | New | Tracked | direct | cycle_appears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | +// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | + +// TODO: The following test is not yet ported. +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +struct Error { + cycle: Vec, +} + +#[salsa::jar(db = Db)] +struct Jar( + MyInput, + memoized_a, + memoized_b, + volatile_a, + volatile_b, + ABC, + cycle_a, + cycle_b, + cycle_c, +); + +trait Db: salsa::DbWithJar {} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + self.storage.runtime() + } +} + +impl Db for Database {} + +impl RefUnwindSafe for Database {} + +#[salsa::input(jar = Jar)] +struct MyInput {} + +#[salsa::tracked(jar = Jar)] +fn memoized_a(db: &dyn Db, input: MyInput) { + memoized_b(db, input) +} + +#[salsa::tracked(jar = Jar)] +fn memoized_b(db: &dyn Db, input: MyInput) { + memoized_a(db, input) +} + +#[salsa::tracked(jar = Jar)] +fn volatile_a(db: &dyn Db, input: MyInput) { + db.runtime().report_untracked_read(); + volatile_b(db, input) +} + +#[salsa::tracked(jar = Jar)] +fn volatile_b(db: &dyn Db, input: MyInput) { + db.runtime().report_untracked_read(); + volatile_a(db, input) +} + +/// The queries A, B, and C in `Database` can be configured +/// to invoke one another in arbitrary ways using this +/// enum. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CycleQuery { + None, + A, + B, + C, + AthenC, +} + +#[salsa::input(jar = Jar)] +struct ABC { + a: CycleQuery, + b: CycleQuery, + c: CycleQuery, +} + +impl CycleQuery { + fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { + match self { + CycleQuery::A => cycle_a(db, abc), + CycleQuery::B => cycle_b(db, abc), + CycleQuery::C => cycle_c(db, abc), + CycleQuery::AthenC => { + let _ = cycle_a(db, abc); + cycle_c(db, abc) + } + CycleQuery::None => Ok(()), + } + } +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_a)] +fn cycle_a(db: &dyn Db, abc: ABC) -> Result<(), Error> { + abc.a(db).invoke(db, abc) +} + +fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { + Err(Error { + cycle: cycle.all_participants(db), + }) +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b)] +fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { + abc.b(db).invoke(db, abc) +} + +fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { + Err(Error { + cycle: cycle.all_participants(db), + }) +} + +#[salsa::tracked(jar = Jar)] +fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { + abc.c(db).invoke(db, abc) +} + +#[track_caller] +fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { + let v = std::panic::catch_unwind(f); + if let Err(d) = &v { + if let Some(cycle) = d.downcast_ref::() { + return cycle.clone(); + } + } + panic!("unexpected value: {:?}", v) +} + +#[test] +fn cycle_memoized() { + let mut db = Database::default(); + let input = MyInput::new(&mut db); + let cycle = extract_cycle(|| memoized_a(&db, input)); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(1), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(2), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&cycle.all_participants(&db)); +} + +#[test] +fn cycle_volatile() { + let mut db = Database::default(); + let input = MyInput::new(&mut db); + let cycle = extract_cycle(|| volatile_a(&db, input)); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(3), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(4), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&cycle.all_participants(&db)); +} + +#[test] +fn expect_cycle() { + // A --> B + // ^ | + // +-----+ + + let mut db = Database::default(); + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc).is_err()); +} + +#[test] +fn inner_cycle() { + // A --> B <-- C + // ^ | + // +-----+ + let mut db = Database::default(); + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::B); + let err = cycle_c(&db, abc); + assert!(err.is_err()); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&err.unwrap_err().cycle); +} + +#[test] +fn cycle_revalidate() { + // A --> B + // ^ | + // +-----+ + let mut db = Database::default(); + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc).is_err()); + abc.set_b(&mut db, CycleQuery::A); // same value as default + assert!(cycle_a(&db, abc).is_err()); +} + +#[test] +fn cycle_recovery_unchanged_twice() { + // A --> B + // ^ | + // +-----+ + let mut db = Database::default(); + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc).is_err()); + + abc.set_c(&mut db, CycleQuery::A); // force new revision + assert!(cycle_a(&db, abc).is_err()); +} + +#[test] +fn cycle_appears() { + let mut db = Database::default(); + + // A --> B + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::None, CycleQuery::None); + assert!(cycle_a(&db, abc).is_ok()); + + // A --> B + // ^ | + // +-----+ + abc.set_b(&mut db, CycleQuery::A); + assert!(cycle_a(&db, abc).is_err()); +} + +#[test] +fn cycle_disappears() { + let mut db = Database::default(); + + // A --> B + // ^ | + // +-----+ + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc).is_err()); + + // A --> B + abc.set_b(&mut db, CycleQuery::None); + assert!(cycle_a(&db, abc).is_ok()); +} + +#[test] +fn cycle_mixed_1() { + let mut db = Database::default(); + + // A --> B <-- C + // | ^ + // +-----+ + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::C, CycleQuery::B); + + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&cycle_c(&db, abc).unwrap_err().cycle); +} + +#[test] +fn cycle_mixed_2() { + let mut db = Database::default(); + + // Configuration: + // + // A --> B --> C + // ^ | + // +-----------+ + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::C, CycleQuery::A); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&cycle_a(&db, abc).unwrap_err().cycle); +} + +#[test] +fn cycle_deterministic_order() { + // No matter whether we start from A or B, we get the same set of participants: + let f = || { + let mut db = Database::default(); + + // A --> B + // ^ | + // +-----+ + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + (db, abc) + }; + let (db, abc) = f(); + let a = cycle_a(&db, abc); + let (db, abc) = f(); + let b = cycle_b(&db, abc); + let expected = expect![[r#" + ( + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + ) + "#]]; + expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); +} + +#[test] +fn cycle_multiple() { + // No matter whether we start from A or B, we get the same set of participants: + let mut db = Database::default(); + + // Configuration: + // + // A --> B <-- C + // ^ | ^ + // +-----+ | + // | | + // +-----+ + // + // Here, conceptually, B encounters a cycle with A and then + // recovers. + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); + + let c = cycle_c(&db, abc); + let b = cycle_b(&db, abc); + let a = cycle_a(&db, abc); + let expected = expect![[r#" + ( + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + ) + "#]]; + expected.assert_debug_eq(&( + c.unwrap_err().cycle, + b.unwrap_err().cycle, + a.unwrap_err().cycle, + )); +} + +#[test] +fn cycle_recovery_set_but_not_participating() { + let mut db = Database::default(); + + // A --> C -+ + // ^ | + // +--+ + let abc = ABC::new(&mut db, CycleQuery::C, CycleQuery::None, CycleQuery::C); + + // Here we expect C to panic and A not to recover: + let r = extract_cycle(|| drop(cycle_a(&db, abc))); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&r.all_participants(&db)); +} diff --git a/salsa-2022-tests/tests/parallel/main.rs b/salsa-2022-tests/tests/parallel/main.rs new file mode 100644 index 00000000..3f8ce0e2 --- /dev/null +++ b/salsa-2022-tests/tests/parallel/main.rs @@ -0,0 +1,7 @@ +mod setup; + +mod parallel_cycle_all_recover; +mod parallel_cycle_mid_recover; +mod parallel_cycle_none_recover; +mod parallel_cycle_one_recover; +mod signal; diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_all_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_all_recover.rs new file mode 100644 index 00000000..9c42abbf --- /dev/null +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_all_recover.rs @@ -0,0 +1,112 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Database; +use crate::setup::Knobs; +use salsa::ParallelDatabase; + +pub(crate) trait Db: salsa::DbWithJar + Knobs {} + +impl + Knobs> Db for T {} + +#[salsa::jar(db = Db)] +pub(crate) struct Jar(MyInput, a1, a2, b1, b2); + +#[salsa::input(jar = Jar)] +pub(crate) struct MyInput { + field: i32, +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_a1)] +pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + a2(db, input) +} + +fn recover_a1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_a1"); + key.field(db) * 10 + 1 +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_a2)] +pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { + b1(db, input) +} + +fn recover_a2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_a2"); + key.field(db) * 10 + 2 +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b1)] +pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + b2(db, input) +} + +fn recover_b1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_b1"); + key.field(db) * 20 + 1 +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b2)] +pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { + a1(db, input) +} + +fn recover_b2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_b2"); + key.field(db) * 20 + 2 +} + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected, recovers) +// | b2 completes, recovers +// | b1 completes, recovers +// a2 sees cycle, recovers +// a1 completes, recovers + +#[test] +fn execute() { + let mut db = Database::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, 1); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || a1(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b1(&*db, input) + }); + + assert_eq!(thread_a.join().unwrap(), 11); + assert_eq!(thread_b.join().unwrap(), 21); +} diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs new file mode 100644 index 00000000..4256082f --- /dev/null +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs @@ -0,0 +1,110 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Database; +use crate::setup::Knobs; +use salsa::ParallelDatabase; + +pub(crate) trait Db: salsa::DbWithJar + Knobs {} + +impl + Knobs> Db for T {} + +#[salsa::jar(db = Db)] +pub(crate) struct Jar(MyInput, a1, a2, b1, b2, b3); + +#[salsa::input(jar = Jar)] +pub(crate) struct MyInput { + field: i32, +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { + // tell thread b we have started + db.signal(1); + + // wait for thread b to block on a1 + db.wait_for(2); + + a2(db, input) +} +#[salsa::tracked(jar = Jar)] +pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { + // create the cycle + b1(db, input) +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b1)] +pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { + // wait for thread a to have started + db.wait_for(1); + b2(db, input) +} + +fn recover_b1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_b1"); + key.field(db) * 20 + 2 +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { + // will encounter a cycle but recover + b3(db, input); + b1(db, input); // hasn't recovered yet + 0 +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b3)] +pub(crate) fn b3(db: &dyn Db, input: MyInput) -> i32 { + // will block on thread a, signaling stage 2 + a1(db, input) +} + +fn recover_b3(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_b3"); + key.field(db) * 200 + 2 +} + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | | +// | b2 +// | b3 +// | a1 (blocks -> stage 2) +// (unblocked) | +// a2 (cycle detected) | +// b3 recovers +// b2 resumes +// b1 recovers + +#[test] +fn execute() { + let mut db = Database::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, 1); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || a1(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b1(&*db, input) + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs new file mode 100644 index 00000000..5a0227f2 --- /dev/null +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs @@ -0,0 +1,83 @@ +//! Test a cycle where no queries recover that occurs across threads. +//! See the `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Database; +use crate::setup::Knobs; +use expect_test::expect; +use salsa::ParallelDatabase; + +pub(crate) trait Db: salsa::DbWithJar + Knobs {} + +impl + Knobs> Db for T {} + +#[salsa::jar(db = Db)] +pub(crate) struct Jar(MyInput, a, b); + +#[salsa::input(jar = Jar)] +pub(crate) struct MyInput { + field: i32, +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn a(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + b(db, input) +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn b(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + + // Now try to execute A + a(db, input) +} + +#[test] +fn execute() { + let mut db = Database::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, -1); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || a(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b(&*db, input) + }); + + // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). + // Right now, it panics with a string. + let err_b = thread_b.join().unwrap_err(); + if let Some(c) = err_b.downcast_ref::() { + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(8), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&c.all_participants(&db)); + } else { + panic!("b failed in an unexpected way: {:?}", err_b); + } + + // We expect A to propagate a panic, which causes us to use the sentinel + // type `Canceled`. + assert!(thread_a + .join() + .unwrap_err() + .downcast_ref::() + .is_some()); +} diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs new file mode 100644 index 00000000..becdcddd --- /dev/null +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs @@ -0,0 +1,99 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Database; +use crate::setup::Knobs; +use salsa::ParallelDatabase; + +pub(crate) trait Db: salsa::DbWithJar + Knobs {} + +impl + Knobs> Db for T {} + +#[salsa::jar(db = Db)] +pub(crate) struct Jar(MyInput, a1, a2, b1, b2); + +#[salsa::input(jar = Jar)] +pub(crate) struct MyInput { + field: i32, +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + a2(db, input) +} +#[salsa::tracked(jar = Jar, recovery_fn=recover)] +pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { + b1(db, input) +} + +fn recover(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover"); + key.field(db) * 20 + 2 +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + b2(db, input) +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { + a1(db, input) +} + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected) +// a2 recovery fn executes | +// a1 completes normally | +// b2 completes, recovers +// b1 completes, recovers + +#[test] +fn execute() { + let mut db = Database::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, 1); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || a1(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b1(&*db, input) + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} diff --git a/salsa-2022-tests/tests/parallel/setup.rs b/salsa-2022-tests/tests/parallel/setup.rs new file mode 100644 index 00000000..8012928c --- /dev/null +++ b/salsa-2022-tests/tests/parallel/setup.rs @@ -0,0 +1,73 @@ +use std::{cell::Cell, sync::Arc}; + +use crate::signal::Signal; + +/// Various "knobs" and utilities used by tests to force +/// a certain behavior. +pub(crate) trait Knobs { + fn knobs(&self) -> &KnobsStruct; + + fn signal(&self, stage: usize); + + fn wait_for(&self, stage: usize); +} + +/// Various "knobs" that can be used to customize how the queries +/// behave on one specific thread. Note that this state is +/// intentionally thread-local (apart from `signal`). +#[derive(Clone, Default)] +pub(crate) struct KnobsStruct { + /// A kind of flexible barrier used to coordinate execution across + /// threads to ensure we reach various weird states. + pub(crate) signal: Arc, + + /// When this database is about to block, send a signal. + pub(crate) signal_on_will_block: Cell, +} + +#[salsa::db( + crate::parallel_cycle_one_recover::Jar, + crate::parallel_cycle_none_recover::Jar, + crate::parallel_cycle_mid_recover::Jar, + crate::parallel_cycle_all_recover::Jar +)] +#[derive(Default)] +pub(crate) struct Database { + storage: salsa::Storage, + knobs: KnobsStruct, +} + +impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + self.storage.runtime() + } + + fn salsa_event(&self, event: salsa::Event) { + if let salsa::EventKind::WillBlockOn { .. } = event.kind { + self.signal(self.knobs().signal_on_will_block.get()); + } + } +} + +impl salsa::ParallelDatabase for Database { + fn snapshot(&self) -> salsa::Snapshot { + salsa::Snapshot::new(Database { + storage: self.storage.snapshot(), + knobs: self.knobs.clone(), + }) + } +} + +impl Knobs for Database { + fn knobs(&self) -> &KnobsStruct { + &self.knobs + } + + fn signal(&self, stage: usize) { + self.knobs.signal.signal(stage); + } + + fn wait_for(&self, stage: usize) { + self.knobs.signal.wait_for(stage); + } +} diff --git a/salsa-2022-tests/tests/parallel/signal.rs b/salsa-2022-tests/tests/parallel/signal.rs new file mode 100644 index 00000000..f09aecc8 --- /dev/null +++ b/salsa-2022-tests/tests/parallel/signal.rs @@ -0,0 +1,40 @@ +use parking_lot::{Condvar, Mutex}; + +#[derive(Default)] +pub(crate) struct Signal { + value: Mutex, + cond_var: Condvar, +} + +impl Signal { + pub(crate) fn signal(&self, stage: usize) { + dbg!(format!("signal({})", stage)); + + // This check avoids acquiring the lock for things that will + // clearly be a no-op. Not *necessary* but helps to ensure we + // are more likely to encounter weird race conditions; + // otherwise calls to `sum` will tend to be unnecessarily + // synchronous. + if stage > 0 { + let mut v = self.value.lock(); + if stage > *v { + *v = stage; + self.cond_var.notify_all(); + } + } + } + + /// Waits until the given condition is true; the fn is invoked + /// with the current stage. + pub(crate) fn wait_for(&self, stage: usize) { + dbg!(format!("wait_for({})", stage)); + + // As above, avoid lock if clearly a no-op. + if stage > 0 { + let mut v = self.value.lock(); + while *v < stage { + self.cond_var.wait(&mut v); + } + } + } +}