Mirror of https://github.com/zed-industries/zed.git, synced 2025-01-27 04:44:30 +00:00

Working incremental index engine, with streaming similarity search!

Co-authored-by: maxbrunsfeld <max@zed.dev>

parent 953e928bdb
commit 4bfe3de1f2

5 changed files with 268 additions and 111 deletions

Cargo.lock (generated)

@@ -7967,6 +7967,7 @@ dependencies = [
  "serde_json",
  "sha-1 0.10.1",
  "smol",
+ "tempdir",
  "tree-sitter",
  "tree-sitter-rust",
  "unindent",

@@ -17,7 +17,7 @@ util = { path = "../util" }
 anyhow.workspace = true
 futures.workspace = true
 smol.workspace = true
-rusqlite = { version = "0.27.0", features=["blob"] }
+rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
 isahc.workspace = true
 log.workspace = true
 tree-sitter.workspace = true

@@ -38,3 +38,4 @@ workspace = { path = "../workspace", features = ["test-support"] }
 tree-sitter-rust = "*"
 rand.workspace = true
 unindent.workspace = true
+tempdir.workspace = true

@@ -7,9 +7,10 @@ use anyhow::{anyhow, Result};
 use rusqlite::{
     params,
-    types::{FromSql, FromSqlResult, ValueRef},
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
     Connection,
+    ToSql,
 };
+use sha1::{Digest, Sha1};
 
 use crate::IndexedFile;

@@ -32,7 +33,60 @@ pub struct DocumentRecord {
 pub struct FileRecord {
     pub id: usize,
     pub relative_path: String,
-    pub sha1: String,
+    pub sha1: FileSha1,
 }
+
+#[derive(Debug)]
+pub struct FileSha1(pub Vec<u8>);
+
+impl FileSha1 {
+    pub fn from_str(content: String) -> Self {
+        let mut hasher = Sha1::new();
+        hasher.update(content);
+        let sha1 = hasher.finalize()[..]
+            .into_iter()
+            .map(|val| val.to_owned())
+            .collect::<Vec<u8>>();
+        return FileSha1(sha1);
+    }
+
+    pub fn equals(&self, content: &String) -> bool {
+        let mut hasher = Sha1::new();
+        hasher.update(content);
+        let sha1 = hasher.finalize()[..]
+            .into_iter()
+            .map(|val| val.to_owned())
+            .collect::<Vec<u8>>();
+
+        let equal = self
+            .0
+            .clone()
+            .into_iter()
+            .zip(sha1)
+            .filter(|&(a, b)| a == b)
+            .count()
+            == self.0.len();
+
+        equal
+    }
+}
+
+impl ToSql for FileSha1 {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        return self.0.to_sql();
+    }
+}
+
+impl FromSql for FileSha1 {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        Ok(FileSha1(
+            bytes
+                .into_iter()
+                .map(|val| val.to_owned())
+                .collect::<Vec<u8>>(),
+        ))
+    }
+}
+
 #[derive(Debug)]
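
FileSha1 wraps the raw 20-byte digest so it can round-trip through a BLOB column via ToSql/FromSql. A side note, not part of the commit: since Vec<u8> already compares element-wise with ==, the zip/filter/count loop in equals can be expressed directly. A minimal sketch, assuming the same sha1 = "0.10" crate:

    use sha1::{Digest, Sha1};

    pub struct FileSha1(pub Vec<u8>);

    impl FileSha1 {
        // Hash the full file contents into a 20-byte SHA-1 digest.
        pub fn from_str(content: &str) -> Self {
            FileSha1(Sha1::digest(content.as_bytes()).to_vec())
        }

        // Recompute the digest and compare; equivalent to the zip/filter/count
        // above because both digests are always exactly 20 bytes.
        pub fn equals(&self, content: &str) -> bool {
            self.0 == Sha1::digest(content.as_bytes()).as_slice()
        }
    }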

@@ -63,6 +117,8 @@ impl VectorDatabase {
     }
 
     fn initialize_database(&self) -> Result<()> {
+        rusqlite::vtab::array::load_module(&self.db)?;
+
         // This will create the database if it doesnt exist
 
         // Initialize Vector Databasing Tables
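
The rarray() table-valued function used by get_documents_by_ids below comes from a loadable virtual-table module, so it must be registered on every connection that will query through it; hence the load_module call at the top of initialization. A minimal sketch of the same setup, assuming rusqlite 0.27 with the "array" feature enabled in the Cargo.toml change above:

    use rusqlite::{Connection, Result};

    fn open_connection() -> Result<Connection> {
        let db = Connection::open_in_memory()?;
        // Registers rarray() on this connection only.
        rusqlite::vtab::array::load_module(&db)?;
        Ok(db)
    }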

@@ -81,7 +137,7 @@ impl VectorDatabase {
             id INTEGER PRIMARY KEY AUTOINCREMENT,
             worktree_id INTEGER NOT NULL,
             relative_path VARCHAR NOT NULL,
-            sha1 NVARCHAR(40) NOT NULL,
+            sha1 BLOB NOT NULL,
             FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
             )",
         [],
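
One caveat worth noting, not addressed in the commit: SQLite ships with foreign-key enforcement off by default, so the ON DELETE CASCADE on files(worktree_id) only fires if the connection opts in. A one-line sketch of that opt-in:

    use rusqlite::{Connection, Result};

    fn enable_foreign_keys(db: &Connection) -> Result<()> {
        // Must be set per connection; without it, cascading deletes are no-ops.
        db.execute_batch("PRAGMA foreign_keys = ON;")
    }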

@@ -102,30 +158,23 @@ impl VectorDatabase {
         Ok(())
     }
 
-    // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
-    //     // Check if we have the project, if we do, return the ID
-    //     // If we do not have the project, insert the project and return the ID
-
-    //     let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
-    //     let projects_query = db.prepare(&format!(
-    //         "SELECT id FROM projects WHERE path = {}",
-    //         project_path.to_str().unwrap() // This is unsafe
-    //     ))?;
-
-    //     let project_id = db.last_insert_rowid();
-
-    //     return Ok(project_id as usize);
-    // }
-
-    pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
+    pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
         // Write to files table, and return generated id.
-        let files_insert = self.db.execute(
-            "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
-            params![indexed_file.path.to_str(), indexed_file.sha1],
+        log::info!("Inserting File!");
+        self.db.execute(
+            "
+            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
+            ",
+            params![worktree_id, indexed_file.path.to_str()],
+        )?;
+        self.db.execute(
+            "
+            INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
+            ",
+            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
         )?;
 
-        let inserted_id = self.db.last_insert_rowid();
+        let file_id = self.db.last_insert_rowid();
 
         // Currently inserting at approximately 3400 documents a second
         // I imagine we can speed this up with a bulk insert of some kind.
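
The DELETE/INSERT pair above replaces a file's row as two separate statements. A sketch, an editor's suggestion rather than part of the commit, of the same swap inside a single transaction, which makes the replacement atomic and typically speeds up the follow-on per-document inserts that the comment mentions:

    use anyhow::Result;
    use rusqlite::{params, Connection};

    fn replace_file(db: &mut Connection, worktree_id: i64, path: &str, sha1: &[u8]) -> Result<()> {
        let tx = db.transaction()?;
        tx.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, path],
        )?;
        tx.execute(
            "INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, ?3)",
            params![worktree_id, path, sha1],
        )?;
        // ...document rows for this file would be inserted here...
        tx.commit()?;
        Ok(())
    }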

@@ -135,7 +184,7 @@ impl VectorDatabase {
             self.db.execute(
                 "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
                 params![
-                    inserted_id,
+                    file_id,
                     document.offset.to_string(),
                     document.name,
                     embedding_blob

@@ -147,25 +196,41 @@ impl VectorDatabase {
         Ok(())
     }
 
+    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+        // Check that the absolute path doesnt exist
+        let mut worktree_query = self
+            .db
+            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+
+        let worktree_id = worktree_query
+            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
+                Ok(row.get::<_, i64>(0)?)
+            })
+            .map_err(|err| anyhow!(err));
+
+        if worktree_id.is_ok() {
+            return worktree_id;
+        }
+
+        // If worktree_id is Err, insert new worktree
+        self.db.execute(
+            "
+            INSERT into worktrees (absolute_path) VALUES (?1)
+            ON CONFLICT DO NOTHING
+            ",
+            params![worktree_root_path.to_string_lossy()],
+        )?;
+        Ok(self.db.last_insert_rowid())
+    }
+
-    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
-        let mut statement = self
-            .db
-            .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
-        let mut result = Vec::new();
-        for row in
-            statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
-        {
-            result.push(row?);
+    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
+        let mut statement = self.db.prepare(
+            "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
+        )?;
+        let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
+        for row in statement.query_map(params![worktree_id], |row| {
+            Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
+        })? {
+            let row = row?;
+            result.insert(row.0, row.1);
         }
         Ok(result)
     }
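
find_or_create_worktree routes the "row not found" case through an Err and falls back to an insert. A sketch of the same logic using rusqlite's OptionalExtension, which maps QueryReturnedNoRows to None instead (an editor's variant, not what the commit does):

    use anyhow::Result;
    use rusqlite::{params, Connection, OptionalExtension};

    fn find_or_create_worktree(db: &Connection, absolute_path: &str) -> Result<i64> {
        let existing: Option<i64> = db
            .query_row(
                "SELECT id FROM worktrees WHERE absolute_path = ?1",
                params![absolute_path],
                |row| row.get(0),
            )
            .optional()?;
        if let Some(id) = existing {
            return Ok(id);
        }
        db.execute(
            "INSERT INTO worktrees (absolute_path) VALUES (?1)",
            params![absolute_path],
        )?;
        Ok(db.last_insert_rowid())
    }

As in the commit, last_insert_rowid is only meaningful when the insert actually ran; the early return covers the common case, though two connections racing on the same path could still interleave.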

@@ -204,6 +269,53 @@ impl VectorDatabase {
         Ok(())
     }
 
+    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
+        let mut statement = self.db.prepare(
+            "
+            SELECT
+                documents.id, files.relative_path, documents.offset, documents.name
+            FROM
+                documents, files
+            WHERE
+                documents.file_id = files.id AND
+                documents.id in rarray(?)
+            ",
+        )?;
+
+        let result_iter = statement.query_map(
+            params![std::rc::Rc::new(
+                ids.iter()
+                    .copied()
+                    .map(|v| rusqlite::types::Value::from(v))
+                    .collect::<Vec<_>>()
+            )],
+            |row| {
+                Ok((
+                    row.get::<_, i64>(0)?,
+                    row.get::<_, String>(1)?.into(),
+                    row.get(2)?,
+                    row.get(3)?,
+                ))
+            },
+        )?;
+
+        let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
+        for row in result_iter {
+            let (id, path, offset, name) = row?;
+            values_by_id.insert(id, (path, offset, name));
+        }
+
+        let mut results = Vec::with_capacity(ids.len());
+        for id in ids {
+            let (path, offset, name) = values_by_id
+                .remove(id)
+                .ok_or(anyhow!("missing document id {}", id))?;
+            results.push((path, offset, name));
+        }
+
+        Ok(results)
+    }
+
     pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
         let mut query_statement = self
             .db
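
Because `IN rarray(?)` returns rows in whatever order SQLite produces them, the function buffers rows into values_by_id and then rebuilds the caller's id order. The rarray binding itself takes an Rc<Vec<Value>> as a single parameter. A self-contained sketch of the pattern, assuming the "array" feature and a hypothetical `items` table:

    use std::rc::Rc;

    use rusqlite::{types::Value, Connection, Result};

    fn names_for_ids(db: &Connection, ids: &[i64]) -> Result<Vec<String>> {
        // rarray() must be registered on this connection (see initialize_database).
        rusqlite::vtab::array::load_module(db)?;
        let mut stmt = db.prepare("SELECT name FROM items WHERE id IN rarray(?)")?;
        let values = Rc::new(ids.iter().copied().map(Value::from).collect::<Vec<Value>>());
        let rows = stmt.query_map([values], |row| row.get::<_, String>(0))?;
        rows.collect()
    }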

@@ -7,15 +7,14 @@ mod search;
 mod vector_store_tests;
 
 use anyhow::{anyhow, Result};
-use db::{VectorDatabase, VECTOR_DB_URL};
-use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
+use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
+use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
-use language::LanguageRegistry;
+use language::{Language, LanguageRegistry};
 use parsing::Document;
 use project::{Fs, Project};
 use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
-use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
 use tree_sitter::{Parser, QueryCursor};
 use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;

@@ -45,7 +44,7 @@ pub fn init(
             let project = workspace.read(cx).project().clone();
             if project.read(cx).is_local() {
                 vector_store.update(cx, |store, cx| {
-                    store.add_project(project, cx);
+                    store.add_project(project, cx).detach();
                 });
             }
         }

@@ -57,16 +56,10 @@ pub fn init(
 #[derive(Debug)]
 pub struct IndexedFile {
     path: PathBuf,
-    sha1: String,
+    sha1: FileSha1,
     documents: Vec<Document>,
 }
 
-// struct SearchResult {
-//     path: PathBuf,
-//     offset: usize,
-//     name: String,
-//     distance: f32,
-// }
 struct VectorStore {
     fs: Arc<dyn Fs>,
     database_url: Arc<str>,

@@ -99,20 +92,10 @@ impl VectorStore {
         cursor: &mut QueryCursor,
         parser: &mut Parser,
         embedding_provider: &dyn EmbeddingProvider,
-        language_registry: &Arc<LanguageRegistry>,
+        language: Arc<Language>,
         file_path: PathBuf,
         content: String,
     ) -> Result<IndexedFile> {
-        dbg!(&file_path, &content);
-
-        let language = language_registry
-            .language_for_file(&file_path, None)
-            .await?;
-
-        if language.name().as_ref() != "Rust" {
-            Err(anyhow!("unsupported language"))?;
-        }
-
         let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
         let outline_config = grammar
             .outline_config

@@ -156,9 +139,11 @@ impl VectorStore {
             document.embedding = embedding;
         }
 
+        let sha1 = FileSha1::from_str(content);
+
         return Ok(IndexedFile {
             path: file_path,
-            sha1: String::new(),
+            sha1,
             documents,
         });
     }

@@ -171,7 +156,13 @@ impl VectorStore {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
-            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
+            .map(|worktree| {
+                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
+                async move {
+                    scan_complete.await;
+                    log::info!("worktree scan completed");
+                }
+            })
             .collect::<Vec<_>>();
 
         let fs = self.fs.clone();

@@ -182,6 +173,13 @@ impl VectorStore {
         cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
+            // TODO: remove this after fixing the bug in scan_complete
+            cx.background()
+                .timer(std::time::Duration::from_secs(3))
+                .await;
+
+            let db = VectorDatabase::new(&database_url)?;
+
             let worktrees = project.read_with(&cx, |project, cx| {
                 project
                     .worktrees(cx)

@@ -189,37 +187,74 @@ impl VectorStore {
                     .collect::<Vec<_>>()
             });
 
-            let db = VectorDatabase::new(&database_url)?;
             let worktree_root_paths = worktrees
                 .iter()
                 .map(|worktree| worktree.abs_path().clone())
                 .collect::<Vec<_>>();
-            let (db, file_hashes) = cx
+
+            // Here we query the worktree ids, and yet we dont have them elsewhere
+            // We likely want to clean up these datastructures
+            let (db, worktree_hashes, worktree_ids) = cx
                 .background()
                 .spawn(async move {
-                    let mut hashes = Vec::new();
+                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
+                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                     for worktree_root_path in worktree_root_paths {
                         let worktree_id =
                             db.find_or_create_worktree(worktree_root_path.as_ref())?;
-                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
+                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
+                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                     }
-                    anyhow::Ok((db, hashes))
+                    anyhow::Ok((db, hashes, worktree_ids))
                 })
                 .await?;
 
-            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
-            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
+            let (paths_tx, paths_rx) =
+                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
+            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
             cx.background()
                 .spawn({
                     let fs = fs.clone();
                     async move {
                         for worktree in worktrees.into_iter() {
+                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
+                            let file_hashes = &worktree_hashes[&worktree_id];
                             for file in worktree.files(false, 0) {
                                 let absolute_path = worktree.absolutize(&file.path);
-                                dbg!(&absolute_path);
-                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
-                                    dbg!(&content);
-                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
+
+                                if let Ok(language) = language_registry
+                                    .language_for_file(&absolute_path, None)
+                                    .await
+                                {
+                                    if language.name().as_ref() != "Rust" {
+                                        continue;
+                                    }
+
+                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
+                                        log::info!("loaded file: {absolute_path:?}");
+
+                                        let path_buf = file.path.to_path_buf();
+                                        let already_stored = file_hashes
+                                            .get(&path_buf)
+                                            .map_or(false, |existing_hash| {
+                                                existing_hash.equals(&content)
+                                            });
+
+                                        if !already_stored {
+                                            log::info!(
+                                                "File Changed (Sending to Parse): {:?}",
+                                                &path_buf
+                                            );
+                                            paths_tx
+                                                .try_send((
+                                                    worktree_id,
+                                                    path_buf,
+                                                    content,
+                                                    language,
+                                                ))
+                                                .unwrap();
+                                        }
+                                    }
                                 }
                             }
                         }
                     }
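
The indexing pipeline is a fan-out/fan-in over two unbounded channels: one producer walks worktree files and enqueues the changed ones, a worker per CPU parses and embeds, and a single writer owns the SQLite connection. A reduced sketch of that shape with plain smol primitives (illustrative only; names and payloads are stand-ins):

    use smol::channel;

    fn main() {
        smol::block_on(async {
            let (paths_tx, paths_rx) = channel::unbounded::<(i64, String)>();
            let (indexed_tx, indexed_rx) = channel::unbounded::<(i64, usize)>();

            // Producer: enqueue work, then drop the sender so workers can finish.
            let producer = smol::spawn(async move {
                for id in 0..4i64 {
                    paths_tx.try_send((id, format!("file-{id}.rs"))).unwrap();
                }
            });

            // Workers: recv() returns Err once the channel is closed and drained.
            let workers: Vec<_> = (0..2)
                .map(|_| {
                    let paths_rx = paths_rx.clone();
                    let indexed_tx = indexed_tx.clone();
                    smol::spawn(async move {
                        while let Ok((id, content)) = paths_rx.recv().await {
                            indexed_tx.try_send((id, content.len())).unwrap();
                        }
                    })
                })
                .collect();
            drop(indexed_tx);

            producer.await;
            for worker in workers {
                worker.await;
            }

            // Single consumer drains the results, like the db_write_task below.
            while let Ok((id, len)) = indexed_rx.recv().await {
                println!("indexed {id}: {len} bytes");
            }
        });
    }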

@@ -230,8 +265,8 @@ impl VectorStore {
             let db_write_task = cx.background().spawn(
                 async move {
                     // Initialize Database, creates database and tables if not exists
-                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
-                        db.insert_file(indexed_file).log_err();
+                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
+                        db.insert_file(worktree_id, indexed_file).log_err();
                     }
 
                     // ALL OF THE BELOW IS FOR TESTING,
|
|||
.log_err(),
|
||||
);
|
||||
|
||||
let provider = DummyEmbeddings {};
|
||||
// let provider = OpenAIEmbeddings { client };
|
||||
|
||||
cx.background()
|
||||
.scoped(|scope| {
|
||||
for _ in 0..cx.background().num_cpus() {
|
||||
scope.spawn(async {
|
||||
let mut parser = Parser::new();
|
||||
let mut cursor = QueryCursor::new();
|
||||
while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
|
||||
while let Ok((worktree_id, file_path, content, language)) =
|
||||
paths_rx.recv().await
|
||||
{
|
||||
if let Some(indexed_file) = Self::index_file(
|
||||
&mut cursor,
|
||||
&mut parser,
|
||||
&provider,
|
||||
&language_registry,
|
||||
embedding_provider.as_ref(),
|
||||
language,
|
||||
file_path,
|
||||
content,
|
||||
)
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
indexed_files_tx.try_send(indexed_file).unwrap();
|
||||
indexed_files_tx
|
||||
.try_send((worktree_id, indexed_file))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
});
|
||||
|

@@ -315,41 +350,42 @@ impl VectorStore {
     ) -> Task<Result<Vec<SearchResult>>> {
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
-        cx.spawn(|this, cx| async move {
+        cx.background().spawn(async move {
             let database = VectorDatabase::new(database_url.as_ref())?;
 
-            // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
-            //
+            let phrase_embedding = embedding_provider
+                .embed_batch(vec![&phrase])
+                .await?
+                .into_iter()
+                .next()
+                .unwrap();
 
             let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
 
             database.for_each_document(0, |id, embedding| {
-                dbg!(id, &embedding);
-
-                let similarity = dot(&embedding.0, &embedding.0);
+                let similarity = dot(&embedding.0, &phrase_embedding);
                 let ix = match results.binary_search_by(|(_, s)| {
-                    s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                 }) {
                     Ok(ix) => ix,
                     Err(ix) => ix,
                 };
-
                 results.insert(ix, (id, similarity));
                 results.truncate(limit);
             })?;
 
-            dbg!(&results);
-
             let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-            // let documents = database.get_documents_by_ids(ids)?;
-
-            // let search_provider = cx
-            //     .background()
-            //     .spawn(async move { BruteForceSearch::load(&database) })
-            //     .await?;
-
-            // let results = search_provider.top_k_search(&embedding, limit))
-
-            anyhow::Ok(vec![])
+            let documents = database.get_documents_by_ids(&ids)?;
+
+            anyhow::Ok(
+                documents
+                    .into_iter()
+                    .map(|(file_path, offset, name)| SearchResult {
+                        name,
+                        offset,
+                        file_path,
+                    })
+                    .collect(),
+            )
         })
     }
 }
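
The search itself is a brute-force scan: every stored embedding is compared against the phrase embedding with a dot product, and a reverse-sorted insert keeps only the best `limit` hits. A sketch of those two pieces in isolation (assuming, as the call sites suggest, that `dot` is a plain dot product):

    use std::cmp::Ordering;

    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    // `results` stays sorted from most to least similar because the comparator
    // is reversed; insert-then-truncate keeps the top `limit` entries.
    fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
        let ix = match results
            .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
        {
            Ok(ix) | Err(ix) => ix,
        };
        results.insert(ix, (id, similarity));
        results.truncate(limit);
    }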

@@ -57,20 +57,26 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     );
     languages.add(rust_language);
 
+    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
+    let db_path = db_dir.path().join("db.sqlite");
+
     let store = cx.add_model(|_| {
         VectorStore::new(
             fs.clone(),
-            "foo".to_string(),
+            db_path.to_string_lossy().to_string(),
             Arc::new(FakeEmbeddingProvider),
             languages,
         )
     });
 
     let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
-    store
-        .update(cx, |store, cx| store.add_project(project, cx))
-        .await
-        .unwrap();
+    let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
+
+    // TODO - remove
+    cx.foreground()
+        .advance_clock(std::time::Duration::from_secs(3));
+
+    add_project.await.unwrap();
 
     let search_results = store
         .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))

@@ -78,7 +84,7 @@ async fn test_vector_store(cx: &mut TestAppContext) {
         .unwrap();
 
     assert_eq!(search_results[0].offset, 0);
-    assert_eq!(search_results[1].name, "aaa");
+    assert_eq!(search_results[0].name, "aaa");
 }
 
 #[test]

@@ -114,9 +120,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         Ok(spans
             .iter()
             .map(|span| {
-                let mut result = vec![0.0; 26];
+                let mut result = vec![1.0; 26];
                 for letter in span.chars() {
-                    if letter as u32 > 'a' as u32 {
+                    let letter = letter.to_ascii_lowercase();
+                    if letter as u32 >= 'a' as u32 {
                         let ix = (letter as u32) - ('a' as u32);
                         if ix < 26 {
                             result[ix as usize] += 1.0;
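
This last hunk fixes a real bug in the fake embedding: `letter as u32 > 'a' as u32` skipped 'a' itself, so a span of all a's embedded as a zero vector and its dot product tied with everything at zero. The equivalent logic of the fixed version, as a standalone sketch:

    // Letter-frequency embedding over a 26-dim vector; the 1.0 baseline keeps
    // every embedding non-zero so dot-product scores never all collapse to 0.
    fn fake_embed(span: &str) -> Vec<f32> {
        let mut result = vec![1.0; 26];
        for letter in span.chars() {
            let letter = letter.to_ascii_lowercase();
            if letter.is_ascii_lowercase() {
                result[(letter as u32 - 'a' as u32) as usize] += 1.0;
            }
        }
        result
    }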