diff --git a/Cargo.toml b/Cargo.toml index 9b78d04b..5e8b17d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,4 +19,5 @@ smallvec = "0.6.5" [dev-dependencies] diff = "0.1.0" -env_logger = "0.5.13" \ No newline at end of file +env_logger = "0.5.13" +rand = "0.5.5" diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index c1bb0004..7b8823e1 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -6,3 +6,4 @@ mod race; mod revision_lock; mod signal; mod true_parallel; +mod stress; diff --git a/tests/parallel/stress.rs b/tests/parallel/stress.rs new file mode 100644 index 00000000..32670695 --- /dev/null +++ b/tests/parallel/stress.rs @@ -0,0 +1,169 @@ +use rand::Rng; + +use salsa::Database; +use salsa::ParallelDatabase; +use salsa::SweepStrategy; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct Canceled; +type Cancelable = Result; + +salsa::query_group! { + trait StressDatabase: salsa::Database { + fn a(key: usize) -> usize { + type A; + storage input; + } + + fn b(key: usize) -> Cancelable { + type B; + } + + fn c(key: usize) -> Cancelable { + type C; + } + } +} + +fn b(db: &impl StressDatabase, key: usize) -> Cancelable { + if db.salsa_runtime().is_current_revision_canceled() { + return Err(Canceled); + } + Ok(db.a(key)) +} + +fn c(db: &impl StressDatabase, key: usize) -> Cancelable { + db.b(key) +} + +#[derive(Default)] +struct StressDatabaseImpl { + runtime: salsa::Runtime +} + +impl salsa::Database for StressDatabaseImpl { + fn salsa_runtime(&self) -> &salsa::Runtime { + &self.runtime + } +} + +impl salsa::ParallelDatabase for StressDatabaseImpl { + fn fork(&self) -> StressDatabaseImpl { + StressDatabaseImpl { runtime: self.runtime.fork() } + } +} + +salsa::database_storage! { + pub struct DatabaseImplStorage for StressDatabaseImpl { + impl StressDatabase { + fn a() for A; + fn b() for B; + fn c() for C; + } + } +} + +#[derive(Clone, Copy, Debug)] +enum Query { A, B, C } + +#[derive(Debug)] +enum Op { + SetA(usize, usize), + Get(Query, usize), + Gc(Query, SweepStrategy), + GcAll(SweepStrategy), +} + +impl rand::distributions::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> Query { + *rng.choose(&[Query::A, Query::B, Query::C]).unwrap() + } +} + +impl rand::distributions::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> Op { + if rng.gen_bool(0.5) { + let query = rng.gen::(); + let key = rng.gen::() % 10; + return Op::Get(query, key); + } + if rng.gen_bool(0.5) { + let key = rng.gen::() % 10; + let value = rng.gen::() % 10; + return Op::SetA(key, value); + } + let mut strategy = SweepStrategy::default(); + if rng.gen_bool(0.5) { + strategy = strategy.discard_values(); + } + if rng.gen_bool(0.5) { + Op::Gc(rng.gen::(), strategy) + } else { + Op::GcAll(strategy) + } + } +} + +fn db_thread(db: StressDatabaseImpl, ops: Vec) { + for op in ops { + // eprintln!("{:02?}: {:?}", std::thread::current().id(), op); + match op { + Op::SetA(key, value) => { + db.query(A).set(key, value); + } + Op::Get(query, key) => { + match query { + Query::A => { + db.a(key); + }, + Query::B => { + let _ = db.b(key); + }, + Query::C => { + let _ = db.c(key); + }, + } + } + Op::Gc(query, strategy) => { + match query { + Query::A => { + db.query(A).sweep(strategy); + }, + Query::B => { + db.query(B).sweep(strategy); + }, + Query::C => { + db.query(C).sweep(strategy); + }, + } + } + Op::GcAll(strategy) => { + db.sweep_all(strategy); + } + } + } +} + +fn random_ops(n_ops: usize) -> Vec { + let mut rng = rand::thread_rng(); + (0..n_ops).map(|_| rng.gen::()).collect() +} + +#[test] +fn stress_test() { + let db = StressDatabaseImpl::default(); + for i in 0..10 { + db.query(A).set(i, i); + } + let n_threads = 20; + let n_ops = 100; + let ops = (0..n_threads).map(|_| random_ops(n_ops)); + let threads = ops.into_iter().map(|ops| { + let db = db.fork(); + std::thread::spawn(move || db_thread(db, ops)) + }).collect::>(); + std::mem::drop(db); + for thread in threads { + thread.join().unwrap(); + } +}