From 75b5ac8488c72af2702f582249543320e3f541e4 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Thu, 15 Jun 2023 16:24:53 -0600 Subject: [PATCH] Cycle message roles on ctrl-r --- assets/keymaps/default.json | 3 +- crates/ai/src/assistant.rs | 168 ++++++++++++++++++++---------------- 2 files changed, 98 insertions(+), 73 deletions(-) diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index f6682a9f0b..a642697a37 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -201,7 +201,8 @@ "bindings": { "cmd-enter": "assistant::Assist", "cmd->": "assistant::QuoteSelection", - "shift-enter": "assistant::Split" + "shift-enter": "assistant::Split", + "ctrl-r": "assistant::CycleMessageRole" } }, { diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index eff3dc4d20..853f7262d3 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -44,6 +44,7 @@ actions!( NewContext, Assist, Split, + CycleMessageRole, QuoteSelection, ToggleFocus, ResetKey @@ -72,6 +73,7 @@ pub fn init(cx: &mut AppContext) { cx.add_action(AssistantEditor::quote_selection); cx.capture_action(AssistantEditor::copy); cx.capture_action(AssistantEditor::split); + cx.capture_action(AssistantEditor::cycle_message_role); cx.add_action(AssistantPanel::save_api_key); cx.add_action(AssistantPanel::reset_api_key); cx.add_action( @@ -446,7 +448,7 @@ enum AssistantEvent { struct Assistant { buffer: ModelHandle, - messages: Vec, + message_anchors: Vec, messages_metadata: HashMap, next_message_id: MessageId, summary: Option, @@ -491,7 +493,7 @@ impl Assistant { }); let mut this = Self { - messages: Default::default(), + message_anchors: Default::default(), messages_metadata: Default::default(), next_message_id: Default::default(), summary: None, @@ -506,11 +508,11 @@ impl Assistant { api_key, buffer, }; - let message = Message { + let message = MessageAnchor { id: MessageId(post_inc(&mut this.next_message_id.0)), start: language::Anchor::MIN, }; - this.messages.push(message.clone()); + this.message_anchors.push(message.clone()); this.messages_metadata.insert( message.id, MessageMetadata { @@ -587,7 +589,7 @@ impl Assistant { cx.notify(); } - fn assist(&mut self, cx: &mut ModelContext) -> Option<(Message, Message)> { + fn assist(&mut self, cx: &mut ModelContext) -> Option<(MessageAnchor, MessageAnchor)> { let request = OpenAIRequest { model: self.model.clone(), messages: self.open_ai_request_messages(cx), @@ -597,7 +599,7 @@ impl Assistant { let api_key = self.api_key.borrow().clone()?; let stream = stream_completion(api_key, cx.background().clone(), request); let assistant_message = - self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?; + self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?; let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?; let task = cx.spawn_weak({ |this, mut cx| async move { @@ -613,14 +615,15 @@ impl Assistant { .update(&mut cx, |this, cx| { let text: Arc = choice.delta.content?.into(); let message_ix = this - .messages + .message_anchors .iter() .position(|message| message.id == assistant_message_id)?; this.buffer.update(cx, |buffer, cx| { - let offset = if message_ix + 1 == this.messages.len() { + let offset = if message_ix + 1 == this.message_anchors.len() + { buffer.len() } else { - this.messages[message_ix + 1] + this.message_anchors[message_ix + 1] .start .to_offset(buffer) .saturating_sub(1) @@ -685,25 +688,26 @@ impl Assistant { message_id: MessageId, role: Role, cx: &mut ModelContext, - ) -> Option { + ) -> Option { if let Some(prev_message_ix) = self - .messages + .message_anchors .iter() .position(|message| message.id == message_id) { let start = self.buffer.update(cx, |buffer, cx| { - let offset = self.messages[prev_message_ix + 1..] + let offset = self.message_anchors[prev_message_ix + 1..] .iter() .find(|message| message.start.is_valid(buffer)) .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1); buffer.edit([(offset..offset, "\n")], None, cx); buffer.anchor_before(offset + 1) }); - let message = Message { + let message = MessageAnchor { id: MessageId(post_inc(&mut self.next_message_id.0)), start, }; - self.messages.insert(prev_message_ix + 1, message.clone()); + self.message_anchors + .insert(prev_message_ix + 1, message.clone()); self.messages_metadata.insert( message.id, MessageMetadata { @@ -723,23 +727,21 @@ impl Assistant { &mut self, range: Range, cx: &mut ModelContext, - ) -> (Option, Option) { + ) -> (Option, Option) { let start_message = self.message_for_offset(range.start, cx); let end_message = self.message_for_offset(range.end, cx); if let Some((start_message, end_message)) = start_message.zip(end_message) { - let (start_message_ix, _, metadata, message_range) = start_message; - let (end_message_ix, _, _, _) = end_message; - // Prevent splitting when range spans multiple messages. - if start_message_ix != end_message_ix { + if start_message.index != end_message.index { return (None, None); } - let role = metadata.role; + let message = start_message; + let role = message.role; let mut edited_buffer = false; let mut suffix_start = None; - if range.start > message_range.start && range.end < message_range.end - 1 { + if range.start > message.range.start && range.end < message.range.end - 1 { if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') { suffix_start = Some(range.end + 1); } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') { @@ -748,7 +750,7 @@ impl Assistant { } let suffix = if let Some(suffix_start) = suffix_start { - Message { + MessageAnchor { id: MessageId(post_inc(&mut self.next_message_id.0)), start: self.buffer.read(cx).anchor_before(suffix_start), } @@ -757,13 +759,14 @@ impl Assistant { buffer.edit([(range.end..range.end, "\n")], None, cx); }); edited_buffer = true; - Message { + MessageAnchor { id: MessageId(post_inc(&mut self.next_message_id.0)), start: self.buffer.read(cx).anchor_before(range.end + 1), } }; - self.messages.insert(start_message_ix + 1, suffix.clone()); + self.message_anchors + .insert(message.index + 1, suffix.clone()); self.messages_metadata.insert( suffix.id, MessageMetadata { @@ -773,11 +776,11 @@ impl Assistant { }, ); - let new_messages = if range.start == range.end || range.start == message_range.start { + let new_messages = if range.start == range.end || range.start == message.range.start { (None, Some(suffix)) } else { let mut prefix_end = None; - if range.start > message_range.start && range.end < message_range.end - 1 { + if range.start > message.range.start && range.end < message.range.end - 1 { if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { prefix_end = Some(range.start + 1); } else if self.buffer.read(cx).reversed_chars_at(range.start).next() @@ -789,7 +792,7 @@ impl Assistant { let selection = if let Some(prefix_end) = prefix_end { cx.emit(AssistantEvent::MessagesEdited); - Message { + MessageAnchor { id: MessageId(post_inc(&mut self.next_message_id.0)), start: self.buffer.read(cx).anchor_before(prefix_end), } @@ -798,14 +801,14 @@ impl Assistant { buffer.edit([(range.start..range.start, "\n")], None, cx) }); edited_buffer = true; - Message { + MessageAnchor { id: MessageId(post_inc(&mut self.next_message_id.0)), start: self.buffer.read(cx).anchor_before(range.end + 1), } }; - self.messages - .insert(start_message_ix + 1, selection.clone()); + self.message_anchors + .insert(message.index + 1, selection.clone()); self.messages_metadata.insert( selection.id, MessageMetadata { @@ -827,7 +830,7 @@ impl Assistant { } fn summarize(&mut self, cx: &mut ModelContext) { - if self.messages.len() >= 2 && self.summary.is_none() { + if self.message_anchors.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); if let Some(api_key) = api_key { let mut messages = self.open_ai_request_messages(cx); @@ -870,50 +873,51 @@ impl Assistant { fn open_ai_request_messages(&self, cx: &AppContext) -> Vec { let buffer = self.buffer.read(cx); self.messages(cx) - .map(|(_ix, _message, metadata, range)| RequestMessage { - role: metadata.role, - content: buffer.text_for_range(range).collect(), + .map(|message| RequestMessage { + role: message.role, + content: buffer.text_for_range(message.range).collect(), }) .collect() } - fn message_for_offset<'a>( - &'a self, - offset: usize, - cx: &'a AppContext, - ) -> Option<(usize, &Message, &MessageMetadata, Range)> { + fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option { let mut messages = self.messages(cx).peekable(); - while let Some((ix, message, metadata, range)) = messages.next() { - if range.contains(&offset) || messages.peek().is_none() { - return Some((ix, message, metadata, range)); + while let Some(message) = messages.next() { + if message.range.contains(&offset) || messages.peek().is_none() { + return Some(message); } } None } - fn messages<'a>( - &'a self, - cx: &'a AppContext, - ) -> impl 'a + Iterator)> { + fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator { let buffer = self.buffer.read(cx); - let mut messages = self.messages.iter().enumerate().peekable(); + let mut message_anchors = self.message_anchors.iter().enumerate().peekable(); iter::from_fn(move || { - while let Some((ix, message)) = messages.next() { - let metadata = self.messages_metadata.get(&message.id)?; - let message_start = message.start.to_offset(buffer); + while let Some((ix, message_anchor)) = message_anchors.next() { + let metadata = self.messages_metadata.get(&message_anchor.id)?; + let message_start = message_anchor.start.to_offset(buffer); let mut message_end = None; - while let Some((_, next_message)) = messages.peek() { + while let Some((_, next_message)) = message_anchors.peek() { if next_message.start.is_valid(buffer) { message_end = Some(next_message.start); break; } else { - messages.next(); + message_anchors.next(); } } let message_end = message_end .unwrap_or(language::Anchor::MAX) .to_offset(buffer); - return Some((ix, message, metadata, message_start..message_end)); + return Some(Message { + index: ix, + range: message_start..message_end, + id: message_anchor.id, + anchor: message_anchor.start, + role: metadata.role, + sent_at: metadata.sent_at, + error: metadata.error.clone(), + }); } None }) @@ -1003,6 +1007,15 @@ impl AssistantEditor { } } + fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext) { + let cursor_offset = self.editor.read(cx).selections.newest(cx).head(); + self.assistant.update(cx, |assistant, cx| { + if let Some(message) = assistant.message_for_offset(cursor_offset, cx) { + assistant.cycle_message_role(message.id, cx); + } + }); + } + fn handle_assistant_event( &mut self, _: ModelHandle, @@ -1087,14 +1100,14 @@ impl AssistantEditor { .assistant .read(cx) .messages(cx) - .map(|(_, message, metadata, _)| BlockProperties { - position: buffer.anchor_in_excerpt(excerpt_id, message.start), + .map(|message| BlockProperties { + position: buffer.anchor_in_excerpt(excerpt_id, message.anchor), height: 2, style: BlockStyle::Sticky, render: Arc::new({ let assistant = self.assistant.clone(); - let metadata = metadata.clone(); - let message = message.clone(); + // let metadata = message.metadata.clone(); + // let message = message.clone(); move |cx| { enum Sender {} enum ErrorTooltip {} @@ -1105,7 +1118,7 @@ impl AssistantEditor { let sender = MouseEventHandler::::new( message_id.0, cx, - |state, _| match metadata.role { + |state, _| match message.role { Role::User => { let style = style.user_sender.style_for(state, false); Label::new("You", style.text.clone()) @@ -1140,14 +1153,14 @@ impl AssistantEditor { .with_child(sender.aligned()) .with_child( Label::new( - metadata.sent_at.format("%I:%M%P").to_string(), + message.sent_at.format("%I:%M%P").to_string(), style.sent_at.text.clone(), ) .contained() .with_style(style.sent_at.container) .aligned(), ) - .with_children(metadata.error.clone().map(|error| { + .with_children(message.error.as_ref().map(|error| { Svg::new("icons/circle_x_mark_12.svg") .with_color(style.error_icon.color) .constrained() @@ -1156,7 +1169,7 @@ impl AssistantEditor { .with_style(style.error_icon.container) .with_tooltip::( message_id.0, - error, + error.to_string(), None, theme.tooltip.clone(), cx, @@ -1252,15 +1265,15 @@ impl AssistantEditor { let selection = editor.selections.newest::(cx); let mut copied_text = String::new(); let mut spanned_messages = 0; - for (_ix, _message, metadata, message_range) in assistant.messages(cx) { - if message_range.start >= selection.range().end { + for message in assistant.messages(cx) { + if message.range.start >= selection.range().end { break; - } else if message_range.end >= selection.range().start { - let range = cmp::max(message_range.start, selection.range().start) - ..cmp::min(message_range.end, selection.range().end); + } else if message.range.end >= selection.range().start { + let range = cmp::max(message.range.start, selection.range().start) + ..cmp::min(message.range.end, selection.range().end); if !range.is_empty() { spanned_messages += 1; - write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap(); + write!(&mut copied_text, "## {}\n\n", message.role).unwrap(); for chunk in assistant.buffer.read(cx).text_for_range(range) { copied_text.push_str(&chunk); } @@ -1395,7 +1408,7 @@ impl Item for AssistantEditor { struct MessageId(usize); #[derive(Clone, Debug)] -struct Message { +struct MessageAnchor { id: MessageId, start: language::Anchor, } @@ -1404,7 +1417,18 @@ struct Message { struct MessageMetadata { role: Role, sent_at: DateTime, - error: Option, + error: Option>, +} + +#[derive(Clone, Debug)] +pub struct Message { + range: Range, + index: usize, + id: MessageId, + anchor: language::Anchor, + role: Role, + sent_at: DateTime, + error: Option>, } async fn stream_completion( @@ -1504,7 +1528,7 @@ mod tests { let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); let buffer = assistant.read(cx).buffer.clone(); - let message_1 = assistant.read(cx).messages[0].clone(); + let message_1 = assistant.read(cx).message_anchors[0].clone(); assert_eq!( messages(&assistant, cx), vec![(message_1.id, Role::User, 0..0)] @@ -1630,7 +1654,7 @@ mod tests { let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); let buffer = assistant.read(cx).buffer.clone(); - let message_1 = assistant.read(cx).messages[0].clone(); + let message_1 = assistant.read(cx).message_anchors[0].clone(); assert_eq!( messages(&assistant, cx), vec![(message_1.id, Role::User, 0..0)] @@ -1724,7 +1748,7 @@ mod tests { assistant .read(cx) .messages(cx) - .map(|(_, message, metadata, range)| (message.id, metadata.role, range)) + .map(|message| (message.id, message.role, message.range)) .collect() } }