diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 313df40674..81b05720d2 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,13 +1,10 @@ -use crate::{parsing::Document, SEMANTIC_INDEX_VERSION}; +use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION}; use anyhow::{anyhow, Context, Result}; use futures::channel::oneshot; use gpui::executor; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; -use rusqlite::{ - params, - types::{FromSql, FromSqlResult, ValueRef}, -}; +use rusqlite::params; use std::{ cmp::Ordering, collections::HashMap, @@ -27,34 +24,6 @@ pub struct FileRecord { pub mtime: Timestamp, } -#[derive(Debug)] -struct Embedding(pub Vec); - -#[derive(Debug)] -struct Sha1(pub Vec); - -impl FromSql for Embedding { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - let embedding: Result, Box> = bincode::deserialize(bytes); - if embedding.is_err() { - return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); - } - return Ok(Embedding(embedding.unwrap())); - } -} - -impl FromSql for Sha1 { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - let sha1: Result, Box> = bincode::deserialize(bytes); - if sha1.is_err() { - return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err())); - } - return Ok(Sha1(sha1.unwrap())); - } -} - #[derive(Clone)] pub struct VectorDatabase { path: Arc, @@ -255,9 +224,6 @@ impl VectorDatabase { // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. for document in documents { - let embedding_blob = bincode::serialize(&document.embedding)?; - let sha_blob = bincode::serialize(&document.sha1)?; - db.execute( "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", params![ @@ -265,8 +231,8 @@ impl VectorDatabase { document.range.start.to_string(), document.range.end.to_string(), document.name, - embedding_blob, - sha_blob + document.embedding, + document.sha1 ], )?; } @@ -351,7 +317,7 @@ impl VectorDatabase { pub fn top_k_search( &self, - query_embedding: &Vec, + query_embedding: &Embedding, limit: usize, file_ids: &[i64], ) -> impl Future>> { @@ -360,7 +326,7 @@ impl VectorDatabase { self.transact(move |db| { let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); Self::for_each_document(db, &file_ids, |id, embedding| { - let similarity = dot(&embedding, &query_embedding); + let similarity = embedding.similarity(&query_embedding); let ix = match results.binary_search_by(|(_, s)| { similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) }) { @@ -417,7 +383,7 @@ impl VectorDatabase { fn for_each_document( db: &rusqlite::Connection, file_ids: &[i64], - mut f: impl FnMut(i64, Vec), + mut f: impl FnMut(i64, Embedding), ) -> Result<()> { let mut query_statement = db.prepare( " @@ -435,7 +401,7 @@ impl VectorDatabase { Ok((row.get(0)?, row.get::<_, Embedding>(1)?)) })? .filter_map(|row| row.ok()) - .for_each(|(id, embedding)| f(id, embedding.0)); + .for_each(|(id, embedding)| f(id, embedding)); Ok(()) } @@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc> { .collect::>(), ) } - -pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 { - let len = vec_a.len(); - assert_eq!(len, vec_b.len()); - - let mut result = 0.0; - unsafe { - matrixmultiply::sgemm( - 1, - len, - 1, - 1.0, - vec_a.as_ptr(), - len as isize, - 1, - vec_b.as_ptr(), - 1, - len as isize, - 0.0, - &mut result as *mut f32, - 1, - 1, - ); - } - result -} diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 60e13a9e01..97c25ca170 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -8,6 +8,8 @@ use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; use parse_duration::parse; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; +use rusqlite::ToSql; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; @@ -20,6 +22,62 @@ lazy_static! { static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(Vec); + +impl From> for Embedding { + fn from(value: Vec) -> Self { + Embedding(value) + } +} + +impl Embedding { + pub fn similarity(&self, other: &Self) -> f32 { + let len = self.0.len(); + assert_eq!(len, other.0.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + self.0.as_ptr(), + len as isize, + 1, + other.0.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + result + } +} + +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + Ok(Embedding(embedding.unwrap())) + } +} + +impl ToSql for Embedding { + fn to_sql(&self) -> rusqlite::Result { + let bytes = bincode::serialize(&self.0) + .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; + Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) + } +} + #[derive(Clone)] pub struct OpenAIEmbeddings { pub client: Arc, @@ -53,7 +111,7 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { - async fn embed_batch(&self, spans: Vec) -> Result>>; + async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn truncate(&self, span: &str) -> (String, usize); } @@ -62,10 +120,10 @@ pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. - let dummy_vec = vec![0.32 as f32; 1536]; + let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); return Ok(vec![dummy_vec; spans.len()]); } @@ -137,7 +195,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { (output, token_count) } - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -175,7 +233,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { return Ok(response .data .into_iter() - .map(|embedding| embedding.embedding) + .map(|embedding| Embedding::from(embedding.embedding)) .collect()); } StatusCode::TOO_MANY_REQUESTS => { @@ -218,3 +276,49 @@ impl EmbeddingProvider for OpenAIEmbeddings { Err(anyhow!("openai max retries")) } } + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[gpui::test] + fn test_similarity(mut rng: StdRng) { + assert_eq!( + Embedding::from(vec![1., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])), + 0. + ); + assert_eq!( + Embedding::from(vec![2., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])), + 6. + ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() + } + } +} diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index c3a5de1373..4c82ced918 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -121,7 +121,7 @@ impl EmbeddingQueue { &mut fragment.file.lock().documents[fragment.document_range.clone()] { if let Some(embedding) = embeddings.next() { - document.embedding = embedding; + document.embedding = Some(embedding); } else { // log::error!("number of embeddings returned different from number of documents"); diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 51f1bd7ca9..2b67f41714 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,7 +1,11 @@ -use crate::embedding::EmbeddingProvider; -use anyhow::{anyhow, Ok, Result}; +use crate::embedding::{EmbeddingProvider, Embedding}; +use anyhow::{anyhow, Result}; use language::{Grammar, Language}; -use sha1::{Digest, Sha1}; +use rusqlite::{ + types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, + ToSql, +}; +use sha1::Digest; use std::{ cmp::{self, Reverse}, collections::HashSet, @@ -11,13 +15,43 @@ use std::{ }; use tree_sitter::{Parser, QueryCursor}; +#[derive(Debug, PartialEq, Clone)] +pub struct Sha1([u8; 20]); + +impl FromSql for Sha1 { + fn column_result(value: ValueRef) -> FromSqlResult { + let blob = value.as_blob()?; + let bytes = + blob.try_into() + .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize { + expected_size: 20, + blob_size: blob.len(), + })?; + return Ok(Sha1(bytes)); + } +} + +impl ToSql for Sha1 { + fn to_sql(&self) -> rusqlite::Result { + self.0.to_sql() + } +} + +impl From<&'_ str> for Sha1 { + fn from(value: &'_ str) -> Self { + let mut sha1 = sha1::Sha1::new(); + sha1.update(value); + Self(sha1.finalize().into()) + } +} + #[derive(Debug, PartialEq, Clone)] pub struct Document { pub name: String, pub range: Range, pub content: String, - pub embedding: Vec, - pub sha1: [u8; 20], + pub embedding: Option, + pub sha1: Sha1, pub token_count: usize, } @@ -69,17 +103,16 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("", &content); - let mut sha1 = Sha1::new(); - sha1.update(&document_span); + let sha1 = Sha1::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), content: document_span, - embedding: Vec::new(), + embedding: Default::default(), name: language_name.to_string(), - sha1: sha1.finalize().into(), + sha1, token_count, }]) } @@ -88,18 +121,14 @@ impl CodeContextRetriever { let document_span = MARKDOWN_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", &content); - - let mut sha1 = Sha1::new(); - sha1.update(&document_span); - + let sha1 = Sha1::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); - Ok(vec![Document { range: 0..content.len(), content: document_span, - embedding: Vec::new(), + embedding: None, name: "Markdown".to_string(), - sha1: sha1.finalize().into(), + sha1, token_count, }]) } @@ -279,15 +308,13 @@ impl CodeContextRetriever { ); } - let mut sha1 = Sha1::new(); - sha1.update(&document_content); - + let sha1 = Sha1::from(document_content.as_str()); documents.push(Document { name, content: document_content, range: item_range.clone(), - embedding: vec![], - sha1: sha1.finalize().into(), + embedding: None, + sha1, token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index dc41c09f7a..75232eb4d2 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,8 +1,7 @@ use crate::{ - db::dot, - embedding::{DummyEmbeddings, EmbeddingProvider}, + embedding::{DummyEmbeddings, Embedding, EmbeddingProvider}, embedding_queue::EmbeddingQueue, - parsing::{subtract_ranges, CodeContextRetriever, Document}, + parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1}, semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, }; @@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { documents: (0..rng.gen_range(4..22)) .map(|document_ix| { let content_len = rng.gen_range(10..100); + let content = RandomCharIter::new(&mut rng) + .with_simple_text() + .take(content_len) + .collect::(); + let sha1 = Sha1::from(content.as_str()); Document { range: 0..10, - embedding: Vec::new(), + embedding: None, name: format!("document {document_ix}"), - content: RandomCharIter::new(&mut rng) - .with_simple_text() - .take(content_len) - .collect(), - sha1: rng.gen(), + content, + sha1, token_count: rng.gen_range(10..30), } }) @@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .map(|file| { let mut file = file.clone(); for doc in &mut file.documents { - doc.embedding = embedding_provider.embed_sync(doc.content.as_ref()); + doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref())); } file }) @@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() { ); } -#[gpui::test] -fn test_dot_product(mut rng: StdRng) { - assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.); - assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.); - - for _ in 0..100 { - let size = 1536; - let mut a = vec![0.; size]; - let mut b = vec![0.; size]; - for (a, b) in a.iter_mut().zip(b.iter_mut()) { - *a = rng.gen(); - *b = rng.gen(); - } - - assert_eq!( - round_to_decimals(dot(&a, &b), 1), - round_to_decimals(reference_dot(&a, &b), 1) - ); - } - - fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { - let factor = (10.0 as f32).powi(decimal_places); - (n * factor).round() / factor - } - - fn reference_dot(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() - } -} - #[derive(Default)] struct FakeEmbeddingProvider { embedding_count: AtomicUsize, @@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider { self.embedding_count.load(atomic::Ordering::SeqCst) } - fn embed_sync(&self, span: &str) -> Vec { + fn embed_sync(&self, span: &str) -> Embedding { let mut result = vec![1.0; 26]; for letter in span.chars() { let letter = letter.to_ascii_lowercase(); @@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider { *x /= norm; } - result + result.into() } } @@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider { 200 } - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); Ok(spans.iter().map(|span| self.embed_sync(span)).collect())