Mirror of https://github.com/zed-industries/zed.git, synced 2025-01-30 06:05:19 +00:00

Reify Embedding/Sha1 structs that can be (de)serialized from SQL

Co-Authored-By: Kyle Caverly <kyle@zed.dev>

commit 3001a46f69
parent c763e728d1

5 changed files with 180 additions and 138 deletions
db.rs

@@ -1,13 +1,10 @@
-use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
+use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION};
 use anyhow::{anyhow, Context, Result};
 use futures::channel::oneshot;
 use gpui::executor;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
-use rusqlite::{
-    params,
-    types::{FromSql, FromSqlResult, ValueRef},
-};
+use rusqlite::params;
 use std::{
     cmp::Ordering,
     collections::HashMap,
@@ -27,34 +24,6 @@ pub struct FileRecord {
     pub mtime: Timestamp,
 }
 
-#[derive(Debug)]
-struct Embedding(pub Vec<f32>);
-
-#[derive(Debug)]
-struct Sha1(pub Vec<u8>);
-
-impl FromSql for Embedding {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let bytes = value.as_blob()?;
-        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
-        if embedding.is_err() {
-            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
-        }
-        return Ok(Embedding(embedding.unwrap()));
-    }
-}
-
-impl FromSql for Sha1 {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let bytes = value.as_blob()?;
-        let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
-        if sha1.is_err() {
-            return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
-        }
-        return Ok(Sha1(sha1.unwrap()));
-    }
-}
-
 #[derive(Clone)]
 pub struct VectorDatabase {
     path: Arc<Path>,
@@ -255,9 +224,6 @@ impl VectorDatabase {
         // Currently inserting at approximately 3400 documents a second
         // I imagine we can speed this up with a bulk insert of some kind.
         for document in documents {
-            let embedding_blob = bincode::serialize(&document.embedding)?;
-            let sha_blob = bincode::serialize(&document.sha1)?;
-
             db.execute(
                 "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                 params![
@@ -265,8 +231,8 @@ impl VectorDatabase {
                     document.range.start.to_string(),
                     document.range.end.to_string(),
                     document.name,
-                    embedding_blob,
-                    sha_blob
+                    document.embedding,
+                    document.sha1
                 ],
             )?;
         }
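An aside on the `params!` call above: it can now take `document.embedding` and `document.sha1` directly because the reified types implement rusqlite's `ToSql` (the impls appear further down, in embedding.rs and parsing.rs), so serialization happens inside the binding rather than at each call site. A minimal self-contained sketch of the same pattern with a hypothetical `Scores` newtype:

    use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
    use rusqlite::{params, Connection, ToSql};

    // Hypothetical newtype stored as a bincode blob, mirroring the
    // ToSql/FromSql pattern this commit gives Embedding and Sha1.
    struct Scores(Vec<f32>);

    impl ToSql for Scores {
        fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
            let bytes = bincode::serialize(&self.0)
                .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
            Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
        }
    }

    impl FromSql for Scores {
        fn column_result(value: ValueRef) -> FromSqlResult<Self> {
            let bytes = value.as_blob()?;
            bincode::deserialize(bytes)
                .map(Scores)
                .map_err(|err| rusqlite::types::FromSqlError::Other(err))
        }
    }

    fn main() -> rusqlite::Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute("CREATE TABLE t (scores BLOB)", [])?;
        // The newtype goes straight into params!; no serialize step at the call site.
        db.execute("INSERT INTO t (scores) VALUES (?1)", params![Scores(vec![0.1, 0.2])])?;
        let loaded: Scores = db.query_row("SELECT scores FROM t", [], |row| row.get(0))?;
        assert_eq!(loaded.0, vec![0.1, 0.2]);
        Ok(())
    }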
@@ -351,7 +317,7 @@ impl VectorDatabase {
 
     pub fn top_k_search(
         &self,
-        query_embedding: &Vec<f32>,
+        query_embedding: &Embedding,
         limit: usize,
         file_ids: &[i64],
     ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
@@ -360,7 +326,7 @@ impl VectorDatabase {
         self.transact(move |db| {
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
             Self::for_each_document(db, &file_ids, |id, embedding| {
-                let similarity = dot(&embedding, &query_embedding);
+                let similarity = embedding.similarity(&query_embedding);
                 let ix = match results.binary_search_by(|(_, s)| {
                     similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                 }) {
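A note on the `binary_search_by` just below the changed line: `results` is kept sorted by descending similarity, which is why the comparison is flipped (`similarity.partial_cmp(&s)` compares the query's score against each element's), and both `Ok` and `Err` yield a usable insertion index. The insert/truncate steps fall outside this hunk, so the following standalone sketch of the top-k pattern is an assumption about the elided code, not a quote of it:

    use std::cmp::Ordering;

    // Sketch: keep only the `limit` highest-scoring ids, in descending order.
    fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
        // Flipped comparison keeps the vector sorted by descending similarity.
        let ix = match results.binary_search_by(|(_, s)| {
            similarity.partial_cmp(s).unwrap_or(Ordering::Equal)
        }) {
            Ok(ix) | Err(ix) => ix,
        };
        results.insert(ix, (id, similarity));
        results.truncate(limit);
    }

    fn main() {
        // Capacity limit + 1, matching the with_capacity(limit + 1) above:
        // one slot for the transient element before truncation.
        let mut results = Vec::with_capacity(3);
        for (id, sim) in [(1, 0.2), (2, 0.9), (3, 0.5), (4, 0.7)] {
            push_top_k(&mut results, id, sim, 2);
        }
        assert_eq!(results, vec![(2, 0.9), (4, 0.7)]);
    }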
@@ -417,7 +383,7 @@ impl VectorDatabase {
     fn for_each_document(
         db: &rusqlite::Connection,
         file_ids: &[i64],
-        mut f: impl FnMut(i64, Vec<f32>),
+        mut f: impl FnMut(i64, Embedding),
     ) -> Result<()> {
         let mut query_statement = db.prepare(
             "
@@ -435,7 +401,7 @@ impl VectorDatabase {
                 Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|(id, embedding)| f(id, embedding.0));
+            .for_each(|(id, embedding)| f(id, embedding));
         Ok(())
     }
 
@@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
             .collect::<Vec<_>>(),
     )
 }
-
-pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
-    let len = vec_a.len();
-    assert_eq!(len, vec_b.len());
-
-    let mut result = 0.0;
-    unsafe {
-        matrixmultiply::sgemm(
-            1,
-            len,
-            1,
-            1.0,
-            vec_a.as_ptr(),
-            len as isize,
-            1,
-            vec_b.as_ptr(),
-            1,
-            len as isize,
-            0.0,
-            &mut result as *mut f32,
-            1,
-            1,
-        );
-    }
-    result
-}
embedding.rs

@@ -8,6 +8,8 @@ use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
 use parse_duration::parse;
+use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
 use serde::{Deserialize, Serialize};
 use std::env;
 use std::sync::Arc;
@@ -20,6 +22,62 @@ lazy_static! {
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
+#[derive(Debug, PartialEq, Clone)]
+pub struct Embedding(Vec<f32>);
+
+impl From<Vec<f32>> for Embedding {
+    fn from(value: Vec<f32>) -> Self {
+        Embedding(value)
+    }
+}
+
+impl Embedding {
+    pub fn similarity(&self, other: &Self) -> f32 {
+        let len = self.0.len();
+        assert_eq!(len, other.0.len());
+
+        let mut result = 0.0;
+        unsafe {
+            matrixmultiply::sgemm(
+                1,
+                len,
+                1,
+                1.0,
+                self.0.as_ptr(),
+                len as isize,
+                1,
+                other.0.as_ptr(),
+                1,
+                len as isize,
+                0.0,
+                &mut result as *mut f32,
+                1,
+                1,
+            );
+        }
+        result
+    }
+}
+
+impl FromSql for Embedding {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+        if embedding.is_err() {
+            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+        }
+        Ok(Embedding(embedding.unwrap()))
+    }
+}
+
+impl ToSql for Embedding {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        let bytes = bincode::serialize(&self.0)
+            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+    }
+}
+
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
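For readers unfamiliar with `matrixmultiply`: `sgemm` computes C := alpha·A·B + beta·C, so a 1×len row vector times a len×1 column vector yields a 1×1 C, i.e. the dot product. A sketch restating `similarity` with each argument named (names follow the crate's documented parameter order; the zero-vector edge case is ignored, as in the original):

    // C := alpha * A * B + beta * C, where A is 1 x len and B is len x 1,
    // so C is the 1 x 1 dot product.
    fn dot_via_sgemm(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len());
        let len = a.len();
        let mut result = 0.0f32;
        unsafe {
            matrixmultiply::sgemm(
                1,                       // m: rows of A and C
                len,                     // k: columns of A / rows of B
                1,                       // n: columns of B and C
                1.0,                     // alpha
                a.as_ptr(),              // A, a 1 x len row vector
                len as isize,            // rsa: row stride of A
                1,                       // csa: column stride of A
                b.as_ptr(),              // B, a len x 1 column vector
                1,                       // rsb: row stride of B
                len as isize,            // csb: column stride of B
                0.0,                     // beta: overwrite C rather than accumulate
                &mut result as *mut f32, // C, the 1 x 1 output
                1,                       // rsc
                1,                       // csc
            );
        }
        result
    }

    fn main() {
        assert_eq!(dot_via_sgemm(&[2., 0., 1.], &[3., 1., 4.]), 10.);
    }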
@@ -53,7 +111,7 @@ struct OpenAIEmbeddingUsage {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn truncate(&self, span: &str) -> (String, usize);
 }
@@ -62,10 +120,10 @@ pub struct DummyEmbeddings {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
-        let dummy_vec = vec![0.32 as f32; 1536];
+        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
         return Ok(vec![dummy_vec; spans.len()]);
     }
 
@@ -137,7 +195,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         (output, token_count)
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
@@ -175,7 +233,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                         return Ok(response
                             .data
                             .into_iter()
-                            .map(|embedding| embedding.embedding)
+                            .map(|embedding| Embedding::from(embedding.embedding))
                             .collect());
                     }
                     StatusCode::TOO_MANY_REQUESTS => {
@@ -218,3 +276,49 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         Err(anyhow!("openai max retries"))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::prelude::*;
+
+    #[gpui::test]
+    fn test_similarity(mut rng: StdRng) {
+        assert_eq!(
+            Embedding::from(vec![1., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+            0.
+        );
+        assert_eq!(
+            Embedding::from(vec![2., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+            6.
+        );
+
+        for _ in 0..100 {
+            let size = 1536;
+            let mut a = vec![0.; size];
+            let mut b = vec![0.; size];
+            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+                *a = rng.gen();
+                *b = rng.gen();
+            }
+            let a = Embedding::from(a);
+            let b = Embedding::from(b);
+
+            assert_eq!(
+                round_to_decimals(a.similarity(&b), 1),
+                round_to_decimals(reference_dot(&a.0, &b.0), 1)
+            );
+        }
+
+        fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+            let factor = (10.0 as f32).powi(decimal_places);
+            (n * factor).round() / factor
+        }
+
+        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
+            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+        }
+    }
+}
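Since `embed_batch` now returns `Result<Vec<Embedding>>` across the trait, every provider converts its raw `Vec<f32>` output at the boundary via `Embedding::from`. A minimal sketch of a conforming provider, with a hypothetical name (`ConstantProvider`) and a stand-in `truncate` (the real OpenAI impl counts BPE tokens), assuming the crate's `Embedding` and `EmbeddingProvider` are in scope:

    use anyhow::Result;
    use async_trait::async_trait;

    struct ConstantProvider;

    #[async_trait]
    impl EmbeddingProvider for ConstantProvider {
        async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
            // One fixed 1536-dimensional vector per span, wrapped at the boundary.
            Ok(spans
                .iter()
                .map(|_| Embedding::from(vec![0.0; 1536]))
                .collect())
        }

        fn max_tokens_per_batch(&self) -> usize {
            4096
        }

        fn truncate(&self, span: &str) -> (String, usize) {
            // Stand-in: returns the span unchanged with its char count as "tokens".
            (span.to_string(), span.chars().count())
        }
    }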
embedding_queue.rs

@@ -121,7 +121,7 @@ impl EmbeddingQueue {
                 &mut fragment.file.lock().documents[fragment.document_range.clone()]
             {
                 if let Some(embedding) = embeddings.next() {
-                    document.embedding = embedding;
+                    document.embedding = Some(embedding);
                 } else {
                     //
                     log::error!("number of embeddings returned different from number of documents");
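The payoff of `Option<Embedding>` here: a freshly parsed `Document` now carries `None` (see parsing.rs below) until the queue receives the provider's vector at this point, so an unembedded document is distinguishable by type rather than by an empty `Vec<f32>` sentinel.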
parsing.rs

@@ -1,7 +1,11 @@
-use crate::embedding::EmbeddingProvider;
-use anyhow::{anyhow, Ok, Result};
+use crate::embedding::{EmbeddingProvider, Embedding};
+use anyhow::{anyhow, Result};
 use language::{Grammar, Language};
-use sha1::{Digest, Sha1};
+use rusqlite::{
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
+};
+use sha1::Digest;
 use std::{
     cmp::{self, Reverse},
     collections::HashSet,
@@ -11,13 +15,43 @@ use std::{
 };
 use tree_sitter::{Parser, QueryCursor};
 
+#[derive(Debug, PartialEq, Clone)]
+pub struct Sha1([u8; 20]);
+
+impl FromSql for Sha1 {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let blob = value.as_blob()?;
+        let bytes =
+            blob.try_into()
+                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
+                    expected_size: 20,
+                    blob_size: blob.len(),
+                })?;
+        return Ok(Sha1(bytes));
+    }
+}
+
+impl ToSql for Sha1 {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        self.0.to_sql()
+    }
+}
+
+impl From<&'_ str> for Sha1 {
+    fn from(value: &'_ str) -> Self {
+        let mut sha1 = sha1::Sha1::new();
+        sha1.update(value);
+        Self(sha1.finalize().into())
+    }
+}
+
 #[derive(Debug, PartialEq, Clone)]
 pub struct Document {
     pub name: String,
     pub range: Range<usize>,
     pub content: String,
-    pub embedding: Vec<f32>,
-    pub sha1: [u8; 20],
+    pub embedding: Option<Embedding>,
+    pub sha1: Sha1,
     pub token_count: usize,
 }
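The `From<&str>` impl hashes the string's UTF-8 bytes into a fixed 20-byte digest, which is why the call sites below shrink to one line. A small sketch of the full round trip through a blob column, reusing the impls above (table name hypothetical, crate's `Sha1` assumed in scope):

    use rusqlite::{params, Connection};

    fn main() -> anyhow::Result<()> {
        // One-line hashing via From<&str>.
        let sha1 = Sha1::from("fn main() {}");

        // Round trip through SQLite, exercising the ToSql and FromSql impls.
        let db = Connection::open_in_memory()?;
        db.execute("CREATE TABLE docs_example (sha1 BLOB)", [])?;
        db.execute("INSERT INTO docs_example (sha1) VALUES (?1)", params![sha1])?;
        let loaded: Sha1 = db.query_row("SELECT sha1 FROM docs_example", [], |row| row.get(0))?;
        assert_eq!(loaded, sha1);
        Ok(())
    }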
@@ -69,17 +103,16 @@ impl CodeContextRetriever {
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
 
-        let mut sha1 = Sha1::new();
-        sha1.update(&document_span);
+        let sha1 = Sha1::from(document_span.as_str());
 
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
-            embedding: Vec::new(),
+            embedding: Default::default(),
             name: language_name.to_string(),
-            sha1: sha1.finalize().into(),
+            sha1,
             token_count,
         }])
     }
@@ -88,18 +121,14 @@ impl CodeContextRetriever {
         let document_span = MARKDOWN_CONTEXT_TEMPLATE
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<item>", &content);
 
-        let mut sha1 = Sha1::new();
-        sha1.update(&document_span);
-
+        let sha1 = Sha1::from(document_span.as_str());
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
-            embedding: Vec::new(),
+            embedding: None,
             name: "Markdown".to_string(),
-            sha1: sha1.finalize().into(),
+            sha1,
             token_count,
         }])
     }
@@ -279,15 +308,13 @@ impl CodeContextRetriever {
             );
         }
 
-        let mut sha1 = Sha1::new();
-        sha1.update(&document_content);
-
+        let sha1 = Sha1::from(document_content.as_str());
         documents.push(Document {
             name,
             content: document_content,
             range: item_range.clone(),
-            embedding: vec![],
-            sha1: sha1.finalize().into(),
+            embedding: None,
+            sha1,
             token_count: 0,
         })
     }
semantic_index_tests.rs

@@ -1,8 +1,7 @@
 use crate::{
-    db::dot,
-    embedding::{DummyEmbeddings, EmbeddingProvider},
+    embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
     embedding_queue::EmbeddingQueue,
-    parsing::{subtract_ranges, CodeContextRetriever, Document},
+    parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1},
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex,
 };
@@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
                 documents: (0..rng.gen_range(4..22))
                     .map(|document_ix| {
                         let content_len = rng.gen_range(10..100);
+                        let content = RandomCharIter::new(&mut rng)
+                            .with_simple_text()
+                            .take(content_len)
+                            .collect::<String>();
+                        let sha1 = Sha1::from(content.as_str());
                         Document {
                             range: 0..10,
-                            embedding: Vec::new(),
+                            embedding: None,
                             name: format!("document {document_ix}"),
-                            content: RandomCharIter::new(&mut rng)
-                                .with_simple_text()
-                                .take(content_len)
-                                .collect(),
-                            sha1: rng.gen(),
+                            content,
+                            sha1,
                             token_count: rng.gen_range(10..30),
                         }
                     })
@@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
             .map(|file| {
                 let mut file = file.clone();
                 for doc in &mut file.documents {
-                    doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
+                    doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
                 }
                 file
             })
@@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() {
     );
 }
 
-#[gpui::test]
-fn test_dot_product(mut rng: StdRng) {
-    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
-    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
-
-    for _ in 0..100 {
-        let size = 1536;
-        let mut a = vec![0.; size];
-        let mut b = vec![0.; size];
-        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
-            *a = rng.gen();
-            *b = rng.gen();
-        }
-
-        assert_eq!(
-            round_to_decimals(dot(&a, &b), 1),
-            round_to_decimals(reference_dot(&a, &b), 1)
-        );
-    }
-
-    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
-        let factor = (10.0 as f32).powi(decimal_places);
-        (n * factor).round() / factor
-    }
-
-    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
-        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
-    }
-}
-
 #[derive(Default)]
 struct FakeEmbeddingProvider {
     embedding_count: AtomicUsize,
@@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider {
         self.embedding_count.load(atomic::Ordering::SeqCst)
     }
 
-    fn embed_sync(&self, span: &str) -> Vec<f32> {
+    fn embed_sync(&self, span: &str) -> Embedding {
         let mut result = vec![1.0; 26];
         for letter in span.chars() {
             let letter = letter.to_ascii_lowercase();
@@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider {
             *x /= norm;
         }
 
-        result
+        result.into()
     }
 }
 
@@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         200
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
         Ok(spans.iter().map(|span| self.embed_sync(span)).collect())