From f1499a20e2b508fe34b94b1730e08ed3e3c814f6 Mon Sep 17 00:00:00 2001 From: Jack Rickard Date: Sat, 10 Sep 2022 18:05:26 +0100 Subject: [PATCH] Fix unsoundness in input_field.rs --- components/salsa-2022-macros/src/input.rs | 10 ++-- components/salsa-2022/src/input.rs | 2 +- components/salsa-2022/src/input_field.rs | 60 +++++++++++++++++----- components/salsa-2022/src/setter.rs | 2 +- salsa-2022-tests/tests/deletion-cascade.rs | 2 +- 5 files changed, 54 insertions(+), 22 deletions(-) diff --git a/components/salsa-2022-macros/src/input.rs b/components/salsa-2022-macros/src/input.rs index 1429ef77..9793550d 100644 --- a/components/salsa-2022-macros/src/input.rs +++ b/components/salsa-2022-macros/src/input.rs @@ -130,13 +130,13 @@ impl InputStruct { let constructor: syn::ImplItemMethod = if singleton { parse_quote! { - pub fn #constructor_name(__db: &#db_dyn_ty, #(#field_names: #field_tys,)*) -> Self + pub fn #constructor_name(__db: &mut #db_dyn_ty, #(#field_names: #field_tys,)*) -> Self { - let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); - let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); + let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar_mut(__db); + let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient_mut(__jar); let __id = __ingredients.#input_index.new_singleton_input(__runtime); #( - __ingredients.#field_indices.store(__runtime, __id, #field_names, salsa::Durability::LOW); + __ingredients.#field_indices.store_mut(__runtime, __id, #field_names, salsa::Durability::LOW); )* __id } @@ -149,7 +149,7 @@ impl InputStruct { let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); let __id = __ingredients.#input_index.new_input(__runtime); #( - __ingredients.#field_indices.store(__runtime, __id, #field_names, salsa::Durability::LOW); + __ingredients.#field_indices.store_new(__runtime, __id, #field_names, salsa::Durability::LOW); )* __id } diff --git a/components/salsa-2022/src/input.rs b/components/salsa-2022/src/input.rs index 7439bc71..3f483f09 100644 --- a/components/salsa-2022/src/input.rs +++ b/components/salsa-2022/src/input.rs @@ -49,7 +49,7 @@ where Id::from_id(crate::Id::from_u32(next_id)) } - pub fn new_singleton_input(&self, _runtime: &Runtime) -> Id { + pub fn new_singleton_input(&mut self, _runtime: &Runtime) -> Id { // There's only one singleton so record that we've created it // and return the only id. self.counter.store(1, Ordering::Relaxed); diff --git a/components/salsa-2022/src/input_field.rs b/components/salsa-2022/src/input_field.rs index 02a71e75..f9b4f16a 100644 --- a/components/salsa-2022/src/input_field.rs +++ b/components/salsa-2022/src/input_field.rs @@ -4,18 +4,23 @@ use crate::key::DependencyIndex; use crate::runtime::local_state::QueryOrigin; use crate::runtime::StampedValue; use crate::{AsId, DatabaseKeyIndex, Durability, Id, IngredientIndex, Revision, Runtime}; +use dashmap::mapref::entry::Entry; use dashmap::DashMap; use std::fmt; use std::hash::Hash; /// Ingredient used to represent the fields of a `#[salsa::input]`. -/// These fields can only be mutated by an explicit call to a setter -/// with an `&mut` reference to the database, -/// and therefore cannot be mutated during a tracked function or in parallel. -/// This makes the implementation considerably simpler. +/// +/// These fields can only be mutated by a call to a setter with an `&mut` +/// reference to the database, and therefore cannot be mutated during a tracked +/// function or in parallel. +/// However for on-demand inputs to work the fields must be able to be set via +/// a shared reference, so some locking is required. +/// Altogether this makes the implementation somewhat simpler than tracked +/// structs. pub struct InputFieldIngredient { index: IngredientIndex, - map: DashMap>, + map: DashMap>>, debug_name: &'static str, } @@ -31,18 +36,43 @@ where } } - pub fn store(&self, runtime: &Runtime, key: K, value: F, durability: Durability) -> Option { + pub fn store_mut( + &mut self, + runtime: &Runtime, + key: K, + value: F, + durability: Durability, + ) -> Option { let revision = runtime.current_revision(); - let stamped_value = StampedValue { + let stamped_value = Box::new(StampedValue { value, durability, changed_at: revision, - }; + }); - if let Some(old_value) = self.map.insert(key, stamped_value) { - Some(old_value.value) - } else { - None + self.map + .insert(key, stamped_value) + .map(|old_value| old_value.value) + } + + /// Set the field of a new input. + /// + /// This function panics if the field has ever been set before. + pub fn store_new(&self, runtime: &Runtime, key: K, value: F, durability: Durability) { + let revision = runtime.current_revision(); + let stamped_value = Box::new(StampedValue { + value, + durability, + changed_at: revision, + }); + + match self.map.entry(key) { + Entry::Occupied(_) => { + panic!("attempted to set field of existing input using `store_new`, use `store_mut` instead"); + } + Entry::Vacant(entry) => { + entry.insert(stamped_value); + } } } @@ -51,7 +81,7 @@ where value, durability, changed_at, - } = &*self.map.get(&key).unwrap(); + } = &**self.map.get(&key).unwrap(); runtime.report_tracked_read( self.database_key_index(key).into(), @@ -60,7 +90,9 @@ where ); // SAFETY: - // * Values are only removed or altered when we have `&mut self` + // The value is stored in a box so internal moves in the dashmap don't + // invalidate the reference to the value inside the box. + // Values are only removed or altered when we have `&mut self`. unsafe { transmute_lifetime(self, value) } } diff --git a/components/salsa-2022/src/setter.rs b/components/salsa-2022/src/setter.rs index 05af9c0e..f23289a2 100644 --- a/components/salsa-2022/src/setter.rs +++ b/components/salsa-2022/src/setter.rs @@ -33,7 +33,7 @@ where pub fn to(self, value: F) -> F { self.ingredient - .store(self.runtime, self.key, value, self.durability) + .store_mut(self.runtime, self.key, value, self.durability) .unwrap() } } diff --git a/salsa-2022-tests/tests/deletion-cascade.rs b/salsa-2022-tests/tests/deletion-cascade.rs index a6bb3750..65fc0a6b 100644 --- a/salsa-2022-tests/tests/deletion-cascade.rs +++ b/salsa-2022-tests/tests/deletion-cascade.rs @@ -91,7 +91,7 @@ fn basic() { let mut db = Database::default(); // Creates 3 tracked structs - let input = MyInput::new(&db, 3); + let input = MyInput::new(&mut db, 3); assert_eq!(final_result(&db, input), 2 * 2 + 2); db.assert_logs(expect![[r#" [