context_servers: Update protocol (#19547)

We sadly have to change the underlying protocol once again. This will
likely be the last change to the core protocol without correctly
handling older versions. From here on out, we want to get better with
version handling. To do so, we introduce the notion of a string protocol
version to be explicit of when the underlying protocol last changed.

The change also changes the return values of prompts. For now we only
allow User messages from servers to match the current behaviour. We will
change this once #19222 lands which will allow slash commands to insert
user and assistant messages.

Release Notes:

- N/A
This commit is contained in:
David Soria Parra 2024-10-22 16:19:32 +01:00 committed by GitHub
parent 680b3dd80b
commit d8d8c908ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 52 additions and 9 deletions

View file

@ -145,7 +145,28 @@ impl SlashCommand for ContextServerSlashCommand {
return Err(anyhow!("Context server not initialized"));
};
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
let mut prompt = result.prompt;
// Check that there are only user roles
if result
.messages
.iter()
.any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
{
return Err(anyhow!(
"Prompt contains non-user roles, which is not supported"
));
}
// Extract text from user messages into a single prompt string
let mut prompt = result
.messages
.into_iter()
.filter_map(|msg| match msg.content {
context_servers::types::SamplingContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()
.join("\n\n");
// We must normalize the line endings here, since servers might return CR characters.
LineEnding::normalize(&mut prompt);

View file

@ -11,7 +11,7 @@ use collections::HashMap;
use crate::client::Client;
use crate::types;
const PROTOCOL_VERSION: u32 = 1;
const PROTOCOL_VERSION: &str = "2024-10-07";
pub struct ModelContextProtocol {
inner: Client,
@ -22,12 +22,19 @@ impl ModelContextProtocol {
Self { inner }
}
fn supported_protocols() -> Vec<types::ProtocolVersion> {
vec![
types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
types::ProtocolVersion::VersionNumber(1),
]
}
pub async fn initialize(
self,
client_info: types::Implementation,
) -> Result<InitializedContextServerProtocol> {
let params = types::InitializeParams {
protocol_version: PROTOCOL_VERSION,
protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
capabilities: types::ClientCapabilities {
experimental: None,
sampling: None,
@ -40,6 +47,13 @@ impl ModelContextProtocol {
.request(types::RequestType::Initialize.as_str(), params)
.await?;
if !Self::supported_protocols().contains(&response.protocol_version) {
return Err(anyhow::anyhow!(
"Unsupported protocol version: {:?}",
response.protocol_version
));
}
log::trace!("mcp server info {:?}", response.server_info);
self.inner.notify(

View file

@ -36,10 +36,17 @@ impl RequestType {
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProtocolVersion {
VersionString(String),
VersionNumber(u32),
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
pub protocol_version: u32,
pub protocol_version: ProtocolVersion,
pub capabilities: ClientCapabilities,
pub client_info: Implementation,
}
@ -131,7 +138,7 @@ pub struct CompletionArgument {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
pub protocol_version: u32,
pub protocol_version: ProtocolVersion,
pub capabilities: ServerCapabilities,
pub server_info: Implementation,
}
@ -145,10 +152,9 @@ pub struct ResourcesReadResponse {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesListResponse {
pub resources: Vec<Resource>,
#[serde(skip_serializing_if = "Option::is_none")]
pub resource_templates: Option<Vec<ResourceTemplate>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub resources: Option<Vec<Resource>>,
pub next_cursor: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
@ -179,13 +185,15 @@ pub enum SamplingContent {
pub struct PromptsGetResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub prompt: String,
pub messages: Vec<SamplingMessage>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsListResponse {
pub prompts: Vec<Prompt>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
}
#[derive(Debug, Deserialize)]