mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-24 02:46:43 +00:00
implement new search strategy
This commit is contained in:
parent
0697d08e54
commit
86ec0b1d9f
5 changed files with 227 additions and 50 deletions
121
Cargo.lock
generated
121
Cargo.lock
generated
|
@ -570,7 +570,7 @@ dependencies = [
|
|||
"libc",
|
||||
"pin-project",
|
||||
"redox_syscall 0.2.16",
|
||||
"xattr",
|
||||
"xattr 0.2.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -903,6 +903,15 @@ dependencies = [
|
|||
"wyz",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blas-src"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb48fbaa7a0cb9d6d96c46bac6cedb16f13a10aebcef1c4e73515aaae8c9909d"
|
||||
dependencies = [
|
||||
"openblas-src",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "block"
|
||||
version = "0.1.6"
|
||||
|
@ -1181,6 +1190,15 @@ version = "0.1.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2698f953def977c68f935bb0dfa959375ad4638570e969e2f1e9f433cbf1af6"
|
||||
|
||||
[[package]]
|
||||
name = "cblas-sys"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.83"
|
||||
|
@ -4580,6 +4598,21 @@ dependencies = [
|
|||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray"
|
||||
version = "0.15.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
||||
dependencies = [
|
||||
"cblas-sys",
|
||||
"libc",
|
||||
"matrixmultiply",
|
||||
"num-complex 0.4.4",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndk"
|
||||
version = "0.7.0"
|
||||
|
@ -4706,7 +4739,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
|
||||
dependencies = [
|
||||
"num-bigint 0.2.6",
|
||||
"num-complex",
|
||||
"num-complex 0.2.4",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational 0.2.4",
|
||||
|
@ -4762,6 +4795,15 @@ dependencies = [
|
|||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-derive"
|
||||
version = "0.3.3"
|
||||
|
@ -4948,6 +4990,32 @@ version = "0.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
|
||||
|
||||
[[package]]
|
||||
name = "openblas-build"
|
||||
version = "0.10.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eba42c395477605f400a8d79ee0b756cfb82abe3eb5618e35fa70d3a36010a7f"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"flate2",
|
||||
"native-tls",
|
||||
"tar",
|
||||
"thiserror",
|
||||
"ureq",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openblas-src"
|
||||
version = "0.10.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38e5d8af0b707ac2fe1574daa88b4157da73b0de3dc7c39fe3e2c0bb64070501"
|
||||
dependencies = [
|
||||
"dirs 3.0.2",
|
||||
"openblas-build",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.57"
|
||||
|
@ -6422,6 +6490,18 @@ dependencies = [
|
|||
"webpki 0.22.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pemfile",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "1.0.3"
|
||||
|
@ -6740,6 +6820,7 @@ dependencies = [
|
|||
"ai",
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"blas-src",
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
|
@ -6751,6 +6832,7 @@ dependencies = [
|
|||
"language",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"ndarray",
|
||||
"node_runtime",
|
||||
"ordered-float",
|
||||
"parking_lot 0.11.2",
|
||||
|
@ -7629,6 +7711,17 @@ version = "1.0.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
|
||||
dependencies = [
|
||||
"filetime",
|
||||
"libc",
|
||||
"xattr 1.0.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "target-lexicon"
|
||||
version = "0.12.11"
|
||||
|
@ -8703,6 +8796,21 @@ version = "0.7.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9"
|
||||
dependencies = [
|
||||
"base64 0.21.4",
|
||||
"flate2",
|
||||
"log",
|
||||
"native-tls",
|
||||
"once_cell",
|
||||
"rustls-native-certs",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.4.1"
|
||||
|
@ -9754,6 +9862,15 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xmlparser"
|
||||
version = "0.13.5"
|
||||
|
|
|
@ -39,6 +39,8 @@ rand.workspace = true
|
|||
schemars.workspace = true
|
||||
globset.workspace = true
|
||||
sha1 = "0.10.5"
|
||||
ndarray = { version = "0.15.0", features = ["blas"] }
|
||||
blas-src = { version = "0.8", features = ["openblas"] }
|
||||
|
||||
[dev-dependencies]
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
extern crate blas_src;
|
||||
|
||||
use crate::{
|
||||
parsing::{Span, SpanDigest},
|
||||
SEMANTIC_INDEX_VERSION,
|
||||
|
@ -7,6 +9,7 @@ use anyhow::{anyhow, Context, Result};
|
|||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::executor;
|
||||
use ndarray::{Array1, Array2};
|
||||
use ordered_float::OrderedFloat;
|
||||
use project::{search::PathMatcher, Fs};
|
||||
use rpc::proto::Timestamp;
|
||||
|
@ -19,10 +22,16 @@ use std::{
|
|||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::SystemTime,
|
||||
time::{Instant, SystemTime},
|
||||
};
|
||||
use util::TryFutureExt;
|
||||
|
||||
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
|
||||
let mut indices = (0..data.len()).collect::<Vec<_>>();
|
||||
indices.sort_by_key(|&i| &data[i]);
|
||||
indices
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FileRecord {
|
||||
pub id: usize,
|
||||
|
@ -409,23 +418,82 @@ impl VectorDatabase {
|
|||
limit: usize,
|
||||
file_ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
|
||||
let query_embedding = query_embedding.clone();
|
||||
let file_ids = file_ids.to_vec();
|
||||
let query = query_embedding.clone().0;
|
||||
let query = Array1::from_vec(query);
|
||||
self.transact(move |db| {
|
||||
let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
|
||||
Self::for_each_span(db, &file_ids, |id, embedding| {
|
||||
let similarity = embedding.similarity(&query_embedding);
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
})?;
|
||||
let mut query_statement = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, embedding
|
||||
FROM
|
||||
spans
|
||||
WHERE
|
||||
file_id IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
anyhow::Ok(results)
|
||||
let deserialized_rows = query_statement
|
||||
.query_map(params![ids_to_sql(&file_ids)], |row| {
|
||||
Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?
|
||||
.filter_map(|row| row.ok())
|
||||
.collect::<Vec<(usize, Embedding)>>();
|
||||
|
||||
let batch_n = 250;
|
||||
let mut batches = Vec::new();
|
||||
let mut batch_ids = Vec::new();
|
||||
let mut batch_embeddings: Vec<f32> = Vec::new();
|
||||
deserialized_rows.iter().for_each(|(id, embedding)| {
|
||||
batch_ids.push(id);
|
||||
batch_embeddings.extend(&embedding.0);
|
||||
if batch_ids.len() == batch_n {
|
||||
let array =
|
||||
Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
|
||||
match array {
|
||||
Ok(array) => {
|
||||
batches.push((batch_ids.clone(), array));
|
||||
}
|
||||
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
|
||||
}
|
||||
|
||||
batch_ids = Vec::new();
|
||||
batch_embeddings = Vec::new();
|
||||
}
|
||||
});
|
||||
|
||||
if batch_ids.len() > 0 {
|
||||
let array =
|
||||
Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
|
||||
match array {
|
||||
Ok(array) => {
|
||||
batches.push((batch_ids.clone(), array));
|
||||
}
|
||||
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
let mut ids: Vec<usize> = Vec::new();
|
||||
let mut results = Vec::new();
|
||||
for (batch_ids, array) in batches {
|
||||
let scores = array
|
||||
.dot(&query.t())
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|score| OrderedFloat(*score))
|
||||
.collect::<Vec<OrderedFloat<f32>>>();
|
||||
results.extend(scores);
|
||||
ids.extend(batch_ids);
|
||||
}
|
||||
|
||||
let sorted_idx = argsort(&results);
|
||||
let mut sorted_results = Vec::new();
|
||||
let last_idx = limit.min(sorted_idx.len());
|
||||
for idx in &sorted_idx[0..last_idx] {
|
||||
sorted_results.push((ids[*idx] as i64, results[*idx]))
|
||||
}
|
||||
|
||||
Ok(sorted_results)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -468,31 +536,6 @@ impl VectorDatabase {
|
|||
})
|
||||
}
|
||||
|
||||
fn for_each_span(
|
||||
db: &rusqlite::Connection,
|
||||
file_ids: &[i64],
|
||||
mut f: impl FnMut(i64, Embedding),
|
||||
) -> Result<()> {
|
||||
let mut query_statement = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, embedding
|
||||
FROM
|
||||
spans
|
||||
WHERE
|
||||
file_id IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
query_statement
|
||||
.query_map(params![ids_to_sql(&file_ids)], |row| {
|
||||
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?
|
||||
.filter_map(|row| row.ok())
|
||||
.for_each(|(id, embedding)| f(id, embedding));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn spans_for_ids(
|
||||
&self,
|
||||
ids: &[i64],
|
||||
|
|
|
@ -705,11 +705,13 @@ impl SemanticIndex {
|
|||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
index.await?;
|
||||
let t0 = Instant::now();
|
||||
let query = embedding_provider
|
||||
.embed_batch(vec![query])
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
log::trace!("Embedding Search Query: {:?}", t0.elapsed().as_millis());
|
||||
|
||||
let search_start = Instant::now();
|
||||
let modified_buffer_results = this.update(&mut cx, |this, cx| {
|
||||
|
@ -787,10 +789,15 @@ impl SemanticIndex {
|
|||
|
||||
let batch_n = cx.background().num_cpus();
|
||||
let ids_len = file_ids.clone().len();
|
||||
let batch_size = if ids_len <= batch_n {
|
||||
ids_len
|
||||
} else {
|
||||
ids_len / batch_n
|
||||
let minimum_batch_size = 50;
|
||||
|
||||
let batch_size = {
|
||||
let size = ids_len / batch_n;
|
||||
if size < minimum_batch_size {
|
||||
minimum_batch_size
|
||||
} else {
|
||||
size
|
||||
}
|
||||
};
|
||||
|
||||
let mut batch_results = Vec::new();
|
||||
|
@ -813,17 +820,26 @@ impl SemanticIndex {
|
|||
let batch_results = futures::future::join_all(batch_results).await;
|
||||
|
||||
let mut results = Vec::new();
|
||||
let mut min_similarity = None;
|
||||
for batch_result in batch_results {
|
||||
if batch_result.is_ok() {
|
||||
for (id, similarity) in batch_result.unwrap() {
|
||||
if min_similarity.map_or_else(|| false, |min_sim| min_sim > similarity) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
|
||||
if ix <= limit {
|
||||
min_similarity = Some(similarity);
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -856,7 +872,6 @@ impl SemanticIndex {
|
|||
})?;
|
||||
|
||||
let buffers = futures::future::join_all(tasks).await;
|
||||
|
||||
Ok(buffers
|
||||
.into_iter()
|
||||
.zip(ranges)
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
#!/bin/bash
|
||||
|
||||
cargo run -p semantic_index --example eval
|
||||
RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release
|
||||
|
|
Loading…
Reference in a new issue