mirror of
https://github.com/zed-industries/zed.git
synced 2025-02-03 08:54:04 +00:00
added brute force search and VectorSearch trait
This commit is contained in:
parent
65bbb7c57b
commit
7937a16002
3 changed files with 122 additions and 2 deletions
39
Cargo.lock
generated
39
Cargo.lock
generated
|
@ -3837,6 +3837,16 @@ version = "0.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matrixmultiply"
|
||||||
|
version = "0.3.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg 1.1.0",
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "maybe-owned"
|
name = "maybe-owned"
|
||||||
version = "0.3.4"
|
version = "0.3.4"
|
||||||
|
@ -4121,6 +4131,19 @@ dependencies = [
|
||||||
"tempfile",
|
"tempfile",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray"
|
||||||
|
version = "0.15.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
||||||
|
dependencies = [
|
||||||
|
"matrixmultiply",
|
||||||
|
"num-complex",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "net2"
|
name = "net2"
|
||||||
version = "0.2.38"
|
version = "0.2.38"
|
||||||
|
@ -4228,6 +4251,15 @@ dependencies = [
|
||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.4.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num-integer"
|
name = "num-integer"
|
||||||
version = "0.1.45"
|
version = "0.1.45"
|
||||||
|
@ -5245,6 +5277,12 @@ dependencies = [
|
||||||
"rand_core 0.5.1",
|
"rand_core 0.5.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rawpointer"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rayon"
|
name = "rayon"
|
||||||
version = "1.7.0"
|
version = "1.7.0"
|
||||||
|
@ -7920,6 +7958,7 @@ dependencies = [
|
||||||
"language",
|
"language",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"log",
|
"log",
|
||||||
|
"ndarray",
|
||||||
"project",
|
"project",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
"serde",
|
"serde",
|
||||||
|
|
|
@ -26,6 +26,7 @@ serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
bincode = "1.3.3"
|
bincode = "1.3.3"
|
||||||
|
ndarray = "0.15.6"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
gpui = { path = "../gpui", features = ["test-support"] }
|
gpui = { path = "../gpui", features = ["test-support"] }
|
||||||
|
|
|
@ -1,5 +1,85 @@
|
||||||
trait VectorSearch {
|
use std::cmp::Ordering;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use ndarray::{Array1, Array2};
|
||||||
|
|
||||||
|
use crate::db::{DocumentRecord, VectorDatabase};
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait VectorSearch {
|
||||||
// Given a query vector, and a limit to return
|
// Given a query vector, and a limit to return
|
||||||
// Return a vector of id, distance tuples.
|
// Return a vector of id, distance tuples.
|
||||||
fn top_k_search(&self, vec: &Vec<f32>) -> Vec<(usize, f32)>;
|
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct BruteForceSearch {
|
||||||
|
document_ids: Vec<usize>,
|
||||||
|
candidate_array: ndarray::Array2<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BruteForceSearch {
|
||||||
|
pub fn load() -> Result<Self> {
|
||||||
|
let db = VectorDatabase {};
|
||||||
|
let documents = db.get_documents()?;
|
||||||
|
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
|
||||||
|
let mut document_ids = vec![];
|
||||||
|
for i in documents.keys() {
|
||||||
|
document_ids.push(i.to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut candidate_array = Array2::<f32>::default((documents.len(), 1536));
|
||||||
|
for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() {
|
||||||
|
for (j, col) in row.iter_mut().enumerate() {
|
||||||
|
*col = embeddings[i].embedding.0[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(BruteForceSearch {
|
||||||
|
document_ids,
|
||||||
|
candidate_array,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl VectorSearch for BruteForceSearch {
|
||||||
|
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
|
||||||
|
let target = Array1::from_vec(vec.to_owned());
|
||||||
|
|
||||||
|
let distances = self.candidate_array.dot(&target);
|
||||||
|
|
||||||
|
let distances = distances.to_vec();
|
||||||
|
|
||||||
|
// construct a tuple vector from the floats, the tuple being (index,float)
|
||||||
|
let mut with_indices = distances
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(index, value)| (index, value))
|
||||||
|
.collect::<Vec<(usize, f32)>>();
|
||||||
|
|
||||||
|
// sort the tuple vector by float
|
||||||
|
with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) {
|
||||||
|
(true, true) => Ordering::Equal,
|
||||||
|
(true, false) => Ordering::Greater,
|
||||||
|
(false, true) => Ordering::Less,
|
||||||
|
(false, false) => a.1.partial_cmp(&b.1).unwrap(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// extract the sorted indices from the sorted tuple vector
|
||||||
|
let stored_indices = with_indices
|
||||||
|
.into_iter()
|
||||||
|
.map(|(index, value)| index)
|
||||||
|
.collect::<Vec<usize>>();
|
||||||
|
|
||||||
|
let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
|
||||||
|
|
||||||
|
let mut results = vec![];
|
||||||
|
for idx in sorted_indices[0..limit].to_vec() {
|
||||||
|
results.push((self.document_ids[idx], 1.0 - distances[idx]));
|
||||||
|
}
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue