From 4f8b95cf0d99955555b6b086bed7c3153cd5bc92 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 29 Aug 2023 15:44:51 -0400 Subject: [PATCH 01/20] add proper handling for open ai rate limit delays --- Cargo.lock | 65 ++++++++++++++++- crates/semantic_index/Cargo.toml | 1 + crates/semantic_index/src/embedding.rs | 96 ++++++++++++++++---------- 3 files changed, 124 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 347976691d..e0eb1947e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3532,7 +3532,7 @@ dependencies = [ "gif", "jpeg-decoder", "num-iter", - "num-rational", + "num-rational 0.3.2", "num-traits", "png", "scoped_threadpool", @@ -4625,6 +4625,31 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "num" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" +dependencies = [ + "num-bigint 0.2.6", + "num-complex", + "num-integer", + "num-iter", + "num-rational 0.2.4", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.4" @@ -4653,6 +4678,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -4685,6 +4720,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef" +dependencies = [ + "autocfg", + "num-bigint 0.2.6", + "num-integer", + "num-traits", +] + [[package]] name = "num-rational" version = "0.3.2" @@ -5001,6 +5048,17 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "parse_duration" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d" +dependencies = [ + "lazy_static", + "num", + "regex", +] + [[package]] name = "password-hash" version = "0.2.3" @@ -6667,6 +6725,7 @@ dependencies = [ "log", "matrixmultiply", "parking_lot 0.11.2", + "parse_duration", "picker", "postage", "pretty_assertions", @@ -6998,7 +7057,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80" dependencies = [ "chrono", - "num-bigint", + "num-bigint 0.4.4", "num-traits", "thiserror", ] @@ -7230,7 +7289,7 @@ dependencies = [ "log", "md-5", "memchr", - "num-bigint", + "num-bigint 0.4.4", "once_cell", "paste", "percent-encoding", diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 4e817fcbe2..d46346e0ab 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -39,6 +39,7 @@ rand.workspace = true schemars.workspace = true globset.workspace = true sha1 = "0.10.5" +parse_duration = "2.1.1" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 
f2269a786a..a9cb0245c4 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -7,6 +7,7 @@ use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; +use parse_duration::parse; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; @@ -84,10 +85,15 @@ impl OpenAIEmbeddings { span } - async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result> { + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { let request = Request::post("https://api.openai.com/v1/embeddings") .redirect_policy(isahc::config::RedirectPolicy::Follow) - .timeout(Duration::from_secs(4)) + .timeout(Duration::from_secs(request_timeout)) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body( @@ -114,45 +120,23 @@ impl EmbeddingProvider for OpenAIEmbeddings { .ok_or_else(|| anyhow!("no api key"))?; let mut request_number = 0; + let mut request_timeout: u64 = 10; let mut truncated = false; let mut response: Response; let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); while request_number < MAX_RETRIES { response = self - .send_request(api_key, spans.iter().map(|x| &**x).collect()) + .send_request( + api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) .await?; request_number += 1; - if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK { - return Err(anyhow!( - "openai max retries, error: {:?}", - &response.status() - )); - } - match response.status() { - StatusCode::TOO_MANY_REQUESTS => { - let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); - log::trace!( - "open ai rate limiting, delaying request by {:?} seconds", - delay.as_secs() - ); - self.executor.timer(delay).await; - } - StatusCode::BAD_REQUEST => { - // Only truncate if it hasnt been truncated before - if !truncated { - for span in spans.iter_mut() { - *span = Self::truncate(span.clone()); - } - truncated = true; - } else { - // If failing once already truncated, log the error and break the loop - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - log::trace!("open ai bad request: {:?} {:?}", &response.status(), body); - break; - } + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; } StatusCode::OK => { let mut body = String::new(); @@ -163,18 +147,60 @@ impl EmbeddingProvider for OpenAIEmbeddings { "openai embedding completed. 
tokens: {:?}", response.usage.total_tokens ); + return Ok(response .data .into_iter() .map(|embedding| embedding.embedding) .collect()); } + StatusCode::TOO_MANY_REQUESTS => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } _ => { - return Err(anyhow!("openai embedding failed {}", response.status())); + // TODO: Move this to parsing step + // Only truncate if it hasnt been truncated before + if !truncated { + for span in spans.iter_mut() { + *span = Self::truncate(span.clone()); + } + truncated = true; + } else { + // If failing once already truncated, log the error and break the loop + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } } } } - - Err(anyhow!("openai embedding failed")) + Err(anyhow!("openai max retries")) } } From a7e6a65debbe032edfd180e88f6be545edf89281 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 29 Aug 2023 17:14:44 -0400 Subject: [PATCH 02/20] reindex files in the background after they have not been edited for 10 minutes Co-authored-by: Max --- crates/semantic_index/src/semantic_index.rs | 416 +++++++++----------- 1 file changed, 188 insertions(+), 228 deletions(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 736f2c98a8..2da0d84baf 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -16,16 +16,18 @@ use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; -use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, WorktreeId}; +use project::{ + search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, +}; use smol::channel; use std::{ cmp::Ordering, - collections::HashMap, + collections::{BTreeMap, HashMap}, mem, ops::Range, path::{Path, PathBuf}, sync::{Arc, Weak}, - time::{Instant, SystemTime}, + time::{Duration, Instant, SystemTime}, }; use util::{ channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}, @@ -37,6 +39,7 @@ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 7; const EMBEDDINGS_BATCH_SIZE: usize = 80; +const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); pub fn init( fs: Arc, @@ -77,6 +80,7 @@ pub fn init( let semantic_index = SemanticIndex::new( fs, db_file_path, + // Arc::new(embedding::DummyEmbeddings {}), Arc::new(OpenAIEmbeddings { client: http_client, executor: cx.background(), @@ -113,9 +117,14 @@ struct ProjectState { worktree_db_ids: Vec<(WorktreeId, i64)>, _subscription: gpui::Subscription, outstanding_job_count_rx: watch::Receiver, - _outstanding_job_count_tx: Arc>>, - job_queue_tx: channel::Sender, - _queue_update_task: Task<()>, + outstanding_job_count_tx: Arc>>, + changed_paths: BTreeMap, +} + +struct ChangedPathInfo { + changed_at: Instant, + mtime: 
SystemTime, + is_deleted: bool, } #[derive(Clone)] @@ -133,31 +142,21 @@ impl JobHandle { } } } + impl ProjectState { fn new( - cx: &mut AppContext, subscription: gpui::Subscription, worktree_db_ids: Vec<(WorktreeId, i64)>, - outstanding_job_count_rx: watch::Receiver, - _outstanding_job_count_tx: Arc>>, + changed_paths: BTreeMap, ) -> Self { - let (job_queue_tx, job_queue_rx) = channel::unbounded(); - let _queue_update_task = cx.background().spawn({ - let mut worktree_queue = HashMap::new(); - async move { - while let Ok(operation) = job_queue_rx.recv().await { - Self::update_queue(&mut worktree_queue, operation); - } - } - }); - + let (outstanding_job_count_tx, outstanding_job_count_rx) = watch::channel_with(0); + let outstanding_job_count_tx = Arc::new(Mutex::new(outstanding_job_count_tx)); Self { worktree_db_ids, outstanding_job_count_rx, - _outstanding_job_count_tx, + outstanding_job_count_tx, + changed_paths, _subscription: subscription, - _queue_update_task, - job_queue_tx, } } @@ -165,41 +164,6 @@ impl ProjectState { self.outstanding_job_count_rx.borrow().clone() } - fn update_queue(queue: &mut HashMap, operation: IndexOperation) { - match operation { - IndexOperation::FlushQueue => { - let queue = std::mem::take(queue); - for (_, op) in queue { - match op { - IndexOperation::IndexFile { - absolute_path: _, - payload, - tx, - } => { - let _ = tx.try_send(payload); - } - IndexOperation::DeleteFile { - absolute_path: _, - payload, - tx, - } => { - let _ = tx.try_send(payload); - } - _ => {} - } - } - } - IndexOperation::IndexFile { - ref absolute_path, .. - } - | IndexOperation::DeleteFile { - ref absolute_path, .. - } => { - queue.insert(absolute_path.clone(), operation); - } - } - } - fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option { self.worktree_db_ids .iter() @@ -230,23 +194,10 @@ pub struct PendingFile { worktree_db_id: i64, relative_path: PathBuf, absolute_path: PathBuf, - language: Arc, + language: Option>, modified_time: SystemTime, job_handle: JobHandle, } -enum IndexOperation { - IndexFile { - absolute_path: PathBuf, - payload: PendingFile, - tx: channel::Sender, - }, - DeleteFile { - absolute_path: PathBuf, - payload: DbOperation, - tx: channel::Sender, - }, - FlushQueue, -} pub struct SearchResult { pub buffer: ModelHandle, @@ -582,13 +533,13 @@ impl SemanticIndex { parsing_files_rx: &channel::Receiver, db_update_tx: &channel::Sender, ) { + let Some(language) = pending_file.language else { + return; + }; + if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { if let Some(documents) = retriever - .parse_file_with_template( - &pending_file.relative_path, - &content, - pending_file.language, - ) + .parse_file_with_template(&pending_file.relative_path, &content, language) .log_err() { log::trace!( @@ -679,103 +630,50 @@ impl SemanticIndex { } fn project_entries_changed( - &self, + &mut self, project: ModelHandle, changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, cx: &mut ModelContext<'_, SemanticIndex>, worktree_id: &WorktreeId, - ) -> Result<()> { - let parsing_files_tx = self.parsing_files_tx.clone(); - let db_update_tx = self.db_update_tx.clone(); - let (job_queue_tx, outstanding_job_tx, worktree_db_id) = { - let state = self - .projects - .get(&project.downgrade()) - .ok_or(anyhow!("Project not yet initialized"))?; - let worktree_db_id = state - .db_id_for_worktree_id(*worktree_id) - .ok_or(anyhow!("Worktree ID in Database Not Available"))?; - ( - state.job_queue_tx.clone(), - state._outstanding_job_count_tx.clone(), - worktree_db_id, - 
) + ) { + let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else { + return; + }; + let project = project.downgrade(); + let Some(project_state) = self.projects.get_mut(&project) else { + return; }; - let language_registry = self.language_registry.clone(); - let parsing_files_tx = parsing_files_tx.clone(); - let db_update_tx = db_update_tx.clone(); + let worktree = worktree.read(cx); + let change_time = Instant::now(); + for (path, entry_id, change) in changes.iter() { + let Some(entry) = worktree.entry_for_id(*entry_id) else { + continue; + }; + if entry.is_ignored || entry.is_symlink || entry.is_external { + continue; + } + let project_path = ProjectPath { + worktree_id: *worktree_id, + path: path.clone(), + }; + project_state.changed_paths.insert( + project_path, + ChangedPathInfo { + changed_at: change_time, + mtime: entry.mtime, + is_deleted: *change == PathChange::Removed, + }, + ); + } - let worktree = project - .read(cx) - .worktree_for_id(worktree_id.clone(), cx) - .ok_or(anyhow!("Worktree not available"))? - .read(cx) - .snapshot(); - cx.spawn(|_, _| async move { - let worktree = worktree.clone(); - for (path, entry_id, path_change) in changes.iter() { - let relative_path = path.to_path_buf(); - let absolute_path = worktree.absolutize(path); - - let Some(entry) = worktree.entry_for_id(*entry_id) else { - continue; - }; - if entry.is_ignored || entry.is_symlink || entry.is_external { - continue; - } - - log::trace!("File Event: {:?}, Path: {:?}", &path_change, &path); - match path_change { - PathChange::AddedOrUpdated | PathChange::Updated | PathChange::Added => { - if let Ok(language) = language_registry - .language_for_file(&relative_path, None) - .await - { - if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; - } - - let job_handle = JobHandle::new(&outstanding_job_tx); - let new_operation = IndexOperation::IndexFile { - absolute_path: absolute_path.clone(), - payload: PendingFile { - worktree_db_id, - relative_path, - absolute_path, - language, - modified_time: entry.mtime, - job_handle, - }, - tx: parsing_files_tx.clone(), - }; - let _ = job_queue_tx.try_send(new_operation); - } - } - PathChange::Removed => { - let new_operation = IndexOperation::DeleteFile { - absolute_path, - payload: DbOperation::Delete { - worktree_id: worktree_db_id, - path: relative_path, - }, - tx: db_update_tx.clone(), - }; - let _ = job_queue_tx.try_send(new_operation); - } - _ => {} - } + cx.spawn_weak(|this, mut cx| async move { + cx.background().timer(BACKGROUND_INDEXING_DELAY).await; + if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { + Self::reindex_changed_paths(this, project, Some(change_time), &mut cx).await; } }) .detach(); - - Ok(()) } pub fn initialize_project( @@ -805,14 +703,11 @@ impl SemanticIndex { let _subscription = cx.subscribe(&project, |this, project, event, cx| { if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { - let _ = - this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id); + this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id); }; }); let language_registry = self.language_registry.clone(); - let parsing_files_tx = self.parsing_files_tx.clone(); - let db_update_tx = self.db_update_tx.clone(); cx.spawn(|this, mut cx| async move { 
futures::future::join_all(worktree_scans_complete).await; @@ -843,17 +738,13 @@ impl SemanticIndex { .map(|(a, b)| (*a, *b)) .collect(); - let (job_count_tx, job_count_rx) = watch::channel_with(0); - let job_count_tx = Arc::new(Mutex::new(job_count_tx)); - let job_count_tx_longlived = job_count_tx.clone(); - - let worktree_files = cx + let changed_paths = cx .background() .spawn(async move { - let mut worktree_files = Vec::new(); + let mut changed_paths = BTreeMap::new(); + let now = Instant::now(); for worktree in worktrees.into_iter() { let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap(); - let worktree_db_id = db_ids_by_worktree_id[&worktree.id()]; for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); @@ -876,59 +767,51 @@ impl SemanticIndex { continue; } - let path_buf = file.path.to_path_buf(); let stored_mtime = file_mtimes.remove(&file.path.to_path_buf()); let already_stored = stored_mtime .map_or(false, |existing_mtime| existing_mtime == file.mtime); if !already_stored { - let job_handle = JobHandle::new(&job_count_tx); - worktree_files.push(IndexOperation::IndexFile { - absolute_path: absolute_path.clone(), - payload: PendingFile { - worktree_db_id, - relative_path: path_buf, - absolute_path, - language, - job_handle, - modified_time: file.mtime, + changed_paths.insert( + ProjectPath { + worktree_id: worktree.id(), + path: file.path.clone(), }, - tx: parsing_files_tx.clone(), - }); + ChangedPathInfo { + changed_at: now, + mtime: file.mtime, + is_deleted: false, + }, + ); } } } + // Clean up entries from database that are no longer in the worktree. - for (path, _) in file_mtimes { - worktree_files.push(IndexOperation::DeleteFile { - absolute_path: worktree.absolutize(path.as_path()), - payload: DbOperation::Delete { - worktree_id: worktree_db_id, - path, + for (path, mtime) in file_mtimes { + changed_paths.insert( + ProjectPath { + worktree_id: worktree.id(), + path: path.into(), }, - tx: db_update_tx.clone(), - }); + ChangedPathInfo { + changed_at: now, + mtime, + is_deleted: true, + }, + ); } } - anyhow::Ok(worktree_files) + anyhow::Ok(changed_paths) }) .await?; - this.update(&mut cx, |this, cx| { - let project_state = ProjectState::new( - cx, - _subscription, - worktree_db_ids, - job_count_rx, - job_count_tx_longlived, + this.update(&mut cx, |this, _| { + this.projects.insert( + project.downgrade(), + ProjectState::new(_subscription, worktree_db_ids, changed_paths), ); - - for op in worktree_files { - let _ = project_state.job_queue_tx.try_send(op); - } - - this.projects.insert(project.downgrade(), project_state); }); Result::<(), _>::Ok(()) }) @@ -939,27 +822,17 @@ impl SemanticIndex { project: ModelHandle, cx: &mut ModelContext, ) -> Task)>> { - let state = self.projects.get_mut(&project.downgrade()); - let state = if state.is_none() { - return Task::Ready(Some(Err(anyhow!("Project not yet initialized")))); - } else { - state.unwrap() - }; - - // let parsing_files_tx = self.parsing_files_tx.clone(); - // let db_update_tx = self.db_update_tx.clone(); - let job_count_rx = state.outstanding_job_count_rx.clone(); - let count = state.get_outstanding_count(); - cx.spawn(|this, mut cx| async move { - this.update(&mut cx, |this, _| { - let Some(state) = this.projects.get_mut(&project.downgrade()) else { - return; - }; - let _ = state.job_queue_tx.try_send(IndexOperation::FlushQueue); - }); + Self::reindex_changed_paths(this.clone(), project.clone(), None, &mut cx).await; - Ok((count, job_count_rx)) + this.update(&mut cx, 
|this, _cx| { + let Some(state) = this.projects.get(&project.downgrade()) else { + return Err(anyhow!("Project not yet initialized")); + }; + let job_count_rx = state.outstanding_job_count_rx.clone(); + let count = state.get_outstanding_count(); + Ok((count, job_count_rx)) + }) }) } @@ -1110,6 +983,93 @@ impl SemanticIndex { .collect::>()) }) } + + async fn reindex_changed_paths( + this: ModelHandle, + project: ModelHandle, + last_changed_before: Option, + cx: &mut AsyncAppContext, + ) { + let mut pending_files = Vec::new(); + let (language_registry, parsing_files_tx) = this.update(cx, |this, cx| { + if let Some(project_state) = this.projects.get_mut(&project.downgrade()) { + let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; + let db_ids = &project_state.worktree_db_ids; + let mut worktree: Option> = None; + + project_state.changed_paths.retain(|path, info| { + if let Some(last_changed_before) = last_changed_before { + if info.changed_at > last_changed_before { + return true; + } + } + + if worktree + .as_ref() + .map_or(true, |tree| tree.read(cx).id() != path.worktree_id) + { + worktree = project.read(cx).worktree_for_id(path.worktree_id, cx); + } + let Some(worktree) = &worktree else { + return false; + }; + + let Some(worktree_db_id) = db_ids + .iter() + .find_map(|entry| (entry.0 == path.worktree_id).then_some(entry.1)) + else { + return false; + }; + + if info.is_deleted { + this.db_update_tx + .try_send(DbOperation::Delete { + worktree_id: worktree_db_id, + path: path.path.to_path_buf(), + }) + .ok(); + } else { + let absolute_path = worktree.read(cx).absolutize(&path.path); + let job_handle = JobHandle::new(&outstanding_job_count_tx); + pending_files.push(PendingFile { + absolute_path, + relative_path: path.path.to_path_buf(), + language: None, + job_handle, + modified_time: info.mtime, + worktree_db_id, + }); + } + + false + }); + } + + ( + this.language_registry.clone(), + this.parsing_files_tx.clone(), + ) + }); + + for mut pending_file in pending_files { + if let Ok(language) = language_registry + .language_for_file(&pending_file.relative_path, None) + .await + { + if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) + && &language.name().as_ref() != &"Markdown" + && language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } + pending_file.language = Some(language); + } + parsing_files_tx.try_send(pending_file).ok(); + } + } } impl Entity for SemanticIndex { From e377ada1a9aef3335f08543f6036b69c6adc0ddf Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 11:05:46 -0400 Subject: [PATCH 03/20] added token count to documents during parsing --- crates/semantic_index/src/embedding.rs | 14 +++++++++ crates/semantic_index/src/parsing.rs | 19 ++++++++++-- crates/semantic_index/src/semantic_index.rs | 3 +- .../src/semantic_index_tests.rs | 30 +++++++++++++------ 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index a9cb0245c4..72621d3138 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -54,6 +54,8 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; + fn count_tokens(&self, span: &str) -> usize; + // fn truncate(&self, span: &str) -> Result<&str>; } pub struct DummyEmbeddings {} @@ -66,6 +68,12 @@ impl EmbeddingProvider for DummyEmbeddings { let 
dummy_vec = vec![0.32 as f32; 1536]; return Ok(vec![dummy_vec; spans.len()]); } + + fn count_tokens(&self, span: &str) -> usize { + // For Dummy Providers, we are going to use OpenAI tokenization for ease + let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + tokens.len() + } } const OPENAI_INPUT_LIMIT: usize = 8190; @@ -111,6 +119,12 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { + fn count_tokens(&self, span: &str) -> usize { + // For Dummy Providers, we are going to use OpenAI tokenization for ease + let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + tokens.len() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 4aefb0b00d..b106e5055b 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,3 +1,4 @@ +use crate::embedding::EmbeddingProvider; use anyhow::{anyhow, Ok, Result}; use language::{Grammar, Language}; use sha1::{Digest, Sha1}; @@ -17,6 +18,7 @@ pub struct Document { pub content: String, pub embedding: Vec, pub sha1: [u8; 20], + pub token_count: usize, } const CODE_CONTEXT_TEMPLATE: &str = @@ -30,6 +32,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = pub struct CodeContextRetriever { pub parser: Parser, pub cursor: QueryCursor, + pub embedding_provider: Arc, } // Every match has an item, this represents the fundamental treesitter symbol and anchors the search @@ -47,10 +50,11 @@ pub struct CodeContextMatch { } impl CodeContextRetriever { - pub fn new() -> Self { + pub fn new(embedding_provider: Arc) -> Self { Self { parser: Parser::new(), cursor: QueryCursor::new(), + embedding_provider, } } @@ -68,12 +72,15 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); + let token_count = self.embedding_provider.count_tokens(&document_span); + Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: Vec::new(), name: language_name.to_string(), sha1: sha1.finalize().into(), + token_count, }]) } @@ -85,12 +92,15 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); + let token_count = self.embedding_provider.count_tokens(&document_span); + Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: Vec::new(), name: "Markdown".to_string(), sha1: sha1.finalize().into(), + token_count, }]) } @@ -166,10 +176,14 @@ impl CodeContextRetriever { let mut documents = self.parse_file(content, language)?; for document in &mut documents { - document.content = CODE_CONTEXT_TEMPLATE + let document_content = CODE_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", language_name.as_ref()) .replace("item", &document.content); + + let token_count = self.embedding_provider.count_tokens(&document_content); + document.content = document_content; + document.token_count = token_count; } Ok(documents) } @@ -272,6 +286,7 @@ impl CodeContextRetriever { range: item_range.clone(), embedding: vec![], sha1: sha1.finalize().into(), + token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 2da0d84baf..ab05ca7581 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -332,8 +332,9 @@ impl SemanticIndex { let parsing_files_rx = 
parsing_files_rx.clone(); let batch_files_tx = batch_files_tx.clone(); let db_update_tx = db_update_tx.clone(); + let embedding_provider = embedding_provider.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { - let mut retriever = CodeContextRetriever::new(); + let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { Self::parse_file( &fs, diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 32d8bb0fb8..cb318a9fd6 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,6 +1,6 @@ use crate::{ db::dot, - embedding::EmbeddingProvider, + embedding::{DummyEmbeddings, EmbeddingProvider}, parsing::{subtract_ranges, CodeContextRetriever, Document}, semantic_index_settings::SemanticIndexSettings, SearchResult, SemanticIndex, @@ -227,7 +227,8 @@ fn assert_search_results( #[gpui::test] async fn test_code_context_retrieval_rust() { let language = rust_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /// A doc comment @@ -314,7 +315,8 @@ async fn test_code_context_retrieval_rust() { #[gpui::test] async fn test_code_context_retrieval_json() { let language = json_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" { @@ -397,7 +399,8 @@ fn assert_documents_eq( #[gpui::test] async fn test_code_context_retrieval_javascript() { let language = js_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /* globals importScripts, backend */ @@ -495,7 +498,8 @@ async fn test_code_context_retrieval_javascript() { #[gpui::test] async fn test_code_context_retrieval_lua() { let language = lua_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" -- Creates a new class @@ -568,7 +572,8 @@ async fn test_code_context_retrieval_lua() { #[gpui::test] async fn test_code_context_retrieval_elixir() { let language = elixir_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" defmodule File.Stream do @@ -684,7 +689,8 @@ async fn test_code_context_retrieval_elixir() { #[gpui::test] async fn test_code_context_retrieval_cpp() { let language = cpp_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /** @@ -836,7 +842,8 @@ async fn test_code_context_retrieval_cpp() { #[gpui::test] async fn test_code_context_retrieval_ruby() { let language = ruby_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" # This concern is inspired by "sudo mode" on GitHub. 
It @@ -1026,7 +1033,8 @@ async fn test_code_context_retrieval_ruby() { #[gpui::test] async fn test_code_context_retrieval_php() { let language = php_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" usize { + span.len() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); From 76caea80f7543cf86eaf0f4e899f06ea478f3d8a Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 11:58:45 -0400 Subject: [PATCH 04/20] add should_truncate to embedding providers --- crates/semantic_index/src/embedding.rs | 19 +++++++++++++++++++ .../src/semantic_index_tests.rs | 4 ++++ 2 files changed, 23 insertions(+) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 72621d3138..3dd979f01b 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -55,6 +55,7 @@ struct OpenAIEmbeddingUsage { pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; fn count_tokens(&self, span: &str) -> usize; + fn should_truncate(&self, span: &str) -> bool; // fn truncate(&self, span: &str) -> Result<&str>; } @@ -74,6 +75,20 @@ impl EmbeddingProvider for DummyEmbeddings { let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); tokens.len() } + + fn should_truncate(&self, span: &str) -> bool { + self.count_tokens(span) > OPENAI_INPUT_LIMIT + + // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + // let Ok(output) = { + // if tokens.len() > OPENAI_INPUT_LIMIT { + // tokens.truncate(OPENAI_INPUT_LIMIT); + // OPENAI_BPE_TOKENIZER.decode(tokens) + // } else { + // Ok(span) + // } + // }; + } } const OPENAI_INPUT_LIMIT: usize = 8190; @@ -125,6 +140,10 @@ impl EmbeddingProvider for OpenAIEmbeddings { tokens.len() } + fn should_truncate(&self, span: &str) -> bool { + self.count_tokens(span) > OPENAI_INPUT_LIMIT + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index cb318a9fd6..48cefd93b1 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1228,6 +1228,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider { span.len() } + fn should_truncate(&self, span: &str) -> bool { + false + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); From 97810471569618955c241e4137629b578c46285b Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 12:13:26 -0400 Subject: [PATCH 05/20] move truncation to parsing step leveraging the EmbeddingProvider trait --- crates/semantic_index/src/embedding.rs | 78 +++++++++---------- crates/semantic_index/src/parsing.rs | 4 + .../src/semantic_index_tests.rs | 4 + 3 files changed, 45 insertions(+), 41 deletions(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 3dd979f01b..cba34439c8 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -56,7 +56,7 @@ pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; fn 
count_tokens(&self, span: &str) -> usize; fn should_truncate(&self, span: &str) -> bool; - // fn truncate(&self, span: &str) -> Result<&str>; + fn truncate(&self, span: &str) -> String; } pub struct DummyEmbeddings {} @@ -78,36 +78,27 @@ impl EmbeddingProvider for DummyEmbeddings { fn should_truncate(&self, span: &str) -> bool { self.count_tokens(span) > OPENAI_INPUT_LIMIT + } - // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - // let Ok(output) = { - // if tokens.len() > OPENAI_INPUT_LIMIT { - // tokens.truncate(OPENAI_INPUT_LIMIT); - // OPENAI_BPE_TOKENIZER.decode(tokens) - // } else { - // Ok(span) - // } - // }; + fn truncate(&self, span: &str) -> String { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + let output = if tokens.len() > OPENAI_INPUT_LIMIT { + tokens.truncate(OPENAI_INPUT_LIMIT); + OPENAI_BPE_TOKENIZER + .decode(tokens) + .ok() + .unwrap_or_else(|| span.to_string()) + } else { + span.to_string() + }; + + output } } const OPENAI_INPUT_LIMIT: usize = 8190; impl OpenAIEmbeddings { - fn truncate(span: String) -> String { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref()); - if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); - if result.is_ok() { - let transformed = result.unwrap(); - return transformed; - } - } - - span - } - async fn send_request( &self, api_key: &str, @@ -144,6 +135,21 @@ impl EmbeddingProvider for OpenAIEmbeddings { self.count_tokens(span) > OPENAI_INPUT_LIMIT } + fn truncate(&self, span: &str) -> String { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + let output = if tokens.len() > OPENAI_INPUT_LIMIT { + tokens.truncate(OPENAI_INPUT_LIMIT); + OPENAI_BPE_TOKENIZER + .decode(tokens) + .ok() + .unwrap_or_else(|| span.to_string()) + } else { + span.to_string() + }; + + output + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -214,23 +220,13 @@ impl EmbeddingProvider for OpenAIEmbeddings { self.executor.timer(delay_duration).await; } _ => { - // TODO: Move this to parsing step - // Only truncate if it hasnt been truncated before - if !truncated { - for span in spans.iter_mut() { - *span = Self::truncate(span.clone()); - } - truncated = true; - } else { - // If failing once already truncated, log the error and break the loop - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!( - "open ai bad request: {:?} {:?}", - &response.status(), - body - )); - } + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); } } } diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index b106e5055b..00849580bb 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -73,6 +73,7 @@ impl CodeContextRetriever { sha1.update(&document_span); let token_count = self.embedding_provider.count_tokens(&document_span); + let document_span = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -93,6 +94,7 @@ impl CodeContextRetriever { sha1.update(&document_span); let token_count = self.embedding_provider.count_tokens(&document_span); + let document_span = self.embedding_provider.truncate(&document_span); 
Ok(vec![Document { range: 0..content.len(), @@ -182,6 +184,8 @@ impl CodeContextRetriever { .replace("item", &document.content); let token_count = self.embedding_provider.count_tokens(&document_content); + let document_content = self.embedding_provider.truncate(&document_content); + document.content = document_content; document.token_count = token_count; } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 48cefd93b1..7093cf9fcf 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1232,6 +1232,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider { false } + fn truncate(&self, span: &str) -> String { + span.to_string() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); From 76ce52df4ee0f4b4b977093f096c76e15b852ae3 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 16:01:28 -0400 Subject: [PATCH 06/20] move queuing to embedding_queue functionality and update embedding provider to include trait items for max tokens per batch" Co-authored-by: Max --- crates/semantic_index/src/embedding.rs | 47 ++---- crates/semantic_index/src/embedding_queue.rs | 140 ++++++++++++++++ crates/semantic_index/src/parsing.rs | 10 +- .../src/semantic_index_tests.rs | 154 +++++++++++++----- crates/util/src/util.rs | 35 ++-- 5 files changed, 295 insertions(+), 91 deletions(-) create mode 100644 crates/semantic_index/src/embedding_queue.rs diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index cba34439c8..7db22c3716 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -53,36 +53,30 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; - fn count_tokens(&self, span: &str) -> usize; - fn should_truncate(&self, span: &str) -> bool; - fn truncate(&self, span: &str) -> String; + async fn embed_batch(&self, spans: Vec) -> Result>>; + fn max_tokens_per_batch(&self) -> usize; + fn truncate(&self, span: &str) -> (String, usize); } pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. 
let dummy_vec = vec![0.32 as f32; 1536]; return Ok(vec![dummy_vec; spans.len()]); } - fn count_tokens(&self, span: &str) -> usize { - // For Dummy Providers, we are going to use OpenAI tokenization for ease - let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - tokens.len() + fn max_tokens_per_batch(&self) -> usize { + OPENAI_INPUT_LIMIT } - fn should_truncate(&self, span: &str) -> bool { - self.count_tokens(span) > OPENAI_INPUT_LIMIT - } - - fn truncate(&self, span: &str) -> String { + fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { + let token_count = tokens.len(); + let output = if token_count > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); OPENAI_BPE_TOKENIZER .decode(tokens) @@ -92,7 +86,7 @@ impl EmbeddingProvider for DummyEmbeddings { span.to_string() }; - output + (output, token_count) } } @@ -125,19 +119,14 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { - fn count_tokens(&self, span: &str) -> usize { - // For Dummy Providers, we are going to use OpenAI tokenization for ease - let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - tokens.len() + fn max_tokens_per_batch(&self) -> usize { + OPENAI_INPUT_LIMIT } - fn should_truncate(&self, span: &str) -> bool { - self.count_tokens(span) > OPENAI_INPUT_LIMIT - } - - fn truncate(&self, span: &str) -> String { + fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { + let token_count = tokens.len(); + let output = if token_count > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); OPENAI_BPE_TOKENIZER .decode(tokens) @@ -147,10 +136,10 @@ impl EmbeddingProvider for OpenAIEmbeddings { span.to_string() }; - output + (output, token_count) } - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -160,9 +149,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { let mut request_number = 0; let mut request_timeout: u64 = 10; - let mut truncated = false; let mut response: Response; - let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); while request_number < MAX_RETRIES { response = self .send_request( diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs new file mode 100644 index 0000000000..6609c39e78 --- /dev/null +++ b/crates/semantic_index/src/embedding_queue.rs @@ -0,0 +1,140 @@ +use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; + +use gpui::AppContext; +use parking_lot::Mutex; +use smol::channel; + +use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; + +#[derive(Clone)] +pub struct FileToEmbed { + pub worktree_id: i64, + pub path: PathBuf, + pub mtime: SystemTime, + pub documents: Vec, + pub job_handle: JobHandle, +} + +impl std::fmt::Debug for FileToEmbed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FileToEmbed") + .field("worktree_id", &self.worktree_id) + .field("path", &self.path) + .field("mtime", &self.mtime) + .field("document", &self.documents) + .finish_non_exhaustive() + } +} + +impl PartialEq for FileToEmbed { + fn eq(&self, other: &Self) -> bool { + self.worktree_id == other.worktree_id + && 
self.path == other.path + && self.mtime == other.mtime + && self.documents == other.documents + } +} + +pub struct EmbeddingQueue { + embedding_provider: Arc, + pending_batch: Vec, + pending_batch_token_count: usize, + finished_files_tx: channel::Sender, + finished_files_rx: channel::Receiver, +} + +pub struct FileToEmbedFragment { + file: Arc>, + document_range: Range, +} + +impl EmbeddingQueue { + pub fn new(embedding_provider: Arc) -> Self { + let (finished_files_tx, finished_files_rx) = channel::unbounded(); + Self { + embedding_provider, + pending_batch: Vec::new(), + pending_batch_token_count: 0, + finished_files_tx, + finished_files_rx, + } + } + + pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) { + let file = Arc::new(Mutex::new(file)); + + self.pending_batch.push(FileToEmbedFragment { + file: file.clone(), + document_range: 0..0, + }); + + let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + for (ix, document) in file.lock().documents.iter().enumerate() { + let next_token_count = self.pending_batch_token_count + document.token_count; + if next_token_count > self.embedding_provider.max_tokens_per_batch() { + let range_end = fragment_range.end; + self.flush(cx); + self.pending_batch.push(FileToEmbedFragment { + file: file.clone(), + document_range: range_end..range_end, + }); + fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + } + + fragment_range.end = ix + 1; + self.pending_batch_token_count += document.token_count; + } + } + + pub fn flush(&mut self, cx: &mut AppContext) { + let batch = mem::take(&mut self.pending_batch); + self.pending_batch_token_count = 0; + if batch.is_empty() { + return; + } + + let finished_files_tx = self.finished_files_tx.clone(); + let embedding_provider = self.embedding_provider.clone(); + cx.background().spawn(async move { + let mut spans = Vec::new(); + for fragment in &batch { + let file = fragment.file.lock(); + spans.extend( + file.documents[fragment.document_range.clone()] + .iter() + .map(|d| d.content.clone()), + ); + } + + match embedding_provider.embed_batch(spans).await { + Ok(embeddings) => { + let mut embeddings = embeddings.into_iter(); + for fragment in batch { + for document in + &mut fragment.file.lock().documents[fragment.document_range.clone()] + { + if let Some(embedding) = embeddings.next() { + document.embedding = embedding; + } else { + // + log::error!("number of embeddings returned different from number of documents"); + } + } + + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + } + Err(error) => { + log::error!("{:?}", error); + } + } + }) + .detach(); + } + + pub fn finished_files(&self) -> channel::Receiver { + self.finished_files_rx.clone() + } +} diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 00849580bb..51f1bd7ca9 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -72,8 +72,7 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); - let token_count = self.embedding_provider.count_tokens(&document_span); - let document_span = self.embedding_provider.truncate(&document_span); + let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -93,8 +92,7 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); - let token_count = 
self.embedding_provider.count_tokens(&document_span); - let document_span = self.embedding_provider.truncate(&document_span); + let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -183,8 +181,8 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("item", &document.content); - let token_count = self.embedding_provider.count_tokens(&document_content); - let document_content = self.embedding_provider.truncate(&document_content); + let (document_content, token_count) = + self.embedding_provider.truncate(&document_content); document.content = document_content; document.token_count = token_count; diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 7093cf9fcf..7178987165 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,14 +1,16 @@ use crate::{ db::dot, embedding::{DummyEmbeddings, EmbeddingProvider}, + embedding_queue::EmbeddingQueue, parsing::{subtract_ranges, CodeContextRetriever, Document}, semantic_index_settings::SemanticIndexSettings, - SearchResult, SemanticIndex, + FileToEmbed, JobHandle, SearchResult, SemanticIndex, }; use anyhow::Result; use async_trait::async_trait; use gpui::{Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; +use parking_lot::Mutex; use pretty_assertions::assert_eq; use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project}; use rand::{rngs::StdRng, Rng}; @@ -20,8 +22,10 @@ use std::{ atomic::{self, AtomicUsize}, Arc, }, + time::SystemTime, }; use unindent::Unindent; +use util::RandomCharIter; #[ctor::ctor] fn init_logger() { @@ -32,11 +36,7 @@ fn init_logger() { #[gpui::test] async fn test_semantic_index(cx: &mut TestAppContext) { - cx.update(|cx| { - cx.set_global(SettingsStore::test(cx)); - settings::register::(cx); - settings::register::(cx); - }); + init_test(cx); let fs = FakeFs::new(cx.background()); fs.insert_tree( @@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { let db_path = db_dir.path().join("db.sqlite"); let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let store = SemanticIndex::new( + let semantic_index = SemanticIndex::new( fs.clone(), db_path, embedding_provider.clone(), @@ -87,13 +87,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) { let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - let _ = store + let _ = semantic_index .update(cx, |store, cx| { store.initialize_project(project.clone(), cx) }) .await; - let (file_count, outstanding_file_count) = store + let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) .await .unwrap(); @@ -101,7 +101,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx.foreground().run_until_parked(); assert_eq!(*outstanding_file_count.borrow(), 0); - let search_results = store + let search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -129,7 +129,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { // Test Include Files Functonality let include_files = vec![PathMatcher::new("*.rs").unwrap()]; let exclude_files = vec![PathMatcher::new("*.rs").unwrap()]; - let rust_only_search_results = store + let rust_only_search_results = semantic_index .update(cx, |store, cx| { store.search_project( 
project.clone(), @@ -153,7 +153,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx, ); - let no_rust_search_results = store + let no_rust_search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -189,7 +189,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx.foreground().run_until_parked(); let prev_embedding_count = embedding_provider.embedding_count(); - let (file_count, outstanding_file_count) = store + let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) .await .unwrap(); @@ -204,6 +204,69 @@ async fn test_semantic_index(cx: &mut TestAppContext) { ); } +#[gpui::test(iterations = 10)] +async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { + let (outstanding_job_count, _) = postage::watch::channel_with(0); + let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count)); + + let files = (1..=3) + .map(|file_ix| FileToEmbed { + worktree_id: 5, + path: format!("path-{file_ix}").into(), + mtime: SystemTime::now(), + documents: (0..rng.gen_range(4..22)) + .map(|document_ix| { + let content_len = rng.gen_range(10..100); + Document { + range: 0..10, + embedding: Vec::new(), + name: format!("document {document_ix}"), + content: RandomCharIter::new(&mut rng) + .with_simple_text() + .take(content_len) + .collect(), + sha1: rng.gen(), + token_count: rng.gen_range(10..30), + } + }) + .collect(), + job_handle: JobHandle::new(&outstanding_job_count), + }) + .collect::>(); + + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut queue = EmbeddingQueue::new(embedding_provider.clone()); + + let finished_files = cx.update(|cx| { + for file in &files { + queue.push(file.clone(), cx); + } + queue.flush(cx); + queue.finished_files() + }); + + cx.foreground().run_until_parked(); + let mut embedded_files: Vec<_> = files + .iter() + .map(|_| finished_files.try_recv().expect("no finished file")) + .collect(); + + let expected_files: Vec<_> = files + .iter() + .map(|file| { + let mut file = file.clone(); + for doc in &mut file.documents { + doc.embedding = embedding_provider.embed_sync(doc.content.as_ref()); + } + file + }) + .collect(); + + embedded_files.sort_by_key(|f| f.path.clone()); + + assert_eq!(embedded_files, expected_files); +} + #[track_caller] fn assert_search_results( actual: &[SearchResult], @@ -1220,47 +1283,42 @@ impl FakeEmbeddingProvider { fn embedding_count(&self) -> usize { self.embedding_count.load(atomic::Ordering::SeqCst) } + + fn embed_sync(&self, span: &str) -> Vec { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result + } } #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { - fn count_tokens(&self, span: &str) -> usize { - span.len() + fn truncate(&self, span: &str) -> (String, usize) { + (span.to_string(), 1) } - fn should_truncate(&self, span: &str) -> bool { - false + fn max_tokens_per_batch(&self) -> usize { + 200 } - fn truncate(&self, span: &str) -> String { - span.to_string() - } - - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); - 
Ok(spans - .iter() - .map(|span| { - let mut result = vec![1.0; 26]; - for letter in span.chars() { - let letter = letter.to_ascii_lowercase(); - if letter as u32 >= 'a' as u32 { - let ix = (letter as u32) - ('a' as u32); - if ix < 26 { - result[ix as usize] += 1.0; - } - } - } - - let norm = result.iter().map(|x| x * x).sum::().sqrt(); - for x in &mut result { - *x /= norm; - } - - result - }) - .collect()) + Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } @@ -1704,3 +1762,11 @@ fn test_subtract_ranges() { assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]); } + +fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + cx.set_global(SettingsStore::test(cx)); + settings::register::(cx); + settings::register::(cx); + }); +} diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index c8beb86aef..785426ed4c 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -260,11 +260,22 @@ pub fn defer(f: F) -> impl Drop { Defer(Some(f)) } -pub struct RandomCharIter(T); +pub struct RandomCharIter { + rng: T, + simple_text: bool, +} impl RandomCharIter { pub fn new(rng: T) -> Self { - Self(rng) + Self { + rng, + simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()), + } + } + + pub fn with_simple_text(mut self) -> Self { + self.simple_text = true; + self } } @@ -272,25 +283,27 @@ impl Iterator for RandomCharIter { type Item = char; fn next(&mut self) -> Option { - if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) { - return if self.0.gen_range(0..100) < 5 { + if self.simple_text { + return if self.rng.gen_range(0..100) < 5 { Some('\n') } else { - Some(self.0.gen_range(b'a'..b'z' + 1).into()) + Some(self.rng.gen_range(b'a'..b'z' + 1).into()) }; } - match self.0.gen_range(0..100) { + match self.rng.gen_range(0..100) { // whitespace - 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(), + 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(), // two-byte greek letters - 20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))), + 20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))), // // three-byte characters - 33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(), + 33..=45 => ['✋', '✅', '❌', '❎', '⭐'] + .choose(&mut self.rng) + .copied(), // // four-byte characters - 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(), + 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(), // ascii letters - _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()), + _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()), } } } From 5abad58b0d81941726f81fd8e6e8ca876811163e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 16:58:45 -0400 Subject: [PATCH 07/20] moved semantic index to use embeddings queue to batch and managed for atomic database writes Co-authored-by: Max --- crates/semantic_index/src/embedding_queue.rs | 25 +- crates/semantic_index/src/semantic_index.rs | 238 +++--------------- .../src/semantic_index_tests.rs | 14 +- 3 files changed, 55 insertions(+), 222 deletions(-) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 6609c39e78..2b48b7a7d6 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,10 +1,8 @@ -use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; - -use gpui::AppContext; +use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use 
gpui::executor::Background; use parking_lot::Mutex; use smol::channel; - -use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; #[derive(Clone)] pub struct FileToEmbed { @@ -38,6 +36,7 @@ impl PartialEq for FileToEmbed { pub struct EmbeddingQueue { embedding_provider: Arc, pending_batch: Vec, + executor: Arc, pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, @@ -49,10 +48,11 @@ pub struct FileToEmbedFragment { } impl EmbeddingQueue { - pub fn new(embedding_provider: Arc) -> Self { + pub fn new(embedding_provider: Arc, executor: Arc) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { embedding_provider, + executor, pending_batch: Vec::new(), pending_batch_token_count: 0, finished_files_tx, @@ -60,7 +60,12 @@ impl EmbeddingQueue { } } - pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) { + pub fn push(&mut self, file: FileToEmbed) { + if file.documents.is_empty() { + self.finished_files_tx.try_send(file).unwrap(); + return; + } + let file = Arc::new(Mutex::new(file)); self.pending_batch.push(FileToEmbedFragment { @@ -73,7 +78,7 @@ impl EmbeddingQueue { let next_token_count = self.pending_batch_token_count + document.token_count; if next_token_count > self.embedding_provider.max_tokens_per_batch() { let range_end = fragment_range.end; - self.flush(cx); + self.flush(); self.pending_batch.push(FileToEmbedFragment { file: file.clone(), document_range: range_end..range_end, @@ -86,7 +91,7 @@ impl EmbeddingQueue { } } - pub fn flush(&mut self, cx: &mut AppContext) { + pub fn flush(&mut self) { let batch = mem::take(&mut self.pending_batch); self.pending_batch_token_count = 0; if batch.is_empty() { @@ -95,7 +100,7 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - cx.background().spawn(async move { + self.executor.spawn(async move { let mut spans = Vec::new(); for fragment in &batch { let file = fragment.file.lock(); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index ab05ca7581..cde53182dc 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -1,5 +1,6 @@ mod db; mod embedding; +mod embedding_queue; mod parsing; pub mod semantic_index_settings; @@ -10,6 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings; use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use embedding_queue::{EmbeddingQueue, FileToEmbed}; use futures::{channel::oneshot, Future}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; @@ -23,7 +25,6 @@ use smol::channel; use std::{ cmp::Ordering, collections::{BTreeMap, HashMap}, - mem, ops::Range, path::{Path, PathBuf}, sync::{Arc, Weak}, @@ -38,7 +39,6 @@ use util::{ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 7; -const EMBEDDINGS_BATCH_SIZE: usize = 80; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); pub fn init( @@ -106,9 +106,8 @@ pub struct SemanticIndex { language_registry: Arc, db_update_tx: channel::Sender, parsing_files_tx: channel::Sender, + _embedding_task: Task<()>, _db_update_task: Task<()>, - _embed_batch_tasks: Vec>, - _batch_files_task: Task<()>, 
_parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, } @@ -128,7 +127,7 @@ struct ChangedPathInfo { } #[derive(Clone)] -struct JobHandle { +pub struct JobHandle { /// The outer Arc is here to count the clones of a JobHandle instance; /// when the last handle to a given job is dropped, we decrement a counter (just once). tx: Arc>>>, @@ -230,17 +229,6 @@ enum DbOperation { }, } -enum EmbeddingJob { - Enqueue { - worktree_id: i64, - path: PathBuf, - mtime: SystemTime, - documents: Vec, - job_handle: JobHandle, - }, - Flush, -} - impl SemanticIndex { pub fn global(cx: &AppContext) -> Option> { if cx.has_global::>() { @@ -287,52 +275,35 @@ impl SemanticIndex { } }); - // Group documents into batches and send them to the embedding provider. - let (embed_batch_tx, embed_batch_rx) = - channel::unbounded::, PathBuf, SystemTime, JobHandle)>>(); - let mut _embed_batch_tasks = Vec::new(); - for _ in 0..cx.background().num_cpus() { - let embed_batch_rx = embed_batch_rx.clone(); - _embed_batch_tasks.push(cx.background().spawn({ - let db_update_tx = db_update_tx.clone(); - let embedding_provider = embedding_provider.clone(); - async move { - while let Ok(embeddings_queue) = embed_batch_rx.recv().await { - Self::compute_embeddings_for_batch( - embeddings_queue, - &embedding_provider, - &db_update_tx, - ) - .await; - } + let embedding_queue = + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); + let _embedding_task = cx.background().spawn({ + let embedded_files = embedding_queue.finished_files(); + let db_update_tx = db_update_tx.clone(); + async move { + while let Ok(file) = embedded_files.recv().await { + db_update_tx + .try_send(DbOperation::InsertFile { + worktree_id: file.worktree_id, + documents: file.documents, + path: file.path, + mtime: file.mtime, + job_handle: file.job_handle, + }) + .ok(); } - })); - } - - // Group documents into batches and send them to the embedding provider. - let (batch_files_tx, batch_files_rx) = channel::unbounded::(); - let _batch_files_task = cx.background().spawn(async move { - let mut queue_len = 0; - let mut embeddings_queue = vec![]; - while let Ok(job) = batch_files_rx.recv().await { - Self::enqueue_documents_to_embed( - job, - &mut queue_len, - &mut embeddings_queue, - &embed_batch_tx, - ); } }); // Parse files into embeddable documents. 
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::(); + let embedding_queue = Arc::new(Mutex::new(embedding_queue)); let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { let fs = fs.clone(); let parsing_files_rx = parsing_files_rx.clone(); - let batch_files_tx = batch_files_tx.clone(); - let db_update_tx = db_update_tx.clone(); let embedding_provider = embedding_provider.clone(); + let embedding_queue = embedding_queue.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { @@ -340,9 +311,8 @@ impl SemanticIndex { &fs, pending_file, &mut retriever, - &batch_files_tx, + &embedding_queue, &parsing_files_rx, - &db_update_tx, ) .await; } @@ -361,8 +331,7 @@ impl SemanticIndex { db_update_tx, parsing_files_tx, _db_update_task, - _embed_batch_tasks, - _batch_files_task, + _embedding_task, _parsing_files_tasks, projects: HashMap::new(), } @@ -403,136 +372,12 @@ impl SemanticIndex { } } - async fn compute_embeddings_for_batch( - mut embeddings_queue: Vec<(i64, Vec, PathBuf, SystemTime, JobHandle)>, - embedding_provider: &Arc, - db_update_tx: &channel::Sender, - ) { - let mut batch_documents = vec![]; - for (_, documents, _, _, _) in embeddings_queue.iter() { - batch_documents.extend(documents.iter().map(|document| document.content.as_str())); - } - - if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await { - log::trace!( - "created {} embeddings for {} files", - embeddings.len(), - embeddings_queue.len(), - ); - - let mut i = 0; - let mut j = 0; - - for embedding in embeddings.iter() { - while embeddings_queue[i].1.len() == j { - i += 1; - j = 0; - } - - embeddings_queue[i].1[j].embedding = embedding.to_owned(); - j += 1; - } - - for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id, - documents, - path, - mtime, - job_handle, - }) - .await - .unwrap(); - } - } else { - // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed). - for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id, - documents: vec![], - path, - mtime, - job_handle, - }) - .await - .unwrap(); - } - } - } - - fn enqueue_documents_to_embed( - job: EmbeddingJob, - queue_len: &mut usize, - embeddings_queue: &mut Vec<(i64, Vec, PathBuf, SystemTime, JobHandle)>, - embed_batch_tx: &channel::Sender, PathBuf, SystemTime, JobHandle)>>, - ) { - // Handle edge case where individual file has more documents than max batch size - let should_flush = match job { - EmbeddingJob::Enqueue { - documents, - worktree_id, - path, - mtime, - job_handle, - } => { - // If documents is greater than embeddings batch size, recursively batch existing rows. 
- if &documents.len() > &EMBEDDINGS_BATCH_SIZE { - let first_job = EmbeddingJob::Enqueue { - documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(), - worktree_id, - path: path.clone(), - mtime, - job_handle: job_handle.clone(), - }; - - Self::enqueue_documents_to_embed( - first_job, - queue_len, - embeddings_queue, - embed_batch_tx, - ); - - let second_job = EmbeddingJob::Enqueue { - documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(), - worktree_id, - path: path.clone(), - mtime, - job_handle: job_handle.clone(), - }; - - Self::enqueue_documents_to_embed( - second_job, - queue_len, - embeddings_queue, - embed_batch_tx, - ); - return; - } else { - *queue_len += &documents.len(); - embeddings_queue.push((worktree_id, documents, path, mtime, job_handle)); - *queue_len >= EMBEDDINGS_BATCH_SIZE - } - } - EmbeddingJob::Flush => true, - }; - - if should_flush { - embed_batch_tx - .try_send(mem::take(embeddings_queue)) - .unwrap(); - *queue_len = 0; - } - } - async fn parse_file( fs: &Arc, pending_file: PendingFile, retriever: &mut CodeContextRetriever, - batch_files_tx: &channel::Sender, + embedding_queue: &Arc>, parsing_files_rx: &channel::Receiver, - db_update_tx: &channel::Sender, ) { let Some(language) = pending_file.language else { return; @@ -549,33 +394,18 @@ impl SemanticIndex { documents.len() ); - if documents.len() == 0 { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id: pending_file.worktree_db_id, - documents, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - }) - .await - .unwrap(); - } else { - batch_files_tx - .try_send(EmbeddingJob::Enqueue { - worktree_id: pending_file.worktree_db_id, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - documents, - }) - .unwrap(); - } + embedding_queue.lock().push(FileToEmbed { + worktree_id: pending_file.worktree_db_id, + path: pending_file.relative_path, + mtime: pending_file.modified_time, + job_handle: pending_file.job_handle, + documents, + }); } } if parsing_files_rx.len() == 0 { - batch_files_tx.try_send(EmbeddingJob::Flush).unwrap(); + embedding_queue.lock().flush(); } } @@ -881,7 +711,7 @@ impl SemanticIndex { let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; let phrase_embedding = embedding_provider - .embed_batch(vec![&phrase]) + .embed_batch(vec![phrase]) .await? 
.into_iter() .next() diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 7178987165..dc41c09f7a 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -235,17 +235,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .collect::>(); let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut queue = EmbeddingQueue::new(embedding_provider.clone()); - let finished_files = cx.update(|cx| { - for file in &files { - queue.push(file.clone(), cx); - } - queue.flush(cx); - queue.finished_files() - }); + let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background()); + for file in &files { + queue.push(file.clone()); + } + queue.flush(); cx.foreground().run_until_parked(); + let finished_files = queue.finished_files(); let mut embedded_files: Vec<_> = files .iter() .map(|_| finished_files.try_recv().expect("no finished file")) From 7d4d6c871ba88eafc8a084539a4619c8ba686872 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 17:42:16 -0400 Subject: [PATCH 08/20] fix bug for truncation ensuring no valid inputs are sent to openai --- crates/semantic_index/src/embedding.rs | 10 ++++------ crates/semantic_index/src/embedding_queue.rs | 8 +++++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 7db22c3716..60e13a9e01 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -78,15 +78,13 @@ impl EmbeddingProvider for DummyEmbeddings { let token_count = tokens.len(); let output = if token_count > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); - OPENAI_BPE_TOKENIZER - .decode(tokens) - .ok() - .unwrap_or_else(|| span.to_string()) + let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); + new_input.ok().unwrap_or_else(|| span.to_string()) } else { span.to_string() }; - (output, token_count) + (output, tokens.len()) } } @@ -120,7 +118,7 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { fn max_tokens_per_batch(&self) -> usize { - OPENAI_INPUT_LIMIT + 50000 } fn truncate(&self, span: &str) -> (String, usize) { diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 2b48b7a7d6..c3a5de1373 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -105,9 +105,11 @@ impl EmbeddingQueue { for fragment in &batch { let file = fragment.file.lock(); spans.extend( - file.documents[fragment.document_range.clone()] - .iter() - .map(|d| d.content.clone()), + { + file.documents[fragment.document_range.clone()] + .iter() + .map(|d| d.content.clone()) + } ); } From 35440be98e13df2d87f1e87e2ef750adf2ff59cc Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 16:54:11 +0200 Subject: [PATCH 09/20] Abstract away how database transactions are executed Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 630 +++++++++++--------- crates/semantic_index/src/semantic_index.rs | 199 ++----- 2 files changed, 397 insertions(+), 432 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 60ecf3b45f..652c2819ce 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,5 +1,7 @@ use crate::{parsing::Document, 
SEMANTIC_INDEX_VERSION}; use anyhow::{anyhow, Context, Result}; +use futures::channel::oneshot; +use gpui::executor; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; use rusqlite::{ @@ -9,12 +11,14 @@ use rusqlite::{ use std::{ cmp::Ordering, collections::HashMap, + future::Future, ops::Range, path::{Path, PathBuf}, rc::Rc, sync::Arc, time::SystemTime, }; +use util::TryFutureExt; #[derive(Debug)] pub struct FileRecord { @@ -51,117 +55,161 @@ impl FromSql for Sha1 { } } +#[derive(Clone)] pub struct VectorDatabase { - db: rusqlite::Connection, + path: Arc, + transactions: smol::channel::Sender>, } impl VectorDatabase { - pub async fn new(fs: Arc, path: Arc) -> Result { + pub async fn new( + fs: Arc, + path: Arc, + executor: Arc, + ) -> Result { if let Some(db_directory) = path.parent() { fs.create_dir(db_directory).await?; } + let (transactions_tx, transactions_rx) = + smol::channel::unbounded::>(); + executor + .spawn({ + let path = path.clone(); + async move { + let connection = rusqlite::Connection::open(&path)?; + while let Ok(transaction) = transactions_rx.recv().await { + transaction(&connection); + } + + anyhow::Ok(()) + } + .log_err() + }) + .detach(); let this = Self { - db: rusqlite::Connection::open(path.as_path())?, + transactions: transactions_tx, + path, }; - this.initialize_database()?; + this.initialize_database().await?; Ok(this) } - fn get_existing_version(&self) -> Result { - let mut version_query = self - .db - .prepare("SELECT version from semantic_index_config")?; - version_query - .query_row([], |row| Ok(row.get::<_, i64>(0)?)) - .map_err(|err| anyhow!("version query failed: {err}")) + pub fn path(&self) -> &Arc { + &self.path } - fn initialize_database(&self) -> Result<()> { - rusqlite::vtab::array::load_module(&self.db)?; - - // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped - if self - .get_existing_version() - .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) - { - log::trace!("vector database schema up to date"); - return Ok(()); + fn transact(&self, transaction: F) -> impl Future> + where + F: 'static + Send + FnOnce(&rusqlite::Connection) -> Result, + T: 'static + Send, + { + let (tx, rx) = oneshot::channel(); + let transactions = self.transactions.clone(); + async move { + if transactions + .send(Box::new(|connection| { + let result = transaction(connection); + let _ = tx.send(result); + })) + .await + .is_err() + { + return Err(anyhow!("connection was dropped"))?; + } + rx.await? } - - log::trace!("vector database schema out of date. 
updating..."); - self.db - .execute("DROP TABLE IF EXISTS documents", []) - .context("failed to drop 'documents' table")?; - self.db - .execute("DROP TABLE IF EXISTS files", []) - .context("failed to drop 'files' table")?; - self.db - .execute("DROP TABLE IF EXISTS worktrees", []) - .context("failed to drop 'worktrees' table")?; - self.db - .execute("DROP TABLE IF EXISTS semantic_index_config", []) - .context("failed to drop 'semantic_index_config' table")?; - - // Initialize Vector Databasing Tables - self.db.execute( - "CREATE TABLE semantic_index_config ( - version INTEGER NOT NULL - )", - [], - )?; - - self.db.execute( - "INSERT INTO semantic_index_config (version) VALUES (?1)", - params![SEMANTIC_INDEX_VERSION], - )?; - - self.db.execute( - "CREATE TABLE worktrees ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - absolute_path VARCHAR NOT NULL - ); - CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path); - ", - [], - )?; - - self.db.execute( - "CREATE TABLE files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - worktree_id INTEGER NOT NULL, - relative_path VARCHAR NOT NULL, - mtime_seconds INTEGER NOT NULL, - mtime_nanos INTEGER NOT NULL, - FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE - )", - [], - )?; - - self.db.execute( - "CREATE TABLE documents ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - file_id INTEGER NOT NULL, - start_byte INTEGER NOT NULL, - end_byte INTEGER NOT NULL, - name VARCHAR NOT NULL, - embedding BLOB NOT NULL, - sha1 BLOB NOT NULL, - FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE - )", - [], - )?; - - log::trace!("vector database initialized with updated schema."); - Ok(()) } - pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> { - self.db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", - params![worktree_id, delete_path.to_str()], - )?; - Ok(()) + fn initialize_database(&self) -> impl Future> { + self.transact(|db| { + rusqlite::vtab::array::load_module(&db)?; + + // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped + let version_query = db.prepare("SELECT version from semantic_index_config"); + let version = version_query + .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?))); + if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) { + log::trace!("vector database schema up to date"); + return Ok(()); + } + + log::trace!("vector database schema out of date. 
updating..."); + db.execute("DROP TABLE IF EXISTS documents", []) + .context("failed to drop 'documents' table")?; + db.execute("DROP TABLE IF EXISTS files", []) + .context("failed to drop 'files' table")?; + db.execute("DROP TABLE IF EXISTS worktrees", []) + .context("failed to drop 'worktrees' table")?; + db.execute("DROP TABLE IF EXISTS semantic_index_config", []) + .context("failed to drop 'semantic_index_config' table")?; + + // Initialize Vector Databasing Tables + db.execute( + "CREATE TABLE semantic_index_config ( + version INTEGER NOT NULL + )", + [], + )?; + + db.execute( + "INSERT INTO semantic_index_config (version) VALUES (?1)", + params![SEMANTIC_INDEX_VERSION], + )?; + + db.execute( + "CREATE TABLE worktrees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + absolute_path VARCHAR NOT NULL + ); + CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path); + ", + [], + )?; + + db.execute( + "CREATE TABLE files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + worktree_id INTEGER NOT NULL, + relative_path VARCHAR NOT NULL, + mtime_seconds INTEGER NOT NULL, + mtime_nanos INTEGER NOT NULL, + FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE + )", + [], + )?; + + db.execute( + "CREATE TABLE documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL, + start_byte INTEGER NOT NULL, + end_byte INTEGER NOT NULL, + name VARCHAR NOT NULL, + embedding BLOB NOT NULL, + sha1 BLOB NOT NULL, + FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE + )", + [], + )?; + + log::trace!("vector database initialized with updated schema."); + Ok(()) + }) + } + + pub fn delete_file( + &self, + worktree_id: i64, + delete_path: PathBuf, + ) -> impl Future> { + self.transact(move |db| { + db.execute( + "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", + params![worktree_id, delete_path.to_str()], + )?; + Ok(()) + }) } pub fn insert_file( @@ -170,117 +218,126 @@ impl VectorDatabase { path: PathBuf, mtime: SystemTime, documents: Vec, - ) -> Result<()> { - // Return the existing ID, if both the file and mtime match - let mtime = Timestamp::from(mtime); - let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?; - let existing_id = existing_id_query - .query_row( - params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], - |row| Ok(row.get::<_, i64>(0)?), - ) - .map_err(|err| anyhow!(err)); - let file_id = if existing_id.is_ok() { - // If already exists, just return the existing id - existing_id.unwrap() - } else { - // Delete Existing Row - self.db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;", - params![worktree_id, path.to_str()], + ) -> impl Future> { + self.transact(move |db| { + // Return the existing ID, if both the file and mtime match + let mtime = Timestamp::from(mtime); + + let mut existing_id_query = db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?; + let existing_id = existing_id_query + .query_row( + params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], + |row| Ok(row.get::<_, i64>(0)?), + ); + + let file_id = if existing_id.is_ok() { + // If already exists, just return the existing id + existing_id? 
+ } else { + // Delete Existing Row + db.execute( + "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;", + params![worktree_id, path.to_str()], + )?; + db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?; + db.last_insert_rowid() + }; + + // Currently inserting at approximately 3400 documents a second + // I imagine we can speed this up with a bulk insert of some kind. + for document in documents { + let embedding_blob = bincode::serialize(&document.embedding)?; + let sha_blob = bincode::serialize(&document.sha1)?; + + db.execute( + "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + params![ + file_id, + document.range.start.to_string(), + document.range.end.to_string(), + document.name, + embedding_blob, + sha_blob + ], + )?; + } + + Ok(()) + }) + } + + pub fn worktree_previously_indexed( + &self, + worktree_root_path: &Path, + ) -> impl Future> { + let worktree_root_path = worktree_root_path.to_string_lossy().into_owned(); + self.transact(move |db| { + let mut worktree_query = + db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + let worktree_id = worktree_query + .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?)); + + if worktree_id.is_ok() { + return Ok(true); + } else { + return Ok(false); + } + }) + } + + pub fn find_or_create_worktree( + &self, + worktree_root_path: PathBuf, + ) -> impl Future> { + self.transact(move |db| { + let mut worktree_query = + db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + let worktree_id = worktree_query + .query_row(params![worktree_root_path.to_string_lossy()], |row| { + Ok(row.get::<_, i64>(0)?) + }); + + if worktree_id.is_ok() { + return Ok(worktree_id?); + } + + // If worktree_id is Err, insert new worktree + db.execute( + "INSERT into worktrees (absolute_path) VALUES (?1)", + params![worktree_root_path.to_string_lossy()], )?; - self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?; - self.db.last_insert_rowid() - }; + Ok(db.last_insert_rowid()) + }) + } - // Currently inserting at approximately 3400 documents a second - // I imagine we can speed this up with a bulk insert of some kind. - for document in documents { - let embedding_blob = bincode::serialize(&document.embedding)?; - let sha_blob = bincode::serialize(&document.sha1)?; - - self.db.execute( - "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", - params![ - file_id, - document.range.start.to_string(), - document.range.end.to_string(), - document.name, - embedding_blob, - sha_blob - ], + pub fn get_file_mtimes( + &self, + worktree_id: i64, + ) -> impl Future>> { + self.transact(move |db| { + let mut statement = db.prepare( + " + SELECT relative_path, mtime_seconds, mtime_nanos + FROM files + WHERE worktree_id = ?1 + ORDER BY relative_path", )?; - } - - Ok(()) - } - - pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result { - let mut worktree_query = self - .db - .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; - let worktree_id = worktree_query - .query_row(params![worktree_root_path.to_string_lossy()], |row| { - Ok(row.get::<_, i64>(0)?) 
- }) - .map_err(|err| anyhow!(err)); - - if worktree_id.is_ok() { - return Ok(true); - } else { - return Ok(false); - } - } - - pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result { - // Check that the absolute path doesnt exist - let mut worktree_query = self - .db - .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; - - let worktree_id = worktree_query - .query_row(params![worktree_root_path.to_string_lossy()], |row| { - Ok(row.get::<_, i64>(0)?) - }) - .map_err(|err| anyhow!(err)); - - if worktree_id.is_ok() { - return worktree_id; - } - - // If worktree_id is Err, insert new worktree - self.db.execute( - " - INSERT into worktrees (absolute_path) VALUES (?1) - ", - params![worktree_root_path.to_string_lossy()], - )?; - Ok(self.db.last_insert_rowid()) - } - - pub fn get_file_mtimes(&self, worktree_id: i64) -> Result> { - let mut statement = self.db.prepare( - " - SELECT relative_path, mtime_seconds, mtime_nanos - FROM files - WHERE worktree_id = ?1 - ORDER BY relative_path", - )?; - let mut result: HashMap = HashMap::new(); - for row in statement.query_map(params![worktree_id], |row| { - Ok(( - row.get::<_, String>(0)?.into(), - Timestamp { - seconds: row.get(1)?, - nanos: row.get(2)?, - } - .into(), - )) - })? { - let row = row?; - result.insert(row.0, row.1); - } - Ok(result) + let mut result: HashMap = HashMap::new(); + for row in statement.query_map(params![worktree_id], |row| { + Ok(( + row.get::<_, String>(0)?.into(), + Timestamp { + seconds: row.get(1)?, + nanos: row.get(2)?, + } + .into(), + )) + })? { + let row = row?; + result.insert(row.0, row.1); + } + Ok(result) + }) } pub fn top_k_search( @@ -288,21 +345,25 @@ impl VectorDatabase { query_embedding: &Vec, limit: usize, file_ids: &[i64], - ) -> Result> { - let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - self.for_each_document(file_ids, |id, embedding| { - let similarity = dot(&embedding, &query_embedding); - let ix = match results - .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)) - { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - })?; + ) -> impl Future>> { + let query_embedding = query_embedding.clone(); + let file_ids = file_ids.to_vec(); + self.transact(move |db| { + let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); + Self::for_each_document(db, &file_ids, |id, embedding| { + let similarity = dot(&embedding, &query_embedding); + let ix = match results.binary_search_by(|(_, s)| { + similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + })?; - Ok(results) + anyhow::Ok(results) + }) } pub fn retrieve_included_file_ids( @@ -310,37 +371,46 @@ impl VectorDatabase { worktree_ids: &[i64], includes: &[PathMatcher], excludes: &[PathMatcher], - ) -> Result> { - let mut file_query = self.db.prepare( - " - SELECT - id, relative_path - FROM - files - WHERE - worktree_id IN rarray(?) - ", - )?; + ) -> impl Future>> { + let worktree_ids = worktree_ids.to_vec(); + let includes = includes.to_vec(); + let excludes = excludes.to_vec(); + self.transact(move |db| { + let mut file_query = db.prepare( + " + SELECT + id, relative_path + FROM + files + WHERE + worktree_id IN rarray(?) 
+ ", + )?; - let mut file_ids = Vec::::new(); - let mut rows = file_query.query([ids_to_sql(worktree_ids)])?; + let mut file_ids = Vec::::new(); + let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?; - while let Some(row) = rows.next()? { - let file_id = row.get(0)?; - let relative_path = row.get_ref(1)?.as_str()?; - let included = - includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path)); - let excluded = excludes.iter().any(|glob| glob.is_match(relative_path)); - if included && !excluded { - file_ids.push(file_id); + while let Some(row) = rows.next()? { + let file_id = row.get(0)?; + let relative_path = row.get_ref(1)?.as_str()?; + let included = + includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path)); + let excluded = excludes.iter().any(|glob| glob.is_match(relative_path)); + if included && !excluded { + file_ids.push(file_id); + } } - } - Ok(file_ids) + anyhow::Ok(file_ids) + }) } - fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec)) -> Result<()> { - let mut query_statement = self.db.prepare( + fn for_each_document( + db: &rusqlite::Connection, + file_ids: &[i64], + mut f: impl FnMut(i64, Vec), + ) -> Result<()> { + let mut query_statement = db.prepare( " SELECT id, embedding @@ -360,47 +430,53 @@ impl VectorDatabase { Ok(()) } - pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result)>> { - let mut statement = self.db.prepare( - " - SELECT - documents.id, - files.worktree_id, - files.relative_path, - documents.start_byte, - documents.end_byte - FROM - documents, files - WHERE - documents.file_id = files.id AND - documents.id in rarray(?) - ", - )?; + pub fn get_documents_by_ids( + &self, + ids: &[i64], + ) -> impl Future)>>> { + let ids = ids.to_vec(); + self.transact(move |db| { + let mut statement = db.prepare( + " + SELECT + documents.id, + files.worktree_id, + files.relative_path, + documents.start_byte, + documents.end_byte + FROM + documents, files + WHERE + documents.file_id = files.id AND + documents.id in rarray(?) 
+ ", + )?; - let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| { - Ok(( - row.get::<_, i64>(0)?, - row.get::<_, i64>(1)?, - row.get::<_, String>(2)?.into(), - row.get(3)?..row.get(4)?, - )) - })?; + let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| { + Ok(( + row.get::<_, i64>(0)?, + row.get::<_, i64>(1)?, + row.get::<_, String>(2)?.into(), + row.get(3)?..row.get(4)?, + )) + })?; - let mut values_by_id = HashMap::)>::default(); - for row in result_iter { - let (id, worktree_id, path, range) = row?; - values_by_id.insert(id, (worktree_id, path, range)); - } + let mut values_by_id = HashMap::)>::default(); + for row in result_iter { + let (id, worktree_id, path, range) = row?; + values_by_id.insert(id, (worktree_id, path, range)); + } - let mut results = Vec::with_capacity(ids.len()); - for id in ids { - let value = values_by_id - .remove(id) - .ok_or(anyhow!("missing document id {}", id))?; - results.push(value); - } + let mut results = Vec::with_capacity(ids.len()); + for id in &ids { + let value = values_by_id + .remove(id) + .ok_or(anyhow!("missing document id {}", id))?; + results.push(value); + } - Ok(results) + Ok(results) + }) } } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index cde53182dc..7a0985b273 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -12,11 +12,10 @@ use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; -use futures::{channel::oneshot, Future}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; -use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES}; +use parsing::{CodeContextRetriever, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; use project::{ search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, @@ -101,13 +100,11 @@ pub fn init( pub struct SemanticIndex { fs: Arc, - database_url: Arc, + db: VectorDatabase, embedding_provider: Arc, language_registry: Arc, - db_update_tx: channel::Sender, parsing_files_tx: channel::Sender, _embedding_task: Task<()>, - _db_update_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, } @@ -203,32 +200,6 @@ pub struct SearchResult { pub range: Range, } -enum DbOperation { - InsertFile { - worktree_id: i64, - documents: Vec, - path: PathBuf, - mtime: SystemTime, - job_handle: JobHandle, - }, - Delete { - worktree_id: i64, - path: PathBuf, - }, - FindOrCreateWorktree { - path: PathBuf, - sender: oneshot::Sender>, - }, - FileMTimes { - worktree_id: i64, - sender: oneshot::Sender>>, - }, - WorktreePreviouslyIndexed { - path: Arc, - sender: oneshot::Sender>, - }, -} - impl SemanticIndex { pub fn global(cx: &AppContext) -> Option> { if cx.has_global::>() { @@ -245,18 +216,14 @@ impl SemanticIndex { async fn new( fs: Arc, - database_url: PathBuf, + database_path: PathBuf, embedding_provider: Arc, language_registry: Arc, mut cx: AsyncAppContext, ) -> Result> { let t0 = Instant::now(); - let database_url = Arc::new(database_url); - - let db = cx - .background() - .spawn(VectorDatabase::new(fs.clone(), database_url.clone())) - .await?; + let database_path = Arc::from(database_path); + let db = VectorDatabase::new(fs.clone(), database_path, 
cx.background()).await?; log::trace!( "db initialization took {:?} milliseconds", @@ -265,32 +232,16 @@ impl SemanticIndex { Ok(cx.add_model(|cx| { let t0 = Instant::now(); - // Perform database operations - let (db_update_tx, db_update_rx) = channel::unbounded(); - let _db_update_task = cx.background().spawn({ - async move { - while let Ok(job) = db_update_rx.recv().await { - Self::run_db_operation(&db, job) - } - } - }); - let embedding_queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); let _embedding_task = cx.background().spawn({ let embedded_files = embedding_queue.finished_files(); - let db_update_tx = db_update_tx.clone(); + let db = db.clone(); async move { while let Ok(file) = embedded_files.recv().await { - db_update_tx - .try_send(DbOperation::InsertFile { - worktree_id: file.worktree_id, - documents: file.documents, - path: file.path, - mtime: file.mtime, - job_handle: file.job_handle, - }) - .ok(); + db.insert_file(file.worktree_id, file.path, file.mtime, file.documents) + .await + .log_err(); } } }); @@ -325,12 +276,10 @@ impl SemanticIndex { ); Self { fs, - database_url, + db, embedding_provider, language_registry, - db_update_tx, parsing_files_tx, - _db_update_task, _embedding_task, _parsing_files_tasks, projects: HashMap::new(), @@ -338,40 +287,6 @@ impl SemanticIndex { })) } - fn run_db_operation(db: &VectorDatabase, job: DbOperation) { - match job { - DbOperation::InsertFile { - worktree_id, - documents, - path, - mtime, - job_handle, - } => { - db.insert_file(worktree_id, path, mtime, documents) - .log_err(); - drop(job_handle) - } - DbOperation::Delete { worktree_id, path } => { - db.delete_file(worktree_id, path).log_err(); - } - DbOperation::FindOrCreateWorktree { path, sender } => { - let id = db.find_or_create_worktree(&path); - sender.send(id).ok(); - } - DbOperation::FileMTimes { - worktree_id: worktree_db_id, - sender, - } => { - let file_mtimes = db.get_file_mtimes(worktree_db_id); - sender.send(file_mtimes).ok(); - } - DbOperation::WorktreePreviouslyIndexed { path, sender } => { - let worktree_indexed = db.worktree_previously_indexed(path.as_ref()); - sender.send(worktree_indexed).ok(); - } - } - } - async fn parse_file( fs: &Arc, pending_file: PendingFile, @@ -409,36 +324,6 @@ impl SemanticIndex { } } - fn find_or_create_worktree(&self, path: PathBuf) -> impl Future> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx }) - .unwrap(); - async move { rx.await? } - } - - fn get_file_mtimes( - &self, - worktree_id: i64, - ) -> impl Future>> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::FileMTimes { - worktree_id, - sender: tx, - }) - .unwrap(); - async move { rx.await? } - } - - fn worktree_previously_indexed(&self, path: Arc) -> impl Future> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx }) - .unwrap(); - async move { rx.await? 
} - } - pub fn project_previously_indexed( &mut self, project: ModelHandle, @@ -447,7 +332,10 @@ impl SemanticIndex { let worktrees_indexed_previously = project .read(cx) .worktrees(cx) - .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path())) + .map(|worktree| { + self.db + .worktree_previously_indexed(&worktree.read(cx).abs_path()) + }) .collect::>(); cx.spawn(|_, _cx| async move { let worktree_indexed_previously = @@ -528,7 +416,8 @@ impl SemanticIndex { .read(cx) .worktrees(cx) .map(|worktree| { - self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf()) + self.db + .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf()) }) .collect::>(); @@ -559,7 +448,7 @@ impl SemanticIndex { db_ids_by_worktree_id.insert(worktree.id(), db_id); worktree_file_mtimes.insert( worktree.id(), - this.read_with(&cx, |this, _| this.get_file_mtimes(db_id)) + this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id)) .await?, ); } @@ -704,11 +593,12 @@ impl SemanticIndex { .collect::>(); let embedding_provider = self.embedding_provider.clone(); - let database_url = self.database_url.clone(); + let db_path = self.db.path().clone(); let fs = self.fs.clone(); cx.spawn(|this, mut cx| async move { let t0 = Instant::now(); - let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; + let database = + VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?; let phrase_embedding = embedding_provider .embed_batch(vec![phrase]) @@ -722,8 +612,9 @@ impl SemanticIndex { t0.elapsed().as_millis() ); - let file_ids = - database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?; + let file_ids = database + .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes) + .await?; let batch_n = cx.background().num_cpus(); let ids_len = file_ids.clone().len(); @@ -733,27 +624,24 @@ impl SemanticIndex { ids_len / batch_n }; - let mut result_tasks = Vec::new(); + let mut batch_results = Vec::new(); for batch in file_ids.chunks(batch_size) { let batch = batch.into_iter().map(|v| *v).collect::>(); let limit = limit.clone(); let fs = fs.clone(); - let database_url = database_url.clone(); + let db_path = db_path.clone(); let phrase_embedding = phrase_embedding.clone(); - let task = cx.background().spawn(async move { - let database = VectorDatabase::new(fs, database_url).await.log_err(); - if database.is_none() { - return Err(anyhow!("failed to acquire database connection")); - } else { - database - .unwrap() - .top_k_search(&phrase_embedding, limit, batch.as_slice()) - } - }); - result_tasks.push(task); + if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background()) + .await + .log_err() + { + batch_results.push(async move { + db.top_k_search(&phrase_embedding, limit, batch.as_slice()) + .await + }); + } } - - let batch_results = futures::future::join_all(result_tasks).await; + let batch_results = futures::future::join_all(batch_results).await; let mut results = Vec::new(); for batch_result in batch_results { @@ -772,7 +660,7 @@ impl SemanticIndex { } let ids = results.into_iter().map(|(id, _)| id).collect::>(); - let documents = database.get_documents_by_ids(ids.as_slice())?; + let documents = database.get_documents_by_ids(ids.as_slice()).await?; let mut tasks = Vec::new(); let mut ranges = Vec::new(); @@ -822,7 +710,8 @@ impl SemanticIndex { cx: &mut AsyncAppContext, ) { let mut pending_files = Vec::new(); - let (language_registry, parsing_files_tx) = this.update(cx, |this, cx| { + let mut files_to_delete = 
Vec::new(); + let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| { if let Some(project_state) = this.projects.get_mut(&project.downgrade()) { let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; let db_ids = &project_state.worktree_db_ids; @@ -853,12 +742,7 @@ impl SemanticIndex { }; if info.is_deleted { - this.db_update_tx - .try_send(DbOperation::Delete { - worktree_id: worktree_db_id, - path: path.path.to_path_buf(), - }) - .ok(); + files_to_delete.push((worktree_db_id, path.path.to_path_buf())); } else { let absolute_path = worktree.read(cx).absolutize(&path.path); let job_handle = JobHandle::new(&outstanding_job_count_tx); @@ -877,11 +761,16 @@ impl SemanticIndex { } ( + this.db.clone(), this.language_registry.clone(), this.parsing_files_tx.clone(), ) }); + for (worktree_db_id, path) in files_to_delete { + db.delete_file(worktree_db_id, path).await.log_err(); + } + for mut pending_file in pending_files { if let Ok(language) = language_registry .language_for_file(&pending_file.relative_path, None) From c763e728d12b413d27ae9f1477026dc82c0cf002 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 16:59:54 +0200 Subject: [PATCH 10/20] Write to and read from the database in a transactional way Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 652c2819ce..313df40674 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -58,7 +58,8 @@ impl FromSql for Sha1 { #[derive(Clone)] pub struct VectorDatabase { path: Arc, - transactions: smol::channel::Sender>, + transactions: + smol::channel::Sender>, } impl VectorDatabase { @@ -71,15 +72,16 @@ impl VectorDatabase { fs.create_dir(db_directory).await?; } - let (transactions_tx, transactions_rx) = - smol::channel::unbounded::>(); + let (transactions_tx, transactions_rx) = smol::channel::unbounded::< + Box, + >(); executor .spawn({ let path = path.clone(); async move { - let connection = rusqlite::Connection::open(&path)?; + let mut connection = rusqlite::Connection::open(&path)?; while let Ok(transaction) = transactions_rx.recv().await { - transaction(&connection); + transaction(&mut connection); } anyhow::Ok(()) @@ -99,9 +101,9 @@ impl VectorDatabase { &self.path } - fn transact(&self, transaction: F) -> impl Future> + fn transact(&self, f: F) -> impl Future> where - F: 'static + Send + FnOnce(&rusqlite::Connection) -> Result, + F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result, T: 'static + Send, { let (tx, rx) = oneshot::channel(); @@ -109,7 +111,14 @@ impl VectorDatabase { async move { if transactions .send(Box::new(|connection| { - let result = transaction(connection); + let result = connection + .transaction() + .map_err(|err| anyhow!(err)) + .and_then(|transaction| { + let result = f(&transaction)?; + transaction.commit()?; + Ok(result) + }); let _ = tx.send(result); })) .await From 3001a46f6995cd900cae7bf633605dc0fb1334e4 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 17:55:43 +0200 Subject: [PATCH 11/20] Reify `Embedding`/`Sha1` structs that can be (de)serialized from SQL Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 76 ++---------- crates/semantic_index/src/embedding.rs | 114 +++++++++++++++++- crates/semantic_index/src/embedding_queue.rs | 2 +- crates/semantic_index/src/parsing.rs | 69 +++++++---- 
.../src/semantic_index_tests.rs | 57 +++------ 5 files changed, 180 insertions(+), 138 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 313df40674..81b05720d2 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,13 +1,10 @@ -use crate::{parsing::Document, SEMANTIC_INDEX_VERSION}; +use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION}; use anyhow::{anyhow, Context, Result}; use futures::channel::oneshot; use gpui::executor; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; -use rusqlite::{ - params, - types::{FromSql, FromSqlResult, ValueRef}, -}; +use rusqlite::params; use std::{ cmp::Ordering, collections::HashMap, @@ -27,34 +24,6 @@ pub struct FileRecord { pub mtime: Timestamp, } -#[derive(Debug)] -struct Embedding(pub Vec); - -#[derive(Debug)] -struct Sha1(pub Vec); - -impl FromSql for Embedding { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - let embedding: Result, Box> = bincode::deserialize(bytes); - if embedding.is_err() { - return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); - } - return Ok(Embedding(embedding.unwrap())); - } -} - -impl FromSql for Sha1 { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - let sha1: Result, Box> = bincode::deserialize(bytes); - if sha1.is_err() { - return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err())); - } - return Ok(Sha1(sha1.unwrap())); - } -} - #[derive(Clone)] pub struct VectorDatabase { path: Arc, @@ -255,9 +224,6 @@ impl VectorDatabase { // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. for document in documents { - let embedding_blob = bincode::serialize(&document.embedding)?; - let sha_blob = bincode::serialize(&document.sha1)?; - db.execute( "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", params![ @@ -265,8 +231,8 @@ impl VectorDatabase { document.range.start.to_string(), document.range.end.to_string(), document.name, - embedding_blob, - sha_blob + document.embedding, + document.sha1 ], )?; } @@ -351,7 +317,7 @@ impl VectorDatabase { pub fn top_k_search( &self, - query_embedding: &Vec, + query_embedding: &Embedding, limit: usize, file_ids: &[i64], ) -> impl Future>> { @@ -360,7 +326,7 @@ impl VectorDatabase { self.transact(move |db| { let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); Self::for_each_document(db, &file_ids, |id, embedding| { - let similarity = dot(&embedding, &query_embedding); + let similarity = embedding.similarity(&query_embedding); let ix = match results.binary_search_by(|(_, s)| { similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) }) { @@ -417,7 +383,7 @@ impl VectorDatabase { fn for_each_document( db: &rusqlite::Connection, file_ids: &[i64], - mut f: impl FnMut(i64, Vec), + mut f: impl FnMut(i64, Embedding), ) -> Result<()> { let mut query_statement = db.prepare( " @@ -435,7 +401,7 @@ impl VectorDatabase { Ok((row.get(0)?, row.get::<_, Embedding>(1)?)) })? 
.filter_map(|row| row.ok()) - .for_each(|(id, embedding)| f(id, embedding.0)); + .for_each(|(id, embedding)| f(id, embedding)); Ok(()) } @@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc> { .collect::>(), ) } - -pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 { - let len = vec_a.len(); - assert_eq!(len, vec_b.len()); - - let mut result = 0.0; - unsafe { - matrixmultiply::sgemm( - 1, - len, - 1, - 1.0, - vec_a.as_ptr(), - len as isize, - 1, - vec_b.as_ptr(), - 1, - len as isize, - 0.0, - &mut result as *mut f32, - 1, - 1, - ); - } - result -} diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 60e13a9e01..97c25ca170 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -8,6 +8,8 @@ use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; use parse_duration::parse; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; +use rusqlite::ToSql; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; @@ -20,6 +22,62 @@ lazy_static! { static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(Vec); + +impl From> for Embedding { + fn from(value: Vec) -> Self { + Embedding(value) + } +} + +impl Embedding { + pub fn similarity(&self, other: &Self) -> f32 { + let len = self.0.len(); + assert_eq!(len, other.0.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + self.0.as_ptr(), + len as isize, + 1, + other.0.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + result + } +} + +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + Ok(Embedding(embedding.unwrap())) + } +} + +impl ToSql for Embedding { + fn to_sql(&self) -> rusqlite::Result { + let bytes = bincode::serialize(&self.0) + .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; + Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) + } +} + #[derive(Clone)] pub struct OpenAIEmbeddings { pub client: Arc, @@ -53,7 +111,7 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { - async fn embed_batch(&self, spans: Vec) -> Result>>; + async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn truncate(&self, span: &str) -> (String, usize); } @@ -62,10 +120,10 @@ pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. 
- let dummy_vec = vec![0.32 as f32; 1536]; + let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); return Ok(vec![dummy_vec; spans.len()]); } @@ -137,7 +195,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { (output, token_count) } - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -175,7 +233,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { return Ok(response .data .into_iter() - .map(|embedding| embedding.embedding) + .map(|embedding| Embedding::from(embedding.embedding)) .collect()); } StatusCode::TOO_MANY_REQUESTS => { @@ -218,3 +276,49 @@ impl EmbeddingProvider for OpenAIEmbeddings { Err(anyhow!("openai max retries")) } } + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[gpui::test] + fn test_similarity(mut rng: StdRng) { + assert_eq!( + Embedding::from(vec![1., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])), + 0. + ); + assert_eq!( + Embedding::from(vec![2., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])), + 6. + ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() + } + } +} diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index c3a5de1373..4c82ced918 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -121,7 +121,7 @@ impl EmbeddingQueue { &mut fragment.file.lock().documents[fragment.document_range.clone()] { if let Some(embedding) = embeddings.next() { - document.embedding = embedding; + document.embedding = Some(embedding); } else { // log::error!("number of embeddings returned different from number of documents"); diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 51f1bd7ca9..2b67f41714 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,7 +1,11 @@ -use crate::embedding::EmbeddingProvider; -use anyhow::{anyhow, Ok, Result}; +use crate::embedding::{EmbeddingProvider, Embedding}; +use anyhow::{anyhow, Result}; use language::{Grammar, Language}; -use sha1::{Digest, Sha1}; +use rusqlite::{ + types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, + ToSql, +}; +use sha1::Digest; use std::{ cmp::{self, Reverse}, collections::HashSet, @@ -11,13 +15,43 @@ use std::{ }; use tree_sitter::{Parser, QueryCursor}; +#[derive(Debug, PartialEq, Clone)] +pub struct Sha1([u8; 20]); + +impl FromSql for Sha1 { + fn column_result(value: ValueRef) -> FromSqlResult { + let blob = value.as_blob()?; + let bytes = + blob.try_into() + .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize { + expected_size: 20, + blob_size: blob.len(), + })?; + return Ok(Sha1(bytes)); + } +} + +impl ToSql for Sha1 { + fn to_sql(&self) -> rusqlite::Result { + self.0.to_sql() + } +} + +impl From<&'_ str> for Sha1 { + fn from(value: 
&'_ str) -> Self { + let mut sha1 = sha1::Sha1::new(); + sha1.update(value); + Self(sha1.finalize().into()) + } +} + #[derive(Debug, PartialEq, Clone)] pub struct Document { pub name: String, pub range: Range, pub content: String, - pub embedding: Vec, - pub sha1: [u8; 20], + pub embedding: Option, + pub sha1: Sha1, pub token_count: usize, } @@ -69,17 +103,16 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("", &content); - let mut sha1 = Sha1::new(); - sha1.update(&document_span); + let sha1 = Sha1::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), content: document_span, - embedding: Vec::new(), + embedding: Default::default(), name: language_name.to_string(), - sha1: sha1.finalize().into(), + sha1, token_count, }]) } @@ -88,18 +121,14 @@ impl CodeContextRetriever { let document_span = MARKDOWN_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", &content); - - let mut sha1 = Sha1::new(); - sha1.update(&document_span); - + let sha1 = Sha1::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); - Ok(vec![Document { range: 0..content.len(), content: document_span, - embedding: Vec::new(), + embedding: None, name: "Markdown".to_string(), - sha1: sha1.finalize().into(), + sha1, token_count, }]) } @@ -279,15 +308,13 @@ impl CodeContextRetriever { ); } - let mut sha1 = Sha1::new(); - sha1.update(&document_content); - + let sha1 = Sha1::from(document_content.as_str()); documents.push(Document { name, content: document_content, range: item_range.clone(), - embedding: vec![], - sha1: sha1.finalize().into(), + embedding: None, + sha1, token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index dc41c09f7a..75232eb4d2 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,8 +1,7 @@ use crate::{ - db::dot, - embedding::{DummyEmbeddings, EmbeddingProvider}, + embedding::{DummyEmbeddings, Embedding, EmbeddingProvider}, embedding_queue::EmbeddingQueue, - parsing::{subtract_ranges, CodeContextRetriever, Document}, + parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1}, semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, }; @@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { documents: (0..rng.gen_range(4..22)) .map(|document_ix| { let content_len = rng.gen_range(10..100); + let content = RandomCharIter::new(&mut rng) + .with_simple_text() + .take(content_len) + .collect::(); + let sha1 = Sha1::from(content.as_str()); Document { range: 0..10, - embedding: Vec::new(), + embedding: None, name: format!("document {document_ix}"), - content: RandomCharIter::new(&mut rng) - .with_simple_text() - .take(content_len) - .collect(), - sha1: rng.gen(), + content, + sha1, token_count: rng.gen_range(10..30), } }) @@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .map(|file| { let mut file = file.clone(); for doc in &mut file.documents { - doc.embedding = embedding_provider.embed_sync(doc.content.as_ref()); + doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref())); } file }) @@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() { ); } -#[gpui::test] -fn 
test_dot_product(mut rng: StdRng) { - assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.); - assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.); - - for _ in 0..100 { - let size = 1536; - let mut a = vec![0.; size]; - let mut b = vec![0.; size]; - for (a, b) in a.iter_mut().zip(b.iter_mut()) { - *a = rng.gen(); - *b = rng.gen(); - } - - assert_eq!( - round_to_decimals(dot(&a, &b), 1), - round_to_decimals(reference_dot(&a, &b), 1) - ); - } - - fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { - let factor = (10.0 as f32).powi(decimal_places); - (n * factor).round() / factor - } - - fn reference_dot(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() - } -} - #[derive(Default)] struct FakeEmbeddingProvider { embedding_count: AtomicUsize, @@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider { self.embedding_count.load(atomic::Ordering::SeqCst) } - fn embed_sync(&self, span: &str) -> Vec { + fn embed_sync(&self, span: &str) -> Embedding { let mut result = vec![1.0; 26]; for letter in span.chars() { let letter = letter.to_ascii_lowercase(); @@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider { *x /= norm; } - result + result.into() } } @@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider { 200 } - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) From 2503d54d1957f9b34c64af54b2a6d2e0e712ac13 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 18:00:36 +0200 Subject: [PATCH 12/20] Rename `Sha1` to `DocumentDigest` Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 12 ++++--- crates/semantic_index/src/parsing.rs | 35 +++++++++---------- crates/semantic_index/src/semantic_index.rs | 2 +- .../src/semantic_index_tests.rs | 6 ++-- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 81b05720d2..375934e7fe 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,4 +1,8 @@ -use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION}; +use crate::{ + embedding::Embedding, + parsing::{Document, DocumentDigest}, + SEMANTIC_INDEX_VERSION, +}; use anyhow::{anyhow, Context, Result}; use futures::channel::oneshot; use gpui::executor; @@ -165,7 +169,7 @@ impl VectorDatabase { end_byte INTEGER NOT NULL, name VARCHAR NOT NULL, embedding BLOB NOT NULL, - sha1 BLOB NOT NULL, + digest BLOB NOT NULL, FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE )", [], @@ -225,14 +229,14 @@ impl VectorDatabase { // I imagine we can speed this up with a bulk insert of some kind. 
for document in documents { db.execute( - "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, digest) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", params![ file_id, document.range.start.to_string(), document.range.end.to_string(), document.name, document.embedding, - document.sha1 + document.digest ], )?; } diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 2b67f41714..c0a94c6b73 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,11 +1,11 @@ -use crate::embedding::{EmbeddingProvider, Embedding}; +use crate::embedding::{Embedding, EmbeddingProvider}; use anyhow::{anyhow, Result}; use language::{Grammar, Language}; use rusqlite::{ types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, ToSql, }; -use sha1::Digest; +use sha1::{Digest, Sha1}; use std::{ cmp::{self, Reverse}, collections::HashSet, @@ -15,10 +15,10 @@ use std::{ }; use tree_sitter::{Parser, QueryCursor}; -#[derive(Debug, PartialEq, Clone)] -pub struct Sha1([u8; 20]); +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct DocumentDigest([u8; 20]); -impl FromSql for Sha1 { +impl FromSql for DocumentDigest { fn column_result(value: ValueRef) -> FromSqlResult { let blob = value.as_blob()?; let bytes = @@ -27,19 +27,19 @@ impl FromSql for Sha1 { expected_size: 20, blob_size: blob.len(), })?; - return Ok(Sha1(bytes)); + return Ok(DocumentDigest(bytes)); } } -impl ToSql for Sha1 { +impl ToSql for DocumentDigest { fn to_sql(&self) -> rusqlite::Result { self.0.to_sql() } } -impl From<&'_ str> for Sha1 { +impl From<&'_ str> for DocumentDigest { fn from(value: &'_ str) -> Self { - let mut sha1 = sha1::Sha1::new(); + let mut sha1 = Sha1::new(); sha1.update(value); Self(sha1.finalize().into()) } @@ -51,7 +51,7 @@ pub struct Document { pub range: Range, pub content: String, pub embedding: Option, - pub sha1: Sha1, + pub digest: DocumentDigest, pub token_count: usize, } @@ -102,17 +102,14 @@ impl CodeContextRetriever { .replace("", relative_path.to_string_lossy().as_ref()) .replace("", language_name.as_ref()) .replace("", &content); - - let sha1 = Sha1::from(document_span.as_str()); - + let digest = DocumentDigest::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); - Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: Default::default(), name: language_name.to_string(), - sha1, + digest, token_count, }]) } @@ -121,14 +118,14 @@ impl CodeContextRetriever { let document_span = MARKDOWN_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", &content); - let sha1 = Sha1::from(document_span.as_str()); + let digest = DocumentDigest::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: None, name: "Markdown".to_string(), - sha1, + digest, token_count, }]) } @@ -308,13 +305,13 @@ impl CodeContextRetriever { ); } - let sha1 = Sha1::from(document_content.as_str()); + let sha1 = DocumentDigest::from(document_content.as_str()); documents.push(Document { name, content: document_content, range: item_range.clone(), embedding: None, - sha1, + digest: sha1, token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 
7a0985b273..0a9a808a64 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -37,7 +37,7 @@ use util::{ }; use workspace::WorkspaceCreated; -const SEMANTIC_INDEX_VERSION: usize = 7; +const SEMANTIC_INDEX_VERSION: usize = 8; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); pub fn init( diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 75232eb4d2..e65bc04412 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,7 +1,7 @@ use crate::{ embedding::{DummyEmbeddings, Embedding, EmbeddingProvider}, embedding_queue::EmbeddingQueue, - parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1}, + parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest}, semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, }; @@ -220,13 +220,13 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .with_simple_text() .take(content_len) .collect::(); - let sha1 = Sha1::from(content.as_str()); + let digest = DocumentDigest::from(content.as_str()); Document { range: 0..10, embedding: None, name: format!("document {document_ix}"), content, - sha1, + digest, token_count: rng.gen_range(10..30), } }) From 220533ff1abf46066853eae31c11eb17b219554d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 18:00:57 +0200 Subject: [PATCH 13/20] WIP --- crates/semantic_index/src/db.rs | 19 +++++++++++++++++++ crates/semantic_index/src/semantic_index.rs | 17 +++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 375934e7fe..134a70972f 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -264,6 +264,25 @@ impl VectorDatabase { }) } + pub fn embeddings_for_file( + &self, + worktree_id: i64, + relative_path: PathBuf, + ) -> impl Future>> { + let relative_path = relative_path.to_string_lossy().into_owned(); + self.transact(move |db| { + let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE files.worktree_id = ?1 AND files.relative_path = ?2")?; + let mut result: HashMap = HashMap::new(); + for row in query.query_map(params![worktree_id, relative_path], |row| { + Ok((row.get::<_, DocumentDigest>(0)?.into(), row.get::<_, Embedding>(1)?.into())) + })? 
{ + let row = row?; + result.insert(row.0, row.1); + } + Ok(result) + }) + } + pub fn find_or_create_worktree( &self, worktree_root_path: PathBuf, diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 0a9a808a64..58166c1a22 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -309,6 +309,23 @@ impl SemanticIndex { documents.len() ); + todo!(); + // if let Some(embeddings) = db + // .embeddings_for_documents( + // pending_file.worktree_db_id, + // pending_file.relative_path, + // &documents, + // ) + // .await + // .log_err() + // { + // for (document, embedding) in documents.iter_mut().zip(embeddings) { + // if let Some(embedding) = embedding { + // document.embedding = embedding; + // } + // } + // } + embedding_queue.lock().push(FileToEmbed { worktree_id: pending_file.worktree_db_id, path: pending_file.relative_path, From 50cfb067e7c536636ed5bf7e119968d50843b287 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 31 Aug 2023 13:19:17 -0400 Subject: [PATCH 14/20] fill embeddings with database values and skip them in the embedding queue --- crates/semantic_index/src/embedding_queue.rs | 34 ++++++++++++++++--- crates/semantic_index/src/semantic_index.rs | 35 ++++++++++---------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 4c82ced918..96493fc4d3 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -42,6 +42,7 @@ pub struct EmbeddingQueue { finished_files_rx: channel::Receiver, } +#[derive(Clone)] pub struct FileToEmbedFragment { file: Arc>, document_range: Range, @@ -74,8 +75,16 @@ impl EmbeddingQueue { }); let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + let mut saved_tokens = 0; for (ix, document) in file.lock().documents.iter().enumerate() { - let next_token_count = self.pending_batch_token_count + document.token_count; + let document_token_count = if document.embedding.is_none() { + document.token_count + } else { + saved_tokens += document.token_count; + 0 + }; + + let next_token_count = self.pending_batch_token_count + document_token_count; if next_token_count > self.embedding_provider.max_tokens_per_batch() { let range_end = fragment_range.end; self.flush(); @@ -87,8 +96,9 @@ impl EmbeddingQueue { } fragment_range.end = ix + 1; - self.pending_batch_token_count += document.token_count; + self.pending_batch_token_count += document_token_count; } + log::trace!("Saved Tokens: {:?}", saved_tokens); } pub fn flush(&mut self) { @@ -100,25 +110,41 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); + self.executor.spawn(async move { let mut spans = Vec::new(); + let mut document_count = 0; for fragment in &batch { let file = fragment.file.lock(); + document_count += file.documents[fragment.document_range.clone()].len(); spans.extend( { file.documents[fragment.document_range.clone()] - .iter() + .iter().filter(|d| d.embedding.is_none()) .map(|d| d.content.clone()) } ); } + log::trace!("Documents Length: {:?}", document_count); + log::trace!("Span Length: {:?}", spans.clone().len()); + + // If there are no spans left to embed, just send the fragment to the finished files channel if it's the last one.
+ if spans.len() == 0 { + for fragment in batch.clone() { + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + return; + }; + match embedding_provider.embed_batch(spans).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { for document in - &mut fragment.file.lock().documents[fragment.document_range.clone()] + &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none()) { if let Some(embedding) = embeddings.next() { document.embedding = Some(embedding); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 58166c1a22..726b04583a 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -255,6 +255,7 @@ impl SemanticIndex { let parsing_files_rx = parsing_files_rx.clone(); let embedding_provider = embedding_provider.clone(); let embedding_queue = embedding_queue.clone(); + let db = db.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { @@ -264,6 +265,7 @@ impl SemanticIndex { &mut retriever, &embedding_queue, &parsing_files_rx, + &db, ) .await; } @@ -293,13 +295,14 @@ impl SemanticIndex { retriever: &mut CodeContextRetriever, embedding_queue: &Arc>, parsing_files_rx: &channel::Receiver, + db: &VectorDatabase, ) { let Some(language) = pending_file.language else { return; }; if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { - if let Some(documents) = retriever + if let Some(mut documents) = retriever .parse_file_with_template(&pending_file.relative_path, &content, language) .log_err() { @@ -309,22 +312,20 @@ impl SemanticIndex { documents.len() ); - todo!(); - // if let Some(embeddings) = db - // .embeddings_for_documents( - // pending_file.worktree_db_id, - // pending_file.relative_path, - // &documents, - // ) - // .await - // .log_err() - // { - // for (document, embedding) in documents.iter_mut().zip(embeddings) { - // if let Some(embedding) = embedding { - // document.embedding = embedding; - // } - // } - // } + if let Some(sha_to_embeddings) = db + .embeddings_for_file( + pending_file.worktree_db_id, + pending_file.relative_path.clone(), + ) + .await + .log_err() + { + for document in documents.iter_mut() { + if let Some(embedding) = sha_to_embeddings.get(&document.digest) { + document.embedding = Some(embedding.to_owned()); + } + } + } embedding_queue.lock().push(FileToEmbed { worktree_id: pending_file.worktree_db_id, path: pending_file.relative_path, From afa59abbcd8a6208a844227e122e0e439e50bfda Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 31 Aug 2023 16:42:39 -0400 Subject: [PATCH 15/20] WIP: work towards wiring up an embeddings_for_digest hashmap that is stored for all indexed files --- crates/semantic_index/src/db.rs | 36 ++++++++ crates/semantic_index/src/semantic_index.rs | 91 +++++++++++++++------ 2 files changed, 104 insertions(+), 23 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 134a70972f..4a953a2866 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -9,6 +9,7 @@ use gpui::executor; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; use rusqlite::params; +use rusqlite::types::Value; use std::{ cmp::Ordering, collections::HashMap, @@ -283,6 +284,41 @@ impl
VectorDatabase { }) } + pub fn embeddings_for_files( + &self, + worktree_id_file_paths: Vec<(i64, PathBuf)>, + ) -> impl Future>> { + todo!(); + // The remainder of the code is wired up. + // I'm having a bit of trouble figuring out the rusqlite syntax for a WHERE (files.worktree_id, files.relative_path) IN (VALUES (?, ?), (?, ?)) query + async { Ok(HashMap::new()) } + // let mut embeddings_by_digest = HashMap::new(); + // self.transact(move |db| { + + // let worktree_ids: Rc> = Rc::new( + // worktree_id_file_paths + // .iter() + // .map(|(id, _)| Value::from(*id)) + // .collect(), + // ); + // let file_paths: Rc> = Rc::new(worktree_id_file_paths + // .iter() + // .map(|(_, path)| Value::from(path.to_string_lossy().to_string())) + // .collect()); + + // let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE (files.worktree_id, files.relative_path) IN (VALUES (rarray = (?1), rarray = (?2))")?; + + // for row in query.query_map(params![worktree_ids, file_paths], |row| { + // Ok((row.get::<_, DocumentDigest>(0)?, row.get::<_, Embedding>(1)?)) + // })? { + // if let Ok(row) = row { + // embeddings_by_digest.insert(row.0, row.1); + // } + // } + // Ok(embeddings_by_digest) + // }) + } + pub fn find_or_create_worktree( &self, worktree_root_path: PathBuf, diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 726b04583a..908ac1f4be 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -10,12 +10,12 @@ mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; use anyhow::{anyhow, Result}; use db::VectorDatabase; -use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; -use parsing::{CodeContextRetriever, PARSEABLE_ENTIRE_FILE_TYPES}; +use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; use project::{ search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, @@ -103,7 +103,7 @@ pub struct SemanticIndex { db: VectorDatabase, embedding_provider: Arc, language_registry: Arc, - parsing_files_tx: channel::Sender, + parsing_files_tx: channel::Sender<(Arc>, PendingFile)>, _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, @@ -247,7 +247,8 @@ impl SemanticIndex { }); // Parse files into embeddable documents. 
- let (parsing_files_tx, parsing_files_rx) = channel::unbounded::(); + let (parsing_files_tx, parsing_files_rx) = + channel::unbounded::<(Arc>, PendingFile)>(); let embedding_queue = Arc::new(Mutex::new(embedding_queue)); let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { @@ -258,14 +259,16 @@ impl SemanticIndex { let db = db.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); - while let Ok(pending_file) = parsing_files_rx.recv().await { + while let Ok((embeddings_for_digest, pending_file)) = + parsing_files_rx.recv().await + { Self::parse_file( &fs, pending_file, &mut retriever, &embedding_queue, &parsing_files_rx, - &db, + &embeddings_for_digest, ) .await; } @@ -294,8 +297,11 @@ impl SemanticIndex { pending_file: PendingFile, retriever: &mut CodeContextRetriever, embedding_queue: &Arc>, - parsing_files_rx: &channel::Receiver, - db: &VectorDatabase, + parsing_files_rx: &channel::Receiver<( + Arc>, + PendingFile, + )>, + embeddings_for_digest: &HashMap, ) { let Some(language) = pending_file.language else { return; @@ -312,18 +318,9 @@ impl SemanticIndex { documents.len() ); - if let Some(sha_to_embeddings) = db - .embeddings_for_file( - pending_file.worktree_db_id, - pending_file.relative_path.clone(), - ) - .await - .log_err() - { - for document in documents.iter_mut() { - if let Some(embedding) = sha_to_embeddings.get(&document.digest) { - document.embedding = Some(embedding.to_owned()); - } + for document in documents.iter_mut() { + if let Some(embedding) = embeddings_for_digest.get(&document.digest) { + document.embedding = Some(embedding.to_owned()); } } @@ -381,6 +378,17 @@ impl SemanticIndex { return; }; + let embeddings_for_digest = { + let mut worktree_id_file_paths = Vec::new(); + for (path, _) in &project_state.changed_paths { + if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id) + { + worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf())); + } + } + self.db.embeddings_for_files(worktree_id_file_paths) + }; + let worktree = worktree.read(cx); let change_time = Instant::now(); for (path, entry_id, change) in changes.iter() { @@ -405,9 +413,18 @@ impl SemanticIndex { } cx.spawn_weak(|this, mut cx| async move { + let embeddings_for_digest = embeddings_for_digest.await.log_err().unwrap_or_default(); + cx.background().timer(BACKGROUND_INDEXING_DELAY).await; if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { - Self::reindex_changed_paths(this, project, Some(change_time), &mut cx).await; + Self::reindex_changed_paths( + this, + project, + Some(change_time), + &mut cx, + Arc::new(embeddings_for_digest), + ) + .await; } }) .detach(); @@ -561,7 +578,32 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task)>> { cx.spawn(|this, mut cx| async move { - Self::reindex_changed_paths(this.clone(), project.clone(), None, &mut cx).await; + let embeddings_for_digest = this.read_with(&cx, |this, cx| { + if let Some(state) = this.projects.get(&project.downgrade()) { + let mut worktree_id_file_paths = Vec::new(); + for (path, _) in &state.changed_paths { + if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id) + { + worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf())); + } + } + + Ok(this.db.embeddings_for_files(worktree_id_file_paths)) + } else { + Err(anyhow!("Project not yet initialized")) + } + })?; + + let embeddings_for_digest = 
Arc::new(embeddings_for_digest.await?); + + Self::reindex_changed_paths( + this.clone(), + project.clone(), + None, + &mut cx, + embeddings_for_digest, + ) + .await; this.update(&mut cx, |this, _cx| { let Some(state) = this.projects.get(&project.downgrade()) else { @@ -726,6 +768,7 @@ impl SemanticIndex { project: ModelHandle, last_changed_before: Option, cx: &mut AsyncAppContext, + embeddings_for_digest: Arc>, ) { let mut pending_files = Vec::new(); let mut files_to_delete = Vec::new(); @@ -805,7 +848,9 @@ impl SemanticIndex { } pending_file.language = Some(language); } - parsing_files_tx.try_send(pending_file).ok(); + parsing_files_tx + .try_send((embeddings_for_digest.clone(), pending_file)) + .ok(); } } } From c4db914f0a4397878ffeb7ffb74c8f6a3522e272 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 08:59:18 -0400 Subject: [PATCH 16/20] move embeddings queue to use single hashmap for all changed paths Co-authored-by: Antonio --- crates/semantic_index/src/db.rs | 79 ++++++++----------- crates/semantic_index/src/semantic_index.rs | 14 +++- .../src/semantic_index_tests.rs | 5 +- 3 files changed, 46 insertions(+), 52 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 4a953a2866..abb47cddf0 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -265,58 +265,43 @@ impl VectorDatabase { }) } - pub fn embeddings_for_file( - &self, - worktree_id: i64, - relative_path: PathBuf, - ) -> impl Future>> { - let relative_path = relative_path.to_string_lossy().into_owned(); - self.transact(move |db| { - let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE files.worktree_id = ?1 AND files.relative_path = ?2")?; - let mut result: HashMap = HashMap::new(); - for row in query.query_map(params![worktree_id, relative_path], |row| { - Ok((row.get::<_, DocumentDigest>(0)?.into(), row.get::<_, Embedding>(1)?.into())) - })? { - let row = row?; - result.insert(row.0, row.1); - } - Ok(result) - }) - } - pub fn embeddings_for_files( &self, - worktree_id_file_paths: Vec<(i64, PathBuf)>, + worktree_id_file_paths: HashMap>>, ) -> impl Future>> { - todo!(); - // The remainder of the code is wired up. - // I'm having a bit of trouble figuring out the rusqlite syntax for a WHERE (files.worktree_id, files.relative_path) IN (VALUES (?, ?), (?, ?)) query - async { Ok(HashMap::new()) } - // let mut embeddings_by_digest = HashMap::new(); - // self.transact(move |db| { + self.transact(move |db| { + let mut query = db.prepare( + " + SELECT digest, embedding + FROM documents + LEFT JOIN files ON files.id = documents.file_id + WHERE files.worktree_id = ? AND files.relative_path IN rarray(?) 
+ ", + )?; + let mut embeddings_by_digest = HashMap::new(); + for (worktree_id, file_paths) in worktree_id_file_paths { + let file_paths = Rc::new( + file_paths + .into_iter() + .map(|p| Value::Text(p.to_string_lossy().into_owned())) + .collect::>(), + ); + let rows = query.query_map(params![worktree_id, file_paths], |row| { + Ok(( + row.get::<_, DocumentDigest>(0)?, + row.get::<_, Embedding>(1)?, + )) + })?; - // let worktree_ids: Rc> = Rc::new( - // worktree_id_file_paths - // .iter() - // .map(|(id, _)| Value::from(*id)) - // .collect(), - // ); - // let file_paths: Rc> = Rc::new(worktree_id_file_paths - // .iter() - // .map(|(_, path)| Value::from(path.to_string_lossy().to_string())) - // .collect()); + for row in rows { + if let Ok(row) = row { + embeddings_by_digest.insert(row.0, row.1); + } + } + } - // let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE (files.worktree_id, files.relative_path) IN (VALUES (rarray = (?1), rarray = (?2))")?; - - // for row in query.query_map(params![worktree_ids, file_paths], |row| { - // Ok((row.get::<_, DocumentDigest>(0)?, row.get::<_, Embedding>(1)?)) - // })? { - // if let Ok(row) = row { - // embeddings_by_digest.insert(row.0, row.1); - // } - // } - // Ok(embeddings_by_digest) - // }) + Ok(embeddings_by_digest) + }) } pub fn find_or_create_worktree( diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 908ac1f4be..6d140931d6 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -379,11 +379,14 @@ impl SemanticIndex { }; let embeddings_for_digest = { - let mut worktree_id_file_paths = Vec::new(); + let mut worktree_id_file_paths = HashMap::new(); for (path, _) in &project_state.changed_paths { if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id) { - worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf())); + worktree_id_file_paths + .entry(worktree_db_id) + .or_insert(Vec::new()) + .push(path.path.clone()); } } self.db.embeddings_for_files(worktree_id_file_paths) @@ -580,11 +583,14 @@ impl SemanticIndex { cx.spawn(|this, mut cx| async move { let embeddings_for_digest = this.read_with(&cx, |this, cx| { if let Some(state) = this.projects.get(&project.downgrade()) { - let mut worktree_id_file_paths = Vec::new(); + let mut worktree_id_file_paths = HashMap::default(); for (path, _) in &state.changed_paths { if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id) { - worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf())); + worktree_id_file_paths + .entry(worktree_db_id) + .or_insert(Vec::new()) + .push(path.path.clone()); } } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index e65bc04412..01f34a2b1d 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -55,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { fn bbb() { println!(\"bbbbbbbbbbbbb!\"); } + struct pqpqpqp {} ".unindent(), "file3.toml": " ZZZZZZZZZZZZZZZZZZ = 5 @@ -121,6 +122,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { (Path::new("src/file2.rs").into(), 0), (Path::new("src/file3.toml").into(), 0), (Path::new("src/file1.rs").into(), 45), + (Path::new("src/file2.rs").into(), 45), ], cx, ); @@ -148,6 +150,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { 
(Path::new("src/file1.rs").into(), 0), (Path::new("src/file2.rs").into(), 0), (Path::new("src/file1.rs").into(), 45), + (Path::new("src/file2.rs").into(), 45), ], cx, ); @@ -199,7 +202,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { assert_eq!( embedding_provider.embedding_count() - prev_embedding_count, - 2 + 1 ); } From 524533cfb227dffba93adfec461fee722c73ba4d Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 11:24:08 -0400 Subject: [PATCH 17/20] flush embeddings queue when no files are parsed for 250 milliseconds Co-authored-by: Antonio --- crates/semantic_index/src/semantic_index.rs | 50 ++++++++++--------- .../src/semantic_index_tests.rs | 12 ++--- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6d140931d6..a8518ce695 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -12,6 +12,7 @@ use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; +use futures::{FutureExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; @@ -39,6 +40,7 @@ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 8; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); +const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); pub fn init( fs: Arc, @@ -253,24 +255,34 @@ impl SemanticIndex { let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { let fs = fs.clone(); - let parsing_files_rx = parsing_files_rx.clone(); + let mut parsing_files_rx = parsing_files_rx.clone(); let embedding_provider = embedding_provider.clone(); let embedding_queue = embedding_queue.clone(); - let db = db.clone(); + let background = cx.background().clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); - while let Ok((embeddings_for_digest, pending_file)) = - parsing_files_rx.recv().await - { - Self::parse_file( - &fs, - pending_file, - &mut retriever, - &embedding_queue, - &parsing_files_rx, - &embeddings_for_digest, - ) - .await; + loop { + let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse(); + let mut next_file_to_parse = parsing_files_rx.next().fuse(); + futures::select_biased! 
{ + next_file_to_parse = next_file_to_parse => { + if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse { + Self::parse_file( + &fs, + pending_file, + &mut retriever, + &embedding_queue, + &embeddings_for_digest, + ) + .await + } else { + break; + } + }, + _ = timer => { + embedding_queue.lock().flush(); + } + } } })); } @@ -297,10 +309,6 @@ impl SemanticIndex { pending_file: PendingFile, retriever: &mut CodeContextRetriever, embedding_queue: &Arc>, - parsing_files_rx: &channel::Receiver<( - Arc>, - PendingFile, - )>, embeddings_for_digest: &HashMap, ) { let Some(language) = pending_file.language else { @@ -333,10 +341,6 @@ impl SemanticIndex { }); } } - - if parsing_files_rx.len() == 0 { - embedding_queue.lock().flush(); - } } pub fn project_previously_indexed( @@ -581,7 +585,7 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task)>> { cx.spawn(|this, mut cx| async move { - let embeddings_for_digest = this.read_with(&cx, |this, cx| { + let embeddings_for_digest = this.read_with(&cx, |this, _| { if let Some(state) = this.projects.get(&project.downgrade()) { let mut worktree_id_file_paths = HashMap::default(); for (path, _) in &state.changed_paths { diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 01f34a2b1d..f549e68e04 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -3,11 +3,11 @@ use crate::{ embedding_queue::EmbeddingQueue, parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest}, semantic_index_settings::SemanticIndexSettings, - FileToEmbed, JobHandle, SearchResult, SemanticIndex, + FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; use anyhow::Result; use async_trait::async_trait; -use gpui::{Task, TestAppContext}; +use gpui::{executor::Deterministic, Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; use parking_lot::Mutex; use pretty_assertions::assert_eq; @@ -34,7 +34,7 @@ fn init_logger() { } #[gpui::test] -async fn test_semantic_index(cx: &mut TestAppContext) { +async fn test_semantic_index(deterministic: Arc, cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.background()); @@ -98,7 +98,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { .await .unwrap(); assert_eq!(file_count, 3); - cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); assert_eq!(*outstanding_file_count.borrow(), 0); let search_results = semantic_index @@ -188,7 +188,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { .await .unwrap(); - cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); let prev_embedding_count = embedding_provider.embedding_count(); let (file_count, outstanding_file_count) = semantic_index @@ -197,7 +197,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { .unwrap(); assert_eq!(file_count, 1); - cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); assert_eq!(*outstanding_file_count.borrow(), 0); assert_eq!( From e86964eb5d4f7e2a387d8faec32f18df8da91362 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 13:01:37 -0400 Subject: [PATCH 18/20] optimize insert file in vector database Co-authored-by: Max --- crates/semantic_index/src/db.rs | 65 ++++++++++----------- crates/semantic_index/src/semantic_index.rs | 2 +- 2 files changed, 33 insertions(+), 34 
deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index abb47cddf0..6cfd01456d 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -162,6 +162,11 @@ impl VectorDatabase { [], )?; + db.execute( + "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)", + [], + )?; + db.execute( "CREATE TABLE documents ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -206,43 +211,37 @@ impl VectorDatabase { // Return the existing ID, if both the file and mtime match let mtime = Timestamp::from(mtime); - let mut existing_id_query = db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?; - let existing_id = existing_id_query - .query_row( - params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], - |row| Ok(row.get::<_, i64>(0)?), - ); + db.execute( + " + REPLACE INTO files + (worktree_id, relative_path, mtime_seconds, mtime_nanos) + VALUES (?1, ?2, ?3, ?4) + ", + params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], + )?; - let file_id = if existing_id.is_ok() { - // If already exists, just return the existing id - existing_id? - } else { - // Delete Existing Row - db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;", - params![worktree_id, path.to_str()], - )?; - db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?; - db.last_insert_rowid() - }; + let file_id = db.last_insert_rowid(); + + let mut query = db.prepare( + " + INSERT INTO documents + (file_id, start_byte, end_byte, name, embedding, digest) + VALUES (?1, ?2, ?3, ?4, ?5, ?6) + ", + )?; - // Currently inserting at approximately 3400 documents a second - // I imagine we can speed this up with a bulk insert of some kind. 
for document in documents { - db.execute( - "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, digest) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", - params![ - file_id, - document.range.start.to_string(), - document.range.end.to_string(), - document.name, - document.embedding, - document.digest - ], - )?; - } + query.execute(params![ + file_id, + document.range.start.to_string(), + document.range.end.to_string(), + document.name, + document.embedding, + document.digest + ])?; + } - Ok(()) + Ok(()) }) } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index a8518ce695..e155fe3c74 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -38,7 +38,7 @@ use util::{ }; use workspace::WorkspaceCreated; -const SEMANTIC_INDEX_VERSION: usize = 8; +const SEMANTIC_INDEX_VERSION: usize = 9; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); From 54235f4fb179049b7f8b27eaf9de1cd5e7e54d33 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 13:04:09 -0400 Subject: [PATCH 19/20] updated embeddings background delay to 5 minutes Co-authored-by: Max --- crates/semantic_index/src/semantic_index.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index e155fe3c74..4e48b9cd71 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -39,7 +39,7 @@ use util::{ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 9; -const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); +const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60); const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); pub fn init( From 8dbc0fe0333958d6057a99bb734700deb270bf8b Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 17:07:20 -0400 Subject: [PATCH 20/20] update pragma settings for improved database performance --- crates/semantic_index/src/db.rs | 13 ++++++++++++- crates/semantic_index/src/semantic_index.rs | 1 - 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 6cfd01456d..2ececc1eb6 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -18,7 +18,7 @@ use std::{ path::{Path, PathBuf}, rc::Rc, sync::Arc, - time::SystemTime, + time::{Instant, SystemTime}, }; use util::TryFutureExt; @@ -54,6 +54,12 @@ impl VectorDatabase { let path = path.clone(); async move { let mut connection = rusqlite::Connection::open(&path)?; + + connection.pragma_update(None, "journal_mode", "wal")?; + connection.pragma_update(None, "synchronous", "normal")?; + connection.pragma_update(None, "cache_size", 1000000)?; + connection.pragma_update(None, "temp_store", "MEMORY")?; + while let Ok(transaction) = transactions_rx.recv().await { transaction(&mut connection); } @@ -222,6 +228,7 @@ impl VectorDatabase { let file_id = db.last_insert_rowid(); + let t0 = Instant::now(); let mut query = db.prepare( " INSERT INTO documents @@ -229,6 +236,10 @@ impl VectorDatabase { VALUES (?1, ?2, ?3, ?4, ?5, ?6) ", )?; + log::trace!( + "Preparing Query Took: {:?} milliseconds", + t0.elapsed().as_millis() + ); for document in documents { query.execute(params![ diff --git a/crates/semantic_index/src/semantic_index.rs 
b/crates/semantic_index/src/semantic_index.rs index 4e48b9cd71..a917eabfc8 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -81,7 +81,6 @@ pub fn init( let semantic_index = SemanticIndex::new( fs, db_file_path, - // Arc::new(embedding::DummyEmbeddings {}), Arc::new(OpenAIEmbeddings { client: http_client, executor: cx.background(),
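For reference, below is a minimal, self-contained sketch of the rusqlite `rarray` lookup pattern that the `embeddings_for_files` query in patch 16 relies on. This is not code from these patches; it assumes rusqlite is built with its "array" feature, and the table, column, and path values are illustrative only. The `rarray` table-valued function must be registered on each connection before such a query will run.

use rusqlite::{params, types::Value, Connection, Result};
use std::rc::Rc;

fn main() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    // `rarray` is a table-valued function; it must be registered per connection.
    rusqlite::vtab::array::load_module(&conn)?;

    // Illustrative schema and data, standing in for the real `files` table.
    conn.execute_batch(
        "CREATE TABLE files (id INTEGER PRIMARY KEY, relative_path TEXT NOT NULL);
         INSERT INTO files (relative_path) VALUES ('src/a.rs'), ('src/b.rs'), ('src/c.rs');",
    )?;

    // Bind the whole list of paths as one parameter instead of formatting `IN (?, ?, ...)`.
    let paths: Rc<Vec<Value>> = Rc::new(
        ["src/a.rs", "src/c.rs"]
            .iter()
            .map(|p| Value::Text(p.to_string()))
            .collect(),
    );

    let mut stmt =
        conn.prepare("SELECT id, relative_path FROM files WHERE relative_path IN rarray(?1)")?;
    let rows = stmt.query_map(params![paths], |row| {
        Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?))
    })?;

    for row in rows {
        let (id, path) = row?;
        println!("{id}: {path}");
    }
    Ok(())
}

Binding the path list as a single Rc of Values keeps one reusable prepared statement per connection and avoids interpolating a variable-length IN list into the SQL, which is the same trade-off the batched embedding lookup in these patches makes.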