mirror of
https://github.com/zed-industries/zed.git
synced 2024-10-26 08:31:04 +00:00
af5a9fabc6
Allows `LanguageModelTool`s to include nested structures, by exposing the definitions section of their JSON Schema. Release Notes: - N/A
219 lines
5.9 KiB
Rust
219 lines
5.9 KiB
Rust
use anyhow::Context as _;
|
|
use assets::Assets;
|
|
use assistant2::AssistantPanel;
|
|
use assistant_tooling::{LanguageModelTool, ToolRegistry};
|
|
use client::Client;
|
|
use gpui::{actions, AnyElement, App, AppContext, KeyBinding, Task, View, WindowOptions};
|
|
use language::LanguageRegistry;
|
|
use project::Project;
|
|
use rand::Rng;
|
|
use schemars::JsonSchema;
|
|
use serde::{Deserialize, Serialize};
|
|
use settings::{KeymapFile, DEFAULT_KEYMAP_PATH};
|
|
use std::sync::Arc;
|
|
use theme::LoadThemes;
|
|
use ui::{div, prelude::*, Render};
|
|
use util::ResultExt as _;
|
|
|
|
// Declares the `Quit` action in the `example` namespace; `main` binds it to
// `cmd-q` and quits the app when it fires.
actions!(example, [Quit]);
|
|
|
|
/// Language-model tool that rolls a number of dice and reports the results.
///
/// It holds no state, so every instance behaves identically.
struct RollDiceTool {}

