diff --git a/lib/src/dag_walk.rs b/lib/src/dag_walk.rs index ff698f964..a5707fbcb 100644 --- a/lib/src/dag_walk.rs +++ b/lib/src/dag_walk.rs @@ -488,14 +488,28 @@ where { let start: Vec = start.into_iter().try_collect()?; let mut heads: HashSet = start.iter().cloned().collect(); - dfs_ok(start.into_iter().map(Ok), id_fn, |node| { - let neighbors: Vec> = neighbors_fn(node).into_iter().collect(); - for neighbor in neighbors.iter().filter_map(|x| x.as_ref().ok()) { - heads.remove(neighbor); + // Do a BFS until we have only one item left in the frontier. That frontier must + // have originated from one of the heads, and since there can't be cycles, + // it won't be able to eliminate any other heads. + let mut frontier: Vec = heads.iter().cloned().collect(); + let mut visited: HashSet = heads.iter().map(&id_fn).collect(); + let mut root_reached = false; + while frontier.len() > 1 || (!frontier.is_empty() && root_reached) { + frontier = frontier + .iter() + .flat_map(|node| { + let neighbors = neighbors_fn(node).into_iter().collect_vec(); + if neighbors.is_empty() { + root_reached = true; + } + neighbors + }) + .try_collect()?; + for node in &frontier { + heads.remove(node); } - neighbors - }) - .try_for_each(|r| r.map(|_| ()))?; + frontier.retain(|node| visited.insert(id_fn(node))); + } Ok(heads) } @@ -1211,10 +1225,14 @@ mod tests { let neighbors_fn = |node: &char| neighbors[node].clone(); let result = heads_ok([Ok('C')], id_fn, neighbors_fn); - assert_eq!(result, Err('X')); + assert_eq!(result, Ok(hashset! {'C'})); let result = heads_ok([Ok('B')], id_fn, neighbors_fn); - assert_eq!(result, Err('X')); + assert_eq!(result, Ok(hashset! {'B'})); let result = heads_ok([Ok('A')], id_fn, neighbors_fn); assert_eq!(result, Ok(hashset! {'A'})); + let result = heads_ok([Ok('C'), Ok('B')], id_fn, neighbors_fn); + assert_eq!(result, Err('X')); + let result = heads_ok([Ok('C'), Ok('A')], id_fn, neighbors_fn); + assert_eq!(result, Err('X')); } }