From 5bae6eb4933476736d94cce5d1d53e719c1f2f52 Mon Sep 17 00:00:00 2001 From: David Soria Parra <167242713+dsp-ant@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:56:58 +0100 Subject: [PATCH] context_servers: Completion support for context server slash commands (#17085) This PR adds support for completions via MCP. The protocol now supports a new request type "completion/complete" that can either complete a resource URI template (which we currently don't use in Zed), or a prompt argument. We use this to add autocompletion to our context server slash commands! https://github.com/user-attachments/assets/08c9cf04-cbeb-49a7-903f-5049fb3b3d9f Release Notes: - context_servers: Added support for argument completions for context server prompts. These show up as regular completions to slash commands. --- .../slash_command/context_server_command.rs | 72 ++++++++++++++-- crates/context_servers/src/protocol.rs | 29 +++++++ crates/context_servers/src/types.rs | 83 +++++++++++++++++++ 3 files changed, 179 insertions(+), 5 deletions(-) diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 66c8ab21d3..7dc9b34ceb 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, Result}; use assistant_slash_command::{ - ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, + AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput, + SlashCommandOutputSection, }; use collections::HashMap; use context_servers::{ @@ -8,7 +9,7 @@ use context_servers::{ protocol::PromptInfo, }; use gpui::{Task, WeakView, WindowContext}; -use language::LspAdapterDelegate; +use language::{CodeLabel, LspAdapterDelegate}; use std::sync::atomic::AtomicBool; use std::sync::Arc; use ui::{IconName, SharedString}; @@ -50,12 +51,57 @@ impl SlashCommand for ContextServerSlashCommand { fn complete_argument( self: Arc, - _arguments: &[String], + arguments: &[String], _cancel: Arc, _workspace: Option>, - _cx: &mut WindowContext, + cx: &mut WindowContext, ) -> Task>> { - Task::ready(Ok(Vec::new())) + let server_id = self.server_id.clone(); + let prompt_name = self.prompt.name.clone(); + let manager = ContextServerManager::global(cx); + let manager = manager.read(cx); + + let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) { + Ok(tp) => tp, + Err(e) => { + return Task::ready(Err(e)); + } + }; + if let Some(server) = manager.get_server(&server_id) { + cx.foreground_executor().spawn(async move { + let Some(protocol) = server.client.read().clone() else { + return Err(anyhow!("Context server not initialized")); + }; + + let completion_result = protocol + .completion( + context_servers::types::CompletionReference::Prompt( + context_servers::types::PromptReference { + r#type: context_servers::types::PromptReferenceType::Prompt, + name: prompt_name, + }, + ), + arg_name, + arg_val, + ) + .await?; + + let completions = completion_result + .values + .into_iter() + .map(|value| ArgumentCompletion { + label: CodeLabel::plain(value.clone(), None), + new_text: value, + after_completion: AfterCompletion::Continue, + replace_previous_arguments: false, + }) + .collect(); + + Ok(completions) + }) + } else { + Task::ready(Err(anyhow!("Context server not found"))) + } } fn run( @@ -102,6 +148,22 @@ impl SlashCommand for ContextServerSlashCommand { } } +fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> { + if arguments.is_empty() { + return Err(anyhow!("No arguments given")); + } + + match &prompt.arguments { + Some(args) if args.len() == 1 => { + let arg_name = args[0].name.clone(); + let arg_value = arguments.join(" "); + Ok((arg_name, arg_value)) + } + Some(_) => Err(anyhow!("Prompt must have exactly one argument")), + None => Err(anyhow!("Prompt has no arguments")), + } +} + fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result> { match &prompt.arguments { Some(args) if args.len() > 1 => Err(anyhow!( diff --git a/crates/context_servers/src/protocol.rs b/crates/context_servers/src/protocol.rs index 1440f8248a..87da217f7d 100644 --- a/crates/context_servers/src/protocol.rs +++ b/crates/context_servers/src/protocol.rs @@ -127,6 +127,35 @@ impl InitializedContextServerProtocol { Ok(response) } + + pub async fn completion>( + &self, + reference: types::CompletionReference, + argument: P, + value: P, + ) -> Result { + let params = types::CompletionCompleteParams { + r#ref: reference, + argument: types::CompletionArgument { + name: argument.into(), + value: value.into(), + }, + }; + let result: types::CompletionCompleteResponse = self + .inner + .request(types::RequestType::CompletionComplete.as_str(), params) + .await?; + + let completion = types::Completion { + values: result.completion.values, + total: types::CompletionTotal::from_options( + result.completion.has_more, + result.completion.total, + ), + }; + + Ok(completion) + } } impl InitializedContextServerProtocol { diff --git a/crates/context_servers/src/types.rs b/crates/context_servers/src/types.rs index d880a45bd4..c0e9a79f15 100644 --- a/crates/context_servers/src/types.rs +++ b/crates/context_servers/src/types.rs @@ -14,6 +14,7 @@ pub enum RequestType { LoggingSetLevel, PromptsGet, PromptsList, + CompletionComplete, } impl RequestType { @@ -28,6 +29,7 @@ impl RequestType { RequestType::LoggingSetLevel => "logging/setLevel", RequestType::PromptsGet => "prompts/get", RequestType::PromptsList => "prompts/list", + RequestType::CompletionComplete => "completion/complete", } } } @@ -78,6 +80,50 @@ pub struct PromptsGetParams { pub arguments: Option>, } +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CompletionCompleteParams { + pub r#ref: CompletionReference, + pub argument: CompletionArgument, +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum CompletionReference { + Prompt(PromptReference), + Resource(ResourceReference), +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptReference { + pub r#type: PromptReferenceType, + pub name: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum PromptReferenceType { + #[serde(rename = "ref/prompt")] + Prompt, + #[serde(rename = "ref/resource")] + Resource, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourceReference { + pub r#type: String, + pub uri: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CompletionArgument { + pub name: String, + pub value: String, +} + #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InitializeResponse { @@ -112,6 +158,20 @@ pub struct PromptsListResponse { pub prompts: Vec, } +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CompletionCompleteResponse { + pub completion: CompletionResult, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CompletionResult { + pub values: Vec, + pub total: Option, + pub has_more: Option, +} + #[derive(Debug, Deserialize, Clone)] #[serde(rename_all = "camelCase")] pub struct PromptInfo { @@ -233,3 +293,26 @@ pub struct ProgressParams { pub progress: f64, pub total: Option, } + +// Helper Types that don't map directly to the protocol + +pub enum CompletionTotal { + Exact(u32), + HasMore, + Unknown, +} + +impl CompletionTotal { + pub fn from_options(has_more: Option, total: Option) -> Self { + match (has_more, total) { + (_, Some(count)) => CompletionTotal::Exact(count), + (Some(true), _) => CompletionTotal::HasMore, + _ => CompletionTotal::Unknown, + } + } +} + +pub struct Completion { + pub values: Vec, + pub total: CompletionTotal, +}