Test serialization roundtrip

This commit is contained in:
Antonio Scandurra 2023-06-23 10:42:15 +02:00
parent c38bf2de33
commit 6c7271c633

View file

@ -492,15 +492,14 @@ impl AssistantPanel {
}
let fs = self.fs.clone();
let conversation = Conversation::load(
path.clone(),
self.api_key.clone(),
self.languages.clone(),
self.fs.clone(),
cx,
);
let api_key = self.api_key.clone();
let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move {
let conversation = conversation.await?;
let saved_conversation = fs.load(&path).await?;
let saved_conversation = serde_json::from_str(&saved_conversation)?;
let conversation = cx.add_model(|cx| {
Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
});
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
// the same conversation, we don't want to open it again.
@ -508,7 +507,7 @@ impl AssistantPanel {
this.set_active_editor_index(Some(ix), cx);
} else {
let editor = cx
.add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx));
.add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
this.add_conversation(editor, cx);
}
})?;
@ -861,72 +860,86 @@ impl Conversation {
this
}
fn load(
fn serialize(&self, cx: &AppContext) -> SavedConversation {
SavedConversation {
zed: "conversation".into(),
version: SavedConversation::VERSION.into(),
text: self.buffer.read(cx).text(),
message_metadata: self.messages_metadata.clone(),
messages: self
.messages(cx)
.map(|message| SavedMessage {
id: message.id,
start: message.range.start,
})
.collect(),
summary: self
.summary
.as_ref()
.map(|summary| summary.text.clone())
.unwrap_or_default(),
model: self.model.clone(),
}
}
fn deserialize(
saved_conversation: SavedConversation,
path: PathBuf,
api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
cx: &mut AppContext,
) -> Task<Result<ModelHandle<Self>>> {
cx.spawn(|mut cx| async move {
let saved_conversation = fs.load(&path).await?;
let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?;
cx: &mut ModelContext<Self>,
) -> Self {
let model = saved_conversation.model;
let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0);
let buffer = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, saved_conversation.text, cx);
for message in saved_conversation.messages {
message_anchors.push(MessageAnchor {
id: message.id,
start: buffer.anchor_before(message.start),
});
next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
}
buffer.set_language_registry(language_registry);
cx.spawn_weak(|buffer, mut cx| async move {
let markdown = markdown.await?;
let buffer = buffer
.upgrade(&cx)
.ok_or_else(|| anyhow!("buffer was dropped"))?;
buffer.update(&mut cx, |buffer, cx| {
buffer.set_language(Some(markdown), cx)
});
anyhow::Ok(())
})
.detach_and_log_err(cx);
buffer
});
let model = saved_conversation.model;
let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0);
let buffer = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, saved_conversation.text, cx);
for message in saved_conversation.messages {
message_anchors.push(MessageAnchor {
id: message.id,
start: buffer.anchor_before(message.start),
});
next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
}
buffer.set_language_registry(language_registry);
cx.spawn_weak(|buffer, mut cx| async move {
let markdown = markdown.await?;
let buffer = buffer
.upgrade(&cx)
.ok_or_else(|| anyhow!("buffer was dropped"))?;
buffer.update(&mut cx, |buffer, cx| {
buffer.set_language(Some(markdown), cx)
});
anyhow::Ok(())
})
.detach_and_log_err(cx);
buffer
});
let conversation = cx.add_model(|cx| {
let mut this = Self {
message_anchors,
messages_metadata: saved_conversation.message_metadata,
next_message_id,
summary: Some(Summary {
text: saved_conversation.summary,
done: true,
}),
pending_summary: Task::ready(None),
completion_count: Default::default(),
pending_completions: Default::default(),
token_count: None,
max_token_count: tiktoken_rs::model::get_context_size(&model),
pending_token_count: Task::ready(None),
model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
api_key,
buffer,
};
this.count_remaining_tokens(cx);
this
});
Ok(conversation)
})
let mut this = Self {
message_anchors,
messages_metadata: saved_conversation.message_metadata,
next_message_id,
summary: Some(Summary {
text: saved_conversation.summary,
done: true,
}),
pending_summary: Task::ready(None),
completion_count: Default::default(),
pending_completions: Default::default(),
token_count: None,
max_token_count: tiktoken_rs::model::get_context_size(&model),
pending_token_count: Task::ready(None),
model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
api_key,
buffer,
};
this.count_remaining_tokens(cx);
this
}
fn handle_buffer_event(
@ -1453,23 +1466,7 @@ impl Conversation {
});
if let Some(summary) = summary {
let conversation = this.read_with(&cx, |this, cx| SavedConversation {
zed: "conversation".into(),
version: SavedConversation::VERSION.into(),
text: this.buffer.read(cx).text(),
message_metadata: this.messages_metadata.clone(),
messages: this
.message_anchors
.iter()
.map(|message| SavedMessage {
id: message.id,
start: message.start.to_offset(this.buffer.read(cx)),
})
.collect(),
summary: summary.clone(),
model: this.model.clone(),
});
let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
let path = if let Some(old_path) = old_path {
old_path
} else {
@ -1533,10 +1530,10 @@ impl ConversationEditor {
cx: &mut ViewContext<Self>,
) -> Self {
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
Self::from_conversation(conversation, fs, cx)
Self::for_conversation(conversation, fs, cx)
}
fn from_conversation(
fn for_conversation(
conversation: ModelHandle<Conversation>,
fs: Arc<dyn Fs>,
cx: &mut ViewContext<Self>,
@ -2116,9 +2113,8 @@ async fn stream_completion(
#[cfg(test)]
mod tests {
use crate::MessageId;
use super::*;
use crate::MessageId;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use project::Project;
@ -2464,6 +2460,64 @@ mod tests {
}
}
#[gpui::test]
fn test_serialization(cx: &mut AppContext) {
let registry = Arc::new(LanguageRegistry::test());
let conversation =
cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
let buffer = conversation.read(cx).buffer.clone();
let message_0 = conversation.read(cx).message_anchors[0].id;
let message_1 = conversation.update(cx, |conversation, cx| {
conversation
.insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
});
let message_2 = conversation.update(cx, |conversation, cx| {
conversation
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
buffer.finalize_last_transaction();
});
let _message_3 = conversation.update(cx, |conversation, cx| {
conversation
.insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
assert_eq!(
messages(&conversation, cx),
[
(message_0, Role::User, 0..2),
(message_1.id, Role::Assistant, 2..6),
(message_2.id, Role::System, 6..6),
]
);
let deserialized_conversation = cx.add_model(|cx| {
Conversation::deserialize(
conversation.read(cx).serialize(cx),
Default::default(),
Default::default(),
registry.clone(),
cx,
)
});
let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
assert_eq!(
messages(&deserialized_conversation, cx),
[
(message_0, Role::User, 0..2),
(message_1.id, Role::Assistant, 2..6),
(message_2.id, Role::System, 6..6),
]
);
}
fn messages(
conversation: &ModelHandle<Conversation>,
cx: &AppContext,