mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-15 01:39:25 +00:00
Fix unsoundness in input_field.rs
This commit is contained in:
parent
609acc396c
commit
f1499a20e2
5 changed files with 54 additions and 22 deletions
|
@ -130,13 +130,13 @@ impl InputStruct {
|
|||
|
||||
let constructor: syn::ImplItemMethod = if singleton {
|
||||
parse_quote! {
|
||||
pub fn #constructor_name(__db: &#db_dyn_ty, #(#field_names: #field_tys,)*) -> Self
|
||||
pub fn #constructor_name(__db: &mut #db_dyn_ty, #(#field_names: #field_tys,)*) -> Self
|
||||
{
|
||||
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
|
||||
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar);
|
||||
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar_mut(__db);
|
||||
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient_mut(__jar);
|
||||
let __id = __ingredients.#input_index.new_singleton_input(__runtime);
|
||||
#(
|
||||
__ingredients.#field_indices.store(__runtime, __id, #field_names, salsa::Durability::LOW);
|
||||
__ingredients.#field_indices.store_mut(__runtime, __id, #field_names, salsa::Durability::LOW);
|
||||
)*
|
||||
__id
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ impl InputStruct {
|
|||
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar);
|
||||
let __id = __ingredients.#input_index.new_input(__runtime);
|
||||
#(
|
||||
__ingredients.#field_indices.store(__runtime, __id, #field_names, salsa::Durability::LOW);
|
||||
__ingredients.#field_indices.store_new(__runtime, __id, #field_names, salsa::Durability::LOW);
|
||||
)*
|
||||
__id
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ where
|
|||
Id::from_id(crate::Id::from_u32(next_id))
|
||||
}
|
||||
|
||||
pub fn new_singleton_input(&self, _runtime: &Runtime) -> Id {
|
||||
pub fn new_singleton_input(&mut self, _runtime: &Runtime) -> Id {
|
||||
// There's only one singleton so record that we've created it
|
||||
// and return the only id.
|
||||
self.counter.store(1, Ordering::Relaxed);
|
||||
|
|
|
@ -4,18 +4,23 @@ use crate::key::DependencyIndex;
|
|||
use crate::runtime::local_state::QueryOrigin;
|
||||
use crate::runtime::StampedValue;
|
||||
use crate::{AsId, DatabaseKeyIndex, Durability, Id, IngredientIndex, Revision, Runtime};
|
||||
use dashmap::mapref::entry::Entry;
|
||||
use dashmap::DashMap;
|
||||
use std::fmt;
|
||||
use std::hash::Hash;
|
||||
|
||||
/// Ingredient used to represent the fields of a `#[salsa::input]`.
|
||||
/// These fields can only be mutated by an explicit call to a setter
|
||||
/// with an `&mut` reference to the database,
|
||||
/// and therefore cannot be mutated during a tracked function or in parallel.
|
||||
/// This makes the implementation considerably simpler.
|
||||
///
|
||||
/// These fields can only be mutated by a call to a setter with an `&mut`
|
||||
/// reference to the database, and therefore cannot be mutated during a tracked
|
||||
/// function or in parallel.
|
||||
/// However for on-demand inputs to work the fields must be able to be set via
|
||||
/// a shared reference, so some locking is required.
|
||||
/// Altogether this makes the implementation somewhat simpler than tracked
|
||||
/// structs.
|
||||
pub struct InputFieldIngredient<K, F> {
|
||||
index: IngredientIndex,
|
||||
map: DashMap<K, StampedValue<F>>,
|
||||
map: DashMap<K, Box<StampedValue<F>>>,
|
||||
debug_name: &'static str,
|
||||
}
|
||||
|
||||
|
@ -31,18 +36,43 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn store(&self, runtime: &Runtime, key: K, value: F, durability: Durability) -> Option<F> {
|
||||
pub fn store_mut(
|
||||
&mut self,
|
||||
runtime: &Runtime,
|
||||
key: K,
|
||||
value: F,
|
||||
durability: Durability,
|
||||
) -> Option<F> {
|
||||
let revision = runtime.current_revision();
|
||||
let stamped_value = StampedValue {
|
||||
let stamped_value = Box::new(StampedValue {
|
||||
value,
|
||||
durability,
|
||||
changed_at: revision,
|
||||
};
|
||||
});
|
||||
|
||||
if let Some(old_value) = self.map.insert(key, stamped_value) {
|
||||
Some(old_value.value)
|
||||
} else {
|
||||
None
|
||||
self.map
|
||||
.insert(key, stamped_value)
|
||||
.map(|old_value| old_value.value)
|
||||
}
|
||||
|
||||
/// Set the field of a new input.
|
||||
///
|
||||
/// This function panics if the field has ever been set before.
|
||||
pub fn store_new(&self, runtime: &Runtime, key: K, value: F, durability: Durability) {
|
||||
let revision = runtime.current_revision();
|
||||
let stamped_value = Box::new(StampedValue {
|
||||
value,
|
||||
durability,
|
||||
changed_at: revision,
|
||||
});
|
||||
|
||||
match self.map.entry(key) {
|
||||
Entry::Occupied(_) => {
|
||||
panic!("attempted to set field of existing input using `store_new`, use `store_mut` instead");
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(stamped_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -51,7 +81,7 @@ where
|
|||
value,
|
||||
durability,
|
||||
changed_at,
|
||||
} = &*self.map.get(&key).unwrap();
|
||||
} = &**self.map.get(&key).unwrap();
|
||||
|
||||
runtime.report_tracked_read(
|
||||
self.database_key_index(key).into(),
|
||||
|
@ -60,7 +90,9 @@ where
|
|||
);
|
||||
|
||||
// SAFETY:
|
||||
// * Values are only removed or altered when we have `&mut self`
|
||||
// The value is stored in a box so internal moves in the dashmap don't
|
||||
// invalidate the reference to the value inside the box.
|
||||
// Values are only removed or altered when we have `&mut self`.
|
||||
unsafe { transmute_lifetime(self, value) }
|
||||
}
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ where
|
|||
|
||||
pub fn to(self, value: F) -> F {
|
||||
self.ingredient
|
||||
.store(self.runtime, self.key, value, self.durability)
|
||||
.store_mut(self.runtime, self.key, value, self.durability)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ fn basic() {
|
|||
let mut db = Database::default();
|
||||
|
||||
// Creates 3 tracked structs
|
||||
let input = MyInput::new(&db, 3);
|
||||
let input = MyInput::new(&mut db, 3);
|
||||
assert_eq!(final_result(&db, input), 2 * 2 + 2);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
|
|
Loading…
Reference in a new issue