From b366592878d01d3174fdbd2afb901bb354bea03b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 10 Oct 2023 19:00:05 +0200 Subject: [PATCH] Don't include start of a line when selection ends at start of line --- crates/assistant/src/assistant_panel.rs | 38 +++++++++++++++++-------- crates/assistant/src/codegen.rs | 23 ++------------- crates/assistant/src/prompts.rs | 1 + 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 62e2f61111..b1c6038602 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -17,7 +17,7 @@ use editor::{ BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint, }, scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, + Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint, }; use fs::Fs; use futures::StreamExt; @@ -278,22 +278,36 @@ impl AssistantPanel { if selection.start.excerpt_id() != selection.end.excerpt_id() { return; } + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); + + // Extend the selection to the start and the end of the line. + let mut point_selection = selection.map(|selection| selection.to_point(&snapshot)); + if point_selection.end > point_selection.start { + point_selection.start.column = 0; + // If the selection ends at the start of the line, we don't want to include it. + if point_selection.end.column == 0 { + point_selection.end.row -= 1; + } + point_selection.end.column = snapshot.line_len(point_selection.end.row); + } + + let codegen_kind = if point_selection.start == point_selection.end { + CodegenKind::Generate { + position: snapshot.anchor_after(point_selection.start), + } + } else { + CodegenKind::Transform { + range: snapshot.anchor_before(point_selection.start) + ..snapshot.anchor_after(point_selection.end), + } + }; let inline_assist_id = post_inc(&mut self.next_inline_assist_id); - let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); let provider = Arc::new(OpenAICompletionProvider::new( api_key, cx.background().clone(), )); - let codegen_kind = if editor.read(cx).selections.newest::(cx).is_empty() { - CodegenKind::Generate { - position: selection.start, - } - } else { - CodegenKind::Transform { - range: selection.start..selection.end, - } - }; + let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) }); @@ -319,7 +333,7 @@ impl AssistantPanel { editor.insert_blocks( [BlockProperties { style: BlockStyle::Flex, - position: selection.head().bias_left(&snapshot), + position: snapshot.anchor_before(point_selection.head()), height: 2, render: Arc::new({ let inline_assistant = inline_assistant.clone(); diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index e956d72260..b6ef6b5cfa 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -1,9 +1,7 @@ use crate::streaming_diff::{Hunk, StreamingDiff}; use ai::completion::{CompletionProvider, OpenAIRequest}; use anyhow::Result; -use editor::{ - multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, -}; +use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{Entity, ModelContext, ModelHandle, Task}; use language::{Rope, TransactionId}; @@ -40,26 +38,11 @@ impl Entity for Codegen { impl Codegen { pub fn new( buffer: ModelHandle, - mut kind: CodegenKind, + kind: CodegenKind, provider: Arc, cx: &mut ModelContext, ) -> Self { let snapshot = buffer.read(cx).snapshot(cx); - match &mut kind { - CodegenKind::Transform { range } => { - let mut point_range = range.to_point(&snapshot); - point_range.start.column = 0; - if point_range.end.column > 0 || point_range.start.row == point_range.end.row { - point_range.end.column = snapshot.line_len(point_range.end.row); - } - range.start = snapshot.anchor_before(point_range.start); - range.end = snapshot.anchor_after(point_range.end); - } - CodegenKind::Generate { position } => { - *position = position.bias_right(&snapshot); - } - } - Self { provider, buffer: buffer.clone(), @@ -386,7 +369,7 @@ mod tests { let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4)) + snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); let provider = Arc::new(TestCompletionProvider::new()); let codegen = cx.add_model(|cx| { diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 7aca365776..d326a7f445 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -4,6 +4,7 @@ use std::cmp::{self, Reverse}; use std::fmt::Write; use std::ops::Range; +#[allow(dead_code)] fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { #[derive(Debug)] struct Match {