Reify Embedding/Sha1 structs that can be (de)serialized from SQL

Co-Authored-By: Kyle Caverly <kyle@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-08-31 17:55:43 +02:00
parent c763e728d1
commit 3001a46f69
5 changed files with 180 additions and 138 deletions

View file

@ -1,13 +1,10 @@
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
use futures::channel::oneshot;
use gpui::executor;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::{
params,
types::{FromSql, FromSqlResult, ValueRef},
};
use rusqlite::params;
use std::{
cmp::Ordering,
collections::HashMap,
@ -27,34 +24,6 @@ pub struct FileRecord {
pub mtime: Timestamp,
}
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
#[derive(Debug)]
struct Sha1(pub Vec<u8>);
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
return Ok(Embedding(embedding.unwrap()));
}
}
impl FromSql for Sha1 {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if sha1.is_err() {
return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
}
return Ok(Sha1(sha1.unwrap()));
}
}
#[derive(Clone)]
pub struct VectorDatabase {
path: Arc<Path>,
@ -255,9 +224,6 @@ impl VectorDatabase {
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
let sha_blob = bincode::serialize(&document.sha1)?;
db.execute(
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
@ -265,8 +231,8 @@ impl VectorDatabase {
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
embedding_blob,
sha_blob
document.embedding,
document.sha1
],
)?;
}
@ -351,7 +317,7 @@ impl VectorDatabase {
pub fn top_k_search(
&self,
query_embedding: &Vec<f32>,
query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
@ -360,7 +326,7 @@ impl VectorDatabase {
self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
Self::for_each_document(db, &file_ids, |id, embedding| {
let similarity = dot(&embedding, &query_embedding);
let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
@ -417,7 +383,7 @@ impl VectorDatabase {
fn for_each_document(
db: &rusqlite::Connection,
file_ids: &[i64],
mut f: impl FnMut(i64, Vec<f32>),
mut f: impl FnMut(i64, Embedding),
) -> Result<()> {
let mut query_statement = db.prepare(
"
@ -435,7 +401,7 @@ impl VectorDatabase {
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.for_each(|(id, embedding)| f(id, embedding.0));
.for_each(|(id, embedding)| f(id, embedding));
Ok(())
}
@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
.collect::<Vec<_>>(),
)
}
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
let len = vec_a.len();
assert_eq!(len, vec_b.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
vec_a.as_ptr(),
len as isize,
1,
vec_b.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
result
}

View file

@ -8,6 +8,8 @@ use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parse_duration::parse;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
@ -20,6 +22,62 @@ lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
}
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> f32 {
let len = self.0.len();
assert_eq!(len, other.0.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
self.0.as_ptr(),
len as isize,
1,
other.0.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
result
}
}
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
Ok(Embedding(embedding.unwrap()))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
@ -53,7 +111,7 @@ struct OpenAIEmbeddingUsage {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
}
@ -62,10 +120,10 @@ pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
let dummy_vec = vec![0.32 as f32; 1536];
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
return Ok(vec![dummy_vec; spans.len()]);
}
@ -137,7 +195,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
(output, token_count)
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
@ -175,7 +233,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
return Ok(response
.data
.into_iter()
.map(|embedding| embedding.embedding)
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
@ -218,3 +276,49 @@ impl EmbeddingProvider for OpenAIEmbeddings {
Err(anyhow!("openai max retries"))
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::prelude::*;
#[gpui::test]
fn test_similarity(mut rng: StdRng) {
assert_eq!(
Embedding::from(vec![1., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
0.
);
assert_eq!(
Embedding::from(vec![2., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
6.
);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
let a = Embedding::from(a);
let b = Embedding::from(b);
assert_eq!(
round_to_decimals(a.similarity(&b), 1),
round_to_decimals(reference_dot(&a.0, &b.0), 1)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
}

View file

@ -121,7 +121,7 @@ impl EmbeddingQueue {
&mut fragment.file.lock().documents[fragment.document_range.clone()]
{
if let Some(embedding) = embeddings.next() {
document.embedding = embedding;
document.embedding = Some(embedding);
} else {
//
log::error!("number of embeddings returned different from number of documents");

View file

@ -1,7 +1,11 @@
use crate::embedding::EmbeddingProvider;
use anyhow::{anyhow, Ok, Result};
use crate::embedding::{EmbeddingProvider, Embedding};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use sha1::{Digest, Sha1};
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
use sha1::Digest;
use std::{
cmp::{self, Reverse},
collections::HashSet,
@ -11,13 +15,43 @@ use std::{
};
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Clone)]
pub struct Sha1([u8; 20]);
impl FromSql for Sha1 {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let blob = value.as_blob()?;
let bytes =
blob.try_into()
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
expected_size: 20,
blob_size: blob.len(),
})?;
return Ok(Sha1(bytes));
}
}
impl ToSql for Sha1 {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
self.0.to_sql()
}
}
impl From<&'_ str> for Sha1 {
fn from(value: &'_ str) -> Self {
let mut sha1 = sha1::Sha1::new();
sha1.update(value);
Self(sha1.finalize().into())
}
}
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
pub name: String,
pub range: Range<usize>,
pub content: String,
pub embedding: Vec<f32>,
pub sha1: [u8; 20],
pub embedding: Option<Embedding>,
pub sha1: Sha1,
pub token_count: usize,
}
@ -69,17 +103,16 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let mut sha1 = Sha1::new();
sha1.update(&document_span);
let sha1 = Sha1::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
embedding: Default::default(),
name: language_name.to_string(),
sha1: sha1.finalize().into(),
sha1,
token_count,
}])
}
@ -88,18 +121,14 @@ impl CodeContextRetriever {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<item>", &content);
let mut sha1 = Sha1::new();
sha1.update(&document_span);
let sha1 = Sha1::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
embedding: None,
name: "Markdown".to_string(),
sha1: sha1.finalize().into(),
sha1,
token_count,
}])
}
@ -279,15 +308,13 @@ impl CodeContextRetriever {
);
}
let mut sha1 = Sha1::new();
sha1.update(&document_content);
let sha1 = Sha1::from(document_content.as_str());
documents.push(Document {
name,
content: document_content,
range: item_range.clone(),
embedding: vec![],
sha1: sha1.finalize().into(),
embedding: None,
sha1,
token_count: 0,
})
}

