dag_walk: flatten nested loops in topo_order_reverse()

This commit is contained in:
Yuya Nishihara 2023-06-07 21:30:42 +09:00
parent d9c417dcb8
commit 4987c74d4b

View file

@ -16,6 +16,8 @@ use std::collections::HashSet;
use std::hash::Hash;
use std::iter;
use itertools::Itertools as _;
pub fn dfs<T, ID, II, NI>(
start: II,
id_fn: impl Fn(&T) -> ID,
@ -53,32 +55,24 @@ where
II: IntoIterator<Item = T>,
NI: IntoIterator<Item = T>,
{
let mut stack = start.into_iter().map(|node| (node, false)).collect_vec();
let mut visiting = HashSet::new();
let mut emitted = HashSet::new();
let mut result = vec![];
let mut start_nodes: Vec<T> = start.into_iter().collect();
start_nodes.reverse();
for start_node in start_nodes {
let mut stack = vec![(start_node, false)];
while let Some((node, neighbors_visited)) = stack.pop() {
let id = id_fn(&node);
if emitted.contains(&id) {
continue;
}
if !neighbors_visited {
assert!(visiting.insert(id.clone()), "graph has cycle");
let neighbors = neighbors_fn(&node);
stack.push((node, true));
for neighbor in neighbors {
stack.push((neighbor, false));
}
} else {
visiting.remove(&id);
emitted.insert(id);
result.push(node);
}
while let Some((node, neighbors_visited)) = stack.pop() {
let id = id_fn(&node);
if emitted.contains(&id) {
continue;
}
if !neighbors_visited {
assert!(visiting.insert(id.clone()), "graph has cycle");
let neighbors = neighbors_fn(&node);
stack.push((node, true));
stack.extend(neighbors.into_iter().map(|neighbor| (neighbor, false)));
} else {
visiting.remove(&id);
emitted.insert(id);
result.push(node);
}
}
result.reverse();