diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index c5ba1bfc6a..6a9101e1b1 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -35,14 +35,30 @@ pub enum Model {
     Gpt4,
     #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
     Gpt3_5Turbo,
+    #[serde(alias = "o1-preview", rename = "o1-preview-2024-09-12")]
+    O1Preview,
+    #[serde(alias = "o1-mini", rename = "o1-mini-2024-09-12")]
+    O1Mini,
+    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
+    Claude3_5Sonnet,
 }
 
 impl Model {
+    pub fn uses_streaming(&self) -> bool {
+        match self {
+            Self::Gpt4o | Self::Gpt4 | Self::Gpt3_5Turbo | Self::Claude3_5Sonnet => true,
+            Self::O1Mini | Self::O1Preview => false,
+        }
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "gpt-4o" => Ok(Self::Gpt4o),
             "gpt-4" => Ok(Self::Gpt4),
             "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
+            "o1-preview" => Ok(Self::O1Preview),
+            "o1-mini" => Ok(Self::O1Mini),
+            "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
             _ => Err(anyhow!("Invalid model id: {}", id)),
         }
     }
@@ -52,6 +68,9 @@ impl Model {
             Self::Gpt3_5Turbo => "gpt-3.5-turbo",
             Self::Gpt4 => "gpt-4",
             Self::Gpt4o => "gpt-4o",
+            Self::O1Mini => "o1-mini",
+            Self::O1Preview => "o1-preview",
+            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
         }
     }
 
@@ -60,6 +79,9 @@ impl Model {
             Self::Gpt3_5Turbo => "GPT-3.5",
             Self::Gpt4 => "GPT-4",
             Self::Gpt4o => "GPT-4o",
+            Self::O1Mini => "o1-mini",
+            Self::O1Preview => "o1-preview",
+            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
         }
     }
 
@@ -68,6 +90,9 @@ impl Model {
             Self::Gpt4o => 128000,
             Self::Gpt4 => 8192,
             Self::Gpt3_5Turbo => 16385,
+            Self::O1Mini => 128000,
+            Self::O1Preview => 128000,
+            Self::Claude3_5Sonnet => 200_000,
         }
     }
 }
@@ -87,7 +112,7 @@ impl Request {
         Self {
             intent: true,
             n: 1,
-            stream: true,
+            stream: model.uses_streaming(),
             temperature: 0.1,
             model,
             messages,
@@ -113,7 +138,8 @@ pub struct ResponseEvent {
 pub struct ResponseChoice {
     pub index: usize,
     pub finish_reason: Option<String>,
-    pub delta: ResponseDelta,
+    pub delta: Option<ResponseDelta>,
+    pub message: Option<ResponseDelta>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -333,9 +359,23 @@ async fn stream_completion(
     if let Some(low_speed_timeout) = low_speed_timeout {
         request_builder = request_builder.read_timeout(low_speed_timeout);
     }
+    let is_streaming = request.stream;
+
     let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
     let mut response = client.send(request).await?;
-    if response.status().is_success() {
+
+    if !response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let body_str = std::str::from_utf8(&body)?;
+        return Err(anyhow!(
+            "Failed to connect to API: {} {}",
+            response.status(),
+            body_str
+        ));
+    }
+
+    if is_streaming {
         let reader = BufReader::new(response.into_body());
         Ok(reader
             .lines()
@@ -367,19 +407,9 @@ async fn stream_completion(
     } else {
         let mut body = Vec::new();
         response.body_mut().read_to_end(&mut body).await?;
-        let body_str = std::str::from_utf8(&body)?;
+        let body_str = std::str::from_utf8(&body)?;
+        let response: ResponseEvent = serde_json::from_str(body_str)?;
 
-        match serde_json::from_str::<ResponseEvent>(body_str) {
-            Ok(_) => Err(anyhow!(
-                "Unexpected success response while expecting an error: {}",
-                body_str,
-            )),
-            Err(_) => Err(anyhow!(
-                "Failed to connect to API: {} {}",
-                response.status(),
-                body_str,
-            )),
-        }
+        Ok(futures::stream::once(async move { Ok(response) }).boxed())
     }
 }
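A note for reviewers on the non-streaming path added above: the o1 models reject `stream: true`, so `stream_completion` now parses the whole response body into a single `ResponseEvent` and hands it back wrapped in a one-item stream, letting callers consume both paths through the same `BoxStream` interface. Below is a minimal sketch of that wrapping pattern using the `futures` crate; the `Event` type and `single_event_stream` helper are illustrative stand-ins, not code from this patch:

```rust
use anyhow::Result;
use futures::stream::{self, BoxStream, StreamExt};

#[derive(Debug)]
struct Event {
    content: String,
}

/// Wrap one already-parsed response in a single-item stream so that
/// streaming and non-streaming code paths share a return type.
fn single_event_stream(event: Event) -> BoxStream<'static, Result<Event>> {
    stream::once(async move { Ok(event) }).boxed()
}

fn main() {
    futures::executor::block_on(async {
        let mut events = single_event_stream(Event {
            content: "done".into(),
        });
        // The consumer loop is identical to the streaming case; it just
        // happens to yield exactly one item.
        while let Some(event) = events.next().await {
            println!("{:?}", event);
        }
    });
}
```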
diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs
index 58b486921a..87ab12998e 100644
--- a/crates/language_model/src/provider/copilot_chat.rs
+++ b/crates/language_model/src/provider/copilot_chat.rs
@@ -30,6 +30,7 @@ use crate::{
 };
 use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};
 
+use super::anthropic::count_anthropic_tokens;
 use super::open_ai::count_open_ai_tokens;
 
 const PROVIDER_ID: &str = "copilot_chat";
@@ -179,13 +180,19 @@ impl LanguageModel for CopilotChatLanguageModel {
         request: LanguageModelRequest,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
-        let model = match self.model {
-            CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
-            CopilotChatModel::Gpt4 => open_ai::Model::Four,
-            CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
-        };
-
-        count_open_ai_tokens(request, model, cx)
+        match self.model {
+            CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx),
+            _ => {
+                let model = match self.model {
+                    CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
+                    CopilotChatModel::Gpt4 => open_ai::Model::Four,
+                    CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
+                    CopilotChatModel::O1Preview | CopilotChatModel::O1Mini => open_ai::Model::Four,
+                    CopilotChatModel::Claude3_5Sonnet => unreachable!(),
+                };
+                count_open_ai_tokens(request, model, cx)
+            }
+        }
     }
 
     fn stream_completion(
@@ -209,7 +216,8 @@ impl LanguageModel for CopilotChatLanguageModel {
             }
         }
 
-        let request = self.to_copilot_chat_request(request);
+        let copilot_request = self.to_copilot_chat_request(request);
+        let is_streaming = copilot_request.stream;
         let Ok(low_speed_timeout) = cx.update(|cx| {
             AllLanguageModelSettings::get_global(cx)
                 .copilot_chat
                 .low_speed_timeout
         })
@@ -220,16 +228,31 @@
         let request_limiter = self.request_limiter.clone();
         let future = cx.spawn(|cx| async move {
-            let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
+            let response = CopilotChat::stream_completion(copilot_request, low_speed_timeout, cx);
             request_limiter.stream(async move {
                 let response = response.await?;
                 let stream = response
-                    .filter_map(|response| async move {
+                    .filter_map(move |response| async move {
                         match response {
                             Ok(result) => {
                                 let choice = result.choices.first();
                                 match choice {
-                                    Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
+                                    Some(choice) if !is_streaming => {
+                                        match &choice.message {
+                                            Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
+                                            None => Some(Err(anyhow::anyhow!(
+                                                "The Copilot Chat API returned a response with no message content"
+                                            ))),
+                                        }
+                                    },
+                                    Some(choice) => {
+                                        match &choice.delta {
+                                            Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
+                                            None => Some(Err(anyhow::anyhow!(
+                                                "The Copilot Chat API returned a response with no delta content"
+                                            ))),
+                                        }
+                                    },
                                     None => Some(Err(anyhow::anyhow!(
                                         "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
                                     ))),
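One more note on the consumer side shown above: streamed chunks carry text in `choices[].delta`, while single-shot responses (the o1 models) put the whole message in `choices[].message`, so both fields become `Option`s and the captured `is_streaming` flag selects which one to read. A simplified sketch of that selection, with `Choice` and `Content` as illustrative stand-ins for the real `ResponseChoice` and `ResponseDelta`:

```rust
use anyhow::{anyhow, Result};

struct Content {
    content: Option<String>,
}

struct Choice {
    delta: Option<Content>,   // populated on streamed chunks
    message: Option<Content>, // populated on single-shot responses
}

/// Pick the field that matches the transport mode and surface a clear
/// error when the expected field is missing.
fn extract_text(choice: &Choice, is_streaming: bool) -> Result<String> {
    let part = if is_streaming {
        choice
            .delta
            .as_ref()
            .ok_or_else(|| anyhow!("streamed chunk had no delta"))?
    } else {
        choice
            .message
            .as_ref()
            .ok_or_else(|| anyhow!("response had no message"))?
    };
    // Chunks that only carry metadata (e.g. a final `finish_reason`)
    // yield an empty string rather than an error.
    Ok(part.content.clone().unwrap_or_default())
}

fn main() -> Result<()> {
    let streamed = Choice {
        delta: Some(Content { content: Some("hel".into()) }),
        message: None,
    };
    assert_eq!(extract_text(&streamed, true)?, "hel");
    Ok(())
}
```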