From 80a894b82999d4e562a18800568c8f712a705e6e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 21 Jun 2023 14:53:08 -0400 Subject: [PATCH 01/51] WIP: started work on vector store db, by walking project worktrees.\n\nCo-Authored-By: Max --- Cargo.lock | 15 +++ Cargo.toml | 1 + crates/vector_store/Cargo.toml | 25 +++++ crates/vector_store/README.md | 31 ++++++ crates/vector_store/src/vector_store.rs | 134 ++++++++++++++++++++++++ crates/zed/Cargo.toml | 1 + crates/zed/src/main.rs | 1 + 7 files changed, 208 insertions(+) create mode 100644 crates/vector_store/Cargo.toml create mode 100644 crates/vector_store/README.md create mode 100644 crates/vector_store/src/vector_store.rs diff --git a/Cargo.lock b/Cargo.lock index a4b12223e5..3bf0a568a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7877,6 +7877,20 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "vector_store" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.28", + "gpui", + "language", + "project", + "smol", + "util", + "workspace", +] + [[package]] name = "version_check" version = "0.9.4" @@ -8917,6 +8931,7 @@ dependencies = [ "urlencoding", "util", "uuid 1.3.2", + "vector_store", "vim", "welcome", "workspace", diff --git a/Cargo.toml b/Cargo.toml index fca7355964..b1faf158df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ members = [ "crates/theme_selector", "crates/theme_testbench", "crates/util", + "crates/vector_store", "crates/vim", "crates/workspace", "crates/welcome", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml new file mode 100644 index 0000000000..c33a35bcad --- /dev/null +++ b/crates/vector_store/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "vector_store" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/vector_store.rs" +doctest = false + +[dependencies] +gpui = { path = "../gpui" } +language = { path = "../language" } +project = { path = "../project" } +workspace = { path = "../workspace" } +util = { path = "../util" } +anyhow.workspace = true +futures.workspace = true +smol.workspace = true + +[dev-dependencies] +gpui = { path = "../gpui", features = ["test-support"] } +language = { path = "../language", features = ["test-support"] } +project = { path = "../project", features = ["test-support"] } +workspace = { path = "../workspace", features = ["test-support"] } diff --git a/crates/vector_store/README.md b/crates/vector_store/README.md new file mode 100644 index 0000000000..86e68dc414 --- /dev/null +++ b/crates/vector_store/README.md @@ -0,0 +1,31 @@ + +WIP: Sample SQL Queries +/* + +create table "files" ( +"id" INTEGER PRIMARY KEY, +"path" VARCHAR, +"sha1" VARCHAR, +); + +create table symbols ( +"file_id" INTEGER REFERENCES("files", "id") ON CASCADE DELETE, +"offset" INTEGER, +"embedding" VECTOR, +); + +insert into "files" ("path", "sha1") values ("src/main.rs", "sha1") return id; +insert into symbols ( +"file_id", +"start", +"end", +"embedding" +) values ( +(id,), +(id,), +(id,), +(id,), +) + + +*/ diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs new file mode 100644 index 0000000000..1556df7ebe --- /dev/null +++ b/crates/vector_store/src/vector_store.rs @@ -0,0 +1,134 @@ +use anyhow::{anyhow, Result}; +use gpui::{AppContext, Entity, ModelContext, ModelHandle}; +use language::LanguageRegistry; +use project::{Fs, Project}; +use smol::channel; +use 
std::{path::PathBuf, sync::Arc}; +use util::ResultExt; +use workspace::WorkspaceCreated; + +pub fn init(fs: Arc, language_registry: Arc, cx: &mut AppContext) { + let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry)); + + cx.subscribe_global::({ + let vector_store = vector_store.clone(); + move |event, cx| { + let workspace = &event.0; + if let Some(workspace) = workspace.upgrade(cx) { + let project = workspace.read(cx).project().clone(); + if project.read(cx).is_local() { + vector_store.update(cx, |store, cx| { + store.add_project(project, cx); + }); + } + } + } + }) + .detach(); +} + +struct Document { + offset: usize, + name: String, + embedding: Vec, +} + +struct IndexedFile { + path: PathBuf, + sha1: String, + documents: Vec, +} + +struct SearchResult { + path: PathBuf, + offset: usize, + name: String, + distance: f32, +} + +struct VectorStore { + fs: Arc, + language_registry: Arc, +} + +impl VectorStore { + fn new(fs: Arc, language_registry: Arc) -> Self { + Self { + fs, + language_registry, + } + } + + async fn index_file( + fs: &Arc, + language_registry: &Arc, + file_path: PathBuf, + ) -> Result { + eprintln!("indexing file {file_path:?}"); + Err(anyhow!("not implemented")) + // todo!(); + } + + fn add_project(&mut self, project: ModelHandle, cx: &mut ModelContext) { + let worktree_scans_complete = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete()) + .collect::>(); + + let fs = self.fs.clone(); + let language_registry = self.language_registry.clone(); + + cx.spawn(|this, cx| async move { + futures::future::join_all(worktree_scans_complete).await; + + let worktrees = project.read_with(&cx, |project, cx| { + project + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect::>() + }); + + let (paths_tx, paths_rx) = channel::unbounded::(); + let (indexed_files_tx, indexed_files_rx) = channel::unbounded::(); + cx.background() + .spawn(async move { + for worktree in worktrees { + for file in worktree.files(false, 0) { + paths_tx.try_send(worktree.absolutize(&file.path)).unwrap(); + } + } + }) + .detach(); + cx.background() + .spawn(async move { + while let Ok(indexed_file) = indexed_files_rx.recv().await { + // write document to database + } + }) + .detach(); + cx.background() + .scoped(|scope| { + for _ in 0..cx.background().num_cpus() { + scope.spawn(async { + while let Ok(file_path) = paths_rx.recv().await { + if let Some(indexed_file) = + Self::index_file(&fs, &language_registry, file_path) + .await + .log_err() + { + indexed_files_tx.try_send(indexed_file).unwrap(); + } + } + }); + } + }) + .await; + }) + .detach(); + } +} + +impl Entity for VectorStore { + type Event = (); +} diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index d8e47d1c3e..26e27a9193 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -64,6 +64,7 @@ theme = { path = "../theme" } theme_selector = { path = "../theme_selector" } theme_testbench = { path = "../theme_testbench" } util = { path = "../util" } +vector_store = { path = "../vector_store" } vim = { path = "../vim" } workspace = { path = "../workspace" } welcome = { path = "../welcome" } diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index dcdf5c1ea5..76d02307f6 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -152,6 +152,7 @@ fn main() { project_panel::init(cx); diagnostics::init(cx); search::init(cx); + vector_store::init(fs.clone(), languages.clone(), cx); vim::init(cx); terminal_view::init(cx); 
theme_testbench::init(cx); From d4a4db42aa4d96c2576713bd86260d38e8febc8f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 22 Jun 2023 13:25:33 -0400 Subject: [PATCH 02/51] WIP: started DB creating and naive inserts --- Cargo.lock | 19 +++++ crates/vector_store/Cargo.toml | 4 + crates/vector_store/src/db.rs | 107 ++++++++++++++++++++++++ crates/vector_store/src/vector_store.rs | 38 +++++++-- 4 files changed, 161 insertions(+), 7 deletions(-) create mode 100644 crates/vector_store/src/db.rs diff --git a/Cargo.lock b/Cargo.lock index 3bf0a568a2..beb84e04bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1389,6 +1389,15 @@ dependencies = [ "theme", ] +[[package]] +name = "conv" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ff10625fd0ac447827aa30ea8b861fead473bb60aeb73af6c1c58caf0d1299" +dependencies = [ + "custom_derive", +] + [[package]] name = "copilot" version = "0.1.0" @@ -1766,6 +1775,12 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "custom_derive" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9" + [[package]] name = "cxx" version = "1.0.94" @@ -7882,11 +7897,15 @@ name = "vector_store" version = "0.1.0" dependencies = [ "anyhow", + "async-compat", + "conv", "futures 0.3.28", "gpui", "language", "project", + "rand 0.8.5", "smol", + "sqlx", "util", "workspace", ] diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index c33a35bcad..74ad23740e 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -17,6 +17,10 @@ util = { path = "../util" } anyhow.workspace = true futures.workspace = true smol.workspace = true +sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] } +async-compat = "0.2.1" +conv = "0.3.3" +rand.workspace = true [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs new file mode 100644 index 0000000000..dfa85044d6 --- /dev/null +++ b/crates/vector_store/src/db.rs @@ -0,0 +1,107 @@ +use anyhow::Result; +use async_compat::{Compat, CompatExt}; +use conv::ValueFrom; +use sqlx::{migrate::MigrateDatabase, Pool, Sqlite, SqlitePool}; +use std::time::{Duration, Instant}; + +use crate::IndexedFile; + +// This is saving to a local database store within the users dev zed path +// Where do we want this to sit? +// Assuming near where the workspace DB sits. +const VECTOR_DB_URL: &str = "embeddings_db"; + +pub struct VectorDatabase {} + +impl VectorDatabase { + pub async fn initialize_database() -> Result<()> { + // If database doesnt exist create database + if !Sqlite::database_exists(VECTOR_DB_URL) + .compat() + .await + .unwrap_or(false) + { + Sqlite::create_database(VECTOR_DB_URL).compat().await?; + } + + let db = SqlitePool::connect(VECTOR_DB_URL).compat().await?; + + // Initialize Vector Databasing Tables + // We may be able to skip this assuming the database is never created + // without creating the tables at the same time. 
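+        // (The tables below are created with IF NOT EXISTS, so re-running this
+        // initialization on every startup is harmless either way.)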
+ sqlx::query( + "CREATE TABLE IF NOT EXISTS files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + path NVARCHAR(100) NOT NULL, + sha1 NVARCHAR(40) NOT NULL + )", + ) + .execute(&db) + .compat() + .await?; + + sqlx::query( + "CREATE TABLE IF NOT EXISTS documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL, + offset INTEGER NOT NULL, + name NVARCHAR(100) NOT NULL, + embedding BLOB NOT NULL, + FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE + )", + ) + .execute(&db) + .compat() + .await?; + + Ok(()) + } + + pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> { + // Write to files table, and return generated id. + let db = SqlitePool::connect(VECTOR_DB_URL).compat().await?; + + let files_insert = sqlx::query("INSERT INTO files (path, sha1) VALUES ($1, $2)") + .bind(indexed_file.path.to_str()) + .bind(indexed_file.sha1) + .execute(&db) + .compat() + .await?; + + let inserted_id = files_insert.last_insert_rowid(); + + // I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again + // I imagine there is a better way to serialize to/from blob + fn get_binary_from_values(values: Vec) -> String { + let bits: Vec<_> = values.iter().map(|v| v.to_bits().to_string()).collect(); + bits.join(";") + } + + fn get_values_from_binary(bin: &str) -> Vec { + (0..bin.len() / 32) + .map(|i| { + let start = i * 32; + let end = start + 32; + f32::from_bits(u32::from_str_radix(&bin[start..end], 2).unwrap()) + }) + .collect() + } + + // Currently inserting at approximately 3400 documents a second + // I imagine we can speed this up with a bulk insert of some kind. + for document in indexed_file.documents { + sqlx::query( + "INSERT INTO documents (file_id, offset, name, embedding) VALUES ($1, $2, $3, $4)", + ) + .bind(inserted_id) + .bind(document.offset.to_string()) + .bind(document.name) + .bind(get_binary_from_values(document.embedding)) + .execute(&db) + .compat() + .await?; + } + + Ok(()) + } +} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 1556df7ebe..93f9fbe06d 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,9 +1,12 @@ -use anyhow::{anyhow, Result}; +mod db; +use anyhow::Result; +use db::VectorDatabase; use gpui::{AppContext, Entity, ModelContext, ModelHandle}; use language::LanguageRegistry; use project::{Fs, Project}; +use rand::Rng; use smol::channel; -use std::{path::PathBuf, sync::Arc}; +use std::{path::PathBuf, sync::Arc, time::Instant}; use util::ResultExt; use workspace::WorkspaceCreated; @@ -27,13 +30,15 @@ pub fn init(fs: Arc, language_registry: Arc, cx: &mut .detach(); } +#[derive(Debug, sqlx::FromRow)] struct Document { offset: usize, name: String, embedding: Vec, } -struct IndexedFile { +#[derive(Debug, sqlx::FromRow)] +pub struct IndexedFile { path: PathBuf, sha1: String, documents: Vec, @@ -64,9 +69,24 @@ impl VectorStore { language_registry: &Arc, file_path: PathBuf, ) -> Result { - eprintln!("indexing file {file_path:?}"); - Err(anyhow!("not implemented")) - // todo!(); + // This is creating dummy documents to test the database writes. 
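+        // (The randomized document counts and fixed 768-wide vectors are
+        // stand-ins: they let us measure raw database write throughput without
+        // paying for a real embedding API.)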
+ let mut documents = vec![]; + let mut rng = rand::thread_rng(); + let rand_num_of_documents: u8 = rng.gen_range(0..200); + for _ in 0..rand_num_of_documents { + let doc = Document { + offset: 0, + name: "test symbol".to_string(), + embedding: vec![0.32 as f32; 768], + }; + documents.push(doc); + } + + return Ok(IndexedFile { + path: file_path, + sha1: "asdfasdfasdf".to_string(), + documents, + }); } fn add_project(&mut self, project: ModelHandle, cx: &mut ModelContext) { @@ -100,13 +120,17 @@ impl VectorStore { } }) .detach(); + cx.background() .spawn(async move { + // Initialize Database, creates database and tables if not exists + VectorDatabase::initialize_database().await.log_err(); while let Ok(indexed_file) = indexed_files_rx.recv().await { - // write document to database + VectorDatabase::insert_file(indexed_file).await.log_err(); } }) .detach(); + cx.background() .scoped(|scope| { for _ in 0..cx.background().num_cpus() { From dd309070eb03dd51041d412ecce553ab43450342 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 22 Jun 2023 16:50:07 -0400 Subject: [PATCH 03/51] open ai indexing on open for rust files --- Cargo.lock | 57 ++++++++---- crates/language/src/language.rs | 16 ++-- crates/vector_store/Cargo.toml | 10 +- crates/vector_store/src/db.rs | 4 +- crates/vector_store/src/embedding.rs | 100 ++++++++++++++++++++ crates/vector_store/src/vector_store.rs | 118 +++++++++++++++++++----- crates/zed/src/main.rs | 2 +- 7 files changed, 252 insertions(+), 55 deletions(-) create mode 100644 crates/vector_store/src/embedding.rs diff --git a/Cargo.lock b/Cargo.lock index beb84e04bd..5a93ce77af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1389,15 +1389,6 @@ dependencies = [ "theme", ] -[[package]] -name = "conv" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ff10625fd0ac447827aa30ea8b861fead473bb60aeb73af6c1c58caf0d1299" -dependencies = [ - "custom_derive", -] - [[package]] name = "copilot" version = "0.1.0" @@ -1775,12 +1766,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "custom_derive" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9" - [[package]] name = "cxx" version = "1.0.94" @@ -2219,6 +2204,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -2909,6 +2900,15 @@ dependencies = [ "ahash 0.8.3", ] +[[package]] +name = "hashlink" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" +dependencies = [ + "hashbrown 0.11.2", +] + [[package]] name = "hashlink" version = "0.8.1" @@ -5600,6 +5600,21 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rusqlite" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85127183a999f7db96d1a976a309eebbfb6ea3b0b400ddd8340190129de6eb7a" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink 0.7.0", + "libsqlite3-sys", + "memchr", + "smallvec", +] + [[package]] name = "rust-embed" version = 
"6.6.1" @@ -6531,7 +6546,7 @@ dependencies = [ "futures-executor", "futures-intrusive", "futures-util", - "hashlink", + "hashlink 0.8.1", "hex", "hkdf", "hmac 0.12.1", @@ -7898,14 +7913,20 @@ version = "0.1.0" dependencies = [ "anyhow", "async-compat", - "conv", + "async-trait", "futures 0.3.28", "gpui", + "isahc", "language", + "lazy_static", + "log", "project", - "rand 0.8.5", + "rusqlite", + "serde", + "serde_json", "smol", "sqlx", + "tree-sitter", "util", "workspace", ] diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 5a4d604ce3..4c6f709f38 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -476,12 +476,12 @@ pub struct Language { pub struct Grammar { id: usize, - pub(crate) ts_language: tree_sitter::Language, + pub ts_language: tree_sitter::Language, pub(crate) error_query: Query, pub(crate) highlights_query: Option, pub(crate) brackets_config: Option, pub(crate) indents_config: Option, - pub(crate) outline_config: Option, + pub outline_config: Option, pub(crate) injection_config: Option, pub(crate) override_config: Option, pub(crate) highlight_map: Mutex, @@ -495,12 +495,12 @@ struct IndentConfig { outdent_capture_ix: Option, } -struct OutlineConfig { - query: Query, - item_capture_ix: u32, - name_capture_ix: u32, - context_capture_ix: Option, - extra_context_capture_ix: Option, +pub struct OutlineConfig { + pub query: Query, + pub item_capture_ix: u32, + pub name_capture_ix: u32, + pub context_capture_ix: Option, + pub extra_context_capture_ix: Option, } struct InjectionConfig { diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 74ad23740e..2db672ed25 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -19,8 +19,14 @@ futures.workspace = true smol.workspace = true sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] } async-compat = "0.2.1" -conv = "0.3.3" -rand.workspace = true +rusqlite = "0.27.0" +isahc.workspace = true +log.workspace = true +tree-sitter.workspace = true +lazy_static.workspace = true +serde.workspace = true +serde_json.workspace = true +async-trait.workspace = true [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index dfa85044d6..d335d327b8 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,8 +1,6 @@ use anyhow::Result; use async_compat::{Compat, CompatExt}; -use conv::ValueFrom; -use sqlx::{migrate::MigrateDatabase, Pool, Sqlite, SqlitePool}; -use std::time::{Duration, Instant}; +use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool}; use crate::IndexedFile; diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs new file mode 100644 index 0000000000..f1ae5479ee --- /dev/null +++ b/crates/vector_store/src/embedding.rs @@ -0,0 +1,100 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui::serde_json; +use isahc::prelude::Configurable; +use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use std::env; +use std::sync::Arc; +use util::http::{HttpClient, Request}; + +lazy_static! 
{ + static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); +} + +pub struct OpenAIEmbeddings { + pub client: Arc, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +#[async_trait] +pub trait EmbeddingProvider: Sync { + async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddings { + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + let api_key = OPENAI_API_KEY + .as_ref() + .ok_or_else(|| anyhow!("no api key"))?; + + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans, + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + let mut response = self.client.send(request).await?; + if !response.status().is_success() { + return Err(anyhow!("openai embedding failed {}", response.status())); + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::info!( + "openai embedding completed. tokens: {:?}", + response.usage.total_tokens + ); + + // do we need to re-order these based on the `index` field? + eprintln!( + "indices: {:?}", + response + .data + .iter() + .map(|embedding| embedding.index) + .collect::>() + ); + + Ok(response + .data + .into_iter() + .map(|embedding| embedding.embedding) + .collect()) + } +} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 93f9fbe06d..f4d5baca80 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,17 +1,25 @@ mod db; -use anyhow::Result; +mod embedding; + +use anyhow::{anyhow, Result}; use db::VectorDatabase; +use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use gpui::{AppContext, Entity, ModelContext, ModelHandle}; use language::LanguageRegistry; use project::{Fs, Project}; -use rand::Rng; use smol::channel; use std::{path::PathBuf, sync::Arc, time::Instant}; -use util::ResultExt; +use tree_sitter::{Parser, QueryCursor}; +use util::{http::HttpClient, ResultExt}; use workspace::WorkspaceCreated; -pub fn init(fs: Arc, language_registry: Arc, cx: &mut AppContext) { - let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry)); +pub fn init( + fs: Arc, + http_client: Arc, + language_registry: Arc, + cx: &mut AppContext, +) { + let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry)); cx.subscribe_global::({ let vector_store = vector_store.clone(); @@ -53,38 +61,86 @@ struct SearchResult { struct VectorStore { fs: Arc, + http_client: Arc, language_registry: Arc, } impl VectorStore { - fn new(fs: Arc, language_registry: Arc) -> Self { + fn new( + fs: Arc, + http_client: Arc, + language_registry: Arc, + ) -> Self { Self { fs, + http_client, language_registry, } } async fn index_file( + cursor: &mut QueryCursor, + parser: &mut Parser, + embedding_provider: 
&dyn EmbeddingProvider, fs: &Arc, language_registry: &Arc, file_path: PathBuf, ) -> Result { - // This is creating dummy documents to test the database writes. - let mut documents = vec![]; - let mut rng = rand::thread_rng(); - let rand_num_of_documents: u8 = rng.gen_range(0..200); - for _ in 0..rand_num_of_documents { - let doc = Document { - offset: 0, - name: "test symbol".to_string(), - embedding: vec![0.32 as f32; 768], - }; - documents.push(doc); + let language = language_registry + .language_for_file(&file_path, None) + .await?; + + if language.name().as_ref() != "Rust" { + Err(anyhow!("unsupported language"))?; + } + + let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; + let outline_config = grammar + .outline_config + .as_ref() + .ok_or_else(|| anyhow!("no outline query"))?; + + let content = fs.load(&file_path).await?; + parser.set_language(grammar.ts_language).unwrap(); + let tree = parser + .parse(&content, None) + .ok_or_else(|| anyhow!("parsing failed"))?; + + let mut documents = Vec::new(); + let mut context_spans = Vec::new(); + for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) { + let mut item_range = None; + let mut name_range = None; + for capture in mat.captures { + if capture.index == outline_config.item_capture_ix { + item_range = Some(capture.node.byte_range()); + } else if capture.index == outline_config.name_capture_ix { + name_range = Some(capture.node.byte_range()); + } + } + + if let Some((item_range, name_range)) = item_range.zip(name_range) { + if let Some((item, name)) = + content.get(item_range.clone()).zip(content.get(name_range)) + { + context_spans.push(item); + documents.push(Document { + name: name.to_string(), + offset: item_range.start, + embedding: Vec::new(), + }); + } + } + } + + let embeddings = embedding_provider.embed_batch(context_spans).await?; + for (document, embedding) in documents.iter_mut().zip(embeddings) { + document.embedding = embedding; } return Ok(IndexedFile { path: file_path, - sha1: "asdfasdfasdf".to_string(), + sha1: String::new(), documents, }); } @@ -98,8 +154,9 @@ impl VectorStore { let fs = self.fs.clone(); let language_registry = self.language_registry.clone(); + let client = self.http_client.clone(); - cx.spawn(|this, cx| async move { + cx.spawn(|_, cx| async move { futures::future::join_all(worktree_scans_complete).await; let worktrees = project.read_with(&cx, |project, cx| { @@ -131,15 +188,27 @@ impl VectorStore { }) .detach(); + let provider = OpenAIEmbeddings { client }; + + let t0 = Instant::now(); + cx.background() .scoped(|scope| { for _ in 0..cx.background().num_cpus() { scope.spawn(async { + let mut parser = Parser::new(); + let mut cursor = QueryCursor::new(); while let Ok(file_path) = paths_rx.recv().await { - if let Some(indexed_file) = - Self::index_file(&fs, &language_registry, file_path) - .await - .log_err() + if let Some(indexed_file) = Self::index_file( + &mut cursor, + &mut parser, + &provider, + &fs, + &language_registry, + file_path, + ) + .await + .log_err() { indexed_files_tx.try_send(indexed_file).unwrap(); } @@ -148,6 +217,9 @@ impl VectorStore { } }) .await; + + let duration = t0.elapsed(); + log::info!("indexed project in {duration:?}"); }) .detach(); } diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 76d02307f6..8a59bbde41 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -152,7 +152,7 @@ fn main() { project_panel::init(cx); diagnostics::init(cx); search::init(cx); - vector_store::init(fs.clone(), 
languages.clone(), cx); + vector_store::init(fs.clone(), http.clone(), languages.clone(), cx); vim::init(cx); terminal_view::init(cx); theme_testbench::init(cx); From c071b271be195b0e8af9335469c969e6f1624d6d Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 23 Jun 2023 10:25:12 -0400 Subject: [PATCH 04/51] removed tokio and sqlx dependency, added dummy embeddings provider to save on open ai costs when testing --- Cargo.lock | 2 - crates/vector_store/Cargo.toml | 2 - crates/vector_store/src/db.rs | 74 ++++++++++--------------- crates/vector_store/src/embedding.rs | 12 ++++ crates/vector_store/src/vector_store.rs | 9 +-- 5 files changed, 45 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a93ce77af..3f13c75dda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7912,7 +7912,6 @@ name = "vector_store" version = "0.1.0" dependencies = [ "anyhow", - "async-compat", "async-trait", "futures 0.3.28", "gpui", @@ -7925,7 +7924,6 @@ dependencies = [ "serde", "serde_json", "smol", - "sqlx", "tree-sitter", "util", "workspace", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 2db672ed25..434f341147 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -17,8 +17,6 @@ util = { path = "../util" } anyhow.workspace = true futures.workspace = true smol.workspace = true -sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] } -async-compat = "0.2.1" rusqlite = "0.27.0" isahc.workspace = true log.workspace = true diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index d335d327b8..e2b23f7548 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,6 +1,5 @@ use anyhow::Result; -use async_compat::{Compat, CompatExt}; -use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool}; +use rusqlite::params; use crate::IndexedFile; @@ -13,32 +12,20 @@ pub struct VectorDatabase {} impl VectorDatabase { pub async fn initialize_database() -> Result<()> { - // If database doesnt exist create database - if !Sqlite::database_exists(VECTOR_DB_URL) - .compat() - .await - .unwrap_or(false) - { - Sqlite::create_database(VECTOR_DB_URL).compat().await?; - } - - let db = SqlitePool::connect(VECTOR_DB_URL).compat().await?; + // This will create the database if it doesnt exist + let db = rusqlite::Connection::open(VECTOR_DB_URL)?; // Initialize Vector Databasing Tables - // We may be able to skip this assuming the database is never created - // without creating the tables at the same time. - sqlx::query( + db.execute( "CREATE TABLE IF NOT EXISTS files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - path NVARCHAR(100) NOT NULL, - sha1 NVARCHAR(40) NOT NULL - )", - ) - .execute(&db) - .compat() - .await?; + id INTEGER PRIMARY KEY AUTOINCREMENT, + path NVARCHAR(100) NOT NULL, + sha1 NVARCHAR(40) NOT NULL + )", + [], + )?; - sqlx::query( + db.execute( "CREATE TABLE IF NOT EXISTS documents ( id INTEGER PRIMARY KEY AUTOINCREMENT, file_id INTEGER NOT NULL, @@ -47,26 +34,22 @@ impl VectorDatabase { embedding BLOB NOT NULL, FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE )", - ) - .execute(&db) - .compat() - .await?; + [], + )?; Ok(()) } pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> { // Write to files table, and return generated id. 
- let db = SqlitePool::connect(VECTOR_DB_URL).compat().await?; + let db = rusqlite::Connection::open(VECTOR_DB_URL)?; - let files_insert = sqlx::query("INSERT INTO files (path, sha1) VALUES ($1, $2)") - .bind(indexed_file.path.to_str()) - .bind(indexed_file.sha1) - .execute(&db) - .compat() - .await?; + let files_insert = db.execute( + "INSERT INTO files (path, sha1) VALUES (?1, ?2)", + params![indexed_file.path.to_str(), indexed_file.sha1], + )?; - let inserted_id = files_insert.last_insert_rowid(); + let inserted_id = db.last_insert_rowid(); // I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again // I imagine there is a better way to serialize to/from blob @@ -88,16 +71,15 @@ impl VectorDatabase { // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. for document in indexed_file.documents { - sqlx::query( - "INSERT INTO documents (file_id, offset, name, embedding) VALUES ($1, $2, $3, $4)", - ) - .bind(inserted_id) - .bind(document.offset.to_string()) - .bind(document.name) - .bind(get_binary_from_values(document.embedding)) - .execute(&db) - .compat() - .await?; + db.execute( + "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)", + params![ + inserted_id, + document.offset.to_string(), + document.name, + get_binary_from_values(document.embedding) + ], + )?; } Ok(()) diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index f1ae5479ee..4883917d5a 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -47,6 +47,18 @@ pub trait EmbeddingProvider: Sync { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; } +pub struct DummyEmbeddings {} + +#[async_trait] +impl EmbeddingProvider for DummyEmbeddings { + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + // 1024 is the OpenAI Embeddings size for ada models. + // the model we will likely be starting with. 
+ let dummy_vec = vec![0.32 as f32; 1024]; + return Ok(vec![dummy_vec; spans.len()]); + } +} + #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index f4d5baca80..f424346d56 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -3,7 +3,7 @@ mod embedding; use anyhow::{anyhow, Result}; use db::VectorDatabase; -use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings}; use gpui::{AppContext, Entity, ModelContext, ModelHandle}; use language::LanguageRegistry; use project::{Fs, Project}; @@ -38,14 +38,14 @@ pub fn init( .detach(); } -#[derive(Debug, sqlx::FromRow)] +#[derive(Debug)] struct Document { offset: usize, name: String, embedding: Vec, } -#[derive(Debug, sqlx::FromRow)] +#[derive(Debug)] pub struct IndexedFile { path: PathBuf, sha1: String, @@ -188,7 +188,8 @@ impl VectorStore { }) .detach(); - let provider = OpenAIEmbeddings { client }; + // let provider = OpenAIEmbeddings { client }; + let provider = DummyEmbeddings {}; let t0 = Instant::now(); From 65bbb7c57bad891dbe9303ddc413c674974d5234 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sun, 25 Jun 2023 20:02:56 -0400 Subject: [PATCH 05/51] added proper blob serialization for embeddings and vector search trait --- Cargo.lock | 13 +-- crates/vector_store/Cargo.toml | 3 +- crates/vector_store/src/db.rs | 102 +++++++++++++++++++----- crates/vector_store/src/embedding.rs | 3 +- crates/vector_store/src/search.rs | 5 ++ crates/vector_store/src/vector_store.rs | 17 ++-- 6 files changed, 104 insertions(+), 39 deletions(-) create mode 100644 crates/vector_store/src/search.rs diff --git a/Cargo.lock b/Cargo.lock index 3f13c75dda..309bcfa378 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1768,9 +1768,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.94" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f61f1b6389c3fe1c316bf8a4dccc90a38208354b330925bce1f74a6c4756eb93" +checksum = "e88abab2f5abbe4c56e8f1fb431b784d710b709888f35755a160e62e33fe38e8" dependencies = [ "cc", "cxxbridge-flags", @@ -1795,15 +1795,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.94" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7944172ae7e4068c533afbb984114a56c46e9ccddda550499caa222902c7f7bb" +checksum = "8d3816ed957c008ccd4728485511e3d9aaf7db419aa321e3d2c5a2f3411e36c8" [[package]] name = "cxxbridge-macro" -version = "1.0.94" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5" +checksum = "a26acccf6f445af85ea056362561a24ef56cdc15fcc685f03aec50b9c702cb6d" dependencies = [ "proc-macro2", "quote", @@ -7913,6 +7913,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "bincode", "futures 0.3.28", "gpui", "isahc", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 434f341147..6446651d5d 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -17,7 +17,7 @@ util = { path = "../util" } anyhow.workspace = true futures.workspace = true smol.workspace = true -rusqlite = "0.27.0" +rusqlite = { version = "0.27.0", features=["blob"] } isahc.workspace = true log.workspace = true 
tree-sitter.workspace = true @@ -25,6 +25,7 @@ lazy_static.workspace = true serde.workspace = true serde_json.workspace = true async-trait.workspace = true +bincode = "1.3.3" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index e2b23f7548..54f0292d1f 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,13 +1,44 @@ -use anyhow::Result; -use rusqlite::params; +use std::collections::HashMap; -use crate::IndexedFile; +use anyhow::{anyhow, Result}; + +use rusqlite::{ + params, + types::{FromSql, FromSqlResult, ValueRef}, + Connection, +}; +use util::ResultExt; + +use crate::{Document, IndexedFile}; // This is saving to a local database store within the users dev zed path // Where do we want this to sit? // Assuming near where the workspace DB sits. const VECTOR_DB_URL: &str = "embeddings_db"; +// Note this is not an appropriate document +#[derive(Debug)] +pub struct DocumentRecord { + id: usize, + offset: usize, + name: String, + embedding: Embedding, +} + +#[derive(Debug)] +struct Embedding(Vec); + +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + return Ok(Embedding(embedding.unwrap())); + } +} + pub struct VectorDatabase {} impl VectorDatabase { @@ -51,37 +82,66 @@ impl VectorDatabase { let inserted_id = db.last_insert_rowid(); - // I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again - // I imagine there is a better way to serialize to/from blob - fn get_binary_from_values(values: Vec) -> String { - let bits: Vec<_> = values.iter().map(|v| v.to_bits().to_string()).collect(); - bits.join(";") - } - - fn get_values_from_binary(bin: &str) -> Vec { - (0..bin.len() / 32) - .map(|i| { - let start = i * 32; - let end = start + 32; - f32::from_bits(u32::from_str_radix(&bin[start..end], 2).unwrap()) - }) - .collect() - } - // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. 
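        // (Each embedding is serialized below into a compact bincode byte blob;
        // the FromSql impl above reverses this when rows are read back out.)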
for document in indexed_file.documents { + let embedding_blob = bincode::serialize(&document.embedding)?; + db.execute( "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)", params![ inserted_id, document.offset.to_string(), document.name, - get_binary_from_values(document.embedding) + embedding_blob ], )?; } Ok(()) } + + pub fn get_documents(&self) -> Result> { + // Should return a HashMap in which the key is the id, and the value is the finished document + + // Get Data from Database + let db = rusqlite::Connection::open(VECTOR_DB_URL)?; + + fn query(db: Connection) -> rusqlite::Result> { + let mut query_statement = + db.prepare("SELECT id, offset, name, embedding FROM documents LIMIT 10")?; + let result_iter = query_statement.query_map([], |row| { + Ok(DocumentRecord { + id: row.get(0)?, + offset: row.get(1)?, + name: row.get(2)?, + embedding: row.get(3)?, + }) + })?; + + let mut results = vec![]; + for result in result_iter { + results.push(result?); + } + + return Ok(results); + } + + let mut documents: HashMap = HashMap::new(); + let result_iter = query(db); + if result_iter.is_ok() { + for result in result_iter.unwrap() { + documents.insert( + result.id, + Document { + offset: result.offset, + name: result.name, + embedding: result.embedding.0, + }, + ); + } + } + + return Ok(documents); + } } diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index 4883917d5a..903c2451b3 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -13,6 +13,7 @@ lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); } +#[derive(Clone)] pub struct OpenAIEmbeddings { pub client: Arc, } @@ -54,7 +55,7 @@ impl EmbeddingProvider for DummyEmbeddings { async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. - let dummy_vec = vec![0.32 as f32; 1024]; + let dummy_vec = vec![0.32 as f32; 1536]; return Ok(vec![dummy_vec; spans.len()]); } } diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs new file mode 100644 index 0000000000..3dc72edbce --- /dev/null +++ b/crates/vector_store/src/search.rs @@ -0,0 +1,5 @@ +trait VectorSearch { + // Given a query vector, and a limit to return + // Return a vector of id, distance tuples. 
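+    // A hypothetical result for some query vector: vec![(17, 0.12), (4, 0.38)],
+    // i.e. the ids of the closest documents paired with their distances,
+    // nearest first.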
+ fn top_k_search(&self, vec: &Vec) -> Vec<(usize, f32)>; +} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index f424346d56..0b6d2928cc 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,5 +1,6 @@ mod db; mod embedding; +mod search; use anyhow::{anyhow, Result}; use db::VectorDatabase; @@ -39,10 +40,10 @@ pub fn init( } #[derive(Debug)] -struct Document { - offset: usize, - name: String, - embedding: Vec, +pub struct Document { + pub offset: usize, + pub name: String, + pub embedding: Vec, } #[derive(Debug)] @@ -185,14 +186,13 @@ impl VectorStore { while let Ok(indexed_file) = indexed_files_rx.recv().await { VectorDatabase::insert_file(indexed_file).await.log_err(); } + + anyhow::Ok(()) }) .detach(); - // let provider = OpenAIEmbeddings { client }; let provider = DummyEmbeddings {}; - let t0 = Instant::now(); - cx.background() .scoped(|scope| { for _ in 0..cx.background().num_cpus() { @@ -218,9 +218,6 @@ impl VectorStore { } }) .await; - - let duration = t0.elapsed(); - log::info!("indexed project in {duration:?}"); }) .detach(); } From 7937a16002f7fa4abb752f20bce1bf0d810a823e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 26 Jun 2023 10:34:12 -0400 Subject: [PATCH 06/51] added brute force search and VectorSearch trait --- Cargo.lock | 39 ++++++++++++++ crates/vector_store/Cargo.toml | 1 + crates/vector_store/src/search.rs | 84 ++++++++++++++++++++++++++++++- 3 files changed, 122 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 309bcfa378..48952d6c25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3837,6 +3837,16 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" +[[package]] +name = "matrixmultiply" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77" +dependencies = [ + "autocfg 1.1.0", + "rawpointer", +] + [[package]] name = "maybe-owned" version = "0.3.4" @@ -4121,6 +4131,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "net2" version = "0.2.38" @@ -4228,6 +4251,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -5245,6 +5277,12 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.7.0" @@ -7920,6 +7958,7 @@ dependencies = [ "language", "lazy_static", "log", + "ndarray", "project", "rusqlite", "serde", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 6446651d5d..8de93c0401 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -26,6 +26,7 @@ serde.workspace = true 
serde_json.workspace = true async-trait.workspace = true bincode = "1.3.3" +ndarray = "0.15.6" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs index 3dc72edbce..6b508b401b 100644 --- a/crates/vector_store/src/search.rs +++ b/crates/vector_store/src/search.rs @@ -1,5 +1,85 @@ -trait VectorSearch { +use std::cmp::Ordering; + +use async_trait::async_trait; +use ndarray::{Array1, Array2}; + +use crate::db::{DocumentRecord, VectorDatabase}; +use anyhow::Result; + +#[async_trait] +pub trait VectorSearch { // Given a query vector, and a limit to return // Return a vector of id, distance tuples. - fn top_k_search(&self, vec: &Vec) -> Vec<(usize, f32)>; + async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)>; +} + +pub struct BruteForceSearch { + document_ids: Vec, + candidate_array: ndarray::Array2, +} + +impl BruteForceSearch { + pub fn load() -> Result { + let db = VectorDatabase {}; + let documents = db.get_documents()?; + let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect(); + let mut document_ids = vec![]; + for i in documents.keys() { + document_ids.push(i.to_owned()); + } + + let mut candidate_array = Array2::::default((documents.len(), 1536)); + for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() { + for (j, col) in row.iter_mut().enumerate() { + *col = embeddings[i].embedding.0[j]; + } + } + + return Ok(BruteForceSearch { + document_ids, + candidate_array, + }); + } +} + +#[async_trait] +impl VectorSearch for BruteForceSearch { + async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)> { + let target = Array1::from_vec(vec.to_owned()); + + let distances = self.candidate_array.dot(&target); + + let distances = distances.to_vec(); + + // construct a tuple vector from the floats, the tuple being (index,float) + let mut with_indices = distances + .clone() + .into_iter() + .enumerate() + .map(|(index, value)| (index, value)) + .collect::>(); + + // sort the tuple vector by float + with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) { + (true, true) => Ordering::Equal, + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + (false, false) => a.1.partial_cmp(&b.1).unwrap(), + }); + + // extract the sorted indices from the sorted tuple vector + let stored_indices = with_indices + .into_iter() + .map(|(index, value)| index) + .collect::>(); + + let sorted_indices: Vec = stored_indices.into_iter().rev().collect(); + + let mut results = vec![]; + for idx in sorted_indices[0..limit].to_vec() { + results.push((self.document_ids[idx], 1.0 - distances[idx])); + } + + return results; + } } From 0f232e0ce2c7e50ef91b0daf9b8618c81f0ec33d Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 26 Jun 2023 10:35:56 -0400 Subject: [PATCH 07/51] added file metadata retrieval from db --- crates/vector_store/src/db.rs | 87 ++++++++++++++++++++++++----------- 1 file changed, 60 insertions(+), 27 deletions(-) diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index 54f0292d1f..bc5a7fd497 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -7,9 +7,8 @@ use rusqlite::{ types::{FromSql, FromSqlResult, ValueRef}, Connection, }; -use util::ResultExt; -use crate::{Document, IndexedFile}; +use crate::IndexedFile; // This is saving to a local database store within the users dev zed path // Where do we want this to sit? 
@@ -19,14 +18,22 @@ const VECTOR_DB_URL: &str = "embeddings_db"; // Note this is not an appropriate document #[derive(Debug)] pub struct DocumentRecord { - id: usize, - offset: usize, - name: String, - embedding: Embedding, + pub id: usize, + pub file_id: usize, + pub offset: usize, + pub name: String, + pub embedding: Embedding, } #[derive(Debug)] -struct Embedding(Vec); +pub struct FileRecord { + pub id: usize, + pub path: String, + pub sha1: String, +} + +#[derive(Debug)] +pub struct Embedding(pub Vec); impl FromSql for Embedding { fn column_result(value: ValueRef) -> FromSqlResult { @@ -101,21 +108,16 @@ impl VectorDatabase { Ok(()) } - pub fn get_documents(&self) -> Result> { - // Should return a HashMap in which the key is the id, and the value is the finished document - - // Get Data from Database + pub fn get_files(&self) -> Result> { let db = rusqlite::Connection::open(VECTOR_DB_URL)?; - fn query(db: Connection) -> rusqlite::Result> { - let mut query_statement = - db.prepare("SELECT id, offset, name, embedding FROM documents LIMIT 10")?; + fn query(db: Connection) -> rusqlite::Result> { + let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?; let result_iter = query_statement.query_map([], |row| { - Ok(DocumentRecord { + Ok(FileRecord { id: row.get(0)?, - offset: row.get(1)?, - name: row.get(2)?, - embedding: row.get(3)?, + path: row.get(1)?, + sha1: row.get(2)?, }) })?; @@ -127,18 +129,49 @@ impl VectorDatabase { return Ok(results); } - let mut documents: HashMap = HashMap::new(); + let mut pages: HashMap = HashMap::new(); let result_iter = query(db); if result_iter.is_ok() { for result in result_iter.unwrap() { - documents.insert( - result.id, - Document { - offset: result.offset, - name: result.name, - embedding: result.embedding.0, - }, - ); + pages.insert(result.id, result); + } + } + + return Ok(pages); + } + + pub fn get_documents(&self) -> Result> { + // Should return a HashMap in which the key is the id, and the value is the finished document + + // Get Data from Database + let db = rusqlite::Connection::open(VECTOR_DB_URL)?; + + fn query(db: Connection) -> rusqlite::Result> { + let mut query_statement = + db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?; + let result_iter = query_statement.query_map([], |row| { + Ok(DocumentRecord { + id: row.get(0)?, + file_id: row.get(1)?, + offset: row.get(2)?, + name: row.get(3)?, + embedding: row.get(4)?, + }) + })?; + + let mut results = vec![]; + for result in result_iter { + results.push(result?); + } + + return Ok(results); + } + + let mut documents: HashMap = HashMap::new(); + let result_iter = query(db); + if result_iter.is_ok() { + for result in result_iter.unwrap() { + documents.insert(result.id, result); } } From 74b693d6b915f587956a85e3decadfab2d5238fc Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 26 Jun 2023 14:57:57 -0400 Subject: [PATCH 08/51] Updated database calls to share single connection, and simplified top_k_search sorting. 
Co-authored-by: maxbrunsfeld --- crates/vector_store/src/db.rs | 159 ++++++++++++------------ crates/vector_store/src/embedding.rs | 10 -- crates/vector_store/src/search.rs | 47 ++++--- crates/vector_store/src/vector_store.rs | 56 +++++++-- 4 files changed, 148 insertions(+), 124 deletions(-) diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index bc5a7fd497..4f6da14cab 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, path::PathBuf}; use anyhow::{anyhow, Result}; @@ -46,31 +46,50 @@ impl FromSql for Embedding { } } -pub struct VectorDatabase {} +pub struct VectorDatabase { + db: rusqlite::Connection, +} impl VectorDatabase { - pub async fn initialize_database() -> Result<()> { + pub fn new() -> Result { + let this = Self { + db: rusqlite::Connection::open(VECTOR_DB_URL)?, + }; + this.initialize_database()?; + Ok(this) + } + + fn initialize_database(&self) -> Result<()> { // This will create the database if it doesnt exist - let db = rusqlite::Connection::open(VECTOR_DB_URL)?; // Initialize Vector Databasing Tables - db.execute( + // self.db.execute( + // " + // CREATE TABLE IF NOT EXISTS projects ( + // id INTEGER PRIMARY KEY AUTOINCREMENT, + // path NVARCHAR(100) NOT NULL + // ) + // ", + // [], + // )?; + + self.db.execute( "CREATE TABLE IF NOT EXISTS files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - path NVARCHAR(100) NOT NULL, - sha1 NVARCHAR(40) NOT NULL - )", + id INTEGER PRIMARY KEY AUTOINCREMENT, + path NVARCHAR(100) NOT NULL, + sha1 NVARCHAR(40) NOT NULL + )", [], )?; - db.execute( + self.db.execute( "CREATE TABLE IF NOT EXISTS documents ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - file_id INTEGER NOT NULL, - offset INTEGER NOT NULL, - name NVARCHAR(100) NOT NULL, - embedding BLOB NOT NULL, - FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL, + offset INTEGER NOT NULL, + name NVARCHAR(100) NOT NULL, + embedding BLOB NOT NULL, + FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE )", [], )?; @@ -78,23 +97,37 @@ impl VectorDatabase { Ok(()) } - pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> { - // Write to files table, and return generated id. - let db = rusqlite::Connection::open(VECTOR_DB_URL)?; + // pub async fn get_or_create_project(project_path: PathBuf) -> Result { + // // Check if we have the project, if we do, return the ID + // // If we do not have the project, insert the project and return the ID - let files_insert = db.execute( + // let db = rusqlite::Connection::open(VECTOR_DB_URL)?; + + // let projects_query = db.prepare(&format!( + // "SELECT id FROM projects WHERE path = {}", + // project_path.to_str().unwrap() // This is unsafe + // ))?; + + // let project_id = db.last_insert_rowid(); + + // return Ok(project_id as usize); + // } + + pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> { + // Write to files table, and return generated id. + let files_insert = self.db.execute( "INSERT INTO files (path, sha1) VALUES (?1, ?2)", params![indexed_file.path.to_str(), indexed_file.sha1], )?; - let inserted_id = db.last_insert_rowid(); + let inserted_id = self.db.last_insert_rowid(); // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. 
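        // One hedged option (an untested sketch, not part of this patch): batch
        // the writes in one transaction with a cached prepared statement, so the
        // per-row commit is amortized into a single one:
        //
        //     let tx = self.db.unchecked_transaction()?;
        //     {
        //         let mut insert = tx.prepare_cached(
        //             "INSERT INTO documents (file_id, offset, name, embedding)
        //              VALUES (?1, ?2, ?3, ?4)",
        //         )?;
        //         for document in &indexed_file.documents {
        //             // bind the same params![..] as below via insert.execute(..)
        //         }
        //     }
        //     tx.commit()?;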
for document in indexed_file.documents { let embedding_blob = bincode::serialize(&document.embedding)?; - db.execute( + self.db.execute( "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)", params![ inserted_id, @@ -109,70 +142,42 @@ impl VectorDatabase { } pub fn get_files(&self) -> Result> { - let db = rusqlite::Connection::open(VECTOR_DB_URL)?; - - fn query(db: Connection) -> rusqlite::Result> { - let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?; - let result_iter = query_statement.query_map([], |row| { - Ok(FileRecord { - id: row.get(0)?, - path: row.get(1)?, - sha1: row.get(2)?, - }) - })?; - - let mut results = vec![]; - for result in result_iter { - results.push(result?); - } - - return Ok(results); - } + let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?; + let result_iter = query_statement.query_map([], |row| { + Ok(FileRecord { + id: row.get(0)?, + path: row.get(1)?, + sha1: row.get(2)?, + }) + })?; let mut pages: HashMap = HashMap::new(); - let result_iter = query(db); - if result_iter.is_ok() { - for result in result_iter.unwrap() { - pages.insert(result.id, result); - } + for result in result_iter { + let result = result?; + pages.insert(result.id, result); } - return Ok(pages); + Ok(pages) } pub fn get_documents(&self) -> Result> { - // Should return a HashMap in which the key is the id, and the value is the finished document - - // Get Data from Database - let db = rusqlite::Connection::open(VECTOR_DB_URL)?; - - fn query(db: Connection) -> rusqlite::Result> { - let mut query_statement = - db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?; - let result_iter = query_statement.query_map([], |row| { - Ok(DocumentRecord { - id: row.get(0)?, - file_id: row.get(1)?, - offset: row.get(2)?, - name: row.get(3)?, - embedding: row.get(4)?, - }) - })?; - - let mut results = vec![]; - for result in result_iter { - results.push(result?); - } - - return Ok(results); - } + let mut query_statement = self + .db + .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?; + let result_iter = query_statement.query_map([], |row| { + Ok(DocumentRecord { + id: row.get(0)?, + file_id: row.get(1)?, + offset: row.get(2)?, + name: row.get(3)?, + embedding: row.get(4)?, + }) + })?; let mut documents: HashMap = HashMap::new(); - let result_iter = query(db); - if result_iter.is_ok() { - for result in result_iter.unwrap() { - documents.insert(result.id, result); - } + for result in result_iter { + let result = result?; + documents.insert(result.id, result); } return Ok(documents); diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index 903c2451b3..f995639e64 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -94,16 +94,6 @@ impl EmbeddingProvider for OpenAIEmbeddings { response.usage.total_tokens ); - // do we need to re-order these based on the `index` field? 
- eprintln!( - "indices: {:?}", - response - .data - .iter() - .map(|embedding| embedding.index) - .collect::>() - ); - Ok(response .data .into_iter() diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs index 6b508b401b..ce8bdd1af4 100644 --- a/crates/vector_store/src/search.rs +++ b/crates/vector_store/src/search.rs @@ -19,8 +19,8 @@ pub struct BruteForceSearch { } impl BruteForceSearch { - pub fn load() -> Result { - let db = VectorDatabase {}; + pub fn load(db: &VectorDatabase) -> Result { + // let db = VectorDatabase {}; let documents = db.get_documents()?; let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect(); let mut document_ids = vec![]; @@ -47,39 +47,36 @@ impl VectorSearch for BruteForceSearch { async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)> { let target = Array1::from_vec(vec.to_owned()); - let distances = self.candidate_array.dot(&target); + let similarities = self.candidate_array.dot(&target); - let distances = distances.to_vec(); + let similarities = similarities.to_vec(); // construct a tuple vector from the floats, the tuple being (index,float) - let mut with_indices = distances - .clone() - .into_iter() + let mut with_indices = similarities + .iter() + .copied() .enumerate() - .map(|(index, value)| (index, value)) + .map(|(index, value)| (self.document_ids[index], value)) .collect::>(); // sort the tuple vector by float - with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) { - (true, true) => Ordering::Equal, - (true, false) => Ordering::Greater, - (false, true) => Ordering::Less, - (false, false) => a.1.partial_cmp(&b.1).unwrap(), - }); + with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); + with_indices.truncate(limit); + with_indices - // extract the sorted indices from the sorted tuple vector - let stored_indices = with_indices - .into_iter() - .map(|(index, value)| index) - .collect::>(); + // // extract the sorted indices from the sorted tuple vector + // let stored_indices = with_indices + // .into_iter() + // .map(|(index, value)| index) + // .collect::>(); - let sorted_indices: Vec = stored_indices.into_iter().rev().collect(); + // let sorted_indices: Vec = stored_indices.into_iter().rev().collect(); - let mut results = vec![]; - for idx in sorted_indices[0..limit].to_vec() { - results.push((self.document_ids[idx], 1.0 - distances[idx])); - } + // let mut results = vec![]; + // for idx in sorted_indices[0..limit].to_vec() { + // results.push((self.document_ids[idx], 1.0 - similarities[idx])); + // } - return results; + // return results; } } diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 0b6d2928cc..6e6bedc33a 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,5 +1,6 @@ mod db; mod embedding; +mod parsing; mod search; use anyhow::{anyhow, Result}; @@ -7,11 +8,13 @@ use db::VectorDatabase; use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings}; use gpui::{AppContext, Entity, ModelContext, ModelHandle}; use language::LanguageRegistry; +use parsing::Document; use project::{Fs, Project}; +use search::{BruteForceSearch, VectorSearch}; use smol::channel; use std::{path::PathBuf, sync::Arc, time::Instant}; use tree_sitter::{Parser, QueryCursor}; -use util::{http::HttpClient, ResultExt}; +use util::{http::HttpClient, ResultExt, TryFutureExt}; use workspace::WorkspaceCreated; pub fn init( @@ -39,13 +42,6 @@ pub fn init( 
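// A note on the `top_k_search` rewrite above: the scores are raw dot products,
// which coincide with cosine similarity only when both vectors have unit
// length (OpenAI's embeddings are documented as normalized; other providers
// may not be). A sketch of the general form for vectors that are not known to
// be normalized:
//
//     fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
//         let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
//         let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
//         dot / (norm(a) * norm(b)).max(f32::EPSILON)
//     }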
.detach(); } -#[derive(Debug)] -pub struct Document { - pub offset: usize, - pub name: String, - pub embedding: Vec, -} - #[derive(Debug)] pub struct IndexedFile { path: PathBuf, @@ -180,18 +176,54 @@ impl VectorStore { .detach(); cx.background() - .spawn(async move { + .spawn({ + let client = client.clone(); + async move { // Initialize Database, creates database and tables if not exists - VectorDatabase::initialize_database().await.log_err(); + let db = VectorDatabase::new()?; while let Ok(indexed_file) = indexed_files_rx.recv().await { - VectorDatabase::insert_file(indexed_file).await.log_err(); + db.insert_file(indexed_file).log_err(); + } + + // ALL OF THE BELOW IS FOR TESTING, + // This should be removed as we find and appropriate place for evaluate our search. + + let embedding_provider = OpenAIEmbeddings{ client }; + let queries = vec![ + "compute embeddings for all of the symbols in the codebase, and write them to a database", + "compute an outline view of all of the symbols in a buffer", + "scan a directory on the file system and load all of its children into an in-memory snapshot", + ]; + let embeddings = embedding_provider.embed_batch(queries.clone()).await?; + + let t2 = Instant::now(); + let documents = db.get_documents().unwrap(); + let files = db.get_files().unwrap(); + println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis()); + + let t1 = Instant::now(); + let mut bfs = BruteForceSearch::load(&db).unwrap(); + println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis()); + for (idx, embed) in embeddings.into_iter().enumerate() { + let t0 = Instant::now(); + println!("\nQuery: {:?}", queries[idx]); + let results = bfs.top_k_search(&embed, 5).await; + println!("Search Elapsed: {}", t0.elapsed().as_millis()); + for (id, distance) in results { + println!(""); + println!(" distance: {:?}", distance); + println!(" document: {:?}", documents[&id].name); + println!(" path: {:?}", files[&documents[&id].file_id].path); + } + } anyhow::Ok(()) - }) + }}.log_err()) .detach(); let provider = DummyEmbeddings {}; + // let provider = OpenAIEmbeddings { client }; cx.background() .scoped(|scope| { From 953e928bdb3aa80744a13ff53a197fd798fec0fe Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 26 Jun 2023 19:01:19 -0400 Subject: [PATCH 09/51] WIP: Got the streaming matrix multiplication working, and started work on file hashing. 
Co-authored-by: maxbrunsfeld --- Cargo.lock | 5 + crates/vector_store/Cargo.toml | 5 + crates/vector_store/src/db.rs | 84 ++++-- crates/vector_store/src/embedding.rs | 2 +- crates/vector_store/src/search.rs | 18 +- crates/vector_store/src/vector_store.rs | 243 +++++++++++++----- crates/vector_store/src/vector_store_tests.rs | 136 ++++++++++ 7 files changed, 396 insertions(+), 97 deletions(-) create mode 100644 crates/vector_store/src/vector_store_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 48952d6c25..ff4caaa5a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7958,13 +7958,18 @@ dependencies = [ "language", "lazy_static", "log", + "matrixmultiply", "ndarray", "project", + "rand 0.8.5", "rusqlite", "serde", "serde_json", + "sha-1 0.10.1", "smol", "tree-sitter", + "tree-sitter-rust", + "unindent", "util", "workspace", ] diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 8de93c0401..dbe0a2e69c 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -27,9 +27,14 @@ serde_json.workspace = true async-trait.workspace = true bincode = "1.3.3" ndarray = "0.15.6" +sha-1 = "0.10.1" +matrixmultiply = "0.3.7" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } language = { path = "../language", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } +tree-sitter-rust = "*" +rand.workspace = true +unindent.workspace = true diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index 4f6da14cab..bcb1090a8d 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,4 +1,7 @@ -use std::{collections::HashMap, path::PathBuf}; +use std::{ + collections::HashMap, + path::{Path, PathBuf}, +}; use anyhow::{anyhow, Result}; @@ -13,7 +16,7 @@ use crate::IndexedFile; // This is saving to a local database store within the users dev zed path // Where do we want this to sit? // Assuming near where the workspace DB sits. 
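// Two SQLite details worth noting for the schema changes below: foreign-key
// enforcement (and with it `ON DELETE CASCADE`) is off by default and must be
// enabled per connection, and `Connection::execute` prepares a single
// statement, so a string holding both a CREATE TABLE and a CREATE INDEX
// generally belongs in `execute_batch` instead. A sketch using rusqlite's
// standard API:
//
//     db.execute_batch(
//         "PRAGMA foreign_keys = ON;
//          CREATE TABLE IF NOT EXISTS worktrees (
//              id INTEGER PRIMARY KEY AUTOINCREMENT,
//              absolute_path VARCHAR NOT NULL
//          );
//          CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path
//              ON worktrees (absolute_path);",
//     )?;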
-const VECTOR_DB_URL: &str = "embeddings_db"; +pub const VECTOR_DB_URL: &str = "embeddings_db"; // Note this is not an appropriate document #[derive(Debug)] @@ -28,7 +31,7 @@ pub struct DocumentRecord { #[derive(Debug)] pub struct FileRecord { pub id: usize, - pub path: String, + pub relative_path: String, pub sha1: String, } @@ -51,9 +54,9 @@ pub struct VectorDatabase { } impl VectorDatabase { - pub fn new() -> Result { + pub fn new(path: &str) -> Result { let this = Self { - db: rusqlite::Connection::open(VECTOR_DB_URL)?, + db: rusqlite::Connection::open(path)?, }; this.initialize_database()?; Ok(this) @@ -63,21 +66,23 @@ impl VectorDatabase { // This will create the database if it doesnt exist // Initialize Vector Databasing Tables - // self.db.execute( - // " - // CREATE TABLE IF NOT EXISTS projects ( - // id INTEGER PRIMARY KEY AUTOINCREMENT, - // path NVARCHAR(100) NOT NULL - // ) - // ", - // [], - // )?; + self.db.execute( + "CREATE TABLE IF NOT EXISTS worktrees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + absolute_path VARCHAR NOT NULL + ); + CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path); + ", + [], + )?; self.db.execute( "CREATE TABLE IF NOT EXISTS files ( id INTEGER PRIMARY KEY AUTOINCREMENT, - path NVARCHAR(100) NOT NULL, - sha1 NVARCHAR(40) NOT NULL + worktree_id INTEGER NOT NULL, + relative_path VARCHAR NOT NULL, + sha1 NVARCHAR(40) NOT NULL, + FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE )", [], )?; @@ -87,7 +92,7 @@ impl VectorDatabase { id INTEGER PRIMARY KEY AUTOINCREMENT, file_id INTEGER NOT NULL, offset INTEGER NOT NULL, - name NVARCHAR(100) NOT NULL, + name VARCHAR NOT NULL, embedding BLOB NOT NULL, FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE )", @@ -116,7 +121,7 @@ impl VectorDatabase { pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> { // Write to files table, and return generated id. let files_insert = self.db.execute( - "INSERT INTO files (path, sha1) VALUES (?1, ?2)", + "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)", params![indexed_file.path.to_str(), indexed_file.sha1], )?; @@ -141,12 +146,38 @@ impl VectorDatabase { Ok(()) } + pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result { + self.db.execute( + " + INSERT into worktrees (absolute_path) VALUES (?1) + ON CONFLICT DO NOTHING + ", + params![worktree_root_path.to_string_lossy()], + )?; + Ok(self.db.last_insert_rowid()) + } + + pub fn get_file_hashes(&self, worktree_id: i64) -> Result> { + let mut statement = self + .db + .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?; + let mut result = Vec::new(); + for row in + statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))? + { + result.push(row?); + } + Ok(result) + } + pub fn get_files(&self) -> Result> { - let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?; + let mut query_statement = self + .db + .prepare("SELECT id, relative_path, sha1 FROM files")?; let result_iter = query_statement.query_map([], |row| { Ok(FileRecord { id: row.get(0)?, - path: row.get(1)?, + relative_path: row.get(1)?, sha1: row.get(2)?, }) })?; @@ -160,6 +191,19 @@ impl VectorDatabase { Ok(pages) } + pub fn for_each_document( + &self, + worktree_id: i64, + mut f: impl FnMut(i64, Embedding), + ) -> Result<()> { + let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?; + query_statement + .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))? 
+ .filter_map(|row| row.ok()) + .for_each(|row| f(row.0, row.1)); + Ok(()) + } + pub fn get_documents(&self) -> Result> { let mut query_statement = self .db diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index f995639e64..86d8494ab4 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -44,7 +44,7 @@ struct OpenAIEmbeddingUsage { } #[async_trait] -pub trait EmbeddingProvider: Sync { +pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; } diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs index ce8bdd1af4..90a8d874da 100644 --- a/crates/vector_store/src/search.rs +++ b/crates/vector_store/src/search.rs @@ -1,4 +1,4 @@ -use std::cmp::Ordering; +use std::{cmp::Ordering, path::PathBuf}; use async_trait::async_trait; use ndarray::{Array1, Array2}; @@ -20,7 +20,6 @@ pub struct BruteForceSearch { impl BruteForceSearch { pub fn load(db: &VectorDatabase) -> Result { - // let db = VectorDatabase {}; let documents = db.get_documents()?; let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect(); let mut document_ids = vec![]; @@ -63,20 +62,5 @@ impl VectorSearch for BruteForceSearch { with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); with_indices.truncate(limit); with_indices - - // // extract the sorted indices from the sorted tuple vector - // let stored_indices = with_indices - // .into_iter() - // .map(|(index, value)| index) - // .collect::>(); - - // let sorted_indices: Vec = stored_indices.into_iter().rev().collect(); - - // let mut results = vec![]; - // for idx in sorted_indices[0..limit].to_vec() { - // results.push((self.document_ids[idx], 1.0 - similarities[idx])); - // } - - // return results; } } diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 6e6bedc33a..f34316e950 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -3,16 +3,19 @@ mod embedding; mod parsing; mod search; +#[cfg(test)] +mod vector_store_tests; + use anyhow::{anyhow, Result}; -use db::VectorDatabase; +use db::{VectorDatabase, VECTOR_DB_URL}; use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings}; -use gpui::{AppContext, Entity, ModelContext, ModelHandle}; +use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task}; use language::LanguageRegistry; use parsing::Document; use project::{Fs, Project}; use search::{BruteForceSearch, VectorSearch}; use smol::channel; -use std::{path::PathBuf, sync::Arc, time::Instant}; +use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant}; use tree_sitter::{Parser, QueryCursor}; use util::{http::HttpClient, ResultExt, TryFutureExt}; use workspace::WorkspaceCreated; @@ -23,7 +26,16 @@ pub fn init( language_registry: Arc, cx: &mut AppContext, ) { - let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry)); + let vector_store = cx.add_model(|cx| { + VectorStore::new( + fs, + VECTOR_DB_URL.to_string(), + Arc::new(OpenAIEmbeddings { + client: http_client, + }), + language_registry, + ) + }); cx.subscribe_global::({ let vector_store = vector_store.clone(); @@ -49,28 +61,36 @@ pub struct IndexedFile { documents: Vec, } -struct SearchResult { - path: PathBuf, - offset: usize, - name: String, - distance: f32, -} - +// struct SearchResult { +// path: PathBuf, +// offset: usize, +// name: String, +// distance: f32, +// } struct 
VectorStore { fs: Arc, - http_client: Arc, + database_url: Arc, + embedding_provider: Arc, language_registry: Arc, } +pub struct SearchResult { + pub name: String, + pub offset: usize, + pub file_path: PathBuf, +} + impl VectorStore { fn new( fs: Arc, - http_client: Arc, + database_url: String, + embedding_provider: Arc, language_registry: Arc, ) -> Self { Self { fs, - http_client, + database_url: database_url.into(), + embedding_provider, language_registry, } } @@ -79,10 +99,12 @@ impl VectorStore { cursor: &mut QueryCursor, parser: &mut Parser, embedding_provider: &dyn EmbeddingProvider, - fs: &Arc, language_registry: &Arc, file_path: PathBuf, + content: String, ) -> Result { + dbg!(&file_path, &content); + let language = language_registry .language_for_file(&file_path, None) .await?; @@ -97,7 +119,6 @@ impl VectorStore { .as_ref() .ok_or_else(|| anyhow!("no outline query"))?; - let content = fs.load(&file_path).await?; parser.set_language(grammar.ts_language).unwrap(); let tree = parser .parse(&content, None) @@ -142,7 +163,11 @@ impl VectorStore { }); } - fn add_project(&mut self, project: ModelHandle, cx: &mut ModelContext) { + fn add_project( + &mut self, + project: ModelHandle, + cx: &mut ModelContext, + ) -> Task> { let worktree_scans_complete = project .read(cx) .worktrees(cx) @@ -151,7 +176,8 @@ impl VectorStore { let fs = self.fs.clone(); let language_registry = self.language_registry.clone(); - let client = self.http_client.clone(); + let embedding_provider = self.embedding_provider.clone(); + let database_url = self.database_url.clone(); cx.spawn(|_, cx| async move { futures::future::join_all(worktree_scans_complete).await; @@ -163,24 +189,47 @@ impl VectorStore { .collect::>() }); - let (paths_tx, paths_rx) = channel::unbounded::(); + let db = VectorDatabase::new(&database_url)?; + let worktree_root_paths = worktrees + .iter() + .map(|worktree| worktree.abs_path().clone()) + .collect::>(); + let (db, file_hashes) = cx + .background() + .spawn(async move { + let mut hashes = Vec::new(); + for worktree_root_path in worktree_root_paths { + let worktree_id = + db.find_or_create_worktree(worktree_root_path.as_ref())?; + hashes.push((worktree_id, db.get_file_hashes(worktree_id)?)); + } + anyhow::Ok((db, hashes)) + }) + .await?; + + let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>(); let (indexed_files_tx, indexed_files_rx) = channel::unbounded::(); cx.background() - .spawn(async move { - for worktree in worktrees { - for file in worktree.files(false, 0) { - paths_tx.try_send(worktree.absolutize(&file.path)).unwrap(); + .spawn({ + let fs = fs.clone(); + async move { + for worktree in worktrees.into_iter() { + for file in worktree.files(false, 0) { + let absolute_path = worktree.absolutize(&file.path); + dbg!(&absolute_path); + if let Some(content) = fs.load(&absolute_path).await.log_err() { + dbg!(&content); + paths_tx.try_send((0, absolute_path, content)).unwrap(); + } + } } } }) .detach(); - cx.background() - .spawn({ - let client = client.clone(); - async move { + let db_write_task = cx.background().spawn( + async move { // Initialize Database, creates database and tables if not exists - let db = VectorDatabase::new()?; while let Ok(indexed_file) = indexed_files_rx.recv().await { db.insert_file(indexed_file).log_err(); } @@ -188,39 +237,39 @@ impl VectorStore { // ALL OF THE BELOW IS FOR TESTING, // This should be removed as we find and appropriate place for evaluate our search. 
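// The evaluation block below is commented out in this patch; it gets a more
// durable home in the `vector_store_tests.rs` added in the same commit, which
// drives the end-to-end path deterministically with a `FakeEmbeddingProvider`:
//
//     let search_results = store
//         .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
//         .await
//         .unwrap();
//     assert_eq!(search_results[0].name, "aaa");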
- let embedding_provider = OpenAIEmbeddings{ client }; - let queries = vec![ - "compute embeddings for all of the symbols in the codebase, and write them to a database", - "compute an outline view of all of the symbols in a buffer", - "scan a directory on the file system and load all of its children into an in-memory snapshot", - ]; - let embeddings = embedding_provider.embed_batch(queries.clone()).await?; + // let queries = vec![ + // "compute embeddings for all of the symbols in the codebase, and write them to a database", + // "compute an outline view of all of the symbols in a buffer", + // "scan a directory on the file system and load all of its children into an in-memory snapshot", + // ]; + // let embeddings = embedding_provider.embed_batch(queries.clone()).await?; - let t2 = Instant::now(); - let documents = db.get_documents().unwrap(); - let files = db.get_files().unwrap(); - println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis()); + // let t2 = Instant::now(); + // let documents = db.get_documents().unwrap(); + // let files = db.get_files().unwrap(); + // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis()); - let t1 = Instant::now(); - let mut bfs = BruteForceSearch::load(&db).unwrap(); - println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis()); - for (idx, embed) in embeddings.into_iter().enumerate() { - let t0 = Instant::now(); - println!("\nQuery: {:?}", queries[idx]); - let results = bfs.top_k_search(&embed, 5).await; - println!("Search Elapsed: {}", t0.elapsed().as_millis()); - for (id, distance) in results { - println!(""); - println!(" distance: {:?}", distance); - println!(" document: {:?}", documents[&id].name); - println!(" path: {:?}", files[&documents[&id].file_id].path); - } + // let t1 = Instant::now(); + // let mut bfs = BruteForceSearch::load(&db).unwrap(); + // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis()); + // for (idx, embed) in embeddings.into_iter().enumerate() { + // let t0 = Instant::now(); + // println!("\nQuery: {:?}", queries[idx]); + // let results = bfs.top_k_search(&embed, 5).await; + // println!("Search Elapsed: {}", t0.elapsed().as_millis()); + // for (id, distance) in results { + // println!(""); + // println!(" distance: {:?}", distance); + // println!(" document: {:?}", documents[&id].name); + // println!(" path: {:?}", files[&documents[&id].file_id].relative_path); + // } - } + // } anyhow::Ok(()) - }}.log_err()) - .detach(); + } + .log_err(), + ); let provider = DummyEmbeddings {}; // let provider = OpenAIEmbeddings { client }; @@ -231,14 +280,15 @@ impl VectorStore { scope.spawn(async { let mut parser = Parser::new(); let mut cursor = QueryCursor::new(); - while let Ok(file_path) = paths_rx.recv().await { + while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await + { if let Some(indexed_file) = Self::index_file( &mut cursor, &mut parser, &provider, - &fs, &language_registry, file_path, + content, ) .await .log_err() @@ -250,11 +300,86 @@ impl VectorStore { } }) .await; + drop(indexed_files_tx); + + db_write_task.await; + anyhow::Ok(()) + }) + } + + pub fn search( + &mut self, + phrase: String, + limit: usize, + cx: &mut ModelContext, + ) -> Task>> { + let embedding_provider = self.embedding_provider.clone(); + let database_url = self.database_url.clone(); + cx.spawn(|this, cx| async move { + let database = VectorDatabase::new(database_url.as_ref())?; + + // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?; + // + let mut 
results = Vec::<(i64, f32)>::with_capacity(limit + 1); + + database.for_each_document(0, |id, embedding| { + dbg!(id, &embedding); + + let similarity = dot(&embedding.0, &embedding.0); + let ix = match results.binary_search_by(|(_, s)| { + s.partial_cmp(&similarity).unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + + results.insert(ix, (id, similarity)); + results.truncate(limit); + })?; + + dbg!(&results); + + let ids = results.into_iter().map(|(id, _)| id).collect::>(); + // let documents = database.get_documents_by_ids(ids)?; + + // let search_provider = cx + // .background() + // .spawn(async move { BruteForceSearch::load(&database) }) + // .await?; + + // let results = search_provider.top_k_search(&embedding, limit)) + + anyhow::Ok(vec![]) }) - .detach(); } } impl Entity for VectorStore { type Event = (); } + +fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 { + let len = vec_a.len(); + assert_eq!(len, vec_b.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + vec_a.as_ptr(), + len as isize, + 1, + vec_b.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + result +} diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs new file mode 100644 index 0000000000..f3d01835e9 --- /dev/null +++ b/crates/vector_store/src/vector_store_tests.rs @@ -0,0 +1,136 @@ +use std::sync::Arc; + +use crate::{dot, embedding::EmbeddingProvider, VectorStore}; +use anyhow::Result; +use async_trait::async_trait; +use gpui::{Task, TestAppContext}; +use language::{Language, LanguageConfig, LanguageRegistry}; +use project::{FakeFs, Project}; +use rand::Rng; +use serde_json::json; +use unindent::Unindent; + +#[gpui::test] +async fn test_vector_store(cx: &mut TestAppContext) { + let fs = FakeFs::new(cx.background()); + fs.insert_tree( + "/the-root", + json!({ + "src": { + "file1.rs": " + fn aaa() { + println!(\"aaaa!\"); + } + + fn zzzzzzzzz() { + println!(\"SLEEPING\"); + } + ".unindent(), + "file2.rs": " + fn bbb() { + println!(\"bbbb!\"); + } + ".unindent(), + } + }), + ) + .await; + + let languages = Arc::new(LanguageRegistry::new(Task::ready(()))); + let rust_language = Arc::new( + Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".into()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_outline_query( + r#" + (function_item + name: (identifier) @name + body: (block)) @item + "#, + ) + .unwrap(), + ); + languages.add(rust_language); + + let store = cx.add_model(|_| { + VectorStore::new( + fs.clone(), + "foo".to_string(), + Arc::new(FakeEmbeddingProvider), + languages, + ) + }); + + let project = Project::test(fs, ["/the-root".as_ref()], cx).await; + store + .update(cx, |store, cx| store.add_project(project, cx)) + .await + .unwrap(); + + let search_results = store + .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx)) + .await + .unwrap(); + + assert_eq!(search_results[0].offset, 0); + assert_eq!(search_results[1].name, "aaa"); +} + +#[test] +fn test_dot_product() { + assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.); + assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.); + + for _ in 0..100 { + let mut rng = rand::thread_rng(); + let a: [f32; 32] = rng.gen(); + let b: [f32; 32] = rng.gen(); + assert_eq!( + round_to_decimals(dot(&a, &b), 3), + round_to_decimals(reference_dot(&a, &b), 3) + ); + } + + fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { + let 
factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() + } +} + +struct FakeEmbeddingProvider; + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + Ok(spans + .iter() + .map(|span| { + let mut result = vec![0.0; 26]; + for letter in span.chars() { + if letter as u32 > 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result + }) + .collect()) + } +} From 4bfe3de1f2e012bc3fb7ee9a928d7ae223e3c97d Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 27 Jun 2023 15:31:21 -0400 Subject: [PATCH 10/51] Working incremental index engine, with streaming similarity search! Co-authored-by: maxbrunsfeld --- Cargo.lock | 1 + crates/vector_store/Cargo.toml | 3 +- crates/vector_store/src/db.rs | 184 ++++++++++++++---- crates/vector_store/src/vector_store.rs | 168 +++++++++------- crates/vector_store/src/vector_store_tests.rs | 23 ++- 5 files changed, 268 insertions(+), 111 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ff4caaa5a6..1ea1d1a1b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7967,6 +7967,7 @@ dependencies = [ "serde_json", "sha-1 0.10.1", "smol", + "tempdir", "tree-sitter", "tree-sitter-rust", "unindent", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index dbe0a2e69c..edc06bb295 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -17,7 +17,7 @@ util = { path = "../util" } anyhow.workspace = true futures.workspace = true smol.workspace = true -rusqlite = { version = "0.27.0", features=["blob"] } +rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] } isahc.workspace = true log.workspace = true tree-sitter.workspace = true @@ -38,3 +38,4 @@ workspace = { path = "../workspace", features = ["test-support"] } tree-sitter-rust = "*" rand.workspace = true unindent.workspace = true +tempdir.workspace = true diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index bcb1090a8d..f074a7066b 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -7,9 +7,10 @@ use anyhow::{anyhow, Result}; use rusqlite::{ params, - types::{FromSql, FromSqlResult, ValueRef}, - Connection, + types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, + ToSql, }; +use sha1::{Digest, Sha1}; use crate::IndexedFile; @@ -32,7 +33,60 @@ pub struct DocumentRecord { pub struct FileRecord { pub id: usize, pub relative_path: String, - pub sha1: String, + pub sha1: FileSha1, +} + +#[derive(Debug)] +pub struct FileSha1(pub Vec); + +impl FileSha1 { + pub fn from_str(content: String) -> Self { + let mut hasher = Sha1::new(); + hasher.update(content); + let sha1 = hasher.finalize()[..] + .into_iter() + .map(|val| val.to_owned()) + .collect::>(); + return FileSha1(sha1); + } + + pub fn equals(&self, content: &String) -> bool { + let mut hasher = Sha1::new(); + hasher.update(content); + let sha1 = hasher.finalize()[..] 
+ .into_iter() + .map(|val| val.to_owned()) + .collect::>(); + + let equal = self + .0 + .clone() + .into_iter() + .zip(sha1) + .filter(|&(a, b)| a == b) + .count() + == self.0.len(); + + equal + } +} + +impl ToSql for FileSha1 { + fn to_sql(&self) -> rusqlite::Result> { + return self.0.to_sql(); + } +} + +impl FromSql for FileSha1 { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + Ok(FileSha1( + bytes + .into_iter() + .map(|val| val.to_owned()) + .collect::>(), + )) + } } #[derive(Debug)] @@ -63,6 +117,8 @@ impl VectorDatabase { } fn initialize_database(&self) -> Result<()> { + rusqlite::vtab::array::load_module(&self.db)?; + // This will create the database if it doesnt exist // Initialize Vector Databasing Tables @@ -81,7 +137,7 @@ impl VectorDatabase { id INTEGER PRIMARY KEY AUTOINCREMENT, worktree_id INTEGER NOT NULL, relative_path VARCHAR NOT NULL, - sha1 NVARCHAR(40) NOT NULL, + sha1 BLOB NOT NULL, FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE )", [], @@ -102,30 +158,23 @@ impl VectorDatabase { Ok(()) } - // pub async fn get_or_create_project(project_path: PathBuf) -> Result { - // // Check if we have the project, if we do, return the ID - // // If we do not have the project, insert the project and return the ID - - // let db = rusqlite::Connection::open(VECTOR_DB_URL)?; - - // let projects_query = db.prepare(&format!( - // "SELECT id FROM projects WHERE path = {}", - // project_path.to_str().unwrap() // This is unsafe - // ))?; - - // let project_id = db.last_insert_rowid(); - - // return Ok(project_id as usize); - // } - - pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> { + pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> { // Write to files table, and return generated id. - let files_insert = self.db.execute( - "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)", - params![indexed_file.path.to_str(), indexed_file.sha1], + log::info!("Inserting File!"); + self.db.execute( + " + DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2; + ", + params![worktree_id, indexed_file.path.to_str()], + )?; + self.db.execute( + " + INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3); + ", + params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1], )?; - let inserted_id = self.db.last_insert_rowid(); + let file_id = self.db.last_insert_rowid(); // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. @@ -135,7 +184,7 @@ impl VectorDatabase { self.db.execute( "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)", params![ - inserted_id, + file_id, document.offset.to_string(), document.name, embedding_blob @@ -147,25 +196,41 @@ impl VectorDatabase { } pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result { + // Check that the absolute path doesnt exist + let mut worktree_query = self + .db + .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + + let worktree_id = worktree_query + .query_row(params![worktree_root_path.to_string_lossy()], |row| { + Ok(row.get::<_, i64>(0)?) 
+ }) + .map_err(|err| anyhow!(err)); + + if worktree_id.is_ok() { + return worktree_id; + } + + // If worktree_id is Err, insert new worktree self.db.execute( " INSERT into worktrees (absolute_path) VALUES (?1) - ON CONFLICT DO NOTHING ", params![worktree_root_path.to_string_lossy()], )?; Ok(self.db.last_insert_rowid()) } - pub fn get_file_hashes(&self, worktree_id: i64) -> Result> { - let mut statement = self - .db - .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?; - let mut result = Vec::new(); - for row in - statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))? - { - result.push(row?); + pub fn get_file_hashes(&self, worktree_id: i64) -> Result> { + let mut statement = self.db.prepare( + "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path", + )?; + let mut result: HashMap = HashMap::new(); + for row in statement.query_map(params![worktree_id], |row| { + Ok((row.get::<_, String>(0)?.into(), row.get(1)?)) + })? { + let row = row?; + result.insert(row.0, row.1); } Ok(result) } @@ -204,6 +269,53 @@ impl VectorDatabase { Ok(()) } + pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result> { + let mut statement = self.db.prepare( + " + SELECT + documents.id, files.relative_path, documents.offset, documents.name + FROM + documents, files + WHERE + documents.file_id = files.id AND + documents.id in rarray(?) + ", + )?; + + let result_iter = statement.query_map( + params![std::rc::Rc::new( + ids.iter() + .copied() + .map(|v| rusqlite::types::Value::from(v)) + .collect::>() + )], + |row| { + Ok(( + row.get::<_, i64>(0)?, + row.get::<_, String>(1)?.into(), + row.get(2)?, + row.get(3)?, + )) + }, + )?; + + let mut values_by_id = HashMap::::default(); + for row in result_iter { + let (id, path, offset, name) = row?; + values_by_id.insert(id, (path, offset, name)); + } + + let mut results = Vec::with_capacity(ids.len()); + for id in ids { + let (path, offset, name) = values_by_id + .remove(id) + .ok_or(anyhow!("missing document id {}", id))?; + results.push((path, offset, name)); + } + + Ok(results) + } + pub fn get_documents(&self) -> Result> { let mut query_statement = self .db diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index f34316e950..7e4c29cef6 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -7,15 +7,14 @@ mod search; mod vector_store_tests; use anyhow::{anyhow, Result}; -use db::{VectorDatabase, VECTOR_DB_URL}; -use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings}; +use db::{FileSha1, VectorDatabase, VECTOR_DB_URL}; +use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task}; -use language::LanguageRegistry; +use language::{Language, LanguageRegistry}; use parsing::Document; use project::{Fs, Project}; -use search::{BruteForceSearch, VectorSearch}; use smol::channel; -use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant}; +use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc}; use tree_sitter::{Parser, QueryCursor}; use util::{http::HttpClient, ResultExt, TryFutureExt}; use workspace::WorkspaceCreated; @@ -45,7 +44,7 @@ pub fn init( let project = workspace.read(cx).project().clone(); if project.read(cx).is_local() { vector_store.update(cx, |store, cx| { - store.add_project(project, cx); + store.add_project(project, cx).detach(); }); } } @@ -57,16 +56,10 @@ pub fn init( #[derive(Debug)] pub struct 
IndexedFile { path: PathBuf, - sha1: String, + sha1: FileSha1, documents: Vec, } -// struct SearchResult { -// path: PathBuf, -// offset: usize, -// name: String, -// distance: f32, -// } struct VectorStore { fs: Arc, database_url: Arc, @@ -99,20 +92,10 @@ impl VectorStore { cursor: &mut QueryCursor, parser: &mut Parser, embedding_provider: &dyn EmbeddingProvider, - language_registry: &Arc, + language: Arc, file_path: PathBuf, content: String, ) -> Result { - dbg!(&file_path, &content); - - let language = language_registry - .language_for_file(&file_path, None) - .await?; - - if language.name().as_ref() != "Rust" { - Err(anyhow!("unsupported language"))?; - } - let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; let outline_config = grammar .outline_config @@ -156,9 +139,11 @@ impl VectorStore { document.embedding = embedding; } + let sha1 = FileSha1::from_str(content); + return Ok(IndexedFile { path: file_path, - sha1: String::new(), + sha1, documents, }); } @@ -171,7 +156,13 @@ impl VectorStore { let worktree_scans_complete = project .read(cx) .worktrees(cx) - .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete()) + .map(|worktree| { + let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete(); + async move { + scan_complete.await; + log::info!("worktree scan completed"); + } + }) .collect::>(); let fs = self.fs.clone(); @@ -182,6 +173,13 @@ impl VectorStore { cx.spawn(|_, cx| async move { futures::future::join_all(worktree_scans_complete).await; + // TODO: remove this after fixing the bug in scan_complete + cx.background() + .timer(std::time::Duration::from_secs(3)) + .await; + + let db = VectorDatabase::new(&database_url)?; + let worktrees = project.read_with(&cx, |project, cx| { project .worktrees(cx) @@ -189,37 +187,74 @@ impl VectorStore { .collect::>() }); - let db = VectorDatabase::new(&database_url)?; let worktree_root_paths = worktrees .iter() .map(|worktree| worktree.abs_path().clone()) .collect::>(); - let (db, file_hashes) = cx + + // Here we query the worktree ids, and yet we dont have them elsewhere + // We likely want to clean up these datastructures + let (db, worktree_hashes, worktree_ids) = cx .background() .spawn(async move { - let mut hashes = Vec::new(); + let mut worktree_ids: HashMap = HashMap::new(); + let mut hashes: HashMap> = HashMap::new(); for worktree_root_path in worktree_root_paths { let worktree_id = db.find_or_create_worktree(worktree_root_path.as_ref())?; - hashes.push((worktree_id, db.get_file_hashes(worktree_id)?)); + worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id); + hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?); } - anyhow::Ok((db, hashes)) + anyhow::Ok((db, hashes, worktree_ids)) }) .await?; - let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>(); - let (indexed_files_tx, indexed_files_rx) = channel::unbounded::(); + let (paths_tx, paths_rx) = + channel::unbounded::<(i64, PathBuf, String, Arc)>(); + let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>(); cx.background() .spawn({ let fs = fs.clone(); async move { for worktree in worktrees.into_iter() { + let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()]; + let file_hashes = &worktree_hashes[&worktree_id]; for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); - dbg!(&absolute_path); - if let Some(content) = fs.load(&absolute_path).await.log_err() { - dbg!(&content); - paths_tx.try_send((0, absolute_path, 
content)).unwrap(); + + if let Ok(language) = language_registry + .language_for_file(&absolute_path, None) + .await + { + if language.name().as_ref() != "Rust" { + continue; + } + + if let Some(content) = fs.load(&absolute_path).await.log_err() { + log::info!("loaded file: {absolute_path:?}"); + + let path_buf = file.path.to_path_buf(); + let already_stored = file_hashes + .get(&path_buf) + .map_or(false, |existing_hash| { + existing_hash.equals(&content) + }); + + if !already_stored { + log::info!( + "File Changed (Sending to Parse): {:?}", + &path_buf + ); + paths_tx + .try_send(( + worktree_id, + path_buf, + content, + language, + )) + .unwrap(); + } + } } } } @@ -230,8 +265,8 @@ impl VectorStore { let db_write_task = cx.background().spawn( async move { // Initialize Database, creates database and tables if not exists - while let Ok(indexed_file) = indexed_files_rx.recv().await { - db.insert_file(indexed_file).log_err(); + while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await { + db.insert_file(worktree_id, indexed_file).log_err(); } // ALL OF THE BELOW IS FOR TESTING, @@ -271,29 +306,29 @@ impl VectorStore { .log_err(), ); - let provider = DummyEmbeddings {}; - // let provider = OpenAIEmbeddings { client }; - cx.background() .scoped(|scope| { for _ in 0..cx.background().num_cpus() { scope.spawn(async { let mut parser = Parser::new(); let mut cursor = QueryCursor::new(); - while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await + while let Ok((worktree_id, file_path, content, language)) = + paths_rx.recv().await { if let Some(indexed_file) = Self::index_file( &mut cursor, &mut parser, - &provider, - &language_registry, + embedding_provider.as_ref(), + language, file_path, content, ) .await .log_err() { - indexed_files_tx.try_send(indexed_file).unwrap(); + indexed_files_tx + .try_send((worktree_id, indexed_file)) + .unwrap(); } } }); @@ -315,41 +350,42 @@ impl VectorStore { ) -> Task>> { let embedding_provider = self.embedding_provider.clone(); let database_url = self.database_url.clone(); - cx.spawn(|this, cx| async move { + cx.background().spawn(async move { let database = VectorDatabase::new(database_url.as_ref())?; - // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?; - // + let phrase_embedding = embedding_provider + .embed_batch(vec![&phrase]) + .await? 
+ .into_iter() + .next() + .unwrap(); + let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - database.for_each_document(0, |id, embedding| { - dbg!(id, &embedding); - - let similarity = dot(&embedding.0, &embedding.0); + let similarity = dot(&embedding.0, &phrase_embedding); let ix = match results.binary_search_by(|(_, s)| { - s.partial_cmp(&similarity).unwrap_or(Ordering::Equal) + similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) }) { Ok(ix) => ix, Err(ix) => ix, }; - results.insert(ix, (id, similarity)); results.truncate(limit); })?; - dbg!(&results); - let ids = results.into_iter().map(|(id, _)| id).collect::>(); - // let documents = database.get_documents_by_ids(ids)?; + let documents = database.get_documents_by_ids(&ids)?; - // let search_provider = cx - // .background() - // .spawn(async move { BruteForceSearch::load(&database) }) - // .await?; - - // let results = search_provider.top_k_search(&embedding, limit)) - - anyhow::Ok(vec![]) + anyhow::Ok( + documents + .into_iter() + .map(|(file_path, offset, name)| SearchResult { + name, + offset, + file_path, + }) + .collect(), + ) }) } } diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index f3d01835e9..c67bb9954f 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -57,20 +57,26 @@ async fn test_vector_store(cx: &mut TestAppContext) { ); languages.add(rust_language); + let db_dir = tempdir::TempDir::new("vector-store").unwrap(); + let db_path = db_dir.path().join("db.sqlite"); + let store = cx.add_model(|_| { VectorStore::new( fs.clone(), - "foo".to_string(), + db_path.to_string_lossy().to_string(), Arc::new(FakeEmbeddingProvider), languages, ) }); let project = Project::test(fs, ["/the-root".as_ref()], cx).await; - store - .update(cx, |store, cx| store.add_project(project, cx)) - .await - .unwrap(); + let add_project = store.update(cx, |store, cx| store.add_project(project, cx)); + + // TODO - remove + cx.foreground() + .advance_clock(std::time::Duration::from_secs(3)); + + add_project.await.unwrap(); let search_results = store .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx)) @@ -78,7 +84,7 @@ async fn test_vector_store(cx: &mut TestAppContext) { .unwrap(); assert_eq!(search_results[0].offset, 0); - assert_eq!(search_results[1].name, "aaa"); + assert_eq!(search_results[0].name, "aaa"); } #[test] @@ -114,9 +120,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider { Ok(spans .iter() .map(|span| { - let mut result = vec![0.0; 26]; + let mut result = vec![1.0; 26]; for letter in span.chars() { - if letter as u32 > 'a' as u32 { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { let ix = (letter as u32) - ('a' as u32); if ix < 26 { result[ix as usize] += 1.0; From d1bdfa0be6e2a638b5f8dd6e836fd2aa06c3264f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 27 Jun 2023 15:53:07 -0400 Subject: [PATCH 11/51] Added a dummy action for testing the semantic search functionality in the command palette. 
Co-authored-by: maxbrunsfeld --- crates/vector_store/src/vector_store.rs | 27 +++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 7e4c29cef6..4860bcd2bb 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -9,7 +9,7 @@ mod vector_store_tests; use anyhow::{anyhow, Result}; use db::{FileSha1, VectorDatabase, VECTOR_DB_URL}; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; -use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task}; +use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; use language::{Language, LanguageRegistry}; use parsing::Document; use project::{Fs, Project}; @@ -17,7 +17,9 @@ use smol::channel; use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc}; use tree_sitter::{Parser, QueryCursor}; use util::{http::HttpClient, ResultExt, TryFutureExt}; -use workspace::WorkspaceCreated; +use workspace::{Workspace, WorkspaceCreated}; + +actions!(semantic_search, [TestSearch]); pub fn init( fs: Arc, @@ -51,6 +53,26 @@ pub fn init( } }) .detach(); + + cx.add_action({ + let vector_store = vector_store.clone(); + move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext| { + let t0 = std::time::Instant::now(); + let task = vector_store.update(cx, |store, cx| { + store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx) + }); + + cx.spawn(|this, cx| async move { + let results = task.await?; + let duration = t0.elapsed(); + + println!("search took {:?}", duration); + println!("results {:?}", results); + + anyhow::Ok(()) + }).detach() + } + }); } #[derive(Debug)] @@ -67,6 +89,7 @@ struct VectorStore { language_registry: Arc, } +#[derive(Debug)] pub struct SearchResult { pub name: String, pub offset: usize, From 9d19dea7dd858bb49fbbc34ed8eb56b0146d8ed3 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 28 Jun 2023 08:58:50 -0400 Subject: [PATCH 12/51] updated vector_store to remove parsing module --- crates/vector_store/src/vector_store.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 4860bcd2bb..d7fd59466f 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,6 +1,5 @@ mod db; mod embedding; -mod parsing; mod search; #[cfg(test)] @@ -11,7 +10,6 @@ use db::{FileSha1, VectorDatabase, VECTOR_DB_URL}; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; use language::{Language, LanguageRegistry}; -use parsing::Document; use project::{Fs, Project}; use smol::channel; use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc}; @@ -21,6 +19,13 @@ use workspace::{Workspace, WorkspaceCreated}; actions!(semantic_search, [TestSearch]); +#[derive(Debug)] +pub struct Document { + pub offset: usize, + pub name: String, + pub embedding: Vec, +} + pub fn init( fs: Arc, http_client: Arc, From 40ff7779bbf858fe4786602ad1edef90ad69ca51 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 28 Jun 2023 13:27:26 -0400 Subject: [PATCH 13/51] WIP: Working modal, without navigation and search on every keystroke --- Cargo.lock | 2 + crates/vector_store/Cargo.toml | 2 + crates/vector_store/src/modal.rs | 107 ++++++++++++++++++++++++ 
crates/vector_store/src/vector_store.rs | 54 +++++++----- 4 files changed, 146 insertions(+), 19 deletions(-) create mode 100644 crates/vector_store/src/modal.rs diff --git a/Cargo.lock b/Cargo.lock index 1ea1d1a1b4..2eff8630cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7960,6 +7960,7 @@ dependencies = [ "log", "matrixmultiply", "ndarray", + "picker", "project", "rand 0.8.5", "rusqlite", @@ -7968,6 +7969,7 @@ dependencies = [ "sha-1 0.10.1", "smol", "tempdir", + "theme", "tree-sitter", "tree-sitter-rust", "unindent", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index edc06bb295..ddfef6927b 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -14,6 +14,8 @@ language = { path = "../language" } project = { path = "../project" } workspace = { path = "../workspace" } util = { path = "../util" } +picker = { path = "../picker" } +theme = { path = "../theme" } anyhow.workspace = true futures.workspace = true smol.workspace = true diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs new file mode 100644 index 0000000000..48429150cd --- /dev/null +++ b/crates/vector_store/src/modal.rs @@ -0,0 +1,107 @@ +use std::sync::Arc; + +use gpui::{ + actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext, + WeakViewHandle, +}; +use picker::{Picker, PickerDelegate, PickerEvent}; +use project::Project; +use util::ResultExt; +use workspace::Workspace; + +use crate::{SearchResult, VectorStore}; + +actions!(semantic_search, [Toggle]); + +pub type SemanticSearch = Picker; + +pub struct SemanticSearchDelegate { + workspace: WeakViewHandle, + project: ModelHandle, + vector_store: ModelHandle, + selected_match_index: usize, + matches: Vec, +} + +impl SemanticSearchDelegate { + // This is currently searching on every keystroke, + // This is wildly overkill, and has the potential to get expensive + // We will need to update this to throttle searching + pub fn new( + workspace: WeakViewHandle, + project: ModelHandle, + vector_store: ModelHandle, + ) -> Self { + Self { + workspace, + project, + vector_store, + selected_match_index: 0, + matches: vec![], + } + } +} + +impl PickerDelegate for SemanticSearchDelegate { + fn placeholder_text(&self) -> Arc { + "Search repository in natural language...".into() + } + + fn confirm(&mut self, cx: &mut ViewContext) { + todo!() + } + + fn dismissed(&mut self, _cx: &mut ViewContext) {} + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_match_index + } + + fn set_selected_index(&mut self, ix: usize, _cx: &mut ViewContext) { + self.selected_match_index = ix; + } + + fn update_matches(&mut self, query: String, cx: &mut ViewContext) -> Task<()> { + let task = self + .vector_store + .update(cx, |store, cx| store.search(query.to_string(), 10, cx)); + + cx.spawn(|this, mut cx| async move { + let results = task.await.log_err(); + this.update(&mut cx, |this, cx| { + if let Some(results) = results { + let delegate = this.delegate_mut(); + delegate.matches = results; + } + }); + }) + } + + fn render_match( + &self, + ix: usize, + mouse_state: &mut MouseState, + selected: bool, + cx: &AppContext, + ) -> AnyElement> { + let theme = theme::current(cx); + let style = &theme.picker.item; + let current_style = style.style_for(mouse_state, selected); + + let search_result = &self.matches[ix]; + + let mut path = search_result.file_path.to_string_lossy(); + let name = search_result.name.clone(); + + Flex::column() 
+ .with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false)) + .with_child(Label::new(path.to_string(), style.default.label.clone())) + .contained() + .with_style(current_style.container) + .into_any() + } +} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index d7fd59466f..2dc479045f 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,5 +1,6 @@ mod db; mod embedding; +mod modal; mod search; #[cfg(test)] @@ -10,6 +11,7 @@ use db::{FileSha1, VectorDatabase, VECTOR_DB_URL}; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; use language::{Language, LanguageRegistry}; +use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; use project::{Fs, Project}; use smol::channel; use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc}; @@ -17,8 +19,6 @@ use tree_sitter::{Parser, QueryCursor}; use util::{http::HttpClient, ResultExt, TryFutureExt}; use workspace::{Workspace, WorkspaceCreated}; -actions!(semantic_search, [TestSearch]); - #[derive(Debug)] pub struct Document { pub offset: usize, @@ -60,24 +60,40 @@ pub fn init( .detach(); cx.add_action({ - let vector_store = vector_store.clone(); - move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext| { - let t0 = std::time::Instant::now(); - let task = vector_store.update(cx, |store, cx| { - store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx) - }); - - cx.spawn(|this, cx| async move { - let results = task.await?; - let duration = t0.elapsed(); - - println!("search took {:?}", duration); - println!("results {:?}", results); - - anyhow::Ok(()) - }).detach() + move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext| { + let vector_store = vector_store.clone(); + workspace.toggle_modal(cx, |workspace, cx| { + let project = workspace.project().clone(); + let workspace = cx.weak_handle(); + cx.add_view(|cx| { + SemanticSearch::new( + SemanticSearchDelegate::new(workspace, project, vector_store), + cx, + ) + }) + }) } }); + SemanticSearch::init(cx); + // cx.add_action({ + // let vector_store = vector_store.clone(); + // move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext| { + // let t0 = std::time::Instant::now(); + // let task = vector_store.update(cx, |store, cx| { + // store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx) + // }); + + // cx.spawn(|this, cx| async move { + // let results = task.await?; + // let duration = t0.elapsed(); + + // println!("search took {:?}", duration); + // println!("results {:?}", results); + + // anyhow::Ok(()) + // }).detach() + // } + // }); } #[derive(Debug)] @@ -87,7 +103,7 @@ pub struct IndexedFile { documents: Vec, } -struct VectorStore { +pub struct VectorStore { fs: Arc, database_url: Arc, embedding_provider: Arc, From 400d39740ca505c3b5f143818c0ebe8eeead0e6e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 28 Jun 2023 16:21:03 -0400 Subject: [PATCH 14/51] updated both indexing and search method for vector store, to maintain both zed worktree ids and db worktree ids Co-authored-by: maxbrunsfeld --- crates/vector_store/src/db.rs | 67 ++++--- crates/vector_store/src/modal.rs | 17 +- crates/vector_store/src/vector_store.rs | 174 ++++++++++-------- crates/vector_store/src/vector_store_tests.rs | 10 +- 4 
files changed, 159 insertions(+), 109 deletions(-) diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index f074a7066b..96856936fc 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, path::{Path, PathBuf}, + rc::Rc, }; use anyhow::{anyhow, Result}; @@ -258,22 +259,34 @@ impl VectorDatabase { pub fn for_each_document( &self, - worktree_id: i64, + worktree_ids: &[i64], mut f: impl FnMut(i64, Embedding), ) -> Result<()> { - let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?; + let mut query_statement = self.db.prepare( + " + SELECT + documents.id, documents.embedding + FROM + documents, files + WHERE + documents.file_id = files.id AND + files.worktree_id IN rarray(?) + ", + )?; query_statement - .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))? + .query_map(params![ids_to_sql(worktree_ids)], |row| { + Ok((row.get(0)?, row.get(1)?)) + })? .filter_map(|row| row.ok()) .for_each(|row| f(row.0, row.1)); Ok(()) } - pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result> { + pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result> { let mut statement = self.db.prepare( " SELECT - documents.id, files.relative_path, documents.offset, documents.name + documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name FROM documents, files WHERE @@ -282,35 +295,28 @@ impl VectorDatabase { ", )?; - let result_iter = statement.query_map( - params![std::rc::Rc::new( - ids.iter() - .copied() - .map(|v| rusqlite::types::Value::from(v)) - .collect::>() - )], - |row| { - Ok(( - row.get::<_, i64>(0)?, - row.get::<_, String>(1)?.into(), - row.get(2)?, - row.get(3)?, - )) - }, - )?; + let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| { + Ok(( + row.get::<_, i64>(0)?, + row.get::<_, i64>(1)?, + row.get::<_, String>(2)?.into(), + row.get(3)?, + row.get(4)?, + )) + })?; - let mut values_by_id = HashMap::::default(); + let mut values_by_id = HashMap::::default(); for row in result_iter { - let (id, path, offset, name) = row?; - values_by_id.insert(id, (path, offset, name)); + let (id, worktree_id, path, offset, name) = row?; + values_by_id.insert(id, (worktree_id, path, offset, name)); } let mut results = Vec::with_capacity(ids.len()); for id in ids { - let (path, offset, name) = values_by_id + let value = values_by_id .remove(id) .ok_or(anyhow!("missing document id {}", id))?; - results.push((path, offset, name)); + results.push(value); } Ok(results) @@ -339,3 +345,12 @@ impl VectorDatabase { return Ok(documents); } } + +fn ids_to_sql(ids: &[i64]) -> Rc> { + Rc::new( + ids.iter() + .copied() + .map(|v| rusqlite::types::Value::from(v)) + .collect::>(), + ) +} diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs index 48429150cd..8052277a0b 100644 --- a/crates/vector_store/src/modal.rs +++ b/crates/vector_store/src/modal.rs @@ -48,7 +48,9 @@ impl PickerDelegate for SemanticSearchDelegate { } fn confirm(&mut self, cx: &mut ViewContext) { - todo!() + if let Some(search_result) = self.matches.get(self.selected_match_index) { + // search_result.file_path + } } fn dismissed(&mut self, _cx: &mut ViewContext) {} @@ -66,9 +68,9 @@ impl PickerDelegate for SemanticSearchDelegate { } fn update_matches(&mut self, query: String, cx: &mut ViewContext) -> Task<()> { - let task = self - .vector_store - .update(cx, |store, cx| store.search(query.to_string(), 10, cx)); + let task = self.vector_store.update(cx, |store, cx| { 
+ store.search(&self.project, query.to_string(), 10, cx) + }); cx.spawn(|this, mut cx| async move { let results = task.await.log_err(); @@ -90,7 +92,7 @@ impl PickerDelegate for SemanticSearchDelegate { ) -> AnyElement> { let theme = theme::current(cx); let style = &theme.picker.item; - let current_style = style.style_for(mouse_state, selected); + let current_style = style.in_state(selected).style_for(mouse_state); let search_result = &self.matches[ix]; @@ -99,7 +101,10 @@ impl PickerDelegate for SemanticSearchDelegate { Flex::column() .with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false)) - .with_child(Label::new(path.to_string(), style.default.label.clone())) + .with_child(Label::new( + path.to_string(), + style.inactive_state().default.label.clone(), + )) .contained() .with_style(current_style.container) .into_any() diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 2dc479045f..92926b1f75 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -8,11 +8,11 @@ mod vector_store_tests; use anyhow::{anyhow, Result}; use db::{FileSha1, VectorDatabase, VECTOR_DB_URL}; -use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings}; use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; use language::{Language, LanguageRegistry}; use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; -use project::{Fs, Project}; +use project::{Fs, Project, WorktreeId}; use smol::channel; use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc}; use tree_sitter::{Parser, QueryCursor}; @@ -36,9 +36,10 @@ pub fn init( VectorStore::new( fs, VECTOR_DB_URL.to_string(), - Arc::new(OpenAIEmbeddings { - client: http_client, - }), + // Arc::new(OpenAIEmbeddings { + // client: http_client, + // }), + Arc::new(DummyEmbeddings {}), language_registry, ) }); @@ -75,25 +76,6 @@ pub fn init( } }); SemanticSearch::init(cx); - // cx.add_action({ - // let vector_store = vector_store.clone(); - // move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext| { - // let t0 = std::time::Instant::now(); - // let task = vector_store.update(cx, |store, cx| { - // store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx) - // }); - - // cx.spawn(|this, cx| async move { - // let results = task.await?; - // let duration = t0.elapsed(); - - // println!("search took {:?}", duration); - // println!("results {:?}", results); - - // anyhow::Ok(()) - // }).detach() - // } - // }); } #[derive(Debug)] @@ -108,10 +90,12 @@ pub struct VectorStore { database_url: Arc, embedding_provider: Arc, language_registry: Arc, + worktree_db_ids: Vec<(WorktreeId, i64)>, } #[derive(Debug)] pub struct SearchResult { + pub worktree_id: WorktreeId, pub name: String, pub offset: usize, pub file_path: PathBuf, @@ -129,6 +113,7 @@ impl VectorStore { database_url: database_url.into(), embedding_provider, language_registry, + worktree_db_ids: Vec::new(), } } @@ -178,9 +163,11 @@ impl VectorStore { } } - let embeddings = embedding_provider.embed_batch(context_spans).await?; - for (document, embedding) in documents.iter_mut().zip(embeddings) { - document.embedding = embedding; + if !documents.is_empty() { + let embeddings = embedding_provider.embed_batch(context_spans).await?; + for (document, embedding) in documents.iter_mut().zip(embeddings) { + 
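+                // Only hit the embedding provider when the file actually
+                // produced documents; each returned vector is zipped back
+                // onto its document in parse order.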
document.embedding = embedding; + } } let sha1 = FileSha1::from_str(content); @@ -214,7 +201,7 @@ impl VectorStore { let embedding_provider = self.embedding_provider.clone(); let database_url = self.database_url.clone(); - cx.spawn(|_, cx| async move { + cx.spawn(|this, mut cx| async move { futures::future::join_all(worktree_scans_complete).await; // TODO: remove this after fixing the bug in scan_complete @@ -231,25 +218,24 @@ impl VectorStore { .collect::>() }); - let worktree_root_paths = worktrees - .iter() - .map(|worktree| worktree.abs_path().clone()) - .collect::>(); - // Here we query the worktree ids, and yet we dont have them elsewhere // We likely want to clean up these datastructures - let (db, worktree_hashes, worktree_ids) = cx + let (db, worktree_hashes, worktree_db_ids) = cx .background() - .spawn(async move { - let mut worktree_ids: HashMap = HashMap::new(); - let mut hashes: HashMap> = HashMap::new(); - for worktree_root_path in worktree_root_paths { - let worktree_id = - db.find_or_create_worktree(worktree_root_path.as_ref())?; - worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id); - hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?); + .spawn({ + let worktrees = worktrees.clone(); + async move { + let mut worktree_db_ids: HashMap = HashMap::new(); + let mut hashes: HashMap> = + HashMap::new(); + for worktree in worktrees { + let worktree_db_id = + db.find_or_create_worktree(worktree.abs_path().as_ref())?; + worktree_db_ids.insert(worktree.id(), worktree_db_id); + hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?); + } + anyhow::Ok((db, hashes, worktree_db_ids)) } - anyhow::Ok((db, hashes, worktree_ids)) }) .await?; @@ -259,10 +245,10 @@ impl VectorStore { cx.background() .spawn({ let fs = fs.clone(); + let worktree_db_ids = worktree_db_ids.clone(); async move { for worktree in worktrees.into_iter() { - let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()]; - let file_hashes = &worktree_hashes[&worktree_id]; + let file_hashes = &worktree_hashes[&worktree.id()]; for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); @@ -291,7 +277,7 @@ impl VectorStore { ); paths_tx .try_send(( - worktree_id, + worktree_db_ids[&worktree.id()], path_buf, content, language, @@ -382,54 +368,92 @@ impl VectorStore { drop(indexed_files_tx); db_write_task.await; + + this.update(&mut cx, |this, _| { + this.worktree_db_ids.extend(worktree_db_ids); + }); + anyhow::Ok(()) }) } pub fn search( &mut self, + project: &ModelHandle, phrase: String, limit: usize, cx: &mut ModelContext, ) -> Task>> { + let project = project.read(cx); + let worktree_db_ids = project + .worktrees(cx) + .filter_map(|worktree| { + let worktree_id = worktree.read(cx).id(); + self.worktree_db_ids.iter().find_map(|(id, db_id)| { + if *id == worktree_id { + Some(*db_id) + } else { + None + } + }) + }) + .collect::>(); + let embedding_provider = self.embedding_provider.clone(); let database_url = self.database_url.clone(); - cx.background().spawn(async move { - let database = VectorDatabase::new(database_url.as_ref())?; + cx.spawn(|this, cx| async move { + let documents = cx + .background() + .spawn(async move { + let database = VectorDatabase::new(database_url.as_ref())?; - let phrase_embedding = embedding_provider - .embed_batch(vec![&phrase]) - .await? - .into_iter() - .next() - .unwrap(); + let phrase_embedding = embedding_provider + .embed_batch(vec![&phrase]) + .await? 
+ .into_iter() + .next() + .unwrap(); - let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - database.for_each_document(0, |id, embedding| { - let similarity = dot(&embedding.0, &phrase_embedding); - let ix = match results.binary_search_by(|(_, s)| { - similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) - }) { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - })?; + let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); + database.for_each_document(&worktree_db_ids, |id, embedding| { + let similarity = dot(&embedding.0, &phrase_embedding); + let ix = match results.binary_search_by(|(_, s)| { + similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + })?; - let ids = results.into_iter().map(|(id, _)| id).collect::>(); - let documents = database.get_documents_by_ids(&ids)?; + let ids = results.into_iter().map(|(id, _)| id).collect::>(); + database.get_documents_by_ids(&ids) + }) + .await?; - anyhow::Ok( + let results = this.read_with(&cx, |this, _| { documents .into_iter() - .map(|(file_path, offset, name)| SearchResult { - name, - offset, - file_path, + .filter_map(|(worktree_db_id, file_path, offset, name)| { + let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| { + if *db_id == worktree_db_id { + Some(*id) + } else { + None + } + })?; + Some(SearchResult { + worktree_id, + name, + offset, + file_path, + }) }) - .collect(), - ) + .collect() + }); + + anyhow::Ok(results) }) } } diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index c67bb9954f..6f8856c4fb 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -70,7 +70,10 @@ async fn test_vector_store(cx: &mut TestAppContext) { }); let project = Project::test(fs, ["/the-root".as_ref()], cx).await; - let add_project = store.update(cx, |store, cx| store.add_project(project, cx)); + let worktree_id = project.read_with(cx, |project, cx| { + project.worktrees(cx).next().unwrap().read(cx).id() + }); + let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx)); // TODO - remove cx.foreground() @@ -79,12 +82,15 @@ async fn test_vector_store(cx: &mut TestAppContext) { add_project.await.unwrap(); let search_results = store - .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx)) + .update(cx, |store, cx| { + store.search(&project, "aaaa".to_string(), 5, cx) + }) .await .unwrap(); assert_eq!(search_results[0].offset, 0); assert_eq!(search_results[0].name, "aaa"); + assert_eq!(search_results[0].worktree_id, worktree_id); } #[test] From 85e71415fea6102001c324b08c8558abea9b07f7 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 28 Jun 2023 16:25:05 -0400 Subject: [PATCH 15/51] updated embedding database calls to maintain project consistency Co-authored-by: maxbrunsfeld --- crates/vector_store/src/db.rs | 44 ----------------- crates/vector_store/src/search.rs | 66 ------------------------- crates/vector_store/src/vector_store.rs | 1 - 3 files changed, 111 deletions(-) delete mode 100644 crates/vector_store/src/search.rs diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index 96856936fc..f1453141bb 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -236,27 +236,6 @@ impl VectorDatabase { Ok(result) } - pub fn get_files(&self) -> Result> { - let 
mut query_statement = self - .db - .prepare("SELECT id, relative_path, sha1 FROM files")?; - let result_iter = query_statement.query_map([], |row| { - Ok(FileRecord { - id: row.get(0)?, - relative_path: row.get(1)?, - sha1: row.get(2)?, - }) - })?; - - let mut pages: HashMap = HashMap::new(); - for result in result_iter { - let result = result?; - pages.insert(result.id, result); - } - - Ok(pages) - } - pub fn for_each_document( &self, worktree_ids: &[i64], @@ -321,29 +300,6 @@ impl VectorDatabase { Ok(results) } - - pub fn get_documents(&self) -> Result> { - let mut query_statement = self - .db - .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?; - let result_iter = query_statement.query_map([], |row| { - Ok(DocumentRecord { - id: row.get(0)?, - file_id: row.get(1)?, - offset: row.get(2)?, - name: row.get(3)?, - embedding: row.get(4)?, - }) - })?; - - let mut documents: HashMap = HashMap::new(); - for result in result_iter { - let result = result?; - documents.insert(result.id, result); - } - - return Ok(documents); - } } fn ids_to_sql(ids: &[i64]) -> Rc> { diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs deleted file mode 100644 index 90a8d874da..0000000000 --- a/crates/vector_store/src/search.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::{cmp::Ordering, path::PathBuf}; - -use async_trait::async_trait; -use ndarray::{Array1, Array2}; - -use crate::db::{DocumentRecord, VectorDatabase}; -use anyhow::Result; - -#[async_trait] -pub trait VectorSearch { - // Given a query vector, and a limit to return - // Return a vector of id, distance tuples. - async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)>; -} - -pub struct BruteForceSearch { - document_ids: Vec, - candidate_array: ndarray::Array2, -} - -impl BruteForceSearch { - pub fn load(db: &VectorDatabase) -> Result { - let documents = db.get_documents()?; - let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect(); - let mut document_ids = vec![]; - for i in documents.keys() { - document_ids.push(i.to_owned()); - } - - let mut candidate_array = Array2::::default((documents.len(), 1536)); - for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() { - for (j, col) in row.iter_mut().enumerate() { - *col = embeddings[i].embedding.0[j]; - } - } - - return Ok(BruteForceSearch { - document_ids, - candidate_array, - }); - } -} - -#[async_trait] -impl VectorSearch for BruteForceSearch { - async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)> { - let target = Array1::from_vec(vec.to_owned()); - - let similarities = self.candidate_array.dot(&target); - - let similarities = similarities.to_vec(); - - // construct a tuple vector from the floats, the tuple being (index,float) - let mut with_indices = similarities - .iter() - .copied() - .enumerate() - .map(|(index, value)| (self.document_ids[index], value)) - .collect::>(); - - // sort the tuple vector by float - with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); - with_indices.truncate(limit); - with_indices - } -} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 92926b1f75..a66c2d65ba 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,7 +1,6 @@ mod db; mod embedding; mod modal; -mod search; #[cfg(test)] mod vector_store_tests; From fd68a2afaec50423b714b615157857278b038321 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: 
Wed, 28 Jun 2023 15:02:20 -0700 Subject: [PATCH 16/51] Debounce searches in semantic search modal --- crates/vector_store/src/modal.rs | 31 ++++++++++++++++--------- crates/vector_store/src/vector_store.rs | 11 +++++---- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs index 8052277a0b..1ca59c5585 100644 --- a/crates/vector_store/src/modal.rs +++ b/crates/vector_store/src/modal.rs @@ -1,15 +1,16 @@ -use std::sync::Arc; - +use crate::{SearchResult, VectorStore}; use gpui::{ actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext, WeakViewHandle, }; use picker::{Picker, PickerDelegate, PickerEvent}; use project::Project; +use std::{sync::Arc, time::Duration}; use util::ResultExt; use workspace::Workspace; -use crate::{SearchResult, VectorStore}; +const MIN_QUERY_LEN: usize = 5; +const EMBEDDING_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(500); actions!(semantic_search, [Toggle]); @@ -68,18 +69,26 @@ impl PickerDelegate for SemanticSearchDelegate { } fn update_matches(&mut self, query: String, cx: &mut ViewContext) -> Task<()> { - let task = self.vector_store.update(cx, |store, cx| { - store.search(&self.project, query.to_string(), 10, cx) - }); + if query.len() < MIN_QUERY_LEN { + return Task::ready(()); + } + let vector_store = self.vector_store.clone(); + let project = self.project.clone(); cx.spawn(|this, mut cx| async move { - let results = task.await.log_err(); - this.update(&mut cx, |this, cx| { - if let Some(results) = results { + cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await; + + let task = vector_store.update(&mut cx, |store, cx| { + store.search(&project, query.to_string(), 10, cx) + }); + + if let Some(results) = task.await.log_err() { + this.update(&mut cx, |this, _| { let delegate = this.delegate_mut(); delegate.matches = results; - } - }); + }) + .ok(); + } }) } diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index a66c2d65ba..c37a50e4de 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -31,14 +31,14 @@ pub fn init( language_registry: Arc, cx: &mut AppContext, ) { - let vector_store = cx.add_model(|cx| { + let vector_store = cx.add_model(|_| { VectorStore::new( fs, VECTOR_DB_URL.to_string(), - // Arc::new(OpenAIEmbeddings { - // client: http_client, - // }), - Arc::new(DummyEmbeddings {}), + // Arc::new(DummyEmbeddings {}), + Arc::new(OpenAIEmbeddings { + client: http_client, + }), language_registry, ) }); @@ -74,6 +74,7 @@ pub fn init( }) } }); + SemanticSearch::init(cx); } From a08d60fc61a307bc838b7f53930bc8be2c6bcb37 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 29 Jun 2023 11:58:47 -0400 Subject: [PATCH 17/51] added navigation on confirm to semantic search modal --- Cargo.lock | 1 + crates/vector_store/Cargo.toml | 1 + crates/vector_store/src/modal.rs | 36 ++++++++++++++++++++++--- crates/vector_store/src/vector_store.rs | 33 ----------------------- 4 files changed, 35 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fb8c719278..dbdf7f5774 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8041,6 +8041,7 @@ dependencies = [ "anyhow", "async-trait", "bincode", + "editor", "futures 0.3.28", "gpui", "isahc", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index ddfef6927b..4ecd46cb92 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -16,6 +16,7 @@ workspace = { 
path = "../workspace" } util = { path = "../util" } picker = { path = "../picker" } theme = { path = "../theme" } +editor = { path = "../editor" } anyhow.workspace = true futures.workspace = true smol.workspace = true diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs index 1ca59c5585..e857aa2ab2 100644 --- a/crates/vector_store/src/modal.rs +++ b/crates/vector_store/src/modal.rs @@ -1,11 +1,12 @@ use crate::{SearchResult, VectorStore}; +use editor::{scroll::autoscroll::Autoscroll, Editor}; use gpui::{ actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext, WeakViewHandle, }; use picker::{Picker, PickerDelegate, PickerEvent}; -use project::Project; -use std::{sync::Arc, time::Duration}; +use project::{Project, ProjectPath}; +use std::{path::Path, sync::Arc, time::Duration}; use util::ResultExt; use workspace::Workspace; @@ -50,7 +51,34 @@ impl PickerDelegate for SemanticSearchDelegate { fn confirm(&mut self, cx: &mut ViewContext) { if let Some(search_result) = self.matches.get(self.selected_match_index) { - // search_result.file_path + // Open Buffer + let search_result = search_result.clone(); + let buffer = self.project.update(cx, |project, cx| { + project.open_buffer( + ProjectPath { + worktree_id: search_result.worktree_id, + path: search_result.file_path.clone().into(), + }, + cx, + ) + }); + + let workspace = self.workspace.clone(); + let position = search_result.clone().offset; + cx.spawn(|_, mut cx| async move { + let buffer = buffer.await?; + workspace.update(&mut cx, |workspace, cx| { + let editor = workspace.open_project_item::(buffer, cx); + editor.update(cx, |editor, cx| { + editor.change_selections(Some(Autoscroll::center()), cx, |s| { + s.select_ranges([position..position]) + }); + }); + })?; + Ok::<_, anyhow::Error>(()) + }) + .detach_and_log_err(cx); + cx.emit(PickerEvent::Dismiss); } } @@ -78,6 +106,8 @@ impl PickerDelegate for SemanticSearchDelegate { cx.spawn(|this, mut cx| async move { cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await; + log::info!("Searching for {:?}", &query); + let task = vector_store.update(&mut cx, |store, cx| { store.search(&project, query.to_string(), 10, cx) }); diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index c37a50e4de..641fdd86f2 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -294,43 +294,10 @@ impl VectorStore { let db_write_task = cx.background().spawn( async move { - // Initialize Database, creates database and tables if not exists while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await { db.insert_file(worktree_id, indexed_file).log_err(); } - // ALL OF THE BELOW IS FOR TESTING, - // This should be removed as we find and appropriate place for evaluate our search. 
- - // let queries = vec![ - // "compute embeddings for all of the symbols in the codebase, and write them to a database", - // "compute an outline view of all of the symbols in a buffer", - // "scan a directory on the file system and load all of its children into an in-memory snapshot", - // ]; - // let embeddings = embedding_provider.embed_batch(queries.clone()).await?; - - // let t2 = Instant::now(); - // let documents = db.get_documents().unwrap(); - // let files = db.get_files().unwrap(); - // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis()); - - // let t1 = Instant::now(); - // let mut bfs = BruteForceSearch::load(&db).unwrap(); - // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis()); - // for (idx, embed) in embeddings.into_iter().enumerate() { - // let t0 = Instant::now(); - // println!("\nQuery: {:?}", queries[idx]); - // let results = bfs.top_k_search(&embed, 5).await; - // println!("Search Elapsed: {}", t0.elapsed().as_millis()); - // for (id, distance) in results { - // println!(""); - // println!(" distance: {:?}", distance); - // println!(" document: {:?}", documents[&id].name); - // println!(" path: {:?}", files[&documents[&id].file_id].relative_path); - // } - - // } - anyhow::Ok(()) } .log_err(), From 0a7245a583667789dde8d03f9be07117f73e1e31 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 29 Jun 2023 13:50:49 -0400 Subject: [PATCH 18/51] updated semantic search modal to manage for duplicate queries --- crates/vector_store/src/modal.rs | 45 ++++++++++++++++++------- crates/vector_store/src/vector_store.rs | 6 ++-- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs index e857aa2ab2..4d377c6819 100644 --- a/crates/vector_store/src/modal.rs +++ b/crates/vector_store/src/modal.rs @@ -6,7 +6,7 @@ use gpui::{ }; use picker::{Picker, PickerDelegate, PickerEvent}; use project::{Project, ProjectPath}; -use std::{path::Path, sync::Arc, time::Duration}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use util::ResultExt; use workspace::Workspace; @@ -23,6 +23,7 @@ pub struct SemanticSearchDelegate { vector_store: ModelHandle, selected_match_index: usize, matches: Vec, + history: HashMap>, } impl SemanticSearchDelegate { @@ -40,6 +41,7 @@ impl SemanticSearchDelegate { vector_store, selected_match_index: 0, matches: vec![], + history: HashMap::new(), } } } @@ -97,7 +99,9 @@ impl PickerDelegate for SemanticSearchDelegate { } fn update_matches(&mut self, query: String, cx: &mut ViewContext) -> Task<()> { + log::info!("Searching for {:?}...", query); if query.len() < MIN_QUERY_LEN { + log::info!("Query below minimum length"); return Task::ready(()); } @@ -106,18 +110,35 @@ impl PickerDelegate for SemanticSearchDelegate { cx.spawn(|this, mut cx| async move { cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await; - log::info!("Searching for {:?}", &query); - - let task = vector_store.update(&mut cx, |store, cx| { - store.search(&project, query.to_string(), 10, cx) + let retrieved_cached = this.update(&mut cx, |this, _| { + let delegate = this.delegate_mut(); + if delegate.history.contains_key(&query) { + let historic_results = delegate.history.get(&query).unwrap().to_owned(); + delegate.matches = historic_results.clone(); + true + } else { + false + } }); - if let Some(results) = task.await.log_err() { - this.update(&mut cx, |this, _| { - let delegate = this.delegate_mut(); - delegate.matches = results; - }) - .ok(); + if let Some(retrieved) = 
retrieved_cached.log_err() { + if !retrieved { + let task = vector_store.update(&mut cx, |store, cx| { + store.search(&project, query.to_string(), 10, cx) + }); + + if let Some(results) = task.await.log_err() { + log::info!("Not queried previously, searching..."); + this.update(&mut cx, |this, _| { + let delegate = this.delegate_mut(); + delegate.matches = results.clone(); + delegate.history.insert(query, results); + }) + .ok(); + } + } else { + log::info!("Already queried, retrieved directly from cached history"); + } } }) } @@ -135,7 +156,7 @@ impl PickerDelegate for SemanticSearchDelegate { let search_result = &self.matches[ix]; - let mut path = search_result.file_path.to_string_lossy(); + let path = search_result.file_path.to_string_lossy(); let name = search_result.name.clone(); Flex::column() diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 641fdd86f2..b3894f3686 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -7,8 +7,8 @@ mod vector_store_tests; use anyhow::{anyhow, Result}; use db::{FileSha1, VectorDatabase, VECTOR_DB_URL}; -use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings}; -use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; +use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; use language::{Language, LanguageRegistry}; use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; use project::{Fs, Project, WorktreeId}; @@ -93,7 +93,7 @@ pub struct VectorStore { worktree_db_ids: Vec<(WorktreeId, i64)>, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SearchResult { pub worktree_id: WorktreeId, pub name: String, From 39137fc19f001a5ab3d24d54a2c7ebaa50ca4d06 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 29 Jun 2023 15:18:32 -0400 Subject: [PATCH 19/51] updated vector_store db to leverage EMBEDDINGS_DIR path --- crates/util/src/paths.rs | 1 + crates/vector_store/src/db.rs | 7 +--- crates/vector_store/src/vector_store.rs | 33 +++++++++++++------ crates/vector_store/src/vector_store_tests.rs | 2 +- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/crates/util/src/paths.rs b/crates/util/src/paths.rs index 7ef55a9918..5df0ed12e9 100644 --- a/crates/util/src/paths.rs +++ b/crates/util/src/paths.rs @@ -6,6 +6,7 @@ lazy_static::lazy_static! { pub static ref HOME: PathBuf = dirs::home_dir().expect("failed to determine home directory"); pub static ref CONFIG_DIR: PathBuf = HOME.join(".config").join("zed"); pub static ref CONVERSATIONS_DIR: PathBuf = HOME.join(".config/zed/conversations"); + pub static ref EMBEDDINGS_DIR: PathBuf = HOME.join(".config/zed/embeddings"); pub static ref LOGS_DIR: PathBuf = HOME.join("Library/Logs/Zed"); pub static ref SUPPORT_DIR: PathBuf = HOME.join("Library/Application Support/Zed"); pub static ref LANGUAGES_DIR: PathBuf = HOME.join("Library/Application Support/Zed/languages"); diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index f1453141bb..768df8069f 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -15,11 +15,6 @@ use sha1::{Digest, Sha1}; use crate::IndexedFile; -// This is saving to a local database store within the users dev zed path -// Where do we want this to sit? -// Assuming near where the workspace DB sits. 
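// (Removed below: the hard-coded database location. The path is now built at
// startup from EMBEDDINGS_DIR plus the release channel name; see the
// corresponding change in vector_store.rs.)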
-pub const VECTOR_DB_URL: &str = "embeddings_db"; - // Note this is not an appropriate document #[derive(Debug)] pub struct DocumentRecord { @@ -109,7 +104,7 @@ pub struct VectorDatabase { } impl VectorDatabase { - pub fn new(path: &str) -> Result { + pub fn new(path: String) -> Result { let this = Self { db: rusqlite::Connection::open(path)?, }; diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index b3894f3686..47d6932685 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -6,16 +6,23 @@ mod modal; mod vector_store_tests; use anyhow::{anyhow, Result}; -use db::{FileSha1, VectorDatabase, VECTOR_DB_URL}; +use db::{FileSha1, VectorDatabase}; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; use language::{Language, LanguageRegistry}; use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; use project::{Fs, Project, WorktreeId}; use smol::channel; -use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc}; +use std::{ + cmp::Ordering, + collections::HashMap, + path::{Path, PathBuf}, + sync::Arc, +}; use tree_sitter::{Parser, QueryCursor}; -use util::{http::HttpClient, ResultExt, TryFutureExt}; +use util::{ + channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt, +}; use workspace::{Workspace, WorkspaceCreated}; #[derive(Debug)] @@ -31,11 +38,14 @@ pub fn init( language_registry: Arc, cx: &mut AppContext, ) { + let db_file_path = EMBEDDINGS_DIR + .join(Path::new(RELEASE_CHANNEL_NAME.as_str())) + .join("embeddings_db"); + let vector_store = cx.add_model(|_| { VectorStore::new( fs, - VECTOR_DB_URL.to_string(), - // Arc::new(DummyEmbeddings {}), + db_file_path, Arc::new(OpenAIEmbeddings { client: http_client, }), @@ -87,7 +97,7 @@ pub struct IndexedFile { pub struct VectorStore { fs: Arc, - database_url: Arc, + database_url: Arc, embedding_provider: Arc, language_registry: Arc, worktree_db_ids: Vec<(WorktreeId, i64)>, @@ -104,13 +114,13 @@ pub struct SearchResult { impl VectorStore { fn new( fs: Arc, - database_url: String, + database_url: PathBuf, embedding_provider: Arc, language_registry: Arc, ) -> Self { Self { fs, - database_url: database_url.into(), + database_url: Arc::new(database_url), embedding_provider, language_registry, worktree_db_ids: Vec::new(), @@ -209,7 +219,10 @@ impl VectorStore { .timer(std::time::Duration::from_secs(3)) .await; - let db = VectorDatabase::new(&database_url)?; + if let Some(db_directory) = database_url.parent() { + fs.create_dir(db_directory).await.log_err(); + } + let db = VectorDatabase::new(database_url.to_string_lossy().into())?; let worktrees = project.read_with(&cx, |project, cx| { project @@ -372,7 +385,7 @@ impl VectorStore { let documents = cx .background() .spawn(async move { - let database = VectorDatabase::new(database_url.as_ref())?; + let database = VectorDatabase::new(database_url.to_string_lossy().into())?; let phrase_embedding = embedding_provider .embed_batch(vec![&phrase]) diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index 6f8856c4fb..e232ba9f21 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -63,7 +63,7 @@ async fn test_vector_store(cx: &mut TestAppContext) { let store = cx.add_model(|_| { VectorStore::new( fs.clone(), - db_path.to_string_lossy().to_string(), + db_path, 
Arc::new(FakeEmbeddingProvider), languages, ) From e3ab54942ee46b4395f20ccdaba31c838223b4cb Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 30 Jun 2023 10:17:31 -0400 Subject: [PATCH 20/51] removed sleep from directory scanning as fixes upstream appear to be scanning correctly --- crates/vector_store/src/vector_store.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 47d6932685..9e589e010f 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -214,11 +214,6 @@ impl VectorStore { cx.spawn(|this, mut cx| async move { futures::future::join_all(worktree_scans_complete).await; - // TODO: remove this after fixing the bug in scan_complete - cx.background() - .timer(std::time::Duration::from_secs(3)) - .await; - if let Some(db_directory) = database_url.parent() { fs.create_dir(db_directory).await.log_err(); } From 0db0876289c9bb96706fbd997c52df1a33191b13 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 30 Jun 2023 11:01:35 -0400 Subject: [PATCH 21/51] implemented file deletes on project indexing --- crates/vector_store/src/db.rs | 9 +++++- crates/vector_store/src/vector_store.rs | 39 ++++++++++++++++--------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index 768df8069f..fec2980550 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -154,9 +154,16 @@ impl VectorDatabase { Ok(()) } + pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> { + self.db.execute( + "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", + params![worktree_id, delete_path.to_str()], + )?; + Ok(()) + } + pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> { // Write to files table, and return generated id. 
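        // The DELETE just below followed by the INSERT acts as an upsert
        // keyed on (worktree_id, relative_path); stale document rows tied to
        // the old file id are presumably dropped via ON DELETE CASCADE.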
- log::info!("Inserting File!"); self.db.execute( " DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2; diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 9e589e010f..876a6018b8 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -15,7 +15,7 @@ use project::{Fs, Project, WorktreeId}; use smol::channel; use std::{ cmp::Ordering, - collections::HashMap, + collections::{HashMap, HashSet}, path::{Path, PathBuf}, sync::Arc, }; @@ -201,7 +201,6 @@ impl VectorStore { let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete(); async move { scan_complete.await; - log::info!("worktree scan completed"); } }) .collect::>(); @@ -249,6 +248,7 @@ impl VectorStore { let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String, Arc)>(); + let (delete_paths_tx, delete_paths_rx) = channel::unbounded::<(i64, PathBuf)>(); let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>(); cx.background() .spawn({ @@ -257,6 +257,8 @@ impl VectorStore { async move { for worktree in worktrees.into_iter() { let file_hashes = &worktree_hashes[&worktree.id()]; + let mut files_included = + file_hashes.keys().collect::>(); for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); @@ -269,20 +271,16 @@ impl VectorStore { } if let Some(content) = fs.load(&absolute_path).await.log_err() { - log::info!("loaded file: {absolute_path:?}"); - let path_buf = file.path.to_path_buf(); - let already_stored = file_hashes - .get(&path_buf) - .map_or(false, |existing_hash| { + let already_stored = file_hashes.get(&path_buf).map_or( + false, + |existing_hash| { + files_included.remove(&path_buf); existing_hash.equals(&content) - }); + }, + ); if !already_stored { - log::info!( - "File Changed (Sending to Parse): {:?}", - &path_buf - ); paths_tx .try_send(( worktree_db_ids[&worktree.id()], @@ -295,17 +293,30 @@ impl VectorStore { } } } + for file in files_included { + delete_paths_tx + .try_send((worktree_db_ids[&worktree.id()], file.to_owned())) + .unwrap(); + } } } }) .detach(); - let db_write_task = cx.background().spawn( + let db_update_task = cx.background().spawn( async move { + // Inserting all new files while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await { + log::info!("Inserting File: {:?}", &indexed_file.path); db.insert_file(worktree_id, indexed_file).log_err(); } + // Deleting all old files + while let Ok((worktree_id, delete_path)) = delete_paths_rx.recv().await { + log::info!("Deleting File: {:?}", &delete_path); + db.delete_file(worktree_id, delete_path).log_err(); + } + anyhow::Ok(()) } .log_err(), @@ -342,7 +353,7 @@ impl VectorStore { .await; drop(indexed_files_tx); - db_write_task.await; + db_update_task.await; this.update(&mut cx, |this, _| { this.worktree_db_ids.extend(worktree_db_ids); From 36907bb4dc604c2715242d7cedfc04cde7cf60ff Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 30 Jun 2023 16:14:11 -0400 Subject: [PATCH 22/51] updated vector store indexing to only use languages with an embedding.scm treesitter query Co-authored-by: maxbrunsfeld --- crates/language/src/language.rs | 44 +++++++++++++++++++ crates/vector_store/src/vector_store.rs | 22 +++++++--- crates/vector_store/src/vector_store_tests.rs | 2 +- crates/zed/src/languages.rs | 1 + crates/zed/src/languages/rust/embedding.scm | 36 +++++++++++++++ 5 files changed, 98 insertions(+), 7 deletions(-) create mode 100644 
crates/zed/src/languages/rust/embedding.scm diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index b880cbc8d7..4ef9d25894 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -350,6 +350,7 @@ pub struct LanguageQueries { pub brackets: Option>, pub indents: Option>, pub outline: Option>, + pub embedding: Option>, pub injections: Option>, pub overrides: Option>, } @@ -495,6 +496,7 @@ pub struct Grammar { pub(crate) brackets_config: Option, pub(crate) indents_config: Option, pub outline_config: Option, + pub embedding_config: Option, pub(crate) injection_config: Option, pub(crate) override_config: Option, pub(crate) highlight_map: Mutex, @@ -516,6 +518,15 @@ pub struct OutlineConfig { pub extra_context_capture_ix: Option, } +#[derive(Debug)] +pub struct EmbeddingConfig { + pub query: Query, + pub item_capture_ix: u32, + pub name_capture_ix: u32, + pub context_capture_ix: Option, + pub extra_context_capture_ix: Option, +} + struct InjectionConfig { query: Query, content_capture_ix: u32, @@ -1145,6 +1156,7 @@ impl Language { highlights_query: None, brackets_config: None, outline_config: None, + embedding_config: None, indents_config: None, injection_config: None, override_config: None, @@ -1181,6 +1193,9 @@ impl Language { if let Some(query) = queries.outline { self = self.with_outline_query(query.as_ref())?; } + if let Some(query) = queries.embedding { + self = self.with_embedding_query(query.as_ref())?; + } if let Some(query) = queries.injections { self = self.with_injection_query(query.as_ref())?; } @@ -1189,6 +1204,7 @@ impl Language { } Ok(self) } + pub fn with_highlights_query(mut self, source: &str) -> Result { let grammar = self.grammar_mut(); grammar.highlights_query = Some(Query::new(grammar.ts_language, source)?); @@ -1223,6 +1239,34 @@ impl Language { Ok(self) } + pub fn with_embedding_query(mut self, source: &str) -> Result { + let grammar = self.grammar_mut(); + let query = Query::new(grammar.ts_language, source)?; + let mut item_capture_ix = None; + let mut name_capture_ix = None; + let mut context_capture_ix = None; + let mut extra_context_capture_ix = None; + get_capture_indices( + &query, + &mut [ + ("item", &mut item_capture_ix), + ("name", &mut name_capture_ix), + ("context", &mut context_capture_ix), + ("context.extra", &mut extra_context_capture_ix), + ], + ); + if let Some((item_capture_ix, name_capture_ix)) = item_capture_ix.zip(name_capture_ix) { + grammar.embedding_config = Some(EmbeddingConfig { + query, + item_capture_ix, + name_capture_ix, + context_capture_ix, + extra_context_capture_ix, + }); + } + Ok(self) + } + pub fn with_brackets_query(mut self, source: &str) -> Result { let grammar = self.grammar_mut(); let query = Query::new(grammar.ts_language, source)?; diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 876a6018b8..35a467b82f 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -136,8 +136,8 @@ impl VectorStore { content: String, ) -> Result { let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; - let outline_config = grammar - .outline_config + let embedding_config = grammar + .embedding_config .as_ref() .ok_or_else(|| anyhow!("no outline query"))?; @@ -148,13 +148,17 @@ impl VectorStore { let mut documents = Vec::new(); let mut context_spans = Vec::new(); - for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) { + for mat in cursor.matches( + 
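+            // The dedicated embedding query (embedding.scm) replaces the
+            // outline query previously used to chunk files into documents.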
&embedding_config.query, + tree.root_node(), + content.as_bytes(), + ) { let mut item_range = None; let mut name_range = None; for capture in mat.captures { - if capture.index == outline_config.item_capture_ix { + if capture.index == embedding_config.item_capture_ix { item_range = Some(capture.node.byte_range()); - } else if capture.index == outline_config.name_capture_ix { + } else if capture.index == embedding_config.name_capture_ix { name_range = Some(capture.node.byte_range()); } } @@ -266,7 +270,11 @@ impl VectorStore { .language_for_file(&absolute_path, None) .await { - if language.name().as_ref() != "Rust" { + if language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { continue; } @@ -359,6 +367,8 @@ impl VectorStore { this.worktree_db_ids.extend(worktree_db_ids); }); + log::info!("Semantic Indexing Complete!"); + anyhow::Ok(()) }) } diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index e232ba9f21..78470ad4be 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -46,7 +46,7 @@ async fn test_vector_store(cx: &mut TestAppContext) { }, Some(tree_sitter_rust::language()), ) - .with_outline_query( + .with_embedding_query( r#" (function_item name: (identifier) @name diff --git a/crates/zed/src/languages.rs b/crates/zed/src/languages.rs index 44e144e89b..820f564151 100644 --- a/crates/zed/src/languages.rs +++ b/crates/zed/src/languages.rs @@ -170,6 +170,7 @@ fn load_queries(name: &str) -> LanguageQueries { brackets: load_query(name, "/brackets"), indents: load_query(name, "/indents"), outline: load_query(name, "/outline"), + embedding: load_query(name, "/embedding"), injections: load_query(name, "/injections"), overrides: load_query(name, "/overrides"), } diff --git a/crates/zed/src/languages/rust/embedding.scm b/crates/zed/src/languages/rust/embedding.scm new file mode 100644 index 0000000000..ea8bab9f68 --- /dev/null +++ b/crates/zed/src/languages/rust/embedding.scm @@ -0,0 +1,36 @@ +(struct_item + (visibility_modifier)? @context + "struct" @context + name: (_) @name) @item + +(enum_item + (visibility_modifier)? @context + "enum" @context + name: (_) @name) @item + +(impl_item + "impl" @context + trait: (_)? @name + "for"? @context + type: (_) @name) @item + +(trait_item + (visibility_modifier)? @context + "trait" @context + name: (_) @name) @item + +(function_item + (visibility_modifier)? @context + (function_modifiers)? @context + "fn" @context + name: (_) @name) @item + +(function_signature_item + (visibility_modifier)? @context + (function_modifiers)? @context + "fn" @context + name: (_) @name) @item + +(macro_definition + . "macro_rules!" 
@context + name: (_) @name) @item From 3408b98167481aa54c70839f6024bdc5cdfb2aec Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 30 Jun 2023 16:53:23 -0400 Subject: [PATCH 23/51] updated file compare in the semantic indexing engine, to work off of modified system times as opposed to file hashes Co-authored-by: maxbrunsfeld --- Cargo.lock | 25 +------ crates/vector_store/Cargo.toml | 4 +- crates/vector_store/src/db.rs | 99 +++++++++---------------- crates/vector_store/src/vector_store.rs | 76 +++++++++---------- 4 files changed, 74 insertions(+), 130 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 85599036a1..59cf30001e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4232,19 +4232,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "ndarray" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" -dependencies = [ - "matrixmultiply", - "num-complex", - "num-integer", - "num-traits", - "rawpointer", -] - [[package]] name = "net2" version = "0.2.38" @@ -4353,15 +4340,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "num-complex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" -dependencies = [ - "num-traits", -] - [[package]] name = "num-integer" version = "0.1.45" @@ -8050,14 +8028,13 @@ dependencies = [ "lazy_static", "log", "matrixmultiply", - "ndarray", "picker", "project", "rand 0.8.5", + "rpc", "rusqlite", "serde", "serde_json", - "sha-1 0.10.1", "smol", "tempdir", "theme", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 4ecd46cb92..d1ad8a0f9b 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -17,6 +17,7 @@ util = { path = "../util" } picker = { path = "../picker" } theme = { path = "../theme" } editor = { path = "../editor" } +rpc = { path = "../rpc" } anyhow.workspace = true futures.workspace = true smol.workspace = true @@ -29,14 +30,13 @@ serde.workspace = true serde_json.workspace = true async-trait.workspace = true bincode = "1.3.3" -ndarray = "0.15.6" -sha-1 = "0.10.1" matrixmultiply = "0.3.7" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } language = { path = "../language", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } tree-sitter-rust = "*" rand.workspace = true diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index fec2980550..f822cca77e 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -2,18 +2,17 @@ use std::{ collections::HashMap, path::{Path, PathBuf}, rc::Rc, + time::SystemTime, }; use anyhow::{anyhow, Result}; +use crate::IndexedFile; +use rpc::proto::Timestamp; use rusqlite::{ params, - types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, - ToSql, + types::{FromSql, FromSqlResult, ValueRef}, }; -use sha1::{Digest, Sha1}; - -use crate::IndexedFile; // Note this is not an appropriate document #[derive(Debug)] @@ -29,60 +28,7 @@ pub struct DocumentRecord { pub struct FileRecord { pub id: usize, pub relative_path: String, - pub sha1: FileSha1, -} - -#[derive(Debug)] -pub struct FileSha1(pub Vec); - -impl FileSha1 { - pub fn from_str(content: String) -> Self { - let mut hasher = Sha1::new(); - 
hasher.update(content); - let sha1 = hasher.finalize()[..] - .into_iter() - .map(|val| val.to_owned()) - .collect::>(); - return FileSha1(sha1); - } - - pub fn equals(&self, content: &String) -> bool { - let mut hasher = Sha1::new(); - hasher.update(content); - let sha1 = hasher.finalize()[..] - .into_iter() - .map(|val| val.to_owned()) - .collect::>(); - - let equal = self - .0 - .clone() - .into_iter() - .zip(sha1) - .filter(|&(a, b)| a == b) - .count() - == self.0.len(); - - equal - } -} - -impl ToSql for FileSha1 { - fn to_sql(&self) -> rusqlite::Result> { - return self.0.to_sql(); - } -} - -impl FromSql for FileSha1 { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - Ok(FileSha1( - bytes - .into_iter() - .map(|val| val.to_owned()) - .collect::>(), - )) - } + pub mtime: Timestamp, } #[derive(Debug)] @@ -133,7 +79,8 @@ impl VectorDatabase { id INTEGER PRIMARY KEY AUTOINCREMENT, worktree_id INTEGER NOT NULL, relative_path VARCHAR NOT NULL, - sha1 BLOB NOT NULL, + mtime_seconds INTEGER NOT NULL, + mtime_nanos INTEGER NOT NULL, FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE )", [], @@ -170,11 +117,20 @@ impl VectorDatabase { ", params![worktree_id, indexed_file.path.to_str()], )?; + let mtime = Timestamp::from(indexed_file.mtime); self.db.execute( " - INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3); + INSERT INTO files + (worktree_id, relative_path, mtime_seconds, mtime_nanos) + VALUES + (?1, ?2, $3, $4); ", - params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1], + params![ + worktree_id, + indexed_file.path.to_str(), + mtime.seconds, + mtime.nanos + ], )?; let file_id = self.db.last_insert_rowid(); @@ -224,13 +180,24 @@ impl VectorDatabase { Ok(self.db.last_insert_rowid()) } - pub fn get_file_hashes(&self, worktree_id: i64) -> Result> { + pub fn get_file_mtimes(&self, worktree_id: i64) -> Result> { let mut statement = self.db.prepare( - "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path", + " + SELECT relative_path, mtime_seconds, mtime_nanos + FROM files + WHERE worktree_id = ?1 + ORDER BY relative_path", )?; - let mut result: HashMap = HashMap::new(); + let mut result: HashMap = HashMap::new(); for row in statement.query_map(params![worktree_id], |row| { - Ok((row.get::<_, String>(0)?.into(), row.get(1)?)) + Ok(( + row.get::<_, String>(0)?.into(), + Timestamp { + seconds: row.get(1)?, + nanos: row.get(2)?, + } + .into(), + )) })? 
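+        // Each (mtime_seconds, mtime_nanos) pair is rebuilt into a
+        // SystemTime via rpc::proto::Timestamp, so stored mtimes compare
+        // exactly against the worktree entries' mtimes during indexing.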
{ let row = row?; result.insert(row.0, row.1); diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 35a467b82f..c329206c4b 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -6,7 +6,7 @@ mod modal; mod vector_store_tests; use anyhow::{anyhow, Result}; -use db::{FileSha1, VectorDatabase}; +use db::VectorDatabase; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; use language::{Language, LanguageRegistry}; @@ -15,9 +15,10 @@ use project::{Fs, Project, WorktreeId}; use smol::channel; use std::{ cmp::Ordering, - collections::{HashMap, HashSet}, + collections::HashMap, path::{Path, PathBuf}, sync::Arc, + time::SystemTime, }; use tree_sitter::{Parser, QueryCursor}; use util::{ @@ -46,6 +47,7 @@ pub fn init( VectorStore::new( fs, db_file_path, + // Arc::new(embedding::DummyEmbeddings {}), Arc::new(OpenAIEmbeddings { client: http_client, }), @@ -91,7 +93,7 @@ pub fn init( #[derive(Debug)] pub struct IndexedFile { path: PathBuf, - sha1: FileSha1, + mtime: SystemTime, documents: Vec, } @@ -131,9 +133,10 @@ impl VectorStore { cursor: &mut QueryCursor, parser: &mut Parser, embedding_provider: &dyn EmbeddingProvider, + fs: &Arc, language: Arc, file_path: PathBuf, - content: String, + mtime: SystemTime, ) -> Result { let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; let embedding_config = grammar @@ -141,6 +144,8 @@ impl VectorStore { .as_ref() .ok_or_else(|| anyhow!("no outline query"))?; + let content = fs.load(&file_path).await?; + parser.set_language(grammar.ts_language).unwrap(); let tree = parser .parse(&content, None) @@ -184,11 +189,9 @@ impl VectorStore { } } - let sha1 = FileSha1::from_str(content); - return Ok(IndexedFile { path: file_path, - sha1, + mtime, documents, }); } @@ -231,38 +234,36 @@ impl VectorStore { // Here we query the worktree ids, and yet we dont have them elsewhere // We likely want to clean up these datastructures - let (db, worktree_hashes, worktree_db_ids) = cx + let (db, mut worktree_file_times, worktree_db_ids) = cx .background() .spawn({ let worktrees = worktrees.clone(); async move { let mut worktree_db_ids: HashMap = HashMap::new(); - let mut hashes: HashMap> = + let mut file_times: HashMap> = HashMap::new(); for worktree in worktrees { let worktree_db_id = db.find_or_create_worktree(worktree.abs_path().as_ref())?; worktree_db_ids.insert(worktree.id(), worktree_db_id); - hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?); + file_times.insert(worktree.id(), db.get_file_mtimes(worktree_db_id)?); } - anyhow::Ok((db, hashes, worktree_db_ids)) + anyhow::Ok((db, file_times, worktree_db_ids)) } }) .await?; let (paths_tx, paths_rx) = - channel::unbounded::<(i64, PathBuf, String, Arc)>(); + channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); let (delete_paths_tx, delete_paths_rx) = channel::unbounded::<(i64, PathBuf)>(); let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>(); cx.background() .spawn({ - let fs = fs.clone(); let worktree_db_ids = worktree_db_ids.clone(); async move { for worktree in worktrees.into_iter() { - let file_hashes = &worktree_hashes[&worktree.id()]; - let mut files_included = - file_hashes.keys().collect::>(); + let mut file_mtimes = + worktree_file_times.remove(&worktree.id()).unwrap(); for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); @@ -278,30 +279,26 @@ 
impl VectorStore { continue; } - if let Some(content) = fs.load(&absolute_path).await.log_err() { - let path_buf = file.path.to_path_buf(); - let already_stored = file_hashes.get(&path_buf).map_or( - false, - |existing_hash| { - files_included.remove(&path_buf); - existing_hash.equals(&content) - }, - ); + let path_buf = file.path.to_path_buf(); + let stored_mtime = file_mtimes.remove(&file.path.to_path_buf()); + let already_stored = stored_mtime + .map_or(false, |existing_mtime| { + existing_mtime == file.mtime + }); - if !already_stored { - paths_tx - .try_send(( - worktree_db_ids[&worktree.id()], - path_buf, - content, - language, - )) - .unwrap(); - } + if !already_stored { + paths_tx + .try_send(( + worktree_db_ids[&worktree.id()], + path_buf, + language, + file.mtime, + )) + .unwrap(); } } } - for file in files_included { + for file in file_mtimes.keys() { delete_paths_tx .try_send((worktree_db_ids[&worktree.id()], file.to_owned())) .unwrap(); @@ -336,16 +333,17 @@ impl VectorStore { scope.spawn(async { let mut parser = Parser::new(); let mut cursor = QueryCursor::new(); - while let Ok((worktree_id, file_path, content, language)) = + while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await { if let Some(indexed_file) = Self::index_file( &mut cursor, &mut parser, embedding_provider.as_ref(), + &fs, language, file_path, - content, + mtime, ) .await .log_err() @@ -395,6 +393,8 @@ impl VectorStore { }) .collect::>(); + log::info!("Searching for: {:?}", phrase); + let embedding_provider = self.embedding_provider.clone(); let database_url = self.database_url.clone(); cx.spawn(|this, cx| async move { From 18a5a47f8ab758d0b4288871457af5aa05d1404b Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 30 Jun 2023 18:41:19 -0400 Subject: [PATCH 24/51] moved semantic search model to dev and preview only. moved db update tasks to long lived persistent task. 
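
This serializes all writes through a single connection owned by one long-lived
background task; callers enqueue jobs on a channel instead of opening the
database themselves. A rough sketch of the pattern (names match the patch,
simplified for illustration):

    let (db_update_tx, db_update_rx) = channel::unbounded();
    let _db_update_task = cx.background().spawn(async move {
        // Drain write jobs one at a time for the lifetime of the store.
        while let Ok(job) = db_update_rx.recv().await {
            match job {
                DbWrite::InsertFile { worktree_id, indexed_file } => {
                    db.insert_file(worktree_id, indexed_file).log_err();
                }
                DbWrite::Delete { worktree_id, path } => {
                    db.delete_file(worktree_id, path).log_err();
                }
                DbWrite::FindOrCreateWorktree { path, sender } => {
                    // Reads that need a reply carry a oneshot sender.
                    sender.send(db.find_or_create_worktree(&path)).ok();
                }
            }
        }
    });
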
Co-authored-by: maxbrunsfeld --- crates/project/src/project.rs | 5 + crates/vector_store/src/modal.rs | 2 +- crates/vector_store/src/vector_store.rs | 328 ++++++++++++------ crates/vector_store/src/vector_store_tests.rs | 25 +- 4 files changed, 239 insertions(+), 121 deletions(-) diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index bbb2064da2..eb0004850c 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -260,6 +260,7 @@ pub enum Event { ActiveEntryChanged(Option), WorktreeAdded, WorktreeRemoved(WorktreeId), + WorktreeUpdatedEntries(WorktreeId, UpdatedEntriesSet), DiskBasedDiagnosticsStarted { language_server_id: LanguageServerId, }, @@ -5371,6 +5372,10 @@ impl Project { this.update_local_worktree_buffers(&worktree, changes, cx); this.update_local_worktree_language_servers(&worktree, changes, cx); this.update_local_worktree_settings(&worktree, changes, cx); + cx.emit(Event::WorktreeUpdatedEntries( + worktree.read(cx).id(), + changes.clone(), + )); } worktree::Event::UpdatedGitRepositories(updated_repos) => { this.update_local_worktree_buffers_git_repos(worktree, updated_repos, cx) diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs index 4d377c6819..9225fe8786 100644 --- a/crates/vector_store/src/modal.rs +++ b/crates/vector_store/src/modal.rs @@ -124,7 +124,7 @@ impl PickerDelegate for SemanticSearchDelegate { if let Some(retrieved) = retrieved_cached.log_err() { if !retrieved { let task = vector_store.update(&mut cx, |store, cx| { - store.search(&project, query.to_string(), 10, cx) + store.search(project.clone(), query.to_string(), 10, cx) }); if let Some(results) = task.await.log_err() { diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index c329206c4b..3f0a7001ef 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -8,7 +8,11 @@ mod vector_store_tests; use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; -use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext}; +use futures::{channel::oneshot, Future}; +use gpui::{ + AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext, + WeakModelHandle, +}; use language::{Language, LanguageRegistry}; use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; use project::{Fs, Project, WorktreeId}; @@ -22,7 +26,10 @@ use std::{ }; use tree_sitter::{Parser, QueryCursor}; use util::{ - channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt, + channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}, + http::HttpClient, + paths::EMBEDDINGS_DIR, + ResultExt, }; use workspace::{Workspace, WorkspaceCreated}; @@ -39,12 +46,16 @@ pub fn init( language_registry: Arc, cx: &mut AppContext, ) { + if *RELEASE_CHANNEL == ReleaseChannel::Stable { + return; + } + let db_file_path = EMBEDDINGS_DIR .join(Path::new(RELEASE_CHANNEL_NAME.as_str())) .join("embeddings_db"); - let vector_store = cx.add_model(|_| { - VectorStore::new( + cx.spawn(move |mut cx| async move { + let vector_store = VectorStore::new( fs, db_file_path, // Arc::new(embedding::DummyEmbeddings {}), @@ -52,42 +63,49 @@ pub fn init( client: http_client, }), language_registry, + cx.clone(), ) - }); + .await?; - cx.subscribe_global::({ - let vector_store = vector_store.clone(); - move |event, cx| { - let workspace = &event.0; - if let Some(workspace) = workspace.upgrade(cx) 
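+                        // Only local projects are indexed; the `is_local`
+                        // check below skips remote projects whose worktrees
+                        // are not on this machine's disk.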
{ - let project = workspace.read(cx).project().clone(); - if project.read(cx).is_local() { - vector_store.update(cx, |store, cx| { - store.add_project(project, cx).detach(); - }); + cx.update(|cx| { + cx.subscribe_global::({ + let vector_store = vector_store.clone(); + move |event, cx| { + let workspace = &event.0; + if let Some(workspace) = workspace.upgrade(cx) { + let project = workspace.read(cx).project().clone(); + if project.read(cx).is_local() { + vector_store.update(cx, |store, cx| { + store.add_project(project, cx).detach(); + }); + } + } } - } - } + }) + .detach(); + + cx.add_action({ + move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext| { + let vector_store = vector_store.clone(); + workspace.toggle_modal(cx, |workspace, cx| { + let project = workspace.project().clone(); + let workspace = cx.weak_handle(); + cx.add_view(|cx| { + SemanticSearch::new( + SemanticSearchDelegate::new(workspace, project, vector_store), + cx, + ) + }) + }) + } + }); + + SemanticSearch::init(cx); + }); + + anyhow::Ok(()) }) .detach(); - - cx.add_action({ - move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext| { - let vector_store = vector_store.clone(); - workspace.toggle_modal(cx, |workspace, cx| { - let project = workspace.project().clone(); - let workspace = cx.weak_handle(); - cx.add_view(|cx| { - SemanticSearch::new( - SemanticSearchDelegate::new(workspace, project, vector_store), - cx, - ) - }) - }) - } - }); - - SemanticSearch::init(cx); } #[derive(Debug)] @@ -102,7 +120,14 @@ pub struct VectorStore { database_url: Arc, embedding_provider: Arc, language_registry: Arc, + db_update_tx: channel::Sender, + _db_update_task: Task<()>, + projects: HashMap, ProjectState>, +} + +struct ProjectState { worktree_db_ids: Vec<(WorktreeId, i64)>, + _subscription: gpui::Subscription, } #[derive(Debug, Clone)] @@ -113,20 +138,81 @@ pub struct SearchResult { pub file_path: PathBuf, } +enum DbWrite { + InsertFile { + worktree_id: i64, + indexed_file: IndexedFile, + }, + Delete { + worktree_id: i64, + path: PathBuf, + }, + FindOrCreateWorktree { + path: PathBuf, + sender: oneshot::Sender>, + }, +} + impl VectorStore { - fn new( + async fn new( fs: Arc, database_url: PathBuf, embedding_provider: Arc, language_registry: Arc, - ) -> Self { - Self { - fs, - database_url: Arc::new(database_url), - embedding_provider, - language_registry, - worktree_db_ids: Vec::new(), - } + mut cx: AsyncAppContext, + ) -> Result> { + let database_url = Arc::new(database_url); + + let db = cx + .background() + .spawn({ + let fs = fs.clone(); + let database_url = database_url.clone(); + async move { + if let Some(db_directory) = database_url.parent() { + fs.create_dir(db_directory).await.log_err(); + } + + let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?; + anyhow::Ok(db) + } + }) + .await?; + + Ok(cx.add_model(|cx| { + let (db_update_tx, db_update_rx) = channel::unbounded(); + let _db_update_task = cx.background().spawn(async move { + while let Ok(job) = db_update_rx.recv().await { + match job { + DbWrite::InsertFile { + worktree_id, + indexed_file, + } => { + log::info!("Inserting File: {:?}", &indexed_file.path); + db.insert_file(worktree_id, indexed_file).log_err(); + } + DbWrite::Delete { worktree_id, path } => { + log::info!("Deleting File: {:?}", &path); + db.delete_file(worktree_id, path).log_err(); + } + DbWrite::FindOrCreateWorktree { path, sender } => { + let id = db.find_or_create_worktree(&path); + sender.send(id).ok(); + } + } + } + }); + + Self { + fs, + database_url, + 
db_update_tx, + embedding_provider, + language_registry, + projects: HashMap::new(), + _db_update_task, + } + })) } async fn index_file( @@ -196,6 +282,14 @@ impl VectorStore { }); } + fn find_or_create_worktree(&self, path: PathBuf) -> impl Future> { + let (tx, rx) = oneshot::channel(); + self.db_update_tx + .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx }) + .unwrap(); + async move { rx.await? } + } + fn add_project( &mut self, project: ModelHandle, @@ -211,19 +305,28 @@ impl VectorStore { } }) .collect::>(); + let worktree_db_ids = project + .read(cx) + .worktrees(cx) + .map(|worktree| { + self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf()) + }) + .collect::>(); let fs = self.fs.clone(); let language_registry = self.language_registry.clone(); let embedding_provider = self.embedding_provider.clone(); let database_url = self.database_url.clone(); + let db_update_tx = self.db_update_tx.clone(); cx.spawn(|this, mut cx| async move { futures::future::join_all(worktree_scans_complete).await; + let worktree_db_ids = futures::future::join_all(worktree_db_ids).await; + if let Some(db_directory) = database_url.parent() { fs.create_dir(db_directory).await.log_err(); } - let db = VectorDatabase::new(database_url.to_string_lossy().into())?; let worktrees = project.read_with(&cx, |project, cx| { project @@ -234,32 +337,31 @@ impl VectorStore { // Here we query the worktree ids, and yet we dont have them elsewhere // We likely want to clean up these datastructures - let (db, mut worktree_file_times, worktree_db_ids) = cx + let (mut worktree_file_times, db_ids_by_worktree_id) = cx .background() .spawn({ let worktrees = worktrees.clone(); async move { - let mut worktree_db_ids: HashMap = HashMap::new(); + let db = VectorDatabase::new(database_url.to_string_lossy().into())?; + let mut db_ids_by_worktree_id = HashMap::new(); let mut file_times: HashMap> = HashMap::new(); - for worktree in worktrees { - let worktree_db_id = - db.find_or_create_worktree(worktree.abs_path().as_ref())?; - worktree_db_ids.insert(worktree.id(), worktree_db_id); - file_times.insert(worktree.id(), db.get_file_mtimes(worktree_db_id)?); + for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) { + let db_id = db_id?; + db_ids_by_worktree_id.insert(worktree.id(), db_id); + file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?); } - anyhow::Ok((db, file_times, worktree_db_ids)) + anyhow::Ok((file_times, db_ids_by_worktree_id)) } }) .await?; let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); - let (delete_paths_tx, delete_paths_rx) = channel::unbounded::<(i64, PathBuf)>(); - let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>(); cx.background() .spawn({ - let worktree_db_ids = worktree_db_ids.clone(); + let db_ids_by_worktree_id = db_ids_by_worktree_id.clone(); + let db_update_tx = db_update_tx.clone(); async move { for worktree in worktrees.into_iter() { let mut file_mtimes = @@ -289,7 +391,7 @@ impl VectorStore { if !already_stored { paths_tx .try_send(( - worktree_db_ids[&worktree.id()], + db_ids_by_worktree_id[&worktree.id()], path_buf, language, file.mtime, @@ -299,8 +401,11 @@ impl VectorStore { } } for file in file_mtimes.keys() { - delete_paths_tx - .try_send((worktree_db_ids[&worktree.id()], file.to_owned())) + db_update_tx + .try_send(DbWrite::Delete { + worktree_id: db_ids_by_worktree_id[&worktree.id()], + path: file.to_owned(), + }) .unwrap(); } } @@ -308,25 +413,6 @@ impl VectorStore { }) .detach(); - let 
db_update_task = cx.background().spawn( - async move { - // Inserting all new files - while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await { - log::info!("Inserting File: {:?}", &indexed_file.path); - db.insert_file(worktree_id, indexed_file).log_err(); - } - - // Deleting all old files - while let Ok((worktree_id, delete_path)) = delete_paths_rx.recv().await { - log::info!("Deleting File: {:?}", &delete_path); - db.delete_file(worktree_id, delete_path).log_err(); - } - - anyhow::Ok(()) - } - .log_err(), - ); - cx.background() .scoped(|scope| { for _ in 0..cx.background().num_cpus() { @@ -348,8 +434,11 @@ impl VectorStore { .await .log_err() { - indexed_files_tx - .try_send((worktree_id, indexed_file)) + db_update_tx + .try_send(DbWrite::InsertFile { + worktree_id, + indexed_file, + }) .unwrap(); } } @@ -357,12 +446,22 @@ impl VectorStore { } }) .await; - drop(indexed_files_tx); - db_update_task.await; + this.update(&mut cx, |this, cx| { + let _subscription = cx.subscribe(&project, |this, project, event, cx| { + if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { + // + log::info!("worktree changes {:?}", changes); + } + }); - this.update(&mut cx, |this, _| { - this.worktree_db_ids.extend(worktree_db_ids); + this.projects.insert( + project.downgrade(), + ProjectState { + worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(), + _subscription, + }, + ); }); log::info!("Semantic Indexing Complete!"); @@ -373,23 +472,32 @@ impl VectorStore { pub fn search( &mut self, - project: &ModelHandle, + project: ModelHandle, phrase: String, limit: usize, cx: &mut ModelContext, ) -> Task>> { - let project = project.read(cx); + let project_state = if let Some(state) = self.projects.get(&project.downgrade()) { + state + } else { + return Task::ready(Err(anyhow!("project not added"))); + }; + let worktree_db_ids = project + .read(cx) .worktrees(cx) .filter_map(|worktree| { let worktree_id = worktree.read(cx).id(); - self.worktree_db_ids.iter().find_map(|(id, db_id)| { - if *id == worktree_id { - Some(*db_id) - } else { - None - } - }) + project_state + .worktree_db_ids + .iter() + .find_map(|(id, db_id)| { + if *id == worktree_id { + Some(*db_id) + } else { + None + } + }) }) .collect::>(); @@ -428,17 +536,27 @@ impl VectorStore { }) .await?; - let results = this.read_with(&cx, |this, _| { - documents + this.read_with(&cx, |this, _| { + let project_state = if let Some(state) = this.projects.get(&project.downgrade()) { + state + } else { + return Err(anyhow!("project not added")); + }; + + Ok(documents .into_iter() .filter_map(|(worktree_db_id, file_path, offset, name)| { - let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| { - if *db_id == worktree_db_id { - Some(*id) - } else { - None - } - })?; + let worktree_id = + project_state + .worktree_db_ids + .iter() + .find_map(|(id, db_id)| { + if *db_id == worktree_db_id { + Some(*id) + } else { + None + } + })?; Some(SearchResult { worktree_id, name, @@ -446,10 +564,8 @@ impl VectorStore { file_path, }) }) - .collect() - }); - - anyhow::Ok(results) + .collect()) + }) }) } } diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index 78470ad4be..51065c0ee4 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -5,7 +5,7 @@ use anyhow::Result; use async_trait::async_trait; use gpui::{Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry}; -use 
project::{FakeFs, Project}; +use project::{FakeFs, Fs, Project}; use rand::Rng; use serde_json::json; use unindent::Unindent; @@ -60,14 +60,15 @@ async fn test_vector_store(cx: &mut TestAppContext) { let db_dir = tempdir::TempDir::new("vector-store").unwrap(); let db_path = db_dir.path().join("db.sqlite"); - let store = cx.add_model(|_| { - VectorStore::new( - fs.clone(), - db_path, - Arc::new(FakeEmbeddingProvider), - languages, - ) - }); + let store = VectorStore::new( + fs.clone(), + db_path, + Arc::new(FakeEmbeddingProvider), + languages, + cx.to_async(), + ) + .await + .unwrap(); let project = Project::test(fs, ["/the-root".as_ref()], cx).await; let worktree_id = project.read_with(cx, |project, cx| { @@ -75,15 +76,11 @@ async fn test_vector_store(cx: &mut TestAppContext) { }); let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx)); - // TODO - remove - cx.foreground() - .advance_clock(std::time::Duration::from_secs(3)); - add_project.await.unwrap(); let search_results = store .update(cx, |store, cx| { - store.search(&project, "aaaa".to_string(), 5, cx) + store.search(project.clone(), "aaaa".to_string(), 5, cx) }) .await .unwrap(); From e45d3a0a635ed7d8846134c7a02514ce3733d727 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 4 Jul 2023 11:46:09 -0400 Subject: [PATCH 25/51] WIP: initial reindexing logic worked out --- crates/vector_store/src/vector_store.rs | 100 ++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 7 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 3f0a7001ef..1bdc0127b7 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -58,10 +58,10 @@ pub fn init( let vector_store = VectorStore::new( fs, db_file_path, - // Arc::new(embedding::DummyEmbeddings {}), - Arc::new(OpenAIEmbeddings { - client: http_client, - }), + Arc::new(embedding::DummyEmbeddings {}), + // Arc::new(OpenAIEmbeddings { + // client: http_client, + // }), language_registry, cx.clone(), ) @@ -362,6 +362,8 @@ impl VectorStore { .spawn({ let db_ids_by_worktree_id = db_ids_by_worktree_id.clone(); let db_update_tx = db_update_tx.clone(); + let language_registry = language_registry.clone(); + let paths_tx = paths_tx.clone(); async move { for worktree in worktrees.into_iter() { let mut file_mtimes = @@ -449,9 +451,93 @@ impl VectorStore { this.update(&mut cx, |this, cx| { let _subscription = cx.subscribe(&project, |this, project, event, cx| { - if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { - // - log::info!("worktree changes {:?}", changes); + if let Some(project_state) = this.projects.get(&project.downgrade()) { + let worktree_db_ids = project_state.worktree_db_ids.clone(); + + if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event + { + // Iterate through changes + let language_registry = this.language_registry.clone(); + + let db = + VectorDatabase::new(this.database_url.to_string_lossy().into()); + if db.is_err() { + return; + } + let db = db.unwrap(); + + let worktree_db_id: Option = { + let mut found_db_id = None; + for (w_id, db_id) in worktree_db_ids.into_iter() { + if &w_id == worktree_id { + found_db_id = Some(db_id); + } + } + + found_db_id + }; + + if worktree_db_id.is_none() { + return; + } + let worktree_db_id = worktree_db_id.unwrap(); + + let file_mtimes = db.get_file_mtimes(worktree_db_id); + if file_mtimes.is_err() { + return; + } + + let file_mtimes = file_mtimes.unwrap(); + + 
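The `found_db_id` loop in this change, which maps the event's worktree id onto its database id, can be written more directly with `find_map`, the same shape the `search` method already uses elsewhere in this file. A small illustrative sketch, with `WorktreeId` modeled as a hypothetical `u64` newtype rather than the project crate's real type:

    // Hypothetical stand-in for project::WorktreeId, for illustration only.
    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    struct WorktreeId(u64);

    // Resolve a worktree id to its database id from the cached pairs.
    fn worktree_db_id(pairs: &[(WorktreeId, i64)], target: WorktreeId) -> Option<i64> {
        pairs.iter().find_map(|(id, db_id)| {
            if *id == target {
                Some(*db_id)
            } else {
                None
            }
        })
    }

    fn main() {
        let pairs = vec![(WorktreeId(1), 10), (WorktreeId(2), 20)];
        assert_eq!(worktree_db_id(&pairs, WorktreeId(2)), Some(20));
        assert_eq!(worktree_db_id(&pairs, WorktreeId(3)), None);
    }
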
smol::block_on(async move { + for change in changes.into_iter() { + let change_path = change.0.clone(); + log::info!("Change: {:?}", &change_path); + if let Ok(language) = language_registry + .language_for_file(&change_path.to_path_buf(), None) + .await + { + if language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } + log::info!("Language found: {:?}: ", language.name()); + + // TODO: Make this a bit more defensive + let modified_time = + change_path.metadata().unwrap().modified().unwrap(); + let existing_time = + file_mtimes.get(&change_path.to_path_buf()); + let already_stored = + existing_time.map_or(false, |existing_time| { + if &modified_time != existing_time + && existing_time.elapsed().unwrap().as_secs() + > 30 + { + false + } else { + true + } + }); + + if !already_stored { + log::info!("Need to reindex: {:?}", &change_path); + // paths_tx + // .try_send(( + // worktree_db_id, + // change_path.to_path_buf(), + // language, + // modified_time, + // )) + // .unwrap(); + } + } + } + }) + } } }); From b6520a8f1d11d39055273758d59d3647f2864046 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 4 Jul 2023 14:42:12 -0400 Subject: [PATCH 26/51] updated vector_store to reindex on save after timed delay --- crates/vector_store/src/vector_store.rs | 108 +++++++++--------- crates/vector_store/src/vector_store_tests.rs | 2 +- 2 files changed, 57 insertions(+), 53 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 1bdc0127b7..5189993eee 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -33,6 +33,8 @@ use util::{ }; use workspace::{Workspace, WorkspaceCreated}; +const REINDEXING_DELAY: u64 = 30; + #[derive(Debug)] pub struct Document { pub offset: usize, @@ -58,10 +60,10 @@ pub fn init( let vector_store = VectorStore::new( fs, db_file_path, - Arc::new(embedding::DummyEmbeddings {}), - // Arc::new(OpenAIEmbeddings { - // client: http_client, - // }), + // Arc::new(embedding::DummyEmbeddings {}), + Arc::new(OpenAIEmbeddings { + client: http_client, + }), language_registry, cx.clone(), ) @@ -121,7 +123,9 @@ pub struct VectorStore { embedding_provider: Arc, language_registry: Arc, db_update_tx: channel::Sender, + paths_tx: channel::Sender<(i64, PathBuf, Arc, SystemTime)>, _db_update_task: Task<()>, + _paths_update_task: Task<()>, projects: HashMap, ProjectState>, } @@ -203,14 +207,50 @@ impl VectorStore { } }); + let (paths_tx, paths_rx) = + channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); + + let fs_clone = fs.clone(); + let db_update_tx_clone = db_update_tx.clone(); + let embedding_provider_clone = embedding_provider.clone(); + + let _paths_update_task = cx.background().spawn(async move { + let mut parser = Parser::new(); + let mut cursor = QueryCursor::new(); + while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await { + log::info!("Parsing File: {:?}", &file_path); + if let Some(indexed_file) = Self::index_file( + &mut cursor, + &mut parser, + embedding_provider_clone.as_ref(), + &fs_clone, + language, + file_path, + mtime, + ) + .await + .log_err() + { + db_update_tx_clone + .try_send(DbWrite::InsertFile { + worktree_id, + indexed_file, + }) + .unwrap(); + } + } + }); + Self { fs, database_url, db_update_tx, + paths_tx, embedding_provider, language_registry, projects: HashMap::new(), _db_update_task, + _paths_update_task, } })) } @@ -315,9 +355,9 @@ impl VectorStore { let fs = self.fs.clone(); let 
language_registry = self.language_registry.clone(); - let embedding_provider = self.embedding_provider.clone(); let database_url = self.database_url.clone(); let db_update_tx = self.db_update_tx.clone(); + let paths_tx = self.paths_tx.clone(); cx.spawn(|this, mut cx| async move { futures::future::join_all(worktree_scans_complete).await; @@ -356,8 +396,6 @@ impl VectorStore { }) .await?; - let (paths_tx, paths_rx) = - channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); cx.background() .spawn({ let db_ids_by_worktree_id = db_ids_by_worktree_id.clone(); @@ -415,42 +453,8 @@ impl VectorStore { }) .detach(); - cx.background() - .scoped(|scope| { - for _ in 0..cx.background().num_cpus() { - scope.spawn(async { - let mut parser = Parser::new(); - let mut cursor = QueryCursor::new(); - while let Ok((worktree_id, file_path, language, mtime)) = - paths_rx.recv().await - { - if let Some(indexed_file) = Self::index_file( - &mut cursor, - &mut parser, - embedding_provider.as_ref(), - &fs, - language, - file_path, - mtime, - ) - .await - .log_err() - { - db_update_tx - .try_send(DbWrite::InsertFile { - worktree_id, - indexed_file, - }) - .unwrap(); - } - } - }); - } - }) - .await; - this.update(&mut cx, |this, cx| { - let _subscription = cx.subscribe(&project, |this, project, event, cx| { + let _subscription = cx.subscribe(&project, |this, project, event, _cx| { if let Some(project_state) = this.projects.get(&project.downgrade()) { let worktree_db_ids = project_state.worktree_db_ids.clone(); @@ -488,6 +492,7 @@ impl VectorStore { } let file_mtimes = file_mtimes.unwrap(); + let paths_tx = this.paths_tx.clone(); smol::block_on(async move { for change in changes.into_iter() { @@ -504,7 +509,6 @@ impl VectorStore { { continue; } - log::info!("Language found: {:?}: ", language.name()); // TODO: Make this a bit more defensive let modified_time = @@ -515,7 +519,7 @@ impl VectorStore { existing_time.map_or(false, |existing_time| { if &modified_time != existing_time && existing_time.elapsed().unwrap().as_secs() - > 30 + > REINDEXING_DELAY { false } else { @@ -525,14 +529,14 @@ impl VectorStore { if !already_stored { log::info!("Need to reindex: {:?}", &change_path); - // paths_tx - // .try_send(( - // worktree_db_id, - // change_path.to_path_buf(), - // language, - // modified_time, - // )) - // .unwrap(); + paths_tx + .try_send(( + worktree_db_id, + change_path.to_path_buf(), + language, + modified_time, + )) + .unwrap(); } } } diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index 51065c0ee4..e25b737b06 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -5,7 +5,7 @@ use anyhow::Result; use async_trait::async_trait; use gpui::{Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry}; -use project::{FakeFs, Fs, Project}; +use project::{FakeFs, Project}; use rand::Rng; use serde_json::json; use unindent::Unindent; From eff0ee3b60406c53ac2e6f7ebf6f968264e56b5e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 5 Jul 2023 10:02:42 -0400 Subject: [PATCH 27/51] enabled batching for embedding calls --- crates/vector_store/src/vector_store.rs | 157 ++++++++++++++++++------ 1 file changed, 120 insertions(+), 37 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 5189993eee..e072793e25 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -22,7 +22,7 @@ use std::{ 
collections::HashMap, path::{Path, PathBuf}, sync::Arc, - time::SystemTime, + time::{Instant, SystemTime}, }; use tree_sitter::{Parser, QueryCursor}; use util::{ @@ -34,8 +34,9 @@ use util::{ use workspace::{Workspace, WorkspaceCreated}; const REINDEXING_DELAY: u64 = 30; +const EMBEDDINGS_BATCH_SIZE: usize = 25; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Document { pub offset: usize, pub name: String, @@ -110,7 +111,7 @@ pub fn init( .detach(); } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct IndexedFile { path: PathBuf, mtime: SystemTime, @@ -126,6 +127,7 @@ pub struct VectorStore { paths_tx: channel::Sender<(i64, PathBuf, Arc, SystemTime)>, _db_update_task: Task<()>, _paths_update_task: Task<()>, + _embeddings_update_task: Task<()>, projects: HashMap, ProjectState>, } @@ -184,7 +186,14 @@ impl VectorStore { .await?; Ok(cx.add_model(|cx| { + // paths_tx -> embeddings_tx -> db_update_tx + let (db_update_tx, db_update_rx) = channel::unbounded(); + let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); + let (embeddings_tx, embeddings_rx) = channel::unbounded::<(i64, IndexedFile, Vec)>(); + let _db_update_task = cx.background().spawn(async move { while let Ok(job) = db_update_rx.recv().await { match job { @@ -192,11 +201,9 @@ impl VectorStore { worktree_id, indexed_file, } => { - log::info!("Inserting File: {:?}", &indexed_file.path); db.insert_file(worktree_id, indexed_file).log_err(); } DbWrite::Delete { worktree_id, path } => { - log::info!("Deleting File: {:?}", &path); db.delete_file(worktree_id, path).log_err(); } DbWrite::FindOrCreateWorktree { path, sender } => { @@ -207,35 +214,116 @@ impl VectorStore { } }); - let (paths_tx, paths_rx) = - channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); + async fn embed_batch( + embeddings_queue: Vec<(i64, IndexedFile, Vec)>, + embedding_provider: &Arc, + db_update_tx: channel::Sender, + ) -> Result<()> { + let mut embeddings_queue = embeddings_queue.clone(); + + let mut document_spans = vec![]; + for (_, _, document_span) in embeddings_queue.clone().into_iter() { + document_spans.extend(document_span); + } + + let mut embeddings = embedding_provider + .embed_batch(document_spans.iter().map(|x| &**x).collect()) + .await?; + + // This assumes the embeddings are returned in order + let t0 = Instant::now(); + let mut i = 0; + let mut j = 0; + while let Some(embedding) = embeddings.pop() { + // This has to accommodate multiple indexed_files in a row without documents + while embeddings_queue[i].1.documents.len() == j { + i += 1; + j = 0; + } + + embeddings_queue[i].1.documents[j].embedding = embedding; + j += 1; + } + + for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() { + // TODO: Update this so it doesn't panic + for document in indexed_file.documents.iter() { + assert!( + document.embedding.len() > 0, + "Document Embedding not Complete" + ); + } + + db_update_tx + .send(DbWrite::InsertFile { + worktree_id, + indexed_file, + }) + .await + .unwrap(); + } + + anyhow::Ok(()) + } + + let embedding_provider_clone = embedding_provider.clone(); + + let db_update_tx_clone = db_update_tx.clone(); + let _embeddings_update_task = cx.background().spawn(async move { + let mut queue_len = 0; + let mut embeddings_queue = vec![]; + let mut request_count = 0; + while let Ok((worktree_id, indexed_file, document_spans)) = + embeddings_rx.recv().await + { + queue_len += &document_spans.len(); + embeddings_queue.push((worktree_id, indexed_file, document_spans)); + + if queue_len >= 
EMBEDDINGS_BATCH_SIZE { + let _ = embed_batch( + embeddings_queue, + &embedding_provider_clone, + db_update_tx_clone.clone(), + ) + .await; + + embeddings_queue = vec![]; + queue_len = 0; + + request_count += 1; + } + } + + if queue_len > 0 { + let _ = embed_batch( + embeddings_queue, + &embedding_provider_clone, + db_update_tx_clone.clone(), + ) + .await; + request_count += 1; + } + }); let fs_clone = fs.clone(); - let db_update_tx_clone = db_update_tx.clone(); - let embedding_provider_clone = embedding_provider.clone(); let _paths_update_task = cx.background().spawn(async move { let mut parser = Parser::new(); let mut cursor = QueryCursor::new(); while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await { - log::info!("Parsing File: {:?}", &file_path); - if let Some(indexed_file) = Self::index_file( + if let Some((indexed_file, document_spans)) = Self::index_file( &mut cursor, &mut parser, - embedding_provider_clone.as_ref(), &fs_clone, language, - file_path, + file_path.clone(), mtime, ) .await .log_err() { - db_update_tx_clone - .try_send(DbWrite::InsertFile { - worktree_id, - indexed_file, - }) + embeddings_tx + .try_send((worktree_id, indexed_file, document_spans)) .unwrap(); } } @@ -251,6 +339,7 @@ impl VectorStore { projects: HashMap::new(), _db_update_task, _paths_update_task, + _embeddings_update_task, } })) } @@ -258,12 +347,11 @@ impl VectorStore { async fn index_file( cursor: &mut QueryCursor, parser: &mut Parser, - embedding_provider: &dyn EmbeddingProvider, fs: &Arc, language: Arc, file_path: PathBuf, mtime: SystemTime, - ) -> Result { + ) -> Result<(IndexedFile, Vec)> { let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; let embedding_config = grammar .embedding_config @@ -298,7 +386,7 @@ impl VectorStore { if let Some((item, name)) = content.get(item_range.clone()).zip(content.get(name_range)) { - context_spans.push(item); + context_spans.push(item.to_string()); documents.push(Document { name: name.to_string(), offset: item_range.start, @@ -308,18 +396,14 @@ impl VectorStore { } } - if !documents.is_empty() { - let embeddings = embedding_provider.embed_batch(context_spans).await?; - for (document, embedding) in documents.iter_mut().zip(embeddings) { - document.embedding = embedding; - } - } - - return Ok(IndexedFile { - path: file_path, - mtime, - documents, - }); + return Ok(( + IndexedFile { + path: file_path, + mtime, + documents, + }, + context_spans, + )); } fn find_or_create_worktree(&self, path: PathBuf) -> impl Future> { @@ -454,6 +538,9 @@ impl VectorStore { .detach(); this.update(&mut cx, |this, cx| { + // The below is managing for updated on save + // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is + // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed. 
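The staleness test described in the comment above reduces to a small predicate over the stored mtimes. A runnable sketch of that predicate, assuming stored mtimes arrive as the `HashMap<PathBuf, SystemTime>` that `get_file_mtimes` returns, and reusing the delay constant's name:

    use std::{
        collections::HashMap,
        path::{Path, PathBuf},
        time::SystemTime,
    };

    const REINDEXING_DELAY: u64 = 30;

    // A changed path needs reindexing when we have never stored it, or when its
    // mtime differs from the stored one and the stored entry has aged past the
    // reindexing delay window.
    fn needs_reindex(
        stored: &HashMap<PathBuf, SystemTime>,
        path: &Path,
        current_mtime: SystemTime,
    ) -> bool {
        match stored.get(path) {
            None => true,
            Some(existing) => {
                let aged_out = existing
                    .elapsed()
                    .map_or(false, |age| age.as_secs() > REINDEXING_DELAY);
                current_mtime != *existing && aged_out
            }
        }
    }

    fn main() {
        let stored: HashMap<PathBuf, SystemTime> = HashMap::new();
        assert!(needs_reindex(&stored, Path::new("src/lib.rs"), SystemTime::now()));
    }
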
let _subscription = cx.subscribe(&project, |this, project, event, _cx| { if let Some(project_state) = this.projects.get(&project.downgrade()) { let worktree_db_ids = project_state.worktree_db_ids.clone(); @@ -554,8 +641,6 @@ impl VectorStore { ); }); - log::info!("Semantic Indexing Complete!"); - anyhow::Ok(()) }) } @@ -591,8 +676,6 @@ impl VectorStore { }) .collect::>(); - log::info!("Searching for: {:?}", phrase); - let embedding_provider = self.embedding_provider.clone(); let database_url = self.database_url.clone(); cx.spawn(|this, cx| async move { From afccf608f42d9b35d6b1942ae60734f3b3e8d3a9 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 5 Jul 2023 12:39:08 -0400 Subject: [PATCH 28/51] updated both embed and parsing tasks to be multi-threaded. --- Cargo.lock | 34 +- crates/vector_store/Cargo.toml | 1 + crates/vector_store/src/embedding.rs | 27 +- crates/vector_store/src/vector_store.rs | 411 +++++++++++++----------- 4 files changed, 281 insertions(+), 192 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 59cf30001e..dbc2a1cbb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,7 +118,7 @@ dependencies = [ "settings", "smol", "theme", - "tiktoken-rs", + "tiktoken-rs 0.4.2", "util", "workspace", ] @@ -737,9 +737,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.0" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "base64ct" @@ -914,9 +914,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d4260bcc2e8fc9df1eac4919a720effeb63a3f0952f5bf4944adfa18897f09" +checksum = "a246e68bb43f6cd9db24bea052a53e40405417c5fb372e3d1a8a7f770a564ef5" dependencies = [ "memchr", "once_cell", @@ -4812,7 +4812,7 @@ version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bd9647b268a3d3e14ff09c23201133a62589c658db02bb7388c7246aafe0590" dependencies = [ - "base64 0.21.0", + "base64 0.21.2", "indexmap", "line-wrap", "quick-xml", @@ -5529,7 +5529,7 @@ version = "0.11.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13293b639a097af28fc8a90f22add145a9c954e49d77da06263d58cf44d5fb91" dependencies = [ - "base64 0.21.0", + "base64 0.21.2", "bytes 1.4.0", "encoding_rs", "futures-core", @@ -5868,7 +5868,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.21.0", + "base64 0.21.2", ] [[package]] @@ -7118,7 +7118,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ba161c549e2c0686f35f5d920e63fad5cafba2c28ad2caceaf07e5d9fa6e8c4" dependencies = [ "anyhow", - "base64 0.21.0", + "base64 0.21.2", + "bstr", + "fancy-regex", + "lazy_static", + "parking_lot 0.12.1", + "rustc-hash", +] + +[[package]] +name = "tiktoken-rs" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a99d843674a3468b4a9200a565bbe909a0152f95e82a52feae71e6bf2d4b49d" +dependencies = [ + "anyhow", + "base64 0.21.2", "bstr", "fancy-regex", "lazy_static", @@ -8038,6 +8053,7 @@ dependencies = [ "smol", "tempdir", "theme", + "tiktoken-rs 0.5.0", "tree-sitter", 
"tree-sitter-rust", "unindent", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index d1ad8a0f9b..854afe5b6e 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -31,6 +31,7 @@ serde_json.workspace = true async-trait.workspace = true bincode = "1.3.3" matrixmultiply = "0.3.7" +tiktoken-rs = "0.5.0" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index 86d8494ab4..72b30d9424 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -5,8 +5,8 @@ use gpui::serde_json; use isahc::prelude::Configurable; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; -use std::env; use std::sync::Arc; +use std::{env, time::Instant}; use util::http::{HttpClient, Request}; lazy_static! { @@ -60,9 +60,34 @@ impl EmbeddingProvider for DummyEmbeddings { } } +// impl OpenAIEmbeddings { +// async fn truncate(span: &str) -> String { +// let bpe = cl100k_base().unwrap(); +// let mut tokens = bpe.encode_with_special_tokens(span); +// if tokens.len() > 8192 { +// tokens.truncate(8192); +// let result = bpe.decode(tokens); +// if result.is_ok() { +// return result.unwrap(); +// } +// } + +// return span.to_string(); +// } +// } + #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + // Truncate spans to 8192 if needed + // let t0 = Instant::now(); + // let mut truncated_spans = vec![]; + // for span in spans { + // truncated_spans.push(Self::truncate(span)); + // } + // let spans = futures::future::join_all(truncated_spans).await; + // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs()); + let api_key = OPENAI_API_KEY .as_ref() .ok_or_else(|| anyhow!("no api key"))?; diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index e072793e25..a63674bc34 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -34,7 +34,7 @@ use util::{ use workspace::{Workspace, WorkspaceCreated}; const REINDEXING_DELAY: u64 = 30; -const EMBEDDINGS_BATCH_SIZE: usize = 25; +const EMBEDDINGS_BATCH_SIZE: usize = 150; #[derive(Debug, Clone)] pub struct Document { @@ -74,6 +74,7 @@ pub fn init( cx.subscribe_global::({ let vector_store = vector_store.clone(); move |event, cx| { + let t0 = Instant::now(); let workspace = &event.0; if let Some(workspace) = workspace.upgrade(cx) { let project = workspace.read(cx).project().clone(); @@ -124,10 +125,14 @@ pub struct VectorStore { embedding_provider: Arc, language_registry: Arc, db_update_tx: channel::Sender, - paths_tx: channel::Sender<(i64, PathBuf, Arc, SystemTime)>, + // embed_batch_tx: channel::Sender)>>, + batch_files_tx: channel::Sender<(i64, IndexedFile, Vec)>, + parsing_files_tx: channel::Sender<(i64, PathBuf, Arc, SystemTime)>, + parsing_files_rx: channel::Receiver<(i64, PathBuf, Arc, SystemTime)>, _db_update_task: Task<()>, - _paths_update_task: Task<()>, - _embeddings_update_task: Task<()>, + _embed_batch_task: Vec>, + _batch_files_task: Task<()>, + _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, } @@ -188,12 +193,8 @@ impl VectorStore { Ok(cx.add_model(|cx| { // paths_tx -> embeddings_tx -> db_update_tx + //db_update_tx/rx: Updating Database let (db_update_tx, db_update_rx) = channel::unbounded(); - let (paths_tx, paths_rx) = - channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); - 
let (embeddings_tx, embeddings_rx) = - channel::unbounded::<(i64, IndexedFile, Vec)>(); - let _db_update_task = cx.background().spawn(async move { while let Ok(job) = db_update_rx.recv().await { match job { @@ -201,6 +202,7 @@ impl VectorStore { worktree_id, indexed_file, } => { + log::info!("Inserting Data for {:?}", &indexed_file.path); db.insert_file(worktree_id, indexed_file).log_err(); } DbWrite::Delete { worktree_id, path } => { @@ -214,132 +216,137 @@ impl VectorStore { } }); - async fn embed_batch( - embeddings_queue: Vec<(i64, IndexedFile, Vec)>, - embedding_provider: &Arc, - db_update_tx: channel::Sender, - ) -> Result<()> { - let mut embeddings_queue = embeddings_queue.clone(); + // embed_tx/rx: Embed Batch and Send to Database + let (embed_batch_tx, embed_batch_rx) = + channel::unbounded::)>>(); + let mut _embed_batch_task = Vec::new(); + for _ in 0..cx.background().num_cpus() { + let db_update_tx = db_update_tx.clone(); + let embed_batch_rx = embed_batch_rx.clone(); + let embedding_provider = embedding_provider.clone(); + _embed_batch_task.push(cx.background().spawn(async move { + while let Ok(embeddings_queue) = embed_batch_rx.recv().await { + log::info!("Embedding Batch! "); - let mut document_spans = vec![]; - for (_, _, document_span) in embeddings_queue.clone().into_iter() { - document_spans.extend(document_span); - } + // Construct Batch + let mut embeddings_queue = embeddings_queue.clone(); + let mut document_spans = vec![]; + for (_, _, document_span) in embeddings_queue.clone().into_iter() { + document_spans.extend(document_span); + } - let mut embeddings = embedding_provider - .embed_batch(document_spans.iter().map(|x| &**x).collect()) - .await?; + if let Some(mut embeddings) = embedding_provider + .embed_batch(document_spans.iter().map(|x| &**x).collect()) + .await + .log_err() + { + let mut i = 0; + let mut j = 0; + while let Some(embedding) = embeddings.pop() { + while embeddings_queue[i].1.documents.len() == j { + i += 1; + j = 0; + } - // This assumes the embeddings are returned in order - let t0 = Instant::now(); - let mut i = 0; - let mut j = 0; - while let Some(embedding) = embeddings.pop() { - // This has to accomodate for multiple indexed_files in a row without documents - while embeddings_queue[i].1.documents.len() == j { - i += 1; - j = 0; + embeddings_queue[i].1.documents[j].embedding = embedding; + j += 1; + } + + for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() { + for document in indexed_file.documents.iter() { + // TODO: Update this so it doesn't panic + assert!( + document.embedding.len() > 0, + "Document Embedding Not Complete" + ); + } + + db_update_tx + .send(DbWrite::InsertFile { + worktree_id, + indexed_file, + }) + .await + .unwrap(); + } + } } - - embeddings_queue[i].1.documents[j].embedding = embedding; - j += 1; - } - - for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() { - // TODO: Update this so it doesnt panic - for document in indexed_file.documents.iter() { - assert!( - document.embedding.len() > 0, - "Document Embedding not Complete" - ); - } - - db_update_tx - .send(DbWrite::InsertFile { - worktree_id, - indexed_file, - }) - .await - .unwrap(); - } - - anyhow::Ok(()) + })) } - let embedding_provider_clone = embedding_provider.clone(); - - let db_update_tx_clone = db_update_tx.clone(); - let _embeddings_update_task = cx.background().spawn(async move { + // batch_tx/rx: Batch Files to Send for Embeddings + let (batch_files_tx, batch_files_rx) = + channel::unbounded::<(i64, IndexedFile, Vec)>(); + let 
_batch_files_task = cx.background().spawn(async move { let mut queue_len = 0; let mut embeddings_queue = vec![]; - let mut request_count = 0; while let Ok((worktree_id, indexed_file, document_spans)) = - embeddings_rx.recv().await + batch_files_rx.recv().await { + log::info!("Batching File: {:?}", &indexed_file.path); queue_len += &document_spans.len(); embeddings_queue.push((worktree_id, indexed_file, document_spans)); - if queue_len >= EMBEDDINGS_BATCH_SIZE { - let _ = embed_batch( - embeddings_queue, - &embedding_provider_clone, - db_update_tx_clone.clone(), - ) - .await; - + embed_batch_tx.try_send(embeddings_queue).unwrap(); embeddings_queue = vec![]; queue_len = 0; - - request_count += 1; } } - if queue_len > 0 { - let _ = embed_batch( - embeddings_queue, - &embedding_provider_clone, - db_update_tx_clone.clone(), - ) - .await; - request_count += 1; + embed_batch_tx.try_send(embeddings_queue).unwrap(); } }); - let fs_clone = fs.clone(); + // parsing_files_tx/rx: Parsing Files to Embeddable Documents + let (parsing_files_tx, parsing_files_rx) = + channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); - let _paths_update_task = cx.background().spawn(async move { - let mut parser = Parser::new(); - let mut cursor = QueryCursor::new(); - while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await { - if let Some((indexed_file, document_spans)) = Self::index_file( - &mut cursor, - &mut parser, - &fs_clone, - language, - file_path.clone(), - mtime, - ) - .await - .log_err() + let mut _parsing_files_tasks = Vec::new(); + for _ in 0..cx.background().num_cpus() { + let fs = fs.clone(); + let parsing_files_rx = parsing_files_rx.clone(); + let batch_files_tx = batch_files_tx.clone(); + _parsing_files_tasks.push(cx.background().spawn(async move { + let mut parser = Parser::new(); + let mut cursor = QueryCursor::new(); + while let Ok((worktree_id, file_path, language, mtime)) = + parsing_files_rx.recv().await { - embeddings_tx - .try_send((worktree_id, indexed_file, document_spans)) - .unwrap(); + log::info!("Parsing File: {:?}", &file_path); + if let Some((indexed_file, document_spans)) = Self::index_file( + &mut cursor, + &mut parser, + &fs, + language, + file_path.clone(), + mtime, + ) + .await + .log_err() + { + batch_files_tx + .try_send((worktree_id, indexed_file, document_spans)) + .unwrap(); + } } - } - }); + })); + } Self { fs, database_url, - db_update_tx, - paths_tx, embedding_provider, language_registry, - projects: HashMap::new(), + db_update_tx, + // embed_batch_tx, + batch_files_tx, + parsing_files_tx, + parsing_files_rx, _db_update_task, - _paths_update_task, - _embeddings_update_task, + _embed_batch_task, + _batch_files_task, + _parsing_files_tasks, + projects: HashMap::new(), } })) } @@ -441,12 +448,16 @@ impl VectorStore { let language_registry = self.language_registry.clone(); let database_url = self.database_url.clone(); let db_update_tx = self.db_update_tx.clone(); - let paths_tx = self.paths_tx.clone(); + let parsing_files_tx = self.parsing_files_tx.clone(); + let parsing_files_rx = self.parsing_files_rx.clone(); + let batch_files_tx = self.batch_files_tx.clone(); cx.spawn(|this, mut cx| async move { + let t0 = Instant::now(); futures::future::join_all(worktree_scans_complete).await; let worktree_db_ids = futures::future::join_all(worktree_db_ids).await; + log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis()); if let Some(db_directory) = database_url.parent() { fs.create_dir(db_directory).await.log_err(); @@ -485,8 +496,9 @@ impl 
VectorStore { let db_ids_by_worktree_id = db_ids_by_worktree_id.clone(); let db_update_tx = db_update_tx.clone(); let language_registry = language_registry.clone(); - let paths_tx = paths_tx.clone(); + let parsing_files_tx = parsing_files_tx.clone(); async move { + let t0 = Instant::now(); for worktree in worktrees.into_iter() { let mut file_mtimes = worktree_file_times.remove(&worktree.id()).unwrap(); @@ -513,7 +525,7 @@ impl VectorStore { }); if !already_stored { - paths_tx + parsing_files_tx .try_send(( db_ids_by_worktree_id[&worktree.id()], path_buf, @@ -533,10 +545,45 @@ impl VectorStore { .unwrap(); } } + log::info!( + "Parsing Worktree Completed in {:?}", + t0.elapsed().as_millis() + ); } }) .detach(); + // cx.background() + // .scoped(|scope| { + // for _ in 0..cx.background().num_cpus() { + // scope.spawn(async { + // let mut parser = Parser::new(); + // let mut cursor = QueryCursor::new(); + // while let Ok((worktree_id, file_path, language, mtime)) = + // parsing_files_rx.recv().await + // { + // log::info!("Parsing File: {:?}", &file_path); + // if let Some((indexed_file, document_spans)) = Self::index_file( + // &mut cursor, + // &mut parser, + // &fs, + // language, + // file_path.clone(), + // mtime, + // ) + // .await + // .log_err() + // { + // batch_files_tx + // .try_send((worktree_id, indexed_file, document_spans)) + // .unwrap(); + // } + // } + // }); + // } + // }) + // .await; + this.update(&mut cx, |this, cx| { // The below is managing for updated on save // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is @@ -545,90 +592,90 @@ impl VectorStore { if let Some(project_state) = this.projects.get(&project.downgrade()) { let worktree_db_ids = project_state.worktree_db_ids.clone(); - if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event - { - // Iterate through changes - let language_registry = this.language_registry.clone(); + // if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event + // { + // // Iterate through changes + // let language_registry = this.language_registry.clone(); - let db = - VectorDatabase::new(this.database_url.to_string_lossy().into()); - if db.is_err() { - return; - } - let db = db.unwrap(); + // let db = + // VectorDatabase::new(this.database_url.to_string_lossy().into()); + // if db.is_err() { + // return; + // } + // let db = db.unwrap(); - let worktree_db_id: Option = { - let mut found_db_id = None; - for (w_id, db_id) in worktree_db_ids.into_iter() { - if &w_id == worktree_id { - found_db_id = Some(db_id); - } - } + // let worktree_db_id: Option = { + // let mut found_db_id = None; + // for (w_id, db_id) in worktree_db_ids.into_iter() { + // if &w_id == worktree_id { + // found_db_id = Some(db_id); + // } + // } - found_db_id - }; + // found_db_id + // }; - if worktree_db_id.is_none() { - return; - } - let worktree_db_id = worktree_db_id.unwrap(); + // if worktree_db_id.is_none() { + // return; + // } + // let worktree_db_id = worktree_db_id.unwrap(); - let file_mtimes = db.get_file_mtimes(worktree_db_id); - if file_mtimes.is_err() { - return; - } + // let file_mtimes = db.get_file_mtimes(worktree_db_id); + // if file_mtimes.is_err() { + // return; + // } - let file_mtimes = file_mtimes.unwrap(); - let paths_tx = this.paths_tx.clone(); + // let file_mtimes = file_mtimes.unwrap(); + // let paths_tx = this.paths_tx.clone(); - smol::block_on(async move { - for change in changes.into_iter() { - let change_path = change.0.clone(); - 
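The broader point of this patch is fanning the parse and embed stages out to one background task per CPU, as set up above. Since `smol` channel receivers can be cloned and clones compete for messages, a worker pool is just N tasks spawned over the same receiver; a self-contained sketch of the pattern, with the worker count hard-coded instead of taken from `num_cpus`:

    use smol::channel;

    fn main() {
        smol::block_on(async {
            let (tx, rx) = channel::unbounded::<u32>();

            // Each cloned receiver competes for messages, so N spawned tasks
            // over one channel behave as a simple work-distributing pool.
            let workers: Vec<_> = (0..4)
                .map(|worker_id| {
                    let rx = rx.clone();
                    smol::spawn(async move {
                        while let Ok(job) = rx.recv().await {
                            println!("worker {worker_id} took job {job}");
                        }
                    })
                })
                .collect();

            for job in 0..16 {
                tx.send(job).await.unwrap();
            }
            drop(tx); // closing the sender lets every worker's recv() fail and exit

            for worker in workers {
                worker.await;
            }
        });
    }
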
log::info!("Change: {:?}", &change_path); - if let Ok(language) = language_registry - .language_for_file(&change_path.to_path_buf(), None) - .await - { - if language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; - } + // smol::block_on(async move { + // for change in changes.into_iter() { + // let change_path = change.0.clone(); + // log::info!("Change: {:?}", &change_path); + // if let Ok(language) = language_registry + // .language_for_file(&change_path.to_path_buf(), None) + // .await + // { + // if language + // .grammar() + // .and_then(|grammar| grammar.embedding_config.as_ref()) + // .is_none() + // { + // continue; + // } - // TODO: Make this a bit more defensive - let modified_time = - change_path.metadata().unwrap().modified().unwrap(); - let existing_time = - file_mtimes.get(&change_path.to_path_buf()); - let already_stored = - existing_time.map_or(false, |existing_time| { - if &modified_time != existing_time - && existing_time.elapsed().unwrap().as_secs() - > REINDEXING_DELAY - { - false - } else { - true - } - }); + // // TODO: Make this a bit more defensive + // let modified_time = + // change_path.metadata().unwrap().modified().unwrap(); + // let existing_time = + // file_mtimes.get(&change_path.to_path_buf()); + // let already_stored = + // existing_time.map_or(false, |existing_time| { + // if &modified_time != existing_time + // && existing_time.elapsed().unwrap().as_secs() + // > REINDEXING_DELAY + // { + // false + // } else { + // true + // } + // }); - if !already_stored { - log::info!("Need to reindex: {:?}", &change_path); - paths_tx - .try_send(( - worktree_db_id, - change_path.to_path_buf(), - language, - modified_time, - )) - .unwrap(); - } - } - } - }) - } + // if !already_stored { + // log::info!("Need to reindex: {:?}", &change_path); + // paths_tx + // .try_send(( + // worktree_db_id, + // change_path.to_path_buf(), + // language, + // modified_time, + // )) + // .unwrap(); + // } + // } + // } + // }) + // } } }); From a86b6c42c77c7fe5e3721ba3fe4df0fbe91eb268 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 6 Jul 2023 11:11:39 -0400 Subject: [PATCH 29/51] corrected batching order and managed for open ai embedding errors --- crates/vector_store/Cargo.toml | 1 + crates/vector_store/src/embedding.rs | 138 ++++++++++------ crates/vector_store/src/vector_store.rs | 203 ++++++++++-------------- 3 files changed, 169 insertions(+), 173 deletions(-) diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 854afe5b6e..35a6a689ae 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -32,6 +32,7 @@ async-trait.workspace = true bincode = "1.3.3" matrixmultiply = "0.3.7" tiktoken-rs = "0.5.0" +rand.workspace = true [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index 72b30d9424..029a6cdf61 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -2,15 +2,20 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::serde_json; +use isahc::http::StatusCode; use isahc::prelude::Configurable; +use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +use std::env; use std::sync::Arc; -use std::{env, time::Instant}; +use std::time::Duration; +use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; 
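The embedding.rs changes below keep a shared `cl100k_base` tokenizer in a `lazy_static` and truncate oversized spans before retrying, since ada-002 rejects inputs past its token limit. A standalone sketch of just that truncation step, assuming `tiktoken-rs` 0.5 and the same 8190-token cutoff used here:

    use tiktoken_rs::cl100k_base;

    // Clip a span to at most `max_tokens` under the cl100k_base encoding,
    // decoding the truncated token ids back into a string.
    fn truncate_span(span: &str, max_tokens: usize) -> String {
        let bpe = cl100k_base().unwrap();
        let mut tokens = bpe.encode_with_special_tokens(span);
        if tokens.len() <= max_tokens {
            return span.to_string();
        }
        tokens.truncate(max_tokens);
        // Decoding can fail if the cut lands awkwardly; fall back to the
        // original span, mirroring the defensive handling in the diff below.
        bpe.decode(tokens).unwrap_or_else(|_| span.to_string())
    }

    fn main() {
        let long_input = "fn main() { println!(\"hello\"); }\n".repeat(2000);
        let clipped = truncate_span(&long_input, 8190);
        assert!(clipped.len() < long_input.len());
    }
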
lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } #[derive(Clone)] @@ -60,69 +65,100 @@ impl EmbeddingProvider for DummyEmbeddings { } } -// impl OpenAIEmbeddings { -// async fn truncate(span: &str) -> String { -// let bpe = cl100k_base().unwrap(); -// let mut tokens = bpe.encode_with_special_tokens(span); -// if tokens.len() > 8192 { -// tokens.truncate(8192); -// let result = bpe.decode(tokens); -// if result.is_ok() { -// return result.unwrap(); -// } -// } +impl OpenAIEmbeddings { + async fn truncate(span: String) -> String { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref()); + if tokens.len() > 8190 { + tokens.truncate(8190); + let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); + if result.is_ok() { + let transformed = result.unwrap(); + // assert_ne!(transformed, span); + return transformed; + } + } -// return span.to_string(); -// } -// } - -#[async_trait] -impl EmbeddingProvider for OpenAIEmbeddings { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { - // Truncate spans to 8192 if needed - // let t0 = Instant::now(); - // let mut truncated_spans = vec![]; - // for span in spans { - // truncated_spans.push(Self::truncate(span)); - // } - // let spans = futures::future::join_all(truncated_spans).await; - // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs()); - - let api_key = OPENAI_API_KEY - .as_ref() - .ok_or_else(|| anyhow!("no api key"))?; + return span.to_string(); + } + async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result> { let request = Request::post("https://api.openai.com/v1/embeddings") .redirect_policy(isahc::config::RedirectPolicy::Follow) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body( serde_json::to_string(&OpenAIEmbeddingRequest { - input: spans, + input: spans.clone(), model: "text-embedding-ada-002", }) .unwrap() .into(), )?; - let mut response = self.client.send(request).await?; - if !response.status().is_success() { - return Err(anyhow!("openai embedding failed {}", response.status())); - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; - - log::info!( - "openai embedding completed. tokens: {:?}", - response.usage.total_tokens - ); - - Ok(response - .data - .into_iter() - .map(|embedding| embedding.embedding) - .collect()) + Ok(self.client.send(request).await?) 
+ } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddings { + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360]; + const MAX_RETRIES: usize = 3; + + let api_key = OPENAI_API_KEY + .as_ref() + .ok_or_else(|| anyhow!("no api key"))?; + + let mut request_number = 0; + let mut response: Response; + let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); + while request_number < MAX_RETRIES { + response = self + .send_request(api_key, spans.iter().map(|x| &**x).collect()) + .await?; + request_number += 1; + + if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK { + return Err(anyhow!( + "openai max retries, error: {:?}", + &response.status() + )); + } + + match response.status() { + StatusCode::TOO_MANY_REQUESTS => { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + std::thread::sleep(delay); + } + StatusCode::BAD_REQUEST => { + log::info!("BAD REQUEST: {:?}", &response.status()); + // Don't worry about delaying bad request, as we can assume + // we haven't been rate limited yet. + for span in spans.iter_mut() { + *span = Self::truncate(span.to_string()).await; + } + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::info!( + "openai embedding completed. tokens: {:?}", + response.usage.total_tokens + ); + return Ok(response + .data + .into_iter() + .map(|embedding| embedding.embedding) + .collect()); + } + _ => { + return Err(anyhow!("openai embedding failed {}", response.status())); + } + } + } + + Err(anyhow!("openai embedding failed")) } } diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index a63674bc34..5141451e64 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -74,7 +74,6 @@ pub fn init( cx.subscribe_global::({ let vector_store = vector_store.clone(); move |event, cx| { - let t0 = Instant::now(); let workspace = &event.0; if let Some(workspace) = workspace.upgrade(cx) { let project = workspace.read(cx).project().clone(); @@ -126,9 +125,7 @@ pub struct VectorStore { language_registry: Arc, db_update_tx: channel::Sender, // embed_batch_tx: channel::Sender)>>, - batch_files_tx: channel::Sender<(i64, IndexedFile, Vec)>, parsing_files_tx: channel::Sender<(i64, PathBuf, Arc, SystemTime)>, - parsing_files_rx: channel::Receiver<(i64, PathBuf, Arc, SystemTime)>, _db_update_task: Task<()>, _embed_batch_task: Vec>, _batch_files_task: Task<()>, @@ -220,14 +217,13 @@ impl VectorStore { let (embed_batch_tx, embed_batch_rx) = channel::unbounded::)>>(); let mut _embed_batch_task = Vec::new(); - for _ in 0..cx.background().num_cpus() { + for _ in 0..1 { + //cx.background().num_cpus() { let db_update_tx = db_update_tx.clone(); let embed_batch_rx = embed_batch_rx.clone(); let embedding_provider = embedding_provider.clone(); _embed_batch_task.push(cx.background().spawn(async move { while let Ok(embeddings_queue) = embed_batch_rx.recv().await { - log::info!("Embedding Batch! 
"); - // Construct Batch let mut embeddings_queue = embeddings_queue.clone(); let mut document_spans = vec![]; @@ -235,20 +231,20 @@ impl VectorStore { document_spans.extend(document_span); } - if let Some(mut embeddings) = embedding_provider + if let Ok(embeddings) = embedding_provider .embed_batch(document_spans.iter().map(|x| &**x).collect()) .await - .log_err() { let mut i = 0; let mut j = 0; - while let Some(embedding) = embeddings.pop() { + + for embedding in embeddings.iter() { while embeddings_queue[i].1.documents.len() == j { i += 1; j = 0; } - embeddings_queue[i].1.documents[j].embedding = embedding; + embeddings_queue[i].1.documents[j].embedding = embedding.to_owned(); j += 1; } @@ -283,7 +279,6 @@ impl VectorStore { while let Ok((worktree_id, indexed_file, document_spans)) = batch_files_rx.recv().await { - log::info!("Batching File: {:?}", &indexed_file.path); queue_len += &document_spans.len(); embeddings_queue.push((worktree_id, indexed_file, document_spans)); if queue_len >= EMBEDDINGS_BATCH_SIZE { @@ -338,10 +333,7 @@ impl VectorStore { embedding_provider, language_registry, db_update_tx, - // embed_batch_tx, - batch_files_tx, parsing_files_tx, - parsing_files_rx, _db_update_task, _embed_batch_task, _batch_files_task, @@ -449,8 +441,6 @@ impl VectorStore { let database_url = self.database_url.clone(); let db_update_tx = self.db_update_tx.clone(); let parsing_files_tx = self.parsing_files_tx.clone(); - let parsing_files_rx = self.parsing_files_rx.clone(); - let batch_files_tx = self.batch_files_tx.clone(); cx.spawn(|this, mut cx| async move { let t0 = Instant::now(); @@ -553,37 +543,6 @@ impl VectorStore { }) .detach(); - // cx.background() - // .scoped(|scope| { - // for _ in 0..cx.background().num_cpus() { - // scope.spawn(async { - // let mut parser = Parser::new(); - // let mut cursor = QueryCursor::new(); - // while let Ok((worktree_id, file_path, language, mtime)) = - // parsing_files_rx.recv().await - // { - // log::info!("Parsing File: {:?}", &file_path); - // if let Some((indexed_file, document_spans)) = Self::index_file( - // &mut cursor, - // &mut parser, - // &fs, - // language, - // file_path.clone(), - // mtime, - // ) - // .await - // .log_err() - // { - // batch_files_tx - // .try_send((worktree_id, indexed_file, document_spans)) - // .unwrap(); - // } - // } - // }); - // } - // }) - // .await; - this.update(&mut cx, |this, cx| { // The below is managing for updated on save // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is @@ -592,90 +551,90 @@ impl VectorStore { if let Some(project_state) = this.projects.get(&project.downgrade()) { let worktree_db_ids = project_state.worktree_db_ids.clone(); - // if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event - // { - // // Iterate through changes - // let language_registry = this.language_registry.clone(); + if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event + { + // Iterate through changes + let language_registry = this.language_registry.clone(); - // let db = - // VectorDatabase::new(this.database_url.to_string_lossy().into()); - // if db.is_err() { - // return; - // } - // let db = db.unwrap(); + let db = + VectorDatabase::new(this.database_url.to_string_lossy().into()); + if db.is_err() { + return; + } + let db = db.unwrap(); - // let worktree_db_id: Option = { - // let mut found_db_id = None; - // for (w_id, db_id) in worktree_db_ids.into_iter() { - // if &w_id == worktree_id { - // found_db_id 
= Some(db_id); - // } - // } + let worktree_db_id: Option = { + let mut found_db_id = None; + for (w_id, db_id) in worktree_db_ids.into_iter() { + if &w_id == worktree_id { + found_db_id = Some(db_id); + } + } - // found_db_id - // }; + found_db_id + }; - // if worktree_db_id.is_none() { - // return; - // } - // let worktree_db_id = worktree_db_id.unwrap(); + if worktree_db_id.is_none() { + return; + } + let worktree_db_id = worktree_db_id.unwrap(); - // let file_mtimes = db.get_file_mtimes(worktree_db_id); - // if file_mtimes.is_err() { - // return; - // } + let file_mtimes = db.get_file_mtimes(worktree_db_id); + if file_mtimes.is_err() { + return; + } - // let file_mtimes = file_mtimes.unwrap(); - // let paths_tx = this.paths_tx.clone(); + let file_mtimes = file_mtimes.unwrap(); + let parsing_files_tx = this.parsing_files_tx.clone(); - // smol::block_on(async move { - // for change in changes.into_iter() { - // let change_path = change.0.clone(); - // log::info!("Change: {:?}", &change_path); - // if let Ok(language) = language_registry - // .language_for_file(&change_path.to_path_buf(), None) - // .await - // { - // if language - // .grammar() - // .and_then(|grammar| grammar.embedding_config.as_ref()) - // .is_none() - // { - // continue; - // } + smol::block_on(async move { + for change in changes.into_iter() { + let change_path = change.0.clone(); + log::info!("Change: {:?}", &change_path); + if let Ok(language) = language_registry + .language_for_file(&change_path.to_path_buf(), None) + .await + { + if language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } - // // TODO: Make this a bit more defensive - // let modified_time = - // change_path.metadata().unwrap().modified().unwrap(); - // let existing_time = - // file_mtimes.get(&change_path.to_path_buf()); - // let already_stored = - // existing_time.map_or(false, |existing_time| { - // if &modified_time != existing_time - // && existing_time.elapsed().unwrap().as_secs() - // > REINDEXING_DELAY - // { - // false - // } else { - // true - // } - // }); + // TODO: Make this a bit more defensive + let modified_time = + change_path.metadata().unwrap().modified().unwrap(); + let existing_time = + file_mtimes.get(&change_path.to_path_buf()); + let already_stored = + existing_time.map_or(false, |existing_time| { + if &modified_time != existing_time + && existing_time.elapsed().unwrap().as_secs() + > REINDEXING_DELAY + { + false + } else { + true + } + }); - // if !already_stored { - // log::info!("Need to reindex: {:?}", &change_path); - // paths_tx - // .try_send(( - // worktree_db_id, - // change_path.to_path_buf(), - // language, - // modified_time, - // )) - // .unwrap(); - // } - // } - // } - // }) - // } + if !already_stored { + log::info!("Need to reindex: {:?}", &change_path); + parsing_files_tx + .try_send(( + worktree_db_id, + change_path.to_path_buf(), + language, + modified_time, + )) + .unwrap(); + } + } + } + }) + } } }); From e57f6f21fe11e1bb585202c346cb5c28360c935f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 6 Jul 2023 15:26:43 -0400 Subject: [PATCH 30/51] reindexing update to appropriately accomodate for buffer delay and persistent pending files list --- crates/vector_store/src/vector_store.rs | 234 ++++++++++++++++-------- 1 file changed, 160 insertions(+), 74 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 5141451e64..57277e39af 100644 --- a/crates/vector_store/src/vector_store.rs +++ 
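Restated as a standalone predicate, the save-time check re-enabled above queues a file for reindexing only when its on-disk mtime differs from the stored one and the stored copy has aged past the delay. A sketch (hedging with map_or where the original unwraps):

use std::time::{Duration, SystemTime};

// `stored` is the mtime previously recorded in the embeddings database.
fn needs_reindex(modified: SystemTime, stored: Option<&SystemTime>, delay: Duration) -> bool {
    match stored {
        None => true, // never indexed before
        Some(stored) => {
            modified != *stored && stored.elapsed().map_or(false, |age| age > delay)
        }
    }
}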
b/crates/vector_store/src/vector_store.rs @@ -18,11 +18,13 @@ use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; use project::{Fs, Project, WorktreeId}; use smol::channel; use std::{ + cell::RefCell, cmp::Ordering, collections::HashMap, path::{Path, PathBuf}, + rc::Rc, sync::Arc, - time::{Instant, SystemTime}, + time::{Duration, Instant, SystemTime}, }; use tree_sitter::{Parser, QueryCursor}; use util::{ @@ -33,7 +35,7 @@ use util::{ }; use workspace::{Workspace, WorkspaceCreated}; -const REINDEXING_DELAY: u64 = 30; +const REINDEXING_DELAY_SECONDS: u64 = 3; const EMBEDDINGS_BATCH_SIZE: usize = 150; #[derive(Debug, Clone)] @@ -124,20 +126,62 @@ pub struct VectorStore { embedding_provider: Arc, language_registry: Arc, db_update_tx: channel::Sender, - // embed_batch_tx: channel::Sender)>>, - parsing_files_tx: channel::Sender<(i64, PathBuf, Arc, SystemTime)>, + parsing_files_tx: channel::Sender, _db_update_task: Task<()>, _embed_batch_task: Vec>, _batch_files_task: Task<()>, _parsing_files_tasks: Vec>, - projects: HashMap, ProjectState>, + projects: HashMap, Rc>>, } struct ProjectState { worktree_db_ids: Vec<(WorktreeId, i64)>, + pending_files: HashMap, _subscription: gpui::Subscription, } +impl ProjectState { + fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) { + // If Pending File Already Exists, Replace it with the new one + // but keep the old indexing time + if let Some(old_file) = self.pending_files.remove(&pending_file.path.clone()) { + self.pending_files + .insert(pending_file.path.clone(), (pending_file, old_file.1)); + } else { + self.pending_files + .insert(pending_file.path.clone(), (pending_file, indexing_time)); + }; + } + + fn get_outstanding_files(&mut self) -> Vec { + let mut outstanding_files = vec![]; + let mut remove_keys = vec![]; + for key in self.pending_files.keys().into_iter() { + if let Some(pending_details) = self.pending_files.get(key) { + let (pending_file, index_time) = pending_details; + if index_time <= &SystemTime::now() { + outstanding_files.push(pending_file.clone()); + remove_keys.push(key.clone()); + } + } + } + + for key in remove_keys.iter() { + self.pending_files.remove(key); + } + + return outstanding_files; + } +} + +#[derive(Clone, Debug)] +struct PendingFile { + worktree_db_id: i64, + path: PathBuf, + language: Arc, + modified_time: SystemTime, +} + #[derive(Debug, Clone)] pub struct SearchResult { pub worktree_id: WorktreeId, @@ -293,8 +337,7 @@ impl VectorStore { }); // parsing_files_tx/rx: Parsing Files to Embeddable Documents - let (parsing_files_tx, parsing_files_rx) = - channel::unbounded::<(i64, PathBuf, Arc, SystemTime)>(); + let (parsing_files_tx, parsing_files_rx) = channel::unbounded::(); let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { @@ -304,23 +347,25 @@ impl VectorStore { _parsing_files_tasks.push(cx.background().spawn(async move { let mut parser = Parser::new(); let mut cursor = QueryCursor::new(); - while let Ok((worktree_id, file_path, language, mtime)) = - parsing_files_rx.recv().await - { - log::info!("Parsing File: {:?}", &file_path); + while let Ok(pending_file) = parsing_files_rx.recv().await { + log::info!("Parsing File: {:?}", &pending_file.path); if let Some((indexed_file, document_spans)) = Self::index_file( &mut cursor, &mut parser, &fs, - language, - file_path.clone(), - mtime, + pending_file.language, + pending_file.path.clone(), + pending_file.modified_time, ) .await .log_err() { batch_files_tx - .try_send((worktree_id, indexed_file, 
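Note the debounce semantics of update_pending_files above: a re-save replaces the queued PendingFile but keeps the original deadline, so rapid saves cannot postpone indexing indefinitely, and get_outstanding_files drains whatever has come due. A reduced sketch of that bookkeeping, with PendingFile shrunk to just a path:

use std::{collections::HashMap, path::PathBuf, time::SystemTime};

#[derive(Clone)]
struct Pending {
    path: PathBuf,
}

struct PendingQueue {
    // path -> (queued job, deadline after which it may be indexed)
    pending: HashMap<PathBuf, (Pending, SystemTime)>,
}

impl PendingQueue {
    fn upsert(&mut self, job: Pending, deadline: SystemTime) {
        // Keep the earlier deadline if this file is already queued.
        let deadline = match self.pending.remove(&job.path) {
            Some((_, existing)) => existing,
            None => deadline,
        };
        self.pending.insert(job.path.clone(), (job, deadline));
    }

    fn drain_due(&mut self) -> Vec<Pending> {
        let now = SystemTime::now();
        let due: Vec<PathBuf> = self
            .pending
            .iter()
            .filter(|(_, (_, deadline))| *deadline <= now)
            .map(|(path, _)| path.clone())
            .collect();
        due.into_iter()
            .filter_map(|path| self.pending.remove(&path))
            .map(|(job, _)| job)
            .collect()
    }
}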
document_spans)) + .try_send(( + pending_file.worktree_db_id, + indexed_file, + document_spans, + )) .unwrap(); } } @@ -516,12 +561,13 @@ impl VectorStore { if !already_stored { parsing_files_tx - .try_send(( - db_ids_by_worktree_id[&worktree.id()], - path_buf, + .try_send(PendingFile { + worktree_db_id: db_ids_by_worktree_id + [&worktree.id()], + path: path_buf, language, - file.mtime, - )) + modified_time: file.mtime, + }) .unwrap(); } } @@ -543,54 +589,82 @@ impl VectorStore { }) .detach(); + // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc, SystemTime), SystemTime))> = vec![]; this.update(&mut cx, |this, cx| { // The below is managing for updated on save // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed. - let _subscription = cx.subscribe(&project, |this, project, event, _cx| { + let _subscription = cx.subscribe(&project, |this, project, event, cx| { if let Some(project_state) = this.projects.get(&project.downgrade()) { + let mut project_state = project_state.borrow_mut(); let worktree_db_ids = project_state.worktree_db_ids.clone(); if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { - // Iterate through changes - let language_registry = this.language_registry.clone(); - - let db = - VectorDatabase::new(this.database_url.to_string_lossy().into()); - if db.is_err() { + // Get Worktree Object + let worktree = + project.read(cx).worktree_for_id(worktree_id.clone(), cx); + if worktree.is_none() { return; } - let db = db.unwrap(); + let worktree = worktree.unwrap(); - let worktree_db_id: Option = { - let mut found_db_id = None; - for (w_id, db_id) in worktree_db_ids.into_iter() { - if &w_id == worktree_id { - found_db_id = Some(db_id); + // Get Database + let db_values = { + if let Ok(db) = + VectorDatabase::new(this.database_url.to_string_lossy().into()) + { + let worktree_db_id: Option = { + let mut found_db_id = None; + for (w_id, db_id) in worktree_db_ids.into_iter() { + if &w_id == &worktree.read(cx).id() { + found_db_id = Some(db_id) + } + } + found_db_id + }; + if worktree_db_id.is_none() { + return; } - } + let worktree_db_id = worktree_db_id.unwrap(); - found_db_id + let file_mtimes = db.get_file_mtimes(worktree_db_id); + if file_mtimes.is_err() { + return; + } + + let file_mtimes = file_mtimes.unwrap(); + Some((file_mtimes, worktree_db_id)) + } else { + return; + } }; - if worktree_db_id.is_none() { - return; - } - let worktree_db_id = worktree_db_id.unwrap(); - - let file_mtimes = db.get_file_mtimes(worktree_db_id); - if file_mtimes.is_err() { + if db_values.is_none() { return; } - let file_mtimes = file_mtimes.unwrap(); + let (file_mtimes, worktree_db_id) = db_values.unwrap(); + + // Iterate Through Changes + let language_registry = this.language_registry.clone(); let parsing_files_tx = this.parsing_files_tx.clone(); smol::block_on(async move { for change in changes.into_iter() { let change_path = change.0.clone(); - log::info!("Change: {:?}", &change_path); + // Skip if git ignored or symlink + if let Some(entry) = worktree.read(cx).entry_for_id(change.1) { + if entry.is_ignored || entry.is_symlink { + continue; + } else { + log::info!( + "Testing for Reindexing: {:?}", + &change_path + ); + } + }; + if let Ok(language) = language_registry .language_for_file(&change_path.to_path_buf(), None) .await @@ -603,47 +677,59 @@ impl VectorStore { continue; } - // 
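At this point the pipeline reads: worktree scan feeds parsing_files, one parser task per CPU feeds batch_files, a single batching task feeds embed_batch, and results land on db_update, with every hop an unbounded smol channel. A minimal sketch of the parser fan-out, assuming plain smol::spawn in place of gpui's background executor:

use smol::channel;

// One shared receiver, N workers; each worker owns its own parser state,
// since tree-sitter parsers and cursors are not shared across tasks.
fn spawn_parsers(num_cpus: usize) -> channel::Sender<String> {
    let (tx, rx) = channel::unbounded::<String>();
    for _ in 0..num_cpus {
        let rx = rx.clone();
        smol::spawn(async move {
            while let Ok(path) = rx.recv().await {
                // parse `path` here, then forward results to the batching stage
                let _ = path;
            }
        })
        .detach();
    }
    tx
}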
TODO: Make this a bit more defensive - let modified_time = - change_path.metadata().unwrap().modified().unwrap(); - let existing_time = - file_mtimes.get(&change_path.to_path_buf()); - let already_stored = - existing_time.map_or(false, |existing_time| { - if &modified_time != existing_time - && existing_time.elapsed().unwrap().as_secs() - > REINDEXING_DELAY - { - false + if let Some(modified_time) = { + let metadata = change_path.metadata(); + if metadata.is_err() { + None + } else { + let mtime = metadata.unwrap().modified(); + if mtime.is_err() { + None } else { - true + Some(mtime.unwrap()) } - }); + } + } { + let existing_time = + file_mtimes.get(&change_path.to_path_buf()); + let already_stored = existing_time + .map_or(false, |existing_time| { + &modified_time != existing_time + }); - if !already_stored { - log::info!("Need to reindex: {:?}", &change_path); - parsing_files_tx - .try_send(( - worktree_db_id, - change_path.to_path_buf(), - language, - modified_time, - )) - .unwrap(); + let reindex_time = modified_time + + Duration::from_secs(REINDEXING_DELAY_SECONDS); + + if !already_stored { + project_state.update_pending_files( + PendingFile { + path: change_path.to_path_buf(), + modified_time, + worktree_db_id, + language: language.clone(), + }, + reindex_time, + ); + + for file in project_state.get_outstanding_files() { + parsing_files_tx.try_send(file).unwrap(); + } + } } } } - }) - } + }); + }; } }); this.projects.insert( project.downgrade(), - ProjectState { + Rc::new(RefCell::new(ProjectState { + pending_files: HashMap::new(), worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(), _subscription, - }, + })), ); }); @@ -659,7 +745,7 @@ impl VectorStore { cx: &mut ModelContext, ) -> Task>> { let project_state = if let Some(state) = self.projects.get(&project.downgrade()) { - state + state.borrow() } else { return Task::ready(Err(anyhow!("project not added"))); }; @@ -717,7 +803,7 @@ impl VectorStore { this.read_with(&cx, |this, _| { let project_state = if let Some(state) = this.projects.get(&project.downgrade()) { - state + state.borrow() } else { return Err(anyhow!("project not added")); }; From 7d634f66e2b2b7196b3e9141c0664d8641251323 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 6 Jul 2023 16:33:54 -0400 Subject: [PATCH 31/51] updated vector_store to include extra context for semantic search modal --- crates/vector_store/src/vector_store.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 57277e39af..065dfb51f2 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -418,21 +418,34 @@ impl VectorStore { ) { let mut item_range = None; let mut name_range = None; + let mut context_range = None; for capture in mat.captures { if capture.index == embedding_config.item_capture_ix { item_range = Some(capture.node.byte_range()); } else if capture.index == embedding_config.name_capture_ix { name_range = Some(capture.node.byte_range()); } + if let Some(context_capture_ix) = embedding_config.context_capture_ix { + if capture.index == context_capture_ix { + context_range = Some(capture.node.byte_range()); + } + } } if let Some((item_range, name_range)) = item_range.zip(name_range) { + let mut context_data = String::new(); + if let Some(context_range) = context_range { + if let Some(context) = content.get(context_range.clone()) { + context_data.push_str(context); + } + } + if let Some((item, name)) = 
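// Note on the @context capture handled above: context_range is a single
// Option, so when a pattern defines several @context captures only the last
// one matched in this loop survives. With the Python query added in the next
// patch, a hypothetical `class Foo:` yields a document named "class Foo" and
// `def bar():` one named "def bar", via the format!("{} {}", context, name)
// below.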
content.get(item_range.clone()).zip(content.get(name_range)) { context_spans.push(item.to_string()); documents.push(Document { - name: name.to_string(), + name: format!("{} {}", context_data.to_string(), name.to_string()), offset: item_range.start, embedding: Vec::new(), }); From 6f1e988cb92aa76bc31c841e0009884576370219 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 6 Jul 2023 16:36:28 -0400 Subject: [PATCH 32/51] updated embedding treesitter query for python --- crates/zed/src/languages/python/embedding.scm | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 crates/zed/src/languages/python/embedding.scm diff --git a/crates/zed/src/languages/python/embedding.scm b/crates/zed/src/languages/python/embedding.scm new file mode 100644 index 0000000000..e3efb3dbf6 --- /dev/null +++ b/crates/zed/src/languages/python/embedding.scm @@ -0,0 +1,9 @@ +(class_definition + "class" @context + name: (identifier) @name + ) @item + +(function_definition + "async"? @context + "def" @context + name: (_) @name) @item From c03dda1a0cc9f99f841622dc95358e8b9dc39ea8 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 6 Jul 2023 17:15:41 -0400 Subject: [PATCH 33/51] fixed bug on absolute vs relative path --- crates/vector_store/src/vector_store.rs | 39 ++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 065dfb51f2..baab05bec2 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -144,12 +144,19 @@ impl ProjectState { fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) { // If Pending File Already Exists, Replace it with the new one // but keep the old indexing time - if let Some(old_file) = self.pending_files.remove(&pending_file.path.clone()) { - self.pending_files - .insert(pending_file.path.clone(), (pending_file, old_file.1)); + if let Some(old_file) = self + .pending_files + .remove(&pending_file.relative_path.clone()) + { + self.pending_files.insert( + pending_file.relative_path.clone(), + (pending_file, old_file.1), + ); } else { - self.pending_files - .insert(pending_file.path.clone(), (pending_file, indexing_time)); + self.pending_files.insert( + pending_file.relative_path.clone(), + (pending_file, indexing_time), + ); }; } @@ -177,7 +184,8 @@ impl ProjectState { #[derive(Clone, Debug)] struct PendingFile { worktree_db_id: i64, - path: PathBuf, + relative_path: PathBuf, + absolute_path: PathBuf, language: Arc, modified_time: SystemTime, } @@ -348,13 +356,14 @@ impl VectorStore { let mut parser = Parser::new(); let mut cursor = QueryCursor::new(); while let Ok(pending_file) = parsing_files_rx.recv().await { - log::info!("Parsing File: {:?}", &pending_file.path); + log::info!("Parsing File: {:?}", &pending_file.relative_path); if let Some((indexed_file, document_spans)) = Self::index_file( &mut cursor, &mut parser, &fs, pending_file.language, - pending_file.path.clone(), + pending_file.relative_path.clone(), + pending_file.absolute_path.clone(), pending_file.modified_time, ) .await @@ -393,7 +402,8 @@ impl VectorStore { parser: &mut Parser, fs: &Arc, language: Arc, - file_path: PathBuf, + relative_file_path: PathBuf, + absolute_file_path: PathBuf, mtime: SystemTime, ) -> Result<(IndexedFile, Vec)> { let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; @@ -402,7 +412,7 @@ impl VectorStore { .as_ref() .ok_or_else(|| anyhow!("no outline query"))?; - let content = 
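// The relative/absolute split introduced here makes the worktree-relative
// path the stable database key while the absolute path is used only for
// filesystem IO. A sketch of the relationship, assuming absolutize() is
// effectively root.join(relative):
//
//     let relative = PathBuf::from("src/main.rs");  // stored in the files table
//     let absolute = worktree_root.join(&relative); // handed to fs.load()
//
// Keeping rows keyed by relative paths means they stay valid if the project
// directory is moved or opened from another location.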
fs.load(&file_path).await?; + let content = fs.load(&absolute_file_path).await?; parser.set_language(grammar.ts_language).unwrap(); let tree = parser @@ -455,7 +465,7 @@ impl VectorStore { return Ok(( IndexedFile { - path: file_path, + path: relative_file_path, mtime, documents, }, @@ -577,7 +587,8 @@ impl VectorStore { .try_send(PendingFile { worktree_db_id: db_ids_by_worktree_id [&worktree.id()], - path: path_buf, + relative_path: path_buf, + absolute_path, language, modified_time: file.mtime, }) @@ -666,6 +677,7 @@ impl VectorStore { smol::block_on(async move { for change in changes.into_iter() { let change_path = change.0.clone(); + let absolute_path = worktree.read(cx).absolutize(&change_path); // Skip if git ignored or symlink if let Some(entry) = worktree.read(cx).entry_for_id(change.1) { if entry.is_ignored || entry.is_symlink { @@ -716,7 +728,8 @@ impl VectorStore { if !already_stored { project_state.update_pending_files( PendingFile { - path: change_path.to_path_buf(), + relative_path: change_path.to_path_buf(), + absolute_path, modified_time, worktree_db_id, language: language.clone(), From 01897424979b51cd4f5cf52dd909c807e229324f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 10 Jul 2023 10:06:07 -0400 Subject: [PATCH 34/51] pulled treesitter parsing to own file for ease of testing and management --- crates/vector_store/src/db.rs | 4 +- crates/vector_store/src/parsing.rs | 94 ++++++++++++++++++ crates/vector_store/src/vector_store.rs | 121 +++--------------------- 3 files changed, 110 insertions(+), 109 deletions(-) create mode 100644 crates/vector_store/src/parsing.rs diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index f822cca77e..4882db443b 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -7,7 +7,7 @@ use std::{ use anyhow::{anyhow, Result}; -use crate::IndexedFile; +use crate::parsing::ParsedFile; use rpc::proto::Timestamp; use rusqlite::{ params, @@ -109,7 +109,7 @@ impl VectorDatabase { Ok(()) } - pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> { + pub fn insert_file(&self, worktree_id: i64, indexed_file: ParsedFile) -> Result<()> { // Write to files table, and return generated id. 
self.db.execute( " diff --git a/crates/vector_store/src/parsing.rs b/crates/vector_store/src/parsing.rs new file mode 100644 index 0000000000..6a8742fedd --- /dev/null +++ b/crates/vector_store/src/parsing.rs @@ -0,0 +1,94 @@ +use std::{ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; + +use anyhow::{anyhow, Ok, Result}; +use project::Fs; +use tree_sitter::{Parser, QueryCursor}; + +use crate::PendingFile; + +#[derive(Debug, PartialEq, Clone)] +pub struct Document { + pub offset: usize, + pub name: String, + pub embedding: Vec, +} + +#[derive(Debug, PartialEq, Clone)] +pub struct ParsedFile { + pub path: PathBuf, + pub mtime: SystemTime, + pub documents: Vec, +} + +pub struct CodeContextRetriever { + pub parser: Parser, + pub cursor: QueryCursor, + pub fs: Arc, +} + +impl CodeContextRetriever { + pub async fn parse_file( + &mut self, + pending_file: PendingFile, + ) -> Result<(ParsedFile, Vec)> { + let grammar = pending_file + .language + .grammar() + .ok_or_else(|| anyhow!("no grammar for language"))?; + let embedding_config = grammar + .embedding_config + .as_ref() + .ok_or_else(|| anyhow!("no embedding queries"))?; + + let content = self.fs.load(&pending_file.absolute_path).await?; + + self.parser.set_language(grammar.ts_language).unwrap(); + + let tree = self + .parser + .parse(&content, None) + .ok_or_else(|| anyhow!("parsing failed"))?; + + let mut documents = Vec::new(); + let mut context_spans = Vec::new(); + + // Iterate through query matches + for mat in self.cursor.matches( + &embedding_config.query, + tree.root_node(), + content.as_bytes(), + ) { + let mut item_range: Option> = None; + let mut name_range: Option> = None; + for capture in mat.captures { + if capture.index == embedding_config.item_capture_ix { + item_range = Some(capture.node.byte_range()); + } else if capture.index == embedding_config.name_capture_ix { + name_range = Some(capture.node.byte_range()); + } + } + + if let Some((item_range, name_range)) = item_range.zip(name_range) { + if let Some((item, name)) = + content.get(item_range.clone()).zip(content.get(name_range)) + { + context_spans.push(item.to_string()); + documents.push(Document { + name: name.to_string(), + offset: item_range.start, + embedding: Vec::new(), + }); + } + } + } + + return Ok(( + ParsedFile { + path: pending_file.relative_path, + mtime: pending_file.modified_time, + documents, + }, + context_spans, + )); + } +} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index baab05bec2..92557fd801 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,6 +1,7 @@ mod db; mod embedding; mod modal; +mod parsing; #[cfg(test)] mod vector_store_tests; @@ -15,6 +16,7 @@ use gpui::{ }; use language::{Language, LanguageRegistry}; use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; +use parsing::{CodeContextRetriever, ParsedFile}; use project::{Fs, Project, WorktreeId}; use smol::channel; use std::{ @@ -38,13 +40,6 @@ use workspace::{Workspace, WorkspaceCreated}; const REINDEXING_DELAY_SECONDS: u64 = 3; const EMBEDDINGS_BATCH_SIZE: usize = 150; -#[derive(Debug, Clone)] -pub struct Document { - pub offset: usize, - pub name: String, - pub embedding: Vec, -} - pub fn init( fs: Arc, http_client: Arc, @@ -113,13 +108,6 @@ pub fn init( .detach(); } -#[derive(Debug, Clone)] -pub struct IndexedFile { - path: PathBuf, - mtime: SystemTime, - documents: Vec, -} - pub struct VectorStore { fs: Arc, database_url: Arc, @@ -182,7 +170,7 @@ impl ProjectState { } 
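With parsing pulled into its own module, a worker owns one CodeContextRetriever and feeds it pending files, which is also what makes the parser unit-testable in isolation. A sketch of the call shape inside this crate (not a real test):

use anyhow::Result;
use std::sync::Arc;
use tree_sitter::{Parser, QueryCursor};

async fn parse_one(fs: Arc<dyn project::Fs>, pending_file: PendingFile) -> Result<()> {
    let mut retriever = CodeContextRetriever {
        parser: Parser::new(),
        cursor: QueryCursor::new(),
        fs,
    };
    let (parsed_file, spans) = retriever.parse_file(pending_file).await?;
    // parse_file pushes one source span per extracted document, in order.
    assert_eq!(parsed_file.documents.len(), spans.len());
    Ok(())
}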
#[derive(Clone, Debug)] -struct PendingFile { +pub struct PendingFile { worktree_db_id: i64, relative_path: PathBuf, absolute_path: PathBuf, @@ -201,7 +189,7 @@ pub struct SearchResult { enum DbWrite { InsertFile { worktree_id: i64, - indexed_file: IndexedFile, + indexed_file: ParsedFile, }, Delete { worktree_id: i64, @@ -267,7 +255,7 @@ impl VectorStore { // embed_tx/rx: Embed Batch and Send to Database let (embed_batch_tx, embed_batch_rx) = - channel::unbounded::)>>(); + channel::unbounded::)>>(); let mut _embed_batch_task = Vec::new(); for _ in 0..1 { //cx.background().num_cpus() { @@ -324,13 +312,14 @@ impl VectorStore { // batch_tx/rx: Batch Files to Send for Embeddings let (batch_files_tx, batch_files_rx) = - channel::unbounded::<(i64, IndexedFile, Vec)>(); + channel::unbounded::<(i64, ParsedFile, Vec)>(); let _batch_files_task = cx.background().spawn(async move { let mut queue_len = 0; let mut embeddings_queue = vec![]; while let Ok((worktree_id, indexed_file, document_spans)) = batch_files_rx.recv().await { + dbg!("Batching in while loop"); queue_len += &document_spans.len(); embeddings_queue.push((worktree_id, indexed_file, document_spans)); if queue_len >= EMBEDDINGS_BATCH_SIZE { @@ -339,6 +328,7 @@ impl VectorStore { queue_len = 0; } } + // TODO: This is never getting called, We've gotta manage for how to clear the embedding batch if its less than the necessary batch size. if queue_len > 0 { embed_batch_tx.try_send(embeddings_queue).unwrap(); } @@ -353,21 +343,14 @@ impl VectorStore { let parsing_files_rx = parsing_files_rx.clone(); let batch_files_tx = batch_files_tx.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { - let mut parser = Parser::new(); - let mut cursor = QueryCursor::new(); + let parser = Parser::new(); + let cursor = QueryCursor::new(); + let mut retriever = CodeContextRetriever { parser, cursor, fs }; while let Ok(pending_file) = parsing_files_rx.recv().await { log::info!("Parsing File: {:?}", &pending_file.relative_path); - if let Some((indexed_file, document_spans)) = Self::index_file( - &mut cursor, - &mut parser, - &fs, - pending_file.language, - pending_file.relative_path.clone(), - pending_file.absolute_path.clone(), - pending_file.modified_time, - ) - .await - .log_err() + + if let Some((indexed_file, document_spans)) = + retriever.parse_file(pending_file.clone()).await.log_err() { batch_files_tx .try_send(( @@ -397,82 +380,6 @@ impl VectorStore { })) } - async fn index_file( - cursor: &mut QueryCursor, - parser: &mut Parser, - fs: &Arc, - language: Arc, - relative_file_path: PathBuf, - absolute_file_path: PathBuf, - mtime: SystemTime, - ) -> Result<(IndexedFile, Vec)> { - let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; - let embedding_config = grammar - .embedding_config - .as_ref() - .ok_or_else(|| anyhow!("no outline query"))?; - - let content = fs.load(&absolute_file_path).await?; - - parser.set_language(grammar.ts_language).unwrap(); - let tree = parser - .parse(&content, None) - .ok_or_else(|| anyhow!("parsing failed"))?; - - let mut documents = Vec::new(); - let mut context_spans = Vec::new(); - for mat in cursor.matches( - &embedding_config.query, - tree.root_node(), - content.as_bytes(), - ) { - let mut item_range = None; - let mut name_range = None; - let mut context_range = None; - for capture in mat.captures { - if capture.index == embedding_config.item_capture_ix { - item_range = Some(capture.node.byte_range()); - } else if capture.index == embedding_config.name_capture_ix { - name_range = 
Some(capture.node.byte_range()); - } - if let Some(context_capture_ix) = embedding_config.context_capture_ix { - if capture.index == context_capture_ix { - context_range = Some(capture.node.byte_range()); - } - } - } - - if let Some((item_range, name_range)) = item_range.zip(name_range) { - let mut context_data = String::new(); - if let Some(context_range) = context_range { - if let Some(context) = content.get(context_range.clone()) { - context_data.push_str(context); - } - } - - if let Some((item, name)) = - content.get(item_range.clone()).zip(content.get(name_range)) - { - context_spans.push(item.to_string()); - documents.push(Document { - name: format!("{} {}", context_data.to_string(), name.to_string()), - offset: item_range.start, - embedding: Vec::new(), - }); - } - } - } - - return Ok(( - IndexedFile { - path: relative_file_path, - mtime, - documents, - }, - context_spans, - )); - } - fn find_or_create_worktree(&self, path: PathBuf) -> impl Future> { let (tx, rx) = oneshot::channel(); self.db_update_tx From 82079dd422613b98c8b1c6edfedaac1187ab2536 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 10 Jul 2023 16:33:14 -0400 Subject: [PATCH 35/51] Updated batching to accomodate for full flushes, and cleaned up reindexing. Co-authored-by: maxbrunsfeld --- crates/vector_store/src/embedding.rs | 4 +- crates/vector_store/src/vector_store.rs | 300 ++++++++++++------------ 2 files changed, 150 insertions(+), 154 deletions(-) diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index 029a6cdf61..ea349c8afa 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; +use gpui::executor::Background; use gpui::serde_json; use isahc::http::StatusCode; use isahc::prelude::Configurable; @@ -21,6 +22,7 @@ lazy_static! 
{ #[derive(Clone)] pub struct OpenAIEmbeddings { pub client: Arc, + pub executor: Arc, } #[derive(Serialize)] @@ -128,7 +130,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { match response.status() { StatusCode::TOO_MANY_REQUESTS => { let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); - std::thread::sleep(delay); + self.executor.timer(delay).await; } StatusCode::BAD_REQUEST => { log::info!("BAD REQUEST: {:?}", &response.status()); diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 92557fd801..c27c4992f3 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -17,14 +17,12 @@ use gpui::{ use language::{Language, LanguageRegistry}; use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; use parsing::{CodeContextRetriever, ParsedFile}; -use project::{Fs, Project, WorktreeId}; +use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId}; use smol::channel; use std::{ - cell::RefCell, cmp::Ordering, collections::HashMap, path::{Path, PathBuf}, - rc::Rc, sync::Arc, time::{Duration, Instant, SystemTime}, }; @@ -61,6 +59,7 @@ pub fn init( // Arc::new(embedding::DummyEmbeddings {}), Arc::new(OpenAIEmbeddings { client: http_client, + executor: cx.background(), }), language_registry, cx.clone(), @@ -119,7 +118,7 @@ pub struct VectorStore { _embed_batch_task: Vec>, _batch_files_task: Task<()>, _parsing_files_tasks: Vec>, - projects: HashMap, Rc>>, + projects: HashMap, ProjectState>, } struct ProjectState { @@ -201,6 +200,15 @@ enum DbWrite { }, } +enum EmbeddingJob { + Enqueue { + worktree_id: i64, + parsed_file: ParsedFile, + document_spans: Vec, + }, + Flush, +} + impl VectorStore { async fn new( fs: Arc, @@ -309,29 +317,32 @@ impl VectorStore { } })) } - // batch_tx/rx: Batch Files to Send for Embeddings - let (batch_files_tx, batch_files_rx) = - channel::unbounded::<(i64, ParsedFile, Vec)>(); + let (batch_files_tx, batch_files_rx) = channel::unbounded::(); let _batch_files_task = cx.background().spawn(async move { let mut queue_len = 0; let mut embeddings_queue = vec![]; - while let Ok((worktree_id, indexed_file, document_spans)) = - batch_files_rx.recv().await - { - dbg!("Batching in while loop"); - queue_len += &document_spans.len(); - embeddings_queue.push((worktree_id, indexed_file, document_spans)); - if queue_len >= EMBEDDINGS_BATCH_SIZE { + + while let Ok(job) = batch_files_rx.recv().await { + let should_flush = match job { + EmbeddingJob::Enqueue { + document_spans, + worktree_id, + parsed_file, + } => { + queue_len += &document_spans.len(); + embeddings_queue.push((worktree_id, parsed_file, document_spans)); + queue_len >= EMBEDDINGS_BATCH_SIZE + } + EmbeddingJob::Flush => true, + }; + + if should_flush { embed_batch_tx.try_send(embeddings_queue).unwrap(); embeddings_queue = vec![]; queue_len = 0; } } - // TODO: This is never getting called, We've gotta manage for how to clear the embedding batch if its less than the necessary batch size. 
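The Flush job resolves the TODO removed above: parser workers emit EmbeddingJob::Flush whenever the parsing queue drains, so a final sub-batch-size tail no longer sits in embeddings_queue indefinitely. A self-contained sketch of the batcher's decision, with one guard against empty flushes that the original omits:

enum Job<T> {
    Enqueue { spans: usize, payload: T },
    Flush,
}

fn run_batcher<T>(jobs: Vec<Job<T>>, batch_size: usize) -> Vec<Vec<T>> {
    let mut batches = Vec::new();
    let mut queue = Vec::new();
    let mut queued_spans = 0;
    for job in jobs {
        let flush = match job {
            Job::Enqueue { spans, payload } => {
                queued_spans += spans;
                queue.push(payload);
                queued_spans >= batch_size
            }
            Job::Flush => true,
        };
        if flush && !queue.is_empty() {
            batches.push(std::mem::take(&mut queue));
            queued_spans = 0;
        }
    }
    batches
}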
- if queue_len > 0 { - embed_batch_tx.try_send(embeddings_queue).unwrap(); - } }); // parsing_files_tx/rx: Parsing Files to Embeddable Documents @@ -353,13 +364,17 @@ impl VectorStore { retriever.parse_file(pending_file.clone()).await.log_err() { batch_files_tx - .try_send(( - pending_file.worktree_db_id, - indexed_file, + .try_send(EmbeddingJob::Enqueue { + worktree_id: pending_file.worktree_db_id, + parsed_file: indexed_file, document_spans, - )) + }) .unwrap(); } + + if parsing_files_rx.len() == 0 { + batch_files_tx.try_send(EmbeddingJob::Flush).unwrap(); + } } })); } @@ -526,143 +541,18 @@ impl VectorStore { // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed. let _subscription = cx.subscribe(&project, |this, project, event, cx| { - if let Some(project_state) = this.projects.get(&project.downgrade()) { - let mut project_state = project_state.borrow_mut(); - let worktree_db_ids = project_state.worktree_db_ids.clone(); - - if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event - { - // Get Worktree Object - let worktree = - project.read(cx).worktree_for_id(worktree_id.clone(), cx); - if worktree.is_none() { - return; - } - let worktree = worktree.unwrap(); - - // Get Database - let db_values = { - if let Ok(db) = - VectorDatabase::new(this.database_url.to_string_lossy().into()) - { - let worktree_db_id: Option = { - let mut found_db_id = None; - for (w_id, db_id) in worktree_db_ids.into_iter() { - if &w_id == &worktree.read(cx).id() { - found_db_id = Some(db_id) - } - } - found_db_id - }; - if worktree_db_id.is_none() { - return; - } - let worktree_db_id = worktree_db_id.unwrap(); - - let file_mtimes = db.get_file_mtimes(worktree_db_id); - if file_mtimes.is_err() { - return; - } - - let file_mtimes = file_mtimes.unwrap(); - Some((file_mtimes, worktree_db_id)) - } else { - return; - } - }; - - if db_values.is_none() { - return; - } - - let (file_mtimes, worktree_db_id) = db_values.unwrap(); - - // Iterate Through Changes - let language_registry = this.language_registry.clone(); - let parsing_files_tx = this.parsing_files_tx.clone(); - - smol::block_on(async move { - for change in changes.into_iter() { - let change_path = change.0.clone(); - let absolute_path = worktree.read(cx).absolutize(&change_path); - // Skip if git ignored or symlink - if let Some(entry) = worktree.read(cx).entry_for_id(change.1) { - if entry.is_ignored || entry.is_symlink { - continue; - } else { - log::info!( - "Testing for Reindexing: {:?}", - &change_path - ); - } - }; - - if let Ok(language) = language_registry - .language_for_file(&change_path.to_path_buf(), None) - .await - { - if language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; - } - - if let Some(modified_time) = { - let metadata = change_path.metadata(); - if metadata.is_err() { - None - } else { - let mtime = metadata.unwrap().modified(); - if mtime.is_err() { - None - } else { - Some(mtime.unwrap()) - } - } - } { - let existing_time = - file_mtimes.get(&change_path.to_path_buf()); - let already_stored = existing_time - .map_or(false, |existing_time| { - &modified_time != existing_time - }); - - let reindex_time = modified_time - + Duration::from_secs(REINDEXING_DELAY_SECONDS); - - if !already_stored { - project_state.update_pending_files( - PendingFile { - relative_path: 
change_path.to_path_buf(), - absolute_path, - modified_time, - worktree_db_id, - language: language.clone(), - }, - reindex_time, - ); - - for file in project_state.get_outstanding_files() { - parsing_files_tx.try_send(file).unwrap(); - } - } - } - } - } - }); - }; + if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { + this.project_entries_changed(project, changes, cx, worktree_id); } }); this.projects.insert( project.downgrade(), - Rc::new(RefCell::new(ProjectState { + ProjectState { pending_files: HashMap::new(), worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(), _subscription, - })), + }, ); }); @@ -678,7 +568,7 @@ impl VectorStore { cx: &mut ModelContext, ) -> Task>> { let project_state = if let Some(state) = self.projects.get(&project.downgrade()) { - state.borrow() + state } else { return Task::ready(Err(anyhow!("project not added"))); }; @@ -736,7 +626,7 @@ impl VectorStore { this.read_with(&cx, |this, _| { let project_state = if let Some(state) = this.projects.get(&project.downgrade()) { - state.borrow() + state } else { return Err(anyhow!("project not added")); }; @@ -766,6 +656,110 @@ impl VectorStore { }) }) } + + fn project_entries_changed( + &mut self, + project: ModelHandle, + changes: &[(Arc, ProjectEntryId, PathChange)], + cx: &mut ModelContext<'_, VectorStore>, + worktree_id: &WorktreeId, + ) -> Option<()> { + let project_state = self.projects.get_mut(&project.downgrade())?; + let worktree_db_ids = project_state.worktree_db_ids.clone(); + let worktree = project.read(cx).worktree_for_id(worktree_id.clone(), cx)?; + + // Get Database + let (file_mtimes, worktree_db_id) = { + if let Ok(db) = VectorDatabase::new(self.database_url.to_string_lossy().into()) { + let worktree_db_id = { + let mut found_db_id = None; + for (w_id, db_id) in worktree_db_ids.into_iter() { + if &w_id == &worktree.read(cx).id() { + found_db_id = Some(db_id) + } + } + found_db_id + }?; + + let file_mtimes = db.get_file_mtimes(worktree_db_id).log_err()?; + + Some((file_mtimes, worktree_db_id)) + } else { + return None; + } + }?; + + // Iterate Through Changes + let language_registry = self.language_registry.clone(); + let parsing_files_tx = self.parsing_files_tx.clone(); + + smol::block_on(async move { + for change in changes.into_iter() { + let change_path = change.0.clone(); + let absolute_path = worktree.read(cx).absolutize(&change_path); + // Skip if git ignored or symlink + if let Some(entry) = worktree.read(cx).entry_for_id(change.1) { + if entry.is_ignored || entry.is_symlink { + continue; + } else { + log::info!("Testing for Reindexing: {:?}", &change_path); + } + }; + + if let Ok(language) = language_registry + .language_for_file(&change_path.to_path_buf(), None) + .await + { + if language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } + + if let Some(modified_time) = { + let metadata = change_path.metadata(); + if metadata.is_err() { + None + } else { + let mtime = metadata.unwrap().modified(); + if mtime.is_err() { + None + } else { + Some(mtime.unwrap()) + } + } + } { + let existing_time = file_mtimes.get(&change_path.to_path_buf()); + let already_stored = existing_time + .map_or(false, |existing_time| &modified_time != existing_time); + + let reindex_time = + modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS); + + if !already_stored { + project_state.update_pending_files( + PendingFile { + relative_path: change_path.to_path_buf(), + absolute_path, + modified_time, + worktree_db_id, + 
language: language.clone(), + }, + reindex_time, + ); + + for file in project_state.get_outstanding_files() { + parsing_files_tx.try_send(file).unwrap(); + } + } + } + } + } + }); + Some(()) + } } impl Entity for VectorStore { From 307d8d9c8d26ecaf4ecd2a3bddf58ec00be7a666 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 10 Jul 2023 17:50:19 -0400 Subject: [PATCH 36/51] Reduced redundant database connections on each worktree change. Co-authored-by: maxbrunsfeld --- crates/vector_store/src/db.rs | 78 +++++-- crates/vector_store/src/vector_store.rs | 282 ++++++++++-------------- 2 files changed, 182 insertions(+), 178 deletions(-) diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index 4882db443b..197e7d5696 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,4 +1,5 @@ use std::{ + cmp::Ordering, collections::HashMap, path::{Path, PathBuf}, rc::Rc, @@ -14,16 +15,6 @@ use rusqlite::{ types::{FromSql, FromSqlResult, ValueRef}, }; -// Note this is not an appropriate document -#[derive(Debug)] -pub struct DocumentRecord { - pub id: usize, - pub file_id: usize, - pub offset: usize, - pub name: String, - pub embedding: Embedding, -} - #[derive(Debug)] pub struct FileRecord { pub id: usize, @@ -32,7 +23,7 @@ pub struct FileRecord { } #[derive(Debug)] -pub struct Embedding(pub Vec); +struct Embedding(pub Vec); impl FromSql for Embedding { fn column_result(value: ValueRef) -> FromSqlResult { @@ -205,10 +196,35 @@ impl VectorDatabase { Ok(result) } - pub fn for_each_document( + pub fn top_k_search( &self, worktree_ids: &[i64], - mut f: impl FnMut(i64, Embedding), + query_embedding: &Vec, + limit: usize, + ) -> Result> { + let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); + self.for_each_document(&worktree_ids, |id, embedding| { + eprintln!("document {id} {embedding:?}"); + + let similarity = dot(&embedding, &query_embedding); + let ix = match results + .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)) + { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + })?; + + let ids = results.into_iter().map(|(id, _)| id).collect::>(); + self.get_documents_by_ids(&ids) + } + + fn for_each_document( + &self, + worktree_ids: &[i64], + mut f: impl FnMut(i64, Vec), ) -> Result<()> { let mut query_statement = self.db.prepare( " @@ -221,16 +237,20 @@ impl VectorDatabase { files.worktree_id IN rarray(?) ", )?; + query_statement .query_map(params![ids_to_sql(worktree_ids)], |row| { - Ok((row.get(0)?, row.get(1)?)) + Ok((row.get(0)?, row.get::<_, Embedding>(1)?)) })? 
.filter_map(|row| row.ok()) - .for_each(|row| f(row.0, row.1)); + .for_each(|(id, embedding)| { + dbg!("id"); + f(id, embedding.0) + }); Ok(()) } - pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result> { + fn get_documents_by_ids(&self, ids: &[i64]) -> Result> { let mut statement = self.db.prepare( " SELECT @@ -279,3 +299,29 @@ fn ids_to_sql(ids: &[i64]) -> Rc> { .collect::>(), ) } + +pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 { + let len = vec_a.len(); + assert_eq!(len, vec_b.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + vec_a.as_ptr(), + len as isize, + 1, + vec_b.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + result +} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index c27c4992f3..c42b7ab129 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -20,7 +20,6 @@ use parsing::{CodeContextRetriever, ParsedFile}; use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId}; use smol::channel; use std::{ - cmp::Ordering, collections::HashMap, path::{Path, PathBuf}, sync::Arc, @@ -112,10 +111,10 @@ pub struct VectorStore { database_url: Arc, embedding_provider: Arc, language_registry: Arc, - db_update_tx: channel::Sender, + db_update_tx: channel::Sender, parsing_files_tx: channel::Sender, _db_update_task: Task<()>, - _embed_batch_task: Vec>, + _embed_batch_task: Task<()>, _batch_files_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, @@ -128,6 +127,30 @@ struct ProjectState { } impl ProjectState { + fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option { + self.worktree_db_ids + .iter() + .find_map(|(worktree_id, db_id)| { + if *worktree_id == id { + Some(*db_id) + } else { + None + } + }) + } + + fn worktree_id_for_db_id(&self, id: i64) -> Option { + self.worktree_db_ids + .iter() + .find_map(|(worktree_id, db_id)| { + if *db_id == id { + Some(*worktree_id) + } else { + None + } + }) + } + fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) { // If Pending File Already Exists, Replace it with the new one // but keep the old indexing time @@ -185,7 +208,7 @@ pub struct SearchResult { pub file_path: PathBuf, } -enum DbWrite { +enum DbOperation { InsertFile { worktree_id: i64, indexed_file: ParsedFile, @@ -198,6 +221,10 @@ enum DbWrite { path: PathBuf, sender: oneshot::Sender>, }, + FileMTimes { + worktree_id: i64, + sender: oneshot::Sender>>, + }, } enum EmbeddingJob { @@ -243,20 +270,27 @@ impl VectorStore { let _db_update_task = cx.background().spawn(async move { while let Ok(job) = db_update_rx.recv().await { match job { - DbWrite::InsertFile { + DbOperation::InsertFile { worktree_id, indexed_file, } => { log::info!("Inserting Data for {:?}", &indexed_file.path); db.insert_file(worktree_id, indexed_file).log_err(); } - DbWrite::Delete { worktree_id, path } => { + DbOperation::Delete { worktree_id, path } => { db.delete_file(worktree_id, path).log_err(); } - DbWrite::FindOrCreateWorktree { path, sender } => { + DbOperation::FindOrCreateWorktree { path, sender } => { let id = db.find_or_create_worktree(&path); sender.send(id).ok(); } + DbOperation::FileMTimes { + worktree_id: worktree_db_id, + sender, + } => { + let file_mtimes = db.get_file_mtimes(worktree_db_id); + sender.send(file_mtimes).ok(); + } } } }); @@ -264,24 +298,18 @@ impl VectorStore { // embed_tx/rx: Embed Batch and Send to Database let (embed_batch_tx, 
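top_k_search above keeps its running winners in a list sorted by descending similarity: binary-search for the insertion point, insert, truncate to the limit. Extracted as a sketch:

use std::cmp::Ordering;

fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
    // Comparing the candidate against each element (rather than the reverse)
    // is what keeps the list in descending order of similarity.
    let ix = match results
        .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
    {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}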
embed_batch_rx) = channel::unbounded::)>>(); - let mut _embed_batch_task = Vec::new(); - for _ in 0..1 { - //cx.background().num_cpus() { + let _embed_batch_task = cx.background().spawn({ let db_update_tx = db_update_tx.clone(); - let embed_batch_rx = embed_batch_rx.clone(); let embedding_provider = embedding_provider.clone(); - _embed_batch_task.push(cx.background().spawn(async move { - while let Ok(embeddings_queue) = embed_batch_rx.recv().await { + async move { + while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await { // Construct Batch - let mut embeddings_queue = embeddings_queue.clone(); let mut document_spans = vec![]; - for (_, _, document_span) in embeddings_queue.clone().into_iter() { - document_spans.extend(document_span); + for (_, _, document_span) in embeddings_queue.iter() { + document_spans.extend(document_span.iter().map(|s| s.as_str())); } - if let Ok(embeddings) = embedding_provider - .embed_batch(document_spans.iter().map(|x| &**x).collect()) - .await + if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await { let mut i = 0; let mut j = 0; @@ -306,7 +334,7 @@ impl VectorStore { } db_update_tx - .send(DbWrite::InsertFile { + .send(DbOperation::InsertFile { worktree_id, indexed_file, }) @@ -315,8 +343,9 @@ impl VectorStore { } } } - })) - } + } + }); + // batch_tx/rx: Batch Files to Send for Embeddings let (batch_files_tx, batch_files_rx) = channel::unbounded::(); let _batch_files_task = cx.background().spawn(async move { @@ -398,7 +427,21 @@ impl VectorStore { fn find_or_create_worktree(&self, path: PathBuf) -> impl Future> { let (tx, rx) = oneshot::channel(); self.db_update_tx - .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx }) + .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx }) + .unwrap(); + async move { rx.await? } + } + + fn get_file_mtimes( + &self, + worktree_id: i64, + ) -> impl Future>> { + let (tx, rx) = oneshot::channel(); + self.db_update_tx + .try_send(DbOperation::FileMTimes { + worktree_id, + sender: tx, + }) .unwrap(); async move { rx.await? 
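// Worth noting in find_or_create_worktree here: reads as well as writes now
// travel over the single db_update channel, so every rusqlite call stays on
// one task and no extra connections are opened. Requests that need an answer
// carry a futures oneshot sender; the shape, as a sketch:
//
//     let (tx, rx) = futures::channel::oneshot::channel();
//     db_update_tx.try_send(DbOperation::FileMTimes { worktree_id, sender: tx })?;
//     let mtimes = rx.await??; // outer ?: actor dropped the sender; inner ?: DB error
//
// (The methods here return that future instead of awaiting it themselves.)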
} } @@ -450,26 +493,17 @@ impl VectorStore { .collect::>() }); - // Here we query the worktree ids, and yet we dont have them elsewhere - // We likely want to clean up these datastructures - let (mut worktree_file_times, db_ids_by_worktree_id) = cx - .background() - .spawn({ - let worktrees = worktrees.clone(); - async move { - let db = VectorDatabase::new(database_url.to_string_lossy().into())?; - let mut db_ids_by_worktree_id = HashMap::new(); - let mut file_times: HashMap> = - HashMap::new(); - for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) { - let db_id = db_id?; - db_ids_by_worktree_id.insert(worktree.id(), db_id); - file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?); - } - anyhow::Ok((file_times, db_ids_by_worktree_id)) - } - }) - .await?; + let mut worktree_file_times = HashMap::new(); + let mut db_ids_by_worktree_id = HashMap::new(); + for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) { + let db_id = db_id?; + db_ids_by_worktree_id.insert(worktree.id(), db_id); + worktree_file_times.insert( + worktree.id(), + this.read_with(&cx, |this, _| this.get_file_mtimes(db_id)) + .await?, + ); + } cx.background() .spawn({ @@ -520,7 +554,7 @@ impl VectorStore { } for file in file_mtimes.keys() { db_update_tx - .try_send(DbWrite::Delete { + .try_send(DbOperation::Delete { worktree_id: db_ids_by_worktree_id[&worktree.id()], path: file.to_owned(), }) @@ -542,7 +576,7 @@ impl VectorStore { // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed. let _subscription = cx.subscribe(&project, |this, project, event, cx| { if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { - this.project_entries_changed(project, changes, cx, worktree_id); + this.project_entries_changed(project, changes.clone(), cx, worktree_id); } }); @@ -578,16 +612,7 @@ impl VectorStore { .worktrees(cx) .filter_map(|worktree| { let worktree_id = worktree.read(cx).id(); - project_state - .worktree_db_ids - .iter() - .find_map(|(id, db_id)| { - if *id == worktree_id { - Some(*db_id) - } else { - None - } - }) + project_state.db_id_for_worktree_id(worktree_id) }) .collect::>(); @@ -606,24 +631,12 @@ impl VectorStore { .next() .unwrap(); - let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - database.for_each_document(&worktree_db_ids, |id, embedding| { - let similarity = dot(&embedding.0, &phrase_embedding); - let ix = match results.binary_search_by(|(_, s)| { - similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) - }) { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - })?; - - let ids = results.into_iter().map(|(id, _)| id).collect::>(); - database.get_documents_by_ids(&ids) + database.top_k_search(&worktree_db_ids, &phrase_embedding, limit) }) .await?; + dbg!(&documents); + this.read_with(&cx, |this, _| { let project_state = if let Some(state) = this.projects.get(&project.downgrade()) { state @@ -634,17 +647,7 @@ impl VectorStore { Ok(documents .into_iter() .filter_map(|(worktree_db_id, file_path, offset, name)| { - let worktree_id = - project_state - .worktree_db_ids - .iter() - .find_map(|(id, db_id)| { - if *db_id == worktree_db_id { - Some(*id) - } else { - None - } - })?; + let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?; Some(SearchResult { worktree_id, name, @@ -660,51 +663,36 @@ impl VectorStore { fn project_entries_changed( &mut self, project: ModelHandle, - changes: &[(Arc, ProjectEntryId, PathChange)], + 
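// Query-path sketch, matching the search task above: the phrase is embedded
// through the same embed_batch call used for documents (a one-element batch),
// then compared against every stored embedding on a background thread:
//
//     let phrase_embedding = embedding_provider
//         .embed_batch(vec![&phrase])
//         .await?
//         .into_iter()
//         .next()
//         .unwrap();
//     let hits = database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)?;
//
// This is a linear scan over all stored documents; there is no approximate
// nearest-neighbor index at this stage.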
changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, cx: &mut ModelContext<'_, VectorStore>, worktree_id: &WorktreeId, ) -> Option<()> { - let project_state = self.projects.get_mut(&project.downgrade())?; - let worktree_db_ids = project_state.worktree_db_ids.clone(); - let worktree = project.read(cx).worktree_for_id(worktree_id.clone(), cx)?; + let worktree = project + .read(cx) + .worktree_for_id(worktree_id.clone(), cx)? + .read(cx) + .snapshot(); - // Get Database - let (file_mtimes, worktree_db_id) = { - if let Ok(db) = VectorDatabase::new(self.database_url.to_string_lossy().into()) { - let worktree_db_id = { - let mut found_db_id = None; - for (w_id, db_id) in worktree_db_ids.into_iter() { - if &w_id == &worktree.read(cx).id() { - found_db_id = Some(db_id) - } - } - found_db_id - }?; + let worktree_db_id = self + .projects + .get(&project.downgrade())? + .db_id_for_worktree_id(worktree.id())?; + let file_mtimes = self.get_file_mtimes(worktree_db_id); - let file_mtimes = db.get_file_mtimes(worktree_db_id).log_err()?; - - Some((file_mtimes, worktree_db_id)) - } else { - return None; - } - }?; - - // Iterate Through Changes let language_registry = self.language_registry.clone(); - let parsing_files_tx = self.parsing_files_tx.clone(); - smol::block_on(async move { + cx.spawn(|this, mut cx| async move { + let file_mtimes = file_mtimes.await.log_err()?; + for change in changes.into_iter() { let change_path = change.0.clone(); - let absolute_path = worktree.read(cx).absolutize(&change_path); + let absolute_path = worktree.absolutize(&change_path); // Skip if git ignored or symlink - if let Some(entry) = worktree.read(cx).entry_for_id(change.1) { - if entry.is_ignored || entry.is_symlink { + if let Some(entry) = worktree.entry_for_id(change.1) { + if entry.is_ignored || entry.is_symlink || entry.is_external { continue; - } else { - log::info!("Testing for Reindexing: {:?}", &change_path); } - }; + } if let Ok(language) = language_registry .language_for_file(&change_path.to_path_buf(), None) @@ -718,27 +706,18 @@ impl VectorStore { continue; } - if let Some(modified_time) = { - let metadata = change_path.metadata(); - if metadata.is_err() { - None - } else { - let mtime = metadata.unwrap().modified(); - if mtime.is_err() { - None - } else { - Some(mtime.unwrap()) - } - } - } { - let existing_time = file_mtimes.get(&change_path.to_path_buf()); - let already_stored = existing_time - .map_or(false, |existing_time| &modified_time != existing_time); + let modified_time = change_path.metadata().log_err()?.modified().log_err()?; - let reindex_time = - modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS); + let existing_time = file_mtimes.get(&change_path.to_path_buf()); + let already_stored = existing_time + .map_or(false, |existing_time| &modified_time != existing_time); - if !already_stored { + if !already_stored { + this.update(&mut cx, |this, _| { + let reindex_time = + modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS); + + let project_state = this.projects.get_mut(&project.downgrade())?; project_state.update_pending_files( PendingFile { relative_path: change_path.to_path_buf(), @@ -751,13 +730,18 @@ impl VectorStore { ); for file in project_state.get_outstanding_files() { - parsing_files_tx.try_send(file).unwrap(); + this.parsing_files_tx.try_send(file).unwrap(); } - } + Some(()) + }); } } } - }); + + Some(()) + }) + .detach(); + Some(()) } } @@ -765,29 +749,3 @@ impl VectorStore { impl Entity for VectorStore { type Event = (); } - -fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 
{ - let len = vec_a.len(); - assert_eq!(len, vec_b.len()); - - let mut result = 0.0; - unsafe { - matrixmultiply::sgemm( - 1, - len, - 1, - 1.0, - vec_a.as_ptr(), - len as isize, - 1, - vec_b.as_ptr(), - 1, - len as isize, - 0.0, - &mut result as *mut f32, - 1, - 1, - ); - } - result -} From dce72a1ce71ddf16b0d900e4d673fee49204888a Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 10 Jul 2023 18:19:29 -0400 Subject: [PATCH 37/51] updated tests to accomodate for new dot location --- crates/vector_store/src/vector_store_tests.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index e25b737b06..b1756b7964 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -1,6 +1,4 @@ -use std::sync::Arc; - -use crate::{dot, embedding::EmbeddingProvider, VectorStore}; +use crate::{db::dot, embedding::EmbeddingProvider, VectorStore}; use anyhow::Result; use async_trait::async_trait; use gpui::{Task, TestAppContext}; @@ -8,6 +6,7 @@ use language::{Language, LanguageConfig, LanguageRegistry}; use project::{FakeFs, Project}; use rand::Rng; use serde_json::json; +use std::sync::Arc; use unindent::Unindent; #[gpui::test] From f5fec559308a12391bca54e1d981c5be7c846d1e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 11 Jul 2023 10:03:53 -0400 Subject: [PATCH 38/51] updated vector_store to handle for removed files --- crates/vector_store/src/vector_store.rs | 95 ++++++++++++++----------- 1 file changed, 54 insertions(+), 41 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index c42b7ab129..9b21073998 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -635,8 +635,6 @@ impl VectorStore { }) .await?; - dbg!(&documents); - this.read_with(&cx, |this, _| { let project_state = if let Some(state) = this.projects.get(&project.downgrade()) { state @@ -687,6 +685,7 @@ impl VectorStore { for change in changes.into_iter() { let change_path = change.0.clone(); let absolute_path = worktree.absolutize(&change_path); + // Skip if git ignored or symlink if let Some(entry) = worktree.entry_for_id(change.1) { if entry.is_ignored || entry.is_symlink || entry.is_external { @@ -694,46 +693,60 @@ impl VectorStore { } } - if let Ok(language) = language_registry - .language_for_file(&change_path.to_path_buf(), None) - .await - { - if language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; - } - - let modified_time = change_path.metadata().log_err()?.modified().log_err()?; - - let existing_time = file_mtimes.get(&change_path.to_path_buf()); - let already_stored = existing_time - .map_or(false, |existing_time| &modified_time != existing_time); - - if !already_stored { - this.update(&mut cx, |this, _| { - let reindex_time = - modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS); - - let project_state = this.projects.get_mut(&project.downgrade())?; - project_state.update_pending_files( - PendingFile { - relative_path: change_path.to_path_buf(), - absolute_path, - modified_time, - worktree_db_id, - language: language.clone(), - }, - reindex_time, - ); - - for file in project_state.get_outstanding_files() { - this.parsing_files_tx.try_send(file).unwrap(); + match change.2 { + PathChange::Removed => this.update(&mut cx, |this, _| { + this.db_update_tx + .try_send(DbOperation::Delete { + worktree_id: 
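// Aside on the dot product relocated into db.rs: it drives matrixmultiply's
// sgemm as a 1 x len by len x 1 product, i.e. a plain dot product with
// SIMD-friendly inner loops. A naive reference version (useful in tests, as
// patch 37 does with db::dot):
//
//     fn dot_naive(a: &[f32], b: &[f32]) -> f32 {
//         assert_eq!(a.len(), b.len());
//         a.iter().zip(b).map(|(x, y)| x * y).sum()
//     }
//
// Expect last-ulp differences from sgemm due to summation order.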
worktree_db_id, + path: absolute_path, + }) + .unwrap(); + }), + _ => { + if let Ok(language) = language_registry + .language_for_file(&change_path.to_path_buf(), None) + .await + { + if language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; } - Some(()) - }); + + let modified_time = + change_path.metadata().log_err()?.modified().log_err()?; + + let existing_time = file_mtimes.get(&change_path.to_path_buf()); + let already_stored = existing_time + .map_or(false, |existing_time| &modified_time != existing_time); + + if !already_stored { + this.update(&mut cx, |this, _| { + let reindex_time = modified_time + + Duration::from_secs(REINDEXING_DELAY_SECONDS); + + let project_state = + this.projects.get_mut(&project.downgrade())?; + project_state.update_pending_files( + PendingFile { + relative_path: change_path.to_path_buf(), + absolute_path, + modified_time, + worktree_db_id, + language: language.clone(), + }, + reindex_time, + ); + + for file in project_state.get_outstanding_files() { + this.parsing_files_tx.try_send(file).unwrap(); + } + Some(()) + }); + } + } } } } From 298c2213a0afa68f1dbaf04dc8b90420303743a9 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 11 Jul 2023 12:03:56 -0400 Subject: [PATCH 39/51] added opt-in default settings for vector store --- Cargo.lock | 2 ++ assets/settings/default.json | 6 ++++ crates/vector_store/Cargo.toml | 3 ++ crates/vector_store/src/db.rs | 7 +--- crates/vector_store/src/vector_store.rs | 29 ++++++++++------- .../vector_store/src/vector_store_settings.rs | 32 +++++++++++++++++++ crates/vector_store/src/vector_store_tests.rs | 10 +++++- 7 files changed, 70 insertions(+), 19 deletions(-) create mode 100644 crates/vector_store/src/vector_store_settings.rs diff --git a/Cargo.lock b/Cargo.lock index 22df4083fd..cd92d0003a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8503,8 +8503,10 @@ dependencies = [ "rand 0.8.5", "rpc", "rusqlite", + "schemars", "serde", "serde_json", + "settings", "smol", "tempdir", "theme", diff --git a/assets/settings/default.json b/assets/settings/default.json index 9ae5c916b5..cf8f630dfb 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -291,6 +291,12 @@ // the terminal will default to matching the buffer's font family. // "font_family": "Zed Mono" }, + // Difference settings for vector_store + "vector_store": { + "enable": false, + "reindexing_delay_seconds": 600, + "embedding_batch_size": 150 + }, // Different settings for specific languages. 
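A note on the removal path added at the top of this patch: the PathChange::Removed arm only enqueues a DbOperation::Delete, and the database task applies it off the main thread. The statement it ultimately runs is not shown in these hunks; a hedged sketch in the rusqlite style db.rs already uses, assuming the documents table cascades the way the crate README's schema sketch suggests:

    // Sketch only: a plausible body for the delete that DbOperation::Delete
    // triggers; the real method lives in VectorDatabase and may differ.
    fn delete_file(
        conn: &rusqlite::Connection,
        worktree_id: i64,
        path: &std::path::Path,
    ) -> anyhow::Result<()> {
        // ON DELETE CASCADE removes the file's document rows with it.
        conn.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
            rusqlite::params![worktree_id, path.to_str()],
        )?;
        Ok(())
    }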
"languages": { "Plain Text": { diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 35a6a689ae..40bff8b95c 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -18,6 +18,7 @@ picker = { path = "../picker" } theme = { path = "../theme" } editor = { path = "../editor" } rpc = { path = "../rpc" } +settings = { path = "../settings" } anyhow.workspace = true futures.workspace = true smol.workspace = true @@ -33,6 +34,7 @@ bincode = "1.3.3" matrixmultiply = "0.3.7" tiktoken-rs = "0.5.0" rand.workspace = true +schemars.workspace = true [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } @@ -40,6 +42,7 @@ language = { path = "../language", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } rpc = { path = "../rpc", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } +settings = { path = "../settings", features = ["test-support"]} tree-sitter-rust = "*" rand.workspace = true unindent.workspace = true diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index 197e7d5696..79d90e87bf 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -204,8 +204,6 @@ impl VectorDatabase { ) -> Result> { let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); self.for_each_document(&worktree_ids, |id, embedding| { - eprintln!("document {id} {embedding:?}"); - let similarity = dot(&embedding, &query_embedding); let ix = match results .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)) @@ -243,10 +241,7 @@ impl VectorDatabase { Ok((row.get(0)?, row.get::<_, Embedding>(1)?)) })? .filter_map(|row| row.ok()) - .for_each(|(id, embedding)| { - dbg!("id"); - f(id, embedding.0) - }); + .for_each(|(id, embedding)| f(id, embedding.0)); Ok(()) } diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 9b21073998..4b5f6b636f 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -2,22 +2,25 @@ mod db; mod embedding; mod modal; mod parsing; +mod vector_store_settings; #[cfg(test)] mod vector_store_tests; +use crate::vector_store_settings::VectorStoreSettings; use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use futures::{channel::oneshot, Future}; use gpui::{ - AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext, - WeakModelHandle, + AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task, + ViewContext, WeakModelHandle, }; use language::{Language, LanguageRegistry}; use modal::{SemanticSearch, SemanticSearchDelegate, Toggle}; use parsing::{CodeContextRetriever, ParsedFile}; use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId}; +use settings::SettingsStore; use smol::channel; use std::{ collections::HashMap, @@ -34,9 +37,6 @@ use util::{ }; use workspace::{Workspace, WorkspaceCreated}; -const REINDEXING_DELAY_SECONDS: u64 = 3; -const EMBEDDINGS_BATCH_SIZE: usize = 150; - pub fn init( fs: Arc, http_client: Arc, @@ -47,6 +47,12 @@ pub fn init( return; } + settings::register::(cx); + + if !settings::get::(cx).enable { + return; + } + let db_file_path = EMBEDDINGS_DIR .join(Path::new(RELEASE_CHANNEL_NAME.as_str())) .join("embeddings_db"); @@ -83,6 +89,7 @@ pub fn init( .detach(); cx.add_action({ + // "semantic search: Toggle" move |workspace: &mut Workspace, _: 
&Toggle, cx: &mut ViewContext| { let vector_store = vector_store.clone(); workspace.toggle_modal(cx, |workspace, cx| { @@ -274,7 +281,6 @@ impl VectorStore { worktree_id, indexed_file, } => { - log::info!("Inserting Data for {:?}", &indexed_file.path); db.insert_file(worktree_id, indexed_file).log_err(); } DbOperation::Delete { worktree_id, path } => { @@ -347,6 +353,7 @@ impl VectorStore { }); // batch_tx/rx: Batch Files to Send for Embeddings + let batch_size = settings::get::(cx).embedding_batch_size; let (batch_files_tx, batch_files_rx) = channel::unbounded::(); let _batch_files_task = cx.background().spawn(async move { let mut queue_len = 0; @@ -361,7 +368,7 @@ impl VectorStore { } => { queue_len += &document_spans.len(); embeddings_queue.push((worktree_id, parsed_file, document_spans)); - queue_len >= EMBEDDINGS_BATCH_SIZE + queue_len >= batch_size } EmbeddingJob::Flush => true, }; @@ -387,8 +394,6 @@ impl VectorStore { let cursor = QueryCursor::new(); let mut retriever = CodeContextRetriever { parser, cursor, fs }; while let Ok(pending_file) = parsing_files_rx.recv().await { - log::info!("Parsing File: {:?}", &pending_file.relative_path); - if let Some((indexed_file, document_spans)) = retriever.parse_file(pending_file.clone()).await.log_err() { @@ -476,11 +481,9 @@ impl VectorStore { let parsing_files_tx = self.parsing_files_tx.clone(); cx.spawn(|this, mut cx| async move { - let t0 = Instant::now(); futures::future::join_all(worktree_scans_complete).await; let worktree_db_ids = futures::future::join_all(worktree_db_ids).await; - log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis()); if let Some(db_directory) = database_url.parent() { fs.create_dir(db_directory).await.log_err(); @@ -665,6 +668,8 @@ impl VectorStore { cx: &mut ModelContext<'_, VectorStore>, worktree_id: &WorktreeId, ) -> Option<()> { + let reindexing_delay = settings::get::(cx).reindexing_delay_seconds; + let worktree = project .read(cx) .worktree_for_id(worktree_id.clone(), cx)? 
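The two settings read above replace the old hard-coded constants. reindexing_delay_seconds drives a simple debounce: a changed file is stamped with a reindex time in the future, and only files whose stamp has elapsed are handed to the parser. The real bookkeeping lives in ProjectState::update_pending_files and get_outstanding_files, which these hunks only call; a minimal sketch of the idea, with invented names:

    use std::{path::PathBuf, time::SystemTime};

    // Illustrative debounce sketch: stamp entries when a change arrives,
    // release only the ones whose delay has elapsed, keep the rest pending.
    struct PendingEntry {
        path: PathBuf,
        reindex_time: SystemTime,
    }

    fn outstanding(pending: &mut Vec<PendingEntry>) -> Vec<PathBuf> {
        let now = SystemTime::now();
        let (ready, waiting): (Vec<_>, Vec<_>) = pending
            .drain(..)
            .partition(|entry| entry.reindex_time <= now);
        *pending = waiting;
        ready.into_iter().map(|entry| entry.path).collect()
    }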
@@ -725,7 +730,7 @@ impl VectorStore { if !already_stored { this.update(&mut cx, |this, _| { let reindex_time = modified_time - + Duration::from_secs(REINDEXING_DELAY_SECONDS); + + Duration::from_secs(reindexing_delay as u64); let project_state = this.projects.get_mut(&project.downgrade())?; diff --git a/crates/vector_store/src/vector_store_settings.rs b/crates/vector_store/src/vector_store_settings.rs new file mode 100644 index 0000000000..0bde07dd65 --- /dev/null +++ b/crates/vector_store/src/vector_store_settings.rs @@ -0,0 +1,32 @@ +use anyhow; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::Setting; + +#[derive(Deserialize, Debug)] +pub struct VectorStoreSettings { + pub enable: bool, + pub reindexing_delay_seconds: usize, + pub embedding_batch_size: usize, +} + +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] +pub struct VectorStoreSettingsContent { + pub enable: Option, + pub reindexing_delay_seconds: Option, + pub embedding_batch_size: Option, +} + +impl Setting for VectorStoreSettings { + const KEY: Option<&'static str> = Some("vector_store"); + + type FileContent = VectorStoreSettingsContent; + + fn load( + default_value: &Self::FileContent, + user_values: &[&Self::FileContent], + _: &gpui::AppContext, + ) -> anyhow::Result { + Self::load_via_json_merge(default_value, user_values) + } +} diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index b1756b7964..a3a40722ea 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -1,4 +1,6 @@ -use crate::{db::dot, embedding::EmbeddingProvider, VectorStore}; +use crate::{ + db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore, +}; use anyhow::Result; use async_trait::async_trait; use gpui::{Task, TestAppContext}; @@ -6,11 +8,17 @@ use language::{Language, LanguageConfig, LanguageRegistry}; use project::{FakeFs, Project}; use rand::Rng; use serde_json::json; +use settings::SettingsStore; use std::sync::Arc; use unindent::Unindent; #[gpui::test] async fn test_vector_store(cx: &mut TestAppContext) { + cx.update(|cx| { + cx.set_global(SettingsStore::test(cx)); + settings::register::(cx); + }); + let fs = FakeFs::new(cx.background()); fs.insert_tree( "/the-root", From 1649cf81de4bc3cc506b3d118c2454693758088f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 11 Jul 2023 14:42:03 -0400 Subject: [PATCH 40/51] added versioning to files table --- crates/vector_store/src/db.rs | 9 ++++++--- crates/vector_store/src/vector_store.rs | 7 ++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index 79d90e87bf..a91a1872b5 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -9,6 +9,7 @@ use std::{ use anyhow::{anyhow, Result}; use crate::parsing::ParsedFile; +use crate::VECTOR_STORE_VERSION; use rpc::proto::Timestamp; use rusqlite::{ params, @@ -72,6 +73,7 @@ impl VectorDatabase { relative_path VARCHAR NOT NULL, mtime_seconds INTEGER NOT NULL, mtime_nanos INTEGER NOT NULL, + vector_store_version INTEGER NOT NULL, FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE )", [], @@ -112,15 +114,16 @@ impl VectorDatabase { self.db.execute( " INSERT INTO files - (worktree_id, relative_path, mtime_seconds, mtime_nanos) + (worktree_id, relative_path, mtime_seconds, mtime_nanos, vector_store_version) VALUES - (?1, ?2, $3, $4); + (?1, ?2, $3, 
$4, $5);
             ",
             params![
                 worktree_id,
                 indexed_file.path.to_str(),
                 mtime.seconds,
                 mtime.nanos,
                 VECTOR_STORE_VERSION
             ],
         )?;

diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index 4b5f6b636f..6f63f07b88 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -13,14 +13,13 @@
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
 use gpui::{
-    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task,
-    ViewContext, WeakModelHandle,
+    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
+    WeakModelHandle,
 };
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 use parsing::{CodeContextRetriever, ParsedFile};
 use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
-use settings::SettingsStore;
 use smol::channel;
 use std::{
     collections::HashMap,
@@ -37,6 +36,8 @@ use util::{
 };
 use workspace::{Workspace, WorkspaceCreated};
 
+const VECTOR_STORE_VERSION: usize = 0;
+
 pub fn init(
     fs: Arc,
     http_client: Arc,

From 02f523094be66efdcc9f6ca6b072ce787f8860c8 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Tue, 11 Jul 2023 15:58:33 -0400
Subject: [PATCH 41/51] expanded embeddable context to accommodate struct
 context and file paths

---
 crates/vector_store/src/parsing.rs | 46 ++++++++++++++++++++----------
 1 file changed, 31 insertions(+), 15 deletions(-)

diff --git a/crates/vector_store/src/parsing.rs b/crates/vector_store/src/parsing.rs
index 6a8742fedd..a91e87aa22 100644
--- a/crates/vector_store/src/parsing.rs
+++ b/crates/vector_store/src/parsing.rs
@@ -20,6 +20,9 @@ pub struct ParsedFile {
     pub documents: Vec<Document>,
 }
 
+const CODE_CONTEXT_TEMPLATE: &str =
+    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
+
 pub struct CodeContextRetriever {
     pub parser: Parser,
     pub cursor: QueryCursor,
@@ -58,27 +61,40 @@ impl CodeContextRetriever {
             tree.root_node(),
             content.as_bytes(),
         ) {
-            let mut item_range: Option<Range<usize>> = None;
-            let mut name_range: Option<Range<usize>> = None;
+            let mut name: Vec<&str> = vec![];
+            let mut item: Option<&str> = None;
+            let mut offset: Option<usize> = None;
 
             for capture in mat.captures {
                 if capture.index == embedding_config.item_capture_ix {
-                    item_range = Some(capture.node.byte_range());
+                    offset = Some(capture.node.byte_range().start);
+                    item = content.get(capture.node.byte_range());
                 } else if capture.index == embedding_config.name_capture_ix {
-                    name_range = Some(capture.node.byte_range());
+                    if let Some(name_content) = content.get(capture.node.byte_range()) {
+                        name.push(name_content);
+                    }
+                }
+
+                if let Some(context_capture_ix) = embedding_config.context_capture_ix {
+                    if capture.index == context_capture_ix {
+                        if let Some(context) = content.get(capture.node.byte_range()) {
+                            name.push(context);
+                        }
+                    }
                 }
             }
 
-            if let Some((item_range, name_range)) = item_range.zip(name_range) {
-                if let Some((item, name)) =
-                    content.get(item_range.clone()).zip(content.get(name_range))
-                {
-                    context_spans.push(item.to_string());
-                    documents.push(Document {
-                        name: name.to_string(),
-                        offset: item_range.start,
-                        embedding: Vec::new(),
-                    });
-                }
+            if item.is_some() && offset.is_some() && name.len() > 0 {
+                let context_span = CODE_CONTEXT_TEMPLATE
+                    .replace("<path>", pending_file.relative_path.to_str().unwrap())
+                    .replace("<language>", &pending_file.language.name().to_lowercase())
+                    .replace("<item>", item.unwrap());
+
+                context_spans.push(context_span);
+
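With the template's placeholders in place, each embeddable span now carries its file path and a language-tagged fence alongside the item text. An invented example of one expanded span (the path, language, and item below are illustrative, not from the diff):

    // Invented example of a CODE_CONTEXT_TEMPLATE expansion:
    let span = CODE_CONTEXT_TEMPLATE
        .replace("<path>", "crates/vector_store/src/db.rs")
        .replace("<language>", "rust")
        .replace("<item>", "fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 { .. }");
    // span reads:
    // The below code snippet is from file 'crates/vector_store/src/db.rs'
    //
    // ```rust
    // fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 { .. }
    // ```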
documents.push(Document { + name: name.join(" "), + offset: offset.unwrap(), + embedding: Vec::new(), + }) } } From debe6f107e44c4c9a5b07c9286135d474508da88 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 11 Jul 2023 16:22:40 -0400 Subject: [PATCH 42/51] updated embedding queries for tsx and typescript --- crates/zed/src/languages/tsx/embedding.scm | 59 +++++++++++++++++++ .../src/languages/typescript/embedding.scm | 59 +++++++++++++++++++ 2 files changed, 118 insertions(+) create mode 100644 crates/zed/src/languages/tsx/embedding.scm create mode 100644 crates/zed/src/languages/typescript/embedding.scm diff --git a/crates/zed/src/languages/tsx/embedding.scm b/crates/zed/src/languages/tsx/embedding.scm new file mode 100644 index 0000000000..a8cde61b9e --- /dev/null +++ b/crates/zed/src/languages/tsx/embedding.scm @@ -0,0 +1,59 @@ +; (internal_module +; "namespace" @context + name: (_) @name) @item + +(enum_declaration + "enum" @context + name: (_) @name) @item + +; (type_alias_declaration +; "type" @context + name: (_) @name) @item + +(function_declaration + "async"? @context + "function" @context + name: (_) @name) @item + +(interface_declaration + "interface" @context + name: (_) @name) @item + +; (export_statement +; (lexical_declaration +; ["let" "const"] @context +; (variable_declarator +; name: (_) @name) @item)) + +(program + (lexical_declaration + ["let" "const"] @context + (variable_declarator + name: (_) @name) @item)) + +(class_declaration + "class" @context + name: (_) @name) @item + +(method_definition + [ + "get" + "set" + "async" + "*" + "readonly" + "static" + (override_modifier) + (accessibility_modifier) + ]* @context + name: (_) @name) @item + +; (public_field_definition +; [ +; "declare" +; "readonly" +; "abstract" +; "static" +; (accessibility_modifier) +; ]* @context +; name: (_) @name) @item diff --git a/crates/zed/src/languages/typescript/embedding.scm b/crates/zed/src/languages/typescript/embedding.scm new file mode 100644 index 0000000000..f261a0a565 --- /dev/null +++ b/crates/zed/src/languages/typescript/embedding.scm @@ -0,0 +1,59 @@ +; (internal_module +; "namespace" @context +; name: (_) @name) @item + +(enum_declaration + "enum" @context + name: (_) @name) @item + +; (type_alias_declaration +; "type" @context +; name: (_) @name) @item + +(function_declaration + "async"? 
@context
+    "function" @context
+    name: (_) @name) @item
+
+(interface_declaration
+    "interface" @context
+    name: (_) @name) @item
+
+; (export_statement
+;     (lexical_declaration
+;         ["let" "const"] @context
+;         (variable_declarator
+;             name: (_) @name) @item))
+
+(program
+    (lexical_declaration
+        ["let" "const"] @context
+        (variable_declarator
+            name: (_) @name) @item))
+
+(class_declaration
+    "class" @context
+    name: (_) @name) @item
+
+(method_definition
+    [
+        "get"
+        "set"
+        "async"
+        "*"
+        "readonly"
+        "static"
+        (override_modifier)
+        (accessibility_modifier)
+    ]* @context
+    name: (_) @name) @item
+
+; (public_field_definition
+;     [
+;         "declare"
+;         "readonly"
+;         "abstract"
+;         "static"
+;         (accessibility_modifier)
+;     ]* @context
+;     name: (_) @name) @item

From 2ca4b3f4cc399ac5773514a7b76962cfc0aa568f Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Tue, 11 Jul 2023 16:41:08 -0400
Subject: [PATCH 43/51] cleaned up warnings and added javascript

---
 crates/vector_store/src/parsing.rs            | 10 +++-
 crates/vector_store/src/vector_store.rs       |  3 +-
 .../src/languages/javascript/embedding.scm    | 56 +++++++++++++++++++
 crates/zed/src/languages/tsx/embedding.scm    |  2 +-
 4 files changed, 68 insertions(+), 3 deletions(-)
 create mode 100644 crates/zed/src/languages/javascript/embedding.scm

diff --git a/crates/vector_store/src/parsing.rs b/crates/vector_store/src/parsing.rs
index a91e87aa22..91dcf699f8 100644
--- a/crates/vector_store/src/parsing.rs
+++ b/crates/vector_store/src/parsing.rs
@@ -1,4 +1,4 @@
-use std::{ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
+use std::{path::PathBuf, sync::Arc, time::SystemTime};
 
 use anyhow::{anyhow, Ok, Result};
 use project::Fs;
@@ -61,6 +61,8 @@ impl CodeContextRetriever {
             tree.root_node(),
             content.as_bytes(),
         ) {
+            // log::info!("-----MATCH-----");
+
             let mut name: Vec<&str> = vec![];
             let mut item: Option<&str> = None;
             let mut offset: Option<usize> = None;
@@ -89,6 +91,12 @@ impl CodeContextRetriever {
                     .replace("<language>", &pending_file.language.name().to_lowercase())
                     .replace("<item>", item.unwrap());
 
+                let mut truncated_span = context_span.clone();
+                truncated_span.truncate(100);
+
+                // log::info!("Name: {:?}", name);
+                // log::info!("Span: {:?}", truncated_span);
+
                 context_spans.push(context_span);
 
                 documents.push(Document {
                     name: name.join(" "),
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index 6f63f07b88..a2ca90e84e 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -386,7 +386,8 @@ impl VectorStore {
         let (parsing_files_tx, parsing_files_rx) = channel::unbounded::();
         let mut _parsing_files_tasks = Vec::new();
-        for _ in 0..cx.background().num_cpus() {
+        // for _ in 0..cx.background().num_cpus() {
+        for _ in 0..1 {
             let fs = fs.clone();
             let parsing_files_rx = parsing_files_rx.clone();
             let batch_files_tx = batch_files_tx.clone();
diff --git a/crates/zed/src/languages/javascript/embedding.scm b/crates/zed/src/languages/javascript/embedding.scm
new file mode 100644
index 0000000000..ec6eb5ab1a
--- /dev/null
+++ b/crates/zed/src/languages/javascript/embedding.scm
@@ -0,0 +1,56 @@
+; (internal_module
+;     "namespace" @context
+;     name: (_) @name) @item
+
+(enum_declaration
+    "enum" @context
+    name: (_) @name) @item
+
+(function_declaration
+    "async"?
@context + "function" @context + name: (_) @name) @item + +(interface_declaration + "interface" @context + name: (_) @name) @item + +; (program +; (export_statement +; (lexical_declaration +; ["let" "const"] @context +; (variable_declarator +; name: (_) @name) @item))) + +(program + (lexical_declaration + ["let" "const"] @context + (variable_declarator + name: (_) @name) @item)) + +(class_declaration + "class" @context + name: (_) @name) @item + +(method_definition + [ + "get" + "set" + "async" + "*" + "readonly" + "static" + (override_modifier) + (accessibility_modifier) + ]* @context + name: (_) @name) @item + +; (public_field_definition +; [ +; "declare" +; "readonly" +; "abstract" +; "static" +; (accessibility_modifier) +; ]* @context +; name: (_) @name) @item diff --git a/crates/zed/src/languages/tsx/embedding.scm b/crates/zed/src/languages/tsx/embedding.scm index a8cde61b9e..96c56abe9f 100644 --- a/crates/zed/src/languages/tsx/embedding.scm +++ b/crates/zed/src/languages/tsx/embedding.scm @@ -1,6 +1,6 @@ ; (internal_module ; "namespace" @context - name: (_) @name) @item + ; name: (_) @name) @item (enum_declaration "enum" @context From af7b2f17ae28699fc20bbe88513db00450390fa3 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 11 Jul 2023 17:13:58 -0400 Subject: [PATCH 44/51] added initial keymap for toggle semantic search Co-authored-by: maxbrunsfeld --- assets/keymaps/default.json | 1 + crates/vector_store/src/vector_store.rs | 42 +++++++++++++------------ 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 8c3a1f407c..3f0b545ebc 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -405,6 +405,7 @@ "cmd-k cmd-t": "theme_selector::Toggle", "cmd-k cmd-s": "zed::OpenKeymap", "cmd-t": "project_symbols::Toggle", + "cmd-alt-t": "semantic_search::Toggle", "cmd-p": "file_finder::Toggle", "cmd-shift-p": "command_palette::Toggle", "cmd-shift-m": "diagnostics::Deploy", diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index a2ca90e84e..d3f89d568a 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -49,7 +49,6 @@ pub fn init( } settings::register::(cx); - if !settings::get::(cx).enable { return; } @@ -58,6 +57,27 @@ pub fn init( .join(Path::new(RELEASE_CHANNEL_NAME.as_str())) .join("embeddings_db"); + SemanticSearch::init(cx); + cx.add_action( + |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext| { + eprintln!("semantic_search::Toggle action"); + + if cx.has_global::>() { + let vector_store = cx.global::>().clone(); + workspace.toggle_modal(cx, |workspace, cx| { + let project = workspace.project().clone(); + let workspace = cx.weak_handle(); + cx.add_view(|cx| { + SemanticSearch::new( + SemanticSearchDelegate::new(workspace, project, vector_store), + cx, + ) + }) + }); + } + }, + ); + cx.spawn(move |mut cx| async move { let vector_store = VectorStore::new( fs, @@ -73,6 +93,7 @@ pub fn init( .await?; cx.update(|cx| { + cx.set_global(vector_store.clone()); cx.subscribe_global::({ let vector_store = vector_store.clone(); move |event, cx| { @@ -88,25 +109,6 @@ pub fn init( } }) .detach(); - - cx.add_action({ - // "semantic search: Toggle" - move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext| { - let vector_store = vector_store.clone(); - workspace.toggle_modal(cx, |workspace, cx| { - let project = workspace.project().clone(); - let workspace = cx.weak_handle(); - cx.add_view(|cx| { 
- SemanticSearch::new( - SemanticSearchDelegate::new(workspace, project, vector_store), - cx, - ) - }) - }) - } - }); - - SemanticSearch::init(cx); }); anyhow::Ok(()) From 08e24bbbae8de4f8db3d0bdc68c2c1e3293958f6 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 11 Jul 2023 14:29:06 -0700 Subject: [PATCH 45/51] Use cmd-ctrl-t for semantic search key binding Co-authored-by: Kyle --- assets/keymaps/default.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 3f0b545ebc..4726c67aea 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -405,7 +405,7 @@ "cmd-k cmd-t": "theme_selector::Toggle", "cmd-k cmd-s": "zed::OpenKeymap", "cmd-t": "project_symbols::Toggle", - "cmd-alt-t": "semantic_search::Toggle", + "cmd-ctrl-t": "semantic_search::Toggle", "cmd-p": "file_finder::Toggle", "cmd-shift-p": "command_palette::Toggle", "cmd-shift-m": "diagnostics::Deploy", From badf94b097e7fc5c158b624d4ef2907ac0ae1b0e Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 11 Jul 2023 14:29:48 -0700 Subject: [PATCH 46/51] Update dot product test to use larger vectors Co-authored-by: Kyle --- crates/vector_store/src/vector_store_tests.rs | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index a3a40722ea..ede43b9ff8 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -6,7 +6,7 @@ use async_trait::async_trait; use gpui::{Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry}; use project::{FakeFs, Project}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; use serde_json::json; use settings::SettingsStore; use std::sync::Arc; @@ -97,18 +97,23 @@ async fn test_vector_store(cx: &mut TestAppContext) { assert_eq!(search_results[0].worktree_id, worktree_id); } -#[test] -fn test_dot_product() { +#[gpui::test] +fn test_dot_product(mut rng: StdRng) { assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.); assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.); for _ in 0..100 { - let mut rng = rand::thread_rng(); - let a: [f32; 32] = rng.gen(); - let b: [f32; 32] = rng.gen(); + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + assert_eq!( - round_to_decimals(dot(&a, &b), 3), - round_to_decimals(reference_dot(&a, &b), 3) + round_to_decimals(dot(&a, &b), 1), + round_to_decimals(reference_dot(&a, &b), 1) ); } From d244c0fcea07bc936baaab71e21a40638f24f383 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 11 Jul 2023 14:30:11 -0700 Subject: [PATCH 47/51] Get vector store test passing - wait for indexing Co-authored-by: Kyle --- crates/vector_store/src/vector_store_tests.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index ede43b9ff8..8c5a667c7d 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -81,9 +81,11 @@ async fn test_vector_store(cx: &mut TestAppContext) { let worktree_id = project.read_with(cx, |project, cx| { project.worktrees(cx).next().unwrap().read(cx).id() }); - let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx)); - - 
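The reduced-precision comparison in the dot product test above relies on a round_to_decimals helper that none of these hunks show; its body is presumably something like the following (assumed, not the committed code):

    // Assumed body for the test-side helper referenced above: round to a
    // fixed number of decimal places before comparing f32 results.
    fn round_to_decimals(value: f32, decimal_places: i32) -> f32 {
        let factor = 10.0_f32.powi(decimal_places);
        (value * factor).round() / factor
    }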
add_project.await.unwrap(); + store + .update(cx, |store, cx| store.add_project(project.clone(), cx)) + .await + .unwrap(); + cx.foreground().run_until_parked(); let search_results = store .update(cx, |store, cx| { From 4a4dd398750add1b93c7db41f7e4739405043e22 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 11 Jul 2023 15:02:19 -0700 Subject: [PATCH 48/51] Fix TSX embedding query --- crates/zed/src/languages/tsx/embedding.scm | 24 ---------------------- 1 file changed, 24 deletions(-) diff --git a/crates/zed/src/languages/tsx/embedding.scm b/crates/zed/src/languages/tsx/embedding.scm index 96c56abe9f..305f634e04 100644 --- a/crates/zed/src/languages/tsx/embedding.scm +++ b/crates/zed/src/languages/tsx/embedding.scm @@ -1,15 +1,7 @@ -; (internal_module -; "namespace" @context - ; name: (_) @name) @item - (enum_declaration "enum" @context name: (_) @name) @item -; (type_alias_declaration -; "type" @context - name: (_) @name) @item - (function_declaration "async"? @context "function" @context @@ -19,12 +11,6 @@ "interface" @context name: (_) @name) @item -; (export_statement -; (lexical_declaration -; ["let" "const"] @context -; (variable_declarator -; name: (_) @name) @item)) - (program (lexical_declaration ["let" "const"] @context @@ -47,13 +33,3 @@ (accessibility_modifier) ]* @context name: (_) @name) @item - -; (public_field_definition -; [ -; "declare" -; "readonly" -; "abstract" -; "static" -; (accessibility_modifier) -; ]* @context -; name: (_) @name) @item From 4b3bb2c6611eda1eed0ba88e9d3ef731f1439a62 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 11 Jul 2023 15:02:43 -0700 Subject: [PATCH 49/51] Define semantic search action regardless of whether the feature is enabled --- crates/vector_store/src/vector_store.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index d3f89d568a..87e70230ee 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -44,14 +44,7 @@ pub fn init( language_registry: Arc, cx: &mut AppContext, ) { - if *RELEASE_CHANNEL == ReleaseChannel::Stable { - return; - } - settings::register::(cx); - if !settings::get::(cx).enable { - return; - } let db_file_path = EMBEDDINGS_DIR .join(Path::new(RELEASE_CHANNEL_NAME.as_str())) @@ -60,8 +53,6 @@ pub fn init( SemanticSearch::init(cx); cx.add_action( |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext| { - eprintln!("semantic_search::Toggle action"); - if cx.has_global::>() { let vector_store = cx.global::>().clone(); workspace.toggle_modal(cx, |workspace, cx| { @@ -78,6 +69,12 @@ pub fn init( }, ); + if *RELEASE_CHANNEL == ReleaseChannel::Stable + || !settings::get::(cx).enable + { + return; + } + cx.spawn(move |mut cx| async move { let vector_store = VectorStore::new( fs, From b68cd58a3b9ac1aa4a13955bae4a8c2fc08ce279 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 11 Jul 2023 19:54:03 -0400 Subject: [PATCH 50/51] updated vector store settings to remove batch embeddings size --- assets/settings/default.json | 5 ++--- crates/vector_store/src/vector_store.rs | 6 +++--- crates/vector_store/src/vector_store_settings.rs | 6 ++---- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/assets/settings/default.json b/assets/settings/default.json index cf8f630dfb..1f8d12a3d9 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -293,9 +293,8 @@ }, // Difference settings for vector_store "vector_store": { - 
"enable": false, - "reindexing_delay_seconds": 600, - "embedding_batch_size": 150 + "enabled": false, + "reindexing_delay_seconds": 600 }, // Different settings for specific languages. "languages": { diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 87e70230ee..0a197bc406 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -37,6 +37,7 @@ use util::{ use workspace::{Workspace, WorkspaceCreated}; const VECTOR_STORE_VERSION: usize = 0; +const EMBEDDINGS_BATCH_SIZE: usize = 150; pub fn init( fs: Arc, @@ -70,7 +71,7 @@ pub fn init( ); if *RELEASE_CHANNEL == ReleaseChannel::Stable - || !settings::get::(cx).enable + || !settings::get::(cx).enabled { return; } @@ -353,7 +354,6 @@ impl VectorStore { }); // batch_tx/rx: Batch Files to Send for Embeddings - let batch_size = settings::get::(cx).embedding_batch_size; let (batch_files_tx, batch_files_rx) = channel::unbounded::(); let _batch_files_task = cx.background().spawn(async move { let mut queue_len = 0; @@ -368,7 +368,7 @@ impl VectorStore { } => { queue_len += &document_spans.len(); embeddings_queue.push((worktree_id, parsed_file, document_spans)); - queue_len >= batch_size + queue_len >= EMBEDDINGS_BATCH_SIZE } EmbeddingJob::Flush => true, }; diff --git a/crates/vector_store/src/vector_store_settings.rs b/crates/vector_store/src/vector_store_settings.rs index 0bde07dd65..e1fa7cc05a 100644 --- a/crates/vector_store/src/vector_store_settings.rs +++ b/crates/vector_store/src/vector_store_settings.rs @@ -5,16 +5,14 @@ use settings::Setting; #[derive(Deserialize, Debug)] pub struct VectorStoreSettings { - pub enable: bool, + pub enabled: bool, pub reindexing_delay_seconds: usize, - pub embedding_batch_size: usize, } #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] pub struct VectorStoreSettingsContent { - pub enable: Option, + pub enabled: Option, pub reindexing_delay_seconds: Option, - pub embedding_batch_size: Option, } impl Setting for VectorStoreSettings { From 33e2b52a01fce046082b4aa8f7933b73343ecd6c Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 11 Jul 2023 20:12:43 -0400 Subject: [PATCH 51/51] added test registration for project settings --- crates/vector_store/src/vector_store_tests.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index 8c5a667c7d..b6e47e7a23 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -5,7 +5,7 @@ use anyhow::Result; use async_trait::async_trait; use gpui::{Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry}; -use project::{FakeFs, Project}; +use project::{project_settings::ProjectSettings, FakeFs, Project}; use rand::{rngs::StdRng, Rng}; use serde_json::json; use settings::SettingsStore; @@ -17,6 +17,7 @@ async fn test_vector_store(cx: &mut TestAppContext) { cx.update(|cx| { cx.set_global(SettingsStore::test(cx)); settings::register::(cx); + settings::register::(cx); }); let fs = FakeFs::new(cx.background());