use single lock for LRU

This commit is contained in:
Aleksey Kladov 2019-06-07 15:42:42 +03:00
parent 3d89c0d817
commit 0827c88259
2 changed files with 105 additions and 43 deletions

View file

@ -40,8 +40,6 @@ pub type DependencyStorage<DB, Q> = DerivedStorage<DB, Q, NeverMemoizeValue>;
/// storage requirements. /// storage requirements.
pub type VolatileStorage<DB, Q> = DerivedStorage<DB, Q, VolatileValue>; pub type VolatileStorage<DB, Q> = DerivedStorage<DB, Q, VolatileValue>;
type LinkedHashSet<T> = LinkedHashMap<T, (), BuildHasherDefault<FxHasher>>;
/// Handles storage where the value is 'derived' by executing a /// Handles storage where the value is 'derived' by executing a
/// function (in contrast to "inputs"). /// function (in contrast to "inputs").
pub struct DerivedStorage<DB, Q, MP> pub struct DerivedStorage<DB, Q, MP>
@ -50,11 +48,10 @@ where
DB: Database, DB: Database,
MP: MemoizationPolicy<DB, Q>, MP: MemoizationPolicy<DB, Q>,
{ {
// `lru_cap` logically belongs to `QueryMap`, but we store it outside, so
// that we can read it without aquiring the lock.
lru_cap: AtomicUsize, lru_cap: AtomicUsize,
// if `lru_keys` and `map` are locked togeter, map: RwLock<QueryMap<DB, Q>>,
// `lru_keys` is locked first, to prevent deadlocks.
lru_keys: Mutex<LinkedHashSet<Q::Key>>,
map: RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
policy: PhantomData<MP>, policy: PhantomData<MP>,
} }
@ -238,6 +235,63 @@ impl<DB: Database> std::fmt::Debug for MemoInputs<DB> {
} }
} }
type LinkedHashSet<T> = LinkedHashMap<T, (), BuildHasherDefault<FxHasher>>;
struct QueryMap<DB, Q>
where
Q: QueryFunction<DB>,
DB: Database,
{
lru_keys: LinkedHashSet<Q::Key>,
data: FxHashMap<Q::Key, QueryState<DB, Q>>,
}
impl<DB, Q> Default for QueryMap<DB, Q>
where
Q: QueryFunction<DB>,
DB: Database,
{
fn default() -> Self {
QueryMap {
lru_keys: Default::default(),
data: Default::default(),
}
}
}
impl<DB, Q> QueryMap<DB, Q>
where
Q: QueryFunction<DB>,
DB: Database,
{
fn set_lru_capacity(&mut self, new_capacity: usize) {
if new_capacity == 0 {
self.lru_keys.clear();
} else {
while self.lru_keys.len() > new_capacity {
self.remove_lru();
}
let additional_cap = new_capacity - self.lru_keys.len();
self.lru_keys.reserve(additional_cap);
}
}
fn record_use(&mut self, key: &Q::Key, lru_cap: usize) {
self.lru_keys.insert(key.clone(), ());
if self.lru_keys.len() > lru_cap {
self.remove_lru();
}
}
fn remove_lru(&mut self) {
if let Some((evicted, ())) = self.lru_keys.pop_front() {
if let Some(QueryState::Memoized(memo)) = self.data.get_mut(&evicted) {
memo.value = None;
}
}
}
}
impl<DB, Q, MP> Default for DerivedStorage<DB, Q, MP> impl<DB, Q, MP> Default for DerivedStorage<DB, Q, MP>
where where
Q: QueryFunction<DB>, Q: QueryFunction<DB>,
@ -246,9 +300,8 @@ where
{ {
fn default() -> Self { fn default() -> Self {
DerivedStorage { DerivedStorage {
map: RwLock::new(FxHashMap::default()),
lru_cap: AtomicUsize::new(0), lru_cap: AtomicUsize::new(0),
lru_keys: Mutex::new(LinkedHashSet::with_hasher(Default::default())), map: RwLock::new(QueryMap::default()),
policy: PhantomData, policy: PhantomData,
} }
} }
@ -341,7 +394,10 @@ where
) { ) {
ProbeState::UpToDate(v) => return v, ProbeState::UpToDate(v) => return v,
ProbeState::StaleOrAbsent(mut map) => { ProbeState::StaleOrAbsent(mut map) => {
match map.insert(key.clone(), QueryState::in_progress(runtime.id())) { match map
.data
.insert(key.clone(), QueryState::in_progress(runtime.id()))
{
Some(QueryState::Memoized(old_memo)) => Some(old_memo), Some(QueryState::Memoized(old_memo)) => Some(old_memo),
Some(QueryState::InProgress { .. }) => unreachable!(), Some(QueryState::InProgress { .. }) => unreachable!(),
None => None, None => None,
@ -499,9 +555,9 @@ where
key: &Q::Key, key: &Q::Key,
) -> ProbeState<StampedValue<Q::Value>, MapGuard> ) -> ProbeState<StampedValue<Q::Value>, MapGuard>
where where
MapGuard: Deref<Target = FxHashMap<Q::Key, QueryState<DB, Q>>>, MapGuard: Deref<Target = QueryMap<DB, Q>>,
{ {
match map.get(key) { match map.data.get(key) {
Some(QueryState::InProgress { id, waiting }) => { Some(QueryState::InProgress { id, waiting }) => {
let other_id = *id; let other_id = *id;
return match self.register_with_in_progress_thread( return match self.register_with_in_progress_thread(
@ -600,7 +656,7 @@ where
database_key: &'db DB::DatabaseKey, database_key: &'db DB::DatabaseKey,
key: &'db Q::Key, key: &'db Q::Key,
memo: Option<Memo<DB, Q>>, memo: Option<Memo<DB, Q>>,
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>, map: &'db RwLock<QueryMap<DB, Q>>,
runtime: &'db Runtime<DB>, runtime: &'db Runtime<DB>,
} }
@ -610,7 +666,7 @@ where
Q: QueryFunction<DB>, Q: QueryFunction<DB>,
{ {
fn new( fn new(
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>, map: &'db RwLock<QueryMap<DB, Q>>,
key: &'db Q::Key, key: &'db Q::Key,
memo: Option<Memo<DB, Q>>, memo: Option<Memo<DB, Q>>,
database_key: &'db DB::DatabaseKey, database_key: &'db DB::DatabaseKey,
@ -642,12 +698,14 @@ where
let old_value = match self.memo.take() { let old_value = match self.memo.take() {
// Replace the `InProgress` marker that we installed with the new // Replace the `InProgress` marker that we installed with the new
// memo, thus releasing our unique access to this key. // memo, thus releasing our unique access to this key.
Some(memo) => write.insert(self.key.clone(), QueryState::Memoized(memo)), Some(memo) => write
.data
.insert(self.key.clone(), QueryState::Memoized(memo)),
// We had installed an `InProgress` marker, but we panicked before // We had installed an `InProgress` marker, but we panicked before
// it could be removed. At this point, we therefore "own" unique // it could be removed. At this point, we therefore "own" unique
// access to our slot, so we can just remove the key. // access to our slot, so we can just remove the key.
None => write.remove(self.key), None => write.data.remove(self.key),
}; };
match old_value { match old_value {
@ -716,15 +774,7 @@ where
let lru_cap = self.lru_cap.load(Ordering::Relaxed); let lru_cap = self.lru_cap.load(Ordering::Relaxed);
if lru_cap > 0 { if lru_cap > 0 {
let mut lru_keys = self.lru_keys.lock(); self.map.write().record_use(key, lru_cap);
lru_keys.insert(key.clone(), ());
if lru_keys.len() > lru_cap {
if let Some((evicted, ())) = lru_keys.pop_front() {
if let Some(QueryState::Memoized(memo)) = self.map.write().get_mut(&evicted) {
memo.value = None;
}
}
}
} }
db.salsa_runtime() db.salsa_runtime()
@ -756,7 +806,7 @@ where
let map = self.map.read(); let map = self.map.read();
// Look for a memoized value. // Look for a memoized value.
let memo = match map.get(key) { let memo = match map.data.get(key) {
// If somebody depends on us, but we have no map // If somebody depends on us, but we have no map
// entry, that must mean that it was found to be out // entry, that must mean that it was found to be out
// of date and removed. // of date and removed.
@ -888,7 +938,7 @@ where
// ought to do nothing. // ought to do nothing.
{ {
let mut map = self.map.write(); let mut map = self.map.write();
match map.get_mut(key) { match map.data.get_mut(key) {
Some(QueryState::Memoized(memo)) => { Some(QueryState::Memoized(memo)) => {
if memo.verified_at == revision_now { if memo.verified_at == revision_now {
// Since we started verifying inputs, somebody // Since we started verifying inputs, somebody
@ -908,7 +958,7 @@ where
// We found this entry is out of date and // We found this entry is out of date and
// nobody touch it in the meantime. Just // nobody touch it in the meantime. Just
// remove it. // remove it.
map.remove(key); map.data.remove(key);
} else { } else {
// We found this entry is valid. Update the // We found this entry is valid. Update the
// `verified_at` to reflect the current // `verified_at` to reflect the current
@ -937,7 +987,7 @@ where
fn is_constant(&self, _db: &DB, key: &Q::Key) -> bool { fn is_constant(&self, _db: &DB, key: &Q::Key) -> bool {
let map_read = self.map.read(); let map_read = self.map.read();
match map_read.get(key) { match map_read.data.get(key) {
None => false, None => false,
Some(QueryState::InProgress { .. }) => panic!("query in progress"), Some(QueryState::InProgress { .. }) => panic!("query in progress"),
Some(QueryState::Memoized(memo)) => memo.inputs.is_constant(), Some(QueryState::Memoized(memo)) => memo.inputs.is_constant(),
@ -949,7 +999,8 @@ where
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>, C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
{ {
let map = self.map.read(); let map = self.map.read();
map.iter() map.data
.iter()
.map(|(key, query_state)| TableEntry::new(key.clone(), query_state.value())) .map(|(key, query_state)| TableEntry::new(key.clone(), query_state.value()))
.collect() .collect()
} }
@ -964,7 +1015,7 @@ where
fn sweep(&self, db: &DB, strategy: SweepStrategy) { fn sweep(&self, db: &DB, strategy: SweepStrategy) {
let mut map_write = self.map.write(); let mut map_write = self.map.write();
let revision_now = db.salsa_runtime().current_revision(); let revision_now = db.salsa_runtime().current_revision();
map_write.retain(|key, query_state| { map_write.data.retain(|key, query_state| {
match query_state { match query_state {
// Leave stuff that is currently being computed -- the // Leave stuff that is currently being computed -- the
// other thread doing that work has unique access to // other thread doing that work has unique access to
@ -1043,17 +1094,8 @@ where
MP: MemoizationPolicy<DB, Q>, MP: MemoizationPolicy<DB, Q>,
{ {
fn set_lru_capacity(&self, new_capacity: usize) { fn set_lru_capacity(&self, new_capacity: usize) {
let mut lru_keys = self.lru_keys.lock();
let mut map = self.map.write();
self.lru_cap.store(new_capacity, Ordering::SeqCst); self.lru_cap.store(new_capacity, Ordering::SeqCst);
while lru_keys.len() > new_capacity { self.map.write().set_lru_capacity(new_capacity);
let (evicted, ()) = lru_keys.pop_front().unwrap();
if let Some(QueryState::Memoized(memo)) = map.get_mut(&evicted) {
memo.value = None;
}
}
let additional_cap = new_capacity - lru_keys.len();
lru_keys.reserve(additional_cap);
} }
} }

View file

@ -48,7 +48,6 @@ impl salsa::Database for Database {
#[test] #[test]
fn lru_works() { fn lru_works() {
let mut db = Database::default(); let mut db = Database::default();
let cap = 32;
db.query_mut(GetQuery).set_lru_capacity(32); db.query_mut(GetQuery).set_lru_capacity(32);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0); assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
@ -56,13 +55,34 @@ fn lru_works() {
let p = db.get(i); let p = db.get(i);
assert_eq!(p.0, i) assert_eq!(p.0, i)
} }
assert_eq!(N_POTATOES.load(Ordering::SeqCst), cap); assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
for i in 0..128u32 { for i in 0..128u32 {
let p = db.get(i); let p = db.get(i);
assert_eq!(p.0, i) assert_eq!(p.0, i)
} }
assert_eq!(N_POTATOES.load(Ordering::SeqCst), cap); assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
db.query_mut(GetQuery).set_lru_capacity(32);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
db.query_mut(GetQuery).set_lru_capacity(64);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
// Special case: setting capacity to zero disables LRU
db.query_mut(GetQuery).set_lru_capacity(0);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 128);
drop(db); drop(db);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0); assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
} }