View file

@ -1,8 +1,7 @@
use crate::{
db::dot,
embedding::{DummyEmbeddings, EmbeddingProvider},
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
embedding_queue::EmbeddingQueue,
parsing::{subtract_ranges, CodeContextRetriever, Document},
parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1},
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex,
};
@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
documents: (0..rng.gen_range(4..22))
.map(|document_ix| {
let content_len = rng.gen_range(10..100);
let content = RandomCharIter::new(&mut rng)
.with_simple_text()
.take(content_len)
.collect::<String>();
let sha1 = Sha1::from(content.as_str());
Document {
range: 0..10,
embedding: Vec::new(),
embedding: None,
name: format!("document {document_ix}"),
content: RandomCharIter::new(&mut rng)
.with_simple_text()
.take(content_len)
.collect(),
sha1: rng.gen(),
content,
sha1,
token_count: rng.gen_range(10..30),
}
})
@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
.map(|file| {
let mut file = file.clone();
for doc in &mut file.documents {
doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
}
file
})
@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() {
);
}
#[gpui::test]
fn test_dot_product(mut rng: StdRng) {
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
assert_eq!(
round_to_decimals(dot(&a, &b), 1),
round_to_decimals(reference_dot(&a, &b), 1)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
fn embed_sync(&self, span: &str) -> Vec<f32> {
fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider {
*x /= norm;
}
result
result.into()
}
}
@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
200
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())