From 307d8d9c8d26ecaf4ecd2a3bddf58ec00be7a666 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Mon, 10 Jul 2023 17:50:19 -0400
Subject: [PATCH] Reduced redundant database connections on each worktree
 change.

Co-authored-by: maxbrunsfeld
---
 crates/vector_store/src/db.rs           |  75 +++++--
 crates/vector_store/src/vector_store.rs | 280 ++++++++++--------------
 2 files changed, 177 insertions(+), 178 deletions(-)

diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs
index 4882db443b..197e7d5696 100644
--- a/crates/vector_store/src/db.rs
+++ b/crates/vector_store/src/db.rs
@@ -1,4 +1,5 @@
 use std::{
+    cmp::Ordering,
     collections::HashMap,
     path::{Path, PathBuf},
     rc::Rc,
@@ -14,16 +15,6 @@ use rusqlite::{
     types::{FromSql, FromSqlResult, ValueRef},
 };
 
-// Note this is not an appropriate document
-#[derive(Debug)]
-pub struct DocumentRecord {
-    pub id: usize,
-    pub file_id: usize,
-    pub offset: usize,
-    pub name: String,
-    pub embedding: Embedding,
-}
-
 #[derive(Debug)]
 pub struct FileRecord {
     pub id: usize,
@@ -32,7 +23,7 @@ pub struct FileRecord {
 }
 
 #[derive(Debug)]
-pub struct Embedding(pub Vec<f32>);
+struct Embedding(pub Vec<f32>);
 
 impl FromSql for Embedding {
     fn column_result(value: ValueRef) -> FromSqlResult<Self> {
@@ -205,10 +196,33 @@ impl VectorDatabase {
         Ok(result)
     }
 
-    pub fn for_each_document(
+    pub fn top_k_search(
         &self,
         worktree_ids: &[i64],
-        mut f: impl FnMut(i64, Embedding),
+        query_embedding: &Vec<f32>,
+        limit: usize,
+    ) -> Result<Vec<(i64, PathBuf, usize, String)>> {
+        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+        self.for_each_document(&worktree_ids, |id, embedding| {
+            let similarity = dot(&embedding, &query_embedding);
+            let ix = match results
+                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
+            {
+                Ok(ix) => ix,
+                Err(ix) => ix,
+            };
+            results.insert(ix, (id, similarity));
+            results.truncate(limit);
+        })?;
+
+        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+        self.get_documents_by_ids(&ids)
+    }
+
+    fn for_each_document(
+        &self,
+        worktree_ids: &[i64],
+        mut f: impl FnMut(i64, Vec<f32>),
     ) -> Result<()> {
         let mut query_statement = self.db.prepare(
             "
@@ -221,16 +235,17 @@ impl VectorDatabase {
                 files.worktree_id IN rarray(?)
             ",
         )?;
+
         query_statement
             .query_map(params![ids_to_sql(worktree_ids)], |row| {
-                Ok((row.get(0)?, row.get(1)?))
+                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|row| f(row.0, row.1));
+            .for_each(|(id, embedding)| f(id, embedding.0));
         Ok(())
     }
 
-    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
+    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
         let mut statement = self.db.prepare(
             "
             SELECT
@@ -279,3 +294,29 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
             .collect::<Vec<_>>(),
     )
 }
+
+pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
+    let len = vec_a.len();
+    assert_eq!(len, vec_b.len());
+
+    let mut result = 0.0;
+    unsafe {
+        matrixmultiply::sgemm(
+            1,
+            len,
+            1,
+            1.0,
+            vec_a.as_ptr(),
+            len as isize,
+            1,
+            vec_b.as_ptr(),
+            1,
+            len as isize,
+            0.0,
+            &mut result as *mut f32,
+            1,
+            1,
+        );
+    }
+    result
+}
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index c27c4992f3..c42b7ab129 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -20,7 +20,6 @@ use parsing::{CodeContextRetriever, ParsedFile};
 use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
 use smol::channel;
 use std::{
-    cmp::Ordering,
     collections::HashMap,
     path::{Path, PathBuf},
     sync::Arc,
@@ -112,10 +111,10 @@ pub struct VectorStore {
     database_url: Arc<PathBuf>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
-    db_update_tx: channel::Sender<DbWrite>,
+    db_update_tx: channel::Sender<DbOperation>,
     parsing_files_tx: channel::Sender<PendingFile>,
     _db_update_task: Task<()>,
-    _embed_batch_task: Vec<Task<()>>,
+    _embed_batch_task: Task<()>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@@ -128,6 +127,30 @@ struct ProjectState {
 }
 
 impl ProjectState {
+    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
+        self.worktree_db_ids
+            .iter()
+            .find_map(|(worktree_id, db_id)| {
+                if *worktree_id == id {
+                    Some(*db_id)
+                } else {
+                    None
+                }
+            })
+    }
+
+    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
+        self.worktree_db_ids
+            .iter()
+            .find_map(|(worktree_id, db_id)| {
+                if *db_id == id {
+                    Some(*worktree_id)
+                } else {
+                    None
+                }
+            })
+    }
+
     fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
         // If Pending File Already Exists, Replace it with the new one
         // but keep the old indexing time
@@ -185,7 +208,7 @@ pub struct SearchResult {
     pub file_path: PathBuf,
 }
 
-enum DbWrite {
+enum DbOperation {
     InsertFile {
         worktree_id: i64,
         indexed_file: ParsedFile,
@@ -198,6 +221,10 @@ enum DbWrite {
         path: PathBuf,
         sender: oneshot::Sender<Result<i64>>,
     },
+    FileMTimes {
+        worktree_id: i64,
+        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
+    },
 }
 
 enum EmbeddingJob {
@@ -243,20 +270,27 @@ impl VectorStore {
         let _db_update_task = cx.background().spawn(async move {
             while let Ok(job) = db_update_rx.recv().await {
                 match job {
-                    DbWrite::InsertFile {
+                    DbOperation::InsertFile {
                         worktree_id,
                         indexed_file,
                     } => {
                         log::info!("Inserting Data for {:?}", &indexed_file.path);
                         db.insert_file(worktree_id, indexed_file).log_err();
                     }
-                    DbWrite::Delete { worktree_id, path } => {
+                    DbOperation::Delete { worktree_id, path } => {
                         db.delete_file(worktree_id, path).log_err();
                     }
-                    DbWrite::FindOrCreateWorktree { path, sender } => {
+                    DbOperation::FindOrCreateWorktree { path, sender } => {
                         let id = db.find_or_create_worktree(&path);
                         sender.send(id).ok();
                     }
+                    DbOperation::FileMTimes {
+                        worktree_id: worktree_db_id,
+                        sender,
+                    } => {
+                        let file_mtimes = db.get_file_mtimes(worktree_db_id);
+                        sender.send(file_mtimes).ok();
+                    }
                 }
             }
         });
@@ -264,24 +298,18 @@ impl VectorStore {
 
         // embed_tx/rx: Embed Batch and Send to Database
         let (embed_batch_tx, embed_batch_rx) = channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
-        let mut _embed_batch_task = Vec::new();
-        for _ in 0..1 {
-            //cx.background().num_cpus() {
+        let _embed_batch_task = cx.background().spawn({
             let db_update_tx = db_update_tx.clone();
-            let embed_batch_rx = embed_batch_rx.clone();
             let embedding_provider = embedding_provider.clone();
-            _embed_batch_task.push(cx.background().spawn(async move {
-                while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+            async move {
+                while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                     // Construct Batch
-                    let mut embeddings_queue = embeddings_queue.clone();
                     let mut document_spans = vec![];
-                    for (_, _, document_span) in embeddings_queue.clone().into_iter() {
-                        document_spans.extend(document_span);
+                    for (_, _, document_span) in embeddings_queue.iter() {
+                        document_spans.extend(document_span.iter().map(|s| s.as_str()));
                     }
 
-                    if let Ok(embeddings) = embedding_provider
-                        .embed_batch(document_spans.iter().map(|x| &**x).collect())
-                        .await
+                    if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
                     {
                         let mut i = 0;
                         let mut j = 0;
@@ -306,7 +334,7 @@ impl VectorStore {
                         }
 
                         db_update_tx
-                            .send(DbWrite::InsertFile {
+                            .send(DbOperation::InsertFile {
                                 worktree_id,
                                 indexed_file,
                             })
@@ -315,8 +343,9 @@ impl VectorStore {
                     }
                 }
             }
-        }))
-    }
+            }
+        });
+
         // batch_tx/rx: Batch Files to Send for Embeddings
         let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
         let _batch_files_task = cx.background().spawn(async move {
@@ -398,7 +427,21 @@ impl VectorStore {
     fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
         let (tx, rx) = oneshot::channel();
         self.db_update_tx
-            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
+            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
             .unwrap();
         async move { rx.await? }
     }
+
+    fn get_file_mtimes(
+        &self,
+        worktree_id: i64,
+    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
+        let (tx, rx) = oneshot::channel();
+        self.db_update_tx
+            .try_send(DbOperation::FileMTimes {
+                worktree_id,
+                sender: tx,
+            })
+            .unwrap();
+        async move { rx.await? }
+    }
@@ -450,26 +493,17 @@ impl VectorStore {
                 .collect::<Vec<_>>()
         });
 
-        // Here we query the worktree ids, and yet we dont have them elsewhere
-        // We likely want to clean up these datastructures
-        let (mut worktree_file_times, db_ids_by_worktree_id) = cx
-            .background()
-            .spawn({
-                let worktrees = worktrees.clone();
-                async move {
-                    let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
-                    let mut db_ids_by_worktree_id = HashMap::new();
-                    let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
-                        HashMap::new();
-                    for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
-                        let db_id = db_id?;
-                        db_ids_by_worktree_id.insert(worktree.id(), db_id);
-                        file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
-                    }
-                    anyhow::Ok((file_times, db_ids_by_worktree_id))
-                }
-            })
-            .await?;
+        let mut worktree_file_times = HashMap::new();
+        let mut db_ids_by_worktree_id = HashMap::new();
+        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
+            let db_id = db_id?;
+            db_ids_by_worktree_id.insert(worktree.id(), db_id);
+            worktree_file_times.insert(
+                worktree.id(),
+                this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+                    .await?,
+            );
+        }
 
         cx.background()
             .spawn({
@@ -520,7 +554,7 @@ impl VectorStore {
                 }
                 for file in file_mtimes.keys() {
                     db_update_tx
-                        .try_send(DbWrite::Delete {
+                        .try_send(DbOperation::Delete {
                             worktree_id: db_ids_by_worktree_id[&worktree.id()],
                             path: file.to_owned(),
                         })
@@ -542,7 +576,7 @@ impl VectorStore {
         // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
         let _subscription = cx.subscribe(&project, |this, project, event, cx| {
             if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
-                this.project_entries_changed(project, changes, cx, worktree_id);
+                this.project_entries_changed(project, changes.clone(), cx, worktree_id);
             }
         });
 
@@ -578,16 +612,7 @@ impl VectorStore {
             .worktrees(cx)
             .filter_map(|worktree| {
                 let worktree_id = worktree.read(cx).id();
-                project_state
-                    .worktree_db_ids
-                    .iter()
-                    .find_map(|(id, db_id)| {
-                        if *id == worktree_id {
-                            Some(*db_id)
-                        } else {
-                            None
-                        }
-                    })
+                project_state.db_id_for_worktree_id(worktree_id)
             })
             .collect::<Vec<_>>();
 
@@ -606,21 +631,7 @@ impl VectorStore {
                     .next()
                     .unwrap();
 
-                let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-                database.for_each_document(&worktree_db_ids, |id, embedding| {
-                    let similarity = dot(&embedding.0, &phrase_embedding);
-                    let ix = match results.binary_search_by(|(_, s)| {
-                        similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                    }) {
-                        Ok(ix) => ix,
-                        Err(ix) => ix,
-                    };
-                    results.insert(ix, (id, similarity));
-                    results.truncate(limit);
-                })?;
-
-                let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-                database.get_documents_by_ids(&ids)
+                database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
             })
             .await?;
 
@@ -634,17 +645,7 @@ impl VectorStore {
             Ok(documents
                 .into_iter()
                 .filter_map(|(worktree_db_id, file_path, offset, name)| {
-                    let worktree_id =
-                        project_state
-                            .worktree_db_ids
-                            .iter()
-                            .find_map(|(id, db_id)| {
-                                if *db_id == worktree_db_id {
-                                    Some(*id)
-                                } else {
-                                    None
-                                }
-                            })?;
+                    let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
                     Some(SearchResult {
                         worktree_id,
                         name,
@@ -660,51 +661,36 @@ impl VectorStore {
     fn project_entries_changed(
         &mut self,
         project: ModelHandle<Project>,
-        changes: &[(Arc<Path>, ProjectEntryId, PathChange)],
+        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
         cx: &mut ModelContext<'_, VectorStore>,
         worktree_id: &WorktreeId,
     ) -> Option<()> {
-        let project_state = self.projects.get_mut(&project.downgrade())?;
-        let worktree_db_ids = project_state.worktree_db_ids.clone();
-        let worktree = project.read(cx).worktree_for_id(worktree_id.clone(), cx)?;
+        let worktree = project
+            .read(cx)
+            .worktree_for_id(worktree_id.clone(), cx)?
+            .read(cx)
+            .snapshot();
 
-        // Get Database
-        let (file_mtimes, worktree_db_id) = {
-            if let Ok(db) = VectorDatabase::new(self.database_url.to_string_lossy().into()) {
-                let worktree_db_id = {
-                    let mut found_db_id = None;
-                    for (w_id, db_id) in worktree_db_ids.into_iter() {
-                        if &w_id == &worktree.read(cx).id() {
-                            found_db_id = Some(db_id)
-                        }
-                    }
-                    found_db_id
-                }?;
+        let worktree_db_id = self
+            .projects
+            .get(&project.downgrade())?
+            .db_id_for_worktree_id(worktree.id())?;
+        let file_mtimes = self.get_file_mtimes(worktree_db_id);
 
-                let file_mtimes = db.get_file_mtimes(worktree_db_id).log_err()?;
-
-                Some((file_mtimes, worktree_db_id))
-            } else {
-                return None;
-            }
-        }?;
-
-        // Iterate Through Changes
         let language_registry = self.language_registry.clone();
-        let parsing_files_tx = self.parsing_files_tx.clone();
 
-        smol::block_on(async move {
+        cx.spawn(|this, mut cx| async move {
+            let file_mtimes = file_mtimes.await.log_err()?;
+
             for change in changes.into_iter() {
                 let change_path = change.0.clone();
-                let absolute_path = worktree.read(cx).absolutize(&change_path);
+                let absolute_path = worktree.absolutize(&change_path);
 
                 // Skip if git ignored or symlink
-                if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
-                    if entry.is_ignored || entry.is_symlink {
+                if let Some(entry) = worktree.entry_for_id(change.1) {
+                    if entry.is_ignored || entry.is_symlink || entry.is_external {
                         continue;
-                    } else {
-                        log::info!("Testing for Reindexing: {:?}", &change_path);
                     }
-                };
+                }
 
                 if let Ok(language) = language_registry
                     .language_for_file(&change_path.to_path_buf(), None)
@@ -718,27 +704,18 @@ impl VectorStore {
                         continue;
                     }
 
-                if let Some(modified_time) = {
-                    let metadata = change_path.metadata();
-                    if metadata.is_err() {
-                        None
-                    } else {
-                        let mtime = metadata.unwrap().modified();
-                        if mtime.is_err() {
-                            None
-                        } else {
-                            Some(mtime.unwrap())
-                        }
-                    }
-                } {
-                    let existing_time = file_mtimes.get(&change_path.to_path_buf());
-                    let already_stored = existing_time
-                        .map_or(false, |existing_time| &modified_time != existing_time);
+                let modified_time = change_path.metadata().log_err()?.modified().log_err()?;
 
-                    let reindex_time =
-                        modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
+                let existing_time = file_mtimes.get(&change_path.to_path_buf());
+                let already_stored = existing_time
+                    .map_or(false, |existing_time| &modified_time == existing_time);
 
-                    if !already_stored {
+                if !already_stored {
+                    this.update(&mut cx, |this, _| {
+                        let reindex_time =
+                            modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
+
+                        let project_state = this.projects.get_mut(&project.downgrade())?;
                         project_state.update_pending_files(
                             PendingFile {
                                 relative_path: change_path.to_path_buf(),
@@ -751,13 +728,18 @@ impl VectorStore {
                         );
 
                         for file in project_state.get_outstanding_files() {
-                            parsing_files_tx.try_send(file).unwrap();
+                            this.parsing_files_tx.try_send(file).unwrap();
                         }
-                    }
+                        Some(())
+                    });
                 }
             }
-        });
+
+            Some(())
+        })
+        .detach();
+
         Some(())
     }
 }
@@ -765,29 +747,3 @@
 
 impl Entity for VectorStore {
     type Event = ();
 }
-
-fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
-    let len = vec_a.len();
-    assert_eq!(len, vec_b.len());
-
-    let mut result = 0.0;
-    unsafe {
-        matrixmultiply::sgemm(
-            1,
-            len,
-            1,
-            1.0,
-            vec_a.as_ptr(),
-            len as isize,
-            1,
-            vec_b.as_ptr(),
-            1,
-            len as isize,
-            0.0,
-            &mut result as *mut f32,
-            1,
-            1,
-        );
-    }
-    result
-}
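
Note on the overall shape of this change: rather than opening a fresh VectorDatabase connection on every worktree event, all reads and writes now flow through the single background task that owns the connection, expressed as DbOperation messages carrying a oneshot reply sender (FileMTimes above, answered via get_file_mtimes). Below is a minimal, self-contained sketch of that request/reply ownership pattern, assuming the smol and futures crates; the Operation enum and the HashMap standing in for the database are illustrative, not code from this patch:

    use futures::channel::oneshot;
    use std::collections::HashMap;

    // Stand-in for the patch's DbOperation: each request carries a oneshot
    // sender on which the owning task posts the reply.
    enum Operation {
        Get {
            key: String,
            reply: oneshot::Sender<Option<String>>,
        },
    }

    fn main() {
        let (op_tx, op_rx) = smol::channel::unbounded::<Operation>();

        // The single task that owns the state; nothing else touches it.
        let owner = smol::spawn(async move {
            let mut store: HashMap<String, String> = HashMap::new();
            store.insert("a".into(), "1".into());
            while let Ok(op) = op_rx.recv().await {
                match op {
                    Operation::Get { key, reply } => {
                        reply.send(store.get(&key).cloned()).ok();
                    }
                }
            }
        });

        smol::block_on(async move {
            let (tx, rx) = oneshot::channel();
            op_tx
                .try_send(Operation::Get { key: "a".into(), reply: tx })
                .unwrap();
            // The caller awaits the reply without ever owning the state.
            assert_eq!(rx.await.unwrap(), Some("1".into()));
            drop(op_tx); // closing the channel lets the owner task exit
            owner.await;
        });
    }

Because only the owning task ever touches the state, no locking is needed and operations are serialized in channel order.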
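Note on top_k_search in db.rs: it keeps a running buffer of the `limit` best documents by binary-searching for each new similarity's slot in a descending-sorted Vec and truncating after every insert, so memory stays at O(limit) while every stored embedding is scanned once. The same pattern in isolation (the `top_k` helper below is a sketch for illustration, not part of the patch):

    use std::cmp::Ordering;

    // Keep the `limit` highest scores seen so far, best first.
    fn top_k(scores: impl IntoIterator<Item = (i64, f32)>, limit: usize) -> Vec<(i64, f32)> {
        let mut results: Vec<(i64, f32)> = Vec::with_capacity(limit + 1);
        for (id, score) in scores {
            // Comparing the candidate score against each element (instead of
            // the other way around) makes the descending buffer look ascending
            // to binary_search_by; Ok or Err, the index keeps it sorted.
            let ix = match results
                .binary_search_by(|(_, s)| score.partial_cmp(s).unwrap_or(Ordering::Equal))
            {
                Ok(ix) | Err(ix) => ix,
            };
            results.insert(ix, (id, score));
            results.truncate(limit);
        }
        results
    }

    fn main() {
        let scored = vec![(1, 0.2), (2, 0.9), (3, 0.5), (4, 0.7)];
        assert_eq!(top_k(scored, 2), vec![(2, 0.9), (4, 0.7)]);
    }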
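Note on the dot helper moved into db.rs: it computes a dot product as a 1 x len by len x 1 matrix multiply through matrixmultiply::sgemm, presumably to pick up that crate's SIMD-optimized kernels on long embedding vectors. A plain safe-Rust equivalent of the same math, for reference:

    // Safe-Rust equivalent of the sgemm-based dot product above.
    fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
        assert_eq!(vec_a.len(), vec_b.len());
        vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
    }

    fn main() {
        assert_eq!(dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);
    }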