impl RollDiceTool {
    /// Builds a new, field-less dice-rolling tool.
    fn new() -> Self {
        RollDiceTool {}
    }
}
|
|
|
|
#[derive(Serialize, Deserialize, JsonSchema, Clone)]
|
|
#[serde(rename_all = "snake_case")]
|
|
enum Die {
|
|
D6 = 6,
|
|
D20 = 20,
|
|
}
|
|
|
|
impl Die {
|
|
fn into_str(&self) -> &'static str {
|
|
match self {
|
|
Die::D6 => "d6",
|
|
Die::D20 => "d20",
|
|
}
|
|
}
|
|
}
|
|
|
|
// Arguments the model supplies when it calls the `roll_dice` tool
// (this is `RollDiceTool::Input`).
// NOTE(review): schemars turns `///` doc comments into schema `description`
// fields, so the field docs below are presumably model-facing. Keep them
// accurate. Plain `//` comments like this one are not picked up.
#[derive(Serialize, Deserialize, JsonSchema, Clone)]
struct DiceParams {
    /// The number of dice to roll.
    num_dice: u8,
    /// Which die to roll. Defaults to a d6 if not provided.
    die_type: Option<Die>,
}
|
|
|
|
/// The outcome of rolling a single die.
#[derive(Serialize, Deserialize)]
struct DieRoll {
    /// Which kind of die was rolled.
    die: Die,
    /// The value that came up. Rolls produced by `RollDiceTool::execute` fall
    /// in `1..=sides` for the die.
    roll: u8,
}
|
|
|
|
impl DieRoll {
|
|
fn render(&self) -> AnyElement {
|
|
match self.die {
|
|
Die::D6 => {
|
|
let face = match self.roll {
|
|
6 => div().child("⚅"),
|
|
5 => div().child("⚄"),
|
|
4 => div().child("⚃"),
|
|
3 => div().child("⚂"),
|
|
2 => div().child("⚁"),
|
|
1 => div().child("⚀"),
|
|
_ => div().child("😅"),
|
|
};
|
|
face.text_3xl().into_any_element()
|
|
}
|
|
_ => div()
|
|
.child(format!("{}", self.roll))
|
|
.text_3xl()
|
|
.into_any_element(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The result of one `roll_dice` tool call: one entry per die rolled, in
/// roll order.
#[derive(Serialize, Deserialize)]
struct DiceRoll {
    rolls: Vec<DieRoll>,
}
|
|
|
|
impl LanguageModelTool for RollDiceTool {
|
|
type Input = DiceParams;
|
|
type Output = DiceRoll;
|
|
|
|
fn name(&self) -> String {
|
|
"roll_dice".to_string()
|
|
}
|
|
|
|
fn description(&self) -> String {
|
|
"Rolls N many dice and returns the results.".to_string()
|
|
}
|
|
|
|
fn execute(&self, input: &Self::Input, _cx: &AppContext) -> Task<gpui::Result<Self::Output>> {
|
|
let rolls = (0..input.num_dice)
|
|
.map(|_| {
|
|
let die_type = input.die_type.as_ref().unwrap_or(&Die::D6).clone();
|
|
|
|
DieRoll {
|
|
die: die_type.clone(),
|
|
roll: rand::thread_rng().gen_range(1..=die_type as u8),
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
return Task::ready(Ok(DiceRoll { rolls }));
|
|
}
|
|
|
|
fn render(
|
|
_tool_call_id: &str,
|
|
_input: &Self::Input,
|
|
output: &Self::Output,
|
|
_cx: &mut WindowContext,
|
|
) -> gpui::AnyElement {
|
|
h_flex()
|
|
.children(
|
|
output
|
|
.rolls
|
|
.iter()
|
|
.map(|roll| div().p_2().child(roll.render())),
|
|
)
|
|
.into_any_element()
|
|
}
|
|
|
|
fn format(_input: &Self::Input, output: &Self::Output) -> String {
|
|
let mut result = String::new();
|
|
for roll in &output.rolls {
|
|
let die = &roll.die;
|
|
result.push_str(&format!("{}: {}\n", die.into_str(), roll.roll));
|
|
}
|
|
result
|
|
}
|
|
}
|
|
|
|
fn main() {
|
|
env_logger::init();
|
|
App::new().with_assets(Assets).run(|cx| {
|
|
cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None)));
|
|
cx.on_action(|_: &Quit, cx: &mut AppContext| {
|
|
cx.quit();
|
|
});
|
|
|
|
settings::init(cx);
|
|
language::init(cx);
|
|
Project::init_settings(cx);
|
|
editor::init(cx);
|
|
theme::init(LoadThemes::JustBase, cx);
|
|
Assets.load_fonts(cx).unwrap();
|
|
KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap();
|
|
client::init_settings(cx);
|
|
release_channel::init("0.130.0", cx);
|
|
|
|
let client = Client::production(cx);
|
|
{
|
|
let client = client.clone();
|
|
cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await })
|
|
.detach_and_log_err(cx);
|
|
}
|
|
assistant2::init(client.clone(), cx);
|
|
|
|
let language_registry = Arc::new(LanguageRegistry::new(
|
|
Task::ready(()),
|
|
cx.background_executor().clone(),
|
|
));
|
|
let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client());
|
|
languages::init(language_registry.clone(), node_runtime, cx);
|
|
|
|
cx.spawn(|cx| async move {
|
|
cx.update(|cx| {
|
|
let mut tool_registry = ToolRegistry::new();
|
|
tool_registry
|
|
.register(RollDiceTool::new())
|
|
.context("failed to register DummyTool")
|
|
.log_err();
|
|
|
|
let tool_registry = Arc::new(tool_registry);
|
|
|
|
println!("Tools registered");
|
|
for definition in tool_registry.definitions() {
|
|
println!("{}", definition);
|
|
}
|
|
|
|
cx.open_window(WindowOptions::default(), |cx| {
|
|
cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
|
|
});
|
|
cx.activate(true);
|
|
})
|
|
})
|
|
.detach_and_log_err(cx);
|
|
})
|
|
}
|
|
|
|
struct Example {
|
|
assistant_panel: View<AssistantPanel>,
|
|
}
|
|
|
|
impl Example {
|
|
fn new(
|
|
language_registry: Arc<LanguageRegistry>,
|
|
tool_registry: Arc<ToolRegistry>,
|
|
cx: &mut ViewContext<Self>,
|
|
) -> Self {
|
|
Self {
|
|
assistant_panel: cx
|
|
.new_view(|cx| AssistantPanel::new(language_registry, tool_registry, cx)),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Render for Example {
|
|
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl ui::prelude::IntoElement {
|
|
div().size_full().child(self.assistant_panel.clone())
|
|
}
|
|
}
|