mirror of
https://github.com/zed-industries/zed.git
synced 2025-02-10 04:09:37 +00:00
update semantic search to use keychain as fallback (#3151)
Use the keychain for authenticating as fallback when api_key is not present in environment variables. Release Notes: - Add consistency between OPENAI_API_KEY management in Semantic Search and Assistant
This commit is contained in:
commit
ef1a69156d
4 changed files with 85 additions and 23 deletions
|
@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
|
|||
use async_trait::async_trait;
|
||||
use futures::AsyncReadExt;
|
||||
use gpui::executor::Background;
|
||||
use gpui::serde_json;
|
||||
use gpui::{serde_json, ViewContext};
|
||||
use isahc::http::StatusCode;
|
||||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
|
@ -20,9 +20,11 @@ use std::sync::Arc;
|
|||
use std::time::{Duration, Instant};
|
||||
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||
use util::http::{HttpClient, Request};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::completion::OPENAI_API_URL;
|
||||
|
||||
lazy_static! {
|
||||
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
|
||||
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
||||
}
|
||||
|
||||
|
@ -87,6 +89,7 @@ impl Embedding {
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIEmbeddings {
|
||||
pub api_key: Option<String>,
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
pub executor: Arc<Background>,
|
||||
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||
|
@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings {
|
|||
const OPENAI_INPUT_LIMIT: usize = 8190;
|
||||
|
||||
impl OpenAIEmbeddings {
|
||||
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
|
||||
pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
|
||||
if self.api_key.is_none() {
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
Some(api_key)
|
||||
} else if let Some((_, api_key)) = cx
|
||||
.platform()
|
||||
.read_credentials(OPENAI_API_URL)
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
String::from_utf8(api_key).log_err()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
self.api_key = Some(api_key);
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn new(
|
||||
api_key: Option<String>,
|
||||
client: Arc<dyn HttpClient>,
|
||||
executor: Arc<Background>,
|
||||
) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||
|
||||
OpenAIEmbeddings {
|
||||
api_key,
|
||||
client,
|
||||
executor,
|
||||
rate_limit_count_rx,
|
||||
|
@ -237,8 +265,9 @@ impl OpenAIEmbeddings {
|
|||
#[async_trait]
|
||||
impl EmbeddingProvider for OpenAIEmbeddings {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
OPENAI_API_KEY.as_ref().is_some()
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
50000
|
||||
}
|
||||
|
@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
||||
let api_key = OPENAI_API_KEY
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no api key"))?;
|
||||
let Some(api_key) = self.api_key.clone() else {
|
||||
return Err(anyhow!("no open ai key provided"));
|
||||
};
|
||||
|
||||
let mut request_number = 0;
|
||||
let mut rate_limiting = false;
|
||||
|
@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
while request_number < MAX_RETRIES {
|
||||
response = self
|
||||
.send_request(
|
||||
api_key,
|
||||
&api_key,
|
||||
spans.iter().map(|x| &**x).collect(),
|
||||
request_timeout,
|
||||
)
|
||||
|
|
|
@ -351,33 +351,32 @@ impl View for ProjectSearchView {
|
|||
SemanticIndexStatus::NotAuthenticated => {
|
||||
major_text = Cow::Borrowed("Not Authenticated");
|
||||
show_minor_text = false;
|
||||
Some(
|
||||
"API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables"
|
||||
.to_string(),
|
||||
)
|
||||
Some(vec![
|
||||
"API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables."
|
||||
.to_string(), "If you authenticated using the Assistant Panel, please restart Zed to Authenticate.".to_string()])
|
||||
}
|
||||
SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
|
||||
SemanticIndexStatus::Indexed => Some(vec!["Indexing complete".to_string()]),
|
||||
SemanticIndexStatus::Indexing {
|
||||
remaining_files,
|
||||
rate_limit_expiry,
|
||||
} => {
|
||||
if remaining_files == 0 {
|
||||
Some(format!("Indexing..."))
|
||||
Some(vec![format!("Indexing...")])
|
||||
} else {
|
||||
if let Some(rate_limit_expiry) = rate_limit_expiry {
|
||||
let remaining_seconds =
|
||||
rate_limit_expiry.duration_since(Instant::now());
|
||||
if remaining_seconds > Duration::from_secs(0) {
|
||||
Some(format!(
|
||||
Some(vec![format!(
|
||||
"Remaining files to index (rate limit resets in {}s): {}",
|
||||
remaining_seconds.as_secs(),
|
||||
remaining_files
|
||||
))
|
||||
)])
|
||||
} else {
|
||||
Some(format!("Remaining files to index: {}", remaining_files))
|
||||
Some(vec![format!("Remaining files to index: {}", remaining_files)])
|
||||
}
|
||||
} else {
|
||||
Some(format!("Remaining files to index: {}", remaining_files))
|
||||
Some(vec![format!("Remaining files to index: {}", remaining_files)])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -394,9 +393,11 @@ impl View for ProjectSearchView {
|
|||
} else {
|
||||
match current_mode {
|
||||
SearchMode::Semantic => {
|
||||
let mut minor_text = Vec::new();
|
||||
let mut minor_text: Vec<String> = Vec::new();
|
||||
minor_text.push("".into());
|
||||
minor_text.extend(semantic_status);
|
||||
if let Some(semantic_status) = semantic_status {
|
||||
minor_text.extend(semantic_status);
|
||||
}
|
||||
if show_minor_text {
|
||||
minor_text
|
||||
.push("Simply explain the code you are looking to find.".into());
|
||||
|
|
|
@ -7,7 +7,10 @@ pub mod semantic_index_settings;
|
|||
mod semantic_index_tests;
|
||||
|
||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
|
||||
use ai::{
|
||||
completion::OPENAI_API_URL,
|
||||
embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use db::VectorDatabase;
|
||||
|
@ -55,6 +58,19 @@ pub fn init(
|
|||
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
|
||||
.join("embeddings_db");
|
||||
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
Some(api_key)
|
||||
} else if let Some((_, api_key)) = cx
|
||||
.platform()
|
||||
.read_credentials(OPENAI_API_URL)
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
String::from_utf8(api_key).log_err()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
cx.subscribe_global::<WorkspaceCreated, _>({
|
||||
move |event, cx| {
|
||||
let Some(semantic_index) = SemanticIndex::global(cx) else {
|
||||
|
@ -88,7 +104,7 @@ pub fn init(
|
|||
let semantic_index = SemanticIndex::new(
|
||||
fs,
|
||||
db_file_path,
|
||||
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
|
||||
Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
|
||||
language_registry,
|
||||
cx.clone(),
|
||||
)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use ai::completion::OPENAI_API_URL;
|
||||
use ai::embedding::OpenAIEmbeddings;
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{self, UserStore};
|
||||
|
@ -17,6 +18,7 @@ use std::{cmp, env, fs};
|
|||
use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
|
||||
use util::http::{self};
|
||||
use util::paths::EMBEDDINGS_DIR;
|
||||
use util::ResultExt;
|
||||
use zed::languages;
|
||||
|
||||
#[derive(Deserialize, Clone, Serialize)]
|
||||
|
@ -469,12 +471,26 @@ fn main() {
|
|||
.join("embeddings_db");
|
||||
|
||||
let languages = languages.clone();
|
||||
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
Some(api_key)
|
||||
} else if let Some((_, api_key)) = cx
|
||||
.platform()
|
||||
.read_credentials(OPENAI_API_URL)
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
String::from_utf8(api_key).log_err()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let fs = fs.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let semantic_index = SemanticIndex::new(
|
||||
fs.clone(),
|
||||
db_file_path,
|
||||
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
|
||||
Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
|
||||
languages.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue