Instruct the assistant to reply to a specific message (#2631)

Closes
https://linear.app/zed-industries/issue/Z-2384/hitting-cmd-enter-in-a-user-or-system-message-should-generate-a

Release Notes:

- Introduced the ability to generate assistant messages for any
user/system message, as well as to generate multiple assists at the
same time, one for each cursor. (preview-only)
This commit is contained in:
Antonio Scandurra 2023-06-20 18:16:23 +02:00 committed by GitHub
commit 6ed86781b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 344 additions and 142 deletions

1
Cargo.lock generated
View file

@ -114,6 +114,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
"smol",
"theme",
"tiktoken-rs",
"util",

View file

@ -28,6 +28,7 @@ isahc.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
tiktoken-rs = "0.4"
[dev-dependencies]

View file

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};
// Data types for chat completion requests
#[derive(Serialize)]
#[derive(Debug, Serialize)]
struct OpenAIRequest {
model: String,
messages: Vec<RequestMessage>,

View file

@ -473,7 +473,7 @@ impl Assistant {
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
let model = "gpt-3.5-turbo";
let model = "gpt-3.5-turbo-0613";
let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, "", cx);
@ -518,7 +518,7 @@ impl Assistant {
MessageMetadata {
role: Role::User,
sent_at: Local::now(),
error: None,
status: MessageStatus::Done,
},
);
@ -543,7 +543,7 @@ impl Assistant {
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
let messages = self
.open_ai_request_messages(cx)
.messages(cx)
.into_iter()
.filter_map(|message| {
Some(tiktoken_rs::ChatCompletionRequestMessage {
@ -552,7 +552,7 @@ impl Assistant {
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: message.content,
content: self.buffer.read(cx).text_for_range(message.range).collect(),
name: None,
})
})
@ -589,97 +589,169 @@ impl Assistant {
cx.notify();
}
fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> {
let request = OpenAIRequest {
model: self.model.clone(),
messages: self.open_ai_request_messages(cx),
stream: true,
};
fn assist(
&mut self,
selected_messages: HashSet<MessageId>,
cx: &mut ModelContext<Self>,
) -> Vec<MessageAnchor> {
let mut user_messages = Vec::new();
let mut tasks = Vec::new();
for selected_message_id in selected_messages {
let selected_message_role =
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
metadata.role
} else {
continue;
};
let api_key = self.api_key.borrow().clone()?;
let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message =
self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?;
let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
let task = cx.spawn_weak({
|this, mut cx| async move {
let assistant_message_id = assistant_message.id;
let stream_completion = async {
let mut messages = stream.await?;
if selected_message_role == Role::Assistant {
if let Some(user_message) = self.insert_message_after(
selected_message_id,
Role::User,
MessageStatus::Done,
cx,
) {
user_messages.push(user_message);
} else {
continue;
}
} else {
let request = OpenAIRequest {
model: self.model.clone(),
messages: self
.messages(cx)
.filter(|message| matches!(message.status, MessageStatus::Done))
.flat_map(|message| {
let mut system_message = None;
if message.id == selected_message_id {
system_message = Some(RequestMessage {
role: Role::System,
content: concat!(
"Treat the following messages as additional knowledge you have learned about, ",
"but act as if they were not part of this conversation. That is, treat them ",
"as if the user didn't see them and couldn't possibly inquire about them."
).into()
});
}
Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
})
.chain(Some(RequestMessage {
role: Role::System,
content: format!(
"Direct your reply to message with id {}. Do not include a [Message X] header.",
selected_message_id.0
),
}))
.collect(),
stream: true,
};
let Some(api_key) = self.api_key.borrow().clone() else { continue };
let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message = self
.insert_message_after(
selected_message_id,
Role::Assistant,
MessageStatus::Pending,
cx,
)
.unwrap();
tasks.push(cx.spawn_weak({
|this, mut cx| async move {
let assistant_message_id = assistant_message.id;
let stream_completion = async {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into();
let message_ix = this.message_anchors.iter().position(
|message| message.id == assistant_message_id,
)?;
this.buffer.update(cx, |buffer, cx| {
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message
.start
.to_offset(buffer)
.saturating_sub(1)
});
buffer.edit([(offset..offset, text)], None, cx);
});
cx.emit(AssistantEvent::StreamedCompletion);
Some(())
});
}
smol::future::yield_now().await;
}
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into();
let message_ix = this
.message_anchors
.iter()
.position(|message| message.id == assistant_message_id)?;
this.buffer.update(cx, |buffer, cx| {
let offset = if message_ix + 1 == this.message_anchors.len()
{
buffer.len()
} else {
this.message_anchors[message_ix + 1]
.start
.to_offset(buffer)
.saturating_sub(1)
};
buffer.edit([(offset..offset, text)], None, cx);
this.pending_completions.retain(|completion| {
completion.id != this.completion_count
});
cx.emit(AssistantEvent::StreamedCompletion);
Some(())
this.summarize(cx);
});
anyhow::Ok(())
};
let result = stream_completion.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Some(metadata) =
this.messages_metadata.get_mut(&assistant_message.id)
{
match result {
Ok(_) => {
metadata.status = MessageStatus::Done;
}
Err(error) => {
metadata.status = MessageStatus::Error(
error.to_string().trim().into(),
);
}
}
cx.notify();
}
});
}
}
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
this.pending_completions
.retain(|completion| completion.id != this.completion_count);
this.summarize(cx);
});
anyhow::Ok(())
};
let result = stream_completion.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Err(error) = result {
if let Some(metadata) =
this.messages_metadata.get_mut(&assistant_message.id)
{
metadata.error = Some(error.to_string().trim().into());
cx.notify();
}
}
});
}
}));
}
});
}
self.pending_completions.push(PendingCompletion {
id: post_inc(&mut self.completion_count),
_task: task,
});
Some((assistant_message, user_message))
if !tasks.is_empty() {
self.pending_completions.push(PendingCompletion {
id: post_inc(&mut self.completion_count),
_tasks: tasks,
});
}
user_messages
}
/// Cancels the most recently started assist by removing its pending
/// completion entry.
///
/// Returns `true` if there was a pending completion to cancel.
// NOTE(review): dropping `PendingCompletion` drops its `_tasks`; presumably
// this aborts the in-flight requests — confirm gpui `Task` drop semantics.
fn cancel_last_assist(&mut self) -> bool {
    match self.pending_completions.pop() {
        Some(_cancelled) => true,
        None => false,
    }
}
fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
if let Some(metadata) = self.messages_metadata.get_mut(&id) {
metadata.role.cycle();
cx.emit(AssistantEvent::MessagesEdited);
cx.notify();
/// Cycles the role (user/assistant/system) of every message in `ids`.
///
/// Messages without metadata are silently skipped; for each message that is
/// cycled, an edit event is emitted and observers are notified.
fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
    for id in ids {
        // Unknown ids (e.g. stale selections) are ignored.
        let Some(metadata) = self.messages_metadata.get_mut(&id) else {
            continue;
        };
        metadata.role.cycle();
        cx.emit(AssistantEvent::MessagesEdited);
        cx.notify();
    }
}
@ -687,6 +759,7 @@ impl Assistant {
&mut self,
message_id: MessageId,
role: Role,
status: MessageStatus,
cx: &mut ModelContext<Self>,
) -> Option<MessageAnchor> {
if let Some(prev_message_ix) = self
@ -713,7 +786,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
error: None,
status,
},
);
cx.emit(AssistantEvent::MessagesEdited);
@ -772,7 +845,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
error: None,
status: MessageStatus::Done,
},
);
@ -814,7 +887,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
error: None,
status: MessageStatus::Done,
},
);
(Some(selection), Some(suffix))
@ -833,16 +906,19 @@ impl Assistant {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
let mut messages = self.open_ai_request_messages(cx);
messages.truncate(2);
messages.push(RequestMessage {
role: Role::User,
content: "Summarize the conversation into a short title without punctuation"
.into(),
});
let messages = self
.messages(cx)
.take(2)
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.chain(Some(RequestMessage {
role: Role::User,
content:
"Summarize the conversation into a short title without punctuation"
.into(),
}));
let request = OpenAIRequest {
model: self.model.clone(),
messages,
messages: messages.collect(),
stream: true,
};
@ -870,24 +946,39 @@ impl Assistant {
}
}
fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
let buffer = self.buffer.read(cx);
self.messages(cx)
.map(|message| RequestMessage {
role: message.role,
content: buffer.text_for_range(message.range).collect(),
})
.collect()
/// Returns the message associated with a single buffer `offset`
/// (see `messages_for_offsets` for the matching rules), if any.
fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
    let mut hits = self.messages_for_offsets(std::iter::once(offset), cx);
    hits.pop()
}
fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
fn messages_for_offsets(
&self,
offsets: impl IntoIterator<Item = usize>,
cx: &AppContext,
) -> Vec<Message> {
let mut result = Vec::new();
let buffer_len = self.buffer.read(cx).len();
let mut messages = self.messages(cx).peekable();
while let Some(message) = messages.next() {
if message.range.contains(&offset) || messages.peek().is_none() {
return Some(message);
let mut offsets = offsets.into_iter().peekable();
while let Some(offset) = offsets.next() {
// Skip messages that start after the offset.
while messages.peek().map_or(false, |message| {
message.range.end < offset || (message.range.end == offset && offset < buffer_len)
}) {
messages.next();
}
let Some(message) = messages.peek() else { continue };
// Skip offsets that are in the same message.
while offsets.peek().map_or(false, |offset| {
message.range.contains(offset) || message.range.end == buffer_len
}) {
offsets.next();
}
result.push(message.clone());
}
None
result
}
fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
@ -916,7 +1007,7 @@ impl Assistant {
anchor: message_anchor.start,
role: metadata.role,
sent_at: metadata.sent_at,
error: metadata.error.clone(),
status: metadata.status.clone(),
});
}
None
@ -926,7 +1017,7 @@ impl Assistant {
struct PendingCompletion {
id: usize,
_task: Task<()>,
_tasks: Vec<Task<()>>,
}
enum AssistantEditorEvent {
@ -979,20 +1070,31 @@ impl AssistantEditor {
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
let user_message = self.assistant.update(cx, |assistant, cx| {
let (_, user_message) = assistant.assist(cx)?;
Some(user_message)
});
let cursors = self.cursors(cx);
if let Some(user_message) = user_message {
let cursor = user_message
.start
.to_offset(&self.assistant.read(cx).buffer.read(cx));
let user_messages = self.assistant.update(cx, |assistant, cx| {
let selected_messages = assistant
.messages_for_offsets(cursors, cx)
.into_iter()
.map(|message| message.id)
.collect();
assistant.assist(selected_messages, cx)
});
let new_selections = user_messages
.iter()
.map(|message| {
let cursor = message
.start
.to_offset(self.assistant.read(cx).buffer.read(cx));
cursor..cursor
})
.collect::<Vec<_>>();
if !new_selections.is_empty() {
self.editor.update(cx, |editor, cx| {
editor.change_selections(
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
cx,
|selections| selections.select_ranges([cursor..cursor]),
|selections| selections.select_ranges(new_selections),
);
});
}
@ -1008,14 +1110,25 @@ impl AssistantEditor {
}
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
let cursors = self.cursors(cx);
self.assistant.update(cx, |assistant, cx| {
if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
assistant.cycle_message_role(message.id, cx);
}
let messages = assistant
.messages_for_offsets(cursors, cx)
.into_iter()
.map(|message| message.id)
.collect();
assistant.cycle_message_roles(messages, cx)
});
}
/// Collects the head (cursor) offset of every selection in the editor.
fn cursors(&self, cx: &AppContext) -> Vec<usize> {
    self.editor
        .read(cx)
        .selections
        .all::<usize>(cx)
        .into_iter()
        .map(|selection| selection.head())
        .collect()
}
fn handle_assistant_event(
&mut self,
_: ModelHandle<Assistant>,
@ -1144,7 +1257,10 @@ impl AssistantEditor {
let assistant = assistant.clone();
move |_, _, cx| {
assistant.update(cx, |assistant, cx| {
assistant.cycle_message_role(message_id, cx)
assistant.cycle_message_roles(
HashSet::from_iter(Some(message_id)),
cx,
)
})
}
});
@ -1160,22 +1276,28 @@ impl AssistantEditor {
.with_style(style.sent_at.container)
.aligned(),
)
.with_children(message.error.as_ref().map(|error| {
Svg::new("icons/circle_x_mark_12.svg")
.with_color(style.error_icon.color)
.constrained()
.with_width(style.error_icon.width)
.contained()
.with_style(style.error_icon.container)
.with_tooltip::<ErrorTooltip>(
message_id.0,
error.to_string(),
None,
theme.tooltip.clone(),
cx,
.with_children(
if let MessageStatus::Error(error) = &message.status {
Some(
Svg::new("icons/circle_x_mark_12.svg")
.with_color(style.error_icon.color)
.constrained()
.with_width(style.error_icon.width)
.contained()
.with_style(style.error_icon.container)
.with_tooltip::<ErrorTooltip>(
message_id.0,
error.to_string(),
None,
theme.tooltip.clone(),
cx,
)
.aligned(),
)
.aligned()
}))
} else {
None
},
)
.aligned()
.left()
.contained()
@ -1308,8 +1430,8 @@ impl AssistantEditor {
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
self.assistant.update(cx, |assistant, cx| {
let new_model = match assistant.model.as_str() {
"gpt-4" => "gpt-3.5-turbo",
_ => "gpt-4",
"gpt-4-0613" => "gpt-3.5-turbo-0613",
_ => "gpt-4-0613",
};
assistant.set_model(new_model.into(), cx);
});
@ -1423,7 +1545,14 @@ struct MessageAnchor {
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
error: Option<Arc<str>>,
status: MessageStatus,
}
/// Lifecycle state of a message's completion request.
#[derive(Clone, Debug)]
enum MessageStatus {
/// An assistant completion is still streaming in for this message.
Pending,
/// The message is finished (user-authored, or its completion succeeded).
Done,
/// The completion failed; carries the trimmed error text for display.
Error(Arc<str>),
}
#[derive(Clone, Debug)]
@ -1434,7 +1563,18 @@ pub struct Message {
anchor: language::Anchor,
role: Role,
sent_at: DateTime<Local>,
error: Option<Arc<str>>,
status: MessageStatus,
}
impl Message {
    /// Converts this message into an OpenAI `RequestMessage`, prefixing the
    /// buffer text with a `[Message <id>]` header so the model can direct a
    /// reply at a specific message id.
    fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
        // `format!` already produces a `String`; the previous trailing
        // `.to_string()` was a redundant extra conversion.
        let mut content = format!("[Message {}]\n", self.id.0);
        content.extend(buffer.text_for_range(self.range.clone()));
        RequestMessage {
            role: self.role,
            content,
        }
    }
}
async fn stream_completion(
@ -1542,7 +1682,7 @@ mod tests {
let message_2 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_1.id, Role::Assistant, cx)
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@ -1566,7 +1706,7 @@ mod tests {
let message_3 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_2.id, Role::User, cx)
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@ -1580,7 +1720,7 @@ mod tests {
let message_4 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_2.id, Role::User, cx)
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@ -1641,7 +1781,7 @@ mod tests {
// Ensure we can still insert after a merged message.
let message_5 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_1.id, Role::System, cx)
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@ -1747,6 +1887,66 @@ mod tests {
);
}
// Verifies `messages_for_offsets`: one message per distinct offset, with
// offsets inside the same message deduplicated and the buffer-end offset
// mapped to the final message.
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
let registry = Arc::new(LanguageRegistry::test());
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
let buffer = assistant.read(cx).buffer.clone();
// A new assistant starts with a single empty user message.
let message_1 = assistant.read(cx).message_anchors[0].clone();
assert_eq!(
messages(&assistant, cx),
vec![(message_1.id, Role::User, 0..0)]
);
// Build three messages laid out as "aaa\nbbb\nccc" (ranges 0..4, 4..8, 8..11).
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
let message_2 = assistant
.update(cx, |assistant, cx| {
assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
let message_3 = assistant
.update(cx, |assistant, cx| {
assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_2.id, Role::User, 4..8),
(message_3.id, Role::User, 8..11)
]
);
// One offset inside each message yields all three, in order.
assert_eq!(
message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
[message_1.id, message_2.id, message_3.id]
);
// Offsets 0 and 1 are both in message_1 (deduplicated); offset 11 is the
// buffer end, which resolves to the last message.
assert_eq!(
message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
[message_1.id, message_3.id]
);
// Helper: maps the returned messages down to just their ids.
fn message_ids_for_offsets(
assistant: &ModelHandle<Assistant>,
offsets: &[usize],
cx: &AppContext,
) -> Vec<MessageId> {
assistant
.read(cx)
.messages_for_offsets(offsets.iter().copied(), cx)
.into_iter()
.map(|message| message.id)
.collect()
}
}
fn messages(
assistant: &ModelHandle<Assistant>,
cx: &AppContext,