From 75e23290285a751fed9b3faf0d81832bb21d5d8c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 19 Jun 2023 17:23:40 +0200 Subject: [PATCH] Allow for multi-cursor `assist` and `cycle_role` actions Co-Authored-By: Nathan Sobo Co-Authored-By: Kyle Caverly --- crates/ai/src/assistant.rs | 350 +++++++++++++++++++++++++------------ 1 file changed, 241 insertions(+), 109 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index be3a49ce18..83b7105be2 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -22,7 +22,7 @@ use gpui::{ Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; use isahc::{http::StatusCode, Request, RequestExt}; -use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, Selection, ToOffset as _}; +use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; use serde::Deserialize; use settings::SettingsStore; use std::{ @@ -591,106 +591,129 @@ impl Assistant { fn assist( &mut self, - selection: Selection, + selected_messages: HashSet, cx: &mut ModelContext, - ) -> Option<(MessageAnchor, MessageAnchor)> { - let request = OpenAIRequest { - model: self.model.clone(), - messages: self - .messages(cx) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .collect(), - stream: true, - }; + ) -> Vec { + let mut user_messages = Vec::new(); + for selected_message_id in selected_messages { + let selected_message_role = + if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { + metadata.role + } else { + continue; + }; + let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else { + continue; + }; + user_messages.push(user_message); + if selected_message_role == Role::User { + let request = OpenAIRequest { + model: self.model.clone(), + messages: self + .messages(cx) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::System, + content: format!( + "Direct your reply to message with id {}. Do not include a [Message X] header.", + selected_message_id.0 + ), + })) + .collect(), + stream: true, + }; - let api_key = self.api_key.borrow().clone()?; - let stream = stream_completion(api_key, cx.background().clone(), request); - let assistant_message = self.insert_message_after( - self.message_for_offset(selection.head(), cx)?.id, - Role::Assistant, - cx, - )?; - let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?; + let Some(api_key) = self.api_key.borrow().clone() else { continue }; + let stream = stream_completion(api_key, cx.background().clone(), request); + let assistant_message = self + .insert_message_after(selected_message_id, Role::Assistant, cx) + .unwrap(); - let task = cx.spawn_weak({ - |this, mut cx| async move { - let assistant_message_id = assistant_message.id; - let stream_completion = async { - let mut messages = stream.await?; + let task = cx.spawn_weak({ + |this, mut cx| async move { + let assistant_message_id = assistant_message.id; + let stream_completion = async { + let mut messages = stream.await?; + + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + this.upgrade(&cx) + .ok_or_else(|| anyhow!("assistant was dropped"))? + .update(&mut cx, |this, cx| { + let text: Arc = choice.delta.content?.into(); + let message_ix = this.message_anchors.iter().position( + |message| message.id == assistant_message_id, + )?; + this.buffer.update(cx, |buffer, cx| { + let offset = if message_ix + 1 + == this.message_anchors.len() + { + buffer.len() + } else { + this.message_anchors[message_ix + 1] + .start + .to_offset(buffer) + .saturating_sub(1) + }; + buffer.edit([(offset..offset, text)], None, cx); + }); + cx.emit(AssistantEvent::StreamedCompletion); + + Some(()) + }); + } + } - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { this.upgrade(&cx) .ok_or_else(|| anyhow!("assistant was dropped"))? .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = this - .message_anchors - .iter() - .position(|message| message.id == assistant_message_id)?; - this.buffer.update(cx, |buffer, cx| { - let offset = if message_ix + 1 == this.message_anchors.len() - { - buffer.len() - } else { - this.message_anchors[message_ix + 1] - .start - .to_offset(buffer) - .saturating_sub(1) - }; - buffer.edit([(offset..offset, text)], None, cx); + this.pending_completions.retain(|completion| { + completion.id != this.completion_count }); - cx.emit(AssistantEvent::StreamedCompletion); - - Some(()) + this.summarize(cx); }); + + anyhow::Ok(()) + }; + + let result = stream_completion.await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + if let Err(error) = result { + if let Some(metadata) = + this.messages_metadata.get_mut(&assistant_message.id) + { + metadata.error = Some(error.to_string().trim().into()); + cx.notify(); + } + } + }); } } + }); - this.upgrade(&cx) - .ok_or_else(|| anyhow!("assistant was dropped"))? - .update(&mut cx, |this, cx| { - this.pending_completions - .retain(|completion| completion.id != this.completion_count); - this.summarize(cx); - }); - - anyhow::Ok(()) - }; - - let result = stream_completion.await; - if let Some(this) = this.upgrade(&cx) { - this.update(&mut cx, |this, cx| { - if let Err(error) = result { - if let Some(metadata) = - this.messages_metadata.get_mut(&assistant_message.id) - { - metadata.error = Some(error.to_string().trim().into()); - cx.notify(); - } - } - }); - } + self.pending_completions.push(PendingCompletion { + id: post_inc(&mut self.completion_count), + _task: task, + }); } - }); + } - self.pending_completions.push(PendingCompletion { - id: post_inc(&mut self.completion_count), - _task: task, - }); - Some((assistant_message, user_message)) + user_messages } fn cancel_last_assist(&mut self) -> bool { self.pending_completions.pop().is_some() } - fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext) { - if let Some(metadata) = self.messages_metadata.get_mut(&id) { - metadata.role.cycle(); - cx.emit(AssistantEvent::MessagesEdited); - cx.notify(); + fn cycle_message_roles(&mut self, ids: HashSet, cx: &mut ModelContext) { + for id in ids { + if let Some(metadata) = self.messages_metadata.get_mut(&id) { + metadata.role.cycle(); + cx.emit(AssistantEvent::MessagesEdited); + cx.notify(); + } } } @@ -884,14 +907,39 @@ impl Assistant { } } - fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option { + fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option { + self.messages_for_offsets([offset], cx).pop() + } + + fn messages_for_offsets( + &self, + offsets: impl IntoIterator, + cx: &AppContext, + ) -> Vec { + let mut result = Vec::new(); + + let buffer_len = self.buffer.read(cx).len(); let mut messages = self.messages(cx).peekable(); - while let Some(message) = messages.next() { - if message.range.contains(&offset) || messages.peek().is_none() { - return Some(message); + let mut offsets = offsets.into_iter().peekable(); + while let Some(offset) = offsets.next() { + // Skip messages that start after the offset. + while messages.peek().map_or(false, |message| { + message.range.end < offset || (message.range.end == offset && offset < buffer_len) + }) { + messages.next(); } + let Some(message) = messages.peek() else { continue }; + + // Skip offsets that are in the same message. + while offsets.peek().map_or(false, |offset| { + message.range.contains(offset) || message.range.end == buffer_len + }) { + offsets.next(); + } + + result.push(message.clone()); } - None + result } fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator { @@ -983,24 +1031,32 @@ impl AssistantEditor { } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { - let selection = self.editor.read(cx).selections.newest(cx); - let user_message = self.assistant.update(cx, |assistant, cx| { - let (_, user_message) = assistant.assist(selection, cx)?; - Some(user_message) - }); + let cursors = self.cursors(cx); - if let Some(user_message) = user_message { - let cursor = user_message - .start - .to_offset(&self.assistant.read(cx).buffer.read(cx)); - self.editor.update(cx, |editor, cx| { - editor.change_selections( - Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), - cx, - |selections| selections.select_ranges([cursor..cursor]), - ); - }); - } + let user_messages = self.assistant.update(cx, |assistant, cx| { + let selected_messages = assistant + .messages_for_offsets(cursors, cx) + .into_iter() + .map(|message| message.id) + .collect(); + assistant.assist(selected_messages, cx) + }); + let new_selections = user_messages + .iter() + .map(|message| { + let cursor = message + .start + .to_offset(self.assistant.read(cx).buffer.read(cx)); + cursor..cursor + }) + .collect::>(); + self.editor.update(cx, |editor, cx| { + editor.change_selections( + Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), + cx, + |selections| selections.select_ranges(new_selections), + ); + }); } fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext) { @@ -1013,14 +1069,25 @@ impl AssistantEditor { } fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext) { - let cursor_offset = self.editor.read(cx).selections.newest(cx).head(); + let cursors = self.cursors(cx); self.assistant.update(cx, |assistant, cx| { - if let Some(message) = assistant.message_for_offset(cursor_offset, cx) { - assistant.cycle_message_role(message.id, cx); - } + let messages = assistant + .messages_for_offsets(cursors, cx) + .into_iter() + .map(|message| message.id) + .collect(); + assistant.cycle_message_roles(messages, cx) }); } + fn cursors(&self, cx: &AppContext) -> Vec { + let selections = self.editor.read(cx).selections.all::(cx); + selections + .into_iter() + .map(|selection| selection.head()) + .collect() + } + fn handle_assistant_event( &mut self, _: ModelHandle, @@ -1149,7 +1216,10 @@ impl AssistantEditor { let assistant = assistant.clone(); move |_, _, cx| { assistant.update(cx, |assistant, cx| { - assistant.cycle_message_role(message_id, cx) + assistant.cycle_message_roles( + HashSet::from_iter(Some(message_id)), + cx, + ) }) } }); @@ -1444,9 +1514,11 @@ pub struct Message { impl Message { fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage { + let mut content = format!("[Message {}]\n", self.id.0).to_string(); + content.extend(buffer.text_for_range(self.range.clone())); RequestMessage { role: self.role, - content: buffer.text_for_range(self.range.clone()).collect(), + content, } } } @@ -1761,6 +1833,66 @@ mod tests { ); } + #[gpui::test] + fn test_messages_for_offsets(cx: &mut AppContext) { + let registry = Arc::new(LanguageRegistry::test()); + let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); + let buffer = assistant.read(cx).buffer.clone(); + + let message_1 = assistant.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&assistant, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); + let message_2 = assistant + .update(cx, |assistant, cx| { + assistant.insert_message_after(message_1.id, Role::User, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); + + let message_3 = assistant + .update(cx, |assistant, cx| { + assistant.insert_message_after(message_2.id, Role::User, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); + + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..8), + (message_3.id, Role::User, 8..11) + ] + ); + + assert_eq!( + message_ids_for_offsets(&assistant, &[0, 4, 9], cx), + [message_1.id, message_2.id, message_3.id] + ); + assert_eq!( + message_ids_for_offsets(&assistant, &[0, 1, 11], cx), + [message_1.id, message_3.id] + ); + + fn message_ids_for_offsets( + assistant: &ModelHandle, + offsets: &[usize], + cx: &AppContext, + ) -> Vec { + assistant + .read(cx) + .messages_for_offsets(offsets.iter().copied(), cx) + .into_iter() + .map(|message| message.id) + .collect() + } + } + fn messages( assistant: &ModelHandle, cx: &AppContext,