OpenAI indexing on open for Rust files

This commit is contained in:
KCaverly 2023-06-22 16:50:07 -04:00
parent d4a4db42aa
commit dd309070eb
7 changed files with 252 additions and 55 deletions

57
Cargo.lock generated
View file

@ -1389,15 +1389,6 @@ dependencies = [
"theme",
]
[[package]]
name = "conv"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ff10625fd0ac447827aa30ea8b861fead473bb60aeb73af6c1c58caf0d1299"
dependencies = [
"custom_derive",
]
[[package]]
name = "copilot"
version = "0.1.0"
@ -1775,12 +1766,6 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "custom_derive"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9"
[[package]]
name = "cxx"
version = "1.0.94"
@ -2219,6 +2204,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fallible-streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fancy-regex"
version = "0.11.0"
@ -2909,6 +2900,15 @@ dependencies = [
"ahash 0.8.3",
]
[[package]]
name = "hashlink"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf"
dependencies = [
"hashbrown 0.11.2",
]
[[package]]
name = "hashlink"
version = "0.8.1"
@ -5600,6 +5600,21 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rusqlite"
version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85127183a999f7db96d1a976a309eebbfb6ea3b0b400ddd8340190129de6eb7a"
dependencies = [
"bitflags",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink 0.7.0",
"libsqlite3-sys",
"memchr",
"smallvec",
]
[[package]]
name = "rust-embed"
version = "6.6.1"
@ -6531,7 +6546,7 @@ dependencies = [
"futures-executor",
"futures-intrusive",
"futures-util",
"hashlink",
"hashlink 0.8.1",
"hex",
"hkdf",
"hmac 0.12.1",
@ -7898,14 +7913,20 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-compat",
"conv",
"async-trait",
"futures 0.3.28",
"gpui",
"isahc",
"language",
"lazy_static",
"log",
"project",
"rand 0.8.5",
"rusqlite",
"serde",
"serde_json",
"smol",
"sqlx",
"tree-sitter",
"util",
"workspace",
]

View file

