mirror of
https://github.com/zed-industries/zed.git
synced 2025-02-03 08:54:04 +00:00
Instruct the assistant to reply to a specific message (#2631)
Closes https://linear.app/zed-industries/issue/Z-2384/hitting-cmd-enter-in-a-user-or-system-message-should-generate-a Release Notes: - Introduced the ability to generate assistant messages for any user/system message, as well as generating multiple assists at the same time, one for each cursor. (preview-only)
This commit is contained in:
commit
6ed86781b2
4 changed files with 344 additions and 142 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -114,6 +114,7 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"settings",
|
"settings",
|
||||||
|
"smol",
|
||||||
"theme",
|
"theme",
|
||||||
"tiktoken-rs",
|
"tiktoken-rs",
|
||||||
"util",
|
"util",
|
||||||
|
|
|
@ -28,6 +28,7 @@ isahc.workspace = true
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
smol.workspace = true
|
||||||
tiktoken-rs = "0.4"
|
tiktoken-rs = "0.4"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use std::fmt::{self, Display};
|
use std::fmt::{self, Display};
|
||||||
|
|
||||||
// Data types for chat completion requests
|
// Data types for chat completion requests
|
||||||
#[derive(Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct OpenAIRequest {
|
struct OpenAIRequest {
|
||||||
model: String,
|
model: String,
|
||||||
messages: Vec<RequestMessage>,
|
messages: Vec<RequestMessage>,
|
||||||
|
|
|
@ -473,7 +473,7 @@ impl Assistant {
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let model = "gpt-3.5-turbo";
|
let model = "gpt-3.5-turbo-0613";
|
||||||
let markdown = language_registry.language_for_name("Markdown");
|
let markdown = language_registry.language_for_name("Markdown");
|
||||||
let buffer = cx.add_model(|cx| {
|
let buffer = cx.add_model(|cx| {
|
||||||
let mut buffer = Buffer::new(0, "", cx);
|
let mut buffer = Buffer::new(0, "", cx);
|
||||||
|
@ -518,7 +518,7 @@ impl Assistant {
|
||||||
MessageMetadata {
|
MessageMetadata {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
sent_at: Local::now(),
|
sent_at: Local::now(),
|
||||||
error: None,
|
status: MessageStatus::Done,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -543,7 +543,7 @@ impl Assistant {
|
||||||
|
|
||||||
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
|
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
|
||||||
let messages = self
|
let messages = self
|
||||||
.open_ai_request_messages(cx)
|
.messages(cx)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|message| {
|
.filter_map(|message| {
|
||||||
Some(tiktoken_rs::ChatCompletionRequestMessage {
|
Some(tiktoken_rs::ChatCompletionRequestMessage {
|
||||||
|
@ -552,7 +552,7 @@ impl Assistant {
|
||||||
Role::Assistant => "assistant".into(),
|
Role::Assistant => "assistant".into(),
|
||||||
Role::System => "system".into(),
|
Role::System => "system".into(),
|
||||||
},
|
},
|
||||||
content: message.content,
|
content: self.buffer.read(cx).text_for_range(message.range).collect(),
|
||||||
name: None,
|
name: None,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -589,97 +589,169 @@ impl Assistant {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> {
|
fn assist(
|
||||||
let request = OpenAIRequest {
|
&mut self,
|
||||||
model: self.model.clone(),
|
selected_messages: HashSet<MessageId>,
|
||||||
messages: self.open_ai_request_messages(cx),
|
cx: &mut ModelContext<Self>,
|
||||||
stream: true,
|
) -> Vec<MessageAnchor> {
|
||||||
};
|
let mut user_messages = Vec::new();
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
for selected_message_id in selected_messages {
|
||||||
|
let selected_message_role =
|
||||||
|
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
|
||||||
|
metadata.role
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
let api_key = self.api_key.borrow().clone()?;
|
if selected_message_role == Role::Assistant {
|
||||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
if let Some(user_message) = self.insert_message_after(
|
||||||
let assistant_message =
|
selected_message_id,
|
||||||
self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?;
|
Role::User,
|
||||||
let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
|
MessageStatus::Done,
|
||||||
let task = cx.spawn_weak({
|
cx,
|
||||||
|this, mut cx| async move {
|
) {
|
||||||
let assistant_message_id = assistant_message.id;
|
user_messages.push(user_message);
|
||||||
let stream_completion = async {
|
} else {
|
||||||
let mut messages = stream.await?;
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let request = OpenAIRequest {
|
||||||
|
model: self.model.clone(),
|
||||||
|
messages: self
|
||||||
|
.messages(cx)
|
||||||
|
.filter(|message| matches!(message.status, MessageStatus::Done))
|
||||||
|
.flat_map(|message| {
|
||||||
|
let mut system_message = None;
|
||||||
|
if message.id == selected_message_id {
|
||||||
|
system_message = Some(RequestMessage {
|
||||||
|
role: Role::System,
|
||||||
|
content: concat!(
|
||||||
|
"Treat the following messages as additional knowledge you have learned about, ",
|
||||||
|
"but act as if they were not part of this conversation. That is, treat them ",
|
||||||
|
"as if the user didn't see them and couldn't possibly inquire about them."
|
||||||
|
).into()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
|
||||||
|
})
|
||||||
|
.chain(Some(RequestMessage {
|
||||||
|
role: Role::System,
|
||||||
|
content: format!(
|
||||||
|
"Direct your reply to message with id {}. Do not include a [Message X] header.",
|
||||||
|
selected_message_id.0
|
||||||
|
),
|
||||||
|
}))
|
||||||
|
.collect(),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(api_key) = self.api_key.borrow().clone() else { continue };
|
||||||
|
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||||
|
let assistant_message = self
|
||||||
|
.insert_message_after(
|
||||||
|
selected_message_id,
|
||||||
|
Role::Assistant,
|
||||||
|
MessageStatus::Pending,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
tasks.push(cx.spawn_weak({
|
||||||
|
|this, mut cx| async move {
|
||||||
|
let assistant_message_id = assistant_message.id;
|
||||||
|
let stream_completion = async {
|
||||||
|
let mut messages = stream.await?;
|
||||||
|
|
||||||
|
while let Some(message) = messages.next().await {
|
||||||
|
let mut message = message?;
|
||||||
|
if let Some(choice) = message.choices.pop() {
|
||||||
|
this.upgrade(&cx)
|
||||||
|
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
||||||
|
.update(&mut cx, |this, cx| {
|
||||||
|
let text: Arc<str> = choice.delta.content?.into();
|
||||||
|
let message_ix = this.message_anchors.iter().position(
|
||||||
|
|message| message.id == assistant_message_id,
|
||||||
|
)?;
|
||||||
|
this.buffer.update(cx, |buffer, cx| {
|
||||||
|
let offset = this.message_anchors[message_ix + 1..]
|
||||||
|
.iter()
|
||||||
|
.find(|message| message.start.is_valid(buffer))
|
||||||
|
.map_or(buffer.len(), |message| {
|
||||||
|
message
|
||||||
|
.start
|
||||||
|
.to_offset(buffer)
|
||||||
|
.saturating_sub(1)
|
||||||
|
});
|
||||||
|
buffer.edit([(offset..offset, text)], None, cx);
|
||||||
|
});
|
||||||
|
cx.emit(AssistantEvent::StreamedCompletion);
|
||||||
|
|
||||||
|
Some(())
|
||||||
|
});
|
||||||
|
}
|
||||||
|
smol::future::yield_now().await;
|
||||||
|
}
|
||||||
|
|
||||||
while let Some(message) = messages.next().await {
|
|
||||||
let mut message = message?;
|
|
||||||
if let Some(choice) = message.choices.pop() {
|
|
||||||
this.upgrade(&cx)
|
this.upgrade(&cx)
|
||||||
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
||||||
.update(&mut cx, |this, cx| {
|
.update(&mut cx, |this, cx| {
|
||||||
let text: Arc<str> = choice.delta.content?.into();
|
this.pending_completions.retain(|completion| {
|
||||||
let message_ix = this
|
completion.id != this.completion_count
|
||||||
.message_anchors
|
|
||||||
.iter()
|
|
||||||
.position(|message| message.id == assistant_message_id)?;
|
|
||||||
this.buffer.update(cx, |buffer, cx| {
|
|
||||||
let offset = if message_ix + 1 == this.message_anchors.len()
|
|
||||||
{
|
|
||||||
buffer.len()
|
|
||||||
} else {
|
|
||||||
this.message_anchors[message_ix + 1]
|
|
||||||
.start
|
|
||||||
.to_offset(buffer)
|
|
||||||
.saturating_sub(1)
|
|
||||||
};
|
|
||||||
buffer.edit([(offset..offset, text)], None, cx);
|
|
||||||
});
|
});
|
||||||
cx.emit(AssistantEvent::StreamedCompletion);
|
this.summarize(cx);
|
||||||
|
|
||||||
Some(())
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = stream_completion.await;
|
||||||
|
if let Some(this) = this.upgrade(&cx) {
|
||||||
|
this.update(&mut cx, |this, cx| {
|
||||||
|
if let Some(metadata) =
|
||||||
|
this.messages_metadata.get_mut(&assistant_message.id)
|
||||||
|
{
|
||||||
|
match result {
|
||||||
|
Ok(_) => {
|
||||||
|
metadata.status = MessageStatus::Done;
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
metadata.status = MessageStatus::Error(
|
||||||
|
error.to_string().trim().into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}));
|
||||||
this.upgrade(&cx)
|
|
||||||
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
|
||||||
.update(&mut cx, |this, cx| {
|
|
||||||
this.pending_completions
|
|
||||||
.retain(|completion| completion.id != this.completion_count);
|
|
||||||
this.summarize(cx);
|
|
||||||
});
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = stream_completion.await;
|
|
||||||
if let Some(this) = this.upgrade(&cx) {
|
|
||||||
this.update(&mut cx, |this, cx| {
|
|
||||||
if let Err(error) = result {
|
|
||||||
if let Some(metadata) =
|
|
||||||
this.messages_metadata.get_mut(&assistant_message.id)
|
|
||||||
{
|
|
||||||
metadata.error = Some(error.to_string().trim().into());
|
|
||||||
cx.notify();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
|
||||||
self.pending_completions.push(PendingCompletion {
|
if !tasks.is_empty() {
|
||||||
id: post_inc(&mut self.completion_count),
|
self.pending_completions.push(PendingCompletion {
|
||||||
_task: task,
|
id: post_inc(&mut self.completion_count),
|
||||||
});
|
_tasks: tasks,
|
||||||
Some((assistant_message, user_message))
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
user_messages
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cancel_last_assist(&mut self) -> bool {
|
fn cancel_last_assist(&mut self) -> bool {
|
||||||
self.pending_completions.pop().is_some()
|
self.pending_completions.pop().is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
|
fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
|
||||||
if let Some(metadata) = self.messages_metadata.get_mut(&id) {
|
for id in ids {
|
||||||
metadata.role.cycle();
|
if let Some(metadata) = self.messages_metadata.get_mut(&id) {
|
||||||
cx.emit(AssistantEvent::MessagesEdited);
|
metadata.role.cycle();
|
||||||
cx.notify();
|
cx.emit(AssistantEvent::MessagesEdited);
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -687,6 +759,7 @@ impl Assistant {
|
||||||
&mut self,
|
&mut self,
|
||||||
message_id: MessageId,
|
message_id: MessageId,
|
||||||
role: Role,
|
role: Role,
|
||||||
|
status: MessageStatus,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Option<MessageAnchor> {
|
) -> Option<MessageAnchor> {
|
||||||
if let Some(prev_message_ix) = self
|
if let Some(prev_message_ix) = self
|
||||||
|
@ -713,7 +786,7 @@ impl Assistant {
|
||||||
MessageMetadata {
|
MessageMetadata {
|
||||||
role,
|
role,
|
||||||
sent_at: Local::now(),
|
sent_at: Local::now(),
|
||||||
error: None,
|
status,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
cx.emit(AssistantEvent::MessagesEdited);
|
cx.emit(AssistantEvent::MessagesEdited);
|
||||||
|
@ -772,7 +845,7 @@ impl Assistant {
|
||||||
MessageMetadata {
|
MessageMetadata {
|
||||||
role,
|
role,
|
||||||
sent_at: Local::now(),
|
sent_at: Local::now(),
|
||||||
error: None,
|
status: MessageStatus::Done,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -814,7 +887,7 @@ impl Assistant {
|
||||||
MessageMetadata {
|
MessageMetadata {
|
||||||
role,
|
role,
|
||||||
sent_at: Local::now(),
|
sent_at: Local::now(),
|
||||||
error: None,
|
status: MessageStatus::Done,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
(Some(selection), Some(suffix))
|
(Some(selection), Some(suffix))
|
||||||
|
@ -833,16 +906,19 @@ impl Assistant {
|
||||||
if self.message_anchors.len() >= 2 && self.summary.is_none() {
|
if self.message_anchors.len() >= 2 && self.summary.is_none() {
|
||||||
let api_key = self.api_key.borrow().clone();
|
let api_key = self.api_key.borrow().clone();
|
||||||
if let Some(api_key) = api_key {
|
if let Some(api_key) = api_key {
|
||||||
let mut messages = self.open_ai_request_messages(cx);
|
let messages = self
|
||||||
messages.truncate(2);
|
.messages(cx)
|
||||||
messages.push(RequestMessage {
|
.take(2)
|
||||||
role: Role::User,
|
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
||||||
content: "Summarize the conversation into a short title without punctuation"
|
.chain(Some(RequestMessage {
|
||||||
.into(),
|
role: Role::User,
|
||||||
});
|
content:
|
||||||
|
"Summarize the conversation into a short title without punctuation"
|
||||||
|
.into(),
|
||||||
|
}));
|
||||||
let request = OpenAIRequest {
|
let request = OpenAIRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
messages,
|
messages: messages.collect(),
|
||||||
stream: true,
|
stream: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -870,24 +946,39 @@ impl Assistant {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
|
fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
|
||||||
let buffer = self.buffer.read(cx);
|
self.messages_for_offsets([offset], cx).pop()
|
||||||
self.messages(cx)
|
|
||||||
.map(|message| RequestMessage {
|
|
||||||
role: message.role,
|
|
||||||
content: buffer.text_for_range(message.range).collect(),
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
|
fn messages_for_offsets(
|
||||||
|
&self,
|
||||||
|
offsets: impl IntoIterator<Item = usize>,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> Vec<Message> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
|
||||||
|
let buffer_len = self.buffer.read(cx).len();
|
||||||
let mut messages = self.messages(cx).peekable();
|
let mut messages = self.messages(cx).peekable();
|
||||||
while let Some(message) = messages.next() {
|
let mut offsets = offsets.into_iter().peekable();
|
||||||
if message.range.contains(&offset) || messages.peek().is_none() {
|
while let Some(offset) = offsets.next() {
|
||||||
return Some(message);
|
// Skip messages that start after the offset.
|
||||||
|
while messages.peek().map_or(false, |message| {
|
||||||
|
message.range.end < offset || (message.range.end == offset && offset < buffer_len)
|
||||||
|
}) {
|
||||||
|
messages.next();
|
||||||
}
|
}
|
||||||
|
let Some(message) = messages.peek() else { continue };
|
||||||
|
|
||||||
|
// Skip offsets that are in the same message.
|
||||||
|
while offsets.peek().map_or(false, |offset| {
|
||||||
|
message.range.contains(offset) || message.range.end == buffer_len
|
||||||
|
}) {
|
||||||
|
offsets.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
result.push(message.clone());
|
||||||
}
|
}
|
||||||
None
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
|
fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
|
||||||
|
@ -916,7 +1007,7 @@ impl Assistant {
|
||||||
anchor: message_anchor.start,
|
anchor: message_anchor.start,
|
||||||
role: metadata.role,
|
role: metadata.role,
|
||||||
sent_at: metadata.sent_at,
|
sent_at: metadata.sent_at,
|
||||||
error: metadata.error.clone(),
|
status: metadata.status.clone(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
|
@ -926,7 +1017,7 @@ impl Assistant {
|
||||||
|
|
||||||
struct PendingCompletion {
|
struct PendingCompletion {
|
||||||
id: usize,
|
id: usize,
|
||||||
_task: Task<()>,
|
_tasks: Vec<Task<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum AssistantEditorEvent {
|
enum AssistantEditorEvent {
|
||||||
|
@ -979,20 +1070,31 @@ impl AssistantEditor {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
|
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
|
||||||
let user_message = self.assistant.update(cx, |assistant, cx| {
|
let cursors = self.cursors(cx);
|
||||||
let (_, user_message) = assistant.assist(cx)?;
|
|
||||||
Some(user_message)
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(user_message) = user_message {
|
let user_messages = self.assistant.update(cx, |assistant, cx| {
|
||||||
let cursor = user_message
|
let selected_messages = assistant
|
||||||
.start
|
.messages_for_offsets(cursors, cx)
|
||||||
.to_offset(&self.assistant.read(cx).buffer.read(cx));
|
.into_iter()
|
||||||
|
.map(|message| message.id)
|
||||||
|
.collect();
|
||||||
|
assistant.assist(selected_messages, cx)
|
||||||
|
});
|
||||||
|
let new_selections = user_messages
|
||||||
|
.iter()
|
||||||
|
.map(|message| {
|
||||||
|
let cursor = message
|
||||||
|
.start
|
||||||
|
.to_offset(self.assistant.read(cx).buffer.read(cx));
|
||||||
|
cursor..cursor
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
if !new_selections.is_empty() {
|
||||||
self.editor.update(cx, |editor, cx| {
|
self.editor.update(cx, |editor, cx| {
|
||||||
editor.change_selections(
|
editor.change_selections(
|
||||||
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
|
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
|
||||||
cx,
|
cx,
|
||||||
|selections| selections.select_ranges([cursor..cursor]),
|
|selections| selections.select_ranges(new_selections),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -1008,14 +1110,25 @@ impl AssistantEditor {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
|
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
|
||||||
let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
|
let cursors = self.cursors(cx);
|
||||||
self.assistant.update(cx, |assistant, cx| {
|
self.assistant.update(cx, |assistant, cx| {
|
||||||
if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
|
let messages = assistant
|
||||||
assistant.cycle_message_role(message.id, cx);
|
.messages_for_offsets(cursors, cx)
|
||||||
}
|
.into_iter()
|
||||||
|
.map(|message| message.id)
|
||||||
|
.collect();
|
||||||
|
assistant.cycle_message_roles(messages, cx)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cursors(&self, cx: &AppContext) -> Vec<usize> {
|
||||||
|
let selections = self.editor.read(cx).selections.all::<usize>(cx);
|
||||||
|
selections
|
||||||
|
.into_iter()
|
||||||
|
.map(|selection| selection.head())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn handle_assistant_event(
|
fn handle_assistant_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
_: ModelHandle<Assistant>,
|
_: ModelHandle<Assistant>,
|
||||||
|
@ -1144,7 +1257,10 @@ impl AssistantEditor {
|
||||||
let assistant = assistant.clone();
|
let assistant = assistant.clone();
|
||||||
move |_, _, cx| {
|
move |_, _, cx| {
|
||||||
assistant.update(cx, |assistant, cx| {
|
assistant.update(cx, |assistant, cx| {
|
||||||
assistant.cycle_message_role(message_id, cx)
|
assistant.cycle_message_roles(
|
||||||
|
HashSet::from_iter(Some(message_id)),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1160,22 +1276,28 @@ impl AssistantEditor {
|
||||||
.with_style(style.sent_at.container)
|
.with_style(style.sent_at.container)
|
||||||
.aligned(),
|
.aligned(),
|
||||||
)
|
)
|
||||||
.with_children(message.error.as_ref().map(|error| {
|
.with_children(
|
||||||
Svg::new("icons/circle_x_mark_12.svg")
|
if let MessageStatus::Error(error) = &message.status {
|
||||||
.with_color(style.error_icon.color)
|
Some(
|
||||||
.constrained()
|
Svg::new("icons/circle_x_mark_12.svg")
|
||||||
.with_width(style.error_icon.width)
|
.with_color(style.error_icon.color)
|
||||||
.contained()
|
.constrained()
|
||||||
.with_style(style.error_icon.container)
|
.with_width(style.error_icon.width)
|
||||||
.with_tooltip::<ErrorTooltip>(
|
.contained()
|
||||||
message_id.0,
|
.with_style(style.error_icon.container)
|
||||||
error.to_string(),
|
.with_tooltip::<ErrorTooltip>(
|
||||||
None,
|
message_id.0,
|
||||||
theme.tooltip.clone(),
|
error.to_string(),
|
||||||
cx,
|
None,
|
||||||
|
theme.tooltip.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.aligned(),
|
||||||
)
|
)
|
||||||
.aligned()
|
} else {
|
||||||
}))
|
None
|
||||||
|
},
|
||||||
|
)
|
||||||
.aligned()
|
.aligned()
|
||||||
.left()
|
.left()
|
||||||
.contained()
|
.contained()
|
||||||
|
@ -1308,8 +1430,8 @@ impl AssistantEditor {
|
||||||
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
|
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
|
||||||
self.assistant.update(cx, |assistant, cx| {
|
self.assistant.update(cx, |assistant, cx| {
|
||||||
let new_model = match assistant.model.as_str() {
|
let new_model = match assistant.model.as_str() {
|
||||||
"gpt-4" => "gpt-3.5-turbo",
|
"gpt-4-0613" => "gpt-3.5-turbo-0613",
|
||||||
_ => "gpt-4",
|
_ => "gpt-4-0613",
|
||||||
};
|
};
|
||||||
assistant.set_model(new_model.into(), cx);
|
assistant.set_model(new_model.into(), cx);
|
||||||
});
|
});
|
||||||
|
@ -1423,7 +1545,14 @@ struct MessageAnchor {
|
||||||
struct MessageMetadata {
|
struct MessageMetadata {
|
||||||
role: Role,
|
role: Role,
|
||||||
sent_at: DateTime<Local>,
|
sent_at: DateTime<Local>,
|
||||||
error: Option<Arc<str>>,
|
status: MessageStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
enum MessageStatus {
|
||||||
|
Pending,
|
||||||
|
Done,
|
||||||
|
Error(Arc<str>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -1434,7 +1563,18 @@ pub struct Message {
|
||||||
anchor: language::Anchor,
|
anchor: language::Anchor,
|
||||||
role: Role,
|
role: Role,
|
||||||
sent_at: DateTime<Local>,
|
sent_at: DateTime<Local>,
|
||||||
error: Option<Arc<str>>,
|
status: MessageStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Message {
|
||||||
|
fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
|
||||||
|
let mut content = format!("[Message {}]\n", self.id.0).to_string();
|
||||||
|
content.extend(buffer.text_for_range(self.range.clone()));
|
||||||
|
RequestMessage {
|
||||||
|
role: self.role,
|
||||||
|
content,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_completion(
|
async fn stream_completion(
|
||||||
|
@ -1542,7 +1682,7 @@ mod tests {
|
||||||
|
|
||||||
let message_2 = assistant.update(cx, |assistant, cx| {
|
let message_2 = assistant.update(cx, |assistant, cx| {
|
||||||
assistant
|
assistant
|
||||||
.insert_message_after(message_1.id, Role::Assistant, cx)
|
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1566,7 +1706,7 @@ mod tests {
|
||||||
|
|
||||||
let message_3 = assistant.update(cx, |assistant, cx| {
|
let message_3 = assistant.update(cx, |assistant, cx| {
|
||||||
assistant
|
assistant
|
||||||
.insert_message_after(message_2.id, Role::User, cx)
|
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1580,7 +1720,7 @@ mod tests {
|
||||||
|
|
||||||
let message_4 = assistant.update(cx, |assistant, cx| {
|
let message_4 = assistant.update(cx, |assistant, cx| {
|
||||||
assistant
|
assistant
|
||||||
.insert_message_after(message_2.id, Role::User, cx)
|
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1641,7 +1781,7 @@ mod tests {
|
||||||
// Ensure we can still insert after a merged message.
|
// Ensure we can still insert after a merged message.
|
||||||
let message_5 = assistant.update(cx, |assistant, cx| {
|
let message_5 = assistant.update(cx, |assistant, cx| {
|
||||||
assistant
|
assistant
|
||||||
.insert_message_after(message_1.id, Role::System, cx)
|
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1747,6 +1887,66 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
fn test_messages_for_offsets(cx: &mut AppContext) {
|
||||||
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
|
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
|
||||||
|
let buffer = assistant.read(cx).buffer.clone();
|
||||||
|
|
||||||
|
let message_1 = assistant.read(cx).message_anchors[0].clone();
|
||||||
|
assert_eq!(
|
||||||
|
messages(&assistant, cx),
|
||||||
|
vec![(message_1.id, Role::User, 0..0)]
|
||||||
|
);
|
||||||
|
|
||||||
|
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
|
||||||
|
let message_2 = assistant
|
||||||
|
.update(cx, |assistant, cx| {
|
||||||
|
assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
|
||||||
|
|
||||||
|
let message_3 = assistant
|
||||||
|
.update(cx, |assistant, cx| {
|
||||||
|
assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
|
||||||
|
|
||||||
|
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
|
||||||
|
assert_eq!(
|
||||||
|
messages(&assistant, cx),
|
||||||
|
vec![
|
||||||
|
(message_1.id, Role::User, 0..4),
|
||||||
|
(message_2.id, Role::User, 4..8),
|
||||||
|
(message_3.id, Role::User, 8..11)
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
|
||||||
|
[message_1.id, message_2.id, message_3.id]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
|
||||||
|
[message_1.id, message_3.id]
|
||||||
|
);
|
||||||
|
|
||||||
|
fn message_ids_for_offsets(
|
||||||
|
assistant: &ModelHandle<Assistant>,
|
||||||
|
offsets: &[usize],
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> Vec<MessageId> {
|
||||||
|
assistant
|
||||||
|
.read(cx)
|
||||||
|
.messages_for_offsets(offsets.iter().copied(), cx)
|
||||||
|
.into_iter()
|
||||||
|
.map(|message| message.id)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn messages(
|
fn messages(
|
||||||
assistant: &ModelHandle<Assistant>,
|
assistant: &ModelHandle<Assistant>,
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
|
|
Loading…
Reference in a new issue