mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-26 10:40:54 +00:00
open ai indexing on open for rust files
This commit is contained in:
parent
d4a4db42aa
commit
dd309070eb
7 changed files with 252 additions and 55 deletions
57
Cargo.lock
generated
57
Cargo.lock
generated
|
@ -1389,15 +1389,6 @@ dependencies = [
|
|||
"theme",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "conv"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78ff10625fd0ac447827aa30ea8b861fead473bb60aeb73af6c1c58caf0d1299"
|
||||
dependencies = [
|
||||
"custom_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "copilot"
|
||||
version = "0.1.0"
|
||||
|
@ -1775,12 +1766,6 @@ dependencies = [
|
|||
"winapi 0.3.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "custom_derive"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9"
|
||||
|
||||
[[package]]
|
||||
name = "cxx"
|
||||
version = "1.0.94"
|
||||
|
@ -2219,6 +2204,12 @@ version = "0.2.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
|
||||
[[package]]
|
||||
name = "fallible-streaming-iterator"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
|
||||
|
||||
[[package]]
|
||||
name = "fancy-regex"
|
||||
version = "0.11.0"
|
||||
|
@ -2909,6 +2900,15 @@ dependencies = [
|
|||
"ahash 0.8.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf"
|
||||
dependencies = [
|
||||
"hashbrown 0.11.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.8.1"
|
||||
|
@ -5600,6 +5600,21 @@ dependencies = [
|
|||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rusqlite"
|
||||
version = "0.27.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85127183a999f7db96d1a976a309eebbfb6ea3b0b400ddd8340190129de6eb7a"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"fallible-iterator",
|
||||
"fallible-streaming-iterator",
|
||||
"hashlink 0.7.0",
|
||||
"libsqlite3-sys",
|
||||
"memchr",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed"
|
||||
version = "6.6.1"
|
||||
|
@ -6531,7 +6546,7 @@ dependencies = [
|
|||
"futures-executor",
|
||||
"futures-intrusive",
|
||||
"futures-util",
|
||||
"hashlink",
|
||||
"hashlink 0.8.1",
|
||||
"hex",
|
||||
"hkdf",
|
||||
"hmac 0.12.1",
|
||||
|
@ -7898,14 +7913,20 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"async-compat",
|
||||
"conv",
|
||||
"async-trait",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"isahc",
|
||||
"language",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"sqlx",
|
||||
"tree-sitter",
|
||||
"util",
|
||||
"workspace",
|
||||
]
|
||||
|
|
|
@ -476,12 +476,12 @@ pub struct Language {
|
|||
|
||||
pub struct Grammar {
|
||||
id: usize,
|
||||
pub(crate) ts_language: tree_sitter::Language,
|
||||
pub ts_language: tree_sitter::Language,
|
||||
pub(crate) error_query: Query,
|
||||
pub(crate) highlights_query: Option<Query>,
|
||||
pub(crate) brackets_config: Option<BracketConfig>,
|
||||
pub(crate) indents_config: Option<IndentConfig>,
|
||||
pub(crate) outline_config: Option<OutlineConfig>,
|
||||
pub outline_config: Option<OutlineConfig>,
|
||||
pub(crate) injection_config: Option<InjectionConfig>,
|
||||
pub(crate) override_config: Option<OverrideConfig>,
|
||||
pub(crate) highlight_map: Mutex<HighlightMap>,
|
||||
|
@ -495,12 +495,12 @@ struct IndentConfig {
|
|||
outdent_capture_ix: Option<u32>,
|
||||
}
|
||||
|
||||
struct OutlineConfig {
|
||||
query: Query,
|
||||
item_capture_ix: u32,
|
||||
name_capture_ix: u32,
|
||||
context_capture_ix: Option<u32>,
|
||||
extra_context_capture_ix: Option<u32>,
|
||||
pub struct OutlineConfig {
|
||||
pub query: Query,
|
||||
pub item_capture_ix: u32,
|
||||
pub name_capture_ix: u32,
|
||||
pub context_capture_ix: Option<u32>,
|
||||
pub extra_context_capture_ix: Option<u32>,
|
||||
}
|
||||
|
||||
struct InjectionConfig {
|
||||
|
|
|
@ -19,8 +19,14 @@ futures.workspace = true
|
|||
smol.workspace = true
|
||||
sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] }
|
||||
async-compat = "0.2.1"
|
||||
conv = "0.3.3"
|
||||
rand.workspace = true
|
||||
rusqlite = "0.27.0"
|
||||
isahc.workspace = true
|
||||
log.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
lazy_static.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
async-trait.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use async_compat::{Compat, CompatExt};
|
||||
use conv::ValueFrom;
|
||||
use sqlx::{migrate::MigrateDatabase, Pool, Sqlite, SqlitePool};
|
||||
use std::time::{Duration, Instant};
|
||||
use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool};
|
||||
|
||||
use crate::IndexedFile;
|
||||
|
||||
|
|
100
crates/vector_store/src/embedding.rs
Normal file
100
crates/vector_store/src/embedding.rs
Normal file
|
@ -0,0 +1,100 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use futures::AsyncReadExt;
|
||||
use gpui::serde_json;
|
||||
use isahc::prelude::Configurable;
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use util::http::{HttpClient, Request};
|
||||
|
||||
lazy_static! {
|
||||
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
|
||||
}
|
||||
|
||||
pub struct OpenAIEmbeddings {
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OpenAIEmbeddingRequest<'a> {
|
||||
model: &'static str,
|
||||
input: Vec<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIEmbeddingResponse {
|
||||
data: Vec<OpenAIEmbedding>,
|
||||
usage: OpenAIEmbeddingUsage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAIEmbedding {
|
||||
embedding: Vec<f32>,
|
||||
index: usize,
|
||||
object: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIEmbeddingUsage {
|
||||
prompt_tokens: usize,
|
||||
total_tokens: usize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait EmbeddingProvider: Sync {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for OpenAIEmbeddings {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
let api_key = OPENAI_API_KEY
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no api key"))?;
|
||||
|
||||
let request = Request::post("https://api.openai.com/v1/embeddings")
|
||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(
|
||||
serde_json::to_string(&OpenAIEmbeddingRequest {
|
||||
input: spans,
|
||||
model: "text-embedding-ada-002",
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
)?;
|
||||
|
||||
let mut response = self.client.send(request).await?;
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow!("openai embedding failed {}", response.status()));
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
|
||||
|
||||
log::info!(
|
||||
"openai embedding completed. tokens: {:?}",
|
||||
response.usage.total_tokens
|
||||
);
|
||||
|
||||
// do we need to re-order these based on the `index` field?
|
||||
eprintln!(
|
||||
"indices: {:?}",
|
||||
response
|
||||
.data
|
||||
.iter()
|
||||
.map(|embedding| embedding.index)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|embedding| embedding.embedding)
|
||||
.collect())
|
||||
}
|
||||
}
|
|
@ -1,17 +1,25 @@
|
|||
mod db;
|
||||
use anyhow::Result;
|
||||
mod embedding;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use db::VectorDatabase;
|
||||
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
|
||||
use gpui::{AppContext, Entity, ModelContext, ModelHandle};
|
||||
use language::LanguageRegistry;
|
||||
use project::{Fs, Project};
|
||||
use rand::Rng;
|
||||
use smol::channel;
|
||||
use std::{path::PathBuf, sync::Arc, time::Instant};
|
||||
use util::ResultExt;
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
use util::{http::HttpClient, ResultExt};
|
||||
use workspace::WorkspaceCreated;
|
||||
|
||||
pub fn init(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>, cx: &mut AppContext) {
|
||||
let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry));
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
|
||||
|
||||
cx.subscribe_global::<WorkspaceCreated, _>({
|
||||
let vector_store = vector_store.clone();
|
||||
|
@ -53,38 +61,86 @@ struct SearchResult {
|
|||
|
||||
struct VectorStore {
|
||||
fs: Arc<dyn Fs>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
}
|
||||
|
||||
impl VectorStore {
|
||||
fn new(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>) -> Self {
|
||||
fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
) -> Self {
|
||||
Self {
|
||||
fs,
|
||||
http_client,
|
||||
language_registry,
|
||||
}
|
||||
}
|
||||
|
||||
async fn index_file(
|
||||
cursor: &mut QueryCursor,
|
||||
parser: &mut Parser,
|
||||
embedding_provider: &dyn EmbeddingProvider,
|
||||
fs: &Arc<dyn Fs>,
|
||||
language_registry: &Arc<LanguageRegistry>,
|
||||
file_path: PathBuf,
|
||||
) -> Result<IndexedFile> {
|
||||
// This is creating dummy documents to test the database writes.
|
||||
let mut documents = vec![];
|
||||
let mut rng = rand::thread_rng();
|
||||
let rand_num_of_documents: u8 = rng.gen_range(0..200);
|
||||
for _ in 0..rand_num_of_documents {
|
||||
let doc = Document {
|
||||
offset: 0,
|
||||
name: "test symbol".to_string(),
|
||||
embedding: vec![0.32 as f32; 768],
|
||||
};
|
||||
documents.push(doc);
|
||||
let language = language_registry
|
||||
.language_for_file(&file_path, None)
|
||||
.await?;
|
||||
|
||||
if language.name().as_ref() != "Rust" {
|
||||
Err(anyhow!("unsupported language"))?;
|
||||
}
|
||||
|
||||
let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
|
||||
let outline_config = grammar
|
||||
.outline_config
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no outline query"))?;
|
||||
|
||||
let content = fs.load(&file_path).await?;
|
||||
parser.set_language(grammar.ts_language).unwrap();
|
||||
let tree = parser
|
||||
.parse(&content, None)
|
||||
.ok_or_else(|| anyhow!("parsing failed"))?;
|
||||
|
||||
let mut documents = Vec::new();
|
||||
let mut context_spans = Vec::new();
|
||||
for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
|
||||
let mut item_range = None;
|
||||
let mut name_range = None;
|
||||
for capture in mat.captures {
|
||||
if capture.index == outline_config.item_capture_ix {
|
||||
item_range = Some(capture.node.byte_range());
|
||||
} else if capture.index == outline_config.name_capture_ix {
|
||||
name_range = Some(capture.node.byte_range());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((item_range, name_range)) = item_range.zip(name_range) {
|
||||
if let Some((item, name)) =
|
||||
content.get(item_range.clone()).zip(content.get(name_range))
|
||||
{
|
||||
context_spans.push(item);
|
||||
documents.push(Document {
|
||||
name: name.to_string(),
|
||||
offset: item_range.start,
|
||||
embedding: Vec::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let embeddings = embedding_provider.embed_batch(context_spans).await?;
|
||||
for (document, embedding) in documents.iter_mut().zip(embeddings) {
|
||||
document.embedding = embedding;
|
||||
}
|
||||
|
||||
return Ok(IndexedFile {
|
||||
path: file_path,
|
||||
sha1: "asdfasdfasdf".to_string(),
|
||||
sha1: String::new(),
|
||||
documents,
|
||||
});
|
||||
}
|
||||
|
@ -98,8 +154,9 @@ impl VectorStore {
|
|||
|
||||
let fs = self.fs.clone();
|
||||
let language_registry = self.language_registry.clone();
|
||||
let client = self.http_client.clone();
|
||||
|
||||
cx.spawn(|this, cx| async move {
|
||||
cx.spawn(|_, cx| async move {
|
||||
futures::future::join_all(worktree_scans_complete).await;
|
||||
|
||||
let worktrees = project.read_with(&cx, |project, cx| {
|
||||
|
@ -131,15 +188,27 @@ impl VectorStore {
|
|||
})
|
||||
.detach();
|
||||
|
||||
let provider = OpenAIEmbeddings { client };
|
||||
|
||||
let t0 = Instant::now();
|
||||
|
||||
cx.background()
|
||||
.scoped(|scope| {
|
||||
for _ in 0..cx.background().num_cpus() {
|
||||
scope.spawn(async {
|
||||
let mut parser = Parser::new();
|
||||
let mut cursor = QueryCursor::new();
|
||||
while let Ok(file_path) = paths_rx.recv().await {
|
||||
if let Some(indexed_file) =
|
||||
Self::index_file(&fs, &language_registry, file_path)
|
||||
.await
|
||||
.log_err()
|
||||
if let Some(indexed_file) = Self::index_file(
|
||||
&mut cursor,
|
||||
&mut parser,
|
||||
&provider,
|
||||
&fs,
|
||||
&language_registry,
|
||||
file_path,
|
||||
)
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
indexed_files_tx.try_send(indexed_file).unwrap();
|
||||
}
|
||||
|
@ -148,6 +217,9 @@ impl VectorStore {
|
|||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let duration = t0.elapsed();
|
||||
log::info!("indexed project in {duration:?}");
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
|
|
@ -152,7 +152,7 @@ fn main() {
|
|||
project_panel::init(cx);
|
||||
diagnostics::init(cx);
|
||||
search::init(cx);
|
||||
vector_store::init(fs.clone(), languages.clone(), cx);
|
||||
vector_store::init(fs.clone(), http.clone(), languages.clone(), cx);
|
||||
vim::init(cx);
|
||||
terminal_view::init(cx);
|
||||
theme_testbench::init(cx);
|
||||
|
|
Loading…
Reference in a new issue