diff --git a/assets/icons/database_zap.svg b/assets/icons/database_zap.svg
new file mode 100644
index 0000000000..06241b35f4
--- /dev/null
+++ b/assets/icons/database_zap.svg
@@ -0,0 +1 @@
+
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index d74b4870a0..e227fa03a3 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -11,11 +11,11 @@ use crate::{
},
slash_command_picker,
terminal_inline_assistant::TerminalInlineAssistant,
- Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
- DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId, InlineAssistant,
- InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus,
- QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split, ToggleFocus,
- ToggleModelSelector, WorkflowStepResolution, WorkflowStepView,
+ Assist, CacheStatus, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
+ CycleMessageRole, DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId,
+ InlineAssistant, InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand,
+ PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split,
+ ToggleFocus, ToggleModelSelector, WorkflowStepResolution, WorkflowStepView,
};
use crate::{ContextStoreEvent, ModelPickerDelegate};
use anyhow::{anyhow, Result};
@@ -3137,6 +3137,36 @@ impl ContextEditor {
.relative()
.gap_1()
.child(sender)
+ .children(match &message.cache {
+ Some(cache) if cache.is_final_anchor => match cache.status {
+ CacheStatus::Cached => Some(
+ div()
+ .id("cached")
+ .child(
+ Icon::new(IconName::DatabaseZap)
+ .size(IconSize::XSmall)
+ .color(Color::Hint),
+ )
+ .tooltip(|cx| {
+ Tooltip::with_meta(
+ "Context cached",
+ None,
+ "Large messages cached to optimize performance",
+ cx,
+ )
+ }).into_any_element()
+ ),
+ CacheStatus::Pending => Some(
+ div()
+ .child(
+ Icon::new(IconName::Ellipsis)
+ .size(IconSize::XSmall)
+ .color(Color::Hint),
+ ).into_any_element()
+ ),
+ },
+ _ => None,
+ })
.children(match &message.status {
MessageStatus::Error(error) => Some(
Button::new("show-error", "Error")
diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs
index f68cdb53eb..41f894aeb2 100644
--- a/crates/assistant/src/context.rs
+++ b/crates/assistant/src/context.rs
@@ -40,6 +40,7 @@ use std::{
time::{Duration, Instant},
};
use telemetry_events::AssistantKind;
+use text::BufferSnapshot;
use util::{post_inc, ResultExt, TryFutureExt};
use uuid::Uuid;
@@ -107,8 +108,7 @@ impl ContextOperation {
message.status.context("invalid status")?,
),
timestamp: id.0,
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
},
version: language::proto::deserialize_version(&insert.version),
})
@@ -123,8 +123,7 @@ impl ContextOperation {
timestamp: language::proto::deserialize_timestamp(
update.timestamp.context("invalid timestamp")?,
),
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
},
version: language::proto::deserialize_version(&update.version),
}),
@@ -313,13 +312,43 @@ pub struct MessageAnchor {
pub start: language::Anchor,
}
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum CacheStatus {
+ Pending,
+ Cached,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct MessageCacheMetadata {
+ pub is_anchor: bool,
+ pub is_final_anchor: bool,
+ pub status: CacheStatus,
+ pub cached_at: clock::Global,
+}
+
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct MessageMetadata {
pub role: Role,
pub status: MessageStatus,
timestamp: clock::Lamport,
- should_cache: bool,
- is_cache_anchor: bool,
+ #[serde(skip)]
+ pub cache: Option<MessageCacheMetadata>,
+}
+
+impl MessageMetadata {
+ pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range<usize>) -> bool {
+ let result = match &self.cache {
+ Some(MessageCacheMetadata { cached_at, .. }) => !buffer.has_edits_since_in_range(
+ &cached_at,
+ Range {
+ start: buffer.anchor_at(range.start, Bias::Right),
+ end: buffer.anchor_at(range.end, Bias::Left),
+ },
+ ),
+ _ => false,
+ };
+ result
+ }
}
#[derive(Clone, Debug)]
@@ -345,7 +374,7 @@ pub struct Message {
pub anchor: language::Anchor,
pub role: Role,
pub status: MessageStatus,
- pub cache: bool,
+ pub cache: Option<MessageCacheMetadata>,
}
impl Message {
@@ -381,7 +410,7 @@ impl Message {
Some(LanguageModelRequestMessage {
role: self.role,
content,
- cache: self.cache,
+ cache: self.cache.as_ref().map_or(false, |cache| cache.is_anchor),
})
}
@@ -544,8 +573,7 @@ impl Context {
role: Role::User,
status: MessageStatus::Done,
timestamp: first_message_id.0,
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
},
);
this.message_anchors.push(message);
@@ -979,7 +1007,7 @@ impl Context {
});
}
- pub fn mark_longest_messages_for_cache(
+ pub fn mark_cache_anchors(
&mut self,
cache_configuration: &Option<LanguageModelCacheConfiguration>,
speculative: bool,
@@ -994,66 +1022,104 @@ impl Context {
min_total_token: 0,
});
- let messages: Vec<Message> = self
- .messages_from_anchors(
- self.message_anchors.iter().take(if speculative {
- self.message_anchors.len().saturating_sub(1)
- } else {
- self.message_anchors.len()
- }),
- cx,
- )
- .filter(|message| message.offset_range.len() >= 5_000)
- .collect();
+ let messages: Vec<Message> = self.messages(cx).collect();
let mut sorted_messages = messages.clone();
- sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
- if cache_configuration.max_cache_anchors == 0 && cache_configuration.should_speculate {
- // Some models support caching, but don't support anchors. In that case we want to
- // mark the largest message as needing to be cached, but we will not mark it as an
- // anchor.
- sorted_messages.truncate(1);
- } else {
- // Save 1 anchor for the inline assistant.
- sorted_messages.truncate(max(cache_configuration.max_cache_anchors, 1) - 1);
+ if speculative {
+ // Avoid caching the last message if this is a speculative cache fetch as
+ // it's likely to change.
+ sorted_messages.pop();
}
+ sorted_messages.retain(|m| m.role == Role::User);
+ sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
- let longest_message_ids: HashSet<MessageId> = sorted_messages
+ let cache_anchors = if self.token_count.unwrap_or(0) < cache_configuration.min_total_token {
+ // If we haven't hit the minimum threshold to enable caching, don't cache anything.
+ 0
+ } else {
+ // Save 1 anchor for the inline assistant to use.
+ max(cache_configuration.max_cache_anchors, 1) - 1
+ };
+ sorted_messages.truncate(cache_anchors);
+
+ let anchors: HashSet<MessageId> = sorted_messages
.into_iter()
.map(|message| message.id)
.collect();
- let cache_deltas: HashSet<MessageId> = self
- .messages_metadata
+ let buffer = self.buffer.read(cx).snapshot();
+ let invalidated_caches: HashSet<MessageId> = messages
.iter()
- .filter_map(|(id, metadata)| {
- let should_cache = longest_message_ids.contains(id);
- let should_be_anchor = should_cache && cache_configuration.max_cache_anchors > 0;
- if metadata.should_cache != should_cache
- || metadata.is_cache_anchor != should_be_anchor
- {
- Some(*id)
- } else {
- None
- }
+ .scan(false, |encountered_invalid, message| {
+ let message_id = message.id;
+ let is_invalid = self
+ .messages_metadata
+ .get(&message_id)
+ .map_or(true, |metadata| {
+ !metadata.is_cache_valid(&buffer, &message.offset_range)
+ || *encountered_invalid
+ });
+ *encountered_invalid |= is_invalid;
+ Some(if is_invalid { Some(message_id) } else { None })
})
+ .flatten()
.collect();
- let mut newly_cached_item = false;
- for id in cache_deltas {
- newly_cached_item = newly_cached_item || longest_message_ids.contains(&id);
- self.update_metadata(id, cx, |metadata| {
- metadata.should_cache = longest_message_ids.contains(&id);
- metadata.is_cache_anchor =
- metadata.should_cache && (cache_configuration.max_cache_anchors > 0);
+ let last_anchor = messages.iter().rev().find_map(|message| {
+ if anchors.contains(&message.id) {
+ Some(message.id)
+ } else {
+ None
+ }
+ });
+
+ let mut new_anchor_needs_caching = false;
+ let current_version = &buffer.version;
+ // If we have no anchors, mark all messages as not being cached.
+ let mut hit_last_anchor = last_anchor.is_none();
+
+ for message in messages.iter() {
+ if hit_last_anchor {
+ self.update_metadata(message.id, cx, |metadata| metadata.cache = None);
+ continue;
+ }
+
+ if let Some(last_anchor) = last_anchor {
+ if message.id == last_anchor {
+ hit_last_anchor = true;
+ }
+ }
+
+ new_anchor_needs_caching = new_anchor_needs_caching
+ || (invalidated_caches.contains(&message.id) && anchors.contains(&message.id));
+
+ self.update_metadata(message.id, cx, |metadata| {
+ let cache_status = if invalidated_caches.contains(&message.id) {
+ CacheStatus::Pending
+ } else {
+ metadata
+ .cache
+ .as_ref()
+ .map_or(CacheStatus::Pending, |cm| cm.status.clone())
+ };
+ metadata.cache = Some(MessageCacheMetadata {
+ is_anchor: anchors.contains(&message.id),
+ is_final_anchor: hit_last_anchor,
+ status: cache_status,
+ cached_at: current_version.clone(),
+ });
});
}
- newly_cached_item
+ new_anchor_needs_caching
}
fn start_cache_warming(&mut self, model: &Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
let cache_configuration = model.cache_configuration();
- if !self.mark_longest_messages_for_cache(&cache_configuration, true, cx) {
+
+ if !self.mark_cache_anchors(&cache_configuration, true, cx) {
+ return;
+ }
+ if !self.pending_completions.is_empty() {
return;
}
if let Some(cache_configuration) = cache_configuration {
@@ -1076,7 +1142,7 @@ impl Context {
};
let model = Arc::clone(model);
- self.pending_cache_warming_task = cx.spawn(|_, cx| {
+ self.pending_cache_warming_task = cx.spawn(|this, mut cx| {
async move {
match model.stream_completion(request, &cx).await {
Ok(mut stream) => {
@@ -1087,13 +1153,41 @@ impl Context {
log::warn!("Cache warming failed: {}", e);
}
};
-
+ this.update(&mut cx, |this, cx| {
+ this.update_cache_status_for_completion(cx);
+ })
+ .ok();
anyhow::Ok(())
}
.log_err()
});
}
+ pub fn update_cache_status_for_completion(&mut self, cx: &mut ModelContext<Self>) {
+ let cached_message_ids: Vec<MessageId> = self
+ .messages_metadata
+ .iter()
+ .filter_map(|(message_id, metadata)| {
+ metadata.cache.as_ref().and_then(|cache| {
+ if cache.status == CacheStatus::Pending {
+ Some(*message_id)
+ } else {
+ None
+ }
+ })
+ })
+ .collect();
+
+ for message_id in cached_message_ids {
+ self.update_metadata(message_id, cx, |metadata| {
+ if let Some(cache) = &mut metadata.cache {
+ cache.status = CacheStatus::Cached;
+ }
+ });
+ }
+ cx.notify();
+ }
+
pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
let buffer = self.buffer.read(cx);
let mut row_ranges = self
@@ -1531,7 +1625,7 @@ impl Context {
return None;
}
// Compute which messages to cache, including the last one.
- self.mark_longest_messages_for_cache(&model.cache_configuration(), false, cx);
+ self.mark_cache_anchors(&model.cache_configuration(), false, cx);
let request = self.to_completion_request(cx);
let assistant_message = self
@@ -1596,6 +1690,7 @@ impl Context {
this.pending_completions
.retain(|completion| completion.id != pending_completion_id);
this.summarize(false, cx);
+ this.update_cache_status_for_completion(cx);
})?;
anyhow::Ok(())
@@ -1746,8 +1841,7 @@ impl Context {
role,
status,
timestamp: anchor.id.0,
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
};
self.insert_message(anchor.clone(), metadata.clone(), cx);
self.push_op(
@@ -1864,8 +1958,7 @@ impl Context {
role,
status: MessageStatus::Done,
timestamp: suffix.id.0,
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
};
self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
self.push_op(
@@ -1915,8 +2008,7 @@ impl Context {
role,
status: MessageStatus::Done,
timestamp: selection.id.0,
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
};
self.insert_message(selection.clone(), selection_metadata.clone(), cx);
self.push_op(
@@ -2150,7 +2242,7 @@ impl Context {
anchor: message_anchor.start,
role: metadata.role,
status: metadata.status.clone(),
- cache: metadata.is_cache_anchor,
+ cache: metadata.cache.clone(),
image_offsets,
});
}
@@ -2397,8 +2489,7 @@ impl SavedContext {
role: message.metadata.role,
status: message.metadata.status,
timestamp: message.metadata.timestamp,
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
},
version: version.clone(),
});
@@ -2415,8 +2506,7 @@ impl SavedContext {
role: metadata.role,
status: metadata.status,
timestamp,
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
},
version: version.clone(),
});
@@ -2511,8 +2601,7 @@ impl SavedContextV0_3_0 {
role: metadata.role,
status: metadata.status.clone(),
timestamp,
- should_cache: false,
- is_cache_anchor: false,
+ cache: None,
},
image_offsets: Vec::new(),
})
diff --git a/crates/assistant/src/context/context_tests.rs b/crates/assistant/src/context/context_tests.rs
index 4eb7b75a64..35764822d7 100644
--- a/crates/assistant/src/context/context_tests.rs
+++ b/crates/assistant/src/context/context_tests.rs
@@ -1,6 +1,6 @@
use crate::{
- assistant_panel, prompt_library, slash_command::file_command, workflow::tool, Context,
- ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
+ assistant_panel, prompt_library, slash_command::file_command, workflow::tool, CacheStatus,
+ Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
};
use anyhow::Result;
use assistant_slash_command::{
@@ -12,7 +12,7 @@ use fs::{FakeFs, Fs as _};
use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
use indoc::indoc;
use language::{Buffer, LanguageRegistry, LspAdapterDelegate};
-use language_model::{LanguageModelRegistry, Role};
+use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
use parking_lot::Mutex;
use project::Project;
use rand::prelude::*;
@@ -33,6 +33,8 @@ use unindent::Unindent;
use util::{test::marked_text_ranges, RandomCharIter};
use workspace::Workspace;
+use super::MessageCacheMetadata;
+
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
@@ -1002,6 +1004,159 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
});
}
+#[gpui::test]
+fn test_mark_cache_anchors(cx: &mut AppContext) {
+ let settings_store = SettingsStore::test(cx);
+ LanguageModelRegistry::test(cx);
+ cx.set_global(settings_store);
+ assistant_panel::init(cx);
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let context =
+ cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
+ let buffer = context.read(cx).buffer.clone();
+
+ // Create a test cache configuration
+ let cache_configuration = &Some(LanguageModelCacheConfiguration {
+ max_cache_anchors: 3,
+ should_speculate: true,
+ min_total_token: 10,
+ });
+
+ let message_1 = context.read(cx).message_anchors[0].clone();
+
+ context.update(cx, |context, cx| {
+ context.mark_cache_anchors(cache_configuration, false, cx)
+ });
+
+ assert_eq!(
+ messages_cache(&context, cx)
+ .iter()
+ .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
+ .count(),
+ 0,
+ "Empty messages should not have any cache anchors."
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
+ let message_2 = context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
+ })
+ .unwrap();
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
+ let message_3 = context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
+ })
+ .unwrap();
+ buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
+
+ context.update(cx, |context, cx| {
+ context.mark_cache_anchors(cache_configuration, false, cx)
+ });
+ assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
+ assert_eq!(
+ messages_cache(&context, cx)
+ .iter()
+ .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
+ .count(),
+ 0,
+ "Messages should not be marked for cache before going over the token minimum."
+ );
+ context.update(cx, |context, _| {
+ context.token_count = Some(20);
+ });
+
+ context.update(cx, |context, cx| {
+ context.mark_cache_anchors(cache_configuration, true, cx)
+ });
+ assert_eq!(
+ messages_cache(&context, cx)
+ .iter()
+ .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
+ .collect::<Vec<bool>>(),
+ vec![true, true, false],
+ "Last message should not be an anchor on speculative request."
+ );
+
+ context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
+ })
+ .unwrap();
+
+ context.update(cx, |context, cx| {
+ context.mark_cache_anchors(cache_configuration, false, cx)
+ });
+ assert_eq!(
+ messages_cache(&context, cx)
+ .iter()
+ .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
+ .collect::<Vec<bool>>(),
+ vec![false, true, true, false],
+ "Most recent message should also be cached if not a speculative request."
+ );
+ context.update(cx, |context, cx| {
+ context.update_cache_status_for_completion(cx)
+ });
+ assert_eq!(
+ messages_cache(&context, cx)
+ .iter()
+ .map(|(_, cache)| cache
+ .as_ref()
+ .map_or(None, |cache| Some(cache.status.clone())))
+ .collect::<Vec<Option<CacheStatus>>>(),
+ vec![
+ Some(CacheStatus::Cached),
+ Some(CacheStatus::Cached),
+ Some(CacheStatus::Cached),
+ None
+ ],
+ "All user messages prior to anchor should be marked as cached."
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
+ context.update(cx, |context, cx| {
+ context.mark_cache_anchors(cache_configuration, false, cx)
+ });
+ assert_eq!(
+ messages_cache(&context, cx)
+ .iter()
+ .map(|(_, cache)| cache
+ .as_ref()
+ .map_or(None, |cache| Some(cache.status.clone())))
+ .collect::<Vec<Option<CacheStatus>>>(),
+ vec![
+ Some(CacheStatus::Cached),
+ Some(CacheStatus::Cached),
+ Some(CacheStatus::Pending),
+ None
+ ],
+ "Modifying a message should invalidate it's cache but leave previous messages."
+ );
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
+ context.update(cx, |context, cx| {
+ context.mark_cache_anchors(cache_configuration, false, cx)
+ });
+ assert_eq!(
+ messages_cache(&context, cx)
+ .iter()
+ .map(|(_, cache)| cache
+ .as_ref()
+ .map_or(None, |cache| Some(cache.status.clone())))
+ .collect::<Vec<Option<CacheStatus>>>(),
+ vec![
+ Some(CacheStatus::Pending),
+ Some(CacheStatus::Pending),
+ Some(CacheStatus::Pending),
+ None
+ ],
+ "Modifying a message should invalidate all future messages."
+ );
+}
+
fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
context
.read(cx)
@@ -1010,6 +1165,17 @@ fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role,
.collect()
}
+fn messages_cache(
+ context: &Model<Context>,
+ cx: &AppContext,
+) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
+ context
+ .read(cx)
+ .messages(cx)
+ .map(|message| (message.id, message.cache.clone()))
+ .collect()
+}
+
#[derive(Clone)]
struct FakeSlashCommand(String);
diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs
index f399820016..74c565988d 100644
--- a/crates/ui/src/components/icon.rs
+++ b/crates/ui/src/components/icon.rs
@@ -154,6 +154,7 @@ pub enum IconName {
Copy,
CountdownTimer,
Dash,
+ DatabaseZap,
Delete,
Disconnected,
Download,
@@ -322,6 +323,7 @@ impl IconName {
IconName::Copy => "icons/copy.svg",
IconName::CountdownTimer => "icons/countdown_timer.svg",
IconName::Dash => "icons/dash.svg",
+ IconName::DatabaseZap => "icons/database_zap.svg",
IconName::Delete => "icons/delete.svg",
IconName::Disconnected => "icons/disconnected.svg",
IconName::Download => "icons/download.svg",