added sha1 encoding for each document

KCaverly 2023-08-21 16:35:57 +02:00
parent bbe6d3b261
commit 3d89cd10a4
6 changed files with 245 additions and 207 deletions

Cargo.lock (generated, 411 lines changed)

File diff suppressed because it is too large


@@ -38,6 +38,7 @@ parking_lot.workspace = true
rand.workspace = true
schemars.workspace = true
globset.workspace = true
sha1 = "0.10.5"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }


@@ -26,6 +26,9 @@ pub struct FileRecord {
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
#[derive(Debug)]
struct Sha1(pub Vec<u8>);
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
@@ -37,6 +40,17 @@ impl FromSql for Embedding {
}
}
impl FromSql for Sha1 {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if sha1.is_err() {
return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
}
return Ok(Sha1(sha1.unwrap()));
}
}
pub struct VectorDatabase {
db: rusqlite::Connection,
}
@@ -132,6 +146,7 @@ impl VectorDatabase {
end_byte INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
sha1 BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
@@ -182,15 +197,17 @@ impl VectorDatabase {
// I imagine we can speed this up with a bulk insert of some kind.
for document in documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
let sha_blob = bincode::serialize(&document.sha1)?;
self.db.execute(
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
file_id,
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
embedding_blob
embedding_blob,
sha_blob
],
)?;
}
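The digest is stored the same way as the embedding: serialized with bincode into the new sha1 BLOB NOT NULL column. One detail worth noting: the insert path serializes a [u8; 20] (Document::sha1) while the FromSql impl for Sha1 deserializes a Vec<u8>, and bincode 1.x encodes those two layouts differently (fixed-size array vs. length-prefixed vector). The standalone sketch below (not part of the commit; assumes bincode 1.x and a placeholder digest value) keeps the same type on both sides of the round trip.

// Minimal, type-consistent round trip of a digest blob with bincode.
// The digest here is a stand-in value, not a real SHA-1 hash.
fn main() -> Result<(), Box<bincode::ErrorKind>> {
    let digest: [u8; 20] = [0xab; 20];
    let blob = bincode::serialize(&digest)?; // what the INSERT above would store
    let restored: [u8; 20] = bincode::deserialize(&blob)?; // read back with the same type
    assert_eq!(restored, digest);
    Ok(())
}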


@@ -39,7 +39,7 @@ struct OpenAIEmbeddingResponse {
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
embedding: Vec<f16>,
index: usize,
object: String,
}


@@ -1,5 +1,6 @@
use anyhow::{anyhow, Ok, Result};
use language::{Grammar, Language};
use sha1::{Digest, Sha1};
use std::{
cmp::{self, Reverse},
collections::HashSet,
@@ -15,6 +16,7 @@ pub struct Document {
pub range: Range<usize>,
pub content: String,
pub embedding: Vec<f32>,
pub sha1: [u8; 20],
}
const CODE_CONTEXT_TEMPLATE: &str =
@@ -63,11 +65,15 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let mut sha1 = Sha1::new();
sha1.update(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
name: language_name.to_string(),
sha1: sha1.finalize().into(),
}])
}
@@ -76,11 +82,15 @@ impl CodeContextRetriever {
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<item>", &content);
let mut sha1 = Sha1::new();
sha1.update(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
name: "Markdown".to_string(),
sha1: sha1.finalize().into(),
}])
}
@@ -253,11 +263,15 @@ impl CodeContextRetriever {
);
}
let mut sha1 = Sha1::new();
sha1.update(&document_content);
documents.push(Document {
name,
content: document_content,
range: item_range.clone(),
embedding: vec![],
sha1: sha1.finalize().into(),
})
}
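All three retriever paths above repeat the same new/update/finalize hashing steps before building a Document. A small sketch of factoring that into a helper (hypothetical function name, not part of the commit; assumes the sha1 0.10 API already imported above):

use sha1::{Digest, Sha1};

// Hypothetical helper: digest a rendered document span into the
// [u8; 20] value stored on Document::sha1.
fn span_digest(span: &str) -> [u8; 20] {
    let mut sha1 = Sha1::new();
    sha1.update(span);
    sha1.finalize().into()
}

Each call site could then set sha1: span_digest(&document_span) (or &document_content) instead of repeating the three lines.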


@@ -34,7 +34,7 @@ use util::{
ResultExt,
};
const SEMANTIC_INDEX_VERSION: usize = 6;
const SEMANTIC_INDEX_VERSION: usize = 7;
const EMBEDDINGS_BATCH_SIZE: usize = 80;
pub fn init(
@@ -92,6 +92,7 @@ pub struct SemanticIndex {
struct ProjectState {
worktree_db_ids: Vec<(WorktreeId, i64)>,
file_mtimes: HashMap<PathBuf, SystemTime>,
outstanding_job_count_rx: watch::Receiver<usize>,
_outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
}
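Bumping SEMANTIC_INDEX_VERSION from 6 to 7 presumably invalidates indexes built with the old schema, since the documents table now has a NOT NULL sha1 column that version-6 databases never populated. A sketch of the kind of gate such a constant typically drives (hypothetical helper and table handling, not the code in this commit):

// Hypothetical version gate: rebuild the on-disk index when the stored
// schema version does not match the current constant, so the new sha1
// column exists and gets populated by a fresh indexing pass.
fn ensure_schema_version(db: &rusqlite::Connection, stored_version: usize) -> rusqlite::Result<()> {
    const SEMANTIC_INDEX_VERSION: usize = 7;
    if stored_version != SEMANTIC_INDEX_VERSION {
        db.execute("DROP TABLE IF EXISTS documents", [])?;
        // ...recreate the tables and queue a full re-index here.
    }
    Ok(())
}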