Cycle message roles on ctrl-r (#2619)

I'd like to follow up to allow roles to be cycled for the selected range
and support multi-cursors, but this is a start and contains a
refactoring, so going to merge.

Release Notes:

- Added the ability to cycle roles in the assistant with `ctrl-r`
This commit is contained in:
Nathan Sobo 2023-06-16 14:11:01 -06:00 committed by GitHub
commit c3b2b4c4e3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 98 additions and 73 deletions

View file

@ -201,7 +201,8 @@
"bindings": { "bindings": {
"cmd-enter": "assistant::Assist", "cmd-enter": "assistant::Assist",
"cmd->": "assistant::QuoteSelection", "cmd->": "assistant::QuoteSelection",
"shift-enter": "assistant::Split" "shift-enter": "assistant::Split",
"ctrl-r": "assistant::CycleMessageRole"
} }
}, },
{ {

View file

@ -44,6 +44,7 @@ actions!(
NewContext, NewContext,
Assist, Assist,
Split, Split,
CycleMessageRole,
QuoteSelection, QuoteSelection,
ToggleFocus, ToggleFocus,
ResetKey ResetKey
@ -72,6 +73,7 @@ pub fn init(cx: &mut AppContext) {
cx.add_action(AssistantEditor::quote_selection); cx.add_action(AssistantEditor::quote_selection);
cx.capture_action(AssistantEditor::copy); cx.capture_action(AssistantEditor::copy);
cx.capture_action(AssistantEditor::split); cx.capture_action(AssistantEditor::split);
cx.capture_action(AssistantEditor::cycle_message_role);
cx.add_action(AssistantPanel::save_api_key); cx.add_action(AssistantPanel::save_api_key);
cx.add_action(AssistantPanel::reset_api_key); cx.add_action(AssistantPanel::reset_api_key);
cx.add_action( cx.add_action(
@ -446,7 +448,7 @@ enum AssistantEvent {
struct Assistant { struct Assistant {
buffer: ModelHandle<Buffer>, buffer: ModelHandle<Buffer>,
messages: Vec<Message>, message_anchors: Vec<MessageAnchor>,
messages_metadata: HashMap<MessageId, MessageMetadata>, messages_metadata: HashMap<MessageId, MessageMetadata>,
next_message_id: MessageId, next_message_id: MessageId,
summary: Option<String>, summary: Option<String>,
@ -491,7 +493,7 @@ impl Assistant {
}); });
let mut this = Self { let mut this = Self {
messages: Default::default(), message_anchors: Default::default(),
messages_metadata: Default::default(), messages_metadata: Default::default(),
next_message_id: Default::default(), next_message_id: Default::default(),
summary: None, summary: None,
@ -506,11 +508,11 @@ impl Assistant {
api_key, api_key,
buffer, buffer,
}; };
let message = Message { let message = MessageAnchor {
id: MessageId(post_inc(&mut this.next_message_id.0)), id: MessageId(post_inc(&mut this.next_message_id.0)),
start: language::Anchor::MIN, start: language::Anchor::MIN,
}; };
this.messages.push(message.clone()); this.message_anchors.push(message.clone());
this.messages_metadata.insert( this.messages_metadata.insert(
message.id, message.id,
MessageMetadata { MessageMetadata {
@ -587,7 +589,7 @@ impl Assistant {
cx.notify(); cx.notify();
} }
fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> { fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> {
let request = OpenAIRequest { let request = OpenAIRequest {
model: self.model.clone(), model: self.model.clone(),
messages: self.open_ai_request_messages(cx), messages: self.open_ai_request_messages(cx),
@ -597,7 +599,7 @@ impl Assistant {
let api_key = self.api_key.borrow().clone()?; let api_key = self.api_key.borrow().clone()?;
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message = let assistant_message =
self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?; self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?;
let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?; let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
let task = cx.spawn_weak({ let task = cx.spawn_weak({
|this, mut cx| async move { |this, mut cx| async move {
@ -613,14 +615,15 @@ impl Assistant {
.update(&mut cx, |this, cx| { .update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into(); let text: Arc<str> = choice.delta.content?.into();
let message_ix = this let message_ix = this
.messages .message_anchors
.iter() .iter()
.position(|message| message.id == assistant_message_id)?; .position(|message| message.id == assistant_message_id)?;
this.buffer.update(cx, |buffer, cx| { this.buffer.update(cx, |buffer, cx| {
let offset = if message_ix + 1 == this.messages.len() { let offset = if message_ix + 1 == this.message_anchors.len()
{
buffer.len() buffer.len()
} else { } else {
this.messages[message_ix + 1] this.message_anchors[message_ix + 1]
.start .start
.to_offset(buffer) .to_offset(buffer)
.saturating_sub(1) .saturating_sub(1)
@ -685,25 +688,26 @@ impl Assistant {
message_id: MessageId, message_id: MessageId,
role: Role, role: Role,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Option<Message> { ) -> Option<MessageAnchor> {
if let Some(prev_message_ix) = self if let Some(prev_message_ix) = self
.messages .message_anchors
.iter() .iter()
.position(|message| message.id == message_id) .position(|message| message.id == message_id)
{ {
let start = self.buffer.update(cx, |buffer, cx| { let start = self.buffer.update(cx, |buffer, cx| {
let offset = self.messages[prev_message_ix + 1..] let offset = self.message_anchors[prev_message_ix + 1..]
.iter() .iter()
.find(|message| message.start.is_valid(buffer)) .find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1); .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
buffer.edit([(offset..offset, "\n")], None, cx); buffer.edit([(offset..offset, "\n")], None, cx);
buffer.anchor_before(offset + 1) buffer.anchor_before(offset + 1)
}); });
let message = Message { let message = MessageAnchor {
id: MessageId(post_inc(&mut self.next_message_id.0)), id: MessageId(post_inc(&mut self.next_message_id.0)),
start, start,
}; };
self.messages.insert(prev_message_ix + 1, message.clone()); self.message_anchors
.insert(prev_message_ix + 1, message.clone());
self.messages_metadata.insert( self.messages_metadata.insert(
message.id, message.id,
MessageMetadata { MessageMetadata {
@ -723,23 +727,21 @@ impl Assistant {
&mut self, &mut self,
range: Range<usize>, range: Range<usize>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> (Option<Message>, Option<Message>) { ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
let start_message = self.message_for_offset(range.start, cx); let start_message = self.message_for_offset(range.start, cx);
let end_message = self.message_for_offset(range.end, cx); let end_message = self.message_for_offset(range.end, cx);
if let Some((start_message, end_message)) = start_message.zip(end_message) { if let Some((start_message, end_message)) = start_message.zip(end_message) {
let (start_message_ix, _, metadata, message_range) = start_message;
let (end_message_ix, _, _, _) = end_message;
// Prevent splitting when range spans multiple messages. // Prevent splitting when range spans multiple messages.
if start_message_ix != end_message_ix { if start_message.index != end_message.index {
return (None, None); return (None, None);
} }
let role = metadata.role; let message = start_message;
let role = message.role;
let mut edited_buffer = false; let mut edited_buffer = false;
let mut suffix_start = None; let mut suffix_start = None;
if range.start > message_range.start && range.end < message_range.end - 1 { if range.start > message.range.start && range.end < message.range.end - 1 {
if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') { if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
suffix_start = Some(range.end + 1); suffix_start = Some(range.end + 1);
} else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') { } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
@ -748,7 +750,7 @@ impl Assistant {
} }
let suffix = if let Some(suffix_start) = suffix_start { let suffix = if let Some(suffix_start) = suffix_start {
Message { MessageAnchor {
id: MessageId(post_inc(&mut self.next_message_id.0)), id: MessageId(post_inc(&mut self.next_message_id.0)),
start: self.buffer.read(cx).anchor_before(suffix_start), start: self.buffer.read(cx).anchor_before(suffix_start),
} }
@ -757,13 +759,14 @@ impl Assistant {
buffer.edit([(range.end..range.end, "\n")], None, cx); buffer.edit([(range.end..range.end, "\n")], None, cx);
}); });
edited_buffer = true; edited_buffer = true;
Message { MessageAnchor {
id: MessageId(post_inc(&mut self.next_message_id.0)), id: MessageId(post_inc(&mut self.next_message_id.0)),
start: self.buffer.read(cx).anchor_before(range.end + 1), start: self.buffer.read(cx).anchor_before(range.end + 1),
} }
}; };
self.messages.insert(start_message_ix + 1, suffix.clone()); self.message_anchors
.insert(message.index + 1, suffix.clone());
self.messages_metadata.insert( self.messages_metadata.insert(
suffix.id, suffix.id,
MessageMetadata { MessageMetadata {
@ -773,11 +776,11 @@ impl Assistant {
}, },
); );
let new_messages = if range.start == range.end || range.start == message_range.start { let new_messages = if range.start == range.end || range.start == message.range.start {
(None, Some(suffix)) (None, Some(suffix))
} else { } else {
let mut prefix_end = None; let mut prefix_end = None;
if range.start > message_range.start && range.end < message_range.end - 1 { if range.start > message.range.start && range.end < message.range.end - 1 {
if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
prefix_end = Some(range.start + 1); prefix_end = Some(range.start + 1);
} else if self.buffer.read(cx).reversed_chars_at(range.start).next() } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
@ -789,7 +792,7 @@ impl Assistant {
let selection = if let Some(prefix_end) = prefix_end { let selection = if let Some(prefix_end) = prefix_end {
cx.emit(AssistantEvent::MessagesEdited); cx.emit(AssistantEvent::MessagesEdited);
Message { MessageAnchor {
id: MessageId(post_inc(&mut self.next_message_id.0)), id: MessageId(post_inc(&mut self.next_message_id.0)),
start: self.buffer.read(cx).anchor_before(prefix_end), start: self.buffer.read(cx).anchor_before(prefix_end),
} }
@ -798,14 +801,14 @@ impl Assistant {
buffer.edit([(range.start..range.start, "\n")], None, cx) buffer.edit([(range.start..range.start, "\n")], None, cx)
}); });
edited_buffer = true; edited_buffer = true;
Message { MessageAnchor {
id: MessageId(post_inc(&mut self.next_message_id.0)), id: MessageId(post_inc(&mut self.next_message_id.0)),
start: self.buffer.read(cx).anchor_before(range.end + 1), start: self.buffer.read(cx).anchor_before(range.end + 1),
} }
}; };
self.messages self.message_anchors
.insert(start_message_ix + 1, selection.clone()); .insert(message.index + 1, selection.clone());
self.messages_metadata.insert( self.messages_metadata.insert(
selection.id, selection.id,
MessageMetadata { MessageMetadata {
@ -827,7 +830,7 @@ impl Assistant {
} }
fn summarize(&mut self, cx: &mut ModelContext<Self>) { fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.messages.len() >= 2 && self.summary.is_none() { if self.message_anchors.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone(); let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key { if let Some(api_key) = api_key {
let mut messages = self.open_ai_request_messages(cx); let mut messages = self.open_ai_request_messages(cx);
@ -870,50 +873,51 @@ impl Assistant {
fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> { fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
let buffer = self.buffer.read(cx); let buffer = self.buffer.read(cx);
self.messages(cx) self.messages(cx)
.map(|(_ix, _message, metadata, range)| RequestMessage { .map(|message| RequestMessage {
role: metadata.role, role: message.role,
content: buffer.text_for_range(range).collect(), content: buffer.text_for_range(message.range).collect(),
}) })
.collect() .collect()
} }
fn message_for_offset<'a>( fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
&'a self,
offset: usize,
cx: &'a AppContext,
) -> Option<(usize, &Message, &MessageMetadata, Range<usize>)> {
let mut messages = self.messages(cx).peekable(); let mut messages = self.messages(cx).peekable();
while let Some((ix, message, metadata, range)) = messages.next() { while let Some(message) = messages.next() {
if range.contains(&offset) || messages.peek().is_none() { if message.range.contains(&offset) || messages.peek().is_none() {
return Some((ix, message, metadata, range)); return Some(message);
} }
} }
None None
} }
fn messages<'a>( fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
&'a self,
cx: &'a AppContext,
) -> impl 'a + Iterator<Item = (usize, &Message, &MessageMetadata, Range<usize>)> {
let buffer = self.buffer.read(cx); let buffer = self.buffer.read(cx);
let mut messages = self.messages.iter().enumerate().peekable(); let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
iter::from_fn(move || { iter::from_fn(move || {
while let Some((ix, message)) = messages.next() { while let Some((ix, message_anchor)) = message_anchors.next() {
let metadata = self.messages_metadata.get(&message.id)?; let metadata = self.messages_metadata.get(&message_anchor.id)?;
let message_start = message.start.to_offset(buffer); let message_start = message_anchor.start.to_offset(buffer);
let mut message_end = None; let mut message_end = None;
while let Some((_, next_message)) = messages.peek() { while let Some((_, next_message)) = message_anchors.peek() {
if next_message.start.is_valid(buffer) { if next_message.start.is_valid(buffer) {
message_end = Some(next_message.start); message_end = Some(next_message.start);
break; break;
} else { } else {
messages.next(); message_anchors.next();
} }
} }
let message_end = message_end let message_end = message_end
.unwrap_or(language::Anchor::MAX) .unwrap_or(language::Anchor::MAX)
.to_offset(buffer); .to_offset(buffer);
return Some((ix, message, metadata, message_start..message_end)); return Some(Message {
index: ix,
range: message_start..message_end,
id: message_anchor.id,
anchor: message_anchor.start,
role: metadata.role,
sent_at: metadata.sent_at,
error: metadata.error.clone(),
});
} }
None None
}) })
@ -1003,6 +1007,15 @@ impl AssistantEditor {
} }
} }
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
self.assistant.update(cx, |assistant, cx| {
if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
assistant.cycle_message_role(message.id, cx);
}
});
}
fn handle_assistant_event( fn handle_assistant_event(
&mut self, &mut self,
_: ModelHandle<Assistant>, _: ModelHandle<Assistant>,
@ -1087,14 +1100,14 @@ impl AssistantEditor {
.assistant .assistant
.read(cx) .read(cx)
.messages(cx) .messages(cx)
.map(|(_, message, metadata, _)| BlockProperties { .map(|message| BlockProperties {
position: buffer.anchor_in_excerpt(excerpt_id, message.start), position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
height: 2, height: 2,
style: BlockStyle::Sticky, style: BlockStyle::Sticky,
render: Arc::new({ render: Arc::new({
let assistant = self.assistant.clone(); let assistant = self.assistant.clone();
let metadata = metadata.clone(); // let metadata = message.metadata.clone();
let message = message.clone(); // let message = message.clone();
move |cx| { move |cx| {
enum Sender {} enum Sender {}
enum ErrorTooltip {} enum ErrorTooltip {}
@ -1105,7 +1118,7 @@ impl AssistantEditor {
let sender = MouseEventHandler::<Sender, _>::new( let sender = MouseEventHandler::<Sender, _>::new(
message_id.0, message_id.0,
cx, cx,
|state, _| match metadata.role { |state, _| match message.role {
Role::User => { Role::User => {
let style = style.user_sender.style_for(state, false); let style = style.user_sender.style_for(state, false);
Label::new("You", style.text.clone()) Label::new("You", style.text.clone())
@ -1140,14 +1153,14 @@ impl AssistantEditor {
.with_child(sender.aligned()) .with_child(sender.aligned())
.with_child( .with_child(
Label::new( Label::new(
metadata.sent_at.format("%I:%M%P").to_string(), message.sent_at.format("%I:%M%P").to_string(),
style.sent_at.text.clone(), style.sent_at.text.clone(),
) )
.contained() .contained()
.with_style(style.sent_at.container) .with_style(style.sent_at.container)
.aligned(), .aligned(),
) )
.with_children(metadata.error.clone().map(|error| { .with_children(message.error.as_ref().map(|error| {
Svg::new("icons/circle_x_mark_12.svg") Svg::new("icons/circle_x_mark_12.svg")
.with_color(style.error_icon.color) .with_color(style.error_icon.color)
.constrained() .constrained()
@ -1156,7 +1169,7 @@ impl AssistantEditor {
.with_style(style.error_icon.container) .with_style(style.error_icon.container)
.with_tooltip::<ErrorTooltip>( .with_tooltip::<ErrorTooltip>(
message_id.0, message_id.0,
error, error.to_string(),
None, None,
theme.tooltip.clone(), theme.tooltip.clone(),
cx, cx,
@ -1252,15 +1265,15 @@ impl AssistantEditor {
let selection = editor.selections.newest::<usize>(cx); let selection = editor.selections.newest::<usize>(cx);
let mut copied_text = String::new(); let mut copied_text = String::new();
let mut spanned_messages = 0; let mut spanned_messages = 0;
for (_ix, _message, metadata, message_range) in assistant.messages(cx) { for message in assistant.messages(cx) {
if message_range.start >= selection.range().end { if message.range.start >= selection.range().end {
break; break;
} else if message_range.end >= selection.range().start { } else if message.range.end >= selection.range().start {
let range = cmp::max(message_range.start, selection.range().start) let range = cmp::max(message.range.start, selection.range().start)
..cmp::min(message_range.end, selection.range().end); ..cmp::min(message.range.end, selection.range().end);
if !range.is_empty() { if !range.is_empty() {
spanned_messages += 1; spanned_messages += 1;
write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap(); write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
for chunk in assistant.buffer.read(cx).text_for_range(range) { for chunk in assistant.buffer.read(cx).text_for_range(range) {
copied_text.push_str(&chunk); copied_text.push_str(&chunk);
} }
@ -1395,7 +1408,7 @@ impl Item for AssistantEditor {
struct MessageId(usize); struct MessageId(usize);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct Message { struct MessageAnchor {
id: MessageId, id: MessageId,
start: language::Anchor, start: language::Anchor,
} }
@ -1404,7 +1417,18 @@ struct Message {
struct MessageMetadata { struct MessageMetadata {
role: Role, role: Role,
sent_at: DateTime<Local>, sent_at: DateTime<Local>,
error: Option<String>, error: Option<Arc<str>>,
}
#[derive(Clone, Debug)]
pub struct Message {
range: Range<usize>,
index: usize,
id: MessageId,
anchor: language::Anchor,
role: Role,
sent_at: DateTime<Local>,
error: Option<Arc<str>>,
} }
async fn stream_completion( async fn stream_completion(
@ -1504,7 +1528,7 @@ mod tests {
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
let buffer = assistant.read(cx).buffer.clone(); let buffer = assistant.read(cx).buffer.clone();
let message_1 = assistant.read(cx).messages[0].clone(); let message_1 = assistant.read(cx).message_anchors[0].clone();
assert_eq!( assert_eq!(
messages(&assistant, cx), messages(&assistant, cx),
vec![(message_1.id, Role::User, 0..0)] vec![(message_1.id, Role::User, 0..0)]
@ -1630,7 +1654,7 @@ mod tests {
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
let buffer = assistant.read(cx).buffer.clone(); let buffer = assistant.read(cx).buffer.clone();
let message_1 = assistant.read(cx).messages[0].clone(); let message_1 = assistant.read(cx).message_anchors[0].clone();
assert_eq!( assert_eq!(
messages(&assistant, cx), messages(&assistant, cx),
vec![(message_1.id, Role::User, 0..0)] vec![(message_1.id, Role::User, 0..0)]
@ -1724,7 +1748,7 @@ mod tests {
assistant assistant
.read(cx) .read(cx)
.messages(cx) .messages(cx)
.map(|(_, message, metadata, range)| (message.id, metadata.role, range)) .map(|message| (message.id, message.role, message.range))
.collect() .collect()
} }
} }