mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-12 13:24:19 +00:00
464 lines
15 KiB
Rust
464 lines
15 KiB
Rust
use std::{
|
||
borrow::Cow,
|
||
sync::atomic::{self, AtomicBool},
|
||
};
|
||
|
||
use crate::CharBag;
|
||
|
||
/// Per-character score used when a matched character does not sit on a
/// recognised boundary (after `/`, `-`, `_`, space, digit, `.` or a
/// lower-to-upper case change).
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for every additional
/// character skipped before the match.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Floor below which the distance-based per-character score never drops.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
|
||
|
||
pub struct Matcher<'a> {
|
||
query: &'a [char],
|
||
lowercase_query: &'a [char],
|
||
query_char_bag: CharBag,
|
||
smart_case: bool,
|
||
max_results: usize,
|
||
min_score: f64,
|
||
match_positions: Vec<usize>,
|
||
last_positions: Vec<usize>,
|
||
score_matrix: Vec<Option<f64>>,
|
||
best_position_matrix: Vec<usize>,
|
||
}
|
||
|
||
/// A scored match produced by the matcher.
///
/// Implementors must be totally ordered: the matcher uses `Ord` to keep its
/// results vector sorted.
pub trait Match: Ord {
    /// Score assigned to this match.
    fn score(&self) -> f64;

