diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
index 4d5e40fad9..b791414ba2 100644
--- a/crates/ai/src/embedding.rs
+++ b/crates/ai/src/embedding.rs
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::{serde_json, ViewContext};
+use gpui::{serde_json, AppContext};
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
@@ -89,7 +89,6 @@ impl Embedding {
 
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
-    pub api_key: Option<String>,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -123,8 +122,12 @@ struct OpenAIEmbeddingUsage {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
-    fn is_authenticated(&self) -> bool;
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        api_key: Option<String>,
+    ) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn truncate(&self, span: &str) -> (String, usize);
     fn rate_limit_expiration(&self) -> Option<Instant>;
@@ -134,13 +137,17 @@ pub struct DummyEmbeddings {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        true
+    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+        Some("Dummy API KEY".to_string())
     }
 
     fn rate_limit_expiration(&self) -> Option<Instant> {
        None
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        _api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
         let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
@@ -169,36 +176,11 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
-        if self.api_key.is_none() {
-            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-                Some(api_key)
-            } else if let Some((_, api_key)) = cx
-                .platform()
-                .read_credentials(OPENAI_API_URL)
-                .log_err()
-                .flatten()
-            {
-                String::from_utf8(api_key).log_err()
-            } else {
-                None
-            };
-
-            if let Some(api_key) = api_key {
-                self.api_key = Some(api_key);
-            }
-        }
-    }
-
-    pub fn new(
-        api_key: Option<String>,
-        client: Arc<dyn HttpClient>,
-        executor: Arc<Background>,
-    ) -> Self {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         OpenAIEmbeddings {
-            api_key,
             client,
             executor,
             rate_limit_count_rx,
@@ -264,8 +246,19 @@ impl OpenAIEmbeddings {
 
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        self.api_key.is_some()
+    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
+        if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+            Some(api_key)
+        } else if let Some((_, api_key)) = cx
+            .platform()
+            .read_credentials(OPENAI_API_URL)
+            .log_err()
+            .flatten()
+        {
+            String::from_utf8(api_key).log_err()
+        } else {
+            None
+        }
     }
 
     fn max_tokens_per_batch(&self) -> usize {
@@ -290,11 +283,15 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         (output, tokens.len())
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let Some(api_key) = self.api_key.clone() else {
+        let Some(api_key) = api_key else {
             return Err(anyhow!("no open ai key provided"));
         };
diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs
index 2eb1fd421c..cb238a8673 100644
--- a/crates/project/src/project.rs
+++ b/crates/project/src/project.rs
@@ -53,7 +53,7 @@ use lsp::{
 use lsp_command::*;
 use node_runtime::NodeRuntime;
 use postage::watch;
-use prettier::{LocateStart, Prettier, PRETTIER_SERVER_FILE, PRETTIER_SERVER_JS};
+use prettier::{LocateStart, Prettier};
 use project_settings::{LspSettings, ProjectSettings};
 use rand::prelude::*;
 use search::SearchQuery;
@@ -79,13 +79,10 @@ use std::{
     time::{Duration, Instant},
 };
 use terminals::Terminals;
-use text::{Anchor, LineEnding, Rope};
+use text::Anchor;
 use util::{
-    debug_panic, defer,
-    http::HttpClient,
-    merge_json_value_into,
-    paths::{DEFAULT_PRETTIER_DIR, LOCAL_SETTINGS_RELATIVE_PATH},
-    post_inc, ResultExt, TryFutureExt as _,
+    debug_panic, defer, http::HttpClient, merge_json_value_into,
+    paths::LOCAL_SETTINGS_RELATIVE_PATH, post_inc, ResultExt, TryFutureExt as _,
 };
 
 pub use fs::*;
@@ -8489,6 +8486,18 @@ impl Project {
         }
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    fn install_default_formatters(
+        &self,
+        _worktree: Option<WorktreeId>,
+        _new_language: &Language,
+        _language_settings: &LanguageSettings,
+        _cx: &mut ModelContext<Self>,
+    ) -> Task<anyhow::Result<()>> {
+        return Task::ready(Ok(()));
+    }
+
+    #[cfg(not(any(test, feature = "test-support")))]
     fn install_default_formatters(
         &self,
         worktree: Option<WorktreeId>,
@@ -8519,7 +8528,7 @@ impl Project {
             return Task::ready(Ok(()));
         };
 
-        let default_prettier_dir = DEFAULT_PRETTIER_DIR.as_path();
+        let default_prettier_dir = util::paths::DEFAULT_PRETTIER_DIR.as_path();
         let already_running_prettier = self
             .prettier_instances
             .get(&(worktree, default_prettier_dir.to_path_buf()))
@@ -8528,10 +8537,10 @@ impl Project {
         let fs = Arc::clone(&self.fs);
         cx.background()
             .spawn(async move {
-                let prettier_wrapper_path = default_prettier_dir.join(PRETTIER_SERVER_FILE);
+                let prettier_wrapper_path = default_prettier_dir.join(prettier::PRETTIER_SERVER_FILE);
                 // method creates parent directory if it doesn't exist
-                fs.save(&prettier_wrapper_path, &Rope::from(PRETTIER_SERVER_JS), LineEnding::Unix).await
-                .with_context(|| format!("writing {PRETTIER_SERVER_FILE} file at {prettier_wrapper_path:?}"))?;
+                fs.save(&prettier_wrapper_path, &text::Rope::from(prettier::PRETTIER_SERVER_JS), text::LineEnding::Unix).await
+                .with_context(|| format!("writing {} file at {prettier_wrapper_path:?}", prettier::PRETTIER_SERVER_FILE))?;
 
                 let packages_to_versions = future::try_join_all(
                     prettier_plugins
diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs
index 6ae8faa4cd..d57d5c7bbe 100644
--- a/crates/semantic_index/src/embedding_queue.rs
+++ b/crates/semantic_index/src/embedding_queue.rs
@@ -41,6 +41,7 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
+    api_key: Option<String>,
 }
 
 #[derive(Clone)]
@@ -50,7 +51,11 @@ pub struct FileFragmentToEmbed {
 }
 
 impl EmbeddingQueue {
-    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
+    pub fn new(
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        executor: Arc<Background>,
+        api_key: Option<String>,
+    ) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,
@@ -59,9 +64,14 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
+            api_key,
         }
     }
 
+    pub fn set_api_key(&mut self, api_key: Option<String>) {
+        self.api_key = api_key
+    }
+
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();
@@ -108,6 +118,7 @@ impl EmbeddingQueue {
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
+        let api_key = self.api_key.clone();
 
         self.executor
             .spawn(async move {
@@ -132,7 +143,7 @@ impl EmbeddingQueue {
                     return;
                 };
 
-                match embedding_provider.embed_batch(spans).await {
+                match embedding_provider.embed_batch(spans, api_key).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index aae289e417..8839d25a84 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -7,10 +7,7 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::{
-    completion::OPENAI_API_URL,
-    embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
-};
+use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -58,19 +55,6 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
-    let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-        Some(api_key)
-    } else if let Some((_, api_key)) = cx
-        .platform()
-        .read_credentials(OPENAI_API_URL)
-        .log_err()
-        .flatten()
-    {
-        String::from_utf8(api_key).log_err()
-    } else {
-        None
-    };
-
     cx.subscribe_global::<WorkspaceCreated, _>({
         move |event, cx| {
             let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -104,7 +88,7 @@ pub fn init(
             let semantic_index = SemanticIndex::new(
                 fs,
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
                 language_registry,
                 cx.clone(),
             )
@@ -139,6 +123,8 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
+    api_key: Option<String>,
+    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }
 
 struct ProjectState {
@@ -284,7 +270,7 @@ pub struct SearchResult {
 }
 
 impl SemanticIndex {
-    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
+    pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
         if cx.has_global::<ModelHandle<Self>>() {
             Some(cx.global::<ModelHandle<Self>>().clone())
         } else {
@@ -292,12 +278,26 @@ impl SemanticIndex {
         }
     }
 
+    pub fn authenticate(&mut self, cx: &AppContext) {
+        if self.api_key.is_none() {
+            self.api_key = self.embedding_provider.retrieve_credentials(cx);
+
+            self.embedding_queue
+                .lock()
+                .set_api_key(self.api_key.clone());
+        }
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.api_key.is_some()
+    }
+
     pub fn enabled(cx: &AppContext) -> bool {
         settings::get::<SemanticIndexSettings>(cx).enabled
     }
 
     pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
-        if !self.embedding_provider.is_authenticated() {
+        if !self.is_authenticated() {
             return SemanticIndexStatus::NotAuthenticated;
         }
 
@@ -339,7 +339,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();
@@ -404,6 +404,8 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
+                api_key: None,
+                embedding_queue
             }
         }))
     }
@@ -718,12 +720,13 @@ impl SemanticIndex {
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
+        let api_key = self.api_key.clone();
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();
 
             let query = embedding_provider
-                .embed_batch(vec![query])
+                .embed_batch(vec![query], api_key)
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
@@ -941,6 +944,7 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
+        let api_key = self.api_key.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();
@@ -955,10 +959,15 @@ impl SemanticIndex {
                     .parse_file_with_template(None, &snapshot.text(), language)
                     .log_err()
                     .unwrap_or_default();
-                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
-                    .await
-                    .log_err()
-                    .is_some()
+                if Self::embed_spans(
+                    &mut spans,
+                    embedding_provider.as_ref(),
+                    &db,
+                    api_key.clone(),
+                )
+                .await
+                .log_err()
+                .is_some()
                 {
                     for span in spans {
                         let similarity = span.embedding.unwrap().similarity(&query);
@@ -998,8 +1007,11 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
-        if !self.embedding_provider.is_authenticated() {
-            return Task::ready(Err(anyhow!("user is not authenticated")));
+        if self.api_key.is_none() {
+            self.authenticate(cx);
+            if self.api_key.is_none() {
+                return Task::ready(Err(anyhow!("user is not authenticated")));
+            }
         }
 
         if !self.projects.contains_key(&project.downgrade()) {
@@ -1180,6 +1192,7 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
+        api_key: Option<String>,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;
@@ -1202,7 +1215,7 @@ impl SemanticIndex {
             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch))
+                    .embed_batch(mem::take(&mut batch), api_key.clone())
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;
@@ -1214,7 +1227,7 @@ impl SemanticIndex {
         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch))
+                .embed_batch(mem::take(&mut batch), api_key)
                 .await?;
 
             embeddings.extend(batch_embeddings);
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index 182010ca83..a1ee3e5ada 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -7,7 +7,7 @@ use crate::{
 use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
 use anyhow::Result;
 use async_trait::async_trait;
-use gpui::{executor::Deterministic, Task, TestAppContext};
+use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
@@ -228,7 +228,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 
-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
     for file in &files {
         queue.push(file.clone());
     }
@@ -1281,8 +1281,8 @@ impl FakeEmbeddingProvider {
 
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn is_authenticated(&self) -> bool {
-        true
+    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+        Some("Fake Credentials".to_string())
     }
 
     fn truncate(&self, span: &str) -> (String, usize) {
         (span.to_string(), 1)
@@ -1296,7 +1296,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         None
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        _api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
         Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
diff --git a/crates/zed/examples/semantic_index_eval.rs b/crates/zed/examples/semantic_index_eval.rs
index 73b3b9987b..e750307800 100644
--- a/crates/zed/examples/semantic_index_eval.rs
+++ b/crates/zed/examples/semantic_index_eval.rs
@@ -1,4 +1,3 @@
-use ai::completion::OPENAI_API_URL;
 use ai::embedding::OpenAIEmbeddings;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
@@ -18,7 +17,6 @@ use std::{cmp, env, fs};
 use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 use util::http::{self};
 use util::paths::EMBEDDINGS_DIR;
-use util::ResultExt;
 use zed::languages;
 
 #[derive(Deserialize, Clone, Serialize)]
@@ -57,7 +55,7 @@ fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
         .as_path()
         .parent()
         .unwrap()
-        .join("crates/semantic_index/eval");
+        .join("zed/crates/semantic_index/eval");
 
     let mut repo_evals: Vec<RepoEval> = Vec::new();
     for entry in fs::read_dir(eval_folder)? {
@@ -472,25 +470,12 @@ fn main() {
 
         let languages = languages.clone();
 
-        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            Some(api_key)
-        } else if let Some((_, api_key)) = cx
-            .platform()
-            .read_credentials(OPENAI_API_URL)
-            .log_err()
-            .flatten()
-        {
-            String::from_utf8(api_key).log_err()
-        } else {
-            None
-        };
-
         let fs = fs.clone();
         cx.spawn(|mut cx| async move {
             let semantic_index = SemanticIndex::new(
                 fs.clone(),
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
                 languages.clone(),
                 cx.clone(),
             )
diff --git a/script/evaluate_semantic_index b/script/evaluate_semantic_index
index 8dcb53c399..9ecfe898c5 100755
--- a/script/evaluate_semantic_index
+++ b/script/evaluate_semantic_index
@@ -1,3 +1,3 @@
 #!/bin/bash
 
-RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release
+RUST_LOG=semantic_index=trace cargo run --example semantic_index_eval --release