diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 3db057d074..9e6c4b7718 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -145,7 +145,28 @@ impl SlashCommand for ContextServerSlashCommand { return Err(anyhow!("Context server not initialized")); }; let result = protocol.run_prompt(&prompt_name, prompt_args).await?; - let mut prompt = result.prompt; + + // Check that there are only user roles + if result + .messages + .iter() + .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User)) + { + return Err(anyhow!( + "Prompt contains non-user roles, which is not supported" + )); + } + + // Extract text from user messages into a single prompt string + let mut prompt = result + .messages + .into_iter() + .filter_map(|msg| match msg.content { + context_servers::types::SamplingContent::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join("\n\n"); // We must normalize the line endings here, since servers might return CR characters. LineEnding::normalize(&mut prompt); diff --git a/crates/context_servers/src/protocol.rs b/crates/context_servers/src/protocol.rs index 451db56ef3..80a7a7f991 100644 --- a/crates/context_servers/src/protocol.rs +++ b/crates/context_servers/src/protocol.rs @@ -11,7 +11,7 @@ use collections::HashMap; use crate::client::Client; use crate::types; -const PROTOCOL_VERSION: u32 = 1; +const PROTOCOL_VERSION: &str = "2024-10-07"; pub struct ModelContextProtocol { inner: Client, @@ -22,12 +22,19 @@ impl ModelContextProtocol { Self { inner } } + fn supported_protocols() -> Vec { + vec![ + types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()), + types::ProtocolVersion::VersionNumber(1), + ] + } + pub async fn initialize( self, client_info: types::Implementation, ) -> Result { let params = types::InitializeParams { - protocol_version: PROTOCOL_VERSION, + protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()), capabilities: types::ClientCapabilities { experimental: None, sampling: None, @@ -40,6 +47,13 @@ impl ModelContextProtocol { .request(types::RequestType::Initialize.as_str(), params) .await?; + if !Self::supported_protocols().contains(&response.protocol_version) { + return Err(anyhow::anyhow!( + "Unsupported protocol version: {:?}", + response.protocol_version + )); + } + log::trace!("mcp server info {:?}", response.server_info); self.inner.notify( diff --git a/crates/context_servers/src/types.rs b/crates/context_servers/src/types.rs index 04ac87c704..2bca0a021a 100644 --- a/crates/context_servers/src/types.rs +++ b/crates/context_servers/src/types.rs @@ -36,10 +36,17 @@ impl RequestType { } } +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ProtocolVersion { + VersionString(String), + VersionNumber(u32), +} + #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct InitializeParams { - pub protocol_version: u32, + pub protocol_version: ProtocolVersion, pub capabilities: ClientCapabilities, pub client_info: Implementation, } @@ -131,7 +138,7 @@ pub struct CompletionArgument { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InitializeResponse { - pub protocol_version: u32, + pub protocol_version: ProtocolVersion, pub capabilities: ServerCapabilities, pub server_info: Implementation, } @@ -145,10 +152,9 @@ pub struct ResourcesReadResponse { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesListResponse { + pub resources: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub resource_templates: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub resources: Option>, + pub next_cursor: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -179,13 +185,15 @@ pub enum SamplingContent { pub struct PromptsGetResponse { #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, - pub prompt: String, + pub messages: Vec, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsListResponse { pub prompts: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, } #[derive(Debug, Deserialize)]