From 3d5a07e86a223b8f52846b830ef72f344e5143a7 Mon Sep 17 00:00:00 2001 From: Yuya Nishihara Date: Sun, 12 Nov 2023 09:10:38 +0900 Subject: [PATCH] dag_walk: add fallible dfs(), topo_order(), heads(), and closest_common_node() This unblocks the use of Result in op.parents(). There are two ways to encode errors: a. impl IntoIterator> b. Result where V: FromIterator I think (a) is more natural to algorithms like dfs(), which can process error nodes transparently. Still the caller might have to collect the source iterator to temporary Vec to conform to the neighbors_fn signature. It's not easy for neighbors_fn to return an iterator borrowing the input node. We already have GAT, but doesn't have return-position impl Trait in trait yet. --- lib/src/dag_walk.rs | 269 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 243 insertions(+), 26 deletions(-) diff --git a/lib/src/dag_walk.rs b/lib/src/dag_walk.rs index b070cf6de..ea06a3e14 100644 --- a/lib/src/dag_walk.rs +++ b/lib/src/dag_walk.rs @@ -12,14 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![allow(missing_docs)] +//! General-purpose DAG algorithms. use std::collections::{BinaryHeap, HashMap, HashSet}; +use std::convert::Infallible; use std::hash::Hash; use std::{iter, mem}; use itertools::Itertools as _; +/// Traverses nodes from `start` in depth-first order. pub fn dfs( start: II, id_fn: impl Fn(&T) -> ID, @@ -30,10 +32,31 @@ where II: IntoIterator, NI: IntoIterator, { - let mut work: Vec = start.into_iter().collect(); + let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node)); + dfs_ok(to_ok_iter(start), id_fn, neighbors_fn).map(Result::unwrap) +} + +/// Traverses nodes from `start` in depth-first order. +/// +/// An `Err` is emitted as a node with no neighbors. Caller may decide to +/// short-circuit on it. +pub fn dfs_ok( + start: II, + id_fn: impl Fn(&T) -> ID, + mut neighbors_fn: impl FnMut(&T) -> NI, +) -> impl Iterator> +where + ID: Hash + Eq, + II: IntoIterator>, + NI: IntoIterator>, +{ + let mut work: Vec> = start.into_iter().collect(); let mut visited: HashSet = HashSet::new(); iter::from_fn(move || loop { - let c = work.pop()?; + let c = match work.pop() { + Some(Ok(c)) => c, + r @ (Some(Err(_)) | None) => return r, + }; let id = id_fn(&c); if visited.contains(&id) { continue; @@ -42,10 +65,12 @@ where work.push(p); } visited.insert(id); - return Some(c); + return Some(Ok(c)); }) } +/// Builds a list of nodes reachable from the `start` where neighbors come +/// before the node itself. pub fn topo_order_forward( start: II, id_fn: impl Fn(&T) -> ID, @@ -56,7 +81,26 @@ where II: IntoIterator, NI: IntoIterator, { - let mut stack = start.into_iter().map(|node| (node, false)).collect_vec(); + let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node)); + topo_order_forward_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap() +} + +/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come +/// before the node itself. +/// +/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and +/// returns the error. +pub fn topo_order_forward_ok( + start: II, + id_fn: impl Fn(&T) -> ID, + mut neighbors_fn: impl FnMut(&T) -> NI, +) -> Result, E> +where + ID: Hash + Eq + Clone, + II: IntoIterator>, + NI: IntoIterator>, +{ + let mut stack: Vec<(T, bool)> = start.into_iter().map(|r| Ok((r?, false))).try_collect()?; let mut visiting = HashSet::new(); let mut emitted = HashSet::new(); let mut result = vec![]; @@ -67,32 +111,55 @@ where } if !neighbors_visited { assert!(visiting.insert(id.clone()), "graph has cycle"); - let neighbors = neighbors_fn(&node); + let neighbors_iter = neighbors_fn(&node).into_iter(); + stack.reserve(neighbors_iter.size_hint().0 + 1); stack.push((node, true)); - stack.extend(neighbors.into_iter().map(|neighbor| (neighbor, false))); + for neighbor in neighbors_iter { + stack.push((neighbor?, false)); + } } else { visiting.remove(&id); emitted.insert(id); result.push(node); } } - result + Ok(result) } -/// Returns neighbors before the node itself. +/// Builds a list of nodes reachable from the `start` where neighbors come after +/// the node itself. pub fn topo_order_reverse( start: II, id_fn: impl Fn(&T) -> ID, - neighbors_fn: impl FnMut(&T) -> NI, + mut neighbors_fn: impl FnMut(&T) -> NI, ) -> Vec where ID: Hash + Eq + Clone, II: IntoIterator, NI: IntoIterator, { - let mut result = topo_order_forward(start, id_fn, neighbors_fn); + let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node)); + topo_order_reverse_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap() +} + +/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come +/// after the node itself. +/// +/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and +/// returns the error. +pub fn topo_order_reverse_ok( + start: II, + id_fn: impl Fn(&T) -> ID, + neighbors_fn: impl FnMut(&T) -> NI, +) -> Result, E> +where + ID: Hash + Eq + Clone, + II: IntoIterator>, + NI: IntoIterator>, +{ + let mut result = topo_order_forward_ok(start, id_fn, neighbors_fn)?; result.reverse(); - result + Ok(result) } /// Like `topo_order_reverse()`, but can iterate linear DAG lazily. @@ -253,6 +320,29 @@ where ID: Hash + Eq + Clone, II: IntoIterator, NI: IntoIterator, +{ + let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node)); + topo_order_reverse_ord_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap() +} + +/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come +/// after the node itself. +/// +/// Unlike `topo_order_reverse_ok()`, nodes are sorted in reverse `T: Ord` order +/// so long as they can respect the topological requirement. +/// +/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and +/// returns the error. +pub fn topo_order_reverse_ord_ok( + start: II, + id_fn: impl Fn(&T) -> ID, + mut neighbors_fn: impl FnMut(&T) -> NI, +) -> Result, E> +where + T: Ord, + ID: Hash + Eq + Clone, + II: IntoIterator>, + NI: IntoIterator>, { struct InnerNode { node: Option, @@ -260,7 +350,7 @@ where } // DFS to accumulate incoming edges - let mut stack: Vec = start.into_iter().collect(); + let mut stack: Vec = start.into_iter().try_collect()?; let mut head_node_map: HashMap = HashMap::new(); let mut inner_node_map: HashMap> = HashMap::new(); let mut neighbor_ids_map: HashMap> = HashMap::new(); @@ -270,8 +360,12 @@ where continue; // Already visited } + let neighbors_iter = neighbors_fn(&node).into_iter(); let pos = stack.len(); - stack.extend(neighbors_fn(&node)); + stack.reserve(neighbors_iter.size_hint().0); + for neighbor in neighbors_iter { + stack.push(neighbor?); + } let neighbor_ids = stack[pos..].iter().map(&id_fn).collect_vec(); if let Some(inner) = inner_node_map.get_mut(&node_id) { inner.node = Some(node); @@ -316,7 +410,7 @@ where } assert!(inner_node_map.is_empty(), "graph has cycle"); - result + Ok(result) } /// Find nodes in the start set that are not reachable from other nodes in the @@ -332,18 +426,40 @@ where II: IntoIterator, NI: IntoIterator, { - let start: Vec = start.into_iter().collect(); + let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node)); + heads_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap() +} + +/// Finds `Ok` nodes in the start set that are not reachable from other nodes in +/// the start set. +/// +/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and +/// returns the error. +pub fn heads_ok( + start: II, + id_fn: impl Fn(&T) -> ID, + mut neighbors_fn: impl FnMut(&T) -> NI, +) -> Result, E> +where + T: Hash + Eq + Clone, + ID: Hash + Eq, + II: IntoIterator>, + NI: IntoIterator>, +{ + let start: Vec = start.into_iter().try_collect()?; let mut reachable: HashSet = start.iter().cloned().collect(); - for _node in dfs(start, id_fn, |node| { - let neighbors: Vec = neighbors_fn(node).into_iter().collect(); - for neighbor in &neighbors { + dfs_ok(start.into_iter().map(Ok), id_fn, |node| { + let neighbors: Vec> = neighbors_fn(node).into_iter().collect(); + for neighbor in neighbors.iter().filter_map(|x| x.as_ref().ok()) { reachable.remove(neighbor); } neighbors - }) {} - reachable + }) + .try_for_each(|r| r.map(|_| ()))?; + Ok(reachable) } +/// Finds the closest common neighbor among the `set1` and `set2`. pub fn closest_common_node( set1: II1, set2: II2, @@ -355,18 +471,42 @@ where II1: IntoIterator, II2: IntoIterator, NI: IntoIterator, +{ + let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node)); + closest_common_node_ok(to_ok_iter(set1), to_ok_iter(set2), id_fn, neighbors_fn).unwrap() +} + +/// Finds the closest common `Ok` neighbor among the `set1` and `set2`. +/// +/// If the traverse reached to an `Err`, this function terminates and returns +/// the error. +pub fn closest_common_node_ok( + set1: II1, + set2: II2, + id_fn: impl Fn(&T) -> ID, + mut neighbors_fn: impl FnMut(&T) -> NI, +) -> Result, E> +where + ID: Hash + Eq, + II1: IntoIterator>, + II2: IntoIterator>, + NI: IntoIterator>, { let mut visited1 = HashSet::new(); let mut visited2 = HashSet::new(); - let mut work1: Vec = set1.into_iter().collect(); - let mut work2: Vec = set2.into_iter().collect(); + // TODO: might be better to leave an Err so long as the work contains at + // least one Ok node. If a work1 node is included in visited2, it should be + // the closest node even if work2 had previously contained an Err. + let mut work1: Vec> = set1.into_iter().collect(); + let mut work2: Vec> = set2.into_iter().collect(); while !work1.is_empty() || !work2.is_empty() { let mut new_work1 = vec![]; for node in work1 { + let node = node?; let id: ID = id_fn(&node); if visited2.contains(&id) { - return Some(node); + return Ok(Some(node)); } if visited1.insert(id) { for neighbor in neighbors_fn(&node) { @@ -378,9 +518,10 @@ where let mut new_work2 = vec![]; for node in work2 { + let node = node?; let id: ID = id_fn(&node); if visited1.contains(&id) { - return Some(node); + return Ok(Some(node)); } if visited2.insert(id) { for neighbor in neighbors_fn(&node) { @@ -390,7 +531,11 @@ where } work2 = new_work2; } - None + Ok(None) +} + +fn to_ok_iter(iter: impl IntoIterator) -> impl Iterator> { + iter.into_iter().map(Ok) } #[cfg(test)] @@ -401,6 +546,21 @@ mod tests { use super::*; + #[test] + fn test_dfs_ok() { + let neighbors = hashmap! { + 'A' => vec![], + 'B' => vec![Ok('A'), Err('X')], + 'C' => vec![Ok('B')], + }; + let id_fn = |node: &char| *node; + let neighbors_fn = |node: &char| neighbors[node].clone(); + + // Self and neighbor nodes shouldn't be lost at the error. + let nodes = dfs_ok([Ok('C')], id_fn, neighbors_fn).collect_vec(); + assert_eq!(nodes, [Ok('C'), Ok('B'), Err('X'), Ok('A')]); + } + #[test] fn test_topo_order_reverse_linear() { // This graph: @@ -879,6 +1039,26 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_topo_order_ok() { + let neighbors = hashmap! { + 'A' => vec![Err('Y')], + 'B' => vec![Ok('A'), Err('X')], + 'C' => vec![Ok('B')], + }; + let id_fn = |node: &char| *node; + let neighbors_fn = |node: &char| neighbors[node].clone(); + + // Terminates at Err('X') no matter if the sorting order is forward or + // reverse. The visiting order matters. + let result = topo_order_forward_ok([Ok('C')], id_fn, neighbors_fn); + assert_eq!(result, Err('X')); + let result = topo_order_reverse_ok([Ok('C')], id_fn, neighbors_fn); + assert_eq!(result, Err('X')); + let result = topo_order_reverse_ord_ok([Ok('C')], id_fn, neighbors_fn); + assert_eq!(result, Err('X')); + } + #[test] fn test_closest_common_node_tricky() { // Test this case where A is the shortest distance away, but we still want the @@ -916,6 +1096,25 @@ mod tests { assert_eq!(common, Some('A')); } + #[test] + fn test_closest_common_node_ok() { + let neighbors = hashmap! { + 'A' => vec![Err('Y')], + 'B' => vec![Ok('A')], + 'C' => vec![Ok('A')], + 'D' => vec![Err('X')], + }; + let id_fn = |node: &char| *node; + let neighbors_fn = |node: &char| neighbors[node].clone(); + + let result = closest_common_node_ok([Ok('B')], [Ok('C')], id_fn, neighbors_fn); + assert_eq!(result, Ok(Some('A'))); + let result = closest_common_node_ok([Ok('C')], [Ok('D')], id_fn, neighbors_fn); + assert_eq!(result, Err('X')); + let result = closest_common_node_ok([Ok('C')], [Err('Z')], id_fn, neighbors_fn); + assert_eq!(result, Err('Z')); + } + #[test] fn test_heads_mixed() { // Test the uppercase letters are in the start set @@ -952,4 +1151,22 @@ mod tests { ); assert_eq!(actual, hashset!['D', 'F']); } + + #[test] + fn test_heads_ok() { + let neighbors = hashmap! { + 'A' => vec![], + 'B' => vec![Ok('A'), Err('X')], + 'C' => vec![Ok('B')], + }; + let id_fn = |node: &char| *node; + let neighbors_fn = |node: &char| neighbors[node].clone(); + + let result = heads_ok([Ok('C')], id_fn, neighbors_fn); + assert_eq!(result, Err('X')); + let result = heads_ok([Ok('B')], id_fn, neighbors_fn); + assert_eq!(result, Err('X')); + let result = heads_ok([Ok('A')], id_fn, neighbors_fn); + assert_eq!(result, Ok(hashset! {'A'})); + } }