diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
index 4587ece0a2..4d5e40fad9 100644
--- a/crates/ai/src/embedding.rs
+++ b/crates/ai/src/embedding.rs
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::serde_json;
+use gpui::{serde_json, ViewContext};
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
@@ -20,9 +20,11 @@ use std::sync::Arc;
 use std::time::{Duration, Instant};
 use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::completion::OPENAI_API_URL;
 
 lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
@@ -87,6 +89,7 @@ impl Embedding {
 
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
+    pub api_key: Option<String>,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+    pub fn authenticate(&mut self, cx: &mut ViewContext) {
+        if self.api_key.is_none() {
+            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                Some(api_key)
+            } else if let Some((_, api_key)) = cx
+                .platform()
+                .read_credentials(OPENAI_API_URL)
+                .log_err()
+                .flatten()
+            {
+                String::from_utf8(api_key).log_err()
+            } else {
+                None
+            };
+
+            if let Some(api_key) = api_key {
+                self.api_key = Some(api_key);
+            }
+        }
+    }
+    pub fn new(
+        api_key: Option<String>,
+        client: Arc<dyn HttpClient>,
+        executor: Arc<Background>,
+    ) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         OpenAIEmbeddings {
+            api_key,
             client,
             executor,
             rate_limit_count_rx,
@@ -237,8 +265,9 @@ impl OpenAIEmbeddings {
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
     fn is_authenticated(&self) -> bool {
-        OPENAI_API_KEY.as_ref().is_some()
+        self.api_key.is_some()
     }
+
     fn max_tokens_per_batch(&self) -> usize {
         50000
     }
@@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
+        let Some(api_key) = self.api_key.clone() else {
+            return Err(anyhow!("no open ai key provided"));
+        };
 
         let mut request_number = 0;
         let mut rate_limiting = false;
@@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
-                    api_key,
+                    &api_key,
                     spans.iter().map(|x| &**x).collect(),
                     request_timeout,
                 )
diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs
index c03e5dc80e..55e3f6babd 100644
--- a/crates/search/src/project_search.rs
+++ b/crates/search/src/project_search.rs
@@ -351,33 +351,32 @@ impl View for ProjectSearchView {
                     SemanticIndexStatus::NotAuthenticated => {
                         major_text = Cow::Borrowed("Not Authenticated");
                         show_minor_text = false;
-                        Some(
-                            "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables"
-                                .to_string(),
-                        )
+                        Some(vec![
+                            "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables."
+                                .to_string(), "If you authenticated using the Assistant Panel, please restart Zed to Authenticate.".to_string()])
                     }
-                    SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
+                    SemanticIndexStatus::Indexed => Some(vec!["Indexing complete".to_string()]),
                     SemanticIndexStatus::Indexing {
                         remaining_files,
                         rate_limit_expiry,
                     } => {
                         if remaining_files == 0 {
-                            Some(format!("Indexing..."))
+                            Some(vec![format!("Indexing...")])
                         } else {
                             if let Some(rate_limit_expiry) = rate_limit_expiry {
                                 let remaining_seconds =
                                     rate_limit_expiry.duration_since(Instant::now());
                                 if remaining_seconds > Duration::from_secs(0) {
-                                    Some(format!(
+                                    Some(vec![format!(
                                         "Remaining files to index (rate limit resets in {}s): {}",
                                         remaining_seconds.as_secs(),
                                         remaining_files
-                                    ))
+                                    )])
                                 } else {
-                                    Some(format!("Remaining files to index: {}", remaining_files))
+                                    Some(vec![format!("Remaining files to index: {}", remaining_files)])
                                 }
                             } else {
-                                Some(format!("Remaining files to index: {}", remaining_files))
+                                Some(vec![format!("Remaining files to index: {}", remaining_files)])
                             }
                         }
                     }
@@ -394,9 +393,11 @@ impl View for ProjectSearchView {
                 } else {
                     match current_mode {
                         SearchMode::Semantic => {
-                            let mut minor_text = Vec::new();
+                            let mut minor_text: Vec<String> = Vec::new();
                             minor_text.push("".into());
-                            minor_text.extend(semantic_status);
+                            if let Some(semantic_status) = semantic_status {
+                                minor_text.extend(semantic_status);
+                            }
                             if show_minor_text {
                                 minor_text
                                     .push("Simply explain the code you are looking to find.".into());
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index ecdba43643..aae289e417 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -7,7 +7,10 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::{
+    completion::OPENAI_API_URL,
+    embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
+};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -55,6 +58,19 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
+    let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+        Some(api_key)
+    } else if let Some((_, api_key)) = cx
+        .platform()
+        .read_credentials(OPENAI_API_URL)
+        .log_err()
+        .flatten()
+    {
+        String::from_utf8(api_key).log_err()
+    } else {
+        None
+    };
+
     cx.subscribe_global::<WorkspaceCreated, _>({
         move |event, cx| {
             let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -88,7 +104,7 @@ pub fn init(
     let semantic_index = SemanticIndex::new(
         fs,
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
         language_registry,
         cx.clone(),
     )
diff --git a/crates/zed/examples/semantic_index_eval.rs b/crates/zed/examples/semantic_index_eval.rs
index 33d6b3689c..73b3b9987b 100644
--- a/crates/zed/examples/semantic_index_eval.rs
+++ b/crates/zed/examples/semantic_index_eval.rs
@@ -1,3 +1,4 @@
+use ai::completion::OPENAI_API_URL;
 use ai::embedding::OpenAIEmbeddings;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
@@ -17,6 +18,7 @@ use std::{cmp, env, fs};
 use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 use util::http::{self};
 use util::paths::EMBEDDINGS_DIR;
+use util::ResultExt;
 use zed::languages;
 
 #[derive(Deserialize, Clone, Serialize)]
@@ -469,12 +471,26 @@ fn main() {
         .join("embeddings_db");
         let languages = languages.clone();
+
+        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+            Some(api_key)
+        } else if let Some((_, api_key)) = cx
+            .platform()
+            .read_credentials(OPENAI_API_URL)
+            .log_err()
+            .flatten()
+        {
+            String::from_utf8(api_key).log_err()
+        } else {
+            None
+        };
+
         let fs = fs.clone();
         cx.spawn(|mut cx| async move {
             let semantic_index = SemanticIndex::new(
                 fs.clone(),
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+                Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
                 languages.clone(),
                 cx.clone(),
            )
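
Note: the environment-variable-or-stored-credentials lookup added by this patch is repeated verbatim in embedding.rs, semantic_index.rs, and the eval example. The standalone sketch below only illustrates that resolution order (OPENAI_API_KEY first, then credentials previously saved for the OpenAI endpoint); the read_credentials closure and the URL constant are stand-ins for gpui's cx.platform().read_credentials(OPENAI_API_URL) and ai::completion::OPENAI_API_URL, not code from the patch.

use std::env;

// Assumption: mirrors ai::completion::OPENAI_API_URL, which is used as the
// keychain lookup key in the diff above.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";

// Resolve an OpenAI API key in the same order as the patch: the environment
// variable wins, otherwise fall back to previously stored credentials.
fn resolve_openai_api_key(
    read_credentials: impl Fn(&str) -> Option<(String, Vec<u8>)>,
) -> Option<String> {
    if let Ok(api_key) = env::var("OPENAI_API_KEY") {
        Some(api_key)
    } else if let Some((_, api_key)) = read_credentials(OPENAI_API_URL) {
        // Stored credentials arrive as bytes; decode them as UTF-8.
        String::from_utf8(api_key).ok()
    } else {
        None
    }
}

fn main() {
    // No keychain in this sketch: pretend nothing is stored for the URL.
    let api_key = resolve_openai_api_key(|_url| None);
    println!("api key found: {}", api_key.is_some());
}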