added sha1 encoding for each document

KCaverly 2023-08-21 16:35:57 +02:00
parent bbe6d3b261
commit 3d89cd10a4
6 changed files with 245 additions and 207 deletions

Cargo.lock (generated, 411 lines): diff suppressed because it is too large.


@@ -38,6 +38,7 @@ parking_lot.workspace = true
 rand.workspace = true
 schemars.workspace = true
 globset.workspace = true
+sha1 = "0.10.5"

 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }


@@ -26,6 +26,9 @@ pub struct FileRecord {
 #[derive(Debug)]
 struct Embedding(pub Vec<f32>);

+#[derive(Debug)]
+struct Sha1(pub Vec<u8>);
+
 impl FromSql for Embedding {
     fn column_result(value: ValueRef) -> FromSqlResult<Self> {
         let bytes = value.as_blob()?;
@@ -37,6 +40,17 @@ impl FromSql for Embedding {
     }
 }

+impl FromSql for Sha1 {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+        if sha1.is_err() {
+            return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
+        }
+        return Ok(Sha1(sha1.unwrap()));
+    }
+}
+
 pub struct VectorDatabase {
     db: rusqlite::Connection,
 }
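The is_err/unwrap_err/unwrap sequence in the new FromSql impl is correct but verbose; map_err and the ? operator express the same decoding in one step. An equivalent, more idiomatic form (a sketch, not part of this commit):

impl FromSql for Sha1 {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        // Same bincode decoding, with the error mapped into FromSqlError::Other.
        let sha1: Vec<u8> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Sha1(sha1))
    }
}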
@@ -132,6 +146,7 @@ impl VectorDatabase {
             end_byte INTEGER NOT NULL,
             name VARCHAR NOT NULL,
             embedding BLOB NOT NULL,
+            sha1 BLOB NOT NULL,
             FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
         )",
         [],
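On the read side, the Sha1 newtype's FromSql impl above lets the new column come straight out of a row. A sketch of such a lookup (the query and the document_id variable are illustrative, not part of this commit):

// Reading the new column back through the `Sha1` newtype.
let sha1: Sha1 = self.db.query_row(
    "SELECT sha1 FROM documents WHERE id = ?1",
    params![document_id],
    |row| row.get(0),
)?;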
@@ -182,15 +197,17 @@ impl VectorDatabase {
         // I imagine we can speed this up with a bulk insert of some kind.
         for document in documents {
             let embedding_blob = bincode::serialize(&document.embedding)?;
+            let sha_blob = bincode::serialize(&document.sha1)?;
             self.db.execute(
-                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
+                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                 params![
                     file_id,
                     document.range.start.to_string(),
                     document.range.end.to_string(),
                     document.name,
-                    embedding_blob
+                    embedding_blob,
+                    sha_blob
                 ],
             )?;
         }
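The comment above already hints at the next optimization: preparing the INSERT once and wrapping the loop in a transaction lets SQLite parse the statement a single time and commit once. A minimal sketch, assuming the surrounding method can take &mut self, which rusqlite's Connection::transaction requires:

// Sketch of the bulk insert hinted at in the code comment above.
let tx = self.db.transaction()?;
{
    let mut stmt = tx.prepare(
        "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1)
         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
    )?;
    for document in documents {
        let embedding_blob = bincode::serialize(&document.embedding)?;
        let sha_blob = bincode::serialize(&document.sha1)?;
        stmt.execute(params![
            file_id,
            document.range.start.to_string(),
            document.range.end.to_string(),
            document.name,
            embedding_blob,
            sha_blob
        ])?;
    }
} // statement dropped here, before the commit
tx.commit()?;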


@@ -39,7 +39,7 @@ struct OpenAIEmbeddingResponse {
 #[derive(Debug, Deserialize)]
 struct OpenAIEmbedding {
-    embedding: Vec<f32>,
+    embedding: Vec<f16>,
     index: usize,
     object: String,
 }
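Note that f16 is not a built-in Rust type, so this hunk presumably leans on an external type such as half::f16, which implements serde's Deserialize behind that crate's serde feature. A sketch of the assumed setup:

// Assumed, not shown in this diff: the `half` crate providing `f16`,
// with its `serde` feature enabled so the JSON floats deserialize.
use half::f16;
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f16>,
    index: usize,
    object: String,
}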


@ -1,5 +1,6 @@
use anyhow::{anyhow, Ok, Result}; use anyhow::{anyhow, Ok, Result};
use language::{Grammar, Language}; use language::{Grammar, Language};
use sha1::{Digest, Sha1};
use std::{ use std::{
cmp::{self, Reverse}, cmp::{self, Reverse},
collections::HashSet, collections::HashSet,
@ -15,6 +16,7 @@ pub struct Document {
pub range: Range<usize>, pub range: Range<usize>,
pub content: String, pub content: String,
pub embedding: Vec<f32>, pub embedding: Vec<f32>,
pub sha1: [u8; 20],
} }
const CODE_CONTEXT_TEMPLATE: &str = const CODE_CONTEXT_TEMPLATE: &str =
@@ -63,11 +65,15 @@ impl CodeContextRetriever {
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);

+        let mut sha1 = Sha1::new();
+        sha1.update(&document_span);
+
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
             embedding: Vec::new(),
             name: language_name.to_string(),
+            sha1: sha1.finalize().into(),
         }])
     }
@@ -76,11 +82,15 @@ impl CodeContextRetriever {
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<item>", &content);

+        let mut sha1 = Sha1::new();
+        sha1.update(&document_span);
+
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
             embedding: Vec::new(),
             name: "Markdown".to_string(),
+            sha1: sha1.finalize().into(),
         }])
     }
@@ -253,11 +263,15 @@ impl CodeContextRetriever {
             );
         }

+        let mut sha1 = Sha1::new();
+        sha1.update(&document_content);
+
         documents.push(Document {
             name,
             content: document_content,
             range: item_range.clone(),
             embedding: vec![],
+            sha1: sha1.finalize().into(),
         })
     }
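Every Document now carries a SHA-1 digest of exactly the text that gets embedded, which is what would let the index detect unchanged spans and reuse their cached embeddings. A sketch of the comparison this digest enables (helper name hypothetical, not part of this commit):

// Hypothetical helper: re-embed when there is no stored digest, or when the
// stored digest differs from the freshly computed one.
fn needs_reembedding(stored_sha1: Option<[u8; 20]>, document: &Document) -> bool {
    match stored_sha1 {
        Some(stored) => stored != document.sha1,
        None => true,
    }
}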


@@ -34,7 +34,7 @@ use util::{
     ResultExt,
 };

-const SEMANTIC_INDEX_VERSION: usize = 6;
+const SEMANTIC_INDEX_VERSION: usize = 7;
 const EMBEDDINGS_BATCH_SIZE: usize = 80;

 pub fn init(
@@ -92,6 +92,7 @@ pub struct SemanticIndex {
 struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
+    file_mtimes: HashMap<PathBuf, SystemTime>,
     outstanding_job_count_rx: watch::Receiver<usize>,
     _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 }
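Bumping SEMANTIC_INDEX_VERSION from 6 to 7 marks databases built against the old schema, which lacks the sha1 column, as outdated, presumably so they get rebuilt rather than read with the new code. The new file_mtimes map gives the indexer a cheap first-pass staleness check before any hashing happens; a sketch of the check it enables (helper name hypothetical):

// Hypothetical helper: a file is stale if it was never indexed, or if its
// on-disk mtime no longer matches the one recorded when it was last indexed.
fn file_is_stale(state: &ProjectState, path: &Path, current_mtime: SystemTime) -> bool {
    state
        .file_mtimes
        .get(path)
        .map_or(true, |indexed| *indexed != current_mtime)
}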