mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-12 13:24:19 +00:00
context_servers: Completion support for context server slash commands (#17085)
This PR adds support for completions via MCP. The protocol now supports a new request type "completion/complete" that can either complete a resource URI template (which we currently don't use in Zed), or a prompt argument. We use this to add autocompletion to our context server slash commands! https://github.com/user-attachments/assets/08c9cf04-cbeb-49a7-903f-5049fb3b3d9f Release Notes: - context_servers: Added support for argument completions for context server prompts. These show up as regular completions to slash commands.
This commit is contained in:
parent
01f8d27f22
commit
5bae6eb493
3 changed files with 179 additions and 5 deletions
|
@ -1,6 +1,7 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::{
|
||||
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
|
||||
AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
|
||||
SlashCommandOutputSection,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use context_servers::{
|
||||
|
@ -8,7 +9,7 @@ use context_servers::{
|
|||
protocol::PromptInfo,
|
||||
};
|
||||
use gpui::{Task, WeakView, WindowContext};
|
||||
use language::LspAdapterDelegate;
|
||||
use language::{CodeLabel, LspAdapterDelegate};
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use ui::{IconName, SharedString};
|
||||
|
@ -50,12 +51,57 @@ impl SlashCommand for ContextServerSlashCommand {
|
|||
|
||||
fn complete_argument(
|
||||
self: Arc<Self>,
|
||||
_arguments: &[String],
|
||||
arguments: &[String],
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Ok(Vec::new()))
|
||||
let server_id = self.server_id.clone();
|
||||
let prompt_name = self.prompt.name.clone();
|
||||
let manager = ContextServerManager::global(cx);
|
||||
let manager = manager.read(cx);
|
||||
|
||||
let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
|
||||
Ok(tp) => tp,
|
||||
Err(e) => {
|
||||
return Task::ready(Err(e));
|
||||
}
|
||||
};
|
||||
if let Some(server) = manager.get_server(&server_id) {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let Some(protocol) = server.client.read().clone() else {
|
||||
return Err(anyhow!("Context server not initialized"));
|
||||
};
|
||||
|
||||
let completion_result = protocol
|
||||
.completion(
|
||||
context_servers::types::CompletionReference::Prompt(
|
||||
context_servers::types::PromptReference {
|
||||
r#type: context_servers::types::PromptReferenceType::Prompt,
|
||||
name: prompt_name,
|
||||
},
|
||||
),
|
||||
arg_name,
|
||||
arg_val,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let completions = completion_result
|
||||
.values
|
||||
.into_iter()
|
||||
.map(|value| ArgumentCompletion {
|
||||
label: CodeLabel::plain(value.clone(), None),
|
||||
new_text: value,
|
||||
after_completion: AfterCompletion::Continue,
|
||||
replace_previous_arguments: false,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(completions)
|
||||
})
|
||||
} else {
|
||||
Task::ready(Err(anyhow!("Context server not found")))
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
|
@ -102,6 +148,22 @@ impl SlashCommand for ContextServerSlashCommand {
|
|||
}
|
||||
}
|
||||
|
||||
fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
|
||||
if arguments.is_empty() {
|
||||
return Err(anyhow!("No arguments given"));
|
||||
}
|
||||
|
||||
match &prompt.arguments {
|
||||
Some(args) if args.len() == 1 => {
|
||||
let arg_name = args[0].name.clone();
|
||||
let arg_value = arguments.join(" ");
|
||||
Ok((arg_name, arg_value))
|
||||
}
|
||||
Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
|
||||
None => Err(anyhow!("Prompt has no arguments")),
|
||||
}
|
||||
}
|
||||
|
||||
fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
|
||||
match &prompt.arguments {
|
||||
Some(args) if args.len() > 1 => Err(anyhow!(
|
||||
|
|
|
@ -127,6 +127,35 @@ impl InitializedContextServerProtocol {
|
|||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn completion<P: Into<String>>(
|
||||
&self,
|
||||
reference: types::CompletionReference,
|
||||
argument: P,
|
||||
value: P,
|
||||
) -> Result<types::Completion> {
|
||||
let params = types::CompletionCompleteParams {
|
||||
r#ref: reference,
|
||||
argument: types::CompletionArgument {
|
||||
name: argument.into(),
|
||||
value: value.into(),
|
||||
},
|
||||
};
|
||||
let result: types::CompletionCompleteResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::CompletionComplete.as_str(), params)
|
||||
.await?;
|
||||
|
||||
let completion = types::Completion {
|
||||
values: result.completion.values,
|
||||
total: types::CompletionTotal::from_options(
|
||||
result.completion.has_more,
|
||||
result.completion.total,
|
||||
),
|
||||
};
|
||||
|
||||
Ok(completion)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitializedContextServerProtocol {
|
||||
|
|
|
@ -14,6 +14,7 @@ pub enum RequestType {
|
|||
LoggingSetLevel,
|
||||
PromptsGet,
|
||||
PromptsList,
|
||||
CompletionComplete,
|
||||
}
|
||||
|
||||
impl RequestType {
|
||||
|
@ -28,6 +29,7 @@ impl RequestType {
|
|||
RequestType::LoggingSetLevel => "logging/setLevel",
|
||||
RequestType::PromptsGet => "prompts/get",
|
||||
RequestType::PromptsList => "prompts/list",
|
||||
RequestType::CompletionComplete => "completion/complete",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -78,6 +80,50 @@ pub struct PromptsGetParams {
|
|||
pub arguments: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionCompleteParams {
|
||||
pub r#ref: CompletionReference,
|
||||
pub argument: CompletionArgument,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CompletionReference {
|
||||
Prompt(PromptReference),
|
||||
Resource(ResourceReference),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptReference {
|
||||
pub r#type: PromptReferenceType,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PromptReferenceType {
|
||||
#[serde(rename = "ref/prompt")]
|
||||
Prompt,
|
||||
#[serde(rename = "ref/resource")]
|
||||
Resource,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceReference {
|
||||
pub r#type: String,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionArgument {
|
||||
pub name: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResponse {
|
||||
|
@ -112,6 +158,20 @@ pub struct PromptsListResponse {
|
|||
pub prompts: Vec<PromptInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionCompleteResponse {
|
||||
pub completion: CompletionResult,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionResult {
|
||||
pub values: Vec<String>,
|
||||
pub total: Option<u32>,
|
||||
pub has_more: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptInfo {
|
||||
|
@ -233,3 +293,26 @@ pub struct ProgressParams {
|
|||
pub progress: f64,
|
||||
pub total: Option<f64>,
|
||||
}
|
||||
|
||||
// Helper Types that don't map directly to the protocol
|
||||
|
||||
pub enum CompletionTotal {
|
||||
Exact(u32),
|
||||
HasMore,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl CompletionTotal {
|
||||
pub fn from_options(has_more: Option<bool>, total: Option<u32>) -> Self {
|
||||
match (has_more, total) {
|
||||
(_, Some(count)) => CompletionTotal::Exact(count),
|
||||
(Some(true), _) => CompletionTotal::HasMore,
|
||||
_ => CompletionTotal::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Completion {
|
||||
pub values: Vec<String>,
|
||||
pub total: CompletionTotal,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue