diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 785bc657cf..013565e14f 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -22,7 +22,7 @@ util = { path = "../util" } workspace = { path = "../workspace" } anyhow.workspace = true -chrono = "0.4" +chrono = { version = "0.4", features = ["serde"] } futures.workspace = true isahc.workspace = true regex.workspace = true diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 75aabe561c..c1e6a4c569 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -3,6 +3,8 @@ mod assistant_settings; use anyhow::Result; pub use assistant::AssistantPanel; +use chrono::{DateTime, Local}; +use collections::HashMap; use fs::Fs; use futures::StreamExt; use gpui::AppContext; @@ -12,7 +14,6 @@ use std::{ fmt::{self, Display}, path::PathBuf, sync::Arc, - time::SystemTime, }; use util::paths::CONVERSATIONS_DIR; @@ -24,11 +25,44 @@ struct OpenAIRequest { stream: bool, } +#[derive( + Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize, +)] +struct MessageId(usize); + +#[derive(Clone, Debug, Serialize, Deserialize)] +struct MessageMetadata { + role: Role, + sent_at: DateTime, + status: MessageStatus, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum MessageStatus { + Pending, + Done, + Error(Arc), +} + +#[derive(Serialize, Deserialize)] +struct SavedMessage { + id: MessageId, + start: usize, +} + #[derive(Serialize, Deserialize)] struct SavedConversation { zed: String, version: String, - messages: Vec, + text: String, + messages: Vec, + message_metadata: HashMap, + summary: String, + model: String, +} + +impl SavedConversation { + const VERSION: &'static str = "0.1.0"; } struct SavedConversationMetadata { diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index a21ab29fc0..e8faa48000 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,7 +1,7 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, - OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role, SavedConversation, - SavedConversationMetadata, + MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent, + RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; @@ -27,10 +27,18 @@ use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset a use serde::Deserialize; use settings::SettingsStore; use std::{ - borrow::Cow, cell::RefCell, cmp, env, fmt::Write, io, iter, ops::Range, path::PathBuf, rc::Rc, - sync::Arc, time::Duration, + borrow::Cow, + cell::RefCell, + cmp, env, + fmt::Write, + io, iter, + ops::Range, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, + time::Duration, }; -use theme::{ui::IconStyle, IconButton, Theme}; +use theme::ui::IconStyle; use util::{ channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, truncate_and_trailoff, ResultExt, TryFutureExt, @@ -68,7 +76,7 @@ pub fn init(cx: &mut AppContext) { |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext| { if let Some(this) = workspace.panel::(cx) { this.update(cx, |this, cx| { - this.add_conversation(cx); + this.new_conversation(cx); }) } @@ -187,13 +195,8 @@ impl AssistantPanel { }) } - fn add_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { - let focus = self.has_focus(cx); + fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { - if focus { - cx.focus_self(); - } - ConversationEditor::new( self.api_key.clone(), self.languages.clone(), @@ -201,14 +204,24 @@ impl AssistantPanel { cx, ) }); + self.add_conversation(editor.clone(), cx); + editor + } + + fn add_conversation( + &mut self, + editor: ViewHandle, + cx: &mut ViewContext, + ) { self.subscriptions .push(cx.subscribe(&editor, Self::handle_conversation_editor_event)); self.active_conversation_index = Some(self.conversation_editors.len()); self.conversation_editors.push(editor.clone()); - + if self.has_focus(cx) { + cx.focus(&editor); + } cx.notify(); - editor } fn handle_conversation_editor_event( @@ -264,9 +277,28 @@ impl AssistantPanel { } fn render_hamburger_button(style: &IconStyle) -> impl Element { + enum ListConversations {} Svg::for_style(style.icon.clone()) .contained() .with_style(style.container) + .mouse::(0) + .with_cursor_style(CursorStyle::PointingHand) + .on_click(MouseButton::Left, |_, this: &mut Self, cx| { + this.active_conversation_index = None; + cx.notify(); + }) + } + + fn render_plus_button(style: &IconStyle) -> impl Element { + enum AddConversation {} + Svg::for_style(style.icon.clone()) + .contained() + .with_style(style.container) + .mouse::(0) + .with_cursor_style(CursorStyle::PointingHand) + .on_click(MouseButton::Left, |_, this: &mut Self, cx| { + this.new_conversation(cx); + }) } fn render_saved_conversation( @@ -274,20 +306,23 @@ impl AssistantPanel { index: usize, cx: &mut ViewContext, ) -> impl Element { + let conversation = &self.saved_conversations[index]; + let path = conversation.path.clone(); MouseEventHandler::::new(index, cx, move |state, cx| { let style = &theme::current(cx).assistant.saved_conversation; - let conversation = &self.saved_conversations[index]; Flex::row() .with_child( Label::new( - conversation.mtime.format("%c").to_string(), + conversation.mtime.format("%F %I:%M%p").to_string(), style.saved_at.text.clone(), ) + .aligned() .contained() .with_style(style.saved_at.container), ) .with_child( Label::new(conversation.title.clone(), style.title.text.clone()) + .aligned() .contained() .with_style(style.title.container), ) @@ -295,7 +330,48 @@ impl AssistantPanel { .with_style(*style.container.style_for(state, false)) }) .with_cursor_style(CursorStyle::PointingHand) - .on_click(MouseButton::Left, |_, this, cx| {}) + .on_click(MouseButton::Left, move |_, this, cx| { + this.open_conversation(path.clone(), cx) + .detach_and_log_err(cx) + }) + } + + fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext) -> Task> { + if let Some(ix) = self.conversation_editor_index_for_path(&path, cx) { + self.active_conversation_index = Some(ix); + cx.notify(); + return Task::ready(Ok(())); + } + + let fs = self.fs.clone(); + let conversation = Conversation::load( + path.clone(), + self.api_key.clone(), + self.languages.clone(), + self.fs.clone(), + cx, + ); + cx.spawn(|this, mut cx| async move { + let conversation = conversation.await?; + this.update(&mut cx, |this, cx| { + // If, by the time we've loaded the conversation, the user has already opened + // the same conversation, we don't want to open it again. + if let Some(ix) = this.conversation_editor_index_for_path(&path, cx) { + this.active_conversation_index = Some(ix); + } else { + let editor = cx + .add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx)); + this.add_conversation(editor, cx); + } + })?; + Ok(()) + }) + } + + fn conversation_editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option { + self.conversation_editors + .iter() + .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } } @@ -341,30 +417,37 @@ impl View for AssistantPanel { .with_style(style.api_key_prompt.container) .aligned() .into_any() - } else if let Some(editor) = self.active_conversation_editor() { + } else { Flex::column() .with_child( Flex::row() - .with_child(Self::render_hamburger_button(&style.hamburger_button)) + .with_child( + Self::render_hamburger_button(&style.hamburger_button).aligned(), + ) + .with_child(Self::render_plus_button(&style.plus_button).aligned()) .contained() .with_style(theme.workspace.tab_bar.container) + .expanded() .constrained() .with_height(theme.workspace.tab_bar.height), ) - .with_child(ChildView::new(editor, cx).flex(1., true)) + .with_child(if let Some(editor) = self.active_conversation_editor() { + ChildView::new(editor, cx).flex(1., true).into_any() + } else { + UniformList::new( + self.saved_conversations_list_state.clone(), + self.saved_conversations.len(), + cx, + |this, range, items, cx| { + for ix in range { + items.push(this.render_saved_conversation(ix, cx).into_any()); + } + }, + ) + .flex(1., true) + .into_any() + }) .into_any() - } else { - UniformList::new( - self.saved_conversations_list_state.clone(), - self.saved_conversations.len(), - cx, - |this, range, items, cx| { - for ix in range { - items.push(this.render_saved_conversation(ix, cx).into_any()); - } - }, - ) - .into_any() } } @@ -468,7 +551,7 @@ impl Panel for AssistantPanel { } if self.conversation_editors.is_empty() { - self.add_conversation(cx); + self.new_conversation(cx); } } } @@ -598,6 +681,74 @@ impl Conversation { this } + fn load( + path: PathBuf, + api_key: Rc>>, + language_registry: Arc, + fs: Arc, + cx: &mut AppContext, + ) -> Task>> { + cx.spawn(|mut cx| async move { + let saved_conversation = fs.load(&path).await?; + let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?; + + let model = saved_conversation.model; + let markdown = language_registry.language_for_name("Markdown"); + let mut message_anchors = Vec::new(); + let mut next_message_id = MessageId(0); + let buffer = cx.add_model(|cx| { + let mut buffer = Buffer::new(0, saved_conversation.text, cx); + for message in saved_conversation.messages { + message_anchors.push(MessageAnchor { + id: message.id, + start: buffer.anchor_before(message.start), + }); + next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1)); + } + buffer.set_language_registry(language_registry); + cx.spawn_weak(|buffer, mut cx| async move { + let markdown = markdown.await?; + let buffer = buffer + .upgrade(&cx) + .ok_or_else(|| anyhow!("buffer was dropped"))?; + buffer.update(&mut cx, |buffer, cx| { + buffer.set_language(Some(markdown), cx) + }); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + buffer + }); + let conversation = cx.add_model(|cx| { + let mut this = Self { + message_anchors, + messages_metadata: saved_conversation.message_metadata, + next_message_id, + summary: Some(Summary { + text: saved_conversation.summary, + done: true, + }), + pending_summary: Task::ready(None), + completion_count: Default::default(), + pending_completions: Default::default(), + token_count: None, + max_token_count: tiktoken_rs::model::get_context_size(&model), + pending_token_count: Task::ready(None), + model, + _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], + pending_save: Task::ready(Ok(())), + path: Some(path), + api_key, + buffer, + }; + + this.count_remaining_tokens(cx); + this + }); + Ok(conversation) + }) + } + fn handle_buffer_event( &mut self, _: ModelHandle, @@ -1122,15 +1273,22 @@ impl Conversation { }); if let Some(summary) = summary { - let conversation = SavedConversation { + let conversation = this.read_with(&cx, |this, cx| SavedConversation { zed: "conversation".into(), - version: "0.1".into(), - messages: this.read_with(&cx, |this, cx| { - this.messages(cx) - .map(|message| message.to_open_ai_message(this.buffer.read(cx))) - .collect() - }), - }; + version: SavedConversation::VERSION.into(), + text: this.buffer.read(cx).text(), + message_metadata: this.messages_metadata.clone(), + messages: this + .message_anchors + .iter() + .map(|message| SavedMessage { + id: message.id, + start: message.start.to_offset(this.buffer.read(cx)), + }) + .collect(), + summary: summary.clone(), + model: this.model.clone(), + }); let path = if let Some(old_path) = old_path { old_path @@ -1195,6 +1353,14 @@ impl ConversationEditor { cx: &mut ViewContext, ) -> Self { let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx)); + Self::from_conversation(conversation, fs, cx) + } + + fn from_conversation( + conversation: ModelHandle, + fs: Arc, + cx: &mut ViewContext, + ) -> Self { let editor = cx.add_view(|cx| { let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx); editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx); @@ -1524,7 +1690,7 @@ impl ConversationEditor { let conversation = panel .active_conversation_editor() .cloned() - .unwrap_or_else(|| panel.add_conversation(cx)); + .unwrap_or_else(|| panel.new_conversation(cx)); conversation.update(cx, |conversation, cx| { conversation .editor @@ -1693,29 +1859,12 @@ impl Item for ConversationEditor { } } -#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)] -struct MessageId(usize); - #[derive(Clone, Debug)] struct MessageAnchor { id: MessageId, start: language::Anchor, } -#[derive(Clone, Debug)] -struct MessageMetadata { - role: Role, - sent_at: DateTime, - status: MessageStatus, -} - -#[derive(Clone, Debug)] -enum MessageStatus { - Pending, - Done, - Error(Arc), -} - #[derive(Clone, Debug)] pub struct Message { range: Range, @@ -1733,7 +1882,7 @@ impl Message { content.extend(buffer.text_for_range(self.range.clone())); RequestMessage { role: self.role, - content, + content: content.trim_end().into(), } } } @@ -1826,6 +1975,8 @@ async fn stream_completion( #[cfg(test)] mod tests { + use crate::MessageId; + use super::*; use fs::FakeFs; use gpui::{AppContext, TestAppContext}; diff --git a/crates/theme/src/theme.rs b/crates/theme/src/theme.rs index 2ab12f1d96..d76c3432d1 100644 --- a/crates/theme/src/theme.rs +++ b/crates/theme/src/theme.rs @@ -995,6 +995,7 @@ pub struct TerminalStyle { pub struct AssistantStyle { pub container: ContainerStyle, pub hamburger_button: IconStyle, + pub plus_button: IconStyle, pub message_header: ContainerStyle, pub sent_at: ContainedText, pub user_sender: Interactive, diff --git a/styles/src/styleTree/assistant.ts b/styles/src/styleTree/assistant.ts index 13ef484391..94547bd154 100644 --- a/styles/src/styleTree/assistant.ts +++ b/styles/src/styleTree/assistant.ts @@ -23,7 +23,36 @@ export default function assistant(colorScheme: ColorScheme) { height: 15, }, }, - container: {} + container: { + margin: { left: 8 }, + } + }, + plusButton: { + icon: { + color: text(layer, "sans", "default", { size: "sm" }).color, + asset: "icons/plus_12.svg", + dimensions: { + width: 12, + height: 12, + }, + }, + container: { + margin: { left: 8 }, + } + }, + savedConversation: { + background: background(layer, "on"), + hover: { + background: background(layer, "on", "hovered"), + }, + savedAt: { + margin: { left: 8 }, + ...text(layer, "sans", "default", { size: "xs" }), + }, + title: { + margin: { left: 8 }, + ...text(layer, "sans", "default", { size: "sm", weight: "bold" }), + } }, userSender: { ...text(layer, "sans", "default", { size: "sm", weight: "bold" }),