diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs
index bc5a7fd497..4f6da14cab 100644
--- a/crates/vector_store/src/db.rs
+++ b/crates/vector_store/src/db.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, path::PathBuf};
 
 use anyhow::{anyhow, Result};
 
@@ -46,31 +46,50 @@ impl FromSql for Embedding {
     }
 }
 
-pub struct VectorDatabase {}
+pub struct VectorDatabase {
+    db: rusqlite::Connection,
+}
 
 impl VectorDatabase {
-    pub async fn initialize_database() -> Result<()> {
+    pub fn new() -> Result<Self> {
+        let this = Self {
+            db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+        };
+        this.initialize_database()?;
+        Ok(this)
+    }
+
+    fn initialize_database(&self) -> Result<()> {
         // This will create the database if it doesn't exist
-        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
 
         // Initialize Vector Databasing Tables
-        db.execute(
+        // self.db.execute(
+        //     "
+        //     CREATE TABLE IF NOT EXISTS projects (
+        //         id INTEGER PRIMARY KEY AUTOINCREMENT,
+        //         path NVARCHAR(100) NOT NULL
+        //     )
+        //     ",
+        //     [],
+        // )?;
+
+        self.db.execute(
             "CREATE TABLE IF NOT EXISTS files (
-            id INTEGER PRIMARY KEY AUTOINCREMENT,
-            path NVARCHAR(100) NOT NULL,
-            sha1 NVARCHAR(40) NOT NULL
-            )",
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                path NVARCHAR(100) NOT NULL,
+                sha1 NVARCHAR(40) NOT NULL
+            )",
             [],
         )?;
 
-        db.execute(
+        self.db.execute(
             "CREATE TABLE IF NOT EXISTS documents (
-            id INTEGER PRIMARY KEY AUTOINCREMENT,
-            file_id INTEGER NOT NULL,
-            offset INTEGER NOT NULL,
-            name NVARCHAR(100) NOT NULL,
-            embedding BLOB NOT NULL,
-            FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                file_id INTEGER NOT NULL,
+                offset INTEGER NOT NULL,
+                name NVARCHAR(100) NOT NULL,
+                embedding BLOB NOT NULL,
+                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
             )",
             [],
         )?;
@@ -78,23 +97,37 @@ impl VectorDatabase {
         Ok(())
     }
 
-    pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> {
-        // Write to files table, and return generated id.
-        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
+    // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
+    //     // Check if we have the project, if we do, return the ID
+    //     // If we do not have the project, insert the project and return the ID
 
-        let files_insert = db.execute(
+    //     let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
+
+    //     let projects_query = db.prepare(&format!(
+    //         "SELECT id FROM projects WHERE path = {}",
+    //         project_path.to_str().unwrap() // This is unsafe
+    //     ))?;
+
+    //     let project_id = db.last_insert_rowid();
+
+    //     return Ok(project_id as usize);
+    // }
+
+    pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
+        // Write to files table, and return generated id.
+        let files_insert = self.db.execute(
             "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
             params![indexed_file.path.to_str(), indexed_file.sha1],
         )?;
 
-        let inserted_id = db.last_insert_rowid();
+        let inserted_id = self.db.last_insert_rowid();
 
         // Currently inserting at approximately 3400 documents a second
         // I imagine we can speed this up with a bulk insert of some kind.
         for document in indexed_file.documents {
             let embedding_blob = bincode::serialize(&document.embedding)?;
 
-            db.execute(
+            self.db.execute(
                 "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
                 params![
                     inserted_id,
@@ -109,70 +142,42 @@ impl VectorDatabase {
     }
 
     pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
-        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
-        fn query(db: Connection) -> rusqlite::Result<Vec<FileRecord>> {
-            let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?;
-            let result_iter = query_statement.query_map([], |row| {
-                Ok(FileRecord {
-                    id: row.get(0)?,
-                    path: row.get(1)?,
-                    sha1: row.get(2)?,
-                })
-            })?;
-
-            let mut results = vec![];
-            for result in result_iter {
-                results.push(result?);
-            }
-
-            return Ok(results);
-        }
+        let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
+        let result_iter = query_statement.query_map([], |row| {
+            Ok(FileRecord {
+                id: row.get(0)?,
+                path: row.get(1)?,
+                sha1: row.get(2)?,
+            })
+        })?;
 
         let mut pages: HashMap<usize, FileRecord> = HashMap::new();
-        let result_iter = query(db);
-        if result_iter.is_ok() {
-            for result in result_iter.unwrap() {
-                pages.insert(result.id, result);
-            }
+        for result in result_iter {
+            let result = result?;
+            pages.insert(result.id, result);
         }
 
-        return Ok(pages);
+        Ok(pages)
     }
 
     pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
-        // Should return a HashMap in which the key is the id, and the value is the finished document
-
-        // Get Data from Database
-        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
-        fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
-            let mut query_statement =
-                db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
-            let result_iter = query_statement.query_map([], |row| {
-                Ok(DocumentRecord {
-                    id: row.get(0)?,
-                    file_id: row.get(1)?,
-                    offset: row.get(2)?,
-                    name: row.get(3)?,
-                    embedding: row.get(4)?,
-                })
-            })?;
-
-            let mut results = vec![];
-            for result in result_iter {
-                results.push(result?);
-            }
-
-            return Ok(results);
-        }
+        let mut query_statement = self
+            .db
+            .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
+        let result_iter = query_statement.query_map([], |row| {
+            Ok(DocumentRecord {
+                id: row.get(0)?,
+                file_id: row.get(1)?,
+                offset: row.get(2)?,
+                name: row.get(3)?,
+                embedding: row.get(4)?,
+            })
+        })?;
 
         let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
-        let result_iter = query(db);
-        if result_iter.is_ok() {
-            for result in result_iter.unwrap() {
-                documents.insert(result.id, result);
-            }
+        for result in result_iter {
+            let result = result?;
+            documents.insert(result.id, result);
        }
 
        return Ok(documents);
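A note on the bulk-insert TODO in `insert_file` above: the usual first fix in rusqlite is to wrap the per-document writes in a single transaction and reuse one cached prepared statement, since each bare `execute` pays statement-compilation and per-row commit overhead. A minimal sketch under those assumptions; `insert_document_batch` is a hypothetical helper, not part of this diff:

```rust
use anyhow::Result;
use rusqlite::{params, Connection};

// Hypothetical helper (not in this diff): batch all document rows for one
// file into a single transaction, reusing one prepared statement.
fn insert_document_batch(
    db: &mut Connection,
    file_id: i64,
    rows: &[(i64, String, Vec<u8>)], // (offset, name, serialized embedding)
) -> Result<()> {
    let tx = db.transaction()?;
    {
        // `prepare_cached` compiles the statement once and reuses it.
        let mut stmt = tx.prepare_cached(
            "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
        )?;
        for (offset, name, embedding_blob) in rows {
            stmt.execute(params![file_id, offset, name, embedding_blob])?;
        }
    } // statement borrow ends before commit
    tx.commit()?;
    Ok(())
}
```

Because SQLite then commits once per batch instead of once per row, this usually buys an order of magnitude or more over the ~3400 documents/second noted above.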
diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs
index 903c2451b3..f995639e64 100644
--- a/crates/vector_store/src/embedding.rs
+++ b/crates/vector_store/src/embedding.rs
@@ -94,16 +94,6 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             response.usage.total_tokens
         );
 
-        // do we need to re-order these based on the `index` field?
-        eprintln!(
-            "indices: {:?}",
-            response
-                .data
-                .iter()
-                .map(|embedding| embedding.index)
-                .collect::<Vec<usize>>()
-        );
-
         Ok(response
             .data
             .into_iter()
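On the question deleted above ("do we need to re-order these based on the `index` field?"): the `index` field exists so the caller can restore request order, and sorting by it before discarding it is cheap insurance if the API ever returns items out of order. A sketch, assuming the item shape implied by the `index`/`embedding` fields used in this file:

```rust
// Assumed shape of one item in `response.data` (matching the fields
// referenced in this file).
struct OpenAIEmbedding {
    index: usize,
    embedding: Vec<f32>,
}

// Restore request order by `index`, then strip the indices. If the items
// already arrive in order, the sort is a cheap pass.
fn embeddings_in_request_order(mut data: Vec<OpenAIEmbedding>) -> Vec<Vec<f32>> {
    data.sort_unstable_by_key(|item| item.index);
    data.into_iter().map(|item| item.embedding).collect()
}
```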
diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs
index 6b508b401b..ce8bdd1af4 100644
--- a/crates/vector_store/src/search.rs
+++ b/crates/vector_store/src/search.rs
@@ -19,8 +19,8 @@ pub struct BruteForceSearch {
 }
 
 impl BruteForceSearch {
-    pub fn load() -> Result<Self> {
-        let db = VectorDatabase {};
+    pub fn load(db: &VectorDatabase) -> Result<Self> {
+        // let db = VectorDatabase {};
         let documents = db.get_documents()?;
         let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
         let mut document_ids = vec![];
@@ -47,39 +47,36 @@ impl VectorSearch for BruteForceSearch {
     async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
         let target = Array1::from_vec(vec.to_owned());
 
-        let distances = self.candidate_array.dot(&target);
+        let similarities = self.candidate_array.dot(&target);
 
-        let distances = distances.to_vec();
+        let similarities = similarities.to_vec();
 
         // construct a tuple vector from the floats, the tuple being (index, float)
-        let mut with_indices = distances
-            .clone()
-            .into_iter()
+        let mut with_indices = similarities
+            .iter()
+            .copied()
             .enumerate()
-            .map(|(index, value)| (index, value))
+            .map(|(index, value)| (self.document_ids[index], value))
            .collect::<Vec<(usize, f32)>>();
 
         // sort the tuple vector by float
-        with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) {
-            (true, true) => Ordering::Equal,
-            (true, false) => Ordering::Greater,
-            (false, true) => Ordering::Less,
-            (false, false) => a.1.partial_cmp(&b.1).unwrap(),
-        });
+        with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
+        with_indices.truncate(limit);
+        with_indices
 
-        // extract the sorted indices from the sorted tuple vector
-        let stored_indices = with_indices
-            .into_iter()
-            .map(|(index, value)| index)
-            .collect::<Vec<usize>>();
+        // // extract the sorted indices from the sorted tuple vector
+        // let stored_indices = with_indices
+        //     .into_iter()
+        //     .map(|(index, value)| index)
+        //     .collect::<Vec<usize>>();
 
-        let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
+        // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
 
-        let mut results = vec![];
-        for idx in sorted_indices[0..limit].to_vec() {
-            results.push((self.document_ids[idx], 1.0 - distances[idx]));
-        }
+        // let mut results = vec![];
+        // for idx in sorted_indices[0..limit].to_vec() {
+        //     results.push((self.document_ids[idx], 1.0 - similarities[idx]));
+        // }
 
-        return results;
+        // return results;
     }
 }
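Two smaller points on `top_k_search` above. First, if the embeddings are unit-normalized, the dot product is exactly cosine similarity, so sorting descending puts the most similar documents first. Second, the full `sort_by` is O(n log n) over every document; once the corpus grows, an O(n) partition that only fully sorts the winners is worth having, and `f32::total_cmp` sidesteps the NaN comparator entirely. A sketch of that variant as a hypothetical free function, using the same `(document_id, similarity)` tuples:

```rust
use std::cmp::Ordering;

// Hypothetical variant of the top-k selection: partition the `limit` best
// scores to the front in O(n), then sort only that prefix. `total_cmp`
// gives a deterministic total order even if a NaN sneaks in.
fn top_k(mut scored: Vec<(usize, f32)>, limit: usize) -> Vec<(usize, f32)> {
    let limit = limit.min(scored.len());
    if limit == 0 {
        return Vec::new();
    }
    let descending = |a: &(usize, f32), b: &(usize, f32)| -> Ordering { b.1.total_cmp(&a.1) };
    if limit < scored.len() {
        // Everything before index `limit` now scores >= everything after it.
        scored.select_nth_unstable_by(limit - 1, descending);
        scored.truncate(limit);
    }
    scored.sort_unstable_by(descending);
    scored
}
```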
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index 0b6d2928cc..6e6bedc33a 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -1,5 +1,6 @@
 mod db;
 mod embedding;
+mod parsing;
 mod search;
 
 use anyhow::{anyhow, Result};
@@ -7,11 +8,13 @@ use db::VectorDatabase;
 use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{AppContext, Entity, ModelContext, ModelHandle};
 use language::LanguageRegistry;
+use parsing::Document;
 use project::{Fs, Project};
+use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
 use std::{path::PathBuf, sync::Arc, time::Instant};
 use tree_sitter::{Parser, QueryCursor};
-use util::{http::HttpClient, ResultExt};
+use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;
 
 pub fn init(
@@ -39,13 +42,6 @@
     .detach();
 }
 
-#[derive(Debug)]
-pub struct Document {
-    pub offset: usize,
-    pub name: String,
-    pub embedding: Vec<f32>,
-}
-
 #[derive(Debug)]
 pub struct IndexedFile {
     path: PathBuf,
@@ -180,18 +176,54 @@ impl VectorStore {
             .detach();
 
         cx.background()
-            .spawn(async move {
+            .spawn({
+                let client = client.clone();
+                async move {
                 // Initialize Database, creates database and tables if not exists
-                VectorDatabase::initialize_database().await.log_err();
+                let db = VectorDatabase::new()?;
 
                 while let Ok(indexed_file) = indexed_files_rx.recv().await {
-                    VectorDatabase::insert_file(indexed_file).await.log_err();
+                    db.insert_file(indexed_file).log_err();
+                }
+
+                // ALL OF THE BELOW IS FOR TESTING,
+                // This should be removed once we find an appropriate place to evaluate our search.
+
+                let embedding_provider = OpenAIEmbeddings { client };
+                let queries = vec![
+                    "compute embeddings for all of the symbols in the codebase, and write them to a database",
+                    "compute an outline view of all of the symbols in a buffer",
+                    "scan a directory on the file system and load all of its children into an in-memory snapshot",
+                ];
+                let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
+
+                let t2 = Instant::now();
+                let documents = db.get_documents().unwrap();
+                let files = db.get_files().unwrap();
+                println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+
+                let t1 = Instant::now();
+                let mut bfs = BruteForceSearch::load(&db).unwrap();
+                println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+                for (idx, embed) in embeddings.into_iter().enumerate() {
+                    let t0 = Instant::now();
+                    println!("\nQuery: {:?}", queries[idx]);
+                    let results = bfs.top_k_search(&embed, 5).await;
+                    println!("Search Elapsed: {}", t0.elapsed().as_millis());
+                    for (id, distance) in results {
+                        println!("");
+                        println!("    distance: {:?}", distance);
+                        println!("    document: {:?}", documents[&id].name);
+                        println!("    path: {:?}", files[&documents[&id].file_id].path);
+                    }
                 }
 
                 anyhow::Ok(())
-            })
+            }}.log_err())
             .detach();
 
         let provider = DummyEmbeddings {};
+        // let provider = OpenAIEmbeddings { client };
 
         cx.background()
             .scoped(|scope| {