From 953e928bdb3aa80744a13ff53a197fd798fec0fe Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Mon, 26 Jun 2023 19:01:19 -0400
Subject: [PATCH] WIP: Got the streaming matrix multiplication working, and
 started work on file hashing.

Co-authored-by: maxbrunsfeld
---
 Cargo.lock                                    |   5 +
 crates/vector_store/Cargo.toml                |   5 +
 crates/vector_store/src/db.rs                 |  84 ++++--
 crates/vector_store/src/embedding.rs          |   2 +-
 crates/vector_store/src/search.rs             |  18 +-
 crates/vector_store/src/vector_store.rs       | 243 +++++++++++++-----
 crates/vector_store/src/vector_store_tests.rs | 136 ++++++++++
 7 files changed, 396 insertions(+), 97 deletions(-)
 create mode 100644 crates/vector_store/src/vector_store_tests.rs

diff --git a/Cargo.lock b/Cargo.lock
index 48952d6c25..ff4caaa5a6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7958,13 +7958,18 @@ dependencies = [
  "language",
  "lazy_static",
  "log",
+ "matrixmultiply",
  "ndarray",
  "project",
+ "rand 0.8.5",
  "rusqlite",
  "serde",
  "serde_json",
+ "sha-1 0.10.1",
  "smol",
  "tree-sitter",
+ "tree-sitter-rust",
+ "unindent",
  "util",
  "workspace",
 ]
diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml
index 8de93c0401..dbe0a2e69c 100644
--- a/crates/vector_store/Cargo.toml
+++ b/crates/vector_store/Cargo.toml
@@ -27,9 +27,14 @@ serde_json.workspace = true
 async-trait.workspace = true
 bincode = "1.3.3"
 ndarray = "0.15.6"
+sha-1 = "0.10.1"
+matrixmultiply = "0.3.7"

 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }
 language = { path = "../language", features = ["test-support"] }
 project = { path = "../project", features = ["test-support"] }
 workspace = { path = "../workspace", features = ["test-support"] }
+tree-sitter-rust = "*"
+rand.workspace = true
+unindent.workspace = true
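The two new dependencies map onto the two halves of the commit subject: `sha-1` is for the file hashing that is just getting started (the `files.sha1` column below stores a 40-character digest), and `matrixmultiply` drives the streaming dot product used by search. As a reference point, here is a minimal sketch of producing that hex digest — a hypothetical helper, not part of this patch, assuming the `sha-1` package exposes its library as `sha1` like the other RustCrypto hashes:

```rust
use sha1::{Digest, Sha1};

// Hypothetical helper: hash file content for the `files.sha1` column.
fn content_hash(content: &str) -> String {
    let mut hasher = Sha1::new();
    hasher.update(content.as_bytes());
    // The 20-byte digest, rendered as a 40-character lowercase hex string.
    hasher
        .finalize()
        .iter()
        .map(|byte| format!("{byte:02x}"))
        .collect()
}
```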
diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs
index 4f6da14cab..bcb1090a8d 100644
--- a/crates/vector_store/src/db.rs
+++ b/crates/vector_store/src/db.rs
@@ -1,4 +1,7 @@
-use std::{collections::HashMap, path::PathBuf};
+use std::{
+    collections::HashMap,
+    path::{Path, PathBuf},
+};

 use anyhow::{anyhow, Result};
@@ -13,7 +16,7 @@ use crate::IndexedFile;
 // This is saving to a local database store within the user's dev zed path
 // Where do we want this to sit?
 // Assuming near where the workspace DB sits.
-const VECTOR_DB_URL: &str = "embeddings_db";
+pub const VECTOR_DB_URL: &str = "embeddings_db";

 // Note this is not an appropriate document
 #[derive(Debug)]
 pub struct DocumentRecord {
@@ -28,7 +31,7 @@
 #[derive(Debug)]
 pub struct FileRecord {
     pub id: usize,
-    pub path: String,
+    pub relative_path: String,
     pub sha1: String,
 }
@@ -51,9 +54,9 @@ pub struct VectorDatabase {
 }

 impl VectorDatabase {
-    pub fn new() -> Result<Self> {
+    pub fn new(path: &str) -> Result<Self> {
         let this = Self {
-            db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+            db: rusqlite::Connection::open(path)?,
         };
         this.initialize_database()?;
         Ok(this)
@@ -63,21 +66,23 @@ impl VectorDatabase {
         // This will create the database if it doesn't exist
         // Initialize Vector Databasing Tables

-        // self.db.execute(
-        //     "
-        //     CREATE TABLE IF NOT EXISTS projects (
-        //         id INTEGER PRIMARY KEY AUTOINCREMENT,
-        //         path NVARCHAR(100) NOT NULL
-        //     )
-        //     ",
-        //     [],
-        // )?;
+        self.db.execute(
+            "CREATE TABLE IF NOT EXISTS worktrees (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                absolute_path VARCHAR NOT NULL
+            );
+            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
+            ",
+            [],
+        )?;

         self.db.execute(
             "CREATE TABLE IF NOT EXISTS files (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
-                path NVARCHAR(100) NOT NULL,
-                sha1 NVARCHAR(40) NOT NULL
+                worktree_id INTEGER NOT NULL,
+                relative_path VARCHAR NOT NULL,
+                sha1 NVARCHAR(40) NOT NULL,
+                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
             )",
             [],
         )?;
@@ -87,7 +92,7 @@
             id INTEGER PRIMARY KEY AUTOINCREMENT,
             file_id INTEGER NOT NULL,
             offset INTEGER NOT NULL,
-            name NVARCHAR(100) NOT NULL,
+            name VARCHAR NOT NULL,
             embedding BLOB NOT NULL,
             FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
             )",
             [],
         )?;
@@ -116,7 +121,7 @@
     pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
         // Write to files table, and return generated id.
         let files_insert = self.db.execute(
-            "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
+            "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
             params![indexed_file.path.to_str(), indexed_file.sha1],
         )?;
@@ -141,12 +146,38 @@
         Ok(())
     }

+    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+        self.db.execute(
+            "
+            INSERT into worktrees (absolute_path) VALUES (?1)
+            ON CONFLICT DO NOTHING
+            ",
+            params![worktree_root_path.to_string_lossy()],
+        )?;
+        Ok(self.db.last_insert_rowid())
+    }
+
+    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
+        let mut statement = self
+            .db
+            .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
+        let mut result = Vec::new();
+        for row in
+            statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
+        {
+            result.push(row?);
+        }
+        Ok(result)
+    }
+
     pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
-        let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
+        let mut query_statement = self
+            .db
+            .prepare("SELECT id, relative_path, sha1 FROM files")?;
         let result_iter = query_statement.query_map([], |row| {
             Ok(FileRecord {
                 id: row.get(0)?,
-                path: row.get(1)?,
+                relative_path: row.get(1)?,
                 sha1: row.get(2)?,
             })
         })?;
@@ -160,6 +191,19 @@
         Ok(pages)
     }

+    pub fn for_each_document(
+        &self,
+        worktree_id: i64,
+        mut f: impl FnMut(i64, Embedding),
+    ) -> Result<()> {
+        let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
+        query_statement
+            .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
+            .filter_map(|row| row.ok())
+            .for_each(|row| f(row.0, row.1));
+        Ok(())
+    }
+
     pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
         let mut query_statement = self
             .db
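One caveat in `find_or_create_worktree` above: with `ON CONFLICT DO NOTHING`, SQLite does not update `last_insert_rowid()` when the insert is a no-op, so the function returns a stale id whenever the worktree row already exists. A sketch of a lookup-after-upsert variant (a suggestion, not what the patch does):

```rust
pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
    let absolute_path = worktree_root_path.to_string_lossy();
    self.db.execute(
        "INSERT INTO worktrees (absolute_path) VALUES (?1) ON CONFLICT DO NOTHING",
        params![absolute_path],
    )?;
    // Read the id back instead of trusting last_insert_rowid() after a possible no-op.
    Ok(self.db.query_row(
        "SELECT id FROM worktrees WHERE absolute_path = ?1",
        params![absolute_path],
        |row| row.get(0),
    )?)
}
```

(`get_file_hashes` above takes a `worktree_id` but does not yet filter on it; presumably that lands with the rest of the hashing work.)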
diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs
index f995639e64..86d8494ab4 100644
--- a/crates/vector_store/src/embedding.rs
+++ b/crates/vector_store/src/embedding.rs
@@ -44,7 +44,7 @@ struct OpenAIEmbeddingUsage {
 }

 #[async_trait]
-pub trait EmbeddingProvider: Sync {
+pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
 }

diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs
index ce8bdd1af4..90a8d874da 100644
--- a/crates/vector_store/src/search.rs
+++ b/crates/vector_store/src/search.rs
@@ -1,4 +1,4 @@
-use std::cmp::Ordering;
+use std::{cmp::Ordering, path::PathBuf};

 use async_trait::async_trait;
 use ndarray::{Array1, Array2};
@@ -20,7 +20,6 @@ pub struct BruteForceSearch {
 impl BruteForceSearch {
     pub fn load(db: &VectorDatabase) -> Result<Self> {
-        // let db = VectorDatabase {};
         let documents = db.get_documents()?;
         let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
         let mut document_ids = vec![];
@@ -63,20 +62,5 @@ impl VectorSearch for BruteForceSearch {
         with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
         with_indices.truncate(limit);
         with_indices
-
-        // // extract the sorted indices from the sorted tuple vector
-        // let stored_indices = with_indices
-        //     .into_iter()
-        //     .map(|(index, value)| index)
-        //     .collect::<Vec<usize>>();
-
-        // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
-
-        // let mut results = vec![];
-        // for idx in sorted_indices[0..limit].to_vec() {
-        //     results.push((self.document_ids[idx], 1.0 - similarities[idx]));
-        // }
-
-        // return results;
     }
 }
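The `Send` bound added to `EmbeddingProvider` is what allows the provider to be stored as an `Arc<dyn EmbeddingProvider>` and moved into background tasks in `add_project` and `search` below; without it, the trait object could not cross thread boundaries. A minimal conforming provider, as a sketch (the struct and its fixed dimension are illustrative, not from the patch):

```rust
use anyhow::Result;
use async_trait::async_trait;

struct ConstantEmbeddings;

#[async_trait]
impl EmbeddingProvider for ConstantEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // One fixed-size vector per input span; the 1536 dimension is illustrative.
        Ok(spans.iter().map(|_| vec![0.0; 1536]).collect())
    }
}
```

A unit struct is automatically `Send + Sync`, so `Arc<ConstantEmbeddings>` coerces to `Arc<dyn EmbeddingProvider>` and can be cloned into any spawned task.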
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index 6e6bedc33a..f34316e950 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -3,16 +3,19 @@ mod embedding;
 mod parsing;
 mod search;

+#[cfg(test)]
+mod vector_store_tests;
+
 use anyhow::{anyhow, Result};
-use db::VectorDatabase;
+use db::{VectorDatabase, VECTOR_DB_URL};
 use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
-use gpui::{AppContext, Entity, ModelContext, ModelHandle};
+use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
 use language::LanguageRegistry;
 use parsing::Document;
 use project::{Fs, Project};
 use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
-use std::{path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
 use tree_sitter::{Parser, QueryCursor};
 use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;
@@ -23,7 +26,16 @@ pub fn init(
     language_registry: Arc<LanguageRegistry>,
     cx: &mut AppContext,
 ) {
-    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
+    let vector_store = cx.add_model(|cx| {
+        VectorStore::new(
+            fs,
+            VECTOR_DB_URL.to_string(),
+            Arc::new(OpenAIEmbeddings {
+                client: http_client,
+            }),
+            language_registry,
+        )
+    });

     cx.subscribe_global::<WorkspaceCreated, _>({
         let vector_store = vector_store.clone();
@@ -49,28 +61,36 @@ pub struct IndexedFile {
     documents: Vec<Document>,
 }

-struct SearchResult {
-    path: PathBuf,
-    offset: usize,
-    name: String,
-    distance: f32,
-}
-
+// struct SearchResult {
+//     path: PathBuf,
+//     offset: usize,
+//     name: String,
+//     distance: f32,
+// }
 struct VectorStore {
     fs: Arc<dyn Fs>,
-    http_client: Arc<dyn HttpClient>,
+    database_url: Arc<str>,
+    embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
 }

+pub struct SearchResult {
+    pub name: String,
+    pub offset: usize,
+    pub file_path: PathBuf,
+}
+
 impl VectorStore {
     fn new(
         fs: Arc<dyn Fs>,
-        http_client: Arc<dyn HttpClient>,
+        database_url: String,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
     ) -> Self {
         Self {
             fs,
-            http_client,
+            database_url: database_url.into(),
+            embedding_provider,
             language_registry,
         }
     }
@@ -79,10 +99,12 @@ impl VectorStore {
         cursor: &mut QueryCursor,
         parser: &mut Parser,
         embedding_provider: &dyn EmbeddingProvider,
-        fs: &Arc<dyn Fs>,
         language_registry: &Arc<LanguageRegistry>,
         file_path: PathBuf,
+        content: String,
     ) -> Result<IndexedFile> {
+        dbg!(&file_path, &content);
+
         let language = language_registry
             .language_for_file(&file_path, None)
             .await?;
@@ -97,7 +119,6 @@
             .as_ref()
             .ok_or_else(|| anyhow!("no outline query"))?;

-        let content = fs.load(&file_path).await?;
         parser.set_language(grammar.ts_language).unwrap();
         let tree = parser
             .parse(&content, None)
@@ -142,7 +163,11 @@
         });
     }

-    fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
+    fn add_project(
+        &mut self,
+        project: ModelHandle<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -151,7 +176,8 @@

         let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
-        let client = self.http_client.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        let database_url = self.database_url.clone();

         cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
@@ -163,24 +189,47 @@
                     .collect::<Vec<_>>()
             });

-            let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
+            let db = VectorDatabase::new(&database_url)?;
+            let worktree_root_paths = worktrees
+                .iter()
+                .map(|worktree| worktree.abs_path().clone())
+                .collect::<Vec<_>>();
+            let (db, file_hashes) = cx
+                .background()
+                .spawn(async move {
+                    let mut hashes = Vec::new();
+                    for worktree_root_path in worktree_root_paths {
+                        let worktree_id =
+                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
+                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
+                    }
+                    anyhow::Ok((db, hashes))
+                })
+                .await?;
+
+            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
             let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
             cx.background()
-                .spawn(async move {
-                    for worktree in worktrees {
-                        for file in worktree.files(false, 0) {
-                            paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
+                .spawn({
+                    let fs = fs.clone();
+                    async move {
+                        for worktree in worktrees.into_iter() {
+                            for file in worktree.files(false, 0) {
+                                let absolute_path = worktree.absolutize(&file.path);
+                                dbg!(&absolute_path);
+                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
+                                    dbg!(&content);
+                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
+                                }
+                            }
                         }
                     }
                 })
                 .detach();
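The task spawned above is the producer half of the indexing pipeline: it walks each worktree snapshot, loads file contents, and feeds `paths_tx`; a pool of scoped workers below parses and embeds each file, and a single `db_write_task` drains `indexed_files_rx` into SQLite. Distilled to its shape with plain `smol` and illustrative names (no gpui), the pattern looks like this — note that dropping the sender is what terminates the consumer loop:

```rust
use smol::channel;

fn main() {
    smol::block_on(async {
        let (paths_tx, paths_rx) = channel::unbounded::<String>();

        let producer = smol::spawn(async move {
            for path in ["src/file1.rs", "src/file2.rs"] {
                paths_tx.send(path.to_string()).await.ok();
            }
            // paths_tx is dropped here, so recv() below starts returning Err.
        });

        let consumer = smol::spawn(async move {
            while let Ok(path) = paths_rx.recv().await {
                println!("indexing {path}");
            }
        });

        producer.await;
        consumer.await;
    });
}
```

The same reasoning explains the explicit `drop(indexed_files_tx)` further down: the DB writer's `while let Ok(..) = recv().await` loop only ends once every sender is gone.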
-            cx.background()
-                .spawn({
-                    let client = client.clone();
-                    async move {
+            let db_write_task = cx.background().spawn(
+                async move {
                     // Initialize Database, creates database and tables if not exists
-                    let db = VectorDatabase::new()?;
                     while let Ok(indexed_file) = indexed_files_rx.recv().await {
                         db.insert_file(indexed_file).log_err();
                     }

                     // ALL OF THE BELOW IS FOR TESTING,
                     // This should be removed as we find an appropriate place to evaluate our search.
-                    let embedding_provider = OpenAIEmbeddings{ client };
-                    let queries = vec![
-                        "compute embeddings for all of the symbols in the codebase, and write them to a database",
-                        "compute an outline view of all of the symbols in a buffer",
-                        "scan a directory on the file system and load all of its children into an in-memory snapshot",
-                    ];
-                    let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
+                    // let queries = vec![
+                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
+                    //     "compute an outline view of all of the symbols in a buffer",
+                    //     "scan a directory on the file system and load all of its children into an in-memory snapshot",
+                    // ];
+                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

-                    let t2 = Instant::now();
-                    let documents = db.get_documents().unwrap();
-                    let files = db.get_files().unwrap();
-                    println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+                    // let t2 = Instant::now();
+                    // let documents = db.get_documents().unwrap();
+                    // let files = db.get_files().unwrap();
+                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

-                    let t1 = Instant::now();
-                    let mut bfs = BruteForceSearch::load(&db).unwrap();
-                    println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
-                    for (idx, embed) in embeddings.into_iter().enumerate() {
-                        let t0 = Instant::now();
-                        println!("\nQuery: {:?}", queries[idx]);
-                        let results = bfs.top_k_search(&embed, 5).await;
-                        println!("Search Elapsed: {}", t0.elapsed().as_millis());
-                        for (id, distance) in results {
-                            println!("");
-                            println!(" distance: {:?}", distance);
-                            println!(" document: {:?}", documents[&id].name);
-                            println!(" path: {:?}", files[&documents[&id].file_id].path);
-                        }
+                    // let t1 = Instant::now();
+                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
+                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+                    // for (idx, embed) in embeddings.into_iter().enumerate() {
+                    //     let t0 = Instant::now();
+                    //     println!("\nQuery: {:?}", queries[idx]);
+                    //     let results = bfs.top_k_search(&embed, 5).await;
+                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
+                    //     for (id, distance) in results {
+                    //         println!("");
+                    //         println!(" distance: {:?}", distance);
+                    //         println!(" document: {:?}", documents[&id].name);
+                    //         println!(" path: {:?}", files[&documents[&id].file_id].relative_path);
+                    //     }

-                    }
+                    // }

                     anyhow::Ok(())
-                }}.log_err())
-                .detach();
+                }
+                .log_err(),
+            );

             let provider = DummyEmbeddings {};
             // let provider = OpenAIEmbeddings { client };
@@ -231,14 +280,15 @@
                     scope.spawn(async {
                         let mut parser = Parser::new();
                         let mut cursor = QueryCursor::new();
-                        while let Ok(file_path) = paths_rx.recv().await {
+                        while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+                        {
                             if let Some(indexed_file) = Self::index_file(
                                 &mut cursor,
                                 &mut parser,
                                 &provider,
-                                &fs,
                                 &language_registry,
                                 file_path,
+                                content,
                             )
                             .await
                             .log_err()
@@ -250,11 +300,86 @@
                 }
             })
             .await;
+            drop(indexed_files_tx);
+
+            db_write_task.await;
+            anyhow::Ok(())
+        })
+    }
+
+    pub fn search(
+        &mut self,
+        phrase: String,
+        limit: usize,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let embedding_provider = self.embedding_provider.clone();
+        let database_url = self.database_url.clone();
+        cx.spawn(|this, cx| async move {
+            let database = VectorDatabase::new(database_url.as_ref())?;
+
+            // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
+            //
+            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+
+            database.for_each_document(0, |id, embedding| {
+                dbg!(id, &embedding);
+
+                let similarity = dot(&embedding.0, &embedding.0);
+                let ix = match results.binary_search_by(|(_, s)| {
+                    s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+                }) {
+                    Ok(ix) => ix,
+                    Err(ix) => ix,
+                };
+
+                results.insert(ix, (id, similarity));
+                results.truncate(limit);
+            })?;
+
+            dbg!(&results);
+
+            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+            // let documents = database.get_documents_by_ids(ids)?;
+
+            // let search_provider = cx
+            //     .background()
+            //     .spawn(async move { BruteForceSearch::load(&database) })
+            //     .await?;
+
+            // let results = search_provider.top_k_search(&embedding, limit))
+
+            anyhow::Ok(vec![])
         })
-        .detach();
     }
 }

 impl Entity for VectorStore {
     type Event = ();
 }
+
+fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
+    let len = vec_a.len();
+    assert_eq!(len, vec_b.len());
+
+    let mut result = 0.0;
+    unsafe {
+        matrixmultiply::sgemm(
+            1,
+            len,
+            1,
+            1.0,
+            vec_a.as_ptr(),
+            len as isize,
+            1,
+            vec_b.as_ptr(),
+            1,
+            len as isize,
+            0.0,
+            &mut result as *mut f32,
+            1,
+            1,
+        );
+    }
+    result
+}
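Two notes on `search` above. First, the similarity is still a placeholder: with the query embedding commented out, `dot(&embedding.0, &embedding.0)` scores each document against itself. Second, the binary-search insert plus `truncate(limit)` is a running top-k selection; a standalone sketch of that pattern, here ordered descending so the k highest scores are the ones kept:

```rust
use std::cmp::Ordering;

// Keep the k highest-scoring ids seen so far, best first.
fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, score: f32, k: usize) {
    let ix = results
        .binary_search_by(|(_, s)| {
            // Comparing the new score against each element keeps the vec descending.
            score.partial_cmp(s).unwrap_or(Ordering::Equal)
        })
        .unwrap_or_else(|ix| ix);
    results.insert(ix, (id, score));
    results.truncate(k);
}
```

(`dot` itself is a 1×len by len×1 `sgemm`: with m = n = 1 the matrix product collapses to a single f32, which is what makes the matrixmultiply kernel usable for streaming similarity scores.)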
diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs
new file mode 100644
index 0000000000..f3d01835e9
--- /dev/null
+++ b/crates/vector_store/src/vector_store_tests.rs
@@ -0,0 +1,136 @@
+use std::sync::Arc;
+
+use crate::{dot, embedding::EmbeddingProvider, VectorStore};
+use anyhow::Result;
+use async_trait::async_trait;
+use gpui::{Task, TestAppContext};
+use language::{Language, LanguageConfig, LanguageRegistry};
+use project::{FakeFs, Project};
+use rand::Rng;
+use serde_json::json;
+use unindent::Unindent;
+
+#[gpui::test]
+async fn test_vector_store(cx: &mut TestAppContext) {
+    let fs = FakeFs::new(cx.background());
+    fs.insert_tree(
+        "/the-root",
+        json!({
+            "src": {
+                "file1.rs": "
+                    fn aaa() {
+                        println!(\"aaaa!\");
+                    }
+
+                    fn zzzzzzzzz() {
+                        println!(\"SLEEPING\");
+                    }
+                ".unindent(),
+                "file2.rs": "
+                    fn bbb() {
+                        println!(\"bbbb!\");
+                    }
+                ".unindent(),
+            }
+        }),
+    )
+    .await;
+
+    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
+    let rust_language = Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                path_suffixes: vec!["rs".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::language()),
+        )
+        .with_outline_query(
+            r#"
+            (function_item
+                name: (identifier) @name
+                body: (block)) @item
+            "#,
+        )
+        .unwrap(),
+    );
+    languages.add(rust_language);
+
+    let store = cx.add_model(|_| {
+        VectorStore::new(
+            fs.clone(),
+            "foo".to_string(),
+            Arc::new(FakeEmbeddingProvider),
+            languages,
+        )
+    });
+
+    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
+    store
+        .update(cx, |store, cx| store.add_project(project, cx))
+        .await
+        .unwrap();
+
+    let search_results = store
+        .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
+        .await
+        .unwrap();
+
+    assert_eq!(search_results[0].offset, 0);
+    assert_eq!(search_results[1].name, "aaa");
+}
+
+#[test]
+fn test_dot_product() {
+    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
+    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
+
+    for _ in 0..100 {
+        let mut rng = rand::thread_rng();
+        let a: [f32; 32] = rng.gen();
+        let b: [f32; 32] = rng.gen();
+        assert_eq!(
+            round_to_decimals(dot(&a, &b), 3),
+            round_to_decimals(reference_dot(&a, &b), 3)
+        );
+    }
+
+    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+        let factor = (10.0 as f32).powi(decimal_places);
+        (n * factor).round() / factor
+    }
+
+    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
+        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+    }
+}
+
+struct FakeEmbeddingProvider;
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+        Ok(spans
+            .iter()
+            .map(|span| {
+                let mut result = vec![0.0; 26];
+                for letter in span.chars() {
+                    // Use >= so that 'a' itself (ix == 0) is counted.
+                    if letter as u32 >= 'a' as u32 {
+                        let ix = (letter as u32) - ('a' as u32);
+                        if ix < 26 {
+                            result[ix as usize] += 1.0;
+                        }
+                    }
+                }
+
+                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+                for x in &mut result {
+                    *x /= norm;
+                }
+
+                result
+            })
+            .collect())
+    }
+}
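A closing observation on `FakeEmbeddingProvider`: it L2-normalizes its letter-frequency vectors, which is what makes a plain `dot` meaningful as a similarity score — for unit vectors, the dot product equals cosine similarity. The general, unnormalized form for reference (a sketch, not code from the patch):

```rust
// Cosine similarity for arbitrary, not necessarily normalized, vectors.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}
```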