From 815bf6003b4e183d44bc0124b646587617a28df1 Mon Sep 17 00:00:00 2001 From: XFFXFF <1247714429@qq.com> Date: Mon, 8 Aug 2022 07:20:39 +0800 Subject: [PATCH 1/7] add a test --- salsa-2022-tests/tests/cycles.rs | 43 ++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 salsa-2022-tests/tests/cycles.rs diff --git a/salsa-2022-tests/tests/cycles.rs b/salsa-2022-tests/tests/cycles.rs new file mode 100644 index 00000000..0e5972a6 --- /dev/null +++ b/salsa-2022-tests/tests/cycles.rs @@ -0,0 +1,43 @@ + +#![allow(warnings)] + +#[salsa::jar(db = Db)] +struct Jar(MyInput, memoized_a, memoized_b); + +trait Db: salsa::DbWithJar {} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + self.storage.runtime() + } +} + +impl Db for Database {} + +#[salsa::tracked(jar = Jar)] +struct MyInput { + +} + +#[salsa::tracked(jar = Jar)] +fn memoized_a(db: &dyn Db, input: MyInput) -> () { + memoized_b(db, input) +} + +#[salsa::tracked(jar = Jar)] +fn memoized_b(db: &dyn Db, input: MyInput) -> () { + memoized_a(db, input) +} + +#[test] +fn execute() { + let mut db = Database::default(); + let input = MyInput::new(&mut db); + memoized_a(&db, input); +} \ No newline at end of file From 0f907dd3cd6220635480eda0d945f57746ebeaaf Mon Sep 17 00:00:00 2001 From: XFFXFF <1247714429@qq.com> Date: Mon, 8 Aug 2022 07:32:39 +0800 Subject: [PATCH 2/7] add recovery_fn option --- .../salsa-2022-macros/src/accumulator.rs | 2 ++ components/salsa-2022-macros/src/jar.rs | 2 ++ components/salsa-2022-macros/src/options.rs | 20 +++++++++++++++++++ .../salsa-2022-macros/src/salsa_struct.rs | 2 ++ .../salsa-2022-macros/src/tracked_fn.rs | 2 ++ salsa-2022-tests/tests/cycles.rs | 4 +++- 6 files changed, 31 insertions(+), 1 deletion(-) diff --git a/components/salsa-2022-macros/src/accumulator.rs b/components/salsa-2022-macros/src/accumulator.rs index ed870fe8..266f3ac3 100644 --- a/components/salsa-2022-macros/src/accumulator.rs +++ b/components/salsa-2022-macros/src/accumulator.rs @@ -30,6 +30,8 @@ impl crate::options::AllowedOptions for Accumulator { const DATA: bool = false; const DB: bool = false; + + const RECOVERY_FN: bool = false; } fn accumulator_contents( diff --git a/components/salsa-2022-macros/src/jar.rs b/components/salsa-2022-macros/src/jar.rs index 6b474ce7..923ecc17 100644 --- a/components/salsa-2022-macros/src/jar.rs +++ b/components/salsa-2022-macros/src/jar.rs @@ -39,6 +39,8 @@ impl crate::options::AllowedOptions for Jar { const DATA: bool = false; const DB: bool = true; + + const RECOVERY_FN: bool = false; } pub(crate) fn jar_struct_and_friends( diff --git a/components/salsa-2022-macros/src/options.rs b/components/salsa-2022-macros/src/options.rs index e6f080f4..9d72b8dd 100644 --- a/components/salsa-2022-macros/src/options.rs +++ b/components/salsa-2022-macros/src/options.rs @@ -41,6 +41,8 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. pub data: Option, + pub recovery_fn: Option, + /// Remember the `A` parameter, which plays no role after parsing. phantom: PhantomData, } @@ -53,6 +55,7 @@ impl Default for Options { no_eq: Default::default(), jar_ty: Default::default(), db_path: Default::default(), + recovery_fn: Default::default(), data: Default::default(), phantom: Default::default(), } @@ -67,6 +70,7 @@ pub(crate) trait AllowedOptions { const JAR: bool; const DATA: bool; const DB: bool; + const RECOVERY_FN: bool; } type Equals = syn::Token![=]; @@ -159,6 +163,22 @@ impl syn::parse::Parse for Options { "`db` option not allowed here", )); } + } else if ident == "recovery_fn" { + if A::RECOVERY_FN { + let _eq = Equals::parse(input)?; + let path = syn::Path::parse(input)?; + if let Some(old) = std::mem::replace(&mut options.recovery_fn, Some(path)) { + return Err(syn::Error::new( + old.span(), + "option `recovery_fn` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`recovery_fn` option not allowed here", + )); + } } else if ident == "data" { if A::DATA { let _eq = Equals::parse(input)?; diff --git a/components/salsa-2022-macros/src/salsa_struct.rs b/components/salsa-2022-macros/src/salsa_struct.rs index 6721eb34..72bf27ce 100644 --- a/components/salsa-2022-macros/src/salsa_struct.rs +++ b/components/salsa-2022-macros/src/salsa_struct.rs @@ -47,6 +47,8 @@ impl crate::options::AllowedOptions for SalsaStruct { const DATA: bool = true; const DB: bool = false; + + const RECOVERY_FN: bool = false; } const BANNED_FIELD_NAMES: &[&str] = &["from", "new"]; diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index 84be7a37..662d7e7e 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -70,6 +70,8 @@ impl crate::options::AllowedOptions for TrackedFn { const DATA: bool = false; const DB: bool = false; + + const RECOVERY_FN: bool = true; } /// Returns the key type for this tracked function. diff --git a/salsa-2022-tests/tests/cycles.rs b/salsa-2022-tests/tests/cycles.rs index 0e5972a6..125a5f5a 100644 --- a/salsa-2022-tests/tests/cycles.rs +++ b/salsa-2022-tests/tests/cycles.rs @@ -25,11 +25,13 @@ struct MyInput { } -#[salsa::tracked(jar = Jar)] +#[salsa::tracked(jar = Jar, recovery_fn = my_recover_fn)] fn memoized_a(db: &dyn Db, input: MyInput) -> () { memoized_b(db, input) } +fn my_recover_fn(db: &dyn Db, cycle: &salsa::Cycle) -> () {} + #[salsa::tracked(jar = Jar)] fn memoized_b(db: &dyn Db, input: MyInput) -> () { memoized_a(db, input) From 045f5186b3e9ab41022baa5e62920bf709a1ce69 Mon Sep 17 00:00:00 2001 From: XFFXFF <1247714429@qq.com> Date: Mon, 8 Aug 2022 08:57:29 +0800 Subject: [PATCH 3/7] modify tracked_fn macro to use it --- .../salsa-2022-macros/src/tracked_fn.rs | 28 ++++++++++++++++--- salsa-2022-tests/tests/cycles.rs | 5 ++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index 662d7e7e..e52a34d8 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -134,14 +134,34 @@ fn fn_configuration(args: &Args, item_fn: &syn::ItemFn) -> Configuration { }; let value_ty = configuration::value_ty(&item_fn.sig); - // FIXME: these are hardcoded for now - let cycle_strategy = CycleRecoveryStrategy::Panic; + let fn_ty = item_fn.sig.ident.clone(); + // FIXME: these are hardcoded for now + let indices = (0..item_fn.sig.inputs.len() - 1).map(|i| Literal::usize_unsuffixed(i)); + let (cycle_strategy, recover_fn) = if let Some(recovery_fn) = &args.recovery_fn { + let cycle_strategy = CycleRecoveryStrategy::Fallback; + + let cycle_fullback = parse_quote! { + fn recover_from_cycle(__db: &salsa::function::DynDb, __cycle: &salsa::Cycle, __id: Self::Key) -> Self::Value { + let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); + let __ingredients = + <_ as salsa::storage::HasIngredientsFor<#fn_ty>>::ingredient(__jar); + let __key = __ingredients.intern_map.data(__runtime, __id).clone(); + #recovery_fn(__db, __cycle, #(__key.#indices),*) + } + }; + (cycle_strategy, cycle_fullback) + } else { + let cycle_strategy = CycleRecoveryStrategy::Panic; + let cycle_panic = configuration::panic_cycle_recovery_fn(); + (cycle_strategy, cycle_panic) + }; + + // let cycle_strategy = CycleRecoveryStrategy::Panic; + // let recover_fn = configuration::panic_cycle_recovery_fn(); let backdate_fn = configuration::should_backdate_value_fn(args.should_backdate()); - let recover_fn = configuration::panic_cycle_recovery_fn(); // The type of the configuration struct; this has the same name as the fn itself. - let fn_ty = item_fn.sig.ident.clone(); // Make a copy of the fn with a different name; we will invoke this from `execute`. // We need to change the name because, otherwise, if the function invoked itself diff --git a/salsa-2022-tests/tests/cycles.rs b/salsa-2022-tests/tests/cycles.rs index 125a5f5a..31cf6f84 100644 --- a/salsa-2022-tests/tests/cycles.rs +++ b/salsa-2022-tests/tests/cycles.rs @@ -20,7 +20,7 @@ impl salsa::Database for Database { impl Db for Database {} -#[salsa::tracked(jar = Jar)] +#[salsa::input(jar = Jar)] struct MyInput { } @@ -30,7 +30,8 @@ fn memoized_a(db: &dyn Db, input: MyInput) -> () { memoized_b(db, input) } -fn my_recover_fn(db: &dyn Db, cycle: &salsa::Cycle) -> () {} +fn my_recover_fn(db: &dyn Db, cycle: &salsa::Cycle, input: MyInput) -> () { +} #[salsa::tracked(jar = Jar)] fn memoized_b(db: &dyn Db, input: MyInput) -> () { From 80bfff8d7ae700b71ecae9a3ffed369ee3fcf31a Mon Sep 17 00:00:00 2001 From: XFFXFF <1247714429@qq.com> Date: Tue, 9 Aug 2022 07:54:15 +0800 Subject: [PATCH 4/7] port old tests for cycle in a single thread --- salsa-2022-tests/tests/cycles.rs | 405 +++++++++++++++++++++++++++++-- 1 file changed, 391 insertions(+), 14 deletions(-) diff --git a/salsa-2022-tests/tests/cycles.rs b/salsa-2022-tests/tests/cycles.rs index 31cf6f84..111a9685 100644 --- a/salsa-2022-tests/tests/cycles.rs +++ b/salsa-2022-tests/tests/cycles.rs @@ -1,8 +1,68 @@ - #![allow(warnings)] +use std::panic::{RefUnwindSafe, UnwindSafe}; + +use expect_test::expect; +use salsa::storage::HasJarsDyn; + +// Axes: +// +// Threading +// * Intra-thread +// * Cross-thread -- part of cycle is on one thread, part on another +// +// Recovery strategies: +// * Panic +// * Fallback +// * Mixed -- multiple strategies within cycle participants +// +// Across revisions: +// * N/A -- only one revision +// * Present in new revision, not old +// * Present in old revision, not new +// * Present in both revisions +// +// Dependencies +// * Tracked +// * Untracked -- cycle participant(s) contain untracked reads +// +// Layers +// * Direct -- cycle participant is directly invoked from test +// * Indirect -- invoked a query that invokes the cycle +// +// +// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | +// | ------ | -------- | -------- | --------- | ------ | --------- | +// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | +// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | +// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | +// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | +// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | +// | Intra | Fallback | New | Tracked | direct | cycle_appears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | +// | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle | +// | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle | + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +struct Error { + cycle: Vec, +} + #[salsa::jar(db = Db)] -struct Jar(MyInput, memoized_a, memoized_b); +struct Jar( + MyInput, + memoized_a, + memoized_b, + volatile_a, + volatile_b, + ABC, + cycle_a, + cycle_b, + cycle_c, +); trait Db: salsa::DbWithJar {} @@ -20,27 +80,344 @@ impl salsa::Database for Database { impl Db for Database {} +impl RefUnwindSafe for Database {} + #[salsa::input(jar = Jar)] -struct MyInput { +struct MyInput {} -} - -#[salsa::tracked(jar = Jar, recovery_fn = my_recover_fn)] -fn memoized_a(db: &dyn Db, input: MyInput) -> () { +#[salsa::tracked(jar = Jar)] +fn memoized_a(db: &dyn Db, input: MyInput) { memoized_b(db, input) } -fn my_recover_fn(db: &dyn Db, cycle: &salsa::Cycle, input: MyInput) -> () { -} - #[salsa::tracked(jar = Jar)] -fn memoized_b(db: &dyn Db, input: MyInput) -> () { +fn memoized_b(db: &dyn Db, input: MyInput) { memoized_a(db, input) } +#[salsa::tracked(jar = Jar)] +fn volatile_a(db: &dyn Db, input: MyInput) { + db.runtime().report_untracked_read(); + volatile_b(db, input) +} + +#[salsa::tracked(jar = Jar)] +fn volatile_b(db: &dyn Db, input: MyInput) { + db.runtime().report_untracked_read(); + volatile_a(db, input) +} + +/// The queries A, B, and C in `Database` can be configured +/// to invoke one another in arbitrary ways using this +/// enum. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CycleQuery { + None, + A, + B, + C, + AthenC, +} + +#[salsa::input(jar = Jar)] +struct ABC { + a: CycleQuery, + b: CycleQuery, + c: CycleQuery, +} + +impl CycleQuery { + fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { + match self { + CycleQuery::A => cycle_a(db, abc), + CycleQuery::B => cycle_b(db, abc), + CycleQuery::C => cycle_c(db, abc), + CycleQuery::AthenC => { + let _ = cycle_a(db, abc); + cycle_c(db, abc) + } + CycleQuery::None => Ok(()), + } + } +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_a)] +fn cycle_a(db: &dyn Db, abc: ABC) -> Result<(), Error> { + abc.a(db).invoke(db, abc) +} + +fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { + Err(Error { + cycle: cycle.all_participants(db), + }) +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b)] +fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { + abc.b(db).invoke(db, abc) +} + +fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { + Err(Error { + cycle: cycle.all_participants(db), + }) +} + +#[salsa::tracked(jar = Jar)] +fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { + abc.c(db).invoke(db, abc) +} + +#[track_caller] +fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { + let v = std::panic::catch_unwind(f); + if let Err(d) = &v { + if let Some(cycle) = d.downcast_ref::() { + return cycle.clone(); + } + } + panic!("unexpected value: {:?}", v) +} + #[test] -fn execute() { +fn cycle_memoized() { let mut db = Database::default(); let input = MyInput::new(&mut db); - memoized_a(&db, input); -} \ No newline at end of file + let cycle = extract_cycle(|| memoized_a(&db, input)); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(1), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(2), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&cycle.all_participants(&db)); +} + +#[test] +fn cycle_volatile() { + let mut db = Database::default(); + let input = MyInput::new(&mut db); + let cycle = extract_cycle(|| volatile_a(&db, input)); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(3), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(4), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&cycle.all_participants(&db)); +} + +#[test] +fn expect_cycle() { + // A --> B + // ^ | + // +-----+ + + let mut db = Database::default(); + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc).is_err()); +} + +#[test] +fn inner_cycle() { + // A --> B <-- C + // ^ | + // +-----+ + let mut db = Database::default(); + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::B); + let err = cycle_c(&db, abc); + assert!(err.is_err()); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&err.unwrap_err().cycle); +} + +#[test] +fn cycle_revalidate() { + // A --> B + // ^ | + // +-----+ + let mut db = Database::default(); + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc).is_err()); + abc.set_b(&mut db, CycleQuery::A); // same value as default + assert!(cycle_a(&db, abc).is_err()); +} + +#[test] +fn cycle_recovery_unchanged_twice() { + // A --> B + // ^ | + // +-----+ + let mut db = Database::default(); + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc).is_err()); + + abc.set_c(&mut db, CycleQuery::A); // same value as default + assert!(cycle_a(&db, abc).is_err()); +} + +#[test] +fn cycle_appears() { + let mut db = Database::default(); + + // A --> B + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::None, CycleQuery::None); + assert!(cycle_a(&db, abc).is_ok()); + + // A --> B + // ^ | + // +-----+ + abc.set_b(&mut db, CycleQuery::A); + assert!(cycle_a(&db, abc).is_err()); +} + +#[test] +fn cycle_disappears() { + let mut db = Database::default(); + + // A --> B + // ^ | + // +-----+ + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc).is_err()); + + // A --> B + abc.set_b(&mut db, CycleQuery::None); + assert!(cycle_a(&db, abc).is_ok()); +} + +#[test] +fn cycle_mixed_1() { + let mut db = Database::default(); + + // A --> B <-- C + // | ^ + // +-----+ + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::C, CycleQuery::B); + + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&cycle_c(&db, abc).unwrap_err().cycle); +} + +#[test] +fn cycle_mixed_2() { + let mut db = Database::default(); + + // Configuration: + // + // A --> B --> C + // ^ | + // +-----------+ + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::C, CycleQuery::A); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&cycle_a(&db, abc).unwrap_err().cycle); +} + +#[test] +fn cycle_deterministic_order() { + // No matter whether we start from A or B, we get the same set of participants: + let f = || { + let mut db = Database::default(); + + // A --> B + // ^ | + // +-----+ + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + (db, abc) + }; + let (db, abc) = f(); + let a = cycle_a(&db, abc); + let (db, abc) = f(); + let b = cycle_b(&db, abc); + let expected = expect![[r#" + ( + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + ) + "#]]; + expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); +} + +#[test] +fn cycle_multiple() { + // No matter whether we start from A or B, we get the same set of participants: + let mut db = Database::default(); + + // Configuration: + // + // A --> B <-- C + // ^ | ^ + // +-----+ | + // | | + // +-----+ + // + // Here, conceptually, B encounters a cycle with A and then + // recovers. + let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); + + let c = cycle_c(&db, abc); + let b = cycle_b(&db, abc); + let a = cycle_a(&db, abc); + let expected = expect![[r#" + ( + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + [ + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }", + ], + ) + "#]]; + expected.assert_debug_eq(&( + c.unwrap_err().cycle, + b.unwrap_err().cycle, + a.unwrap_err().cycle, + )); +} + +#[test] +fn cycle_recovery_set_but_not_participating() { + let mut db = Database::default(); + + // A --> C -+ + // ^ | + // +--+ + let abc = ABC::new(&mut db, CycleQuery::C, CycleQuery::None, CycleQuery::C); + + // Here we expect C to panic and A not to recover: + let r = extract_cycle(|| drop(cycle_a(&db, abc))); + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&r.all_participants(&db)); +} From 9fb5f7a366ea462bce89011b4fe8fb31701ce4bf Mon Sep 17 00:00:00 2001 From: XFFXFF <1247714429@qq.com> Date: Tue, 9 Aug 2022 08:13:21 +0800 Subject: [PATCH 5/7] add some comments --- components/salsa-2022-macros/src/options.rs | 13 ++++++++----- components/salsa-2022-macros/src/tracked_fn.rs | 6 +++--- salsa-2022-tests/tests/cycles.rs | 6 ++++-- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/components/salsa-2022-macros/src/options.rs b/components/salsa-2022-macros/src/options.rs index 9d72b8dd..eb8bd937 100644 --- a/components/salsa-2022-macros/src/options.rs +++ b/components/salsa-2022-macros/src/options.rs @@ -27,22 +27,25 @@ pub(crate) struct Options { /// The `jar = ` option is used to indicate the jar; it defaults to `crate::jar`. /// - /// If this is `Some`, the value is the ``. + /// If this is `Some`, the value is the ``. pub jar_ty: Option, - /// The `db = ` option is used to indicate the db. + /// The `db = ` option is used to indicate the db. /// - /// If this is `Some`, the value is the ``. + /// If this is `Some`, the value is the ``. pub db_path: Option, + /// The `recovery_fn = ` option is used to indicate the recovery function. + /// + /// If this is `Some`, the value is the ``. + pub recovery_fn: Option, + /// The `data = ` option is used to define the name of the data type for an interned /// struct. /// /// If this is `Some`, the value is the ``. pub data: Option, - pub recovery_fn: Option, - /// Remember the `A` parameter, which plays no role after parsing. phantom: PhantomData, } diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index e52a34d8..f22dd213 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -136,9 +136,10 @@ fn fn_configuration(args: &Args, item_fn: &syn::ItemFn) -> Configuration { let fn_ty = item_fn.sig.ident.clone(); - // FIXME: these are hardcoded for now let indices = (0..item_fn.sig.inputs.len() - 1).map(|i| Literal::usize_unsuffixed(i)); let (cycle_strategy, recover_fn) = if let Some(recovery_fn) = &args.recovery_fn { + // Create the `recover_from_cycle` function, which (a) maps from the interned id to the actual + // keys and then (b) invokes the recover function itself. let cycle_strategy = CycleRecoveryStrategy::Fallback; let cycle_fullback = parse_quote! { @@ -152,13 +153,12 @@ fn fn_configuration(args: &Args, item_fn: &syn::ItemFn) -> Configuration { }; (cycle_strategy, cycle_fullback) } else { + // When the `recovery_fn` attribute is not set, set `cycle_strategy` to `Panic` let cycle_strategy = CycleRecoveryStrategy::Panic; let cycle_panic = configuration::panic_cycle_recovery_fn(); (cycle_strategy, cycle_panic) }; - // let cycle_strategy = CycleRecoveryStrategy::Panic; - // let recover_fn = configuration::panic_cycle_recovery_fn(); let backdate_fn = configuration::should_backdate_value_fn(args.should_backdate()); // The type of the configuration struct; this has the same name as the fn itself. diff --git a/salsa-2022-tests/tests/cycles.rs b/salsa-2022-tests/tests/cycles.rs index 111a9685..5bd066ad 100644 --- a/salsa-2022-tests/tests/cycles.rs +++ b/salsa-2022-tests/tests/cycles.rs @@ -40,9 +40,11 @@ use salsa::storage::HasJarsDyn; // | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | // | Intra | Fallback | New | Tracked | direct | cycle_appears | // | Intra | Fallback | Old | Tracked | direct | cycle_disappears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | // | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | // | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | + +// TODO: The following tests are not yet ported. +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | // | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle | // | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle | @@ -257,7 +259,7 @@ fn cycle_recovery_unchanged_twice() { let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None); assert!(cycle_a(&db, abc).is_err()); - abc.set_c(&mut db, CycleQuery::A); // same value as default + abc.set_c(&mut db, CycleQuery::A); // force new revision assert!(cycle_a(&db, abc).is_err()); } From 9feb0050e48ed07306363c82e94e9f461b9f0a2a Mon Sep 17 00:00:00 2001 From: XFFXFF <1247714429@qq.com> Date: Tue, 9 Aug 2022 18:06:39 +0800 Subject: [PATCH 6/7] ports parallel related tests for cycle --- salsa-2022-tests/Cargo.toml | 1 + salsa-2022-tests/tests/parallel/main.rs | 6 + .../parallel/parallel_cycle_mid_recover.rs | 105 ++++++++++++++++++ .../parallel/parallel_cycle_none_recover.rs | 58 ++++++++++ .../parallel/parallel_cycle_one_recover.rs | 91 +++++++++++++++ salsa-2022-tests/tests/parallel/setup.rs | 92 +++++++++++++++ salsa-2022-tests/tests/parallel/signal.rs | 40 +++++++ 7 files changed, 393 insertions(+) create mode 100644 salsa-2022-tests/tests/parallel/main.rs create mode 100644 salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs create mode 100644 salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs create mode 100644 salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs create mode 100644 salsa-2022-tests/tests/parallel/setup.rs create mode 100644 salsa-2022-tests/tests/parallel/signal.rs diff --git a/salsa-2022-tests/Cargo.toml b/salsa-2022-tests/Cargo.toml index f0140c02..edd31a4e 100644 --- a/salsa-2022-tests/Cargo.toml +++ b/salsa-2022-tests/Cargo.toml @@ -8,4 +8,5 @@ edition = "2021" [dependencies] salsa = { path = "../components/salsa-2022", package = "salsa-2022" } expect-test = "1.4.0" +parking_lot = "0.12.1" diff --git a/salsa-2022-tests/tests/parallel/main.rs b/salsa-2022-tests/tests/parallel/main.rs new file mode 100644 index 00000000..a1ded1ce --- /dev/null +++ b/salsa-2022-tests/tests/parallel/main.rs @@ -0,0 +1,6 @@ +mod setup; + +mod signal; +mod parallel_cycle_none_recover; +mod parallel_cycle_one_recover; +mod parallel_cycle_mid_recover; \ No newline at end of file diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs new file mode 100644 index 00000000..52c4f94b --- /dev/null +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs @@ -0,0 +1,105 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Knobs; +use crate::setup::Database as DatabaseImpl; +use salsa::ParallelDatabase; +use crate::setup::Jar; +use crate::setup::Db; + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | | +// | b2 +// | b3 +// | a1 (blocks -> stage 2) +// (unblocked) | +// a2 (cycle detected) | +// b3 recovers +// b2 resumes +// b1 panics because bug + +#[test] +fn execute() { + let mut db = DatabaseImpl::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, 1); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || a1(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b1(&*db, input) + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} + +#[salsa::input(jar = Jar)] +pub(crate) struct MyInput { + field: i32 +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { + // tell thread b we have started + db.signal(1); + + // wait for thread b to block on a1 + db.wait_for(2); + + a2(db, input) +} +#[salsa::tracked(jar = Jar)] +pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { + // create the cycle + b1(db, input) +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b1)] +pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { + // wait for thread a to have started + db.wait_for(1); + b2(db, input) +} + +fn recover_b1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_b1"); + key.field(db) * 20 + 2 +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { + // will encounter a cycle but recover + b3(db, input); + b1(db, input); // hasn't recovered yet + 0 +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b3)] +pub(crate) fn b3(db: &dyn Db, input: MyInput) -> i32 { + // will block on thread a, signaling stage 2 + a1(db, input) +} + +fn recover_b3(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_b3"); + key.field(db) * 200 + 2 +} diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs new file mode 100644 index 00000000..04d0be74 --- /dev/null +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs @@ -0,0 +1,58 @@ +use crate::setup::Knobs; +use crate::setup::Database as DatabaseImpl; +use expect_test::expect; +use salsa::ParallelDatabase; +use crate::setup::Jar; +use crate::setup::Db; + +#[salsa::input(jar = Jar)] +pub(crate) struct MyInput { + field: i32 +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn a(db: &dyn Db, input: MyInput) -> i32 { + db.signal(1); + db.wait_for(2); + + b(db, input) +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn b(db: &dyn Db, input: MyInput) -> i32 { + db.wait_for(1); + db.signal(2); + + db.wait_for(3); + a(db, input) +} + +#[test] +fn execute() { + let mut db = DatabaseImpl::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, -1); + + std::thread::spawn({ + let db = db.snapshot(); + move || a(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b(&*db, input) + }); + let err_b = thread_b.join().unwrap_err(); + if let Some(c) = err_b.downcast_ref::() { + let expected = expect![[r#" + [ + "DependencyIndex { ingredient_index: IngredientIndex(2), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(3), key_index: Some(Id { value: 1 }) }", + ] + "#]]; + expected.assert_debug_eq(&c.all_participants(&db)); + } else { + panic!("b failed in an unexpected way: {:?}", err_b); + } +} \ No newline at end of file diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs new file mode 100644 index 00000000..cdce86d7 --- /dev/null +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs @@ -0,0 +1,91 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Knobs; +use crate::setup::Database as DatabaseImpl; +use salsa::ParallelDatabase; +use crate::setup::Jar; +use crate::setup::Db; + +#[salsa::input(jar = Jar)] +pub(crate) struct MyInput { + field: i32 +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { + db.signal(1); + db.wait_for(2); + + a2(db, input) +} +#[salsa::tracked(jar = Jar, recovery_fn=recover)] +pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { + b1(db, input) +} + +fn recover(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover"); + key.field(db) * 20 + 2 +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { + db.wait_for(1); + db.signal(2); + + db.wait_for(3); + b2(db, input) +} + +#[salsa::tracked(jar = Jar)] +pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { + a1(db, input) +} + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected) +// a2 recovery fn executes | +// a1 completes normally | +// b2 completes, recovers +// b1 completes, recovers + +#[test] +fn execute() { + let mut db = DatabaseImpl::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, 1); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || a1(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b1(&*db, input) + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} \ No newline at end of file diff --git a/salsa-2022-tests/tests/parallel/setup.rs b/salsa-2022-tests/tests/parallel/setup.rs new file mode 100644 index 00000000..716f1841 --- /dev/null +++ b/salsa-2022-tests/tests/parallel/setup.rs @@ -0,0 +1,92 @@ +use std::{sync::Arc, cell::Cell}; + +use crate::signal::Signal; + +/// Various "knobs" and utilities used by tests to force +/// a certain behavior. +pub(crate) trait Knobs { + fn knobs(&self) -> &KnobsStruct; + + fn signal(&self, stage: usize); + + fn wait_for(&self, stage: usize); +} + +/// Various "knobs" that can be used to customize how the queries +/// behave on one specific thread. Note that this state is +/// intentionally thread-local (apart from `signal`). +#[derive(Clone, Default)] +pub(crate) struct KnobsStruct { + /// A kind of flexible barrier used to coordinate execution across + /// threads to ensure we reach various weird states. + pub(crate) signal: Arc, + + /// When this database is about to block, send a signal. + pub(crate) signal_on_will_block: Cell, +} + +#[salsa::jar(db = Db)] +pub(crate) struct Jar( + crate::parallel_cycle_none_recover::MyInput, + crate::parallel_cycle_none_recover::a, + crate::parallel_cycle_none_recover::b, + crate::parallel_cycle_one_recover::MyInput, + crate::parallel_cycle_one_recover::a1, + crate::parallel_cycle_one_recover::a2, + crate::parallel_cycle_one_recover::b1, + crate::parallel_cycle_one_recover::b2, + crate::parallel_cycle_mid_recover::MyInput, + crate::parallel_cycle_mid_recover::a1, + crate::parallel_cycle_mid_recover::a2, + crate::parallel_cycle_mid_recover::b1, + crate::parallel_cycle_mid_recover::b2, + crate::parallel_cycle_mid_recover::b3, +); + +pub(crate) trait Db: salsa::DbWithJar + Knobs {} + +#[salsa::db(Jar)] +#[derive(Default)] +pub(crate) struct Database { + storage: salsa::Storage, + knobs: KnobsStruct +} + +impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + self.storage.runtime() + } + + fn salsa_event(&self, event: salsa::Event) { + if let salsa::EventKind::WillBlockOn { .. } = event.kind { + self.signal(self.knobs().signal_on_will_block.get()); + } + } +} + +impl salsa::ParallelDatabase for Database { + fn snapshot(&self) -> salsa::Snapshot { + salsa::Snapshot::new( + Database { + storage: self.storage.snapshot(), + knobs: self.knobs.clone() + } + ) + } +} + +impl Db for Database {} + +impl Knobs for Database { + fn knobs(&self) -> &KnobsStruct { + &self.knobs + } + + fn signal(&self, stage: usize) { + self.knobs.signal.signal(stage); + } + + fn wait_for(&self, stage: usize) { + self.knobs.signal.wait_for(stage); + } +} diff --git a/salsa-2022-tests/tests/parallel/signal.rs b/salsa-2022-tests/tests/parallel/signal.rs new file mode 100644 index 00000000..f09aecc8 --- /dev/null +++ b/salsa-2022-tests/tests/parallel/signal.rs @@ -0,0 +1,40 @@ +use parking_lot::{Condvar, Mutex}; + +#[derive(Default)] +pub(crate) struct Signal { + value: Mutex, + cond_var: Condvar, +} + +impl Signal { + pub(crate) fn signal(&self, stage: usize) { + dbg!(format!("signal({})", stage)); + + // This check avoids acquiring the lock for things that will + // clearly be a no-op. Not *necessary* but helps to ensure we + // are more likely to encounter weird race conditions; + // otherwise calls to `sum` will tend to be unnecessarily + // synchronous. + if stage > 0 { + let mut v = self.value.lock(); + if stage > *v { + *v = stage; + self.cond_var.notify_all(); + } + } + } + + /// Waits until the given condition is true; the fn is invoked + /// with the current stage. + pub(crate) fn wait_for(&self, stage: usize) { + dbg!(format!("wait_for({})", stage)); + + // As above, avoid lock if clearly a no-op. + if stage > 0 { + let mut v = self.value.lock(); + while *v < stage { + self.cond_var.wait(&mut v); + } + } + } +} From d32cff1bfba2f200f8639880b9817d9e2a09ff94 Mon Sep 17 00:00:00 2001 From: XFFXFF <1247714429@qq.com> Date: Tue, 9 Aug 2022 18:33:46 +0800 Subject: [PATCH 7/7] define Jar struct separately --- salsa-2022-tests/tests/cycles.rs | 8 +- salsa-2022-tests/tests/parallel/main.rs | 5 +- .../parallel/parallel_cycle_all_recover.rs | 112 ++++++++++++++++++ .../parallel/parallel_cycle_mid_recover.rs | 95 ++++++++------- .../parallel/parallel_cycle_none_recover.rs | 43 +++++-- .../parallel/parallel_cycle_one_recover.rs | 20 +++- salsa-2022-tests/tests/parallel/setup.rs | 43 ++----- 7 files changed, 230 insertions(+), 96 deletions(-) create mode 100644 salsa-2022-tests/tests/parallel/parallel_cycle_all_recover.rs diff --git a/salsa-2022-tests/tests/cycles.rs b/salsa-2022-tests/tests/cycles.rs index 5bd066ad..5a88ed9a 100644 --- a/salsa-2022-tests/tests/cycles.rs +++ b/salsa-2022-tests/tests/cycles.rs @@ -42,11 +42,13 @@ use salsa::storage::HasJarsDyn; // | Intra | Fallback | Old | Tracked | direct | cycle_disappears | // | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | // | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | +// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | -// TODO: The following tests are not yet ported. +// TODO: The following test is not yet ported. // | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | -// | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle | -// | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle | #[derive(PartialEq, Eq, Hash, Clone, Debug)] struct Error { diff --git a/salsa-2022-tests/tests/parallel/main.rs b/salsa-2022-tests/tests/parallel/main.rs index a1ded1ce..3f8ce0e2 100644 --- a/salsa-2022-tests/tests/parallel/main.rs +++ b/salsa-2022-tests/tests/parallel/main.rs @@ -1,6 +1,7 @@ mod setup; -mod signal; +mod parallel_cycle_all_recover; +mod parallel_cycle_mid_recover; mod parallel_cycle_none_recover; mod parallel_cycle_one_recover; -mod parallel_cycle_mid_recover; \ No newline at end of file +mod signal; diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_all_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_all_recover.rs new file mode 100644 index 00000000..9c42abbf --- /dev/null +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_all_recover.rs @@ -0,0 +1,112 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Database; +use crate::setup::Knobs; +use salsa::ParallelDatabase; + +pub(crate) trait Db: salsa::DbWithJar + Knobs {} + +impl + Knobs> Db for T {} + +#[salsa::jar(db = Db)] +pub(crate) struct Jar(MyInput, a1, a2, b1, b2); + +#[salsa::input(jar = Jar)] +pub(crate) struct MyInput { + field: i32, +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_a1)] +pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + a2(db, input) +} + +fn recover_a1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_a1"); + key.field(db) * 10 + 1 +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_a2)] +pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { + b1(db, input) +} + +fn recover_a2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_a2"); + key.field(db) * 10 + 2 +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b1)] +pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + b2(db, input) +} + +fn recover_b1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_b1"); + key.field(db) * 20 + 1 +} + +#[salsa::tracked(jar = Jar, recovery_fn=recover_b2)] +pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { + a1(db, input) +} + +fn recover_b2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { + dbg!("recover_b2"); + key.field(db) * 20 + 2 +} + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected, recovers) +// | b2 completes, recovers +// | b1 completes, recovers +// a2 sees cycle, recovers +// a1 completes, recovers + +#[test] +fn execute() { + let mut db = Database::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, 1); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || a1(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b1(&*db, input) + }); + + assert_eq!(thread_a.join().unwrap(), 11); + assert_eq!(thread_b.join().unwrap(), 21); +} diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs index 52c4f94b..4256082f 100644 --- a/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_mid_recover.rs @@ -2,59 +2,20 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. +use crate::setup::Database; use crate::setup::Knobs; -use crate::setup::Database as DatabaseImpl; use salsa::ParallelDatabase; -use crate::setup::Jar; -use crate::setup::Db; -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | | -// | b2 -// | b3 -// | a1 (blocks -> stage 2) -// (unblocked) | -// a2 (cycle detected) | -// b3 recovers -// b2 resumes -// b1 panics because bug +pub(crate) trait Db: salsa::DbWithJar + Knobs {} -#[test] -fn execute() { - let mut db = DatabaseImpl::default(); - db.knobs().signal_on_will_block.set(3); +impl + Knobs> Db for T {} - let input = MyInput::new(&mut db, 1); - - let thread_a = std::thread::spawn({ - let db = db.snapshot(); - move || a1(&*db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.snapshot(); - move || b1(&*db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} +#[salsa::jar(db = Db)] +pub(crate) struct Jar(MyInput, a1, a2, b1, b2, b3); #[salsa::input(jar = Jar)] pub(crate) struct MyInput { - field: i32 + field: i32, } #[salsa::tracked(jar = Jar)] @@ -103,3 +64,47 @@ fn recover_b3(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { dbg!("recover_b3"); key.field(db) * 200 + 2 } + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | | +// | b2 +// | b3 +// | a1 (blocks -> stage 2) +// (unblocked) | +// a2 (cycle detected) | +// b3 recovers +// b2 resumes +// b1 recovers + +#[test] +fn execute() { + let mut db = Database::default(); + db.knobs().signal_on_will_block.set(3); + + let input = MyInput::new(&mut db, 1); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || a1(&*db, input) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || b1(&*db, input) + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs index 04d0be74..5a0227f2 100644 --- a/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_none_recover.rs @@ -1,17 +1,27 @@ +//! Test a cycle where no queries recover that occurs across threads. +//! See the `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Database; use crate::setup::Knobs; -use crate::setup::Database as DatabaseImpl; use expect_test::expect; use salsa::ParallelDatabase; -use crate::setup::Jar; -use crate::setup::Db; + +pub(crate) trait Db: salsa::DbWithJar + Knobs {} + +impl + Knobs> Db for T {} + +#[salsa::jar(db = Db)] +pub(crate) struct Jar(MyInput, a, b); #[salsa::input(jar = Jar)] pub(crate) struct MyInput { - field: i32 + field: i32, } #[salsa::tracked(jar = Jar)] pub(crate) fn a(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered db.signal(1); db.wait_for(2); @@ -20,21 +30,25 @@ pub(crate) fn a(db: &dyn Db, input: MyInput) -> i32 { #[salsa::tracked(jar = Jar)] pub(crate) fn b(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered db.wait_for(1); db.signal(2); + // Wait for thread A to block on this thread db.wait_for(3); + + // Now try to execute A a(db, input) } #[test] fn execute() { - let mut db = DatabaseImpl::default(); + let mut db = Database::default(); db.knobs().signal_on_will_block.set(3); let input = MyInput::new(&mut db, -1); - std::thread::spawn({ + let thread_a = std::thread::spawn({ let db = db.snapshot(); move || a(&*db, input) }); @@ -43,16 +57,27 @@ fn execute() { let db = db.snapshot(); move || b(&*db, input) }); + + // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). + // Right now, it panics with a string. let err_b = thread_b.join().unwrap_err(); if let Some(c) = err_b.downcast_ref::() { let expected = expect![[r#" [ - "DependencyIndex { ingredient_index: IngredientIndex(2), key_index: Some(Id { value: 1 }) }", - "DependencyIndex { ingredient_index: IngredientIndex(3), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(8), key_index: Some(Id { value: 1 }) }", + "DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }", ] "#]]; expected.assert_debug_eq(&c.all_participants(&db)); } else { panic!("b failed in an unexpected way: {:?}", err_b); } -} \ No newline at end of file + + // We expect A to propagate a panic, which causes us to use the sentinel + // type `Canceled`. + assert!(thread_a + .join() + .unwrap_err() + .downcast_ref::() + .is_some()); +} diff --git a/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs b/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs index cdce86d7..becdcddd 100644 --- a/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs +++ b/salsa-2022-tests/tests/parallel/parallel_cycle_one_recover.rs @@ -2,19 +2,25 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. +use crate::setup::Database; use crate::setup::Knobs; -use crate::setup::Database as DatabaseImpl; use salsa::ParallelDatabase; -use crate::setup::Jar; -use crate::setup::Db; + +pub(crate) trait Db: salsa::DbWithJar + Knobs {} + +impl + Knobs> Db for T {} + +#[salsa::jar(db = Db)] +pub(crate) struct Jar(MyInput, a1, a2, b1, b2); #[salsa::input(jar = Jar)] pub(crate) struct MyInput { - field: i32 + field: i32, } #[salsa::tracked(jar = Jar)] pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered db.signal(1); db.wait_for(2); @@ -32,9 +38,11 @@ fn recover(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { #[salsa::tracked(jar = Jar)] pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { + // Wait to create the cycle until both threads have entered db.wait_for(1); db.signal(2); + // Wait for thread A to block on this thread db.wait_for(3); b2(db, input) } @@ -68,7 +76,7 @@ pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { #[test] fn execute() { - let mut db = DatabaseImpl::default(); + let mut db = Database::default(); db.knobs().signal_on_will_block.set(3); let input = MyInput::new(&mut db, 1); @@ -88,4 +96,4 @@ fn execute() { // to b1, and from there to a2 and a1. assert_eq!(thread_a.join().unwrap(), 22); assert_eq!(thread_b.join().unwrap(), 22); -} \ No newline at end of file +} diff --git a/salsa-2022-tests/tests/parallel/setup.rs b/salsa-2022-tests/tests/parallel/setup.rs index 716f1841..8012928c 100644 --- a/salsa-2022-tests/tests/parallel/setup.rs +++ b/salsa-2022-tests/tests/parallel/setup.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, cell::Cell}; +use std::{cell::Cell, sync::Arc}; use crate::signal::Signal; @@ -25,31 +25,16 @@ pub(crate) struct KnobsStruct { pub(crate) signal_on_will_block: Cell, } -#[salsa::jar(db = Db)] -pub(crate) struct Jar( - crate::parallel_cycle_none_recover::MyInput, - crate::parallel_cycle_none_recover::a, - crate::parallel_cycle_none_recover::b, - crate::parallel_cycle_one_recover::MyInput, - crate::parallel_cycle_one_recover::a1, - crate::parallel_cycle_one_recover::a2, - crate::parallel_cycle_one_recover::b1, - crate::parallel_cycle_one_recover::b2, - crate::parallel_cycle_mid_recover::MyInput, - crate::parallel_cycle_mid_recover::a1, - crate::parallel_cycle_mid_recover::a2, - crate::parallel_cycle_mid_recover::b1, - crate::parallel_cycle_mid_recover::b2, - crate::parallel_cycle_mid_recover::b3, -); - -pub(crate) trait Db: salsa::DbWithJar + Knobs {} - -#[salsa::db(Jar)] +#[salsa::db( + crate::parallel_cycle_one_recover::Jar, + crate::parallel_cycle_none_recover::Jar, + crate::parallel_cycle_mid_recover::Jar, + crate::parallel_cycle_all_recover::Jar +)] #[derive(Default)] pub(crate) struct Database { storage: salsa::Storage, - knobs: KnobsStruct + knobs: KnobsStruct, } impl salsa::Database for Database { @@ -66,17 +51,13 @@ impl salsa::Database for Database { impl salsa::ParallelDatabase for Database { fn snapshot(&self) -> salsa::Snapshot { - salsa::Snapshot::new( - Database { - storage: self.storage.snapshot(), - knobs: self.knobs.clone() - } - ) + salsa::Snapshot::new(Database { + storage: self.storage.snapshot(), + knobs: self.knobs.clone(), + }) } } -impl Db for Database {} - impl Knobs for Database { fn knobs(&self) -> &KnobsStruct { &self.knobs