introduce query_mut which you must use to get set methods

This commit is contained in:
Niko Matsakis 2018-11-01 04:30:54 -04:00
parent 981de0cac1
commit 49cc8abe43
16 changed files with 162 additions and 107 deletions

View file

@ -93,11 +93,11 @@ salsa::database_storage! {
// This shows how to use a query.
fn main() {
let db = DatabaseStruct::default();
let mut db = DatabaseStruct::default();
println!("Initially, the length is {}.", db.length(()));
db.query(InputString)
db.query_mut(InputString)
.set((), Arc::new(format!("Hello, world")));
println!("Now, the length is {}.", db.length(()));

View file

@ -63,6 +63,17 @@ pub trait Database: plumbing::DatabaseStorageTypes + plumbing::DatabaseOps {
<Self as plumbing::GetQueryTable<Q>>::get_query_table(self)
}
/// Get access to extra methods pertaining to a given query,
/// notably `set` (for inputs).
#[allow(unused_variables)]
fn query_mut<Q>(&mut self, query: Q) -> QueryTableMut<'_, Self, Q>
where
Q: Query<Self>,
Self: plumbing::GetQueryTable<Q>,
{
<Self as plumbing::GetQueryTable<Q>>::get_query_table_mut(self)
}
/// This function is invoked at key points in the salsa
/// runtime. It permits the database to be customized and to
/// inject logging or other custom behavior.
@ -300,6 +311,31 @@ where
self.storage.sweep(self.db, strategy);
}
fn descriptor(&self, key: &Q::Key) -> DB::QueryDescriptor {
(self.descriptor_fn)(self.db, key)
}
}
#[derive(new)]
pub struct QueryTableMut<'me, DB, Q>
where
DB: Database + 'me,
Q: Query<DB> + 'me,
{
db: &'me DB,
storage: &'me Q::Storage,
descriptor_fn: fn(&DB, &Q::Key) -> DB::QueryDescriptor,
}
impl<DB, Q> QueryTableMut<'_, DB, Q>
where
DB: Database,
Q: Query<DB>,
{
fn descriptor(&self, key: &Q::Key) -> DB::QueryDescriptor {
(self.descriptor_fn)(self.db, key)
}
/// Assign a value to an "input query". Must be used outside of
/// an active query computation.
pub fn set(&self, key: Q::Key, value: Q::Value)
@ -342,10 +378,6 @@ where
{
self.storage.set_unchecked(self.db, &key, value);
}
fn descriptor(&self, key: &Q::Key) -> DB::QueryDescriptor {
(self.descriptor_fn)(self.db, key)
}
}
/// A macro that helps in defining the "context trait" of a given
@ -748,6 +780,24 @@ macro_rules! database_storage {
},
)
}
fn get_query_table_mut(
db: &mut Self,
) -> $crate::QueryTableMut<'_, Self, $QueryType> {
let db = &*db;
$crate::QueryTableMut::new(
db,
&$crate::Database::salsa_runtime(db)
.storage()
.$query_method,
|_, key| {
let key = std::clone::Clone::clone(key);
__SalsaQueryDescriptor {
kind: __SalsaQueryDescriptorKind::$query_method(key),
}
},
)
}
}
)*
)*

View file

