From 0827c88259a3b9522d0d26d62465d9d26c088d6b Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Fri, 7 Jun 2019 15:42:42 +0300 Subject: [PATCH] use single lock for LRU --- src/derived.rs | 122 +++++++++++++++++++++++++++++++++---------------- tests/lru.rs | 26 +++++++++-- 2 files changed, 105 insertions(+), 43 deletions(-) diff --git a/src/derived.rs b/src/derived.rs index 9452b949..e7de5efa 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -40,8 +40,6 @@ pub type DependencyStorage = DerivedStorage; /// storage requirements. pub type VolatileStorage = DerivedStorage; -type LinkedHashSet = LinkedHashMap>; - /// Handles storage where the value is 'derived' by executing a /// function (in contrast to "inputs"). pub struct DerivedStorage @@ -50,11 +48,10 @@ where DB: Database, MP: MemoizationPolicy, { + // `lru_cap` logically belongs to `QueryMap`, but we store it outside, so + // that we can read it without aquiring the lock. lru_cap: AtomicUsize, - // if `lru_keys` and `map` are locked togeter, - // `lru_keys` is locked first, to prevent deadlocks. - lru_keys: Mutex>, - map: RwLock>>, + map: RwLock>, policy: PhantomData, } @@ -238,6 +235,63 @@ impl std::fmt::Debug for MemoInputs { } } +type LinkedHashSet = LinkedHashMap>; + +struct QueryMap +where + Q: QueryFunction, + DB: Database, +{ + lru_keys: LinkedHashSet, + data: FxHashMap>, +} + +impl Default for QueryMap +where + Q: QueryFunction, + DB: Database, +{ + fn default() -> Self { + QueryMap { + lru_keys: Default::default(), + data: Default::default(), + } + } +} + +impl QueryMap +where + Q: QueryFunction, + DB: Database, +{ + fn set_lru_capacity(&mut self, new_capacity: usize) { + if new_capacity == 0 { + self.lru_keys.clear(); + } else { + while self.lru_keys.len() > new_capacity { + self.remove_lru(); + } + let additional_cap = new_capacity - self.lru_keys.len(); + self.lru_keys.reserve(additional_cap); + } + } + + fn record_use(&mut self, key: &Q::Key, lru_cap: usize) { + self.lru_keys.insert(key.clone(), ()); + if self.lru_keys.len() > lru_cap { + self.remove_lru(); + } + } + + fn remove_lru(&mut self) { + if let Some((evicted, ())) = self.lru_keys.pop_front() { + if let Some(QueryState::Memoized(memo)) = self.data.get_mut(&evicted) { + memo.value = None; + } + } + } +} + impl Default for DerivedStorage where Q: QueryFunction, @@ -246,9 +300,8 @@ where { fn default() -> Self { DerivedStorage { - map: RwLock::new(FxHashMap::default()), lru_cap: AtomicUsize::new(0), - lru_keys: Mutex::new(LinkedHashSet::with_hasher(Default::default())), + map: RwLock::new(QueryMap::default()), policy: PhantomData, } } @@ -341,7 +394,10 @@ where ) { ProbeState::UpToDate(v) => return v, ProbeState::StaleOrAbsent(mut map) => { - match map.insert(key.clone(), QueryState::in_progress(runtime.id())) { + match map + .data + .insert(key.clone(), QueryState::in_progress(runtime.id())) + { Some(QueryState::Memoized(old_memo)) => Some(old_memo), Some(QueryState::InProgress { .. }) => unreachable!(), None => None, @@ -499,9 +555,9 @@ where key: &Q::Key, ) -> ProbeState, MapGuard> where - MapGuard: Deref>>, + MapGuard: Deref>, { - match map.get(key) { + match map.data.get(key) { Some(QueryState::InProgress { id, waiting }) => { let other_id = *id; return match self.register_with_in_progress_thread( @@ -600,7 +656,7 @@ where database_key: &'db DB::DatabaseKey, key: &'db Q::Key, memo: Option>, - map: &'db RwLock>>, + map: &'db RwLock>, runtime: &'db Runtime, } @@ -610,7 +666,7 @@ where Q: QueryFunction, { fn new( - map: &'db RwLock>>, + map: &'db RwLock>, key: &'db Q::Key, memo: Option>, database_key: &'db DB::DatabaseKey, @@ -642,12 +698,14 @@ where let old_value = match self.memo.take() { // Replace the `InProgress` marker that we installed with the new // memo, thus releasing our unique access to this key. - Some(memo) => write.insert(self.key.clone(), QueryState::Memoized(memo)), + Some(memo) => write + .data + .insert(self.key.clone(), QueryState::Memoized(memo)), // We had installed an `InProgress` marker, but we panicked before // it could be removed. At this point, we therefore "own" unique // access to our slot, so we can just remove the key. - None => write.remove(self.key), + None => write.data.remove(self.key), }; match old_value { @@ -716,15 +774,7 @@ where let lru_cap = self.lru_cap.load(Ordering::Relaxed); if lru_cap > 0 { - let mut lru_keys = self.lru_keys.lock(); - lru_keys.insert(key.clone(), ()); - if lru_keys.len() > lru_cap { - if let Some((evicted, ())) = lru_keys.pop_front() { - if let Some(QueryState::Memoized(memo)) = self.map.write().get_mut(&evicted) { - memo.value = None; - } - } - } + self.map.write().record_use(key, lru_cap); } db.salsa_runtime() @@ -756,7 +806,7 @@ where let map = self.map.read(); // Look for a memoized value. - let memo = match map.get(key) { + let memo = match map.data.get(key) { // If somebody depends on us, but we have no map // entry, that must mean that it was found to be out // of date and removed. @@ -888,7 +938,7 @@ where // ought to do nothing. { let mut map = self.map.write(); - match map.get_mut(key) { + match map.data.get_mut(key) { Some(QueryState::Memoized(memo)) => { if memo.verified_at == revision_now { // Since we started verifying inputs, somebody @@ -908,7 +958,7 @@ where // We found this entry is out of date and // nobody touch it in the meantime. Just // remove it. - map.remove(key); + map.data.remove(key); } else { // We found this entry is valid. Update the // `verified_at` to reflect the current @@ -937,7 +987,7 @@ where fn is_constant(&self, _db: &DB, key: &Q::Key) -> bool { let map_read = self.map.read(); - match map_read.get(key) { + match map_read.data.get(key) { None => false, Some(QueryState::InProgress { .. }) => panic!("query in progress"), Some(QueryState::Memoized(memo)) => memo.inputs.is_constant(), @@ -949,7 +999,8 @@ where C: std::iter::FromIterator>, { let map = self.map.read(); - map.iter() + map.data + .iter() .map(|(key, query_state)| TableEntry::new(key.clone(), query_state.value())) .collect() } @@ -964,7 +1015,7 @@ where fn sweep(&self, db: &DB, strategy: SweepStrategy) { let mut map_write = self.map.write(); let revision_now = db.salsa_runtime().current_revision(); - map_write.retain(|key, query_state| { + map_write.data.retain(|key, query_state| { match query_state { // Leave stuff that is currently being computed -- the // other thread doing that work has unique access to @@ -1043,17 +1094,8 @@ where MP: MemoizationPolicy, { fn set_lru_capacity(&self, new_capacity: usize) { - let mut lru_keys = self.lru_keys.lock(); - let mut map = self.map.write(); self.lru_cap.store(new_capacity, Ordering::SeqCst); - while lru_keys.len() > new_capacity { - let (evicted, ()) = lru_keys.pop_front().unwrap(); - if let Some(QueryState::Memoized(memo)) = map.get_mut(&evicted) { - memo.value = None; - } - } - let additional_cap = new_capacity - lru_keys.len(); - lru_keys.reserve(additional_cap); + self.map.write().set_lru_capacity(new_capacity); } } diff --git a/tests/lru.rs b/tests/lru.rs index 87ed76d6..40bba2ef 100644 --- a/tests/lru.rs +++ b/tests/lru.rs @@ -48,7 +48,6 @@ impl salsa::Database for Database { #[test] fn lru_works() { let mut db = Database::default(); - let cap = 32; db.query_mut(GetQuery).set_lru_capacity(32); assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0); @@ -56,13 +55,34 @@ fn lru_works() { let p = db.get(i); assert_eq!(p.0, i) } - assert_eq!(N_POTATOES.load(Ordering::SeqCst), cap); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32); for i in 0..128u32 { let p = db.get(i); assert_eq!(p.0, i) } - assert_eq!(N_POTATOES.load(Ordering::SeqCst), cap); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32); + + db.query_mut(GetQuery).set_lru_capacity(32); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32); + + db.query_mut(GetQuery).set_lru_capacity(64); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32); + for i in 0..128u32 { + let p = db.get(i); + assert_eq!(p.0, i) + } + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64); + + // Special case: setting capacity to zero disables LRU + db.query_mut(GetQuery).set_lru_capacity(0); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64); + for i in 0..128u32 { + let p = db.get(i); + assert_eq!(p.0, i) + } + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 128); + drop(db); assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0); }