diff --git a/Cargo.lock b/Cargo.lock
index 430a665f98..484ef3644b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6430,6 +6430,7 @@ dependencies = [
  "menu",
  "postage",
  "project",
+ "semantic_index",
  "serde",
  "serde_derive",
  "serde_json",
@@ -6484,6 +6485,7 @@ dependencies = [
  "matrixmultiply",
  "parking_lot 0.11.2",
  "picker",
+ "postage",
  "project",
  "rand 0.8.5",
  "rpc",
diff --git a/crates/search/Cargo.toml b/crates/search/Cargo.toml
index 7ef388f7c0..f6ed6c3fef 100644
--- a/crates/search/Cargo.toml
+++ b/crates/search/Cargo.toml
@@ -19,6 +19,7 @@ settings = { path = "../settings" }
 theme = { path = "../theme" }
 util = { path = "../util" }
 workspace = { path = "../workspace" }
+semantic_index = { path = "../semantic_index" }
 anyhow.workspace = true
 futures.workspace = true
 log.workspace = true
diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs
index ebd504d02c..91d2b142ae 100644
--- a/crates/search/src/project_search.rs
+++ b/crates/search/src/project_search.rs
@@ -2,7 +2,7 @@ use crate::{
     SearchOption, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleRegex,
     ToggleWholeWord,
 };
-use anyhow::Result;
+use anyhow::{Context, Result};
 use collections::HashMap;
 use editor::{
     items::active_match_index, scroll::autoscroll::Autoscroll, Anchor, Editor, MultiBuffer,
@@ -18,7 +18,9 @@ use gpui::{
     Task, View, ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle,
 };
 use menu::Confirm;
+use postage::stream::Stream;
 use project::{search::SearchQuery, Project};
+use semantic_index::SemanticIndex;
 use smallvec::SmallVec;
 use std::{
     any::{Any, TypeId},
@@ -36,7 +38,10 @@ use workspace::{
     ItemNavHistory, Pane, ToolbarItemLocation, ToolbarItemView, Workspace, WorkspaceId,
 };
 
-actions!(project_search, [SearchInNew, ToggleFocus, NextField]);
+actions!(
+    project_search,
+    [SearchInNew, ToggleFocus, NextField, ToggleSemanticSearch]
+);
 
 #[derive(Default)]
 struct ActiveSearches(HashMap<WeakModelHandle<Project>, WeakViewHandle<ProjectSearchView>>);
@@ -92,6 +97,7 @@ pub struct ProjectSearchView {
     case_sensitive: bool,
     whole_word: bool,
     regex: bool,
+    semantic: Option<SemanticSearchState>,
     panels_with_errors: HashSet<InputPanel>,
     active_match_index: Option<usize>,
     search_id: usize,
@@ -100,6 +106,13 @@ pub struct ProjectSearchView {
     excluded_files_editor: ViewHandle<Editor>,
 }
 
+struct SemanticSearchState {
+    file_count: usize,
+    outstanding_file_count: usize,
+    _progress_task: Task<()>,
+    search_task: Option<Task<Result<()>>>,
+}
+
 pub struct ProjectSearchBar {
     active_project_search: Option<ViewHandle<ProjectSearchView>>,
     subscription: Option<Subscription>,
@@ -198,12 +211,25 @@ impl View for ProjectSearchView {
         let theme = theme::current(cx).clone();
 
         let text = if self.query_editor.read(cx).text(cx).is_empty() {
-            ""
+            Cow::Borrowed("")
+        } else if let Some(semantic) = &self.semantic {
+            if semantic.search_task.is_some() {
+                Cow::Borrowed("Searching...")
+            } else if semantic.outstanding_file_count > 0 {
+                Cow::Owned(format!(
+                    "Indexing. {} of {}...",
+                    semantic.file_count - semantic.outstanding_file_count,
+                    semantic.file_count
+                ))
+            } else {
+                Cow::Borrowed("Indexing complete")
+            }
         } else if model.pending_search.is_some() {
-            "Searching..."
+ Cow::Borrowed("Searching...") } else { - "No results" + Cow::Borrowed("No results") }; + MouseEventHandler::::new(0, cx, |_, _| { Label::new(text, theme.search.results_status.clone()) .aligned() @@ -499,6 +525,7 @@ impl ProjectSearchView { case_sensitive, whole_word, regex, + semantic: None, panels_with_errors: HashSet::new(), active_match_index: None, query_editor_was_focused: false, @@ -563,6 +590,35 @@ impl ProjectSearchView { } fn search(&mut self, cx: &mut ViewContext) { + if let Some(semantic) = &mut self.semantic { + if semantic.outstanding_file_count > 0 { + return; + } + + let search_phrase = self.query_editor.read(cx).text(cx); + let project = self.model.read(cx).project.clone(); + if let Some(semantic_index) = SemanticIndex::global(cx) { + let search_task = semantic_index.update(cx, |semantic_index, cx| { + semantic_index.search_project(project, search_phrase, 10, cx) + }); + semantic.search_task = Some(cx.spawn(|this, mut cx| async move { + let results = search_task.await.context("search task")?; + + this.update(&mut cx, |this, cx| { + dbg!(&results); + // TODO: Update results + + if let Some(semantic) = &mut this.semantic { + semantic.search_task = None; + } + })?; + + anyhow::Ok(()) + })); + } + return; + } + if let Some(query) = self.build_search_query(cx) { self.model.update(cx, |model, cx| model.search(query, cx)); } @@ -876,6 +932,59 @@ impl ProjectSearchBar { } } + fn toggle_semantic_search(&mut self, cx: &mut ViewContext) -> bool { + if let Some(search_view) = self.active_project_search.as_ref() { + search_view.update(cx, |search_view, cx| { + if search_view.semantic.is_some() { + search_view.semantic = None; + } else if let Some(semantic_index) = SemanticIndex::global(cx) { + // TODO: confirm that it's ok to send this project + + let project = search_view.model.read(cx).project.clone(); + let index_task = semantic_index.update(cx, |semantic_index, cx| { + semantic_index.index_project(project, cx) + }); + + cx.spawn(|search_view, mut cx| async move { + let (files_to_index, mut files_remaining_rx) = index_task.await?; + + search_view.update(&mut cx, |search_view, cx| { + search_view.semantic = Some(SemanticSearchState { + file_count: files_to_index, + outstanding_file_count: files_to_index, + search_task: None, + _progress_task: cx.spawn(|search_view, mut cx| async move { + while let Some(count) = files_remaining_rx.recv().await { + search_view + .update(&mut cx, |search_view, cx| { + if let Some(semantic_search_state) = + &mut search_view.semantic + { + semantic_search_state.outstanding_file_count = + count; + cx.notify(); + if count == 0 { + return; + } + } + }) + .ok(); + } + }), + }); + })?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + }); + cx.notify(); + true + } else { + false + } + } + fn render_nav_button( &self, icon: &'static str, @@ -953,6 +1062,42 @@ impl ProjectSearchBar { .into_any() } + fn render_semantic_search_button(&self, cx: &mut ViewContext) -> AnyElement { + let tooltip_style = theme::current(cx).tooltip.clone(); + let is_active = if let Some(search) = self.active_project_search.as_ref() { + let search = search.read(cx); + search.semantic.is_some() + } else { + false + }; + + let region_id = 3; + + MouseEventHandler::::new(region_id, cx, |state, cx| { + let theme = theme::current(cx); + let style = theme + .search + .option_button + .in_state(is_active) + .style_for(state); + Label::new("Semantic", style.text.clone()) + .contained() + .with_style(style.container) + }) + .on_click(MouseButton::Left, move |_, this, cx| { + 
+            this.toggle_semantic_search(cx);
+        })
+        .with_cursor_style(CursorStyle::PointingHand)
+        .with_tooltip::<ToggleSemanticSearch>(
+            region_id,
+            "Toggle Semantic Search".to_string(),
+            Some(Box::new(ToggleSemanticSearch)),
+            tooltip_style,
+            cx,
+        )
+        .into_any()
+    }
+
     fn is_option_enabled(&self, option: SearchOption, cx: &AppContext) -> bool {
         if let Some(search) = self.active_project_search.as_ref() {
             let search = search.read(cx);
@@ -1049,6 +1194,7 @@ impl View for ProjectSearchBar {
                     )
                     .with_child(
                         Flex::row()
+                            .with_child(self.render_semantic_search_button(cx))
                             .with_child(self.render_option_button(
                                 "Case",
                                 SearchOption::CaseSensitive,
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 5c5af072c8..2d21ff6c1c 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -20,6 +20,7 @@ editor = { path = "../editor" }
 rpc = { path = "../rpc" }
 settings = { path = "../settings" }
 anyhow.workspace = true
+postage.workspace = true
 futures.workspace = true
 smol.workspace = true
 rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs
index 1d5a9a475e..a667ff877c 100644
--- a/crates/semantic_index/src/db.rs
+++ b/crates/semantic_index/src/db.rs
@@ -1,5 +1,5 @@
 use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
 use project::Fs;
 use rpc::proto::Timestamp;
 use rusqlite::{
@@ -76,14 +76,14 @@ impl VectorDatabase {
         self.db
             .execute(
                 "
-                DROP TABLE semantic_index_config;
-                DROP TABLE worktrees;
-                DROP TABLE files;
-                DROP TABLE documents;
+                DROP TABLE IF EXISTS documents;
+                DROP TABLE IF EXISTS files;
+                DROP TABLE IF EXISTS worktrees;
+                DROP TABLE IF EXISTS semantic_index_config;
                 ",
                 [],
             )
-            .ok();
+            .context("failed to drop tables")?;
 
         // Initialize Vector Databasing Tables
         self.db.execute(
diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs
index ea349c8afa..4f49d66ce7 100644
--- a/crates/semantic_index/src/embedding.rs
+++ b/crates/semantic_index/src/embedding.rs
@@ -86,6 +86,7 @@ impl OpenAIEmbeddings {
     async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
         let request = Request::post("https://api.openai.com/v1/embeddings")
             .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(4))
            .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {}", api_key))
             .body(
@@ -133,7 +134,11 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                     self.executor.timer(delay).await;
                 }
                 StatusCode::BAD_REQUEST => {
-                    log::info!("BAD REQUEST: {:?}", &response.status());
+                    log::info!(
+                        "BAD REQUEST: {:?} {:?}",
+                        &response.status(),
+                        response.body()
+                    );
                     // Don't worry about delaying bad request, as we can assume
                     // we haven't been rate limited yet.
                     for span in spans.iter_mut() {
diff --git a/crates/semantic_index/src/modal.rs b/crates/semantic_index/src/modal.rs
deleted file mode 100644
index ffc64a195c..0000000000
--- a/crates/semantic_index/src/modal.rs
+++ /dev/null
@@ -1,172 +0,0 @@
-use crate::{SearchResult, SemanticIndex};
-use editor::{scroll::autoscroll::Autoscroll, Editor};
-use gpui::{
-    actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext,
-    WeakViewHandle,
-};
-use picker::{Picker, PickerDelegate, PickerEvent};
-use project::{Project, ProjectPath};
-use std::{collections::HashMap, sync::Arc, time::Duration};
-use util::ResultExt;
-use workspace::Workspace;
-
-const MIN_QUERY_LEN: usize = 5;
-const EMBEDDING_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(500);
-
-actions!(semantic_search, [Toggle]);
-
-pub type SemanticSearch = Picker<SemanticSearchDelegate>;
-
-pub struct SemanticSearchDelegate {
-    workspace: WeakViewHandle<Workspace>,
-    project: ModelHandle<Project>,
-    semantic_index: ModelHandle<SemanticIndex>,
-    selected_match_index: usize,
-    matches: Vec<SearchResult>,
-    history: HashMap<String, Vec<SearchResult>>,
-}
-
-impl SemanticSearchDelegate {
-    // This is currently searching on every keystroke,
-    // This is wildly overkill, and has the potential to get expensive
-    // We will need to update this to throttle searching
-    pub fn new(
-        workspace: WeakViewHandle<Workspace>,
-        project: ModelHandle<Project>,
-        semantic_index: ModelHandle<SemanticIndex>,
-    ) -> Self {
-        Self {
-            workspace,
-            project,
-            semantic_index,
-            selected_match_index: 0,
-            matches: vec![],
-            history: HashMap::new(),
-        }
-    }
-}
-
-impl PickerDelegate for SemanticSearchDelegate {
-    fn placeholder_text(&self) -> Arc<str> {
-        "Search repository in natural language...".into()
-    }
-
-    fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
-        if let Some(search_result) = self.matches.get(self.selected_match_index) {
-            // Open Buffer
-            let search_result = search_result.clone();
-            let buffer = self.project.update(cx, |project, cx| {
-                project.open_buffer(
-                    ProjectPath {
-                        worktree_id: search_result.worktree_id,
-                        path: search_result.file_path.clone().into(),
-                    },
-                    cx,
-                )
-            });
-
-            let workspace = self.workspace.clone();
-            let position = search_result.clone().byte_range.start;
-            cx.spawn(|_, mut cx| async move {
-                let buffer = buffer.await?;
-                workspace.update(&mut cx, |workspace, cx| {
-                    let editor = workspace.open_project_item::<Editor>(buffer, cx);
-                    editor.update(cx, |editor, cx| {
-                        editor.change_selections(Some(Autoscroll::center()), cx, |s| {
-                            s.select_ranges([position..position])
-                        });
-                    });
-                })?;
-                Ok::<_, anyhow::Error>(())
-            })
-            .detach_and_log_err(cx);
-            cx.emit(PickerEvent::Dismiss);
-        }
-    }
-
-    fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}
-
-    fn match_count(&self) -> usize {
-        self.matches.len()
-    }
-
-    fn selected_index(&self) -> usize {
-        self.selected_match_index
-    }
-
-    fn set_selected_index(&mut self, ix: usize, _cx: &mut ViewContext<SemanticSearch>) {
-        self.selected_match_index = ix;
-    }
-
-    fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
-        log::info!("Searching for {:?}...", query);
-        if query.len() < MIN_QUERY_LEN {
-            log::info!("Query below minimum length");
-            return Task::ready(());
-        }
-
-        let semantic_index = self.semantic_index.clone();
-        let project = self.project.clone();
-        cx.spawn(|this, mut cx| async move {
-            cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await;
-
-            let retrieved_cached = this.update(&mut cx, |this, _| {
-                let delegate = this.delegate_mut();
-                if delegate.history.contains_key(&query) {
-                    let historic_results = delegate.history.get(&query).unwrap().to_owned();
-                    delegate.matches =
-                        historic_results.clone();
-                    true
-                } else {
-                    false
-                }
-            });
-
-            if let Some(retrieved) = retrieved_cached.log_err() {
-                if !retrieved {
-                    let task = semantic_index.update(&mut cx, |store, cx| {
-                        store.search_project(project.clone(), query.to_string(), 10, cx)
-                    });
-
-                    if let Some(results) = task.await.log_err() {
-                        log::info!("Not queried previously, searching...");
-                        this.update(&mut cx, |this, _| {
-                            let delegate = this.delegate_mut();
-                            delegate.matches = results.clone();
-                            delegate.history.insert(query, results);
-                        })
-                        .ok();
-                    }
-                } else {
-                    log::info!("Already queried, retrieved directly from cached history");
-                }
-            }
-        })
-    }
-
-    fn render_match(
-        &self,
-        ix: usize,
-        mouse_state: &mut MouseState,
-        selected: bool,
-        cx: &AppContext,
-    ) -> AnyElement<Picker<Self>> {
-        let theme = theme::current(cx);
-        let style = &theme.picker.item;
-        let current_style = style.in_state(selected).style_for(mouse_state);
-
-        let search_result = &self.matches[ix];
-
-        let path = search_result.file_path.to_string_lossy();
-        let name = search_result.name.clone();
-
-        Flex::column()
-            .with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
-            .with_child(Label::new(
-                path.to_string(),
-                style.inactive_state().default.label.clone(),
-            ))
-            .contained()
-            .with_style(current_style.container)
-            .into_any()
-    }
-}
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index b59b20370a..e6443870aa 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -1,6 +1,5 @@
 mod db;
 mod embedding;
-mod modal;
 mod parsing;
 mod semantic_index_settings;
 
@@ -12,25 +11,20 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
-use gpui::{
-    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
-    WeakModelHandle,
-};
+use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Language, LanguageRegistry};
-use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 use parking_lot::Mutex;
 use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
+use postage::watch;
 use project::{Fs, Project, WorktreeId};
 use smol::channel;
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
+    mem,
     ops::Range,
     path::{Path, PathBuf},
-    sync::{
-        atomic::{self, AtomicUsize},
-        Arc, Weak,
-    },
-    time::{Instant, SystemTime},
+    sync::{Arc, Weak},
+    time::SystemTime,
 };
 use util::{
     channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -38,9 +32,8 @@ use util::{
     paths::EMBEDDINGS_DIR,
     ResultExt,
 };
-use workspace::{Workspace, WorkspaceCreated};
 
-const SEMANTIC_INDEX_VERSION: usize = 1;
+const SEMANTIC_INDEX_VERSION: usize = 3;
 const EMBEDDINGS_BATCH_SIZE: usize = 150;
 
 pub fn init(
@@ -55,25 +48,6 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
-    SemanticSearch::init(cx);
-    cx.add_action(
-        |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
-            if cx.has_global::<ModelHandle<SemanticIndex>>() {
-                let semantic_index = cx.global::<ModelHandle<SemanticIndex>>().clone();
-                workspace.toggle_modal(cx, |workspace, cx| {
-                    let project = workspace.project().clone();
-                    let workspace = cx.weak_handle();
-                    cx.add_view(|cx| {
-                        SemanticSearch::new(
-                            SemanticSearchDelegate::new(workspace, project, semantic_index),
-                            cx,
-                        )
-                    })
-                });
-            }
-        },
-    );
-
     if *RELEASE_CHANNEL == ReleaseChannel::Stable
         || !settings::get::<SemanticIndexSettings>(cx).enabled
     {
@@ -95,21 +69,6 @@ pub fn init(
 
     cx.update(|cx| {
         cx.set_global(semantic_index.clone());
-        cx.subscribe_global::<WorkspaceCreated, _>({
-            let semantic_index = semantic_index.clone();
-            move |event, cx| {
-                let workspace = &event.0;
-                if let Some(workspace) = workspace.upgrade(cx) {
-                    let project = workspace.read(cx).project().clone();
-                    if project.read(cx).is_local() {
-                        semantic_index.update(cx, |store, cx| {
-                            store.index_project(project, cx).detach();
-                        });
-                    }
-                }
-            }
-        })
-        .detach();
     });
 
     anyhow::Ok(())
@@ -128,20 +87,17 @@ pub struct SemanticIndex {
     _embed_batch_task: Task<()>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
-    next_job_id: Arc<AtomicUsize>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
 
 struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
-    outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
+    outstanding_job_count_rx: watch::Receiver<usize>,
+    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 }
 
-type JobId = usize;
-
 struct JobHandle {
-    id: JobId,
-    set: Weak<Mutex<HashSet<JobId>>>,
+    tx: Weak<Mutex<watch::Sender<usize>>>,
 }
 
 impl ProjectState {
@@ -221,6 +177,14 @@ enum EmbeddingJob {
 }
 
 impl SemanticIndex {
+    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
+        if cx.has_global::<ModelHandle<Self>>() {
+            Some(cx.global::<ModelHandle<Self>>().clone())
+        } else {
+            None
+        }
+    }
+
     async fn new(
         fs: Arc<dyn Fs>,
         database_url: PathBuf,
@@ -236,184 +200,69 @@ impl SemanticIndex {
             .await?;
 
         Ok(cx.add_model(|cx| {
-            // paths_tx -> embeddings_tx -> db_update_tx
-
-            //db_update_tx/rx: Updating Database
+            // Perform database operations
             let (db_update_tx, db_update_rx) = channel::unbounded();
-            let _db_update_task = cx.background().spawn(async move {
-                while let Ok(job) = db_update_rx.recv().await {
-                    match job {
-                        DbOperation::InsertFile {
-                            worktree_id,
-                            documents,
-                            path,
-                            mtime,
-                            job_handle,
-                        } => {
-                            db.insert_file(worktree_id, path, mtime, documents)
-                                .log_err();
-                            drop(job_handle)
-                        }
-                        DbOperation::Delete { worktree_id, path } => {
-                            db.delete_file(worktree_id, path).log_err();
-                        }
-                        DbOperation::FindOrCreateWorktree { path, sender } => {
-                            let id = db.find_or_create_worktree(&path);
-                            sender.send(id).ok();
-                        }
-                        DbOperation::FileMTimes {
-                            worktree_id: worktree_db_id,
-                            sender,
-                        } => {
-                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                            sender.send(file_mtimes).ok();
-                        }
+            let _db_update_task = cx.background().spawn({
+                async move {
+                    while let Ok(job) = db_update_rx.recv().await {
+                        Self::run_db_operation(&db, job)
                     }
                 }
             });
 
-            // embed_tx/rx: Embed Batch and Send to Database
+            // Compute embeddings for each batch of documents and send the results to the database task.
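+            // Pipeline: parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx.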
             let (embed_batch_tx, embed_batch_rx) =
                 channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
             let _embed_batch_task = cx.background().spawn({
                 let db_update_tx = db_update_tx.clone();
                 let embedding_provider = embedding_provider.clone();
                 async move {
-                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
-                        // Construct Batch
-                        let mut batch_documents = vec![];
-                        for (_, documents, _, _, _) in embeddings_queue.iter() {
-                            batch_documents
-                                .extend(documents.iter().map(|document| document.content.as_str()));
-                        }
-
-                        if let Ok(embeddings) =
-                            embedding_provider.embed_batch(batch_documents).await
-                        {
-                            log::trace!(
-                                "created {} embeddings for {} files",
-                                embeddings.len(),
-                                embeddings_queue.len(),
-                            );
-
-                            let mut i = 0;
-                            let mut j = 0;
-
-                            for embedding in embeddings.iter() {
-                                while embeddings_queue[i].1.len() == j {
-                                    i += 1;
-                                    j = 0;
-                                }
-
-                                embeddings_queue[i].1[j].embedding = embedding.to_owned();
-                                j += 1;
-                            }
-
-                            for (worktree_id, documents, path, mtime, job_handle) in
-                                embeddings_queue.into_iter()
-                            {
-                                for document in documents.iter() {
-                                    // TODO: Update this so it doesn't panic
-                                    assert!(
-                                        document.embedding.len() > 0,
-                                        "Document Embedding Not Complete"
-                                    );
-                                }
-
-                                db_update_tx
-                                    .send(DbOperation::InsertFile {
-                                        worktree_id,
-                                        documents,
-                                        path,
-                                        mtime,
-                                        job_handle,
-                                    })
-                                    .await
-                                    .unwrap();
-                            }
-                        }
+                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+                        Self::compute_embeddings_for_batch(
+                            embeddings_queue,
+                            &embedding_provider,
+                            &db_update_tx,
+                        )
+                        .await;
                     }
                 }
             });
 
-            // batch_tx/rx: Batch Files to Send for Embeddings
+            // Group documents into batches and send them to the embedding provider.
             let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
             let _batch_files_task = cx.background().spawn(async move {
                 let mut queue_len = 0;
                 let mut embeddings_queue = vec![];
                 while let Ok(job) = batch_files_rx.recv().await {
-                    let should_flush = match job {
-                        EmbeddingJob::Enqueue {
-                            documents,
-                            worktree_id,
-                            path,
-                            mtime,
-                            job_handle,
-                        } => {
-                            queue_len += &documents.len();
-                            embeddings_queue.push((
-                                worktree_id,
-                                documents,
-                                path,
-                                mtime,
-                                job_handle,
-                            ));
-                            queue_len >= EMBEDDINGS_BATCH_SIZE
-                        }
-                        EmbeddingJob::Flush => true,
-                    };
-
-                    if should_flush {
-                        embed_batch_tx.try_send(embeddings_queue).unwrap();
-                        embeddings_queue = vec![];
-                        queue_len = 0;
-                    }
+                    Self::enqueue_documents_to_embed(
+                        job,
+                        &mut queue_len,
+                        &mut embeddings_queue,
+                        &embed_batch_tx,
+                    );
                 }
             });
 
-            // parsing_files_tx/rx: Parsing Files to Embeddable Documents
+            // Parse files into embeddable documents.
             let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
-
             let mut _parsing_files_tasks = Vec::new();
             for _ in 0..cx.background().num_cpus() {
                 let fs = fs.clone();
                 let parsing_files_rx = parsing_files_rx.clone();
                 let batch_files_tx = batch_files_tx.clone();
+                let db_update_tx = db_update_tx.clone();
                 _parsing_files_tasks.push(cx.background().spawn(async move {
                     let mut retriever = CodeContextRetriever::new();
                     while let Ok(pending_file) = parsing_files_rx.recv().await {
-                        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
-                        {
-                            if let Some(documents) = retriever
-                                .parse_file(
-                                    &pending_file.relative_path,
-                                    &content,
-                                    pending_file.language,
-                                )
-                                .log_err()
-                            {
-                                log::trace!(
-                                    "parsed path {:?}: {} documents",
-                                    pending_file.relative_path,
-                                    documents.len()
-                                );
-
-                                batch_files_tx
-                                    .try_send(EmbeddingJob::Enqueue {
-                                        worktree_id: pending_file.worktree_db_id,
-                                        path: pending_file.relative_path,
-                                        mtime: pending_file.modified_time,
-                                        job_handle: pending_file.job_handle,
-                                        documents,
-                                    })
-                                    .unwrap();
-                            }
-                        }
-
-                        if parsing_files_rx.len() == 0 {
-                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
-                        }
+                        Self::parse_file(
+                            &fs,
+                            pending_file,
+                            &mut retriever,
+                            &batch_files_tx,
+                            &parsing_files_rx,
+                            &db_update_tx,
+                        )
+                        .await;
                     }
                 }));
             }
@@ -424,7 +273,6 @@ impl SemanticIndex {
                 embedding_provider,
                 language_registry,
                 db_update_tx,
-                next_job_id: Default::default(),
                 parsing_files_tx,
                 _db_update_task,
                 _embed_batch_task,
@@ -435,6 +283,167 @@ impl SemanticIndex {
         }))
     }
 
+    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
+        match job {
+            DbOperation::InsertFile {
+                worktree_id,
+                documents,
+                path,
+                mtime,
+                job_handle,
+            } => {
+                db.insert_file(worktree_id, path, mtime, documents)
+                    .log_err();
+                drop(job_handle)
+            }
+            DbOperation::Delete { worktree_id, path } => {
+                db.delete_file(worktree_id, path).log_err();
+            }
+            DbOperation::FindOrCreateWorktree { path, sender } => {
+                let id = db.find_or_create_worktree(&path);
+                sender.send(id).ok();
+            }
+            DbOperation::FileMTimes {
+                worktree_id: worktree_db_id,
+                sender,
+            } => {
+                let file_mtimes = db.get_file_mtimes(worktree_db_id);
+                sender.send(file_mtimes).ok();
+            }
+        }
+    }
+
+    async fn compute_embeddings_for_batch(
+        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
+        embedding_provider: &Arc<dyn EmbeddingProvider>,
+        db_update_tx: &channel::Sender<DbOperation>,
+    ) {
+        let mut batch_documents = vec![];
+        for (_, documents, _, _, _) in embeddings_queue.iter() {
+            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
+        }
+
+        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
+            log::trace!(
+                "created {} embeddings for {} files",
+                embeddings.len(),
+                embeddings_queue.len(),
+            );
+
+            let mut i = 0;
+            let mut j = 0;
+
+            for embedding in embeddings.iter() {
+                while embeddings_queue[i].1.len() == j {
+                    i += 1;
+                    j = 0;
+                }
+
+                embeddings_queue[i].1[j].embedding = embedding.to_owned();
+                j += 1;
+            }
+
+            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
+                // for document in documents.iter() {
+                //     // TODO: Update this so it doesn't panic
+                //     assert!(
+                //         document.embedding.len() > 0,
+                //         "Document Embedding Not Complete"
+                //     );
+                // }
+
+                db_update_tx
+                    .send(DbOperation::InsertFile {
+                        worktree_id,
+                        documents,
+                        path,
+                        mtime,
+                        job_handle,
+                    })
+                    .await
+                    .unwrap();
+            }
+        }
+    }
+
+    fn enqueue_documents_to_embed(
+        job: EmbeddingJob,
+        queue_len: &mut usize,
+        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
+        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
+    ) {
+        let should_flush = match job {
+            EmbeddingJob::Enqueue {
+                documents,
+                worktree_id,
+                path,
+                mtime,
+                job_handle,
+            } => {
+                *queue_len += documents.len();
+                embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
+                *queue_len >= EMBEDDINGS_BATCH_SIZE
+            }
+            EmbeddingJob::Flush => true,
+        };
+
+        if should_flush {
+            embed_batch_tx
+                .try_send(mem::take(embeddings_queue))
+                .unwrap();
+            *queue_len = 0;
+        }
+    }
+
+    async fn parse_file(
+        fs: &Arc<dyn Fs>,
+        pending_file: PendingFile,
+        retriever: &mut CodeContextRetriever,
+        batch_files_tx: &channel::Sender<EmbeddingJob>,
+        parsing_files_rx: &channel::Receiver<PendingFile>,
+        db_update_tx: &channel::Sender<DbOperation>,
+    ) {
+        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
+            if let Some(documents) = retriever
+                .parse_file(&pending_file.relative_path, &content, pending_file.language)
+                .log_err()
+            {
+                log::trace!(
+                    "parsed path {:?}: {} documents",
+                    pending_file.relative_path,
+                    documents.len()
+                );
+
+                if documents.is_empty() {
+                    db_update_tx
+                        .send(DbOperation::InsertFile {
+                            worktree_id: pending_file.worktree_db_id,
+                            documents,
+                            path: pending_file.relative_path,
+                            mtime: pending_file.modified_time,
+                            job_handle: pending_file.job_handle,
+                        })
+                        .await
+                        .unwrap();
+                } else {
+                    batch_files_tx
+                        .try_send(EmbeddingJob::Enqueue {
+                            worktree_id: pending_file.worktree_db_id,
+                            path: pending_file.relative_path,
+                            mtime: pending_file.modified_time,
+                            job_handle: pending_file.job_handle,
+                            documents,
+                        })
+                        .unwrap();
+                }
+            }
+        }
+
+        if parsing_files_rx.is_empty() {
+            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
+        }
+    }
+
     fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
         let (tx, rx) = oneshot::channel();
         self.db_update_tx
@@ -457,11 +466,11 @@ impl SemanticIndex {
         async move { rx.await? }
     }
 
-    fn index_project(
+    pub fn index_project(
         &mut self,
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<usize>> {
+    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -483,7 +492,6 @@ impl SemanticIndex {
         let language_registry = self.language_registry.clone();
         let db_update_tx = self.db_update_tx.clone();
         let parsing_files_tx = self.parsing_files_tx.clone();
-        let next_job_id = self.next_job_id.clone();
 
         cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
@@ -509,8 +517,8 @@ impl SemanticIndex {
                 );
             }
 
-            // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
-            let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
+            let (job_count_tx, job_count_rx) = watch::channel_with(0);
+            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
             this.update(&mut cx, |this, _| {
                 this.projects.insert(
                     project.downgrade(),
@@ -519,7 +527,8 @@ impl SemanticIndex {
                             .iter()
                             .map(|(a, b)| (*a, *b))
                             .collect(),
-                        outstanding_jobs: outstanding_jobs.clone(),
+                        outstanding_job_count_rx: job_count_rx.clone(),
+                        outstanding_job_count_tx: job_count_tx.clone(),
                     },
                 );
             });
@@ -527,7 +536,6 @@ impl SemanticIndex {
             cx.background()
                 .spawn(async move {
                     let mut count = 0;
-                    let t0 = Instant::now();
                    for worktree in worktrees.into_iter() {
                         let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
                         for file in worktree.files(false, 0) {
@@ -552,14 +560,11 @@ impl SemanticIndex {
                                     .map_or(false, |existing_mtime| existing_mtime == file.mtime);
 
                                 if !already_stored {
-                                    log::trace!("sending for parsing: {:?}", path_buf);
                                     count += 1;
-                                    let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
+                                    *job_count_tx.lock().borrow_mut() += 1;
                                     let job_handle = JobHandle {
-                                        id: job_id,
-                                        set: Arc::downgrade(&outstanding_jobs),
+                                        tx: Arc::downgrade(&job_count_tx),
                                     };
-                                    outstanding_jobs.lock().insert(job_id);
                                     parsing_files_tx
                                         .try_send(PendingFile {
                                             worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
@@ -582,27 +587,22 @@ impl SemanticIndex {
                                         .unwrap();
                                 }
                             }
                         }
-                    log::trace!(
-                        "parsing worktree completed in {:?}",
-                        t0.elapsed().as_millis()
-                    );
-
-                    Ok(count)
+                    anyhow::Ok((count, job_count_rx))
                 })
                 .await
         })
     }
 
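+    // Each outstanding JobHandle decrements this count when dropped, so the
+    // receiver reads zero once indexing has finished.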
-    pub fn remaining_files_to_index_for_project(
+    pub fn outstanding_job_count_rx(
         &self,
         project: &ModelHandle<Project>,
-    ) -> Option<usize> {
+    ) -> Option<watch::Receiver<usize>> {
         Some(
             self.projects
                 .get(&project.downgrade())?
-                .outstanding_jobs
-                .lock()
-                .len(),
+                .outstanding_job_count_rx
+                .clone(),
         )
     }
 
@@ -678,8 +678,9 @@ impl Entity for SemanticIndex {
 
 impl Drop for JobHandle {
     fn drop(&mut self) {
-        if let Some(set) = self.set.upgrade() {
-            set.lock().remove(&self.id);
+        if let Some(tx) = self.tx.upgrade() {
+            let mut tx = tx.lock();
+            *tx.borrow_mut() -= 1;
         }
     }
 }
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index ed48cf256b..2ccc52d64b 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -88,18 +88,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     let worktree_id = project.read_with(cx, |project, cx| {
         project.worktrees(cx).next().unwrap().read(cx).id()
     });
-    let file_count = store
+    let (file_count, outstanding_file_count) = store
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
     assert_eq!(file_count, 3);
     cx.foreground().run_until_parked();
-    store.update(cx, |store, _cx| {
-        assert_eq!(
-            store.remaining_files_to_index_for_project(&project),
-            Some(0)
-        );
-    });
+    assert_eq!(*outstanding_file_count.borrow(), 0);
 
     let search_results = store
         .update(cx, |store, cx| {
@@ -128,19 +123,14 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     cx.foreground().run_until_parked();
 
     let prev_embedding_count = embedding_provider.embedding_count();
-    let file_count = store
+    let (file_count, outstanding_file_count) = store
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
     assert_eq!(file_count, 1);
     cx.foreground().run_until_parked();
-    store.update(cx, |store, _cx| {
-        assert_eq!(
-            store.remaining_files_to_index_for_project(&project),
-            Some(0)
-        );
-    });
+    assert_eq!(*outstanding_file_count.borrow(), 0);
 
     assert_eq!(
         embedding_provider.embedding_count() - prev_embedding_count,