diff --git a/Cargo.lock b/Cargo.lock index 5c0570f912..0ea65f93ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,6 +104,7 @@ dependencies = [ "editor", "futures 0.3.28", "gpui", + "indoc", "isahc", "pulldown-cmark", "serde", diff --git a/Cargo.toml b/Cargo.toml index d8bf005b77..7411dd53ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ ctor = { version = "0.1" } env_logger = { version = "0.9" } futures = { version = "0.3" } glob = { version = "0.3.1" } +indoc = "1" isahc = "1.7.2" lazy_static = { version = "1.4.0" } log = { version = "0.4.16", features = ["kv_unstable_serde"] } diff --git a/Untitled b/Untitled new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 0953330a69..dacdbbbf63 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -16,6 +16,7 @@ util = { path = "../util" } serde.workspace = true serde_json.workspace = true anyhow.workspace = true +indoc.workspace = true pulldown-cmark = "0.9.2" futures.workspace = true isahc.workspace = true diff --git a/crates/ai/README.zmd b/crates/ai/README.zmd new file mode 100644 index 0000000000..44cda74cd5 --- /dev/null +++ b/crates/ai/README.zmd @@ -0,0 +1,5 @@ +This is Zed Markdown. + +Mention a language model with / at the start of any line, like this: + +/ What do you think of this idea? diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index b0bbd15d59..101378e747 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,16 +1,14 @@ -use std::io; -use std::rc::Rc; - use anyhow::{anyhow, Result}; use editor::Editor; use futures::AsyncBufReadExt; use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt}; -use gpui::executor::Foreground; +use gpui::executor::Background; use gpui::{actions, AppContext, Task, ViewContext}; +use indoc::indoc; use isahc::prelude::*; use isahc::{http::StatusCode, Request}; -use pulldown_cmark::{Event, HeadingLevel, Parser, Tag}; use serde::{Deserialize, Serialize}; +use std::{io, sync::Arc}; use util::ResultExt; actions!(ai, [Assist]); @@ -93,99 +91,87 @@ fn assist( ) -> Option>> { let api_key = std::env::var("OPENAI_API_KEY").log_err()?; - let markdown = editor.text(cx); - let prompt = parse_dialog(&markdown); - let response = stream_completion(api_key, prompt, cx.foreground().clone()); + const SYSTEM_MESSAGE: &'static str = indoc! {r#" + You an AI language model embedded in a code editor named Zed, authored by Zed Industries. + The input you are currently processing was produced by a special \"model mention\" in a document that is open in the editor. + A model mention is indicated via a leading / on a line. + The user's currently selected text is indicated via ->->selected text<-<- surrounding selected text. + In this sentence, the word ->->example<-<- is selected. + Respond to any selected model mention. + Summarize each mention in a single short sentence like: + > The user selected the word \"example\". + Then provide your response to that mention below its summary. + "#}; - let range = editor.buffer().update(cx, |buffer, cx| { + let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| { + // Insert ->-> <-<- around selected text as described in the system prompt above. let snapshot = buffer.snapshot(cx); - let chars = snapshot.reversed_chars_at(snapshot.len()); - let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count(); - let suffix = "\n".repeat(2 - trailing_newlines); - let end = snapshot.len(); - buffer.edit([(end..end, suffix.clone())], None, cx); - let snapshot = buffer.snapshot(cx); - let start = snapshot.anchor_before(snapshot.len()); - let end = snapshot.anchor_after(snapshot.len()); - start..end + let mut user_message = String::new(); + let mut buffer_offset = 0; + for selection in editor.selections.all(cx) { + user_message.extend(snapshot.text_for_range(buffer_offset..selection.start)); + user_message.push_str("->->"); + user_message.extend(snapshot.text_for_range(selection.start..selection.end)); + buffer_offset = selection.end; + user_message.push_str("<-<-"); + } + if buffer_offset < snapshot.len() { + user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len())); + } + + // Ensure the document ends with 4 trailing newlines. + let trailing_newline_count = snapshot + .reversed_chars_at(snapshot.len()) + .take_while(|c| *c == '\n') + .take(4); + let suffix = "\n".repeat(4 - trailing_newline_count.count()); + buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx); + + let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing. + let insertion_site = snapshot.len() - 2; // Insert text at end of buffer, with an empty line both above and below. + + (user_message, insertion_site) }); + + let stream = stream_completion( + api_key, + cx.background_executor().clone(), + OpenAIRequest { + model: "gpt-4".to_string(), + messages: vec![ + RequestMessage { + role: Role::System, + content: SYSTEM_MESSAGE.to_string(), + }, + RequestMessage { + role: Role::User, + content: user_message, + }, + ], + stream: false, + }, + ); let buffer = editor.buffer().clone(); - Some(cx.spawn(|_, mut cx| async move { - let mut stream = response.await?; - let mut message = String::new(); - while let Some(stream_event) = stream.next().await { - if let Some(choice) = stream_event?.choices.first() { - if let Some(content) = &choice.delta.content { - message.push_str(content); - } + let mut messages = stream.await?; + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + buffer.update(&mut cx, |buffer, cx| { + let text: Arc = choice.delta.content?.into(); + buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx); + Some(()) + }); } - - buffer.update(&mut cx, |buffer, cx| { - buffer.edit([(range.clone(), message.clone())], None, cx); - }); } Ok(()) })) } -fn parse_dialog(markdown: &str) -> OpenAIRequest { - let parser = Parser::new(markdown); - let mut messages = Vec::new(); - - let mut current_role: Option = None; - let mut buffer = String::new(); - for event in parser { - match event { - Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => { - if let Some(role) = current_role.take() { - if !buffer.is_empty() { - messages.push(RequestMessage { - role, - content: buffer.trim().to_string(), - }); - buffer.clear(); - } - } - } - Event::Text(text) => { - if current_role.is_some() { - buffer.push_str(&text); - } else { - // Determine the current role based on the H2 header text - let text = text.to_lowercase(); - current_role = if text.contains("user") { - Some(Role::User) - } else if text.contains("assistant") { - Some(Role::Assistant) - } else if text.contains("system") { - Some(Role::System) - } else { - None - }; - } - } - _ => (), - } - } - if let Some(role) = current_role { - messages.push(RequestMessage { - role, - content: buffer, - }); - } - - OpenAIRequest { - model: "gpt-4".into(), - messages, - stream: true, - } -} - async fn stream_completion( api_key: String, + executor: Arc, mut request: OpenAIRequest, - executor: Rc, ) -> Result>> { request.stream = true; @@ -240,32 +226,4 @@ async fn stream_completion( } #[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_dialog() { - use unindent::Unindent; - - let test_input = r#" - ## System - Hey there, welcome to Zed! - - ## Assintant - Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast. - "#.unindent(); - - let expected_output = vec![ - RequestMessage { - role: Role::User, - content: "Hey there, welcome to Zed!".to_string(), - }, - RequestMessage { - role: Role::Assistant, - content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(), - }, - ]; - - assert_eq!(parse_dialog(&test_input).messages, expected_output); - } -} +mod tests {} diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index f2202618f4..cd06b9a70a 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -76,7 +76,7 @@ workspace = { path = "../workspace", features = ["test-support"] } ctor.workspace = true env_logger.workspace = true -indoc = "1.0.4" +indoc.workspace = true util = { path = "../util" } lazy_static.workspace = true sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] } diff --git a/crates/db/Cargo.toml b/crates/db/Cargo.toml index 8cb7170ef6..b49078e860 100644 --- a/crates/db/Cargo.toml +++ b/crates/db/Cargo.toml @@ -18,7 +18,7 @@ sqlez = { path = "../sqlez" } sqlez_macros = { path = "../sqlez_macros" } util = { path = "../util" } anyhow.workspace = true -indoc = "1.0.4" +indoc.workspace = true async-trait.workspace = true lazy_static.workspace = true log.workspace = true diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index fc7bf4b8ab..482923fee7 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -50,7 +50,7 @@ aho-corasick = "0.7" anyhow.workspace = true futures.workspace = true glob.workspace = true -indoc = "1.0.4" +indoc.workspace = true itertools = "0.10" lazy_static.workspace = true log.workspace = true diff --git a/crates/language/Cargo.toml b/crates/language/Cargo.toml index 5a7644d98e..79121b3799 100644 --- a/crates/language/Cargo.toml +++ b/crates/language/Cargo.toml @@ -70,7 +70,7 @@ settings = { path = "../settings", features = ["test-support"] } util = { path = "../util", features = ["test-support"] } ctor.workspace = true env_logger.workspace = true -indoc = "1.0.4" +indoc.workspace = true rand.workspace = true tree-sitter-embedded-template = "*" tree-sitter-html = "*" diff --git a/crates/sqlez/Cargo.toml b/crates/sqlez/Cargo.toml index 7371a7863a..01d17d4812 100644 --- a/crates/sqlez/Cargo.toml +++ b/crates/sqlez/Cargo.toml @@ -6,7 +6,7 @@ publish = false [dependencies] anyhow.workspace = true -indoc = "1.0.7" +indoc.workspace = true libsqlite3-sys = { version = "0.24", features = ["bundled"] } smol.workspace = true thread_local = "1.1.4" diff --git a/crates/vim/Cargo.toml b/crates/vim/Cargo.toml index c34a5b469b..ee3144fd56 100644 --- a/crates/vim/Cargo.toml +++ b/crates/vim/Cargo.toml @@ -35,7 +35,7 @@ settings = { path = "../settings" } workspace = { path = "../workspace" } [dev-dependencies] -indoc = "1.0.4" +indoc.workspace = true parking_lot.workspace = true lazy_static.workspace = true diff --git a/crates/workspace/Cargo.toml b/crates/workspace/Cargo.toml index 33e5e7aefe..b22607e20d 100644 --- a/crates/workspace/Cargo.toml +++ b/crates/workspace/Cargo.toml @@ -62,5 +62,5 @@ settings = { path = "../settings", features = ["test-support"] } fs = { path = "../fs", features = ["test-support"] } db = { path = "../db", features = ["test-support"] } -indoc = "1.0.4" +indoc.workspace = true env_logger.workspace = true