diff --git a/lib/src/matchers.rs b/lib/src/matchers.rs index dd6464959..0f5573a42 100644 --- a/lib/src/matchers.rs +++ b/lib/src/matchers.rs @@ -38,6 +38,13 @@ impl Visit { files: VisitFiles::Set(hashset! {}), } } + + pub fn is_nothing(&self) -> bool { + matches!(self, Visit::Specific { + dirs: VisitDirs::Set(dirs), + files: VisitFiles::Set(files) + } if dirs.is_empty() && files.is_empty()) + } } #[derive(PartialEq, Eq, Debug)] @@ -160,6 +167,44 @@ impl Matcher for PrefixMatcher { } } +pub struct DifferenceMatcher<'input> { + /// The minuend + wanted: &'input dyn Matcher, + /// The subtrahend + unwanted: &'input dyn Matcher, +} + +impl<'input> DifferenceMatcher<'input> { + pub fn new(wanted: &'input dyn Matcher, unwanted: &'input dyn Matcher) -> Self { + Self { wanted, unwanted } + } +} + +impl Matcher for DifferenceMatcher<'_> { + fn matches(&self, file: &RepoPath) -> bool { + self.wanted.matches(file) && !self.unwanted.matches(file) + } + + fn visit(&self, dir: &RepoPath) -> Visit { + match self.unwanted.visit(dir) { + Visit::AllRecursively => Visit::nothing(), + unwanted_visit => match self.wanted.visit(dir) { + Visit::AllRecursively => { + if unwanted_visit.is_nothing() { + Visit::AllRecursively + } else { + Visit::Specific { + dirs: VisitDirs::All, + files: VisitFiles::All, + } + } + } + wanted_visit => wanted_visit, + }, + } + } +} + /// Keeps track of which subdirectories and files of each directory need to be /// visited. #[derive(PartialEq, Eq, Debug)] @@ -419,4 +464,94 @@ mod tests { Visit::AllRecursively ); } + + #[test] + fn test_differencematcher_remove_subdir() { + let m1 = PrefixMatcher::new(&[ + RepoPath::from_internal_string("foo"), + RepoPath::from_internal_string("bar"), + ]); + let m2 = PrefixMatcher::new(&[RepoPath::from_internal_string("foo/bar")]); + let m = DifferenceMatcher::new(&m1, &m2); + + assert!(m.matches(&RepoPath::from_internal_string("foo"))); + assert!(!m.matches(&RepoPath::from_internal_string("foo/bar"))); + assert!(!m.matches(&RepoPath::from_internal_string("foo/bar/baz"))); + assert!(m.matches(&RepoPath::from_internal_string("foo/baz"))); + assert!(m.matches(&RepoPath::from_internal_string("bar"))); + + assert_eq!( + m.visit(&RepoPath::root()), + Visit::Specific { + dirs: VisitDirs::Set( + hashset! {RepoPathComponent::from("foo"), RepoPathComponent::from("bar")} + ), + files: VisitFiles::Set( + hashset! {RepoPathComponent::from("foo"), RepoPathComponent::from("bar")} + ), + } + ); + assert_eq!( + m.visit(&RepoPath::from_internal_string("foo")), + Visit::Specific { + dirs: VisitDirs::All, + files: VisitFiles::All, + } + ); + assert_eq!( + m.visit(&RepoPath::from_internal_string("foo/bar")), + Visit::nothing() + ); + assert_eq!( + m.visit(&RepoPath::from_internal_string("foo/baz")), + Visit::AllRecursively + ); + assert_eq!( + m.visit(&RepoPath::from_internal_string("bar")), + Visit::AllRecursively + ); + } + + #[test] + fn test_differencematcher_shared_patterns() { + let m1 = PrefixMatcher::new(&[ + RepoPath::from_internal_string("foo"), + RepoPath::from_internal_string("bar"), + ]); + let m2 = PrefixMatcher::new(&[RepoPath::from_internal_string("foo")]); + let m = DifferenceMatcher::new(&m1, &m2); + + assert!(!m.matches(&RepoPath::from_internal_string("foo"))); + assert!(!m.matches(&RepoPath::from_internal_string("foo/bar"))); + assert!(m.matches(&RepoPath::from_internal_string("bar"))); + assert!(m.matches(&RepoPath::from_internal_string("bar/foo"))); + + assert_eq!( + m.visit(&RepoPath::root()), + Visit::Specific { + dirs: VisitDirs::Set( + hashset! {RepoPathComponent::from("foo"), RepoPathComponent::from("bar")} + ), + files: VisitFiles::Set( + hashset! {RepoPathComponent::from("foo"), RepoPathComponent::from("bar")} + ), + } + ); + assert_eq!( + m.visit(&RepoPath::from_internal_string("foo")), + Visit::nothing() + ); + assert_eq!( + m.visit(&RepoPath::from_internal_string("foo/bar")), + Visit::nothing() + ); + assert_eq!( + m.visit(&RepoPath::from_internal_string("bar")), + Visit::AllRecursively + ); + assert_eq!( + m.visit(&RepoPath::from_internal_string("bar/foo")), + Visit::AllRecursively + ); + } }