mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-12 08:30:51 +00:00
use single lock for LRU
This commit is contained in:
parent
3d89c0d817
commit
0827c88259
2 changed files with 105 additions and 43 deletions
122
src/derived.rs
122
src/derived.rs
|
@ -40,8 +40,6 @@ pub type DependencyStorage<DB, Q> = DerivedStorage<DB, Q, NeverMemoizeValue>;
|
||||||
/// storage requirements.
|
/// storage requirements.
|
||||||
pub type VolatileStorage<DB, Q> = DerivedStorage<DB, Q, VolatileValue>;
|
pub type VolatileStorage<DB, Q> = DerivedStorage<DB, Q, VolatileValue>;
|
||||||
|
|
||||||
type LinkedHashSet<T> = LinkedHashMap<T, (), BuildHasherDefault<FxHasher>>;
|
|
||||||
|
|
||||||
/// Handles storage where the value is 'derived' by executing a
|
/// Handles storage where the value is 'derived' by executing a
|
||||||
/// function (in contrast to "inputs").
|
/// function (in contrast to "inputs").
|
||||||
pub struct DerivedStorage<DB, Q, MP>
|
pub struct DerivedStorage<DB, Q, MP>
|
||||||
|
@ -50,11 +48,10 @@ where
|
||||||
DB: Database,
|
DB: Database,
|
||||||
MP: MemoizationPolicy<DB, Q>,
|
MP: MemoizationPolicy<DB, Q>,
|
||||||
{
|
{
|
||||||
|
// `lru_cap` logically belongs to `QueryMap`, but we store it outside, so
|
||||||
|
// that we can read it without aquiring the lock.
|
||||||
lru_cap: AtomicUsize,
|
lru_cap: AtomicUsize,
|
||||||
// if `lru_keys` and `map` are locked togeter,
|
map: RwLock<QueryMap<DB, Q>>,
|
||||||
// `lru_keys` is locked first, to prevent deadlocks.
|
|
||||||
lru_keys: Mutex<LinkedHashSet<Q::Key>>,
|
|
||||||
map: RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
|
||||||
policy: PhantomData<MP>,
|
policy: PhantomData<MP>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -238,6 +235,63 @@ impl<DB: Database> std::fmt::Debug for MemoInputs<DB> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type LinkedHashSet<T> = LinkedHashMap<T, (), BuildHasherDefault<FxHasher>>;
|
||||||
|
|
||||||
|
struct QueryMap<DB, Q>
|
||||||
|
where
|
||||||
|
Q: QueryFunction<DB>,
|
||||||
|
DB: Database,
|
||||||
|
{
|
||||||
|
lru_keys: LinkedHashSet<Q::Key>,
|
||||||
|
data: FxHashMap<Q::Key, QueryState<DB, Q>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<DB, Q> Default for QueryMap<DB, Q>
|
||||||
|
where
|
||||||
|
Q: QueryFunction<DB>,
|
||||||
|
DB: Database,
|
||||||
|
{
|
||||||
|
fn default() -> Self {
|
||||||
|
QueryMap {
|
||||||
|
lru_keys: Default::default(),
|
||||||
|
data: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<DB, Q> QueryMap<DB, Q>
|
||||||
|
where
|
||||||
|
Q: QueryFunction<DB>,
|
||||||
|
DB: Database,
|
||||||
|
{
|
||||||
|
fn set_lru_capacity(&mut self, new_capacity: usize) {
|
||||||
|
if new_capacity == 0 {
|
||||||
|
self.lru_keys.clear();
|
||||||
|
} else {
|
||||||
|
while self.lru_keys.len() > new_capacity {
|
||||||
|
self.remove_lru();
|
||||||
|
}
|
||||||
|
let additional_cap = new_capacity - self.lru_keys.len();
|
||||||
|
self.lru_keys.reserve(additional_cap);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_use(&mut self, key: &Q::Key, lru_cap: usize) {
|
||||||
|
self.lru_keys.insert(key.clone(), ());
|
||||||
|
if self.lru_keys.len() > lru_cap {
|
||||||
|
self.remove_lru();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_lru(&mut self) {
|
||||||
|
if let Some((evicted, ())) = self.lru_keys.pop_front() {
|
||||||
|
if let Some(QueryState::Memoized(memo)) = self.data.get_mut(&evicted) {
|
||||||
|
memo.value = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<DB, Q, MP> Default for DerivedStorage<DB, Q, MP>
|
impl<DB, Q, MP> Default for DerivedStorage<DB, Q, MP>
|
||||||
where
|
where
|
||||||
Q: QueryFunction<DB>,
|
Q: QueryFunction<DB>,
|
||||||
|
@ -246,9 +300,8 @@ where
|
||||||
{
|
{
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
DerivedStorage {
|
DerivedStorage {
|
||||||
map: RwLock::new(FxHashMap::default()),
|
|
||||||
lru_cap: AtomicUsize::new(0),
|
lru_cap: AtomicUsize::new(0),
|
||||||
lru_keys: Mutex::new(LinkedHashSet::with_hasher(Default::default())),
|
map: RwLock::new(QueryMap::default()),
|
||||||
policy: PhantomData,
|
policy: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -341,7 +394,10 @@ where
|
||||||
) {
|
) {
|
||||||
ProbeState::UpToDate(v) => return v,
|
ProbeState::UpToDate(v) => return v,
|
||||||
ProbeState::StaleOrAbsent(mut map) => {
|
ProbeState::StaleOrAbsent(mut map) => {
|
||||||
match map.insert(key.clone(), QueryState::in_progress(runtime.id())) {
|
match map
|
||||||
|
.data
|
||||||
|
.insert(key.clone(), QueryState::in_progress(runtime.id()))
|
||||||
|
{
|
||||||
Some(QueryState::Memoized(old_memo)) => Some(old_memo),
|
Some(QueryState::Memoized(old_memo)) => Some(old_memo),
|
||||||
Some(QueryState::InProgress { .. }) => unreachable!(),
|
Some(QueryState::InProgress { .. }) => unreachable!(),
|
||||||
None => None,
|
None => None,
|
||||||
|
@ -499,9 +555,9 @@ where
|
||||||
key: &Q::Key,
|
key: &Q::Key,
|
||||||
) -> ProbeState<StampedValue<Q::Value>, MapGuard>
|
) -> ProbeState<StampedValue<Q::Value>, MapGuard>
|
||||||
where
|
where
|
||||||
MapGuard: Deref<Target = FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
MapGuard: Deref<Target = QueryMap<DB, Q>>,
|
||||||
{
|
{
|
||||||
match map.get(key) {
|
match map.data.get(key) {
|
||||||
Some(QueryState::InProgress { id, waiting }) => {
|
Some(QueryState::InProgress { id, waiting }) => {
|
||||||
let other_id = *id;
|
let other_id = *id;
|
||||||
return match self.register_with_in_progress_thread(
|
return match self.register_with_in_progress_thread(
|
||||||
|
@ -600,7 +656,7 @@ where
|
||||||
database_key: &'db DB::DatabaseKey,
|
database_key: &'db DB::DatabaseKey,
|
||||||
key: &'db Q::Key,
|
key: &'db Q::Key,
|
||||||
memo: Option<Memo<DB, Q>>,
|
memo: Option<Memo<DB, Q>>,
|
||||||
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
map: &'db RwLock<QueryMap<DB, Q>>,
|
||||||
runtime: &'db Runtime<DB>,
|
runtime: &'db Runtime<DB>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -610,7 +666,7 @@ where
|
||||||
Q: QueryFunction<DB>,
|
Q: QueryFunction<DB>,
|
||||||
{
|
{
|
||||||
fn new(
|
fn new(
|
||||||
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
map: &'db RwLock<QueryMap<DB, Q>>,
|
||||||
key: &'db Q::Key,
|
key: &'db Q::Key,
|
||||||
memo: Option<Memo<DB, Q>>,
|
memo: Option<Memo<DB, Q>>,
|
||||||
database_key: &'db DB::DatabaseKey,
|
database_key: &'db DB::DatabaseKey,
|
||||||
|
@ -642,12 +698,14 @@ where
|
||||||
let old_value = match self.memo.take() {
|
let old_value = match self.memo.take() {
|
||||||
// Replace the `InProgress` marker that we installed with the new
|
// Replace the `InProgress` marker that we installed with the new
|
||||||
// memo, thus releasing our unique access to this key.
|
// memo, thus releasing our unique access to this key.
|
||||||
Some(memo) => write.insert(self.key.clone(), QueryState::Memoized(memo)),
|
Some(memo) => write
|
||||||
|
.data
|
||||||
|
.insert(self.key.clone(), QueryState::Memoized(memo)),
|
||||||
|
|
||||||
// We had installed an `InProgress` marker, but we panicked before
|
// We had installed an `InProgress` marker, but we panicked before
|
||||||
// it could be removed. At this point, we therefore "own" unique
|
// it could be removed. At this point, we therefore "own" unique
|
||||||
// access to our slot, so we can just remove the key.
|
// access to our slot, so we can just remove the key.
|
||||||
None => write.remove(self.key),
|
None => write.data.remove(self.key),
|
||||||
};
|
};
|
||||||
|
|
||||||
match old_value {
|
match old_value {
|
||||||
|
@ -716,15 +774,7 @@ where
|
||||||
|
|
||||||
let lru_cap = self.lru_cap.load(Ordering::Relaxed);
|
let lru_cap = self.lru_cap.load(Ordering::Relaxed);
|
||||||
if lru_cap > 0 {
|
if lru_cap > 0 {
|
||||||
let mut lru_keys = self.lru_keys.lock();
|
self.map.write().record_use(key, lru_cap);
|
||||||
lru_keys.insert(key.clone(), ());
|
|
||||||
if lru_keys.len() > lru_cap {
|
|
||||||
if let Some((evicted, ())) = lru_keys.pop_front() {
|
|
||||||
if let Some(QueryState::Memoized(memo)) = self.map.write().get_mut(&evicted) {
|
|
||||||
memo.value = None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
db.salsa_runtime()
|
db.salsa_runtime()
|
||||||
|
@ -756,7 +806,7 @@ where
|
||||||
let map = self.map.read();
|
let map = self.map.read();
|
||||||
|
|
||||||
// Look for a memoized value.
|
// Look for a memoized value.
|
||||||
let memo = match map.get(key) {
|
let memo = match map.data.get(key) {
|
||||||
// If somebody depends on us, but we have no map
|
// If somebody depends on us, but we have no map
|
||||||
// entry, that must mean that it was found to be out
|
// entry, that must mean that it was found to be out
|
||||||
// of date and removed.
|
// of date and removed.
|
||||||
|
@ -888,7 +938,7 @@ where
|
||||||
// ought to do nothing.
|
// ought to do nothing.
|
||||||
{
|
{
|
||||||
let mut map = self.map.write();
|
let mut map = self.map.write();
|
||||||
match map.get_mut(key) {
|
match map.data.get_mut(key) {
|
||||||
Some(QueryState::Memoized(memo)) => {
|
Some(QueryState::Memoized(memo)) => {
|
||||||
if memo.verified_at == revision_now {
|
if memo.verified_at == revision_now {
|
||||||
// Since we started verifying inputs, somebody
|
// Since we started verifying inputs, somebody
|
||||||
|
@ -908,7 +958,7 @@ where
|
||||||
// We found this entry is out of date and
|
// We found this entry is out of date and
|
||||||
// nobody touch it in the meantime. Just
|
// nobody touch it in the meantime. Just
|
||||||
// remove it.
|
// remove it.
|
||||||
map.remove(key);
|
map.data.remove(key);
|
||||||
} else {
|
} else {
|
||||||
// We found this entry is valid. Update the
|
// We found this entry is valid. Update the
|
||||||
// `verified_at` to reflect the current
|
// `verified_at` to reflect the current
|
||||||
|
@ -937,7 +987,7 @@ where
|
||||||
|
|
||||||
fn is_constant(&self, _db: &DB, key: &Q::Key) -> bool {
|
fn is_constant(&self, _db: &DB, key: &Q::Key) -> bool {
|
||||||
let map_read = self.map.read();
|
let map_read = self.map.read();
|
||||||
match map_read.get(key) {
|
match map_read.data.get(key) {
|
||||||
None => false,
|
None => false,
|
||||||
Some(QueryState::InProgress { .. }) => panic!("query in progress"),
|
Some(QueryState::InProgress { .. }) => panic!("query in progress"),
|
||||||
Some(QueryState::Memoized(memo)) => memo.inputs.is_constant(),
|
Some(QueryState::Memoized(memo)) => memo.inputs.is_constant(),
|
||||||
|
@ -949,7 +999,8 @@ where
|
||||||
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
|
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
|
||||||
{
|
{
|
||||||
let map = self.map.read();
|
let map = self.map.read();
|
||||||
map.iter()
|
map.data
|
||||||
|
.iter()
|
||||||
.map(|(key, query_state)| TableEntry::new(key.clone(), query_state.value()))
|
.map(|(key, query_state)| TableEntry::new(key.clone(), query_state.value()))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
@ -964,7 +1015,7 @@ where
|
||||||
fn sweep(&self, db: &DB, strategy: SweepStrategy) {
|
fn sweep(&self, db: &DB, strategy: SweepStrategy) {
|
||||||
let mut map_write = self.map.write();
|
let mut map_write = self.map.write();
|
||||||
let revision_now = db.salsa_runtime().current_revision();
|
let revision_now = db.salsa_runtime().current_revision();
|
||||||
map_write.retain(|key, query_state| {
|
map_write.data.retain(|key, query_state| {
|
||||||
match query_state {
|
match query_state {
|
||||||
// Leave stuff that is currently being computed -- the
|
// Leave stuff that is currently being computed -- the
|
||||||
// other thread doing that work has unique access to
|
// other thread doing that work has unique access to
|
||||||
|
@ -1043,17 +1094,8 @@ where
|
||||||
MP: MemoizationPolicy<DB, Q>,
|
MP: MemoizationPolicy<DB, Q>,
|
||||||
{
|
{
|
||||||
fn set_lru_capacity(&self, new_capacity: usize) {
|
fn set_lru_capacity(&self, new_capacity: usize) {
|
||||||
let mut lru_keys = self.lru_keys.lock();
|
|
||||||
let mut map = self.map.write();
|
|
||||||
self.lru_cap.store(new_capacity, Ordering::SeqCst);
|
self.lru_cap.store(new_capacity, Ordering::SeqCst);
|
||||||
while lru_keys.len() > new_capacity {
|
self.map.write().set_lru_capacity(new_capacity);
|
||||||
let (evicted, ()) = lru_keys.pop_front().unwrap();
|
|
||||||
if let Some(QueryState::Memoized(memo)) = map.get_mut(&evicted) {
|
|
||||||
memo.value = None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let additional_cap = new_capacity - lru_keys.len();
|
|
||||||
lru_keys.reserve(additional_cap);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
26
tests/lru.rs
26
tests/lru.rs
|
@ -48,7 +48,6 @@ impl salsa::Database for Database {
|
||||||
#[test]
|
#[test]
|
||||||
fn lru_works() {
|
fn lru_works() {
|
||||||
let mut db = Database::default();
|
let mut db = Database::default();
|
||||||
let cap = 32;
|
|
||||||
db.query_mut(GetQuery).set_lru_capacity(32);
|
db.query_mut(GetQuery).set_lru_capacity(32);
|
||||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
|
||||||
|
|
||||||
|
@ -56,13 +55,34 @@ fn lru_works() {
|
||||||
let p = db.get(i);
|
let p = db.get(i);
|
||||||
assert_eq!(p.0, i)
|
assert_eq!(p.0, i)
|
||||||
}
|
}
|
||||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), cap);
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
|
||||||
|
|
||||||
for i in 0..128u32 {
|
for i in 0..128u32 {
|
||||||
let p = db.get(i);
|
let p = db.get(i);
|
||||||
assert_eq!(p.0, i)
|
assert_eq!(p.0, i)
|
||||||
}
|
}
|
||||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), cap);
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
|
||||||
|
|
||||||
|
db.query_mut(GetQuery).set_lru_capacity(32);
|
||||||
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
|
||||||
|
|
||||||
|
db.query_mut(GetQuery).set_lru_capacity(64);
|
||||||
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
|
||||||
|
for i in 0..128u32 {
|
||||||
|
let p = db.get(i);
|
||||||
|
assert_eq!(p.0, i)
|
||||||
|
}
|
||||||
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
|
||||||
|
|
||||||
|
// Special case: setting capacity to zero disables LRU
|
||||||
|
db.query_mut(GetQuery).set_lru_capacity(0);
|
||||||
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
|
||||||
|
for i in 0..128u32 {
|
||||||
|
let p = db.get(i);
|
||||||
|
assert_eq!(p.0, i)
|
||||||
|
}
|
||||||
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 128);
|
||||||
|
|
||||||
drop(db);
|
drop(db);
|
||||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
|
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue