dag_walk: add fallible dfs(), topo_order(), heads(), and closest_common_node()

This unblocks the use of Result<T, E> in op.parents().

There are two ways to encode errors:
 a. impl IntoIterator<Item = Result<T, E>>
 b. Result<V, E> where V: FromIterator<Item = T>
I think (a) is more natural to algorithms like dfs(), which can process error
nodes transparently.

Still the caller might have to collect the source iterator to temporary Vec
to conform to the neighbors_fn signature. It's not easy for neighbors_fn to
return an iterator borrowing the input node. We already have GAT, but doesn't
have return-position impl Trait in trait yet.
This commit is contained in:
Yuya Nishihara 2023-11-12 09:10:38 +09:00
parent e5a9a26911
commit 3d5a07e86a

View file

@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(missing_docs)]
//! General-purpose DAG algorithms.
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::convert::Infallible;
use std::hash::Hash;
use std::{iter, mem};
use itertools::Itertools as _;
/// Traverses nodes from `start` in depth-first order.
pub fn dfs<T, ID, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
@ -30,10 +32,31 @@ where
II: IntoIterator<Item = T>,
NI: IntoIterator<Item = T>,
{
let mut work: Vec<T> = start.into_iter().collect();
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
dfs_ok(to_ok_iter(start), id_fn, neighbors_fn).map(Result::unwrap)
}
/// Traverses nodes from `start` in depth-first order.
///
/// An `Err` is emitted as a node with no neighbors. Caller may decide to
/// short-circuit on it.
pub fn dfs_ok<T, ID, E, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
mut neighbors_fn: impl FnMut(&T) -> NI,
) -> impl Iterator<Item = Result<T, E>>
where
ID: Hash + Eq,
II: IntoIterator<Item = Result<T, E>>,
NI: IntoIterator<Item = Result<T, E>>,
{
let mut work: Vec<Result<T, E>> = start.into_iter().collect();
let mut visited: HashSet<ID> = HashSet::new();
iter::from_fn(move || loop {
let c = work.pop()?;
let c = match work.pop() {
Some(Ok(c)) => c,
r @ (Some(Err(_)) | None) => return r,
};
let id = id_fn(&c);
if visited.contains(&id) {
continue;
@ -42,10 +65,12 @@ where
work.push(p);
}
visited.insert(id);
return Some(c);
return Some(Ok(c));
})
}
/// Builds a list of nodes reachable from the `start` where neighbors come
/// before the node itself.
pub fn topo_order_forward<T, ID, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
@ -56,7 +81,26 @@ where
II: IntoIterator<Item = T>,
NI: IntoIterator<Item = T>,
{
let mut stack = start.into_iter().map(|node| (node, false)).collect_vec();
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
topo_order_forward_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap()
}
/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come
/// before the node itself.
///
/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and
/// returns the error.
pub fn topo_order_forward_ok<T, ID, E, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
mut neighbors_fn: impl FnMut(&T) -> NI,
) -> Result<Vec<T>, E>
where
ID: Hash + Eq + Clone,
II: IntoIterator<Item = Result<T, E>>,
NI: IntoIterator<Item = Result<T, E>>,
{
let mut stack: Vec<(T, bool)> = start.into_iter().map(|r| Ok((r?, false))).try_collect()?;
let mut visiting = HashSet::new();
let mut emitted = HashSet::new();
let mut result = vec![];
@ -67,32 +111,55 @@ where
}
if !neighbors_visited {
assert!(visiting.insert(id.clone()), "graph has cycle");
let neighbors = neighbors_fn(&node);
let neighbors_iter = neighbors_fn(&node).into_iter();
stack.reserve(neighbors_iter.size_hint().0 + 1);
stack.push((node, true));
stack.extend(neighbors.into_iter().map(|neighbor| (neighbor, false)));
for neighbor in neighbors_iter {
stack.push((neighbor?, false));
}
} else {
visiting.remove(&id);
emitted.insert(id);
result.push(node);
}
}
result
Ok(result)
}
/// Returns neighbors before the node itself.
/// Builds a list of nodes reachable from the `start` where neighbors come after
/// the node itself.
pub fn topo_order_reverse<T, ID, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
neighbors_fn: impl FnMut(&T) -> NI,
mut neighbors_fn: impl FnMut(&T) -> NI,
) -> Vec<T>
where
ID: Hash + Eq + Clone,
II: IntoIterator<Item = T>,
NI: IntoIterator<Item = T>,
{
let mut result = topo_order_forward(start, id_fn, neighbors_fn);
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
topo_order_reverse_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap()
}
/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come
/// after the node itself.
///
/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and
/// returns the error.
pub fn topo_order_reverse_ok<T, ID, E, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
neighbors_fn: impl FnMut(&T) -> NI,
) -> Result<Vec<T>, E>
where
ID: Hash + Eq + Clone,
II: IntoIterator<Item = Result<T, E>>,
NI: IntoIterator<Item = Result<T, E>>,
{
let mut result = topo_order_forward_ok(start, id_fn, neighbors_fn)?;
result.reverse();
result
Ok(result)
}
/// Like `topo_order_reverse()`, but can iterate linear DAG lazily.
@ -253,6 +320,29 @@ where
ID: Hash + Eq + Clone,
II: IntoIterator<Item = T>,
NI: IntoIterator<Item = T>,
{
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
topo_order_reverse_ord_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap()
}
/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come
/// after the node itself.
///
/// Unlike `topo_order_reverse_ok()`, nodes are sorted in reverse `T: Ord` order
/// so long as they can respect the topological requirement.
///
/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and
/// returns the error.
pub fn topo_order_reverse_ord_ok<T, ID, E, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
mut neighbors_fn: impl FnMut(&T) -> NI,
) -> Result<Vec<T>, E>
where
T: Ord,
ID: Hash + Eq + Clone,
II: IntoIterator<Item = Result<T, E>>,
NI: IntoIterator<Item = Result<T, E>>,
{
struct InnerNode<T> {
node: Option<T>,
@ -260,7 +350,7 @@ where
}
// DFS to accumulate incoming edges
let mut stack: Vec<T> = start.into_iter().collect();
let mut stack: Vec<T> = start.into_iter().try_collect()?;
let mut head_node_map: HashMap<ID, T> = HashMap::new();
let mut inner_node_map: HashMap<ID, InnerNode<T>> = HashMap::new();
let mut neighbor_ids_map: HashMap<ID, Vec<ID>> = HashMap::new();
@ -270,8 +360,12 @@ where
continue; // Already visited
}
let neighbors_iter = neighbors_fn(&node).into_iter();
let pos = stack.len();
stack.extend(neighbors_fn(&node));
stack.reserve(neighbors_iter.size_hint().0);
for neighbor in neighbors_iter {
stack.push(neighbor?);
}
let neighbor_ids = stack[pos..].iter().map(&id_fn).collect_vec();
if let Some(inner) = inner_node_map.get_mut(&node_id) {
inner.node = Some(node);
@ -316,7 +410,7 @@ where
}
assert!(inner_node_map.is_empty(), "graph has cycle");
result
Ok(result)
}
/// Find nodes in the start set that are not reachable from other nodes in the
@ -332,18 +426,40 @@ where
II: IntoIterator<Item = T>,
NI: IntoIterator<Item = T>,
{
let start: Vec<T> = start.into_iter().collect();
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
heads_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap()
}
/// Finds `Ok` nodes in the start set that are not reachable from other nodes in
/// the start set.
///
/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and
/// returns the error.
pub fn heads_ok<T, ID, E, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
mut neighbors_fn: impl FnMut(&T) -> NI,
) -> Result<HashSet<T>, E>
where
T: Hash + Eq + Clone,
ID: Hash + Eq,
II: IntoIterator<Item = Result<T, E>>,
NI: IntoIterator<Item = Result<T, E>>,
{
let start: Vec<T> = start.into_iter().try_collect()?;
let mut reachable: HashSet<T> = start.iter().cloned().collect();
for _node in dfs(start, id_fn, |node| {
let neighbors: Vec<T> = neighbors_fn(node).into_iter().collect();
for neighbor in &neighbors {
dfs_ok(start.into_iter().map(Ok), id_fn, |node| {
let neighbors: Vec<Result<T, E>> = neighbors_fn(node).into_iter().collect();
for neighbor in neighbors.iter().filter_map(|x| x.as_ref().ok()) {
reachable.remove(neighbor);
}
neighbors
}) {}
reachable
})
.try_for_each(|r| r.map(|_| ()))?;
Ok(reachable)
}
/// Finds the closest common neighbor among the `set1` and `set2`.
pub fn closest_common_node<T, ID, II1, II2, NI>(
set1: II1,
set2: II2,
@ -355,18 +471,42 @@ where
II1: IntoIterator<Item = T>,
II2: IntoIterator<Item = T>,
NI: IntoIterator<Item = T>,
{
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
closest_common_node_ok(to_ok_iter(set1), to_ok_iter(set2), id_fn, neighbors_fn).unwrap()
}
/// Finds the closest common `Ok` neighbor among the `set1` and `set2`.
///
/// If the traverse reached to an `Err`, this function terminates and returns
/// the error.
pub fn closest_common_node_ok<T, ID, E, II1, II2, NI>(
set1: II1,
set2: II2,
id_fn: impl Fn(&T) -> ID,
mut neighbors_fn: impl FnMut(&T) -> NI,
) -> Result<Option<T>, E>
where
ID: Hash + Eq,
II1: IntoIterator<Item = Result<T, E>>,
II2: IntoIterator<Item = Result<T, E>>,
NI: IntoIterator<Item = Result<T, E>>,
{
let mut visited1 = HashSet::new();
let mut visited2 = HashSet::new();
let mut work1: Vec<T> = set1.into_iter().collect();
let mut work2: Vec<T> = set2.into_iter().collect();
// TODO: might be better to leave an Err so long as the work contains at
// least one Ok node. If a work1 node is included in visited2, it should be
// the closest node even if work2 had previously contained an Err.
let mut work1: Vec<Result<T, E>> = set1.into_iter().collect();
let mut work2: Vec<Result<T, E>> = set2.into_iter().collect();
while !work1.is_empty() || !work2.is_empty() {
let mut new_work1 = vec![];
for node in work1 {
let node = node?;
let id: ID = id_fn(&node);
if visited2.contains(&id) {
return Some(node);
return Ok(Some(node));
}
if visited1.insert(id) {
for neighbor in neighbors_fn(&node) {
@ -378,9 +518,10 @@ where
let mut new_work2 = vec![];
for node in work2 {
let node = node?;
let id: ID = id_fn(&node);
if visited1.contains(&id) {
return Some(node);
return Ok(Some(node));
}
if visited2.insert(id) {
for neighbor in neighbors_fn(&node) {
@ -390,7 +531,11 @@ where
}
work2 = new_work2;
}
None
Ok(None)
}
fn to_ok_iter<T>(iter: impl IntoIterator<Item = T>) -> impl Iterator<Item = Result<T, Infallible>> {
iter.into_iter().map(Ok)
}
#[cfg(test)]
@ -401,6 +546,21 @@ mod tests {
use super::*;
#[test]
fn test_dfs_ok() {
let neighbors = hashmap! {
'A' => vec![],
'B' => vec![Ok('A'), Err('X')],
'C' => vec![Ok('B')],
};
let id_fn = |node: &char| *node;
let neighbors_fn = |node: &char| neighbors[node].clone();
// Self and neighbor nodes shouldn't be lost at the error.
let nodes = dfs_ok([Ok('C')], id_fn, neighbors_fn).collect_vec();
assert_eq!(nodes, [Ok('C'), Ok('B'), Err('X'), Ok('A')]);
}
#[test]
fn test_topo_order_reverse_linear() {
// This graph:
@ -879,6 +1039,26 @@ mod tests {
assert!(result.is_err());
}
#[test]
fn test_topo_order_ok() {
let neighbors = hashmap! {
'A' => vec![Err('Y')],
'B' => vec![Ok('A'), Err('X')],
'C' => vec![Ok('B')],
};
let id_fn = |node: &char| *node;
let neighbors_fn = |node: &char| neighbors[node].clone();
// Terminates at Err('X') no matter if the sorting order is forward or
// reverse. The visiting order matters.
let result = topo_order_forward_ok([Ok('C')], id_fn, neighbors_fn);
assert_eq!(result, Err('X'));
let result = topo_order_reverse_ok([Ok('C')], id_fn, neighbors_fn);
assert_eq!(result, Err('X'));
let result = topo_order_reverse_ord_ok([Ok('C')], id_fn, neighbors_fn);
assert_eq!(result, Err('X'));
}
#[test]
fn test_closest_common_node_tricky() {
// Test this case where A is the shortest distance away, but we still want the
@ -916,6 +1096,25 @@ mod tests {
assert_eq!(common, Some('A'));
}
#[test]
fn test_closest_common_node_ok() {
let neighbors = hashmap! {
'A' => vec![Err('Y')],
'B' => vec![Ok('A')],
'C' => vec![Ok('A')],
'D' => vec![Err('X')],
};
let id_fn = |node: &char| *node;
let neighbors_fn = |node: &char| neighbors[node].clone();
let result = closest_common_node_ok([Ok('B')], [Ok('C')], id_fn, neighbors_fn);
assert_eq!(result, Ok(Some('A')));
let result = closest_common_node_ok([Ok('C')], [Ok('D')], id_fn, neighbors_fn);
assert_eq!(result, Err('X'));
let result = closest_common_node_ok([Ok('C')], [Err('Z')], id_fn, neighbors_fn);
assert_eq!(result, Err('Z'));
}
#[test]
fn test_heads_mixed() {
// Test the uppercase letters are in the start set
@ -952,4 +1151,22 @@ mod tests {
);
assert_eq!(actual, hashset!['D', 'F']);
}
#[test]
fn test_heads_ok() {
let neighbors = hashmap! {
'A' => vec![],
'B' => vec![Ok('A'), Err('X')],
'C' => vec![Ok('B')],
};
let id_fn = |node: &char| *node;
let neighbors_fn = |node: &char| neighbors[node].clone();
let result = heads_ok([Ok('C')], id_fn, neighbors_fn);
assert_eq!(result, Err('X'));
let result = heads_ok([Ok('B')], id_fn, neighbors_fn);
assert_eq!(result, Err('X'));
let result = heads_ok([Ok('A')], id_fn, neighbors_fn);
assert_eq!(result, Ok(hashset! {'A'}));
}
}