diff --git a/examples/hello_world/main.rs b/examples/hello_world/main.rs index f7f6ac29..4c9a6304 100644 --- a/examples/hello_world/main.rs +++ b/examples/hello_world/main.rs @@ -93,11 +93,11 @@ salsa::database_storage! { // This shows how to use a query. fn main() { - let db = DatabaseStruct::default(); + let mut db = DatabaseStruct::default(); println!("Initially, the length is {}.", db.length(())); - db.query(InputString) + db.query_mut(InputString) .set((), Arc::new(format!("Hello, world"))); println!("Now, the length is {}.", db.length(())); diff --git a/src/lib.rs b/src/lib.rs index 434f48f6..342ce29f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,6 +63,17 @@ pub trait Database: plumbing::DatabaseStorageTypes + plumbing::DatabaseOps { >::get_query_table(self) } + /// Get access to extra methods pertaining to a given query, + /// notably `set` (for inputs). + #[allow(unused_variables)] + fn query_mut(&mut self, query: Q) -> QueryTableMut<'_, Self, Q> + where + Q: Query, + Self: plumbing::GetQueryTable, + { + >::get_query_table_mut(self) + } + /// This function is invoked at key points in the salsa /// runtime. It permits the database to be customized and to /// inject logging or other custom behavior. @@ -300,6 +311,31 @@ where self.storage.sweep(self.db, strategy); } + fn descriptor(&self, key: &Q::Key) -> DB::QueryDescriptor { + (self.descriptor_fn)(self.db, key) + } +} + +#[derive(new)] +pub struct QueryTableMut<'me, DB, Q> +where + DB: Database + 'me, + Q: Query + 'me, +{ + db: &'me DB, + storage: &'me Q::Storage, + descriptor_fn: fn(&DB, &Q::Key) -> DB::QueryDescriptor, +} + +impl QueryTableMut<'_, DB, Q> +where + DB: Database, + Q: Query, +{ + fn descriptor(&self, key: &Q::Key) -> DB::QueryDescriptor { + (self.descriptor_fn)(self.db, key) + } + /// Assign a value to an "input query". Must be used outside of /// an active query computation. pub fn set(&self, key: Q::Key, value: Q::Value) @@ -342,10 +378,6 @@ where { self.storage.set_unchecked(self.db, &key, value); } - - fn descriptor(&self, key: &Q::Key) -> DB::QueryDescriptor { - (self.descriptor_fn)(self.db, key) - } } /// A macro that helps in defining the "context trait" of a given @@ -748,6 +780,24 @@ macro_rules! database_storage { }, ) } + + fn get_query_table_mut( + db: &mut Self, + ) -> $crate::QueryTableMut<'_, Self, $QueryType> { + let db = &*db; + $crate::QueryTableMut::new( + db, + &$crate::Database::salsa_runtime(db) + .storage() + .$query_method, + |_, key| { + let key = std::clone::Clone::clone(key); + __SalsaQueryDescriptor { + kind: __SalsaQueryDescriptorKind::$query_method(key), + } + }, + ) + } } )* )* diff --git a/src/plumbing.rs b/src/plumbing.rs index d6810665..4b35edfe 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -1,6 +1,7 @@ use crate::Database; use crate::Query; use crate::QueryTable; +use crate::QueryTableMut; use crate::SweepStrategy; use std::fmt::Debug; use std::hash::Hash; @@ -57,6 +58,8 @@ pub trait QueryFunction: Query { pub trait GetQueryTable>: Database { fn get_query_table(db: &Self) -> QueryTable<'_, Self, Q>; + + fn get_query_table_mut(db: &mut Self) -> QueryTableMut<'_, Self, Q>; } pub trait QueryStorageOps: Default diff --git a/tests/gc/derived_tests.rs b/tests/gc/derived_tests.rs index 88223374..ca2bab67 100644 --- a/tests/gc/derived_tests.rs +++ b/tests/gc/derived_tests.rs @@ -15,10 +15,10 @@ macro_rules! assert_keys { #[test] fn compute_one() { - let db = db::DatabaseImpl::default(); + let mut db = db::DatabaseImpl::default(); // Will compute fibonacci(5) - db.query(UseTriangular).set(5, false); + db.query_mut(UseTriangular).set(5, false); db.compute(5); db.salsa_runtime().next_revision(); @@ -50,14 +50,14 @@ fn compute_one() { #[test] fn compute_switch() { - let db = db::DatabaseImpl::default(); + let mut db = db::DatabaseImpl::default(); // Will compute fibonacci(5) - db.query(UseTriangular).set(5, false); + db.query_mut(UseTriangular).set(5, false); assert_eq!(db.compute(5), 5); // Change to triangular mode - db.query(UseTriangular).set(5, true); + db.query_mut(UseTriangular).set(5, true); // Now computes triangular(5) assert_eq!(db.compute(5), 15); @@ -107,14 +107,14 @@ fn compute_switch() { /// Test a query with multiple layers of keys. #[test] fn compute_all() { - let db = db::DatabaseImpl::default(); + let mut db = db::DatabaseImpl::default(); for i in 0..6 { - db.query(UseTriangular).set(i, (i % 2) != 0); + db.query_mut(UseTriangular).set(i, (i % 2) != 0); } - db.query(Min).set((), 0); - db.query(Max).set((), 6); + db.query_mut(Min).set((), 0); + db.query_mut(Max).set((), 6); db.compute_all(); db.salsa_runtime().next_revision(); @@ -133,7 +133,7 @@ fn compute_all() { } // Reduce the range to exclude index 5. - db.query(Max).set((), 5); + db.query_mut(Max).set((), 5); db.compute_all(); assert_keys! { diff --git a/tests/incremental/constants.rs b/tests/incremental/constants.rs index 1d7ebc43..f7bc960f 100644 --- a/tests/incremental/constants.rs +++ b/tests/incremental/constants.rs @@ -23,24 +23,24 @@ fn constants_add(db: &impl ConstantsDatabase, (key1, key2): (char, char)) -> usi #[test] #[should_panic] fn invalidate_constant() { - let db = &TestContextImpl::default(); - db.query(ConstantsInput).set_constant('a', 44); - db.query(ConstantsInput).set_constant('a', 66); + let db = &mut TestContextImpl::default(); + db.query_mut(ConstantsInput).set_constant('a', 44); + db.query_mut(ConstantsInput).set_constant('a', 66); } #[test] #[should_panic] fn invalidate_constant_1() { - let db = &TestContextImpl::default(); + let db = &mut TestContextImpl::default(); // Not constant: - db.query(ConstantsInput).set('a', 44); + db.query_mut(ConstantsInput).set('a', 44); // Becomes constant: - db.query(ConstantsInput).set_constant('a', 44); + db.query_mut(ConstantsInput).set_constant('a', 44); // Invalidates: - db.query(ConstantsInput).set_constant('a', 66); + db.query_mut(ConstantsInput).set_constant('a', 66); } /// Test that invoking `set` on a constant is an error, even if you @@ -48,55 +48,55 @@ fn invalidate_constant_1() { #[test] #[should_panic] fn set_after_constant_same_value() { - let db = &TestContextImpl::default(); - db.query(ConstantsInput).set_constant('a', 44); - db.query(ConstantsInput).set('a', 44); + let db = &mut TestContextImpl::default(); + db.query_mut(ConstantsInput).set_constant('a', 44); + db.query_mut(ConstantsInput).set('a', 44); } #[test] fn not_constant() { - let db = &TestContextImpl::default(); + let db = &mut TestContextImpl::default(); - db.query(ConstantsInput).set('a', 22); - db.query(ConstantsInput).set('b', 44); + db.query_mut(ConstantsInput).set('a', 22); + db.query_mut(ConstantsInput).set('b', 44); assert_eq!(db.constants_add(('a', 'b')), 66); assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); } #[test] fn is_constant() { - let db = &TestContextImpl::default(); + let db = &mut TestContextImpl::default(); - db.query(ConstantsInput).set_constant('a', 22); - db.query(ConstantsInput).set_constant('b', 44); + db.query_mut(ConstantsInput).set_constant('a', 22); + db.query_mut(ConstantsInput).set_constant('b', 44); assert_eq!(db.constants_add(('a', 'b')), 66); assert!(db.query(ConstantsAdd).is_constant(('a', 'b'))); } #[test] fn mixed_constant() { - let db = &TestContextImpl::default(); + let db = &mut TestContextImpl::default(); - db.query(ConstantsInput).set_constant('a', 22); - db.query(ConstantsInput).set('b', 44); + db.query_mut(ConstantsInput).set_constant('a', 22); + db.query_mut(ConstantsInput).set('b', 44); assert_eq!(db.constants_add(('a', 'b')), 66); assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); } #[test] fn becomes_constant_with_change() { - let db = &TestContextImpl::default(); + let db = &mut TestContextImpl::default(); - db.query(ConstantsInput).set('a', 22); - db.query(ConstantsInput).set('b', 44); + db.query_mut(ConstantsInput).set('a', 22); + db.query_mut(ConstantsInput).set('b', 44); assert_eq!(db.constants_add(('a', 'b')), 66); assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); - db.query(ConstantsInput).set_constant('a', 23); + db.query_mut(ConstantsInput).set_constant('a', 23); assert_eq!(db.constants_add(('a', 'b')), 67); assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); - db.query(ConstantsInput).set_constant('b', 45); + db.query_mut(ConstantsInput).set_constant('b', 45); assert_eq!(db.constants_add(('a', 'b')), 68); assert!(db.query(ConstantsAdd).is_constant(('a', 'b'))); } diff --git a/tests/incremental/memoized_dep_inputs.rs b/tests/incremental/memoized_dep_inputs.rs index fd305999..ac895845 100644 --- a/tests/incremental/memoized_dep_inputs.rs +++ b/tests/incremental/memoized_dep_inputs.rs @@ -41,9 +41,9 @@ fn dep_derived1(db: &impl MemoizedDepInputsContext) -> usize { #[test] fn revalidate() { - let db = &TestContextImpl::default(); + let db = &mut TestContextImpl::default(); - db.query(Input1).set((), 0); + db.query_mut(Input1).set((), 0); // Initial run starts from Memoized2: let v = db.dep_memoized2(); @@ -53,19 +53,19 @@ fn revalidate() { // After that, we first try to validate Memoized1 but wind up // running Memoized2. Note that we don't try to validate // Derived1, so it is invoked by Memoized1. - db.query(Input1).set((), 44); + db.query_mut(Input1).set((), 44); let v = db.dep_memoized2(); assert_eq!(v, 44); db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]); // Here validation of Memoized1 succeeds so Memoized2 never runs. - db.query(Input1).set((), 45); + db.query_mut(Input1).set((), 45); let v = db.dep_memoized2(); assert_eq!(v, 44); db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]); // Here, a change to input2 doesn't affect us, so nothing runs. - db.query(Input2).set((), 45); + db.query_mut(Input2).set((), 45); let v = db.dep_memoized2(); assert_eq!(v, 44); db.assert_log(&[]); diff --git a/tests/incremental/memoized_inputs.rs b/tests/incremental/memoized_inputs.rs index 5ff409b5..7957144e 100644 --- a/tests/incremental/memoized_inputs.rs +++ b/tests/incremental/memoized_inputs.rs @@ -24,10 +24,10 @@ fn max(db: &impl MemoizedInputsContext) -> usize { #[test] fn revalidate() { - let db = &TestContextImpl::default(); + let db = &mut TestContextImpl::default(); - db.query(Input1).set((), 0); - db.query(Input2).set((), 0); + db.query_mut(Input1).set((), 0); + db.query_mut(Input2).set((), 0); let v = db.max(); assert_eq!(v, 0); @@ -37,7 +37,7 @@ fn revalidate() { assert_eq!(v, 0); db.assert_log(&[]); - db.query(Input1).set((), 44); + db.query_mut(Input1).set((), 44); db.assert_log(&[]); let v = db.max(); @@ -48,11 +48,11 @@ fn revalidate() { assert_eq!(v, 44); db.assert_log(&[]); - db.query(Input1).set((), 44); + db.query_mut(Input1).set((), 44); db.assert_log(&[]); - db.query(Input2).set((), 66); + db.query_mut(Input2).set((), 66); db.assert_log(&[]); - db.query(Input1).set((), 64); + db.query_mut(Input1).set((), 64); db.assert_log(&[]); let v = db.max(); @@ -68,16 +68,16 @@ fn revalidate() { /// triggers a new revision. #[test] fn set_after_no_change() { - let db = &TestContextImpl::default(); + let db = &mut TestContextImpl::default(); - db.query(Input2).set((), 0); + db.query_mut(Input2).set((), 0); - db.query(Input1).set((), 44); + db.query_mut(Input1).set((), 44); let v = db.max(); assert_eq!(v, 44); db.assert_log(&["Max invoked"]); - db.query(Input1).set((), 44); + db.query_mut(Input1).set((), 44); let v = db.max(); assert_eq!(v, 44); db.assert_log(&["Max invoked"]); diff --git a/tests/panic_safely.rs b/tests/panic_safely.rs index d8dcc19b..132c0546 100644 --- a/tests/panic_safely.rs +++ b/tests/panic_safely.rs @@ -48,7 +48,7 @@ salsa::database_storage! { #[test] fn should_panic_safely() { - let db = DatabaseStruct::default(); + let mut db = DatabaseStruct::default(); // Invoke `db.panic_safely() without having set `db.one`. `db.one` will // default to 0 and we should catch the panic. @@ -59,7 +59,7 @@ fn should_panic_safely() { assert!(result.is_err()); // Set `db.one` to 1 and assert ok - db.query(One).set((), 1); + db.query_mut(One).set((), 1); let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); assert!(result.is_ok()) } diff --git a/tests/parallel/cancellation.rs b/tests/parallel/cancellation.rs index b6fa8024..27cde460 100644 --- a/tests/parallel/cancellation.rs +++ b/tests/parallel/cancellation.rs @@ -6,12 +6,12 @@ use salsa::{Database, ParallelDatabase}; /// though none of the inputs have changed. #[test] fn in_par_get_set_cancellation_immediate() { - let db = ParDatabaseImpl::default(); + let mut db = ParDatabaseImpl::default(); - db.query(Input).set('a', 100); - db.query(Input).set('b', 010); - db.query(Input).set('c', 001); - db.query(Input).set('d', 0); + db.query_mut(Input).set('a', 100); + db.query_mut(Input).set('b', 010); + db.query_mut(Input).set('c', 001); + db.query_mut(Input).set('d', 0); let thread1 = std::thread::spawn({ let db = db.snapshot(); @@ -30,7 +30,7 @@ fn in_par_get_set_cancellation_immediate() { db.wait_for(1); // Try to set the input. This will signal cancellation. - db.query(Input).set('d', 1000); + db.query_mut(Input).set('d', 1000); // This should re-compute the value (even though no input has changed). let thread2 = std::thread::spawn({ @@ -47,12 +47,12 @@ fn in_par_get_set_cancellation_immediate() { /// to `sum2` properly. #[test] fn in_par_get_set_cancellation_transitive() { - let db = ParDatabaseImpl::default(); + let mut db = ParDatabaseImpl::default(); - db.query(Input).set('a', 100); - db.query(Input).set('b', 010); - db.query(Input).set('c', 001); - db.query(Input).set('d', 0); + db.query_mut(Input).set('a', 100); + db.query_mut(Input).set('b', 010); + db.query_mut(Input).set('c', 001); + db.query_mut(Input).set('d', 0); let thread1 = std::thread::spawn({ let db = db.snapshot(); @@ -71,7 +71,7 @@ fn in_par_get_set_cancellation_transitive() { db.wait_for(1); // Try to set the input. This will signal cancellation. - db.query(Input).set('d', 1000); + db.query_mut(Input).set('d', 1000); // This should re-compute the value (even though no input has changed). let thread2 = std::thread::spawn({ diff --git a/tests/parallel/frozen.rs b/tests/parallel/frozen.rs index fb8970f7..9a830af3 100644 --- a/tests/parallel/frozen.rs +++ b/tests/parallel/frozen.rs @@ -8,9 +8,9 @@ use std::sync::Arc; /// though none of the inputs have changed. #[test] fn in_par_get_set_cancellation() { - let db = ParDatabaseImpl::default(); + let mut db = ParDatabaseImpl::default(); - db.query(Input).set('a', 1); + db.query_mut(Input).set('a', 1); let signal = Arc::new(Signal::default()); @@ -50,7 +50,7 @@ fn in_par_get_set_cancellation() { signal.wait_for(1); // This will block until thread1 drops the revision lock. - db.query(Input).set('a', 2); + db.query_mut(Input).set('a', 2); db.input('a') } diff --git a/tests/parallel/independent.rs b/tests/parallel/independent.rs index 27716b66..252b15db 100644 --- a/tests/parallel/independent.rs +++ b/tests/parallel/independent.rs @@ -5,14 +5,14 @@ use salsa::{Database, ParallelDatabase}; /// threads. Really just a test that `snapshot` etc compiles. #[test] fn in_par_two_independent_queries() { - let db = ParDatabaseImpl::default(); + let mut db = ParDatabaseImpl::default(); - db.query(Input).set('a', 100); - db.query(Input).set('b', 010); - db.query(Input).set('c', 001); - db.query(Input).set('d', 200); - db.query(Input).set('e', 020); - db.query(Input).set('f', 002); + db.query_mut(Input).set('a', 100); + db.query_mut(Input).set('b', 010); + db.query_mut(Input).set('c', 001); + db.query_mut(Input).set('d', 200); + db.query_mut(Input).set('e', 020); + db.query_mut(Input).set('f', 002); let thread1 = std::thread::spawn({ let db = db.snapshot(); diff --git a/tests/parallel/race.rs b/tests/parallel/race.rs index 3fd1e409..b862096c 100644 --- a/tests/parallel/race.rs +++ b/tests/parallel/race.rs @@ -5,11 +5,11 @@ use salsa::{Database, ParallelDatabase}; /// Should be atomic. #[test] fn in_par_get_set_race() { - let db = ParDatabaseImpl::default(); + let mut db = ParDatabaseImpl::default(); - db.query(Input).set('a', 100); - db.query(Input).set('b', 010); - db.query(Input).set('c', 001); + db.query_mut(Input).set('a', 100); + db.query_mut(Input).set('b', 010); + db.query_mut(Input).set('c', 001); let thread1 = std::thread::spawn({ let db = db.snapshot(); @@ -20,7 +20,7 @@ fn in_par_get_set_race() { }); let thread2 = std::thread::spawn(move || { - db.query(Input).set('a', 1000); + db.query_mut(Input).set('a', 1000); db.sum("a") }); diff --git a/tests/parallel/stress.rs b/tests/parallel/stress.rs index 4d312bd6..440e4bbc 100644 --- a/tests/parallel/stress.rs +++ b/tests/parallel/stress.rs @@ -157,7 +157,7 @@ impl WriteOp { fn execute(self, db: &mut StressDatabaseImpl) { match self { WriteOp::SetA(key, value) => { - db.query(A).set(key, value); + db.query_mut(A).set(key, value); } } } @@ -199,7 +199,7 @@ impl ReadOp { fn stress_test() { let mut db = StressDatabaseImpl::default(); for i in 0..10 { - db.query(A).set(i, i); + db.query_mut(A).set(i, i); } let mut rng = rand::thread_rng(); diff --git a/tests/parallel/true_parallel.rs b/tests/parallel/true_parallel.rs index 53645390..986dc9c0 100644 --- a/tests/parallel/true_parallel.rs +++ b/tests/parallel/true_parallel.rs @@ -8,11 +8,11 @@ use salsa::ParallelDatabase; /// waits for thread1 to send a signal before it enters). #[test] fn true_parallel_different_keys() { - let db = ParDatabaseImpl::default(); + let mut db = ParDatabaseImpl::default(); - db.query(Input).set('a', 100); - db.query(Input).set('b', 010); - db.query(Input).set('c', 001); + db.query_mut(Input).set('a', 100); + db.query_mut(Input).set('b', 010); + db.query_mut(Input).set('c', 001); // Thread 1 will signal stage 1 when it enters and wait for stage 2. let thread1 = std::thread::spawn({ @@ -48,11 +48,11 @@ fn true_parallel_different_keys() { /// therefore has to block. #[test] fn true_parallel_same_keys() { - let db = ParDatabaseImpl::default(); + let mut db = ParDatabaseImpl::default(); - db.query(Input).set('a', 100); - db.query(Input).set('b', 010); - db.query(Input).set('c', 001); + db.query_mut(Input).set('a', 100); + db.query_mut(Input).set('b', 010); + db.query_mut(Input).set('c', 001); // Thread 1 will wait_for a barrier in the start of `sum` let thread1 = std::thread::spawn({ diff --git a/tests/set_unchecked.rs b/tests/set_unchecked.rs index f27b0c8c..4a308883 100644 --- a/tests/set_unchecked.rs +++ b/tests/set_unchecked.rs @@ -50,10 +50,10 @@ salsa::database_storage! { #[test] fn normal() { - let db = DatabaseStruct::default(); - db.query(Input).set((), format!("Hello, world")); + let mut db = DatabaseStruct::default(); + db.query_mut(Input).set((), format!("Hello, world")); assert_eq!(db.double_length(), 24); - db.query(Input).set((), format!("Hello, world!")); + db.query_mut(Input).set((), format!("Hello, world!")); assert_eq!(db.double_length(), 26); } @@ -66,30 +66,32 @@ fn use_without_set() { #[test] fn using_set_unchecked_on_input() { - let db = DatabaseStruct::default(); - db.query(Input).set_unchecked((), format!("Hello, world")); + let mut db = DatabaseStruct::default(); + db.query_mut(Input) + .set_unchecked((), format!("Hello, world")); assert_eq!(db.double_length(), 24); } #[test] fn using_set_unchecked_on_input_after() { - let db = DatabaseStruct::default(); - db.query(Input).set((), format!("Hello, world")); + let mut db = DatabaseStruct::default(); + db.query_mut(Input).set((), format!("Hello, world")); assert_eq!(db.double_length(), 24); // If we use `set_unchecked`, we don't notice that `double_length` // is out of date. Oh well, don't do that. - db.query(Input).set_unchecked((), format!("Hello, world!")); + db.query_mut(Input) + .set_unchecked((), format!("Hello, world!")); assert_eq!(db.double_length(), 24); } #[test] fn using_set_unchecked() { - let db = DatabaseStruct::default(); + let mut db = DatabaseStruct::default(); // Use `set_unchecked` to intentionally set the wrong value, // demonstrating that the code never runs. - db.query(Length).set_unchecked((), 24); + db.query_mut(Length).set_unchecked((), 24); assert_eq!(db.double_length(), 48); } diff --git a/tests/variadic.rs b/tests/variadic.rs index 4d184e85..6b17697f 100644 --- a/tests/variadic.rs +++ b/tests/variadic.rs @@ -66,10 +66,10 @@ salsa::database_storage! { #[test] fn execute() { - let db = DatabaseStruct::default(); + let mut db = DatabaseStruct::default(); // test what happens with inputs: - db.query(Input).set((1, 2), 3); + db.query_mut(Input).set((1, 2), 3); assert_eq!(db.input(1, 2), 3); assert_eq!(db.none(), 22);