336: Add options to tracked functions for cycle recovery r=nikomatsakis a=XFFXFF

closes #331 

This pr ports the old salsa tests for cycle in a single thread, except for [cycle_disappears_durability](03a27a7054/tests/cycles.rs (L326)), since we don't have the api that permits setting durability.  

~I haven't ported parallel related tests, which would be some work, wondering if we can merge this in first~

Co-authored-by: XFFXFF <1247714429@qq.com>
This commit is contained in:
bors[bot] 2022-08-10 04:14:16 +00:00 committed by GitHub
commit 9ff6fb3376
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 1010 additions and 7 deletions

View file

@ -30,6 +30,8 @@ impl crate::options::AllowedOptions for Accumulator {
const DATA: bool = false;
const DB: bool = false;
const RECOVERY_FN: bool = false;
}
fn accumulator_contents(

View file

@ -39,6 +39,8 @@ impl crate::options::AllowedOptions for Jar {
const DATA: bool = false;
const DB: bool = true;
const RECOVERY_FN: bool = false;
}
pub(crate) fn jar_struct_and_friends(

View file

@ -27,14 +27,19 @@ pub(crate) struct Options<A: AllowedOptions> {
/// The `jar = <type>` option is used to indicate the jar; it defaults to `crate::jar`.
///
/// If this is `Some`, the value is the `<path>`.
/// If this is `Some`, the value is the `<type>`.
pub jar_ty: Option<syn::Type>,
/// The `db = <type>` option is used to indicate the db.
/// The `db = <path>` option is used to indicate the db.
///
/// If this is `Some`, the value is the `<type>`.
/// If this is `Some`, the value is the `<path>`.
pub db_path: Option<syn::Path>,
/// The `recovery_fn = <path>` option is used to indicate the recovery function.
///
/// If this is `Some`, the value is the `<path>`.
pub recovery_fn: Option<syn::Path>,
/// The `data = <ident>` option is used to define the name of the data type for an interned
/// struct.
///
@ -53,6 +58,7 @@ impl<A: AllowedOptions> Default for Options<A> {
no_eq: Default::default(),
jar_ty: Default::default(),
db_path: Default::default(),
recovery_fn: Default::default(),
data: Default::default(),
phantom: Default::default(),
}
@ -67,6 +73,7 @@ pub(crate) trait AllowedOptions {
const JAR: bool;
const DATA: bool;
const DB: bool;
const RECOVERY_FN: bool;
}
type Equals = syn::Token![=];
@ -159,6 +166,22 @@ impl<A: AllowedOptions> syn::parse::Parse for Options<A> {
"`db` option not allowed here",
));
}
} else if ident == "recovery_fn" {
if A::RECOVERY_FN {
let _eq = Equals::parse(input)?;
let path = syn::Path::parse(input)?;
if let Some(old) = std::mem::replace(&mut options.recovery_fn, Some(path)) {
return Err(syn::Error::new(
old.span(),
"option `recovery_fn` provided twice",
));
}
} else {
return Err(syn::Error::new(
ident.span(),
"`recovery_fn` option not allowed here",
));
}
} else if ident == "data" {
if A::DATA {
let _eq = Equals::parse(input)?;

View file

@ -47,6 +47,8 @@ impl crate::options::AllowedOptions for SalsaStruct {
const DATA: bool = true;
const DB: bool = false;
const RECOVERY_FN: bool = false;
}
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];

View file

@ -70,6 +70,8 @@ impl crate::options::AllowedOptions for TrackedFn {
const DATA: bool = false;
const DB: bool = false;
const RECOVERY_FN: bool = true;
}
/// Returns the key type for this tracked function.
@ -132,14 +134,34 @@ fn fn_configuration(args: &Args, item_fn: &syn::ItemFn) -> Configuration {
};
let value_ty = configuration::value_ty(&item_fn.sig);
// FIXME: these are hardcoded for now
let cycle_strategy = CycleRecoveryStrategy::Panic;
let fn_ty = item_fn.sig.ident.clone();
let indices = (0..item_fn.sig.inputs.len() - 1).map(|i| Literal::usize_unsuffixed(i));
let (cycle_strategy, recover_fn) = if let Some(recovery_fn) = &args.recovery_fn {
// Create the `recover_from_cycle` function, which (a) maps from the interned id to the actual
// keys and then (b) invokes the recover function itself.
let cycle_strategy = CycleRecoveryStrategy::Fallback;
let cycle_fullback = parse_quote! {
fn recover_from_cycle(__db: &salsa::function::DynDb<Self>, __cycle: &salsa::Cycle, __id: Self::Key) -> Self::Value {
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients =
<_ as salsa::storage::HasIngredientsFor<#fn_ty>>::ingredient(__jar);
let __key = __ingredients.intern_map.data(__runtime, __id).clone();
#recovery_fn(__db, __cycle, #(__key.#indices),*)
}
};
(cycle_strategy, cycle_fullback)
} else {
// When the `recovery_fn` attribute is not set, set `cycle_strategy` to `Panic`
let cycle_strategy = CycleRecoveryStrategy::Panic;
let cycle_panic = configuration::panic_cycle_recovery_fn();
(cycle_strategy, cycle_panic)
};
let backdate_fn = configuration::should_backdate_value_fn(args.should_backdate());
let recover_fn = configuration::panic_cycle_recovery_fn();
// The type of the configuration struct; this has the same name as the fn itself.
let fn_ty = item_fn.sig.ident.clone();
// Make a copy of the fn with a different name; we will invoke this from `execute`.
// We need to change the name because, otherwise, if the function invoked itself

View file

@ -8,4 +8,5 @@ edition = "2021"
[dependencies]
salsa = { path = "../components/salsa-2022", package = "salsa-2022" }
expect-test = "1.4.0"
parking_lot = "0.12.1"

View file

@ -0,0 +1,427 @@
#![allow(warnings)]
use std::panic::{RefUnwindSafe, UnwindSafe};
use expect_test::expect;
use salsa::storage::HasJarsDyn;
// Axes:
//
// Threading
// * Intra-thread
// * Cross-thread -- part of cycle is on one thread, part on another
//
// Recovery strategies:
// * Panic
// * Fallback
// * Mixed -- multiple strategies within cycle participants
//
// Across revisions:
// * N/A -- only one revision
// * Present in new revision, not old
// * Present in old revision, not new
// * Present in both revisions
//
// Dependencies
// * Tracked
// * Untracked -- cycle participant(s) contain untracked reads
//
// Layers
// * Direct -- cycle participant is directly invoked from test
// * Indirect -- invoked a query that invokes the cycle
//
//
// | Thread | Recovery | Old, New | Dep style | Layers | Test Name |
// | ------ | -------- | -------- | --------- | ------ | --------- |
// | Intra | Panic | N/A | Tracked | direct | cycle_memoized |
// | Intra | Panic | N/A | Untracked | direct | cycle_volatile |
// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle |
// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle |
// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate |
// | Intra | Fallback | New | Tracked | direct | cycle_appears |
// | Intra | Fallback | Old | Tracked | direct | cycle_disappears |
// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 |
// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 |
// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs |
// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs |
// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs |
// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs |
// TODO: The following test is not yet ported.
// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability |
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
struct Error {
cycle: Vec<String>,
}
#[salsa::jar(db = Db)]
struct Jar(
MyInput,
memoized_a,
memoized_b,
volatile_a,
volatile_b,
ABC,
cycle_a,
cycle_b,
cycle_c,
);
trait Db: salsa::DbWithJar<Jar> {}
#[salsa::db(Jar)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {
fn salsa_runtime(&self) -> &salsa::Runtime {
self.storage.runtime()
}
}
impl Db for Database {}
impl RefUnwindSafe for Database {}
#[salsa::input(jar = Jar)]
struct MyInput {}
#[salsa::tracked(jar = Jar)]
fn memoized_a(db: &dyn Db, input: MyInput) {
memoized_b(db, input)
}
#[salsa::tracked(jar = Jar)]
fn memoized_b(db: &dyn Db, input: MyInput) {
memoized_a(db, input)
}
#[salsa::tracked(jar = Jar)]
fn volatile_a(db: &dyn Db, input: MyInput) {
db.runtime().report_untracked_read();
volatile_b(db, input)
}
#[salsa::tracked(jar = Jar)]
fn volatile_b(db: &dyn Db, input: MyInput) {
db.runtime().report_untracked_read();
volatile_a(db, input)
}
/// The queries A, B, and C in `Database` can be configured
/// to invoke one another in arbitrary ways using this
/// enum.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CycleQuery {
None,
A,
B,
C,
AthenC,
}
#[salsa::input(jar = Jar)]
struct ABC {
a: CycleQuery,
b: CycleQuery,
c: CycleQuery,
}
impl CycleQuery {
fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> {
match self {
CycleQuery::A => cycle_a(db, abc),
CycleQuery::B => cycle_b(db, abc),
CycleQuery::C => cycle_c(db, abc),
CycleQuery::AthenC => {
let _ = cycle_a(db, abc);
cycle_c(db, abc)
}
CycleQuery::None => Ok(()),
}
}
}
#[salsa::tracked(jar = Jar, recovery_fn=recover_a)]
fn cycle_a(db: &dyn Db, abc: ABC) -> Result<(), Error> {
abc.a(db).invoke(db, abc)
}
fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> {
Err(Error {
cycle: cycle.all_participants(db),
})
}
#[salsa::tracked(jar = Jar, recovery_fn=recover_b)]
fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> {
abc.b(db).invoke(db, abc)
}
fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> {
Err(Error {
cycle: cycle.all_participants(db),
})
}
#[salsa::tracked(jar = Jar)]
fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> {
abc.c(db).invoke(db, abc)
}
#[track_caller]
fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle {
let v = std::panic::catch_unwind(f);
if let Err(d) = &v {
if let Some(cycle) = d.downcast_ref::<salsa::Cycle>() {
return cycle.clone();
}
}
panic!("unexpected value: {:?}", v)
}
#[test]
fn cycle_memoized() {
let mut db = Database::default();
let input = MyInput::new(&mut db);
let cycle = extract_cycle(|| memoized_a(&db, input));
let expected = expect![[r#"
[
"DependencyIndex { ingredient_index: IngredientIndex(1), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(2), key_index: Some(Id { value: 1 }) }",
]
"#]];
expected.assert_debug_eq(&cycle.all_participants(&db));
}
#[test]
fn cycle_volatile() {
let mut db = Database::default();
let input = MyInput::new(&mut db);
let cycle = extract_cycle(|| volatile_a(&db, input));
let expected = expect![[r#"
[
"DependencyIndex { ingredient_index: IngredientIndex(3), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(4), key_index: Some(Id { value: 1 }) }",
]
"#]];
expected.assert_debug_eq(&cycle.all_participants(&db));
}
#[test]
fn expect_cycle() {
// A --> B
// ^ |
// +-----+
let mut db = Database::default();
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None);
assert!(cycle_a(&db, abc).is_err());
}
#[test]
fn inner_cycle() {
// A --> B <-- C
// ^ |
// +-----+
let mut db = Database::default();
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::B);
let err = cycle_c(&db, abc);
assert!(err.is_err());
let expected = expect![[r#"
[
"DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }",
]
"#]];
expected.assert_debug_eq(&err.unwrap_err().cycle);
}
#[test]
fn cycle_revalidate() {
// A --> B
// ^ |
// +-----+
let mut db = Database::default();
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None);
assert!(cycle_a(&db, abc).is_err());
abc.set_b(&mut db, CycleQuery::A); // same value as default
assert!(cycle_a(&db, abc).is_err());
}
#[test]
fn cycle_recovery_unchanged_twice() {
// A --> B
// ^ |
// +-----+
let mut db = Database::default();
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None);
assert!(cycle_a(&db, abc).is_err());
abc.set_c(&mut db, CycleQuery::A); // force new revision
assert!(cycle_a(&db, abc).is_err());
}
#[test]
fn cycle_appears() {
let mut db = Database::default();
// A --> B
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::None, CycleQuery::None);
assert!(cycle_a(&db, abc).is_ok());
// A --> B
// ^ |
// +-----+
abc.set_b(&mut db, CycleQuery::A);
assert!(cycle_a(&db, abc).is_err());
}
#[test]
fn cycle_disappears() {
let mut db = Database::default();
// A --> B
// ^ |
// +-----+
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None);
assert!(cycle_a(&db, abc).is_err());
// A --> B
abc.set_b(&mut db, CycleQuery::None);
assert!(cycle_a(&db, abc).is_ok());
}
#[test]
fn cycle_mixed_1() {
let mut db = Database::default();
// A --> B <-- C
// | ^
// +-----+
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::C, CycleQuery::B);
let expected = expect![[r#"
[
"DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }",
]
"#]];
expected.assert_debug_eq(&cycle_c(&db, abc).unwrap_err().cycle);
}
#[test]
fn cycle_mixed_2() {
let mut db = Database::default();
// Configuration:
//
// A --> B --> C
// ^ |
// +-----------+
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::C, CycleQuery::A);
let expected = expect![[r#"
[
"DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }",
]
"#]];
expected.assert_debug_eq(&cycle_a(&db, abc).unwrap_err().cycle);
}
#[test]
fn cycle_deterministic_order() {
// No matter whether we start from A or B, we get the same set of participants:
let f = || {
let mut db = Database::default();
// A --> B
// ^ |
// +-----+
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::A, CycleQuery::None);
(db, abc)
};
let (db, abc) = f();
let a = cycle_a(&db, abc);
let (db, abc) = f();
let b = cycle_b(&db, abc);
let expected = expect![[r#"
(
[
"DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }",
],
[
"DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }",
],
)
"#]];
expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle));
}
#[test]
fn cycle_multiple() {
// No matter whether we start from A or B, we get the same set of participants:
let mut db = Database::default();
// Configuration:
//
// A --> B <-- C
// ^ | ^
// +-----+ |
// | |
// +-----+
//
// Here, conceptually, B encounters a cycle with A and then
// recovers.
let abc = ABC::new(&mut db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A);
let c = cycle_c(&db, abc);
let b = cycle_b(&db, abc);
let a = cycle_a(&db, abc);
let expected = expect![[r#"
(
[
"DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }",
],
[
"DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }",
],
[
"DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(10), key_index: Some(Id { value: 1 }) }",
],
)
"#]];
expected.assert_debug_eq(&(
c.unwrap_err().cycle,
b.unwrap_err().cycle,
a.unwrap_err().cycle,
));
}
#[test]
fn cycle_recovery_set_but_not_participating() {
let mut db = Database::default();
// A --> C -+
// ^ |
// +--+
let abc = ABC::new(&mut db, CycleQuery::C, CycleQuery::None, CycleQuery::C);
// Here we expect C to panic and A not to recover:
let r = extract_cycle(|| drop(cycle_a(&db, abc)));
let expected = expect![[r#"
[
"DependencyIndex { ingredient_index: IngredientIndex(11), key_index: Some(Id { value: 1 }) }",
]
"#]];
expected.assert_debug_eq(&r.all_participants(&db));
}

View file

@ -0,0 +1,7 @@
mod setup;
mod parallel_cycle_all_recover;
mod parallel_cycle_mid_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recover;
mod signal;

View file

@ -0,0 +1,112 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::Database;
use crate::setup::Knobs;
use salsa::ParallelDatabase;
pub(crate) trait Db: salsa::DbWithJar<Jar> + Knobs {}
impl<T: salsa::DbWithJar<Jar> + Knobs> Db for T {}
#[salsa::jar(db = Db)]
pub(crate) struct Jar(MyInput, a1, a2, b1, b2);
#[salsa::input(jar = Jar)]
pub(crate) struct MyInput {
field: i32,
}
#[salsa::tracked(jar = Jar, recovery_fn=recover_a1)]
pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
a2(db, input)
}
fn recover_a1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 {
dbg!("recover_a1");
key.field(db) * 10 + 1
}
#[salsa::tracked(jar = Jar, recovery_fn=recover_a2)]
pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 {
b1(db, input)
}
fn recover_a2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 {
dbg!("recover_a2");
key.field(db) * 10 + 2
}
#[salsa::tracked(jar = Jar, recovery_fn=recover_b1)]
pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
b2(db, input)
}
fn recover_b1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 {
dbg!("recover_b1");
key.field(db) * 20 + 1
}
#[salsa::tracked(jar = Jar, recovery_fn=recover_b2)]
pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 {
a1(db, input)
}
fn recover_b2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 {
dbg!("recover_b2");
key.field(db) * 20 + 2
}
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | signal stage 2
// (unblocked) wait for stage 3 (blocks)
// a2 |
// b1 (blocks -> stage 3) |
// | (unblocked)
// | b2
// | a1 (cycle detected, recovers)
// | b2 completes, recovers
// | b1 completes, recovers
// a2 sees cycle, recovers
// a1 completes, recovers
#[test]
fn execute() {
let mut db = Database::default();
db.knobs().signal_on_will_block.set(3);
let input = MyInput::new(&mut db, 1);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || a1(&*db, input)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || b1(&*db, input)
});
assert_eq!(thread_a.join().unwrap(), 11);
assert_eq!(thread_b.join().unwrap(), 21);
}

View file

@ -0,0 +1,110 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::Database;
use crate::setup::Knobs;
use salsa::ParallelDatabase;
pub(crate) trait Db: salsa::DbWithJar<Jar> + Knobs {}
impl<T: salsa::DbWithJar<Jar> + Knobs> Db for T {}
#[salsa::jar(db = Db)]
pub(crate) struct Jar(MyInput, a1, a2, b1, b2, b3);
#[salsa::input(jar = Jar)]
pub(crate) struct MyInput {
field: i32,
}
#[salsa::tracked(jar = Jar)]
pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 {
// tell thread b we have started
db.signal(1);
// wait for thread b to block on a1
db.wait_for(2);
a2(db, input)
}
#[salsa::tracked(jar = Jar)]
pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 {
// create the cycle
b1(db, input)
}
#[salsa::tracked(jar = Jar, recovery_fn=recover_b1)]
pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 {
// wait for thread a to have started
db.wait_for(1);
b2(db, input)
}
fn recover_b1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 {
dbg!("recover_b1");
key.field(db) * 20 + 2
}
#[salsa::tracked(jar = Jar)]
pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 {
// will encounter a cycle but recover
b3(db, input);
b1(db, input); // hasn't recovered yet
0
}
#[salsa::tracked(jar = Jar, recovery_fn=recover_b3)]
pub(crate) fn b3(db: &dyn Db, input: MyInput) -> i32 {
// will block on thread a, signaling stage 2
a1(db, input)
}
fn recover_b3(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 {
dbg!("recover_b3");
key.field(db) * 200 + 2
}
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | |
// | b2
// | b3
// | a1 (blocks -> stage 2)
// (unblocked) |
// a2 (cycle detected) |
// b3 recovers
// b2 resumes
// b1 recovers
#[test]
fn execute() {
let mut db = Database::default();
db.knobs().signal_on_will_block.set(3);
let input = MyInput::new(&mut db, 1);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || a1(&*db, input)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || b1(&*db, input)
});
// We expect that the recovery function yields
// `1 * 20 + 2`, which is returned (and forwarded)
// to b1, and from there to a2 and a1.
assert_eq!(thread_a.join().unwrap(), 22);
assert_eq!(thread_b.join().unwrap(), 22);
}

