Vector store (#2658)

This PR adds a new crate aimed at maintaining a consistent semantic
embedding database for any project opened with Zed. At a high level,
for each file in a project we parse the file with tree-sitter, embed the
resulting symbol "document" objects with OpenAI, and store these
embeddings and their offset locations in a SQLite database. On top of
that store, we provide a simple modal interface for querying these
symbol embeddings in natural language and navigating to the selected
symbol.

This initial PR is intended to provide this functionality only in preview,
as we explore, evaluate and iterate on the vector store.

- Full task details are provided in the [Semantic Search Linear
Project](https://linear.app/zed-industries/project/semantic-search-7c787d198ebe/Z)
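
For orientation, here is a minimal editorial sketch of the query path described above. This is not code from the PR; the real ranking lives in `VectorDatabase::top_k_search` in the diff below. The natural-language phrase is embedded, scored against every stored symbol embedding with a dot product, and the best-scoring offsets are returned.

```rust
// Editorial sketch of the query path (simplified; see VectorDatabase::top_k_search
// in the diff below for the real implementation).
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// `stored` holds (relative path, byte offset, embedding) rows loaded from SQLite.
fn top_k<'a>(
    query_embedding: &[f32],
    stored: &'a [(String, usize, Vec<f32>)],
    k: usize,
) -> Vec<(&'a str, usize)> {
    let mut scored: Vec<(usize, f32)> = stored
        .iter()
        .enumerate()
        .map(|(ix, (_, _, embedding))| (ix, dot(query_embedding, embedding)))
        .collect();
    // Highest similarity first.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored
        .into_iter()
        .take(k)
        .map(|(ix, _)| (stored[ix].0.as_str(), stored[ix].1))
        .collect()
}
```
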
Kyle Caverly 2023-07-12 13:26:17 -04:00 committed by GitHub
commit 37568ccbf0
24 changed files with 2826 additions and 597 deletions

Cargo.lock (generated, 1330 changed lines): file diff suppressed because it is too large.


@ -63,6 +63,7 @@ members = [
"crates/theme",
"crates/theme_selector",
"crates/util",
"crates/vector_store",
"crates/vim",
"crates/vcs_menu",
"crates/workspace",


@ -405,6 +405,7 @@
"cmd-k cmd-t": "theme_selector::Toggle",
"cmd-k cmd-s": "zed::OpenKeymap",
"cmd-t": "project_symbols::Toggle",
"cmd-ctrl-t": "semantic_search::Toggle",
"cmd-p": "file_finder::Toggle",
"cmd-shift-p": "command_palette::Toggle",
"cmd-shift-m": "diagnostics::Deploy",


@ -291,6 +291,11 @@
// the terminal will default to matching the buffer's font family.
// "font_family": "Zed Mono"
},
// Settings for the vector store.
"vector_store": {
"enabled": false,
"reindexing_delay_seconds": 600
},
// Different settings for specific languages.
"languages": {
"Plain Text": {


@ -350,6 +350,7 @@ pub struct LanguageQueries {
pub brackets: Option<Cow<'static, str>>,
pub indents: Option<Cow<'static, str>>,
pub outline: Option<Cow<'static, str>>,
pub embedding: Option<Cow<'static, str>>,
pub injections: Option<Cow<'static, str>>,
pub overrides: Option<Cow<'static, str>>,
}
@ -490,12 +491,13 @@ pub struct Language {
pub struct Grammar {
id: usize,
pub(crate) ts_language: tree_sitter::Language,
pub ts_language: tree_sitter::Language,
pub(crate) error_query: Query,
pub(crate) highlights_query: Option<Query>,
pub(crate) brackets_config: Option<BracketConfig>,
pub(crate) indents_config: Option<IndentConfig>,
pub(crate) outline_config: Option<OutlineConfig>,
pub outline_config: Option<OutlineConfig>,
pub embedding_config: Option<EmbeddingConfig>,
pub(crate) injection_config: Option<InjectionConfig>,
pub(crate) override_config: Option<OverrideConfig>,
pub(crate) highlight_map: Mutex<HighlightMap>,
@ -509,12 +511,21 @@ struct IndentConfig {
outdent_capture_ix: Option<u32>,
}
struct OutlineConfig {
query: Query,
item_capture_ix: u32,
name_capture_ix: u32,
context_capture_ix: Option<u32>,
extra_context_capture_ix: Option<u32>,
pub struct OutlineConfig {
pub query: Query,
pub item_capture_ix: u32,
pub name_capture_ix: u32,
pub context_capture_ix: Option<u32>,
pub extra_context_capture_ix: Option<u32>,
}
#[derive(Debug)]
pub struct EmbeddingConfig {
pub query: Query,
pub item_capture_ix: u32,
pub name_capture_ix: u32,
pub context_capture_ix: Option<u32>,
pub extra_context_capture_ix: Option<u32>,
}
struct InjectionConfig {
@ -1146,6 +1157,7 @@ impl Language {
highlights_query: None,
brackets_config: None,
outline_config: None,
embedding_config: None,
indents_config: None,
injection_config: None,
override_config: None,
@ -1182,6 +1194,9 @@ impl Language {
if let Some(query) = queries.outline {
self = self.with_outline_query(query.as_ref())?;
}
if let Some(query) = queries.embedding {
self = self.with_embedding_query(query.as_ref())?;
}
if let Some(query) = queries.injections {
self = self.with_injection_query(query.as_ref())?;
}
@ -1190,6 +1205,7 @@ impl Language {
}
Ok(self)
}
pub fn with_highlights_query(mut self, source: &str) -> Result<Self> {
let grammar = self.grammar_mut();
grammar.highlights_query = Some(Query::new(grammar.ts_language, source)?);
@ -1224,6 +1240,34 @@ impl Language {
Ok(self)
}
pub fn with_embedding_query(mut self, source: &str) -> Result<Self> {
let grammar = self.grammar_mut();
let query = Query::new(grammar.ts_language, source)?;
let mut item_capture_ix = None;
let mut name_capture_ix = None;
let mut context_capture_ix = None;
let mut extra_context_capture_ix = None;
get_capture_indices(
&query,
&mut [
("item", &mut item_capture_ix),
("name", &mut name_capture_ix),
("context", &mut context_capture_ix),
("context.extra", &mut extra_context_capture_ix),
],
);
if let Some((item_capture_ix, name_capture_ix)) = item_capture_ix.zip(name_capture_ix) {
grammar.embedding_config = Some(EmbeddingConfig {
query,
item_capture_ix,
name_capture_ix,
context_capture_ix,
extra_context_capture_ix,
});
}
Ok(self)
}
pub fn with_brackets_query(mut self, source: &str) -> Result<Self> {
let grammar = self.grammar_mut();
let query = Query::new(grammar.ts_language, source)?;


@ -261,6 +261,7 @@ pub enum Event {
ActiveEntryChanged(Option<ProjectEntryId>),
WorktreeAdded,
WorktreeRemoved(WorktreeId),
WorktreeUpdatedEntries(WorktreeId, UpdatedEntriesSet),
DiskBasedDiagnosticsStarted {
language_server_id: LanguageServerId,
},
@ -5403,6 +5404,10 @@ impl Project {
this.update_local_worktree_buffers(&worktree, changes, cx);
this.update_local_worktree_language_servers(&worktree, changes, cx);
this.update_local_worktree_settings(&worktree, changes, cx);
cx.emit(Event::WorktreeUpdatedEntries(
worktree.read(cx).id(),
changes.clone(),
));
}
worktree::Event::UpdatedGitRepositories(updated_repos) => {
this.update_local_worktree_buffers_git_repos(worktree, updated_repos, cx)


@ -6,6 +6,7 @@ lazy_static::lazy_static! {
pub static ref HOME: PathBuf = dirs::home_dir().expect("failed to determine home directory");
pub static ref CONFIG_DIR: PathBuf = HOME.join(".config").join("zed");
pub static ref CONVERSATIONS_DIR: PathBuf = HOME.join(".config/zed/conversations");
pub static ref EMBEDDINGS_DIR: PathBuf = HOME.join(".config/zed/embeddings");
pub static ref LOGS_DIR: PathBuf = HOME.join("Library/Logs/Zed");
pub static ref SUPPORT_DIR: PathBuf = HOME.join("Library/Application Support/Zed");
pub static ref LANGUAGES_DIR: PathBuf = HOME.join("Library/Application Support/Zed/languages");


@ -0,0 +1,49 @@
[package]
name = "vector_store"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/vector_store.rs"
doctest = false
[dependencies]
gpui = { path = "../gpui" }
language = { path = "../language" }
project = { path = "../project" }
workspace = { path = "../workspace" }
util = { path = "../util" }
picker = { path = "../picker" }
theme = { path = "../theme" }
editor = { path = "../editor" }
rpc = { path = "../rpc" }
settings = { path = "../settings" }
anyhow.workspace = true
futures.workspace = true
smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true
log.workspace = true
tree-sitter.workspace = true
lazy_static.workspace = true
serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
matrixmultiply = "0.3.7"
tiktoken-rs = "0.5.0"
rand.workspace = true
schemars.workspace = true
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"]}
tree-sitter-rust = "*"
rand.workspace = true
unindent.workspace = true
tempdir.workspace = true


@ -0,0 +1,31 @@
WIP: Sample SQL Queries
/*
create table "files" (
"id" INTEGER PRIMARY KEY,
"path" VARCHAR,
"sha1" VARCHAR
);
create table "symbols" (
"file_id" INTEGER REFERENCES files(id) ON DELETE CASCADE,
"offset" INTEGER,
"embedding" VECTOR
);
insert into "files" ("path", "sha1") values ("src/main.rs", "sha1") returning "id";
insert into "symbols" (
"file_id",
"offset",
"embedding"
) values (
?1,
?2,
?3
);
*/


@ -0,0 +1,325 @@
use std::{
cmp::Ordering,
collections::HashMap,
path::{Path, PathBuf},
rc::Rc,
time::SystemTime,
};
use anyhow::{anyhow, Result};
use crate::parsing::ParsedFile;
use crate::VECTOR_STORE_VERSION;
use rpc::proto::Timestamp;
use rusqlite::{
params,
types::{FromSql, FromSqlResult, ValueRef},
};
#[derive(Debug)]
pub struct FileRecord {
pub id: usize,
pub relative_path: String,
pub mtime: Timestamp,
}
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Vec<f32> = bincode::deserialize(bytes)
.map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
Ok(Embedding(embedding))
}
}
pub struct VectorDatabase {
db: rusqlite::Connection,
}
impl VectorDatabase {
pub fn new(path: String) -> Result<Self> {
let this = Self {
db: rusqlite::Connection::open(path)?,
};
this.initialize_database()?;
Ok(this)
}
fn initialize_database(&self) -> Result<()> {
rusqlite::vtab::array::load_module(&self.db)?;
// This will create the database if it doesn't exist
// Initialize the vector database tables
self.db.execute(
"CREATE TABLE IF NOT EXISTS worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
",
[],
)?;
self.db.execute(
"CREATE TABLE IF NOT EXISTS files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
mtime_seconds INTEGER NOT NULL,
mtime_nanos INTEGER NOT NULL,
vector_store_version INTEGER NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
self.db.execute(
"CREATE TABLE IF NOT EXISTS documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
offset INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
Ok(())
}
pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
self.db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
params![worktree_id, delete_path.to_str()],
)?;
Ok(())
}
pub fn insert_file(&self, worktree_id: i64, indexed_file: ParsedFile) -> Result<()> {
// Remove any stale row for this file, then write to the files table and use the generated id.
self.db.execute(
"
DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
",
params![worktree_id, indexed_file.path.to_str()],
)?;
let mtime = Timestamp::from(indexed_file.mtime);
self.db.execute(
"
INSERT INTO files
(worktree_id, relative_path, mtime_seconds, mtime_nanos, vector_store_version)
VALUES
(?1, ?2, ?3, ?4, ?5);
",
params![
worktree_id,
indexed_file.path.to_str(),
mtime.seconds,
mtime.nanos,
VECTOR_STORE_VERSION
],
)?;
let file_id = self.db.last_insert_rowid();
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in indexed_file.documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
self.db.execute(
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
params![
file_id,
document.offset.to_string(),
document.name,
embedding_blob
],
)?;
}
Ok(())
}
pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
// If this absolute path is already stored, return its existing id
let mut worktree_query = self
.db
.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
})
.map_err(|err| anyhow!(err));
if worktree_id.is_ok() {
return worktree_id;
}
// If worktree_id is Err, insert new worktree
self.db.execute(
"
INSERT into worktrees (absolute_path) VALUES (?1)
",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(self.db.last_insert_rowid())
}
pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
let mut statement = self.db.prepare(
"
SELECT relative_path, mtime_seconds, mtime_nanos
FROM files
WHERE worktree_id = ?1
ORDER BY relative_path",
)?;
let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
for row in statement.query_map(params![worktree_id], |row| {
Ok((
row.get::<_, String>(0)?.into(),
Timestamp {
seconds: row.get(1)?,
nanos: row.get(2)?,
}
.into(),
))
})? {
let row = row?;
result.insert(row.0, row.1);
}
Ok(result)
}
pub fn top_k_search(
&self,
worktree_ids: &[i64],
query_embedding: &Vec<f32>,
limit: usize,
) -> Result<Vec<(i64, PathBuf, usize, String)>> {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
self.for_each_document(&worktree_ids, |id, embedding| {
let similarity = dot(&embedding, &query_embedding);
let ix = match results
.binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
{
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
self.get_documents_by_ids(&ids)
}
fn for_each_document(
&self,
worktree_ids: &[i64],
mut f: impl FnMut(i64, Vec<f32>),
) -> Result<()> {
let mut query_statement = self.db.prepare(
"
SELECT
documents.id, documents.embedding
FROM
documents, files
WHERE
documents.file_id = files.id AND
files.worktree_id IN rarray(?)
",
)?;
query_statement
.query_map(params![ids_to_sql(worktree_ids)], |row| {
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.for_each(|(id, embedding)| f(id, embedding.0));
Ok(())
}
fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
let mut statement = self.db.prepare(
"
SELECT
documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
FROM
documents, files
WHERE
documents.file_id = files.id AND
documents.id in rarray(?)
",
)?;
let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, String>(2)?.into(),
row.get(3)?,
row.get(4)?,
))
})?;
let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
for row in result_iter {
let (id, worktree_id, path, offset, name) = row?;
values_by_id.insert(id, (worktree_id, path, offset, name));
}
let mut results = Vec::with_capacity(ids.len());
for id in ids {
let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing document id {}", id))?;
results.push(value);
}
Ok(results)
}
}
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
Rc::new(
ids.iter()
.copied()
.map(|v| rusqlite::types::Value::from(v))
.collect::<Vec<_>>(),
)
}
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
let len = vec_a.len();
assert_eq!(len, vec_b.len());
let mut result = 0.0;
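// Treat vec_a as a (1 x len) matrix and vec_b as a (len x 1) matrix and let
// matrixmultiply's sgemm produce the single-element product, i.e. the dot product.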
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
vec_a.as_ptr(),
len as isize,
1,
vec_b.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
result
}
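
The comment in `insert_file` notes roughly 3400 document inserts per second and suggests a bulk insert. One possible shape for that, sketched here for illustration only (not part of this commit), is to reuse a cached prepared statement inside a single transaction:

```rust
// Editorial sketch, not part of this commit: insert all of a file's documents
// inside one transaction with a cached prepared statement.
fn insert_documents_bulk(
    db: &mut rusqlite::Connection,
    file_id: i64,
    documents: &[(usize, String, Vec<u8>)], // (offset, name, serialized embedding)
) -> anyhow::Result<()> {
    let tx = db.transaction()?;
    {
        let mut statement = tx.prepare_cached(
            "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
        )?;
        for (offset, name, embedding_blob) in documents {
            statement.execute(rusqlite::params![
                file_id,
                *offset as i64,
                name,
                embedding_blob
            ])?;
        }
    }
    tx.commit()?;
    Ok(())
}
```

Grouping the per-document inserts under one transaction avoids a commit per row, which is usually the dominant cost in this pattern.
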


@ -0,0 +1,166 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
}
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
data: Vec<OpenAIEmbedding>,
usage: OpenAIEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
// 1536 is the embedding size for OpenAI's ada-002 model,
// the model we will likely be starting with.
let dummy_vec = vec![0.32 as f32; 1536];
return Ok(vec![dummy_vec; spans.len()]);
}
}
impl OpenAIEmbeddings {
async fn truncate(span: String) -> String {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
if tokens.len() > 8190 {
tokens.truncate(8190);
let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
if result.is_ok() {
let transformed = result.unwrap();
// assert_ne!(transformed, span);
return transformed;
}
}
return span.to_string();
}
async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
const MAX_RETRIES: usize = 3;
let api_key = OPENAI_API_KEY
.as_ref()
.ok_or_else(|| anyhow!("no api key"))?;
let mut request_number = 0;
let mut response: Response<AsyncBody>;
let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
while request_number < MAX_RETRIES {
response = self
.send_request(api_key, spans.iter().map(|x| &**x).collect())
.await?;
request_number += 1;
if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
return Err(anyhow!(
"openai max retries, error: {:?}",
&response.status()
));
}
match response.status() {
StatusCode::TOO_MANY_REQUESTS => {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
self.executor.timer(delay).await;
}
StatusCode::BAD_REQUEST => {
log::info!("BAD REQUEST: {:?}", &response.status());
// Don't worry about delaying bad request, as we can assume
// we haven't been rate limited yet.
for span in spans.iter_mut() {
*span = Self::truncate(span.to_string()).await;
}
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
log::info!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
return Ok(response
.data
.into_iter()
.map(|embedding| embedding.embedding)
.collect());
}
_ => {
return Err(anyhow!("openai embedding failed {}", response.status()));
}
}
}
Err(anyhow!("openai embedding failed"))
}
}
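
Because indexing goes through the `EmbeddingProvider` trait above, the OpenAI client can be swapped for `DummyEmbeddings` (or the test-only fake provider) without touching the rest of the pipeline. A minimal usage sketch, assuming the trait above and the crate's `anyhow` dependency; the function name is illustrative:

```rust
// Illustrative only: exercise any EmbeddingProvider through a trait object.
async fn embed_two_spans(provider: &dyn EmbeddingProvider) -> anyhow::Result<()> {
    let embeddings = provider
        .embed_batch(vec!["fn aaa() {}", "fn bbb() {}"])
        .await?;
    assert_eq!(embeddings.len(), 2);
    Ok(())
}
```
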


@ -0,0 +1,172 @@
use crate::{SearchResult, VectorStore};
use editor::{scroll::autoscroll::Autoscroll, Editor};
use gpui::{
actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext,
WeakViewHandle,
};
use picker::{Picker, PickerDelegate, PickerEvent};
use project::{Project, ProjectPath};
use std::{collections::HashMap, sync::Arc, time::Duration};
use util::ResultExt;
use workspace::Workspace;
const MIN_QUERY_LEN: usize = 5;
const EMBEDDING_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(500);
actions!(semantic_search, [Toggle]);
pub type SemanticSearch = Picker<SemanticSearchDelegate>;
pub struct SemanticSearchDelegate {
workspace: WeakViewHandle<Workspace>,
project: ModelHandle<Project>,
vector_store: ModelHandle<VectorStore>,
selected_match_index: usize,
matches: Vec<SearchResult>,
history: HashMap<String, Vec<SearchResult>>,
}
impl SemanticSearchDelegate {
// This currently searches on every keystroke,
// which is wildly overkill and has the potential to get expensive.
// We will need to update this to throttle searching.
pub fn new(
workspace: WeakViewHandle<Workspace>,
project: ModelHandle<Project>,
vector_store: ModelHandle<VectorStore>,
) -> Self {
Self {
workspace,
project,
vector_store,
selected_match_index: 0,
matches: vec![],
history: HashMap::new(),
}
}
}
impl PickerDelegate for SemanticSearchDelegate {
fn placeholder_text(&self) -> Arc<str> {
"Search repository in natural language...".into()
}
fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
if let Some(search_result) = self.matches.get(self.selected_match_index) {
// Open Buffer
let search_result = search_result.clone();
let buffer = self.project.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: search_result.worktree_id,
path: search_result.file_path.clone().into(),
},
cx,
)
});
let workspace = self.workspace.clone();
let position = search_result.clone().offset;
cx.spawn(|_, mut cx| async move {
let buffer = buffer.await?;
workspace.update(&mut cx, |workspace, cx| {
let editor = workspace.open_project_item::<Editor>(buffer, cx);
editor.update(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::center()), cx, |s| {
s.select_ranges([position..position])
});
});
})?;
Ok::<_, anyhow::Error>(())
})
.detach_and_log_err(cx);
cx.emit(PickerEvent::Dismiss);
}
}
fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_match_index
}
fn set_selected_index(&mut self, ix: usize, _cx: &mut ViewContext<SemanticSearch>) {
self.selected_match_index = ix;
}
fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
log::info!("Searching for {:?}...", query);
if query.len() < MIN_QUERY_LEN {
log::info!("Query below minimum length");
return Task::ready(());
}
let vector_store = self.vector_store.clone();
let project = self.project.clone();
cx.spawn(|this, mut cx| async move {
cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await;
let retrieved_cached = this.update(&mut cx, |this, _| {
let delegate = this.delegate_mut();
if delegate.history.contains_key(&query) {
let historic_results = delegate.history.get(&query).unwrap().to_owned();
delegate.matches = historic_results.clone();
true
} else {
false
}
});
if let Some(retrieved) = retrieved_cached.log_err() {
if !retrieved {
let task = vector_store.update(&mut cx, |store, cx| {
store.search(project.clone(), query.to_string(), 10, cx)
});
if let Some(results) = task.await.log_err() {
log::info!("Not queried previously, searching...");
this.update(&mut cx, |this, _| {
let delegate = this.delegate_mut();
delegate.matches = results.clone();
delegate.history.insert(query, results);
})
.ok();
}
} else {
log::info!("Already queried, retrieved directly from cached history");
}
}
})
}
fn render_match(
&self,
ix: usize,
mouse_state: &mut MouseState,
selected: bool,
cx: &AppContext,
) -> AnyElement<Picker<Self>> {
let theme = theme::current(cx);
let style = &theme.picker.item;
let current_style = style.in_state(selected).style_for(mouse_state);
let search_result = &self.matches[ix];
let path = search_result.file_path.to_string_lossy();
let name = search_result.name.clone();
Flex::column()
.with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
.with_child(Label::new(
path.to_string(),
style.inactive_state().default.label.clone(),
))
.contained()
.with_style(current_style.container)
.into_any()
}
}


@ -0,0 +1,118 @@
use std::{path::PathBuf, sync::Arc, time::SystemTime};
use anyhow::{anyhow, Ok, Result};
use project::Fs;
use tree_sitter::{Parser, QueryCursor};
use crate::PendingFile;
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
pub offset: usize,
pub name: String,
pub embedding: Vec<f32>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct ParsedFile {
pub path: PathBuf,
pub mtime: SystemTime,
pub documents: Vec<Document>,
}
const CODE_CONTEXT_TEMPLATE: &str =
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
pub struct CodeContextRetriever {
pub parser: Parser,
pub cursor: QueryCursor,
pub fs: Arc<dyn Fs>,
}
impl CodeContextRetriever {
pub async fn parse_file(
&mut self,
pending_file: PendingFile,
) -> Result<(ParsedFile, Vec<String>)> {
let grammar = pending_file
.language
.grammar()
.ok_or_else(|| anyhow!("no grammar for language"))?;
let embedding_config = grammar
.embedding_config
.as_ref()
.ok_or_else(|| anyhow!("no embedding queries"))?;
let content = self.fs.load(&pending_file.absolute_path).await?;
self.parser.set_language(grammar.ts_language).unwrap();
let tree = self
.parser
.parse(&content, None)
.ok_or_else(|| anyhow!("parsing failed"))?;
let mut documents = Vec::new();
let mut context_spans = Vec::new();
// Iterate through query matches
for mat in self.cursor.matches(
&embedding_config.query,
tree.root_node(),
content.as_bytes(),
) {
// log::info!("-----MATCH-----");
let mut name: Vec<&str> = vec![];
let mut item: Option<&str> = None;
let mut offset: Option<usize> = None;
for capture in mat.captures {
if capture.index == embedding_config.item_capture_ix {
offset = Some(capture.node.byte_range().start);
item = content.get(capture.node.byte_range());
} else if capture.index == embedding_config.name_capture_ix {
if let Some(name_content) = content.get(capture.node.byte_range()) {
name.push(name_content);
}
}
if let Some(context_capture_ix) = embedding_config.context_capture_ix {
if capture.index == context_capture_ix {
if let Some(context) = content.get(capture.node.byte_range()) {
name.push(context);
}
}
}
}
if item.is_some() && offset.is_some() && name.len() > 0 {
let context_span = CODE_CONTEXT_TEMPLATE
.replace("<path>", pending_file.relative_path.to_str().unwrap())
.replace("<language>", &pending_file.language.name().to_lowercase())
.replace("<item>", item.unwrap());
let mut truncated_span = context_span.clone();
truncated_span.truncate(100);
// log::info!("Name: {:?}", name);
// log::info!("Span: {:?}", truncated_span);
context_spans.push(context_span);
documents.push(Document {
name: name.join(" "),
offset: offset.unwrap(),
embedding: Vec::new(),
})
}
}
return Ok((
ParsedFile {
path: pending_file.relative_path,
mtime: pending_file.modified_time,
documents,
},
context_spans,
));
}
}
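
As an editorial illustration (not part of this commit) of what `CODE_CONTEXT_TEMPLATE` produces, here is a small, self-contained sketch that renders the span that would be embedded for a short function in a hypothetical `src/lib.rs`:

```rust
// Editorial sketch: what the context span built from CODE_CONTEXT_TEMPLATE looks
// like for one captured item. The file path and function are hypothetical.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";

fn render_context_span(path: &str, language: &str, item: &str) -> String {
    CODE_CONTEXT_TEMPLATE
        .replace("<path>", path)
        .replace("<language>", language)
        .replace("<item>", item)
}

fn main() {
    let span = render_context_span(
        "src/lib.rs",
        "rust",
        "fn aaa() {\n    println!(\"aaaa!\");\n}",
    );
    // It is this rendered span, not the raw file contents, that gets embedded.
    println!("{span}");
}
```

Embedding the rendered span rather than the bare item gives the model the file path and language as extra context.
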


@ -0,0 +1,770 @@
mod db;
mod embedding;
mod modal;
mod parsing;
mod vector_store_settings;
#[cfg(test)]
mod vector_store_tests;
use crate::vector_store_settings::VectorStoreSettings;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{
AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
WeakModelHandle,
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use parsing::{CodeContextRetriever, ParsedFile};
use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
use smol::channel;
use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use tree_sitter::{Parser, QueryCursor};
use util::{
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
http::HttpClient,
paths::EMBEDDINGS_DIR,
ResultExt,
};
use workspace::{Workspace, WorkspaceCreated};
const VECTOR_STORE_VERSION: usize = 0;
const EMBEDDINGS_BATCH_SIZE: usize = 150;
pub fn init(
fs: Arc<dyn Fs>,
http_client: Arc<dyn HttpClient>,
language_registry: Arc<LanguageRegistry>,
cx: &mut AppContext,
) {
settings::register::<VectorStoreSettings>(cx);
let db_file_path = EMBEDDINGS_DIR
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db");
SemanticSearch::init(cx);
cx.add_action(
|workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
if cx.has_global::<ModelHandle<VectorStore>>() {
let vector_store = cx.global::<ModelHandle<VectorStore>>().clone();
workspace.toggle_modal(cx, |workspace, cx| {
let project = workspace.project().clone();
let workspace = cx.weak_handle();
cx.add_view(|cx| {
SemanticSearch::new(
SemanticSearchDelegate::new(workspace, project, vector_store),
cx,
)
})
});
}
},
);
if *RELEASE_CHANNEL == ReleaseChannel::Stable
|| !settings::get::<VectorStoreSettings>(cx).enabled
{
return;
}
cx.spawn(move |mut cx| async move {
let vector_store = VectorStore::new(
fs,
db_file_path,
// Arc::new(embedding::DummyEmbeddings {}),
Arc::new(OpenAIEmbeddings {
client: http_client,
executor: cx.background(),
}),
language_registry,
cx.clone(),
)
.await?;
cx.update(|cx| {
cx.set_global(vector_store.clone());
cx.subscribe_global::<WorkspaceCreated, _>({
let vector_store = vector_store.clone();
move |event, cx| {
let workspace = &event.0;
if let Some(workspace) = workspace.upgrade(cx) {
let project = workspace.read(cx).project().clone();
if project.read(cx).is_local() {
vector_store.update(cx, |store, cx| {
store.add_project(project, cx).detach();
});
}
}
}
})
.detach();
});
anyhow::Ok(())
})
.detach();
}
pub struct VectorStore {
fs: Arc<dyn Fs>,
database_url: Arc<PathBuf>,
embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
db_update_tx: channel::Sender<DbOperation>,
parsing_files_tx: channel::Sender<PendingFile>,
_db_update_task: Task<()>,
_embed_batch_task: Task<()>,
_batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
struct ProjectState {
worktree_db_ids: Vec<(WorktreeId, i64)>,
pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
_subscription: gpui::Subscription,
}
impl ProjectState {
fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
self.worktree_db_ids
.iter()
.find_map(|(worktree_id, db_id)| {
if *worktree_id == id {
Some(*db_id)
} else {
None
}
})
}
fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
self.worktree_db_ids
.iter()
.find_map(|(worktree_id, db_id)| {
if *db_id == id {
Some(*worktree_id)
} else {
None
}
})
}
fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
// If Pending File Already Exists, Replace it with the new one
// but keep the old indexing time
if let Some(old_file) = self
.pending_files
.remove(&pending_file.relative_path.clone())
{
self.pending_files.insert(
pending_file.relative_path.clone(),
(pending_file, old_file.1),
);
} else {
self.pending_files.insert(
pending_file.relative_path.clone(),
(pending_file, indexing_time),
);
};
}
fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
let mut outstanding_files = vec![];
let mut remove_keys = vec![];
for key in self.pending_files.keys().into_iter() {
if let Some(pending_details) = self.pending_files.get(key) {
let (pending_file, index_time) = pending_details;
if index_time <= &SystemTime::now() {
outstanding_files.push(pending_file.clone());
remove_keys.push(key.clone());
}
}
}
for key in remove_keys.iter() {
self.pending_files.remove(key);
}
return outstanding_files;
}
}
#[derive(Clone, Debug)]
pub struct PendingFile {
worktree_db_id: i64,
relative_path: PathBuf,
absolute_path: PathBuf,
language: Arc<Language>,
modified_time: SystemTime,
}
#[derive(Debug, Clone)]
pub struct SearchResult {
pub worktree_id: WorktreeId,
pub name: String,
pub offset: usize,
pub file_path: PathBuf,
}
enum DbOperation {
InsertFile {
worktree_id: i64,
indexed_file: ParsedFile,
},
Delete {
worktree_id: i64,
path: PathBuf,
},
FindOrCreateWorktree {
path: PathBuf,
sender: oneshot::Sender<Result<i64>>,
},
FileMTimes {
worktree_id: i64,
sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
},
}
enum EmbeddingJob {
Enqueue {
worktree_id: i64,
parsed_file: ParsedFile,
document_spans: Vec<String>,
},
Flush,
}
impl VectorStore {
async fn new(
fs: Arc<dyn Fs>,
database_url: PathBuf,
embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
mut cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
let database_url = Arc::new(database_url);
let db = cx
.background()
.spawn({
let fs = fs.clone();
let database_url = database_url.clone();
async move {
if let Some(db_directory) = database_url.parent() {
fs.create_dir(db_directory).await.log_err();
}
let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
anyhow::Ok(db)
}
})
.await?;
Ok(cx.add_model(|cx| {
// parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx
// db_update_tx/rx: updating the database
let (db_update_tx, db_update_rx) = channel::unbounded();
let _db_update_task = cx.background().spawn(async move {
while let Ok(job) = db_update_rx.recv().await {
match job {
DbOperation::InsertFile {
worktree_id,
indexed_file,
} => {
db.insert_file(worktree_id, indexed_file).log_err();
}
DbOperation::Delete { worktree_id, path } => {
db.delete_file(worktree_id, path).log_err();
}
DbOperation::FindOrCreateWorktree { path, sender } => {
let id = db.find_or_create_worktree(&path);
sender.send(id).ok();
}
DbOperation::FileMTimes {
worktree_id: worktree_db_id,
sender,
} => {
let file_mtimes = db.get_file_mtimes(worktree_db_id);
sender.send(file_mtimes).ok();
}
}
}
});
// embed_tx/rx: Embed Batch and Send to Database
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
let _embed_batch_task = cx.background().spawn({
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
async move {
while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
// Construct Batch
let mut document_spans = vec![];
for (_, _, document_span) in embeddings_queue.iter() {
document_spans.extend(document_span.iter().map(|s| s.as_str()));
}
if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
{
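// `embeddings` comes back flat and in the same order the document spans were
// submitted; walk the queue with indices (i, j) to copy each embedding back
// onto its corresponding document.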
let mut i = 0;
let mut j = 0;
for embedding in embeddings.iter() {
while embeddings_queue[i].1.documents.len() == j {
i += 1;
j = 0;
}
embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
j += 1;
}
for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
for document in indexed_file.documents.iter() {
// TODO: Update this so it doesn't panic
assert!(
document.embedding.len() > 0,
"Document Embedding Not Complete"
);
}
db_update_tx
.send(DbOperation::InsertFile {
worktree_id,
indexed_file,
})
.await
.unwrap();
}
}
}
}
});
// batch_tx/rx: Batch Files to Send for Embeddings
let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
let _batch_files_task = cx.background().spawn(async move {
let mut queue_len = 0;
let mut embeddings_queue = vec![];
while let Ok(job) = batch_files_rx.recv().await {
let should_flush = match job {
EmbeddingJob::Enqueue {
document_spans,
worktree_id,
parsed_file,
} => {
queue_len += &document_spans.len();
embeddings_queue.push((worktree_id, parsed_file, document_spans));
queue_len >= EMBEDDINGS_BATCH_SIZE
}
EmbeddingJob::Flush => true,
};
if should_flush {
embed_batch_tx.try_send(embeddings_queue).unwrap();
embeddings_queue = vec![];
queue_len = 0;
}
}
});
// parsing_files_tx/rx: Parsing Files to Embeddable Documents
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
let mut _parsing_files_tasks = Vec::new();
// for _ in 0..cx.background().num_cpus() {
for _ in 0..1 {
let fs = fs.clone();
let parsing_files_rx = parsing_files_rx.clone();
let batch_files_tx = batch_files_tx.clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let parser = Parser::new();
let cursor = QueryCursor::new();
let mut retriever = CodeContextRetriever { parser, cursor, fs };
while let Ok(pending_file) = parsing_files_rx.recv().await {
if let Some((indexed_file, document_spans)) =
retriever.parse_file(pending_file.clone()).await.log_err()
{
batch_files_tx
.try_send(EmbeddingJob::Enqueue {
worktree_id: pending_file.worktree_db_id,
parsed_file: indexed_file,
document_spans,
})
.unwrap();
}
if parsing_files_rx.len() == 0 {
batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
}
}
}));
}
Self {
fs,
database_url,
embedding_provider,
language_registry,
db_update_tx,
parsing_files_tx,
_db_update_task,
_embed_batch_task,
_batch_files_task,
_parsing_files_tasks,
projects: HashMap::new(),
}
}))
}
fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
let (tx, rx) = oneshot::channel();
self.db_update_tx
.try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
.unwrap();
async move { rx.await? }
}
fn get_file_mtimes(
&self,
worktree_id: i64,
) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
let (tx, rx) = oneshot::channel();
self.db_update_tx
.try_send(DbOperation::FileMTimes {
worktree_id,
sender: tx,
})
.unwrap();
async move { rx.await? }
}
fn add_project(
&mut self,
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
.map(|worktree| {
let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
async move {
scan_complete.await;
}
})
.collect::<Vec<_>>();
let worktree_db_ids = project
.read(cx)
.worktrees(cx)
.map(|worktree| {
self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
})
.collect::<Vec<_>>();
let fs = self.fs.clone();
let language_registry = self.language_registry.clone();
let database_url = self.database_url.clone();
let db_update_tx = self.db_update_tx.clone();
let parsing_files_tx = self.parsing_files_tx.clone();
cx.spawn(|this, mut cx| async move {
futures::future::join_all(worktree_scans_complete).await;
let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
if let Some(db_directory) = database_url.parent() {
fs.create_dir(db_directory).await.log_err();
}
let worktrees = project.read_with(&cx, |project, cx| {
project
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect::<Vec<_>>()
});
let mut worktree_file_times = HashMap::new();
let mut db_ids_by_worktree_id = HashMap::new();
for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
let db_id = db_id?;
db_ids_by_worktree_id.insert(worktree.id(), db_id);
worktree_file_times.insert(
worktree.id(),
this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
.await?,
);
}
cx.background()
.spawn({
let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
let db_update_tx = db_update_tx.clone();
let language_registry = language_registry.clone();
let parsing_files_tx = parsing_files_tx.clone();
async move {
let t0 = Instant::now();
for worktree in worktrees.into_iter() {
let mut file_mtimes =
worktree_file_times.remove(&worktree.id()).unwrap();
for file in worktree.files(false, 0) {
let absolute_path = worktree.absolutize(&file.path);
if let Ok(language) = language_registry
.language_for_file(&absolute_path, None)
.await
{
if language
.grammar()
.and_then(|grammar| grammar.embedding_config.as_ref())
.is_none()
{
continue;
}
let path_buf = file.path.to_path_buf();
let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
let already_stored = stored_mtime
.map_or(false, |existing_mtime| {
existing_mtime == file.mtime
});
if !already_stored {
parsing_files_tx
.try_send(PendingFile {
worktree_db_id: db_ids_by_worktree_id
[&worktree.id()],
relative_path: path_buf,
absolute_path,
language,
modified_time: file.mtime,
})
.unwrap();
}
}
}
for file in file_mtimes.keys() {
db_update_tx
.try_send(DbOperation::Delete {
worktree_id: db_ids_by_worktree_id[&worktree.id()],
path: file.to_owned(),
})
.unwrap();
}
}
log::info!(
"Parsing Worktree Completed in {:?}",
t0.elapsed().as_millis()
);
}
})
.detach();
// let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
this.update(&mut cx, |this, cx| {
// The below manages re-indexing on save.
// Each time a file is saved this code runs, and for every file that changed, once the current time
// exceeds the previously recorded indexing time by the reindexing delay, we send the file off to be indexed.
let _subscription = cx.subscribe(&project, |this, project, event, cx| {
if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
this.project_entries_changed(project, changes.clone(), cx, worktree_id);
}
});
this.projects.insert(
project.downgrade(),
ProjectState {
pending_files: HashMap::new(),
worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
_subscription,
},
);
});
anyhow::Ok(())
})
}
pub fn search(
&mut self,
project: ModelHandle<Project>,
phrase: String,
limit: usize,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
state
} else {
return Task::ready(Err(anyhow!("project not added")));
};
let worktree_db_ids = project
.read(cx)
.worktrees(cx)
.filter_map(|worktree| {
let worktree_id = worktree.read(cx).id();
project_state.db_id_for_worktree_id(worktree_id)
})
.collect::<Vec<_>>();
let embedding_provider = self.embedding_provider.clone();
let database_url = self.database_url.clone();
cx.spawn(|this, cx| async move {
let documents = cx
.background()
.spawn(async move {
let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
let phrase_embedding = embedding_provider
.embed_batch(vec![&phrase])
.await?
.into_iter()
.next()
.unwrap();
database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
})
.await?;
this.read_with(&cx, |this, _| {
let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
state
} else {
return Err(anyhow!("project not added"));
};
Ok(documents
.into_iter()
.filter_map(|(worktree_db_id, file_path, offset, name)| {
let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
Some(SearchResult {
worktree_id,
name,
offset,
file_path,
})
})
.collect())
})
})
}
fn project_entries_changed(
&mut self,
project: ModelHandle<Project>,
changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
cx: &mut ModelContext<'_, VectorStore>,
worktree_id: &WorktreeId,
) -> Option<()> {
let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;
let worktree = project
.read(cx)
.worktree_for_id(worktree_id.clone(), cx)?
.read(cx)
.snapshot();
let worktree_db_id = self
.projects
.get(&project.downgrade())?
.db_id_for_worktree_id(worktree.id())?;
let file_mtimes = self.get_file_mtimes(worktree_db_id);
let language_registry = self.language_registry.clone();
cx.spawn(|this, mut cx| async move {
let file_mtimes = file_mtimes.await.log_err()?;
for change in changes.into_iter() {
let change_path = change.0.clone();
let absolute_path = worktree.absolutize(&change_path);
// Skip if git ignored or symlink
if let Some(entry) = worktree.entry_for_id(change.1) {
if entry.is_ignored || entry.is_symlink || entry.is_external {
continue;
}
}
match change.2 {
PathChange::Removed => this.update(&mut cx, |this, _| {
this.db_update_tx
.try_send(DbOperation::Delete {
worktree_id: worktree_db_id,
path: absolute_path,
})
.unwrap();
}),
_ => {
if let Ok(language) = language_registry
.language_for_file(&change_path.to_path_buf(), None)
.await
{
if language
.grammar()
.and_then(|grammar| grammar.embedding_config.as_ref())
.is_none()
{
continue;
}
let modified_time =
change_path.metadata().log_err()?.modified().log_err()?;
let existing_time = file_mtimes.get(&change_path.to_path_buf());
let already_stored = existing_time
.map_or(false, |existing_time| &modified_time == existing_time);
if !already_stored {
this.update(&mut cx, |this, _| {
let reindex_time = modified_time
+ Duration::from_secs(reindexing_delay as u64);
let project_state =
this.projects.get_mut(&project.downgrade())?;
project_state.update_pending_files(
PendingFile {
relative_path: change_path.to_path_buf(),
absolute_path,
modified_time,
worktree_db_id,
language: language.clone(),
},
reindex_time,
);
for file in project_state.get_outstanding_files() {
this.parsing_files_tx.try_send(file).unwrap();
}
Some(())
});
}
}
}
}
}
Some(())
})
.detach();
Some(())
}
}
impl Entity for VectorStore {
type Event = ();
}
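
The pending-file bookkeeping above (see `update_pending_files` and `get_outstanding_files`) amounts to recording a deadline per changed path and only flushing paths whose deadline has passed. A condensed, editorial sketch of that idea; the names are illustrative, not the crate's API:

```rust
use std::{
    collections::HashMap,
    path::PathBuf,
    time::{Duration, SystemTime},
};

// Editorial sketch: record a re-index deadline per changed path, and only hand
// paths whose deadline has passed to the indexer.
struct PendingPaths {
    deadlines: HashMap<PathBuf, SystemTime>,
}

impl PendingPaths {
    fn schedule(&mut self, path: PathBuf, delay: Duration) {
        let deadline = SystemTime::now() + delay;
        // Keep the earlier deadline if this path was already pending.
        self.deadlines
            .entry(path)
            .and_modify(|existing| *existing = (*existing).min(deadline))
            .or_insert(deadline);
    }

    fn take_due(&mut self) -> Vec<PathBuf> {
        let now = SystemTime::now();
        let due: Vec<PathBuf> = self
            .deadlines
            .iter()
            .filter(|(_, deadline)| **deadline <= now)
            .map(|(path, _)| path.clone())
            .collect();
        for path in &due {
            self.deadlines.remove(path);
        }
        due
    }
}
```
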


@ -0,0 +1,30 @@
use anyhow;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Setting;
#[derive(Deserialize, Debug)]
pub struct VectorStoreSettings {
pub enabled: bool,
pub reindexing_delay_seconds: usize,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct VectorStoreSettingsContent {
pub enabled: Option<bool>,
pub reindexing_delay_seconds: Option<usize>,
}
impl Setting for VectorStoreSettings {
const KEY: Option<&'static str> = Some("vector_store");
type FileContent = VectorStoreSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}
}


@ -0,0 +1,161 @@
use crate::{
db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry};
use project::{project_settings::ProjectSettings, FakeFs, Project};
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
use std::sync::Arc;
use unindent::Unindent;
#[gpui::test]
async fn test_vector_store(cx: &mut TestAppContext) {
cx.update(|cx| {
cx.set_global(SettingsStore::test(cx));
settings::register::<VectorStoreSettings>(cx);
settings::register::<ProjectSettings>(cx);
});
let fs = FakeFs::new(cx.background());
fs.insert_tree(
"/the-root",
json!({
"src": {
"file1.rs": "
fn aaa() {
println!(\"aaaa!\");
}
fn zzzzzzzzz() {
println!(\"SLEEPING\");
}
".unindent(),
"file2.rs": "
fn bbb() {
println!(\"bbbb!\");
}
".unindent(),
}
}),
)
.await;
let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
let rust_language = Arc::new(
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".into()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_embedding_query(
r#"
(function_item
name: (identifier) @name
body: (block)) @item
"#,
)
.unwrap(),
);
languages.add(rust_language);
let db_dir = tempdir::TempDir::new("vector-store").unwrap();
let db_path = db_dir.path().join("db.sqlite");
let store = VectorStore::new(
fs.clone(),
db_path,
Arc::new(FakeEmbeddingProvider),
languages,
cx.to_async(),
)
.await
.unwrap();
let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
let worktree_id = project.read_with(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
});
store
.update(cx, |store, cx| store.add_project(project.clone(), cx))
.await
.unwrap();
cx.foreground().run_until_parked();
let search_results = store
.update(cx, |store, cx| {
store.search(project.clone(), "aaaa".to_string(), 5, cx)
})
.await
.unwrap();
assert_eq!(search_results[0].offset, 0);
assert_eq!(search_results[0].name, "aaa");
assert_eq!(search_results[0].worktree_id, worktree_id);
}
#[gpui::test]
fn test_dot_product(mut rng: StdRng) {
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
assert_eq!(
round_to_decimals(dot(&a, &b), 1),
round_to_decimals(reference_dot(&a, &b), 1)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
struct FakeEmbeddingProvider;
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
Ok(spans
.iter()
.map(|span| {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result
})
.collect())
}
}


@ -64,6 +64,7 @@ terminal_view = { path = "../terminal_view" }
theme = { path = "../theme" }
theme_selector = { path = "../theme_selector" }
util = { path = "../util" }
vector_store = { path = "../vector_store" }
vim = { path = "../vim" }
workspace = { path = "../workspace" }
welcome = { path = "../welcome" }


@ -170,6 +170,7 @@ fn load_queries(name: &str) -> LanguageQueries {
brackets: load_query(name, "/brackets"),
indents: load_query(name, "/indents"),
outline: load_query(name, "/outline"),
embedding: load_query(name, "/embedding"),
injections: load_query(name, "/injections"),
overrides: load_query(name, "/overrides"),
}


@ -0,0 +1,56 @@
; (internal_module
; "namespace" @context
; name: (_) @name) @item
(enum_declaration
"enum" @context
name: (_) @name) @item
(function_declaration
"async"? @context
"function" @context
name: (_) @name) @item
(interface_declaration
"interface" @context
name: (_) @name) @item
; (program
; (export_statement
; (lexical_declaration
; ["let" "const"] @context
; (variable_declarator
; name: (_) @name) @item)))
(program
(lexical_declaration
["let" "const"] @context
(variable_declarator
name: (_) @name) @item))
(class_declaration
"class" @context
name: (_) @name) @item
(method_definition
[
"get"
"set"
"async"
"*"
"readonly"
"static"
(override_modifier)
(accessibility_modifier)
]* @context
name: (_) @name) @item
; (public_field_definition
; [
; "declare"
; "readonly"
; "abstract"
; "static"
; (accessibility_modifier)
; ]* @context
; name: (_) @name) @item


@ -0,0 +1,9 @@
(class_definition
"class" @context
name: (identifier) @name
) @item
(function_definition
"async"? @context
"def" @context
name: (_) @name) @item


@ -0,0 +1,36 @@
(struct_item
(visibility_modifier)? @context
"struct" @context
name: (_) @name) @item
(enum_item
(visibility_modifier)? @context
"enum" @context
name: (_) @name) @item
(impl_item
"impl" @context
trait: (_)? @name
"for"? @context
type: (_) @name) @item
(trait_item
(visibility_modifier)? @context
"trait" @context
name: (_) @name) @item
(function_item
(visibility_modifier)? @context
(function_modifiers)? @context
"fn" @context
name: (_) @name) @item
(function_signature_item
(visibility_modifier)? @context
(function_modifiers)? @context
"fn" @context
name: (_) @name) @item
(macro_definition
. "macro_rules!" @context
name: (_) @name) @item
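
To see what the Rust embedding query above captures, here is an editorial sketch (not part of this commit) that runs a simplified subset of the query with the `tree-sitter` and `tree-sitter-rust` crates, both already dependencies of the new crate:

```rust
use tree_sitter::{Parser, Query, QueryCursor};

fn main() {
    let language = tree_sitter_rust::language();
    let mut parser = Parser::new();
    parser.set_language(language).unwrap();

    let source = "pub fn aaa() { println!(\"aaaa!\"); }";
    let tree = parser.parse(source, None).unwrap();

    // Simplified subset of the embedding query above: capture the item,
    // its surrounding context tokens, and its name.
    let query = Query::new(
        language,
        "(function_item (visibility_modifier)? @context \"fn\" @context name: (_) @name) @item",
    )
    .unwrap();

    let mut cursor = QueryCursor::new();
    for m in cursor.matches(&query, tree.root_node(), source.as_bytes()) {
        for capture in m.captures {
            let name = &query.capture_names()[capture.index as usize];
            let text = &source[capture.node.byte_range()];
            println!("@{name}: {text}");
        }
    }
}
```
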


@ -0,0 +1,35 @@
(enum_declaration
"enum" @context
name: (_) @name) @item
(function_declaration
"async"? @context
"function" @context
name: (_) @name) @item
(interface_declaration
"interface" @context
name: (_) @name) @item
(program
(lexical_declaration
["let" "const"] @context
(variable_declarator
name: (_) @name) @item))
(class_declaration
"class" @context
name: (_) @name) @item
(method_definition
[
"get"
"set"
"async"
"*"
"readonly"
"static"
(override_modifier)
(accessibility_modifier)
]* @context
name: (_) @name) @item


@ -0,0 +1,59 @@
; (internal_module
; "namespace" @context
; name: (_) @name) @item
(enum_declaration
"enum" @context
name: (_) @name) @item
; (type_alias_declaration
; "type" @context
; name: (_) @name) @item
(function_declaration
"async"? @context
"function" @context
name: (_) @name) @item
(interface_declaration
"interface" @context
name: (_) @name) @item
; (export_statement
; (lexical_declaration
; ["let" "const"] @context
; (variable_declarator
; name: (_) @name) @item))
(program
(lexical_declaration
["let" "const"] @context
(variable_declarator
name: (_) @name) @item))
(class_declaration
"class" @context
name: (_) @name) @item
(method_definition
[
"get"
"set"
"async"
"*"
"readonly"
"static"
(override_modifier)
(accessibility_modifier)
]* @context
name: (_) @name) @item
; (public_field_definition
; [
; "declare"
; "readonly"
; "abstract"
; "static"
; (accessibility_modifier)
; ]* @context
; name: (_) @name) @item


@ -157,6 +157,7 @@ fn main() {
project_panel::init(cx);
diagnostics::init(cx);
search::init(cx);
vector_store::init(fs.clone(), http.clone(), languages.clone(), cx);
vim::init(cx);
terminal_view::init(cx);
copilot::init(http.clone(), node_runtime, cx);