Get back to a compiling state with Buffer backing the assistant

Antonio Scandurra 2023-06-13 13:43:06 +02:00
parent 7db690b713
commit 2ae8b558b9

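
The diff below drops the per-message excerpt machinery: a conversation is now one shared `Buffer`, each `Message` keeps only a stable id and an anchor to where it starts, and roles (plus any streaming error) live in a `messages_metadata` map keyed by `MessageId`. A minimal sketch of that shape, with a plain byte offset standing in for `language::Anchor` and all names hypothetical:

#![allow(dead_code)]
use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct MessageId(usize);

#[derive(Clone, Copy, Debug)]
enum Role {
    User,
    Assistant,
    System,
}

// In the commit, `start` is a `language::Anchor` that moves with edits;
// a plain byte offset stands in for it here.
struct Message {
    id: MessageId,
    start: usize,
}

// Role (and, in the real code, any streaming error) lives outside the
// buffer text, keyed by the stable message id.
struct MessageMetadata {
    role: Role,
}

struct Conversation {
    text: String,                                       // stands in for the shared `Buffer`
    messages: Vec<Message>,                             // ordered by `start`
    messages_metadata: HashMap<MessageId, MessageMetadata>,
}

fn main() {
    let mut conversation = Conversation {
        text: "hello\n".into(),
        messages: vec![Message { id: MessageId(0), start: 0 }],
        messages_metadata: HashMap::new(),
    };
    conversation
        .messages_metadata
        .insert(MessageId(0), MessageMetadata { role: Role::User });
    println!("{} message(s)", conversation.messages.len());
}

Because only the start is stored, a message's extent is always derived on demand from the next message's start (or the buffer's end), which is what the new `messages(cx)` iterator further down computes.
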

@@ -11,7 +11,7 @@ use editor::{
         autoscroll::{Autoscroll, AutoscrollStrategy},
         ScrollAnchor,
     },
-    Anchor, DisplayPoint, Editor, ExcerptId,
+    Anchor, DisplayPoint, Editor, ToOffset as _,
 };
 use fs::Fs;
 use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
@@ -25,10 +25,13 @@ use gpui::{
     Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
 };
 use isahc::{http::StatusCode, Request, RequestExt};
-use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
+use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
 use serde::Deserialize;
 use settings::SettingsStore;
-use std::{borrow::Cow, cell::RefCell, cmp, fmt::Write, io, rc::Rc, sync::Arc, time::Duration};
+use std::{
+    borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
+    time::Duration,
+};
 use util::{post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
 use workspace::{
     dock::{DockPosition, Panel},
@@ -507,16 +510,16 @@ impl Assistant {
     fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
         let messages = self
-            .messages
-            .iter()
+            .open_ai_request_messages(cx)
+            .into_iter()
             .filter_map(|message| {
                 Some(tiktoken_rs::ChatCompletionRequestMessage {
-                    role: match self.messages_metadata.get(&message.excerpt_id)?.role {
+                    role: match message.role {
                         Role::User => "user".into(),
                         Role::Assistant => "assistant".into(),
                         Role::System => "system".into(),
                     },
-                    content: message.content.read(cx).text(),
+                    content: message.content,
                     name: None,
                 })
             })
@@ -554,45 +557,47 @@ impl Assistant {
     }

     fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
-        let messages = self
-            .messages
-            .iter()
-            .filter_map(|message| {
-                Some(RequestMessage {
-                    role: self.messages_metadata.get(&message.excerpt_id)?.role,
-                    content: message.content.read(cx).text(),
-                })
-            })
-            .collect();
         let request = OpenAIRequest {
             model: self.model.clone(),
-            messages,
+            messages: self.open_ai_request_messages(cx),
             stream: true,
         };
         let api_key = self.api_key.borrow().clone()?;
         let stream = stream_completion(api_key, cx.background().clone(), request);
-        let assistant_message = self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
-        let user_message = self.insert_message_after(ExcerptId::max(), Role::User, cx);
+        let assistant_message =
+            self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?;
+        let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
         let task = cx.spawn_weak({
-            let assistant_message = assistant_message.clone();
             |this, mut cx| async move {
-                let assistant_message = assistant_message;
+                let assistant_message_id = assistant_message.id;
                 let stream_completion = async {
                     let mut messages = stream.await?;
                     while let Some(message) = messages.next().await {
                         let mut message = message?;
                         if let Some(choice) = message.choices.pop() {
-                            assistant_message.content.update(&mut cx, |content, cx| {
-                                let text: Arc<str> = choice.delta.content?.into();
-                                content.edit([(content.len()..content.len(), text)], None, cx);
-                                Some(())
-                            });
                             this.upgrade(&cx)
                                 .ok_or_else(|| anyhow!("assistant was dropped"))?
-                                .update(&mut cx, |_, cx| {
-                                    cx.emit(AssistantEvent::StreamedCompletion);
+                                .update(&mut cx, |this, cx| {
+                                    let text: Arc<str> = choice.delta.content?.into();
+                                    let message_ix = this
+                                        .messages
+                                        .iter()
+                                        .position(|message| message.id == assistant_message_id)?;
+                                    this.buffer.update(cx, |buffer, cx| {
+                                        let offset = if message_ix + 1 == this.messages.len() {
+                                            buffer.len()
+                                        } else {
+                                            this.messages[message_ix + 1]
+                                                .start
+                                                .to_offset(buffer)
+                                                .saturating_sub(1)
+                                        };
+                                        buffer.edit([(offset..offset, text)], None, cx);
+                                    });
+                                    Some(())
                                 });
                         }
                     }
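
In the hunk above, streamed completion deltas are written directly into the shared buffer at the end of the assistant's message: for the last message that is the end of the buffer, otherwise it is the next message's start minus the separating newline. The same rule over plain offsets, as a small self-contained sketch (names hypothetical):

// Illustrative only: where streamed text for the message at index `ix` gets
// inserted, assuming messages are separated by a single "\n" and `starts`
// holds each message's start offset in ascending order.
fn append_offset(starts: &[usize], buffer_len: usize, ix: usize) -> usize {
    if ix + 1 == starts.len() {
        buffer_len // last message: append at the very end of the buffer
    } else {
        starts[ix + 1].saturating_sub(1) // otherwise, just before the next message's "\n"
    }
}

fn main() {
    // Buffer "hi\nthere" with messages starting at offsets 0 and 3.
    let starts = [0, 3];
    assert_eq!(append_offset(&starts, 8, 0), 2); // end of "hi"
    assert_eq!(append_offset(&starts, 8, 1), 8); // end of the buffer
}
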
@@ -612,9 +617,8 @@ impl Assistant {
                 if let Some(this) = this.upgrade(&cx) {
                     this.update(&mut cx, |this, cx| {
                         if let Err(error) = result {
-                            if let Some(metadata) = this
-                                .messages_metadata
-                                .get_mut(&assistant_message.excerpt_id)
+                            if let Some(metadata) =
+                                this.messages_metadata.get_mut(&assistant_message.id)
                             {
                                 metadata.error = Some(error.to_string().trim().into());
                                 cx.notify();
@@ -642,33 +646,33 @@ impl Assistant {
         protected_offsets: HashSet<usize>,
         cx: &mut ModelContext<Self>,
     ) {
-        let mut offset = 0;
-        let mut excerpts_to_remove = Vec::new();
-        self.messages.retain(|message| {
-            let range = offset..offset + message.content.read(cx).len();
-            offset = range.end + 1;
-            if range.is_empty()
-                && !protected_offsets.contains(&range.start)
-                && messages.contains(&message.id)
-            {
-                excerpts_to_remove.push(message.excerpt_id);
-                self.messages_metadata.remove(&message.excerpt_id);
-                false
-            } else {
-                true
-            }
-        });
-        if !excerpts_to_remove.is_empty() {
-            self.buffer.update(cx, |buffer, cx| {
-                buffer.remove_excerpts(excerpts_to_remove, cx)
-            });
-            cx.notify();
-        }
+        // let mut offset = 0;
+        // let mut excerpts_to_remove = Vec::new();
+        // self.messages.retain(|message| {
+        //     let range = offset..offset + message.content.read(cx).len();
+        //     offset = range.end + 1;
+        //     if range.is_empty()
+        //         && !protected_offsets.contains(&range.start)
+        //         && messages.contains(&message.id)
+        //     {
+        //         excerpts_to_remove.push(message.excerpt_id);
+        //         self.messages_metadata.remove(&message.excerpt_id);
+        //         false
+        //     } else {
+        //         true
+        //     }
+        // });
+        // if !excerpts_to_remove.is_empty() {
+        //     self.buffer.update(cx, |buffer, cx| {
+        //         buffer.remove_excerpts(excerpts_to_remove, cx)
+        //     });
+        //     cx.notify();
+        // }
     }

-    fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext<Self>) {
-        if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) {
+    fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
+        if let Some(metadata) = self.messages_metadata.get_mut(&id) {
             metadata.role.cycle();
             cx.notify();
         }
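
`remove_empty_messages` is stubbed out above because it still assumed one excerpt per message. Once ranges are derived from anchors, it would presumably reduce to a `retain` over those derived ranges; a purely speculative sketch, not part of this commit, using simplified types:

use std::collections::HashSet;
use std::ops::Range;

// Hypothetical simplification: each entry is (message id, derived byte range).
fn remove_empty_messages(
    messages: &mut Vec<(u64, Range<usize>)>,
    candidates: &HashSet<u64>,         // ids the caller asked to remove
    protected_offsets: &HashSet<usize>,
) {
    messages.retain(|(id, range)| {
        !(range.is_empty()
            && candidates.contains(id)
            && !protected_offsets.contains(&range.start))
    });
}

fn main() {
    let mut messages = vec![(0, 0..2), (1, 2..2), (2, 2..3)];
    remove_empty_messages(&mut messages, &HashSet::from([1]), &HashSet::new());
    assert_eq!(messages.len(), 2); // the empty, unprotected message was dropped
}
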
@@ -686,15 +690,18 @@ impl Assistant {
             .position(|message| message.id == message_id)
         {
             let start = self.buffer.update(cx, |buffer, cx| {
-                let len = buffer.len();
-                buffer.edit([(len..len, "\n")], None, cx);
-                buffer.anchor_before(len + 1)
+                let offset = self
+                    .messages
+                    .get(prev_message_ix + 1)
+                    .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
+                buffer.edit([(offset..offset, "\n")], None, cx);
+                buffer.anchor_before(offset + 1)
             });
             let message = Message {
                 id: MessageId(post_inc(&mut self.next_message_id.0)),
                 start,
             };
-            self.messages.insert(prev_message_ix, message.clone());
+            self.messages.insert(prev_message_ix + 1, message.clone());
             self.messages_metadata.insert(
                 message.id,
                 MessageMetadata {
@@ -713,24 +720,13 @@ impl Assistant {
         if self.messages.len() >= 2 && self.summary.is_none() {
             let api_key = self.api_key.borrow().clone();
             if let Some(api_key) = api_key {
-                // let messages = self
-                //     .messages
-                //     .iter()
-                //     .take(2)
-                //     .filter_map(|message| {
-                //         Some(RequestMessage {
-                //             role: self.messages_metadata.get(&message.id)?.role,
-                //             content: message.content.read(cx).text(),
-                //         })
-                //     })
-                //     .chain(Some(RequestMessage {
-                //         role: Role::User,
-                //         content:
-                //             "Summarize the conversation into a short title without punctuation"
-                //                 .into(),
-                //     }))
-                //     .collect();
-                let messages = todo!();
+                let mut messages = self.open_ai_request_messages(cx);
+                messages.truncate(2);
+                messages.push(RequestMessage {
+                    role: Role::User,
+                    content: "Summarize the conversation into a short title without punctuation"
+                        .into(),
+                });
                 let request = OpenAIRequest {
                     model: self.model.clone(),
                     messages,
@@ -760,6 +756,44 @@ impl Assistant {
             }
         }
     }
+
+    fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
+        let buffer = self.buffer.read(cx);
+        self.messages(cx)
+            .map(|(message, metadata, range)| RequestMessage {
+                role: metadata.role,
+                content: buffer.text_for_range(range).collect(),
+            })
+            .collect()
+    }
+
+    fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option<MessageId> {
+        Some(
+            self.messages(cx)
+                .find(|(_, _, range)| range.contains(&offset))
+                .map(|(message, _, _)| message)
+                .or(self.messages.last())?
+                .id,
+        )
+    }
+
+    fn messages<'a>(
+        &'a self,
+        cx: &'a AppContext,
+    ) -> impl 'a + Iterator<Item = (&Message, &MessageMetadata, Range<usize>)> {
+        let buffer = self.buffer.read(cx);
+        let mut messages = self.messages.iter().peekable();
+        iter::from_fn(move || {
+            let message = messages.next()?;
+            let metadata = self.messages_metadata.get(&message.id)?;
+            let message_start = message.start.to_offset(buffer);
+            let message_end = messages
+                .peek()
+                .map_or(language::Anchor::MAX, |message| message.start)
+                .to_offset(buffer);
+            Some((message, metadata, message_start..message_end))
+        })
+    }
 }

 struct PendingCompletion {
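
The new `messages(cx)` iterator above is the core of the model: each message's range runs from its own start anchor to the next message's start, or to the end of the buffer for the last one. The same walk expressed over plain offsets, as a self-contained example:

use std::iter;
use std::ops::Range;

// Each start is paired with the following start, or with `buffer_len`
// for the last message, mirroring the peekable walk in the diff above.
fn ranges(starts: &[usize], buffer_len: usize) -> Vec<Range<usize>> {
    let mut starts = starts.iter().copied().peekable();
    iter::from_fn(move || {
        let start = starts.next()?;
        let end = starts.peek().copied().unwrap_or(buffer_len);
        Some(start..end)
    })
    .collect()
}

fn main() {
    // Three messages starting at 0, 2, and 4 in a 5-byte buffer.
    let expected: Vec<Range<usize>> = vec![0..2, 2..4, 4..5];
    assert_eq!(ranges(&[0, 2, 4], 5), expected);
}
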
@@ -812,16 +846,12 @@ impl AssistantEditor {
     fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
         let user_message = self.assistant.update(cx, |assistant, cx| {
             let editor = self.editor.read(cx);
-            let newest_selection = editor.selections.newest_anchor();
-            let message_id = if newest_selection.head() == Anchor::min() {
-                assistant.messages.first().map(|message| message.id)?
-            } else if newest_selection.head() == Anchor::max() {
-                assistant.messages.last().map(|message| message.id)?
-            } else {
-                todo!()
-                // newest_selection.head().excerpt_id()
-            };
+            let newest_selection = editor
+                .selections
+                .newest_anchor()
+                .head()
+                .to_offset(&editor.buffer().read(cx).snapshot(cx));
+            let message_id = assistant.message_id_for_offset(newest_selection, cx)?;
             let metadata = assistant.messages_metadata.get(&message_id)?;
             let user_message = if metadata.role == Role::User {
                 let (_, user_message) = assistant.assist(cx)?;
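
With a single buffer there is no excerpt to identify the message under the cursor, so the hunk above converts the newest selection head to a plain offset and asks `message_id_for_offset` for the containing message, falling back to the last one. A small sketch of that lookup over precomputed ranges (names hypothetical):

use std::ops::Range;

// Illustrative only: pick the message whose range contains `offset`,
// falling back to the last message when the cursor sits past every range.
fn message_index_for_offset(ranges: &[Range<usize>], offset: usize) -> Option<usize> {
    ranges
        .iter()
        .position(|range| range.contains(&offset))
        .or_else(|| ranges.len().checked_sub(1))
}

fn main() {
    let ranges = [0..2, 2..4, 4..4];
    assert_eq!(message_index_for_offset(&ranges, 3), Some(1));
    assert_eq!(message_index_for_offset(&ranges, 9), Some(2)); // past the end -> last message
}
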
@@ -834,16 +864,14 @@ impl AssistantEditor {
         });

         if let Some(user_message) = user_message {
+            let cursor = user_message
+                .start
+                .to_offset(&self.assistant.read(cx).buffer.read(cx));
             self.editor.update(cx, |editor, cx| {
-                let cursor = editor
-                    .buffer()
-                    .read(cx)
-                    .snapshot(cx)
-                    .anchor_in_excerpt(Default::default(), user_message.start);
                 editor.change_selections(
                     Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                     cx,
-                    |selections| selections.select_anchor_ranges([cursor..cursor]),
+                    |selections| selections.select_ranges([cursor..cursor]),
                 );
             });
             self.update_scroll_bottom(cx);
@@ -1011,7 +1039,7 @@ impl AssistantEditor {
         let mut copied_text = String::new();
         let mut spanned_messages = 0;
         for message in &assistant.messages {
-            // TODO
+            todo!();
             // let message_range = offset..offset + message.content.read(cx).len() + 1;
             let message_range = offset..offset + 1;
@@ -1260,28 +1288,100 @@ mod tests {
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let registry = Arc::new(LanguageRegistry::test());
-        cx.add_model(|cx| {
-            let mut assistant = Assistant::new(Default::default(), registry, cx);
-            let message_1 = assistant.messages[0].clone();
-            let message_2 = assistant
-                .insert_message_after(message_1.id, Role::Assistant, cx)
-                .unwrap();
-            let message_3 = assistant
-                .insert_message_after(message_2.id, Role::User, cx)
-                .unwrap();
-            let message_4 = assistant
-                .insert_message_after(message_2.id, Role::User, cx)
-                .unwrap();
-            assistant.remove_empty_messages(
-                HashSet::from_iter([message_3.id, message_4.id]),
-                Default::default(),
-                cx,
-            );
-            assert_eq!(assistant.messages.len(), 2);
-            assert_eq!(assistant.messages[0].id, message_1.id);
-            assert_eq!(assistant.messages[1].id, message_2.id);
-            assistant
-        });
+        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
+        let buffer = assistant.read(cx).buffer.clone();
+
+        let message_1 = assistant.read(cx).messages[0].clone();
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![(message_1.id, Role::User, 0..0)]
+        );
+
+        let message_2 = assistant.update(cx, |assistant, cx| {
+            assistant
+                .insert_message_after(message_1.id, Role::Assistant, cx)
+                .unwrap()
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..1),
+                (message_2.id, Role::Assistant, 1..1)
+            ]
+        );
+
+        buffer.update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..2),
+                (message_2.id, Role::Assistant, 2..3)
+            ]
+        );
+
+        let message_3 = assistant.update(cx, |assistant, cx| {
+            assistant
+                .insert_message_after(message_2.id, Role::User, cx)
+                .unwrap()
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..2),
+                (message_2.id, Role::Assistant, 2..4),
+                (message_3.id, Role::User, 4..4)
+            ]
+        );
+
+        let message_4 = assistant.update(cx, |assistant, cx| {
+            assistant
+                .insert_message_after(message_2.id, Role::User, cx)
+                .unwrap()
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..2),
+                (message_2.id, Role::Assistant, 2..4),
+                (message_4.id, Role::User, 4..5),
+                (message_3.id, Role::User, 5..5),
+            ]
+        );
+
+        buffer.update(cx, |buffer, cx| {
+            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..2),
+                (message_2.id, Role::Assistant, 2..4),
+                (message_4.id, Role::User, 4..6),
+                (message_3.id, Role::User, 6..7),
+            ]
+        );
+
+        // Deleting across message boundaries merges the messages.
+        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..6),
+                (message_3.id, Role::User, 6..7),
+            ]
+        );
+    }
+
+    fn messages(
+        assistant: &ModelHandle<Assistant>,
+        cx: &AppContext,
+    ) -> Vec<(MessageId, Role, Range<usize>)> {
+        assistant
+            .read(cx)
+            .messages(cx)
+            .map(|(message, metadata, range)| (message.id, metadata.role, range))
+            .collect()
     }
 }