@ -1,6 +1,7 @@
use crate::Database;
use crate::Query;
use crate::QueryTable;
use crate::QueryTableMut;
use crate::SweepStrategy;
use std::fmt::Debug;
use std::hash::Hash;
@ -57,6 +58,8 @@ pub trait QueryFunction<DB: Database>: Query<DB> {
pub trait GetQueryTable<Q: Query<Self>>: Database {
fn get_query_table(db: &Self) -> QueryTable<'_, Self, Q>;
fn get_query_table_mut(db: &mut Self) -> QueryTableMut<'_, Self, Q>;
}
pub trait QueryStorageOps<DB, Q>: Default

View file

@ -15,10 +15,10 @@ macro_rules! assert_keys {
#[test]
fn compute_one() {
let db = db::DatabaseImpl::default();
let mut db = db::DatabaseImpl::default();
// Will compute fibonacci(5)
db.query(UseTriangular).set(5, false);
db.query_mut(UseTriangular).set(5, false);
db.compute(5);
db.salsa_runtime().next_revision();
@ -50,14 +50,14 @@ fn compute_one() {
#[test]
fn compute_switch() {
let db = db::DatabaseImpl::default();
let mut db = db::DatabaseImpl::default();
// Will compute fibonacci(5)
db.query(UseTriangular).set(5, false);
db.query_mut(UseTriangular).set(5, false);
assert_eq!(db.compute(5), 5);
// Change to triangular mode
db.query(UseTriangular).set(5, true);
db.query_mut(UseTriangular).set(5, true);
// Now computes triangular(5)
assert_eq!(db.compute(5), 15);
@ -107,14 +107,14 @@ fn compute_switch() {
/// Test a query with multiple layers of keys.
#[test]
fn compute_all() {
let db = db::DatabaseImpl::default();
let mut db = db::DatabaseImpl::default();
for i in 0..6 {
db.query(UseTriangular).set(i, (i % 2) != 0);
db.query_mut(UseTriangular).set(i, (i % 2) != 0);
}
db.query(Min).set((), 0);
db.query(Max).set((), 6);
db.query_mut(Min).set((), 0);
db.query_mut(Max).set((), 6);
db.compute_all();
db.salsa_runtime().next_revision();
@ -133,7 +133,7 @@ fn compute_all() {
}
// Reduce the range to exclude index 5.
db.query(Max).set((), 5);
db.query_mut(Max).set((), 5);
db.compute_all();
assert_keys! {

View file

@ -23,24 +23,24 @@ fn constants_add(db: &impl ConstantsDatabase, (key1, key2): (char, char)) -> usi
#[test]
#[should_panic]
fn invalidate_constant() {
let db = &TestContextImpl::default();
db.query(ConstantsInput).set_constant('a', 44);
db.query(ConstantsInput).set_constant('a', 66);
let db = &mut TestContextImpl::default();
db.query_mut(ConstantsInput).set_constant('a', 44);
db.query_mut(ConstantsInput).set_constant('a', 66);
}
#[test]
#[should_panic]
fn invalidate_constant_1() {
let db = &TestContextImpl::default();
let db = &mut TestContextImpl::default();
// Not constant:
db.query(ConstantsInput).set('a', 44);
db.query_mut(ConstantsInput).set('a', 44);
// Becomes constant:
db.query(ConstantsInput).set_constant('a', 44);
db.query_mut(ConstantsInput).set_constant('a', 44);
// Invalidates:
db.query(ConstantsInput).set_constant('a', 66);
db.query_mut(ConstantsInput).set_constant('a', 66);
}
/// Test that invoking `set` on a constant is an error, even if you
@ -48,55 +48,55 @@ fn invalidate_constant_1() {
#[test]
#[should_panic]
fn set_after_constant_same_value() {
let db = &TestContextImpl::default();
db.query(ConstantsInput).set_constant('a', 44);
db.query(ConstantsInput).set('a', 44);
let db = &mut TestContextImpl::default();
db.query_mut(ConstantsInput).set_constant('a', 44);
db.query_mut(ConstantsInput).set('a', 44);
}
#[test]
fn not_constant() {
let db = &TestContextImpl::default();
let db = &mut TestContextImpl::default();
db.query(ConstantsInput).set('a', 22);
db.query(ConstantsInput).set('b', 44);
db.query_mut(ConstantsInput).set('a', 22);
db.query_mut(ConstantsInput).set('b', 44);
assert_eq!(db.constants_add(('a', 'b')), 66);
assert!(!db.query(ConstantsAdd).is_constant(('a', 'b')));
}
#[test]
fn is_constant() {
let db = &TestContextImpl::default();
let db = &mut TestContextImpl::default();
db.query(ConstantsInput).set_constant('a', 22);
db.query(ConstantsInput).set_constant('b', 44);
db.query_mut(ConstantsInput).set_constant('a', 22);
db.query_mut(ConstantsInput).set_constant('b', 44);
assert_eq!(db.constants_add(('a', 'b')), 66);
assert!(db.query(ConstantsAdd).is_constant(('a', 'b')));
}
#[test]
fn mixed_constant() {
let db = &TestContextImpl::default();
let db = &mut TestContextImpl::default();
db.query(ConstantsInput).set_constant('a', 22);
db.query(ConstantsInput).set('b', 44);
db.query_mut(ConstantsInput).set_constant('a', 22);
db.query_mut(ConstantsInput).set('b', 44);
assert_eq!(db.constants_add(('a', 'b')), 66);
assert!(!db.query(ConstantsAdd).is_constant(('a', 'b')));
}
#[test]
fn becomes_constant_with_change() {
let db = &TestContextImpl::default();
let db = &mut TestContextImpl::default();
db.query(ConstantsInput).set('a', 22);
db.query(ConstantsInput).set('b', 44);
db.query_mut(ConstantsInput).set('a', 22);
db.query_mut(ConstantsInput).set('b', 44);
assert_eq!(db.constants_add(('a', 'b')), 66);
assert!(!db.query(ConstantsAdd).is_constant(('a', 'b')));
db.query(ConstantsInput).set_constant('a', 23);
db.query_mut(ConstantsInput).set_constant('a', 23);
assert_eq!(db.constants_add(('a', 'b')), 67);
assert!(!db.query(ConstantsAdd).is_constant(('a', 'b')));
db.query(ConstantsInput).set_constant('b', 45);
db.query_mut(ConstantsInput).set_constant('b', 45);
assert_eq!(db.constants_add(('a', 'b')), 68);
assert!(db.query(ConstantsAdd).is_constant(('a', 'b')));
}

View file

@ -41,9 +41,9 @@ fn dep_derived1(db: &impl MemoizedDepInputsContext) -> usize {
#[test]
fn revalidate() {
let db = &TestContextImpl::default();
let db = &mut TestContextImpl::default();
db.query(Input1).set((), 0);
db.query_mut(Input1).set((), 0);
// Initial run starts from Memoized2:
let v = db.dep_memoized2();
@ -53,19 +53,19 @@ fn revalidate() {
// After that, we first try to validate Memoized1 but wind up
// running Memoized2. Note that we don't try to validate
// Derived1, so it is invoked by Memoized1.
db.query(Input1).set((), 44);
db.query_mut(Input1).set((), 44);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]);
// Here validation of Memoized1 succeeds so Memoized2 never runs.
db.query(Input1).set((), 45);
db.query_mut(Input1).set((), 45);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]);
// Here, a change to input2 doesn't affect us, so nothing runs.
db.query(Input2).set((), 45);
db.query_mut(Input2).set((), 45);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&[]);

View file

@ -24,10 +24,10 @@ fn max(db: &impl MemoizedInputsContext) -> usize {
#[test]
fn revalidate() {
let db = &TestContextImpl::default();
let db = &mut TestContextImpl::default();
db.query(Input1).set((), 0);
db.query(Input2).set((), 0);
db.query_mut(Input1).set((), 0);
db.query_mut(Input2).set((), 0);
let v = db.max();
assert_eq!(v, 0);
@ -37,7 +37,7 @@ fn revalidate() {
assert_eq!(v, 0);
db.assert_log(&[]);
db.query(Input1).set((), 44);
db.query_mut(Input1).set((), 44);
db.assert_log(&[]);
let v = db.max();
@ -48,11 +48,11 @@ fn revalidate() {
assert_eq!(v, 44);
db.assert_log(&[]);
db.query(Input1).set((), 44);
db.query_mut(Input1).set((), 44);
db.assert_log(&[]);
db.query(Input2).set((), 66);
db.query_mut(Input2).set((), 66);
db.assert_log(&[]);
db.query(Input1).set((), 64);
db.query_mut(Input1).set((), 64);
db.assert_log(&[]);
let v = db.max();
@ -68,16 +68,16 @@ fn revalidate() {
/// triggers a new revision.
#[test]
fn set_after_no_change() {
let db = &TestContextImpl::default();
let db = &mut TestContextImpl::default();
db.query(Input2).set((), 0);
db.query_mut(Input2).set((), 0);
db.query(Input1).set((), 44);
db.query_mut(Input1).set((), 44);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&["Max invoked"]);
db.query(Input1).set((), 44);
db.query_mut(Input1).set((), 44);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&["Max invoked"]);

View file

@ -48,7 +48,7 @@ salsa::database_storage! {
#[test]
fn should_panic_safely() {
let db = DatabaseStruct::default();
let mut db = DatabaseStruct::default();
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
// default to 0 and we should catch the panic.
@ -59,7 +59,7 @@ fn should_panic_safely() {
assert!(result.is_err());
// Set `db.one` to 1 and assert ok
db.query(One).set((), 1);
db.query_mut(One).set((), 1);
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_ok())
}

View file

@ -6,12 +6,12 @@ use salsa::{Database, ParallelDatabase};
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation_immediate() {
let db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.query(Input).set('a', 100);
db.query(Input).set('b', 010);
db.query(Input).set('c', 001);
db.query(Input).set('d', 0);
db.query_mut(Input).set('a', 100);
db.query_mut(Input).set('b', 010);
db.query_mut(Input).set('c', 001);
db.query_mut(Input).set('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
@ -30,7 +30,7 @@ fn in_par_get_set_cancellation_immediate() {
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.query(Input).set('d', 1000);
db.query_mut(Input).set('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
@ -47,12 +47,12 @@ fn in_par_get_set_cancellation_immediate() {
/// to `sum2` properly.
#[test]
fn in_par_get_set_cancellation_transitive() {
let db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.query(Input).set('a', 100);
db.query(Input).set('b', 010);
db.query(Input).set('c', 001);
db.query(Input).set('d', 0);
db.query_mut(Input).set('a', 100);
db.query_mut(Input).set('b', 010);
db.query_mut(Input).set('c', 001);
db.query_mut(Input).set('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
@ -71,7 +71,7 @@ fn in_par_get_set_cancellation_transitive() {
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.query(Input).set('d', 1000);
db.query_mut(Input).set('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({

View file

@ -8,9 +8,9 @@ use std::sync::Arc;
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation() {
let db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.query(Input).set('a', 1);
db.query_mut(Input).set('a', 1);
let signal = Arc::new(Signal::default());
@ -50,7 +50,7 @@ fn in_par_get_set_cancellation() {
signal.wait_for(1);
// This will block until thread1 drops the revision lock.
db.query(Input).set('a', 2);
db.query_mut(Input).set('a', 2);
db.input('a')
}

View file

@ -5,14 +5,14 @@ use salsa::{Database, ParallelDatabase};
/// threads. Really just a test that `snapshot` etc compiles.
#[test]
fn in_par_two_independent_queries() {
let db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.query(Input).set('a', 100);
db.query(Input).set('b', 010);
db.query(Input).set('c', 001);
db.query(Input).set('d', 200);
db.query(Input).set('e', 020);
db.query(Input).set('f', 002);
db.query_mut(Input).set('a', 100);
db.query_mut(Input).set('b', 010);
db.query_mut(Input).set('c', 001);
db.query_mut(Input).set('d', 200);
db.query_mut(Input).set('e', 020);
db.query_mut(Input).set('f', 002);
let thread1 = std::thread::spawn({
let db = db.snapshot();

View file

@ -5,11 +5,11 @@ use salsa::{Database, ParallelDatabase};
/// Should be atomic.
#[test]
fn in_par_get_set_race() {
let db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.query(Input).set('a', 100);
db.query(Input).set('b', 010);
db.query(Input).set('c', 001);
db.query_mut(Input).set('a', 100);
db.query_mut(Input).set('b', 010);
db.query_mut(Input).set('c', 001);
let thread1 = std::thread::spawn({
let db = db.snapshot();
@ -20,7 +20,7 @@ fn in_par_get_set_race() {
});
let thread2 = std::thread::spawn(move || {
db.query(Input).set('a', 1000);
db.query_mut(Input).set('a', 1000);
db.sum("a")
});

View file

@ -157,7 +157,7 @@ impl WriteOp {
fn execute(self, db: &mut StressDatabaseImpl) {
match self {
WriteOp::SetA(key, value) => {
db.query(A).set(key, value);
db.query_mut(A).set(key, value);
}
}
}
@ -199,7 +199,7 @@ impl ReadOp {
fn stress_test() {
let mut db = StressDatabaseImpl::default();
for i in 0..10 {
db.query(A).set(i, i);
db.query_mut(A).set(i, i);
}
let mut rng = rand::thread_rng();

View file

@ -8,11 +8,11 @@ use salsa::ParallelDatabase;
/// waits for thread1 to send a signal before it enters).
#[test]
fn true_parallel_different_keys() {
let db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.query(Input).set('a', 100);
db.query(Input).set('b', 010);
db.query(Input).set('c', 001);
db.query_mut(Input).set('a', 100);
db.query_mut(Input).set('b', 010);
db.query_mut(Input).set('c', 001);
// Thread 1 will signal stage 1 when it enters and wait for stage 2.
let thread1 = std::thread::spawn({
@ -48,11 +48,11 @@ fn true_parallel_different_keys() {
/// therefore has to block.
#[test]
fn true_parallel_same_keys() {
let db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.query(Input).set('a', 100);
db.query(Input).set('b', 010);
db.query(Input).set('c', 001);
db.query_mut(Input).set('a', 100);
db.query_mut(Input).set('b', 010);
db.query_mut(Input).set('c', 001);
// Thread 1 will wait_for a barrier in the start of `sum`
let thread1 = std::thread::spawn({

View file

@ -50,10 +50,10 @@ salsa::database_storage! {
#[test]
fn normal() {
let db = DatabaseStruct::default();
db.query(Input).set((), format!("Hello, world"));
let mut db = DatabaseStruct::default();
db.query_mut(Input).set((), format!("Hello, world"));
assert_eq!(db.double_length(), 24);
db.query(Input).set((), format!("Hello, world!"));
db.query_mut(Input).set((), format!("Hello, world!"));
assert_eq!(db.double_length(), 26);
}
@ -66,30 +66,32 @@ fn use_without_set() {
#[test]
fn using_set_unchecked_on_input() {
let db = DatabaseStruct::default();
db.query(Input).set_unchecked((), format!("Hello, world"));
let mut db = DatabaseStruct::default();
db.query_mut(Input)
.set_unchecked((), format!("Hello, world"));
assert_eq!(db.double_length(), 24);
}
#[test]
fn using_set_unchecked_on_input_after() {
let db = DatabaseStruct::default();
db.query(Input).set((), format!("Hello, world"));
let mut db = DatabaseStruct::default();
db.query_mut(Input).set((), format!("Hello, world"));
assert_eq!(db.double_length(), 24);
// If we use `set_unchecked`, we don't notice that `double_length`
// is out of date. Oh well, don't do that.
db.query(Input).set_unchecked((), format!("Hello, world!"));
db.query_mut(Input)
.set_unchecked((), format!("Hello, world!"));
assert_eq!(db.double_length(), 24);
}
#[test]
fn using_set_unchecked() {
let db = DatabaseStruct::default();
let mut db = DatabaseStruct::default();
// Use `set_unchecked` to intentionally set the wrong value,
// demonstrating that the code never runs.
db.query(Length).set_unchecked((), 24);
db.query_mut(Length).set_unchecked((), 24);
assert_eq!(db.double_length(), 48);
}

View file

@ -66,10 +66,10 @@ salsa::database_storage! {
#[test]
fn execute() {
let db = DatabaseStruct::default();
let mut db = DatabaseStruct::default();
// test what happens with inputs:
db.query(Input).set((1, 2), 3);
db.query_mut(Input).set((1, 2), 3);
assert_eq!(db.input(1, 2), 3);
assert_eq!(db.none(), 22);