diff --git a/tests/parallel/race.rs b/tests/parallel/race.rs index b862096c..234aea8e 100644 --- a/tests/parallel/race.rs +++ b/tests/parallel/race.rs @@ -25,9 +25,14 @@ fn in_par_get_set_race() { }); // If the 1st thread runs first, you get 111, otherwise you get - // 1011. + // 1011; if they run concurrently and the 1st thread observes the + // cancelation, you get back usize::max. let value1 = thread1.join().unwrap(); - assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1); + assert!( + value1 == 111 || value1 == 1011 || value1 == std::usize::MAX, + "illegal result {}", + value1 + ); assert_eq!(thread2.join().unwrap(), 1000); } diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index d4c58880..f66986f1 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -109,6 +109,17 @@ fn sum(db: &impl ParDatabase, key: &'static str) -> usize { std::thread::yield_now(); } log::debug!("cancellation observed"); + } + + // Check for cancelation and return MAX if so. Note that we check + // for cancelation *deterministically* -- but if + // `sum_wait_for_cancellation` is set, we will block + // beforehand. Deterministic execution is a requirement for valid + // salsa user code. It's also important to some tests that `sum` + // *attempts* to invoke `is_current_revision_canceled` even if we + // know it will not be canceled, because that helps us keep the + // accounting up to date. + if db.salsa_runtime().is_current_revision_canceled() { return std::usize::MAX; // when we are cancelled, we return usize::MAX. }