diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 09e308f602..37360589ca 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,3 +1,4 @@ +use crate::ContextStoreEvent; use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, humanize_token_count, @@ -389,6 +390,7 @@ impl AssistantPanel { cx.subscribe(&pane, Self::handle_pane_event), cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event), cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event), + cx.subscribe(&context_store, Self::handle_context_store_event), cx.observe( &LanguageModelCompletionProvider::global(cx), |this, _, cx| { @@ -507,6 +509,46 @@ impl AssistantPanel { } } + fn handle_context_store_event( + &mut self, + _context_store: Model, + event: &ContextStoreEvent, + cx: &mut ViewContext, + ) { + let ContextStoreEvent::ContextCreated(context_id) = event; + let Some(context) = self + .context_store + .read(cx) + .loaded_context_for_id(&context_id, cx) + else { + log::error!("no context found with ID: {}", context_id.to_proto()); + return; + }; + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| { + make_lsp_adapter_delegate(workspace.project(), cx).log_err() + }); + + let assistant_panel = cx.view().downgrade(); + let editor = cx.new_view(|cx| { + let mut editor = ContextEditor::for_context( + context, + self.fs.clone(), + workspace.clone(), + self.project.clone(), + lsp_adapter_delegate, + assistant_panel, + cx, + ); + editor.insert_default_prompt(cx); + editor + }); + + self.show_context(editor.clone(), cx); + } + fn completion_provider_changed(&mut self, cx: &mut ViewContext) { if let Some(editor) = self.active_context_editor(cx) { editor.update(cx, |active_context, cx| { @@ -681,29 +723,75 @@ impl AssistantPanel { } fn new_context(&mut self, cx: &mut ViewContext) -> Option> { - let context = self.context_store.update(cx, |store, cx| store.create(cx)); - let workspace = self.workspace.upgrade()?; - let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| { - make_lsp_adapter_delegate(workspace.project(), cx).log_err() - }); + if self.project.read(cx).is_remote() { + let task = self + .context_store + .update(cx, |store, cx| store.create_remote_context(cx)); - let assistant_panel = cx.view().downgrade(); - let editor = cx.new_view(|cx| { - let mut editor = ContextEditor::for_context( - context, - self.fs.clone(), - workspace.clone(), - self.project.clone(), - lsp_adapter_delegate, - assistant_panel, - cx, - ); - editor.insert_default_prompt(cx); - editor - }); + cx.spawn(|this, mut cx| async move { + let context = task.await?; - self.show_context(editor.clone(), cx); - Some(editor) + this.update(&mut cx, |this, cx| { + let Some(workspace) = this.workspace.upgrade() else { + return Ok(()); + }; + let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| { + make_lsp_adapter_delegate(workspace.project(), cx).log_err() + }); + + let fs = this.fs.clone(); + let project = this.project.clone(); + let weak_assistant_panel = cx.view().downgrade(); + + let editor = cx.new_view(|cx| { + let mut editor = ContextEditor::for_context( + context, + fs, + workspace.clone(), + project, + lsp_adapter_delegate, + weak_assistant_panel, + cx, + ); + editor.insert_default_prompt(cx); + editor + }); + + this.show_context(editor, cx); + + anyhow::Ok(()) + })??; + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + + None + } else { + let context = self.context_store.update(cx, |store, cx| store.create(cx)); + let workspace = self.workspace.upgrade()?; + let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| { + make_lsp_adapter_delegate(workspace.project(), cx).log_err() + }); + + let assistant_panel = cx.view().downgrade(); + let editor = cx.new_view(|cx| { + let mut editor = ContextEditor::for_context( + context, + self.fs.clone(), + workspace.clone(), + self.project.clone(), + lsp_adapter_delegate, + assistant_panel, + cx, + ); + editor.insert_default_prompt(cx); + editor + }); + + self.show_context(editor.clone(), cx); + Some(editor) + } } fn show_context(&mut self, context_editor: View, cx: &mut ViewContext) { diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index 9cf60fc014..a17646a408 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -8,7 +8,9 @@ use clock::ReplicaId; use fs::Fs; use futures::StreamExt; use fuzzy::StringMatchCandidate; -use gpui::{AppContext, AsyncAppContext, Context as _, Model, ModelContext, Task, WeakModel}; +use gpui::{ + AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel, +}; use language::LanguageRegistry; use paths::contexts_dir; use project::Project; @@ -26,6 +28,7 @@ use util::{ResultExt, TryFutureExt}; pub fn init(client: &Arc) { client.add_model_message_handler(ContextStore::handle_advertise_contexts); client.add_model_request_handler(ContextStore::handle_open_context); + client.add_model_request_handler(ContextStore::handle_create_context); client.add_model_message_handler(ContextStore::handle_update_context); client.add_model_request_handler(ContextStore::handle_synchronize_contexts); } @@ -51,6 +54,12 @@ pub struct ContextStore { _project_subscriptions: Vec, } +pub enum ContextStoreEvent { + ContextCreated(ContextId), +} + +impl EventEmitter for ContextStore {} + enum ContextHandle { Weak(WeakModel), Strong(Model), @@ -169,6 +178,34 @@ impl ContextStore { }) } + async fn handle_create_context( + this: Model, + _: TypedEnvelope, + mut cx: AsyncAppContext, + ) -> Result { + let (context_id, operations) = this.update(&mut cx, |this, cx| { + if this.project.read(cx).is_remote() { + return Err(anyhow!("can only create contexts as the host")); + } + + let context = this.create(cx); + let context_id = context.read(cx).id().clone(); + cx.emit(ContextStoreEvent::ContextCreated(context_id.clone())); + + anyhow::Ok(( + context_id, + context + .read(cx) + .serialize_ops(&ContextVersion::default(), cx), + )) + })??; + let operations = operations.await; + Ok(proto::CreateContextResponse { + context_id: context_id.to_proto(), + context: Some(proto::Context { operations }), + }) + } + async fn handle_update_context( this: Model, envelope: TypedEnvelope, @@ -299,6 +336,60 @@ impl ContextStore { context } + pub fn create_remote_context( + &mut self, + cx: &mut ModelContext, + ) -> Task>> { + let project = self.project.read(cx); + let Some(project_id) = project.remote_id() else { + return Task::ready(Err(anyhow!("project was not remote"))); + }; + if project.is_local() { + return Task::ready(Err(anyhow!("cannot create remote contexts as the host"))); + } + + let replica_id = project.replica_id(); + let capability = project.capability(); + let language_registry = self.languages.clone(); + let telemetry = self.telemetry.clone(); + let request = self.client.request(proto::CreateContext { project_id }); + cx.spawn(|this, mut cx| async move { + let response = request.await?; + let context_id = ContextId::from_proto(response.context_id); + let context_proto = response.context.context("invalid context")?; + let context = cx.new_model(|cx| { + Context::new( + context_id.clone(), + replica_id, + capability, + language_registry, + Some(telemetry), + cx, + ) + })?; + let operations = cx + .background_executor() + .spawn(async move { + context_proto + .operations + .into_iter() + .map(|op| ContextOperation::from_proto(op)) + .collect::>>() + }) + .await?; + context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??; + this.update(&mut cx, |this, cx| { + if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) { + existing_context + } else { + this.register_context(&context, cx); + this.synchronize_contexts(cx); + context + } + }) + }) + } + pub fn open_local_context( &mut self, path: PathBuf, @@ -346,7 +437,11 @@ impl ContextStore { }) } - fn loaded_context_for_id(&self, id: &ContextId, cx: &AppContext) -> Option> { + pub(super) fn loaded_context_for_id( + &self, + id: &ContextId, + cx: &AppContext, + ) -> Option> { self.contexts.iter().find_map(|context| { let context = context.upgrade()?; if context.read(cx).id() == id { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index f536f41aca..bced91eb60 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -600,6 +600,9 @@ impl Server { .add_request_handler(user_handler( forward_mutating_project_request::, )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) .add_request_handler(user_handler( forward_mutating_project_request::, )) diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 404acb42e8..2c70bcc613 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -199,7 +199,7 @@ message Envelope { StreamCompleteWithLanguageModel stream_complete_with_language_model = 228; StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229; CountLanguageModelTokens count_language_model_tokens = 230; - CountLanguageModelTokensResponse count_language_model_tokens_response = 231; // current max + CountLanguageModelTokensResponse count_language_model_tokens_response = 231; GetCachedEmbeddings get_cached_embeddings = 189; GetCachedEmbeddingsResponse get_cached_embeddings_response = 190; ComputeEmbeddings compute_embeddings = 191; @@ -255,6 +255,8 @@ message Envelope { AdvertiseContexts advertise_contexts = 211; OpenContext open_context = 212; OpenContextResponse open_context_response = 213; + CreateContext create_context = 232; + CreateContextResponse create_context_response = 233; // current max UpdateContext update_context = 214; SynchronizeContexts synchronize_contexts = 215; SynchronizeContextsResponse synchronize_contexts_response = 216; @@ -2381,6 +2383,15 @@ message OpenContextResponse { Context context = 1; } +message CreateContext { + uint64 project_id = 1; +} + +message CreateContextResponse { + string context_id = 1; + Context context = 2; +} + message UpdateContext { uint64 project_id = 1; string context_id = 2; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 632a6f6951..451292308d 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -398,6 +398,8 @@ messages!( (AdvertiseContexts, Foreground), (OpenContext, Foreground), (OpenContextResponse, Foreground), + (CreateContext, Foreground), + (CreateContextResponse, Foreground), (UpdateContext, Foreground), (SynchronizeContexts, Foreground), (SynchronizeContextsResponse, Foreground), @@ -523,6 +525,7 @@ request_messages!( (RenameDevServer, Ack), (RestartLanguageServers, Ack), (OpenContext, OpenContextResponse), + (CreateContext, CreateContextResponse), (SynchronizeContexts, SynchronizeContextsResponse), (AddWorktree, AddWorktreeResponse), ); @@ -589,6 +592,7 @@ entity_messages!( LspExtExpandMacro, AdvertiseContexts, OpenContext, + CreateContext, UpdateContext, SynchronizeContexts, );