mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-27 04:44:30 +00:00
remove reindexing subscription, and add status methods for vector store
Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
parent
d8fd0be598
commit
b38e3b804c
5 changed files with 208 additions and 253 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -8493,6 +8493,7 @@ dependencies = [
|
|||
"lazy_static",
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
"parking_lot 0.11.2",
|
||||
"picker",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
|
|
|
@ -33,6 +33,7 @@ async-trait.workspace = true
|
|||
bincode = "1.3.3"
|
||||
matrixmultiply = "0.3.7"
|
||||
tiktoken-rs = "0.5.0"
|
||||
parking_lot.workspace = true
|
||||
rand.workspace = true
|
||||
schemars.workspace = true
|
||||
|
||||
|
|
|
@ -124,7 +124,7 @@ impl PickerDelegate for SemanticSearchDelegate {
|
|||
if let Some(retrieved) = retrieved_cached.log_err() {
|
||||
if !retrieved {
|
||||
let task = vector_store.update(&mut cx, |store, cx| {
|
||||
store.search(project.clone(), query.to_string(), 10, cx)
|
||||
store.search_project(project.clone(), query.to_string(), 10, cx)
|
||||
});
|
||||
|
||||
if let Some(results) = task.await.log_err() {
|
||||
|
|
|
@ -18,15 +18,19 @@ use gpui::{
|
|||
};
|
||||
use language::{Language, LanguageRegistry};
|
||||
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
|
||||
use parking_lot::Mutex;
|
||||
use parsing::{CodeContextRetriever, Document};
|
||||
use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
|
||||
use project::{Fs, Project, WorktreeId};
|
||||
use smol::channel;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
collections::{HashMap, HashSet},
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant, SystemTime},
|
||||
sync::{
|
||||
atomic::{self, AtomicUsize},
|
||||
Arc, Weak,
|
||||
},
|
||||
time::{Instant, SystemTime},
|
||||
};
|
||||
use util::{
|
||||
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
|
||||
|
@ -99,7 +103,7 @@ pub fn init(
|
|||
let project = workspace.read(cx).project().clone();
|
||||
if project.read(cx).is_local() {
|
||||
vector_store.update(cx, |store, cx| {
|
||||
store.add_project(project, cx).detach();
|
||||
store.index_project(project, cx).detach();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -124,13 +128,20 @@ pub struct VectorStore {
|
|||
_embed_batch_task: Task<()>,
|
||||
_batch_files_task: Task<()>,
|
||||
_parsing_files_tasks: Vec<Task<()>>,
|
||||
next_job_id: Arc<AtomicUsize>,
|
||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||
}
|
||||
|
||||
struct ProjectState {
|
||||
worktree_db_ids: Vec<(WorktreeId, i64)>,
|
||||
pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
|
||||
_subscription: gpui::Subscription,
|
||||
outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
|
||||
}
|
||||
|
||||
/// Identifier of a single indexing job; used as the element type of a
/// project's set of outstanding jobs.
type JobId = usize;
|
||||
|
||||
/// Handle tied to one outstanding indexing job.
///
/// It carries the job's id together with a weak reference to the set of
/// outstanding job ids the job was registered in. Dropping the handle removes
/// the id from that set (see `impl Drop for JobHandle`), provided the set is
/// still alive.
struct JobHandle {
|
||||
/// Id of the job this handle stands for.
id: JobId,
|
||||
/// Weak reference to the shared set of outstanding job ids, so the handle
/// itself does not keep the set alive.
set: Weak<Mutex<HashSet<JobId>>>,
|
||||
}
|
||||
|
||||
impl ProjectState {
|
||||
|
@ -157,54 +168,15 @@ impl ProjectState {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
|
||||
// If Pending File Already Exists, Replace it with the new one
|
||||
// but keep the old indexing time
|
||||
if let Some(old_file) = self
|
||||
.pending_files
|
||||
.remove(&pending_file.relative_path.clone())
|
||||
{
|
||||
self.pending_files.insert(
|
||||
pending_file.relative_path.clone(),
|
||||
(pending_file, old_file.1),
|
||||
);
|
||||
} else {
|
||||
self.pending_files.insert(
|
||||
pending_file.relative_path.clone(),
|
||||
(pending_file, indexing_time),
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
|
||||
let mut outstanding_files = vec![];
|
||||
let mut remove_keys = vec![];
|
||||
for key in self.pending_files.keys().into_iter() {
|
||||
if let Some(pending_details) = self.pending_files.get(key) {
|
||||
let (pending_file, index_time) = pending_details;
|
||||
if index_time <= &SystemTime::now() {
|
||||
outstanding_files.push(pending_file.clone());
|
||||
remove_keys.push(key.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for key in remove_keys.iter() {
|
||||
self.pending_files.remove(key);
|
||||
}
|
||||
|
||||
return outstanding_files;
|
||||
}
|
||||
}
|
||||
|
||||
/// A file that has been queued for parsing.
///
/// Values of this type are built while walking a worktree's files and are
/// sent over the parsing channel (`parsing_files_tx`).
#[derive(Clone, Debug)]
|
||||
pub struct PendingFile {
|
||||
/// Database id of the worktree the file belongs to.
worktree_db_id: i64,
|
||||
/// Path taken from the worktree entry (`file.path`).
relative_path: PathBuf,
|
||||
/// Result of `worktree.absolutize(..)` applied to the entry's path.
absolute_path: PathBuf,
|
||||
/// Language the registry resolved for this file.
language: Arc<Language>,
|
||||
/// The entry's mtime at the moment the file was queued.
modified_time: SystemTime,
|
||||
/// Tracks this file's outstanding indexing job; it is passed along with
/// the file's documents to the later pipeline stages.
job_handle: JobHandle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -221,6 +193,7 @@ enum DbOperation {
|
|||
documents: Vec<Document>,
|
||||
path: PathBuf,
|
||||
mtime: SystemTime,
|
||||
job_handle: JobHandle,
|
||||
},
|
||||
Delete {
|
||||
worktree_id: i64,
|
||||
|
@ -242,6 +215,7 @@ enum EmbeddingJob {
|
|||
path: PathBuf,
|
||||
mtime: SystemTime,
|
||||
documents: Vec<Document>,
|
||||
job_handle: JobHandle,
|
||||
},
|
||||
Flush,
|
||||
}
|
||||
|
@ -274,9 +248,11 @@ impl VectorStore {
|
|||
documents,
|
||||
path,
|
||||
mtime,
|
||||
job_handle,
|
||||
} => {
|
||||
db.insert_file(worktree_id, path, mtime, documents)
|
||||
.log_err();
|
||||
drop(job_handle)
|
||||
}
|
||||
DbOperation::Delete { worktree_id, path } => {
|
||||
db.delete_file(worktree_id, path).log_err();
|
||||
|
@ -298,7 +274,7 @@ impl VectorStore {
|
|||
|
||||
// embed_tx/rx: Embed Batch and Send to Database
|
||||
let (embed_batch_tx, embed_batch_rx) =
|
||||
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime)>>();
|
||||
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
|
||||
let _embed_batch_task = cx.background().spawn({
|
||||
let db_update_tx = db_update_tx.clone();
|
||||
let embedding_provider = embedding_provider.clone();
|
||||
|
@ -306,7 +282,7 @@ impl VectorStore {
|
|||
while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
|
||||
// Construct Batch
|
||||
let mut batch_documents = vec![];
|
||||
for (_, documents, _, _) in embeddings_queue.iter() {
|
||||
for (_, documents, _, _, _) in embeddings_queue.iter() {
|
||||
batch_documents
|
||||
.extend(documents.iter().map(|document| document.content.as_str()));
|
||||
}
|
||||
|
@ -333,7 +309,7 @@ impl VectorStore {
|
|||
j += 1;
|
||||
}
|
||||
|
||||
for (worktree_id, documents, path, mtime) in
|
||||
for (worktree_id, documents, path, mtime, job_handle) in
|
||||
embeddings_queue.into_iter()
|
||||
{
|
||||
for document in documents.iter() {
|
||||
|
@ -350,6 +326,7 @@ impl VectorStore {
|
|||
documents,
|
||||
path,
|
||||
mtime,
|
||||
job_handle,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -372,9 +349,16 @@ impl VectorStore {
|
|||
worktree_id,
|
||||
path,
|
||||
mtime,
|
||||
job_handle,
|
||||
} => {
|
||||
queue_len += &documents.len();
|
||||
embeddings_queue.push((worktree_id, documents, path, mtime));
|
||||
embeddings_queue.push((
|
||||
worktree_id,
|
||||
documents,
|
||||
path,
|
||||
mtime,
|
||||
job_handle,
|
||||
));
|
||||
queue_len >= EMBEDDINGS_BATCH_SIZE
|
||||
}
|
||||
EmbeddingJob::Flush => true,
|
||||
|
@ -420,6 +404,7 @@ impl VectorStore {
|
|||
worktree_id: pending_file.worktree_db_id,
|
||||
path: pending_file.relative_path,
|
||||
mtime: pending_file.modified_time,
|
||||
job_handle: pending_file.job_handle,
|
||||
documents,
|
||||
})
|
||||
.unwrap();
|
||||
|
@ -439,6 +424,7 @@ impl VectorStore {
|
|||
embedding_provider,
|
||||
language_registry,
|
||||
db_update_tx,
|
||||
next_job_id: Default::default(),
|
||||
parsing_files_tx,
|
||||
_db_update_task,
|
||||
_embed_batch_task,
|
||||
|
@ -471,11 +457,11 @@ impl VectorStore {
|
|||
async move { rx.await? }
|
||||
}
|
||||
|
||||
fn add_project(
|
||||
fn index_project(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
) -> Task<Result<usize>> {
|
||||
let worktree_scans_complete = project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
|
@ -494,21 +480,16 @@ impl VectorStore {
|
|||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let fs = self.fs.clone();
|
||||
let language_registry = self.language_registry.clone();
|
||||
let database_url = self.database_url.clone();
|
||||
let db_update_tx = self.db_update_tx.clone();
|
||||
let parsing_files_tx = self.parsing_files_tx.clone();
|
||||
let next_job_id = self.next_job_id.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
futures::future::join_all(worktree_scans_complete).await;
|
||||
|
||||
let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
|
||||
|
||||
if let Some(db_directory) = database_url.parent() {
|
||||
fs.create_dir(db_directory).await.log_err();
|
||||
}
|
||||
|
||||
let worktrees = project.read_with(&cx, |project, cx| {
|
||||
project
|
||||
.worktrees(cx)
|
||||
|
@ -516,109 +497,115 @@ impl VectorStore {
|
|||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
let mut worktree_file_times = HashMap::new();
|
||||
let mut worktree_file_mtimes = HashMap::new();
|
||||
let mut db_ids_by_worktree_id = HashMap::new();
|
||||
for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
|
||||
let db_id = db_id?;
|
||||
db_ids_by_worktree_id.insert(worktree.id(), db_id);
|
||||
worktree_file_times.insert(
|
||||
worktree_file_mtimes.insert(
|
||||
worktree.id(),
|
||||
this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
cx.background()
|
||||
.spawn({
|
||||
let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
|
||||
let db_update_tx = db_update_tx.clone();
|
||||
let language_registry = language_registry.clone();
|
||||
let parsing_files_tx = parsing_files_tx.clone();
|
||||
async move {
|
||||
let t0 = Instant::now();
|
||||
for worktree in worktrees.into_iter() {
|
||||
let mut file_mtimes =
|
||||
worktree_file_times.remove(&worktree.id()).unwrap();
|
||||
for file in worktree.files(false, 0) {
|
||||
let absolute_path = worktree.absolutize(&file.path);
|
||||
|
||||
if let Ok(language) = language_registry
|
||||
.language_for_file(&absolute_path, None)
|
||||
.await
|
||||
{
|
||||
if language
|
||||
.grammar()
|
||||
.and_then(|grammar| grammar.embedding_config.as_ref())
|
||||
.is_none()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let path_buf = file.path.to_path_buf();
|
||||
let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
|
||||
let already_stored = stored_mtime
|
||||
.map_or(false, |existing_mtime| {
|
||||
existing_mtime == file.mtime
|
||||
});
|
||||
|
||||
if !already_stored {
|
||||
log::trace!("sending for parsing: {:?}", path_buf);
|
||||
parsing_files_tx
|
||||
.try_send(PendingFile {
|
||||
worktree_db_id: db_ids_by_worktree_id
|
||||
[&worktree.id()],
|
||||
relative_path: path_buf,
|
||||
absolute_path,
|
||||
language,
|
||||
modified_time: file.mtime,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
for file in file_mtimes.keys() {
|
||||
db_update_tx
|
||||
.try_send(DbOperation::Delete {
|
||||
worktree_id: db_ids_by_worktree_id[&worktree.id()],
|
||||
path: file.to_owned(),
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
log::trace!(
|
||||
"parsing worktree completed in {:?}",
|
||||
t0.elapsed().as_millis()
|
||||
);
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
// let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
|
||||
this.update(&mut cx, |this, cx| {
|
||||
// The below is managing for updated on save
|
||||
// Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
|
||||
// greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
|
||||
let _subscription = cx.subscribe(&project, |this, project, event, cx| {
|
||||
if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
|
||||
this.project_entries_changed(project, changes.clone(), cx, worktree_id);
|
||||
}
|
||||
});
|
||||
|
||||
let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
|
||||
this.update(&mut cx, |this, _| {
|
||||
this.projects.insert(
|
||||
project.downgrade(),
|
||||
ProjectState {
|
||||
pending_files: HashMap::new(),
|
||||
worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
|
||||
_subscription,
|
||||
worktree_db_ids: db_ids_by_worktree_id
|
||||
.iter()
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect(),
|
||||
outstanding_jobs: outstanding_jobs.clone(),
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
cx.background()
|
||||
.spawn(async move {
|
||||
let mut count = 0;
|
||||
let t0 = Instant::now();
|
||||
for worktree in worktrees.into_iter() {
|
||||
let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
|
||||
for file in worktree.files(false, 0) {
|
||||
let absolute_path = worktree.absolutize(&file.path);
|
||||
|
||||
if let Ok(language) = language_registry
|
||||
.language_for_file(&absolute_path, None)
|
||||
.await
|
||||
{
|
||||
if language
|
||||
.grammar()
|
||||
.and_then(|grammar| grammar.embedding_config.as_ref())
|
||||
.is_none()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let path_buf = file.path.to_path_buf();
|
||||
let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
|
||||
let already_stored = stored_mtime
|
||||
.map_or(false, |existing_mtime| existing_mtime == file.mtime);
|
||||
|
||||
if !already_stored {
|
||||
log::trace!("sending for parsing: {:?}", path_buf);
|
||||
count += 1;
|
||||
let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
let job_handle = JobHandle {
|
||||
id: job_id,
|
||||
set: Arc::downgrade(&outstanding_jobs),
|
||||
};
|
||||
outstanding_jobs.lock().insert(job_id);
|
||||
parsing_files_tx
|
||||
.try_send(PendingFile {
|
||||
worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
|
||||
relative_path: path_buf,
|
||||
absolute_path,
|
||||
language,
|
||||
job_handle,
|
||||
modified_time: file.mtime,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
for file in file_mtimes.keys() {
|
||||
db_update_tx
|
||||
.try_send(DbOperation::Delete {
|
||||
worktree_id: db_ids_by_worktree_id[&worktree.id()],
|
||||
path: file.to_owned(),
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
log::trace!(
|
||||
"parsing worktree completed in {:?}",
|
||||
t0.elapsed().as_millis()
|
||||
);
|
||||
|
||||
Ok(count)
|
||||
})
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
pub fn search(
|
||||
/// Returns the number of indexing jobs that are still outstanding for
/// `project`.
///
/// Returns `None` when `self.projects` has no entry for the project;
/// otherwise returns the current size of that project's outstanding-job set,
/// read while holding its lock.
pub fn remaining_files_to_index_for_project(
|
||||
&self,
|
||||
project: &ModelHandle<Project>,
|
||||
) -> Option<usize> {
|
||||
Some(
|
||||
self.projects
|
||||
// `projects` is keyed by weak project handles, hence the downgrade.
.get(&project.downgrade())?
|
||||
.outstanding_jobs
|
||||
.lock()
|
||||
.len(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn search_project(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
phrase: String,
|
||||
|
@ -682,110 +669,16 @@ impl VectorStore {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn project_entries_changed(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
|
||||
cx: &mut ModelContext<'_, VectorStore>,
|
||||
worktree_id: &WorktreeId,
|
||||
) -> Option<()> {
|
||||
let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;
|
||||
|
||||
let worktree = project
|
||||
.read(cx)
|
||||
.worktree_for_id(worktree_id.clone(), cx)?
|
||||
.read(cx)
|
||||
.snapshot();
|
||||
|
||||
let worktree_db_id = self
|
||||
.projects
|
||||
.get(&project.downgrade())?
|
||||
.db_id_for_worktree_id(worktree.id())?;
|
||||
let file_mtimes = self.get_file_mtimes(worktree_db_id);
|
||||
|
||||
let language_registry = self.language_registry.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let file_mtimes = file_mtimes.await.log_err()?;
|
||||
|
||||
for change in changes.into_iter() {
|
||||
let change_path = change.0.clone();
|
||||
let absolute_path = worktree.absolutize(&change_path);
|
||||
|
||||
// Skip if git ignored or symlink
|
||||
if let Some(entry) = worktree.entry_for_id(change.1) {
|
||||
if entry.is_ignored || entry.is_symlink || entry.is_external {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
match change.2 {
|
||||
PathChange::Removed => this.update(&mut cx, |this, _| {
|
||||
this.db_update_tx
|
||||
.try_send(DbOperation::Delete {
|
||||
worktree_id: worktree_db_id,
|
||||
path: absolute_path,
|
||||
})
|
||||
.unwrap();
|
||||
}),
|
||||
_ => {
|
||||
if let Ok(language) = language_registry
|
||||
.language_for_file(&change_path.to_path_buf(), None)
|
||||
.await
|
||||
{
|
||||
if language
|
||||
.grammar()
|
||||
.and_then(|grammar| grammar.embedding_config.as_ref())
|
||||
.is_none()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let modified_time =
|
||||
change_path.metadata().log_err()?.modified().log_err()?;
|
||||
|
||||
let existing_time = file_mtimes.get(&change_path.to_path_buf());
|
||||
let already_stored = existing_time
|
||||
.map_or(false, |existing_time| &modified_time != existing_time);
|
||||
|
||||
if !already_stored {
|
||||
this.update(&mut cx, |this, _| {
|
||||
let reindex_time = modified_time
|
||||
+ Duration::from_secs(reindexing_delay as u64);
|
||||
|
||||
let project_state =
|
||||
this.projects.get_mut(&project.downgrade())?;
|
||||
project_state.update_pending_files(
|
||||
PendingFile {
|
||||
relative_path: change_path.to_path_buf(),
|
||||
absolute_path,
|
||||
modified_time,
|
||||
worktree_db_id,
|
||||
language: language.clone(),
|
||||
},
|
||||
reindex_time,
|
||||
);
|
||||
|
||||
for file in project_state.get_outstanding_files() {
|
||||
this.parsing_files_tx.try_send(file).unwrap();
|
||||
}
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Entity for VectorStore {
|
||||
type Event = ();
|
||||
}
|
||||
|
||||
/// Removes the job's id from the outstanding-job set when the handle is
/// dropped.
impl Drop for JobHandle {
|
||||
fn drop(&mut self) {
|
||||
// The set is only held weakly; if it is already gone there is nothing
// left to update.
if let Some(set) = self.set.upgrade() {
|
||||
set.lock().remove(&self.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,11 +9,17 @@ use anyhow::Result;
|
|||
use async_trait::async_trait;
|
||||
use gpui::{Task, TestAppContext};
|
||||
use language::{Language, LanguageConfig, LanguageRegistry};
|
||||
use project::{project_settings::ProjectSettings, FakeFs, Project};
|
||||
use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
|
||||
use rand::{rngs::StdRng, Rng};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use std::{
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{self, AtomicUsize},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use unindent::Unindent;
|
||||
|
||||
#[ctor::ctor]
|
||||
|
@ -62,29 +68,37 @@ async fn test_vector_store(cx: &mut TestAppContext) {
|
|||
let db_dir = tempdir::TempDir::new("vector-store").unwrap();
|
||||
let db_path = db_dir.path().join("db.sqlite");
|
||||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let store = VectorStore::new(
|
||||
fs.clone(),
|
||||
db_path,
|
||||
Arc::new(FakeEmbeddingProvider),
|
||||
embedding_provider.clone(),
|
||||
languages,
|
||||
cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
|
||||
let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
|
||||
let worktree_id = project.read_with(cx, |project, cx| {
|
||||
project.worktrees(cx).next().unwrap().read(cx).id()
|
||||
});
|
||||
store
|
||||
.update(cx, |store, cx| store.add_project(project.clone(), cx))
|
||||
let file_count = store
|
||||
.update(cx, |store, cx| store.index_project(project.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(file_count, 2);
|
||||
cx.foreground().run_until_parked();
|
||||
store.update(cx, |store, _cx| {
|
||||
assert_eq!(
|
||||
store.remaining_files_to_index_for_project(&project),
|
||||
Some(0)
|
||||
);
|
||||
});
|
||||
|
||||
let search_results = store
|
||||
.update(cx, |store, cx| {
|
||||
store.search(project.clone(), "aaaa".to_string(), 5, cx)
|
||||
store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -92,10 +106,45 @@ async fn test_vector_store(cx: &mut TestAppContext) {
|
|||
assert_eq!(search_results[0].byte_range.start, 0);
|
||||
assert_eq!(search_results[0].name, "aaa");
|
||||
assert_eq!(search_results[0].worktree_id, worktree_id);
|
||||
|
||||
fs.save(
|
||||
"/the-root/src/file2.rs".as_ref(),
|
||||
&"
|
||||
fn dddd() { println!(\"ddddd!\"); }
|
||||
struct pqpqpqp {}
|
||||
"
|
||||
.unindent()
|
||||
.into(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.foreground().run_until_parked();
|
||||
|
||||
let prev_embedding_count = embedding_provider.embedding_count();
|
||||
let file_count = store
|
||||
.update(cx, |store, cx| store.index_project(project.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(file_count, 1);
|
||||
|
||||
cx.foreground().run_until_parked();
|
||||
store.update(cx, |store, _cx| {
|
||||
assert_eq!(
|
||||
store.remaining_files_to_index_for_project(&project),
|
||||
Some(0)
|
||||
);
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
embedding_provider.embedding_count() - prev_embedding_count,
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval(cx: &mut TestAppContext) {
|
||||
async fn test_code_context_retrieval() {
|
||||
let language = rust_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
|
||||
|
@ -181,11 +230,22 @@ fn test_dot_product(mut rng: StdRng) {
|
|||
}
|
||||
}
|
||||
|
||||
struct FakeEmbeddingProvider;
|
||||
/// Test embedding provider that keeps count of how many spans it has been
/// asked to embed.
#[derive(Default)]
|
||||
struct FakeEmbeddingProvider {
|
||||
/// Running total of spans passed to `embed_batch`.
embedding_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl FakeEmbeddingProvider {
|
||||
/// Returns the current value of the span counter.
fn embedding_count(&self) -> usize {
|
||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
Ok(spans
|
||||
.iter()
|
||||
.map(|span| {
|
||||
|
|
Loading…
Reference in a new issue