From 7db690b713a547fbc1f2e7567a273df8916734c8 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 12 Jun 2023 17:50:13 +0200 Subject: [PATCH] WIP --- Cargo.lock | 2 +- crates/ai/src/assistant.rs | 314 +++++++++++++------------------------ 2 files changed, 110 insertions(+), 206 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8ef48849ac..fba46c59cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7559,7 +7559,7 @@ dependencies = [ [[package]] name = "tree-sitter-yaml" version = "0.0.1" -source = "git+https://github.com/zed-industries/tree-sitter-yaml?rev=5694b7f290cd9ef998829a0a6d8391a666370886#5694b7f290cd9ef998829a0a6d8391a666370886" +source = "git+https://github.com/zed-industries/tree-sitter-yaml?rev=f545a41f57502e1b5ddf2a6668896c1b0620f930#f545a41f57502e1b5ddf2a6668896c1b0620f930" dependencies = [ "cc", "tree-sitter", diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 77353e1ee4..f7057b1e5b 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -11,7 +11,7 @@ use editor::{ autoscroll::{Autoscroll, AutoscrollStrategy}, ScrollAnchor, }, - Anchor, DisplayPoint, Editor, ExcerptId, ExcerptRange, MultiBuffer, + Anchor, DisplayPoint, Editor, ExcerptId, }; use fs::Fs; use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; @@ -420,15 +420,16 @@ impl Panel for AssistantPanel { } enum AssistantEvent { - MessagesEdited { ids: Vec }, + MessagesEdited, SummaryChanged, StreamedCompletion, } struct Assistant { - buffer: ModelHandle, + buffer: ModelHandle, messages: Vec, - messages_metadata: HashMap, + messages_metadata: HashMap, + next_message_id: MessageId, summary: Option, pending_summary: Task>, completion_count: usize, @@ -453,10 +454,11 @@ impl Assistant { cx: &mut ModelContext, ) -> Self { let model = "gpt-3.5-turbo"; - let buffer = cx.add_model(|_| MultiBuffer::new(0)); + let buffer = cx.add_model(|cx| Buffer::new(0, "", cx)); let mut this = Self { messages: Default::default(), messages_metadata: Default::default(), + next_message_id: Default::default(), summary: None, pending_summary: Task::ready(None), completion_count: Default::default(), @@ -470,23 +472,34 @@ impl Assistant { api_key, buffer, }; - this.insert_message_after(ExcerptId::max(), Role::User, cx); + let message = Message { + id: MessageId(post_inc(&mut this.next_message_id.0)), + start: language::Anchor::MIN, + }; + this.messages.push(message.clone()); + this.messages_metadata.insert( + message.id, + MessageMetadata { + role: Role::User, + sent_at: Local::now(), + error: None, + }, + ); + this.count_remaining_tokens(cx); this } fn handle_buffer_event( &mut self, - _: ModelHandle, - event: &editor::multi_buffer::Event, + _: ModelHandle, + event: &language::Event, cx: &mut ModelContext, ) { match event { - editor::multi_buffer::Event::ExcerptsAdded { .. } - | editor::multi_buffer::Event::ExcerptsRemoved { .. } - | editor::multi_buffer::Event::Edited => self.count_remaining_tokens(cx), - editor::multi_buffer::Event::ExcerptsEdited { ids } => { - cx.emit(AssistantEvent::MessagesEdited { ids: ids.clone() }); + language::Event::Edited => { + self.count_remaining_tokens(cx); + cx.emit(AssistantEvent::MessagesEdited); } _ => {} } @@ -625,7 +638,7 @@ impl Assistant { fn remove_empty_messages<'a>( &mut self, - excerpts: HashSet, + messages: HashSet, protected_offsets: HashSet, cx: &mut ModelContext, ) { @@ -636,7 +649,7 @@ impl Assistant { offset = range.end + 1; if range.is_empty() && !protected_offsets.contains(&range.start) - && excerpts.contains(&message.excerpt_id) + && messages.contains(&message.id) { excerpts_to_remove.push(message.excerpt_id); self.messages_metadata.remove(&message.excerpt_id); @@ -663,84 +676,61 @@ impl Assistant { fn insert_message_after( &mut self, - excerpt_id: ExcerptId, + message_id: MessageId, role: Role, cx: &mut ModelContext, - ) -> Message { - let content = cx.add_model(|cx| { - let mut buffer = Buffer::new(0, "", cx); - let markdown = self.languages.language_for_name("Markdown"); - cx.spawn_weak(|buffer, mut cx| async move { - let markdown = markdown.await?; - let buffer = buffer - .upgrade(&cx) - .ok_or_else(|| anyhow!("buffer was dropped"))?; - buffer.update(&mut cx, |buffer, cx| { - buffer.set_language(Some(markdown), cx) - }); - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - buffer.set_language_registry(self.languages.clone()); - buffer - }); - let new_excerpt_id = self.buffer.update(cx, |buffer, cx| { - buffer - .insert_excerpts_after( - excerpt_id, - content.clone(), - vec![ExcerptRange { - context: 0..0, - primary: None, - }], - cx, - ) - .pop() - .unwrap() - }); - - let ix = self + ) -> Option { + if let Some(prev_message_ix) = self .messages .iter() - .position(|message| message.excerpt_id == excerpt_id) - .map_or(self.messages.len(), |ix| ix + 1); - let message = Message { - excerpt_id: new_excerpt_id, - content: content.clone(), - }; - self.messages.insert(ix, message.clone()); - self.messages_metadata.insert( - new_excerpt_id, - MessageMetadata { - role, - sent_at: Local::now(), - error: None, - }, - ); - message + .position(|message| message.id == message_id) + { + let start = self.buffer.update(cx, |buffer, cx| { + let len = buffer.len(); + buffer.edit([(len..len, "\n")], None, cx); + buffer.anchor_before(len + 1) + }); + let message = Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start, + }; + self.messages.insert(prev_message_ix, message.clone()); + self.messages_metadata.insert( + message.id, + MessageMetadata { + role, + sent_at: Local::now(), + error: None, + }, + ); + Some(message) + } else { + None + } } fn summarize(&mut self, cx: &mut ModelContext) { if self.messages.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); if let Some(api_key) = api_key { - let messages = self - .messages - .iter() - .take(2) - .filter_map(|message| { - Some(RequestMessage { - role: self.messages_metadata.get(&message.excerpt_id)?.role, - content: message.content.read(cx).text(), - }) - }) - .chain(Some(RequestMessage { - role: Role::User, - content: - "Summarize the conversation into a short title without punctuation" - .into(), - })) - .collect(); + // let messages = self + // .messages + // .iter() + // .take(2) + // .filter_map(|message| { + // Some(RequestMessage { + // role: self.messages_metadata.get(&message.id)?.role, + // content: message.content.read(cx).text(), + // }) + // }) + // .chain(Some(RequestMessage { + // role: Role::User, + // content: + // "Summarize the conversation into a short title without punctuation" + // .into(), + // })) + // .collect(); + let messages = todo!(); let request = OpenAIRequest { model: self.model.clone(), messages, @@ -796,98 +786,9 @@ impl AssistantEditor { ) -> Self { let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx)); let editor = cx.add_view(|cx| { - let mut editor = Editor::for_multibuffer(assistant.read(cx).buffer.clone(), None, cx); + let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx); editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx); editor.set_show_gutter(false, cx); - editor.set_render_excerpt_header( - { - let assistant = assistant.clone(); - move |_editor, params: editor::RenderExcerptHeaderParams, cx| { - enum Sender {} - enum ErrorTooltip {} - - let theme = theme::current(cx); - let style = &theme.assistant; - let excerpt_id = params.id; - if let Some(metadata) = assistant - .read(cx) - .messages_metadata - .get(&excerpt_id) - .cloned() - { - let sender = MouseEventHandler::::new( - params.id.into(), - cx, - |state, _| match metadata.role { - Role::User => { - let style = style.user_sender.style_for(state, false); - Label::new("You", style.text.clone()) - .contained() - .with_style(style.container) - } - Role::Assistant => { - let style = style.assistant_sender.style_for(state, false); - Label::new("Assistant", style.text.clone()) - .contained() - .with_style(style.container) - } - Role::System => { - let style = style.system_sender.style_for(state, false); - Label::new("System", style.text.clone()) - .contained() - .with_style(style.container) - } - }, - ) - .with_cursor_style(CursorStyle::PointingHand) - .on_down(MouseButton::Left, { - let assistant = assistant.clone(); - move |_, _, cx| { - assistant.update(cx, |assistant, cx| { - assistant.cycle_message_role(excerpt_id, cx) - }) - } - }); - - Flex::row() - .with_child(sender.aligned()) - .with_child( - Label::new( - metadata.sent_at.format("%I:%M%P").to_string(), - style.sent_at.text.clone(), - ) - .contained() - .with_style(style.sent_at.container) - .aligned(), - ) - .with_children(metadata.error.map(|error| { - Svg::new("icons/circle_x_mark_12.svg") - .with_color(style.error_icon.color) - .constrained() - .with_width(style.error_icon.width) - .contained() - .with_style(style.error_icon.container) - .with_tooltip::( - params.id.into(), - error, - None, - theme.tooltip.clone(), - cx, - ) - .aligned() - })) - .aligned() - .left() - .contained() - .with_style(style.header) - .into_any() - } else { - Empty::new().into_any() - } - } - }, - cx, - ); editor }); @@ -912,26 +813,21 @@ impl AssistantEditor { let user_message = self.assistant.update(cx, |assistant, cx| { let editor = self.editor.read(cx); let newest_selection = editor.selections.newest_anchor(); - let excerpt_id = if newest_selection.head() == Anchor::min() { - assistant - .messages - .first() - .map(|message| message.excerpt_id)? + let message_id = if newest_selection.head() == Anchor::min() { + assistant.messages.first().map(|message| message.id)? } else if newest_selection.head() == Anchor::max() { - assistant - .messages - .last() - .map(|message| message.excerpt_id)? + assistant.messages.last().map(|message| message.id)? } else { - newest_selection.head().excerpt_id() + todo!() + // newest_selection.head().excerpt_id() }; - let metadata = assistant.messages_metadata.get(&excerpt_id)?; + let metadata = assistant.messages_metadata.get(&message_id)?; let user_message = if metadata.role == Role::User { let (_, user_message) = assistant.assist(cx)?; user_message } else { - let user_message = assistant.insert_message_after(excerpt_id, Role::User, cx); + let user_message = assistant.insert_message_after(message_id, Role::User, cx)?; user_message }; Some(user_message) @@ -943,7 +839,7 @@ impl AssistantEditor { .buffer() .read(cx) .snapshot(cx) - .anchor_in_excerpt(user_message.excerpt_id, language::Anchor::MIN); + .anchor_in_excerpt(Default::default(), user_message.start); editor.change_selections( Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), cx, @@ -970,16 +866,16 @@ impl AssistantEditor { cx: &mut ViewContext, ) { match event { - AssistantEvent::MessagesEdited { ids } => { + AssistantEvent::MessagesEdited => { let selections = self.editor.read(cx).selections.all::(cx); let selection_heads = selections .iter() .map(|selection| selection.head()) .collect::>(); - let ids = ids.iter().copied().collect::>(); - self.assistant.update(cx, |assistant, cx| { - assistant.remove_empty_messages(ids, selection_heads, cx) - }); + // let ids = ids.iter().copied().collect::>(); + // self.assistant.update(cx, |assistant, cx| { + // assistant.remove_empty_messages(ids, selection_heads, cx) + // }); } AssistantEvent::SummaryChanged => { cx.emit(AssistantEditorEvent::TabContentChanged); @@ -1115,7 +1011,9 @@ impl AssistantEditor { let mut copied_text = String::new(); let mut spanned_messages = 0; for message in &assistant.messages { - let message_range = offset..offset + message.content.read(cx).len() + 1; + // TODO + // let message_range = offset..offset + message.content.read(cx).len() + 1; + let message_range = offset..offset + 1; if message_range.start >= selection.range().end { break; @@ -1123,13 +1021,10 @@ impl AssistantEditor { let range = cmp::max(message_range.start, selection.range().start) ..cmp::min(message_range.end, selection.range().end); if !range.is_empty() { - if let Some(metadata) = assistant.messages_metadata.get(&message.excerpt_id) - { + if let Some(metadata) = assistant.messages_metadata.get(&message.id) { spanned_messages += 1; write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap(); - for chunk in - assistant.buffer.read(cx).snapshot(cx).text_for_range(range) - { + for chunk in assistant.buffer.read(cx).text_for_range(range) { copied_text.push_str(&chunk); } copied_text.push('\n'); @@ -1255,10 +1150,13 @@ impl Item for AssistantEditor { } } +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)] +struct MessageId(usize); + #[derive(Clone, Debug)] struct Message { - excerpt_id: ExcerptId, - content: ModelHandle, + id: MessageId, + start: language::Anchor, } #[derive(Clone, Debug)] @@ -1366,17 +1264,23 @@ mod tests { cx.add_model(|cx| { let mut assistant = Assistant::new(Default::default(), registry, cx); let message_1 = assistant.messages[0].clone(); - let message_2 = assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx); - let message_3 = assistant.insert_message_after(message_2.excerpt_id, Role::User, cx); - let message_4 = assistant.insert_message_after(message_2.excerpt_id, Role::User, cx); + let message_2 = assistant + .insert_message_after(message_1.id, Role::Assistant, cx) + .unwrap(); + let message_3 = assistant + .insert_message_after(message_2.id, Role::User, cx) + .unwrap(); + let message_4 = assistant + .insert_message_after(message_2.id, Role::User, cx) + .unwrap(); assistant.remove_empty_messages( - HashSet::from_iter([message_3.excerpt_id, message_4.excerpt_id]), + HashSet::from_iter([message_3.id, message_4.id]), Default::default(), cx, ); assert_eq!(assistant.messages.len(), 2); - assert_eq!(assistant.messages[0].excerpt_id, message_1.excerpt_id); - assert_eq!(assistant.messages[1].excerpt_id, message_2.excerpt_id); + assert_eq!(assistant.messages[0].id, message_1.id); + assert_eq!(assistant.messages[1].id, message_2.id); assistant }); }