diff --git a/examples/hello_world/class_table.rs b/examples/hello_world/class_table.rs index 5fc439c3..bd722c25 100644 --- a/examples/hello_world/class_table.rs +++ b/examples/hello_world/class_table.rs @@ -34,11 +34,11 @@ query_definition! { pub AllFields(query: &impl ClassTableQueryContext, (): ()) -> Arc> { Arc::new( query.all_classes() - .of(()) + .get(()) .iter() .cloned() .flat_map(|def_id| { - let fields = query.fields().of(def_id); + let fields = query.fields().get(def_id); (0..fields.len()).map(move |i| fields[i]) }) .collect() diff --git a/examples/hello_world/main.rs b/examples/hello_world/main.rs index 5fbf33fc..5b631c8e 100644 --- a/examples/hello_world/main.rs +++ b/examples/hello_world/main.rs @@ -8,7 +8,7 @@ use self::implementation::QueryContextImpl; #[test] fn test() { let query = QueryContextImpl::default(); - let all_def_ids = query.all_fields().of(()); + let all_def_ids = query.all_fields().read(); assert_eq!( format!("{:?}", all_def_ids), "[DefId(1), DefId(2), DefId(11), DefId(12)]" @@ -17,7 +17,7 @@ fn test() { fn main() { let query = QueryContextImpl::default(); - for f in query.all_fields().of(()).iter() { + for f in query.all_fields().read().iter() { println!("{:?}", f); } } diff --git a/src/input.rs b/src/input.rs new file mode 100644 index 00000000..bdce2d19 --- /dev/null +++ b/src/input.rs @@ -0,0 +1,145 @@ +use crate::runtime::QueryDescriptorSet; +use crate::runtime::Revision; +use crate::CycleDetected; +use crate::MutQueryStorageOps; +use crate::Query; +use crate::QueryContext; +use crate::QueryDescriptor; +use crate::QueryStorageOps; +use crate::QueryTable; +use log::debug; +use parking_lot::{RwLock, RwLockUpgradableReadGuard}; +use rustc_hash::FxHashMap; +use std::any::Any; +use std::cell::RefCell; +use std::collections::hash_map::Entry; +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Write; +use std::hash::Hash; + +/// Input queries store the result plus a list of the other queries +/// that they invoked. This means we can avoid recomputing them when +/// none of those inputs have changed. +pub struct InputStorage +where + Q: Query, + QC: QueryContext, + Q::Value: Default, +{ + map: RwLock>>, +} + +impl Default for InputStorage +where + Q: Query, + QC: QueryContext, + Q::Value: Default, +{ + fn default() -> Self { + InputStorage { + map: RwLock::new(FxHashMap::default()), + } + } +} + +impl InputStorage +where + Q: Query, + QC: QueryContext, + Q::Value: Default, +{ + fn read<'q>( + &self, + _query: &'q QC, + key: &Q::Key, + _descriptor: &QC::QueryDescriptor, + ) -> Result, CycleDetected> { + { + let map_read = self.map.read(); + if let Some(value) = map_read.get(key) { + return Ok(value.clone()); + } + } + + Ok(StampedValue { + value: ::default(), + changed_at: Revision::ZERO, + }) + } +} + +impl QueryStorageOps for InputStorage +where + Q: Query, + QC: QueryContext, + Q::Value: Default, +{ + fn try_fetch<'q>( + &self, + query: &'q QC, + key: &Q::Key, + descriptor: &QC::QueryDescriptor, + ) -> Result { + let StampedValue { + value, + changed_at: _, + } = self.read(query, key, &descriptor)?; + + query.salsa_runtime().report_query_read(descriptor); + + Ok(value) + } + + fn maybe_changed_since( + &self, + _query: &'q QC, + revision: Revision, + key: &Q::Key, + _descriptor: &QC::QueryDescriptor, + ) -> bool { + debug!( + "{:?}({:?})::maybe_changed_since(revision={:?})", + Q::default(), + key, + revision, + ); + + let changed_at = { + let map_read = self.map.read(); + map_read + .get(key) + .map(|v| v.changed_at) + .unwrap_or(Revision::ZERO) + }; + + changed_at > revision + } +} + +impl MutQueryStorageOps for InputStorage +where + Q: Query, + QC: QueryContext, + Q::Value: Default, +{ + fn set(&self, query: &QC, key: &Q::Key, value: Q::Value) { + let key = key.clone(); + + let mut map_write = self.map.write(); + + // Do this *after* we acquire the lock, so that we are not + // racing with somebody else to modify this same cell. + // (Otherwise, someone else might write a *newer* revision + // into the same cell while we block on the lock.) + let changed_at = query.salsa_runtime().increment_revision(); + + map_write.insert(key, StampedValue { value, changed_at }); + } +} + +#[derive(Clone)] +struct StampedValue { + value: V, + changed_at: Revision, +} diff --git a/src/lib.rs b/src/lib.rs index e42e838a..911be173 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ use std::fmt::Display; use std::fmt::Write; use std::hash::Hash; +pub mod input; pub mod memoized; pub mod runtime; pub mod volatile; @@ -72,9 +73,9 @@ where /// Returns `Err` in the event of a cycle, meaning that computing /// the value for this `key` is recursively attempting to fetch /// itself. - fn try_fetch<'q>( + fn try_fetch( &self, - query: &'q QC, + query: &QC, key: &Q::Key, descriptor: &QC::QueryDescriptor, ) -> Result; @@ -99,22 +100,33 @@ where /// Other storage types will skip some or all of these steps. fn maybe_changed_since( &self, - query: &'q QC, + query: &QC, revision: runtime::Revision, key: &Q::Key, descriptor: &QC::QueryDescriptor, ) -> bool; } +/// An optional trait that is implemented for "user mutable" storage: +/// that is, storage whose value is not derived from other storage but +/// is set independently. +pub trait MutQueryStorageOps: Default +where + QC: QueryContext, + Q: Query, +{ + fn set(&self, query: &QC, key: &Q::Key, new_value: Q::Value); +} + #[derive(new)] pub struct QueryTable<'me, QC, Q> where QC: QueryContext, Q: Query, { - pub query: &'me QC, - pub storage: &'me Q::Storage, - pub descriptor_fn: fn(&QC, &Q::Key) -> QC::QueryDescriptor, + query: &'me QC, + storage: &'me Q::Storage, + descriptor_fn: fn(&QC, &Q::Key) -> QC::QueryDescriptor, } pub struct CycleDetected; @@ -124,7 +136,7 @@ where QC: QueryContext, Q: Query, { - pub fn of(&self, key: Q::Key) -> Q::Value { + pub fn get(&self, key: Q::Key) -> Q::Value { let descriptor = self.descriptor(&key); self.storage .try_fetch(self.query, &key, &descriptor) @@ -135,11 +147,42 @@ where }) } + /// Equivalent to `of(DefaultKey::default_key())` + pub fn read(&self) -> Q::Value + where + Q::Key: DefaultKey, + { + self.get(DefaultKey::default_key()) + } + + /// Assign a value to an "input queries". Must be used outside of + /// an active query computation. + pub fn set(&self, key: Q::Key, value: Q::Value) + where + Q::Storage: MutQueryStorageOps, + { + self.storage.set(self.query, &key, value); + } + fn descriptor(&self, key: &Q::Key) -> QC::QueryDescriptor { (self.descriptor_fn)(self.query, key) } } +/// A variant of the `Default` trait used for query keys that are +/// either singletons (e.g., `()`) or have some overwhelming default. +/// In this case, you can write `query.my_query().read()` as a +/// convenient shorthand. +pub trait DefaultKey { + fn default_key() -> Self; +} + +impl DefaultKey for () { + fn default_key() -> Self { + () + } +} + /// A macro that helps in defining the "context trait" of a given /// module. This is a trait that defines everything that a block of /// queries need to execute, as well as defining the queries @@ -277,6 +320,7 @@ macro_rules! query_definition { } }; + // Accept a "fn-like" query definition ( @filter_attrs { input { @@ -309,6 +353,12 @@ macro_rules! query_definition { } }; + ( + @storage_ty[$QC:ident, $Self:ident, default] + ) => { + $crate::query_definition! { @storage_ty[$QC, $Self, memoized] } + }; + ( @storage_ty[$QC:ident, $Self:ident, memoized] ) => { @@ -321,6 +371,34 @@ macro_rules! query_definition { $crate::volatile::VolatileStorage }; + // Accept a "field-like" query definition (input) + ( + @filter_attrs { + input { + $v:vis $name:ident: Map<$key_ty:ty, $value_ty:ty>; + }; + storage { default }; + other_attrs { $($attrs:tt)* }; + } + ) => { + #[derive(Default, Debug)] + $($attrs)* + $v struct $name; + + impl $crate::Query for $name + where + QC: $crate::QueryContext + { + type Key = $key_ty; + type Value = $value_ty; + type Storage = $crate::input::InputStorage; + + fn execute(_: &QC, _: $key_ty) -> $value_ty { + panic!("execute should never run for an input query") + } + } + }; + // Various legal start states: ( # $($tokens:tt)* @@ -328,7 +406,7 @@ macro_rules! query_definition { $crate::query_definition! { @filter_attrs { input { # $($tokens)* }; - storage { memoized }; + storage { default }; other_attrs { }; } } @@ -339,7 +417,7 @@ macro_rules! query_definition { $crate::query_definition! { @filter_attrs { input { $v $name $($tokens)* }; - storage { memoized }; + storage { default }; other_attrs { }; } } diff --git a/src/memoized.rs b/src/memoized.rs index 52b8db4b..158b8d86 100644 --- a/src/memoized.rs +++ b/src/memoized.rs @@ -97,9 +97,9 @@ where Q: Query, QC: QueryContext, { - fn read<'q>( + fn read( &self, - query: &'q QC, + query: &QC, key: &Q::Key, descriptor: &QC::QueryDescriptor, ) -> Result, CycleDetected> { diff --git a/src/runtime.rs b/src/runtime.rs index 4e6fbd4e..eee53ce7 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -195,9 +195,7 @@ pub struct Revision { } impl Revision { - crate fn zero() -> Self { - Revision { generation: 0 } - } + crate const ZERO: Self = Revision { generation: 0 }; } impl std::fmt::Debug for Revision { diff --git a/tests/incremental/implementation.rs b/tests/incremental/implementation.rs index c89d060b..10cbc0bd 100644 --- a/tests/incremental/implementation.rs +++ b/tests/incremental/implementation.rs @@ -1,25 +1,62 @@ use crate::counter::Counter; use crate::log::Log; -use crate::queries; +use crate::memoized_inputs; +use crate::memoized_volatile; + +crate trait TestContext: salsa::QueryContext { + fn clock(&self) -> &Counter; + fn log(&self) -> &Log; +} #[derive(Default)] -pub struct QueryContextImpl { - runtime: salsa::runtime::Runtime, +crate struct TestContextImpl { + runtime: salsa::runtime::Runtime, clock: Counter, log: Log, } +impl TestContextImpl { + crate fn assert_log(&self, expected_log: &[&str]) { + use difference::{Changeset, Difference}; + + let expected_text = &format!("{:#?}", expected_log); + let actual_text = &format!("{:#?}", self.log().take()); + + if expected_text == actual_text { + return; + } + + let Changeset { diffs, .. } = Changeset::new(expected_text, actual_text, "\n"); + + for i in 0..diffs.len() { + match &diffs[i] { + Difference::Same(x) => println!(" {}", x), + Difference::Add(x) => println!("+{}", x), + Difference::Rem(x) => println!("-{}", x), + } + } + + panic!("incorrect log results"); + } +} + salsa::query_context_storage! { - pub struct QueryContextImplStorage for QueryContextImpl { - impl queries::QueryContext { - fn memoized2() for queries::Memoized2; - fn memoized1() for queries::Memoized1; - fn volatile() for queries::Volatile; + crate struct TestContextImplStorage for TestContextImpl { + impl memoized_volatile::MemoizedVolatileContext { + fn memoized2() for memoized_volatile::Memoized2; + fn memoized1() for memoized_volatile::Memoized1; + fn volatile() for memoized_volatile::Volatile; + } + + impl memoized_inputs::MemoizedInputsContext { + fn max() for memoized_inputs::Max; + fn input1() for memoized_inputs::Input1; + fn input2() for memoized_inputs::Input2; } } } -impl queries::CounterContext for QueryContextImpl { +impl TestContext for TestContextImpl { fn clock(&self) -> &Counter { &self.clock } @@ -29,8 +66,8 @@ impl queries::CounterContext for QueryContextImpl { } } -impl salsa::QueryContext for QueryContextImpl { - fn salsa_runtime(&self) -> &salsa::runtime::Runtime { +impl salsa::QueryContext for TestContextImpl { + fn salsa_runtime(&self) -> &salsa::runtime::Runtime { &self.runtime } } diff --git a/tests/incremental/main.rs b/tests/incremental/main.rs index 38095b91..25b9a967 100644 --- a/tests/incremental/main.rs +++ b/tests/incremental/main.rs @@ -4,7 +4,7 @@ mod counter; mod implementation; mod log; -mod queries; -mod tests; +mod memoized_inputs; +mod memoized_volatile; fn main() {} diff --git a/tests/incremental/memoized_inputs.rs b/tests/incremental/memoized_inputs.rs new file mode 100644 index 00000000..83b7d370 --- /dev/null +++ b/tests/incremental/memoized_inputs.rs @@ -0,0 +1,66 @@ +use crate::implementation::{TestContext, TestContextImpl}; + +crate trait MemoizedInputsContext: TestContext { + salsa::query_prototype! { + fn max() for Max; + fn input1() for Input1; + fn input2() for Input2; + } +} + +salsa::query_definition! { + crate Max(query: &impl MemoizedInputsContext, (): ()) -> usize { + query.log().add("Max invoked"); + std::cmp::max( + query.input1().read(), + query.input2().read(), + ) + } +} + +salsa::query_definition! { + crate Input1: Map<(), usize>; +} + +salsa::query_definition! { + crate Input2: Map<(), usize>; +} + +#[test] +fn revalidate() { + let query = TestContextImpl::default(); + + let v = query.max().read(); + assert_eq!(v, 0); + query.assert_log(&["Max invoked"]); + + let v = query.max().read(); + assert_eq!(v, 0); + query.assert_log(&[]); + + query.input1().set((), 44); + query.assert_log(&[]); + + let v = query.max().read(); + assert_eq!(v, 44); + query.assert_log(&["Max invoked"]); + + let v = query.max().read(); + assert_eq!(v, 44); + query.assert_log(&[]); + + query.input1().set((), 44); + query.assert_log(&[]); + query.input2().set((), 66); + query.assert_log(&[]); + query.input1().set((), 64); + query.assert_log(&[]); + + let v = query.max().read(); + assert_eq!(v, 66); + query.assert_log(&["Max invoked"]); + + let v = query.max().read(); + assert_eq!(v, 66); + query.assert_log(&[]); +} diff --git a/tests/incremental/memoized_volatile.rs b/tests/incremental/memoized_volatile.rs new file mode 100644 index 00000000..b56e0c69 --- /dev/null +++ b/tests/incremental/memoized_volatile.rs @@ -0,0 +1,85 @@ +use crate::implementation::{TestContext, TestContextImpl}; +use salsa::QueryContext; + +crate trait MemoizedVolatileContext: TestContext { + salsa::query_prototype! { + // Queries for testing a "volatile" value wrapped by + // memoization. + fn memoized2() for Memoized2; + fn memoized1() for Memoized1; + fn volatile() for Volatile; + } +} + +salsa::query_definition! { + crate Memoized2(query: &impl MemoizedVolatileContext, (): ()) -> usize { + query.log().add("Memoized2 invoked"); + query.memoized1().read() + } +} + +salsa::query_definition! { + crate Memoized1(query: &impl MemoizedVolatileContext, (): ()) -> usize { + query.log().add("Memoized1 invoked"); + let v = query.volatile().read(); + v / 2 + } +} + +salsa::query_definition! { + #[storage(volatile)] + crate Volatile(query: &impl MemoizedVolatileContext, (): ()) -> usize { + query.log().add("Volatile invoked"); + query.clock().increment() + } +} + +#[test] +fn volatile_x2() { + let query = TestContextImpl::default(); + + // Invoking volatile twice will simply execute twice. + query.volatile().read(); + query.volatile().read(); + query.assert_log(&["Volatile invoked", "Volatile invoked"]); +} + +/// Test that: +/// +/// - On the first run of R0, we recompute everything. +/// - On the second run of R1, we recompute nothing. +/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result +/// did not change). +/// - On the second run of R1, we recompute nothing. +/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change). +#[test] +fn revalidate() { + let query = TestContextImpl::default(); + + query.memoized2().read(); + query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]); + + query.memoized2().read(); + query.assert_log(&[]); + + // Second generation: volatile will change (to 1) but memoized1 + // will not (still 0, as 1/2 = 0) + query.salsa_runtime().next_revision(); + + query.memoized2().read(); + query.assert_log(&["Memoized1 invoked", "Volatile invoked"]); + + query.memoized2().read(); + query.assert_log(&[]); + + // Third generation: volatile will change (to 2) and memoized1 + // will too (to 1). Therefore, after validating that Memoized1 + // changed, we now invoke Memoized2. + query.salsa_runtime().next_revision(); + + query.memoized2().read(); + query.assert_log(&["Memoized1 invoked", "Volatile invoked", "Memoized2 invoked"]); + + query.memoized2().read(); + query.assert_log(&[]); +} diff --git a/tests/incremental/queries.rs b/tests/incremental/queries.rs deleted file mode 100644 index 328adfbd..00000000 --- a/tests/incremental/queries.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::counter::Counter; -use crate::log::Log; - -crate trait CounterContext: salsa::QueryContext { - fn clock(&self) -> &Counter; - fn log(&self) -> &Log; -} - -crate trait QueryContext: CounterContext { - salsa::query_prototype! { - fn memoized2() for Memoized2; - fn memoized1() for Memoized1; - fn volatile() for Volatile; - } -} - -salsa::query_definition! { - crate Memoized2(query: &impl QueryContext, (): ()) -> usize { - query.log().add("Memoized2 invoked"); - query.memoized1().of(()) - } -} - -salsa::query_definition! { - crate Memoized1(query: &impl QueryContext, (): ()) -> usize { - query.log().add("Memoized1 invoked"); - let v = query.volatile().of(()); - v / 2 - } -} - -salsa::query_definition! { - #[storage(volatile)] - crate Volatile(query: &impl QueryContext, (): ()) -> usize { - query.log().add("Volatile invoked"); - query.clock().increment() - } -} diff --git a/tests/incremental/tests.rs b/tests/incremental/tests.rs deleted file mode 100644 index f4cbacb6..00000000 --- a/tests/incremental/tests.rs +++ /dev/null @@ -1,83 +0,0 @@ -#![cfg(test)] - -use crate::implementation::QueryContextImpl; -use crate::queries::CounterContext; -use crate::queries::QueryContext as _; -use salsa::QueryContext as _; - -impl QueryContextImpl { - fn assert_log(&self, expected_log: &[&str]) { - use difference::{Changeset, Difference}; - - let expected_text = &format!("{:#?}", expected_log); - let actual_text = &format!("{:#?}", self.log().take()); - - if expected_text == actual_text { - return; - } - - let Changeset { diffs, .. } = Changeset::new(expected_text, actual_text, "\n"); - - for i in 0..diffs.len() { - match &diffs[i] { - Difference::Same(x) => println!(" {}", x), - Difference::Add(x) => println!("+{}", x), - Difference::Rem(x) => println!("-{}", x), - } - } - - panic!("incorrect log results"); - } -} - -#[test] -fn volatile_x2() { - let query = QueryContextImpl::default(); - - // Invoking volatile twice will simply execute twice. - query.volatile().of(()); - query.volatile().of(()); - query.assert_log(&["Volatile invoked", "Volatile invoked"]); -} - -/// Test that: -/// -/// - On the first run of R0, we recompute everything. -/// - On the second run of R1, we recompute nothing. -/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result -/// did not change). -/// - On the second run of R1, we recompute nothing. -/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change). -#[test] -fn revalidate() { - env_logger::init(); - - let query = QueryContextImpl::default(); - - query.memoized2().of(()); - query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]); - - query.memoized2().of(()); - query.assert_log(&[]); - - // Second generation: volatile will change (to 1) but memoized1 - // will not (still 0, as 1/2 = 0) - query.salsa_runtime().next_revision(); - - query.memoized2().of(()); - query.assert_log(&["Memoized1 invoked", "Volatile invoked"]); - - query.memoized2().of(()); - query.assert_log(&[]); - - // Third generation: volatile will change (to 2) and memoized1 - // will too (to 1). Therefore, after validating that Memoized1 - // changed, we now invoke Memoized2. - query.salsa_runtime().next_revision(); - - query.memoized2().of(()); - query.assert_log(&["Memoized1 invoked", "Volatile invoked", "Memoized2 invoked"]); - - query.memoized2().of(()); - query.assert_log(&[]); -} diff --git a/tests/storage_varieties/tests.rs b/tests/storage_varieties/tests.rs index 87f11451..c9032e23 100644 --- a/tests/storage_varieties/tests.rs +++ b/tests/storage_varieties/tests.rs @@ -6,26 +6,26 @@ use crate::queries::QueryContext; #[test] fn memoized_twice() { let query = QueryContextImpl::default(); - let v1 = query.memoized().of(()); - let v2 = query.memoized().of(()); + let v1 = query.memoized().read(); + let v2 = query.memoized().read(); assert_eq!(v1, v2); } #[test] fn volatile_twice() { let query = QueryContextImpl::default(); - let v1 = query.volatile().of(()); - let v2 = query.volatile().of(()); + let v1 = query.volatile().read(); + let v2 = query.volatile().read(); assert_eq!(v1 + 1, v2); } #[test] fn intermingled() { let query = QueryContextImpl::default(); - let v1 = query.volatile().of(()); - let v2 = query.memoized().of(()); - let v3 = query.volatile().of(()); - let v4 = query.memoized().of(()); + let v1 = query.volatile().read(); + let v2 = query.memoized().read(); + let v3 = query.volatile().read(); + let v4 = query.memoized().read(); assert_eq!(v1 + 1, v2); assert_eq!(v2 + 1, v3);