From 7e6cccfa3d2dfd6ca76f78c06b4aa0f83ea36c54 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Mon, 22 May 2023 20:28:22 -0600 Subject: [PATCH] WIP: Stream in completions Drop dependency on tokio introduced by async-openai and do it ourselves. The approach I'm taking of replacing instead of appending is causing issues. Need to just append. --- Cargo.lock | 204 +------------------------------ Cargo.toml | 1 + crates/ai/Cargo.toml | 7 +- crates/ai/src/ai.rs | 223 +++++++++++++++++++++++++++++----- crates/auto_update/Cargo.toml | 2 +- crates/feedback/Cargo.toml | 2 +- crates/gpui/src/executor.rs | 2 +- crates/util/Cargo.toml | 2 +- crates/zed/Cargo.toml | 2 +- 9 files changed, 209 insertions(+), 236 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 59f933f4b9..5c0570f912 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,11 +100,16 @@ name = "ai" version = "0.1.0" dependencies = [ "anyhow", - "async-openai", + "async-stream", "editor", + "futures 0.3.28", "gpui", + "isahc", "pulldown-cmark", + "serde", + "serde_json", "unindent", + "util", ] [[package]] @@ -354,28 +359,6 @@ dependencies = [ "futures-lite", ] -[[package]] -name = "async-openai" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d5e93aca1b2f0ca772c76cadd43e965809df87ef98e25e47244c7f006c85d2" -dependencies = [ - "backoff", - "base64 0.21.0", - "derive_builder", - "futures 0.3.28", - "rand 0.8.5", - "reqwest", - "reqwest-eventsource", - "serde", - "serde_json", - "thiserror", - "tokio", - "tokio-stream", - "tokio-util 0.7.8", - "tracing", -] - [[package]] name = "async-pipe" version = "0.1.3" @@ -676,20 +659,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "backoff" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" -dependencies = [ - "futures-core", - "getrandom 0.2.9", - "instant", - "pin-project-lite 0.2.9", - "rand 0.8.5", - "tokio", -] - [[package]] name = "backtrace" version = "0.3.67" @@ -1849,41 +1818,6 @@ dependencies = [ "syn 2.0.15", ] -[[package]] -name = "darling" -version = "0.14.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.14.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim 0.10.0", - "syn 1.0.109", -] - -[[package]] -name = "darling_macro" -version = "0.14.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" -dependencies = [ - "darling_core", - "quote", - "syn 1.0.109", -] - [[package]] name = "dashmap" version = "5.4.0" @@ -1938,37 +1872,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "derive_builder" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "derive_builder_macro" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" -dependencies = [ - "derive_builder_core", - "syn 1.0.109", -] - [[package]] name = "dhat" version = "0.3.2" @@ -2304,17 +2207,6 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" -[[package]] -name = "eventsource-stream" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" -dependencies = [ - "futures-core", - "nom", - "pin-project-lite 0.2.9", -] - [[package]] name = "fallible-iterator" version = "0.2.0" @@ -2711,12 +2603,6 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" -[[package]] -name = "futures-timer" -version = "3.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" - [[package]] name = "futures-util" version = "0.3.28" @@ -3200,19 +3086,6 @@ dependencies = [ "want", ] -[[package]] -name = "hyper-rustls" -version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c" -dependencies = [ - "http", - "hyper", - "rustls 0.20.8", - "tokio", - "tokio-rustls", -] - [[package]] name = "hyper-timeout" version = "0.4.1" @@ -3262,12 +3135,6 @@ dependencies = [ "cxx-build", ] -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "0.3.0" @@ -4062,16 +3929,6 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" -[[package]] -name = "mime_guess" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" -dependencies = [ - "mime", - "unicase", -] - [[package]] name = "minimal-lexical" version = "0.2.1" @@ -5537,52 +5394,28 @@ dependencies = [ "http", "http-body", "hyper", - "hyper-rustls", "hyper-tls", "ipnet", "js-sys", "log", "mime", - "mime_guess", "native-tls", "once_cell", "percent-encoding", "pin-project-lite 0.2.9", - "rustls 0.20.8", - "rustls-native-certs", - "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "tokio", "tokio-native-tls", - "tokio-rustls", - "tokio-util 0.7.8", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", "web-sys", "winreg", ] -[[package]] -name = "reqwest-eventsource" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51" -dependencies = [ - "eventsource-stream", - "futures-core", - "futures-timer", - "mime", - "nom", - "pin-project-lite 0.2.9", - "reqwest", - "thiserror", -] - [[package]] name = "resvg" version = "0.14.1" @@ -5870,18 +5703,6 @@ dependencies = [ "webpki 0.22.0", ] -[[package]] -name = "rustls-native-certs" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" -dependencies = [ - "openssl-probe", - "rustls-pemfile", - "schannel", - "security-framework", -] - [[package]] name = "rustls-pemfile" version = "1.0.2" @@ -8245,19 +8066,6 @@ dependencies = [ "leb128", ] -[[package]] -name = "wasm-streams" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "wasmparser" version = "0.85.0" diff --git a/Cargo.toml b/Cargo.toml index 77252802d5..d8bf005b77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ ctor = { version = "0.1" } env_logger = { version = "0.9" } futures = { version = "0.3" } glob = { version = "0.3.1" } +isahc = "1.7.2" lazy_static = { version = "1.4.0" } log = { version = "0.4.16", features = ["kv_unstable_serde"] } ordered-float = { version = "2.1.1" } diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 30dc5ee5a2..0953330a69 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -11,11 +11,16 @@ doctest = false [dependencies] editor = { path = "../editor" } gpui = { path = "../gpui" } +util = { path = "../util" } +serde.workspace = true +serde_json.workspace = true anyhow.workspace = true -async-openai = "0.10.3" pulldown-cmark = "0.9.2" +futures.workspace = true +isahc.workspace = true unindent.workspace = true +async-stream = "0.3.5" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 0ae960e281..b0bbd15d59 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,11 +1,87 @@ -use anyhow::Result; -use async_openai::types::{ChatCompletionRequestMessage, CreateChatCompletionRequest, Role}; +use std::io; +use std::rc::Rc; + +use anyhow::{anyhow, Result}; use editor::Editor; +use futures::AsyncBufReadExt; +use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt}; +use gpui::executor::Foreground; use gpui::{actions, AppContext, Task, ViewContext}; +use isahc::prelude::*; +use isahc::{http::StatusCode, Request}; use pulldown_cmark::{Event, HeadingLevel, Parser, Tag}; +use serde::{Deserialize, Serialize}; +use util::ResultExt; actions!(ai, [Assist]); +// Data types for chat completion requests +#[derive(Serialize)] +struct OpenAIRequest { + model: String, + messages: Vec, + stream: bool, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +struct RequestMessage { + role: Role, + content: String, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +struct ResponseMessage { + role: Option, + content: Option, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +enum Role { + User, + Assistant, + System, +} + +#[derive(Deserialize, Debug)] +struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +#[derive(Deserialize, Debug)] +struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +struct OpenAIUsage { + prompt_tokens: u64, + completion_tokens: u64, + total_tokens: u64, +} + +#[derive(Deserialize, Debug)] +struct OpenAIChoice { + text: String, + index: u32, + logprobs: Option, + finish_reason: Option, +} + pub fn init(cx: &mut AppContext) { cx.add_async_action(assist) } @@ -15,26 +91,58 @@ fn assist( _: &Assist, cx: &mut ViewContext, ) -> Option>> { + let api_key = std::env::var("OPENAI_API_KEY").log_err()?; + let markdown = editor.text(cx); - parse_dialog(&markdown); - None + let prompt = parse_dialog(&markdown); + let response = stream_completion(api_key, prompt, cx.foreground().clone()); + + let range = editor.buffer().update(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + let chars = snapshot.reversed_chars_at(snapshot.len()); + let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count(); + let suffix = "\n".repeat(2 - trailing_newlines); + let end = snapshot.len(); + buffer.edit([(end..end, suffix.clone())], None, cx); + let snapshot = buffer.snapshot(cx); + let start = snapshot.anchor_before(snapshot.len()); + let end = snapshot.anchor_after(snapshot.len()); + start..end + }); + let buffer = editor.buffer().clone(); + + Some(cx.spawn(|_, mut cx| async move { + let mut stream = response.await?; + let mut message = String::new(); + while let Some(stream_event) = stream.next().await { + if let Some(choice) = stream_event?.choices.first() { + if let Some(content) = &choice.delta.content { + message.push_str(content); + } + } + + buffer.update(&mut cx, |buffer, cx| { + buffer.edit([(range.clone(), message.clone())], None, cx); + }); + } + Ok(()) + })) } -fn parse_dialog(markdown: &str) -> CreateChatCompletionRequest { +fn parse_dialog(markdown: &str) -> OpenAIRequest { let parser = Parser::new(markdown); let mut messages = Vec::new(); - let mut current_role: Option<(Role, Option)> = None; + let mut current_role: Option = None; let mut buffer = String::new(); for event in parser { match event { Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => { - if let Some((role, name)) = current_role.take() { + if let Some(role) = current_role.take() { if !buffer.is_empty() { - messages.push(ChatCompletionRequestMessage { + messages.push(RequestMessage { role, content: buffer.trim().to_string(), - name, }); buffer.clear(); } @@ -45,36 +153,89 @@ fn parse_dialog(markdown: &str) -> CreateChatCompletionRequest { buffer.push_str(&text); } else { // Determine the current role based on the H2 header text - let mut chars = text.chars(); - let first_char = chars.by_ref().skip_while(|c| c.is_whitespace()).next(); - let name = chars.take_while(|c| *c != '\n').collect::(); - let name = if name.is_empty() { None } else { Some(name) }; - - let role = match first_char { - Some('@') => Some(Role::User), - Some('/') => Some(Role::Assistant), - Some('#') => Some(Role::System), - _ => None, + let text = text.to_lowercase(); + current_role = if text.contains("user") { + Some(Role::User) + } else if text.contains("assistant") { + Some(Role::Assistant) + } else if text.contains("system") { + Some(Role::System) + } else { + None }; - - current_role = role.map(|role| (role, name)); } } _ => (), } } - if let Some((role, name)) = current_role { - messages.push(ChatCompletionRequestMessage { + if let Some(role) = current_role { + messages.push(RequestMessage { role, content: buffer, - name, }); } - CreateChatCompletionRequest { + OpenAIRequest { model: "gpt-4".into(), messages, - ..Default::default() + stream: true, + } +} + +async fn stream_completion( + api_key: String, + mut request: OpenAIRequest, + executor: Rc, +) -> Result>> { + request.stream = true; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = serde_json::to_string(&request)?; + let mut response = Request::post("https://api.openai.com/v1/chat/completions") + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? + .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + tx.unbounded_send(event).log_err(); + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )) } } @@ -87,23 +248,21 @@ mod tests { use unindent::Unindent; let test_input = r#" - ## @nathan + ## System Hey there, welcome to Zed! - ## /sky + ## Assintant Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast. "#.unindent(); let expected_output = vec![ - ChatCompletionRequestMessage { + RequestMessage { role: Role::User, content: "Hey there, welcome to Zed!".to_string(), - name: Some("nathan".to_string()), }, - ChatCompletionRequestMessage { + RequestMessage { role: Role::Assistant, content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(), - name: Some("sky".to_string()), }, ]; diff --git a/crates/auto_update/Cargo.toml b/crates/auto_update/Cargo.toml index f2b5cea854..884ed2b7a0 100644 --- a/crates/auto_update/Cargo.toml +++ b/crates/auto_update/Cargo.toml @@ -19,7 +19,7 @@ theme = { path = "../theme" } workspace = { path = "../workspace" } util = { path = "../util" } anyhow.workspace = true -isahc = "1.7" +isahc.workspace = true lazy_static.workspace = true log.workspace = true serde.workspace = true diff --git a/crates/feedback/Cargo.toml b/crates/feedback/Cargo.toml index e74e14ff4c..ae8d0f1569 100644 --- a/crates/feedback/Cargo.toml +++ b/crates/feedback/Cargo.toml @@ -27,7 +27,7 @@ futures.workspace = true anyhow.workspace = true smallvec.workspace = true human_bytes = "0.4.1" -isahc = "1.7" +isahc.workspace = true lazy_static.workspace = true postage.workspace = true serde.workspace = true diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 028656a027..a06e0d5fdb 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -960,7 +960,7 @@ impl Task> { pub fn detach_and_log_err(self, cx: &mut AppContext) { cx.spawn(|_| async move { if let Err(err) = self.await { - log::error!("{}", err); + log::error!("{:#}", err); } }) .detach(); diff --git a/crates/util/Cargo.toml b/crates/util/Cargo.toml index 4ec8f7553c..6216d2e472 100644 --- a/crates/util/Cargo.toml +++ b/crates/util/Cargo.toml @@ -17,7 +17,7 @@ backtrace = "0.3" log.workspace = true lazy_static.workspace = true futures.workspace = true -isahc = "1.7" +isahc.workspace = true smol.workspace = true url = "2.2" rand.workspace = true diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index e24b7ef232..a385d37693 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -82,7 +82,7 @@ futures.workspace = true ignore = "0.4" image = "0.23" indexmap = "1.6.2" -isahc = "1.7" +isahc.workspace = true lazy_static.workspace = true libc = "0.2" log.workspace = true