diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 2dc1415371..ba7c607bd5 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -231,12 +231,18 @@ pub struct AssistantChat { user_store: Model, next_message_id: MessageId, collapsed_messages: HashMap, - editing_message_id: Option, + editing_message: Option, pending_completion: Option>, tool_registry: Arc, project_index: Option>, } +struct EditingMessage { + id: MessageId, + old_body: Arc, + body: View, +} + impl AssistantChat { fn new( language_registry: Arc, @@ -271,13 +277,17 @@ impl AssistantChat { language_registry, project_index, next_message_id: MessageId(0), - editing_message_id: None, + editing_message: None, collapsed_messages: HashMap::default(), pending_completion: None, tool_registry, } } + fn editing_message_id(&self) -> Option { + self.editing_message.as_ref().map(|message| message.id) + } + fn focused_message_id(&self, cx: &WindowContext) -> Option { self.messages.iter().find_map(|message| match message { ChatMessage::User(message) => message @@ -291,30 +301,40 @@ impl AssistantChat { fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext) { // If we're currently editing a message, cancel the edit. - self.editing_message_id.take(); - - if self.pending_completion.take().is_none() { - cx.propagate(); + if let Some(editing_message) = self.editing_message.take() { + editing_message + .body + .update(cx, |body, cx| body.set_text(editing_message.old_body, cx)); return; } - if let Some(ChatMessage::Assistant(message)) = self.messages.last() { - if message.body.text.is_empty() { - self.pop_message(cx); + if self.pending_completion.take().is_some() { + if let Some(ChatMessage::Assistant(message)) = self.messages.last() { + if message.body.text.is_empty() { + self.pop_message(cx); + } } + return; } + + cx.propagate(); } fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext) { - // Don't allow multiple concurrent completions. - if self.pending_completion.is_some() { - cx.propagate(); - return; - } - if let Some(focused_message_id) = self.focused_message_id(cx) { self.truncate_messages(focused_message_id, cx); + self.pending_completion.take(); + self.composer_editor.focus_handle(cx).focus(cx); + if self.editing_message_id() == Some(focused_message_id) { + self.editing_message.take(); + } } else if self.composer_editor.focus_handle(cx).is_focused(cx) { + // Don't allow multiple concurrent completions. + if self.pending_completion.is_some() { + cx.propagate(); + return; + } + let message = self.composer_editor.update(cx, |composer_editor, cx| { let text = composer_editor.text(cx); let id = self.next_message_id.post_inc(); @@ -344,9 +364,7 @@ impl AssistantChat { .await .log_err(); - this.update(&mut cx, |this, cx| { - let composer_focus_handle = this.composer_editor.focus_handle(cx); - cx.focus(&composer_focus_handle); + this.update(&mut cx, |this, _cx| { this.pending_completion = None; }) .context("Failed to push new user message") @@ -572,11 +590,11 @@ impl AssistantChat { .id(SharedString::from(format!("message-{}-container", id.0))) .when(!is_last, |element| element.mb_2()) .map(|element| { - if self.editing_message_id.as_ref() == Some(id) { + if self.editing_message_id() == Some(*id) { element.child(Composer::new( body.clone(), self.user_store.read(cx).current_user(), - self.can_submit(), + true, self.tool_registry.clone(), crate::ui::ModelSelector::new( cx.view().downgrade(), @@ -588,9 +606,15 @@ impl AssistantChat { element .on_click(cx.listener({ let id = *id; - move |assistant_chat, event: &ClickEvent, _cx| { + let body = body.clone(); + move |assistant_chat, event: &ClickEvent, cx| { if event.up.click_count == 2 { - assistant_chat.editing_message_id = Some(id); + assistant_chat.editing_message = Some(EditingMessage { + id, + body: body.clone(), + old_body: body.read(cx).text(cx).into(), + }); + body.focus_handle(cx).focus(cx); } } })) diff --git a/crates/assistant2/src/ui/composer.rs b/crates/assistant2/src/ui/composer.rs index 81c1d4340f..a0da37f21f 100644 --- a/crates/assistant2/src/ui/composer.rs +++ b/crates/assistant2/src/ui/composer.rs @@ -69,8 +69,7 @@ impl RenderOnce for Composer { v_flex() .justify_between() .w_full() - .gap_1() - .min_h(line_height * 4 + px(74.0)) + .gap_2() .child({ let settings = ThemeSettings::get_global(cx); let text_style = TextStyle {