From 3447a9478c62476728f1e0131d708699dad2bcd1 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 26 Oct 2023 11:18:16 +0200 Subject: [PATCH] updated authentication for embedding provider --- crates/ai/Cargo.toml | 3 + crates/ai/src/ai.rs | 3 + crates/ai/src/auth.rs | 20 +++ crates/ai/src/embedding.rs | 8 +- crates/ai/src/prompts/base.rs | 41 +----- crates/ai/src/providers/dummy.rs | 85 ------------ crates/ai/src/providers/mod.rs | 1 - crates/ai/src/providers/open_ai/auth.rs | 33 +++++ crates/ai/src/providers/open_ai/embedding.rs | 46 ++----- crates/ai/src/providers/open_ai/mod.rs | 1 + crates/ai/src/test.rs | 123 ++++++++++++++++++ crates/assistant/src/codegen.rs | 14 +- crates/semantic_index/Cargo.toml | 1 + crates/semantic_index/src/embedding_queue.rs | 16 +-- crates/semantic_index/src/semantic_index.rs | 52 +++++--- .../src/semantic_index_tests.rs | 101 +++----------- 16 files changed, 277 insertions(+), 271 deletions(-) create mode 100644 crates/ai/src/auth.rs delete mode 100644 crates/ai/src/providers/dummy.rs create mode 100644 crates/ai/src/providers/open_ai/auth.rs create mode 100644 crates/ai/src/test.rs diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index b24c4e5ece..fb49a4b515 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -8,6 +8,9 @@ publish = false path = "src/ai.rs" doctest = false +[features] +test-support = [] + [dependencies] gpui = { path = "../gpui" } util = { path = "../util" } diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index a3ae2fcf7f..dda22d2a1d 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,5 +1,8 @@ +pub mod auth; pub mod completion; pub mod embedding; pub mod models; pub mod prompts; pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs new file mode 100644 index 0000000000..a3ce8aece1 --- /dev/null +++ b/crates/ai/src/auth.rs @@ -0,0 +1,20 @@ +use gpui::AppContext; + +#[derive(Clone)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +pub trait CredentialProvider: Send + Sync { + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; +} + +#[derive(Clone)] +pub struct NullCredentialProvider; +impl CredentialProvider for NullCredentialProvider { + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } +} diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 8cfc901525..50f04232ab 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -7,6 +7,7 @@ use ordered_float::OrderedFloat; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; +use crate::auth::{CredentialProvider, ProviderCredential}; use crate::models::LanguageModel; #[derive(Debug, PartialEq, Clone)] @@ -71,11 +72,14 @@ impl Embedding { #[async_trait] pub trait EmbeddingProvider: Sync + Send { fn base_model(&self) -> Box; - fn retrieve_credentials(&self, cx: &AppContext) -> Option; + fn credential_provider(&self) -> Box; + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + self.credential_provider().retrieve_credentials(cx) + } async fn embed_batch( &self, spans: Vec, - api_key: Option, + credential: ProviderCredential, ) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn rate_limit_expiration(&self) -> Option; diff --git a/crates/ai/src/prompts/base.rs b/crates/ai/src/prompts/base.rs index f0ff597e63..a2106c7410 100644 --- a/crates/ai/src/prompts/base.rs +++ b/crates/ai/src/prompts/base.rs @@ -126,6 +126,7 @@ impl PromptChain { #[cfg(test)] pub(crate) mod tests { use crate::models::TruncationDirection; + use crate::test::FakeLanguageModel; use super::*; @@ -181,39 +182,7 @@ pub(crate) mod tests { } } - #[derive(Clone)] - struct DummyLanguageModel { - capacity: usize, - } - - impl LanguageModel for DummyLanguageModel { - fn name(&self) -> String { - "dummy".to_string() - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - anyhow::Ok(content.chars().collect::>().len()) - } - fn truncate( - &self, - content: &str, - length: usize, - direction: TruncationDirection, - ) -> anyhow::Result { - anyhow::Ok(match direction { - TruncationDirection::End => content.chars().collect::>()[..length] - .into_iter() - .collect::(), - TruncationDirection::Start => content.chars().collect::>()[length..] - .into_iter() - .collect::(), - }) - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(self.capacity) - } - } - - let model: Arc = Arc::new(DummyLanguageModel { capacity: 100 }); + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -249,7 +218,7 @@ pub(crate) mod tests { // Testing with Truncation Off // Should ignore capacity and return all prompts - let model: Arc = Arc::new(DummyLanguageModel { capacity: 20 }); + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -286,7 +255,7 @@ pub(crate) mod tests { // Testing with Truncation Off // Should ignore capacity and return all prompts let capacity = 20; - let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let model: Arc = Arc::new(FakeLanguageModel { capacity }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -322,7 +291,7 @@ pub(crate) mod tests { // Change Ordering of Prompts Based on Priority let capacity = 120; let reserved_tokens = 10; - let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let model: Arc = Arc::new(FakeLanguageModel { capacity }); let args = PromptArguments { model: model.clone(), language_name: None, diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs deleted file mode 100644 index 2ee26488bd..0000000000 --- a/crates/ai/src/providers/dummy.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::time::Instant; - -use crate::{ - completion::CompletionRequest, - embedding::{Embedding, EmbeddingProvider}, - models::{LanguageModel, TruncationDirection}, -}; -use async_trait::async_trait; -use gpui::AppContext; -use serde::Serialize; - -pub struct DummyLanguageModel {} - -impl LanguageModel for DummyLanguageModel { - fn name(&self) -> String { - "dummy".to_string() - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(1000) - } - fn truncate( - &self, - content: &str, - length: usize, - direction: crate::models::TruncationDirection, - ) -> anyhow::Result { - if content.len() < length { - return anyhow::Ok(content.to_string()); - } - - let truncated = match direction { - TruncationDirection::End => content.chars().collect::>()[..length] - .iter() - .collect::(), - TruncationDirection::Start => content.chars().collect::>()[..length] - .iter() - .collect::(), - }; - - anyhow::Ok(truncated) - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - anyhow::Ok(content.chars().collect::>().len()) - } -} - -#[derive(Serialize)] -pub struct DummyCompletionRequest { - pub name: String, -} - -impl CompletionRequest for DummyCompletionRequest { - fn data(&self) -> serde_json::Result { - serde_json::to_string(self) - } -} - -pub struct DummyEmbeddingProvider {} - -#[async_trait] -impl EmbeddingProvider for DummyEmbeddingProvider { - fn retrieve_credentials(&self, _cx: &AppContext) -> Option { - Some("Dummy Credentials".to_string()) - } - fn base_model(&self) -> Box { - Box::new(DummyLanguageModel {}) - } - fn rate_limit_expiration(&self) -> Option { - None - } - async fn embed_batch( - &self, - spans: Vec, - api_key: Option, - ) -> anyhow::Result> { - // 1024 is the OpenAI Embeddings size for ada models. - // the model we will likely be starting with. - let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); - return Ok(vec![dummy_vec; spans.len()]); - } - - fn max_tokens_per_batch(&self) -> usize { - 8190 - } -} diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs index 7a7092baf3..acd0f9d910 100644 --- a/crates/ai/src/providers/mod.rs +++ b/crates/ai/src/providers/mod.rs @@ -1,2 +1 @@ -pub mod dummy; pub mod open_ai; diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs new file mode 100644 index 0000000000..c817ffea00 --- /dev/null +++ b/crates/ai/src/providers/open_ai/auth.rs @@ -0,0 +1,33 @@ +use std::env; + +use gpui::AppContext; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::providers::open_ai::OPENAI_API_URL; + +#[derive(Clone)] +pub struct OpenAICredentialProvider {} + +impl CredentialProvider for OpenAICredentialProvider { + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { + Some(api_key) + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + String::from_utf8(api_key).log_err() + } else { + None + }; + + if let Some(api_key) = api_key { + ProviderCredential::Credentials { api_key } + } else { + ProviderCredential::NoCredentials + } + } +} diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 805a906dda..1385b32b4d 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::executor::Background; -use gpui::{serde_json, AppContext}; +use gpui::serde_json; use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; @@ -17,13 +17,13 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; -use util::ResultExt; +use crate::auth::{CredentialProvider, ProviderCredential}; use crate::embedding::{Embedding, EmbeddingProvider}; use crate::models::LanguageModel; use crate::providers::open_ai::OpenAILanguageModel; -use super::OPENAI_API_URL; +use crate::providers::open_ai::auth::OpenAICredentialProvider; lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); @@ -33,6 +33,7 @@ lazy_static! { #[derive(Clone)] pub struct OpenAIEmbeddingProvider { model: OpenAILanguageModel, + credential_provider: OpenAICredentialProvider, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -73,6 +74,7 @@ impl OpenAIEmbeddingProvider { OpenAIEmbeddingProvider { model, + credential_provider: OpenAICredentialProvider {}, client, executor, rate_limit_count_rx, @@ -138,25 +140,17 @@ impl OpenAIEmbeddingProvider { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddingProvider { - fn retrieve_credentials(&self, cx: &AppContext) -> Option { - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - api_key - } fn base_model(&self) -> Box { let model: Box = Box::new(self.model.clone()); model } + + fn credential_provider(&self) -> Box { + let credential_provider: Box = + Box::new(self.credential_provider.clone()); + credential_provider + } + fn max_tokens_per_batch(&self) -> usize { 50000 } @@ -164,25 +158,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { fn rate_limit_expiration(&self) -> Option { *self.rate_limit_count_rx.borrow() } - // fn truncate(&self, span: &str) -> (String, usize) { - // let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - // let output = if tokens.len() > OPENAI_INPUT_LIMIT { - // tokens.truncate(OPENAI_INPUT_LIMIT); - // OPENAI_BPE_TOKENIZER - // .decode(tokens.clone()) - // .ok() - // .unwrap_or_else(|| span.to_string()) - // } else { - // span.to_string() - // }; - - // (output, tokens.len()) - // } async fn embed_batch( &self, spans: Vec, - api_key: Option, + _credential: ProviderCredential, ) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs index 67cb2b5315..49e29fbc8c 100644 --- a/crates/ai/src/providers/open_ai/mod.rs +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -1,3 +1,4 @@ +pub mod auth; pub mod completion; pub mod embedding; pub mod model; diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs new file mode 100644 index 0000000000..d8805bad1a --- /dev/null +++ b/crates/ai/src/test.rs @@ -0,0 +1,123 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; + +use crate::{ + auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] + .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, + pub credential_provider: NullCredentialProvider, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + credential_provider: self.credential_provider.clone(), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + credential_provider: NullCredentialProvider {}, + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn credential_provider(&self) -> Box { + let credential_provider: Box = + Box::new(self.credential_provider.clone()); + credential_provider + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch( + &self, + spans: Vec, + _credential: ProviderCredential, + ) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index e535eca144..e71b1ae2cb 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,6 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::providers::dummy::DummyCompletionRequest; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -345,9 +344,21 @@ mod tests { use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; use parking_lot::Mutex; use rand::prelude::*; + use serde::Serialize; use settings::SettingsStore; use smol::future::FutureExt; + #[derive(Serialize)] + pub struct DummyCompletionRequest { + pub name: String, + } + + impl CompletionRequest for DummyCompletionRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } + } + #[gpui::test(iterations = 10)] async fn test_transform_autoindent( cx: &mut TestAppContext, @@ -381,6 +392,7 @@ mod tests { cx, ) }); + let request = Box::new(DummyCompletionRequest { name: "test".to_string(), }); diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 1febb2af78..875440ef3f 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -42,6 +42,7 @@ sha1 = "0.10.5" ndarray = { version = "0.15.0" } [dev-dependencies] +ai = { path = "../ai", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } language = { path = "../language", features = ["test-support"] } diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index d57d5c7bbe..9ca6d8a0d9 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,5 +1,5 @@ use crate::{parsing::Span, JobHandle}; -use ai::embedding::EmbeddingProvider; +use ai::{auth::ProviderCredential, embedding::EmbeddingProvider}; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; @@ -41,7 +41,7 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - api_key: Option, + provider_credential: ProviderCredential, } #[derive(Clone)] @@ -54,7 +54,7 @@ impl EmbeddingQueue { pub fn new( embedding_provider: Arc, executor: Arc, - api_key: Option, + provider_credential: ProviderCredential, ) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { @@ -64,12 +64,12 @@ impl EmbeddingQueue { pending_batch_token_count: 0, finished_files_tx, finished_files_rx, - api_key, + provider_credential, } } - pub fn set_api_key(&mut self, api_key: Option) { - self.api_key = api_key + pub fn set_credential(&mut self, credential: ProviderCredential) { + self.provider_credential = credential } pub fn push(&mut self, file: FileToEmbed) { @@ -118,7 +118,7 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); self.executor .spawn(async move { @@ -143,7 +143,7 @@ impl EmbeddingQueue { return; }; - match embedding_provider.embed_batch(spans, api_key).await { + match embedding_provider.embed_batch(spans, credential).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6863918d5d..5be3d6ccf5 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -7,6 +7,7 @@ pub mod semantic_index_settings; mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; +use ai::auth::ProviderCredential; use ai::embedding::{Embedding, EmbeddingProvider}; use ai::providers::open_ai::OpenAIEmbeddingProvider; use anyhow::{anyhow, Result}; @@ -124,7 +125,7 @@ pub struct SemanticIndex { _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, - api_key: Option, + provider_credential: ProviderCredential, embedding_queue: Arc>, } @@ -279,18 +280,27 @@ impl SemanticIndex { } } - pub fn authenticate(&mut self, cx: &AppContext) { - if self.api_key.is_none() { - self.api_key = self.embedding_provider.retrieve_credentials(cx); - - self.embedding_queue - .lock() - .set_api_key(self.api_key.clone()); + pub fn authenticate(&mut self, cx: &AppContext) -> bool { + let credential = self.provider_credential.clone(); + match credential { + ProviderCredential::NoCredentials => { + let credential = self.embedding_provider.retrieve_credentials(cx); + self.provider_credential = credential; + } + _ => {} } + + self.embedding_queue.lock().set_credential(credential); + + self.is_authenticated() } pub fn is_authenticated(&self) -> bool { - self.api_key.is_some() + let credential = &self.provider_credential; + match credential { + &ProviderCredential::Credentials { .. } => true, + _ => false, + } } pub fn enabled(cx: &AppContext) -> bool { @@ -340,7 +350,7 @@ impl SemanticIndex { Ok(cx.add_model(|cx| { let t0 = Instant::now(); let embedding_queue = - EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None); + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials); let _embedding_task = cx.background().spawn({ let embedded_files = embedding_queue.finished_files(); let db = db.clone(); @@ -405,7 +415,7 @@ impl SemanticIndex { _embedding_task, _parsing_files_tasks, projects: Default::default(), - api_key: None, + provider_credential: ProviderCredential::NoCredentials, embedding_queue } })) @@ -721,13 +731,14 @@ impl SemanticIndex { let index = self.index_project(project.clone(), cx); let embedding_provider = self.embedding_provider.clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); cx.spawn(|this, mut cx| async move { index.await?; let t0 = Instant::now(); + let query = embedding_provider - .embed_batch(vec![query], api_key) + .embed_batch(vec![query], credential) .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; @@ -945,7 +956,7 @@ impl SemanticIndex { let fs = self.fs.clone(); let db_path = self.db.path().clone(); let background = cx.background().clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); cx.background().spawn(async move { let db = VectorDatabase::new(fs, db_path.clone(), background).await?; let mut results = Vec::::new(); @@ -964,7 +975,7 @@ impl SemanticIndex { &mut spans, embedding_provider.as_ref(), &db, - api_key.clone(), + credential.clone(), ) .await .log_err() @@ -1008,9 +1019,8 @@ impl SemanticIndex { project: ModelHandle, cx: &mut ModelContext, ) -> Task> { - if self.api_key.is_none() { - self.authenticate(cx); - if self.api_key.is_none() { + if !self.is_authenticated() { + if !self.authenticate(cx) { return Task::ready(Err(anyhow!("user is not authenticated"))); } } @@ -1193,7 +1203,7 @@ impl SemanticIndex { spans: &mut [Span], embedding_provider: &dyn EmbeddingProvider, db: &VectorDatabase, - api_key: Option, + credential: ProviderCredential, ) -> Result<()> { let mut batch = Vec::new(); let mut batch_tokens = 0; @@ -1216,7 +1226,7 @@ impl SemanticIndex { if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), api_key.clone()) + .embed_batch(mem::take(&mut batch), credential.clone()) .await?; embeddings.extend(batch_embeddings); batch_tokens = 0; @@ -1228,7 +1238,7 @@ impl SemanticIndex { if !batch.is_empty() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), api_key) + .embed_batch(mem::take(&mut batch), credential) .await?; embeddings.extend(batch_embeddings); diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 1c117c9ea2..7d5a4e22e8 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -4,14 +4,9 @@ use crate::{ semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; -use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel}; -use ai::{ - embedding::{Embedding, EmbeddingProvider}, - models::LanguageModel, -}; -use anyhow::Result; -use async_trait::async_trait; -use gpui::{executor::Deterministic, AppContext, Task, TestAppContext}; +use ai::test::FakeEmbeddingProvider; + +use gpui::{executor::Deterministic, Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; use parking_lot::Mutex; use pretty_assertions::assert_eq; @@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs use rand::{rngs::StdRng, Rng}; use serde_json::json; use settings::SettingsStore; -use std::{ - path::Path, - sync::{ - atomic::{self, AtomicUsize}, - Arc, - }, - time::{Instant, SystemTime}, -}; +use std::{path::Path, sync::Arc, time::SystemTime}; use unindent::Unindent; use util::RandomCharIter; @@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None); + let mut queue = EmbeddingQueue::new( + embedding_provider.clone(), + cx.background(), + ai::auth::ProviderCredential::NoCredentials, + ); for file in &files { queue.push(file.clone()); } @@ -284,7 +276,7 @@ fn assert_search_results( #[gpui::test] async fn test_code_context_retrieval_rust() { let language = rust_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() { #[gpui::test] async fn test_code_context_retrieval_json() { let language = json_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -470,7 +462,7 @@ fn assert_documents_eq( #[gpui::test] async fn test_code_context_retrieval_javascript() { let language = js_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() { #[gpui::test] async fn test_code_context_retrieval_lua() { let language = lua_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() { #[gpui::test] async fn test_code_context_retrieval_elixir() { let language = elixir_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() { #[gpui::test] async fn test_code_context_retrieval_cpp() { let language = cpp_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -913,7 +905,7 @@ async fn test_code_context_retrieval_cpp() { #[gpui::test] async fn test_code_context_retrieval_ruby() { let language = ruby_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() { #[gpui::test] async fn test_code_context_retrieval_php() { let language = php_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() { ); } -#[derive(Default)] -struct FakeEmbeddingProvider { - embedding_count: AtomicUsize, -} - -impl FakeEmbeddingProvider { - fn embedding_count(&self) -> usize { - self.embedding_count.load(atomic::Ordering::SeqCst) - } - - fn embed_sync(&self, span: &str) -> Embedding { - let mut result = vec![1.0; 26]; - for letter in span.chars() { - let letter = letter.to_ascii_lowercase(); - if letter as u32 >= 'a' as u32 { - let ix = (letter as u32) - ('a' as u32); - if ix < 26 { - result[ix as usize] += 1.0; - } - } - } - - let norm = result.iter().map(|x| x * x).sum::().sqrt(); - for x in &mut result { - *x /= norm; - } - - result.into() - } -} - -#[async_trait] -impl EmbeddingProvider for FakeEmbeddingProvider { - fn base_model(&self) -> Box { - Box::new(DummyLanguageModel {}) - } - fn retrieve_credentials(&self, _cx: &AppContext) -> Option { - Some("Fake Credentials".to_string()) - } - fn max_tokens_per_batch(&self) -> usize { - 1000 - } - - fn rate_limit_expiration(&self) -> Option { - None - } - - async fn embed_batch( - &self, - spans: Vec, - _api_key: Option, - ) -> Result> { - self.embedding_count - .fetch_add(spans.len(), atomic::Ordering::SeqCst); - - anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) - } -} - fn js_lang() -> Arc { Arc::new( Language::new(