@ -476,12 +476,12 @@ pub struct Language {
pub struct Grammar {
id: usize,
pub(crate) ts_language: tree_sitter::Language,
pub ts_language: tree_sitter::Language,
pub(crate) error_query: Query,
pub(crate) highlights_query: Option<Query>,
pub(crate) brackets_config: Option<BracketConfig>,
pub(crate) indents_config: Option<IndentConfig>,
pub(crate) outline_config: Option<OutlineConfig>,
pub outline_config: Option<OutlineConfig>,
pub(crate) injection_config: Option<InjectionConfig>,
pub(crate) override_config: Option<OverrideConfig>,
pub(crate) highlight_map: Mutex<HighlightMap>,
@ -495,12 +495,12 @@ struct IndentConfig {
outdent_capture_ix: Option<u32>,
}
struct OutlineConfig {
query: Query,
item_capture_ix: u32,
name_capture_ix: u32,
context_capture_ix: Option<u32>,
extra_context_capture_ix: Option<u32>,
pub struct OutlineConfig {
pub query: Query,
pub item_capture_ix: u32,
pub name_capture_ix: u32,
pub context_capture_ix: Option<u32>,
pub extra_context_capture_ix: Option<u32>,
}
struct InjectionConfig {

View file

@ -19,8 +19,14 @@ futures.workspace = true
smol.workspace = true
sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] }
async-compat = "0.2.1"
conv = "0.3.3"
rand.workspace = true
rusqlite = "0.27.0"
isahc.workspace = true
log.workspace = true
tree-sitter.workspace = true
lazy_static.workspace = true
serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View file

@ -1,8 +1,6 @@
use anyhow::Result;
use async_compat::{Compat, CompatExt};
use conv::ValueFrom;
use sqlx::{migrate::MigrateDatabase, Pool, Sqlite, SqlitePool};
use std::time::{Duration, Instant};
use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool};
use crate::IndexedFile;

View file

@ -0,0 +1,100 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::serde_json;
use isahc::prelude::Configurable;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use util::http::{HttpClient, Request};
lazy_static! {
    // Read once from the environment at first use; when unset, embedding
    // requests fail with a "no api key" error rather than panicking.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
/// Embedding provider backed by the OpenAI `/v1/embeddings` endpoint.
pub struct OpenAIEmbeddings {
    // HTTP client used to issue the embedding requests.
    pub client: Arc<dyn HttpClient>,
}
/// Request body for the OpenAI embeddings API.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    // Model identifier, e.g. "text-embedding-ada-002".
    model: &'static str,
    // Text spans to embed; the API returns one embedding per span.
    input: Vec<&'a str>,
}
/// Top-level response body from the OpenAI embeddings API.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input span.
    data: Vec<OpenAIEmbedding>,
    // Token accounting reported by the API.
    usage: OpenAIEmbeddingUsage,
}
/// A single embedding entry in the API response.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    // The embedding vector itself.
    embedding: Vec<f32>,
    // Position of the corresponding span in the request's `input` list.
    index: usize,
    // Object type tag from the API; deserialized but currently unused.
    object: String,
}
/// Token-usage block of the embeddings response.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    // Tokens consumed by the input spans; deserialized but currently unused.
    prompt_tokens: usize,
    // Total tokens billed for the request; logged after each batch.
    total_tokens: usize,
}
/// Anything that can turn a batch of text spans into embedding vectors.
#[async_trait]
pub trait EmbeddingProvider: Sync {
    /// Embeds `spans`, returning one vector per input span.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    /// Embeds a batch of spans with a single call to the OpenAI
    /// `/v1/embeddings` endpoint (`text-embedding-ada-002`).
    ///
    /// Returns one embedding vector per input span, in input order.
    /// Errors when `OPENAI_API_KEY` is unset, the HTTP request fails,
    /// or the API responds with a non-success status.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        let mut response = self.client.send(request).await?;
        if !response.status().is_success() {
            return Err(anyhow!("openai embedding failed {}", response.status()));
        }

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let mut response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

        log::info!(
            "openai embedding completed. tokens: {:?}",
            response.usage.total_tokens
        );

        // The API associates each embedding with the position of its input
        // span via the `index` field rather than guaranteeing response
        // order, so sort by `index` to keep the output aligned with `spans`.
        response.data.sort_by_key(|embedding| embedding.index);

        Ok(response
            .data
            .into_iter()
            .map(|embedding| embedding.embedding)
            .collect())
    }
}

View file

@ -1,17 +1,25 @@
mod db;
use anyhow::Result;
mod embedding;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle};
use language::LanguageRegistry;
use project::{Fs, Project};
use rand::Rng;
use smol::channel;
use std::{path::PathBuf, sync::Arc, time::Instant};
use util::ResultExt;
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt};
use workspace::WorkspaceCreated;
pub fn init(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>, cx: &mut AppContext) {
let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry));
pub fn init(
fs: Arc<dyn Fs>,
http_client: Arc<dyn HttpClient>,
language_registry: Arc<LanguageRegistry>,
cx: &mut AppContext,
) {
let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
cx.subscribe_global::<WorkspaceCreated, _>({
let vector_store = vector_store.clone();
@ -53,38 +61,86 @@ struct SearchResult {
// Model that owns the resources needed to index a project's files.
struct VectorStore {
    // Filesystem handle used to load file contents for indexing.
    fs: Arc<dyn Fs>,
    // HTTP client handed to the embedding provider.
    http_client: Arc<dyn HttpClient>,
    // Resolves a file path to its language and grammar.
    language_registry: Arc<LanguageRegistry>,
}
impl VectorStore {
fn new(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>) -> Self {
fn new(
fs: Arc<dyn Fs>,
http_client: Arc<dyn HttpClient>,
language_registry: Arc<LanguageRegistry>,
) -> Self {
Self {
fs,
http_client,
language_registry,
}
}
async fn index_file(
cursor: &mut QueryCursor,
parser: &mut Parser,
embedding_provider: &dyn EmbeddingProvider,
fs: &Arc<dyn Fs>,
language_registry: &Arc<LanguageRegistry>,
file_path: PathBuf,
) -> Result<IndexedFile> {
// This is creating dummy documents to test the database writes.
let mut documents = vec![];
let mut rng = rand::thread_rng();
let rand_num_of_documents: u8 = rng.gen_range(0..200);
for _ in 0..rand_num_of_documents {
let doc = Document {
offset: 0,
name: "test symbol".to_string(),
embedding: vec![0.32 as f32; 768],
};
documents.push(doc);
let language = language_registry
.language_for_file(&file_path, None)
.await?;
if language.name().as_ref() != "Rust" {
Err(anyhow!("unsupported language"))?;
}
let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
let outline_config = grammar
.outline_config
.as_ref()
.ok_or_else(|| anyhow!("no outline query"))?;
let content = fs.load(&file_path).await?;
parser.set_language(grammar.ts_language).unwrap();
let tree = parser
.parse(&content, None)
.ok_or_else(|| anyhow!("parsing failed"))?;
let mut documents = Vec::new();
let mut context_spans = Vec::new();
for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
let mut item_range = None;
let mut name_range = None;
for capture in mat.captures {
if capture.index == outline_config.item_capture_ix {
item_range = Some(capture.node.byte_range());
} else if capture.index == outline_config.name_capture_ix {
name_range = Some(capture.node.byte_range());
}
}
if let Some((item_range, name_range)) = item_range.zip(name_range) {
if let Some((item, name)) =
content.get(item_range.clone()).zip(content.get(name_range))
{
context_spans.push(item);
documents.push(Document {
name: name.to_string(),
offset: item_range.start,
embedding: Vec::new(),
});
}
}
}
let embeddings = embedding_provider.embed_batch(context_spans).await?;
for (document, embedding) in documents.iter_mut().zip(embeddings) {
document.embedding = embedding;
}
return Ok(IndexedFile {
path: file_path,
sha1: "asdfasdfasdf".to_string(),
sha1: String::new(),
documents,
});
}
@ -98,8 +154,9 @@ impl VectorStore {
let fs = self.fs.clone();
let language_registry = self.language_registry.clone();
let client = self.http_client.clone();
cx.spawn(|this, cx| async move {
cx.spawn(|_, cx| async move {
futures::future::join_all(worktree_scans_complete).await;
let worktrees = project.read_with(&cx, |project, cx| {
@ -131,15 +188,27 @@ impl VectorStore {
})
.detach();
let provider = OpenAIEmbeddings { client };
let t0 = Instant::now();
cx.background()
.scoped(|scope| {
for _ in 0..cx.background().num_cpus() {
scope.spawn(async {
let mut parser = Parser::new();
let mut cursor = QueryCursor::new();
while let Ok(file_path) = paths_rx.recv().await {
if let Some(indexed_file) =
Self::index_file(&fs, &language_registry, file_path)
.await
.log_err()
if let Some(indexed_file) = Self::index_file(
&mut cursor,
&mut parser,
&provider,
&fs,
&language_registry,
file_path,
)
.await
.log_err()
{
indexed_files_tx.try_send(indexed_file).unwrap();
}
@ -148,6 +217,9 @@ impl VectorStore {
}
})
.await;
let duration = t0.elapsed();
log::info!("indexed project in {duration:?}");
})
.detach();
}

View file

@ -152,7 +152,7 @@ fn main() {
project_panel::init(cx);
diagnostics::init(cx);
search::init(cx);
vector_store::init(fs.clone(), languages.clone(), cx);
vector_store::init(fs.clone(), http.clone(), languages.clone(), cx);
vim::init(cx);
terminal_view::init(cx);
theme_testbench::init(cx);