From 9b673089dbcafd4e9d69fe8419c4e74e5df20303 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sun, 31 Mar 2024 14:57:57 -0700 Subject: [PATCH] Enable Claude 3 models to be used via the Zed server if "language-models" feature flag is enabled for user (#10015) Release Notes: - N/A --- Cargo.lock | 13 + Cargo.toml | 2 + crates/anthropic/Cargo.toml | 22 ++ crates/anthropic/src/anthropic.rs | 234 ++++++++++++++++++ crates/assistant/src/assistant_panel.rs | 13 +- crates/assistant/src/assistant_settings.rs | 40 +-- .../assistant/src/completion_provider/zed.rs | 20 +- crates/collab/Cargo.toml | 1 + crates/collab/k8s/collab.template.yml | 5 + crates/collab/src/lib.rs | 1 + crates/collab/src/rpc.rs | 121 +++++++++ crates/collab/src/tests/test_server.rs | 1 + 12 files changed, 447 insertions(+), 26 deletions(-) create mode 100644 crates/anthropic/Cargo.toml create mode 100644 crates/anthropic/src/anthropic.rs diff --git a/Cargo.lock b/Cargo.lock index 27fc366f17..8a775d2f85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -213,6 +213,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "anthropic" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.28", + "serde", + "serde_json", + "tokio", + "util", +] + [[package]] name = "anyhow" version = "1.0.75" @@ -2214,6 +2226,7 @@ dependencies = [ name = "collab" version = "0.44.0" dependencies = [ + "anthropic", "anyhow", "async-trait", "async-tungstenite", diff --git a/Cargo.toml b/Cargo.toml index a7d0500795..6f35a6b41a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/activity_indicator", + "crates/anthropic", "crates/assets", "crates/assistant", "crates/audio", @@ -119,6 +120,7 @@ resolver = "2" [workspace.dependencies] activity_indicator = { path = "crates/activity_indicator" } ai = { path = "crates/ai" } +anthropic = { path = "crates/anthropic" } assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } audio = { path = "crates/audio" } diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml new file mode 100644 index 0000000000..ba0284185e --- /dev/null +++ b/crates/anthropic/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "anthropic" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +path = "src/anthropic.rs" + +[dependencies] +anyhow.workspace = true +futures.workspace = true +serde.workspace = true +serde_json.workspace = true +util.workspace = true + +[dev-dependencies] +tokio.workspace = true + +[lints] +workspace = true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs new file mode 100644 index 0000000000..a96a23b166 --- /dev/null +++ b/crates/anthropic/src/anthropic.rs @@ -0,0 +1,234 @@ +use anyhow::{anyhow, Result}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use std::convert::TryFrom; +use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; + +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub enum Model { + #[default] + #[serde(rename = "claude-3-opus-20240229")] + Claude3Opus, + #[serde(rename = "claude-3-sonnet-20240229")] + Claude3Sonnet, + #[serde(rename = "claude-3-haiku-20240307")] + Claude3Haiku, +} + +impl Model { + pub fn from_id(id: &str) -> Result { + if id.starts_with("claude-3-opus") { + Ok(Self::Claude3Opus) + } else if id.starts_with("claude-3-sonnet") { + Ok(Self::Claude3Sonnet) + } else if id.starts_with("claude-3-haiku") { + Ok(Self::Claude3Haiku) + } else { + Err(anyhow!("Invalid model id: {}", id)) + } + } + + pub fn display_name(&self) -> &'static str { + match self { + Self::Claude3Opus => "Claude 3 Opus", + Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3Haiku => "Claude 3 Haiku", + } + } + + pub fn max_token_count(&self) -> usize { + 200_000 + } +} + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, +} + +impl TryFrom for Role { + type Error = anyhow::Error; + + fn try_from(value: String) -> Result { + match value.as_str() { + "user" => Ok(Self::User), + "assistant" => Ok(Self::Assistant), + _ => Err(anyhow!("invalid role '{value}'")), + } + } +} + +impl From for String { + fn from(val: Role) -> Self { + match val { + Role::User => "user".to_owned(), + Role::Assistant => "assistant".to_owned(), + } + } +} + +#[derive(Debug, Serialize)] +pub struct Request { + pub model: Model, + pub messages: Vec, + pub stream: bool, + pub system: String, + pub max_tokens: u32, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponseEvent { + MessageStart { + message: ResponseMessage, + }, + ContentBlockStart { + index: u32, + content_block: ContentBlock, + }, + Ping {}, + ContentBlockDelta { + index: u32, + delta: TextDelta, + }, + ContentBlockStop { + index: u32, + }, + MessageDelta { + delta: ResponseMessage, + usage: Usage, + }, + MessageStop {}, +} + +#[derive(Deserialize, Debug)] +pub struct ResponseMessage { + #[serde(rename = "type")] + pub message_type: Option, + pub id: Option, + pub role: Option, + pub content: Option>, + pub model: Option, + pub stop_reason: Option, + pub stop_sequence: Option, + pub usage: Option, +} + +#[derive(Deserialize, Debug)] +pub struct Usage { + pub input_tokens: Option, + pub output_tokens: Option, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlock { + Text { text: String }, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum TextDelta { + TextDelta { text: String }, +} + +pub async fn stream_completion( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result>> { + let uri = format!("{api_url}/v1/messages"); + let request = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("Anthropic-Beta", "messages-2023-12-15") + .header("X-Api-Key", api_key) + .header("Content-Type", "application/json") + .body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + let line = line.strip_prefix("data: ")?; + match serde_json::from_str(line) { + Ok(response) => Some(Ok(response)), + Err(error) => Some(Err(anyhow!(error))), + } + } + Err(error) => Some(Err(anyhow!(error))), + } + }) + .boxed()) + } else { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + let body_str = std::str::from_utf8(&body)?; + + match serde_json::from_str::(body_str) { + Ok(_) => Err(anyhow!( + "Unexpected success response while expecting an error: {}", + body_str, + )), + Err(_) => Err(anyhow!( + "Failed to connect to API: {} {}", + response.status(), + body_str, + )), + } + } +} + +// #[cfg(test)] +// mod tests { +// use super::*; +// use util::http::IsahcHttpClient; + +// #[tokio::test] +// async fn stream_completion_success() { +// let http_client = IsahcHttpClient::new().unwrap(); + +// let request = Request { +// model: Model::Claude3Opus, +// messages: vec![RequestMessage { +// role: Role::User, +// content: "Ping".to_string(), +// }], +// stream: true, +// system: "Respond to ping with pong".to_string(), +// max_tokens: 4096, +// }; + +// let stream = stream_completion( +// &http_client, +// "https://api.anthropic.com", +// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"), +// request, +// ) +// .await +// .unwrap(); + +// stream +// .for_each(|event| async { +// match event { +// Ok(event) => println!("{:?}", event), +// Err(e) => eprintln!("Error: {:?}", e), +// } +// }) +// .await; +// } +// } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 06447616c3..61604d7ef6 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -768,15 +768,18 @@ impl AssistantPanel { open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo, }), LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model { - ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour, - ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo, - ZedDotDevModel::GptFourTurbo => { + ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4, + ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo, + ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Claude3Opus, + ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet, + ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku, + ZedDotDevModel::Claude3Haiku => { match CompletionProvider::global(cx).default_model() { LanguageModel::ZedDotDev(custom) => custom, - _ => ZedDotDevModel::GptThreePointFiveTurbo, + _ => ZedDotDevModel::Gpt3Point5Turbo, } } - ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo, + ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo, }), }; diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index fa68eaa918..fb7060a932 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -14,10 +14,13 @@ use settings::Settings; #[derive(Clone, Debug, Default, PartialEq)] pub enum ZedDotDevModel { - GptThreePointFiveTurbo, - GptFour, + Gpt3Point5Turbo, + Gpt4, #[default] - GptFourTurbo, + Gpt4Turbo, + Claude3Opus, + Claude3Sonnet, + Claude3Haiku, Custom(String), } @@ -49,9 +52,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel { E: de::Error, { match value { - "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo), - "gpt-4" => Ok(ZedDotDevModel::GptFour), - "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo), + "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo), + "gpt-4" => Ok(ZedDotDevModel::Gpt4), + "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo), _ => Ok(ZedDotDevModel::Custom(value.to_owned())), } } @@ -94,27 +97,34 @@ impl JsonSchema for ZedDotDevModel { impl ZedDotDevModel { pub fn id(&self) -> &str { match self { - Self::GptThreePointFiveTurbo => "gpt-3.5-turbo", - Self::GptFour => "gpt-4", - Self::GptFourTurbo => "gpt-4-turbo-preview", + Self::Gpt3Point5Turbo => "gpt-3.5-turbo", + Self::Gpt4 => "gpt-4", + Self::Gpt4Turbo => "gpt-4-turbo-preview", + Self::Claude3Opus => "claude-3-opus", + Self::Claude3Sonnet => "claude-3-sonnet", + Self::Claude3Haiku => "claude-3-haiku", Self::Custom(id) => id, } } pub fn display_name(&self) -> &str { match self { - Self::GptThreePointFiveTurbo => "gpt-3.5-turbo", - Self::GptFour => "gpt-4", - Self::GptFourTurbo => "gpt-4-turbo", + Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", + Self::Gpt4 => "GPT 4", + Self::Gpt4Turbo => "GPT 4 Turbo", + Self::Claude3Opus => "Claude 3 Opus", + Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3Haiku => "Claude 3 Haiku", Self::Custom(id) => id.as_str(), } } pub fn max_token_count(&self) -> usize { match self { - Self::GptThreePointFiveTurbo => 2048, - Self::GptFour => 4096, - Self::GptFourTurbo => 128000, + Self::Gpt3Point5Turbo => 2048, + Self::Gpt4 => 4096, + Self::Gpt4Turbo => 128000, + Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000, Self::Custom(_) => 4096, // TODO: Make this configurable } } diff --git a/crates/assistant/src/completion_provider/zed.rs b/crates/assistant/src/completion_provider/zed.rs index 0febb05278..1ec852da19 100644 --- a/crates/assistant/src/completion_provider/zed.rs +++ b/crates/assistant/src/completion_provider/zed.rs @@ -1,5 +1,5 @@ use crate::{ - assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, + assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelRequest, }; use anyhow::{anyhow, Result}; @@ -78,13 +78,21 @@ impl ZedDotDevCompletionProvider { cx: &AppContext, ) -> BoxFuture<'static, Result> { match request.model { - crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(), - crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour) - | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo) - | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => { + LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(), + LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4) + | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo) + | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => { count_open_ai_tokens(request, cx.background_executor()) } - crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => { + LanguageModel::ZedDotDev( + ZedDotDevModel::Claude3Opus + | ZedDotDevModel::Claude3Sonnet + | ZedDotDevModel::Claude3Haiku, + ) => { + // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation. + count_open_ai_tokens(request, cx.background_executor()) + } + LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => { let request = self.client.request(proto::CountTokensWithLanguageModel { model, messages: request diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index cb0f1b1f8a..7fbc4bfd03 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -18,6 +18,7 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"] test-support = ["sqlite"] [dependencies] +anthropic.workspace = true anyhow.workspace = true async-tungstenite = "0.16" aws-config = { version = "1.1.5" } diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 8a8f55fd1a..da2d9b1b16 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -130,6 +130,11 @@ spec: secretKeyRef: name: openai key: api_key + - name: ANTHROPIC_API_KEY + valueFrom: + secretKeyRef: + name: anthropic + key: api_key - name: BLOB_STORE_ACCESS_KEY valueFrom: secretKeyRef: diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 268d89dd10..925d192fc0 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -134,6 +134,7 @@ pub struct Config { pub zed_environment: Arc, pub openai_api_key: Option>, pub google_ai_api_key: Option>, + pub anthropic_api_key: Option>, pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index af2cdb75ac..7251f095cd 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -419,6 +419,7 @@ impl Server { session, app_state.config.openai_api_key.clone(), app_state.config.google_ai_api_key.clone(), + app_state.config.anthropic_api_key.clone(), ) } }) @@ -3506,6 +3507,7 @@ async fn complete_with_language_model( session: Session, open_ai_api_key: Option>, google_ai_api_key: Option>, + anthropic_api_key: Option>, ) -> Result<()> { let Some(session) = session.for_user() else { return Err(anyhow!("user not found"))?; @@ -3524,6 +3526,10 @@ async fn complete_with_language_model( let api_key = google_ai_api_key .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; complete_with_google_ai(request, response, session, api_key).await?; + } else if request.model.starts_with("claude") { + let api_key = anthropic_api_key + .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?; + complete_with_anthropic(request, response, session, api_key).await?; } Ok(()) @@ -3621,6 +3627,121 @@ async fn complete_with_google_ai( Ok(()) } +async fn complete_with_anthropic( + request: proto::CompleteWithLanguageModel, + response: StreamingResponse, + session: UserSession, + api_key: Arc, +) -> Result<()> { + let model = anthropic::Model::from_id(&request.model)?; + + let mut system_message = String::new(); + let messages = request + .messages + .into_iter() + .filter_map(|message| match message.role() { + LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage { + role: anthropic::Role::User, + content: message.content, + }), + LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage { + role: anthropic::Role::Assistant, + content: message.content, + }), + // Anthropic's API breaks system instructions out as a separate field rather + // than having a system message role. + LanguageModelRole::LanguageModelSystem => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.content); + + None + } + }) + .collect(); + + let mut stream = anthropic::stream_completion( + &session.http_client, + "https://api.anthropic.com", + &api_key, + anthropic::Request { + model, + messages, + stream: true, + system: system_message, + max_tokens: 4092, + }, + ) + .await?; + + let mut current_role = proto::LanguageModelRole::LanguageModelAssistant; + + while let Some(event) = stream.next().await { + let event = event?; + + match event { + anthropic::ResponseEvent::MessageStart { message } => { + if let Some(role) = message.role { + if role == "assistant" { + current_role = proto::LanguageModelRole::LanguageModelAssistant; + } else if role == "user" { + current_role = proto::LanguageModelRole::LanguageModelUser; + } + } + } + anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => { + match content_block { + anthropic::ContentBlock::Text { text } => { + if !text.is_empty() { + response.send(proto::LanguageModelResponse { + choices: vec![proto::LanguageModelChoiceDelta { + index: 0, + delta: Some(proto::LanguageModelResponseMessage { + role: Some(current_role as i32), + content: Some(text), + }), + finish_reason: None, + }], + })?; + } + } + } + } + anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta { + anthropic::TextDelta::TextDelta { text } => { + response.send(proto::LanguageModelResponse { + choices: vec![proto::LanguageModelChoiceDelta { + index: 0, + delta: Some(proto::LanguageModelResponseMessage { + role: Some(current_role as i32), + content: Some(text), + }), + finish_reason: None, + }], + })?; + } + }, + anthropic::ResponseEvent::MessageDelta { delta, .. } => { + if let Some(stop_reason) = delta.stop_reason { + response.send(proto::LanguageModelResponse { + choices: vec![proto::LanguageModelChoiceDelta { + index: 0, + delta: None, + finish_reason: Some(stop_reason), + }], + })?; + } + } + anthropic::ResponseEvent::ContentBlockStop { .. } => {} + anthropic::ResponseEvent::MessageStop {} => {} + anthropic::ResponseEvent::Ping {} => {} + } + } + + Ok(()) +} + struct CountTokensWithLanguageModelRateLimit; impl RateLimit for CountTokensWithLanguageModelRateLimit { diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 64be8cb09b..f1364cdc66 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -512,6 +512,7 @@ impl TestServer { blob_store_bucket: None, openai_api_key: None, google_ai_api_key: None, + anthropic_api_key: None, clickhouse_url: None, clickhouse_user: None, clickhouse_password: None,