From 7937a16002f7fa4abb752f20bce1bf0d810a823e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 26 Jun 2023 10:34:12 -0400 Subject: [PATCH] added brute force search and VectorSearch trait --- Cargo.lock | 39 ++++++++++++++ crates/vector_store/Cargo.toml | 1 + crates/vector_store/src/search.rs | 84 ++++++++++++++++++++++++++++++- 3 files changed, 122 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 309bcfa378..48952d6c25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3837,6 +3837,16 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" +[[package]] +name = "matrixmultiply" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77" +dependencies = [ + "autocfg 1.1.0", + "rawpointer", +] + [[package]] name = "maybe-owned" version = "0.3.4" @@ -4121,6 +4131,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "net2" version = "0.2.38" @@ -4228,6 +4251,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -5245,6 +5277,12 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.7.0" @@ -7920,6 +7958,7 @@ dependencies = [ "language", "lazy_static", "log", + "ndarray", "project", "rusqlite", "serde", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 6446651d5d..8de93c0401 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -26,6 +26,7 @@ serde.workspace = true serde_json.workspace = true async-trait.workspace = true bincode = "1.3.3" +ndarray = "0.15.6" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs index 3dc72edbce..6b508b401b 100644 --- a/crates/vector_store/src/search.rs +++ b/crates/vector_store/src/search.rs @@ -1,5 +1,85 @@ -trait VectorSearch { +use std::cmp::Ordering; + +use async_trait::async_trait; +use ndarray::{Array1, Array2}; + +use crate::db::{DocumentRecord, VectorDatabase}; +use anyhow::Result; + +#[async_trait] +pub trait VectorSearch { // Given a query vector, and a limit to return // Return a vector of id, distance tuples. - fn top_k_search(&self, vec: &Vec) -> Vec<(usize, f32)>; + async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)>; +} + +pub struct BruteForceSearch { + document_ids: Vec, + candidate_array: ndarray::Array2, +} + +impl BruteForceSearch { + pub fn load() -> Result { + let db = VectorDatabase {}; + let documents = db.get_documents()?; + let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect(); + let mut document_ids = vec![]; + for i in documents.keys() { + document_ids.push(i.to_owned()); + } + + let mut candidate_array = Array2::::default((documents.len(), 1536)); + for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() { + for (j, col) in row.iter_mut().enumerate() { + *col = embeddings[i].embedding.0[j]; + } + } + + return Ok(BruteForceSearch { + document_ids, + candidate_array, + }); + } +} + +#[async_trait] +impl VectorSearch for BruteForceSearch { + async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)> { + let target = Array1::from_vec(vec.to_owned()); + + let distances = self.candidate_array.dot(&target); + + let distances = distances.to_vec(); + + // construct a tuple vector from the floats, the tuple being (index,float) + let mut with_indices = distances + .clone() + .into_iter() + .enumerate() + .map(|(index, value)| (index, value)) + .collect::>(); + + // sort the tuple vector by float + with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) { + (true, true) => Ordering::Equal, + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + (false, false) => a.1.partial_cmp(&b.1).unwrap(), + }); + + // extract the sorted indices from the sorted tuple vector + let stored_indices = with_indices + .into_iter() + .map(|(index, value)| index) + .collect::>(); + + let sorted_indices: Vec = stored_indices.into_iter().rev().collect(); + + let mut results = vec![]; + for idx in sorted_indices[0..limit].to_vec() { + results.push((self.document_ids[idx], 1.0 - distances[idx])); + } + + return results; + } }