Merge pull request #5 from nikomatsakis/input

support "inputs" for incremental computation
This commit is contained in:
Niko Matsakis 2018-09-30 11:05:06 -04:00 committed by GitHub
commit 803095f13d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 448 additions and 160 deletions

View file

@ -34,11 +34,11 @@ query_definition! {
pub AllFields(query: &impl ClassTableQueryContext, (): ()) -> Arc<Vec<DefId>> {
Arc::new(
query.all_classes()
.of(())
.get(())
.iter()
.cloned()
.flat_map(|def_id| {
let fields = query.fields().of(def_id);
let fields = query.fields().get(def_id);
(0..fields.len()).map(move |i| fields[i])
})
.collect()

View file

@ -8,7 +8,7 @@ use self::implementation::QueryContextImpl;
#[test]
fn test() {
let query = QueryContextImpl::default();
let all_def_ids = query.all_fields().of(());
let all_def_ids = query.all_fields().read();
assert_eq!(
format!("{:?}", all_def_ids),
"[DefId(1), DefId(2), DefId(11), DefId(12)]"
@ -17,7 +17,7 @@ fn test() {
fn main() {
let query = QueryContextImpl::default();
for f in query.all_fields().of(()).iter() {
for f in query.all_fields().read().iter() {
println!("{:?}", f);
}
}

145
src/input.rs Normal file
View file

@ -0,0 +1,145 @@
use crate::runtime::QueryDescriptorSet;
use crate::runtime::Revision;
use crate::CycleDetected;
use crate::MutQueryStorageOps;
use crate::Query;
use crate::QueryContext;
use crate::QueryDescriptor;
use crate::QueryStorageOps;
use crate::QueryTable;
use log::debug;
use parking_lot::{RwLock, RwLockUpgradableReadGuard};
use rustc_hash::FxHashMap;
use std::any::Any;
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Write;
use std::hash::Hash;
/// Input queries store the result plus a list of the other queries
/// that they invoked. This means we can avoid recomputing them when
/// none of those inputs have changed.
pub struct InputStorage<QC, Q>
where
Q: Query<QC>,
QC: QueryContext,
Q::Value: Default,
{
map: RwLock<FxHashMap<Q::Key, StampedValue<Q::Value>>>,
}
impl<QC, Q> Default for InputStorage<QC, Q>
where
Q: Query<QC>,
QC: QueryContext,
Q::Value: Default,
{
fn default() -> Self {
InputStorage {
map: RwLock::new(FxHashMap::default()),
}
}
}
impl<QC, Q> InputStorage<QC, Q>
where
Q: Query<QC>,
QC: QueryContext,
Q::Value: Default,
{
fn read<'q>(
&self,
_query: &'q QC,
key: &Q::Key,
_descriptor: &QC::QueryDescriptor,
) -> Result<StampedValue<Q::Value>, CycleDetected> {
{
let map_read = self.map.read();
if let Some(value) = map_read.get(key) {
return Ok(value.clone());
}
}
Ok(StampedValue {
value: <Q::Value>::default(),
changed_at: Revision::ZERO,
})
}
}
impl<QC, Q> QueryStorageOps<QC, Q> for InputStorage<QC, Q>
where
Q: Query<QC>,
QC: QueryContext,
Q::Value: Default,
{
fn try_fetch<'q>(
&self,
query: &'q QC,
key: &Q::Key,
descriptor: &QC::QueryDescriptor,
) -> Result<Q::Value, CycleDetected> {
let StampedValue {
value,
changed_at: _,
} = self.read(query, key, &descriptor)?;
query.salsa_runtime().report_query_read(descriptor);
Ok(value)
}
fn maybe_changed_since(
&self,
_query: &'q QC,
revision: Revision,
key: &Q::Key,
_descriptor: &QC::QueryDescriptor,
) -> bool {
debug!(
"{:?}({:?})::maybe_changed_since(revision={:?})",
Q::default(),
key,
revision,
);
let changed_at = {
let map_read = self.map.read();
map_read
.get(key)
.map(|v| v.changed_at)
.unwrap_or(Revision::ZERO)
};
changed_at > revision
}
}
impl<QC, Q> MutQueryStorageOps<QC, Q> for InputStorage<QC, Q>
where
Q: Query<QC>,
QC: QueryContext,
Q::Value: Default,
{
fn set(&self, query: &QC, key: &Q::Key, value: Q::Value) {
let key = key.clone();
let mut map_write = self.map.write();
// Do this *after* we acquire the lock, so that we are not
// racing with somebody else to modify this same cell.
// (Otherwise, someone else might write a *newer* revision
// into the same cell while we block on the lock.)
let changed_at = query.salsa_runtime().increment_revision();
map_write.insert(key, StampedValue { value, changed_at });
}
}
#[derive(Clone)]
struct StampedValue<V> {
value: V,
changed_at: Revision,
}

View file

@ -16,6 +16,7 @@ use std::fmt::Display;
use std::fmt::Write;
use std::hash::Hash;
pub mod input;
pub mod memoized;
pub mod runtime;
pub mod volatile;
@ -72,9 +73,9 @@ where
/// Returns `Err` in the event of a cycle, meaning that computing
/// the value for this `key` is recursively attempting to fetch
/// itself.
fn try_fetch<'q>(
fn try_fetch(
&self,
query: &'q QC,
query: &QC,
key: &Q::Key,
descriptor: &QC::QueryDescriptor,
) -> Result<Q::Value, CycleDetected>;
@ -99,22 +100,33 @@ where
/// Other storage types will skip some or all of these steps.
fn maybe_changed_since(
&self,
query: &'q QC,
query: &QC,
revision: runtime::Revision,
key: &Q::Key,
descriptor: &QC::QueryDescriptor,
) -> bool;
}
/// An optional trait that is implemented for "user mutable" storage:
/// that is, storage whose value is not derived from other storage but
/// is set independently.
pub trait MutQueryStorageOps<QC, Q>: Default
where
QC: QueryContext,
Q: Query<QC>,
{
fn set(&self, query: &QC, key: &Q::Key, new_value: Q::Value);
}
#[derive(new)]
pub struct QueryTable<'me, QC, Q>
where
QC: QueryContext,
Q: Query<QC>,
{
pub query: &'me QC,
pub storage: &'me Q::Storage,
pub descriptor_fn: fn(&QC, &Q::Key) -> QC::QueryDescriptor,
query: &'me QC,
storage: &'me Q::Storage,
descriptor_fn: fn(&QC, &Q::Key) -> QC::QueryDescriptor,
}
pub struct CycleDetected;
@ -124,7 +136,7 @@ where
QC: QueryContext,
Q: Query<QC>,
{
pub fn of(&self, key: Q::Key) -> Q::Value {
pub fn get(&self, key: Q::Key) -> Q::Value {
let descriptor = self.descriptor(&key);
self.storage
.try_fetch(self.query, &key, &descriptor)
@ -135,11 +147,42 @@ where
})
}
/// Equivalent to `of(DefaultKey::default_key())`
pub fn read(&self) -> Q::Value
where
Q::Key: DefaultKey,
{
self.get(DefaultKey::default_key())
}
/// Assign a value to an "input queries". Must be used outside of
/// an active query computation.
pub fn set(&self, key: Q::Key, value: Q::Value)
where
Q::Storage: MutQueryStorageOps<QC, Q>,
{
self.storage.set(self.query, &key, value);
}
fn descriptor(&self, key: &Q::Key) -> QC::QueryDescriptor {
(self.descriptor_fn)(self.query, key)
}
}
/// A variant of the `Default` trait used for query keys that are
/// either singletons (e.g., `()`) or have some overwhelming default.
/// In this case, you can write `query.my_query().read()` as a
/// convenient shorthand.
pub trait DefaultKey {
fn default_key() -> Self;
}
impl DefaultKey for () {
fn default_key() -> Self {
()
}
}
/// A macro that helps in defining the "context trait" of a given
/// module. This is a trait that defines everything that a block of
/// queries need to execute, as well as defining the queries
@ -277,6 +320,7 @@ macro_rules! query_definition {
}
};
// Accept a "fn-like" query definition
(
@filter_attrs {
input {
@ -309,6 +353,12 @@ macro_rules! query_definition {
}
};
(
@storage_ty[$QC:ident, $Self:ident, default]
) => {
$crate::query_definition! { @storage_ty[$QC, $Self, memoized] }
};
(
@storage_ty[$QC:ident, $Self:ident, memoized]
) => {
@ -321,6 +371,34 @@ macro_rules! query_definition {
$crate::volatile::VolatileStorage
};
// Accept a "field-like" query definition (input)
(
@filter_attrs {
input {
$v:vis $name:ident: Map<$key_ty:ty, $value_ty:ty>;
};
storage { default };
other_attrs { $($attrs:tt)* };
}
) => {
#[derive(Default, Debug)]
$($attrs)*
$v struct $name;
impl<QC> $crate::Query<QC> for $name
where
QC: $crate::QueryContext
{
type Key = $key_ty;
type Value = $value_ty;
type Storage = $crate::input::InputStorage<QC, Self>;
fn execute(_: &QC, _: $key_ty) -> $value_ty {
panic!("execute should never run for an input query")
}
}
};
// Various legal start states:
(
# $($tokens:tt)*
@ -328,7 +406,7 @@ macro_rules! query_definition {
$crate::query_definition! {
@filter_attrs {
input { # $($tokens)* };
storage { memoized };
storage { default };
other_attrs { };
}
}
@ -339,7 +417,7 @@ macro_rules! query_definition {
$crate::query_definition! {
@filter_attrs {
input { $v $name $($tokens)* };
storage { memoized };
storage { default };
other_attrs { };
}
}

View file

@ -97,9 +97,9 @@ where
Q: Query<QC>,
QC: QueryContext,
{
fn read<'q>(
fn read(
&self,
query: &'q QC,
query: &QC,
key: &Q::Key,
descriptor: &QC::QueryDescriptor,
) -> Result<StampedValue<Q::Value>, CycleDetected> {

View file

@ -195,9 +195,7 @@ pub struct Revision {
}
impl Revision {
crate fn zero() -> Self {
Revision { generation: 0 }
}
crate const ZERO: Self = Revision { generation: 0 };
}
impl std::fmt::Debug for Revision {

View file

@ -1,25 +1,62 @@
use crate::counter::Counter;
use crate::log::Log;
use crate::queries;
use crate::memoized_inputs;
use crate::memoized_volatile;
crate trait TestContext: salsa::QueryContext {
fn clock(&self) -> &Counter;
fn log(&self) -> &Log;
}
#[derive(Default)]
pub struct QueryContextImpl {
runtime: salsa::runtime::Runtime<QueryContextImpl>,
crate struct TestContextImpl {
runtime: salsa::runtime::Runtime<TestContextImpl>,
clock: Counter,
log: Log,
}
impl TestContextImpl {
crate fn assert_log(&self, expected_log: &[&str]) {
use difference::{Changeset, Difference};
let expected_text = &format!("{:#?}", expected_log);
let actual_text = &format!("{:#?}", self.log().take());
if expected_text == actual_text {
return;
}
let Changeset { diffs, .. } = Changeset::new(expected_text, actual_text, "\n");
for i in 0..diffs.len() {
match &diffs[i] {
Difference::Same(x) => println!(" {}", x),
Difference::Add(x) => println!("+{}", x),
Difference::Rem(x) => println!("-{}", x),
}
}
panic!("incorrect log results");
}
}
salsa::query_context_storage! {
pub struct QueryContextImplStorage for QueryContextImpl {
impl queries::QueryContext {
fn memoized2() for queries::Memoized2;
fn memoized1() for queries::Memoized1;
fn volatile() for queries::Volatile;
crate struct TestContextImplStorage for TestContextImpl {
impl memoized_volatile::MemoizedVolatileContext {
fn memoized2() for memoized_volatile::Memoized2;
fn memoized1() for memoized_volatile::Memoized1;
fn volatile() for memoized_volatile::Volatile;
}
impl memoized_inputs::MemoizedInputsContext {
fn max() for memoized_inputs::Max;
fn input1() for memoized_inputs::Input1;
fn input2() for memoized_inputs::Input2;
}
}
}
impl queries::CounterContext for QueryContextImpl {
impl TestContext for TestContextImpl {
fn clock(&self) -> &Counter {
&self.clock
}
@ -29,8 +66,8 @@ impl queries::CounterContext for QueryContextImpl {
}
}
impl salsa::QueryContext for QueryContextImpl {
fn salsa_runtime(&self) -> &salsa::runtime::Runtime<QueryContextImpl> {
impl salsa::QueryContext for TestContextImpl {
fn salsa_runtime(&self) -> &salsa::runtime::Runtime<TestContextImpl> {
&self.runtime
}
}

View file

@ -4,7 +4,7 @@
mod counter;
mod implementation;
mod log;
mod queries;
mod tests;
mod memoized_inputs;
mod memoized_volatile;
fn main() {}

View file

@ -0,0 +1,66 @@
use crate::implementation::{TestContext, TestContextImpl};
crate trait MemoizedInputsContext: TestContext {
salsa::query_prototype! {
fn max() for Max;
fn input1() for Input1;
fn input2() for Input2;
}
}
salsa::query_definition! {
crate Max(query: &impl MemoizedInputsContext, (): ()) -> usize {
query.log().add("Max invoked");
std::cmp::max(
query.input1().read(),
query.input2().read(),
)
}
}
salsa::query_definition! {
crate Input1: Map<(), usize>;
}
salsa::query_definition! {
crate Input2: Map<(), usize>;
}
#[test]
fn revalidate() {
let query = TestContextImpl::default();
let v = query.max().read();
assert_eq!(v, 0);
query.assert_log(&["Max invoked"]);
let v = query.max().read();
assert_eq!(v, 0);
query.assert_log(&[]);
query.input1().set((), 44);
query.assert_log(&[]);
let v = query.max().read();
assert_eq!(v, 44);
query.assert_log(&["Max invoked"]);
let v = query.max().read();
assert_eq!(v, 44);
query.assert_log(&[]);
query.input1().set((), 44);
query.assert_log(&[]);
query.input2().set((), 66);
query.assert_log(&[]);
query.input1().set((), 64);
query.assert_log(&[]);
let v = query.max().read();
assert_eq!(v, 66);
query.assert_log(&["Max invoked"]);
let v = query.max().read();
assert_eq!(v, 66);
query.assert_log(&[]);
}

View file

@ -0,0 +1,85 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::QueryContext;
crate trait MemoizedVolatileContext: TestContext {
salsa::query_prototype! {
// Queries for testing a "volatile" value wrapped by
// memoization.
fn memoized2() for Memoized2;
fn memoized1() for Memoized1;
fn volatile() for Volatile;
}
}
salsa::query_definition! {
crate Memoized2(query: &impl MemoizedVolatileContext, (): ()) -> usize {
query.log().add("Memoized2 invoked");
query.memoized1().read()
}
}
salsa::query_definition! {
crate Memoized1(query: &impl MemoizedVolatileContext, (): ()) -> usize {
query.log().add("Memoized1 invoked");
let v = query.volatile().read();
v / 2
}
}
salsa::query_definition! {
#[storage(volatile)]
crate Volatile(query: &impl MemoizedVolatileContext, (): ()) -> usize {
query.log().add("Volatile invoked");
query.clock().increment()
}
}
#[test]
fn volatile_x2() {
let query = TestContextImpl::default();
// Invoking volatile twice will simply execute twice.
query.volatile().read();
query.volatile().read();
query.assert_log(&["Volatile invoked", "Volatile invoked"]);
}
/// Test that:
///
/// - On the first run of R0, we recompute everything.
/// - On the second run of R1, we recompute nothing.
/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result
/// did not change).
/// - On the second run of R1, we recompute nothing.
/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change).
#[test]
fn revalidate() {
let query = TestContextImpl::default();
query.memoized2().read();
query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]);
query.memoized2().read();
query.assert_log(&[]);
// Second generation: volatile will change (to 1) but memoized1
// will not (still 0, as 1/2 = 0)
query.salsa_runtime().next_revision();
query.memoized2().read();
query.assert_log(&["Memoized1 invoked", "Volatile invoked"]);
query.memoized2().read();
query.assert_log(&[]);
// Third generation: volatile will change (to 2) and memoized1
// will too (to 1). Therefore, after validating that Memoized1
// changed, we now invoke Memoized2.
query.salsa_runtime().next_revision();
query.memoized2().read();
query.assert_log(&["Memoized1 invoked", "Volatile invoked", "Memoized2 invoked"]);
query.memoized2().read();
query.assert_log(&[]);
}

View file

@ -1,38 +0,0 @@
use crate::counter::Counter;
use crate::log::Log;
crate trait CounterContext: salsa::QueryContext {
fn clock(&self) -> &Counter;
fn log(&self) -> &Log;
}
crate trait QueryContext: CounterContext {
salsa::query_prototype! {
fn memoized2() for Memoized2;
fn memoized1() for Memoized1;
fn volatile() for Volatile;
}
}
salsa::query_definition! {
crate Memoized2(query: &impl QueryContext, (): ()) -> usize {
query.log().add("Memoized2 invoked");
query.memoized1().of(())
}
}
salsa::query_definition! {
crate Memoized1(query: &impl QueryContext, (): ()) -> usize {
query.log().add("Memoized1 invoked");
let v = query.volatile().of(());
v / 2
}
}
salsa::query_definition! {
#[storage(volatile)]
crate Volatile(query: &impl QueryContext, (): ()) -> usize {
query.log().add("Volatile invoked");
query.clock().increment()
}
}

View file

@ -1,83 +0,0 @@
#![cfg(test)]
use crate::implementation::QueryContextImpl;
use crate::queries::CounterContext;
use crate::queries::QueryContext as _;
use salsa::QueryContext as _;
impl QueryContextImpl {
fn assert_log(&self, expected_log: &[&str]) {
use difference::{Changeset, Difference};
let expected_text = &format!("{:#?}", expected_log);
let actual_text = &format!("{:#?}", self.log().take());
if expected_text == actual_text {
return;
}
let Changeset { diffs, .. } = Changeset::new(expected_text, actual_text, "\n");
for i in 0..diffs.len() {
match &diffs[i] {
Difference::Same(x) => println!(" {}", x),
Difference::Add(x) => println!("+{}", x),
Difference::Rem(x) => println!("-{}", x),
}
}
panic!("incorrect log results");
}
}
#[test]
fn volatile_x2() {
let query = QueryContextImpl::default();
// Invoking volatile twice will simply execute twice.
query.volatile().of(());
query.volatile().of(());
query.assert_log(&["Volatile invoked", "Volatile invoked"]);
}
/// Test that:
///
/// - On the first run of R0, we recompute everything.
/// - On the second run of R1, we recompute nothing.
/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result
/// did not change).
/// - On the second run of R1, we recompute nothing.
/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change).
#[test]
fn revalidate() {
env_logger::init();
let query = QueryContextImpl::default();
query.memoized2().of(());
query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]);
query.memoized2().of(());
query.assert_log(&[]);
// Second generation: volatile will change (to 1) but memoized1
// will not (still 0, as 1/2 = 0)
query.salsa_runtime().next_revision();
query.memoized2().of(());
query.assert_log(&["Memoized1 invoked", "Volatile invoked"]);
query.memoized2().of(());
query.assert_log(&[]);
// Third generation: volatile will change (to 2) and memoized1
// will too (to 1). Therefore, after validating that Memoized1
// changed, we now invoke Memoized2.
query.salsa_runtime().next_revision();
query.memoized2().of(());
query.assert_log(&["Memoized1 invoked", "Volatile invoked", "Memoized2 invoked"]);
query.memoized2().of(());
query.assert_log(&[]);
}

View file

@ -6,26 +6,26 @@ use crate::queries::QueryContext;
#[test]
fn memoized_twice() {
let query = QueryContextImpl::default();
let v1 = query.memoized().of(());
let v2 = query.memoized().of(());
let v1 = query.memoized().read();
let v2 = query.memoized().read();
assert_eq!(v1, v2);
}
#[test]
fn volatile_twice() {
let query = QueryContextImpl::default();
let v1 = query.volatile().of(());
let v2 = query.volatile().of(());
let v1 = query.volatile().read();
let v2 = query.volatile().read();
assert_eq!(v1 + 1, v2);
}
#[test]
fn intermingled() {
let query = QueryContextImpl::default();
let v1 = query.volatile().of(());
let v2 = query.memoized().of(());
let v3 = query.volatile().of(());
let v4 = query.memoized().of(());
let v1 = query.volatile().read();
let v2 = query.memoized().read();
let v3 = query.volatile().read();
let v4 = query.memoized().read();
assert_eq!(v1 + 1, v2);
assert_eq!(v2 + 1, v3);