diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index bad153879f..48f490c9c0 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,7 +1,8 @@ pub mod assistant; mod assistant_settings; -mod diff; -mod refactor; +mod refactoring_assistant; +mod refactoring_modal; +mod streaming_diff; use anyhow::{anyhow, Result}; pub use assistant::AssistantPanel; @@ -195,7 +196,7 @@ struct OpenAIChoice { pub fn init(cx: &mut AppContext) { assistant::init(cx); - refactor::init(cx); + refactoring_modal::init(cx); } pub async fn stream_completion( diff --git a/crates/ai/src/refactor.rs b/crates/ai/src/refactoring_assistant.rs similarity index 69% rename from crates/ai/src/refactor.rs rename to crates/ai/src/refactoring_assistant.rs index 1cb370dbba..5562cb4606 100644 --- a/crates/ai/src/refactor.rs +++ b/crates/ai/src/refactoring_assistant.rs @@ -1,25 +1,16 @@ -use crate::{diff::Diff, stream_completion, OpenAIRequest, RequestMessage, Role}; use collections::HashMap; use editor::{Editor, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, StreamExt}; -use gpui::{ - actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, Task, View, - ViewContext, ViewHandle, WeakViewHandle, -}; +use gpui::{AppContext, Task, ViewHandle}; use language::{Point, Rope}; -use menu::{Cancel, Confirm}; -use std::{cmp, env, sync::Arc}; +use std::{cmp, env, fmt::Write}; use util::TryFutureExt; -use workspace::{Modal, Workspace}; -actions!(assistant, [Refactor]); - -pub fn init(cx: &mut AppContext) { - cx.set_global(RefactoringAssistant::new()); - cx.add_action(RefactoringModal::deploy); - cx.add_action(RefactoringModal::confirm); - cx.add_action(RefactoringModal::cancel); -} +use crate::{ + stream_completion, + streaming_diff::{Hunk, StreamingDiff}, + OpenAIRequest, RequestMessage, Role, +}; pub struct RefactoringAssistant { pending_edits_by_editor: HashMap>>, @@ -32,7 +23,30 @@ impl RefactoringAssistant { } } - fn refactor(&mut self, editor: &ViewHandle, prompt: &str, cx: &mut AppContext) { + pub fn update(cx: &mut AppContext, f: F) -> T + where + F: FnOnce(&mut Self, &mut AppContext) -> T, + { + if !cx.has_global::() { + cx.set_global(Self::new()); + } + + cx.update_global(f) + } + + pub fn refactor( + &mut self, + editor: &ViewHandle, + user_prompt: &str, + cx: &mut AppContext, + ) { + let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { + api_key + } else { + // TODO: ensure the API key is present by going through the assistant panel's flow. + return; + }; + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); let selection = editor.read(cx).selections.newest_anchor().clone(); let selected_text = snapshot @@ -83,18 +97,20 @@ impl RefactoringAssistant { .language_at(selection.start) .map(|language| language.name()); let language_name = language_name.as_deref().unwrap_or(""); + + let mut prompt = String::new(); + writeln!(prompt, "Given the following {language_name} snippet:").unwrap(); + writeln!(prompt, "{normalized_selected_text}").unwrap(); + writeln!(prompt, "{user_prompt}.").unwrap(); + writeln!(prompt, "Never make remarks, reply only with the new code.").unwrap(); let request = OpenAIRequest { model: "gpt-4".into(), - messages: vec![ - RequestMessage { + messages: vec![RequestMessage { role: Role::User, - content: format!( - "Given the following {language_name} snippet:\n{normalized_selected_text}\n{prompt}. Never make remarks and reply only with the new code." - ), + content: prompt, }], stream: true, }; - let api_key = env::var("OPENAI_API_KEY").unwrap(); let response = stream_completion(api_key, cx.background().clone(), request); let editor = editor.downgrade(); self.pending_edits_by_editor.insert( @@ -116,7 +132,7 @@ impl RefactoringAssistant { let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); let diff = cx.background().spawn(async move { let mut messages = response.await?.ready_chunks(4); - let mut diff = Diff::new(selected_text.to_string()); + let mut diff = StreamingDiff::new(selected_text.to_string()); let indentation_len; let indentation_text; @@ -177,18 +193,18 @@ impl RefactoringAssistant { buffer.start_transaction(cx); buffer.edit( hunks.into_iter().filter_map(|hunk| match hunk { - crate::diff::Hunk::Insert { text } => { + Hunk::Insert { text } => { let edit_start = snapshot.anchor_after(edit_start); Some((edit_start..edit_start, text)) } - crate::diff::Hunk::Remove { len } => { + Hunk::Remove { len } => { let edit_end = edit_start + len; let edit_range = snapshot.anchor_after(edit_start) ..snapshot.anchor_before(edit_end); edit_start = edit_end; Some((edit_range, String::new())) } - crate::diff::Hunk::Keep { len } => { + Hunk::Keep { len } => { let edit_end = edit_start + len; let edit_range = snapshot.anchor_after(edit_start) ..snapshot.anchor_before(edit_end); @@ -234,99 +250,3 @@ impl RefactoringAssistant { ); } } - -enum Event { - Dismissed, -} - -struct RefactoringModal { - active_editor: WeakViewHandle, - prompt_editor: ViewHandle, - has_focus: bool, -} - -impl Entity for RefactoringModal { - type Event = Event; -} - -impl View for RefactoringModal { - fn ui_name() -> &'static str { - "RefactoringModal" - } - - fn render(&mut self, cx: &mut ViewContext) -> AnyElement { - let theme = theme::current(cx); - - ChildView::new(&self.prompt_editor, cx) - .constrained() - .with_width(theme.assistant.modal.width) - .contained() - .with_style(theme.assistant.modal.container) - .mouse::(0) - .on_click_out(MouseButton::Left, |_, _, cx| cx.emit(Event::Dismissed)) - .on_click_out(MouseButton::Right, |_, _, cx| cx.emit(Event::Dismissed)) - .aligned() - .right() - .into_any() - } - - fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext) { - self.has_focus = true; - cx.focus(&self.prompt_editor); - } - - fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext) { - self.has_focus = false; - } -} - -impl Modal for RefactoringModal { - fn has_focus(&self) -> bool { - self.has_focus - } - - fn dismiss_on_event(event: &Self::Event) -> bool { - matches!(event, Self::Event::Dismissed) - } -} - -impl RefactoringModal { - fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext) { - if let Some(active_editor) = workspace - .active_item(cx) - .and_then(|item| Some(item.act_as::(cx)?.downgrade())) - { - workspace.toggle_modal(cx, |_, cx| { - let prompt_editor = cx.add_view(|cx| { - let mut editor = Editor::auto_height( - theme::current(cx).assistant.modal.editor_max_lines, - Some(Arc::new(|theme| theme.assistant.modal.editor.clone())), - cx, - ); - editor - .set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); - editor - }); - cx.add_view(|_| RefactoringModal { - active_editor, - prompt_editor, - has_focus: false, - }) - }); - } - } - - fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext) { - cx.emit(Event::Dismissed); - } - - fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext) { - if let Some(editor) = self.active_editor.upgrade(cx) { - let prompt = self.prompt_editor.read(cx).text(cx); - cx.update_global(|assistant: &mut RefactoringAssistant, cx| { - assistant.refactor(&editor, &prompt, cx); - }); - cx.emit(Event::Dismissed); - } - } -} diff --git a/crates/ai/src/refactoring_modal.rs b/crates/ai/src/refactoring_modal.rs new file mode 100644 index 0000000000..2203acc921 --- /dev/null +++ b/crates/ai/src/refactoring_modal.rs @@ -0,0 +1,134 @@ +use crate::refactoring_assistant::RefactoringAssistant; +use collections::HashSet; +use editor::{ + display_map::{BlockContext, BlockDisposition, BlockProperties, BlockStyle}, + scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, + Editor, +}; +use gpui::{ + actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, View, + ViewContext, ViewHandle, WeakViewHandle, +}; +use std::sync::Arc; +use workspace::Workspace; + +actions!(assistant, [Refactor]); + +pub fn init(cx: &mut AppContext) { + cx.add_action(RefactoringModal::deploy); + cx.add_action(RefactoringModal::confirm); + cx.add_action(RefactoringModal::cancel); +} + +enum Event { + Dismissed, +} + +struct RefactoringModal { + active_editor: WeakViewHandle, + prompt_editor: ViewHandle, + has_focus: bool, +} + +impl Entity for RefactoringModal { + type Event = Event; +} + +impl View for RefactoringModal { + fn ui_name() -> &'static str { + "RefactoringModal" + } + + fn render(&mut self, cx: &mut ViewContext) -> AnyElement { + ChildView::new(&self.prompt_editor, cx) + .mouse::(0) + .on_click_out(MouseButton::Left, |_, _, cx| cx.emit(Event::Dismissed)) + .on_click_out(MouseButton::Right, |_, _, cx| cx.emit(Event::Dismissed)) + .into_any() + } + + fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext) { + self.has_focus = true; + cx.focus(&self.prompt_editor); + } + + fn focus_out(&mut self, _: AnyViewHandle, cx: &mut ViewContext) { + if !self.prompt_editor.is_focused(cx) { + self.has_focus = false; + cx.emit(Event::Dismissed); + } + } +} + +impl RefactoringModal { + fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext) { + if let Some(active_editor) = workspace + .active_item(cx) + .and_then(|item| item.act_as::(cx)) + { + active_editor.update(cx, |editor, cx| { + let position = editor.selections.newest_anchor().head(); + let prompt_editor = cx.add_view(|cx| { + Editor::single_line( + Some(Arc::new(|theme| theme.assistant.modal.editor.clone())), + cx, + ) + }); + let active_editor = cx.weak_handle(); + let refactoring = cx.add_view(|_| RefactoringModal { + active_editor, + prompt_editor, + has_focus: false, + }); + cx.focus(&refactoring); + + let block_id = editor.insert_blocks( + [BlockProperties { + style: BlockStyle::Flex, + position, + height: 2, + render: Arc::new({ + let refactoring = refactoring.clone(); + move |cx: &mut BlockContext| { + ChildView::new(&refactoring, cx) + .contained() + .with_padding_left(cx.gutter_width) + .aligned() + .left() + .into_any() + } + }), + disposition: BlockDisposition::Below, + }], + Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)), + cx, + )[0]; + cx.subscribe(&refactoring, move |_, refactoring, event, cx| { + let Event::Dismissed = event; + if let Some(active_editor) = refactoring.read(cx).active_editor.upgrade(cx) { + cx.window_context().defer(move |cx| { + active_editor.update(cx, |editor, cx| { + editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); + }) + }); + } + }) + .detach(); + }); + } + } + + fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext) { + cx.emit(Event::Dismissed); + } + + fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + if let Some(editor) = self.active_editor.upgrade(cx) { + let prompt = self.prompt_editor.read(cx).text(cx); + RefactoringAssistant::update(cx, |assistant, cx| { + assistant.refactor(&editor, &prompt, cx); + }); + cx.emit(Event::Dismissed); + } + } +} diff --git a/crates/ai/src/diff.rs b/crates/ai/src/streaming_diff.rs similarity index 98% rename from crates/ai/src/diff.rs rename to crates/ai/src/streaming_diff.rs index 7c5af34ff5..1e5189d4d8 100644 --- a/crates/ai/src/diff.rs +++ b/crates/ai/src/streaming_diff.rs @@ -71,7 +71,7 @@ pub enum Hunk { Keep { len: usize }, } -pub struct Diff { +pub struct StreamingDiff { old: Vec, new: Vec, scores: Matrix, @@ -80,10 +80,10 @@ pub struct Diff { equal_runs: HashMap<(usize, usize), u32>, } -impl Diff { +impl StreamingDiff { const INSERTION_SCORE: f64 = -1.; const DELETION_SCORE: f64 = -5.; - const EQUALITY_BASE: f64 = 1.618; + const EQUALITY_BASE: f64 = 2.; const MAX_EQUALITY_EXPONENT: i32 = 32; pub fn new(old: String) -> Self { @@ -250,7 +250,7 @@ mod tests { .collect::(); log::info!("old text: {:?}", old); - let mut diff = Diff::new(old.clone()); + let mut diff = StreamingDiff::new(old.clone()); let mut hunks = Vec::new(); let mut new_len = 0; let mut new = String::new(); diff --git a/styles/src/style_tree/assistant.ts b/styles/src/style_tree/assistant.ts index 88efabee1e..a02d7eb40c 100644 --- a/styles/src/style_tree/assistant.ts +++ b/styles/src/style_tree/assistant.ts @@ -69,8 +69,7 @@ export default function assistant(): any { width: 500, editor_max_lines: 6, editor: { - background: background(theme.lowest), - text: text(theme.lowest, "mono", "on"), + text: text(theme.lowest, "mono", "on", { size: "sm" }), placeholder_text: text(theme.lowest, "sans", "on", "disabled"), selection: theme.players[0], }