Instruct the assistant to reply to a specific message (#2631)

Closes
https://linear.app/zed-industries/issue/Z-2384/hitting-cmd-enter-in-a-user-or-system-message-should-generate-a

Release Notes:

- Introduced the ability to generate assistant messages for any
user/system message, as well as generating multiple assists at the same
time, one for each cursor. (preview-only)
This commit is contained in:
Antonio Scandurra 2023-06-20 18:16:23 +02:00 committed by GitHub
commit 6ed86781b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 344 additions and 142 deletions

1
Cargo.lock generated
View file

@ -114,6 +114,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"settings", "settings",
"smol",
"theme", "theme",
"tiktoken-rs", "tiktoken-rs",
"util", "util",

View file

@ -28,6 +28,7 @@ isahc.workspace = true
schemars.workspace = true schemars.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
smol.workspace = true
tiktoken-rs = "0.4" tiktoken-rs = "0.4"
[dev-dependencies] [dev-dependencies]

View file

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use std::fmt::{self, Display}; use std::fmt::{self, Display};
// Data types for chat completion requests // Data types for chat completion requests
#[derive(Serialize)] #[derive(Debug, Serialize)]
struct OpenAIRequest { struct OpenAIRequest {
model: String, model: String,
messages: Vec<RequestMessage>, messages: Vec<RequestMessage>,

View file

@ -473,7 +473,7 @@ impl Assistant {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Self { ) -> Self {
let model = "gpt-3.5-turbo"; let model = "gpt-3.5-turbo-0613";
let markdown = language_registry.language_for_name("Markdown"); let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| { let buffer = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, "", cx); let mut buffer = Buffer::new(0, "", cx);
@ -518,7 +518,7 @@ impl Assistant {
MessageMetadata { MessageMetadata {
role: Role::User, role: Role::User,
sent_at: Local::now(), sent_at: Local::now(),
error: None, status: MessageStatus::Done,
}, },
); );
@ -543,7 +543,7 @@ impl Assistant {
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) { fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
let messages = self let messages = self
.open_ai_request_messages(cx) .messages(cx)
.into_iter() .into_iter()
.filter_map(|message| { .filter_map(|message| {
Some(tiktoken_rs::ChatCompletionRequestMessage { Some(tiktoken_rs::ChatCompletionRequestMessage {
@ -552,7 +552,7 @@ impl Assistant {
Role::Assistant => "assistant".into(), Role::Assistant => "assistant".into(),
Role::System => "system".into(), Role::System => "system".into(),
}, },
content: message.content, content: self.buffer.read(cx).text_for_range(message.range).collect(),
name: None, name: None,
}) })
}) })
@ -589,97 +589,169 @@ impl Assistant {
cx.notify(); cx.notify();
} }
fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> { fn assist(
let request = OpenAIRequest { &mut self,
model: self.model.clone(), selected_messages: HashSet<MessageId>,
messages: self.open_ai_request_messages(cx), cx: &mut ModelContext<Self>,
stream: true, ) -> Vec<MessageAnchor> {
}; let mut user_messages = Vec::new();
let mut tasks = Vec::new();
for selected_message_id in selected_messages {
let selected_message_role =
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
metadata.role
} else {
continue;
};
let api_key = self.api_key.borrow().clone()?; if selected_message_role == Role::Assistant {
let stream = stream_completion(api_key, cx.background().clone(), request); if let Some(user_message) = self.insert_message_after(
let assistant_message = selected_message_id,
self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?; Role::User,
let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?; MessageStatus::Done,
let task = cx.spawn_weak({ cx,
|this, mut cx| async move { ) {
let assistant_message_id = assistant_message.id; user_messages.push(user_message);
let stream_completion = async { } else {
let mut messages = stream.await?; continue;
}
} else {
let request = OpenAIRequest {
model: self.model.clone(),
messages: self
.messages(cx)
.filter(|message| matches!(message.status, MessageStatus::Done))
.flat_map(|message| {
let mut system_message = None;
if message.id == selected_message_id {
system_message = Some(RequestMessage {
role: Role::System,
content: concat!(
"Treat the following messages as additional knowledge you have learned about, ",
"but act as if they were not part of this conversation. That is, treat them ",
"as if the user didn't see them and couldn't possibly inquire about them."
).into()
});
}
Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
})
.chain(Some(RequestMessage {
role: Role::System,
content: format!(
"Direct your reply to message with id {}. Do not include a [Message X] header.",
selected_message_id.0
),
}))
.collect(),
stream: true,
};
let Some(api_key) = self.api_key.borrow().clone() else { continue };
let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message = self
.insert_message_after(
selected_message_id,
Role::Assistant,
MessageStatus::Pending,
cx,
)
.unwrap();
tasks.push(cx.spawn_weak({
|this, mut cx| async move {
let assistant_message_id = assistant_message.id;
let stream_completion = async {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into();
let message_ix = this.message_anchors.iter().position(
|message| message.id == assistant_message_id,
)?;
this.buffer.update(cx, |buffer, cx| {
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message
.start
.to_offset(buffer)
.saturating_sub(1)
});
buffer.edit([(offset..offset, text)], None, cx);
});
cx.emit(AssistantEvent::StreamedCompletion);
Some(())
});
}
smol::future::yield_now().await;
}
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
this.upgrade(&cx) this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))? .ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| { .update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into(); this.pending_completions.retain(|completion| {
let message_ix = this completion.id != this.completion_count
.message_anchors
.iter()
.position(|message| message.id == assistant_message_id)?;
this.buffer.update(cx, |buffer, cx| {
let offset = if message_ix + 1 == this.message_anchors.len()
{
buffer.len()
} else {
this.message_anchors[message_ix + 1]
.start
.to_offset(buffer)
.saturating_sub(1)
};
buffer.edit([(offset..offset, text)], None, cx);
}); });
cx.emit(AssistantEvent::StreamedCompletion); this.summarize(cx);
Some(())
}); });
anyhow::Ok(())
};
let result = stream_completion.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Some(metadata) =
this.messages_metadata.get_mut(&assistant_message.id)
{
match result {
Ok(_) => {
metadata.status = MessageStatus::Done;
}
Err(error) => {
metadata.status = MessageStatus::Error(
error.to_string().trim().into(),
);
}
}
cx.notify();
}
});
} }
} }
}));
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
this.pending_completions
.retain(|completion| completion.id != this.completion_count);
this.summarize(cx);
});
anyhow::Ok(())
};
let result = stream_completion.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Err(error) = result {
if let Some(metadata) =
this.messages_metadata.get_mut(&assistant_message.id)
{
metadata.error = Some(error.to_string().trim().into());
cx.notify();
}
}
});
}
} }
}); }
self.pending_completions.push(PendingCompletion { if !tasks.is_empty() {
id: post_inc(&mut self.completion_count), self.pending_completions.push(PendingCompletion {
_task: task, id: post_inc(&mut self.completion_count),
}); _tasks: tasks,
Some((assistant_message, user_message)) });
}
user_messages
} }
fn cancel_last_assist(&mut self) -> bool { fn cancel_last_assist(&mut self) -> bool {
self.pending_completions.pop().is_some() self.pending_completions.pop().is_some()
} }
fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) { fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
if let Some(metadata) = self.messages_metadata.get_mut(&id) { for id in ids {
metadata.role.cycle(); if let Some(metadata) = self.messages_metadata.get_mut(&id) {
cx.emit(AssistantEvent::MessagesEdited); metadata.role.cycle();
cx.notify(); cx.emit(AssistantEvent::MessagesEdited);
cx.notify();
}
} }
} }
@ -687,6 +759,7 @@ impl Assistant {
&mut self, &mut self,
message_id: MessageId, message_id: MessageId,
role: Role, role: Role,
status: MessageStatus,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Option<MessageAnchor> { ) -> Option<MessageAnchor> {
if let Some(prev_message_ix) = self if let Some(prev_message_ix) = self
@ -713,7 +786,7 @@ impl Assistant {
MessageMetadata { MessageMetadata {
role, role,
sent_at: Local::now(), sent_at: Local::now(),
error: None, status,
}, },
); );
cx.emit(AssistantEvent::MessagesEdited); cx.emit(AssistantEvent::MessagesEdited);
@ -772,7 +845,7 @@ impl Assistant {
MessageMetadata { MessageMetadata {
role, role,
sent_at: Local::now(), sent_at: Local::now(),
error: None, status: MessageStatus::Done,
}, },
); );
@ -814,7 +887,7 @@ impl Assistant {
MessageMetadata { MessageMetadata {
role, role,
sent_at: Local::now(), sent_at: Local::now(),
error: None, status: MessageStatus::Done,
}, },
); );
(Some(selection), Some(suffix)) (Some(selection), Some(suffix))
@ -833,16 +906,19 @@ impl Assistant {
if self.message_anchors.len() >= 2 && self.summary.is_none() { if self.message_anchors.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone(); let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key { if let Some(api_key) = api_key {
let mut messages = self.open_ai_request_messages(cx); let messages = self
messages.truncate(2); .messages(cx)
messages.push(RequestMessage { .take(2)
role: Role::User, .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
content: "Summarize the conversation into a short title without punctuation" .chain(Some(RequestMessage {
.into(), role: Role::User,
}); content:
"Summarize the conversation into a short title without punctuation"
.into(),
}));
let request = OpenAIRequest { let request = OpenAIRequest {
model: self.model.clone(), model: self.model.clone(),
messages, messages: messages.collect(),
stream: true, stream: true,
}; };
@ -870,24 +946,39 @@ impl Assistant {
} }
} }
fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> { fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
let buffer = self.buffer.read(cx); self.messages_for_offsets([offset], cx).pop()
self.messages(cx)
.map(|message| RequestMessage {
role: message.role,
content: buffer.text_for_range(message.range).collect(),
})
.collect()
} }
fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> { fn messages_for_offsets(
&self,
offsets: impl IntoIterator<Item = usize>,
cx: &AppContext,
) -> Vec<Message> {
let mut result = Vec::new();
let buffer_len = self.buffer.read(cx).len();
let mut messages = self.messages(cx).peekable(); let mut messages = self.messages(cx).peekable();
while let Some(message) = messages.next() { let mut offsets = offsets.into_iter().peekable();
if message.range.contains(&offset) || messages.peek().is_none() { while let Some(offset) = offsets.next() {
return Some(message); // Skip messages that start after the offset.
while messages.peek().map_or(false, |message| {
message.range.end < offset || (message.range.end == offset && offset < buffer_len)
}) {
messages.next();
} }
let Some(message) = messages.peek() else { continue };
// Skip offsets that are in the same message.
while offsets.peek().map_or(false, |offset| {
message.range.contains(offset) || message.range.end == buffer_len
}) {
offsets.next();
}
result.push(message.clone());
} }
None result
} }
fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> { fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
@ -916,7 +1007,7 @@ impl Assistant {
anchor: message_anchor.start, anchor: message_anchor.start,
role: metadata.role, role: metadata.role,
sent_at: metadata.sent_at, sent_at: metadata.sent_at,
error: metadata.error.clone(), status: metadata.status.clone(),
}); });
} }
None None
@ -926,7 +1017,7 @@ impl Assistant {
struct PendingCompletion { struct PendingCompletion {
id: usize, id: usize,
_task: Task<()>, _tasks: Vec<Task<()>>,
} }
enum AssistantEditorEvent { enum AssistantEditorEvent {
@ -979,20 +1070,31 @@ impl AssistantEditor {
} }
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) { fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
let user_message = self.assistant.update(cx, |assistant, cx| { let cursors = self.cursors(cx);
let (_, user_message) = assistant.assist(cx)?;
Some(user_message)
});
if let Some(user_message) = user_message { let user_messages = self.assistant.update(cx, |assistant, cx| {
let cursor = user_message let selected_messages = assistant
.start .messages_for_offsets(cursors, cx)
.to_offset(&self.assistant.read(cx).buffer.read(cx)); .into_iter()
.map(|message| message.id)
.collect();
assistant.assist(selected_messages, cx)
});
let new_selections = user_messages
.iter()
.map(|message| {
let cursor = message
.start
.to_offset(self.assistant.read(cx).buffer.read(cx));
cursor..cursor
})
.collect::<Vec<_>>();
if !new_selections.is_empty() {
self.editor.update(cx, |editor, cx| { self.editor.update(cx, |editor, cx| {
editor.change_selections( editor.change_selections(
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
cx, cx,
|selections| selections.select_ranges([cursor..cursor]), |selections| selections.select_ranges(new_selections),
); );
}); });
} }
@ -1008,14 +1110,25 @@ impl AssistantEditor {
} }
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) { fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
let cursor_offset = self.editor.read(cx).selections.newest(cx).head(); let cursors = self.cursors(cx);
self.assistant.update(cx, |assistant, cx| { self.assistant.update(cx, |assistant, cx| {
if let Some(message) = assistant.message_for_offset(cursor_offset, cx) { let messages = assistant
assistant.cycle_message_role(message.id, cx); .messages_for_offsets(cursors, cx)
} .into_iter()
.map(|message| message.id)
.collect();
assistant.cycle_message_roles(messages, cx)
}); });
} }
fn cursors(&self, cx: &AppContext) -> Vec<usize> {
let selections = self.editor.read(cx).selections.all::<usize>(cx);
selections
.into_iter()
.map(|selection| selection.head())
.collect()
}
fn handle_assistant_event( fn handle_assistant_event(
&mut self, &mut self,
_: ModelHandle<Assistant>, _: ModelHandle<Assistant>,
@ -1144,7 +1257,10 @@ impl AssistantEditor {
let assistant = assistant.clone(); let assistant = assistant.clone();
move |_, _, cx| { move |_, _, cx| {
assistant.update(cx, |assistant, cx| { assistant.update(cx, |assistant, cx| {
assistant.cycle_message_role(message_id, cx) assistant.cycle_message_roles(
HashSet::from_iter(Some(message_id)),
cx,
)
}) })
} }
}); });
@ -1160,22 +1276,28 @@ impl AssistantEditor {
.with_style(style.sent_at.container) .with_style(style.sent_at.container)
.aligned(), .aligned(),
) )
.with_children(message.error.as_ref().map(|error| { .with_children(
Svg::new("icons/circle_x_mark_12.svg") if let MessageStatus::Error(error) = &message.status {
.with_color(style.error_icon.color) Some(
.constrained() Svg::new("icons/circle_x_mark_12.svg")
.with_width(style.error_icon.width) .with_color(style.error_icon.color)
.contained() .constrained()
.with_style(style.error_icon.container) .with_width(style.error_icon.width)
.with_tooltip::<ErrorTooltip>( .contained()
message_id.0, .with_style(style.error_icon.container)
error.to_string(), .with_tooltip::<ErrorTooltip>(
None, message_id.0,
theme.tooltip.clone(), error.to_string(),
cx, None,
theme.tooltip.clone(),
cx,
)
.aligned(),
) )
.aligned() } else {
})) None
},
)
.aligned() .aligned()
.left() .left()
.contained() .contained()
@ -1308,8 +1430,8 @@ impl AssistantEditor {
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) { fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
self.assistant.update(cx, |assistant, cx| { self.assistant.update(cx, |assistant, cx| {
let new_model = match assistant.model.as_str() { let new_model = match assistant.model.as_str() {
"gpt-4" => "gpt-3.5-turbo", "gpt-4-0613" => "gpt-3.5-turbo-0613",
_ => "gpt-4", _ => "gpt-4-0613",
}; };
assistant.set_model(new_model.into(), cx); assistant.set_model(new_model.into(), cx);
}); });
@ -1423,7 +1545,14 @@ struct MessageAnchor {
struct MessageMetadata { struct MessageMetadata {
role: Role, role: Role,
sent_at: DateTime<Local>, sent_at: DateTime<Local>,
error: Option<Arc<str>>, status: MessageStatus,
}
#[derive(Clone, Debug)]
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -1434,7 +1563,18 @@ pub struct Message {
anchor: language::Anchor, anchor: language::Anchor,
role: Role, role: Role,
sent_at: DateTime<Local>, sent_at: DateTime<Local>,
error: Option<Arc<str>>, status: MessageStatus,
}
impl Message {
fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
let mut content = format!("[Message {}]\n", self.id.0).to_string();
content.extend(buffer.text_for_range(self.range.clone()));
RequestMessage {
role: self.role,
content,
}
}
} }
async fn stream_completion( async fn stream_completion(
@ -1542,7 +1682,7 @@ mod tests {
let message_2 = assistant.update(cx, |assistant, cx| { let message_2 = assistant.update(cx, |assistant, cx| {
assistant assistant
.insert_message_after(message_1.id, Role::Assistant, cx) .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap() .unwrap()
}); });
assert_eq!( assert_eq!(
@ -1566,7 +1706,7 @@ mod tests {
let message_3 = assistant.update(cx, |assistant, cx| { let message_3 = assistant.update(cx, |assistant, cx| {
assistant assistant
.insert_message_after(message_2.id, Role::User, cx) .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap() .unwrap()
}); });
assert_eq!( assert_eq!(
@ -1580,7 +1720,7 @@ mod tests {
let message_4 = assistant.update(cx, |assistant, cx| { let message_4 = assistant.update(cx, |assistant, cx| {
assistant assistant
.insert_message_after(message_2.id, Role::User, cx) .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap() .unwrap()
}); });
assert_eq!( assert_eq!(
@ -1641,7 +1781,7 @@ mod tests {
// Ensure we can still insert after a merged message. // Ensure we can still insert after a merged message.
let message_5 = assistant.update(cx, |assistant, cx| { let message_5 = assistant.update(cx, |assistant, cx| {
assistant assistant
.insert_message_after(message_1.id, Role::System, cx) .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap() .unwrap()
}); });
assert_eq!( assert_eq!(
@ -1747,6 +1887,66 @@ mod tests {
); );
} }
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
let registry = Arc::new(LanguageRegistry::test());
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
let buffer = assistant.read(cx).buffer.clone();
let message_1 = assistant.read(cx).message_anchors[0].clone();
assert_eq!(
messages(&assistant, cx),
vec![(message_1.id, Role::User, 0..0)]
);
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
let message_2 = assistant
.update(cx, |assistant, cx| {
assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
let message_3 = assistant
.update(cx, |assistant, cx| {
assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_2.id, Role::User, 4..8),
(message_3.id, Role::User, 8..11)
]
);
assert_eq!(
message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
[message_1.id, message_2.id, message_3.id]
);
assert_eq!(
message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
[message_1.id, message_3.id]
);
fn message_ids_for_offsets(
assistant: &ModelHandle<Assistant>,
offsets: &[usize],
cx: &AppContext,
) -> Vec<MessageId> {
assistant
.read(cx)
.messages_for_offsets(offsets.iter().copied(), cx)
.into_iter()
.map(|message| message.id)
.collect()
}
}
fn messages( fn messages(
assistant: &ModelHandle<Assistant>, assistant: &ModelHandle<Assistant>,
cx: &AppContext, cx: &AppContext,