Working incremental index engine, with streaming similarity search!

Co-authored-by: maxbrunsfeld <max@zed.dev>
KCaverly 2023-06-27 15:31:21 -04:00
parent 953e928bdb
commit 4bfe3de1f2
5 changed files with 268 additions and 111 deletions

Cargo.lock (generated)
View file

@@ -7967,6 +7967,7 @@ dependencies = [
"serde_json",
"sha-1 0.10.1",
"smol",
"tempdir",
"tree-sitter",
"tree-sitter-rust",
"unindent",

View file

@@ -17,7 +17,7 @@ util = { path = "../util" }
anyhow.workspace = true
futures.workspace = true
smol.workspace = true
rusqlite = { version = "0.27.0", features=["blob"] }
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true
log.workspace = true
tree-sitter.workspace = true
@@ -38,3 +38,4 @@ workspace = { path = "../workspace", features = ["test-support"] }
tree-sitter-rust = "*"
rand.workspace = true
unindent.workspace = true
tempdir.workspace = true

View file

@@ -7,9 +7,10 @@ use anyhow::{anyhow, Result};
use rusqlite::{
params,
types::{FromSql, FromSqlResult, ValueRef},
Connection,
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
use sha1::{Digest, Sha1};
use crate::IndexedFile;
@@ -32,7 +33,60 @@ pub struct DocumentRecord {
pub struct FileRecord {
pub id: usize,
pub relative_path: String,
pub sha1: String,
pub sha1: FileSha1,
}
#[derive(Debug)]
pub struct FileSha1(pub Vec<u8>);
impl FileSha1 {
pub fn from_str(content: String) -> Self {
let mut hasher = Sha1::new();
hasher.update(content);
let sha1 = hasher.finalize()[..]
.into_iter()
.map(|val| val.to_owned())
.collect::<Vec<u8>>();
return FileSha1(sha1);
}
pub fn equals(&self, content: &String) -> bool {
let mut hasher = Sha1::new();
hasher.update(content);
let sha1 = hasher.finalize()[..]
.into_iter()
.map(|val| val.to_owned())
.collect::<Vec<u8>>();
let equal = self
.0
.clone()
.into_iter()
.zip(sha1)
.filter(|&(a, b)| a == b)
.count()
== self.0.len();
equal
}
}
impl ToSql for FileSha1 {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
return self.0.to_sql();
}
}
impl FromSql for FileSha1 {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
Ok(FileSha1(
bytes
.into_iter()
.map(|val| val.to_owned())
.collect::<Vec<u8>>(),
))
}
}
#[derive(Debug)]
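
The FileSha1 type introduced above is what makes the index incremental: file content is hashed when it is indexed, the digest is stored in the new BLOB column, and on the next scan the stored digest is compared against freshly hashed content to decide whether a file needs re-parsing. A minimal, self-contained sketch of that flow (the path and file contents are illustrative, the comparison is simplified to a plain Vec equality, and only the sha1 crate already used above is assumed):

use std::collections::HashMap;
use std::path::PathBuf;

use sha1::{Digest, Sha1};

// Simplified restatement of the FileSha1 type above so the sketch compiles on its own.
#[derive(Debug)]
struct FileSha1(Vec<u8>);

impl FileSha1 {
    fn from_str(content: String) -> Self {
        let mut hasher = Sha1::new();
        hasher.update(content);
        FileSha1(hasher.finalize()[..].to_vec())
    }

    fn equals(&self, content: &String) -> bool {
        let mut hasher = Sha1::new();
        hasher.update(content);
        self.0 == hasher.finalize()[..].to_vec()
    }
}

fn main() {
    // `stored` stands in for the map returned by `get_file_hashes` below.
    let mut stored: HashMap<PathBuf, FileSha1> = HashMap::new();
    stored.insert(
        PathBuf::from("src/lib.rs"),
        FileSha1::from_str("fn main() {}".to_string()),
    );

    // On a later scan, only files whose current content no longer matches the
    // stored digest get sent off for re-parsing and re-embedding.
    let current_content = "fn main() { println!(\"changed\"); }".to_string();
    let unchanged = stored
        .get(&PathBuf::from("src/lib.rs"))
        .map_or(false, |sha1| sha1.equals(&current_content));
    assert!(!unchanged);
}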
@@ -63,6 +117,8 @@ impl VectorDatabase {
}
fn initialize_database(&self) -> Result<()> {
rusqlite::vtab::array::load_module(&self.db)?;
// This will create the database if it doesn't exist
// Initialize vector database tables
@@ -81,7 +137,7 @@ impl VectorDatabase {
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
sha1 NVARCHAR(40) NOT NULL,
sha1 BLOB NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
@@ -102,30 +158,23 @@ impl VectorDatabase {
Ok(())
}
// pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
// // Check if we have the project, if we do, return the ID
// // If we do not have the project, insert the project and return the ID
// let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
// let projects_query = db.prepare(&format!(
// "SELECT id FROM projects WHERE path = {}",
// project_path.to_str().unwrap() // This is unsafe
// ))?;
// let project_id = db.last_insert_rowid();
// return Ok(project_id as usize);
// }
pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
// Write to files table, and return generated id.
let files_insert = self.db.execute(
"INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
params![indexed_file.path.to_str(), indexed_file.sha1],
log::info!("Inserting File!");
self.db.execute(
"
DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
",
params![worktree_id, indexed_file.path.to_str()],
)?;
self.db.execute(
"
INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
",
params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
)?;
let inserted_id = self.db.last_insert_rowid();
let file_id = self.db.last_insert_rowid();
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
@@ -135,7 +184,7 @@ impl VectorDatabase {
self.db.execute(
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
params![
inserted_id,
file_id,
document.offset.to_string(),
document.name,
embedding_blob
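
Regarding the bulk-insert idea in the comment above: the usual way to speed up many small SQLite inserts with rusqlite is to reuse one prepared statement inside a single explicit transaction, so the journal is only committed once per batch. A rough sketch under that assumption (the function name and the pre-serialized embedding bytes are illustrative, not part of this commit):

use anyhow::Result;
use rusqlite::{params, Connection};

// Hypothetical batched variant of the per-document inserts above:
// one transaction and one prepared statement for the whole batch.
fn insert_documents_batched(
    db: &mut Connection,
    file_id: i64,
    documents: &[(usize, String, Vec<u8>)], // (offset, name, embedding_blob)
) -> Result<()> {
    let tx = db.transaction()?;
    {
        let mut statement = tx.prepare(
            "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
        )?;
        for (offset, name, embedding_blob) in documents {
            statement.execute(params![file_id, offset.to_string(), name, embedding_blob])?;
        }
    }
    // A single commit amortizes journal and fsync costs across the whole batch.
    tx.commit()?;
    Ok(())
}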
@@ -147,25 +196,41 @@ impl VectorDatabase {
}
pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
// Check whether a worktree with this absolute path already exists
let mut worktree_query = self
.db
.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
})
.map_err(|err| anyhow!(err));
if worktree_id.is_ok() {
return worktree_id;
}
// If worktree_id is Err, insert new worktree
self.db.execute(
"
INSERT into worktrees (absolute_path) VALUES (?1)
ON CONFLICT DO NOTHING
",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(self.db.last_insert_rowid())
}
pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
let mut statement = self
.db
.prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
let mut result = Vec::new();
for row in
statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
{
result.push(row?);
pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
let mut statement = self.db.prepare(
"SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
)?;
let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
for row in statement.query_map(params![worktree_id], |row| {
Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
})? {
let row = row?;
result.insert(row.0, row.1);
}
Ok(result)
}
@@ -204,6 +269,53 @@ impl VectorDatabase {
Ok(())
}
pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
let mut statement = self.db.prepare(
"
SELECT
documents.id, files.relative_path, documents.offset, documents.name
FROM
documents, files
WHERE
documents.file_id = files.id AND
documents.id in rarray(?)
",
)?;
let result_iter = statement.query_map(
params![std::rc::Rc::new(
ids.iter()
.copied()
.map(|v| rusqlite::types::Value::from(v))
.collect::<Vec<_>>()
)],
|row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, String>(1)?.into(),
row.get(2)?,
row.get(3)?,
))
},
)?;
let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
for row in result_iter {
let (id, path, offset, name) = row?;
values_by_id.insert(id, (path, offset, name));
}
let mut results = Vec::with_capacity(ids.len());
for id in ids {
let (path, offset, name) = values_by_id
.remove(id)
.ok_or(anyhow!("missing document id {}", id))?;
results.push((path, offset, name));
}
Ok(results)
}
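
Since SQL makes no ordering guarantee for the rarray lookup, get_documents_by_ids first collects rows into a map and then walks the caller's ids slice, so the returned Vec lines up index-for-index with the requested ids (which the search below passes in descending-similarity order), and a missing id surfaces as an error rather than a silently shorter result. A small usage sketch, assuming an open VectorDatabase with documents whose ids happen to be 3, 1, and 7 (the ids are illustrative):

use std::path::PathBuf;
use anyhow::Result;

// `db` is the VectorDatabase defined in this file; the ids are hypothetical.
fn print_ranked_documents(db: &VectorDatabase) -> Result<()> {
    let ranked_ids = vec![3_i64, 1, 7];
    let documents: Vec<(PathBuf, usize, String)> = db.get_documents_by_ids(&ranked_ids)?;
    for (id, (path, offset, name)) in ranked_ids.iter().zip(&documents) {
        println!("document {id}: {name} at offset {offset} in {path:?}");
    }
    Ok(())
}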
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
let mut query_statement = self
.db

View file

@@ -7,15 +7,14 @@ mod search;
mod vector_store_tests;
use anyhow::{anyhow, Result};
use db::{VectorDatabase, VECTOR_DB_URL};
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
use language::LanguageRegistry;
use language::{Language, LanguageRegistry};
use parsing::Document;
use project::{Fs, Project};
use search::{BruteForceSearch, VectorSearch};
use smol::channel;
use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::WorkspaceCreated;
@@ -45,7 +44,7 @@ pub fn init(
let project = workspace.read(cx).project().clone();
if project.read(cx).is_local() {
vector_store.update(cx, |store, cx| {
store.add_project(project, cx);
store.add_project(project, cx).detach();
});
}
}
@@ -57,16 +56,10 @@ pub fn init(
#[derive(Debug)]
pub struct IndexedFile {
path: PathBuf,
sha1: String,
sha1: FileSha1,
documents: Vec<Document>,
}
// struct SearchResult {
// path: PathBuf,
// offset: usize,
// name: String,
// distance: f32,
// }
struct VectorStore {
fs: Arc<dyn Fs>,
database_url: Arc<str>,
@@ -99,20 +92,10 @@ impl VectorStore {
cursor: &mut QueryCursor,
parser: &mut Parser,
embedding_provider: &dyn EmbeddingProvider,
language_registry: &Arc<LanguageRegistry>,
language: Arc<Language>,
file_path: PathBuf,
content: String,
) -> Result<IndexedFile> {
dbg!(&file_path, &content);
let language = language_registry
.language_for_file(&file_path, None)
.await?;
if language.name().as_ref() != "Rust" {
Err(anyhow!("unsupported language"))?;
}
let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
let outline_config = grammar
.outline_config
@@ -156,9 +139,11 @@ impl VectorStore {
document.embedding = embedding;
}
let sha1 = FileSha1::from_str(content);
return Ok(IndexedFile {
path: file_path,
sha1: String::new(),
sha1,
documents,
});
}
@@ -171,7 +156,13 @@ impl VectorStore {
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
.map(|worktree| {
let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
async move {
scan_complete.await;
log::info!("worktree scan completed");
}
})
.collect::<Vec<_>>();
let fs = self.fs.clone();
@@ -182,6 +173,13 @@ impl VectorStore {
cx.spawn(|_, cx| async move {
futures::future::join_all(worktree_scans_complete).await;
// TODO: remove this after fixing the bug in scan_complete
cx.background()
.timer(std::time::Duration::from_secs(3))
.await;
let db = VectorDatabase::new(&database_url)?;
let worktrees = project.read_with(&cx, |project, cx| {
project
.worktrees(cx)
@@ -189,37 +187,74 @@ impl VectorStore {
.collect::<Vec<_>>()
});
let db = VectorDatabase::new(&database_url)?;
let worktree_root_paths = worktrees
.iter()
.map(|worktree| worktree.abs_path().clone())
.collect::<Vec<_>>();
let (db, file_hashes) = cx
// Here we query the worktree ids, and yet we don't have them elsewhere
// We likely want to clean up these data structures
let (db, worktree_hashes, worktree_ids) = cx
.background()
.spawn(async move {
let mut hashes = Vec::new();
let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
for worktree_root_path in worktree_root_paths {
let worktree_id =
db.find_or_create_worktree(worktree_root_path.as_ref())?;
hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
}
anyhow::Ok((db, hashes))
anyhow::Ok((db, hashes, worktree_ids))
})
.await?;
let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
let (paths_tx, paths_rx) =
channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
cx.background()
.spawn({
let fs = fs.clone();
async move {
for worktree in worktrees.into_iter() {
let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
let file_hashes = &worktree_hashes[&worktree_id];
for file in worktree.files(false, 0) {
let absolute_path = worktree.absolutize(&file.path);
dbg!(&absolute_path);
if let Some(content) = fs.load(&absolute_path).await.log_err() {
dbg!(&content);
paths_tx.try_send((0, absolute_path, content)).unwrap();
if let Ok(language) = language_registry
.language_for_file(&absolute_path, None)
.await
{
if language.name().as_ref() != "Rust" {
continue;
}
if let Some(content) = fs.load(&absolute_path).await.log_err() {
log::info!("loaded file: {absolute_path:?}");
let path_buf = file.path.to_path_buf();
let already_stored = file_hashes
.get(&path_buf)
.map_or(false, |existing_hash| {
existing_hash.equals(&content)
});
if !already_stored {
log::info!(
"File Changed (Sending to Parse): {:?}",
&path_buf
);
paths_tx
.try_send((
worktree_id,
path_buf,
content,
language,
))
.unwrap();
}
}
}
}
}
@@ -230,8 +265,8 @@ impl VectorStore {
let db_write_task = cx.background().spawn(
async move {
// Initialize Database, creates database and tables if not exists
while let Ok(indexed_file) = indexed_files_rx.recv().await {
db.insert_file(indexed_file).log_err();
while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
db.insert_file(worktree_id, indexed_file).log_err();
}
// ALL OF THE BELOW IS FOR TESTING,
@@ -271,29 +306,29 @@ impl VectorStore {
.log_err(),
);
let provider = DummyEmbeddings {};
// let provider = OpenAIEmbeddings { client };
cx.background()
.scoped(|scope| {
for _ in 0..cx.background().num_cpus() {
scope.spawn(async {
let mut parser = Parser::new();
let mut cursor = QueryCursor::new();
while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
while let Ok((worktree_id, file_path, content, language)) =
paths_rx.recv().await
{
if let Some(indexed_file) = Self::index_file(
&mut cursor,
&mut parser,
&provider,
&language_registry,
embedding_provider.as_ref(),
language,
file_path,
content,
)
.await
.log_err()
{
indexed_files_tx.try_send(indexed_file).unwrap();
indexed_files_tx
.try_send((worktree_id, indexed_file))
.unwrap();
}
}
});
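
Taken together, the tasks wired up above form a three-stage pipeline: one producer walks the worktrees and sends only changed Rust files down paths_tx, one worker per CPU parses and embeds them, and a single writer task owns the database connection and drains indexed_files_rx. A heavily simplified model of that shape, assuming only the smol crate (the String payloads stand in for the real (worktree_id, path, content, language) and (worktree_id, IndexedFile) tuples):

use smol::channel;

fn main() {
    smol::block_on(async {
        let (paths_tx, paths_rx) = channel::unbounded::<String>();
        let (indexed_tx, indexed_rx) = channel::unbounded::<String>();

        // Producer: in the real code this walks worktrees and skips unchanged hashes.
        let producer = smol::spawn(async move {
            for path in ["a.rs", "b.rs"] {
                paths_tx.send(path.to_string()).await.unwrap();
            }
            // Dropping the sender closes the channel so the workers can finish.
        });

        // Workers: in the real code these run tree-sitter parsing and embedding.
        let workers: Vec<_> = (0..2)
            .map(|_| {
                let paths_rx = paths_rx.clone();
                let indexed_tx = indexed_tx.clone();
                smol::spawn(async move {
                    while let Ok(path) = paths_rx.recv().await {
                        indexed_tx.send(format!("indexed {path}")).await.unwrap();
                    }
                })
            })
            .collect();
        drop(indexed_tx);

        // Writer: in the real code this is the task calling db.insert_file.
        let writer = smol::spawn(async move {
            while let Ok(file) = indexed_rx.recv().await {
                println!("{file}");
            }
        });

        producer.await;
        for worker in workers {
            worker.await;
        }
        writer.await;
    });
}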
@@ -315,41 +350,42 @@ impl VectorStore {
) -> Task<Result<Vec<SearchResult>>> {
let embedding_provider = self.embedding_provider.clone();
let database_url = self.database_url.clone();
cx.spawn(|this, cx| async move {
cx.background().spawn(async move {
let database = VectorDatabase::new(database_url.as_ref())?;
// let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
//
let phrase_embedding = embedding_provider
.embed_batch(vec![&phrase])
.await?
.into_iter()
.next()
.unwrap();
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
database.for_each_document(0, |id, embedding| {
dbg!(id, &embedding);
let similarity = dot(&embedding.0, &embedding.0);
let similarity = dot(&embedding.0, &phrase_embedding);
let ix = match results.binary_search_by(|(_, s)| {
s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
dbg!(&results);
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
// let documents = database.get_documents_by_ids(ids)?;
let documents = database.get_documents_by_ids(&ids)?;
// let search_provider = cx
// .background()
// .spawn(async move { BruteForceSearch::load(&database) })
// .await?;
// let results = search_provider.top_k_search(&embedding, limit))
anyhow::Ok(vec![])
anyhow::Ok(
documents
.into_iter()
.map(|(file_path, offset, name)| SearchResult {
name,
offset,
file_path,
})
.collect(),
)
})
}
}
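
The streaming search above never materializes every similarity score: for_each_document visits each stored embedding once, and a results vector of at most limit entries is kept sorted most-similar-first via binary_search_by with the comparison reversed, followed by insert and truncate. The dot helper itself is not part of this diff; a plain implementation consistent with how it is called, plus the top-k loop restated in isolation so it can run on its own, might look like the sketch below (for unit-length embeddings a raw dot product ranks candidates the same way cosine similarity would):

use std::cmp::Ordering;

// Plain dot product over two embedding vectors; assumes equal length,
// which holds when every embedding comes from the same provider.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(a, b)| a * b).sum()
}

fn main() {
    let phrase_embedding = vec![1.0_f32, 0.0, 1.0];
    let stored: Vec<(i64, Vec<f32>)> = vec![
        (1, vec![1.0, 1.0, 1.0]), // similarity 2.0
        (2, vec![0.0, 1.0, 0.0]), // similarity 0.0
        (3, vec![2.0, 0.0, 2.0]), // similarity 4.0
    ];

    let limit = 2;
    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
    for (id, embedding) in &stored {
        let similarity = dot(embedding, &phrase_embedding);
        // Reversing the comparison keeps `results` sorted from most to least similar.
        let ix = match results
            .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
        {
            Ok(ix) => ix,
            Err(ix) => ix,
        };
        results.insert(ix, (*id, similarity));
        results.truncate(limit);
    }

    assert_eq!(results.iter().map(|(id, _)| *id).collect::<Vec<_>>(), vec![3, 1]);
}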

View file

@@ -57,20 +57,26 @@ async fn test_vector_store(cx: &mut TestAppContext) {
);
languages.add(rust_language);
let db_dir = tempdir::TempDir::new("vector-store").unwrap();
let db_path = db_dir.path().join("db.sqlite");
let store = cx.add_model(|_| {
VectorStore::new(
fs.clone(),
"foo".to_string(),
db_path.to_string_lossy().to_string(),
Arc::new(FakeEmbeddingProvider),
languages,
)
});
let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
store
.update(cx, |store, cx| store.add_project(project, cx))
.await
.unwrap();
let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
// TODO - remove
cx.foreground()
.advance_clock(std::time::Duration::from_secs(3));
add_project.await.unwrap();
let search_results = store
.update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
@@ -78,7 +84,7 @@ async fn test_vector_store(cx: &mut TestAppContext) {
.unwrap();
assert_eq!(search_results[0].offset, 0);
assert_eq!(search_results[1].name, "aaa");
assert_eq!(search_results[0].name, "aaa");
}
#[test]
@@ -114,9 +120,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
Ok(spans
.iter()
.map(|span| {
let mut result = vec![0.0; 26];
let mut result = vec![1.0; 26];
for letter in span.chars() {
if letter as u32 > 'a' as u32 {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
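
For reference, the fake provider builds a 26-slot letter-frequency vector: every slot starts at 1.0 and each lowercased ASCII letter adds 1.0 to its slot. Switching the guard from > to >= is what makes the letter 'a' itself count, which is presumably why the expected ordering in the test above changed. A standalone restatement of the per-span computation (the real provider maps over a batch and returns a Result):

fn fake_embedding(span: &str) -> Vec<f32> {
    // Mirrors the provider above for a single span.
    let mut result = vec![1.0; 26];
    for letter in span.chars() {
        let letter = letter.to_ascii_lowercase();
        if letter as u32 >= 'a' as u32 {
            let ix = (letter as u32) - ('a' as u32);
            if ix < 26 {
                result[ix as usize] += 1.0;
            }
        }
    }
    result
}

fn main() {
    // "aaa" gets the 1.0 baseline everywhere plus one count per 'a'.
    let embedding = fake_embedding("aaa");
    assert_eq!(embedding[0], 4.0);
    assert!(embedding[1..].iter().all(|&v| v == 1.0));
}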