mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-27 04:44:30 +00:00
refactored code context retrieval and standardized database migration
Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
parent
5eab628580
commit
0a0e40fb24
7 changed files with 232 additions and 148 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -8483,7 +8483,9 @@ dependencies = [
|
|||
"anyhow",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger 0.9.3",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"isahc",
|
||||
|
|
|
@ -44,6 +44,9 @@ rpc = { path = "../rpc", features = ["test-support"] }
|
|||
workspace = { path = "../workspace", features = ["test-support"] }
|
||||
settings = { path = "../settings", features = ["test-support"]}
|
||||
tree-sitter-rust = "*"
|
||||
|
||||
rand.workspace = true
|
||||
unindent.workspace = true
|
||||
tempdir.workspace = true
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
|
|
@ -1,20 +1,20 @@
|
|||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::HashMap,
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
use crate::{parsing::Document, VECTOR_STORE_VERSION};
|
||||
use anyhow::{anyhow, Result};
|
||||
|
||||
use crate::parsing::ParsedFile;
|
||||
use crate::VECTOR_STORE_VERSION;
|
||||
use project::Fs;
|
||||
use rpc::proto::Timestamp;
|
||||
use rusqlite::{
|
||||
params,
|
||||
types::{FromSql, FromSqlResult, ValueRef},
|
||||
};
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::HashMap,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FileRecord {
|
||||
|
@ -42,48 +42,88 @@ pub struct VectorDatabase {
|
|||
}
|
||||
|
||||
impl VectorDatabase {
|
||||
pub fn new(path: String) -> Result<Self> {
|
||||
pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
|
||||
if let Some(db_directory) = path.parent() {
|
||||
fs.create_dir(db_directory).await?;
|
||||
}
|
||||
|
||||
let this = Self {
|
||||
db: rusqlite::Connection::open(path)?,
|
||||
db: rusqlite::Connection::open(path.as_path())?,
|
||||
};
|
||||
this.initialize_database()?;
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
fn get_existing_version(&self) -> Result<i64> {
|
||||
let mut version_query = self.db.prepare("SELECT version from vector_store_config")?;
|
||||
version_query
|
||||
.query_row([], |row| Ok(row.get::<_, i64>(0)?))
|
||||
.map_err(|err| anyhow!("version query failed: {err}"))
|
||||
}
|
||||
|
||||
fn initialize_database(&self) -> Result<()> {
|
||||
rusqlite::vtab::array::load_module(&self.db)?;
|
||||
|
||||
// This will create the database if it doesnt exist
|
||||
if self
|
||||
.get_existing_version()
|
||||
.map_or(false, |version| version == VECTOR_STORE_VERSION as i64)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.db
|
||||
.execute(
|
||||
"
|
||||
DROP TABLE vector_store_config;
|
||||
DROP TABLE worktrees;
|
||||
DROP TABLE files;
|
||||
DROP TABLE documents;
|
||||
",
|
||||
[],
|
||||
)
|
||||
.ok();
|
||||
|
||||
// Initialize Vector Databasing Tables
|
||||
self.db.execute(
|
||||
"CREATE TABLE IF NOT EXISTS worktrees (
|
||||
"CREATE TABLE vector_store_config (
|
||||
version INTEGER NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
self.db.execute(
|
||||
"INSERT INTO vector_store_config (version) VALUES (?1)",
|
||||
params![VECTOR_STORE_VERSION],
|
||||
)?;
|
||||
|
||||
self.db.execute(
|
||||
"CREATE TABLE worktrees (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
absolute_path VARCHAR NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
|
||||
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
|
||||
",
|
||||
[],
|
||||
)?;
|
||||
|
||||
self.db.execute(
|
||||
"CREATE TABLE IF NOT EXISTS files (
|
||||
"CREATE TABLE files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
worktree_id INTEGER NOT NULL,
|
||||
relative_path VARCHAR NOT NULL,
|
||||
mtime_seconds INTEGER NOT NULL,
|
||||
mtime_nanos INTEGER NOT NULL,
|
||||
vector_store_version INTEGER NOT NULL,
|
||||
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
self.db.execute(
|
||||
"CREATE TABLE IF NOT EXISTS documents (
|
||||
"CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_id INTEGER NOT NULL,
|
||||
offset INTEGER NOT NULL,
|
||||
start_byte INTEGER NOT NULL,
|
||||
end_byte INTEGER NOT NULL,
|
||||
name VARCHAR NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
|
||||
|
@ -102,43 +142,44 @@ impl VectorDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn insert_file(&self, worktree_id: i64, indexed_file: ParsedFile) -> Result<()> {
|
||||
pub fn insert_file(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
path: PathBuf,
|
||||
mtime: SystemTime,
|
||||
documents: Vec<Document>,
|
||||
) -> Result<()> {
|
||||
// Write to files table, and return generated id.
|
||||
self.db.execute(
|
||||
"
|
||||
DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
|
||||
",
|
||||
params![worktree_id, indexed_file.path.to_str()],
|
||||
params![worktree_id, path.to_str()],
|
||||
)?;
|
||||
let mtime = Timestamp::from(indexed_file.mtime);
|
||||
let mtime = Timestamp::from(mtime);
|
||||
self.db.execute(
|
||||
"
|
||||
INSERT INTO files
|
||||
(worktree_id, relative_path, mtime_seconds, mtime_nanos, vector_store_version)
|
||||
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
|
||||
VALUES
|
||||
(?1, ?2, $3, $4, $5);
|
||||
(?1, ?2, $3, $4);
|
||||
",
|
||||
params![
|
||||
worktree_id,
|
||||
indexed_file.path.to_str(),
|
||||
mtime.seconds,
|
||||
mtime.nanos,
|
||||
VECTOR_STORE_VERSION
|
||||
],
|
||||
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
|
||||
)?;
|
||||
|
||||
let file_id = self.db.last_insert_rowid();
|
||||
|
||||
// Currently inserting at approximately 3400 documents a second
|
||||
// I imagine we can speed this up with a bulk insert of some kind.
|
||||
for document in indexed_file.documents {
|
||||
for document in documents {
|
||||
let embedding_blob = bincode::serialize(&document.embedding)?;
|
||||
|
||||
self.db.execute(
|
||||
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
|
||||
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
|
||||
params![
|
||||
file_id,
|
||||
document.offset.to_string(),
|
||||
document.range.start.to_string(),
|
||||
document.range.end.to_string(),
|
||||
document.name,
|
||||
embedding_blob
|
||||
],
|
||||
|
@ -204,7 +245,7 @@ impl VectorDatabase {
|
|||
worktree_ids: &[i64],
|
||||
query_embedding: &Vec<f32>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<(i64, PathBuf, usize, String)>> {
|
||||
) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
|
||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
self.for_each_document(&worktree_ids, |id, embedding| {
|
||||
let similarity = dot(&embedding, &query_embedding);
|
||||
|
@ -248,11 +289,18 @@ impl VectorDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
|
||||
fn get_documents_by_ids(
|
||||
&self,
|
||||
ids: &[i64],
|
||||
) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
|
||||
let mut statement = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
|
||||
documents.id,
|
||||
files.worktree_id,
|
||||
files.relative_path,
|
||||
documents.start_byte,
|
||||
documents.end_byte, documents.name
|
||||
FROM
|
||||
documents, files
|
||||
WHERE
|
||||
|
@ -266,15 +314,15 @@ impl VectorDatabase {
|
|||
row.get::<_, i64>(0)?,
|
||||
row.get::<_, i64>(1)?,
|
||||
row.get::<_, String>(2)?.into(),
|
||||
row.get(3)?,
|
||||
row.get(4)?,
|
||||
row.get(3)?..row.get(4)?,
|
||||
row.get(5)?,
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
|
||||
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
|
||||
for row in result_iter {
|
||||
let (id, worktree_id, path, offset, name) = row?;
|
||||
values_by_id.insert(id, (worktree_id, path, offset, name));
|
||||
let (id, worktree_id, path, range, name) = row?;
|
||||
values_by_id.insert(id, (worktree_id, path, range, name));
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(ids.len());
|
||||
|
|
|
@ -66,7 +66,7 @@ impl PickerDelegate for SemanticSearchDelegate {
|
|||
});
|
||||
|
||||
let workspace = self.workspace.clone();
|
||||
let position = search_result.clone().offset;
|
||||
let position = search_result.clone().byte_range.start;
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
let buffer = buffer.await?;
|
||||
workspace.update(&mut cx, |workspace, cx| {
|
||||
|
|
|
@ -1,41 +1,39 @@
|
|||
use std::{path::PathBuf, sync::Arc, time::SystemTime};
|
||||
|
||||
use anyhow::{anyhow, Ok, Result};
|
||||
use project::Fs;
|
||||
use language::Language;
|
||||
use std::{ops::Range, path::Path, sync::Arc};
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
|
||||
use crate::PendingFile;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Document {
|
||||
pub offset: usize,
|
||||
pub name: String,
|
||||
pub range: Range<usize>,
|
||||
pub content: String,
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct ParsedFile {
|
||||
pub path: PathBuf,
|
||||
pub mtime: SystemTime,
|
||||
pub documents: Vec<Document>,
|
||||
}
|
||||
|
||||
const CODE_CONTEXT_TEMPLATE: &str =
|
||||
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
||||
|
||||
pub struct CodeContextRetriever {
|
||||
pub parser: Parser,
|
||||
pub cursor: QueryCursor,
|
||||
pub fs: Arc<dyn Fs>,
|
||||
}
|
||||
|
||||
impl CodeContextRetriever {
|
||||
pub async fn parse_file(
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
parser: Parser::new(),
|
||||
cursor: QueryCursor::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_file(
|
||||
&mut self,
|
||||
pending_file: PendingFile,
|
||||
) -> Result<(ParsedFile, Vec<String>)> {
|
||||
let grammar = pending_file
|
||||
.language
|
||||
relative_path: &Path,
|
||||
content: &str,
|
||||
language: Arc<Language>,
|
||||
) -> Result<Vec<Document>> {
|
||||
let grammar = language
|
||||
.grammar()
|
||||
.ok_or_else(|| anyhow!("no grammar for language"))?;
|
||||
let embedding_config = grammar
|
||||
|
@ -43,8 +41,6 @@ impl CodeContextRetriever {
|
|||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no embedding queries"))?;
|
||||
|
||||
let content = self.fs.load(&pending_file.absolute_path).await?;
|
||||
|
||||
self.parser.set_language(grammar.ts_language).unwrap();
|
||||
|
||||
let tree = self
|
||||
|
@ -53,7 +49,6 @@ impl CodeContextRetriever {
|
|||
.ok_or_else(|| anyhow!("parsing failed"))?;
|
||||
|
||||
let mut documents = Vec::new();
|
||||
let mut document_texts = Vec::new();
|
||||
|
||||
// Iterate through query matches
|
||||
for mat in self.cursor.matches(
|
||||
|
@ -63,11 +58,11 @@ impl CodeContextRetriever {
|
|||
) {
|
||||
let mut name: Vec<&str> = vec![];
|
||||
let mut item: Option<&str> = None;
|
||||
let mut offset: Option<usize> = None;
|
||||
let mut byte_range: Option<Range<usize>> = None;
|
||||
let mut context_spans: Vec<&str> = vec![];
|
||||
for capture in mat.captures {
|
||||
if capture.index == embedding_config.item_capture_ix {
|
||||
offset = Some(capture.node.byte_range().start);
|
||||
byte_range = Some(capture.node.byte_range());
|
||||
item = content.get(capture.node.byte_range());
|
||||
} else if capture.index == embedding_config.name_capture_ix {
|
||||
if let Some(name_content) = content.get(capture.node.byte_range()) {
|
||||
|
@ -84,30 +79,25 @@ impl CodeContextRetriever {
|
|||
}
|
||||
}
|
||||
|
||||
if item.is_some() && offset.is_some() && name.len() > 0 {
|
||||
let item = format!("{}\n{}", context_spans.join("\n"), item.unwrap());
|
||||
if let Some((item, byte_range)) = item.zip(byte_range) {
|
||||
if !name.is_empty() {
|
||||
let item = format!("{}\n{}", context_spans.join("\n"), item);
|
||||
|
||||
let document_text = CODE_CONTEXT_TEMPLATE
|
||||
.replace("<path>", pending_file.relative_path.to_str().unwrap())
|
||||
.replace("<language>", &pending_file.language.name().to_lowercase())
|
||||
.replace("<item>", item.as_str());
|
||||
let document_text = CODE_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_str().unwrap())
|
||||
.replace("<language>", &language.name().to_lowercase())
|
||||
.replace("<item>", item.as_str());
|
||||
|
||||
document_texts.push(document_text);
|
||||
documents.push(Document {
|
||||
name: name.join(" "),
|
||||
offset: offset.unwrap(),
|
||||
embedding: Vec::new(),
|
||||
})
|
||||
documents.push(Document {
|
||||
range: byte_range,
|
||||
content: document_text,
|
||||
embedding: Vec::new(),
|
||||
name: name.join(" ").to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Ok((
|
||||
ParsedFile {
|
||||
path: pending_file.relative_path,
|
||||
mtime: pending_file.modified_time,
|
||||
documents,
|
||||
},
|
||||
document_texts,
|
||||
));
|
||||
return Ok(documents);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,16 +18,16 @@ use gpui::{
|
|||
};
|
||||
use language::{Language, LanguageRegistry};
|
||||
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
|
||||
use parsing::{CodeContextRetriever, ParsedFile};
|
||||
use parsing::{CodeContextRetriever, Document};
|
||||
use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
|
||||
use smol::channel;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant, SystemTime},
|
||||
};
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
use util::{
|
||||
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
|
||||
http::HttpClient,
|
||||
|
@ -36,7 +36,7 @@ use util::{
|
|||
};
|
||||
use workspace::{Workspace, WorkspaceCreated};
|
||||
|
||||
const VECTOR_STORE_VERSION: usize = 0;
|
||||
const VECTOR_STORE_VERSION: usize = 1;
|
||||
const EMBEDDINGS_BATCH_SIZE: usize = 150;
|
||||
|
||||
pub fn init(
|
||||
|
@ -80,11 +80,11 @@ pub fn init(
|
|||
let vector_store = VectorStore::new(
|
||||
fs,
|
||||
db_file_path,
|
||||
// Arc::new(embedding::DummyEmbeddings {}),
|
||||
Arc::new(OpenAIEmbeddings {
|
||||
client: http_client,
|
||||
executor: cx.background(),
|
||||
}),
|
||||
Arc::new(embedding::DummyEmbeddings {}),
|
||||
// Arc::new(OpenAIEmbeddings {
|
||||
// client: http_client,
|
||||
// executor: cx.background(),
|
||||
// }),
|
||||
language_registry,
|
||||
cx.clone(),
|
||||
)
|
||||
|
@ -212,14 +212,16 @@ pub struct PendingFile {
|
|||
pub struct SearchResult {
|
||||
pub worktree_id: WorktreeId,
|
||||
pub name: String,
|
||||
pub offset: usize,
|
||||
pub byte_range: Range<usize>,
|
||||
pub file_path: PathBuf,
|
||||
}
|
||||
|
||||
enum DbOperation {
|
||||
InsertFile {
|
||||
worktree_id: i64,
|
||||
indexed_file: ParsedFile,
|
||||
documents: Vec<Document>,
|
||||
path: PathBuf,
|
||||
mtime: SystemTime,
|
||||
},
|
||||
Delete {
|
||||
worktree_id: i64,
|
||||
|
@ -238,8 +240,9 @@ enum DbOperation {
|
|||
enum EmbeddingJob {
|
||||
Enqueue {
|
||||
worktree_id: i64,
|
||||
parsed_file: ParsedFile,
|
||||
document_spans: Vec<String>,
|
||||
path: PathBuf,
|
||||
mtime: SystemTime,
|
||||
documents: Vec<Document>,
|
||||
},
|
||||
Flush,
|
||||
}
|
||||
|
@ -256,18 +259,7 @@ impl VectorStore {
|
|||
|
||||
let db = cx
|
||||
.background()
|
||||
.spawn({
|
||||
let fs = fs.clone();
|
||||
let database_url = database_url.clone();
|
||||
async move {
|
||||
if let Some(db_directory) = database_url.parent() {
|
||||
fs.create_dir(db_directory).await.log_err();
|
||||
}
|
||||
|
||||
let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
|
||||
anyhow::Ok(db)
|
||||
}
|
||||
})
|
||||
.spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
|
||||
.await?;
|
||||
|
||||
Ok(cx.add_model(|cx| {
|
||||
|
@ -280,9 +272,12 @@ impl VectorStore {
|
|||
match job {
|
||||
DbOperation::InsertFile {
|
||||
worktree_id,
|
||||
indexed_file,
|
||||
documents,
|
||||
path,
|
||||
mtime,
|
||||
} => {
|
||||
db.insert_file(worktree_id, indexed_file).log_err();
|
||||
db.insert_file(worktree_id, path, mtime, documents)
|
||||
.log_err();
|
||||
}
|
||||
DbOperation::Delete { worktree_id, path } => {
|
||||
db.delete_file(worktree_id, path).log_err();
|
||||
|
@ -304,35 +299,45 @@ impl VectorStore {
|
|||
|
||||
// embed_tx/rx: Embed Batch and Send to Database
|
||||
let (embed_batch_tx, embed_batch_rx) =
|
||||
channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
|
||||
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime)>>();
|
||||
let _embed_batch_task = cx.background().spawn({
|
||||
let db_update_tx = db_update_tx.clone();
|
||||
let embedding_provider = embedding_provider.clone();
|
||||
async move {
|
||||
while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
|
||||
// Construct Batch
|
||||
let mut document_spans = vec![];
|
||||
for (_, _, document_span) in embeddings_queue.iter() {
|
||||
document_spans.extend(document_span.iter().map(|s| s.as_str()));
|
||||
let mut batch_documents = vec![];
|
||||
for (_, documents, _, _) in embeddings_queue.iter() {
|
||||
batch_documents
|
||||
.extend(documents.iter().map(|document| document.content.as_str()));
|
||||
}
|
||||
|
||||
if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
|
||||
if let Ok(embeddings) =
|
||||
embedding_provider.embed_batch(batch_documents).await
|
||||
{
|
||||
log::trace!(
|
||||
"created {} embeddings for {} files",
|
||||
embeddings.len(),
|
||||
embeddings_queue.len(),
|
||||
);
|
||||
|
||||
let mut i = 0;
|
||||
let mut j = 0;
|
||||
|
||||
for embedding in embeddings.iter() {
|
||||
while embeddings_queue[i].1.documents.len() == j {
|
||||
while embeddings_queue[i].1.len() == j {
|
||||
i += 1;
|
||||
j = 0;
|
||||
}
|
||||
|
||||
embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
|
||||
embeddings_queue[i].1[j].embedding = embedding.to_owned();
|
||||
j += 1;
|
||||
}
|
||||
|
||||
for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
|
||||
for document in indexed_file.documents.iter() {
|
||||
for (worktree_id, documents, path, mtime) in
|
||||
embeddings_queue.into_iter()
|
||||
{
|
||||
for document in documents.iter() {
|
||||
// TODO: Update this so it doesn't panic
|
||||
assert!(
|
||||
document.embedding.len() > 0,
|
||||
|
@ -343,7 +348,9 @@ impl VectorStore {
|
|||
db_update_tx
|
||||
.send(DbOperation::InsertFile {
|
||||
worktree_id,
|
||||
indexed_file,
|
||||
documents,
|
||||
path,
|
||||
mtime,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -362,12 +369,13 @@ impl VectorStore {
|
|||
while let Ok(job) = batch_files_rx.recv().await {
|
||||
let should_flush = match job {
|
||||
EmbeddingJob::Enqueue {
|
||||
document_spans,
|
||||
documents,
|
||||
worktree_id,
|
||||
parsed_file,
|
||||
path,
|
||||
mtime,
|
||||
} => {
|
||||
queue_len += &document_spans.len();
|
||||
embeddings_queue.push((worktree_id, parsed_file, document_spans));
|
||||
queue_len += &documents.len();
|
||||
embeddings_queue.push((worktree_id, documents, path, mtime));
|
||||
queue_len >= EMBEDDINGS_BATCH_SIZE
|
||||
}
|
||||
EmbeddingJob::Flush => true,
|
||||
|
@ -385,26 +393,38 @@ impl VectorStore {
|
|||
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
|
||||
|
||||
let mut _parsing_files_tasks = Vec::new();
|
||||
// for _ in 0..cx.background().num_cpus() {
|
||||
for _ in 0..1 {
|
||||
for _ in 0..cx.background().num_cpus() {
|
||||
let fs = fs.clone();
|
||||
let parsing_files_rx = parsing_files_rx.clone();
|
||||
let batch_files_tx = batch_files_tx.clone();
|
||||
_parsing_files_tasks.push(cx.background().spawn(async move {
|
||||
let parser = Parser::new();
|
||||
let cursor = QueryCursor::new();
|
||||
let mut retriever = CodeContextRetriever { parser, cursor, fs };
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
while let Ok(pending_file) = parsing_files_rx.recv().await {
|
||||
if let Some((indexed_file, document_spans)) =
|
||||
retriever.parse_file(pending_file.clone()).await.log_err()
|
||||
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
|
||||
{
|
||||
batch_files_tx
|
||||
.try_send(EmbeddingJob::Enqueue {
|
||||
worktree_id: pending_file.worktree_db_id,
|
||||
parsed_file: indexed_file,
|
||||
document_spans,
|
||||
})
|
||||
.unwrap();
|
||||
if let Some(documents) = retriever
|
||||
.parse_file(
|
||||
&pending_file.relative_path,
|
||||
&content,
|
||||
pending_file.language,
|
||||
)
|
||||
.log_err()
|
||||
{
|
||||
log::trace!(
|
||||
"parsed path {:?}: {} documents",
|
||||
pending_file.relative_path,
|
||||
documents.len()
|
||||
);
|
||||
|
||||
batch_files_tx
|
||||
.try_send(EmbeddingJob::Enqueue {
|
||||
worktree_id: pending_file.worktree_db_id,
|
||||
path: pending_file.relative_path,
|
||||
mtime: pending_file.modified_time,
|
||||
documents,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if parsing_files_rx.len() == 0 {
|
||||
|
@ -543,6 +563,7 @@ impl VectorStore {
|
|||
});
|
||||
|
||||
if !already_stored {
|
||||
log::trace!("sending for parsing: {:?}", path_buf);
|
||||
parsing_files_tx
|
||||
.try_send(PendingFile {
|
||||
worktree_db_id: db_ids_by_worktree_id
|
||||
|
@ -565,8 +586,8 @@ impl VectorStore {
|
|||
.unwrap();
|
||||
}
|
||||
}
|
||||
log::info!(
|
||||
"Parsing Worktree Completed in {:?}",
|
||||
log::trace!(
|
||||
"parsing worktree completed in {:?}",
|
||||
t0.elapsed().as_millis()
|
||||
);
|
||||
}
|
||||
|
@ -622,11 +643,12 @@ impl VectorStore {
|
|||
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let database_url = self.database_url.clone();
|
||||
let fs = self.fs.clone();
|
||||
cx.spawn(|this, cx| async move {
|
||||
let documents = cx
|
||||
.background()
|
||||
.spawn(async move {
|
||||
let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
|
||||
let database = VectorDatabase::new(fs, database_url).await?;
|
||||
|
||||
let phrase_embedding = embedding_provider
|
||||
.embed_batch(vec![&phrase])
|
||||
|
@ -648,12 +670,12 @@ impl VectorStore {
|
|||
|
||||
Ok(documents
|
||||
.into_iter()
|
||||
.filter_map(|(worktree_db_id, file_path, offset, name)| {
|
||||
.filter_map(|(worktree_db_id, file_path, byte_range, name)| {
|
||||
let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
|
||||
Some(SearchResult {
|
||||
worktree_id,
|
||||
name,
|
||||
offset,
|
||||
byte_range,
|
||||
file_path,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -12,6 +12,13 @@ use settings::SettingsStore;
|
|||
use std::sync::Arc;
|
||||
use unindent::Unindent;
|
||||
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_vector_store(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
|
@ -95,11 +102,23 @@ async fn test_vector_store(cx: &mut TestAppContext) {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(search_results[0].offset, 0);
|
||||
assert_eq!(search_results[0].byte_range.start, 0);
|
||||
assert_eq!(search_results[0].name, "aaa");
|
||||
assert_eq!(search_results[0].worktree_id, worktree_id);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval(cx: &mut TestAppContext) {
|
||||
// let mut retriever = CodeContextRetriever::new(fs);
|
||||
|
||||
// retriever::parse_file(
|
||||
// "
|
||||
// //
|
||||
// ",
|
||||
// );
|
||||
//
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_dot_product(mut rng: StdRng) {
|
||||
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
|
||||
|
|
Loading…
Reference in a new issue