diff --git a/Cargo.lock b/Cargo.lock index a185542c63..8786c8ed6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3539,7 +3539,7 @@ dependencies = [ "gif", "jpeg-decoder", "num-iter", - "num-rational", + "num-rational 0.3.2", "num-traits", "png", "scoped_threadpool", @@ -4631,6 +4631,31 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "num" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" +dependencies = [ + "num-bigint 0.2.6", + "num-complex", + "num-integer", + "num-iter", + "num-rational 0.2.4", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.4" @@ -4659,6 +4684,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -4691,6 +4726,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef" +dependencies = [ + "autocfg", + "num-bigint 0.2.6", + "num-integer", + "num-traits", +] + [[package]] name = "num-rational" version = "0.3.2" @@ -5007,6 +5054,17 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "parse_duration" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d" +dependencies = [ + "lazy_static", + "num", + "regex", +] + [[package]] name = "password-hash" version = "0.2.3" @@ -6674,6 +6732,7 @@ dependencies = [ "log", "matrixmultiply", "parking_lot 0.11.2", + "parse_duration", "picker", "postage", "pretty_assertions", @@ -7005,7 +7064,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80" dependencies = [ "chrono", - "num-bigint", + "num-bigint 0.4.4", "num-traits", "thiserror", ] @@ -7237,7 +7296,7 @@ dependencies = [ "log", "md-5", "memchr", - "num-bigint", + "num-bigint 0.4.4", "once_cell", "paste", "percent-encoding", diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 4e817fcbe2..d46346e0ab 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -39,6 +39,7 @@ rand.workspace = true schemars.workspace = true globset.workspace = true sha1 = "0.10.5" +parse_duration = "2.1.1" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 60ecf3b45f..2ececc1eb6 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,20 +1,26 @@ -use crate::{parsing::Document, SEMANTIC_INDEX_VERSION}; +use crate::{ + embedding::Embedding, + parsing::{Document, DocumentDigest}, + SEMANTIC_INDEX_VERSION, +}; use anyhow::{anyhow, Context, Result}; +use futures::channel::oneshot; +use gpui::executor; use project::{search::PathMatcher, 
Fs}; use rpc::proto::Timestamp; -use rusqlite::{ - params, - types::{FromSql, FromSqlResult, ValueRef}, -}; +use rusqlite::params; +use rusqlite::types::Value; use std::{ cmp::Ordering, collections::HashMap, + future::Future, ops::Range, path::{Path, PathBuf}, rc::Rc, sync::Arc, - time::SystemTime, + time::{Instant, SystemTime}, }; +use util::TryFutureExt; #[derive(Debug)] pub struct FileRecord { @@ -23,145 +29,181 @@ pub struct FileRecord { pub mtime: Timestamp, } -#[derive(Debug)] -struct Embedding(pub Vec); - -#[derive(Debug)] -struct Sha1(pub Vec); - -impl FromSql for Embedding { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - let embedding: Result, Box> = bincode::deserialize(bytes); - if embedding.is_err() { - return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); - } - return Ok(Embedding(embedding.unwrap())); - } -} - -impl FromSql for Sha1 { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - let sha1: Result, Box> = bincode::deserialize(bytes); - if sha1.is_err() { - return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err())); - } - return Ok(Sha1(sha1.unwrap())); - } -} - +#[derive(Clone)] pub struct VectorDatabase { - db: rusqlite::Connection, + path: Arc, + transactions: + smol::channel::Sender>, } impl VectorDatabase { - pub async fn new(fs: Arc, path: Arc) -> Result { + pub async fn new( + fs: Arc, + path: Arc, + executor: Arc, + ) -> Result { if let Some(db_directory) = path.parent() { fs.create_dir(db_directory).await?; } + let (transactions_tx, transactions_rx) = smol::channel::unbounded::< + Box, + >(); + executor + .spawn({ + let path = path.clone(); + async move { + let mut connection = rusqlite::Connection::open(&path)?; + + connection.pragma_update(None, "journal_mode", "wal")?; + connection.pragma_update(None, "synchronous", "normal")?; + connection.pragma_update(None, "cache_size", 1000000)?; + connection.pragma_update(None, "temp_store", "MEMORY")?; + + while let Ok(transaction) = transactions_rx.recv().await { + transaction(&mut connection); + } + + anyhow::Ok(()) + } + .log_err() + }) + .detach(); let this = Self { - db: rusqlite::Connection::open(path.as_path())?, + transactions: transactions_tx, + path, }; - this.initialize_database()?; + this.initialize_database().await?; Ok(this) } - fn get_existing_version(&self) -> Result { - let mut version_query = self - .db - .prepare("SELECT version from semantic_index_config")?; - version_query - .query_row([], |row| Ok(row.get::<_, i64>(0)?)) - .map_err(|err| anyhow!("version query failed: {err}")) + pub fn path(&self) -> &Arc { + &self.path } - fn initialize_database(&self) -> Result<()> { - rusqlite::vtab::array::load_module(&self.db)?; - - // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped - if self - .get_existing_version() - .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) - { - log::trace!("vector database schema up to date"); - return Ok(()); + fn transact(&self, f: F) -> impl Future> + where + F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result, + T: 'static + Send, + { + let (tx, rx) = oneshot::channel(); + let transactions = self.transactions.clone(); + async move { + if transactions + .send(Box::new(|connection| { + let result = connection + .transaction() + .map_err(|err| anyhow!(err)) + .and_then(|transaction| { + let result = f(&transaction)?; + transaction.commit()?; + Ok(result) + }); + let _ = tx.send(result); + })) + .await + .is_err() + { + return 
Err(anyhow!("connection was dropped"))?; + } + rx.await? } - - log::trace!("vector database schema out of date. updating..."); - self.db - .execute("DROP TABLE IF EXISTS documents", []) - .context("failed to drop 'documents' table")?; - self.db - .execute("DROP TABLE IF EXISTS files", []) - .context("failed to drop 'files' table")?; - self.db - .execute("DROP TABLE IF EXISTS worktrees", []) - .context("failed to drop 'worktrees' table")?; - self.db - .execute("DROP TABLE IF EXISTS semantic_index_config", []) - .context("failed to drop 'semantic_index_config' table")?; - - // Initialize Vector Databasing Tables - self.db.execute( - "CREATE TABLE semantic_index_config ( - version INTEGER NOT NULL - )", - [], - )?; - - self.db.execute( - "INSERT INTO semantic_index_config (version) VALUES (?1)", - params![SEMANTIC_INDEX_VERSION], - )?; - - self.db.execute( - "CREATE TABLE worktrees ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - absolute_path VARCHAR NOT NULL - ); - CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path); - ", - [], - )?; - - self.db.execute( - "CREATE TABLE files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - worktree_id INTEGER NOT NULL, - relative_path VARCHAR NOT NULL, - mtime_seconds INTEGER NOT NULL, - mtime_nanos INTEGER NOT NULL, - FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE - )", - [], - )?; - - self.db.execute( - "CREATE TABLE documents ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - file_id INTEGER NOT NULL, - start_byte INTEGER NOT NULL, - end_byte INTEGER NOT NULL, - name VARCHAR NOT NULL, - embedding BLOB NOT NULL, - sha1 BLOB NOT NULL, - FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE - )", - [], - )?; - - log::trace!("vector database initialized with updated schema."); - Ok(()) } - pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> { - self.db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", - params![worktree_id, delete_path.to_str()], - )?; - Ok(()) + fn initialize_database(&self) -> impl Future> { + self.transact(|db| { + rusqlite::vtab::array::load_module(&db)?; + + // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped + let version_query = db.prepare("SELECT version from semantic_index_config"); + let version = version_query + .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?))); + if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) { + log::trace!("vector database schema up to date"); + return Ok(()); + } + + log::trace!("vector database schema out of date. 
updating..."); + db.execute("DROP TABLE IF EXISTS documents", []) + .context("failed to drop 'documents' table")?; + db.execute("DROP TABLE IF EXISTS files", []) + .context("failed to drop 'files' table")?; + db.execute("DROP TABLE IF EXISTS worktrees", []) + .context("failed to drop 'worktrees' table")?; + db.execute("DROP TABLE IF EXISTS semantic_index_config", []) + .context("failed to drop 'semantic_index_config' table")?; + + // Initialize Vector Databasing Tables + db.execute( + "CREATE TABLE semantic_index_config ( + version INTEGER NOT NULL + )", + [], + )?; + + db.execute( + "INSERT INTO semantic_index_config (version) VALUES (?1)", + params![SEMANTIC_INDEX_VERSION], + )?; + + db.execute( + "CREATE TABLE worktrees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + absolute_path VARCHAR NOT NULL + ); + CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path); + ", + [], + )?; + + db.execute( + "CREATE TABLE files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + worktree_id INTEGER NOT NULL, + relative_path VARCHAR NOT NULL, + mtime_seconds INTEGER NOT NULL, + mtime_nanos INTEGER NOT NULL, + FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE + )", + [], + )?; + + db.execute( + "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)", + [], + )?; + + db.execute( + "CREATE TABLE documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL, + start_byte INTEGER NOT NULL, + end_byte INTEGER NOT NULL, + name VARCHAR NOT NULL, + embedding BLOB NOT NULL, + digest BLOB NOT NULL, + FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE + )", + [], + )?; + + log::trace!("vector database initialized with updated schema."); + Ok(()) + }) + } + + pub fn delete_file( + &self, + worktree_id: i64, + delete_path: PathBuf, + ) -> impl Future> { + self.transact(move |db| { + db.execute( + "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", + params![worktree_id, delete_path.to_str()], + )?; + Ok(()) + }) } pub fn insert_file( @@ -170,139 +212,187 @@ impl VectorDatabase { path: PathBuf, mtime: SystemTime, documents: Vec, - ) -> Result<()> { - // Return the existing ID, if both the file and mtime match - let mtime = Timestamp::from(mtime); - let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?; - let existing_id = existing_id_query - .query_row( + ) -> impl Future> { + self.transact(move |db| { + // Return the existing ID, if both the file and mtime match + let mtime = Timestamp::from(mtime); + + db.execute( + " + REPLACE INTO files + (worktree_id, relative_path, mtime_seconds, mtime_nanos) + VALUES (?1, ?2, ?3, ?4) + ", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], - |row| Ok(row.get::<_, i64>(0)?), - ) - .map_err(|err| anyhow!(err)); - let file_id = if existing_id.is_ok() { - // If already exists, just return the existing id - existing_id.unwrap() - } else { - // Delete Existing Row - self.db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;", - params![worktree_id, path.to_str()], )?; - self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?; - self.db.last_insert_rowid() - }; - // Currently inserting at approximately 3400 documents a second - // I imagine we can speed this up with a bulk insert of some kind. 
- for document in documents { - let embedding_blob = bincode::serialize(&document.embedding)?; - let sha_blob = bincode::serialize(&document.sha1)?; + let file_id = db.last_insert_rowid(); - self.db.execute( - "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", - params![ + let t0 = Instant::now(); + let mut query = db.prepare( + " + INSERT INTO documents + (file_id, start_byte, end_byte, name, embedding, digest) + VALUES (?1, ?2, ?3, ?4, ?5, ?6) + ", + )?; + log::trace!( + "Preparing Query Took: {:?} milliseconds", + t0.elapsed().as_millis() + ); + + for document in documents { + query.execute(params![ file_id, document.range.start.to_string(), document.range.end.to_string(), document.name, - embedding_blob, - sha_blob - ], - )?; - } + document.embedding, + document.digest + ])?; + } - Ok(()) + Ok(()) + }) } - pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result { - let mut worktree_query = self - .db - .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; - let worktree_id = worktree_query - .query_row(params![worktree_root_path.to_string_lossy()], |row| { - Ok(row.get::<_, i64>(0)?) - }) - .map_err(|err| anyhow!(err)); + pub fn worktree_previously_indexed( + &self, + worktree_root_path: &Path, + ) -> impl Future> { + let worktree_root_path = worktree_root_path.to_string_lossy().into_owned(); + self.transact(move |db| { + let mut worktree_query = + db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + let worktree_id = worktree_query + .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?)); - if worktree_id.is_ok() { - return Ok(true); - } else { - return Ok(false); - } + if worktree_id.is_ok() { + return Ok(true); + } else { + return Ok(false); + } + }) } - pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result { - // Check that the absolute path doesnt exist - let mut worktree_query = self - .db - .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; - - let worktree_id = worktree_query - .query_row(params![worktree_root_path.to_string_lossy()], |row| { - Ok(row.get::<_, i64>(0)?) - }) - .map_err(|err| anyhow!(err)); - - if worktree_id.is_ok() { - return worktree_id; - } - - // If worktree_id is Err, insert new worktree - self.db.execute( - " - INSERT into worktrees (absolute_path) VALUES (?1) + pub fn embeddings_for_files( + &self, + worktree_id_file_paths: HashMap>>, + ) -> impl Future>> { + self.transact(move |db| { + let mut query = db.prepare( + " + SELECT digest, embedding + FROM documents + LEFT JOIN files ON files.id = documents.file_id + WHERE files.worktree_id = ? AND files.relative_path IN rarray(?) 
", - params![worktree_root_path.to_string_lossy()], - )?; - Ok(self.db.last_insert_rowid()) + )?; + let mut embeddings_by_digest = HashMap::new(); + for (worktree_id, file_paths) in worktree_id_file_paths { + let file_paths = Rc::new( + file_paths + .into_iter() + .map(|p| Value::Text(p.to_string_lossy().into_owned())) + .collect::>(), + ); + let rows = query.query_map(params![worktree_id, file_paths], |row| { + Ok(( + row.get::<_, DocumentDigest>(0)?, + row.get::<_, Embedding>(1)?, + )) + })?; + + for row in rows { + if let Ok(row) = row { + embeddings_by_digest.insert(row.0, row.1); + } + } + } + + Ok(embeddings_by_digest) + }) } - pub fn get_file_mtimes(&self, worktree_id: i64) -> Result> { - let mut statement = self.db.prepare( - " - SELECT relative_path, mtime_seconds, mtime_nanos - FROM files - WHERE worktree_id = ?1 - ORDER BY relative_path", - )?; - let mut result: HashMap = HashMap::new(); - for row in statement.query_map(params![worktree_id], |row| { - Ok(( - row.get::<_, String>(0)?.into(), - Timestamp { - seconds: row.get(1)?, - nanos: row.get(2)?, - } - .into(), - )) - })? { - let row = row?; - result.insert(row.0, row.1); - } - Ok(result) + pub fn find_or_create_worktree( + &self, + worktree_root_path: PathBuf, + ) -> impl Future> { + self.transact(move |db| { + let mut worktree_query = + db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + let worktree_id = worktree_query + .query_row(params![worktree_root_path.to_string_lossy()], |row| { + Ok(row.get::<_, i64>(0)?) + }); + + if worktree_id.is_ok() { + return Ok(worktree_id?); + } + + // If worktree_id is Err, insert new worktree + db.execute( + "INSERT into worktrees (absolute_path) VALUES (?1)", + params![worktree_root_path.to_string_lossy()], + )?; + Ok(db.last_insert_rowid()) + }) + } + + pub fn get_file_mtimes( + &self, + worktree_id: i64, + ) -> impl Future>> { + self.transact(move |db| { + let mut statement = db.prepare( + " + SELECT relative_path, mtime_seconds, mtime_nanos + FROM files + WHERE worktree_id = ?1 + ORDER BY relative_path", + )?; + let mut result: HashMap = HashMap::new(); + for row in statement.query_map(params![worktree_id], |row| { + Ok(( + row.get::<_, String>(0)?.into(), + Timestamp { + seconds: row.get(1)?, + nanos: row.get(2)?, + } + .into(), + )) + })? 
{ + let row = row?; + result.insert(row.0, row.1); + } + Ok(result) + }) } pub fn top_k_search( &self, - query_embedding: &Vec, + query_embedding: &Embedding, limit: usize, file_ids: &[i64], - ) -> Result> { - let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - self.for_each_document(file_ids, |id, embedding| { - let similarity = dot(&embedding, &query_embedding); - let ix = match results - .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)) - { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - })?; + ) -> impl Future>> { + let query_embedding = query_embedding.clone(); + let file_ids = file_ids.to_vec(); + self.transact(move |db| { + let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); + Self::for_each_document(db, &file_ids, |id, embedding| { + let similarity = embedding.similarity(&query_embedding); + let ix = match results.binary_search_by(|(_, s)| { + similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + })?; - Ok(results) + anyhow::Ok(results) + }) } pub fn retrieve_included_file_ids( @@ -310,37 +400,46 @@ impl VectorDatabase { worktree_ids: &[i64], includes: &[PathMatcher], excludes: &[PathMatcher], - ) -> Result> { - let mut file_query = self.db.prepare( - " - SELECT - id, relative_path - FROM - files - WHERE - worktree_id IN rarray(?) - ", - )?; + ) -> impl Future>> { + let worktree_ids = worktree_ids.to_vec(); + let includes = includes.to_vec(); + let excludes = excludes.to_vec(); + self.transact(move |db| { + let mut file_query = db.prepare( + " + SELECT + id, relative_path + FROM + files + WHERE + worktree_id IN rarray(?) + ", + )?; - let mut file_ids = Vec::::new(); - let mut rows = file_query.query([ids_to_sql(worktree_ids)])?; + let mut file_ids = Vec::::new(); + let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?; - while let Some(row) = rows.next()? { - let file_id = row.get(0)?; - let relative_path = row.get_ref(1)?.as_str()?; - let included = - includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path)); - let excluded = excludes.iter().any(|glob| glob.is_match(relative_path)); - if included && !excluded { - file_ids.push(file_id); + while let Some(row) = rows.next()? { + let file_id = row.get(0)?; + let relative_path = row.get_ref(1)?.as_str()?; + let included = + includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path)); + let excluded = excludes.iter().any(|glob| glob.is_match(relative_path)); + if included && !excluded { + file_ids.push(file_id); + } } - } - Ok(file_ids) + anyhow::Ok(file_ids) + }) } - fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec)) -> Result<()> { - let mut query_statement = self.db.prepare( + fn for_each_document( + db: &rusqlite::Connection, + file_ids: &[i64], + mut f: impl FnMut(i64, Embedding), + ) -> Result<()> { + let mut query_statement = db.prepare( " SELECT id, embedding @@ -356,51 +455,57 @@ impl VectorDatabase { Ok((row.get(0)?, row.get::<_, Embedding>(1)?)) })? 
.filter_map(|row| row.ok()) - .for_each(|(id, embedding)| f(id, embedding.0)); + .for_each(|(id, embedding)| f(id, embedding)); Ok(()) } - pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result)>> { - let mut statement = self.db.prepare( - " - SELECT - documents.id, - files.worktree_id, - files.relative_path, - documents.start_byte, - documents.end_byte - FROM - documents, files - WHERE - documents.file_id = files.id AND - documents.id in rarray(?) - ", - )?; + pub fn get_documents_by_ids( + &self, + ids: &[i64], + ) -> impl Future)>>> { + let ids = ids.to_vec(); + self.transact(move |db| { + let mut statement = db.prepare( + " + SELECT + documents.id, + files.worktree_id, + files.relative_path, + documents.start_byte, + documents.end_byte + FROM + documents, files + WHERE + documents.file_id = files.id AND + documents.id in rarray(?) + ", + )?; - let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| { - Ok(( - row.get::<_, i64>(0)?, - row.get::<_, i64>(1)?, - row.get::<_, String>(2)?.into(), - row.get(3)?..row.get(4)?, - )) - })?; + let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| { + Ok(( + row.get::<_, i64>(0)?, + row.get::<_, i64>(1)?, + row.get::<_, String>(2)?.into(), + row.get(3)?..row.get(4)?, + )) + })?; - let mut values_by_id = HashMap::)>::default(); - for row in result_iter { - let (id, worktree_id, path, range) = row?; - values_by_id.insert(id, (worktree_id, path, range)); - } + let mut values_by_id = HashMap::)>::default(); + for row in result_iter { + let (id, worktree_id, path, range) = row?; + values_by_id.insert(id, (worktree_id, path, range)); + } - let mut results = Vec::with_capacity(ids.len()); - for id in ids { - let value = values_by_id - .remove(id) - .ok_or(anyhow!("missing document id {}", id))?; - results.push(value); - } + let mut results = Vec::with_capacity(ids.len()); + for id in &ids { + let value = values_by_id + .remove(id) + .ok_or(anyhow!("missing document id {}", id))?; + results.push(value); + } - Ok(results) + Ok(results) + }) } } @@ -412,29 +517,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc> { .collect::>(), ) } - -pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 { - let len = vec_a.len(); - assert_eq!(len, vec_b.len()); - - let mut result = 0.0; - unsafe { - matrixmultiply::sgemm( - 1, - len, - 1, - 1.0, - vec_a.as_ptr(), - len as isize, - 1, - vec_b.as_ptr(), - 1, - len as isize, - 0.0, - &mut result as *mut f32, - 1, - 1, - ); - } - result -} diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index f2269a786a..97c25ca170 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -7,6 +7,9 @@ use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; +use parse_duration::parse; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; +use rusqlite::ToSql; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; @@ -19,6 +22,62 @@ lazy_static! 
{ static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(Vec); + +impl From> for Embedding { + fn from(value: Vec) -> Self { + Embedding(value) + } +} + +impl Embedding { + pub fn similarity(&self, other: &Self) -> f32 { + let len = self.0.len(); + assert_eq!(len, other.0.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + self.0.as_ptr(), + len as isize, + 1, + other.0.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + result + } +} + +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + Ok(Embedding(embedding.unwrap())) + } +} + +impl ToSql for Embedding { + fn to_sql(&self) -> rusqlite::Result { + let bytes = bincode::serialize(&self.0) + .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; + Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) + } +} + #[derive(Clone)] pub struct OpenAIEmbeddings { pub client: Arc, @@ -52,42 +111,53 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; + async fn embed_batch(&self, spans: Vec) -> Result>; + fn max_tokens_per_batch(&self) -> usize; + fn truncate(&self, span: &str) -> (String, usize); } pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. 
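        // (text-embedding-ada-002 actually returns 1536-dimensional vectors, which is why the
        // dummy vector below has length 1536 rather than 1024.)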
- let dummy_vec = vec![0.32 as f32; 1536]; + let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); return Ok(vec![dummy_vec; spans.len()]); } + + fn max_tokens_per_batch(&self) -> usize { + OPENAI_INPUT_LIMIT + } + + fn truncate(&self, span: &str) -> (String, usize) { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + let token_count = tokens.len(); + let output = if token_count > OPENAI_INPUT_LIMIT { + tokens.truncate(OPENAI_INPUT_LIMIT); + let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); + new_input.ok().unwrap_or_else(|| span.to_string()) + } else { + span.to_string() + }; + + (output, tokens.len()) + } } const OPENAI_INPUT_LIMIT: usize = 8190; impl OpenAIEmbeddings { - fn truncate(span: String) -> String { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref()); - if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); - if result.is_ok() { - let transformed = result.unwrap(); - return transformed; - } - } - - span - } - - async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result> { + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { let request = Request::post("https://api.openai.com/v1/embeddings") .redirect_policy(isahc::config::RedirectPolicy::Follow) - .timeout(Duration::from_secs(4)) + .timeout(Duration::from_secs(request_timeout)) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body( @@ -105,7 +175,27 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + fn max_tokens_per_batch(&self) -> usize { + 50000 + } + + fn truncate(&self, span: &str) -> (String, usize) { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + let token_count = tokens.len(); + let output = if token_count > OPENAI_INPUT_LIMIT { + tokens.truncate(OPENAI_INPUT_LIMIT); + OPENAI_BPE_TOKENIZER + .decode(tokens) + .ok() + .unwrap_or_else(|| span.to_string()) + } else { + span.to_string() + }; + + (output, token_count) + } + + async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -114,45 +204,21 @@ impl EmbeddingProvider for OpenAIEmbeddings { .ok_or_else(|| anyhow!("no api key"))?; let mut request_number = 0; - let mut truncated = false; + let mut request_timeout: u64 = 10; let mut response: Response; - let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); while request_number < MAX_RETRIES { response = self - .send_request(api_key, spans.iter().map(|x| &**x).collect()) + .send_request( + api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) .await?; request_number += 1; - if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK { - return Err(anyhow!( - "openai max retries, error: {:?}", - &response.status() - )); - } - match response.status() { - StatusCode::TOO_MANY_REQUESTS => { - let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); - log::trace!( - "open ai rate limiting, delaying request by {:?} seconds", - delay.as_secs() - ); - self.executor.timer(delay).await; - } - StatusCode::BAD_REQUEST => { - // Only truncate if it hasnt been truncated before - if !truncated { - for span in spans.iter_mut() { - *span = Self::truncate(span.clone()); - } - truncated = true; - 
} else { - // If failing once already truncated, log the error and break the loop - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - log::trace!("open ai bad request: {:?} {:?}", &response.status(), body); - break; - } + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; } StatusCode::OK => { let mut body = String::new(); @@ -163,18 +229,96 @@ impl EmbeddingProvider for OpenAIEmbeddings { "openai embedding completed. tokens: {:?}", response.usage.total_tokens ); + return Ok(response .data .into_iter() - .map(|embedding| embedding.embedding) + .map(|embedding| Embedding::from(embedding.embedding)) .collect()); } + StatusCode::TOO_MANY_REQUESTS => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } _ => { - return Err(anyhow!("openai embedding failed {}", response.status())); + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); } } } - - Err(anyhow!("openai embedding failed")) + Err(anyhow!("openai max retries")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[gpui::test] + fn test_similarity(mut rng: StdRng) { + assert_eq!( + Embedding::from(vec![1., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])), + 0. + ); + assert_eq!( + Embedding::from(vec![2., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])), + 6. 
+ ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() + } } } diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs new file mode 100644 index 0000000000..96493fc4d3 --- /dev/null +++ b/crates/semantic_index/src/embedding_queue.rs @@ -0,0 +1,173 @@ +use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use gpui::executor::Background; +use parking_lot::Mutex; +use smol::channel; +use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; + +#[derive(Clone)] +pub struct FileToEmbed { + pub worktree_id: i64, + pub path: PathBuf, + pub mtime: SystemTime, + pub documents: Vec, + pub job_handle: JobHandle, +} + +impl std::fmt::Debug for FileToEmbed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FileToEmbed") + .field("worktree_id", &self.worktree_id) + .field("path", &self.path) + .field("mtime", &self.mtime) + .field("document", &self.documents) + .finish_non_exhaustive() + } +} + +impl PartialEq for FileToEmbed { + fn eq(&self, other: &Self) -> bool { + self.worktree_id == other.worktree_id + && self.path == other.path + && self.mtime == other.mtime + && self.documents == other.documents + } +} + +pub struct EmbeddingQueue { + embedding_provider: Arc, + pending_batch: Vec, + executor: Arc, + pending_batch_token_count: usize, + finished_files_tx: channel::Sender, + finished_files_rx: channel::Receiver, +} + +#[derive(Clone)] +pub struct FileToEmbedFragment { + file: Arc>, + document_range: Range, +} + +impl EmbeddingQueue { + pub fn new(embedding_provider: Arc, executor: Arc) -> Self { + let (finished_files_tx, finished_files_rx) = channel::unbounded(); + Self { + embedding_provider, + executor, + pending_batch: Vec::new(), + pending_batch_token_count: 0, + finished_files_tx, + finished_files_rx, + } + } + + pub fn push(&mut self, file: FileToEmbed) { + if file.documents.is_empty() { + self.finished_files_tx.try_send(file).unwrap(); + return; + } + + let file = Arc::new(Mutex::new(file)); + + self.pending_batch.push(FileToEmbedFragment { + file: file.clone(), + document_range: 0..0, + }); + + let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + let mut saved_tokens = 0; + for (ix, document) in file.lock().documents.iter().enumerate() { + let document_token_count = if document.embedding.is_none() { + document.token_count + } else { + saved_tokens += document.token_count; + 0 + }; + + let next_token_count = self.pending_batch_token_count + document_token_count; + if next_token_count > self.embedding_provider.max_tokens_per_batch() { + let range_end = fragment_range.end; + self.flush(); + self.pending_batch.push(FileToEmbedFragment { + file: file.clone(), + document_range: range_end..range_end, + }); + fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + } + + fragment_range.end = ix + 1; + self.pending_batch_token_count += document_token_count; + } + 
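        // Documents that already carry an embedding (restored from the database by digest)
        // contribute nothing to the pending batch; their token counts are only tallied as
        // `saved_tokens` for the trace log below.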
log::trace!("Saved Tokens: {:?}", saved_tokens); + } + + pub fn flush(&mut self) { + let batch = mem::take(&mut self.pending_batch); + self.pending_batch_token_count = 0; + if batch.is_empty() { + return; + } + + let finished_files_tx = self.finished_files_tx.clone(); + let embedding_provider = self.embedding_provider.clone(); + + self.executor.spawn(async move { + let mut spans = Vec::new(); + let mut document_count = 0; + for fragment in &batch { + let file = fragment.file.lock(); + document_count += file.documents[fragment.document_range.clone()].len(); + spans.extend( + { + file.documents[fragment.document_range.clone()] + .iter().filter(|d| d.embedding.is_none()) + .map(|d| d.content.clone()) + } + ); + } + + log::trace!("Documents Length: {:?}", document_count); + log::trace!("Span Length: {:?}", spans.clone().len()); + + // If spans is 0, just send the fragment to the finished files if its the last one. + if spans.len() == 0 { + for fragment in batch.clone() { + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + return; + }; + + match embedding_provider.embed_batch(spans).await { + Ok(embeddings) => { + let mut embeddings = embeddings.into_iter(); + for fragment in batch { + for document in + &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none()) + { + if let Some(embedding) = embeddings.next() { + document.embedding = Some(embedding); + } else { + // + log::error!("number of embeddings returned different from number of documents"); + } + } + + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + } + Err(error) => { + log::error!("{:?}", error); + } + } + }) + .detach(); + } + + pub fn finished_files(&self) -> channel::Receiver { + self.finished_files_rx.clone() + } +} diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 4aefb0b00d..c0a94c6b73 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,5 +1,10 @@ -use anyhow::{anyhow, Ok, Result}; +use crate::embedding::{Embedding, EmbeddingProvider}; +use anyhow::{anyhow, Result}; use language::{Grammar, Language}; +use rusqlite::{ + types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, + ToSql, +}; use sha1::{Digest, Sha1}; use std::{ cmp::{self, Reverse}, @@ -10,13 +15,44 @@ use std::{ }; use tree_sitter::{Parser, QueryCursor}; +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct DocumentDigest([u8; 20]); + +impl FromSql for DocumentDigest { + fn column_result(value: ValueRef) -> FromSqlResult { + let blob = value.as_blob()?; + let bytes = + blob.try_into() + .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize { + expected_size: 20, + blob_size: blob.len(), + })?; + return Ok(DocumentDigest(bytes)); + } +} + +impl ToSql for DocumentDigest { + fn to_sql(&self) -> rusqlite::Result { + self.0.to_sql() + } +} + +impl From<&'_ str> for DocumentDigest { + fn from(value: &'_ str) -> Self { + let mut sha1 = Sha1::new(); + sha1.update(value); + Self(sha1.finalize().into()) + } +} + #[derive(Debug, PartialEq, Clone)] pub struct Document { pub name: String, pub range: Range, pub content: String, - pub embedding: Vec, - pub sha1: [u8; 20], + pub embedding: Option, + pub digest: DocumentDigest, + pub token_count: usize, } const CODE_CONTEXT_TEMPLATE: &str = @@ -30,6 +66,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = pub struct CodeContextRetriever 
{ pub parser: Parser, pub cursor: QueryCursor, + pub embedding_provider: Arc, } // Every match has an item, this represents the fundamental treesitter symbol and anchors the search @@ -47,10 +84,11 @@ pub struct CodeContextMatch { } impl CodeContextRetriever { - pub fn new() -> Self { + pub fn new(embedding_provider: Arc) -> Self { Self { parser: Parser::new(), cursor: QueryCursor::new(), + embedding_provider, } } @@ -64,16 +102,15 @@ impl CodeContextRetriever { .replace("", relative_path.to_string_lossy().as_ref()) .replace("", language_name.as_ref()) .replace("", &content); - - let mut sha1 = Sha1::new(); - sha1.update(&document_span); - + let digest = DocumentDigest::from(document_span.as_str()); + let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), content: document_span, - embedding: Vec::new(), + embedding: Default::default(), name: language_name.to_string(), - sha1: sha1.finalize().into(), + digest, + token_count, }]) } @@ -81,16 +118,15 @@ impl CodeContextRetriever { let document_span = MARKDOWN_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", &content); - - let mut sha1 = Sha1::new(); - sha1.update(&document_span); - + let digest = DocumentDigest::from(document_span.as_str()); + let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), content: document_span, - embedding: Vec::new(), + embedding: None, name: "Markdown".to_string(), - sha1: sha1.finalize().into(), + digest, + token_count, }]) } @@ -166,10 +202,16 @@ impl CodeContextRetriever { let mut documents = self.parse_file(content, language)?; for document in &mut documents { - document.content = CODE_CONTEXT_TEMPLATE + let document_content = CODE_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", language_name.as_ref()) .replace("item", &document.content); + + let (document_content, token_count) = + self.embedding_provider.truncate(&document_content); + + document.content = document_content; + document.token_count = token_count; } Ok(documents) } @@ -263,15 +305,14 @@ impl CodeContextRetriever { ); } - let mut sha1 = Sha1::new(); - sha1.update(&document_content); - + let sha1 = DocumentDigest::from(document_content.as_str()); documents.push(Document { name, content: document_content, range: item_range.clone(), - embedding: vec![], - sha1: sha1.finalize().into(), + embedding: None, + digest: sha1, + token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 736f2c98a8..a917eabfc8 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -1,5 +1,6 @@ mod db; mod embedding; +mod embedding_queue; mod parsing; pub mod semantic_index_settings; @@ -9,23 +10,25 @@ mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; use anyhow::{anyhow, Result}; use db::VectorDatabase; -use embedding::{EmbeddingProvider, OpenAIEmbeddings}; -use futures::{channel::oneshot, Future}; +use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; +use embedding_queue::{EmbeddingQueue, FileToEmbed}; +use futures::{FutureExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; -use parsing::{CodeContextRetriever, Document, 
PARSEABLE_ENTIRE_FILE_TYPES}; +use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; -use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, WorktreeId}; +use project::{ + search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, +}; use smol::channel; use std::{ cmp::Ordering, - collections::HashMap, - mem, + collections::{BTreeMap, HashMap}, ops::Range, path::{Path, PathBuf}, sync::{Arc, Weak}, - time::{Instant, SystemTime}, + time::{Duration, Instant, SystemTime}, }; use util::{ channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}, @@ -35,8 +38,9 @@ use util::{ }; use workspace::WorkspaceCreated; -const SEMANTIC_INDEX_VERSION: usize = 7; -const EMBEDDINGS_BATCH_SIZE: usize = 80; +const SEMANTIC_INDEX_VERSION: usize = 9; +const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60); +const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); pub fn init( fs: Arc, @@ -97,14 +101,11 @@ pub fn init( pub struct SemanticIndex { fs: Arc, - database_url: Arc, + db: VectorDatabase, embedding_provider: Arc, language_registry: Arc, - db_update_tx: channel::Sender, - parsing_files_tx: channel::Sender, - _db_update_task: Task<()>, - _embed_batch_tasks: Vec>, - _batch_files_task: Task<()>, + parsing_files_tx: channel::Sender<(Arc>, PendingFile)>, + _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, } @@ -113,13 +114,18 @@ struct ProjectState { worktree_db_ids: Vec<(WorktreeId, i64)>, _subscription: gpui::Subscription, outstanding_job_count_rx: watch::Receiver, - _outstanding_job_count_tx: Arc>>, - job_queue_tx: channel::Sender, - _queue_update_task: Task<()>, + outstanding_job_count_tx: Arc>>, + changed_paths: BTreeMap, +} + +struct ChangedPathInfo { + changed_at: Instant, + mtime: SystemTime, + is_deleted: bool, } #[derive(Clone)] -struct JobHandle { +pub struct JobHandle { /// The outer Arc is here to count the clones of a JobHandle instance; /// when the last handle to a given job is dropped, we decrement a counter (just once). 
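    /// `JobHandle::new` increments that counter and stores a downgraded (weak) reference to it.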
tx: Arc>>>, @@ -133,31 +139,21 @@ impl JobHandle { } } } + impl ProjectState { fn new( - cx: &mut AppContext, subscription: gpui::Subscription, worktree_db_ids: Vec<(WorktreeId, i64)>, - outstanding_job_count_rx: watch::Receiver, - _outstanding_job_count_tx: Arc>>, + changed_paths: BTreeMap, ) -> Self { - let (job_queue_tx, job_queue_rx) = channel::unbounded(); - let _queue_update_task = cx.background().spawn({ - let mut worktree_queue = HashMap::new(); - async move { - while let Ok(operation) = job_queue_rx.recv().await { - Self::update_queue(&mut worktree_queue, operation); - } - } - }); - + let (outstanding_job_count_tx, outstanding_job_count_rx) = watch::channel_with(0); + let outstanding_job_count_tx = Arc::new(Mutex::new(outstanding_job_count_tx)); Self { worktree_db_ids, outstanding_job_count_rx, - _outstanding_job_count_tx, + outstanding_job_count_tx, + changed_paths, _subscription: subscription, - _queue_update_task, - job_queue_tx, } } @@ -165,41 +161,6 @@ impl ProjectState { self.outstanding_job_count_rx.borrow().clone() } - fn update_queue(queue: &mut HashMap, operation: IndexOperation) { - match operation { - IndexOperation::FlushQueue => { - let queue = std::mem::take(queue); - for (_, op) in queue { - match op { - IndexOperation::IndexFile { - absolute_path: _, - payload, - tx, - } => { - let _ = tx.try_send(payload); - } - IndexOperation::DeleteFile { - absolute_path: _, - payload, - tx, - } => { - let _ = tx.try_send(payload); - } - _ => {} - } - } - } - IndexOperation::IndexFile { - ref absolute_path, .. - } - | IndexOperation::DeleteFile { - ref absolute_path, .. - } => { - queue.insert(absolute_path.clone(), operation); - } - } - } - fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option { self.worktree_db_ids .iter() @@ -230,66 +191,16 @@ pub struct PendingFile { worktree_db_id: i64, relative_path: PathBuf, absolute_path: PathBuf, - language: Arc, + language: Option>, modified_time: SystemTime, job_handle: JobHandle, } -enum IndexOperation { - IndexFile { - absolute_path: PathBuf, - payload: PendingFile, - tx: channel::Sender, - }, - DeleteFile { - absolute_path: PathBuf, - payload: DbOperation, - tx: channel::Sender, - }, - FlushQueue, -} pub struct SearchResult { pub buffer: ModelHandle, pub range: Range, } -enum DbOperation { - InsertFile { - worktree_id: i64, - documents: Vec, - path: PathBuf, - mtime: SystemTime, - job_handle: JobHandle, - }, - Delete { - worktree_id: i64, - path: PathBuf, - }, - FindOrCreateWorktree { - path: PathBuf, - sender: oneshot::Sender>, - }, - FileMTimes { - worktree_id: i64, - sender: oneshot::Sender>>, - }, - WorktreePreviouslyIndexed { - path: Arc, - sender: oneshot::Sender>, - }, -} - -enum EmbeddingJob { - Enqueue { - worktree_id: i64, - path: PathBuf, - mtime: SystemTime, - documents: Vec, - job_handle: JobHandle, - }, - Flush, -} - impl SemanticIndex { pub fn global(cx: &AppContext) -> Option> { if cx.has_global::>() { @@ -306,18 +217,14 @@ impl SemanticIndex { async fn new( fs: Arc, - database_url: PathBuf, + database_path: PathBuf, embedding_provider: Arc, language_registry: Arc, mut cx: AsyncAppContext, ) -> Result> { let t0 = Instant::now(); - let database_url = Arc::new(database_url); - - let db = cx - .background() - .spawn(VectorDatabase::new(fs.clone(), database_url.clone())) - .await?; + let database_path = Arc::from(database_path); + let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?; log::trace!( "db initialization took {:?} milliseconds", @@ -326,73 +233,55 @@ impl SemanticIndex { 
Ok(cx.add_model(|cx| { let t0 = Instant::now(); - // Perform database operations - let (db_update_tx, db_update_rx) = channel::unbounded(); - let _db_update_task = cx.background().spawn({ + let embedding_queue = + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); + let _embedding_task = cx.background().spawn({ + let embedded_files = embedding_queue.finished_files(); + let db = db.clone(); async move { - while let Ok(job) = db_update_rx.recv().await { - Self::run_db_operation(&db, job) + while let Ok(file) = embedded_files.recv().await { + db.insert_file(file.worktree_id, file.path, file.mtime, file.documents) + .await + .log_err(); } } }); - // Group documents into batches and send them to the embedding provider. - let (embed_batch_tx, embed_batch_rx) = - channel::unbounded::, PathBuf, SystemTime, JobHandle)>>(); - let mut _embed_batch_tasks = Vec::new(); - for _ in 0..cx.background().num_cpus() { - let embed_batch_rx = embed_batch_rx.clone(); - _embed_batch_tasks.push(cx.background().spawn({ - let db_update_tx = db_update_tx.clone(); - let embedding_provider = embedding_provider.clone(); - async move { - while let Ok(embeddings_queue) = embed_batch_rx.recv().await { - Self::compute_embeddings_for_batch( - embeddings_queue, - &embedding_provider, - &db_update_tx, - ) - .await; - } - } - })); - } - - // Group documents into batches and send them to the embedding provider. - let (batch_files_tx, batch_files_rx) = channel::unbounded::(); - let _batch_files_task = cx.background().spawn(async move { - let mut queue_len = 0; - let mut embeddings_queue = vec![]; - while let Ok(job) = batch_files_rx.recv().await { - Self::enqueue_documents_to_embed( - job, - &mut queue_len, - &mut embeddings_queue, - &embed_batch_tx, - ); - } - }); - // Parse files into embeddable documents. - let (parsing_files_tx, parsing_files_rx) = channel::unbounded::(); + let (parsing_files_tx, parsing_files_rx) = + channel::unbounded::<(Arc>, PendingFile)>(); + let embedding_queue = Arc::new(Mutex::new(embedding_queue)); let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { let fs = fs.clone(); - let parsing_files_rx = parsing_files_rx.clone(); - let batch_files_tx = batch_files_tx.clone(); - let db_update_tx = db_update_tx.clone(); + let mut parsing_files_rx = parsing_files_rx.clone(); + let embedding_provider = embedding_provider.clone(); + let embedding_queue = embedding_queue.clone(); + let background = cx.background().clone(); _parsing_files_tasks.push(cx.background().spawn(async move { - let mut retriever = CodeContextRetriever::new(); - while let Ok(pending_file) = parsing_files_rx.recv().await { - Self::parse_file( - &fs, - pending_file, - &mut retriever, - &batch_files_tx, - &parsing_files_rx, - &db_update_tx, - ) - .await; + let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); + loop { + let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse(); + let mut next_file_to_parse = parsing_files_rx.next().fuse(); + futures::select_biased! 
{ + next_file_to_parse = next_file_to_parse => { + if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse { + Self::parse_file( + &fs, + pending_file, + &mut retriever, + &embedding_queue, + &embeddings_for_digest, + ) + .await + } else { + break; + } + }, + _ = timer => { + embedding_queue.lock().flush(); + } + } } })); } @@ -403,192 +292,31 @@ impl SemanticIndex { ); Self { fs, - database_url, + db, embedding_provider, language_registry, - db_update_tx, parsing_files_tx, - _db_update_task, - _embed_batch_tasks, - _batch_files_task, + _embedding_task, _parsing_files_tasks, projects: HashMap::new(), } })) } - fn run_db_operation(db: &VectorDatabase, job: DbOperation) { - match job { - DbOperation::InsertFile { - worktree_id, - documents, - path, - mtime, - job_handle, - } => { - db.insert_file(worktree_id, path, mtime, documents) - .log_err(); - drop(job_handle) - } - DbOperation::Delete { worktree_id, path } => { - db.delete_file(worktree_id, path).log_err(); - } - DbOperation::FindOrCreateWorktree { path, sender } => { - let id = db.find_or_create_worktree(&path); - sender.send(id).ok(); - } - DbOperation::FileMTimes { - worktree_id: worktree_db_id, - sender, - } => { - let file_mtimes = db.get_file_mtimes(worktree_db_id); - sender.send(file_mtimes).ok(); - } - DbOperation::WorktreePreviouslyIndexed { path, sender } => { - let worktree_indexed = db.worktree_previously_indexed(path.as_ref()); - sender.send(worktree_indexed).ok(); - } - } - } - - async fn compute_embeddings_for_batch( - mut embeddings_queue: Vec<(i64, Vec, PathBuf, SystemTime, JobHandle)>, - embedding_provider: &Arc, - db_update_tx: &channel::Sender, - ) { - let mut batch_documents = vec![]; - for (_, documents, _, _, _) in embeddings_queue.iter() { - batch_documents.extend(documents.iter().map(|document| document.content.as_str())); - } - - if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await { - log::trace!( - "created {} embeddings for {} files", - embeddings.len(), - embeddings_queue.len(), - ); - - let mut i = 0; - let mut j = 0; - - for embedding in embeddings.iter() { - while embeddings_queue[i].1.len() == j { - i += 1; - j = 0; - } - - embeddings_queue[i].1[j].embedding = embedding.to_owned(); - j += 1; - } - - for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id, - documents, - path, - mtime, - job_handle, - }) - .await - .unwrap(); - } - } else { - // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed). - for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id, - documents: vec![], - path, - mtime, - job_handle, - }) - .await - .unwrap(); - } - } - } - - fn enqueue_documents_to_embed( - job: EmbeddingJob, - queue_len: &mut usize, - embeddings_queue: &mut Vec<(i64, Vec, PathBuf, SystemTime, JobHandle)>, - embed_batch_tx: &channel::Sender, PathBuf, SystemTime, JobHandle)>>, - ) { - // Handle edge case where individual file has more documents than max batch size - let should_flush = match job { - EmbeddingJob::Enqueue { - documents, - worktree_id, - path, - mtime, - job_handle, - } => { - // If documents is greater than embeddings batch size, recursively batch existing rows. 
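                // This recursive, count-based batching is superseded by `EmbeddingQueue`, which
                // flushes whenever the pending batch would exceed the provider's
                // `max_tokens_per_batch()` and is otherwise drained on a timer via `flush()`.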
- if &documents.len() > &EMBEDDINGS_BATCH_SIZE { - let first_job = EmbeddingJob::Enqueue { - documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(), - worktree_id, - path: path.clone(), - mtime, - job_handle: job_handle.clone(), - }; - - Self::enqueue_documents_to_embed( - first_job, - queue_len, - embeddings_queue, - embed_batch_tx, - ); - - let second_job = EmbeddingJob::Enqueue { - documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(), - worktree_id, - path: path.clone(), - mtime, - job_handle: job_handle.clone(), - }; - - Self::enqueue_documents_to_embed( - second_job, - queue_len, - embeddings_queue, - embed_batch_tx, - ); - return; - } else { - *queue_len += &documents.len(); - embeddings_queue.push((worktree_id, documents, path, mtime, job_handle)); - *queue_len >= EMBEDDINGS_BATCH_SIZE - } - } - EmbeddingJob::Flush => true, - }; - - if should_flush { - embed_batch_tx - .try_send(mem::take(embeddings_queue)) - .unwrap(); - *queue_len = 0; - } - } - async fn parse_file( fs: &Arc, pending_file: PendingFile, retriever: &mut CodeContextRetriever, - batch_files_tx: &channel::Sender, - parsing_files_rx: &channel::Receiver, - db_update_tx: &channel::Sender, + embedding_queue: &Arc>, + embeddings_for_digest: &HashMap, ) { + let Some(language) = pending_file.language else { + return; + }; + if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { - if let Some(documents) = retriever - .parse_file_with_template( - &pending_file.relative_path, - &content, - pending_file.language, - ) + if let Some(mut documents) = retriever + .parse_file_with_template(&pending_file.relative_path, &content, language) .log_err() { log::trace!( @@ -597,64 +325,21 @@ impl SemanticIndex { documents.len() ); - if documents.len() == 0 { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id: pending_file.worktree_db_id, - documents, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - }) - .await - .unwrap(); - } else { - batch_files_tx - .try_send(EmbeddingJob::Enqueue { - worktree_id: pending_file.worktree_db_id, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - documents, - }) - .unwrap(); + for document in documents.iter_mut() { + if let Some(embedding) = embeddings_for_digest.get(&document.digest) { + document.embedding = Some(embedding.to_owned()); + } } + + embedding_queue.lock().push(FileToEmbed { + worktree_id: pending_file.worktree_db_id, + path: pending_file.relative_path, + mtime: pending_file.modified_time, + job_handle: pending_file.job_handle, + documents, + }); } } - - if parsing_files_rx.len() == 0 { - batch_files_tx.try_send(EmbeddingJob::Flush).unwrap(); - } - } - - fn find_or_create_worktree(&self, path: PathBuf) -> impl Future> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx }) - .unwrap(); - async move { rx.await? } - } - - fn get_file_mtimes( - &self, - worktree_id: i64, - ) -> impl Future>> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::FileMTimes { - worktree_id, - sender: tx, - }) - .unwrap(); - async move { rx.await? } - } - - fn worktree_previously_indexed(&self, path: Arc) -> impl Future> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx }) - .unwrap(); - async move { rx.await? 
} } pub fn project_previously_indexed( @@ -665,7 +350,10 @@ impl SemanticIndex { let worktrees_indexed_previously = project .read(cx) .worktrees(cx) - .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path())) + .map(|worktree| { + self.db + .worktree_previously_indexed(&worktree.read(cx).abs_path()) + }) .collect::>(); cx.spawn(|_, _cx| async move { let worktree_indexed_previously = @@ -679,103 +367,73 @@ impl SemanticIndex { } fn project_entries_changed( - &self, + &mut self, project: ModelHandle, changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, cx: &mut ModelContext<'_, SemanticIndex>, worktree_id: &WorktreeId, - ) -> Result<()> { - let parsing_files_tx = self.parsing_files_tx.clone(); - let db_update_tx = self.db_update_tx.clone(); - let (job_queue_tx, outstanding_job_tx, worktree_db_id) = { - let state = self - .projects - .get(&project.downgrade()) - .ok_or(anyhow!("Project not yet initialized"))?; - let worktree_db_id = state - .db_id_for_worktree_id(*worktree_id) - .ok_or(anyhow!("Worktree ID in Database Not Available"))?; - ( - state.job_queue_tx.clone(), - state._outstanding_job_count_tx.clone(), - worktree_db_id, - ) + ) { + let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else { + return; + }; + let project = project.downgrade(); + let Some(project_state) = self.projects.get_mut(&project) else { + return; }; - let language_registry = self.language_registry.clone(); - let parsing_files_tx = parsing_files_tx.clone(); - let db_update_tx = db_update_tx.clone(); - - let worktree = project - .read(cx) - .worktree_for_id(worktree_id.clone(), cx) - .ok_or(anyhow!("Worktree not available"))? - .read(cx) - .snapshot(); - cx.spawn(|_, _| async move { - let worktree = worktree.clone(); - for (path, entry_id, path_change) in changes.iter() { - let relative_path = path.to_path_buf(); - let absolute_path = worktree.absolutize(path); - - let Some(entry) = worktree.entry_for_id(*entry_id) else { - continue; - }; - if entry.is_ignored || entry.is_symlink || entry.is_external { - continue; + let embeddings_for_digest = { + let mut worktree_id_file_paths = HashMap::new(); + for (path, _) in &project_state.changed_paths { + if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id) + { + worktree_id_file_paths + .entry(worktree_db_id) + .or_insert(Vec::new()) + .push(path.path.clone()); } + } + self.db.embeddings_for_files(worktree_id_file_paths) + }; - log::trace!("File Event: {:?}, Path: {:?}", &path_change, &path); - match path_change { - PathChange::AddedOrUpdated | PathChange::Updated | PathChange::Added => { - if let Ok(language) = language_registry - .language_for_file(&relative_path, None) - .await - { - if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; - } + let worktree = worktree.read(cx); + let change_time = Instant::now(); + for (path, entry_id, change) in changes.iter() { + let Some(entry) = worktree.entry_for_id(*entry_id) else { + continue; + }; + if entry.is_ignored || entry.is_symlink || entry.is_external { + continue; + } + let project_path = ProjectPath { + worktree_id: *worktree_id, + path: path.clone(), + }; + project_state.changed_paths.insert( + project_path, + ChangedPathInfo { + changed_at: change_time, + mtime: entry.mtime, + is_deleted: *change == PathChange::Removed, + }, + ); + } - let job_handle = 
JobHandle::new(&outstanding_job_tx); - let new_operation = IndexOperation::IndexFile { - absolute_path: absolute_path.clone(), - payload: PendingFile { - worktree_db_id, - relative_path, - absolute_path, - language, - modified_time: entry.mtime, - job_handle, - }, - tx: parsing_files_tx.clone(), - }; - let _ = job_queue_tx.try_send(new_operation); - } - } - PathChange::Removed => { - let new_operation = IndexOperation::DeleteFile { - absolute_path, - payload: DbOperation::Delete { - worktree_id: worktree_db_id, - path: relative_path, - }, - tx: db_update_tx.clone(), - }; - let _ = job_queue_tx.try_send(new_operation); - } - _ => {} - } + cx.spawn_weak(|this, mut cx| async move { + let embeddings_for_digest = embeddings_for_digest.await.log_err().unwrap_or_default(); + + cx.background().timer(BACKGROUND_INDEXING_DELAY).await; + if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { + Self::reindex_changed_paths( + this, + project, + Some(change_time), + &mut cx, + Arc::new(embeddings_for_digest), + ) + .await; } }) .detach(); - - Ok(()) } pub fn initialize_project( @@ -799,20 +457,18 @@ impl SemanticIndex { .read(cx) .worktrees(cx) .map(|worktree| { - self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf()) + self.db + .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf()) }) .collect::>(); let _subscription = cx.subscribe(&project, |this, project, event, cx| { if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { - let _ = - this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id); + this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id); }; }); let language_registry = self.language_registry.clone(); - let parsing_files_tx = self.parsing_files_tx.clone(); - let db_update_tx = self.db_update_tx.clone(); cx.spawn(|this, mut cx| async move { futures::future::join_all(worktree_scans_complete).await; @@ -833,7 +489,7 @@ impl SemanticIndex { db_ids_by_worktree_id.insert(worktree.id(), db_id); worktree_file_mtimes.insert( worktree.id(), - this.read_with(&cx, |this, _| this.get_file_mtimes(db_id)) + this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id)) .await?, ); } @@ -843,17 +499,13 @@ impl SemanticIndex { .map(|(a, b)| (*a, *b)) .collect(); - let (job_count_tx, job_count_rx) = watch::channel_with(0); - let job_count_tx = Arc::new(Mutex::new(job_count_tx)); - let job_count_tx_longlived = job_count_tx.clone(); - - let worktree_files = cx + let changed_paths = cx .background() .spawn(async move { - let mut worktree_files = Vec::new(); + let mut changed_paths = BTreeMap::new(); + let now = Instant::now(); for worktree in worktrees.into_iter() { let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap(); - let worktree_db_id = db_ids_by_worktree_id[&worktree.id()]; for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); @@ -876,59 +528,51 @@ impl SemanticIndex { continue; } - let path_buf = file.path.to_path_buf(); let stored_mtime = file_mtimes.remove(&file.path.to_path_buf()); let already_stored = stored_mtime .map_or(false, |existing_mtime| existing_mtime == file.mtime); if !already_stored { - let job_handle = JobHandle::new(&job_count_tx); - worktree_files.push(IndexOperation::IndexFile { - absolute_path: absolute_path.clone(), - payload: PendingFile { - worktree_db_id, - relative_path: path_buf, - absolute_path, - language, - job_handle, - modified_time: file.mtime, + changed_paths.insert( + ProjectPath { + 
worktree_id: worktree.id(), + path: file.path.clone(), }, - tx: parsing_files_tx.clone(), - }); + ChangedPathInfo { + changed_at: now, + mtime: file.mtime, + is_deleted: false, + }, + ); } } } + // Clean up entries from database that are no longer in the worktree. - for (path, _) in file_mtimes { - worktree_files.push(IndexOperation::DeleteFile { - absolute_path: worktree.absolutize(path.as_path()), - payload: DbOperation::Delete { - worktree_id: worktree_db_id, - path, + for (path, mtime) in file_mtimes { + changed_paths.insert( + ProjectPath { + worktree_id: worktree.id(), + path: path.into(), }, - tx: db_update_tx.clone(), - }); + ChangedPathInfo { + changed_at: now, + mtime, + is_deleted: true, + }, + ); } } - anyhow::Ok(worktree_files) + anyhow::Ok(changed_paths) }) .await?; - this.update(&mut cx, |this, cx| { - let project_state = ProjectState::new( - cx, - _subscription, - worktree_db_ids, - job_count_rx, - job_count_tx_longlived, + this.update(&mut cx, |this, _| { + this.projects.insert( + project.downgrade(), + ProjectState::new(_subscription, worktree_db_ids, changed_paths), ); - - for op in worktree_files { - let _ = project_state.job_queue_tx.try_send(op); - } - - this.projects.insert(project.downgrade(), project_state); }); Result::<(), _>::Ok(()) }) @@ -939,27 +583,45 @@ impl SemanticIndex { project: ModelHandle, cx: &mut ModelContext, ) -> Task)>> { - let state = self.projects.get_mut(&project.downgrade()); - let state = if state.is_none() { - return Task::Ready(Some(Err(anyhow!("Project not yet initialized")))); - } else { - state.unwrap() - }; - - // let parsing_files_tx = self.parsing_files_tx.clone(); - // let db_update_tx = self.db_update_tx.clone(); - let job_count_rx = state.outstanding_job_count_rx.clone(); - let count = state.get_outstanding_count(); - cx.spawn(|this, mut cx| async move { - this.update(&mut cx, |this, _| { - let Some(state) = this.projects.get_mut(&project.downgrade()) else { - return; - }; - let _ = state.job_queue_tx.try_send(IndexOperation::FlushQueue); - }); + let embeddings_for_digest = this.read_with(&cx, |this, _| { + if let Some(state) = this.projects.get(&project.downgrade()) { + let mut worktree_id_file_paths = HashMap::default(); + for (path, _) in &state.changed_paths { + if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id) + { + worktree_id_file_paths + .entry(worktree_db_id) + .or_insert(Vec::new()) + .push(path.path.clone()); + } + } - Ok((count, job_count_rx)) + Ok(this.db.embeddings_for_files(worktree_id_file_paths)) + } else { + Err(anyhow!("Project not yet initialized")) + } + })?; + + let embeddings_for_digest = Arc::new(embeddings_for_digest.await?); + + Self::reindex_changed_paths( + this.clone(), + project.clone(), + None, + &mut cx, + embeddings_for_digest, + ) + .await; + + this.update(&mut cx, |this, _cx| { + let Some(state) = this.projects.get(&project.downgrade()) else { + return Err(anyhow!("Project not yet initialized")); + }; + let job_count_rx = state.outstanding_job_count_rx.clone(); + let count = state.get_outstanding_count(); + Ok((count, job_count_rx)) + }) }) } @@ -1000,14 +662,15 @@ impl SemanticIndex { .collect::>(); let embedding_provider = self.embedding_provider.clone(); - let database_url = self.database_url.clone(); + let db_path = self.db.path().clone(); let fs = self.fs.clone(); cx.spawn(|this, mut cx| async move { let t0 = Instant::now(); - let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; + let database = + VectorDatabase::new(fs.clone(), 
db_path.clone(), cx.background()).await?; let phrase_embedding = embedding_provider - .embed_batch(vec![&phrase]) + .embed_batch(vec![phrase]) .await? .into_iter() .next() @@ -1018,8 +681,9 @@ impl SemanticIndex { t0.elapsed().as_millis() ); - let file_ids = - database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?; + let file_ids = database + .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes) + .await?; let batch_n = cx.background().num_cpus(); let ids_len = file_ids.clone().len(); @@ -1029,27 +693,24 @@ impl SemanticIndex { ids_len / batch_n }; - let mut result_tasks = Vec::new(); + let mut batch_results = Vec::new(); for batch in file_ids.chunks(batch_size) { let batch = batch.into_iter().map(|v| *v).collect::>(); let limit = limit.clone(); let fs = fs.clone(); - let database_url = database_url.clone(); + let db_path = db_path.clone(); let phrase_embedding = phrase_embedding.clone(); - let task = cx.background().spawn(async move { - let database = VectorDatabase::new(fs, database_url).await.log_err(); - if database.is_none() { - return Err(anyhow!("failed to acquire database connection")); - } else { - database - .unwrap() - .top_k_search(&phrase_embedding, limit, batch.as_slice()) - } - }); - result_tasks.push(task); + if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background()) + .await + .log_err() + { + batch_results.push(async move { + db.top_k_search(&phrase_embedding, limit, batch.as_slice()) + .await + }); + } } - - let batch_results = futures::future::join_all(result_tasks).await; + let batch_results = futures::future::join_all(batch_results).await; let mut results = Vec::new(); for batch_result in batch_results { @@ -1068,7 +729,7 @@ impl SemanticIndex { } let ids = results.into_iter().map(|(id, _)| id).collect::>(); - let documents = database.get_documents_by_ids(ids.as_slice())?; + let documents = database.get_documents_by_ids(ids.as_slice()).await?; let mut tasks = Vec::new(); let mut ranges = Vec::new(); @@ -1110,6 +771,97 @@ impl SemanticIndex { .collect::>()) }) } + + async fn reindex_changed_paths( + this: ModelHandle, + project: ModelHandle, + last_changed_before: Option, + cx: &mut AsyncAppContext, + embeddings_for_digest: Arc>, + ) { + let mut pending_files = Vec::new(); + let mut files_to_delete = Vec::new(); + let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| { + if let Some(project_state) = this.projects.get_mut(&project.downgrade()) { + let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; + let db_ids = &project_state.worktree_db_ids; + let mut worktree: Option> = None; + + project_state.changed_paths.retain(|path, info| { + if let Some(last_changed_before) = last_changed_before { + if info.changed_at > last_changed_before { + return true; + } + } + + if worktree + .as_ref() + .map_or(true, |tree| tree.read(cx).id() != path.worktree_id) + { + worktree = project.read(cx).worktree_for_id(path.worktree_id, cx); + } + let Some(worktree) = &worktree else { + return false; + }; + + let Some(worktree_db_id) = db_ids + .iter() + .find_map(|entry| (entry.0 == path.worktree_id).then_some(entry.1)) + else { + return false; + }; + + if info.is_deleted { + files_to_delete.push((worktree_db_id, path.path.to_path_buf())); + } else { + let absolute_path = worktree.read(cx).absolutize(&path.path); + let job_handle = JobHandle::new(&outstanding_job_count_tx); + pending_files.push(PendingFile { + absolute_path, + relative_path: path.path.to_path_buf(), + language: None, + job_handle, + 
modified_time: info.mtime, + worktree_db_id, + }); + } + + false + }); + } + + ( + this.db.clone(), + this.language_registry.clone(), + this.parsing_files_tx.clone(), + ) + }); + + for (worktree_db_id, path) in files_to_delete { + db.delete_file(worktree_db_id, path).await.log_err(); + } + + for mut pending_file in pending_files { + if let Ok(language) = language_registry + .language_for_file(&pending_file.relative_path, None) + .await + { + if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) + && &language.name().as_ref() != &"Markdown" + && language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } + pending_file.language = Some(language); + } + parsing_files_tx + .try_send((embeddings_for_digest.clone(), pending_file)) + .ok(); + } + } } impl Entity for SemanticIndex { diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 32d8bb0fb8..f549e68e04 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,14 +1,15 @@ use crate::{ - db::dot, - embedding::EmbeddingProvider, - parsing::{subtract_ranges, CodeContextRetriever, Document}, + embedding::{DummyEmbeddings, Embedding, EmbeddingProvider}, + embedding_queue::EmbeddingQueue, + parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest}, semantic_index_settings::SemanticIndexSettings, - SearchResult, SemanticIndex, + FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; use anyhow::Result; use async_trait::async_trait; -use gpui::{Task, TestAppContext}; +use gpui::{executor::Deterministic, Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; +use parking_lot::Mutex; use pretty_assertions::assert_eq; use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project}; use rand::{rngs::StdRng, Rng}; @@ -20,8 +21,10 @@ use std::{ atomic::{self, AtomicUsize}, Arc, }, + time::SystemTime, }; use unindent::Unindent; +use util::RandomCharIter; #[ctor::ctor] fn init_logger() { @@ -31,12 +34,8 @@ fn init_logger() { } #[gpui::test] -async fn test_semantic_index(cx: &mut TestAppContext) { - cx.update(|cx| { - cx.set_global(SettingsStore::test(cx)); - settings::register::(cx); - settings::register::(cx); - }); +async fn test_semantic_index(deterministic: Arc, cx: &mut TestAppContext) { + init_test(cx); let fs = FakeFs::new(cx.background()); fs.insert_tree( @@ -56,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { fn bbb() { println!(\"bbbbbbbbbbbbb!\"); } + struct pqpqpqp {} ".unindent(), "file3.toml": " ZZZZZZZZZZZZZZZZZZ = 5 @@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { let db_path = db_dir.path().join("db.sqlite"); let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let store = SemanticIndex::new( + let semantic_index = SemanticIndex::new( fs.clone(), db_path, embedding_provider.clone(), @@ -87,21 +87,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) { let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - let _ = store + let _ = semantic_index .update(cx, |store, cx| { store.initialize_project(project.clone(), cx) }) .await; - let (file_count, outstanding_file_count) = store + let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) .await .unwrap(); assert_eq!(file_count, 3); - 
cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); assert_eq!(*outstanding_file_count.borrow(), 0); - let search_results = store + let search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -122,6 +122,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { (Path::new("src/file2.rs").into(), 0), (Path::new("src/file3.toml").into(), 0), (Path::new("src/file1.rs").into(), 45), + (Path::new("src/file2.rs").into(), 45), ], cx, ); @@ -129,7 +130,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { // Test Include Files Functonality let include_files = vec![PathMatcher::new("*.rs").unwrap()]; let exclude_files = vec![PathMatcher::new("*.rs").unwrap()]; - let rust_only_search_results = store + let rust_only_search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -149,11 +150,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) { (Path::new("src/file1.rs").into(), 0), (Path::new("src/file2.rs").into(), 0), (Path::new("src/file1.rs").into(), 45), + (Path::new("src/file2.rs").into(), 45), ], cx, ); - let no_rust_search_results = store + let no_rust_search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -186,24 +188,87 @@ async fn test_semantic_index(cx: &mut TestAppContext) { .await .unwrap(); - cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); let prev_embedding_count = embedding_provider.embedding_count(); - let (file_count, outstanding_file_count) = store + let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) .await .unwrap(); assert_eq!(file_count, 1); - cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); assert_eq!(*outstanding_file_count.borrow(), 0); assert_eq!( embedding_provider.embedding_count() - prev_embedding_count, - 2 + 1 ); } +#[gpui::test(iterations = 10)] +async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { + let (outstanding_job_count, _) = postage::watch::channel_with(0); + let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count)); + + let files = (1..=3) + .map(|file_ix| FileToEmbed { + worktree_id: 5, + path: format!("path-{file_ix}").into(), + mtime: SystemTime::now(), + documents: (0..rng.gen_range(4..22)) + .map(|document_ix| { + let content_len = rng.gen_range(10..100); + let content = RandomCharIter::new(&mut rng) + .with_simple_text() + .take(content_len) + .collect::(); + let digest = DocumentDigest::from(content.as_str()); + Document { + range: 0..10, + embedding: None, + name: format!("document {document_ix}"), + content, + digest, + token_count: rng.gen_range(10..30), + } + }) + .collect(), + job_handle: JobHandle::new(&outstanding_job_count), + }) + .collect::>(); + + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + + let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background()); + for file in &files { + queue.push(file.clone()); + } + queue.flush(); + + cx.foreground().run_until_parked(); + let finished_files = queue.finished_files(); + let mut embedded_files: Vec<_> = files + .iter() + .map(|_| finished_files.try_recv().expect("no finished file")) + .collect(); + + let expected_files: Vec<_> = files + .iter() + .map(|file| { + let mut file = file.clone(); + for doc in &mut file.documents { + doc.embedding = 
Some(embedding_provider.embed_sync(doc.content.as_ref())); + } + file + }) + .collect(); + + embedded_files.sort_by_key(|f| f.path.clone()); + + assert_eq!(embedded_files, expected_files); +} + #[track_caller] fn assert_search_results( actual: &[SearchResult], @@ -227,7 +292,8 @@ fn assert_search_results( #[gpui::test] async fn test_code_context_retrieval_rust() { let language = rust_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /// A doc comment @@ -314,7 +380,8 @@ async fn test_code_context_retrieval_rust() { #[gpui::test] async fn test_code_context_retrieval_json() { let language = json_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" { @@ -397,7 +464,8 @@ fn assert_documents_eq( #[gpui::test] async fn test_code_context_retrieval_javascript() { let language = js_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /* globals importScripts, backend */ @@ -495,7 +563,8 @@ async fn test_code_context_retrieval_javascript() { #[gpui::test] async fn test_code_context_retrieval_lua() { let language = lua_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" -- Creates a new class @@ -568,7 +637,8 @@ async fn test_code_context_retrieval_lua() { #[gpui::test] async fn test_code_context_retrieval_elixir() { let language = elixir_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" defmodule File.Stream do @@ -684,7 +754,8 @@ async fn test_code_context_retrieval_elixir() { #[gpui::test] async fn test_code_context_retrieval_cpp() { let language = cpp_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /** @@ -836,7 +907,8 @@ async fn test_code_context_retrieval_cpp() { #[gpui::test] async fn test_code_context_retrieval_ruby() { let language = ruby_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" # This concern is inspired by "sudo mode" on GitHub. 
It @@ -1026,7 +1098,8 @@ async fn test_code_context_retrieval_ruby() { #[gpui::test] async fn test_code_context_retrieval_php() { let language = php_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" f32 { - let factor = (10.0 as f32).powi(decimal_places); - (n * factor).round() / factor - } - - fn reference_dot(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() - } -} - #[derive(Default)] struct FakeEmbeddingProvider { embedding_count: AtomicUsize, @@ -1212,35 +1255,42 @@ impl FakeEmbeddingProvider { fn embedding_count(&self) -> usize { self.embedding_count.load(atomic::Ordering::SeqCst) } + + fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } } #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + fn truncate(&self, span: &str) -> (String, usize) { + (span.to_string(), 1) + } + + fn max_tokens_per_batch(&self) -> usize { + 200 + } + + async fn embed_batch(&self, spans: Vec) -> Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); - Ok(spans - .iter() - .map(|span| { - let mut result = vec![1.0; 26]; - for letter in span.chars() { - let letter = letter.to_ascii_lowercase(); - if letter as u32 >= 'a' as u32 { - let ix = (letter as u32) - ('a' as u32); - if ix < 26 { - result[ix as usize] += 1.0; - } - } - } - - let norm = result.iter().map(|x| x * x).sum::().sqrt(); - for x in &mut result { - *x /= norm; - } - - result - }) - .collect()) + Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } @@ -1684,3 +1734,11 @@ fn test_subtract_ranges() { assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]); } + +fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + cx.set_global(SettingsStore::test(cx)); + settings::register::(cx); + settings::register::(cx); + }); +} diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index c8beb86aef..785426ed4c 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -260,11 +260,22 @@ pub fn defer(f: F) -> impl Drop { Defer(Some(f)) } -pub struct RandomCharIter(T); +pub struct RandomCharIter { + rng: T, + simple_text: bool, +} impl RandomCharIter { pub fn new(rng: T) -> Self { - Self(rng) + Self { + rng, + simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()), + } + } + + pub fn with_simple_text(mut self) -> Self { + self.simple_text = true; + self } } @@ -272,25 +283,27 @@ impl Iterator for RandomCharIter { type Item = char; fn next(&mut self) -> Option { - if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) { - return if self.0.gen_range(0..100) < 5 { + if self.simple_text { + return if self.rng.gen_range(0..100) < 5 { Some('\n') } else { - Some(self.0.gen_range(b'a'..b'z' + 1).into()) + Some(self.rng.gen_range(b'a'..b'z' + 1).into()) }; } - match self.0.gen_range(0..100) { + match self.rng.gen_range(0..100) { // whitespace - 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(), + 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut 
self.rng).copied(), // two-byte greek letters - 20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))), + 20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))), // // three-byte characters - 33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(), + 33..=45 => ['✋', '✅', '❌', '❎', '⭐'] + .choose(&mut self.rng) + .copied(), // // four-byte characters - 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(), + 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(), // ascii letters - _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()), + _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()), } } }
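
For reference, a minimal standalone sketch (not part of the patch) of the letter-frequency embedding that `FakeEmbeddingProvider::embed_sync` computes in the test changes above, with a plain `Vec<f32>` standing in for the crate's `Embedding` type. Each ASCII letter bumps one of 26 counters that start at 1.0, and the vector is normalized to unit length, so the dot product of two embeddings is their cosine similarity.

```rust
/// Deterministic letter-frequency embedding, mirroring the test provider's
/// `embed_sync`: baseline of 1.0 per letter, then L2-normalized.
fn embed_sync(span: &str) -> Vec<f32> {
    let mut result = vec![1.0_f32; 26];
    for letter in span.chars() {
        let letter = letter.to_ascii_lowercase();
        if letter.is_ascii_lowercase() {
            let ix = (letter as u32 - 'a' as u32) as usize;
            result[ix] += 1.0;
        }
    }
    let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
    for x in &mut result {
        *x /= norm;
    }
    result
}

fn main() {
    let a = embed_sync("fn aaa() {}");
    let b = embed_sync("fn aaa() {}");
    let c = embed_sync("struct pqpqpqp {}");
    // Identical spans embed identically; different spans differ.
    assert_eq!(a, b);
    assert_ne!(a, c);
    // Unit norm means the dot product is the cosine similarity.
    let similarity: f32 = a.iter().zip(&c).map(|(x, y)| x * y).sum();
    println!("similarity = {similarity:.3}");
}
```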
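
The reworked `parse_file` reuses previously stored embeddings by content digest: documents whose `DocumentDigest` already appears in `embeddings_for_digest` get their embedding filled in, and only the rest reach the embedding queue. The following is an illustrative sketch under simplified types (`[u8; 20]` for the digest, `Vec<f32>` for the embedding, a stripped-down `Document`), not the crate's real API.

```rust
use std::collections::HashMap;

type Digest = [u8; 20]; // stand-in for DocumentDigest (a hash of the span's content)
type Embedding = Vec<f32>;

struct Document {
    digest: Digest,
    content: String,
    embedding: Option<Embedding>,
}

/// Fill in cached embeddings by digest; return the documents that still need
/// to be sent to the embedding provider.
fn reuse_cached<'a>(
    documents: &'a mut [Document],
    cached: &HashMap<Digest, Embedding>,
) -> Vec<&'a mut Document> {
    let mut to_embed = Vec::new();
    for document in documents.iter_mut() {
        if let Some(embedding) = cached.get(&document.digest) {
            document.embedding = Some(embedding.clone());
        } else {
            to_embed.push(document);
        }
    }
    to_embed
}

fn main() {
    let cached: HashMap<Digest, Embedding> = HashMap::new();
    let mut docs = vec![Document {
        digest: [0; 20],
        content: "fn main() {}".into(),
        embedding: None,
    }];
    // Nothing cached yet, so the document still needs the provider.
    let needs_embedding = reuse_cached(&mut docs, &cached);
    assert_eq!(needs_embedding.len(), 1);
}
```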
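
The other structural change is the debounce in `project_entries_changed` / `reindex_changed_paths`: worktree events only timestamp entries in `changed_paths`, and a delayed background pass drains the entries whose last change is at or before the cutoff, leaving anything touched since for a later pass. A hedged sketch of that retain-and-drain step, assuming placeholder types and a placeholder delay value (the real constant and `ProjectPath` type live elsewhere in the crate):

```rust
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::time::{Duration, Instant};

// Placeholder value; the real BACKGROUND_INDEXING_DELAY is defined elsewhere in the crate.
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5);

struct ChangedPathInfo {
    changed_at: Instant,
    is_deleted: bool,
}

/// Remove and return the paths whose most recent change happened at or before
/// `last_changed_before`, keeping newer entries queued for a later pass.
fn drain_changes(
    changed_paths: &mut BTreeMap<PathBuf, ChangedPathInfo>,
    last_changed_before: Option<Instant>,
) -> Vec<(PathBuf, bool)> {
    let mut ready = Vec::new();
    changed_paths.retain(|path, info| {
        if let Some(cutoff) = last_changed_before {
            if info.changed_at > cutoff {
                return true; // changed again since this pass was scheduled
            }
        }
        ready.push((path.clone(), info.is_deleted));
        false
    });
    ready
}

fn main() {
    let mut changed_paths = BTreeMap::new();
    changed_paths.insert(
        PathBuf::from("src/file1.rs"),
        ChangedPathInfo { changed_at: Instant::now(), is_deleted: false },
    );
    // With no cutoff, everything queued so far is drained.
    let ready = drain_changes(&mut changed_paths, None);
    assert_eq!(ready.len(), 1);
    assert!(changed_paths.is_empty());
    let _ = BACKGROUND_INDEXING_DELAY; // in the real crate this gates the background spawn
}
```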