mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-28 11:29:25 +00:00
assistant: Add tool registry (#17331)
This PR adds a tool registry to hold tools that can be called by the Assistant. Currently we just have a `now` tool for retrieving the current datetime. This is all behind the `assistant-tool-use` feature flag which currently needs to be explicitly opted-in to in order for the LLM to see the tools. Release Notes: - N/A
This commit is contained in:
parent
c2448e1673
commit
e81b484bf2
11 changed files with 243 additions and 2 deletions
15
Cargo.lock
generated
15
Cargo.lock
generated
|
@ -373,6 +373,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"assets",
|
||||
"assistant_slash_command",
|
||||
"assistant_tool",
|
||||
"async-watch",
|
||||
"cargo_toml",
|
||||
"chrono",
|
||||
|
@ -454,6 +455,20 @@ dependencies = [
|
|||
"workspace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assistant_tool"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"derive_more",
|
||||
"gpui",
|
||||
"parking_lot",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-attributes"
|
||||
version = "1.1.2"
|
||||
|
|
|
@ -6,6 +6,7 @@ members = [
|
|||
"crates/assets",
|
||||
"crates/assistant",
|
||||
"crates/assistant_slash_command",
|
||||
"crates/assistant_tool",
|
||||
"crates/audio",
|
||||
"crates/auto_update",
|
||||
"crates/breadcrumbs",
|
||||
|
@ -181,6 +182,7 @@ anthropic = { path = "crates/anthropic" }
|
|||
assets = { path = "crates/assets" }
|
||||
assistant = { path = "crates/assistant" }
|
||||
assistant_slash_command = { path = "crates/assistant_slash_command" }
|
||||
assistant_tool = { path = "crates/assistant_tool" }
|
||||
audio = { path = "crates/audio" }
|
||||
auto_update = { path = "crates/auto_update" }
|
||||
breadcrumbs = { path = "crates/breadcrumbs" }
|
||||
|
|
|
@ -25,6 +25,7 @@ anthropic = { workspace = true, features = ["schemars"] }
|
|||
anyhow.workspace = true
|
||||
assets.workspace = true
|
||||
assistant_slash_command.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
async-watch.workspace = true
|
||||
cargo_toml.workspace = true
|
||||
chrono.workspace = true
|
||||
|
|
|
@ -13,11 +13,13 @@ pub(crate) mod slash_command_picker;
|
|||
pub mod slash_command_settings;
|
||||
mod streaming_diff;
|
||||
mod terminal_inline_assistant;
|
||||
mod tools;
|
||||
mod workflow;
|
||||
|
||||
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_slash_command::SlashCommandRegistry;
|
||||
use assistant_tool::ToolRegistry;
|
||||
use client::{proto, Client};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
pub use context::*;
|
||||
|
@ -214,6 +216,7 @@ pub fn init(
|
|||
prompt_library::init(cx);
|
||||
init_language_model_settings(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_tool::init(cx);
|
||||
assistant_panel::init(cx);
|
||||
context_servers::init(cx);
|
||||
|
||||
|
@ -228,6 +231,7 @@ pub fn init(
|
|||
.map(Arc::new)
|
||||
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
|
||||
register_slash_commands(Some(prompt_builder.clone()), cx);
|
||||
register_tools(cx);
|
||||
inline_assistant::init(
|
||||
fs.clone(),
|
||||
prompt_builder.clone(),
|
||||
|
@ -401,6 +405,11 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
|
|||
}
|
||||
}
|
||||
|
||||
fn register_tools(cx: &mut AppContext) {
|
||||
let tool_registry = ToolRegistry::global(cx);
|
||||
tool_registry.register_tool(tools::now_tool::NowTool);
|
||||
}
|
||||
|
||||
pub fn humanize_token_count(count: usize) -> String {
|
||||
match count {
|
||||
0..=999 => count.to_string(),
|
||||
|
|
|
@ -9,9 +9,11 @@ use anyhow::{anyhow, Context as _, Result};
|
|||
use assistant_slash_command::{
|
||||
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
|
||||
};
|
||||
use assistant_tool::ToolRegistry;
|
||||
use client::{self, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use collections::{HashMap, HashSet};
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::{
|
||||
future::{self, Shared},
|
||||
|
@ -27,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
|
|||
use language_model::{
|
||||
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
||||
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
MessageContent, Role,
|
||||
LanguageModelRequestTool, MessageContent, Role,
|
||||
};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use paths::{context_images_dir, contexts_dir};
|
||||
|
@ -1942,7 +1944,21 @@ impl Context {
|
|||
// Compute which messages to cache, including the last one.
|
||||
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
|
||||
|
||||
let request = self.to_completion_request(cx);
|
||||
let mut request = self.to_completion_request(cx);
|
||||
|
||||
if cx.has_flag::<ToolUseFeatureFlag>() {
|
||||
let tool_registry = ToolRegistry::global(cx);
|
||||
request.tools = tool_registry
|
||||
.tools()
|
||||
.into_iter()
|
||||
.map(|tool| LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema: tool.input_schema(),
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
let assistant_message = self
|
||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||
.unwrap();
|
||||
|
@ -2788,6 +2804,16 @@ pub enum PendingSlashCommandStatus {
|
|||
Error(String),
|
||||
}
|
||||
|
||||
pub(crate) struct ToolUseFeatureFlag;
|
||||
|
||||
impl FeatureFlag for ToolUseFeatureFlag {
|
||||
const NAME: &'static str = "assistant-tool-use";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PendingToolUse {
|
||||
pub id: String,
|
||||
|
|
1
crates/assistant/src/tools.rs
Normal file
1
crates/assistant/src/tools.rs
Normal file
|
@ -0,0 +1 @@
|
|||
pub mod now_tool;
|
60
crates/assistant/src/tools/now_tool.rs
Normal file
60
crates/assistant/src/tools/now_tool.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_tool::Tool;
|
||||
use chrono::{Local, Utc};
|
||||
use gpui::{Task, WeakView, WindowContext};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Timezone {
|
||||
/// Use UTC for the datetime.
|
||||
Utc,
|
||||
/// Use local time for the datetime.
|
||||
Local,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct FileToolInput {
|
||||
/// The timezone to use for the datetime.
|
||||
timezone: Timezone,
|
||||
}
|
||||
|
||||
pub struct NowTool;
|
||||
|
||||
impl Tool for NowTool {
|
||||
fn name(&self) -> String {
|
||||
"now".into()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Returns the current datetime in RFC 3339 format.".into()
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
let schema = schemars::schema_for!(FileToolInput);
|
||||
serde_json::to_value(&schema).unwrap()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_workspace: WeakView<workspace::Workspace>,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Task<Result<String>> {
|
||||
let input: FileToolInput = match serde_json::from_value(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||
};
|
||||
|
||||
let now = match input.timezone {
|
||||
Timezone::Utc => Utc::now().to_rfc3339(),
|
||||
Timezone::Local => Local::now().to_rfc3339(),
|
||||
};
|
||||
let text = format!("The current datetime is {now}.");
|
||||
|
||||
Task::ready(Ok(text))
|
||||
}
|
||||
}
|
22
crates/assistant_tool/Cargo.toml
Normal file
22
crates/assistant_tool/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "assistant_tool"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/assistant_tool.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
derive_more.workspace = true
|
||||
gpui.workspace = true
|
||||
parking_lot.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
workspace.workspace = true
|
1
crates/assistant_tool/LICENSE-GPL
Symbolic link
1
crates/assistant_tool/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
35
crates/assistant_tool/src/assistant_tool.rs
Normal file
35
crates/assistant_tool/src/assistant_tool.rs
Normal file
|
@ -0,0 +1,35 @@
|
|||
mod tool_registry;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::{AppContext, Task, WeakView, WindowContext};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use tool_registry::*;
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
ToolRegistry::default_global(cx);
|
||||
}
|
||||
|
||||
/// A tool that can be used by a language model.
|
||||
pub trait Tool: 'static + Send + Sync {
|
||||
/// Returns the name of the tool.
|
||||
fn name(&self) -> String;
|
||||
|
||||
/// Returns the description of the tool.
|
||||
fn description(&self) -> String;
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::Value::Object(serde_json::Map::default())
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
workspace: WeakView<Workspace>,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<String>>;
|
||||
}
|
69
crates/assistant_tool/src/tool_registry.rs
Normal file
69
crates/assistant_tool/src/tool_registry.rs
Normal file
|
@ -0,0 +1,69 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use collections::HashMap;
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use gpui::Global;
|
||||
use gpui::{AppContext, ReadGlobal};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::Tool;
|
||||
|
||||
#[derive(Default, Deref, DerefMut)]
|
||||
struct GlobalToolRegistry(Arc<ToolRegistry>);
|
||||
|
||||
impl Global for GlobalToolRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ToolRegistryState {
|
||||
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ToolRegistry {
|
||||
state: RwLock<ToolRegistryState>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// Returns the global [`ToolRegistry`].
|
||||
pub fn global(cx: &AppContext) -> Arc<Self> {
|
||||
GlobalToolRegistry::global(cx).0.clone()
|
||||
}
|
||||
|
||||
/// Returns the global [`ToolRegistry`].
|
||||
///
|
||||
/// Inserts a default [`ToolRegistry`] if one does not yet exist.
|
||||
pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
|
||||
cx.default_global::<GlobalToolRegistry>().0.clone()
|
||||
}
|
||||
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
state: RwLock::new(ToolRegistryState {
|
||||
tools: HashMap::default(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
/// Registers the provided [`Tool`].
|
||||
pub fn register_tool(&self, tool: impl Tool) {
|
||||
let mut state = self.state.write();
|
||||
let tool_name: Arc<str> = tool.name().into();
|
||||
state.tools.insert(tool_name, Arc::new(tool));
|
||||
}
|
||||
|
||||
/// Unregisters the provided [`Tool`].
|
||||
pub fn unregister_tool(&self, tool: impl Tool) {
|
||||
self.unregister_tool_by_name(tool.name().as_str())
|
||||
}
|
||||
|
||||
/// Unregisters the tool with the given name.
|
||||
pub fn unregister_tool_by_name(&self, tool_name: &str) {
|
||||
let mut state = self.state.write();
|
||||
state.tools.remove(tool_name);
|
||||
}
|
||||
|
||||
/// Returns the list of tools in the registry.
|
||||
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
|
||||
self.state.read().tools.values().cloned().collect()
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue