From 18a5a47f8ab758d0b4288871457af5aa05d1404b Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Fri, 30 Jun 2023 18:41:19 -0400
Subject: [PATCH] Move semantic search model to dev and preview only; move db
 update tasks to a long-lived persistent task

Co-authored-by: maxbrunsfeld
---
 crates/project/src/project.rs                 |   5 +
 crates/vector_store/src/modal.rs              |   2 +-
 crates/vector_store/src/vector_store.rs       | 328 ++++++++++++------
 crates/vector_store/src/vector_store_tests.rs |  25 +-
 4 files changed, 239 insertions(+), 121 deletions(-)

diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs
index bbb2064da2..eb0004850c 100644
--- a/crates/project/src/project.rs
+++ b/crates/project/src/project.rs
@@ -260,6 +260,7 @@ pub enum Event {
     ActiveEntryChanged(Option<ProjectEntryId>),
     WorktreeAdded,
     WorktreeRemoved(WorktreeId),
+    WorktreeUpdatedEntries(WorktreeId, UpdatedEntriesSet),
     DiskBasedDiagnosticsStarted {
         language_server_id: LanguageServerId,
     },
@@ -5371,6 +5372,10 @@ impl Project {
                 this.update_local_worktree_buffers(&worktree, changes, cx);
                 this.update_local_worktree_language_servers(&worktree, changes, cx);
                 this.update_local_worktree_settings(&worktree, changes, cx);
+                cx.emit(Event::WorktreeUpdatedEntries(
+                    worktree.read(cx).id(),
+                    changes.clone(),
+                ));
             }
             worktree::Event::UpdatedGitRepositories(updated_repos) => {
                 this.update_local_worktree_buffers_git_repos(worktree, updated_repos, cx)
diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs
index 4d377c6819..9225fe8786 100644
--- a/crates/vector_store/src/modal.rs
+++ b/crates/vector_store/src/modal.rs
@@ -124,7 +124,7 @@ impl PickerDelegate for SemanticSearchDelegate {
                     if let Some(retrieved) = retrieved_cached.log_err() {
                         if !retrieved {
                             let task = vector_store.update(&mut cx, |store, cx| {
-                                store.search(&project, query.to_string(), 10, cx)
+                                store.search(project.clone(), query.to_string(), 10, cx)
                             });

                             if let Some(results) = task.await.log_err() {
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index c329206c4b..3f0a7001ef 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -8,7 +8,11 @@
 use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
-use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
+use futures::{channel::oneshot, Future};
+use gpui::{
+    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
+    WeakModelHandle,
+};
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 use project::{Fs, Project, WorktreeId};
@@ -22,7 +26,10 @@ use std::{
 };
 use tree_sitter::{Parser, QueryCursor};
 use util::{
-    channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt,
+    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
+    http::HttpClient,
+    paths::EMBEDDINGS_DIR,
+    ResultExt,
 };
 use workspace::{Workspace, WorkspaceCreated};
@@ -39,12 +46,16 @@ pub fn init(
     language_registry: Arc<LanguageRegistry>,
     cx: &mut AppContext,
 ) {
+    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
+        return;
+    }
+
     let db_file_path = EMBEDDINGS_DIR
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");

-    let vector_store = cx.add_model(|_| {
-        VectorStore::new(
+    cx.spawn(move |mut cx| async move {
+        let vector_store = VectorStore::new(
             fs,
             db_file_path,
             // Arc::new(embedding::DummyEmbeddings {}),
@@ -52,42 +63,49 @@ pub fn init(
                 client: http_client,
             }),
             language_registry,
+            cx.clone(),
         )
-    });
+        .await?;

-    cx.subscribe_global::<WorkspaceCreated, _>({
-        let vector_store = vector_store.clone();
-        move |event, cx| {
-            let workspace = &event.0;
-            if let Some(workspace) = workspace.upgrade(cx) {
-                let project = workspace.read(cx).project().clone();
-                if project.read(cx).is_local() {
-                    vector_store.update(cx, |store, cx| {
-                        store.add_project(project, cx).detach();
-                    });
+        cx.update(|cx| {
+            cx.subscribe_global::<WorkspaceCreated, _>({
+                let vector_store = vector_store.clone();
+                move |event, cx| {
+                    let workspace = &event.0;
+                    if let Some(workspace) = workspace.upgrade(cx) {
+                        let project = workspace.read(cx).project().clone();
+                        if project.read(cx).is_local() {
+                            vector_store.update(cx, |store, cx| {
+                                store.add_project(project, cx).detach();
+                            });
+                        }
+                    }
                 }
-            }
-        }
+            })
+            .detach();
+
+            cx.add_action({
+                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
+                    let vector_store = vector_store.clone();
+                    workspace.toggle_modal(cx, |workspace, cx| {
+                        let project = workspace.project().clone();
+                        let workspace = cx.weak_handle();
+                        cx.add_view(|cx| {
+                            SemanticSearch::new(
+                                SemanticSearchDelegate::new(workspace, project, vector_store),
+                                cx,
+                            )
+                        })
+                    })
+                }
+            });
+
+            SemanticSearch::init(cx);
+        });
+
+        anyhow::Ok(())
     })
     .detach();
-
-    cx.add_action({
-        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
-            let vector_store = vector_store.clone();
-            workspace.toggle_modal(cx, |workspace, cx| {
-                let project = workspace.project().clone();
-                let workspace = cx.weak_handle();
-                cx.add_view(|cx| {
-                    SemanticSearch::new(
-                        SemanticSearchDelegate::new(workspace, project, vector_store),
-                        cx,
-                    )
-                })
-            })
-        }
-    });
-
-    SemanticSearch::init(cx);
 }

 #[derive(Debug)]
@@ -102,7 +120,14 @@ pub struct VectorStore {
     database_url: Arc<PathBuf>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
+    db_update_tx: channel::Sender<DbWrite>,
+    _db_update_task: Task<()>,
+    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
+}
+
+struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
+    _subscription: gpui::Subscription,
 }

 #[derive(Debug, Clone)]
@@ -113,20 +138,81 @@ pub struct SearchResult {
     pub file_path: PathBuf,
 }

+enum DbWrite {
+    InsertFile {
+        worktree_id: i64,
+        indexed_file: IndexedFile,
+    },
+    Delete {
+        worktree_id: i64,
+        path: PathBuf,
+    },
+    FindOrCreateWorktree {
+        path: PathBuf,
+        sender: oneshot::Sender<Result<i64>>,
+    },
+}
+
 impl VectorStore {
-    fn new(
+    async fn new(
         fs: Arc<dyn Fs>,
         database_url: PathBuf,
         embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
-    ) -> Self {
-        Self {
-            fs,
-            database_url: Arc::new(database_url),
-            embedding_provider,
-            language_registry,
-            worktree_db_ids: Vec::new(),
-        }
+        mut cx: AsyncAppContext,
+    ) -> Result<ModelHandle<Self>> {
+        let database_url = Arc::new(database_url);
+
+        let db = cx
+            .background()
+            .spawn({
+                let fs = fs.clone();
+                let database_url = database_url.clone();
+                async move {
+                    if let Some(db_directory) = database_url.parent() {
+                        fs.create_dir(db_directory).await.log_err();
+                    }
+
+                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
+                    anyhow::Ok(db)
+                }
+            })
+            .await?;
+
+        Ok(cx.add_model(|cx| {
+            let (db_update_tx, db_update_rx) = channel::unbounded();
+            let _db_update_task = cx.background().spawn(async move {
+                while let Ok(job) = db_update_rx.recv().await {
+                    match job {
+                        DbWrite::InsertFile {
+                            worktree_id,
+                            indexed_file,
+                        } => {
+                            log::info!("Inserting File: {:?}", &indexed_file.path);
+                            db.insert_file(worktree_id, indexed_file).log_err();
+                        }
+                        DbWrite::Delete { worktree_id, path } => {
+                            log::info!("Deleting File: {:?}", &path);
+                            db.delete_file(worktree_id, path).log_err();
+                        }
+                        DbWrite::FindOrCreateWorktree { path, sender } => {
+                            let id = db.find_or_create_worktree(&path);
+                            sender.send(id).ok();
+                        }
+                    }
+                }
+            });
+
+            Self {
+                fs,
+                database_url,
+                db_update_tx,
+                embedding_provider,
+                language_registry,
+                projects: HashMap::new(),
+                _db_update_task,
+            }
+        }))
     }

     async fn index_file(
@@ -196,6 +282,14 @@ impl VectorStore {
         });
     }

+    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
+        let (tx, rx) = oneshot::channel();
+        self.db_update_tx
+            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
+            .unwrap();
+        async move { rx.await? }
+    }
+
     fn add_project(
         &mut self,
         project: ModelHandle<Project>,
@@ -211,19 +305,28 @@ impl VectorStore {
                 }
             })
             .collect::<Vec<_>>();

+        let worktree_db_ids = project
+            .read(cx)
+            .worktrees(cx)
+            .map(|worktree| {
+                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
+            })
+            .collect::<Vec<_>>();
+
         let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
+        let db_update_tx = self.db_update_tx.clone();

         cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;

+            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
+
             if let Some(db_directory) = database_url.parent() {
                 fs.create_dir(db_directory).await.log_err();
             }
-            let db = VectorDatabase::new(database_url.to_string_lossy().into())?;

             let worktrees = project.read_with(&cx, |project, cx| {
                 project
@@ -234,32 +337,31 @@ impl VectorStore {
             // Here we query the worktree ids, and yet we don't have them elsewhere
             // We likely want to clean up these data structures
-            let (db, mut worktree_file_times, worktree_db_ids) = cx
+            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
                 .background()
                 .spawn({
                     let worktrees = worktrees.clone();
                     async move {
-                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
+                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
+                        let mut db_ids_by_worktree_id = HashMap::new();
                         let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
                             HashMap::new();
-                        for worktree in worktrees {
-                            let worktree_db_id =
-                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
-                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
-                            file_times.insert(worktree.id(), db.get_file_mtimes(worktree_db_id)?);
+                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
+                            let db_id = db_id?;
+                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
+                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
                         }
-                        anyhow::Ok((db, file_times, worktree_db_ids))
+                        anyhow::Ok((file_times, db_ids_by_worktree_id))
                     }
                 })
                 .await?;

             let (paths_tx, paths_rx) =
                 channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
-            let (delete_paths_tx, delete_paths_rx) = channel::unbounded::<(i64, PathBuf)>();
-            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
             cx.background()
                 .spawn({
-                    let worktree_db_ids = worktree_db_ids.clone();
+                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
+                    let db_update_tx = db_update_tx.clone();
                     async move {
                         for worktree in worktrees.into_iter() {
                             let mut file_mtimes =
@@ -289,7 +391,7 @@ impl VectorStore {
                                 if !already_stored {
                                     paths_tx
                                         .try_send((
-                                            worktree_db_ids[&worktree.id()],
+                                            db_ids_by_worktree_id[&worktree.id()],
                                             path_buf,
                                             language,
                                             file.mtime,
@@ -299,8 +401,11 @@ impl VectorStore {
                                 }
                             }
                             for file in file_mtimes.keys() {
-                                delete_paths_tx
-                                    .try_send((worktree_db_ids[&worktree.id()], file.to_owned()))
+                                db_update_tx
+                                    .try_send(DbWrite::Delete {
+                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
+                                        path: file.to_owned(),
+                                    })
                                     .unwrap();
                             }
                         }
@@ -308,25 +413,6 @@ impl VectorStore {
                 })
                 .detach();

-            let db_update_task = cx.background().spawn(
-                async move {
-                    // Inserting all new files
-                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
-                        log::info!("Inserting File: {:?}", &indexed_file.path);
-                        db.insert_file(worktree_id, indexed_file).log_err();
-                    }
-
-                    // Deleting all old files
-                    while let Ok((worktree_id, delete_path)) = delete_paths_rx.recv().await {
-                        log::info!("Deleting File: {:?}", &delete_path);
-                        db.delete_file(worktree_id, delete_path).log_err();
-                    }
-
-                    anyhow::Ok(())
-                }
-                .log_err(),
-            );
-
             cx.background()
                 .scoped(|scope| {
                     for _ in 0..cx.background().num_cpus() {
@@ -348,8 +434,11 @@ impl VectorStore {
                                     .await
                                     .log_err()
                                 {
-                                    indexed_files_tx
-                                        .try_send((worktree_id, indexed_file))
+                                    db_update_tx
+                                        .try_send(DbWrite::InsertFile {
+                                            worktree_id,
+                                            indexed_file,
+                                        })
                                         .unwrap();
                                 }
                             }
@@ -357,12 +446,21 @@ impl VectorStore {
                     }
                 })
                 .await;
-            drop(indexed_files_tx);
-            db_update_task.await;

-            this.update(&mut cx, |this, _| {
-                this.worktree_db_ids.extend(worktree_db_ids);
+            this.update(&mut cx, |this, cx| {
+                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
+                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
+                        log::info!("worktree changes {:?}", changes);
+                    }
+                });
+
+                this.projects.insert(
+                    project.downgrade(),
+                    ProjectState {
+                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
+                        _subscription,
+                    },
+                );
             });

             log::info!("Semantic Indexing Complete!");
@@ -373,23 +472,32 @@ impl VectorStore {

     pub fn search(
         &mut self,
-        project: &ModelHandle<Project>,
+        project: ModelHandle<Project>,
         phrase: String,
         limit: usize,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
-        let project = project.read(cx);
+        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
+            state
+        } else {
+            return Task::ready(Err(anyhow!("project not added")));
+        };
+
         let worktree_db_ids = project
+            .read(cx)
             .worktrees(cx)
             .filter_map(|worktree| {
                 let worktree_id = worktree.read(cx).id();
-                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
-                    if *id == worktree_id {
-                        Some(*db_id)
-                    } else {
-                        None
-                    }
-                })
+                project_state
+                    .worktree_db_ids
+                    .iter()
+                    .find_map(|(id, db_id)| {
+                        if *id == worktree_id {
+                            Some(*db_id)
+                        } else {
+                            None
+                        }
+                    })
             })
             .collect::<Vec<_>>();
@@ -428,17 +536,27 @@ impl VectorStore {
                 })
                 .await?;

-            let results = this.read_with(&cx, |this, _| {
-                documents
+            this.read_with(&cx, |this, _| {
+                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
+                    state
+                } else {
+                    return Err(anyhow!("project not added"));
+                };
+
+                Ok(documents
                     .into_iter()
                     .filter_map(|(worktree_db_id, file_path, offset, name)| {
-                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
-                            if *db_id == worktree_db_id {
-                                Some(*id)
-                            } else {
-                                None
-                            }
-                        })?;
+                        let worktree_id =
+                            project_state
+                                .worktree_db_ids
+                                .iter()
+                                .find_map(|(id, db_id)| {
+                                    if *db_id == worktree_db_id {
+                                        Some(*id)
+                                    } else {
+                                        None
+                                    }
+                                })?;
                         Some(SearchResult {
                             worktree_id,
                             name,
@@ -446,10 +564,8 @@ impl VectorStore {
                             file_path,
                         })
                     })
-                    .collect()
-            });
-
-            anyhow::Ok(results)
+                    .collect())
+            })
         })
     }
 }
diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs
index 78470ad4be..51065c0ee4 100644
--- a/crates/vector_store/src/vector_store_tests.rs
+++ b/crates/vector_store/src/vector_store_tests.rs
@@ -5,7 +5,7 @@ use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry};
-use project::{FakeFs, Project};
+use project::{FakeFs, Fs, Project};
 use rand::Rng;
 use serde_json::json;
 use unindent::Unindent;
@@ -60,14 +60,15 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     let db_dir = tempdir::TempDir::new("vector-store").unwrap();
     let db_path = db_dir.path().join("db.sqlite");

-    let store = cx.add_model(|_| {
-        VectorStore::new(
-            fs.clone(),
-            db_path,
-            Arc::new(FakeEmbeddingProvider),
-            languages,
-        )
-    });
+    let store = VectorStore::new(
+        fs.clone(),
+        db_path,
+        Arc::new(FakeEmbeddingProvider),
+        languages,
+        cx.to_async(),
+    )
+    .await
+    .unwrap();

     let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
     let worktree_id = project.read_with(cx, |project, cx| {
@@ -75,15 +76,11 @@
     });

     let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));

-    // TODO - remove
-    cx.foreground()
-        .advance_clock(std::time::Duration::from_secs(3));
-
     add_project.await.unwrap();

     let search_results = store
         .update(cx, |store, cx| {
-            store.search(&project, "aaaa".to_string(), 5, cx)
+            store.search(project.clone(), "aaaa".to_string(), 5, cx)
         })
         .await
         .unwrap();
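
The core of this change is the single long-lived writer task that owns the database
connection and serializes every write sent to it over a channel. Below is a minimal,
self-contained sketch of that pattern, with std threads and channels standing in for
gpui's background executor and futures::channel::oneshot, and a stub Db type standing
in for Zed's VectorDatabase; every name in it is illustrative, not the patch's actual API.

    use std::path::{Path, PathBuf};
    use std::sync::mpsc;
    use std::thread;

    // Illustrative stand-in for the patch's VectorDatabase; not the real API.
    struct Db;

    impl Db {
        fn insert_file(&self, worktree_id: i64, path: &Path) {
            println!("insert {:?} into worktree {}", path, worktree_id);
        }
        fn delete_file(&self, worktree_id: i64, path: &Path) {
            println!("delete {:?} from worktree {}", path, worktree_id);
        }
        fn find_or_create_worktree(&self, path: &Path) -> i64 {
            // A real implementation would find or insert a row and return its id.
            path.as_os_str().len() as i64
        }
    }

    // Mirrors the patch's DbWrite enum: every database mutation is a message.
    enum DbWrite {
        InsertFile { worktree_id: i64, path: PathBuf },
        Delete { worktree_id: i64, path: PathBuf },
        FindOrCreateWorktree { path: PathBuf, reply: mpsc::Sender<i64> },
    }

    fn main() {
        let (tx, rx) = mpsc::channel::<DbWrite>();

        // The long-lived writer task: sole owner of the database handle. It runs
        // until every sender is dropped, so concurrent producers never race on
        // the connection.
        let writer = thread::spawn(move || {
            let db = Db;
            while let Ok(job) = rx.recv() {
                match job {
                    DbWrite::InsertFile { worktree_id, path } => {
                        db.insert_file(worktree_id, &path)
                    }
                    DbWrite::Delete { worktree_id, path } => db.delete_file(worktree_id, &path),
                    DbWrite::FindOrCreateWorktree { path, reply } => {
                        let id = db.find_or_create_worktree(&path);
                        reply.send(id).ok(); // requester may have gone away; ignore
                    }
                }
            }
        });

        // Request/response over the same queue: send a message that carries its
        // own reply channel, then block on the answer.
        let (reply_tx, reply_rx) = mpsc::channel();
        tx.send(DbWrite::FindOrCreateWorktree {
            path: PathBuf::from("/the-root"),
            reply: reply_tx,
        })
        .unwrap();
        let worktree_id = reply_rx.recv().unwrap();

        // Fire-and-forget writes are just more messages on the queue.
        tx.send(DbWrite::InsertFile {
            worktree_id,
            path: PathBuf::from("src/main.rs"),
        })
        .unwrap();

        drop(tx); // close the channel so the writer loop ends
        writer.join().unwrap();
    }

Because only the writer task ever touches the database handle, producers on any number
of threads can issue writes without locking, and an operation that needs an answer, such
as find_or_create_worktree, is simply a message that carries its own reply channel.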