mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-27 04:44:30 +00:00
Simplify logic & add UI affordances to show model cache status (#16395)
Some checks are pending
CI / Check formatting and spelling (push) Waiting to run
CI / (macOS) Run Clippy and tests (push) Waiting to run
CI / (Linux) Run Clippy and tests (push) Waiting to run
CI / (Windows) Run Clippy and tests (push) Waiting to run
CI / Create a macOS bundle (push) Blocked by required conditions
CI / Create a Linux bundle (push) Blocked by required conditions
CI / Create arm64 Linux bundle (push) Blocked by required conditions
Deploy Docs / Deploy Docs (push) Waiting to run
Docs / Check formatting (push) Waiting to run
Some checks are pending
CI / Check formatting and spelling (push) Waiting to run
CI / (macOS) Run Clippy and tests (push) Waiting to run
CI / (Linux) Run Clippy and tests (push) Waiting to run
CI / (Windows) Run Clippy and tests (push) Waiting to run
CI / Create a macOS bundle (push) Blocked by required conditions
CI / Create a Linux bundle (push) Blocked by required conditions
CI / Create arm64 Linux bundle (push) Blocked by required conditions
Deploy Docs / Deploy Docs (push) Waiting to run
Docs / Check formatting (push) Waiting to run
Release Notes: - Adds UI affordances to the assistant panel to show which messages have been cached - Migrates cache invalidation to be based on `has_edits_since_in_range` to be smarter and more selective about when to invalidate the cache and when to fetch. <img width="310" alt="Screenshot 2024-08-16 at 11 19 23 PM" src="https://github.com/user-attachments/assets/4ee2d111-2f55-4b0e-b944-50c4f78afc42"> <img width="580" alt="Screenshot 2024-08-18 at 10 05 16 PM" src="https://github.com/user-attachments/assets/17630a60-7b78-421c-ae39-425246638a12"> I had originally added the lightning bolt on every message and only added the tooltip warning about editing prior messages on the first anchor, but thought it looked too busy, so I settled on just annotating the last anchor.
This commit is contained in:
parent
971db5c6f6
commit
0042c24d3c
5 changed files with 365 additions and 77 deletions
1
assets/icons/database_zap.svg
Normal file
1
assets/icons/database_zap.svg
Normal file
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-database-zap"><ellipse cx="12" cy="5" rx="9" ry="3"/><path d="M3 5V19A9 3 0 0 0 15 21.84"/><path d="M21 5V8"/><path d="M21 12L18 17H22L19 22"/><path d="M3 12A9 3 0 0 0 14.59 14.87"/></svg>
|
After Width: | Height: | Size: 391 B |
|
@ -11,11 +11,11 @@ use crate::{
|
|||
},
|
||||
slash_command_picker,
|
||||
terminal_inline_assistant::TerminalInlineAssistant,
|
||||
Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
|
||||
DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId, InlineAssistant,
|
||||
InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus,
|
||||
QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split, ToggleFocus,
|
||||
ToggleModelSelector, WorkflowStepResolution, WorkflowStepView,
|
||||
Assist, CacheStatus, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
|
||||
CycleMessageRole, DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId,
|
||||
InlineAssistant, InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand,
|
||||
PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split,
|
||||
ToggleFocus, ToggleModelSelector, WorkflowStepResolution, WorkflowStepView,
|
||||
};
|
||||
use crate::{ContextStoreEvent, ModelPickerDelegate};
|
||||
use anyhow::{anyhow, Result};
|
||||
|
@ -3137,6 +3137,36 @@ impl ContextEditor {
|
|||
.relative()
|
||||
.gap_1()
|
||||
.child(sender)
|
||||
.children(match &message.cache {
|
||||
Some(cache) if cache.is_final_anchor => match cache.status {
|
||||
CacheStatus::Cached => Some(
|
||||
div()
|
||||
.id("cached")
|
||||
.child(
|
||||
Icon::new(IconName::DatabaseZap)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Hint),
|
||||
)
|
||||
.tooltip(|cx| {
|
||||
Tooltip::with_meta(
|
||||
"Context cached",
|
||||
None,
|
||||
"Large messages cached to optimize performance",
|
||||
cx,
|
||||
)
|
||||
}).into_any_element()
|
||||
),
|
||||
CacheStatus::Pending => Some(
|
||||
div()
|
||||
.child(
|
||||
Icon::new(IconName::Ellipsis)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Hint),
|
||||
).into_any_element()
|
||||
),
|
||||
},
|
||||
_ => None,
|
||||
})
|
||||
.children(match &message.status {
|
||||
MessageStatus::Error(error) => Some(
|
||||
Button::new("show-error", "Error")
|
||||
|
|
|
@ -40,6 +40,7 @@ use std::{
|
|||
time::{Duration, Instant},
|
||||
};
|
||||
use telemetry_events::AssistantKind;
|
||||
use text::BufferSnapshot;
|
||||
use util::{post_inc, ResultExt, TryFutureExt};
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -107,8 +108,7 @@ impl ContextOperation {
|
|||
message.status.context("invalid status")?,
|
||||
),
|
||||
timestamp: id.0,
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
},
|
||||
version: language::proto::deserialize_version(&insert.version),
|
||||
})
|
||||
|
@ -123,8 +123,7 @@ impl ContextOperation {
|
|||
timestamp: language::proto::deserialize_timestamp(
|
||||
update.timestamp.context("invalid timestamp")?,
|
||||
),
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
},
|
||||
version: language::proto::deserialize_version(&update.version),
|
||||
}),
|
||||
|
@ -313,13 +312,43 @@ pub struct MessageAnchor {
|
|||
pub start: language::Anchor,
|
||||
}
|
||||
|
||||
/// Lifecycle state of a message's prompt-cache entry.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CacheStatus {
    /// Marked for caching, but no completion request has finished since the
    /// mark was set (see `update_cache_status_for_completion`).
    Pending,
    /// A completion request has completed since this message was marked, so
    /// its contents are assumed to be held in the model's prompt cache.
    Cached,
}
|
||||
|
||||
/// Cache bookkeeping attached to a message via `MessageMetadata::cache`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageCacheMetadata {
    /// Whether this message is marked as a cache anchor for the request
    /// (see `mark_cache_anchors`).
    pub is_anchor: bool,
    /// True for the final cache anchor in the conversation; the UI uses this
    /// to decide which message gets the cache-status icon.
    pub is_final_anchor: bool,
    /// Whether the cache write for this message is pending or confirmed.
    pub status: CacheStatus,
    /// Buffer version at the time the cache mark was set; compared against
    /// later edits by `is_cache_valid` to detect invalidation.
    pub cached_at: clock::Global,
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageMetadata {
|
||||
pub role: Role,
|
||||
pub status: MessageStatus,
|
||||
timestamp: clock::Lamport,
|
||||
should_cache: bool,
|
||||
is_cache_anchor: bool,
|
||||
#[serde(skip)]
|
||||
pub cache: Option<MessageCacheMetadata>,
|
||||
}
|
||||
|
||||
impl MessageMetadata {
|
||||
pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range<usize>) -> bool {
|
||||
let result = match &self.cache {
|
||||
Some(MessageCacheMetadata { cached_at, .. }) => !buffer.has_edits_since_in_range(
|
||||
&cached_at,
|
||||
Range {
|
||||
start: buffer.anchor_at(range.start, Bias::Right),
|
||||
end: buffer.anchor_at(range.end, Bias::Left),
|
||||
},
|
||||
),
|
||||
_ => false,
|
||||
};
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -345,7 +374,7 @@ pub struct Message {
|
|||
pub anchor: language::Anchor,
|
||||
pub role: Role,
|
||||
pub status: MessageStatus,
|
||||
pub cache: bool,
|
||||
pub cache: Option<MessageCacheMetadata>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
|
@ -381,7 +410,7 @@ impl Message {
|
|||
Some(LanguageModelRequestMessage {
|
||||
role: self.role,
|
||||
content,
|
||||
cache: self.cache,
|
||||
cache: self.cache.as_ref().map_or(false, |cache| cache.is_anchor),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -544,8 +573,7 @@ impl Context {
|
|||
role: Role::User,
|
||||
status: MessageStatus::Done,
|
||||
timestamp: first_message_id.0,
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
},
|
||||
);
|
||||
this.message_anchors.push(message);
|
||||
|
@ -979,7 +1007,7 @@ impl Context {
|
|||
});
|
||||
}
|
||||
|
||||
pub fn mark_longest_messages_for_cache(
|
||||
pub fn mark_cache_anchors(
|
||||
&mut self,
|
||||
cache_configuration: &Option<LanguageModelCacheConfiguration>,
|
||||
speculative: bool,
|
||||
|
@ -994,66 +1022,104 @@ impl Context {
|
|||
min_total_token: 0,
|
||||
});
|
||||
|
||||
let messages: Vec<Message> = self
|
||||
.messages_from_anchors(
|
||||
self.message_anchors.iter().take(if speculative {
|
||||
self.message_anchors.len().saturating_sub(1)
|
||||
} else {
|
||||
self.message_anchors.len()
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
.filter(|message| message.offset_range.len() >= 5_000)
|
||||
.collect();
|
||||
let messages: Vec<Message> = self.messages(cx).collect();
|
||||
|
||||
let mut sorted_messages = messages.clone();
|
||||
sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
|
||||
if cache_configuration.max_cache_anchors == 0 && cache_configuration.should_speculate {
|
||||
// Some models support caching, but don't support anchors. In that case we want to
|
||||
// mark the largest message as needing to be cached, but we will not mark it as an
|
||||
// anchor.
|
||||
sorted_messages.truncate(1);
|
||||
} else {
|
||||
// Save 1 anchor for the inline assistant.
|
||||
sorted_messages.truncate(max(cache_configuration.max_cache_anchors, 1) - 1);
|
||||
if speculative {
|
||||
// Avoid caching the last message if this is a speculative cache fetch as
|
||||
// it's likely to change.
|
||||
sorted_messages.pop();
|
||||
}
|
||||
sorted_messages.retain(|m| m.role == Role::User);
|
||||
sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
|
||||
|
||||
let longest_message_ids: HashSet<MessageId> = sorted_messages
|
||||
let cache_anchors = if self.token_count.unwrap_or(0) < cache_configuration.min_total_token {
|
||||
// If we have't hit the minimum threshold to enable caching, don't cache anything.
|
||||
0
|
||||
} else {
|
||||
// Save 1 anchor for the inline assistant to use.
|
||||
max(cache_configuration.max_cache_anchors, 1) - 1
|
||||
};
|
||||
sorted_messages.truncate(cache_anchors);
|
||||
|
||||
let anchors: HashSet<MessageId> = sorted_messages
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect();
|
||||
|
||||
let cache_deltas: HashSet<MessageId> = self
|
||||
.messages_metadata
|
||||
let buffer = self.buffer.read(cx).snapshot();
|
||||
let invalidated_caches: HashSet<MessageId> = messages
|
||||
.iter()
|
||||
.filter_map(|(id, metadata)| {
|
||||
let should_cache = longest_message_ids.contains(id);
|
||||
let should_be_anchor = should_cache && cache_configuration.max_cache_anchors > 0;
|
||||
if metadata.should_cache != should_cache
|
||||
|| metadata.is_cache_anchor != should_be_anchor
|
||||
{
|
||||
Some(*id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
.scan(false, |encountered_invalid, message| {
|
||||
let message_id = message.id;
|
||||
let is_invalid = self
|
||||
.messages_metadata
|
||||
.get(&message_id)
|
||||
.map_or(true, |metadata| {
|
||||
!metadata.is_cache_valid(&buffer, &message.offset_range)
|
||||
|| *encountered_invalid
|
||||
});
|
||||
*encountered_invalid |= is_invalid;
|
||||
Some(if is_invalid { Some(message_id) } else { None })
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
let mut newly_cached_item = false;
|
||||
for id in cache_deltas {
|
||||
newly_cached_item = newly_cached_item || longest_message_ids.contains(&id);
|
||||
self.update_metadata(id, cx, |metadata| {
|
||||
metadata.should_cache = longest_message_ids.contains(&id);
|
||||
metadata.is_cache_anchor =
|
||||
metadata.should_cache && (cache_configuration.max_cache_anchors > 0);
|
||||
let last_anchor = messages.iter().rev().find_map(|message| {
|
||||
if anchors.contains(&message.id) {
|
||||
Some(message.id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let mut new_anchor_needs_caching = false;
|
||||
let current_version = &buffer.version;
|
||||
// If we have no anchors, mark all messages as not being cached.
|
||||
let mut hit_last_anchor = last_anchor.is_none();
|
||||
|
||||
for message in messages.iter() {
|
||||
if hit_last_anchor {
|
||||
self.update_metadata(message.id, cx, |metadata| metadata.cache = None);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(last_anchor) = last_anchor {
|
||||
if message.id == last_anchor {
|
||||
hit_last_anchor = true;
|
||||
}
|
||||
}
|
||||
|
||||
new_anchor_needs_caching = new_anchor_needs_caching
|
||||
|| (invalidated_caches.contains(&message.id) && anchors.contains(&message.id));
|
||||
|
||||
self.update_metadata(message.id, cx, |metadata| {
|
||||
let cache_status = if invalidated_caches.contains(&message.id) {
|
||||
CacheStatus::Pending
|
||||
} else {
|
||||
metadata
|
||||
.cache
|
||||
.as_ref()
|
||||
.map_or(CacheStatus::Pending, |cm| cm.status.clone())
|
||||
};
|
||||
metadata.cache = Some(MessageCacheMetadata {
|
||||
is_anchor: anchors.contains(&message.id),
|
||||
is_final_anchor: hit_last_anchor,
|
||||
status: cache_status,
|
||||
cached_at: current_version.clone(),
|
||||
});
|
||||
});
|
||||
}
|
||||
newly_cached_item
|
||||
new_anchor_needs_caching
|
||||
}
|
||||
|
||||
fn start_cache_warming(&mut self, model: &Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
|
||||
let cache_configuration = model.cache_configuration();
|
||||
if !self.mark_longest_messages_for_cache(&cache_configuration, true, cx) {
|
||||
|
||||
if !self.mark_cache_anchors(&cache_configuration, true, cx) {
|
||||
return;
|
||||
}
|
||||
if !self.pending_completions.is_empty() {
|
||||
return;
|
||||
}
|
||||
if let Some(cache_configuration) = cache_configuration {
|
||||
|
@ -1076,7 +1142,7 @@ impl Context {
|
|||
};
|
||||
|
||||
let model = Arc::clone(model);
|
||||
self.pending_cache_warming_task = cx.spawn(|_, cx| {
|
||||
self.pending_cache_warming_task = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
match model.stream_completion(request, &cx).await {
|
||||
Ok(mut stream) => {
|
||||
|
@ -1087,13 +1153,41 @@ impl Context {
|
|||
log::warn!("Cache warming failed: {}", e);
|
||||
}
|
||||
};
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.update_cache_status_for_completion(cx);
|
||||
})
|
||||
.ok();
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
});
|
||||
}
|
||||
|
||||
pub fn update_cache_status_for_completion(&mut self, cx: &mut ModelContext<Self>) {
|
||||
let cached_message_ids: Vec<MessageId> = self
|
||||
.messages_metadata
|
||||
.iter()
|
||||
.filter_map(|(message_id, metadata)| {
|
||||
metadata.cache.as_ref().and_then(|cache| {
|
||||
if cache.status == CacheStatus::Pending {
|
||||
Some(*message_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for message_id in cached_message_ids {
|
||||
self.update_metadata(message_id, cx, |metadata| {
|
||||
if let Some(cache) = &mut metadata.cache {
|
||||
cache.status = CacheStatus::Cached;
|
||||
}
|
||||
});
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
|
||||
let buffer = self.buffer.read(cx);
|
||||
let mut row_ranges = self
|
||||
|
@ -1531,7 +1625,7 @@ impl Context {
|
|||
return None;
|
||||
}
|
||||
// Compute which messages to cache, including the last one.
|
||||
self.mark_longest_messages_for_cache(&model.cache_configuration(), false, cx);
|
||||
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
|
||||
|
||||
let request = self.to_completion_request(cx);
|
||||
let assistant_message = self
|
||||
|
@ -1596,6 +1690,7 @@ impl Context {
|
|||
this.pending_completions
|
||||
.retain(|completion| completion.id != pending_completion_id);
|
||||
this.summarize(false, cx);
|
||||
this.update_cache_status_for_completion(cx);
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
|
@ -1746,8 +1841,7 @@ impl Context {
|
|||
role,
|
||||
status,
|
||||
timestamp: anchor.id.0,
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
};
|
||||
self.insert_message(anchor.clone(), metadata.clone(), cx);
|
||||
self.push_op(
|
||||
|
@ -1864,8 +1958,7 @@ impl Context {
|
|||
role,
|
||||
status: MessageStatus::Done,
|
||||
timestamp: suffix.id.0,
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
};
|
||||
self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
|
||||
self.push_op(
|
||||
|
@ -1915,8 +2008,7 @@ impl Context {
|
|||
role,
|
||||
status: MessageStatus::Done,
|
||||
timestamp: selection.id.0,
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
};
|
||||
self.insert_message(selection.clone(), selection_metadata.clone(), cx);
|
||||
self.push_op(
|
||||
|
@ -2150,7 +2242,7 @@ impl Context {
|
|||
anchor: message_anchor.start,
|
||||
role: metadata.role,
|
||||
status: metadata.status.clone(),
|
||||
cache: metadata.is_cache_anchor,
|
||||
cache: metadata.cache.clone(),
|
||||
image_offsets,
|
||||
});
|
||||
}
|
||||
|
@ -2397,8 +2489,7 @@ impl SavedContext {
|
|||
role: message.metadata.role,
|
||||
status: message.metadata.status,
|
||||
timestamp: message.metadata.timestamp,
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
},
|
||||
version: version.clone(),
|
||||
});
|
||||
|
@ -2415,8 +2506,7 @@ impl SavedContext {
|
|||
role: metadata.role,
|
||||
status: metadata.status,
|
||||
timestamp,
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
},
|
||||
version: version.clone(),
|
||||
});
|
||||
|
@ -2511,8 +2601,7 @@ impl SavedContextV0_3_0 {
|
|||
role: metadata.role,
|
||||
status: metadata.status.clone(),
|
||||
timestamp,
|
||||
should_cache: false,
|
||||
is_cache_anchor: false,
|
||||
cache: None,
|
||||
},
|
||||
image_offsets: Vec::new(),
|
||||
})
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
assistant_panel, prompt_library, slash_command::file_command, workflow::tool, Context,
|
||||
ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
|
||||
assistant_panel, prompt_library, slash_command::file_command, workflow::tool, CacheStatus,
|
||||
Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use assistant_slash_command::{
|
||||
|
@ -12,7 +12,7 @@ use fs::{FakeFs, Fs as _};
|
|||
use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
|
||||
use indoc::indoc;
|
||||
use language::{Buffer, LanguageRegistry, LspAdapterDelegate};
|
||||
use language_model::{LanguageModelRegistry, Role};
|
||||
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use rand::prelude::*;
|
||||
|
@ -33,6 +33,8 @@ use unindent::Unindent;
|
|||
use util::{test::marked_text_ranges, RandomCharIter};
|
||||
use workspace::Workspace;
|
||||
|
||||
use super::MessageCacheMetadata;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
|
@ -1002,6 +1004,159 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
|
|||
});
|
||||
}
|
||||
|
||||
// Exercises `mark_cache_anchors` and `update_cache_status_for_completion`:
// anchor selection, the `min_total_token` gate, speculative vs. final
// requests, and edit-driven cache invalidation.
#[gpui::test]
fn test_mark_cache_anchors(cx: &mut AppContext) {
    // Standard test-app scaffolding: settings, model registry, and the
    // assistant panel must exist before a local Context can be created.
    let settings_store = SettingsStore::test(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
    let buffer = context.read(cx).buffer.clone();

    // Create a test cache configuration
    let cache_configuration = &Some(LanguageModelCacheConfiguration {
        max_cache_anchors: 3,
        should_speculate: true,
        min_total_token: 10,
    });

    let message_1 = context.read(cx).message_anchors[0].clone();

    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });

    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
            .count(),
        0,
        "Empty messages should not have any cache anchors."
    );

    // Build three user messages ("aaa", "bbbbbbb", "cccccc") by editing the
    // buffer and inserting a new message anchor after each edit.
    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
    let message_2 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
        })
        .unwrap();

    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
    let message_3 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));

    // `token_count` is still None here, so the total is below
    // `min_total_token` and nothing should be anchored yet.
    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });
    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
            .count(),
        0,
        "Messages should not be marked for cache before going over the token minimum."
    );
    // Cross the `min_total_token` threshold (10) so caching can engage.
    context.update(cx, |context, _| {
        context.token_count = Some(20);
    });

    // Speculative pass: the trailing message is likely to change, so it is
    // not anchored.
    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, true, cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
            .collect::<Vec<bool>>(),
        vec![true, true, false],
        "Last message should not be an anchor on speculative request."
    );

    // Add a trailing assistant message, then run a non-speculative pass.
    context
        .update(cx, |context, cx| {
            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
        })
        .unwrap();

    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
            .collect::<Vec<bool>>(),
        vec![false, true, true, false],
        "Most recent message should also be cached if not a speculative request."
    );
    // Simulate a completed request: pending cache writes flip to Cached.
    context.update(cx, |context, cx| {
        context.update_cache_status_for_completion(cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache
                .as_ref()
                .map_or(None, |cache| Some(cache.status.clone())))
            .collect::<Vec<Option<CacheStatus>>>(),
        vec![
            Some(CacheStatus::Cached),
            Some(CacheStatus::Cached),
            Some(CacheStatus::Cached),
            None
        ],
        "All user messages prior to anchor should be marked as cached."
    );

    // Edit inside the third message (offset 14 is within "cccccc"): only
    // that message's cache should be invalidated.
    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache
                .as_ref()
                .map_or(None, |cache| Some(cache.status.clone())))
            .collect::<Vec<Option<CacheStatus>>>(),
        vec![
            Some(CacheStatus::Cached),
            Some(CacheStatus::Cached),
            Some(CacheStatus::Pending),
            None
        ],
        "Modifying a message should invalidate it's cache but leave previous messages."
    );
    // Edit inside the first message (offset 2 is within "aaa"): invalidation
    // cascades to every subsequent message.
    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache
                .as_ref()
                .map_or(None, |cache| Some(cache.status.clone())))
            .collect::<Vec<Option<CacheStatus>>>(),
        vec![
            Some(CacheStatus::Pending),
            Some(CacheStatus::Pending),
            Some(CacheStatus::Pending),
            None
        ],
        "Modifying a message should invalidate all future messages."
    );
}
|
||||
|
||||
fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
|
||||
context
|
||||
.read(cx)
|
||||
|
@ -1010,6 +1165,17 @@ fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role,
|
|||
.collect()
|
||||
}
|
||||
|
||||
fn messages_cache(
|
||||
context: &Model<Context>,
|
||||
cx: &AppContext,
|
||||
) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
|
||||
context
|
||||
.read(cx)
|
||||
.messages(cx)
|
||||
.map(|message| (message.id, message.cache.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FakeSlashCommand(String);
|
||||
|
||||
|
|
|
@ -154,6 +154,7 @@ pub enum IconName {
|
|||
Copy,
|
||||
CountdownTimer,
|
||||
Dash,
|
||||
DatabaseZap,
|
||||
Delete,
|
||||
Disconnected,
|
||||
Download,
|
||||
|
@ -322,6 +323,7 @@ impl IconName {
|
|||
IconName::Copy => "icons/copy.svg",
|
||||
IconName::CountdownTimer => "icons/countdown_timer.svg",
|
||||
IconName::Dash => "icons/dash.svg",
|
||||
IconName::DatabaseZap => "icons/database_zap.svg",
|
||||
IconName::Delete => "icons/delete.svg",
|
||||
IconName::Disconnected => "icons/disconnected.svg",
|
||||
IconName::Download => "icons/download.svg",
|
||||
|
|
Loading…
Reference in a new issue