permit constants to be modified

We now track the last revision in which constants were modified. When
we see a constant query result, we record the current revision as
well. Then later we can check if the result is "still" constant. This
lets us cut out a lot of intermediate work.
This commit is contained in:
Niko Matsakis 2019-06-21 19:55:05 -07:00
parent 636e48d45d
commit 0a5b6b0451
4 changed files with 267 additions and 123 deletions

View file

@ -72,6 +72,11 @@ where
/// Last revision when the memoized value was observed to change.
changed_at: Revision,
/// If `Some`, then this value was considered constant at the
/// given revision. If no constants have changed since then, we
/// don't need to check the inputs to see if they've changed.
constant_in_revision: Option<Revision>,
/// The inputs that went into our query, if we are tracking them.
inputs: MemoInputs<DB>,
}
@ -79,15 +84,15 @@ where
/// An insertion-order-preserving set of queries. Used to track the
/// inputs accessed during query execution.
pub(super) enum MemoInputs<DB: Database> {
// No inputs
Constant,
// Non-empty set of inputs fully known
/// Non-empty set of inputs, fully known
Tracked {
inputs: Arc<FxIndexSet<Dependency<DB>>>,
},
// Unknown quantity of inputs
/// Empty set of inputs, fully known.
NoInputs,
/// Unknown quantity of inputs
Untracked,
}
@ -215,7 +220,12 @@ where
// old value.
if let Some(old_memo) = &panic_guard.memo {
if let Some(old_value) = &old_memo.value {
if MP::memoized_value_eq(&old_value, &result.value) {
// Careful: the "constant-ness" must also not have
// changed, see the test `constant_to_non_constant`.
let memo_was_constant = old_memo.constant_in_revision.is_some();
if memo_was_constant == result.changed_at.is_constant
&& MP::memoized_value_eq(&old_value, &result.value)
{
debug!(
"read_upgrade({:?}): value is equal, back-dating to {:?}",
self, old_memo.changed_at,
@ -243,20 +253,22 @@ where
self, result.changed_at, result.dependencies,
);
let constant_in_revision = if result.changed_at.is_constant {
Some(revision_now)
} else {
None
};
debug!(
"read_upgrade({:?}): constant_in_revision={:?}",
self, constant_in_revision
);
let inputs = match result.dependencies {
None => MemoInputs::Untracked,
Some(dependencies) => {
// If all things that we read were constants, then
// we don't need to track our inputs: our value
// can never be invalidated.
//
// If OTOH we read at least *some* non-constant
// inputs, then we do track our inputs (even the
// constants), so that if we run the GC, we know
// which constants we looked at.
if dependencies.is_empty() || result.changed_at.is_constant {
MemoInputs::Constant
if dependencies.is_empty() {
MemoInputs::NoInputs
} else {
MemoInputs::Tracked {
inputs: Arc::new(dependencies),
@ -264,11 +276,14 @@ where
}
}
};
debug!("read_upgrade({:?}): inputs={:?}", self, inputs);
panic_guard.memo = Some(Memo {
value,
changed_at: result.changed_at.revision,
verified_at: revision_now,
inputs,
constant_in_revision,
});
panic_guard.proceed(&new_value);
@ -350,11 +365,11 @@ where
ProbeState::StaleOrAbsent(state)
}
pub(super) fn is_constant(&self, _db: &DB) -> bool {
pub(super) fn is_constant(&self, db: &DB) -> bool {
match &*self.state.read() {
QueryState::NotComputed => false,
QueryState::InProgress { .. } => panic!("query in progress"),
QueryState::Memoized(memo) => memo.inputs.is_constant(),
QueryState::Memoized(memo) => memo.is_still_constant(db),
}
}
@ -605,28 +620,37 @@ where
}
}
impl<DB: Database> MemoInputs<DB> {
fn is_constant(&self) -> bool {
if let MemoInputs::Constant = self {
true
} else {
false
}
}
}
impl<DB, Q> Memo<DB, Q>
where
Q: QueryFunction<DB>,
DB: Database + HasQueryGroup<Q::Group>,
{
/// True if this memo should still be considered constant
/// (presuming it ever was).
fn is_still_constant(&self, db: &DB) -> bool {
if let Some(constant_at) = self.constant_in_revision {
let last_changed = db.salsa_runtime().revision_when_constant_last_changed();
debug!(
"is_still_constant(last_changed={:?} <= constant_at={:?}) = {:?}",
last_changed,
constant_at,
last_changed <= constant_at,
);
last_changed <= constant_at
} else {
false
}
}
fn validate_memoized_value(
&mut self,
db: &DB,
revision_now: Revision,
) -> Option<StampedValue<Q::Value>> {
// If we don't have a memoized value, nothing to validate.
let value = self.value.as_ref()?;
if self.value.is_none() {
return None;
}
assert!(self.verified_at != revision_now);
let verified_at = self.verified_at;
@ -637,15 +661,18 @@ where
self.inputs,
);
let is_constant = match &mut self.inputs {
if self.is_still_constant(db) {
return Some(self.verify_value(revision_now));
}
match &mut self.inputs {
// We can't validate values that had untracked inputs; just have to
// re-execute.
MemoInputs::Untracked { .. } => {
return None;
}
// Constant: no changed input
MemoInputs::Constant => true,
MemoInputs::NoInputs => {}
// Check whether any of our inputs changed since the
// **last point where we were verified** (not since we
@ -671,19 +698,31 @@ where
return None;
}
false
}
};
Some(self.verify_value(revision_now))
}
fn verify_value(&mut self, revision_now: Revision) -> StampedValue<Q::Value> {
let value = match &self.value {
Some(v) => v.clone(),
None => panic!("invoked `verify_value` without a value!"),
};
self.verified_at = revision_now;
Some(StampedValue {
let is_constant = self.constant_in_revision.is_some();
if is_constant {
self.constant_in_revision = Some(revision_now);
}
StampedValue {
changed_at: ChangedAt {
is_constant,
revision: self.changed_at,
},
value: value.clone(),
})
value,
}
}
/// Returns the memoized value *if* it is known to be update in the given revision.
@ -696,10 +735,7 @@ where
);
if self.verified_at == revision_now {
let is_constant = match self.inputs {
MemoInputs::Constant => true,
_ => false,
};
let is_constant = self.constant_in_revision.is_some();
return Some(StampedValue {
changed_at: ChangedAt {
@ -735,10 +771,10 @@ where
impl<DB: Database> std::fmt::Debug for MemoInputs<DB> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MemoInputs::Constant => fmt.debug_struct("Constant").finish(),
MemoInputs::Tracked { inputs } => {
fmt.debug_struct("Tracked").field("inputs", inputs).finish()
}
MemoInputs::NoInputs => fmt.debug_struct("NoInputs").finish(),
MemoInputs::Untracked => fmt.debug_struct("Untracked").finish(),
}
}
@ -825,62 +861,74 @@ where
return memo.changed_at > revision;
}
let inputs = match &memo.inputs {
MemoInputs::Untracked => {
// we don't know the full set of
// inputs, so if there is a new
// revision, we must assume it is
// dirty
debug!(
"maybe_changed_since({:?}: true since untracked inputs",
self,
);
return true;
}
let maybe_changed;
MemoInputs::Constant => None,
MemoInputs::Tracked { inputs } => {
// At this point, the value may be dirty (we have
// to check the database-keys). If we have a cached
// value, we'll just fall back to invoking `read`,
// which will do that checking (and a bit more) --
// note that we skip the "pure read" part as we
// already know the result.
assert!(inputs.len() > 0);
if memo.value.is_some() {
std::mem::drop(state);
return match self.read_upgrade(db, revision_now) {
Ok(v) => {
debug!(
"maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}",
self,
v.changed_at.changed_since(revision),
v.changed_at,
);
v.changed_at.changed_since(revision)
}
Err(CycleDetected) => true,
};
// If we only depended on constants, and no constant has been
// modified since then, we cannot have changed; no need to
// trace our inputs.
if memo.is_still_constant(db) {
std::mem::drop(state);
maybe_changed = false;
} else {
match &memo.inputs {
MemoInputs::Untracked => {
// we don't know the full set of
// inputs, so if there is a new
// revision, we must assume it is
// dirty
debug!(
"maybe_changed_since({:?}: true since untracked inputs",
self,
);
return true;
}
Some(inputs.clone())
MemoInputs::NoInputs => {
std::mem::drop(state);
maybe_changed = false;
}
MemoInputs::Tracked { inputs } => {
// At this point, the value may be dirty (we have
// to check the database-keys). If we have a cached
// value, we'll just fall back to invoking `read`,
// which will do that checking (and a bit more) --
// note that we skip the "pure read" part as we
// already know the result.
assert!(inputs.len() > 0);
if memo.value.is_some() {
std::mem::drop(state);
return match self.read_upgrade(db, revision_now) {
Ok(v) => {
debug!(
"maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}",
self,
v.changed_at.changed_since(revision),
v.changed_at,
);
v.changed_at.changed_since(revision)
}
Err(CycleDetected) => true,
};
}
let inputs = inputs.clone();
// We have a **tracked set of inputs**
// (found in `database_keys`) that need to
// be validated.
std::mem::drop(state);
// Iterate the inputs and see if any have maybe changed.
maybe_changed = inputs
.iter()
.filter(|input| input.maybe_changed_since(db, revision))
.inspect(|input| debug!("{:?}: input `{:?}` may have changed", self, input))
.next()
.is_some();
}
}
};
// We have a **tracked set of inputs**
// (found in `database_keys`) that need to
// be validated.
std::mem::drop(state);
// Iterate the inputs and see if any have maybe changed.
let maybe_changed = inputs
.iter()
.flat_map(|inputs| inputs.iter())
.filter(|input| input.maybe_changed_since(db, revision))
.inspect(|input| debug!("{:?}: input `{:?}` may have changed", self, input))
.next()
.is_some();
}
// Either way, we have to update our entry.
//

View file

@ -86,7 +86,7 @@ where
// need to read from this input. Therefore, we wait to acquire
// the lock on `map` until we also hold the global query write
// lock.
db.salsa_runtime().with_incremented_revision(|next_revision| {
db.salsa_runtime().with_incremented_revision(|guard| {
let mut slots = self.slots.write();
db.salsa_event(|| Event {
@ -102,7 +102,7 @@ where
// into the same cell while we block on the lock.)
let changed_at = ChangedAt {
is_constant: is_constant.0,
revision: next_revision,
revision: guard.new_revision(),
};
let stamped_value = StampedValue { value, changed_at };
@ -111,14 +111,9 @@ where
Entry::Occupied(entry) => {
let mut slot_stamped_value = entry.get().stamped_value.write();
assert!(
!slot_stamped_value.changed_at.is_constant,
"modifying `{:?}({:?})`, which was previously marked as constant (old value `{:?}`, new value `{:?}`)",
Q::default(),
entry.key(),
slot_stamped_value.value,
stamped_value.value,
);
if slot_stamped_value.changed_at.is_constant {
guard.mark_constants_as_changed();
}
*slot_stamped_value = stamped_value;
}

View file

@ -1,6 +1,7 @@
use crate::dependency::DatabaseSlot;
use crate::dependency::Dependency;
use crate::{Database, Event, EventKind, SweepStrategy};
use crossbeam::atomic::AtomicCell;
use lock_api::{RawRwLock, RawRwLockRecursive};
use log::debug;
use parking_lot::{Mutex, RwLock};
@ -158,6 +159,12 @@ where
Revision::from(self.shared_state.revision.load(Ordering::SeqCst))
}
/// The revision in which constants last changed.
#[inline]
pub(crate) fn revision_when_constant_last_changed(&self) -> Revision {
self.shared_state.constant_revision.load()
}
/// Read current value of the revision counter.
#[inline]
fn pending_revision(&self) -> Revision {
@ -262,7 +269,10 @@ where
/// Note that, given our writer model, we can assume that only one
/// thread is attempting to increment the global revision at a
/// time.
pub(crate) fn with_incremented_revision<R>(&self, op: impl FnOnce(Revision) -> R) -> R {
pub(crate) fn with_incremented_revision<R>(
&self,
op: impl FnOnce(&DatabaseWriteLockGuard<'_, DB>) -> R,
) -> R {
log::debug!("increment_revision()");
if !self.permits_increment() {
@ -287,7 +297,10 @@ where
debug!("increment_revision: incremented to {:?}", new_revision);
op(new_revision)
op(&DatabaseWriteLockGuard {
runtime: self,
new_revision,
})
}
pub(crate) fn permits_increment(&self) -> bool {
@ -402,6 +415,39 @@ where
}
}
/// Temporary guard that indicates that the database write-lock is
/// held. You can get one of these by invoking
/// `with_incremented_revision`. It gives access to the new revision
/// and a few other operations that only make sense to do while an
/// update is happening.
pub(crate) struct DatabaseWriteLockGuard<'db, DB>
where
DB: Database,
{
runtime: &'db Runtime<DB>,
new_revision: Revision,
}
impl<DB> DatabaseWriteLockGuard<'_, DB>
where
DB: Database,
{
pub(crate) fn new_revision(&self) -> Revision {
self.new_revision
}
/// Indicates that this update modified an input marked as
/// "constant". This will force re-evaluation of anything that was
/// dependent on constants (which otherwise might not get
/// re-evaluated).
pub(crate) fn mark_constants_as_changed(&self) {
self.runtime
.shared_state
.constant_revision
.store(self.new_revision);
}
}
/// State that will be common to all threads (when we support multiple threads)
struct SharedState<DB: Database> {
storage: DB::DatabaseStorage,
@ -430,6 +476,12 @@ struct SharedState<DB: Database> {
/// revision is canceled).
pending_revision: AtomicU64,
/// The last time that a value marked as "constant" changed. Like
/// `revision` and `pending_revision`, this is readable without
/// any lock but requires the query-lock to be write-locked for
/// updates
constant_revision: AtomicCell<Revision>,
/// The dependency graph tracks which runtimes are blocked on one
/// another, waiting for queries to terminate.
dependency_graph: Mutex<DependencyGraph<DB>>,
@ -448,8 +500,9 @@ impl<DB: Database> Default for SharedState<DB> {
next_id: AtomicUsize::new(1),
storage: Default::default(),
query_lock: Default::default(),
revision: AtomicU64::new(1),
pending_revision: AtomicU64::new(1),
revision: AtomicU64::new(Revision::START_U64),
pending_revision: AtomicU64::new(Revision::START_U64),
constant_revision: AtomicCell::new(Revision::start()),
dependency_graph: Default::default(),
}
}
@ -562,8 +615,12 @@ pub struct Revision {
}
impl Revision {
/// Value if the initial revision, as a u64. We don't use 0
/// because we want to use a `NonZeroU64`.
const START_U64: u64 = 1;
fn start() -> Self {
Self::from(1)
Self::from(Self::START_U64)
}
fn from(g: u64) -> Self {

View file

@ -7,44 +7,58 @@ pub(crate) trait ConstantsDatabase: TestContext {
#[salsa::input]
fn input(&self, key: char) -> usize;
fn add(&self, keys: (char, char)) -> usize;
fn add(&self, key1: char, key2: char) -> usize;
fn add3(&self, key1: char, key2: char, key3: char) -> usize;
}
fn add(db: &impl ConstantsDatabase, (key1, key2): (char, char)) -> usize {
fn add(db: &impl ConstantsDatabase, key1: char, key2: char) -> usize {
db.log().add(format!("add({}, {})", key1, key2));
db.input(key1) + db.input(key2)
}
fn add3(db: &impl ConstantsDatabase, key1: char, key2: char, key3: char) -> usize {
db.log().add(format!("add3({}, {}, {})", key1, key2, key3));
db.add(key1, key2) + db.input(key3)
}
// Test we can assign a constant and things will be correctly
// recomputed afterwards.
#[test]
#[should_panic]
fn invalidate_constant() {
let db = &mut TestContextImpl::default();
db.set_constant_input('a', 44);
db.set_constant_input('b', 22);
assert_eq!(db.add('a', 'b'), 66);
db.set_constant_input('a', 66);
assert_eq!(db.add('a', 'b'), 88);
}
#[test]
#[should_panic]
fn invalidate_constant_1() {
let db = &mut TestContextImpl::default();
// Not constant:
db.set_input('a', 44);
assert_eq!(db.add('a', 'a'), 88);
// Becomes constant:
db.set_constant_input('a', 44);
assert_eq!(db.add('a', 'a'), 88);
// Invalidates:
db.set_constant_input('a', 66);
db.set_constant_input('a', 33);
assert_eq!(db.add('a', 'a'), 66);
}
/// Test that invoking `set` on a constant is an error, even if you
/// don't change the value.
// Test cases where we assign same value to 'a' after declaring it a
// constant.
#[test]
#[should_panic]
fn set_after_constant_same_value() {
let db = &mut TestContextImpl::default();
db.set_constant_input('a', 44);
db.set_constant_input('a', 44);
db.set_input('a', 44);
}
@ -54,7 +68,7 @@ fn not_constant() {
db.set_input('a', 22);
db.set_input('b', 44);
assert_eq!(db.add(('a', 'b')), 66);
assert_eq!(db.add('a', 'b'), 66);
assert!(!db.query(AddQuery).is_constant(('a', 'b')));
}
@ -64,7 +78,7 @@ fn is_constant() {
db.set_constant_input('a', 22);
db.set_constant_input('b', 44);
assert_eq!(db.add(('a', 'b')), 66);
assert_eq!(db.add('a', 'b'), 66);
assert!(db.query(AddQuery).is_constant(('a', 'b')));
}
@ -74,7 +88,7 @@ fn mixed_constant() {
db.set_constant_input('a', 22);
db.set_input('b', 44);
assert_eq!(db.add(('a', 'b')), 66);
assert_eq!(db.add('a', 'b'), 66);
assert!(!db.query(AddQuery).is_constant(('a', 'b')));
}
@ -84,14 +98,44 @@ fn becomes_constant_with_change() {
db.set_input('a', 22);
db.set_input('b', 44);
assert_eq!(db.add(('a', 'b')), 66);
assert_eq!(db.add('a', 'b'), 66);
assert!(!db.query(AddQuery).is_constant(('a', 'b')));
db.set_constant_input('a', 23);
assert_eq!(db.add(('a', 'b')), 67);
assert_eq!(db.add('a', 'b'), 67);
assert!(!db.query(AddQuery).is_constant(('a', 'b')));
db.set_constant_input('b', 45);
assert_eq!(db.add(('a', 'b')), 68);
assert_eq!(db.add('a', 'b'), 68);
assert!(db.query(AddQuery).is_constant(('a', 'b')));
}
// Test a subtle case in which an input changes from constant to
// non-constant, but its value doesn't change. If we're not careful,
// this can cause us to incorrectly consider derived values as still
// being constant.
#[test]
fn constant_to_non_constant() {
let db = &mut TestContextImpl::default();
db.set_constant_input('a', 11);
db.set_constant_input('b', 22);
db.set_constant_input('c', 33);
// Here, `add3` invokes `add`, which yields 33. Both calls are
// constant.
assert_eq!(db.add3('a', 'b', 'c'), 66);
db.set_input('a', 11);
// Here, `add3` invokes `add`, which *still* yields 33, but which
// is no longer constant. Since value didn't change, we might
// preserve `add3` unchanged, not noticing that it is no longer
// constant.
assert_eq!(db.add3('a', 'b', 'c'), 66);
// In that case, we would not get the correct result here, when
// 'a' changes *again*.
db.set_input('a', 22);
assert_eq!(db.add3('a', 'b', 'c'), 77);
}