Start work on exposing semantic search via project search view

Co-authored-by: Kyle <kyle@zed.dev>
Max Brunsfeld 2023-07-17 18:10:51 -07:00
parent d83c4ffb07
commit afc4c10ec1
9 changed files with 397 additions and 423 deletions

Cargo.lock (generated)
View file

@@ -6430,6 +6430,7 @@ dependencies = [
"menu",
"postage",
"project",
"semantic_index",
"serde",
"serde_derive",
"serde_json",
@@ -6484,6 +6485,7 @@ dependencies = [
"matrixmultiply",
"parking_lot 0.11.2",
"picker",
"postage",
"project",
"rand 0.8.5",
"rpc",

View file

@@ -19,6 +19,7 @@ settings = { path = "../settings" }
theme = { path = "../theme" }
util = { path = "../util" }
workspace = { path = "../workspace" }
semantic_index = { path = "../semantic_index" }
anyhow.workspace = true
futures.workspace = true
log.workspace = true

View file

@@ -2,7 +2,7 @@ use crate::{
SearchOption, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleRegex,
ToggleWholeWord,
};
use anyhow::Result;
use anyhow::{Context, Result};
use collections::HashMap;
use editor::{
items::active_match_index, scroll::autoscroll::Autoscroll, Anchor, Editor, MultiBuffer,
@@ -18,7 +18,9 @@ use gpui::{
Task, View, ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle,
};
use menu::Confirm;
use postage::stream::Stream;
use project::{search::SearchQuery, Project};
use semantic_index::SemanticIndex;
use smallvec::SmallVec;
use std::{
any::{Any, TypeId},
@@ -36,7 +38,10 @@ use workspace::{
ItemNavHistory, Pane, ToolbarItemLocation, ToolbarItemView, Workspace, WorkspaceId,
};
actions!(project_search, [SearchInNew, ToggleFocus, NextField]);
actions!(
project_search,
[SearchInNew, ToggleFocus, NextField, ToggleSemanticSearch]
);
#[derive(Default)]
struct ActiveSearches(HashMap<WeakModelHandle<Project>, WeakViewHandle<ProjectSearchView>>);
@@ -92,6 +97,7 @@ pub struct ProjectSearchView {
case_sensitive: bool,
whole_word: bool,
regex: bool,
semantic: Option<SemanticSearchState>,
panels_with_errors: HashSet<InputPanel>,
active_match_index: Option<usize>,
search_id: usize,
@@ -100,6 +106,13 @@ pub struct ProjectSearchView {
excluded_files_editor: ViewHandle<Editor>,
}
struct SemanticSearchState {
file_count: usize,
outstanding_file_count: usize,
_progress_task: Task<()>,
search_task: Option<Task<Result<()>>>,
}
pub struct ProjectSearchBar {
active_project_search: Option<ViewHandle<ProjectSearchView>>,
subscription: Option<Subscription>,
@@ -198,12 +211,25 @@ impl View for ProjectSearchView {
let theme = theme::current(cx).clone();
let text = if self.query_editor.read(cx).text(cx).is_empty() {
""
Cow::Borrowed("")
} else if let Some(semantic) = &self.semantic {
if semantic.search_task.is_some() {
Cow::Borrowed("Searching...")
} else if semantic.outstanding_file_count > 0 {
Cow::Owned(format!(
"Indexing. {} of {}...",
semantic.file_count - semantic.outstanding_file_count,
semantic.file_count
))
} else {
Cow::Borrowed("Indexing complete")
}
} else if model.pending_search.is_some() {
"Searching..."
Cow::Borrowed("Searching...")
} else {
"No results"
Cow::Borrowed("No results")
};
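
The status label's type changes from `&'static str` to `Cow<'static, str>` here, so the fixed messages stay borrowed and only the indexing-progress string allocates. A minimal sketch of the same idea (the `status` helper is hypothetical, not the view code itself):

```rust
use std::borrow::Cow;

// Borrow the fixed strings; allocate only for the formatted progress message.
fn status(indexed: usize, total: usize) -> Cow<'static, str> {
    if total == 0 {
        Cow::Borrowed("No results")
    } else if indexed < total {
        Cow::Owned(format!("Indexing. {} of {}...", indexed, total))
    } else {
        Cow::Borrowed("Indexing complete")
    }
}
```
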
MouseEventHandler::<Status, _>::new(0, cx, |_, _| {
Label::new(text, theme.search.results_status.clone())
.aligned()
@@ -499,6 +525,7 @@ impl ProjectSearchView {
case_sensitive,
whole_word,
regex,
semantic: None,
panels_with_errors: HashSet::new(),
active_match_index: None,
query_editor_was_focused: false,
@@ -563,6 +590,35 @@ impl ProjectSearchView {
}
fn search(&mut self, cx: &mut ViewContext<Self>) {
if let Some(semantic) = &mut self.semantic {
if semantic.outstanding_file_count > 0 {
return;
}
let search_phrase = self.query_editor.read(cx).text(cx);
let project = self.model.read(cx).project.clone();
if let Some(semantic_index) = SemanticIndex::global(cx) {
let search_task = semantic_index.update(cx, |semantic_index, cx| {
semantic_index.search_project(project, search_phrase, 10, cx)
});
semantic.search_task = Some(cx.spawn(|this, mut cx| async move {
let results = search_task.await.context("search task")?;
this.update(&mut cx, |this, cx| {
dbg!(&results);
// TODO: Update results
if let Some(semantic) = &mut this.semantic {
semantic.search_task = None;
}
})?;
anyhow::Ok(())
}));
}
return;
}
if let Some(query) = self.build_search_query(cx) {
self.model.update(cx, |model, cx| model.search(query, cx));
}
@@ -876,6 +932,59 @@ impl ProjectSearchBar {
}
}
fn toggle_semantic_search(&mut self, cx: &mut ViewContext<Self>) -> bool {
if let Some(search_view) = self.active_project_search.as_ref() {
search_view.update(cx, |search_view, cx| {
if search_view.semantic.is_some() {
search_view.semantic = None;
} else if let Some(semantic_index) = SemanticIndex::global(cx) {
// TODO: confirm that it's ok to send this project
let project = search_view.model.read(cx).project.clone();
let index_task = semantic_index.update(cx, |semantic_index, cx| {
semantic_index.index_project(project, cx)
});
cx.spawn(|search_view, mut cx| async move {
let (files_to_index, mut files_remaining_rx) = index_task.await?;
search_view.update(&mut cx, |search_view, cx| {
search_view.semantic = Some(SemanticSearchState {
file_count: files_to_index,
outstanding_file_count: files_to_index,
search_task: None,
_progress_task: cx.spawn(|search_view, mut cx| async move {
while let Some(count) = files_remaining_rx.recv().await {
search_view
.update(&mut cx, |search_view, cx| {
if let Some(semantic_search_state) =
&mut search_view.semantic
{
semantic_search_state.outstanding_file_count =
count;
cx.notify();
if count == 0 {
return;
}
}
})
.ok();
}
}),
});
})?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
});
cx.notify();
true
} else {
false
}
}
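
Stripped of the gpui plumbing, the `_progress_task` above is a plain stream consumer: postage's `watch::Receiver` yields the latest outstanding-file count each time the indexer updates it, and the loop exits at zero. A sketch under that assumption (`watch_progress` is a made-up name):

```rust
use postage::{stream::Stream, watch};

// Await each new outstanding-file count until indexing finishes.
async fn watch_progress(mut rx: watch::Receiver<usize>) {
    while let Some(remaining) = rx.recv().await {
        // The real task calls cx.notify() here so the label re-renders.
        if remaining == 0 {
            break;
        }
    }
}
```
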
fn render_nav_button(
&self,
icon: &'static str,
@@ -953,6 +1062,42 @@ impl ProjectSearchBar {
.into_any()
}
fn render_semantic_search_button(&self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let tooltip_style = theme::current(cx).tooltip.clone();
let is_active = if let Some(search) = self.active_project_search.as_ref() {
let search = search.read(cx);
search.semantic.is_some()
} else {
false
};
let region_id = 3;
MouseEventHandler::<Self, _>::new(region_id, cx, |state, cx| {
let theme = theme::current(cx);
let style = theme
.search
.option_button
.in_state(is_active)
.style_for(state);
Label::new("Semantic", style.text.clone())
.contained()
.with_style(style.container)
})
.on_click(MouseButton::Left, move |_, this, cx| {
this.toggle_semantic_search(cx);
})
.with_cursor_style(CursorStyle::PointingHand)
.with_tooltip::<Self>(
region_id,
format!("Toggle Semantic Search"),
Some(Box::new(ToggleSemanticSearch)),
tooltip_style,
cx,
)
.into_any()
}
fn is_option_enabled(&self, option: SearchOption, cx: &AppContext) -> bool {
if let Some(search) = self.active_project_search.as_ref() {
let search = search.read(cx);
@@ -1049,6 +1194,7 @@ impl View for ProjectSearchBar {
)
.with_child(
Flex::row()
.with_child(self.render_semantic_search_button(cx))
.with_child(self.render_option_button(
"Case",
SearchOption::CaseSensitive,

View file

@@ -20,6 +20,7 @@ editor = { path = "../editor" }
rpc = { path = "../rpc" }
settings = { path = "../settings" }
anyhow.workspace = true
postage.workspace = true
futures.workspace = true
smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }

View file

@@ -1,5 +1,5 @@
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::{
@@ -76,14 +76,14 @@ impl VectorDatabase {
self.db
.execute(
"
DROP TABLE semantic_index_config;
DROP TABLE worktrees;
DROP TABLE files;
DROP TABLE documents;
DROP TABLE IF EXISTS documents;
DROP TABLE IF EXISTS files;
DROP TABLE IF EXISTS worktrees;
DROP TABLE IF EXISTS semantic_index_config;
",
[],
)
.ok();
.context("failed to drop tables")?;
// Initialize Vector Databasing Tables
self.db.execute(

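Two things change in this reset path: the drops become idempotent with `IF EXISTS`, and the order is reversed so that child tables (`documents`, `files`) are dropped before the tables they reference. For multi-statement scripts like this, rusqlite also offers `execute_batch`; a standalone sketch of the same reset (the helper name is hypothetical):

```rust
use rusqlite::Connection;

// Reset the index schema: drop children before parents, tolerate absence.
fn drop_index_tables(db: &Connection) -> rusqlite::Result<()> {
    db.execute_batch(
        "DROP TABLE IF EXISTS documents;
         DROP TABLE IF EXISTS files;
         DROP TABLE IF EXISTS worktrees;
         DROP TABLE IF EXISTS semantic_index_config;",
    )
}
```
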
View file

@@ -86,6 +86,7 @@ impl OpenAIEmbeddings {
async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(4))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
@@ -133,7 +134,11 @@ impl EmbeddingProvider for OpenAIEmbeddings {
self.executor.timer(delay).await;
}
StatusCode::BAD_REQUEST => {
log::info!("BAD REQUEST: {:?}", &response.status());
log::info!(
"BAD REQUEST: {:?} {:?}",
&response.status(),
response.body()
);
// Don't worry about delaying bad request, as we can assume
// we haven't been rate limited yet.
for span in spans.iter_mut() {

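The surrounding match arms encode a simple retry policy: 429 waits out the rate limit before resending, while 400 retries immediately with truncated spans, on the theory that the request was rejected for size rather than frequency. An illustrative classification of that policy (names and the backoff value are invented, not the provider's actual control flow):

```rust
use std::time::Duration;

// How each response class is handled before the next attempt.
enum Retry {
    AfterDelay(Duration), // 429: rate limited, back off first
    TruncateSpans,        // 400: payload likely too large, shrink and resend
    GiveUp,               // anything else unexpected
}

fn classify(status: u16) -> Retry {
    match status {
        429 => Retry::AfterDelay(Duration::from_secs(2)),
        400 => Retry::TruncateSpans,
        _ => Retry::GiveUp,
    }
}
```
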
View file

@@ -1,172 +0,0 @@
use crate::{SearchResult, SemanticIndex};
use editor::{scroll::autoscroll::Autoscroll, Editor};
use gpui::{
actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext,
WeakViewHandle,
};
use picker::{Picker, PickerDelegate, PickerEvent};
use project::{Project, ProjectPath};
use std::{collections::HashMap, sync::Arc, time::Duration};
use util::ResultExt;
use workspace::Workspace;
const MIN_QUERY_LEN: usize = 5;
const EMBEDDING_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(500);
actions!(semantic_search, [Toggle]);
pub type SemanticSearch = Picker<SemanticSearchDelegate>;
pub struct SemanticSearchDelegate {
workspace: WeakViewHandle<Workspace>,
project: ModelHandle<Project>,
semantic_index: ModelHandle<SemanticIndex>,
selected_match_index: usize,
matches: Vec<SearchResult>,
history: HashMap<String, Vec<SearchResult>>,
}
impl SemanticSearchDelegate {
// This is currently searching on every keystroke,
// This is wildly overkill, and has the potential to get expensive
// We will need to update this to throttle searching
pub fn new(
workspace: WeakViewHandle<Workspace>,
project: ModelHandle<Project>,
semantic_index: ModelHandle<SemanticIndex>,
) -> Self {
Self {
workspace,
project,
semantic_index,
selected_match_index: 0,
matches: vec![],
history: HashMap::new(),
}
}
}
impl PickerDelegate for SemanticSearchDelegate {
fn placeholder_text(&self) -> Arc<str> {
"Search repository in natural language...".into()
}
fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
if let Some(search_result) = self.matches.get(self.selected_match_index) {
// Open Buffer
let search_result = search_result.clone();
let buffer = self.project.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: search_result.worktree_id,
path: search_result.file_path.clone().into(),
},
cx,
)
});
let workspace = self.workspace.clone();
let position = search_result.clone().byte_range.start;
cx.spawn(|_, mut cx| async move {
let buffer = buffer.await?;
workspace.update(&mut cx, |workspace, cx| {
let editor = workspace.open_project_item::<Editor>(buffer, cx);
editor.update(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::center()), cx, |s| {
s.select_ranges([position..position])
});
});
})?;
Ok::<_, anyhow::Error>(())
})
.detach_and_log_err(cx);
cx.emit(PickerEvent::Dismiss);
}
}
fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_match_index
}
fn set_selected_index(&mut self, ix: usize, _cx: &mut ViewContext<SemanticSearch>) {
self.selected_match_index = ix;
}
fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
log::info!("Searching for {:?}...", query);
if query.len() < MIN_QUERY_LEN {
log::info!("Query below minimum length");
return Task::ready(());
}
let semantic_index = self.semantic_index.clone();
let project = self.project.clone();
cx.spawn(|this, mut cx| async move {
cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await;
let retrieved_cached = this.update(&mut cx, |this, _| {
let delegate = this.delegate_mut();
if delegate.history.contains_key(&query) {
let historic_results = delegate.history.get(&query).unwrap().to_owned();
delegate.matches = historic_results.clone();
true
} else {
false
}
});
if let Some(retrieved) = retrieved_cached.log_err() {
if !retrieved {
let task = semantic_index.update(&mut cx, |store, cx| {
store.search_project(project.clone(), query.to_string(), 10, cx)
});
if let Some(results) = task.await.log_err() {
log::info!("Not queried previously, searching...");
this.update(&mut cx, |this, _| {
let delegate = this.delegate_mut();
delegate.matches = results.clone();
delegate.history.insert(query, results);
})
.ok();
}
} else {
log::info!("Already queried, retrieved directly from cached history");
}
}
})
}
fn render_match(
&self,
ix: usize,
mouse_state: &mut MouseState,
selected: bool,
cx: &AppContext,
) -> AnyElement<Picker<Self>> {
let theme = theme::current(cx);
let style = &theme.picker.item;
let current_style = style.in_state(selected).style_for(mouse_state);
let search_result = &self.matches[ix];
let path = search_result.file_path.to_string_lossy();
let name = search_result.name.clone();
Flex::column()
.with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
.with_child(Label::new(
path.to_string(),
style.inactive_state().default.label.clone(),
))
.contained()
.with_style(current_style.container)
.into_any()
}
}
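
Before it disappears, note the two throttling tricks the deleted modal combined in `update_matches`: a 500 ms debounce ahead of any embedding call, and a per-query history cache so a repeated query never re-searches. A library-agnostic sketch of that pattern (types and names are hypothetical):

```rust
use std::{collections::HashMap, time::Duration};

// Debounce first, then answer from cache when possible.
struct CachedSearch {
    history: HashMap<String, Vec<String>>,
}

impl CachedSearch {
    async fn update_matches(&mut self, query: String) -> Vec<String> {
        smol::Timer::after(Duration::from_millis(500)).await; // debounce keystrokes
        if let Some(hit) = self.history.get(&query) {
            return hit.clone(); // previously queried: skip the expensive search
        }
        let results = vec![format!("match for {query}")]; // stand-in for search_project
        self.history.insert(query, results.clone());
        results
    }
}
```
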

View file

@@ -1,6 +1,5 @@
mod db;
mod embedding;
mod modal;
mod parsing;
mod semantic_index_settings;
@@ -12,25 +11,20 @@ use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{
AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
WeakModelHandle,
};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
collections::{HashMap, HashSet},
collections::HashMap,
mem,
ops::Range,
path::{Path, PathBuf},
sync::{
atomic::{self, AtomicUsize},
Arc, Weak,
},
time::{Instant, SystemTime},
sync::{Arc, Weak},
time::SystemTime,
};
use util::{
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -38,9 +32,8 @@ use util::{
paths::EMBEDDINGS_DIR,
ResultExt,
};
use workspace::{Workspace, WorkspaceCreated};
const SEMANTIC_INDEX_VERSION: usize = 1;
const SEMANTIC_INDEX_VERSION: usize = 3;
const EMBEDDINGS_BATCH_SIZE: usize = 150;
pub fn init(
@@ -55,25 +48,6 @@ pub fn init(
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db");
SemanticSearch::init(cx);
cx.add_action(
|workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
if cx.has_global::<ModelHandle<SemanticIndex>>() {
let semantic_index = cx.global::<ModelHandle<SemanticIndex>>().clone();
workspace.toggle_modal(cx, |workspace, cx| {
let project = workspace.project().clone();
let workspace = cx.weak_handle();
cx.add_view(|cx| {
SemanticSearch::new(
SemanticSearchDelegate::new(workspace, project, semantic_index),
cx,
)
})
});
}
},
);
if *RELEASE_CHANNEL == ReleaseChannel::Stable
|| !settings::get::<SemanticIndexSettings>(cx).enabled
{
@@ -95,21 +69,6 @@ pub fn init(
cx.update(|cx| {
cx.set_global(semantic_index.clone());
cx.subscribe_global::<WorkspaceCreated, _>({
let semantic_index = semantic_index.clone();
move |event, cx| {
let workspace = &event.0;
if let Some(workspace) = workspace.upgrade(cx) {
let project = workspace.read(cx).project().clone();
if project.read(cx).is_local() {
semantic_index.update(cx, |store, cx| {
store.index_project(project, cx).detach();
});
}
}
}
})
.detach();
});
anyhow::Ok(())
@@ -128,20 +87,17 @@ pub struct SemanticIndex {
_embed_batch_task: Task<()>,
_batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
next_job_id: Arc<AtomicUsize>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
struct ProjectState {
worktree_db_ids: Vec<(WorktreeId, i64)>,
outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
outstanding_job_count_rx: watch::Receiver<usize>,
outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
}
type JobId = usize;
struct JobHandle {
id: JobId,
set: Weak<Mutex<HashSet<JobId>>>,
tx: Weak<Mutex<watch::Sender<usize>>>,
}
impl ProjectState {
@@ -221,6 +177,14 @@ enum EmbeddingJob {
}
impl SemanticIndex {
pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
if cx.has_global::<ModelHandle<Self>>() {
Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
} else {
None
}
}
async fn new(
fs: Arc<dyn Fs>,
database_url: PathBuf,
@@ -236,184 +200,69 @@ impl SemanticIndex {
.await?;
Ok(cx.add_model(|cx| {
// paths_tx -> embeddings_tx -> db_update_tx
//db_update_tx/rx: Updating Database
// Perform database operations
let (db_update_tx, db_update_rx) = channel::unbounded();
let _db_update_task = cx.background().spawn(async move {
while let Ok(job) = db_update_rx.recv().await {
match job {
DbOperation::InsertFile {
worktree_id,
documents,
path,
mtime,
job_handle,
} => {
db.insert_file(worktree_id, path, mtime, documents)
.log_err();
drop(job_handle)
}
DbOperation::Delete { worktree_id, path } => {
db.delete_file(worktree_id, path).log_err();
}
DbOperation::FindOrCreateWorktree { path, sender } => {
let id = db.find_or_create_worktree(&path);
sender.send(id).ok();
}
DbOperation::FileMTimes {
worktree_id: worktree_db_id,
sender,
} => {
let file_mtimes = db.get_file_mtimes(worktree_db_id);
sender.send(file_mtimes).ok();
}
let _db_update_task = cx.background().spawn({
async move {
while let Ok(job) = db_update_rx.recv().await {
Self::run_db_operation(&db, job)
}
}
});
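
The pipeline comment above names the whole flow: paths are parsed into documents, documents are batched into embedding requests, and finished embeddings are written to the database, with each stage a background task bridged by an unbounded channel. A minimal sketch of one such stage (simplified to `String` items, not the real job types):

```rust
use smol::channel;

// Stage N drains its inbox and forwards transformed work to stage N+1.
fn spawn_stage(
    rx: channel::Receiver<String>,
    tx: channel::Sender<String>,
) -> smol::Task<()> {
    smol::spawn(async move {
        while let Ok(item) = rx.recv().await {
            // ...parse, batch, or embed the item here...
            tx.send(item).await.ok();
        }
    })
}
```
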
// embed_tx/rx: Embed Batch and Send to Database
// Group documents into batches and send them to the embedding provider.
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
let _embed_batch_task = cx.background().spawn({
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
async move {
while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
// Construct Batch
let mut batch_documents = vec![];
for (_, documents, _, _, _) in embeddings_queue.iter() {
batch_documents
.extend(documents.iter().map(|document| document.content.as_str()));
}
if let Ok(embeddings) =
embedding_provider.embed_batch(batch_documents).await
{
log::trace!(
"created {} embeddings for {} files",
embeddings.len(),
embeddings_queue.len(),
);
let mut i = 0;
let mut j = 0;
for embedding in embeddings.iter() {
while embeddings_queue[i].1.len() == j {
i += 1;
j = 0;
}
embeddings_queue[i].1[j].embedding = embedding.to_owned();
j += 1;
}
for (worktree_id, documents, path, mtime, job_handle) in
embeddings_queue.into_iter()
{
for document in documents.iter() {
// TODO: Update this so it doesn't panic
assert!(
document.embedding.len() > 0,
"Document Embedding Not Complete"
);
}
db_update_tx
.send(DbOperation::InsertFile {
worktree_id,
documents,
path,
mtime,
job_handle,
})
.await
.unwrap();
}
}
while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
Self::compute_embeddings_for_batch(
embeddings_queue,
&embedding_provider,
&db_update_tx,
)
.await;
}
}
});
// batch_tx/rx: Batch Files to Send for Embeddings
// Group documents into batches and send them to the embedding provider.
let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
let _batch_files_task = cx.background().spawn(async move {
let mut queue_len = 0;
let mut embeddings_queue = vec![];
while let Ok(job) = batch_files_rx.recv().await {
let should_flush = match job {
EmbeddingJob::Enqueue {
documents,
worktree_id,
path,
mtime,
job_handle,
} => {
queue_len += &documents.len();
embeddings_queue.push((
worktree_id,
documents,
path,
mtime,
job_handle,
));
queue_len >= EMBEDDINGS_BATCH_SIZE
}
EmbeddingJob::Flush => true,
};
if should_flush {
embed_batch_tx.try_send(embeddings_queue).unwrap();
embeddings_queue = vec![];
queue_len = 0;
}
Self::enqueue_documents_to_embed(
job,
&mut queue_len,
&mut embeddings_queue,
&embed_batch_tx,
);
}
});
// parsing_files_tx/rx: Parsing Files to Embeddable Documents
// Parse files into embeddable documents.
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
let mut _parsing_files_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
let fs = fs.clone();
let parsing_files_rx = parsing_files_rx.clone();
let batch_files_tx = batch_files_tx.clone();
let db_update_tx = db_update_tx.clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let mut retriever = CodeContextRetriever::new();
while let Ok(pending_file) = parsing_files_rx.recv().await {
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
{
if let Some(documents) = retriever
.parse_file(
&pending_file.relative_path,
&content,
pending_file.language,
)
.log_err()
{
log::trace!(
"parsed path {:?}: {} documents",
pending_file.relative_path,
documents.len()
);
batch_files_tx
.try_send(EmbeddingJob::Enqueue {
worktree_id: pending_file.worktree_db_id,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
documents,
})
.unwrap();
}
}
if parsing_files_rx.len() == 0 {
batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
}
Self::parse_file(
&fs,
pending_file,
&mut retriever,
&batch_files_tx,
&parsing_files_rx,
&db_update_tx,
)
.await;
}
}));
}
@@ -424,7 +273,6 @@ impl SemanticIndex {
embedding_provider,
language_registry,
db_update_tx,
next_job_id: Default::default(),
parsing_files_tx,
_db_update_task,
_embed_batch_task,
@@ -435,6 +283,167 @@ impl SemanticIndex {
}))
}
fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
match job {
DbOperation::InsertFile {
worktree_id,
documents,
path,
mtime,
job_handle,
} => {
db.insert_file(worktree_id, path, mtime, documents)
.log_err();
drop(job_handle)
}
DbOperation::Delete { worktree_id, path } => {
db.delete_file(worktree_id, path).log_err();
}
DbOperation::FindOrCreateWorktree { path, sender } => {
let id = db.find_or_create_worktree(&path);
sender.send(id).ok();
}
DbOperation::FileMTimes {
worktree_id: worktree_db_id,
sender,
} => {
let file_mtimes = db.get_file_mtimes(worktree_db_id);
sender.send(file_mtimes).ok();
}
}
}
async fn compute_embeddings_for_batch(
mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
embedding_provider: &Arc<dyn EmbeddingProvider>,
db_update_tx: &channel::Sender<DbOperation>,
) {
let mut batch_documents = vec![];
for (_, documents, _, _, _) in embeddings_queue.iter() {
batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
}
if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
log::trace!(
"created {} embeddings for {} files",
embeddings.len(),
embeddings_queue.len(),
);
let mut i = 0;
let mut j = 0;
for embedding in embeddings.iter() {
while embeddings_queue[i].1.len() == j {
i += 1;
j = 0;
}
embeddings_queue[i].1[j].embedding = embedding.to_owned();
j += 1;
}
for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
// for document in documents.iter() {
// // TODO: Update this so it doesn't panic
// assert!(
// document.embedding.len() > 0,
// "Document Embedding Not Complete"
// );
// }
db_update_tx
.send(DbOperation::InsertFile {
worktree_id,
documents,
path,
mtime,
job_handle,
})
.await
.unwrap();
}
}
}
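
The `i`/`j` walk deserves a note: the provider returns one flat list of embeddings for the entire batch, and the loop re-slots them into each file's document list in order, skipping any file whose documents are exhausted (including empty ones). An equivalent standalone version, with scalars standing in for embedding vectors:

```rust
// Re-distribute a flat batch of results onto per-file document slots.
fn redistribute(flat: Vec<f32>, mut per_file: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
    let (mut i, mut j) = (0, 0);
    for embedding in flat {
        while per_file[i].len() == j {
            i += 1; // this file's documents are filled; move to the next file
            j = 0;
        }
        per_file[i][j] = embedding;
        j += 1;
    }
    per_file
}

#[test]
fn flat_batch_respects_file_boundaries() {
    // Two documents in file 0, one document in file 1.
    let out = redistribute(vec![1.0, 2.0, 3.0], vec![vec![0.0, 0.0], vec![0.0]]);
    assert_eq!(out, vec![vec![1.0, 2.0], vec![3.0]]);
}
```
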
fn enqueue_documents_to_embed(
job: EmbeddingJob,
queue_len: &mut usize,
embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
) {
let should_flush = match job {
EmbeddingJob::Enqueue {
documents,
worktree_id,
path,
mtime,
job_handle,
} => {
*queue_len += &documents.len();
embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
*queue_len >= EMBEDDINGS_BATCH_SIZE
}
EmbeddingJob::Flush => true,
};
if should_flush {
embed_batch_tx
.try_send(mem::take(embeddings_queue))
.unwrap();
*queue_len = 0;
}
}
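
`mem::take` does the heavy lifting in the flush: it moves the accumulated batch out of the shared buffer and leaves a fresh empty `Vec` behind, so the batch crosses the channel without a clone. In isolation:

```rust
fn main() {
    let mut queue = vec!["doc1", "doc2"];
    let batch = std::mem::take(&mut queue); // queue is left empty and reusable
    assert!(queue.is_empty());
    assert_eq!(batch.len(), 2);
}
```
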
async fn parse_file(
fs: &Arc<dyn Fs>,
pending_file: PendingFile,
retriever: &mut CodeContextRetriever,
batch_files_tx: &channel::Sender<EmbeddingJob>,
parsing_files_rx: &channel::Receiver<PendingFile>,
db_update_tx: &channel::Sender<DbOperation>,
) {
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
if let Some(documents) = retriever
.parse_file(&pending_file.relative_path, &content, pending_file.language)
.log_err()
{
log::trace!(
"parsed path {:?}: {} documents",
pending_file.relative_path,
documents.len()
);
if documents.len() == 0 {
db_update_tx
.send(DbOperation::InsertFile {
worktree_id: pending_file.worktree_db_id,
documents,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
})
.await
.unwrap();
} else {
batch_files_tx
.try_send(EmbeddingJob::Enqueue {
worktree_id: pending_file.worktree_db_id,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
documents,
})
.unwrap();
}
}
}
if parsing_files_rx.len() == 0 {
batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
}
}
fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
let (tx, rx) = oneshot::channel();
self.db_update_tx
@@ -457,11 +466,11 @@ impl SemanticIndex {
async move { rx.await? }
}
fn index_project(
pub fn index_project(
&mut self,
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<usize>> {
) -> Task<Result<(usize, watch::Receiver<usize>)>> {
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
@@ -483,7 +492,6 @@ impl SemanticIndex {
let language_registry = self.language_registry.clone();
let db_update_tx = self.db_update_tx.clone();
let parsing_files_tx = self.parsing_files_tx.clone();
let next_job_id = self.next_job_id.clone();
cx.spawn(|this, mut cx| async move {
futures::future::join_all(worktree_scans_complete).await;
@@ -509,8 +517,8 @@ impl SemanticIndex {
);
}
// let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
let (job_count_tx, job_count_rx) = watch::channel_with(0);
let job_count_tx = Arc::new(Mutex::new(job_count_tx));
this.update(&mut cx, |this, _| {
this.projects.insert(
project.downgrade(),
@@ -519,7 +527,8 @@ impl SemanticIndex {
.iter()
.map(|(a, b)| (*a, *b))
.collect(),
outstanding_jobs: outstanding_jobs.clone(),
outstanding_job_count_rx: job_count_rx.clone(),
outstanding_job_count_tx: job_count_tx.clone(),
},
);
});
@@ -527,7 +536,6 @@ impl SemanticIndex {
cx.background()
.spawn(async move {
let mut count = 0;
let t0 = Instant::now();
for worktree in worktrees.into_iter() {
let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
for file in worktree.files(false, 0) {
@@ -552,14 +560,11 @@ impl SemanticIndex {
.map_or(false, |existing_mtime| existing_mtime == file.mtime);
if !already_stored {
log::trace!("sending for parsing: {:?}", path_buf);
count += 1;
let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
*job_count_tx.lock().borrow_mut() += 1;
let job_handle = JobHandle {
id: job_id,
set: Arc::downgrade(&outstanding_jobs),
tx: Arc::downgrade(&job_count_tx),
};
outstanding_jobs.lock().insert(job_id);
parsing_files_tx
.try_send(PendingFile {
worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
@@ -582,27 +587,22 @@ impl SemanticIndex {
.unwrap();
}
}
log::trace!(
"parsing worktree completed in {:?}",
t0.elapsed().as_millis()
);
Ok(count)
anyhow::Ok((count, job_count_rx))
})
.await
})
}
pub fn remaining_files_to_index_for_project(
pub fn outstanding_job_count_rx(
&self,
project: &ModelHandle<Project>,
) -> Option<usize> {
) -> Option<watch::Receiver<usize>> {
Some(
self.projects
.get(&project.downgrade())?
.outstanding_jobs
.lock()
.len(),
.outstanding_job_count_rx
.clone(),
)
}
@@ -678,8 +678,9 @@ impl Entity for SemanticIndex {
impl Drop for JobHandle {
fn drop(&mut self) {
if let Some(set) = self.set.upgrade() {
set.lock().remove(&self.id);
if let Some(tx) = self.tx.upgrade() {
let mut tx = tx.lock();
*tx.borrow_mut() -= 1;
}
}
}
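
This drop-driven counter is what replaces the old `HashSet<JobId>` bookkeeping: the count goes up when a job is enqueued, and `Drop` guarantees exactly one decrement per handle no matter how the job ends, with the `Weak` upgrade keeping it safe for handles to outlive the index. A minimal model of the pattern (`enqueue_job` is hypothetical):

```rust
use parking_lot::Mutex;
use postage::watch;
use std::sync::{Arc, Weak};

struct JobHandle {
    tx: Weak<Mutex<watch::Sender<usize>>>,
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        // If the index itself is gone, there is nobody left to notify.
        if let Some(tx) = self.tx.upgrade() {
            *tx.lock().borrow_mut() -= 1;
        }
    }
}

fn enqueue_job(tx: &Arc<Mutex<watch::Sender<usize>>>) -> JobHandle {
    *tx.lock().borrow_mut() += 1; // one more outstanding file
    JobHandle { tx: Arc::downgrade(tx) }
}
```
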

View file

@@ -88,18 +88,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let worktree_id = project.read_with(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
});
let file_count = store
let (file_count, outstanding_file_count) = store
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
assert_eq!(file_count, 3);
cx.foreground().run_until_parked();
store.update(cx, |store, _cx| {
assert_eq!(
store.remaining_files_to_index_for_project(&project),
Some(0)
);
});
assert_eq!(*outstanding_file_count.borrow(), 0);
let search_results = store
.update(cx, |store, cx| {
@@ -128,19 +123,14 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
cx.foreground().run_until_parked();
let prev_embedding_count = embedding_provider.embedding_count();
let file_count = store
let (file_count, outstanding_file_count) = store
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
assert_eq!(file_count, 1);
cx.foreground().run_until_parked();
store.update(cx, |store, _cx| {
assert_eq!(
store.remaining_files_to_index_for_project(&project),
Some(0)
);
});
assert_eq!(*outstanding_file_count.borrow(), 0);
assert_eq!(
embedding_provider.embedding_count() - prev_embedding_count,