diff --git a/components/salsa-2022/src/runtime/active_query.rs b/components/salsa-2022/src/runtime/active_query.rs index 87fa8b46..298b5afe 100644 --- a/components/salsa-2022/src/runtime/active_query.rs +++ b/components/salsa-2022/src/runtime/active_query.rs @@ -1,5 +1,3 @@ -use std::collections::BTreeSet; - use crate::{ durability::Durability, hash::{FxIndexMap, FxIndexSet}, @@ -8,7 +6,7 @@ use crate::{ Cycle, Revision, Runtime, }; -use super::local_state::{QueryEdges, QueryOrigin, QueryRevisions}; +use super::local_state::{QueryEdges, QueryOrigin, QueryRevisions, EdgeKind}; #[derive(Debug)] pub(super) struct ActiveQuery { @@ -22,8 +20,13 @@ pub(super) struct ActiveQuery { /// untracked read, this will be set to the most recent revision. pub(super) changed_at: Revision, - /// Set of subqueries that were accessed thus far. - pub(super) dependencies: FxIndexSet, + /// Inputs: Set of subqueries that were accessed thus far. + /// Outputs: Tracks values written by this query. Could be... + /// + /// * tracked structs created + /// * invocations of `specify` + /// * accumulators pushed to + pub(super) input_outputs: FxIndexSet<(EdgeKind, DependencyIndex)>, /// True if there was an untracked read. pub(super) untracked_read: bool, @@ -35,16 +38,6 @@ pub(super) struct ActiveQuery { /// hash is added to this map. If it is not present, then the disambiguator is 0. /// Otherwise it is 1 more than the current value (which is incremented). pub(super) disambiguator_map: FxIndexMap, - - /// Tracks values written by this query. Could be... - /// - /// * tracked structs created - /// * invocations of `specify` - /// * accumulators pushed to - /// - /// We use a btree-set because we want to be able to - /// extract the keys in sorted order. - pub(super) outputs: BTreeSet, } impl ActiveQuery { @@ -53,11 +46,10 @@ impl ActiveQuery { database_key_index, durability: Durability::MAX, changed_at: Revision::start(), - dependencies: FxIndexSet::default(), + input_outputs: FxIndexSet::default(), untracked_read: false, cycle: None, disambiguator_map: Default::default(), - outputs: Default::default(), } } @@ -67,7 +59,7 @@ impl ActiveQuery { durability: Durability, revision: Revision, ) { - self.dependencies.insert(input); + self.input_outputs.insert((EdgeKind::Input, input)); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(revision); } @@ -86,29 +78,19 @@ impl ActiveQuery { /// Adds a key to our list of outputs. pub(super) fn add_output(&mut self, key: DependencyIndex) { - self.outputs.insert(key); + self.input_outputs.insert((EdgeKind::Output, key)); } /// True if the given key was output by this query. pub(super) fn is_output(&self, key: DatabaseKeyIndex) -> bool { let key: DependencyIndex = key.into(); - self.outputs.contains(&key) + self.input_outputs.contains(&(EdgeKind::Output, key)) } pub(crate) fn revisions(&self, runtime: &Runtime) -> QueryRevisions { - let separator = self.dependencies.len(); + let input_outputs = self.input_outputs.iter().copied().collect(); - let input_outputs = if self.dependencies.is_empty() && self.outputs.is_empty() { - runtime.empty_dependencies() - } else { - self.dependencies - .iter() - .copied() - .chain(self.outputs.iter().copied()) - .collect() - }; - - let edges = QueryEdges::new(separator, input_outputs); + let edges = QueryEdges::new(input_outputs); let origin = if self.untracked_read { QueryOrigin::DerivedUntracked(edges) @@ -129,7 +111,7 @@ impl ActiveQuery { self.changed_at = self.changed_at.max(other.changed_at); self.durability = self.durability.min(other.durability); self.untracked_read |= other.untracked_read; - self.dependencies.extend(other.dependencies.iter().copied()); + self.input_outputs.extend(other.input_outputs.iter().copied()); } /// Removes the participants in `cycle` from my dependencies. @@ -137,7 +119,7 @@ impl ActiveQuery { pub(super) fn remove_cycle_participants(&mut self, cycle: &Cycle) { for p in cycle.participant_keys() { let p: DependencyIndex = p.into(); - self.dependencies.remove(&p); + self.input_outputs.remove(&(EdgeKind::Input, p)); } } @@ -146,7 +128,7 @@ impl ActiveQuery { pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) { self.changed_at = cycle_query.changed_at; self.durability = cycle_query.durability; - self.dependencies = cycle_query.dependencies.clone(); + self.input_outputs = cycle_query.input_outputs.clone(); } pub(super) fn disambiguate(&mut self, hash: u64) -> Disambiguator { diff --git a/components/salsa-2022/src/runtime/local_state.rs b/components/salsa-2022/src/runtime/local_state.rs index e5359cb9..8dbfd60a 100644 --- a/components/salsa-2022/src/runtime/local_state.rs +++ b/components/salsa-2022/src/runtime/local_state.rs @@ -76,19 +76,25 @@ pub enum QueryOrigin { } impl QueryOrigin { - /// Indices for queries *written* by this query (or `&[]` if its value was assigned). + /// Indices for queries *written* by this query (or `vec![]` if its value was assigned). pub(crate) fn outputs(&self) -> impl Iterator + '_ { let slice = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => { - &edges.input_outputs[edges.separator as usize..] + edges.outputs() } - QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => &[], + QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => vec![], }; - slice.iter().copied() + slice.into_iter() } } +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] +pub enum EdgeKind { + Input, + Output, +} + /// The edges between a memoized value and other queries in the dependency graph. /// These edges include both dependency edges /// e.g., when creating the memoized value for Q0 executed another function Q1) @@ -98,8 +104,6 @@ impl QueryOrigin { pub struct QueryEdges { /// The list of outgoing edges from this node. /// This list combines *both* inputs and outputs. - /// The inputs are defined from the indices `0..S` where - /// `S` is the value of the `separator` field. /// /// Note that we always track input dependencies even when there are untracked reads. /// Untracked reads mean that we can't verify values, so we don't use the list of inputs for that, @@ -110,26 +114,35 @@ pub struct QueryEdges { /// Important: /// /// * The inputs must be in **execution order** for the red-green algorithm to work. - /// * The outputs must be in **sorted order** so that we can easily "diff" them between revisions. - input_outputs: Arc<[DependencyIndex]>, - - /// The index that separates inputs from outputs in the `tracked` field. - separator: u32, + input_outputs: Arc<[(EdgeKind, DependencyIndex)]>, } impl QueryEdges { /// Returns the (tracked) inputs that were executed in computing this memoized value. /// /// These will always be in execution order. - pub(crate) fn inputs(&self) -> &[DependencyIndex] { - &self.input_outputs[0..self.separator as usize] + pub(crate) fn inputs(&self) -> Vec { + self.input_outputs + .iter() + .filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input) + .map(|(_, dependency_index)| *dependency_index) + .collect() + } + + /// Returns the (tracked) inputs that were executed in computing this memoized value. + /// + /// These will always be in execution order. + pub(crate) fn outputs(&self) -> Vec { + self.input_outputs + .iter() + .filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output) + .map(|(_, dependency_index)| *dependency_index) + .collect() } /// Creates a new `QueryEdges`; the values given for each field must meet struct invariants. - pub(crate) fn new(separator: usize, input_outputs: Arc<[DependencyIndex]>) -> Self { - debug_assert!(separator <= input_outputs.len()); + pub(crate) fn new(input_outputs: Arc<[(EdgeKind, DependencyIndex)]>) -> Self { Self { - separator: u32::try_from(separator).unwrap(), input_outputs, } }