From dab886f4794e8d14cc82dc8e4906e93b78aa8fe8 Mon Sep 17 00:00:00 2001
From: Marshall Bowers
Date: Thu, 29 Feb 2024 13:02:08 -0500
Subject: [PATCH] Stub out support for Azure OpenAI (#8624)

This PR stubs out support for
[Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
within the `OpenAiCompletionProvider`.

It still requires some additional wiring before it is accessible, but the
necessary hooks should now be in place.

Release Notes:

- N/A
---
 .../ai/src/{providers/mod.rs => providers.rs} |  0
 crates/ai/src/providers/open_ai/completion.rs | 58 +++++++++++++++++--
 crates/assistant/src/assistant_panel.rs       |  8 ++-
 3 files changed, 59 insertions(+), 7 deletions(-)
 rename crates/ai/src/{providers/mod.rs => providers.rs} (100%)

diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers.rs
similarity index 100%
rename from crates/ai/src/providers/mod.rs
rename to crates/ai/src/providers.rs
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index f3c7ebbdbc..5cf6658ba2 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -102,8 +102,9 @@ pub struct OpenAiResponseStreamEvent {
     pub usage: Option<OpenAiUsage>,
 }
 
-pub async fn stream_completion(
+async fn stream_completion(
     api_url: String,
+    kind: OpenAiCompletionProviderKind,
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
@@ -117,10 +118,11 @@ pub async fn stream_completion(
 
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
+    let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{api_url}/chat/completions"))
+    let mut response = Request::post(kind.completions_endpoint_url(&api_url))
         .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
+        .header(auth_header_name, auth_header_value)
         .body(json_data)?
         .send_async()
         .await?;
@@ -194,22 +196,65 @@
     }
 }
 
+#[derive(Clone)]
+pub enum OpenAiCompletionProviderKind {
+    OpenAi,
+    AzureOpenAi {
+        deployment_id: String,
+        api_version: String,
+    },
+}
+
+impl OpenAiCompletionProviderKind {
+    /// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
+    fn completions_endpoint_url(&self, api_url: &str) -> String {
+        match self {
+            Self::OpenAi => {
+                // https://platform.openai.com/docs/api-reference/chat/create
+                format!("{api_url}/chat/completions")
+            }
+            Self::AzureOpenAi {
+                deployment_id,
+                api_version,
+            } => {
+                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions
+                format!("{api_url}/openai/deployments/{deployment_id}/completions?api-version={api_version}")
+            }
+        }
+    }
+
+    /// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
+    fn auth_header(&self, api_key: String) -> (&'static str, String) {
+        match self {
+            Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
+            Self::AzureOpenAi { .. } => ("Api-Key", api_key),
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct OpenAiCompletionProvider {
     api_url: String,
+    kind: OpenAiCompletionProviderKind,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
 impl OpenAiCompletionProvider {
-    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
+    pub async fn new(
+        api_url: String,
+        kind: OpenAiCompletionProviderKind,
+        model_name: String,
+        executor: BackgroundExecutor,
+    ) -> Self {
         let model = executor
             .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
             api_url,
+            kind,
             model,
             credential,
             executor,
@@ -297,6 +342,7 @@ impl CompletionProvider for OpenAiCompletionProvider {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
     }
+
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,
@@ -307,7 +353,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
         // At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
         let api_url = self.api_url.clone();
-        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
+        let kind = self.kind.clone();
+        let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response
@@ -322,6 +369,7 @@ impl CompletionProvider for OpenAiCompletionProvider {
         }
         .boxed()
     }
+
     fn box_clone(&self) -> Box<dyn CompletionProvider> {
         Box::new((*self).clone())
     }
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 9a04f016ba..cbe9edf8e5 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -7,11 +7,13 @@ use crate::{
     SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
 };
 use ai::prompts::repository_context::PromptCodeSnippet;
-use ai::providers::open_ai::OPEN_AI_API_URL;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
-    providers::open_ai::{OpenAiCompletionProvider, OpenAiRequest, RequestMessage},
+    providers::open_ai::{
+        OpenAiCompletionProvider, OpenAiCompletionProviderKind, OpenAiRequest, RequestMessage,
+        OPEN_AI_API_URL,
+    },
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -131,6 +133,7 @@ impl AssistantPanel {
             })?;
             let completion_provider = OpenAiCompletionProvider::new(
                 api_url,
+                OpenAiCompletionProviderKind::OpenAi,
                 model_name,
                 cx.background_executor().clone(),
             )
@@ -1533,6 +1536,7 @@ impl Conversation {
                     api_url
                         .clone()
                         .unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
+                    OpenAiCompletionProviderKind::OpenAi,
                     model.full_name().into(),
                     cx.background_executor().clone(),
                 )