mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-28 03:25:59 +00:00
context_servers: Add initial implementation (#16103)
This commit proposes the addition of "context serveres" and the underlying protocol (model context protocol). Context servers allow simple definition of slash commands in another language and running local on the user machines. This aims to quickly prototype new commands, and provide a way to add personal (or company wide) customizations to the assistant panel, without having to maintain an extension. We can use this to reuse our existing codebase, with authenticators, etc and easily have it provide context into the assistant panel. As such it occupies a different design space as extensions, which I think are more aimed towards long-term, well maintained pieces of code that can be easily distributed. It's implemented as a central crate for easy reusability across the codebase and to easily hook into the assistant panel at all points. Design wise there are a few pieces: 1. client.rs: A simple JSON-RPC client talking over stdio to a spawned server. This is very close to how LSP work and likely there could be a combined client down the line. 2. types.rs: Serialization and deserialization client for the underlying model context protocol. 3. protocol.rs: Handling the session between client and server. 4. manager.rs: Manages settings and adding and deleting servers from a central pool. A server can be defined in the settings.json as: ``` "context_servers": [ {"id": "test", "executable": "python", "args": ["-m", "context_server"] ] ``` ## Quick Example A quick example of how a theoretical backend site can look like. With roughly 100 lines of code (nicely generated by Claude) and a bit of decorator magic (200 lines in total), one can come up with a framework that makes it as easy as: ```python @context_server.slash_command(name="rot13", description="Perform a rot13 transformation") @context_server.argument(name="input", type=str, help="String to rot13") async def rot13(input: str) -> str: return ''.join(chr((ord(c) - 97 + 13) % 26 + 97) if c.isalpha() else c for c in echo.lower()) ``` to define a new slash_command. ## Todo: - Allow context servers to be defined in workspace settings. - Allow passing env variables to context_servers Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
d54818fd9e
commit
02ea6ac845
16 changed files with 1433 additions and 7 deletions
22
Cargo.lock
generated
22
Cargo.lock
generated
|
@ -358,6 +358,7 @@ dependencies = [
|
|||
"clock",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"context_servers",
|
||||
"ctor",
|
||||
"db",
|
||||
"editor",
|
||||
|
@ -2668,6 +2669,27 @@ dependencies = [
|
|||
"tiny-keccak",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "context_servers"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"futures 0.3.30",
|
||||
"gpui",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"postage",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"url",
|
||||
"util",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "convert_case"
|
||||
version = "0.4.0"
|
||||
|
|
|
@ -19,6 +19,7 @@ members = [
|
|||
"crates/collections",
|
||||
"crates/command_palette",
|
||||
"crates/command_palette_hooks",
|
||||
"crates/context_servers",
|
||||
"crates/copilot",
|
||||
"crates/db",
|
||||
"crates/dev_server_projects",
|
||||
|
@ -189,6 +190,7 @@ collab_ui = { path = "crates/collab_ui" }
|
|||
collections = { path = "crates/collections" }
|
||||
command_palette = { path = "crates/command_palette" }
|
||||
command_palette_hooks = { path = "crates/command_palette_hooks" }
|
||||
context_servers = { path = "crates/context_servers" }
|
||||
copilot = { path = "crates/copilot" }
|
||||
db = { path = "crates/db" }
|
||||
dev_server_projects = { path = "crates/dev_server_projects" }
|
||||
|
|
|
@ -1010,5 +1010,16 @@
|
|||
// ]
|
||||
// }
|
||||
// ]
|
||||
"ssh_connections": null
|
||||
"ssh_connections": null,
|
||||
// Configures the Context Server Protocol binaries
|
||||
//
|
||||
// Examples:
|
||||
// {
|
||||
// "id": "server-1",
|
||||
// "executable": "/path",
|
||||
// "args": ['arg1", "args2"]
|
||||
// }
|
||||
"experimental.context_servers": {
|
||||
"servers": []
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ clock.workspace = true
|
|||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
db.workspace = true
|
||||
context_servers.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
|
|
|
@ -21,9 +21,11 @@ use assistant_slash_command::SlashCommandRegistry;
|
|||
use client::{proto, Client};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
pub use context::*;
|
||||
use context_servers::ContextServerRegistry;
|
||||
pub use context_store::*;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use fs::Fs;
|
||||
use gpui::Context as _;
|
||||
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
|
||||
use indexed_docs::IndexedDocsRegistry;
|
||||
pub(crate) use inline_assistant::*;
|
||||
|
@ -37,9 +39,9 @@ use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use settings::{update_settings_file, Settings, SettingsStore};
|
||||
use slash_command::{
|
||||
default_command, diagnostics_command, docs_command, fetch_command, file_command, now_command,
|
||||
project_command, prompt_command, search_command, symbols_command, tab_command,
|
||||
terminal_command, workflow_command,
|
||||
context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
|
||||
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
|
||||
tab_command, terminal_command, workflow_command,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
pub(crate) use streaming_diff::*;
|
||||
|
@ -221,6 +223,7 @@ pub fn init(
|
|||
init_language_model_settings(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_panel::init(cx);
|
||||
context_servers::init(cx);
|
||||
|
||||
let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
|
||||
dev_mode,
|
||||
|
@ -261,9 +264,69 @@ pub fn init(
|
|||
})
|
||||
.detach();
|
||||
|
||||
register_context_server_handlers(cx);
|
||||
|
||||
prompt_builder
|
||||
}
|
||||
|
||||
fn register_context_server_handlers(cx: &mut AppContext) {
|
||||
cx.subscribe(
|
||||
&context_servers::manager::ContextServerManager::global(cx),
|
||||
|manager, event, cx| match event {
|
||||
context_servers::manager::Event::ServerStarted { server_id } => {
|
||||
cx.update_model(
|
||||
&manager,
|
||||
|manager: &mut context_servers::manager::ContextServerManager, cx| {
|
||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||
let context_server_registry = ContextServerRegistry::global(cx);
|
||||
if let Some(server) = manager.get_server(server_id) {
|
||||
cx.spawn(|_, _| async move {
|
||||
let Some(protocol) = server.client.read().clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
||||
for prompt in prompts
|
||||
.into_iter()
|
||||
.filter(context_server_command::acceptable_prompt)
|
||||
{
|
||||
log::info!(
|
||||
"registering context server command: {:?}",
|
||||
prompt.name
|
||||
);
|
||||
context_server_registry.register_command(
|
||||
server.id.clone(),
|
||||
prompt.name.as_str(),
|
||||
);
|
||||
slash_command_registry.register_command(
|
||||
context_server_command::ContextServerSlashCommand::new(
|
||||
&server, prompt,
|
||||
),
|
||||
true,
|
||||
);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
context_servers::manager::Event::ServerStopped { server_id } => {
|
||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||
let context_server_registry = ContextServerRegistry::global(cx);
|
||||
if let Some(commands) = context_server_registry.get_commands(server_id) {
|
||||
for command_name in commands {
|
||||
slash_command_registry.unregister_command_by_name(&command_name);
|
||||
context_server_registry.unregister_command(&server_id, &command_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn init_language_model_settings(cx: &mut AppContext) {
|
||||
update_active_language_model_from_settings(cx);
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ use std::{
|
|||
use ui::ActiveTheme;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub mod context_server_command;
|
||||
pub mod default_command;
|
||||
pub mod diagnostics_command;
|
||||
pub mod docs_command;
|
||||
|
|
125
crates/assistant/src/slash_command/context_server_command.rs
Normal file
125
crates/assistant/src/slash_command/context_server_command.rs
Normal file
|
@ -0,0 +1,125 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::{
|
||||
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use context_servers::{
|
||||
manager::{ContextServer, ContextServerManager},
|
||||
protocol::PromptInfo,
|
||||
};
|
||||
use gpui::{Task, WeakView, WindowContext};
|
||||
use language::LspAdapterDelegate;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use ui::{IconName, SharedString};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub struct ContextServerSlashCommand {
|
||||
server_id: String,
|
||||
prompt: PromptInfo,
|
||||
}
|
||||
|
||||
impl ContextServerSlashCommand {
|
||||
pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
|
||||
Self {
|
||||
server_id: server.id.clone(),
|
||||
prompt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SlashCommand for ContextServerSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
self.prompt.name.clone()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
format!("Run context server command: {}", self.prompt.name)
|
||||
}
|
||||
|
||||
fn menu_text(&self) -> String {
|
||||
format!("Run '{}' from {}", self.prompt.name, self.server_id)
|
||||
}
|
||||
|
||||
fn requires_argument(&self) -> bool {
|
||||
self.prompt
|
||||
.arguments
|
||||
.as_ref()
|
||||
.map_or(false, |args| !args.is_empty())
|
||||
}
|
||||
|
||||
fn complete_argument(
|
||||
self: Arc<Self>,
|
||||
_arguments: &[String],
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Ok(Vec::new()))
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
arguments: &[String],
|
||||
_workspace: WeakView<Workspace>,
|
||||
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<SlashCommandOutput>> {
|
||||
let server_id = self.server_id.clone();
|
||||
let prompt_name = self.prompt.name.clone();
|
||||
let argument = arguments.first().cloned();
|
||||
|
||||
let manager = ContextServerManager::global(cx);
|
||||
let manager = manager.read(cx);
|
||||
if let Some(server) = manager.get_server(&server_id) {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let Some(protocol) = server.client.read().clone() else {
|
||||
return Err(anyhow!("Context server not initialized"));
|
||||
};
|
||||
|
||||
let result = protocol
|
||||
.run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
|
||||
.await?;
|
||||
|
||||
Ok(SlashCommandOutput {
|
||||
sections: vec![SlashCommandOutputSection {
|
||||
range: 0..result.len(),
|
||||
icon: IconName::ZedAssistant,
|
||||
label: SharedString::from(format!("Result from {}", prompt_name)),
|
||||
}],
|
||||
text: result,
|
||||
run_commands_in_text: false,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
Task::ready(Err(anyhow!("Context server not found")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn prompt_arguments(
|
||||
prompt: &PromptInfo,
|
||||
argument: Option<String>,
|
||||
) -> Result<HashMap<String, String>> {
|
||||
match &prompt.arguments {
|
||||
Some(args) if args.len() >= 2 => Err(anyhow!(
|
||||
"Prompt has more than one argument, which is not supported"
|
||||
)),
|
||||
Some(args) if args.len() == 1 => match argument {
|
||||
Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
|
||||
None => Err(anyhow!("Prompt expects argument but none given")),
|
||||
},
|
||||
Some(_) | None => Ok(HashMap::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP servers can return prompts with multiple arguments. Since we only
|
||||
/// support one argument, we ignore all others. This is the necessary predicate
|
||||
/// for this.
|
||||
pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
|
||||
match &prompt.arguments {
|
||||
None => true,
|
||||
Some(args) if args.len() == 1 => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
|
@ -58,10 +58,14 @@ impl SlashCommandRegistry {
|
|||
|
||||
/// Unregisters the provided [`SlashCommand`].
|
||||
pub fn unregister_command(&self, command: impl SlashCommand) {
|
||||
self.unregister_command_by_name(command.name().as_str())
|
||||
}
|
||||
|
||||
/// Unregisters the command with the given name.
|
||||
pub fn unregister_command_by_name(&self, command_name: &str) {
|
||||
let mut state = self.state.write();
|
||||
let command_name: Arc<str> = command.name().into();
|
||||
state.featured_commands.remove(&command_name);
|
||||
state.commands.remove(&command_name);
|
||||
state.featured_commands.remove(command_name);
|
||||
state.commands.remove(command_name);
|
||||
}
|
||||
|
||||
/// Returns the names of registered [`SlashCommand`]s.
|
||||
|
|
29
crates/context_servers/Cargo.toml
Normal file
29
crates/context_servers/Cargo.toml
Normal file
|
@ -0,0 +1,29 @@
|
|||
[package]
|
||||
name = "context_servers"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/context_servers.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
1
crates/context_servers/LICENSE-GPL
Symbolic link
1
crates/context_servers/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
432
crates/context_servers/src/client.rs
Normal file
432
crates/context_servers/src/client.rs
Normal file
|
@ -0,0 +1,432 @@
|
|||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
|
||||
use gpui::{AsyncAppContext, BackgroundExecutor, Task};
|
||||
use parking_lot::Mutex;
|
||||
use postage::barrier;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::{value::RawValue, Value};
|
||||
use smol::{
|
||||
channel,
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||
process::{self, Child},
|
||||
};
|
||||
use std::{
|
||||
fmt,
|
||||
path::PathBuf,
|
||||
sync::{
|
||||
atomic::{AtomicI32, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use util::TryFutureExt;
|
||||
|
||||
const JSON_RPC_VERSION: &str = "2.0";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
|
||||
type NotificationHandler = Box<dyn Send + FnMut(RequestId, Value, AsyncAppContext)>;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum RequestId {
|
||||
Int(i32),
|
||||
Str(String),
|
||||
}
|
||||
|
||||
pub struct Client {
|
||||
server_id: ContextServerId,
|
||||
next_id: AtomicI32,
|
||||
outbound_tx: channel::Sender<String>,
|
||||
name: Arc<str>,
|
||||
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[allow(dead_code)]
|
||||
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
|
||||
#[allow(dead_code)]
|
||||
output_done_rx: Mutex<Option<barrier::Receiver>>,
|
||||
executor: BackgroundExecutor,
|
||||
server: Arc<Mutex<Option<Child>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[repr(transparent)]
|
||||
pub struct ContextServerId(pub String);
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Request<'a, T> {
|
||||
jsonrpc: &'static str,
|
||||
id: RequestId,
|
||||
method: &'a str,
|
||||
params: T,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AnyResponse<'a> {
|
||||
jsonrpc: &'a str,
|
||||
id: RequestId,
|
||||
#[serde(default)]
|
||||
error: Option<Error>,
|
||||
#[serde(borrow)]
|
||||
result: Option<&'a RawValue>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct Response<T> {
|
||||
jsonrpc: &'static str,
|
||||
id: RequestId,
|
||||
#[serde(flatten)]
|
||||
value: CspResult<T>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum CspResult<T> {
|
||||
#[serde(rename = "result")]
|
||||
Ok(Option<T>),
|
||||
#[allow(dead_code)]
|
||||
Error(Option<Error>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Notification<'a, T> {
|
||||
jsonrpc: &'static str,
|
||||
id: RequestId,
|
||||
#[serde(borrow)]
|
||||
method: &'a str,
|
||||
params: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct AnyNotification<'a> {
|
||||
jsonrpc: &'a str,
|
||||
id: RequestId,
|
||||
method: String,
|
||||
#[serde(default)]
|
||||
params: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Error {
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelContextServerBinary {
|
||||
pub executable: PathBuf,
|
||||
pub args: Vec<String>,
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Creates a new Client instance for a context server.
|
||||
///
|
||||
/// This function initializes a new Client by spawning a child process for the context server,
|
||||
/// setting up communication channels, and initializing handlers for input/output operations.
|
||||
/// It takes a server ID, binary information, and an async app context as input.
|
||||
pub fn new(
|
||||
server_id: ContextServerId,
|
||||
binary: ModelContextServerBinary,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
log::info!(
|
||||
"starting context server (executable={:?}, args={:?})",
|
||||
binary.executable,
|
||||
&binary.args
|
||||
);
|
||||
|
||||
let mut command = process::Command::new(&binary.executable);
|
||||
command
|
||||
.args(&binary.args)
|
||||
.envs(binary.env.unwrap_or_default())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
|
||||
let mut server = command.spawn().with_context(|| {
|
||||
format!(
|
||||
"failed to spawn command. (path={:?}, args={:?})",
|
||||
binary.executable, &binary.args
|
||||
)
|
||||
})?;
|
||||
|
||||
let stdin = server.stdin.take().unwrap();
|
||||
let stdout = server.stdout.take().unwrap();
|
||||
let stderr = server.stderr.take().unwrap();
|
||||
|
||||
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
|
||||
let (output_done_tx, output_done_rx) = barrier::channel();
|
||||
|
||||
let notification_handlers =
|
||||
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
|
||||
let response_handlers =
|
||||
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
|
||||
|
||||
let stdout_input_task = cx.spawn({
|
||||
let notification_handlers = notification_handlers.clone();
|
||||
let response_handlers = response_handlers.clone();
|
||||
move |cx| {
|
||||
Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
|
||||
}
|
||||
});
|
||||
let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
|
||||
let input_task = cx.spawn(|_| async move {
|
||||
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
|
||||
stdout.or(stderr)
|
||||
});
|
||||
let output_task = cx.background_executor().spawn({
|
||||
Self::handle_output(
|
||||
stdin,
|
||||
outbound_rx,
|
||||
output_done_tx,
|
||||
response_handlers.clone(),
|
||||
)
|
||||
.log_err()
|
||||
});
|
||||
|
||||
let mut context_server = Self {
|
||||
server_id,
|
||||
notification_handlers,
|
||||
response_handlers,
|
||||
name: "".into(),
|
||||
next_id: Default::default(),
|
||||
outbound_tx,
|
||||
executor: cx.background_executor().clone(),
|
||||
io_tasks: Mutex::new(Some((input_task, output_task))),
|
||||
output_done_rx: Mutex::new(Some(output_done_rx)),
|
||||
server: Arc::new(Mutex::new(Some(server))),
|
||||
};
|
||||
|
||||
if let Some(name) = binary.executable.file_name() {
|
||||
context_server.name = name.to_string_lossy().into();
|
||||
}
|
||||
|
||||
Ok(context_server)
|
||||
}
|
||||
|
||||
/// Handles input from the server's stdout.
|
||||
///
|
||||
/// This function continuously reads lines from the provided stdout stream,
|
||||
/// parses them as JSON-RPC responses or notifications, and dispatches them
|
||||
/// to the appropriate handlers. It processes both responses (which are matched
|
||||
/// to pending requests) and notifications (which trigger registered handlers).
|
||||
async fn handle_input<Stdout>(
|
||||
stdout: Stdout,
|
||||
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
cx: AsyncAppContext,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
Stdout: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stdout = BufReader::new(stdout);
|
||||
let mut buffer = String::new();
|
||||
|
||||
loop {
|
||||
buffer.clear();
|
||||
if stdout.read_line(&mut buffer).await? == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let content = buffer.trim();
|
||||
|
||||
if !content.is_empty() {
|
||||
if let Ok(response) = serde_json::from_str::<AnyResponse>(&content) {
|
||||
if let Some(handlers) = response_handlers.lock().as_mut() {
|
||||
if let Some(handler) = handlers.remove(&response.id) {
|
||||
handler(Ok(content.to_string()));
|
||||
}
|
||||
}
|
||||
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&content) {
|
||||
let mut notification_handlers = notification_handlers.lock();
|
||||
if let Some(handler) =
|
||||
notification_handlers.get_mut(notification.method.as_str())
|
||||
{
|
||||
handler(
|
||||
notification.id,
|
||||
notification.params.unwrap_or(Value::Null),
|
||||
cx.clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
smol::future::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles the stderr output from the context server.
|
||||
/// Continuously reads and logs any error messages from the server.
|
||||
async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
|
||||
where
|
||||
Stderr: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stderr = BufReader::new(stderr);
|
||||
let mut buffer = String::new();
|
||||
|
||||
loop {
|
||||
buffer.clear();
|
||||
if stderr.read_line(&mut buffer).await? == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
log::warn!("context server stderr: {}", buffer.trim());
|
||||
smol::future::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles the output to the context server's stdin.
|
||||
/// This function continuously receives messages from the outbound channel,
|
||||
/// writes them to the server's stdin, and manages the lifecycle of response handlers.
|
||||
async fn handle_output<Stdin>(
|
||||
stdin: Stdin,
|
||||
outbound_rx: channel::Receiver<String>,
|
||||
output_done_tx: barrier::Sender,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
Stdin: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stdin = BufWriter::new(stdin);
|
||||
let _clear_response_handlers = util::defer({
|
||||
let response_handlers = response_handlers.clone();
|
||||
move || {
|
||||
response_handlers.lock().take();
|
||||
}
|
||||
});
|
||||
while let Ok(message) = outbound_rx.recv().await {
|
||||
log::trace!("outgoing message: {}", message);
|
||||
|
||||
stdin.write_all(message.as_bytes()).await?;
|
||||
stdin.write_all(b"\n").await?;
|
||||
stdin.flush().await?;
|
||||
}
|
||||
drop(output_done_tx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sends a JSON-RPC request to the context server and waits for a response.
|
||||
/// This function handles serialization, deserialization, timeout, and error handling.
|
||||
pub async fn request<T: DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: impl Serialize,
|
||||
) -> Result<T> {
|
||||
let id = self.next_id.fetch_add(1, SeqCst);
|
||||
let request = serde_json::to_string(&Request {
|
||||
jsonrpc: JSON_RPC_VERSION,
|
||||
id: RequestId::Int(id),
|
||||
method,
|
||||
params,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let handle_response = self
|
||||
.response_handlers
|
||||
.lock()
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("server shut down"))
|
||||
.map(|handlers| {
|
||||
handlers.insert(
|
||||
RequestId::Int(id),
|
||||
Box::new(move |result| {
|
||||
let _ = tx.send(result);
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
let send = self
|
||||
.outbound_tx
|
||||
.try_send(request)
|
||||
.context("failed to write to context server's stdin");
|
||||
|
||||
let executor = self.executor.clone();
|
||||
let started = Instant::now();
|
||||
handle_response?;
|
||||
send?;
|
||||
|
||||
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
|
||||
select! {
|
||||
response = rx.fuse() => {
|
||||
let elapsed = started.elapsed();
|
||||
log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
|
||||
match response? {
|
||||
Ok(response) => {
|
||||
let parsed: AnyResponse = serde_json::from_str(&response)?;
|
||||
if let Some(error) = parsed.error {
|
||||
Err(anyhow!(error.message))
|
||||
} else if let Some(result) = parsed.result {
|
||||
Ok(serde_json::from_str(result.get())?)
|
||||
} else {
|
||||
Err(anyhow!("Invalid response: no result or error"))
|
||||
}
|
||||
}
|
||||
Err(_) => anyhow::bail!("cancelled")
|
||||
}
|
||||
}
|
||||
_ = timeout => {
|
||||
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
|
||||
anyhow::bail!("Context server request timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a notification to the context server without expecting a response.
|
||||
/// This function serializes the notification and sends it through the outbound channel.
|
||||
pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
|
||||
let id = self.next_id.fetch_add(1, SeqCst);
|
||||
let notification = serde_json::to_string(&Notification {
|
||||
jsonrpc: JSON_RPC_VERSION,
|
||||
id: RequestId::Int(id),
|
||||
method,
|
||||
params,
|
||||
})
|
||||
.unwrap();
|
||||
self.outbound_tx.try_send(notification)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn on_notification<F>(&self, method: &'static str, mut f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncAppContext),
|
||||
{
|
||||
self.notification_handlers
|
||||
.lock()
|
||||
.insert(method, Box::new(move |_, params, cx| f(params, cx)));
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn server_id(&self) -> ContextServerId {
|
||||
self.server_id.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Client {
|
||||
fn drop(&mut self) {
|
||||
if let Some(mut server) = self.server.lock().take() {
|
||||
let _ = server.kill();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ContextServerId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Client {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Context Server Client")
|
||||
.field("id", &self.server_id.0)
|
||||
.field("name", &self.name)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
36
crates/context_servers/src/context_servers.rs
Normal file
36
crates/context_servers/src/context_servers.rs
Normal file
|
@ -0,0 +1,36 @@
|
|||
use gpui::{actions, AppContext, Context, ViewContext};
|
||||
use log;
|
||||
use manager::ContextServerManager;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub mod client;
|
||||
pub mod manager;
|
||||
pub mod protocol;
|
||||
mod registry;
|
||||
pub mod types;
|
||||
|
||||
pub use registry::*;
|
||||
|
||||
actions!(context_servers, [Restart]);
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
log::info!("initializing context server client");
|
||||
manager::init(cx);
|
||||
ContextServerRegistry::register(cx);
|
||||
|
||||
cx.observe_new_views(
|
||||
|workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
|
||||
workspace.register_action(restart_servers);
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewContext<Workspace>) {
|
||||
let model = ContextServerManager::global(&cx);
|
||||
cx.update_model(&model, |manager, cx| {
|
||||
for server in manager.servers() {
|
||||
manager.restart_server(&server.id, cx).detach();
|
||||
}
|
||||
});
|
||||
}
|
278
crates/context_servers/src/manager.rs
Normal file
278
crates/context_servers/src/manager.rs
Normal file
|
@ -0,0 +1,278 @@
|
|||
//! This module implements a context server management system for Zed.
|
||||
//!
|
||||
//! It provides functionality to:
|
||||
//! - Define and load context server settings
|
||||
//! - Manage individual context servers (start, stop, restart)
|
||||
//! - Maintain a global manager for all context servers
|
||||
//!
|
||||
//! Key components:
|
||||
//! - `ContextServerSettings`: Defines the structure for server configurations
|
||||
//! - `ContextServer`: Represents an individual context server
|
||||
//! - `ContextServerManager`: Manages multiple context servers
|
||||
//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
|
||||
//!
|
||||
//! The module also includes initialization logic to set up the context server system
|
||||
//! and react to changes in settings.
|
||||
|
||||
use collections::{HashMap, HashSet};
|
||||
use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task};
|
||||
use log;
|
||||
use parking_lot::RwLock;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources, SettingsStore};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
client::{self, Client},
|
||||
types,
|
||||
};
|
||||
|
||||
#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
|
||||
pub struct ContextServerSettings {
|
||||
pub servers: Vec<ServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
|
||||
pub struct ServerConfig {
|
||||
pub id: String,
|
||||
pub executable: String,
|
||||
pub args: Vec<String>,
|
||||
}
|
||||
|
||||
impl Settings for ContextServerSettings {
|
||||
const KEY: Option<&'static str> = Some("experimental.context_servers");
|
||||
|
||||
type FileContent = Self;
|
||||
|
||||
fn load(
|
||||
sources: SettingsSources<Self::FileContent>,
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
sources.json_merge()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ContextServer {
|
||||
pub id: String,
|
||||
pub config: ServerConfig,
|
||||
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
||||
}
|
||||
|
||||
impl ContextServer {
|
||||
fn new(config: ServerConfig) -> Self {
|
||||
Self {
|
||||
id: config.id.clone(),
|
||||
config,
|
||||
client: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
|
||||
log::info!("starting context server {}", self.config.id);
|
||||
let client = Client::new(
|
||||
client::ContextServerId(self.config.id.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&self.config.executable).to_path_buf(),
|
||||
args: self.config.args.clone(),
|
||||
env: None,
|
||||
},
|
||||
cx.clone(),
|
||||
)?;
|
||||
|
||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||
let client_info = types::EntityInfo {
|
||||
name: "Zed".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
let initialized_protocol = protocol.initialize(client_info).await?;
|
||||
|
||||
log::debug!(
|
||||
"context server {} initialized: {:?}",
|
||||
self.config.id,
|
||||
initialized_protocol.initialize,
|
||||
);
|
||||
|
||||
*self.client.write() = Some(Arc::new(initialized_protocol));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> anyhow::Result<()> {
|
||||
let mut client = self.client.write();
|
||||
if let Some(protocol) = client.take() {
|
||||
drop(protocol);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A Context server manager manages the starting and stopping
|
||||
/// of all servers. To obtain a server to interact with, a crate
|
||||
/// must go through the `GlobalContextServerManager` which holds
|
||||
/// a model to the ContextServerManager.
|
||||
pub struct ContextServerManager {
|
||||
servers: HashMap<String, Arc<ContextServer>>,
|
||||
pending_servers: HashSet<String>,
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
ServerStarted { server_id: String },
|
||||
ServerStopped { server_id: String },
|
||||
}
|
||||
|
||||
impl Global for ContextServerManager {}
|
||||
impl EventEmitter<Event> for ContextServerManager {}
|
||||
|
||||
impl ContextServerManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
servers: HashMap::default(),
|
||||
pending_servers: HashSet::default(),
|
||||
}
|
||||
}
|
||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||
cx.global::<GlobalContextServerManager>().0.clone()
|
||||
}
|
||||
|
||||
pub fn add_server(
|
||||
&mut self,
|
||||
config: ServerConfig,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
let server_id = config.id.clone();
|
||||
let server_id2 = config.id.clone();
|
||||
|
||||
if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let task = cx.spawn(|this, mut cx| async move {
|
||||
let server = Arc::new(ContextServer::new(config));
|
||||
server.start(&cx).await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.servers.insert(server_id.clone(), server);
|
||||
this.pending_servers.remove(&server_id);
|
||||
cx.emit(Event::ServerStarted {
|
||||
server_id: server_id.clone(),
|
||||
});
|
||||
})?;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
self.pending_servers.insert(server_id2);
|
||||
task
|
||||
}
|
||||
|
||||
pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
|
||||
self.servers.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn remove_server(
|
||||
&mut self,
|
||||
id: &str,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
let id = id.to_string();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
|
||||
server.stop().await?;
|
||||
}
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.pending_servers.remove(&id);
|
||||
cx.emit(Event::ServerStopped {
|
||||
server_id: id.clone(),
|
||||
})
|
||||
})?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn restart_server(
|
||||
&mut self,
|
||||
id: &str,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
let id = id.to_string();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
|
||||
server.stop().await?;
|
||||
let config = server.config.clone();
|
||||
let new_server = Arc::new(ContextServer::new(config));
|
||||
new_server.start(&cx).await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.servers.insert(id.clone(), new_server);
|
||||
cx.emit(Event::ServerStopped {
|
||||
server_id: id.clone(),
|
||||
});
|
||||
cx.emit(Event::ServerStarted {
|
||||
server_id: id.clone(),
|
||||
});
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn servers(&self) -> Vec<Arc<ContextServer>> {
|
||||
self.servers.values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn model(cx: &mut AppContext) -> Model<Self> {
|
||||
cx.new_model(|_cx| ContextServerManager::new())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GlobalContextServerManager(Model<ContextServerManager>);
|
||||
impl Global for GlobalContextServerManager {}
|
||||
|
||||
impl GlobalContextServerManager {
|
||||
fn register(cx: &mut AppContext) {
|
||||
let model = ContextServerManager::model(cx);
|
||||
cx.set_global(Self(model));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
ContextServerSettings::register(cx);
|
||||
GlobalContextServerManager::register(cx);
|
||||
cx.observe_global::<SettingsStore>(|cx| {
|
||||
let manager = ContextServerManager::global(cx);
|
||||
cx.update_model(&manager, |manager, cx| {
|
||||
let settings = ContextServerSettings::get_global(cx);
|
||||
let current_servers: HashMap<String, ServerConfig> = manager
|
||||
.servers()
|
||||
.into_iter()
|
||||
.map(|server| (server.id.clone(), server.config.clone()))
|
||||
.collect();
|
||||
|
||||
let new_servers = settings
|
||||
.servers
|
||||
.iter()
|
||||
.map(|config| (config.id.clone(), config.clone()))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let servers_to_add = new_servers
|
||||
.values()
|
||||
.filter(|config| !current_servers.contains_key(&config.id))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let servers_to_remove = current_servers
|
||||
.keys()
|
||||
.filter(|id| !new_servers.contains_key(*id))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
log::trace!("servers_to_add={:?}", servers_to_add);
|
||||
for config in servers_to_add {
|
||||
manager.add_server(config, cx).detach();
|
||||
}
|
||||
|
||||
for id in servers_to_remove {
|
||||
manager.remove_server(&id, cx).detach();
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
140
crates/context_servers/src/protocol.rs
Normal file
140
crates/context_servers/src/protocol.rs
Normal file
|
@ -0,0 +1,140 @@
|
|||
//! This module implements parts of the Model Context Protocol.
|
||||
//!
|
||||
//! It handles the lifecycle messages, and provides a general interface to
|
||||
//! interacting with an MCP server. It uses the generic JSON-RPC client to
|
||||
//! read/write messages and the types from types.rs for serialization/deserialization
|
||||
//! of messages.
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
|
||||
use crate::client::Client;
|
||||
use crate::types;
|
||||
|
||||
pub use types::PromptInfo;
|
||||
|
||||
const PROTOCOL_VERSION: u32 = 1;
|
||||
|
||||
pub struct ModelContextProtocol {
|
||||
inner: Client,
|
||||
}
|
||||
|
||||
impl ModelContextProtocol {
|
||||
pub fn new(inner: Client) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub async fn initialize(
|
||||
self,
|
||||
client_info: types::EntityInfo,
|
||||
) -> Result<InitializedContextServerProtocol> {
|
||||
let params = types::InitializeParams {
|
||||
protocol_version: PROTOCOL_VERSION,
|
||||
capabilities: types::ClientCapabilities {
|
||||
experimental: None,
|
||||
sampling: None,
|
||||
},
|
||||
client_info,
|
||||
};
|
||||
|
||||
let response: types::InitializeResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::Initialize.as_str(), params)
|
||||
.await?;
|
||||
|
||||
log::trace!("mcp server info {:?}", response.server_info);
|
||||
|
||||
self.inner.notify(
|
||||
types::NotificationType::Initialized.as_str(),
|
||||
serde_json::json!({}),
|
||||
)?;
|
||||
|
||||
let initialized_protocol = InitializedContextServerProtocol {
|
||||
inner: self.inner,
|
||||
initialize: response,
|
||||
};
|
||||
|
||||
Ok(initialized_protocol)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InitializedContextServerProtocol {
|
||||
inner: Client,
|
||||
pub initialize: types::InitializeResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub enum ServerCapability {
|
||||
Experimental,
|
||||
Logging,
|
||||
Prompts,
|
||||
Resources,
|
||||
Tools,
|
||||
}
|
||||
|
||||
impl InitializedContextServerProtocol {
|
||||
/// Check if the server supports a specific capability
|
||||
pub fn capable(&self, capability: ServerCapability) -> bool {
|
||||
match capability {
|
||||
ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
|
||||
ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
|
||||
ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
|
||||
ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
|
||||
ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_capability(&self, capability: ServerCapability) -> Result<()> {
|
||||
if self.capable(capability) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"Server does not support {:?} capability",
|
||||
capability
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// List the MCP prompts.
|
||||
pub async fn list_prompts(&self) -> Result<Vec<types::PromptInfo>> {
|
||||
self.check_capability(ServerCapability::Prompts)?;
|
||||
|
||||
let response: types::PromptsListResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::PromptsList.as_str(), ())
|
||||
.await?;
|
||||
|
||||
Ok(response.prompts)
|
||||
}
|
||||
|
||||
/// Executes a prompt with the given arguments and returns the result.
|
||||
pub async fn run_prompt<P: AsRef<str>>(
|
||||
&self,
|
||||
prompt: P,
|
||||
arguments: HashMap<String, String>,
|
||||
) -> Result<String> {
|
||||
self.check_capability(ServerCapability::Prompts)?;
|
||||
|
||||
let params = types::PromptsGetParams {
|
||||
name: prompt.as_ref().to_string(),
|
||||
arguments: Some(arguments),
|
||||
};
|
||||
|
||||
let response: types::PromptsGetResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::PromptsGet.as_str(), params)
|
||||
.await?;
|
||||
|
||||
Ok(response.prompt)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitializedContextServerProtocol {
|
||||
pub async fn request<R: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: impl serde::Serialize,
|
||||
) -> Result<R> {
|
||||
self.inner.request(method, params).await
|
||||
}
|
||||
}
|
47
crates/context_servers/src/registry.rs
Normal file
47
crates/context_servers/src/registry.rs
Normal file
|
@ -0,0 +1,47 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use collections::HashMap;
|
||||
use gpui::{AppContext, Global, ReadGlobal};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
struct GlobalContextServerRegistry(Arc<ContextServerRegistry>);
|
||||
|
||||
impl Global for GlobalContextServerRegistry {}
|
||||
|
||||
pub struct ContextServerRegistry {
|
||||
registry: RwLock<HashMap<String, Vec<Arc<str>>>>,
|
||||
}
|
||||
|
||||
impl ContextServerRegistry {
|
||||
pub fn global(cx: &AppContext) -> Arc<Self> {
|
||||
GlobalContextServerRegistry::global(cx).0.clone()
|
||||
}
|
||||
|
||||
pub fn register(cx: &mut AppContext) {
|
||||
cx.set_global(GlobalContextServerRegistry(Arc::new(
|
||||
ContextServerRegistry {
|
||||
registry: RwLock::new(HashMap::default()),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn register_command(&self, server_id: String, command_name: &str) {
|
||||
let mut registry = self.registry.write();
|
||||
registry
|
||||
.entry(server_id)
|
||||
.or_default()
|
||||
.push(command_name.into());
|
||||
}
|
||||
|
||||
pub fn unregister_command(&self, server_id: &str, command_name: &str) {
|
||||
let mut registry = self.registry.write();
|
||||
if let Some(commands) = registry.get_mut(server_id) {
|
||||
commands.retain(|name| name.as_ref() != command_name);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_commands(&self, server_id: &str) -> Option<Vec<Arc<str>>> {
|
||||
let registry = self.registry.read();
|
||||
registry.get(server_id).cloned()
|
||||
}
|
||||
}
|
234
crates/context_servers/src/types.rs
Normal file
234
crates/context_servers/src/types.rs
Normal file
|
@ -0,0 +1,234 @@
|
|||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum RequestType {
|
||||
Initialize,
|
||||
CallTool,
|
||||
ResourcesUnsubscribe,
|
||||
ResourcesSubscribe,
|
||||
ResourcesRead,
|
||||
ResourcesList,
|
||||
LoggingSetLevel,
|
||||
PromptsGet,
|
||||
PromptsList,
|
||||
}
|
||||
|
||||
impl RequestType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
RequestType::Initialize => "initialize",
|
||||
RequestType::CallTool => "tools/call",
|
||||
RequestType::ResourcesUnsubscribe => "resources/unsubscribe",
|
||||
RequestType::ResourcesSubscribe => "resources/subscribe",
|
||||
RequestType::ResourcesRead => "resources/read",
|
||||
RequestType::ResourcesList => "resources/list",
|
||||
RequestType::LoggingSetLevel => "logging/setLevel",
|
||||
RequestType::PromptsGet => "prompts/get",
|
||||
RequestType::PromptsList => "prompts/list",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeParams {
|
||||
pub protocol_version: u32,
|
||||
pub capabilities: ClientCapabilities,
|
||||
pub client_info: EntityInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CallToolParams {
|
||||
pub name: String,
|
||||
pub arguments: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesUnsubscribeParams {
|
||||
pub uri: Url,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesSubscribeParams {
|
||||
pub uri: Url,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesReadParams {
|
||||
pub uri: Url,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LoggingSetLevelParams {
|
||||
pub level: LoggingLevel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsGetParams {
|
||||
pub name: String,
|
||||
pub arguments: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResponse {
|
||||
pub protocol_version: u32,
|
||||
pub capabilities: ServerCapabilities,
|
||||
pub server_info: EntityInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesReadResponse {
|
||||
pub contents: Vec<ResourceContent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesListResponse {
|
||||
pub resource_templates: Option<Vec<ResourceTemplate>>,
|
||||
pub resources: Vec<Resource>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsGetResponse {
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsListResponse {
|
||||
pub prompts: Vec<PromptInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptInfo {
|
||||
pub name: String,
|
||||
pub arguments: Option<Vec<PromptArgument>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptArgument {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub required: Option<bool>,
|
||||
}
|
||||
|
||||
// Shared Types
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientCapabilities {
|
||||
pub experimental: Option<HashMap<String, serde_json::Value>>,
|
||||
pub sampling: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ServerCapabilities {
|
||||
pub experimental: Option<HashMap<String, serde_json::Value>>,
|
||||
pub logging: Option<HashMap<String, serde_json::Value>>,
|
||||
pub prompts: Option<HashMap<String, serde_json::Value>>,
|
||||
pub resources: Option<ResourcesCapabilities>,
|
||||
pub tools: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesCapabilities {
|
||||
pub subscribe: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Tool {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct EntityInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Resource {
|
||||
pub uri: Url,
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceContent {
|
||||
pub uri: Url,
|
||||
pub mime_type: Option<String>,
|
||||
pub content_type: String,
|
||||
pub text: Option<String>,
|
||||
pub data: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceTemplate {
|
||||
pub uri_template: String,
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum LoggingLevel {
|
||||
Debug,
|
||||
Info,
|
||||
Warning,
|
||||
Error,
|
||||
}
|
||||
|
||||
// Client Notifications
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum NotificationType {
|
||||
Initialized,
|
||||
Progress,
|
||||
}
|
||||
|
||||
impl NotificationType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
NotificationType::Initialized => "notifications/initialized",
|
||||
NotificationType::Progress => "notifications/progress",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ClientNotification {
|
||||
Initialized,
|
||||
Progress(ProgressParams),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProgressParams {
|
||||
pub progress_token: String,
|
||||
pub progress: f64,
|
||||
pub total: Option<f64>,
|
||||
}
|
Loading…
Reference in a new issue