From 3d89c0d8178c37045dee79ffffcd7f1a4c543419 Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Thu, 6 Jun 2019 17:59:54 +0300 Subject: [PATCH] Add LRU to derived storage LRU allows to bound the maximum number of *values* that are present in the table. --- Cargo.toml | 1 + src/derived.rs | 48 +++++++++++++++++++++++++++++++++- src/lib.rs | 16 ++++++++++++ src/plumbing.rs | 11 ++++++++ tests/lru.rs | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 tests/lru.rs diff --git a/Cargo.toml b/Cargo.toml index ce04ecb..f6bc937 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ indexmap = "1.0.1" log = "0.4.5" smallvec = "0.6.5" salsa-macros = { version = "0.12.1", path = "components/salsa-macros" } +linked-hash-map = "0.5.2" [dev-dependencies] diff = "0.1.0" diff --git a/src/derived.rs b/src/derived.rs index 0bc2ded..9452b94 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -1,6 +1,7 @@ use crate::debug::TableEntry; use crate::plumbing::CycleDetected; use crate::plumbing::DatabaseKey; +use crate::plumbing::LruQueryStorageOps; use crate::plumbing::QueryFunction; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; @@ -11,13 +12,16 @@ use crate::runtime::Runtime; use crate::runtime::RuntimeId; use crate::runtime::StampedValue; use crate::{Database, DiscardIf, DiscardWhat, Event, EventKind, SweepStrategy}; +use linked_hash_map::LinkedHashMap; use log::{debug, info}; use parking_lot::Mutex; use parking_lot::RwLock; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHasher}; use smallvec::SmallVec; +use std::hash::BuildHasherDefault; use std::marker::PhantomData; use std::ops::Deref; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::Arc; @@ -36,6 +40,8 @@ pub type DependencyStorage = DerivedStorage; /// storage requirements. 
pub type VolatileStorage = DerivedStorage; +type LinkedHashSet = LinkedHashMap>; + /// Handles storage where the value is 'derived' by executing a /// function (in contrast to "inputs"). pub struct DerivedStorage @@ -44,6 +50,10 @@ where DB: Database, MP: MemoizationPolicy, { + lru_cap: AtomicUsize, + // if `lru_keys` and `map` are locked together, + // `lru_keys` is locked first, to prevent deadlocks. + lru_keys: Mutex>, map: RwLock>>, policy: PhantomData, } @@ -237,6 +247,8 @@ where fn default() -> Self { DerivedStorage { map: RwLock::new(FxHashMap::default()), + lru_cap: AtomicUsize::new(0), + lru_keys: Mutex::new(LinkedHashSet::with_hasher(Default::default())), policy: PhantomData, } } @@ -702,6 +714,19 @@ where ) -> Result { let StampedValue { value, changed_at } = self.read(db, key, &database_key)?; + let lru_cap = self.lru_cap.load(Ordering::Relaxed); + if lru_cap > 0 { + let mut lru_keys = self.lru_keys.lock(); + lru_keys.insert(key.clone(), ()); + if lru_keys.len() > lru_cap { + if let Some((evicted, ())) = lru_keys.pop_front() { + if let Some(QueryState::Memoized(memo)) = self.map.write().get_mut(&evicted) { + memo.value = None; + } + } + } + } + db.salsa_runtime() .report_query_read(database_key, changed_at); @@ -1011,6 +1036,27 @@ where } } +impl LruQueryStorageOps for DerivedStorage +where + Q: QueryFunction, + DB: Database, + MP: MemoizationPolicy, +{ + fn set_lru_capacity(&self, new_capacity: usize) { + let mut lru_keys = self.lru_keys.lock(); + let mut map = self.map.write(); + self.lru_cap.store(new_capacity, Ordering::SeqCst); + while lru_keys.len() > new_capacity { + let (evicted, ()) = lru_keys.pop_front().unwrap(); + if let Some(QueryState::Memoized(memo)) = map.get_mut(&evicted) { + memo.value = None; + } + } + let additional_cap = new_capacity - lru_keys.len(); + lru_keys.reserve(additional_cap); + } +} + impl Memo where Q: QueryFunction, diff --git a/src/lib.rs b/src/lib.rs index acf201a..2872c3e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ 
-24,6 +24,7 @@ use crate::plumbing::CycleDetected; use crate::plumbing::InputQueryStorageOps; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; +use crate::plumbing::LruQueryStorageOps; use derive_new::new; use std::fmt::{self, Debug}; use std::hash::Hash; @@ -534,6 +535,21 @@ where self.storage .set_constant(self.db, &key, &self.database_key(&key), value); } + + /// Sets the size of the LRU cache of values for this query table. + /// + /// That is, at most `cap` values will be present in the table at the same + /// time. This helps with keeping maximum memory usage under control, at the + /// cost of potential extra recalculations of evicted values. + /// + /// If `cap` is zero, all values are preserved; this is the default. + pub fn set_lru_capacity(&self, cap: usize) + where + Q::Storage: plumbing::LruQueryStorageOps, + { + self.storage + .set_lru_capacity(cap); + } } // Re-export the procedural macros. diff --git a/src/plumbing.rs b/src/plumbing.rs index 636a47d..a0c11d5 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -203,3 +203,14 @@ where new_value: Q::Value, ); } + +/// An optional trait that is implemented for storage that can bound the +/// number of memoized values it retains: values beyond the LRU capacity +/// are evicted and recomputed on demand. +pub trait LruQueryStorageOps: Default +{ + fn set_lru_capacity( + &self, + new_capacity: usize, + ); +} diff --git a/tests/lru.rs b/tests/lru.rs new file mode 100644 index 0000000..87ed76d --- /dev/null +++ b/tests/lru.rs @@ -0,0 +1,68 @@ +//! 
Test that setting the LRU capacity actually limits the number of values in the database. +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use salsa::Database as _; + +#[derive(Debug, PartialEq, Eq)] +struct HotPotato(u32); + +static N_POTATOES: AtomicUsize = AtomicUsize::new(0); + +impl HotPotato { + fn new(id: u32) -> HotPotato { + N_POTATOES.fetch_add(1, Ordering::SeqCst); + HotPotato(id) + } +} + +impl Drop for HotPotato { + fn drop(&mut self) { + N_POTATOES.fetch_sub(1, Ordering::SeqCst); + } +} + +#[salsa::query_group(QueryGroupStorage)] +trait QueryGroup { + fn get(&self, x: u32) -> Arc; +} + +fn get(_db: &impl QueryGroup, x: u32) -> Arc { + Arc::new(HotPotato::new(x)) +} + +#[salsa::database(QueryGroupStorage)] +#[derive(Default)] +struct Database { + runtime: salsa::Runtime, +} + +impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + &self.runtime + } +} + +#[test] +fn lru_works() { + let mut db = Database::default(); + let cap = 32; + db.query_mut(GetQuery).set_lru_capacity(32); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0); + + for i in 0..128u32 { + let p = db.get(i); + assert_eq!(p.0, i) + } + assert_eq!(N_POTATOES.load(Ordering::SeqCst), cap); + + for i in 0..128u32 { + let p = db.get(i); + assert_eq!(p.0, i) + } + assert_eq!(N_POTATOES.load(Ordering::SeqCst), cap); + drop(db); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0); +}