View file

@ -0,0 +1,83 @@
//! Test a cycle where no queries recover that occurs across threads.
//! See the `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::Database;
use crate::setup::Knobs;
use expect_test::expect;
use salsa::ParallelDatabase;
pub(crate) trait Db: salsa::DbWithJar<Jar> + Knobs {}
impl<T: salsa::DbWithJar<Jar> + Knobs> Db for T {}
#[salsa::jar(db = Db)]
pub(crate) struct Jar(MyInput, a, b);
#[salsa::input(jar = Jar)]
pub(crate) struct MyInput {
field: i32,
}
#[salsa::tracked(jar = Jar)]
pub(crate) fn a(db: &dyn Db, input: MyInput) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
b(db, input)
}
#[salsa::tracked(jar = Jar)]
pub(crate) fn b(db: &dyn Db, input: MyInput) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
// Now try to execute A
a(db, input)
}
#[test]
fn execute() {
let mut db = Database::default();
db.knobs().signal_on_will_block.set(3);
let input = MyInput::new(&mut db, -1);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || a(&*db, input)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || b(&*db, input)
});
// We expect B to panic because it detects a cycle (it is the one that calls A, ultimately).
// Right now, it panics with a string.
let err_b = thread_b.join().unwrap_err();
if let Some(c) = err_b.downcast_ref::<salsa::Cycle>() {
let expected = expect![[r#"
[
"DependencyIndex { ingredient_index: IngredientIndex(8), key_index: Some(Id { value: 1 }) }",
"DependencyIndex { ingredient_index: IngredientIndex(9), key_index: Some(Id { value: 1 }) }",
]
"#]];
expected.assert_debug_eq(&c.all_participants(&db));
} else {
panic!("b failed in an unexpected way: {:?}", err_b);
}
// We expect A to propagate a panic, which causes us to use the sentinel
// type `Canceled`.
assert!(thread_a
.join()
.unwrap_err()
.downcast_ref::<salsa::Cancelled>()
.is_some());
}

View file

@ -0,0 +1,99 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::Database;
use crate::setup::Knobs;
use salsa::ParallelDatabase;
pub(crate) trait Db: salsa::DbWithJar<Jar> + Knobs {}
impl<T: salsa::DbWithJar<Jar> + Knobs> Db for T {}
#[salsa::jar(db = Db)]
pub(crate) struct Jar(MyInput, a1, a2, b1, b2);
#[salsa::input(jar = Jar)]
pub(crate) struct MyInput {
field: i32,
}
#[salsa::tracked(jar = Jar)]
pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
a2(db, input)
}
#[salsa::tracked(jar = Jar, recovery_fn=recover)]
pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 {
b1(db, input)
}
fn recover(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 {
dbg!("recover");
key.field(db) * 20 + 2
}
#[salsa::tracked(jar = Jar)]
pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
b2(db, input)
}
#[salsa::tracked(jar = Jar)]
pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 {
a1(db, input)
}
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | signal stage 2
// (unblocked) wait for stage 3 (blocks)
// a2 |
// b1 (blocks -> stage 3) |
// | (unblocked)
// | b2
// | a1 (cycle detected)
// a2 recovery fn executes |
// a1 completes normally |
// b2 completes, recovers
// b1 completes, recovers
#[test]
fn execute() {
let mut db = Database::default();
db.knobs().signal_on_will_block.set(3);
let input = MyInput::new(&mut db, 1);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || a1(&*db, input)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || b1(&*db, input)
});
// We expect that the recovery function yields
// `1 * 20 + 2`, which is returned (and forwarded)
// to b1, and from there to a2 and a1.
assert_eq!(thread_a.join().unwrap(), 22);
assert_eq!(thread_b.join().unwrap(), 22);
}

