Get back to a compiling state with Buffer backing the assistant

This commit is contained in:
Antonio Scandurra 2023-06-13 13:43:06 +02:00
parent 7db690b713
commit 2ae8b558b9

View file

@ -11,7 +11,7 @@ use editor::{
autoscroll::{Autoscroll, AutoscrollStrategy},
ScrollAnchor,
},
Anchor, DisplayPoint, Editor, ExcerptId,
Anchor, DisplayPoint, Editor, ToOffset as _,
};
use fs::Fs;
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
@ -25,10 +25,13 @@ use gpui::{
Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
};
use isahc::{http::StatusCode, Request, RequestExt};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use serde::Deserialize;
use settings::SettingsStore;
use std::{borrow::Cow, cell::RefCell, cmp, fmt::Write, io, rc::Rc, sync::Arc, time::Duration};
use std::{
borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
time::Duration,
};
use util::{post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
use workspace::{
dock::{DockPosition, Panel},
@ -507,16 +510,16 @@ impl Assistant {
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
let messages = self
.messages
.iter()
.open_ai_request_messages(cx)
.into_iter()
.filter_map(|message| {
Some(tiktoken_rs::ChatCompletionRequestMessage {
role: match self.messages_metadata.get(&message.excerpt_id)?.role {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: message.content.read(cx).text(),
content: message.content,
name: None,
})
})
@ -554,45 +557,47 @@ impl Assistant {
}
fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
let messages = self
.messages
.iter()
.filter_map(|message| {
Some(RequestMessage {
role: self.messages_metadata.get(&message.excerpt_id)?.role,
content: message.content.read(cx).text(),
})
})
.collect();
let request = OpenAIRequest {
model: self.model.clone(),
messages,
messages: self.open_ai_request_messages(cx),
stream: true,
};
let api_key = self.api_key.borrow().clone()?;
let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message = self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
let user_message = self.insert_message_after(ExcerptId::max(), Role::User, cx);
let assistant_message =
self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?;
let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
let task = cx.spawn_weak({
let assistant_message = assistant_message.clone();
|this, mut cx| async move {
let assistant_message = assistant_message;
let assistant_message_id = assistant_message.id;
let stream_completion = async {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
assistant_message.content.update(&mut cx, |content, cx| {
let text: Arc<str> = choice.delta.content?.into();
content.edit([(content.len()..content.len(), text)], None, cx);
Some(())
});
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |_, cx| {
cx.emit(AssistantEvent::StreamedCompletion);
.update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into();
let message_ix = this
.messages
.iter()
.position(|message| message.id == assistant_message_id)?;
this.buffer.update(cx, |buffer, cx| {
let offset = if message_ix + 1 == this.messages.len() {
buffer.len()
} else {
this.messages[message_ix + 1]
.start
.to_offset(buffer)
.saturating_sub(1)
};
buffer.edit([(offset..offset, text)], None, cx);
});
Some(())
});
}
}
@ -612,9 +617,8 @@ impl Assistant {
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Err(error) = result {
if let Some(metadata) = this
.messages_metadata
.get_mut(&assistant_message.excerpt_id)
if let Some(metadata) =
this.messages_metadata.get_mut(&assistant_message.id)
{
metadata.error = Some(error.to_string().trim().into());
cx.notify();
@ -642,33 +646,33 @@ impl Assistant {
protected_offsets: HashSet<usize>,
cx: &mut ModelContext<Self>,
) {
let mut offset = 0;
let mut excerpts_to_remove = Vec::new();
self.messages.retain(|message| {
let range = offset..offset + message.content.read(cx).len();
offset = range.end + 1;
if range.is_empty()
&& !protected_offsets.contains(&range.start)
&& messages.contains(&message.id)
{
excerpts_to_remove.push(message.excerpt_id);
self.messages_metadata.remove(&message.excerpt_id);
false
} else {
true
}
});
// let mut offset = 0;
// let mut excerpts_to_remove = Vec::new();
// self.messages.retain(|message| {
// let range = offset..offset + message.content.read(cx).len();
// offset = range.end + 1;
// if range.is_empty()
// && !protected_offsets.contains(&range.start)
// && messages.contains(&message.id)
// {
// excerpts_to_remove.push(message.excerpt_id);
// self.messages_metadata.remove(&message.excerpt_id);
// false
// } else {
// true
// }
// });
if !excerpts_to_remove.is_empty() {
self.buffer.update(cx, |buffer, cx| {
buffer.remove_excerpts(excerpts_to_remove, cx)
});
cx.notify();
}
// if !excerpts_to_remove.is_empty() {
// self.buffer.update(cx, |buffer, cx| {
// buffer.remove_excerpts(excerpts_to_remove, cx)
// });
// cx.notify();
// }
}
fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext<Self>) {
if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) {
fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
if let Some(metadata) = self.messages_metadata.get_mut(&id) {
metadata.role.cycle();
cx.notify();
}
@ -686,15 +690,18 @@ impl Assistant {
.position(|message| message.id == message_id)
{
let start = self.buffer.update(cx, |buffer, cx| {
let len = buffer.len();
buffer.edit([(len..len, "\n")], None, cx);
buffer.anchor_before(len + 1)
let offset = self
.messages
.get(prev_message_ix + 1)
.map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
buffer.edit([(offset..offset, "\n")], None, cx);
buffer.anchor_before(offset + 1)
});
let message = Message {
id: MessageId(post_inc(&mut self.next_message_id.0)),
start,
};
self.messages.insert(prev_message_ix, message.clone());
self.messages.insert(prev_message_ix + 1, message.clone());
self.messages_metadata.insert(
message.id,
MessageMetadata {
@ -713,24 +720,13 @@ impl Assistant {
if self.messages.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
// let messages = self
// .messages
// .iter()
// .take(2)
// .filter_map(|message| {
// Some(RequestMessage {
// role: self.messages_metadata.get(&message.id)?.role,
// content: message.content.read(cx).text(),
// })
// })
// .chain(Some(RequestMessage {
// role: Role::User,
// content:
// "Summarize the conversation into a short title without punctuation"
// .into(),
// }))
// .collect();
let messages = todo!();
let mut messages = self.open_ai_request_messages(cx);
messages.truncate(2);
messages.push(RequestMessage {
role: Role::User,
content: "Summarize the conversation into a short title without punctuation"
.into(),
});
let request = OpenAIRequest {
model: self.model.clone(),
messages,
@ -760,6 +756,44 @@ impl Assistant {
}
}
}
fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
let buffer = self.buffer.read(cx);
self.messages(cx)
.map(|(message, metadata, range)| RequestMessage {
role: metadata.role,
content: buffer.text_for_range(range).collect(),
})
.collect()
}
fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option<MessageId> {
Some(
self.messages(cx)
.find(|(_, _, range)| range.contains(&offset))
.map(|(message, _, _)| message)
.or(self.messages.last())?
.id,
)
}
fn messages<'a>(
&'a self,
cx: &'a AppContext,
) -> impl 'a + Iterator<Item = (&Message, &MessageMetadata, Range<usize>)> {
let buffer = self.buffer.read(cx);
let mut messages = self.messages.iter().peekable();
iter::from_fn(move || {
let message = messages.next()?;
let metadata = self.messages_metadata.get(&message.id)?;
let message_start = message.start.to_offset(buffer);
let message_end = messages
.peek()
.map_or(language::Anchor::MAX, |message| message.start)
.to_offset(buffer);
Some((message, metadata, message_start..message_end))
})
}
}
struct PendingCompletion {
@ -812,16 +846,12 @@ impl AssistantEditor {
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
let user_message = self.assistant.update(cx, |assistant, cx| {
let editor = self.editor.read(cx);
let newest_selection = editor.selections.newest_anchor();
let message_id = if newest_selection.head() == Anchor::min() {
assistant.messages.first().map(|message| message.id)?
} else if newest_selection.head() == Anchor::max() {
assistant.messages.last().map(|message| message.id)?
} else {
todo!()
// newest_selection.head().excerpt_id()
};
let newest_selection = editor
.selections
.newest_anchor()
.head()
.to_offset(&editor.buffer().read(cx).snapshot(cx));
let message_id = assistant.message_id_for_offset(newest_selection, cx)?;
let metadata = assistant.messages_metadata.get(&message_id)?;
let user_message = if metadata.role == Role::User {
let (_, user_message) = assistant.assist(cx)?;
@ -834,16 +864,14 @@ impl AssistantEditor {
});
if let Some(user_message) = user_message {
let cursor = user_message
.start
.to_offset(&self.assistant.read(cx).buffer.read(cx));
self.editor.update(cx, |editor, cx| {
let cursor = editor
.buffer()
.read(cx)
.snapshot(cx)
.anchor_in_excerpt(Default::default(), user_message.start);
editor.change_selections(
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
cx,
|selections| selections.select_anchor_ranges([cursor..cursor]),
|selections| selections.select_ranges([cursor..cursor]),
);
});
self.update_scroll_bottom(cx);
@ -1011,7 +1039,7 @@ impl AssistantEditor {
let mut copied_text = String::new();
let mut spanned_messages = 0;
for message in &assistant.messages {
// TODO
todo!();
// let message_range = offset..offset + message.content.read(cx).len() + 1;
let message_range = offset..offset + 1;
@ -1260,28 +1288,100 @@ mod tests {
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let registry = Arc::new(LanguageRegistry::test());
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
let buffer = assistant.read(cx).buffer.clone();
cx.add_model(|cx| {
let mut assistant = Assistant::new(Default::default(), registry, cx);
let message_1 = assistant.messages[0].clone();
let message_2 = assistant
.insert_message_after(message_1.id, Role::Assistant, cx)
.unwrap();
let message_3 = assistant
.insert_message_after(message_2.id, Role::User, cx)
.unwrap();
let message_4 = assistant
.insert_message_after(message_2.id, Role::User, cx)
.unwrap();
assistant.remove_empty_messages(
HashSet::from_iter([message_3.id, message_4.id]),
Default::default(),
cx,
);
assert_eq!(assistant.messages.len(), 2);
assert_eq!(assistant.messages[0].id, message_1.id);
assert_eq!(assistant.messages[1].id, message_2.id);
let message_1 = assistant.read(cx).messages[0].clone();
assert_eq!(
messages(&assistant, cx),
vec![(message_1.id, Role::User, 0..0)]
);
let message_2 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_1.id, Role::Assistant, cx)
.unwrap()
});
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..1),
(message_2.id, Role::Assistant, 1..1)
]
);
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
});
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..3)
]
);
let message_3 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_2.id, Role::User, cx)
.unwrap()
});
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..4),
(message_3.id, Role::User, 4..4)
]
);
let message_4 = assistant.update(cx, |assistant, cx| {
assistant
.insert_message_after(message_2.id, Role::User, cx)
.unwrap()
});
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..4),
(message_4.id, Role::User, 4..5),
(message_3.id, Role::User, 5..5),
]
);
buffer.update(cx, |buffer, cx| {
buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
});
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..4),
(message_4.id, Role::User, 4..6),
(message_3.id, Role::User, 6..7),
]
);
// Deleting across message boundaries merges the messages.
buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
assert_eq!(
messages(&assistant, cx),
vec![
(message_1.id, Role::User, 0..6),
(message_3.id, Role::User, 6..7),
]
);
}
fn messages(
assistant: &ModelHandle<Assistant>,
cx: &AppContext,
) -> Vec<(MessageId, Role, Range<usize>)> {
assistant
.read(cx)
.messages(cx)
.map(|(message, metadata, range)| (message.id, metadata.role, range))
.collect()
}
}