Fix indentation issues when generating or transforming code with inline assistant (#2961)

This pull request extracts a separate `Codegen` struct that models the
interaction with OpenAI and takes care of diffing, auto-indenting and
reporting regions to highlight.

As part of this refactoring, we're now relying less on the AI model to
indent code. The new logic lets tree-sitter decide how the first line
should be indented. Then, for every subsequent line reported by ChatGPT,
it calculates an indent delta relative to the first reported line and
applies it to the indent level chosen by tree-sitter.
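
As a rough sketch of the idea (not the actual implementation — the real logic lives in `Codegen::start`, operates on streamed chunks, and works with the editor's `IndentSize` type), the re-indentation boils down to something like the following, assuming space indentation and a precomputed `suggested_indent` column count from tree-sitter:

```rust
// Simplified, non-streaming sketch of the re-indentation idea.
// `suggested_indent` is the indent (in columns) tree-sitter suggests for the
// first line; `reported_lines` are the lines as reported by the model.
fn reindent(suggested_indent: usize, reported_lines: &[&str]) -> Vec<String> {
    // Indentation of the first reported line; deltas are computed against it.
    let base_indent = reported_lines
        .first()
        .map(|line| line.len() - line.trim_start().len())
        .unwrap_or(0);

    reported_lines
        .iter()
        .map(|line| {
            if line.trim().is_empty() {
                return String::new();
            }
            let line_indent = line.len() - line.trim_start().len();
            // Delta relative to the first reported line, applied on top of the
            // indent level chosen by tree-sitter.
            let delta = line_indent as isize - base_indent as isize;
            let corrected = (suggested_indent as isize + delta).max(0) as usize;
            format!("{}{}", " ".repeat(corrected), line.trim_start())
        })
        .collect()
}
```

For example, if tree-sitter suggests 8 columns and the model reports lines indented by 0 and 4 columns, those lines end up at 8 and 12 columns respectively, regardless of how the model chose to indent the first line.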

Release Notes:

- Improved auto-indentation when using the inline assistant.
Antonio Scandurra, 2023-09-13 12:45:44 +02:00 (committed by GitHub)
commit 5697a87f4a
7 changed files with 871 additions and 491 deletions

Cargo.lock (generated)
View file

@ -114,6 +114,7 @@ dependencies = [
"log",
"menu",
"ordered-float",
"parking_lot 0.11.2",
"project",
"rand 0.8.5",
"regex",

View file

@ -27,6 +27,7 @@ futures.workspace = true
indoc.workspace = true
isahc.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true

View file

@ -1,5 +1,6 @@
pub mod assistant;
mod assistant_settings;
mod codegen;
mod streaming_diff;
use anyhow::{anyhow, Result};
@ -26,7 +27,7 @@ use util::paths::CONVERSATIONS_DIR;
const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
// Data types for chat completion requests
#[derive(Debug, Serialize)]
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
model: String,
messages: Vec<RequestMessage>,

View file

@ -1,9 +1,8 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
stream_completion,
streaming_diff::{Hunk, StreamingDiff},
MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role,
SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider},
stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
@ -13,10 +12,10 @@ use editor::{
BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
},
scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint,
Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset,
};
use fs::Fs;
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use futures::StreamExt;
use gpui::{
actions,
elements::{
@ -30,17 +29,14 @@ use gpui::{
ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
WindowContext,
};
use language::{
language_settings::SoftWrap, Buffer, LanguageRegistry, Point, Rope, ToOffset as _,
TransactionId,
};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use search::BufferSearchBar;
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
cmp, env,
fmt::Write,
future, iter,
iter,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
@ -266,23 +262,40 @@ impl AssistantPanel {
}
fn new_inline_assist(&mut self, editor: &ViewHandle<Editor>, cx: &mut ViewContext<Self>) {
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
api_key
} else {
return;
};
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let provider = Arc::new(OpenAICompletionProvider::new(
api_key,
cx.background().clone(),
));
let selection = editor.read(cx).selections.newest_anchor().clone();
let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot);
let assist_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
InlineAssistKind::Generate
let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
CodegenKind::Generate {
position: selection.start,
}
} else {
InlineAssistKind::Transform
CodegenKind::Transform {
range: selection.start..selection.end,
}
};
let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
});
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
let inline_assistant = cx.add_view(|cx| {
let assistant = InlineAssistant::new(
inline_assist_id,
assist_kind,
measurements.clone(),
self.include_conversation_in_next_inline_assist,
self.inline_prompt_history.clone(),
codegen.clone(),
cx,
);
cx.focus_self();
@ -321,48 +334,64 @@ impl AssistantPanel {
self.pending_inline_assists.insert(
inline_assist_id,
PendingInlineAssist {
kind: assist_kind,
editor: editor.downgrade(),
range,
highlighted_ranges: Default::default(),
inline_assistant: Some((block_id, inline_assistant.clone())),
code_generation: Task::ready(None),
transaction_id: None,
codegen: codegen.clone(),
_subscriptions: vec![
cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
cx.subscribe(editor, {
let inline_assistant = inline_assistant.downgrade();
move |this, editor, event, cx| {
move |_, editor, event, cx| {
if let Some(inline_assistant) = inline_assistant.upgrade(cx) {
match event {
editor::Event::SelectionsChanged { local } => {
if *local && inline_assistant.read(cx).has_focus {
cx.focus(&editor);
}
if let editor::Event::SelectionsChanged { local } = event {
if *local && inline_assistant.read(cx).has_focus {
cx.focus(&editor);
}
editor::Event::TransactionUndone {
transaction_id: tx_id,
} => {
if let Some(pending_assist) =
this.pending_inline_assists.get(&inline_assist_id)
{
if pending_assist.transaction_id == Some(*tx_id) {
// Notice we are supplying `undo: false` here. This
// is because there's no need to undo the transaction
// because the user just did so.
this.close_inline_assist(
inline_assist_id,
false,
cx,
);
}
}
}
_ => {}
}
}
}
}),
cx.observe(&codegen, {
let editor = editor.downgrade();
move |this, _, cx| {
if let Some(editor) = editor.upgrade(cx) {
this.update_highlights_for_editor(&editor, cx);
}
}
}),
cx.subscribe(&codegen, move |this, codegen, event, cx| match event {
codegen::Event::Undone => {
this.finish_inline_assist(inline_assist_id, false, cx)
}
codegen::Event::Finished => {
let pending_assist = if let Some(pending_assist) =
this.pending_inline_assists.get(&inline_assist_id)
{
pending_assist
} else {
return;
};
let error = codegen
.read(cx)
.error()
.map(|error| format!("Inline assistant error: {}", error));
if let Some(error) = error {
if pending_assist.inline_assistant.is_none() {
if let Some(workspace) = this.workspace.upgrade(cx) {
workspace.update(cx, |workspace, cx| {
workspace.show_toast(
Toast::new(inline_assist_id, error),
cx,
);
})
}
}
}
this.finish_inline_assist(inline_assist_id, false, cx);
}
}),
],
},
);
@ -388,7 +417,7 @@ impl AssistantPanel {
self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
}
InlineAssistantEvent::Canceled => {
self.close_inline_assist(assist_id, true, cx);
self.finish_inline_assist(assist_id, true, cx);
}
InlineAssistantEvent::Dismissed => {
self.hide_inline_assist(assist_id, cx);
@ -417,7 +446,7 @@ impl AssistantPanel {
.get(&editor.downgrade())
.and_then(|assist_ids| assist_ids.last().copied())
{
panel.close_inline_assist(assist_id, true, cx);
panel.finish_inline_assist(assist_id, true, cx);
true
} else {
false
@ -432,7 +461,7 @@ impl AssistantPanel {
cx.propagate_action();
}
fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
fn finish_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
self.hide_inline_assist(assist_id, cx);
if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) {
@ -450,13 +479,9 @@ impl AssistantPanel {
self.update_highlights_for_editor(&editor, cx);
if undo {
if let Some(transaction_id) = pending_assist.transaction_id {
editor.update(cx, |editor, cx| {
editor.buffer().update(cx, |buffer, cx| {
buffer.undo_transaction(transaction_id, cx)
});
});
}
pending_assist
.codegen
.update(cx, |codegen, cx| codegen.undo(cx));
}
}
}
@ -481,12 +506,6 @@ impl AssistantPanel {
include_conversation: bool,
cx: &mut ViewContext<Self>,
) {
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
api_key
} else {
return;
};
let conversation = if include_conversation {
self.active_editor()
.map(|editor| editor.read(cx).conversation.clone())
@ -514,56 +533,9 @@ impl AssistantPanel {
self.inline_prompt_history.pop_front();
}
let range = pending_assist.range.clone();
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let selected_text = snapshot
.text_for_range(range.start..range.end)
.collect::<Rope>();
let selection_start = range.start.to_point(&snapshot);
let selection_end = range.end.to_point(&snapshot);
let mut base_indent: Option<language::IndentSize> = None;
let mut start_row = selection_start.row;
if snapshot.is_line_blank(start_row) {
if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
start_row = prev_non_blank_row;
}
}
for row in start_row..=selection_end.row {
if snapshot.is_line_blank(row) {
continue;
}
let line_indent = snapshot.indent_size_for_line(row);
if let Some(base_indent) = base_indent.as_mut() {
if line_indent.len < base_indent.len {
*base_indent = line_indent;
}
} else {
base_indent = Some(line_indent);
}
}
let mut normalized_selected_text = selected_text.clone();
if let Some(base_indent) = base_indent {
for row in selection_start.row..=selection_end.row {
let selection_row = row - selection_start.row;
let line_start =
normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
let indent_len = if row == selection_start.row {
base_indent.len.saturating_sub(selection_start.column)
} else {
let line_len = normalized_selected_text.line_len(selection_row);
cmp::min(line_len, base_indent.len)
};
let indent_end = cmp::min(
line_start + indent_len as usize,
normalized_selected_text.len(),
);
normalized_selected_text.replace(line_start..indent_end, "");
}
}
let range = pending_assist.codegen.read(cx).range();
let selected_text = snapshot.text_for_range(range.clone()).collect::<String>();
let language = snapshot.language_at(range.start);
let language_name = if let Some(language) = language.as_ref() {
@ -581,8 +553,8 @@ impl AssistantPanel {
if let Some(language_name) = language_name {
writeln!(prompt, "You're an expert {language_name} engineer.").unwrap();
}
match pending_assist.kind {
InlineAssistKind::Transform => {
match pending_assist.codegen.read(cx).kind() {
CodegenKind::Transform { .. } => {
writeln!(
prompt,
"You're currently working inside an editor on this file:"
@ -608,7 +580,7 @@ impl AssistantPanel {
} else {
writeln!(prompt, "```").unwrap();
}
writeln!(prompt, "{normalized_selected_text}").unwrap();
writeln!(prompt, "{selected_text}").unwrap();
writeln!(prompt, "```").unwrap();
writeln!(prompt).unwrap();
writeln!(
@ -622,7 +594,7 @@ impl AssistantPanel {
)
.unwrap();
}
InlineAssistKind::Generate => {
CodegenKind::Generate { .. } => {
writeln!(
prompt,
"You're currently working inside an editor on this file:"
@ -689,209 +661,9 @@ impl AssistantPanel {
messages,
stream: true,
};
let response = stream_completion(api_key, cx.background().clone(), request);
let editor = editor.downgrade();
pending_assist.code_generation = cx.spawn(|this, mut cx| {
async move {
let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background().spawn(async move {
let chunks = strip_markdown_codeblock(response.await?.filter_map(
|message| async move {
match message {
Ok(mut message) => Some(Ok(message.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
},
));
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut indent_len;
let indent_text;
if let Some(base_indent) = base_indent {
indent_len = base_indent.len;
indent_text = match base_indent.kind {
language::IndentKind::Space => " ",
language::IndentKind::Tab => "\t",
};
} else {
indent_len = 0;
indent_text = "";
};
let mut first_line_len = 0;
let mut first_line_non_whitespace_char_ix = None;
let mut first_line = true;
let mut new_text = String::new();
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
let mut lines = chunk.split('\n');
if let Some(mut line) = lines.next() {
if first_line {
if first_line_non_whitespace_char_ix.is_none() {
if let Some(mut char_ix) =
line.find(|ch: char| !ch.is_whitespace())
{
line = &line[char_ix..];
char_ix += first_line_len;
first_line_non_whitespace_char_ix = Some(char_ix);
let first_line_indent = char_ix
.saturating_sub(selection_start.column as usize)
as usize;
new_text.push_str(&indent_text.repeat(first_line_indent));
indent_len = indent_len.saturating_sub(char_ix as u32);
}
}
first_line_len += line.len();
}
if first_line_non_whitespace_char_ix.is_some() {
new_text.push_str(line);
}
}
for line in lines {
first_line = false;
new_text.push('\n');
if !line.is_empty() {
new_text.push_str(&indent_text.repeat(indent_len as usize));
}
new_text.push_str(line);
}
let hunks = diff.push_new(&new_text);
hunks_tx.send(hunks).await?;
new_text.clear();
}
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(())
});
while let Some(hunks) = hunks_rx.next().await {
let editor = if let Some(editor) = editor.upgrade(&cx) {
editor
} else {
break;
};
let this = if let Some(this) = this.upgrade(&cx) {
this
} else {
break;
};
this.update(&mut cx, |this, cx| {
let pending_assist = if let Some(pending_assist) =
this.pending_inline_assists.get_mut(&inline_assist_id)
{
pending_assist
} else {
return;
};
pending_assist.highlighted_ranges.clear();
editor.update(cx, |editor, cx| {
let transaction = editor.buffer().update(cx, |buffer, cx| {
// Avoid grouping assistant edits with user edits.
buffer.finalize_last_transaction(cx);
buffer.start_transaction(cx);
buffer.edit(
hunks.into_iter().filter_map(|hunk| match hunk {
Hunk::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
Some((edit_start..edit_start, text))
}
Hunk::Remove { len } => {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
Some((edit_range, String::new()))
}
Hunk::Keep { len } => {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start += len;
pending_assist.highlighted_ranges.push(edit_range);
None
}
}),
None,
cx,
);
buffer.end_transaction(cx)
});
if let Some(transaction) = transaction {
if let Some(first_transaction) = pending_assist.transaction_id {
// Group all assistant edits into the first transaction.
editor.buffer().update(cx, |buffer, cx| {
buffer.merge_transactions(
transaction,
first_transaction,
cx,
)
});
} else {
pending_assist.transaction_id = Some(transaction);
editor.buffer().update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
}
}
});
this.update_highlights_for_editor(&editor, cx);
});
}
if let Err(error) = diff.await {
this.update(&mut cx, |this, cx| {
let pending_assist = if let Some(pending_assist) =
this.pending_inline_assists.get_mut(&inline_assist_id)
{
pending_assist
} else {
return;
};
if let Some((_, inline_assistant)) =
pending_assist.inline_assistant.as_ref()
{
inline_assistant.update(cx, |inline_assistant, cx| {
inline_assistant.set_error(error, cx);
});
} else if let Some(workspace) = this.workspace.upgrade(cx) {
workspace.update(cx, |workspace, cx| {
workspace.show_toast(
Toast::new(
inline_assist_id,
format!("Inline assistant error: {}", error),
),
cx,
);
})
}
})?;
} else {
let _ = this.update(&mut cx, |this, cx| {
this.close_inline_assist(inline_assist_id, false, cx)
});
}
anyhow::Ok(())
}
.log_err()
});
pending_assist
.codegen
.update(cx, |codegen, cx| codegen.start(request, cx));
}
fn update_highlights_for_editor(
@ -909,8 +681,9 @@ impl AssistantPanel {
for inline_assist_id in inline_assist_ids {
if let Some(pending_assist) = self.pending_inline_assists.get(inline_assist_id) {
background_ranges.push(pending_assist.range.clone());
foreground_ranges.extend(pending_assist.highlighted_ranges.iter().cloned());
let codegen = pending_assist.codegen.read(cx);
background_ranges.push(codegen.range());
foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
}
}
@ -2887,12 +2660,6 @@ enum InlineAssistantEvent {
},
}
#[derive(Copy, Clone)]
enum InlineAssistKind {
Transform,
Generate,
}
struct InlineAssistant {
id: usize,
prompt_editor: ViewHandle<Editor>,
@ -2900,11 +2667,11 @@ struct InlineAssistant {
has_focus: bool,
include_conversation: bool,
measurements: Rc<Cell<BlockMeasurements>>,
error: Option<anyhow::Error>,
prompt_history: VecDeque<String>,
prompt_history_ix: Option<usize>,
pending_prompt: String,
_subscription: Subscription,
codegen: ModelHandle<Codegen>,
_subscriptions: Vec<Subscription>,
}
impl Entity for InlineAssistant {
@ -2933,7 +2700,7 @@ impl View for InlineAssistant {
.element()
.aligned(),
)
.with_children(if let Some(error) = self.error.as_ref() {
.with_children(if let Some(error) = self.codegen.read(cx).error() {
Some(
Svg::new("icons/circle_x_mark_12.svg")
.with_color(theme.assistant.error_icon.color)
@ -3007,10 +2774,10 @@ impl View for InlineAssistant {
impl InlineAssistant {
fn new(
id: usize,
kind: InlineAssistKind,
measurements: Rc<Cell<BlockMeasurements>>,
include_conversation: bool,
prompt_history: VecDeque<String>,
codegen: ModelHandle<Codegen>,
cx: &mut ViewContext<Self>,
) -> Self {
let prompt_editor = cx.add_view(|cx| {
@ -3018,14 +2785,17 @@ impl InlineAssistant {
Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
cx,
);
let placeholder = match kind {
InlineAssistKind::Transform => "Enter transformation prompt…",
InlineAssistKind::Generate => "Enter generation prompt…",
let placeholder = match codegen.read(cx).kind() {
CodegenKind::Transform { .. } => "Enter transformation prompt…",
CodegenKind::Generate { .. } => "Enter generation prompt…",
};
editor.set_placeholder_text(placeholder, cx);
editor
});
let subscription = cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events);
let subscriptions = vec![
cx.observe(&codegen, Self::handle_codegen_changed),
cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
];
Self {
id,
prompt_editor,
@ -3033,11 +2803,11 @@ impl InlineAssistant {
has_focus: false,
include_conversation,
measurements,
error: None,
prompt_history,
prompt_history_ix: None,
pending_prompt: String::new(),
_subscription: subscription,
codegen,
_subscriptions: subscriptions,
}
}
@ -3053,6 +2823,31 @@ impl InlineAssistant {
}
}
fn handle_codegen_changed(&mut self, _: ModelHandle<Codegen>, cx: &mut ViewContext<Self>) {
let is_read_only = !self.codegen.read(cx).idle();
self.prompt_editor.update(cx, |editor, cx| {
let was_read_only = editor.read_only();
if was_read_only != is_read_only {
if is_read_only {
editor.set_read_only(true);
editor.set_field_editor_style(
Some(Arc::new(|theme| {
theme.assistant.inline.disabled_editor.clone()
})),
cx,
);
} else {
editor.set_read_only(false);
editor.set_field_editor_style(
Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
cx,
);
}
}
});
cx.notify();
}
fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
cx.emit(InlineAssistantEvent::Canceled);
}
@ -3076,7 +2871,6 @@ impl InlineAssistant {
include_conversation: self.include_conversation,
});
self.confirmed = true;
self.error = None;
cx.notify();
}
}
@ -3093,19 +2887,6 @@ impl InlineAssistant {
cx.notify();
}
fn set_error(&mut self, error: anyhow::Error, cx: &mut ViewContext<Self>) {
self.error = Some(error);
self.confirmed = false;
self.prompt_editor.update(cx, |editor, cx| {
editor.set_read_only(false);
editor.set_field_editor_style(
Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
cx,
);
});
cx.notify();
}
fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
if let Some(ix) = self.prompt_history_ix {
if ix > 0 {
@ -3152,13 +2933,9 @@ struct BlockMeasurements {
}
struct PendingInlineAssist {
kind: InlineAssistKind,
editor: WeakViewHandle<Editor>,
range: Range<Anchor>,
highlighted_ranges: Vec<Range<Anchor>>,
inline_assistant: Option<(BlockId, ViewHandle<InlineAssistant>)>,
code_generation: Task<Option<()>>,
transaction_id: Option<TransactionId>,
codegen: ModelHandle<Codegen>,
_subscriptions: Vec<Subscription>,
}
@ -3184,65 +2961,10 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
}
}
fn strip_markdown_codeblock(
stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
let mut first_line = true;
let mut buffer = String::new();
let mut starts_with_fenced_code_block = false;
stream.filter_map(move |chunk| {
let chunk = match chunk {
Ok(chunk) => chunk,
Err(err) => return future::ready(Some(Err(err))),
};
buffer.push_str(&chunk);
if first_line {
if buffer == "" || buffer == "`" || buffer == "``" {
return future::ready(None);
} else if buffer.starts_with("```") {
starts_with_fenced_code_block = true;
if let Some(newline_ix) = buffer.find('\n') {
buffer.replace_range(..newline_ix + 1, "");
first_line = false;
} else {
return future::ready(None);
}
}
}
let text = if starts_with_fenced_code_block {
buffer
.strip_suffix("\n```\n")
.or_else(|| buffer.strip_suffix("\n```"))
.or_else(|| buffer.strip_suffix("\n``"))
.or_else(|| buffer.strip_suffix("\n`"))
.or_else(|| buffer.strip_suffix('\n'))
.unwrap_or(&buffer)
} else {
&buffer
};
if text.contains('\n') {
first_line = false;
}
let remainder = buffer.split_off(text.len());
let result = if buffer.is_empty() {
None
} else {
Some(Ok(buffer.clone()))
};
buffer = remainder;
future::ready(result)
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::MessageId;
use futures::stream;
use gpui::AppContext;
#[gpui::test]
@ -3611,62 +3333,6 @@ mod tests {
);
}
#[gpui::test]
async fn test_strip_markdown_codeblock() {
assert_eq!(
strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"```js\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"``\nLorem ipsum dolor\n```"
);
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
stream::iter(
text.chars()
.collect::<Vec<_>>()
.chunks(size)
.map(|chunk| Ok(chunk.iter().collect::<String>()))
.collect::<Vec<_>>(),
)
}
}
fn messages(
conversation: &ModelHandle<Conversation>,
cx: &AppContext,

crates/ai/src/codegen.rs (new file)
View file

@ -0,0 +1,704 @@
use crate::{
stream_completion,
streaming_diff::{Hunk, StreamingDiff},
OpenAIRequest,
};
use anyhow::Result;
use editor::{
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
use futures::{
channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
};
use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
use language::{Rope, TransactionId};
use std::{cmp, future, ops::Range, sync::Arc};
pub trait CompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
pub struct OpenAICompletionProvider {
api_key: String,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor }
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
}
pub enum Event {
Finished,
Undone,
}
#[derive(Clone)]
pub enum CodegenKind {
Transform { range: Range<Anchor> },
Generate { position: Anchor },
}
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
buffer: ModelHandle<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
last_equal_ranges: Vec<Range<Anchor>>,
transaction_id: Option<TransactionId>,
error: Option<anyhow::Error>,
generation: Task<()>,
idle: bool,
_subscription: gpui::Subscription,
}
impl Entity for Codegen {
type Event = Event;
}
impl Codegen {
pub fn new(
buffer: ModelHandle<MultiBuffer>,
mut kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
match &mut kind {
CodegenKind::Transform { range } => {
let mut point_range = range.to_point(&snapshot);
point_range.start.column = 0;
if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
point_range.end.column = snapshot.line_len(point_range.end.row);
}
range.start = snapshot.anchor_before(point_range.start);
range.end = snapshot.anchor_after(point_range.end);
}
CodegenKind::Generate { position } => {
*position = position.bias_right(&snapshot);
}
}
Self {
provider,
buffer: buffer.clone(),
snapshot,
kind,
last_equal_ranges: Default::default(),
transaction_id: Default::default(),
error: Default::default(),
idle: true,
generation: Task::ready(()),
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
}
}
fn handle_buffer_event(
&mut self,
_buffer: ModelHandle<MultiBuffer>,
event: &multi_buffer::Event,
cx: &mut ModelContext<Self>,
) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
if self.transaction_id == Some(*transaction_id) {
self.transaction_id = None;
self.generation = Task::ready(());
cx.emit(Event::Undone);
}
}
}
pub fn range(&self) -> Range<Anchor> {
match &self.kind {
CodegenKind::Transform { range } => range.clone(),
CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
}
}
pub fn kind(&self) -> &CodegenKind {
&self.kind
}
pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
&self.last_equal_ranges
}
pub fn idle(&self) -> bool {
self.idle
}
pub fn error(&self) -> Option<&anyhow::Error> {
self.error.as_ref()
}
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
.text_for_range(range.start..range.end)
.collect::<Rope>();
let selection_start = range.start.to_point(&snapshot);
let suggested_line_indent = snapshot
.suggested_indents(selection_start.row..selection_start.row + 1, cx)
.into_values()
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
self.generation = cx.spawn_weak(|this, mut cx| {
async move {
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background().spawn(async move {
let chunks = strip_markdown_codeblock(response.await?);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta = line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(selection_start.column as usize);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
if line_indent.is_some() {
hunks_tx.send(diff.push_new(&new_text)).await?;
new_text.clear();
}
if lines.peek().is_some() {
hunks_tx.send(diff.push_new("\n")).await?;
line_indent = None;
first_line = false;
}
}
}
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(())
});
while let Some(hunks) = hunks_rx.next().await {
let this = if let Some(this) = this.upgrade(&cx) {
this
} else {
break;
};
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
let transaction = this.buffer.update(cx, |buffer, cx| {
// Avoid grouping assistant edits with user edits.
buffer.finalize_last_transaction(cx);
buffer.start_transaction(cx);
buffer.edit(
hunks.into_iter().filter_map(|hunk| match hunk {
Hunk::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
Some((edit_start..edit_start, text))
}
Hunk::Remove { len } => {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
Some((edit_range, String::new()))
}
Hunk::Keep { len } => {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
this.last_equal_ranges.push(edit_range);
None
}
}),
None,
cx,
);
buffer.end_transaction(cx)
});
if let Some(transaction) = transaction {
if let Some(first_transaction) = this.transaction_id {
// Group all assistant edits into the first transaction.
this.buffer.update(cx, |buffer, cx| {
buffer.merge_transactions(
transaction,
first_transaction,
cx,
)
});
} else {
this.transaction_id = Some(transaction);
this.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
}
}
cx.notify();
});
}
diff.await?;
anyhow::Ok(())
};
let result = generate.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
this.idle = true;
if let Err(error) = result {
this.error = Some(error);
}
cx.emit(Event::Finished);
cx.notify();
});
}
}
});
self.error.take();
self.idle = false;
cx.notify();
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
if let Some(transaction_id) = self.transaction_id {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
}
}
fn strip_markdown_codeblock(
stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
let mut first_line = true;
let mut buffer = String::new();
let mut starts_with_fenced_code_block = false;
stream.filter_map(move |chunk| {
let chunk = match chunk {
Ok(chunk) => chunk,
Err(err) => return future::ready(Some(Err(err))),
};
buffer.push_str(&chunk);
if first_line {
if buffer == "" || buffer == "`" || buffer == "``" {
return future::ready(None);
} else if buffer.starts_with("```") {
starts_with_fenced_code_block = true;
if let Some(newline_ix) = buffer.find('\n') {
buffer.replace_range(..newline_ix + 1, "");
first_line = false;
} else {
return future::ready(None);
}
}
}
let text = if starts_with_fenced_code_block {
buffer
.strip_suffix("\n```\n")
.or_else(|| buffer.strip_suffix("\n```"))
.or_else(|| buffer.strip_suffix("\n``"))
.or_else(|| buffer.strip_suffix("\n`"))
.or_else(|| buffer.strip_suffix('\n'))
.unwrap_or(&buffer)
} else {
&buffer
};
if text.contains('\n') {
first_line = false;
}
let remainder = buffer.split_off(text.len());
let result = if buffer.is_empty() {
None
} else {
Some(Ok(buffer.clone()))
};
buffer = remainder;
future::ready(result)
})
}
#[cfg(test)]
mod tests {
use super::*;
use futures::stream;
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
use settings::SettingsStore;
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
for _ in 0..10 {
x += 1;
}
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test]
async fn test_strip_markdown_codeblock() {
assert_eq!(
strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"```js\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"``\nLorem ipsum dolor\n```"
);
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
stream::iter(
text.chars()
.collect::<Vec<_>>()
.chunks(size)
.map(|chunk| Ok(chunk.iter().collect::<String>()))
.collect::<Vec<_>>(),
)
}
}
struct TestCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl TestCompletionProvider {
fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CompletionProvider for TestCompletionProvider {
fn complete(
&self,
_prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}

View file

@ -1734,6 +1734,10 @@ impl Editor {
}
}
pub fn read_only(&self) -> bool {
self.read_only
}
pub fn set_read_only(&mut self, read_only: bool) {
self.read_only = read_only;
}
@ -5121,9 +5125,6 @@ impl Editor {
self.unmark_text(cx);
self.refresh_copilot_suggestions(true, cx);
cx.emit(Event::Edited);
cx.emit(Event::TransactionUndone {
transaction_id: tx_id,
});
}
}
@ -8605,9 +8606,6 @@ pub enum Event {
local: bool,
autoscroll: bool,
},
TransactionUndone {
transaction_id: TransactionId,
},
Closed,
}

View file

@ -70,6 +70,9 @@ pub enum Event {
Edited {
sigleton_buffer_edited: bool,
},
TransactionUndone {
transaction_id: TransactionId,
},
Reloaded,
DiffBaseChanged,
LanguageChanged,
@ -771,30 +774,36 @@ impl MultiBuffer {
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) -> Option<TransactionId> {
let mut transaction_id = None;
if let Some(buffer) = self.as_singleton() {
return buffer.update(cx, |buffer, cx| buffer.undo(cx));
}
transaction_id = buffer.update(cx, |buffer, cx| buffer.undo(cx));
} else {
while let Some(transaction) = self.history.pop_undo() {
let mut undone = false;
for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions {
if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) {
undone |= buffer.update(cx, |buffer, cx| {
let undo_to = *buffer_transaction_id;
if let Some(entry) = buffer.peek_undo_stack() {
*buffer_transaction_id = entry.transaction_id();
}
buffer.undo_to_transaction(undo_to, cx)
});
}
}
while let Some(transaction) = self.history.pop_undo() {
let mut undone = false;
for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions {
if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) {
undone |= buffer.update(cx, |buffer, cx| {
let undo_to = *buffer_transaction_id;
if let Some(entry) = buffer.peek_undo_stack() {
*buffer_transaction_id = entry.transaction_id();
}
buffer.undo_to_transaction(undo_to, cx)
});
if undone {
transaction_id = Some(transaction.id);
break;
}
}
if undone {
return Some(transaction.id);
}
}
None
if let Some(transaction_id) = transaction_id {
cx.emit(Event::TransactionUndone { transaction_id });
}
transaction_id
}
pub fn redo(&mut self, cx: &mut ModelContext<Self>) -> Option<TransactionId> {