diff --git a/Cargo.lock b/Cargo.lock index 3e77ccc35c..957466997a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2815,6 +2815,7 @@ name = "context_servers" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "collections", "command_palette_hooks", "futures 0.3.30", diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index ad8ae9808f..868495d890 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -819,7 +819,7 @@ impl ContextStore { |context_server_manager, cx| { for server in context_server_manager.servers() { context_server_manager - .restart_server(&server.id, cx) + .restart_server(&server.id(), cx) .detach_and_log_err(cx); } }, @@ -850,7 +850,7 @@ impl ContextStore { let server = server.clone(); let server_id = server_id.clone(); |this, mut cx| async move { - let Some(protocol) = server.client.read().clone() else { + let Some(protocol) = server.client() else { return; }; @@ -889,7 +889,7 @@ impl ContextStore { tool_working_set.insert( Arc::new(tools::context_server_tool::ContextServerTool::new( context_server_manager.clone(), - server.id.clone(), + server.id(), tool, )), ) diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 997f289c9b..843c2081a7 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -20,18 +20,18 @@ use crate::slash_command::create_label_for_command; pub struct ContextServerSlashCommand { server_manager: Model, - server_id: String, + server_id: Arc, prompt: Prompt, } impl ContextServerSlashCommand { pub fn new( server_manager: Model, - server: &Arc, + server: &Arc, prompt: Prompt, ) -> Self { Self { - server_id: server.id.clone(), + server_id: server.id(), prompt, server_manager, } @@ -89,7 +89,7 @@ impl SlashCommand for ContextServerSlashCommand { if let Some(server) = self.server_manager.read(cx).get_server(&server_id) { cx.foreground_executor().spawn(async move { - let Some(protocol) = server.client.read().clone() else { + let Some(protocol) = server.client() else { return Err(anyhow!("Context server not initialized")); }; @@ -143,7 +143,7 @@ impl SlashCommand for ContextServerSlashCommand { let manager = self.server_manager.read(cx); if let Some(server) = manager.get_server(&server_id) { cx.foreground_executor().spawn(async move { - let Some(protocol) = server.client.read().clone() else { + let Some(protocol) = server.client() else { return Err(anyhow!("Context server not initialized")); }; let result = protocol.run_prompt(&prompt_name, prompt_args).await?; diff --git a/crates/assistant/src/tools/context_server_tool.rs b/crates/assistant/src/tools/context_server_tool.rs index 72bb87191f..aa742bd9eb 100644 --- a/crates/assistant/src/tools/context_server_tool.rs +++ b/crates/assistant/src/tools/context_server_tool.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use anyhow::{anyhow, bail}; use assistant_tool::Tool; use context_servers::manager::ContextServerManager; @@ -6,14 +8,14 @@ use gpui::{Model, Task}; pub struct ContextServerTool { server_manager: Model, - server_id: String, + server_id: Arc, tool: types::Tool, } impl ContextServerTool { pub fn new( server_manager: Model, - server_id: impl Into, + server_id: impl Into>, tool: types::Tool, ) -> Self { Self { @@ -55,7 +57,7 @@ impl Tool for ContextServerTool { cx.foreground_executor().spawn({ let tool_name = self.tool.name.clone(); async move { - let Some(protocol) = server.client.read().clone() else { + let Some(protocol) = server.client() else { bail!("Context server not initialized"); }; diff --git a/crates/context_servers/Cargo.toml b/crates/context_servers/Cargo.toml index 73a393375c..b3d1ccf964 100644 --- a/crates/context_servers/Cargo.toml +++ b/crates/context_servers/Cargo.toml @@ -13,6 +13,7 @@ path = "src/context_servers.rs" [dependencies] anyhow.workspace = true +async-trait.workspace = true collections.workspace = true command_palette_hooks.workspace = true futures.workspace = true diff --git a/crates/context_servers/src/manager.rs b/crates/context_servers/src/manager.rs index 34583e559a..c94d01d7b7 100644 --- a/crates/context_servers/src/manager.rs +++ b/crates/context_servers/src/manager.rs @@ -15,9 +15,13 @@ //! and react to changes in settings. use std::path::Path; +use std::pin::Pin; use std::sync::Arc; +use anyhow::Result; +use async_trait::async_trait; use collections::{HashMap, HashSet}; +use futures::{Future, FutureExt}; use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task}; use log; use parking_lot::RwLock; @@ -56,51 +60,84 @@ impl Settings for ContextServerSettings { } } -pub struct ContextServer { - pub id: String, - pub config: ServerConfig, +#[async_trait(?Send)] +pub trait ContextServer: Send + Sync + 'static { + fn id(&self) -> Arc; + fn config(&self) -> Arc; + fn client(&self) -> Option>; + fn start<'a>( + self: Arc, + cx: &'a AsyncAppContext, + ) -> Pin>>>; + fn stop(&self) -> Result<()>; +} + +pub struct NativeContextServer { + pub id: Arc, + pub config: Arc, pub client: RwLock>>, } -impl ContextServer { - fn new(config: ServerConfig) -> Self { +impl NativeContextServer { + fn new(config: Arc) -> Self { Self { - id: config.id.clone(), + id: config.id.clone().into(), config, client: RwLock::new(None), } } +} - async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> { - log::info!("starting context server {}", self.config.id,); - let client = Client::new( - client::ContextServerId(self.config.id.clone()), - client::ModelContextServerBinary { - executable: Path::new(&self.config.executable).to_path_buf(), - args: self.config.args.clone(), - env: self.config.env.clone(), - }, - cx.clone(), - )?; - - let protocol = crate::protocol::ModelContextProtocol::new(client); - let client_info = types::Implementation { - name: "Zed".to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), - }; - let initialized_protocol = protocol.initialize(client_info).await?; - - log::debug!( - "context server {} initialized: {:?}", - self.config.id, - initialized_protocol.initialize, - ); - - *self.client.write() = Some(Arc::new(initialized_protocol)); - Ok(()) +#[async_trait(?Send)] +impl ContextServer for NativeContextServer { + fn id(&self) -> Arc { + self.id.clone() } - async fn stop(&self) -> anyhow::Result<()> { + fn config(&self) -> Arc { + self.config.clone() + } + + fn client(&self) -> Option> { + self.client.read().clone() + } + + fn start<'a>( + self: Arc, + cx: &'a AsyncAppContext, + ) -> Pin>>> { + async move { + log::info!("starting context server {}", self.config.id,); + let client = Client::new( + client::ContextServerId(self.config.id.clone()), + client::ModelContextServerBinary { + executable: Path::new(&self.config.executable).to_path_buf(), + args: self.config.args.clone(), + env: self.config.env.clone(), + }, + cx.clone(), + )?; + + let protocol = crate::protocol::ModelContextProtocol::new(client); + let client_info = types::Implementation { + name: "Zed".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }; + let initialized_protocol = protocol.initialize(client_info).await?; + + log::debug!( + "context server {} initialized: {:?}", + self.config.id, + initialized_protocol.initialize, + ); + + *self.client.write() = Some(Arc::new(initialized_protocol)); + Ok(()) + } + .boxed_local() + } + + fn stop(&self) -> Result<()> { let mut client = self.client.write(); if let Some(protocol) = client.take() { drop(protocol); @@ -114,7 +151,7 @@ impl ContextServer { /// must go through the `GlobalContextServerManager` which holds /// a model to the ContextServerManager. pub struct ContextServerManager { - servers: HashMap>, + servers: HashMap>, pending_servers: HashSet, } @@ -141,7 +178,7 @@ impl ContextServerManager { pub fn add_server( &mut self, - config: ServerConfig, + config: Arc, cx: &ModelContext, ) -> Task> { let server_id = config.id.clone(); @@ -153,8 +190,8 @@ impl ContextServerManager { let task = { let server_id = server_id.clone(); cx.spawn(|this, mut cx| async move { - let server = Arc::new(ContextServer::new(config)); - server.start(&cx).await?; + let server = Arc::new(NativeContextServer::new(config)); + server.clone().start(&cx).await?; this.update(&mut cx, |this, cx| { this.servers.insert(server_id.clone(), server); this.pending_servers.remove(&server_id); @@ -170,7 +207,7 @@ impl ContextServerManager { task } - pub fn get_server(&self, id: &str) -> Option> { + pub fn get_server(&self, id: &str) -> Option> { self.servers.get(id).cloned() } @@ -178,7 +215,7 @@ impl ContextServerManager { let id = id.to_string(); cx.spawn(|this, mut cx| async move { if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { - server.stop().await?; + server.stop()?; } this.update(&mut cx, |this, cx| { this.pending_servers.remove(&id); @@ -192,16 +229,16 @@ impl ContextServerManager { pub fn restart_server( &mut self, - id: &str, + id: &Arc, cx: &mut ModelContext, ) -> Task> { let id = id.to_string(); cx.spawn(|this, mut cx| async move { if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { - server.stop().await?; - let config = server.config.clone(); - let new_server = Arc::new(ContextServer::new(config)); - new_server.start(&cx).await?; + server.stop()?; + let config = server.config(); + let new_server = Arc::new(NativeContextServer::new(config)); + new_server.clone().start(&cx).await?; this.update(&mut cx, |this, cx| { this.servers.insert(id.clone(), new_server); cx.emit(Event::ServerStopped { @@ -216,7 +253,7 @@ impl ContextServerManager { }) } - pub fn servers(&self) -> Vec> { + pub fn servers(&self) -> Vec> { self.servers.values().cloned().collect() } @@ -224,7 +261,7 @@ impl ContextServerManager { let current_servers = self .servers() .into_iter() - .map(|server| (server.id.clone(), server.config.clone())) + .map(|server| (server.id(), server.config())) .collect::>(); let new_servers = settings @@ -235,19 +272,19 @@ impl ContextServerManager { let servers_to_add = new_servers .values() - .filter(|config| !current_servers.contains_key(&config.id)) + .filter(|config| !current_servers.contains_key(config.id.as_str())) .cloned() .collect::>(); let servers_to_remove = current_servers .keys() - .filter(|id| !new_servers.contains_key(*id)) + .filter(|id| !new_servers.contains_key(id.as_ref())) .cloned() .collect::>(); log::trace!("servers_to_add={:?}", servers_to_add); for config in servers_to_add { - self.add_server(config, cx).detach_and_log_err(cx); + self.add_server(Arc::new(config), cx).detach_and_log_err(cx); } for id in servers_to_remove {