From 2ae8b558b991cf3beb5c998fb3168f7d4f70c369 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 13 Jun 2023 13:43:06 +0200 Subject: [PATCH] Get back to a compiling state with `Buffer` backing the assistant --- crates/ai/src/assistant.rs | 330 ++++++++++++++++++++++++------------- 1 file changed, 215 insertions(+), 115 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index f7057b1e5b..31a5f66700 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -11,7 +11,7 @@ use editor::{ autoscroll::{Autoscroll, AutoscrollStrategy}, ScrollAnchor, }, - Anchor, DisplayPoint, Editor, ExcerptId, + Anchor, DisplayPoint, Editor, ToOffset as _, }; use fs::Fs; use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; @@ -25,10 +25,13 @@ use gpui::{ Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; use isahc::{http::StatusCode, Request, RequestExt}; -use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; +use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; use serde::Deserialize; use settings::SettingsStore; -use std::{borrow::Cow, cell::RefCell, cmp, fmt::Write, io, rc::Rc, sync::Arc, time::Duration}; +use std::{ + borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc, + time::Duration, +}; use util::{post_inc, truncate_and_trailoff, ResultExt, TryFutureExt}; use workspace::{ dock::{DockPosition, Panel}, @@ -507,16 +510,16 @@ impl Assistant { fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { let messages = self - .messages - .iter() + .open_ai_request_messages(cx) + .into_iter() .filter_map(|message| { Some(tiktoken_rs::ChatCompletionRequestMessage { - role: match self.messages_metadata.get(&message.excerpt_id)?.role { + role: match message.role { Role::User => "user".into(), Role::Assistant => "assistant".into(), Role::System => "system".into(), }, - content: message.content.read(cx).text(), + content: message.content, name: None, }) }) @@ -554,45 +557,47 @@ impl Assistant { } fn assist(&mut self, cx: &mut ModelContext) -> Option<(Message, Message)> { - let messages = self - .messages - .iter() - .filter_map(|message| { - Some(RequestMessage { - role: self.messages_metadata.get(&message.excerpt_id)?.role, - content: message.content.read(cx).text(), - }) - }) - .collect(); let request = OpenAIRequest { model: self.model.clone(), - messages, + messages: self.open_ai_request_messages(cx), stream: true, }; let api_key = self.api_key.borrow().clone()?; let stream = stream_completion(api_key, cx.background().clone(), request); - let assistant_message = self.insert_message_after(ExcerptId::max(), Role::Assistant, cx); - let user_message = self.insert_message_after(ExcerptId::max(), Role::User, cx); + let assistant_message = + self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?; + let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?; let task = cx.spawn_weak({ - let assistant_message = assistant_message.clone(); |this, mut cx| async move { - let assistant_message = assistant_message; + let assistant_message_id = assistant_message.id; let stream_completion = async { let mut messages = stream.await?; while let Some(message) = messages.next().await { let mut message = message?; if let Some(choice) = message.choices.pop() { - assistant_message.content.update(&mut cx, |content, cx| { - let text: Arc = choice.delta.content?.into(); - content.edit([(content.len()..content.len(), text)], None, cx); - Some(()) - }); this.upgrade(&cx) .ok_or_else(|| anyhow!("assistant was dropped"))? - .update(&mut cx, |_, cx| { - cx.emit(AssistantEvent::StreamedCompletion); + .update(&mut cx, |this, cx| { + let text: Arc = choice.delta.content?.into(); + let message_ix = this + .messages + .iter() + .position(|message| message.id == assistant_message_id)?; + this.buffer.update(cx, |buffer, cx| { + let offset = if message_ix + 1 == this.messages.len() { + buffer.len() + } else { + this.messages[message_ix + 1] + .start + .to_offset(buffer) + .saturating_sub(1) + }; + buffer.edit([(offset..offset, text)], None, cx); + }); + + Some(()) }); } } @@ -612,9 +617,8 @@ impl Assistant { if let Some(this) = this.upgrade(&cx) { this.update(&mut cx, |this, cx| { if let Err(error) = result { - if let Some(metadata) = this - .messages_metadata - .get_mut(&assistant_message.excerpt_id) + if let Some(metadata) = + this.messages_metadata.get_mut(&assistant_message.id) { metadata.error = Some(error.to_string().trim().into()); cx.notify(); @@ -642,33 +646,33 @@ impl Assistant { protected_offsets: HashSet, cx: &mut ModelContext, ) { - let mut offset = 0; - let mut excerpts_to_remove = Vec::new(); - self.messages.retain(|message| { - let range = offset..offset + message.content.read(cx).len(); - offset = range.end + 1; - if range.is_empty() - && !protected_offsets.contains(&range.start) - && messages.contains(&message.id) - { - excerpts_to_remove.push(message.excerpt_id); - self.messages_metadata.remove(&message.excerpt_id); - false - } else { - true - } - }); + // let mut offset = 0; + // let mut excerpts_to_remove = Vec::new(); + // self.messages.retain(|message| { + // let range = offset..offset + message.content.read(cx).len(); + // offset = range.end + 1; + // if range.is_empty() + // && !protected_offsets.contains(&range.start) + // && messages.contains(&message.id) + // { + // excerpts_to_remove.push(message.excerpt_id); + // self.messages_metadata.remove(&message.excerpt_id); + // false + // } else { + // true + // } + // }); - if !excerpts_to_remove.is_empty() { - self.buffer.update(cx, |buffer, cx| { - buffer.remove_excerpts(excerpts_to_remove, cx) - }); - cx.notify(); - } + // if !excerpts_to_remove.is_empty() { + // self.buffer.update(cx, |buffer, cx| { + // buffer.remove_excerpts(excerpts_to_remove, cx) + // }); + // cx.notify(); + // } } - fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext) { - if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) { + fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext) { + if let Some(metadata) = self.messages_metadata.get_mut(&id) { metadata.role.cycle(); cx.notify(); } @@ -686,15 +690,18 @@ impl Assistant { .position(|message| message.id == message_id) { let start = self.buffer.update(cx, |buffer, cx| { - let len = buffer.len(); - buffer.edit([(len..len, "\n")], None, cx); - buffer.anchor_before(len + 1) + let offset = self + .messages + .get(prev_message_ix + 1) + .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1); + buffer.edit([(offset..offset, "\n")], None, cx); + buffer.anchor_before(offset + 1) }); let message = Message { id: MessageId(post_inc(&mut self.next_message_id.0)), start, }; - self.messages.insert(prev_message_ix, message.clone()); + self.messages.insert(prev_message_ix + 1, message.clone()); self.messages_metadata.insert( message.id, MessageMetadata { @@ -713,24 +720,13 @@ impl Assistant { if self.messages.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); if let Some(api_key) = api_key { - // let messages = self - // .messages - // .iter() - // .take(2) - // .filter_map(|message| { - // Some(RequestMessage { - // role: self.messages_metadata.get(&message.id)?.role, - // content: message.content.read(cx).text(), - // }) - // }) - // .chain(Some(RequestMessage { - // role: Role::User, - // content: - // "Summarize the conversation into a short title without punctuation" - // .into(), - // })) - // .collect(); - let messages = todo!(); + let mut messages = self.open_ai_request_messages(cx); + messages.truncate(2); + messages.push(RequestMessage { + role: Role::User, + content: "Summarize the conversation into a short title without punctuation" + .into(), + }); let request = OpenAIRequest { model: self.model.clone(), messages, @@ -760,6 +756,44 @@ impl Assistant { } } } + + fn open_ai_request_messages(&self, cx: &AppContext) -> Vec { + let buffer = self.buffer.read(cx); + self.messages(cx) + .map(|(message, metadata, range)| RequestMessage { + role: metadata.role, + content: buffer.text_for_range(range).collect(), + }) + .collect() + } + + fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option { + Some( + self.messages(cx) + .find(|(_, _, range)| range.contains(&offset)) + .map(|(message, _, _)| message) + .or(self.messages.last())? + .id, + ) + } + + fn messages<'a>( + &'a self, + cx: &'a AppContext, + ) -> impl 'a + Iterator)> { + let buffer = self.buffer.read(cx); + let mut messages = self.messages.iter().peekable(); + iter::from_fn(move || { + let message = messages.next()?; + let metadata = self.messages_metadata.get(&message.id)?; + let message_start = message.start.to_offset(buffer); + let message_end = messages + .peek() + .map_or(language::Anchor::MAX, |message| message.start) + .to_offset(buffer); + Some((message, metadata, message_start..message_end)) + }) + } } struct PendingCompletion { @@ -812,16 +846,12 @@ impl AssistantEditor { fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { let user_message = self.assistant.update(cx, |assistant, cx| { let editor = self.editor.read(cx); - let newest_selection = editor.selections.newest_anchor(); - let message_id = if newest_selection.head() == Anchor::min() { - assistant.messages.first().map(|message| message.id)? - } else if newest_selection.head() == Anchor::max() { - assistant.messages.last().map(|message| message.id)? - } else { - todo!() - // newest_selection.head().excerpt_id() - }; - + let newest_selection = editor + .selections + .newest_anchor() + .head() + .to_offset(&editor.buffer().read(cx).snapshot(cx)); + let message_id = assistant.message_id_for_offset(newest_selection, cx)?; let metadata = assistant.messages_metadata.get(&message_id)?; let user_message = if metadata.role == Role::User { let (_, user_message) = assistant.assist(cx)?; @@ -834,16 +864,14 @@ impl AssistantEditor { }); if let Some(user_message) = user_message { + let cursor = user_message + .start + .to_offset(&self.assistant.read(cx).buffer.read(cx)); self.editor.update(cx, |editor, cx| { - let cursor = editor - .buffer() - .read(cx) - .snapshot(cx) - .anchor_in_excerpt(Default::default(), user_message.start); editor.change_selections( Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), cx, - |selections| selections.select_anchor_ranges([cursor..cursor]), + |selections| selections.select_ranges([cursor..cursor]), ); }); self.update_scroll_bottom(cx); @@ -1011,7 +1039,7 @@ impl AssistantEditor { let mut copied_text = String::new(); let mut spanned_messages = 0; for message in &assistant.messages { - // TODO + todo!(); // let message_range = offset..offset + message.content.read(cx).len() + 1; let message_range = offset..offset + 1; @@ -1260,28 +1288,100 @@ mod tests { #[gpui::test] fn test_inserting_and_removing_messages(cx: &mut AppContext) { let registry = Arc::new(LanguageRegistry::test()); + let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); + let buffer = assistant.read(cx).buffer.clone(); - cx.add_model(|cx| { - let mut assistant = Assistant::new(Default::default(), registry, cx); - let message_1 = assistant.messages[0].clone(); - let message_2 = assistant - .insert_message_after(message_1.id, Role::Assistant, cx) - .unwrap(); - let message_3 = assistant - .insert_message_after(message_2.id, Role::User, cx) - .unwrap(); - let message_4 = assistant - .insert_message_after(message_2.id, Role::User, cx) - .unwrap(); - assistant.remove_empty_messages( - HashSet::from_iter([message_3.id, message_4.id]), - Default::default(), - cx, - ); - assert_eq!(assistant.messages.len(), 2); - assert_eq!(assistant.messages[0].id, message_1.id); - assert_eq!(assistant.messages[1].id, message_2.id); + let message_1 = assistant.read(cx).messages[0].clone(); + assert_eq!( + messages(&assistant, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + let message_2 = assistant.update(cx, |assistant, cx| { assistant + .insert_message_after(message_1.id, Role::Assistant, cx) + .unwrap() }); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..1), + (message_2.id, Role::Assistant, 1..1) + ] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "1"), (1..1, "2")], None, cx) + }); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..3) + ] + ); + + let message_3 = assistant.update(cx, |assistant, cx| { + assistant + .insert_message_after(message_2.id, Role::User, cx) + .unwrap() + }); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_3.id, Role::User, 4..4) + ] + ); + + let message_4 = assistant.update(cx, |assistant, cx| { + assistant + .insert_message_after(message_2.id, Role::User, cx) + .unwrap() + }); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_4.id, Role::User, 4..5), + (message_3.id, Role::User, 5..5), + ] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(4..4, "C"), (5..5, "D")], None, cx) + }); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_4.id, Role::User, 4..6), + (message_3.id, Role::User, 6..7), + ] + ); + + // Deleting across message boundaries merges the messages. + buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx)); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..6), + (message_3.id, Role::User, 6..7), + ] + ); + } + + fn messages( + assistant: &ModelHandle, + cx: &AppContext, + ) -> Vec<(MessageId, Role, Range)> { + assistant + .read(cx) + .messages(cx) + .map(|(message, metadata, range)| (message.id, metadata.role, range)) + .collect() } }