mirror of
https://github.com/martinvonz/jj.git
synced 2025-01-07 05:16:33 +00:00
dag_walk: add fallible dfs(), topo_order(), heads(), and closest_common_node()
This unblocks the use of Result<T, E> in op.parents(). There are two ways to encode errors: a. impl IntoIterator<Item = Result<T, E>> b. Result<V, E> where V: FromIterator<Item = T> I think (a) is more natural to algorithms like dfs(), which can process error nodes transparently. Still the caller might have to collect the source iterator to temporary Vec to conform to the neighbors_fn signature. It's not easy for neighbors_fn to return an iterator borrowing the input node. We already have GAT, but doesn't have return-position impl Trait in trait yet.
This commit is contained in:
parent
e5a9a26911
commit
3d5a07e86a
1 changed files with 243 additions and 26 deletions
|
@ -12,14 +12,16 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![allow(missing_docs)]
|
||||
//! General-purpose DAG algorithms.
|
||||
|
||||
use std::collections::{BinaryHeap, HashMap, HashSet};
|
||||
use std::convert::Infallible;
|
||||
use std::hash::Hash;
|
||||
use std::{iter, mem};
|
||||
|
||||
use itertools::Itertools as _;
|
||||
|
||||
/// Traverses nodes from `start` in depth-first order.
|
||||
pub fn dfs<T, ID, II, NI>(
|
||||
start: II,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
|
@ -30,10 +32,31 @@ where
|
|||
II: IntoIterator<Item = T>,
|
||||
NI: IntoIterator<Item = T>,
|
||||
{
|
||||
let mut work: Vec<T> = start.into_iter().collect();
|
||||
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
|
||||
dfs_ok(to_ok_iter(start), id_fn, neighbors_fn).map(Result::unwrap)
|
||||
}
|
||||
|
||||
/// Traverses nodes from `start` in depth-first order.
|
||||
///
|
||||
/// An `Err` is emitted as a node with no neighbors. Caller may decide to
|
||||
/// short-circuit on it.
|
||||
pub fn dfs_ok<T, ID, E, II, NI>(
|
||||
start: II,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
mut neighbors_fn: impl FnMut(&T) -> NI,
|
||||
) -> impl Iterator<Item = Result<T, E>>
|
||||
where
|
||||
ID: Hash + Eq,
|
||||
II: IntoIterator<Item = Result<T, E>>,
|
||||
NI: IntoIterator<Item = Result<T, E>>,
|
||||
{
|
||||
let mut work: Vec<Result<T, E>> = start.into_iter().collect();
|
||||
let mut visited: HashSet<ID> = HashSet::new();
|
||||
iter::from_fn(move || loop {
|
||||
let c = work.pop()?;
|
||||
let c = match work.pop() {
|
||||
Some(Ok(c)) => c,
|
||||
r @ (Some(Err(_)) | None) => return r,
|
||||
};
|
||||
let id = id_fn(&c);
|
||||
if visited.contains(&id) {
|
||||
continue;
|
||||
|
@ -42,10 +65,12 @@ where
|
|||
work.push(p);
|
||||
}
|
||||
visited.insert(id);
|
||||
return Some(c);
|
||||
return Some(Ok(c));
|
||||
})
|
||||
}
|
||||
|
||||
/// Builds a list of nodes reachable from the `start` where neighbors come
|
||||
/// before the node itself.
|
||||
pub fn topo_order_forward<T, ID, II, NI>(
|
||||
start: II,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
|
@ -56,7 +81,26 @@ where
|
|||
II: IntoIterator<Item = T>,
|
||||
NI: IntoIterator<Item = T>,
|
||||
{
|
||||
let mut stack = start.into_iter().map(|node| (node, false)).collect_vec();
|
||||
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
|
||||
topo_order_forward_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap()
|
||||
}
|
||||
|
||||
/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come
|
||||
/// before the node itself.
|
||||
///
|
||||
/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and
|
||||
/// returns the error.
|
||||
pub fn topo_order_forward_ok<T, ID, E, II, NI>(
|
||||
start: II,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
mut neighbors_fn: impl FnMut(&T) -> NI,
|
||||
) -> Result<Vec<T>, E>
|
||||
where
|
||||
ID: Hash + Eq + Clone,
|
||||
II: IntoIterator<Item = Result<T, E>>,
|
||||
NI: IntoIterator<Item = Result<T, E>>,
|
||||
{
|
||||
let mut stack: Vec<(T, bool)> = start.into_iter().map(|r| Ok((r?, false))).try_collect()?;
|
||||
let mut visiting = HashSet::new();
|
||||
let mut emitted = HashSet::new();
|
||||
let mut result = vec![];
|
||||
|
@ -67,32 +111,55 @@ where
|
|||
}
|
||||
if !neighbors_visited {
|
||||
assert!(visiting.insert(id.clone()), "graph has cycle");
|
||||
let neighbors = neighbors_fn(&node);
|
||||
let neighbors_iter = neighbors_fn(&node).into_iter();
|
||||
stack.reserve(neighbors_iter.size_hint().0 + 1);
|
||||
stack.push((node, true));
|
||||
stack.extend(neighbors.into_iter().map(|neighbor| (neighbor, false)));
|
||||
for neighbor in neighbors_iter {
|
||||
stack.push((neighbor?, false));
|
||||
}
|
||||
} else {
|
||||
visiting.remove(&id);
|
||||
emitted.insert(id);
|
||||
result.push(node);
|
||||
}
|
||||
}
|
||||
result
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns neighbors before the node itself.
|
||||
/// Builds a list of nodes reachable from the `start` where neighbors come after
|
||||
/// the node itself.
|
||||
pub fn topo_order_reverse<T, ID, II, NI>(
|
||||
start: II,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
neighbors_fn: impl FnMut(&T) -> NI,
|
||||
mut neighbors_fn: impl FnMut(&T) -> NI,
|
||||
) -> Vec<T>
|
||||
where
|
||||
ID: Hash + Eq + Clone,
|
||||
II: IntoIterator<Item = T>,
|
||||
NI: IntoIterator<Item = T>,
|
||||
{
|
||||
let mut result = topo_order_forward(start, id_fn, neighbors_fn);
|
||||
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
|
||||
topo_order_reverse_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap()
|
||||
}
|
||||
|
||||
/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come
|
||||
/// after the node itself.
|
||||
///
|
||||
/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and
|
||||
/// returns the error.
|
||||
pub fn topo_order_reverse_ok<T, ID, E, II, NI>(
|
||||
start: II,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
neighbors_fn: impl FnMut(&T) -> NI,
|
||||
) -> Result<Vec<T>, E>
|
||||
where
|
||||
ID: Hash + Eq + Clone,
|
||||
II: IntoIterator<Item = Result<T, E>>,
|
||||
NI: IntoIterator<Item = Result<T, E>>,
|
||||
{
|
||||
let mut result = topo_order_forward_ok(start, id_fn, neighbors_fn)?;
|
||||
result.reverse();
|
||||
result
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Like `topo_order_reverse()`, but can iterate linear DAG lazily.
|
||||
|
@ -253,6 +320,29 @@ where
|
|||
ID: Hash + Eq + Clone,
|
||||
II: IntoIterator<Item = T>,
|
||||
NI: IntoIterator<Item = T>,
|
||||
{
|
||||
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
|
||||
topo_order_reverse_ord_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap()
|
||||
}
|
||||
|
||||
/// Builds a list of `Ok` nodes reachable from the `start` where neighbors come
|
||||
/// after the node itself.
|
||||
///
|
||||
/// Unlike `topo_order_reverse_ok()`, nodes are sorted in reverse `T: Ord` order
|
||||
/// so long as they can respect the topological requirement.
|
||||
///
|
||||
/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and
|
||||
/// returns the error.
|
||||
pub fn topo_order_reverse_ord_ok<T, ID, E, II, NI>(
|
||||
start: II,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
mut neighbors_fn: impl FnMut(&T) -> NI,
|
||||
) -> Result<Vec<T>, E>
|
||||
where
|
||||
T: Ord,
|
||||
ID: Hash + Eq + Clone,
|
||||
II: IntoIterator<Item = Result<T, E>>,
|
||||
NI: IntoIterator<Item = Result<T, E>>,
|
||||
{
|
||||
struct InnerNode<T> {
|
||||
node: Option<T>,
|
||||
|
@ -260,7 +350,7 @@ where
|
|||
}
|
||||
|
||||
// DFS to accumulate incoming edges
|
||||
let mut stack: Vec<T> = start.into_iter().collect();
|
||||
let mut stack: Vec<T> = start.into_iter().try_collect()?;
|
||||
let mut head_node_map: HashMap<ID, T> = HashMap::new();
|
||||
let mut inner_node_map: HashMap<ID, InnerNode<T>> = HashMap::new();
|
||||
let mut neighbor_ids_map: HashMap<ID, Vec<ID>> = HashMap::new();
|
||||
|
@ -270,8 +360,12 @@ where
|
|||
continue; // Already visited
|
||||
}
|
||||
|
||||
let neighbors_iter = neighbors_fn(&node).into_iter();
|
||||
let pos = stack.len();
|
||||
stack.extend(neighbors_fn(&node));
|
||||
stack.reserve(neighbors_iter.size_hint().0);
|
||||
for neighbor in neighbors_iter {
|
||||
stack.push(neighbor?);
|
||||
}
|
||||
let neighbor_ids = stack[pos..].iter().map(&id_fn).collect_vec();
|
||||
if let Some(inner) = inner_node_map.get_mut(&node_id) {
|
||||
inner.node = Some(node);
|
||||
|
@ -316,7 +410,7 @@ where
|
|||
}
|
||||
|
||||
assert!(inner_node_map.is_empty(), "graph has cycle");
|
||||
result
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Find nodes in the start set that are not reachable from other nodes in the
|
||||
|
@ -332,18 +426,40 @@ where
|
|||
II: IntoIterator<Item = T>,
|
||||
NI: IntoIterator<Item = T>,
|
||||
{
|
||||
let start: Vec<T> = start.into_iter().collect();
|
||||
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
|
||||
heads_ok(to_ok_iter(start), id_fn, neighbors_fn).unwrap()
|
||||
}
|
||||
|
||||
/// Finds `Ok` nodes in the start set that are not reachable from other nodes in
|
||||
/// the start set.
|
||||
///
|
||||
/// If `start` or `neighbors_fn()` yields an `Err`, this function terminates and
|
||||
/// returns the error.
|
||||
pub fn heads_ok<T, ID, E, II, NI>(
|
||||
start: II,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
mut neighbors_fn: impl FnMut(&T) -> NI,
|
||||
) -> Result<HashSet<T>, E>
|
||||
where
|
||||
T: Hash + Eq + Clone,
|
||||
ID: Hash + Eq,
|
||||
II: IntoIterator<Item = Result<T, E>>,
|
||||
NI: IntoIterator<Item = Result<T, E>>,
|
||||
{
|
||||
let start: Vec<T> = start.into_iter().try_collect()?;
|
||||
let mut reachable: HashSet<T> = start.iter().cloned().collect();
|
||||
for _node in dfs(start, id_fn, |node| {
|
||||
let neighbors: Vec<T> = neighbors_fn(node).into_iter().collect();
|
||||
for neighbor in &neighbors {
|
||||
dfs_ok(start.into_iter().map(Ok), id_fn, |node| {
|
||||
let neighbors: Vec<Result<T, E>> = neighbors_fn(node).into_iter().collect();
|
||||
for neighbor in neighbors.iter().filter_map(|x| x.as_ref().ok()) {
|
||||
reachable.remove(neighbor);
|
||||
}
|
||||
neighbors
|
||||
}) {}
|
||||
reachable
|
||||
})
|
||||
.try_for_each(|r| r.map(|_| ()))?;
|
||||
Ok(reachable)
|
||||
}
|
||||
|
||||
/// Finds the closest common neighbor among the `set1` and `set2`.
|
||||
pub fn closest_common_node<T, ID, II1, II2, NI>(
|
||||
set1: II1,
|
||||
set2: II2,
|
||||
|
@ -355,18 +471,42 @@ where
|
|||
II1: IntoIterator<Item = T>,
|
||||
II2: IntoIterator<Item = T>,
|
||||
NI: IntoIterator<Item = T>,
|
||||
{
|
||||
let neighbors_fn = move |node: &T| to_ok_iter(neighbors_fn(node));
|
||||
closest_common_node_ok(to_ok_iter(set1), to_ok_iter(set2), id_fn, neighbors_fn).unwrap()
|
||||
}
|
||||
|
||||
/// Finds the closest common `Ok` neighbor among the `set1` and `set2`.
|
||||
///
|
||||
/// If the traverse reached to an `Err`, this function terminates and returns
|
||||
/// the error.
|
||||
pub fn closest_common_node_ok<T, ID, E, II1, II2, NI>(
|
||||
set1: II1,
|
||||
set2: II2,
|
||||
id_fn: impl Fn(&T) -> ID,
|
||||
mut neighbors_fn: impl FnMut(&T) -> NI,
|
||||
) -> Result<Option<T>, E>
|
||||
where
|
||||
ID: Hash + Eq,
|
||||
II1: IntoIterator<Item = Result<T, E>>,
|
||||
II2: IntoIterator<Item = Result<T, E>>,
|
||||
NI: IntoIterator<Item = Result<T, E>>,
|
||||
{
|
||||
let mut visited1 = HashSet::new();
|
||||
let mut visited2 = HashSet::new();
|
||||
|
||||
let mut work1: Vec<T> = set1.into_iter().collect();
|
||||
let mut work2: Vec<T> = set2.into_iter().collect();
|
||||
// TODO: might be better to leave an Err so long as the work contains at
|
||||
// least one Ok node. If a work1 node is included in visited2, it should be
|
||||
// the closest node even if work2 had previously contained an Err.
|
||||
let mut work1: Vec<Result<T, E>> = set1.into_iter().collect();
|
||||
let mut work2: Vec<Result<T, E>> = set2.into_iter().collect();
|
||||
while !work1.is_empty() || !work2.is_empty() {
|
||||
let mut new_work1 = vec![];
|
||||
for node in work1 {
|
||||
let node = node?;
|
||||
let id: ID = id_fn(&node);
|
||||
if visited2.contains(&id) {
|
||||
return Some(node);
|
||||
return Ok(Some(node));
|
||||
}
|
||||
if visited1.insert(id) {
|
||||
for neighbor in neighbors_fn(&node) {
|
||||
|
@ -378,9 +518,10 @@ where
|
|||
|
||||
let mut new_work2 = vec![];
|
||||
for node in work2 {
|
||||
let node = node?;
|
||||
let id: ID = id_fn(&node);
|
||||
if visited1.contains(&id) {
|
||||
return Some(node);
|
||||
return Ok(Some(node));
|
||||
}
|
||||
if visited2.insert(id) {
|
||||
for neighbor in neighbors_fn(&node) {
|
||||
|
@ -390,7 +531,11 @@ where
|
|||
}
|
||||
work2 = new_work2;
|
||||
}
|
||||
None
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn to_ok_iter<T>(iter: impl IntoIterator<Item = T>) -> impl Iterator<Item = Result<T, Infallible>> {
|
||||
iter.into_iter().map(Ok)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -401,6 +546,21 @@ mod tests {
|
|||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_dfs_ok() {
|
||||
let neighbors = hashmap! {
|
||||
'A' => vec![],
|
||||
'B' => vec![Ok('A'), Err('X')],
|
||||
'C' => vec![Ok('B')],
|
||||
};
|
||||
let id_fn = |node: &char| *node;
|
||||
let neighbors_fn = |node: &char| neighbors[node].clone();
|
||||
|
||||
// Self and neighbor nodes shouldn't be lost at the error.
|
||||
let nodes = dfs_ok([Ok('C')], id_fn, neighbors_fn).collect_vec();
|
||||
assert_eq!(nodes, [Ok('C'), Ok('B'), Err('X'), Ok('A')]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topo_order_reverse_linear() {
|
||||
// This graph:
|
||||
|
@ -879,6 +1039,26 @@ mod tests {
|
|||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topo_order_ok() {
|
||||
let neighbors = hashmap! {
|
||||
'A' => vec![Err('Y')],
|
||||
'B' => vec![Ok('A'), Err('X')],
|
||||
'C' => vec![Ok('B')],
|
||||
};
|
||||
let id_fn = |node: &char| *node;
|
||||
let neighbors_fn = |node: &char| neighbors[node].clone();
|
||||
|
||||
// Terminates at Err('X') no matter if the sorting order is forward or
|
||||
// reverse. The visiting order matters.
|
||||
let result = topo_order_forward_ok([Ok('C')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Err('X'));
|
||||
let result = topo_order_reverse_ok([Ok('C')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Err('X'));
|
||||
let result = topo_order_reverse_ord_ok([Ok('C')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Err('X'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_closest_common_node_tricky() {
|
||||
// Test this case where A is the shortest distance away, but we still want the
|
||||
|
@ -916,6 +1096,25 @@ mod tests {
|
|||
assert_eq!(common, Some('A'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_closest_common_node_ok() {
|
||||
let neighbors = hashmap! {
|
||||
'A' => vec![Err('Y')],
|
||||
'B' => vec![Ok('A')],
|
||||
'C' => vec![Ok('A')],
|
||||
'D' => vec![Err('X')],
|
||||
};
|
||||
let id_fn = |node: &char| *node;
|
||||
let neighbors_fn = |node: &char| neighbors[node].clone();
|
||||
|
||||
let result = closest_common_node_ok([Ok('B')], [Ok('C')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Ok(Some('A')));
|
||||
let result = closest_common_node_ok([Ok('C')], [Ok('D')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Err('X'));
|
||||
let result = closest_common_node_ok([Ok('C')], [Err('Z')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Err('Z'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_heads_mixed() {
|
||||
// Test the uppercase letters are in the start set
|
||||
|
@ -952,4 +1151,22 @@ mod tests {
|
|||
);
|
||||
assert_eq!(actual, hashset!['D', 'F']);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_heads_ok() {
|
||||
let neighbors = hashmap! {
|
||||
'A' => vec![],
|
||||
'B' => vec![Ok('A'), Err('X')],
|
||||
'C' => vec![Ok('B')],
|
||||
};
|
||||
let id_fn = |node: &char| *node;
|
||||
let neighbors_fn = |node: &char| neighbors[node].clone();
|
||||
|
||||
let result = heads_ok([Ok('C')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Err('X'));
|
||||
let result = heads_ok([Ok('B')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Err('X'));
|
||||
let result = heads_ok([Ok('A')], id_fn, neighbors_fn);
|
||||
assert_eq!(result, Ok(hashset! {'A'}));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue