diff --git a/src/derived.rs b/src/derived.rs index a329822a..301a289d 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -171,6 +171,13 @@ where } } +enum ProbeState { + UpToDate(StampedValue), + CycleDetected, + StaleOrAbsent(G), + BlockedOnOtherThread(RuntimeId), +} + impl DerivedStorage where Q: QueryFunction, @@ -196,22 +203,38 @@ where revision_now, ); - // First, check for an up-to-date value (or a cycle). This we - // can do with a simple read-lock. - match self.read_up_to_date_or_cycle(self.map.read(), runtime, revision_now, key) { - Ok(r) => return r, - Err(_guard) => (), - } + // In this loop, we are looking for an up-to-date value. The loop is needed + // to handle other threads: if we find that some other thread is "using" this + // key, we will block until they are done and then loop back around and try + // again. + // + // Otherwise, we first check for a usable value with the read + // lock. If that fails, we acquire the write lock and try + // again. We don't use an upgradable read lock because that + // would eliminate the ability for multiple cache hits to be + // executing in parallel. + let mut old_value = loop { + // Read-lock check. + match self.probe(self.map.read(), runtime, revision_now, key) { + ProbeState::UpToDate(v) => return Ok(v), + ProbeState::CycleDetected => return Err(CycleDetected), + ProbeState::BlockedOnOtherThread(other_id) => { + self.await_other_thread(other_id, key); + continue; + } + ProbeState::StaleOrAbsent(_guard) => (), + } - // Otherwise, we may have to take ownership. Get the write - // lock and check again. If the value is not up-to-date (or - // we have to verify it), insert an `InProgress` indicator to - // hold our spot. - let mut old_value = { - match self.read_up_to_date_or_cycle(self.map.write(), runtime, revision_now, key) { - Ok(r) => return r, - Err(mut map) => { - map.insert(key.clone(), QueryState::InProgress { id: runtime.id() }) + // Write-lock check: install `InProgress` sentinel if no usable value. + match self.probe(self.map.write(), runtime, revision_now, key) { + ProbeState::UpToDate(v) => return Ok(v), + ProbeState::CycleDetected => return Err(CycleDetected), + ProbeState::BlockedOnOtherThread(other_id) => { + self.await_other_thread(other_id, key); + continue; + } + ProbeState::StaleOrAbsent(mut map) => { + break map.insert(key.clone(), QueryState::InProgress { id: runtime.id() }) } } }; @@ -298,20 +321,20 @@ where /// /// Otherwise, returns `Err(map)` where `map` is the lock guard /// that was given in as argument. - fn read_up_to_date_or_cycle( + fn probe( &self, map: MapGuard, runtime: &Runtime, revision_now: Revision, key: &Q::Key, - ) -> Result, CycleDetected>, MapGuard> + ) -> ProbeState where MapGuard: Deref>>, { match map.get(key) { Some(QueryState::InProgress { id }) => { if *id == runtime.id() { - return Ok(Err(CycleDetected)); + return ProbeState::CycleDetected; } else { unimplemented!(); } @@ -336,10 +359,10 @@ where key, m.changed_at, ); - return Ok(Ok(StampedValue { + return ProbeState::UpToDate(StampedValue { value: value.clone(), changed_at: m.changed_at, - })); + }); }; } } @@ -347,7 +370,7 @@ where None => {} } - Err(map) + ProbeState::StaleOrAbsent(map) } /// If some other thread is tasked with producing a memoized @@ -356,36 +379,24 @@ where /// Pre-conditions: /// - we have installed ourselves in the dependency graph and set the /// bool that informs the producer we are waiting - /// - `self.map` must be locked (with `map_guard` as the guard) - fn await_other_thread( - &self, - map_guard: MapGuard, - revision_now: Revision, - key: &Q::Key, - ) -> StampedValue - where - MapGuard: Deref>>, - { - // Intentionally release the lock on map. We cannot be holding - // it while we are sleeping! - std::mem::drop(map_guard); - + /// - `self.map` must not be locked + fn await_other_thread(&self, other_id: RuntimeId, key: &Q::Key) { let mut signal_lock_guard = self.signal_mutex.lock(); loop { { let map = self.map.read(); - if let Some(QueryState::Memoized(m)) = map.get(key) { - assert_eq!(m.verified_at, revision_now); - return if let Some(value) = &m.value { - StampedValue { - value: value.clone(), - changed_at: m.changed_at, - } - } else { - panic!("awaiting production of non-memoized value"); - }; + match map.get(key) { + Some(QueryState::InProgress { id }) => { + // Other thread still working! + assert_eq!(*id, other_id); + } + + _ => { + // The other thread finished! + return; + } } }