mirror of
https://github.com/zed-industries/zed.git
synced 2024-10-26 08:31:04 +00:00
Test serialization roundtrip
This commit is contained in:
parent
c38bf2de33
commit
6c7271c633
1 changed files with 146 additions and 92 deletions
|
@ -492,15 +492,14 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
let fs = self.fs.clone();
|
||||
let conversation = Conversation::load(
|
||||
path.clone(),
|
||||
self.api_key.clone(),
|
||||
self.languages.clone(),
|
||||
self.fs.clone(),
|
||||
cx,
|
||||
);
|
||||
let api_key = self.api_key.clone();
|
||||
let languages = self.languages.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let conversation = conversation.await?;
|
||||
let saved_conversation = fs.load(&path).await?;
|
||||
let saved_conversation = serde_json::from_str(&saved_conversation)?;
|
||||
let conversation = cx.add_model(|cx| {
|
||||
Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
|
||||
});
|
||||
this.update(&mut cx, |this, cx| {
|
||||
// If, by the time we've loaded the conversation, the user has already opened
|
||||
// the same conversation, we don't want to open it again.
|
||||
|
@ -508,7 +507,7 @@ impl AssistantPanel {
|
|||
this.set_active_editor_index(Some(ix), cx);
|
||||
} else {
|
||||
let editor = cx
|
||||
.add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx));
|
||||
.add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
|
||||
this.add_conversation(editor, cx);
|
||||
}
|
||||
})?;
|
||||
|
@ -861,17 +860,35 @@ impl Conversation {
|
|||
this
|
||||
}
|
||||
|
||||
fn load(
|
||||
fn serialize(&self, cx: &AppContext) -> SavedConversation {
|
||||
SavedConversation {
|
||||
zed: "conversation".into(),
|
||||
version: SavedConversation::VERSION.into(),
|
||||
text: self.buffer.read(cx).text(),
|
||||
message_metadata: self.messages_metadata.clone(),
|
||||
messages: self
|
||||
.messages(cx)
|
||||
.map(|message| SavedMessage {
|
||||
id: message.id,
|
||||
start: message.range.start,
|
||||
})
|
||||
.collect(),
|
||||
summary: self
|
||||
.summary
|
||||
.as_ref()
|
||||
.map(|summary| summary.text.clone())
|
||||
.unwrap_or_default(),
|
||||
model: self.model.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize(
|
||||
saved_conversation: SavedConversation,
|
||||
path: PathBuf,
|
||||
api_key: Rc<RefCell<Option<String>>>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<ModelHandle<Self>>> {
|
||||
cx.spawn(|mut cx| async move {
|
||||
let saved_conversation = fs.load(&path).await?;
|
||||
let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?;
|
||||
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
let model = saved_conversation.model;
|
||||
let markdown = language_registry.language_for_name("Markdown");
|
||||
let mut message_anchors = Vec::new();
|
||||
|
@ -899,7 +916,7 @@ impl Conversation {
|
|||
.detach_and_log_err(cx);
|
||||
buffer
|
||||
});
|
||||
let conversation = cx.add_model(|cx| {
|
||||
|
||||
let mut this = Self {
|
||||
message_anchors,
|
||||
messages_metadata: saved_conversation.message_metadata,
|
||||
|
@ -921,12 +938,8 @@ impl Conversation {
|
|||
api_key,
|
||||
buffer,
|
||||
};
|
||||
|
||||
this.count_remaining_tokens(cx);
|
||||
this
|
||||
});
|
||||
Ok(conversation)
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_buffer_event(
|
||||
|
@ -1453,23 +1466,7 @@ impl Conversation {
|
|||
});
|
||||
|
||||
if let Some(summary) = summary {
|
||||
let conversation = this.read_with(&cx, |this, cx| SavedConversation {
|
||||
zed: "conversation".into(),
|
||||
version: SavedConversation::VERSION.into(),
|
||||
text: this.buffer.read(cx).text(),
|
||||
message_metadata: this.messages_metadata.clone(),
|
||||
messages: this
|
||||
.message_anchors
|
||||
.iter()
|
||||
.map(|message| SavedMessage {
|
||||
id: message.id,
|
||||
start: message.start.to_offset(this.buffer.read(cx)),
|
||||
})
|
||||
.collect(),
|
||||
summary: summary.clone(),
|
||||
model: this.model.clone(),
|
||||
});
|
||||
|
||||
let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
|
||||
let path = if let Some(old_path) = old_path {
|
||||
old_path
|
||||
} else {
|
||||
|
@ -1533,10 +1530,10 @@ impl ConversationEditor {
|
|||
cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
|
||||
Self::from_conversation(conversation, fs, cx)
|
||||
Self::for_conversation(conversation, fs, cx)
|
||||
}
|
||||
|
||||
fn from_conversation(
|
||||
fn for_conversation(
|
||||
conversation: ModelHandle<Conversation>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
|
@ -2116,9 +2113,8 @@ async fn stream_completion(
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::MessageId;
|
||||
|
||||
use super::*;
|
||||
use crate::MessageId;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use project::Project;
|
||||
|
@ -2464,6 +2460,64 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_serialization(cx: &mut AppContext) {
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let conversation =
|
||||
cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
let message_0 = conversation.read(cx).message_anchors[0].id;
|
||||
let message_1 = conversation.update(cx, |conversation, cx| {
|
||||
conversation
|
||||
.insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
|
||||
.unwrap()
|
||||
});
|
||||
let message_2 = conversation.update(cx, |conversation, cx| {
|
||||
conversation
|
||||
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
|
||||
.unwrap()
|
||||
});
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
|
||||
buffer.finalize_last_transaction();
|
||||
});
|
||||
let _message_3 = conversation.update(cx, |conversation, cx| {
|
||||
conversation
|
||||
.insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
|
||||
.unwrap()
|
||||
});
|
||||
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
||||
assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
|
||||
assert_eq!(
|
||||
messages(&conversation, cx),
|
||||
[
|
||||
(message_0, Role::User, 0..2),
|
||||
(message_1.id, Role::Assistant, 2..6),
|
||||
(message_2.id, Role::System, 6..6),
|
||||
]
|
||||
);
|
||||
|
||||
let deserialized_conversation = cx.add_model(|cx| {
|
||||
Conversation::deserialize(
|
||||
conversation.read(cx).serialize(cx),
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
registry.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
|
||||
assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
|
||||
assert_eq!(
|
||||
messages(&deserialized_conversation, cx),
|
||||
[
|
||||
(message_0, Role::User, 0..2),
|
||||
(message_1.id, Role::Assistant, 2..6),
|
||||
(message_2.id, Role::System, 6..6),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn messages(
|
||||
conversation: &ModelHandle<Conversation>,
|
||||
cx: &AppContext,
|
||||
|
|
Loading…
Reference in a new issue