View file

@ -0,0 +1,73 @@
use std::{cell::Cell, sync::Arc};
use crate::signal::Signal;
/// Various "knobs" and utilities used by tests to force
/// a certain behavior.
pub(crate) trait Knobs {
fn knobs(&self) -> &KnobsStruct;
fn signal(&self, stage: usize);
fn wait_for(&self, stage: usize);
}
/// Various "knobs" that can be used to customize how the queries
/// behave on one specific thread. Note that this state is
/// intentionally thread-local (apart from `signal`).
#[derive(Clone, Default)]
pub(crate) struct KnobsStruct {
/// A kind of flexible barrier used to coordinate execution across
/// threads to ensure we reach various weird states.
pub(crate) signal: Arc<Signal>,
/// When this database is about to block, send a signal.
pub(crate) signal_on_will_block: Cell<usize>,
}
#[salsa::db(
crate::parallel_cycle_one_recover::Jar,
crate::parallel_cycle_none_recover::Jar,
crate::parallel_cycle_mid_recover::Jar,
crate::parallel_cycle_all_recover::Jar
)]
#[derive(Default)]
pub(crate) struct Database {
storage: salsa::Storage<Self>,
knobs: KnobsStruct,
}
impl salsa::Database for Database {
fn salsa_runtime(&self) -> &salsa::Runtime {
self.storage.runtime()
}
fn salsa_event(&self, event: salsa::Event) {
if let salsa::EventKind::WillBlockOn { .. } = event.kind {
self.signal(self.knobs().signal_on_will_block.get());
}
}
}
impl salsa::ParallelDatabase for Database {
fn snapshot(&self) -> salsa::Snapshot<Self> {
salsa::Snapshot::new(Database {
storage: self.storage.snapshot(),
knobs: self.knobs.clone(),
})
}
}
impl Knobs for Database {
fn knobs(&self) -> &KnobsStruct {
&self.knobs
}
fn signal(&self, stage: usize) {
self.knobs.signal.signal(stage);
}
fn wait_for(&self, stage: usize) {
self.knobs.signal.wait_for(stage);
}
}

View file

@ -0,0 +1,40 @@
use parking_lot::{Condvar, Mutex};
#[derive(Default)]
pub(crate) struct Signal {
value: Mutex<usize>,
cond_var: Condvar,
}
impl Signal {
pub(crate) fn signal(&self, stage: usize) {
dbg!(format!("signal({})", stage));
// This check avoids acquiring the lock for things that will
// clearly be a no-op. Not *necessary* but helps to ensure we
// are more likely to encounter weird race conditions;
// otherwise calls to `sum` will tend to be unnecessarily
// synchronous.
if stage > 0 {
let mut v = self.value.lock();
if stage > *v {
*v = stage;
self.cond_var.notify_all();
}
}
}
/// Waits until the given condition is true; the fn is invoked
/// with the current stage.
pub(crate) fn wait_for(&self, stage: usize) {
dbg!(format!("wait_for({})", stage));
// As above, avoid lock if clearly a no-op.
if stage > 0 {
let mut v = self.value.lock();
while *v < stage {
self.cond_var.wait(&mut v);
}
}
}
}