    /// Stores the byte offsets of the matched query characters.
    fn set_positions(&mut self, positions: Vec<usize>);
}
|
||
|
||
pub trait MatchCandidate {
|
||
fn has_chars(&self, bag: CharBag) -> bool;
|
||
fn to_string(&self) -> Cow<'_, str>;
|
||
}
|
||
|
||
impl<'a> Matcher<'a> {
|
||
pub fn new(
|
||
query: &'a [char],
|
||
lowercase_query: &'a [char],
|
||
query_char_bag: CharBag,
|
||
smart_case: bool,
|
||
max_results: usize,
|
||
) -> Self {
|
||
Self {
|
||
query,
|
||
lowercase_query,
|
||
query_char_bag,
|
||
min_score: 0.0,
|
||
last_positions: vec![0; query.len()],
|
||
match_positions: vec![0; query.len()],
|
||
score_matrix: Vec::new(),
|
||
best_position_matrix: Vec::new(),
|
||
smart_case,
|
||
max_results,
|
||
}
|
||
}
|
||
|
||
pub fn match_candidates<C: MatchCandidate, R, F>(
|
||
&mut self,
|
||
prefix: &[char],
|
||
lowercase_prefix: &[char],
|
||
candidates: impl Iterator<Item = C>,
|
||
results: &mut Vec<R>,
|
||
cancel_flag: &AtomicBool,
|
||
build_match: F,
|
||
) where
|
||
R: Match,
|
||
F: Fn(&C, f64) -> R,
|
||
{
|
||
let mut candidate_chars = Vec::new();
|
||
let mut lowercase_candidate_chars = Vec::new();
|
||
|
||
for candidate in candidates {
|
||
if !candidate.has_chars(self.query_char_bag) {
|
||
continue;
|
||
}
|
||
|
||
if cancel_flag.load(atomic::Ordering::Relaxed) {
|
||
break;
|
||
}
|
||
|
||
candidate_chars.clear();
|
||
lowercase_candidate_chars.clear();
|
||
for c in candidate.to_string().chars() {
|
||
candidate_chars.push(c);
|
||
lowercase_candidate_chars.push(c.to_ascii_lowercase());
|
||
}
|
||
|
||
if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
|
||
continue;
|
||
}
|
||
|
||
let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
|
||
self.score_matrix.clear();
|
||
self.score_matrix.resize(matrix_len, None);
|
||
self.best_position_matrix.clear();
|
||
self.best_position_matrix.resize(matrix_len, 0);
|
||
|
||
let score = self.score_match(
|
||
&candidate_chars,
|
||
&lowercase_candidate_chars,
|
||
prefix,
|
||
lowercase_prefix,
|
||
);
|
||
|
||
if score > 0.0 {
|
||
let mut mat = build_match(&candidate, score);
|
||
if let Err(i) = results.binary_search_by(|m| mat.cmp(m)) {
|
||
if results.len() < self.max_results {
|
||
mat.set_positions(self.match_positions.clone());
|
||
results.insert(i, mat);
|
||
} else if i < results.len() {
|
||
results.pop();
|
||
mat.set_positions(self.match_positions.clone());
|
||
results.insert(i, mat);
|
||
}
|
||
if results.len() == self.max_results {
|
||
self.min_score = results.last().unwrap().score();
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
fn find_last_positions(
|
||
&mut self,
|
||
lowercase_prefix: &[char],
|
||
lowercase_candidate: &[char],
|
||
) -> bool {
|
||
let mut lowercase_prefix = lowercase_prefix.iter();
|
||
let mut lowercase_candidate = lowercase_candidate.iter();
|
||
for (i, char) in self.lowercase_query.iter().enumerate().rev() {
|
||
if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
|
||
self.last_positions[i] = j + lowercase_prefix.len();
|
||
} else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
|
||
self.last_positions[i] = j;
|
||
} else {
|
||
return false;
|
||
}
|
||
}
|
||
true
|
||
}
|
||
|
||
fn score_match(
|
||
&mut self,
|
||
path: &[char],
|
||
path_cased: &[char],
|
||
prefix: &[char],
|
||
lowercase_prefix: &[char],
|
||
) -> f64 {
|
||
let score = self.recursive_score_match(
|
||
path,
|
||
path_cased,
|
||
prefix,
|
||
lowercase_prefix,
|
||
0,
|
||
0,
|
||
self.query.len() as f64,
|
||
) * self.query.len() as f64;
|
||
|
||
if score <= 0.0 {
|
||
return 0.0;
|
||
}
|
||
|
||
let path_len = prefix.len() + path.len();
|
||
let mut cur_start = 0;
|
||
let mut byte_ix = 0;
|
||
let mut char_ix = 0;
|
||
for i in 0..self.query.len() {
|
||
let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
|
||
while char_ix < match_char_ix {
|
||
let ch = prefix
|
||
.get(char_ix)
|
||
.or_else(|| path.get(char_ix - prefix.len()))
|
||
.unwrap();
|
||
byte_ix += ch.len_utf8();
|
||
char_ix += 1;
|
||
}
|
||
cur_start = match_char_ix + 1;
|
||
self.match_positions[i] = byte_ix;
|
||
}
|
||
|
||
score
|
||
}
|
||
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn recursive_score_match(
|
||
&mut self,
|
||
path: &[char],
|
||
path_cased: &[char],
|
||
prefix: &[char],
|
||
lowercase_prefix: &[char],
|
||
query_idx: usize,
|
||
path_idx: usize,
|
||
cur_score: f64,
|
||
) -> f64 {
|
||
if query_idx == self.query.len() {
|
||
return 1.0;
|
||
}
|
||
|
||
let path_len = prefix.len() + path.len();
|
||
|
||
if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
|
||
return memoized;
|
||
}
|
||
|
||
let mut score = 0.0;
|
||
let mut best_position = 0;
|
||
|
||
let query_char = self.lowercase_query[query_idx];
|
||
let limit = self.last_positions[query_idx];
|
||
|
||
let mut last_slash = 0;
|
||
for j in path_idx..=limit {
|
||
let path_char = if j < prefix.len() {
|
||
lowercase_prefix[j]
|
||
} else {
|
||
path_cased[j - prefix.len()]
|
||
};
|
||
let is_path_sep = path_char == '/' || path_char == '\\';
|
||
|
||
if query_idx == 0 && is_path_sep {
|
||
last_slash = j;
|
||
}
|
||
|
||
if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
|
||
let curr = if j < prefix.len() {
|
||
prefix[j]
|
||
} else {
|
||
path[j - prefix.len()]
|
||
};
|
||
|
||
let mut char_score = 1.0;
|
||
if j > path_idx {
|
||
let last = if j - 1 < prefix.len() {
|
||
prefix[j - 1]
|
||
} else {
|
||
path[j - 1 - prefix.len()]
|
||
};
|
||
|
||
if last == '/' {
|
||
char_score = 0.9;
|
||
} else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
|
||
|| (last.is_lowercase() && curr.is_uppercase())
|
||
{
|
||
char_score = 0.8;
|
||
} else if last == '.' {
|
||
char_score = 0.7;
|
||
} else if query_idx == 0 {
|
||
char_score = BASE_DISTANCE_PENALTY;
|
||
} else {
|
||
char_score = MIN_DISTANCE_PENALTY.max(
|
||
BASE_DISTANCE_PENALTY
|
||
- (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
|
||
);
|
||
}
|
||
}
|
||
|
||
// Apply a severe penalty if the case doesn't match.
|
||
// This will make the exact matches have higher score than the case-insensitive and the
|
||
// path insensitive matches.
|
||
if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
|
||
char_score *= 0.001;
|
||
}
|
||
|
||
let mut multiplier = char_score;
|
||
|
||
// Scale the score based on how deep within the path we found the match.
|
||
if query_idx == 0 {
|
||
multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
|
||
}
|
||
|
||
let mut next_score = 1.0;
|
||
if self.min_score > 0.0 {
|
||
next_score = cur_score * multiplier;
|
||
// Scores only decrease. If we can't pass the previous best, bail
|
||
if next_score < self.min_score {
|
||
// Ensure that score is non-zero so we use it in the memo table.
|
||
if score == 0.0 {
|
||
score = 1e-18;
|
||
}
|
||
continue;
|
||
}
|
||
}
|
||
|
||
let new_score = self.recursive_score_match(
|
||
path,
|
||
path_cased,
|
||
prefix,
|
||
lowercase_prefix,
|
||
query_idx + 1,
|
||
j + 1,
|
||
next_score,
|
||
) * multiplier;
|
||
|
||
if new_score > score {
|
||
score = new_score;
|
||
best_position = j;
|
||
// Optimization: can't score better than 1.
|
||
if new_score == 1.0 {
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if best_position != 0 {
|
||
self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
|
||
}
|
||
|
||
self.score_matrix[query_idx * path_len + path_idx] = Some(score);
|
||
score
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use crate::{PathMatch, PathMatchCandidate};
|
||
|
||
use super::*;
|
||
use std::{
|
||
path::{Path, PathBuf},
|
||
sync::Arc,
|
||
};
|
||
|
||
/// Checks `find_last_positions` for an unmatchable query, a query split
/// across prefix and candidate, and a query with repeated characters.
#[test]
fn test_get_last_positions() {
    let prefix = ['a', 'b', 'c'];
    let candidate = ['b', 'd', 'e', 'f'];

    // 'c' only occurs in the prefix, with no 'd' before it.
    let unordered: &[char] = &['d', 'c'];
    let mut matcher = Matcher::new(unordered, unordered, unordered.into(), false, 10);
    assert!(!matcher.find_last_positions(&prefix, &candidate));

    // 'c' is found in the prefix, 'd' in the candidate (offset by 3).
    let ordered: &[char] = &['c', 'd'];
    let mut matcher = Matcher::new(ordered, ordered, ordered.into(), false, 10);
    assert!(matcher.find_last_positions(&prefix, &candidate));
    assert_eq!(matcher.last_positions, vec![2, 4]);

    // Repeated characters are assigned right-to-left across both slices.
    let repeated: &[char] = &['z', '/', 'z', 'f'];
    let mut matcher = Matcher::new(repeated, repeated, repeated.into(), false, 10);
    let found =
        matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
    assert!(found);
    assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
}
|
||
|
||
/// Checks result order and matched byte positions for ASCII paths.
#[test]
fn test_match_path_entries() {
    let paths = [
        "",
        "a",
        "ab",
        "abC",
        "abcd",
        "alphabravocharlie",
        "AlphaBravoCharlie",
        "thisisatestdir",
        "/////ThisIsATestDir",
        "/this/is/a/test/dir",
        "/test/tiatd",
    ];

    let abc_matches = match_single_path_query("abc", false, &paths);
    assert_eq!(
        abc_matches,
        vec![
            ("abC", vec![0, 1, 2]),
            ("abcd", vec![0, 1, 2]),
            ("AlphaBravoCharlie", vec![0, 5, 10]),
            ("alphabravocharlie", vec![4, 5, 10]),
        ]
    );

    let separated_matches = match_single_path_query("t/i/a/t/d", false, &paths);
    assert_eq!(
        separated_matches,
        vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])]
    );

    let initials_matches = match_single_path_query("tiatd", false, &paths);
    assert_eq!(
        initials_matches,
        vec![
            ("/test/tiatd", vec![6, 7, 8, 9, 10]),
            ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
            ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
            ("thisisatestdir", vec![0, 2, 6, 7, 11]),
        ]
    );
}
|
||
|
||
/// Checks that reported positions are byte offsets, not character indices,
/// when paths contain multi-byte characters.
#[test]
fn test_match_multibyte_path_entries() {
    let paths = ["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
    // Each keycap emoji occupies seven bytes.
    assert_eq!("1️⃣".len(), 7);

    let bcd_matches = match_single_path_query("bcd", false, &paths);
    assert_eq!(
        bcd_matches,
        vec![
            ("αβγδ/bcde", vec![9, 10, 11]),
            ("aαbβ/cγdδ", vec![3, 7, 10]),
        ]
    );

    let cde_matches = match_single_path_query("cde", false, &paths);
    assert_eq!(
        cde_matches,
        vec![
            ("αβγδ/bcde", vec![10, 11, 12]),
            ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
        ]
    );
}
|
||
|
||
fn match_single_path_query<'a>(
|
||
query: &str,
|
||
smart_case: bool,
|
||
paths: &[&'a str],
|
||
) -> Vec<(&'a str, Vec<usize>)> {
|
||
let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
|
||
let query = query.chars().collect::<Vec<_>>();
|
||
let query_chars = CharBag::from(&lowercase_query[..]);
|
||
|
||
let path_arcs: Vec<Arc<Path>> = paths
|
||
.iter()
|
||
.map(|path| Arc::from(PathBuf::from(path)))
|
||
.collect::<Vec<_>>();
|
||
let mut path_entries = Vec::new();
|
||
for (i, path) in paths.iter().enumerate() {
|
||
let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
|
||
let char_bag = CharBag::from(lowercase_path.as_slice());
|
||
path_entries.push(PathMatchCandidate {
|
||
char_bag,
|
||
path: &path_arcs[i],
|
||
});
|
||
}
|
||
|
||
let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);
|
||
|
||
let cancel_flag = AtomicBool::new(false);
|
||
let mut results = Vec::new();
|
||
|
||
matcher.match_candidates(
|
||
&[],
|
||
&[],
|
||
path_entries.into_iter(),
|
||
&mut results,
|
||
&cancel_flag,
|
||
|candidate, score| PathMatch {
|
||
score,
|
||
worktree_id: 0,
|
||
positions: Vec::new(),
|
||
path: candidate.path.clone(),
|
||
path_prefix: "".into(),
|
||
distance_to_relative_ancestor: usize::MAX,
|
||
},
|
||
);
|
||
|
||
results
|
||
.into_iter()
|
||
.map(|result| {
|
||
(
|
||
paths
|
||
.iter()
|
||
.copied()
|
||
.find(|p| result.path.as_ref() == Path::new(p))
|
||
.unwrap(),
|
||
result.positions,
|
||
)
|
||
})
|
||
.collect()
|
||
}
|
||
}
|