dag_walk: flatten nested loops in topo_order_reverse()

This commit is contained in:
Yuya Nishihara 2023-06-07 21:30:42 +09:00
parent d9c417dcb8
commit 4987c74d4b

View file

@ -16,6 +16,8 @@ use std::collections::HashSet;
use std::hash::Hash; use std::hash::Hash;
use std::iter; use std::iter;
use itertools::Itertools as _;
pub fn dfs<T, ID, II, NI>( pub fn dfs<T, ID, II, NI>(
start: II, start: II,
id_fn: impl Fn(&T) -> ID, id_fn: impl Fn(&T) -> ID,
@ -53,32 +55,24 @@ where
II: IntoIterator<Item = T>, II: IntoIterator<Item = T>,
NI: IntoIterator<Item = T>, NI: IntoIterator<Item = T>,
{ {
let mut stack = start.into_iter().map(|node| (node, false)).collect_vec();
let mut visiting = HashSet::new(); let mut visiting = HashSet::new();
let mut emitted = HashSet::new(); let mut emitted = HashSet::new();
let mut result = vec![]; let mut result = vec![];
while let Some((node, neighbors_visited)) = stack.pop() {
let mut start_nodes: Vec<T> = start.into_iter().collect(); let id = id_fn(&node);
start_nodes.reverse(); if emitted.contains(&id) {
continue;
for start_node in start_nodes { }
let mut stack = vec![(start_node, false)]; if !neighbors_visited {
while let Some((node, neighbors_visited)) = stack.pop() { assert!(visiting.insert(id.clone()), "graph has cycle");
let id = id_fn(&node); let neighbors = neighbors_fn(&node);
if emitted.contains(&id) { stack.push((node, true));
continue; stack.extend(neighbors.into_iter().map(|neighbor| (neighbor, false)));
} } else {
if !neighbors_visited { visiting.remove(&id);
assert!(visiting.insert(id.clone()), "graph has cycle"); emitted.insert(id);
let neighbors = neighbors_fn(&node); result.push(node);
stack.push((node, true));
for neighbor in neighbors {
stack.push((neighbor, false));
}
} else {
visiting.remove(&id);
emitted.insert(id);
result.push(node);
}
} }
} }
result.reverse(); result.reverse();