diff --git a/assets/icons/database_zap.svg b/assets/icons/database_zap.svg new file mode 100644 index 0000000000..06241b35f4 --- /dev/null +++ b/assets/icons/database_zap.svg @@ -0,0 +1 @@ + diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index d74b4870a0..e227fa03a3 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -11,11 +11,11 @@ use crate::{ }, slash_command_picker, terminal_inline_assistant::TerminalInlineAssistant, - Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole, - DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId, InlineAssistant, - InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, - QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split, ToggleFocus, - ToggleModelSelector, WorkflowStepResolution, WorkflowStepView, + Assist, CacheStatus, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, + CycleMessageRole, DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId, + InlineAssistant, InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, + PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split, + ToggleFocus, ToggleModelSelector, WorkflowStepResolution, WorkflowStepView, }; use crate::{ContextStoreEvent, ModelPickerDelegate}; use anyhow::{anyhow, Result}; @@ -3137,6 +3137,36 @@ impl ContextEditor { .relative() .gap_1() .child(sender) + .children(match &message.cache { + Some(cache) if cache.is_final_anchor => match cache.status { + CacheStatus::Cached => Some( + div() + .id("cached") + .child( + Icon::new(IconName::DatabaseZap) + .size(IconSize::XSmall) + .color(Color::Hint), + ) + .tooltip(|cx| { + Tooltip::with_meta( + "Context cached", + None, + "Large messages cached to optimize performance", + cx, + ) + }).into_any_element() + ), + CacheStatus::Pending => Some( + div() + 
.child( + Icon::new(IconName::Ellipsis) + .size(IconSize::XSmall) + .color(Color::Hint), + ).into_any_element() + ), + }, + _ => None, + }) .children(match &message.status { MessageStatus::Error(error) => Some( Button::new("show-error", "Error") diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index f68cdb53eb..41f894aeb2 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -40,6 +40,7 @@ use std::{ time::{Duration, Instant}, }; use telemetry_events::AssistantKind; +use text::BufferSnapshot; use util::{post_inc, ResultExt, TryFutureExt}; use uuid::Uuid; @@ -107,8 +108,7 @@ impl ContextOperation { message.status.context("invalid status")?, ), timestamp: id.0, - should_cache: false, - is_cache_anchor: false, + cache: None, }, version: language::proto::deserialize_version(&insert.version), }) @@ -123,8 +123,7 @@ impl ContextOperation { timestamp: language::proto::deserialize_timestamp( update.timestamp.context("invalid timestamp")?, ), - should_cache: false, - is_cache_anchor: false, + cache: None, }, version: language::proto::deserialize_version(&update.version), }), @@ -313,13 +312,43 @@ pub struct MessageAnchor { pub start: language::Anchor, } +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum CacheStatus { + Pending, + Cached, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MessageCacheMetadata { + pub is_anchor: bool, + pub is_final_anchor: bool, + pub status: CacheStatus, + pub cached_at: clock::Global, +} + #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct MessageMetadata { pub role: Role, pub status: MessageStatus, timestamp: clock::Lamport, - should_cache: bool, - is_cache_anchor: bool, + #[serde(skip)] + pub cache: Option, +} + +impl MessageMetadata { + pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range) -> bool { + let result = match &self.cache { + Some(MessageCacheMetadata { cached_at, .. 
}) => !buffer.has_edits_since_in_range( + &cached_at, + Range { + start: buffer.anchor_at(range.start, Bias::Right), + end: buffer.anchor_at(range.end, Bias::Left), + }, + ), + _ => false, + }; + result + } } #[derive(Clone, Debug)] @@ -345,7 +374,7 @@ pub struct Message { pub anchor: language::Anchor, pub role: Role, pub status: MessageStatus, - pub cache: bool, + pub cache: Option, } impl Message { @@ -381,7 +410,7 @@ impl Message { Some(LanguageModelRequestMessage { role: self.role, content, - cache: self.cache, + cache: self.cache.as_ref().map_or(false, |cache| cache.is_anchor), }) } @@ -544,8 +573,7 @@ impl Context { role: Role::User, status: MessageStatus::Done, timestamp: first_message_id.0, - should_cache: false, - is_cache_anchor: false, + cache: None, }, ); this.message_anchors.push(message); @@ -979,7 +1007,7 @@ impl Context { }); } - pub fn mark_longest_messages_for_cache( + pub fn mark_cache_anchors( &mut self, cache_configuration: &Option, speculative: bool, @@ -994,66 +1022,104 @@ impl Context { min_total_token: 0, }); - let messages: Vec = self - .messages_from_anchors( - self.message_anchors.iter().take(if speculative { - self.message_anchors.len().saturating_sub(1) - } else { - self.message_anchors.len() - }), - cx, - ) - .filter(|message| message.offset_range.len() >= 5_000) - .collect(); + let messages: Vec = self.messages(cx).collect(); let mut sorted_messages = messages.clone(); - sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len())); - if cache_configuration.max_cache_anchors == 0 && cache_configuration.should_speculate { - // Some models support caching, but don't support anchors. In that case we want to - // mark the largest message as needing to be cached, but we will not mark it as an - // anchor. - sorted_messages.truncate(1); - } else { - // Save 1 anchor for the inline assistant. 
- sorted_messages.truncate(max(cache_configuration.max_cache_anchors, 1) - 1); + if speculative { + // Avoid caching the last message if this is a speculative cache fetch as + // it's likely to change. + sorted_messages.pop(); } + sorted_messages.retain(|m| m.role == Role::User); + sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len())); - let longest_message_ids: HashSet = sorted_messages + let cache_anchors = if self.token_count.unwrap_or(0) < cache_configuration.min_total_token { + // If we haven't hit the minimum threshold to enable caching, don't cache anything. + 0 + } else { + // Save 1 anchor for the inline assistant to use. + max(cache_configuration.max_cache_anchors, 1) - 1 + }; + sorted_messages.truncate(cache_anchors); + + let anchors: HashSet = sorted_messages .into_iter() .map(|message| message.id) .collect(); - let cache_deltas: HashSet = self - .messages_metadata + let buffer = self.buffer.read(cx).snapshot(); + let invalidated_caches: HashSet = messages .iter() - .filter_map(|(id, metadata)| { - let should_cache = longest_message_ids.contains(id); - let should_be_anchor = should_cache && cache_configuration.max_cache_anchors > 0; - if metadata.should_cache != should_cache - || metadata.is_cache_anchor != should_be_anchor - { - Some(*id) - } else { - None - } + .scan(false, |encountered_invalid, message| { + let message_id = message.id; + let is_invalid = self + .messages_metadata + .get(&message_id) + .map_or(true, |metadata| { + !metadata.is_cache_valid(&buffer, &message.offset_range) + || *encountered_invalid + }); + *encountered_invalid |= is_invalid; + Some(if is_invalid { Some(message_id) } else { None }) }) + .flatten() + .collect(); - let mut newly_cached_item = false; - for id in cache_deltas { - newly_cached_item = newly_cached_item || longest_message_ids.contains(&id); - self.update_metadata(id, cx, |metadata| { - metadata.should_cache = longest_message_ids.contains(&id); - metadata.is_cache_anchor = - 
metadata.should_cache && (cache_configuration.max_cache_anchors > 0); + let last_anchor = messages.iter().rev().find_map(|message| { + if anchors.contains(&message.id) { + Some(message.id) + } else { + None + } + }); + + let mut new_anchor_needs_caching = false; + let current_version = &buffer.version; + // If we have no anchors, mark all messages as not being cached. + let mut hit_last_anchor = last_anchor.is_none(); + + for message in messages.iter() { + if hit_last_anchor { + self.update_metadata(message.id, cx, |metadata| metadata.cache = None); + continue; + } + + if let Some(last_anchor) = last_anchor { + if message.id == last_anchor { + hit_last_anchor = true; + } + } + + new_anchor_needs_caching = new_anchor_needs_caching + || (invalidated_caches.contains(&message.id) && anchors.contains(&message.id)); + + self.update_metadata(message.id, cx, |metadata| { + let cache_status = if invalidated_caches.contains(&message.id) { + CacheStatus::Pending + } else { + metadata + .cache + .as_ref() + .map_or(CacheStatus::Pending, |cm| cm.status.clone()) + }; + metadata.cache = Some(MessageCacheMetadata { + is_anchor: anchors.contains(&message.id), + is_final_anchor: hit_last_anchor, + status: cache_status, + cached_at: current_version.clone(), + }); }); } - newly_cached_item + new_anchor_needs_caching } fn start_cache_warming(&mut self, model: &Arc, cx: &mut ModelContext) { let cache_configuration = model.cache_configuration(); - if !self.mark_longest_messages_for_cache(&cache_configuration, true, cx) { + + if !self.mark_cache_anchors(&cache_configuration, true, cx) { + return; + } + if !self.pending_completions.is_empty() { return; } if let Some(cache_configuration) = cache_configuration { @@ -1076,7 +1142,7 @@ impl Context { }; let model = Arc::clone(model); - self.pending_cache_warming_task = cx.spawn(|_, cx| { + self.pending_cache_warming_task = cx.spawn(|this, mut cx| { async move { match model.stream_completion(request, &cx).await { Ok(mut stream) => { @@ -1087,13 
+1153,41 @@ impl Context { log::warn!("Cache warming failed: {}", e); } }; - + this.update(&mut cx, |this, cx| { + this.update_cache_status_for_completion(cx); + }) + .ok(); anyhow::Ok(()) } .log_err() }); } + pub fn update_cache_status_for_completion(&mut self, cx: &mut ModelContext) { + let cached_message_ids: Vec = self + .messages_metadata + .iter() + .filter_map(|(message_id, metadata)| { + metadata.cache.as_ref().and_then(|cache| { + if cache.status == CacheStatus::Pending { + Some(*message_id) + } else { + None + } + }) + }) + .collect(); + + for message_id in cached_message_ids { + self.update_metadata(message_id, cx, |metadata| { + if let Some(cache) = &mut metadata.cache { + cache.status = CacheStatus::Cached; + } + }); + } + cx.notify(); + } + pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext) { let buffer = self.buffer.read(cx); let mut row_ranges = self @@ -1531,7 +1625,7 @@ impl Context { return None; } // Compute which messages to cache, including the last one. 
- self.mark_longest_messages_for_cache(&model.cache_configuration(), false, cx); + self.mark_cache_anchors(&model.cache_configuration(), false, cx); let request = self.to_completion_request(cx); let assistant_message = self @@ -1596,6 +1690,7 @@ impl Context { this.pending_completions .retain(|completion| completion.id != pending_completion_id); this.summarize(false, cx); + this.update_cache_status_for_completion(cx); })?; anyhow::Ok(()) @@ -1746,8 +1841,7 @@ impl Context { role, status, timestamp: anchor.id.0, - should_cache: false, - is_cache_anchor: false, + cache: None, }; self.insert_message(anchor.clone(), metadata.clone(), cx); self.push_op( @@ -1864,8 +1958,7 @@ impl Context { role, status: MessageStatus::Done, timestamp: suffix.id.0, - should_cache: false, - is_cache_anchor: false, + cache: None, }; self.insert_message(suffix.clone(), suffix_metadata.clone(), cx); self.push_op( @@ -1915,8 +2008,7 @@ impl Context { role, status: MessageStatus::Done, timestamp: selection.id.0, - should_cache: false, - is_cache_anchor: false, + cache: None, }; self.insert_message(selection.clone(), selection_metadata.clone(), cx); self.push_op( @@ -2150,7 +2242,7 @@ impl Context { anchor: message_anchor.start, role: metadata.role, status: metadata.status.clone(), - cache: metadata.is_cache_anchor, + cache: metadata.cache.clone(), image_offsets, }); } @@ -2397,8 +2489,7 @@ impl SavedContext { role: message.metadata.role, status: message.metadata.status, timestamp: message.metadata.timestamp, - should_cache: false, - is_cache_anchor: false, + cache: None, }, version: version.clone(), }); @@ -2415,8 +2506,7 @@ impl SavedContext { role: metadata.role, status: metadata.status, timestamp, - should_cache: false, - is_cache_anchor: false, + cache: None, }, version: version.clone(), }); @@ -2511,8 +2601,7 @@ impl SavedContextV0_3_0 { role: metadata.role, status: metadata.status.clone(), timestamp, - should_cache: false, - is_cache_anchor: false, + cache: None, }, image_offsets: 
Vec::new(), }) diff --git a/crates/assistant/src/context/context_tests.rs b/crates/assistant/src/context/context_tests.rs index 4eb7b75a64..35764822d7 100644 --- a/crates/assistant/src/context/context_tests.rs +++ b/crates/assistant/src/context/context_tests.rs @@ -1,6 +1,6 @@ use crate::{ - assistant_panel, prompt_library, slash_command::file_command, workflow::tool, Context, - ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder, + assistant_panel, prompt_library, slash_command::file_command, workflow::tool, CacheStatus, + Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder, }; use anyhow::Result; use assistant_slash_command::{ @@ -12,7 +12,7 @@ use fs::{FakeFs, Fs as _}; use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView}; use indoc::indoc; use language::{Buffer, LanguageRegistry, LspAdapterDelegate}; -use language_model::{LanguageModelRegistry, Role}; +use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role}; use parking_lot::Mutex; use project::Project; use rand::prelude::*; @@ -33,6 +33,8 @@ use unindent::Unindent; use util::{test::marked_text_ranges, RandomCharIter}; use workspace::Workspace; +use super::MessageCacheMetadata; + #[gpui::test] fn test_inserting_and_removing_messages(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); @@ -1002,6 +1004,159 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std }); } +#[gpui::test] +fn test_mark_cache_anchors(cx: &mut AppContext) { + let settings_store = SettingsStore::test(cx); + LanguageModelRegistry::test(cx); + cx.set_global(settings_store); + assistant_panel::init(cx); + let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let buffer = 
context.read(cx).buffer.clone(); + + // Create a test cache configuration + let cache_configuration = &Some(LanguageModelCacheConfiguration { + max_cache_anchors: 3, + should_speculate: true, + min_total_token: 10, + }); + + let message_1 = context.read(cx).message_anchors[0].clone(); + + context.update(cx, |context, cx| { + context.mark_cache_anchors(cache_configuration, false, cx) + }); + + assert_eq!( + messages_cache(&context, cx) + .iter() + .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) + .count(), + 0, + "Empty messages should not have any cache anchors." + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); + let message_2 = context + .update(cx, |context, cx| { + context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx) + }) + .unwrap(); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx)); + let message_3 = context + .update(cx, |context, cx| { + context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx)); + + context.update(cx, |context, cx| { + context.mark_cache_anchors(cache_configuration, false, cx) + }); + assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc"); + assert_eq!( + messages_cache(&context, cx) + .iter() + .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) + .count(), + 0, + "Messages should not be marked for cache before going over the token minimum." + ); + context.update(cx, |context, _| { + context.token_count = Some(20); + }); + + context.update(cx, |context, cx| { + context.mark_cache_anchors(cache_configuration, true, cx) + }); + assert_eq!( + messages_cache(&context, cx) + .iter() + .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) + .collect::>(), + vec![true, true, false], + "Last message should not be an anchor on speculative request." 
+ ); + + context + .update(cx, |context, cx| { + context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx) + }) + .unwrap(); + + context.update(cx, |context, cx| { + context.mark_cache_anchors(cache_configuration, false, cx) + }); + assert_eq!( + messages_cache(&context, cx) + .iter() + .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) + .collect::>(), + vec![false, true, true, false], + "Most recent message should also be cached if not a speculative request." + ); + context.update(cx, |context, cx| { + context.update_cache_status_for_completion(cx) + }); + assert_eq!( + messages_cache(&context, cx) + .iter() + .map(|(_, cache)| cache + .as_ref() + .map_or(None, |cache| Some(cache.status.clone()))) + .collect::>>(), + vec![ + Some(CacheStatus::Cached), + Some(CacheStatus::Cached), + Some(CacheStatus::Cached), + None + ], + "All user messages prior to anchor should be marked as cached." + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx)); + context.update(cx, |context, cx| { + context.mark_cache_anchors(cache_configuration, false, cx) + }); + assert_eq!( + messages_cache(&context, cx) + .iter() + .map(|(_, cache)| cache + .as_ref() + .map_or(None, |cache| Some(cache.status.clone()))) + .collect::>>(), + vec![ + Some(CacheStatus::Cached), + Some(CacheStatus::Cached), + Some(CacheStatus::Pending), + None + ], + "Modifying a message should invalidate its cache but leave previous messages." 
+ ); + buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx)); + context.update(cx, |context, cx| { + context.mark_cache_anchors(cache_configuration, false, cx) + }); + assert_eq!( + messages_cache(&context, cx) + .iter() + .map(|(_, cache)| cache + .as_ref() + .map_or(None, |cache| Some(cache.status.clone()))) + .collect::>>(), + vec![ + Some(CacheStatus::Pending), + Some(CacheStatus::Pending), + Some(CacheStatus::Pending), + None + ], + "Modifying a message should invalidate all future messages." + ); +} + fn messages(context: &Model, cx: &AppContext) -> Vec<(MessageId, Role, Range)> { context .read(cx) @@ -1010,6 +1165,17 @@ fn messages(context: &Model, cx: &AppContext) -> Vec<(MessageId, Role, .collect() } +fn messages_cache( + context: &Model, + cx: &AppContext, +) -> Vec<(MessageId, Option)> { + context + .read(cx) + .messages(cx) + .map(|message| (message.id, message.cache.clone())) + .collect() +} + #[derive(Clone)] struct FakeSlashCommand(String); diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index f399820016..74c565988d 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -154,6 +154,7 @@ pub enum IconName { Copy, CountdownTimer, Dash, + DatabaseZap, Delete, Disconnected, Download, @@ -322,6 +323,7 @@ impl IconName { IconName::Copy => "icons/copy.svg", IconName::CountdownTimer => "icons/countdown_timer.svg", IconName::Dash => "icons/dash.svg", + IconName::DatabaseZap => "icons/database_zap.svg", IconName::Delete => "icons/delete.svg", IconName::Disconnected => "icons/disconnected.svg", IconName::Download => "icons/download.svg",