diff --git a/lib/src/matchers.rs b/lib/src/matchers.rs index 893d40759..336f19383 100644 --- a/lib/src/matchers.rs +++ b/lib/src/matchers.rs @@ -269,54 +269,67 @@ impl Matcher for IntersectionMatcher<'_> { /// visited. #[derive(PartialEq, Eq, Debug)] struct Dirs { - dirs: HashMap>, - files: HashMap>, + entries: HashMap, + // is_dir/is_file aren't exclusive, both can be set to true. If entries is not empty, + // is_dir should be set. + is_dir: bool, + is_file: bool, } impl Dirs { fn new() -> Self { Dirs { - dirs: HashMap::new(), - files: HashMap::new(), + entries: HashMap::new(), + is_dir: false, + is_file: false, } } + fn add(&mut self, dir: &RepoPath) -> &mut Dirs { + dir.components().iter().fold(self, |sub, name| { + // Avoid name.clone() if entry already exists. + if !sub.entries.contains_key(name) { + sub.is_dir = true; + sub.entries.insert(name.clone(), Dirs::new()); + } + sub.entries.get_mut(name).unwrap() + }) + } + fn add_dir(&mut self, dir: &RepoPath) { - let mut dir = dir.clone(); - let mut maybe_child = None; - loop { - let was_present = self.dirs.contains_key(&dir); - let children = self.dirs.entry(dir.clone()).or_default(); - if let Some(child) = maybe_child { - children.insert(child); - } - if was_present { - break; - } - match dir.split() { - None => break, - Some((new_dir, new_child)) => { - maybe_child = Some(new_child.clone()); - dir = new_dir; - } - }; - } + self.add(dir).is_dir = true; } fn add_file(&mut self, file: &RepoPath) { - let (dir, basename) = file - .split() - .unwrap_or_else(|| panic!("got empty filename: {file:?}")); - self.add_dir(&dir); - self.files.entry(dir).or_default().insert(basename.clone()); + self.add(file).is_file = true; + } + + fn get(&self, dir: &RepoPath) -> Option<&Dirs> { + dir.components() + .iter() + .try_fold(self, |sub, name| sub.entries.get(name)) } fn get_dirs(&self, dir: &RepoPath) -> HashSet { - self.dirs.get(dir).cloned().unwrap_or_default() + self.get(dir) + .map(|sub| { + sub.entries + .iter() + .filter_map(|(name, sub)| sub.is_dir.then(|| name.clone())) + .collect() + }) + .unwrap_or_default() } fn get_files(&self, dir: &RepoPath) -> HashSet { - self.files.get(dir).cloned().unwrap_or_default() + self.get(dir) + .map(|sub| { + sub.entries + .iter() + .filter_map(|(name, sub)| sub.is_file.then(|| name.clone())) + .collect() + }) + .unwrap_or_default() } } @@ -348,6 +361,15 @@ mod tests { dirs.get_dirs(&RepoPath::root()), hashset! {RepoPathComponent::from("dir")} ); + dirs.add_dir(&RepoPath::from_internal_string("dir/sub")); + assert_eq!( + dirs.get_dirs(&RepoPath::from_internal_string("dir")), + hashset! {RepoPathComponent::from("sub")} + ); + assert_eq!( + dirs.get_files(&RepoPath::from_internal_string("dir")), + hashset! {} + ); } #[test] @@ -359,6 +381,14 @@ mod tests { hashset! {RepoPathComponent::from("dir")} ); assert_eq!(dirs.get_files(&RepoPath::root()), hashset! {}); + assert_eq!( + dirs.get_files(&RepoPath::from_internal_string("dir")), + hashset! {RepoPathComponent::from("file")} + ); + assert_eq!( + dirs.get_dirs(&RepoPath::from_internal_string("dir")), + hashset! {} + ); } #[test]