mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-24 01:11:51 +00:00
WIP
This commit is contained in:
parent
7e6cccfa3d
commit
30de64845f
13 changed files with 86 additions and 120 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -104,6 +104,7 @@ dependencies = [
|
|||
"editor",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"isahc",
|
||||
"pulldown-cmark",
|
||||
"serde",
|
||||
|
|
|
@ -79,6 +79,7 @@ ctor = { version = "0.1" }
|
|||
env_logger = { version = "0.9" }
|
||||
futures = { version = "0.3" }
|
||||
glob = { version = "0.3.1" }
|
||||
indoc = "1"
|
||||
isahc = "1.7.2"
|
||||
lazy_static = { version = "1.4.0" }
|
||||
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
|
||||
|
|
0
Untitled
Normal file
0
Untitled
Normal file
|
@ -16,6 +16,7 @@ util = { path = "../util" }
|
|||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
anyhow.workspace = true
|
||||
indoc.workspace = true
|
||||
pulldown-cmark = "0.9.2"
|
||||
futures.workspace = true
|
||||
isahc.workspace = true
|
||||
|
|
5
crates/ai/README.zmd
Normal file
5
crates/ai/README.zmd
Normal file
|
@ -0,0 +1,5 @@
|
|||
This is Zed Markdown.
|
||||
|
||||
Mention a language model with / at the start of any line, like this:
|
||||
|
||||
/ What do you think of this idea?
|
|
@ -1,16 +1,14 @@
|
|||
use std::io;
|
||||
use std::rc::Rc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::Editor;
|
||||
use futures::AsyncBufReadExt;
|
||||
use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
|
||||
use gpui::executor::Foreground;
|
||||
use gpui::executor::Background;
|
||||
use gpui::{actions, AppContext, Task, ViewContext};
|
||||
use indoc::indoc;
|
||||
use isahc::prelude::*;
|
||||
use isahc::{http::StatusCode, Request};
|
||||
use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{io, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
actions!(ai, [Assist]);
|
||||
|
@ -93,99 +91,87 @@ fn assist(
|
|||
) -> Option<Task<Result<()>>> {
|
||||
let api_key = std::env::var("OPENAI_API_KEY").log_err()?;
|
||||
|
||||
let markdown = editor.text(cx);
|
||||
let prompt = parse_dialog(&markdown);
|
||||
let response = stream_completion(api_key, prompt, cx.foreground().clone());
|
||||
const SYSTEM_MESSAGE: &'static str = indoc! {r#"
|
||||
You an AI language model embedded in a code editor named Zed, authored by Zed Industries.
|
||||
The input you are currently processing was produced by a special \"model mention\" in a document that is open in the editor.
|
||||
A model mention is indicated via a leading / on a line.
|
||||
The user's currently selected text is indicated via ->->selected text<-<- surrounding selected text.
|
||||
In this sentence, the word ->->example<-<- is selected.
|
||||
Respond to any selected model mention.
|
||||
Summarize each mention in a single short sentence like:
|
||||
> The user selected the word \"example\".
|
||||
Then provide your response to that mention below its summary.
|
||||
"#};
|
||||
|
||||
let range = editor.buffer().update(cx, |buffer, cx| {
|
||||
let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
|
||||
// Insert ->-> <-<- around selected text as described in the system prompt above.
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
let chars = snapshot.reversed_chars_at(snapshot.len());
|
||||
let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count();
|
||||
let suffix = "\n".repeat(2 - trailing_newlines);
|
||||
let end = snapshot.len();
|
||||
buffer.edit([(end..end, suffix.clone())], None, cx);
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
let start = snapshot.anchor_before(snapshot.len());
|
||||
let end = snapshot.anchor_after(snapshot.len());
|
||||
start..end
|
||||
let mut user_message = String::new();
|
||||
let mut buffer_offset = 0;
|
||||
for selection in editor.selections.all(cx) {
|
||||
user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
|
||||
user_message.push_str("->->");
|
||||
user_message.extend(snapshot.text_for_range(selection.start..selection.end));
|
||||
buffer_offset = selection.end;
|
||||
user_message.push_str("<-<-");
|
||||
}
|
||||
if buffer_offset < snapshot.len() {
|
||||
user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
|
||||
}
|
||||
|
||||
// Ensure the document ends with 4 trailing newlines.
|
||||
let trailing_newline_count = snapshot
|
||||
.reversed_chars_at(snapshot.len())
|
||||
.take_while(|c| *c == '\n')
|
||||
.take(4);
|
||||
let suffix = "\n".repeat(4 - trailing_newline_count.count());
|
||||
buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);
|
||||
|
||||
let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
|
||||
let insertion_site = snapshot.len() - 2; // Insert text at end of buffer, with an empty line both above and below.
|
||||
|
||||
(user_message, insertion_site)
|
||||
});
|
||||
|
||||
let stream = stream_completion(
|
||||
api_key,
|
||||
cx.background_executor().clone(),
|
||||
OpenAIRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
RequestMessage {
|
||||
role: Role::System,
|
||||
content: SYSTEM_MESSAGE.to_string(),
|
||||
},
|
||||
RequestMessage {
|
||||
role: Role::User,
|
||||
content: user_message,
|
||||
},
|
||||
],
|
||||
stream: false,
|
||||
},
|
||||
);
|
||||
let buffer = editor.buffer().clone();
|
||||
|
||||
Some(cx.spawn(|_, mut cx| async move {
|
||||
let mut stream = response.await?;
|
||||
let mut message = String::new();
|
||||
while let Some(stream_event) = stream.next().await {
|
||||
if let Some(choice) = stream_event?.choices.first() {
|
||||
if let Some(content) = &choice.delta.content {
|
||||
message.push_str(content);
|
||||
}
|
||||
let mut messages = stream.await?;
|
||||
while let Some(message) = messages.next().await {
|
||||
let mut message = message?;
|
||||
if let Some(choice) = message.choices.pop() {
|
||||
buffer.update(&mut cx, |buffer, cx| {
|
||||
let text: Arc<str> = choice.delta.content?.into();
|
||||
buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
|
||||
buffer.update(&mut cx, |buffer, cx| {
|
||||
buffer.edit([(range.clone(), message.clone())], None, cx);
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_dialog(markdown: &str) -> OpenAIRequest {
|
||||
let parser = Parser::new(markdown);
|
||||
let mut messages = Vec::new();
|
||||
|
||||
let mut current_role: Option<Role> = None;
|
||||
let mut buffer = String::new();
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
|
||||
if let Some(role) = current_role.take() {
|
||||
if !buffer.is_empty() {
|
||||
messages.push(RequestMessage {
|
||||
role,
|
||||
content: buffer.trim().to_string(),
|
||||
});
|
||||
buffer.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::Text(text) => {
|
||||
if current_role.is_some() {
|
||||
buffer.push_str(&text);
|
||||
} else {
|
||||
// Determine the current role based on the H2 header text
|
||||
let text = text.to_lowercase();
|
||||
current_role = if text.contains("user") {
|
||||
Some(Role::User)
|
||||
} else if text.contains("assistant") {
|
||||
Some(Role::Assistant)
|
||||
} else if text.contains("system") {
|
||||
Some(Role::System)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
if let Some(role) = current_role {
|
||||
messages.push(RequestMessage {
|
||||
role,
|
||||
content: buffer,
|
||||
});
|
||||
}
|
||||
|
||||
OpenAIRequest {
|
||||
model: "gpt-4".into(),
|
||||
messages,
|
||||
stream: true,
|
||||
}
|
||||
}
|
||||
|
||||
async fn stream_completion(
|
||||
api_key: String,
|
||||
executor: Arc<Background>,
|
||||
mut request: OpenAIRequest,
|
||||
executor: Rc<Foreground>,
|
||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||
request.stream = true;
|
||||
|
||||
|
@ -240,32 +226,4 @@ async fn stream_completion(
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_dialog() {
|
||||
use unindent::Unindent;
|
||||
|
||||
let test_input = r#"
|
||||
## System
|
||||
Hey there, welcome to Zed!
|
||||
|
||||
## Assintant
|
||||
Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
|
||||
"#.unindent();
|
||||
|
||||
let expected_output = vec![
|
||||
RequestMessage {
|
||||
role: Role::User,
|
||||
content: "Hey there, welcome to Zed!".to_string(),
|
||||
},
|
||||
RequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
assert_eq!(parse_dialog(&test_input).messages, expected_output);
|
||||
}
|
||||
}
|
||||
mod tests {}
|
||||
|
|
|
@ -76,7 +76,7 @@ workspace = { path = "../workspace", features = ["test-support"] }
|
|||
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
indoc = "1.0.4"
|
||||
indoc.workspace = true
|
||||
util = { path = "../util" }
|
||||
lazy_static.workspace = true
|
||||
sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] }
|
||||
|
|
|
@ -18,7 +18,7 @@ sqlez = { path = "../sqlez" }
|
|||
sqlez_macros = { path = "../sqlez_macros" }
|
||||
util = { path = "../util" }
|
||||
anyhow.workspace = true
|
||||
indoc = "1.0.4"
|
||||
indoc.workspace = true
|
||||
async-trait.workspace = true
|
||||
lazy_static.workspace = true
|
||||
log.workspace = true
|
||||
|
|
|
@ -50,7 +50,7 @@ aho-corasick = "0.7"
|
|||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
glob.workspace = true
|
||||
indoc = "1.0.4"
|
||||
indoc.workspace = true
|
||||
itertools = "0.10"
|
||||
lazy_static.workspace = true
|
||||
log.workspace = true
|
||||
|
|
|
@ -70,7 +70,7 @@ settings = { path = "../settings", features = ["test-support"] }
|
|||
util = { path = "../util", features = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
indoc = "1.0.4"
|
||||
indoc.workspace = true
|
||||
rand.workspace = true
|
||||
tree-sitter-embedded-template = "*"
|
||||
tree-sitter-html = "*"
|
||||
|
|
|
@ -6,7 +6,7 @@ publish = false
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
indoc = "1.0.7"
|
||||
indoc.workspace = true
|
||||
libsqlite3-sys = { version = "0.24", features = ["bundled"] }
|
||||
smol.workspace = true
|
||||
thread_local = "1.1.4"
|
||||
|
|
|
@ -35,7 +35,7 @@ settings = { path = "../settings" }
|
|||
workspace = { path = "../workspace" }
|
||||
|
||||
[dev-dependencies]
|
||||
indoc = "1.0.4"
|
||||
indoc.workspace = true
|
||||
parking_lot.workspace = true
|
||||
lazy_static.workspace = true
|
||||
|
||||
|
|
|
@ -62,5 +62,5 @@ settings = { path = "../settings", features = ["test-support"] }
|
|||
fs = { path = "../fs", features = ["test-support"] }
|
||||
db = { path = "../db", features = ["test-support"] }
|
||||
|
||||
indoc = "1.0.4"
|
||||
indoc.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
|
Loading…
Reference in a new issue