mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-13 00:40:22 +00:00
move the test setup in setup.rs
This commit is contained in:
parent
8f03f3bb76
commit
fb5ba07290
3 changed files with 155 additions and 123 deletions
|
@ -1,125 +1,6 @@
|
||||||
use parking_lot::{Condvar, Mutex};
|
use crate::setup::{Input, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
|
||||||
use salsa::Database;
|
use salsa::Database;
|
||||||
use salsa::ParallelDatabase;
|
use salsa::ParallelDatabase;
|
||||||
use std::cell::Cell;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct ParDatabaseImpl {
|
|
||||||
runtime: salsa::Runtime<ParDatabaseImpl>,
|
|
||||||
signal: Arc<Signal>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Database for ParDatabaseImpl {
|
|
||||||
fn salsa_runtime(&self) -> &salsa::Runtime<ParDatabaseImpl> {
|
|
||||||
&self.runtime
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ParallelDatabase for ParDatabaseImpl {
|
|
||||||
fn fork(&self) -> Self {
|
|
||||||
ParDatabaseImpl {
|
|
||||||
runtime: self.runtime.fork(),
|
|
||||||
signal: self.signal.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
salsa::database_storage! {
|
|
||||||
pub struct DatabaseImplStorage for ParDatabaseImpl {
|
|
||||||
impl ParDatabase {
|
|
||||||
fn input() for Input;
|
|
||||||
fn sum() for Sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
salsa::query_group! {
|
|
||||||
trait ParDatabase: HasSignal + salsa::Database {
|
|
||||||
fn input(key: char) -> usize {
|
|
||||||
type Input;
|
|
||||||
storage input;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sum(key: &'static str) -> usize {
|
|
||||||
type Sum;
|
|
||||||
use fn sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is used to force `sum` to block on the signal sometimes so
|
|
||||||
// that we can forcibly arrange race conditions we would like to test.
|
|
||||||
thread_local! {
|
|
||||||
static SUM_SHOULD_AWAIT_CANCELLATION: Cell<bool> = Cell::new(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
trait HasSignal {
|
|
||||||
fn signal(&self) -> &Signal;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HasSignal for ParDatabaseImpl {
|
|
||||||
fn signal(&self) -> &Signal {
|
|
||||||
&self.signal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
struct Signal {
|
|
||||||
value: Mutex<usize>,
|
|
||||||
cond_var: Condvar,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Signal {
|
|
||||||
fn signal(&self, stage: usize) {
|
|
||||||
log::debug!("signal({})", stage);
|
|
||||||
let mut v = self.value.lock();
|
|
||||||
assert!(
|
|
||||||
stage > *v,
|
|
||||||
"stage should be increasing monotonically (old={}, new={})",
|
|
||||||
*v,
|
|
||||||
stage
|
|
||||||
);
|
|
||||||
*v = stage;
|
|
||||||
self.cond_var.notify_all();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Waits until the given condition is true; the fn is invoked
|
|
||||||
/// with the current stage.
|
|
||||||
fn await(&self, stage: usize) {
|
|
||||||
log::debug!("await({})", stage);
|
|
||||||
let mut v = self.value.lock();
|
|
||||||
while *v < stage {
|
|
||||||
self.cond_var.wait(&mut v);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
|
|
||||||
let mut sum = 0;
|
|
||||||
|
|
||||||
// If we are going to await cancellation, we first *signal* when
|
|
||||||
// we have entered. This way, the other thread can wait and be
|
|
||||||
// sure that we are executing `sum`.
|
|
||||||
if SUM_SHOULD_AWAIT_CANCELLATION.with(|s| s.get()) {
|
|
||||||
db.signal().signal(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
for ch in key.chars() {
|
|
||||||
sum += db.input(ch);
|
|
||||||
}
|
|
||||||
|
|
||||||
if SUM_SHOULD_AWAIT_CANCELLATION.with(|s| s.get()) {
|
|
||||||
log::debug!("awaiting cancellation");
|
|
||||||
while !db.salsa_runtime().is_current_revision_canceled() {
|
|
||||||
std::thread::yield_now();
|
|
||||||
}
|
|
||||||
log::debug!("cancellation observed");
|
|
||||||
return std::usize::MAX; // when we are cancelled, we return usize::MAX.
|
|
||||||
}
|
|
||||||
|
|
||||||
sum
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn in_par() {
|
fn in_par() {
|
||||||
|
@ -180,12 +61,13 @@ fn in_par_get_set_cancellation() {
|
||||||
let thread1 = std::thread::spawn({
|
let thread1 = std::thread::spawn({
|
||||||
let db = db.fork();
|
let db = db.fork();
|
||||||
move || {
|
move || {
|
||||||
SUM_SHOULD_AWAIT_CANCELLATION.with(|c| c.set(true));
|
let v1 = db.sum_signal_on_entry().with_value(1, || {
|
||||||
let v1 = db.sum("abc");
|
db.sum_await_cancellation()
|
||||||
|
.with_value(true, || db.sum("abc"))
|
||||||
|
});
|
||||||
|
|
||||||
// check that we observed cancellation
|
// check that we observed cancellation
|
||||||
assert_eq!(v1, std::usize::MAX);
|
assert_eq!(v1, std::usize::MAX);
|
||||||
SUM_SHOULD_AWAIT_CANCELLATION.with(|c| c.set(false));
|
|
||||||
|
|
||||||
// at this point, we have observed cancellation, so let's
|
// at this point, we have observed cancellation, so let's
|
||||||
// wait until the `set` is known to have occurred.
|
// wait until the `set` is known to have occurred.
|
||||||
|
|
|
@ -1 +1,3 @@
|
||||||
|
mod setup;
|
||||||
|
|
||||||
mod cancellation;
|
mod cancellation;
|
||||||
|
|
148
tests/parallel/setup.rs
Normal file
148
tests/parallel/setup.rs
Normal file
|
@ -0,0 +1,148 @@
|
||||||
|
use parking_lot::{Condvar, Mutex};
|
||||||
|
use salsa::Database;
|
||||||
|
use salsa::ParallelDatabase;
|
||||||
|
use std::cell::Cell;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
salsa::query_group! {
|
||||||
|
pub(crate) trait ParDatabase: Knobs + salsa::Database {
|
||||||
|
fn input(key: char) -> usize {
|
||||||
|
type Input;
|
||||||
|
storage input;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum(key: &'static str) -> usize {
|
||||||
|
type Sum;
|
||||||
|
use fn sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Various "knobs" and utilities used by tests to force
|
||||||
|
/// a certain behavior.
|
||||||
|
pub(crate) trait Knobs {
|
||||||
|
fn signal(&self) -> &Signal;
|
||||||
|
|
||||||
|
/// Invocations of `sum` will signal `stage` this stage on entry.
|
||||||
|
fn sum_signal_on_entry(&self) -> &Cell<usize>;
|
||||||
|
|
||||||
|
/// If set to true, invocations of `sum` will await cancellation
|
||||||
|
/// before they exit.
|
||||||
|
fn sum_await_cancellation(&self) -> &Cell<bool>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) trait WithValue<T> {
|
||||||
|
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> WithValue<T> for Cell<T> {
|
||||||
|
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
|
||||||
|
let old_value = self.replace(value);
|
||||||
|
|
||||||
|
let result = closure();
|
||||||
|
|
||||||
|
self.set(old_value);
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct KnobsStruct {
|
||||||
|
signal: Arc<Signal>,
|
||||||
|
sum_signal_on_entry: Cell<usize>,
|
||||||
|
sum_await_cancellation: Cell<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub(crate) struct Signal {
|
||||||
|
value: Mutex<usize>,
|
||||||
|
cond_var: Condvar,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Signal {
|
||||||
|
pub(crate) fn signal(&self, stage: usize) {
|
||||||
|
log::debug!("signal({})", stage);
|
||||||
|
let mut v = self.value.lock();
|
||||||
|
if stage > *v {
|
||||||
|
*v = stage;
|
||||||
|
self.cond_var.notify_all();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Waits until the given condition is true; the fn is invoked
|
||||||
|
/// with the current stage.
|
||||||
|
pub(crate) fn await(&self, stage: usize) {
|
||||||
|
log::debug!("await({})", stage);
|
||||||
|
let mut v = self.value.lock();
|
||||||
|
while *v < stage {
|
||||||
|
self.cond_var.wait(&mut v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
|
||||||
|
let mut sum = 0;
|
||||||
|
|
||||||
|
let stage = db.sum_signal_on_entry().get();
|
||||||
|
db.signal().signal(stage);
|
||||||
|
|
||||||
|
for ch in key.chars() {
|
||||||
|
sum += db.input(ch);
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.sum_await_cancellation().get() {
|
||||||
|
log::debug!("awaiting cancellation");
|
||||||
|
while !db.salsa_runtime().is_current_revision_canceled() {
|
||||||
|
std::thread::yield_now();
|
||||||
|
}
|
||||||
|
log::debug!("cancellation observed");
|
||||||
|
return std::usize::MAX; // when we are cancelled, we return usize::MAX.
|
||||||
|
}
|
||||||
|
|
||||||
|
sum
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ParDatabaseImpl {
|
||||||
|
runtime: salsa::Runtime<ParDatabaseImpl>,
|
||||||
|
knobs: KnobsStruct,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Database for ParDatabaseImpl {
|
||||||
|
fn salsa_runtime(&self) -> &salsa::Runtime<ParDatabaseImpl> {
|
||||||
|
&self.runtime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ParallelDatabase for ParDatabaseImpl {
|
||||||
|
fn fork(&self) -> Self {
|
||||||
|
ParDatabaseImpl {
|
||||||
|
runtime: self.runtime.fork(),
|
||||||
|
knobs: self.knobs.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Knobs for ParDatabaseImpl {
|
||||||
|
fn signal(&self) -> &Signal {
|
||||||
|
&self.knobs.signal
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum_signal_on_entry(&self) -> &Cell<usize> {
|
||||||
|
&self.knobs.sum_signal_on_entry
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum_await_cancellation(&self) -> &Cell<bool> {
|
||||||
|
&self.knobs.sum_await_cancellation
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
salsa::database_storage! {
|
||||||
|
pub struct DatabaseImplStorage for ParDatabaseImpl {
|
||||||
|
impl ParDatabase {
|
||||||
|
fn input() for Input;
|
||||||
|
fn sum() for Sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue