diff --git a/Cargo.lock b/Cargo.lock index d21006ee55..936cbbdeca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -464,18 +464,23 @@ dependencies = [ "feature_flags", "futures 0.3.31", "gpui", + "language", "language_model", "language_model_selector", "language_models", "log", + "pretty_assertions", "project", "proto", + "schemars", "serde", "serde_json", "settings", "smol", + "text", "theme", "ui", + "unindent", "util", "workspace", ] diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 20e8dfbc9a..d5cd6448e7 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -23,17 +23,24 @@ editor.workspace = true feature_flags.workspace = true futures.workspace = true gpui.workspace = true +language.workspace = true language_model.workspace = true language_model_selector.workspace = true language_models.workspace = true log.workspace = true project.workspace = true proto.workspace = true -settings.workspace = true +schemars.workspace = true serde.workspace = true serde_json.workspace = true +settings.workspace = true smol.workspace = true +text.workspace = true theme.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true +unindent.workspace = true diff --git a/crates/assistant2/src/assistant.rs b/crates/assistant2/src/assistant.rs index 8ef4a1d9dc..37cc966eb4 100644 --- a/crates/assistant2/src/assistant.rs +++ b/crates/assistant2/src/assistant.rs @@ -1,4 +1,5 @@ mod assistant_panel; +mod edits; mod message_editor; mod thread; mod thread_store; @@ -20,6 +21,7 @@ const NAMESPACE: &str = "assistant2"; pub fn init(cx: &mut AppContext) { assistant_panel::init(cx); feature_gate_assistant2_actions(cx); + edits::init(cx); } fn feature_gate_assistant2_actions(cx: &mut AppContext) { diff --git a/crates/assistant2/src/edits.rs b/crates/assistant2/src/edits.rs new file mode 100644 index 0000000000..dba006f728 --- /dev/null +++ b/crates/assistant2/src/edits.rs @@ -0,0 +1,12 @@ +mod code_edits_tool; +pub mod patch; + +use assistant_tool::ToolRegistry; +use gpui::AppContext; + +pub use crate::edits::code_edits_tool::CodeEditsTool; + +pub fn init(cx: &mut AppContext) { + let tool_registry = ToolRegistry::global(cx); + tool_registry.register_tool(CodeEditsTool); +} diff --git a/crates/assistant_tools/src/code_edits_tool.rs b/crates/assistant2/src/edits/code_edits_tool.rs similarity index 56% rename from crates/assistant_tools/src/code_edits_tool.rs rename to crates/assistant2/src/edits/code_edits_tool.rs index 1d05200e74..2bc07fc7b4 100644 --- a/crates/assistant_tools/src/code_edits_tool.rs +++ b/crates/assistant2/src/edits/code_edits_tool.rs @@ -18,31 +18,52 @@ pub struct CodeEditsToolInput { pub struct Edit { /// The path to the file that this edit will change. pub path: String, - /// An arbitrarily-long comment that describes the purpose of this edit. - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// An excerpt from the file's current contents that uniquely identifies a range within the file where the edit should occur. - #[serde(skip_serializing_if = "Option::is_none")] - pub old_text: Option, - /// The new text to insert into the file. - pub new_text: String, /// The type of change that should occur at the given range of the file. pub operation: Operation, } #[derive(Debug, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] +#[serde(tag = "type", rename_all = "snake_case")] pub enum Operation { /// Replaces the entire range with the new text. - Update, + Update { + /// An excerpt from the file's current contents that uniquely identifies a range within the file where the edit should occur. + old_text: String, + /// The new text to insert into the file. + new_text: String, + /// An arbitrarily-long comment that describes the purpose of this edit. + description: Option, + }, /// Inserts the new text before the range. - InsertBefore, + InsertBefore { + /// An excerpt from the file's current contents that uniquely identifies a range within the file where the edit should occur. + old_text: String, + /// The new text to insert into the file. + new_text: String, + /// An arbitrarily-long comment that describes the purpose of this edit. + description: Option, + }, /// Inserts new text after the range. - InsertAfter, + InsertAfter { + /// An excerpt from the file's current contents that uniquely identifies a range within the file where the edit should occur. + old_text: String, + /// The new text to insert into the file. + new_text: String, + /// An arbitrarily-long comment that describes the purpose of this edit. + description: Option, + }, /// Creates a new file with the given path and the new text. - Create, + Create { + /// An arbitrarily-long comment that describes the purpose of this edit. + description: Option, + /// The new text to insert into the file. + new_text: String, + }, /// Deletes the specified range from the file. - Delete, + Delete { + /// An excerpt from the file's current contents that uniquely identifies a range within the file where the edit should occur. + old_text: String, + }, } pub struct CodeEditsTool; @@ -79,8 +100,6 @@ impl Tool for CodeEditsTool { Err(err) => return Task::ready(Err(anyhow!(err))), }; - let text = format!("The tool returned {:?}.", input); - - Task::ready(Ok(text)) + Task::ready(serde_json::to_string(&input).map_err(|err| anyhow!(err))) } } diff --git a/crates/assistant_tools/src/code_edits_tool_description.txt b/crates/assistant2/src/edits/code_edits_tool_description.txt similarity index 100% rename from crates/assistant_tools/src/code_edits_tool_description.txt rename to crates/assistant2/src/edits/code_edits_tool_description.txt diff --git a/crates/assistant2/src/edits/patch.rs b/crates/assistant2/src/edits/patch.rs new file mode 100644 index 0000000000..d5170ec3fc --- /dev/null +++ b/crates/assistant2/src/edits/patch.rs @@ -0,0 +1,997 @@ +use anyhow::{anyhow, Context as _, Result}; +use collections::HashMap; +use editor::ProposedChangesEditor; +use futures::{future, TryFutureExt as _}; +use gpui::{AppContext, AsyncAppContext, Model, SharedString}; +use language::{AutoindentMode, Buffer, BufferSnapshot}; +use project::{Project, ProjectPath}; +use std::{cmp, ops::Range, path::Path, sync::Arc}; +use text::{AnchorRangeExt as _, Bias, OffsetRangeExt as _, Point}; + +use crate::edits::code_edits_tool::CodeEditsToolInput; + +use super::code_edits_tool; + +#[derive(Clone, Debug)] +pub(crate) struct AssistantPatch { + pub range: Range, + pub title: SharedString, + pub edits: Arc<[Result]>, + pub status: AssistantPatchStatus, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) enum AssistantPatchStatus { + Pending, + Ready, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct AssistantEdit { + pub path: String, + pub kind: AssistantEditKind, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum AssistantEditKind { + Update { + old_text: String, + new_text: String, + description: Option, + }, + Create { + new_text: String, + description: Option, + }, + InsertBefore { + old_text: String, + new_text: String, + description: Option, + }, + InsertAfter { + old_text: String, + new_text: String, + description: Option, + }, + Delete { + old_text: String, + }, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct ResolvedPatch { + pub edit_groups: HashMap, Vec>, + pub errors: Vec, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ResolvedEditGroup { + pub context_range: Range, + pub edits: Vec, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ResolvedEdit { + range: Range, + new_text: String, + description: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct AssistantPatchResolutionError { + pub edit_ix: usize, + pub message: String, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum SearchDirection { + Up, + Left, + Diagonal, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct SearchState { + cost: u32, + direction: SearchDirection, +} + +impl SearchState { + fn new(cost: u32, direction: SearchDirection) -> Self { + Self { cost, direction } + } +} + +struct SearchMatrix { + cols: usize, + data: Vec, +} + +impl SearchMatrix { + fn new(rows: usize, cols: usize) -> Self { + SearchMatrix { + cols, + data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols], + } + } + + fn get(&self, row: usize, col: usize) -> SearchState { + self.data[row * self.cols + col] + } + + fn set(&mut self, row: usize, col: usize, cost: SearchState) { + self.data[row * self.cols + col] = cost; + } +} + +impl ResolvedPatch { + pub fn apply(&self, editor: &ProposedChangesEditor, cx: &mut AppContext) { + for (buffer, groups) in &self.edit_groups { + let branch = editor.branch_buffer_for_base(buffer).unwrap(); + Self::apply_edit_groups(groups, &branch, cx); + } + editor.recalculate_all_buffer_diffs(); + } + + fn apply_edit_groups( + groups: &Vec, + buffer: &Model, + cx: &mut AppContext, + ) { + let mut edits = Vec::new(); + for group in groups { + for suggestion in &group.edits { + edits.push((suggestion.range.clone(), suggestion.new_text.clone())); + } + } + buffer.update(cx, |buffer, cx| { + buffer.edit( + edits, + Some(AutoindentMode::Block { + original_indent_columns: Vec::new(), + }), + cx, + ); + }); + } +} + +impl ResolvedEdit { + pub fn try_merge(&mut self, other: &Self, buffer: &text::BufferSnapshot) -> bool { + let range = &self.range; + let other_range = &other.range; + + // Don't merge if we don't contain the other suggestion. + if range.start.cmp(&other_range.start, buffer).is_gt() + || range.end.cmp(&other_range.end, buffer).is_lt() + { + return false; + } + + let other_offset_range = other_range.to_offset(buffer); + let offset_range = range.to_offset(buffer); + + // If the other range is empty at the start of this edit's range, combine the new text + if other_offset_range.is_empty() && other_offset_range.start == offset_range.start { + self.new_text = format!("{}\n{}", other.new_text, self.new_text); + self.range.start = other_range.start; + + if let Some((description, other_description)) = + self.description.as_mut().zip(other.description.as_ref()) + { + *description = format!("{}\n{}", other_description, description) + } + } else { + if let Some((description, other_description)) = + self.description.as_mut().zip(other.description.as_ref()) + { + description.push('\n'); + description.push_str(other_description); + } + } + + true + } +} + +impl AssistantEdit { + pub async fn resolve( + &self, + project: Model, + mut cx: AsyncAppContext, + ) -> Result<(Model, ResolvedEdit)> { + let path = self.path.clone(); + let kind = self.kind.clone(); + let buffer = project + .update(&mut cx, |project, cx| { + let project_path = project + .find_project_path(Path::new(&path), cx) + .or_else(|| { + // If we couldn't find a project path for it, put it in the active worktree + // so that when we create the buffer, it can be saved. + let worktree = project + .active_entry() + .and_then(|entry_id| project.worktree_for_entry(entry_id, cx)) + .or_else(|| project.worktrees(cx).next())?; + let worktree = worktree.read(cx); + + Some(ProjectPath { + worktree_id: worktree.id(), + path: Arc::from(Path::new(&path)), + }) + }) + .with_context(|| format!("worktree not found for {:?}", path))?; + anyhow::Ok(project.open_buffer(project_path, cx)) + })?? + .await?; + + let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?; + let suggestion = cx + .background_executor() + .spawn(async move { kind.resolve(&snapshot) }) + .await; + + Ok((buffer, suggestion)) + } +} + +impl AssistantEditKind { + fn resolve(self, snapshot: &BufferSnapshot) -> ResolvedEdit { + match self { + Self::Update { + old_text, + new_text, + description, + } => { + let range = Self::resolve_location(&snapshot, &old_text); + ResolvedEdit { + range, + new_text, + description, + } + } + Self::Create { + new_text, + description, + } => ResolvedEdit { + range: text::Anchor::MIN..text::Anchor::MAX, + description, + new_text, + }, + Self::InsertBefore { + old_text, + mut new_text, + description, + } => { + let range = Self::resolve_location(&snapshot, &old_text); + new_text.push('\n'); + ResolvedEdit { + range: range.start..range.start, + new_text, + description, + } + } + Self::InsertAfter { + old_text, + mut new_text, + description, + } => { + let range = Self::resolve_location(&snapshot, &old_text); + new_text.insert(0, '\n'); + ResolvedEdit { + range: range.end..range.end, + new_text, + description, + } + } + Self::Delete { old_text } => { + let range = Self::resolve_location(&snapshot, &old_text); + ResolvedEdit { + range, + new_text: String::new(), + description: None, + } + } + } + } + + fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range { + const INSERTION_COST: u32 = 3; + const DELETION_COST: u32 = 10; + const WHITESPACE_INSERTION_COST: u32 = 1; + const WHITESPACE_DELETION_COST: u32 = 1; + + let buffer_len = buffer.len(); + let query_len = search_query.len(); + let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1); + let mut leading_deletion_cost = 0_u32; + for (row, query_byte) in search_query.bytes().enumerate() { + let deletion_cost = if query_byte.is_ascii_whitespace() { + WHITESPACE_DELETION_COST + } else { + DELETION_COST + }; + + leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost); + matrix.set( + row + 1, + 0, + SearchState::new(leading_deletion_cost, SearchDirection::Diagonal), + ); + + for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() { + let insertion_cost = if buffer_byte.is_ascii_whitespace() { + WHITESPACE_INSERTION_COST + } else { + INSERTION_COST + }; + + let up = SearchState::new( + matrix.get(row, col + 1).cost.saturating_add(deletion_cost), + SearchDirection::Up, + ); + let left = SearchState::new( + matrix.get(row + 1, col).cost.saturating_add(insertion_cost), + SearchDirection::Left, + ); + let diagonal = SearchState::new( + if query_byte == *buffer_byte { + matrix.get(row, col).cost + } else { + matrix + .get(row, col) + .cost + .saturating_add(deletion_cost + insertion_cost) + }, + SearchDirection::Diagonal, + ); + matrix.set(row + 1, col + 1, up.min(left).min(diagonal)); + } + } + + // Traceback to find the best match + let mut best_buffer_end = buffer_len; + let mut best_cost = u32::MAX; + for col in 1..=buffer_len { + let cost = matrix.get(query_len, col).cost; + if cost < best_cost { + best_cost = cost; + best_buffer_end = col; + } + } + + let mut query_ix = query_len; + let mut buffer_ix = best_buffer_end; + while query_ix > 0 && buffer_ix > 0 { + let current = matrix.get(query_ix, buffer_ix); + match current.direction { + SearchDirection::Diagonal => { + query_ix -= 1; + buffer_ix -= 1; + } + SearchDirection::Up => { + query_ix -= 1; + } + SearchDirection::Left => { + buffer_ix -= 1; + } + } + } + + let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left)); + start.column = 0; + let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right)); + if end.column > 0 { + end.column = buffer.line_len(end.row); + } + + buffer.anchor_after(start)..buffer.anchor_before(end) + } +} + +impl AssistantPatch { + pub fn from_tool_use(input: &str) -> Result { + let input: CodeEditsToolInput = serde_json::from_str(input)?; + + let edits = input + .edits + .into_iter() + .map(|edit| AssistantEdit { + path: edit.path, + kind: match edit.operation { + code_edits_tool::Operation::Update { + old_text, + new_text, + description, + } => AssistantEditKind::Update { + old_text, + new_text, + description, + }, + code_edits_tool::Operation::InsertBefore { + old_text, + new_text, + description, + } => AssistantEditKind::InsertBefore { + old_text, + new_text, + description, + }, + code_edits_tool::Operation::InsertAfter { + old_text, + new_text, + description, + } => AssistantEditKind::InsertAfter { + old_text, + new_text, + description, + }, + code_edits_tool::Operation::Create { + description, + new_text, + } => AssistantEditKind::Create { + new_text, + description, + }, + code_edits_tool::Operation::Delete { old_text } => { + AssistantEditKind::Delete { old_text } + } + }, + }) + .map(Ok) + .collect::>(); + + Ok(Self { + title: input.title.into(), + // TODO: In Assistant1 it seems the `range` corresponded to the + // source range in the context editor, so we might not need it + // anymore. + range: language::Anchor::MIN..language::Anchor::MAX, + edits: edits.into(), + status: AssistantPatchStatus::Pending, + }) + } + + pub(crate) async fn resolve( + &self, + project: Model, + cx: &mut AsyncAppContext, + ) -> ResolvedPatch { + let mut resolve_tasks = Vec::new(); + for (ix, edit) in self.edits.iter().enumerate() { + if let Ok(edit) = edit.as_ref() { + resolve_tasks.push( + edit.resolve(project.clone(), cx.clone()) + .map_err(move |error| (ix, error)), + ); + } + } + + let edits = future::join_all(resolve_tasks).await; + let mut errors = Vec::new(); + let mut edits_by_buffer = HashMap::default(); + for entry in edits { + match entry { + Ok((buffer, edit)) => { + edits_by_buffer + .entry(buffer) + .or_insert_with(Vec::new) + .push(edit); + } + Err((edit_ix, error)) => errors.push(AssistantPatchResolutionError { + edit_ix, + message: error.to_string(), + }), + } + } + + // Expand the context ranges of each edit and group edits with overlapping context ranges. + let mut edit_groups_by_buffer = HashMap::default(); + for (buffer, edits) in edits_by_buffer { + if let Ok(snapshot) = buffer.update(cx, |buffer, _| buffer.text_snapshot()) { + edit_groups_by_buffer.insert(buffer, Self::group_edits(edits, &snapshot)); + } + } + + ResolvedPatch { + edit_groups: edit_groups_by_buffer, + errors, + } + } + + fn group_edits( + mut edits: Vec, + snapshot: &text::BufferSnapshot, + ) -> Vec { + let mut edit_groups = Vec::::new(); + // Sort edits by their range so that earlier, larger ranges come first + edits.sort_by(|a, b| a.range.cmp(&b.range, &snapshot)); + + // Merge overlapping edits + edits.dedup_by(|a, b| b.try_merge(a, &snapshot)); + + // Create context ranges for each edit + for edit in edits { + let context_range = { + let edit_point_range = edit.range.to_point(&snapshot); + let start_row = edit_point_range.start.row.saturating_sub(5); + let end_row = cmp::min(edit_point_range.end.row + 5, snapshot.max_point().row); + let start = snapshot.anchor_before(Point::new(start_row, 0)); + let end = snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row))); + start..end + }; + + if let Some(last_group) = edit_groups.last_mut() { + if last_group + .context_range + .end + .cmp(&context_range.start, &snapshot) + .is_ge() + { + // Merge with the previous group if context ranges overlap + last_group.context_range.end = context_range.end; + last_group.edits.push(edit); + } else { + // Create a new group + edit_groups.push(ResolvedEditGroup { + context_range, + edits: vec![edit], + }); + } + } else { + // Create the first group + edit_groups.push(ResolvedEditGroup { + context_range, + edits: vec![edit], + }); + } + } + + edit_groups + } + + pub fn path_count(&self) -> usize { + self.paths().count() + } + + pub fn paths(&self) -> impl '_ + Iterator { + let mut prev_path = None; + self.edits.iter().filter_map(move |edit| { + if let Ok(edit) = edit { + let path = Some(edit.path.as_str()); + if path != prev_path { + prev_path = path; + return path; + } + } + None + }) + } +} + +impl PartialEq for AssistantPatch { + fn eq(&self, other: &Self) -> bool { + self.range == other.range + && self.title == other.title + && Arc::ptr_eq(&self.edits, &other.edits) + } +} + +impl Eq for AssistantPatch {} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{AppContext, Context}; + use language::{ + language_settings::AllLanguageSettings, Language, LanguageConfig, LanguageMatcher, + }; + use settings::SettingsStore; + use ui::BorrowAppContext; + use unindent::Unindent as _; + use util::test::{generate_marked_text, marked_text_ranges}; + + #[gpui::test] + fn test_resolve_location(cx: &mut AppContext) { + assert_location_resolution( + concat!( + " Lorem\n", + "« ipsum\n", + " dolor sit amet»\n", + " consecteur", + ), + "ipsum\ndolor", + cx, + ); + + assert_location_resolution( + &" + «fn foo1(a: usize) -> usize { + 40 + }» + + fn foo2(b: usize) -> usize { + 42 + } + " + .unindent(), + "fn foo1(b: usize) {\n40\n}", + cx, + ); + + assert_location_resolution( + &" + fn main() { + « Foo + .bar() + .baz() + .qux()» + } + + fn foo2(b: usize) -> usize { + 42 + } + " + .unindent(), + "Foo.bar.baz.qux()", + cx, + ); + + assert_location_resolution( + &" + class Something { + one() { return 1; } + « two() { return 2222; } + three() { return 333; } + four() { return 4444; } + five() { return 5555; } + six() { return 6666; } + » seven() { return 7; } + eight() { return 8; } + } + " + .unindent(), + &" + two() { return 2222; } + four() { return 4444; } + five() { return 5555; } + six() { return 6666; } + " + .unindent(), + cx, + ); + } + + #[gpui::test] + fn test_resolve_edits(cx: &mut AppContext) { + init_test(cx); + + assert_edits( + " + /// A person + struct Person { + name: String, + age: usize, + } + + /// A dog + struct Dog { + weight: f32, + } + + impl Person { + fn name(&self) -> &str { + &self.name + } + } + " + .unindent(), + vec![ + AssistantEditKind::Update { + old_text: " + name: String, + " + .unindent(), + new_text: " + first_name: String, + last_name: String, + " + .unindent(), + description: None, + }, + AssistantEditKind::Update { + old_text: " + fn name(&self) -> &str { + &self.name + } + " + .unindent(), + new_text: " + fn name(&self) -> String { + format!(\"{} {}\", self.first_name, self.last_name) + } + " + .unindent(), + description: None, + }, + ], + " + /// A person + struct Person { + first_name: String, + last_name: String, + age: usize, + } + + /// A dog + struct Dog { + weight: f32, + } + + impl Person { + fn name(&self) -> String { + format!(\"{} {}\", self.first_name, self.last_name) + } + } + " + .unindent(), + cx, + ); + + // Ensure InsertBefore merges correctly with Update of the same text + assert_edits( + " + fn foo() { + + } + " + .unindent(), + vec![ + AssistantEditKind::InsertBefore { + old_text: " + fn foo() {" + .unindent(), + new_text: " + fn bar() { + qux(); + }" + .unindent(), + description: Some("implement bar".into()), + }, + AssistantEditKind::Update { + old_text: " + fn foo() { + + }" + .unindent(), + new_text: " + fn foo() { + bar(); + }" + .unindent(), + description: Some("call bar in foo".into()), + }, + AssistantEditKind::InsertAfter { + old_text: " + fn foo() { + + } + " + .unindent(), + new_text: " + fn qux() { + // todo + } + " + .unindent(), + description: Some("implement qux".into()), + }, + ], + " + fn bar() { + qux(); + } + + fn foo() { + bar(); + } + + fn qux() { + // todo + } + " + .unindent(), + cx, + ); + + // Correctly indent new text when replacing multiple adjacent indented blocks. + assert_edits( + " + impl Numbers { + fn one() { + 1 + } + + fn two() { + 2 + } + + fn three() { + 3 + } + } + " + .unindent(), + vec![ + AssistantEditKind::Update { + old_text: " + fn one() { + 1 + } + " + .unindent(), + new_text: " + fn one() { + 101 + } + " + .unindent(), + description: None, + }, + AssistantEditKind::Update { + old_text: " + fn two() { + 2 + } + " + .unindent(), + new_text: " + fn two() { + 102 + } + " + .unindent(), + description: None, + }, + AssistantEditKind::Update { + old_text: " + fn three() { + 3 + } + " + .unindent(), + new_text: " + fn three() { + 103 + } + " + .unindent(), + description: None, + }, + ], + " + impl Numbers { + fn one() { + 101 + } + + fn two() { + 102 + } + + fn three() { + 103 + } + } + " + .unindent(), + cx, + ); + + assert_edits( + " + impl Person { + fn set_name(&mut self, name: String) { + self.name = name; + } + + fn name(&self) -> String { + return self.name; + } + } + " + .unindent(), + vec![ + AssistantEditKind::Update { + old_text: "self.name = name;".unindent(), + new_text: "self._name = name;".unindent(), + description: None, + }, + AssistantEditKind::Update { + old_text: "return self.name;\n".unindent(), + new_text: "return self._name;\n".unindent(), + description: None, + }, + ], + " + impl Person { + fn set_name(&mut self, name: String) { + self._name = name; + } + + fn name(&self) -> String { + return self._name; + } + } + " + .unindent(), + cx, + ); + } + + fn init_test(cx: &mut AppContext) { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + cx.update_global::(|settings, cx| { + settings.update_user_settings::(cx, |_| {}); + }); + } + + #[track_caller] + fn assert_location_resolution( + text_with_expected_range: &str, + query: &str, + cx: &mut AppContext, + ) { + let (text, _) = marked_text_ranges(text_with_expected_range, false); + let buffer = cx.new_model(|cx| Buffer::local(text.clone(), cx)); + let snapshot = buffer.read(cx).snapshot(); + let range = AssistantEditKind::resolve_location(&snapshot, query).to_offset(&snapshot); + let text_with_actual_range = generate_marked_text(&text, &[range], false); + pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range); + } + + #[track_caller] + fn assert_edits( + old_text: String, + edits: Vec, + new_text: String, + cx: &mut AppContext, + ) { + let buffer = + cx.new_model(|cx| Buffer::local(old_text, cx).with_language(Arc::new(rust_lang()), cx)); + let snapshot = buffer.read(cx).snapshot(); + let resolved_edits = edits + .into_iter() + .map(|kind| kind.resolve(&snapshot)) + .collect(); + let edit_groups = AssistantPatch::group_edits(resolved_edits, &snapshot); + ResolvedPatch::apply_edit_groups(&edit_groups, &buffer, cx); + let actual_new_text = buffer.read(cx).text(); + pretty_assertions::assert_eq!(actual_new_text, new_text); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(language::tree_sitter_rust::LANGUAGE.into()), + ) + .with_indents_query( + r#" + (call_expression) @indent + (field_expression) @indent + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap() + } +} diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index a5ab415a4d..cea15ad47d 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -15,6 +15,9 @@ use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequi use serde::{Deserialize, Serialize}; use util::post_inc; +use crate::edits::patch::AssistantPatch; +use crate::edits::CodeEditsTool; + #[derive(Debug, Clone, Copy)] pub enum RequestKind { Chat, @@ -167,6 +170,7 @@ impl Thread { } } LanguageModelCompletionEvent::ToolUse(tool_use) => { + dbg!(&tool_use); if let Some(last_assistant_message) = thread .messages .iter() @@ -258,6 +262,12 @@ impl Thread { let output = output.await; thread .update(&mut cx, |thread, cx| { + let Some(pending_tool_use) = + thread.pending_tool_uses_by_id.get(&tool_use_id) + else { + return; + }; + // The tool use was requested by an Assistant message, // so we want to attach the tool results to the next // user message. @@ -270,11 +280,17 @@ impl Thread { match output { Ok(output) => { - tool_results.push(LanguageModelToolResult { - tool_use_id: tool_use_id.to_string(), - content: output, - is_error: false, - }); + if pending_tool_use.name == CodeEditsTool::TOOL_NAME { + let patch = AssistantPatch::from_tool_use(&output); + + dbg!(&patch); + } else { + tool_results.push(LanguageModelToolResult { + tool_use_id: tool_use_id.to_string(), + content: output, + is_error: false, + }); + } cx.emit(ThreadEvent::ToolFinished { tool_use_id }); } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 7d33fdac2a..7d145c61b7 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -1,10 +1,8 @@ -mod code_edits_tool; mod now_tool; use assistant_tool::ToolRegistry; use gpui::AppContext; -use crate::code_edits_tool::CodeEditsTool; use crate::now_tool::NowTool; pub fn init(cx: &mut AppContext) { @@ -12,5 +10,4 @@ pub fn init(cx: &mut AppContext) { let registry = ToolRegistry::global(cx); registry.register_tool(NowTool); - registry.register_tool(CodeEditsTool); }