diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs
index 91d2b142ae..1097969c00 100644
--- a/crates/search/src/project_search.rs
+++ b/crates/search/src/project_search.rs
@@ -2,7 +2,7 @@ use crate::{
     SearchOption, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleRegex,
     ToggleWholeWord,
 };
-use anyhow::{Context, Result};
+use anyhow::Result;
 use collections::HashMap;
 use editor::{
     items::active_match_index, scroll::autoscroll::Autoscroll, Anchor, Editor, MultiBuffer,
@@ -187,6 +187,53 @@ impl ProjectSearch {
         }));
         cx.notify();
     }
+
+    fn semantic_search(&mut self, query: String, cx: &mut ModelContext<Self>) -> Option<()> {
+        let project = self.project.clone();
+        let semantic_index = SemanticIndex::global(cx)?;
+        let search_task = semantic_index.update(cx, |semantic_index, cx| {
+            semantic_index.search_project(project, query.clone(), 10, cx)
+        });
+
+        self.search_id += 1;
+        // self.active_query = Some(query);
+        self.match_ranges.clear();
+        self.pending_search = Some(cx.spawn(|this, mut cx| async move {
+            let results = search_task.await.log_err()?;
+
+            let (_task, mut match_ranges) = this.update(&mut cx, |this, cx| {
+                this.excerpts.update(cx, |excerpts, cx| {
+                    excerpts.clear(cx);
+
+                    let matches = results
+                        .into_iter()
+                        .map(|result| (result.buffer, vec![result.range]))
+                        .collect();
+
+                    excerpts.stream_excerpts_with_context_lines(matches, 3, cx)
+                })
+            });
+
+            while let Some(match_range) = match_ranges.next().await {
+                this.update(&mut cx, |this, cx| {
+                    this.match_ranges.push(match_range);
+                    while let Ok(Some(match_range)) = match_ranges.try_next() {
+                        this.match_ranges.push(match_range);
+                    }
+                    cx.notify();
+                });
+            }
+
+            this.update(&mut cx, |this, cx| {
+                this.pending_search.take();
+                cx.notify();
+            });
+
+            None
+        }));
+
+        Some(())
+    }
 }
 
 pub enum ViewEvent {
@@ -595,27 +642,9 @@ impl ProjectSearchView {
                 return;
             }
 
-            let search_phrase = self.query_editor.read(cx).text(cx);
-            let project = self.model.read(cx).project.clone();
-            if let Some(semantic_index) = SemanticIndex::global(cx) {
-                let search_task = semantic_index.update(cx, |semantic_index, cx| {
-                    semantic_index.search_project(project, search_phrase, 10, cx)
-                });
-                semantic.search_task = Some(cx.spawn(|this, mut cx| async move {
-                    let results = search_task.await.context("search task")?;
-
-                    this.update(&mut cx, |this, cx| {
-                        dbg!(&results);
-                        // TODO: Update results
-
-                        if let Some(semantic) = &mut this.semantic {
-                            semantic.search_task = None;
-                        }
-                    })?;
-
-                    anyhow::Ok(())
-                }));
-            }
+            let query = self.query_editor.read(cx).text(cx);
+            self.model
+                .update(cx, |model, cx| model.semantic_search(query, cx));
             return;
         }
diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs
index 74e1021b15..fd99594aab 100644
--- a/crates/semantic_index/src/db.rs
+++ b/crates/semantic_index/src/db.rs
@@ -252,7 +252,7 @@ impl VectorDatabase {
         worktree_ids: &[i64],
         query_embedding: &Vec<f32>,
         limit: usize,
-    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
+    ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
         self.for_each_document(&worktree_ids, |id, embedding| {
             let similarity = dot(&embedding, &query_embedding);
@@ -296,10 +296,7 @@ impl VectorDatabase {
         Ok(())
     }
 
-    fn get_documents_by_ids(
-        &self,
-        ids: &[i64],
-    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
+    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut statement = self.db.prepare(
             "
             SELECT
@@ -307,7 +304,7 @@ impl VectorDatabase {
                 documents.id,
                 files.worktree_id,
                 files.relative_path,
                 documents.start_byte,
-                documents.end_byte, documents.name
+                documents.end_byte
             FROM
                 documents, files
             WHERE
@@ -322,14 +319,13 @@ impl VectorDatabase {
                 row.get::<_, i64>(1)?,
                 row.get::<_, String>(2)?.into(),
                 row.get(3)?..row.get(4)?,
-                row.get(5)?,
             ))
         })?;
 
-        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
+        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
         for row in result_iter {
-            let (id, worktree_id, path, range, name) = row?;
-            values_by_id.insert(id, (worktree_id, path, range, name));
+            let (id, worktree_id, path, range) = row?;
+            values_by_id.insert(id, (worktree_id, path, range));
         }
 
         let mut results = Vec::with_capacity(ids.len());
diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs
index d41350f321..728fc9283a 100644
--- a/crates/semantic_index/src/embedding.rs
+++ b/crates/semantic_index/src/embedding.rs
@@ -70,10 +70,6 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        Self { client, executor }
-    }
-
     fn truncate(span: String) -> String {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
         if tokens.len() > OPENAI_INPUT_LIMIT {
@@ -81,7 +77,6 @@ impl OpenAIEmbeddings {
             let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
             if result.is_ok() {
                 let transformed = result.unwrap();
-                // assert_ne!(transformed, span);
                 return transformed;
             }
         }
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index f6575f6ad7..5c6919d4fd 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -12,7 +12,7 @@ use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
-use language::{Language, LanguageRegistry};
+use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
 use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
@@ -93,7 +93,7 @@ pub struct SemanticIndex {
 struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
     outstanding_job_count_rx: watch::Receiver<usize>,
-    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 }
 
 struct JobHandle {
@@ -135,12 +135,9 @@ pub struct PendingFile {
     job_handle: JobHandle,
 }
 
-#[derive(Debug, Clone)]
 pub struct SearchResult {
-    pub worktree_id: WorktreeId,
-    pub name: String,
-    pub byte_range: Range<usize>,
-    pub file_path: PathBuf,
+    pub buffer: ModelHandle<Buffer>,
+    pub range: Range<Anchor>,
 }
 
 enum DbOperation {
@@ -520,7 +517,7 @@ impl SemanticIndex {
                     .map(|(a, b)| (*a, *b))
                     .collect(),
                 outstanding_job_count_rx: job_count_rx.clone(),
-                outstanding_job_count_tx: job_count_tx.clone(),
+                _outstanding_job_count_tx: job_count_tx.clone(),
             },
         );
     });
@@ -623,7 +620,7 @@ impl SemanticIndex {
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
         let fs = self.fs.clone();
-        cx.spawn(|this, cx| async move {
+        cx.spawn(|this, mut cx| async move {
             let documents = cx
                 .background()
                 .spawn(async move {
@@ -640,26 +637,39 @@ impl SemanticIndex {
                 })
                 .await?;
 
-            this.read_with(&cx, |this, _| {
-                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
-                    state
-                } else {
-                    return Err(anyhow!("project not added"));
-                };
+            let mut tasks = Vec::new();
+            let mut ranges = Vec::new();
+            let weak_project = project.downgrade();
+            project.update(&mut cx, |project, cx| {
+                for (worktree_db_id, file_path, byte_range) in documents {
+                    let project_state =
+                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
+                            state
+                        } else {
+                            return Err(anyhow!("project not added"));
+                        };
+                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
+                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
+                        ranges.push(byte_range);
+                    }
+                }
 
-            Ok(documents
-                .into_iter()
-                .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
-                    let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
-                    Some(SearchResult {
-                        worktree_id,
-                        name,
-                        byte_range,
-                        file_path,
-                    })
-                })
-                .collect())
-            })
+                Ok(())
+            })?;
+
+            let buffers = futures::future::join_all(tasks).await;
+
+            Ok(buffers
+                .into_iter()
+                .zip(ranges)
+                .filter_map(|(buffer, range)| {
+                    let buffer = buffer.log_err()?;
+                    let range = buffer.read_with(&cx, |buffer, _| {
+                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
+                    });
+                    Some(SearchResult { buffer, range })
+                })
+                .collect::<Vec<_>>())
         })
     }
 }
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index 2ccc52d64b..63b28798ad 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -8,7 +8,7 @@ use crate::{
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{Task, TestAppContext};
-use language::{Language, LanguageConfig, LanguageRegistry};
+use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
@@ -85,9 +85,6 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .unwrap();
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
-    let worktree_id = project.read_with(cx, |project, cx| {
-        project.worktrees(cx).next().unwrap().read(cx).id()
-    });
     let (file_count, outstanding_file_count) = store
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
@@ -103,9 +100,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .await
         .unwrap();
 
-    assert_eq!(search_results[0].byte_range.start, 0);
-    assert_eq!(search_results[0].name, "aaa");
-    assert_eq!(search_results[0].worktree_id, worktree_id);
+    search_results[0].buffer.read_with(cx, |buffer, _cx| {
+        assert_eq!(search_results[0].range.start.to_offset(buffer), 0);
+        assert_eq!(
+            buffer.file().unwrap().path().as_ref(),
+            Path::new("file1.rs")
+        );
+    });
 
     fs.save(
         "/the-root/src/file2.rs".as_ref(),
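
Note: the substance of this change is that semantic search results are now reported as a buffer handle plus an anchor range (the new `SearchResult { buffer, range }`) instead of a raw `(worktree_id, name, byte_range, file_path)` tuple, which lets the project search view stream matches directly into its multi-buffer excerpts and keeps result positions stable while buffers are edited. Below is a minimal standalone sketch of the byte-offset-to-anchor conversion done in `search_project`; the `Buffer` and `Anchor` types here are simplified stand-ins for illustration, not Zed's actual `language` crate API.

    use std::ops::Range;

    // Simplified stand-in types: a real language::Anchor tracks its position
    // across subsequent edits; this toy version just freezes a byte offset.
    struct Buffer {
        text: String,
    }

    #[derive(Debug, Clone, Copy)]
    struct Anchor(usize);

    impl Buffer {
        fn anchor_before(&self, offset: usize) -> Anchor {
            Anchor(offset.min(self.text.len()))
        }

        fn anchor_after(&self, offset: usize) -> Anchor {
            Anchor(offset.min(self.text.len()))
        }
    }

    // Same shape as the new SearchResult: a buffer plus an anchor range,
    // rather than a worktree id, path, symbol name, and raw byte range.
    struct SearchResult<'a> {
        buffer: &'a Buffer,
        range: Range<Anchor>,
    }

    fn main() {
        let buffer = Buffer {
            text: "fn aaa() { println!(\"hello\"); }".into(),
        };
        // A byte range as the vector database's top-k search would return it.
        let byte_range = 0..8;
        let result = SearchResult {
            buffer: &buffer,
            range: buffer.anchor_before(byte_range.start)
                ..buffer.anchor_after(byte_range.end),
        };
        let (start, end) = (result.range.start.0, result.range.end.0);
        println!("matched {:?} at {start}..{end}", &result.buffer.text[start..end]);
    }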