remove the query_in_progress field

This commit is contained in:
Niko Matsakis 2018-10-31 15:59:00 -04:00
parent 252132f9b4
commit 2a6b8e07f9
2 changed files with 18 additions and 31 deletions

View file

@ -1,7 +1,7 @@
use crate::{Database, Event, EventKind, SweepStrategy};
use lock_api::{RawRwLock, RawRwLockRecursive};
use log::debug;
use parking_lot::{Mutex, RwLock, RwLockReadGuard};
use parking_lot::{Mutex, RwLock};
use rustc_hash::{FxHashMap, FxHasher};
use smallvec::SmallVec;
use std::cell::RefCell;
@ -95,14 +95,11 @@ where
counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst),
};
let mut local_state = LocalState::default();
local_state.query_in_progress = true;
Runtime {
id,
revision_guard: Some(revision_guard),
shared_state: self.shared_state.clone(),
local_state: RefCell::new(local_state),
local_state: Default::default(),
}
}
@ -189,7 +186,7 @@ where
pub(crate) fn with_incremented_revision<R>(&self, op: impl FnOnce(Revision) -> R) -> R {
log::debug!("increment_revision()");
if self.query_in_progress() {
if !self.permits_increment() {
panic!("increment_revision invoked during a query computation");
}
@ -228,8 +225,8 @@ where
op(new_revision)
}
pub(crate) fn query_in_progress(&self) -> bool {
self.local_state.borrow().query_in_progress
pub(crate) fn permits_increment(&self) -> bool {
self.revision_guard.is_none() && self.local_state.borrow().query_stack.is_empty()
}
pub(crate) fn execute_query_implementation<V>(
@ -398,38 +395,17 @@ impl<DB: Database> Default for SharedState<DB> {
/// State that will be specific to a single execution threads (when we
/// support multiple threads)
struct LocalState<DB: Database> {
query_in_progress: bool,
query_stack: Vec<ActiveQuery<DB>>,
}
impl<DB: Database> Default for LocalState<DB> {
fn default() -> Self {
LocalState {
query_in_progress: false,
query_stack: Default::default(),
}
}
}
pub(crate) struct QueryGuard<'db, DB: Database + 'db> {
db: &'db Runtime<DB>,
lock: RwLockReadGuard<'db, ()>,
}
impl<'db, DB: Database> QueryGuard<'db, DB> {
fn new(db: &'db Runtime<DB>, lock: RwLockReadGuard<'db, ()>) -> Self {
Self { db, lock }
}
}
impl<'db, DB: Database> Drop for QueryGuard<'db, DB> {
fn drop(&mut self) {
let mut local_state = self.db.local_state.borrow_mut();
assert!(local_state.query_in_progress);
local_state.query_in_progress = false;
}
}
struct ActiveQuery<DB: Database> {
/// What query is executing
descriptor: DB::QueryDescriptor,

View file

@ -1,4 +1,4 @@
use salsa::Database;
use salsa::{Database, Frozen, ParallelDatabase};
use std::panic::{self, AssertUnwindSafe};
salsa::query_group! {
@ -29,6 +29,14 @@ impl salsa::Database for DatabaseStruct {
}
}
impl salsa::ParallelDatabase for DatabaseStruct {
fn fork(&self) -> Frozen<Self> {
Frozen::new(DatabaseStruct {
runtime: self.runtime.fork(self),
})
}
}
salsa::database_storage! {
struct DatabaseStorage for DatabaseStruct {
impl PanicSafelyDatabase {
@ -44,7 +52,10 @@ fn should_panic_safely() {
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
// default to 0 and we should catch the panic.
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
let result = panic::catch_unwind(AssertUnwindSafe({
let db = db.fork();
move || db.panic_safely()
}));
assert!(result.is_err());
// Set `db.one` to 1 and assert ok