Instruct the assistant to reply to a specific message (#2631)

Closes
https://linear.app/zed-industries/issue/Z-2384/hitting-cmd-enter-in-a-user-or-system-message-should-generate-a

Release Notes:

- Introduced the ability to generate assistant messages for any
user/system message, as well as generating multiple assists at the same
time, one for each cursor. (preview-only)
This commit is contained in:
Antonio Scandurra 2023-06-20 18:16:23 +02:00 committed by GitHub
commit 6ed86781b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 344 additions and 142 deletions

1
Cargo.lock generated
View file

@ -114,6 +114,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
"smol",
"theme",
"tiktoken-rs",
"util",

View file

@ -28,6 +28,7 @@ isahc.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
tiktoken-rs = "0.4"
[dev-dependencies]

View file

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};
// Data types for chat completion requests
#[derive(Serialize)]
#[derive(Debug, Serialize)]
struct OpenAIRequest {
model: String,
messages: Vec<RequestMessage>,

View file

@ -473,7 +473,7 @@ impl Assistant {
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
let model = "gpt-3.5-turbo";
let model = "gpt-3.5-turbo-0613";
let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, "", cx);
@ -518,7 +518,7 @@ impl Assistant {
MessageMetadata {
role: Role::User,
sent_at: Local::now(),
error: None,
status: MessageStatus::Done,
},
);
@ -543,7 +543,7 @@ impl Assistant {
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
let messages = self
.open_ai_request_messages(cx)
.messages(cx)
.into_iter()
.filter_map(|message| {
Some(tiktoken_rs::ChatCompletionRequestMessage {
@ -552,7 +552,7 @@ impl Assistant {
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: message.content,
content: self.buffer.read(cx).text_for_range(message.range).collect(),
name: None,
})
})
@ -589,19 +589,76 @@ impl Assistant {
cx.notify();
}
fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> {
fn assist(
&mut self,
selected_messages: HashSet<MessageId>,
cx: &mut ModelContext<Self>,
) -> Vec<MessageAnchor> {
let mut user_messages = Vec::new();
let mut tasks = Vec::new();
for selected_message_id in selected_messages {
let selected_message_role =
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
metadata.role
} else {
continue;
};
if selected_message_role == Role::Assistant {
if let Some(user_message) = self.insert_message_after(
selected_message_id,
Role::User,
MessageStatus::Done,
cx,
) {
user_messages.push(user_message);
} else {
continue;
}
} else {
let request = OpenAIRequest {
model: self.model.clone(),
messages: self.open_ai_request_messages(cx),
messages: self
.messages(cx)
.filter(|message| matches!(message.status, MessageStatus::Done))
.flat_map(|message| {
let mut system_message = None;
if message.id == selected_message_id {
system_message = Some(RequestMessage {
role: Role::System,
content: concat!(
"Treat the following messages as additional knowledge you have learned about, ",
"but act as if they were not part of this conversation. That is, treat them ",
"as if the user didn't see them and couldn't possibly inquire about them."
).into()
});
}
Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
})
.chain(Some(RequestMessage {
role: Role::System,
content: format!(
"Direct your reply to message with id {}. Do not include a [Message X] header.",
selected_message_id.0
),
}))
.collect(),
stream: true,
};
let api_key = self.api_key.borrow().clone()?;
let Some(api_key) = self.api_key.borrow().clone() else { continue };
let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message =
self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?;
let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
let task = cx.spawn_weak({
let assistant_message = self
.insert_message_after(
selected_message_id,
Role::Assistant,
MessageStatus::Pending,
cx,
)
.unwrap();
tasks.push(cx.spawn_weak({
|this, mut cx| async move {
let assistant_message_id = assistant_message.id;
let stream_completion = async {
@ -614,20 +671,19 @@ impl Assistant {
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into();
let message_ix = this
.message_anchors
.iter()
.position(|message| message.id == assistant_message_id)?;
let message_ix = this.message_anchors.iter().position(
|message| message.id == assistant_message_id,
)?;
this.buffer.update(cx, |buffer, cx| {
let offset = if message_ix + 1 == this.message_anchors.len()
{
buffer.len()
} else {
this.message_anchors[message_ix + 1]
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message
.start
.to_offset(buffer)
.saturating_sub(1)
};
});
buffer.edit([(offset..offset, text)], None, cx);
});
cx.emit(AssistantEvent::StreamedCompletion);
@ -635,13 +691,15 @@ impl Assistant {
Some(())
});
}
smol::future::yield_now().await;
}
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
this.pending_completions
.retain(|completion| completion.id != this.completion_count);
this.pending_completions.retain(|completion| {
completion.id != this.completion_count
});
this.summarize(cx);
});
@ -651,42 +709,57 @@ impl Assistant {
let result = stream_completion.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Err(error) = result {
if let Some(metadata) =
this.messages_metadata.get_mut(&assistant_message.id)
{
metadata.error = Some(error.to_string().trim().into());
match result {
Ok(_) => {
metadata.status = MessageStatus::Done;
}
Err(error) => {
metadata.status = MessageStatus::Error(
error.to_string().trim().into(),
);
}
}
cx.notify();
}
}
});
}
}
});
}));
}
}
if !tasks.is_empty() {
self.pending_completions.push(PendingCompletion {
id: post_inc(&mut self.completion_count),
_task: task,
_tasks: tasks,
});
Some((assistant_message, user_message))
}
user_messages
}
fn cancel_last_assist(&mut self) -> bool {
self.pending_completions.pop().is_some()
}
fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
for id in ids {
if let Some(metadata) = self.messages_metadata.get_mut(&id) {
metadata.role.cycle();
cx.emit(AssistantEvent::MessagesEdited);
cx.notify();
}
}
}
fn insert_message_after(
&mut self,
message_id: MessageId,
role: Role,
status: MessageStatus,
cx: &mut ModelContext<Self>,
) -> Option<MessageAnchor> {
if let Some(prev_message_ix) = self
@ -713,7 +786,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
error: None,
status,
},
);
cx.emit(AssistantEvent::MessagesEdited);
@ -772,7 +845,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
error: None,
status: MessageStatus::Done,
},
);
@ -814,7 +887,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
error: None,
status: MessageStatus::Done,
},
);
(Some(selection), Some(suffix))
@ -833,16 +906,19 @@ impl Assistant {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
let mut messages = self.open_ai_request_messages(cx);
messages.truncate(2);
messages.push(RequestMessage {
let messages = self
.messages(cx)
.take(2)
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.chain(Some(RequestMessage {
role: Role::User,
content: "Summarize the conversation into a short title without punctuation"
content:
"Summarize the conversation into a short title without punctuation"
.into(),
});
}));
let request = OpenAIRequest {
model: self.model.clone(),
messages,
messages: messages.collect(),
stream: true,
};
@ -870,24 +946,39 @@ impl Assistant {
}
}
fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
let buffer = self.buffer.read(cx);
self.messages(cx)
.map(|message| RequestMessage {
role: message.role,
content: buffer.text_for_range(message.range).collect(),
})
.collect()
fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
self.messages_for_offsets([offset], cx).pop()
}
fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
fn messages_for_offsets(
&self,
offsets: impl IntoIterator<Item = usize>,
cx: &AppContext,
) -> Vec<Message> {
let mut result = Vec::new();
let buffer_len = self.buffer.read(cx).len();
let mut messages = self.messages(cx).peekable();
while let Some(message) = messages.next() {
if message.range.contains(&offset) || messages.peek().is_none() {
return Some(message);
let mut offsets = offsets.into_iter().peekable();
while let Some(offset) = offsets.next() {
// Skip messages that start after the offset.
while messages.peek().map_or(false, |message| {
message.range.end < offset || (message.range.end == offset && offset < buffer_len)
}) {
messages.next();
}
let Some(message) = messages.peek() else { continue };
// Skip offsets that are in the same message.
while offsets.peek().map_or(false, |offset| {
message.range.contains(offset) || message.range.end == buffer_len
}) {
offsets.next();
}
None
result.push(message.clone());
}
result
}
fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
@ -916,7 +1007,7 @@ impl Assistant {
anchor: message_anchor.start,
role: metadata.role,
sent_at: metadata.sent_at,
error: metadata.error.clone(),
status: metadata.status.clone(),
});
}
None
@ -926,7 +1017,7 @@ impl Assistant {
struct PendingCompletion {
id: usize,
_task: Task<()>,
_tasks: Vec<Task<()>>,
}
enum AssistantEditorEvent {
@ -979,20 +1070,31 @@ impl AssistantEditor {
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
let user_message = self.assistant.update(cx, |assistant, cx| {
let (_, user_message) = assistant.assist(cx)?;
Some(user_message)
});
let cursors = self.cursors(cx);
if let Some(user_message) = user_message {
let cursor = user_message
let user_messages = self.assistant.update(cx, |assistant, cx| {
let selected_messages = assistant
.messages_for_offsets(cursors, cx)
.into_iter()
.map(|message| message.id)
.collect();
assistant.assist(selected_messages, cx)
});
let new_selections = user_messages
.iter()
.map(|message| {
let cursor = message
.start
.to_offset(&self.assistant.read(cx).buffer.read(cx));
.to_offset(self.assistant.read(cx).buffer.read(cx));
cursor..cursor
})
.collect::<Vec<_>>();
if !new_selections.is_empty() {
self.editor.update(cx, |editor, cx| {
editor.change_selections(
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
cx,
|selections| selections.select_ranges([cursor..cursor]),
|selections| selections.select_ranges(new_selections),
);
});
}
@ -1008,14 +1110,25 @@ impl AssistantEditor {
}
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
let cursors = self.cursors(cx);
self.assistant.update(cx, |assistant, cx| {
if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
assistant.cycle_message_role(message.id, cx);
}
let messages = assistant
.messages_for_offsets(cursors, cx)
.into_iter()
.map(|message| message.id)
.collect();
assistant.cycle_message_roles(messages, cx)
});
}
fn cursors(&self, cx: &AppContext) -> Vec<usize> {
let selections = self.editor.read(cx).selections.all::<usize>(cx);
selections
.into_iter()
.map(|selection| selection.head())
.collect()
}
fn handle_assistant_event(
&mut self,
_: ModelHandle<Assistant>,
@ -1144,7 +1257,10 @@ impl AssistantEditor {
let assistant = assistant.clone();
move |_, _, cx| {
assistant.update(cx, |assistant, cx| {
assistant.cycle_message_role(message_id, cx)
assistant.cycle_message_roles(
HashSet::from_iter(Some(message_id)),
cx,
)
})
}
});
@ -1160,7 +1276,9 @@ impl AssistantEditor {
.with_style(style.sent_at.container)
.aligned(),
)
.with_children(message.error.as_ref().map(|error| {
.with_children(
if let MessageStatus::Error(error) = &message.status {
Some(
Svg::new("icons/circle_x_mark_12.svg")
.with_color(style.error_icon.color)
.constrained()
@ -1174,8 +1292,12 @@ impl AssistantEditor {
theme.tooltip.clone(),
cx,
)
.aligned()
}))
.aligned(),
)
} else {
None
},
)
.aligned()
.left()
.contained()
@ -1308,8 +1430,8 @@ impl AssistantEditor {
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
self.assistant.update(cx, |assistant, cx| {
let new_model = match assistant.model.as_str() {
"gpt-4" => "gpt-3.5-turbo",
_ => "gpt-4",
"gpt-4-0613" => "gpt-3.5-turbo-0613",
_ => "gpt-4-0613",
};
assistant.set_model(new_model.into(), cx);
});
@ -1423,7 +1545,14 @@ struct MessageAnchor {
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
error: Option<Arc<str>>,
status: MessageStatus,
}
#[derive(Clone, Debug)]
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
}
#[derive(Clone, Debug)]
@ -1434,7 +1563,18 @@ pub struct Message {
anchor: language::Anchor,
role: Role,
sent_at: DateTime<Local>,
error: Option<Arc<str>>,
status: MessageStatus,
}
impl Message {
fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
let mut content = format!("[Message {}]\n", self.id.0).to_string();
content.extend(buffer.text_for_range(self.range.clone()));
RequestMessage {
role: self.role,
content,
}
}
}
async fn stream_completion(
@ -1542,7 +1682,7 @@ mod tests {
let message_2 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_1.id, Role::Assistant, cx)
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@ -1566,7 +1706,7 @@ mod tests {
let message_3 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_2.id, Role::User, cx)
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@ -1580,7 +1720,7 @@ mod tests {
let message_4 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_2.id, Role::User, cx)
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@ -1641,7 +1781,7 @@ mod tests {
// Ensure we can still insert after a merged message.
let message_5 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_1.id, Role::System, cx)
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@ -1747,6 +1887,66 @@ mod tests {
);
}
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
let registry = Arc::new(LanguageRegistry::test());
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
let buffer = assistant.read(cx).buffer.clone();
let message_1 = assistant.read(cx).message_anchors[0].clone();
assert_eq!(
messages(&assistant, cx),
vec![(message_1.id, Role::User, 0..0)]
);
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
let message_2 = assistant
.update(cx, |assistant, cx| {
assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
let message_3 = assistant
.update(cx, |assistant, cx| {
assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_2.id, Role::User, 4..8),
(message_3.id, Role::User, 8..11)
]
);
assert_eq!(
message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
[message_1.id, message_2.id, message_3.id]
);
assert_eq!(
message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
[message_1.id, message_3.id]
);
fn message_ids_for_offsets(
assistant: &ModelHandle<Assistant>,
offsets: &[usize],
cx: &AppContext,
) -> Vec<MessageId> {
assistant
.read(cx)
.messages_for_offsets(offsets.iter().copied(), cx)
.into_iter()
.map(|message| message.id)
.collect()
}
}
fn messages(
assistant: &ModelHandle<Assistant>,
cx: &AppContext,