New revision of the Assistant Panel (#10870)

This is a crate only addition of a new version of the AssistantPanel.
We'll be putting this behind a feature flag while we iron out the new
experience.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Conrad Irwin <conrad@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Co-authored-by: Antonio Scandurra <antonio@zed.dev>
Co-authored-by: Nate Butler <nate@zed.dev>
Co-authored-by: Nate Butler <iamnbutler@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Kyle Kelley 2024-04-23 16:23:26 -07:00 committed by GitHub
parent e0c83a1d32
commit 68a1ad89bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
55 changed files with 2989 additions and 262 deletions

View file

@ -3,5 +3,10 @@
"label": "clippy",
"command": "cargo",
"args": ["xtask", "clippy"]
},
{
"label": "assistant2",
"command": "cargo",
"args": ["run", "-p", "assistant2", "--example", "assistant_example"]
}
]

133
Cargo.lock generated
View file

@ -371,6 +371,50 @@ dependencies = [
"workspace",
]
[[package]]
name = "assistant2"
version = "0.1.0"
dependencies = [
"anyhow",
"assets",
"assistant_tooling",
"client",
"editor",
"env_logger",
"feature_flags",
"futures 0.3.28",
"gpui",
"language",
"languages",
"log",
"nanoid",
"node_runtime",
"open_ai",
"project",
"release_channel",
"rich_text",
"schemars",
"semantic_index",
"serde",
"serde_json",
"settings",
"theme",
"ui",
"util",
"workspace",
]
[[package]]
name = "assistant_tooling"
version = "0.1.0"
dependencies = [
"anyhow",
"gpui",
"schemars",
"serde",
"serde_json",
]
[[package]]
name = "async-broadcast"
version = "0.7.0"
@ -643,7 +687,7 @@ checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -710,7 +754,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -741,7 +785,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -1385,7 +1429,7 @@ dependencies = [
"regex",
"rustc-hash",
"shlex",
"syn 2.0.48",
"syn 2.0.59",
"which 4.4.2",
]
@ -1468,7 +1512,7 @@ source = "git+https://github.com/kvark/blade?rev=810ec594358aafea29a4a3d8ab601d2
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -1634,7 +1678,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -2019,7 +2063,7 @@ dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -2959,7 +3003,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e"
dependencies = [
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -3442,7 +3486,7 @@ checksum = "5c785274071b1b420972453b306eeca06acf4633829db4223b58a2a8c5953bc4"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -3954,7 +3998,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -4194,7 +4238,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -5061,7 +5105,7 @@ checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -5682,7 +5726,7 @@ checksum = "ba125974b109d512fccbc6c0244e7580143e460895dfd6ea7f8bbb692fd94396"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -6643,7 +6687,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -6719,7 +6763,7 @@ dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -6799,7 +6843,7 @@ checksum = "e8890702dbec0bad9116041ae586f84805b13eecd1d8b1df27c29998a9969d6d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -6977,7 +7021,7 @@ dependencies = [
"phf_shared",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -7028,7 +7072,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -7252,7 +7296,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d"
dependencies = [
"proc-macro2",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -7309,9 +7353,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.78"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba"
dependencies = [
"unicode-ident",
]
@ -7332,7 +7376,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
dependencies = [
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -8175,7 +8219,7 @@ dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"syn 2.0.48",
"syn 2.0.59",
"walkdir",
]
@ -8449,7 +8493,7 @@ dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -8490,7 +8534,7 @@ dependencies = [
"proc-macro2",
"quote",
"sea-bae",
"syn 2.0.48",
"syn 2.0.59",
"unicode-ident",
]
@ -8674,7 +8718,7 @@ checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -8739,7 +8783,7 @@ checksum = "8725e1dfadb3a50f7e5ce0b1a540466f6ed3fe7a0fca2ac2b8b831d31316bd00"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -9505,7 +9549,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -9634,9 +9678,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.48"
version = "2.0.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f"
checksum = "4a6531ffc7b071655e4ce2e04bd464c4830bb585a61cabb96cf808f05172615a"
dependencies = [
"proc-macro2",
"quote",
@ -10001,7 +10045,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -10180,7 +10224,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -10405,7 +10449,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -11172,7 +11216,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
"wasm-bindgen-shared",
]
@ -11206,7 +11250,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@ -11343,7 +11387,7 @@ dependencies = [
"anyhow",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
"wasmtime-component-util",
"wasmtime-wit-bindgen",
"wit-parser",
@ -11504,7 +11548,7 @@ checksum = "6d6d967f01032da7d4c6303da32f6a00d5efe1bac124b156e7342d8ace6ffdfc"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -11784,7 +11828,7 @@ dependencies = [
"proc-macro2",
"quote",
"shellexpand",
"syn 2.0.48",
"syn 2.0.59",
"witx",
]
@ -11796,7 +11840,7 @@ checksum = "512d816dbcd0113103b2eb2402ec9018e7f0755202a5b3e67db726f229d8dcae"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
"wiggle-generate",
]
@ -11914,7 +11958,7 @@ checksum = "942ac266be9249c84ca862f0a164a39533dc2f6f33dc98ec89c8da99b82ea0bd"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -11925,7 +11969,7 @@ checksum = "da33557140a288fae4e1d5f8873aaf9eb6613a9cf82c3e070223ff177f598b60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -12242,7 +12286,7 @@ dependencies = [
"anyhow",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
"wit-bindgen-core",
"wit-bindgen-rust",
]
@ -12567,6 +12611,7 @@ dependencies = [
"anyhow",
"assets",
"assistant",
"assistant2",
"audio",
"auto_update",
"backtrace",
@ -12860,7 +12905,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]
@ -12880,7 +12925,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.59",
]
[[package]]

View file

@ -4,6 +4,8 @@ members = [
"crates/anthropic",
"crates/assets",
"crates/assistant",
"crates/assistant_tooling",
"crates/assistant2",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
@ -137,6 +139,8 @@ ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant2 = { path = "crates/assistant2" }
assistant_tooling = { path = "crates/assistant_tooling" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
base64 = "0.13"
@ -208,6 +212,7 @@ rpc = { path = "crates/rpc" }
task = { path = "crates/task" }
tasks_ui = { path = "crates/tasks_ui" }
search = { path = "crates/search" }
semantic_index = { path = "crates/semantic_index" }
semantic_version = { path = "crates/semantic_version" }
settings = { path = "crates/settings" }
snippet = { path = "crates/snippet" }

View file

@ -209,7 +209,14 @@
}
},
{
"context": "AssistantPanel",
"context": "AssistantChat > Editor", // Used in the assistant2 crate
"bindings": {
"enter": ["assistant2::Submit", "Simple"],
"cmd-enter": ["assistant2::Submit", "Codebase"]
}
},
{
"context": "AssistantPanel", // Used in the assistant crate, which we're replacing
"bindings": {
"cmd-g": "search::SelectNextMatch",
"cmd-shift-g": "search::SelectPrevMatch"

View file

@ -5,6 +5,9 @@ edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lib]
path = "src/assets.rs"
[lints]
workspace = true

View file

@ -1,7 +1,7 @@
// This crate was essentially pulled out verbatim from main `zed` crate to avoid having to run RustEmbed macro whenever zed has to be rebuilt. It saves a second or two on an incremental build.
use anyhow::anyhow;
use gpui::{AssetSource, Result, SharedString};
use gpui::{AppContext, AssetSource, Result, SharedString};
use rust_embed::RustEmbed;
#[derive(RustEmbed)]
@ -34,3 +34,19 @@ impl AssetSource for Assets {
.collect())
}
}
impl Assets {
/// Populate the [`TextSystem`] of the given [`AppContext`] with all `.ttf` fonts in the `fonts` directory.
pub fn load_fonts(&self, cx: &AppContext) -> gpui::Result<()> {
let font_paths = self.list("fonts")?;
let mut embedded_fonts = Vec::new();
for font_path in font_paths {
if font_path.ends_with(".ttf") {
let font_bytes = cx.asset_source().load(&font_path)?;
embedded_fonts.push(font_bytes);
}
}
cx.text_system().add_fonts(embedded_fonts)
}
}

View file

@ -128,6 +128,8 @@ impl LanguageModelRequestMessage {
Role::System => proto::LanguageModelRole::LanguageModelSystem,
} as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
}
@ -147,6 +149,8 @@ impl LanguageModelRequest {
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
}
}
}

View file

@ -140,14 +140,24 @@ impl OpenAiCompletionProvider {
messages: request
.messages
.into_iter()
.map(|msg| RequestMessage {
role: msg.role.into(),
content: msg.content,
.map(|msg| match msg.role {
Role::User => RequestMessage::User {
content: msg.content,
},
Role::Assistant => RequestMessage::Assistant {
content: Some(msg.content),
tool_calls: Vec::new(),
},
Role::System => RequestMessage::System {
content: msg.content,
},
})
.collect(),
stream: true,
stop: request.stop,
temperature: request.temperature,
tools: Vec::new(),
tool_choice: None,
}
}
}

View file

@ -123,6 +123,8 @@ impl ZedDotDevCompletionProvider {
.collect(),
stop: request.stop,
temperature: request.temperature,
tools: Vec::new(),
tool_choice: None,
};
self.client

View file

@ -0,0 +1,56 @@
[package]
name = "assistant2"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lib]
path = "src/assistant2.rs"
[[example]]
name = "assistant_example"
path = "examples/assistant_example.rs"
crate-type = ["bin"]
[dependencies]
anyhow.workspace = true
assistant_tooling.workspace = true
client.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
open_ai.workspace = true
project.workspace = true
rich_text.workspace = true
semantic_index.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
theme.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
nanoid = "0.4"
[dev-dependencies]
assets.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
languages.workspace = true
node_runtime.workspace = true
project = { workspace = true, features = ["test-support"] }
release_channel.workspace = true
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
[lints]
workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,129 @@
use anyhow::Context as _;
use assets::Assets;
use assistant2::{tools::ProjectIndexTool, AssistantPanel};
use assistant_tooling::ToolRegistry;
use client::Client;
use gpui::{actions, App, AppContext, KeyBinding, Task, View, WindowOptions};
use language::LanguageRegistry;
use project::Project;
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
use settings::{KeymapFile, DEFAULT_KEYMAP_PATH};
use std::{
path::{Path, PathBuf},
sync::Arc,
};
use theme::LoadThemes;
use ui::{div, prelude::*, Render};
use util::{http::HttpClientWithUrl, ResultExt as _};
actions!(example, [Quit]);
fn main() {
let args: Vec<String> = std::env::args().collect();
env_logger::init();
App::new().with_assets(Assets).run(|cx| {
cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None)));
cx.on_action(|_: &Quit, cx: &mut AppContext| {
cx.quit();
});
if args.len() < 2 {
eprintln!(
"Usage: cargo run --example assistant_example -p assistant2 -- <project_path>"
);
cx.quit();
return;
}
settings::init(cx);
language::init(cx);
Project::init_settings(cx);
editor::init(cx);
theme::init(LoadThemes::JustBase, cx);
Assets.load_fonts(cx).unwrap();
KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap();
client::init_settings(cx);
release_channel::init("0.130.0", cx);
let client = Client::production(cx);
{
let client = client.clone();
cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await })
.detach_and_log_err(cx);
}
assistant2::init(client.clone(), cx);
let language_registry = Arc::new(LanguageRegistry::new(
Task::ready(()),
cx.background_executor().clone(),
));
let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client());
languages::init(language_registry.clone(), node_runtime, cx);
let http = Arc::new(HttpClientWithUrl::new("http://localhost:11434"));
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let embedding_provider = OpenAiEmbeddingProvider::new(
http.clone(),
OpenAiEmbeddingModel::TextEmbedding3Small,
open_ai::OPEN_AI_API_URL.to_string(),
api_key,
);
cx.spawn(|mut cx| async move {
let mut semantic_index = SemanticIndex::new(
PathBuf::from("/tmp/semantic-index-db.mdb"),
Arc::new(embedding_provider),
&mut cx,
)
.await?;
let project_path = Path::new(&args[1]);
let project = Project::example([project_path], &mut cx).await;
cx.update(|cx| {
let fs = project.read(cx).fs().clone();
let project_index = semantic_index.project_index(project.clone(), cx);
let mut tool_registry = ToolRegistry::new();
tool_registry
.register(ProjectIndexTool::new(project_index.clone(), fs.clone()))
.context("failed to register ProjectIndexTool")
.log_err();
let tool_registry = Arc::new(tool_registry);
cx.open_window(WindowOptions::default(), |cx| {
cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
});
cx.activate(true);
})
})
.detach_and_log_err(cx);
})
}
struct Example {
assistant_panel: View<AssistantPanel>,
}
impl Example {
fn new(
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
cx: &mut ViewContext<Self>,
) -> Self {
Self {
assistant_panel: cx
.new_view(|cx| AssistantPanel::new(language_registry, tool_registry, cx)),
}
}
}
impl Render for Example {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl ui::prelude::IntoElement {
div().size_full().child(self.assistant_panel.clone())
}
}

View file

@ -0,0 +1,952 @@
mod assistant_settings;
mod completion_provider;
pub mod tools;
use anyhow::{Context, Result};
use assistant_tooling::{ToolFunctionCall, ToolRegistry};
use client::{proto, Client};
use completion_provider::*;
use editor::{Editor, EditorEvent};
use feature_flags::FeatureFlagAppExt as _;
use futures::{channel::oneshot, future::join_all, Future, FutureExt, StreamExt};
use gpui::{
list, prelude::*, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
FocusableView, Global, ListAlignment, ListState, Model, Render, Task, View, WeakView,
};
use language::{language_settings::SoftWrap, LanguageRegistry};
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
use project::Fs;
use rich_text::RichText;
use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
use serde::Deserialize;
use settings::Settings;
use std::{cmp, sync::Arc};
use theme::ThemeSettings;
use tools::ProjectIndexTool;
use ui::{popover_menu, prelude::*, ButtonLike, CollapsibleContainer, Color, ContextMenu, Tooltip};
use util::{paths::EMBEDDINGS_DIR, ResultExt};
use workspace::{
dock::{DockPosition, Panel, PanelEvent},
Workspace,
};
pub use assistant_settings::AssistantSettings;
const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
// gpui::actions!(assistant, [Submit]);
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
pub struct Submit(SubmitMode);
/// There are multiple different ways to submit a model request, represented by this enum.
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
pub enum SubmitMode {
/// Only include the conversation.
Simple,
/// Send the current file as context.
CurrentFile,
/// Search the codebase and send relevant excerpts.
Codebase,
}
gpui::actions!(assistant2, [ToggleFocus]);
gpui::impl_actions!(assistant2, [Submit]);
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
AssistantSettings::register(cx);
cx.spawn(|mut cx| {
let client = client.clone();
async move {
let embedding_provider = CloudEmbeddingProvider::new(client.clone());
let semantic_index = SemanticIndex::new(
EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
Arc::new(embedding_provider),
&mut cx,
)
.await?;
cx.update(|cx| cx.set_global(semantic_index))
}
})
.detach();
cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(
client,
)));
cx.observe_new_views(
|workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
workspace.register_action(|workspace, _: &ToggleFocus, cx| {
workspace.toggle_panel_focus::<AssistantPanel>(cx);
});
},
)
.detach();
}
pub fn enabled(cx: &AppContext) -> bool {
cx.is_staff()
}
pub struct AssistantPanel {
chat: View<AssistantChat>,
width: Option<Pixels>,
}
impl AssistantPanel {
pub fn load(
workspace: WeakView<Workspace>,
cx: AsyncWindowContext,
) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move {
let (app_state, project) = workspace.update(&mut cx, |workspace, _| {
(workspace.app_state().clone(), workspace.project().clone())
})?;
cx.new_view(|cx| {
// todo!("this will panic if the semantic index failed to load or has not loaded yet")
let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| {
semantic_index.project_index(project.clone(), cx)
});
let mut tool_registry = ToolRegistry::new();
tool_registry
.register(ProjectIndexTool::new(
project_index.clone(),
app_state.fs.clone(),
))
.context("failed to register ProjectIndexTool")
.log_err();
let tool_registry = Arc::new(tool_registry);
Self::new(app_state.languages.clone(), tool_registry, cx)
})
})
}
pub fn new(
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
cx: &mut ViewContext<Self>,
) -> Self {
let chat = cx.new_view(|cx| {
AssistantChat::new(language_registry.clone(), tool_registry.clone(), cx)
});
Self { width: None, chat }
}
}
impl Render for AssistantPanel {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
div()
.size_full()
.v_flex()
.p_2()
.bg(cx.theme().colors().background)
.child(self.chat.clone())
}
}
impl Panel for AssistantPanel {
fn persistent_name() -> &'static str {
"AssistantPanelv2"
}
fn position(&self, _cx: &WindowContext) -> workspace::dock::DockPosition {
// todo!("Add a setting / use assistant settings")
DockPosition::Right
}
fn position_is_valid(&self, position: workspace::dock::DockPosition) -> bool {
matches!(position, DockPosition::Right)
}
fn set_position(&mut self, _: workspace::dock::DockPosition, _: &mut ViewContext<Self>) {
// Do nothing until we have a setting for this
}
fn size(&self, _cx: &WindowContext) -> Pixels {
self.width.unwrap_or(px(400.))
}
fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
self.width = size;
cx.notify();
}
fn icon(&self, _cx: &WindowContext) -> Option<ui::IconName> {
Some(IconName::Ai)
}
fn icon_tooltip(&self, _: &WindowContext) -> Option<&'static str> {
Some("Assistant Panel ✨")
}
fn toggle_action(&self) -> Box<dyn gpui::Action> {
Box::new(ToggleFocus)
}
}
impl EventEmitter<PanelEvent> for AssistantPanel {}
impl FocusableView for AssistantPanel {
fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
self.chat
.read(cx)
.messages
.iter()
.rev()
.find_map(|msg| msg.focus_handle(cx))
.expect("no user message in chat")
}
}
struct AssistantChat {
model: String,
messages: Vec<ChatMessage>,
list_state: ListState,
language_registry: Arc<LanguageRegistry>,
next_message_id: MessageId,
pending_completion: Option<Task<()>>,
tool_registry: Arc<ToolRegistry>,
}
impl AssistantChat {
fn new(
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
cx: &mut ViewContext<Self>,
) -> Self {
let model = CompletionProvider::get(cx).default_model();
let view = cx.view().downgrade();
let list_state = ListState::new(
0,
ListAlignment::Bottom,
px(1024.),
move |ix, cx: &mut WindowContext| {
view.update(cx, |this, cx| this.render_message(ix, cx))
.unwrap()
},
);
let mut this = Self {
model,
messages: Vec::new(),
list_state,
language_registry,
next_message_id: MessageId(0),
pending_completion: None,
tool_registry,
};
this.push_new_user_message(true, cx);
this
}
fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
self.messages.iter().find_map(|message| match message {
ChatMessage::User(message) => message
.body
.focus_handle(cx)
.contains_focused(cx)
.then_some(message.id),
ChatMessage::Assistant(_) => None,
})
}
fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
let Some(focused_message_id) = self.focused_message_id(cx) else {
log::error!("unexpected state: no user message editor is focused.");
return;
};
self.truncate_messages(focused_message_id, cx);
let mode = *mode;
self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
Self::request_completion(
this.clone(),
mode,
MAX_COMPLETION_CALLS_PER_SUBMISSION,
&mut cx,
)
.await
.log_err();
this.update(&mut cx, |this, cx| {
let focus = this
.user_message(focused_message_id)
.body
.focus_handle(cx)
.contains_focused(cx);
this.push_new_user_message(focus, cx);
})
.context("Failed to push new user message")
.log_err();
}));
}
async fn request_completion(
this: WeakView<Self>,
mode: SubmitMode,
limit: usize,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let mut call_count = 0;
loop {
let complete = async {
let completion = this.update(cx, |this, cx| {
this.push_new_assistant_message(cx);
let definitions = if call_count < limit && matches!(mode, SubmitMode::Codebase)
{
this.tool_registry.definitions()
} else {
&[]
};
call_count += 1;
CompletionProvider::get(cx).complete(
this.model.clone(),
this.completion_messages(cx),
Vec::new(),
1.0,
definitions,
)
});
let mut stream = completion?.await?;
let mut body = String::new();
while let Some(delta) = stream.next().await {
let delta = delta?;
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage {
body: message_body,
tool_calls: message_tool_calls,
..
})) = this.messages.last_mut()
{
if let Some(content) = &delta.content {
body.push_str(content);
}
for tool_call in delta.tool_calls {
let index = tool_call.index as usize;
if index >= message_tool_calls.len() {
message_tool_calls.resize_with(index + 1, Default::default);
}
let call = &mut message_tool_calls[index];
if let Some(id) = &tool_call.id {
call.id.push_str(id);
}
match tool_call.variant {
Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
if let Some(name) = &tool_call.name {
call.name.push_str(name);
}
if let Some(arguments) = &tool_call.arguments {
call.arguments.push_str(arguments);
}
}
None => {}
}
}
*message_body =
RichText::new(body.clone(), &[], &this.language_registry);
cx.notify();
} else {
unreachable!()
}
})?;
}
anyhow::Ok(())
}
.await;
let mut tool_tasks = Vec::new();
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage {
error: message_error,
tool_calls,
..
})) = this.messages.last_mut()
{
if let Err(error) = complete {
message_error.replace(SharedString::from(error.to_string()));
cx.notify();
} else {
for tool_call in tool_calls.iter() {
tool_tasks.push(this.tool_registry.call(tool_call, cx));
}
}
}
})?;
if tool_tasks.is_empty() {
return Ok(());
}
let tools = join_all(tool_tasks.into_iter()).await;
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
this.messages.last_mut()
{
*tool_calls = tools;
cx.notify();
}
})?;
}
}
fn user_message(&mut self, message_id: MessageId) -> &mut UserMessage {
self.messages
.iter_mut()
.find_map(|message| match message {
ChatMessage::User(user_message) if user_message.id == message_id => {
Some(user_message)
}
_ => None,
})
.expect("User message not found")
}
fn push_new_user_message(&mut self, focus: bool, cx: &mut ViewContext<Self>) {
let id = self.next_message_id.post_inc();
let body = cx.new_view(|cx| {
let mut editor = Editor::auto_height(80, cx);
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
if focus {
cx.focus_self();
}
editor
});
let _subscription = cx.subscribe(&body, move |this, editor, event, cx| match event {
EditorEvent::SelectionsChanged { .. } => {
if editor.read(cx).is_focused(cx) {
let (message_ix, _message) = this
.messages
.iter()
.enumerate()
.find_map(|(ix, message)| match message {
ChatMessage::User(user_message) if user_message.id == id => {
Some((ix, user_message))
}
_ => None,
})
.expect("user message not found");
this.list_state.scroll_to_reveal_item(message_ix);
}
}
_ => {}
});
let message = ChatMessage::User(UserMessage {
id,
body,
contexts: Vec::new(),
_subscription,
});
self.push_message(message, cx);
}
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
let message = ChatMessage::Assistant(AssistantMessage {
id: self.next_message_id.post_inc(),
body: RichText::default(),
tool_calls: Vec::new(),
error: None,
});
self.push_message(message, cx);
}
fn push_message(&mut self, message: ChatMessage, cx: &mut ViewContext<Self>) {
let old_len = self.messages.len();
let focus_handle = Some(message.focus_handle(cx));
self.messages.push(message);
self.list_state
.splice_focusable(old_len..old_len, focus_handle);
cx.notify();
}
fn truncate_messages(&mut self, last_message_id: MessageId, cx: &mut ViewContext<Self>) {
if let Some(index) = self.messages.iter().position(|message| match message {
ChatMessage::User(message) => message.id == last_message_id,
ChatMessage::Assistant(message) => message.id == last_message_id,
}) {
self.list_state.splice(index + 1..self.messages.len(), 0);
self.messages.truncate(index + 1);
cx.notify();
}
}
fn render_error(
&self,
error: Option<SharedString>,
_ix: usize,
cx: &mut ViewContext<Self>,
) -> AnyElement {
let theme = cx.theme();
if let Some(error) = error {
div()
.py_1()
.px_2()
.neg_mx_1()
.rounded_md()
.border()
.border_color(theme.status().error_border)
// .bg(theme.status().error_background)
.text_color(theme.status().error)
.child(error.clone())
.into_any_element()
} else {
div().into_any_element()
}
}
fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
let is_last = ix == self.messages.len() - 1;
match &self.messages[ix] {
ChatMessage::User(UserMessage {
body,
contexts: _contexts,
..
}) => div()
.when(!is_last, |element| element.mb_2())
.child(div().p_2().child(Label::new("You").color(Color::Default)))
.child(
div()
.on_action(cx.listener(Self::submit))
.p_2()
.text_color(cx.theme().colors().editor_foreground)
.font(ThemeSettings::get_global(cx).buffer_font.clone())
.bg(cx.theme().colors().editor_background)
.child(body.clone()), // .children(contexts.iter().map(|context| context.render(cx))),
)
.into_any(),
ChatMessage::Assistant(AssistantMessage {
id,
body,
error,
tool_calls,
..
}) => {
let assistant_body = if body.text.is_empty() && !tool_calls.is_empty() {
div()
} else {
div().p_2().child(body.element(ElementId::from(id.0), cx))
};
div()
.when(!is_last, |element| element.mb_2())
.child(
div()
.p_2()
.child(Label::new("Assistant").color(Color::Modified)),
)
.child(assistant_body)
.child(self.render_error(error.clone(), ix, cx))
.children(tool_calls.iter().map(|tool_call| {
let result = &tool_call.result;
let name = tool_call.name.clone();
match result {
Some(result) => div()
.p_2()
.child(result.render(&name, &tool_call.id, cx))
.into_any(),
None => div()
.p_2()
.child(Label::new(name).color(Color::Modified))
.child("Running...")
.into_any(),
}
}))
.into_any()
}
}
}
fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
let mut completion_messages = Vec::new();
for message in &self.messages {
match message {
ChatMessage::User(UserMessage { body, contexts, .. }) => {
// setup context for model
contexts.iter().for_each(|context| {
completion_messages.extend(context.completion_messages(cx))
});
// Show user's message last so that the assistant is grounded in the user's request
completion_messages.push(CompletionMessage::User {
content: body.read(cx).text(cx),
});
}
ChatMessage::Assistant(AssistantMessage {
body, tool_calls, ..
}) => {
// In no case do we want to send an empty message. This shouldn't happen, but we might as well
// not break the Chat API if it does.
if body.text.is_empty() && tool_calls.is_empty() {
continue;
}
let tool_calls_from_assistant = tool_calls
.iter()
.map(|tool_call| ToolCall {
content: ToolCallContent::Function {
function: FunctionContent {
name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
},
},
id: tool_call.id.clone(),
})
.collect();
completion_messages.push(CompletionMessage::Assistant {
content: Some(body.text.to_string()),
tool_calls: tool_calls_from_assistant,
});
for tool_call in tool_calls {
// todo!(): we should not be sending when the tool is still running / has no result
// For now I'm going to have to assume we send an empty string because otherwise
// the Chat API will break -- there is a required message for every tool call by ID
let content = match &tool_call.result {
Some(result) => result.format(&tool_call.name),
None => "".to_string(),
};
completion_messages.push(CompletionMessage::Tool {
content,
tool_call_id: tool_call.id.clone(),
});
}
}
}
}
completion_messages
}
fn render_model_dropdown(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let this = cx.view().downgrade();
div().h_flex().justify_end().child(
div().w_32().child(
popover_menu("user-menu")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::get(cx).available_models() {
menu = menu.custom_entry(
{
let model = model.clone();
move |_| Label::new(model.clone()).into_any_element()
},
{
let this = this.clone();
move |cx| {
_ = this.update(cx, |this, cx| {
this.model = model.clone();
cx.notify();
});
}
},
);
}
menu
})
.into()
})
.trigger(
ButtonLike::new("active-model")
.child(
h_flex()
.w_full()
.gap_0p5()
.child(
div()
.overflow_x_hidden()
.flex_grow()
.whitespace_nowrap()
.child(Label::new(self.model.clone())),
)
.child(div().child(
Icon::new(IconName::ChevronDown).color(Color::Muted),
)),
)
.style(ButtonStyle::Subtle)
.tooltip(move |cx| Tooltip::text("Change Model", cx)),
)
.anchor(gpui::AnchorCorner::TopRight),
),
)
}
}
impl Render for AssistantChat {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
div()
.relative()
.flex_1()
.v_flex()
.key_context("AssistantChat")
.text_color(Color::Default.color(cx))
.child(self.render_model_dropdown(cx))
.child(list(self.list_state.clone()).flex_1())
}
}
#[derive(Copy, Clone, Eq, PartialEq)]
struct MessageId(usize);
impl MessageId {
fn post_inc(&mut self) -> Self {
let id = *self;
self.0 += 1;
id
}
}
enum ChatMessage {
User(UserMessage),
Assistant(AssistantMessage),
}
impl ChatMessage {
fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
match self {
ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
ChatMessage::Assistant(_) => None,
}
}
}
struct UserMessage {
id: MessageId,
body: View<Editor>,
contexts: Vec<AssistantContext>,
_subscription: gpui::Subscription,
}
struct AssistantMessage {
id: MessageId,
body: RichText,
tool_calls: Vec<ToolFunctionCall>,
error: Option<SharedString>,
}
// Since we're swapping out for direct query usage, we might not need to use this injected context
// It will be useful though for when the user _definitely_ wants the model to see a specific file,
// query, error, etc.
#[allow(dead_code)]
enum AssistantContext {
Codebase(View<CodebaseContext>),
}
#[allow(dead_code)]
struct CodebaseExcerpt {
element_id: ElementId,
path: SharedString,
text: SharedString,
score: f32,
expanded: bool,
}
impl AssistantContext {
#[allow(dead_code)]
fn render(&self, _cx: &mut ViewContext<AssistantChat>) -> AnyElement {
match self {
AssistantContext::Codebase(context) => context.clone().into_any_element(),
}
}
fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
match self {
AssistantContext::Codebase(context) => context.read(cx).completion_messages(),
}
}
}
enum CodebaseContext {
Pending { _task: Task<()> },
Done(Result<Vec<CodebaseExcerpt>>),
}
impl CodebaseContext {
fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
if let CodebaseContext::Done(Ok(excerpts)) = self {
if let Some(excerpt) = excerpts
.iter_mut()
.find(|excerpt| excerpt.element_id == element_id)
{
excerpt.expanded = !excerpt.expanded;
cx.notify();
}
}
}
}
impl Render for CodebaseContext {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
match self {
CodebaseContext::Pending { .. } => div()
.h_flex()
.items_center()
.gap_1()
.child(Icon::new(IconName::Ai).color(Color::Muted).into_element())
.child("Searching codebase..."),
CodebaseContext::Done(Ok(excerpts)) => {
div()
.v_flex()
.gap_2()
.children(excerpts.iter().map(|excerpt| {
let expanded = excerpt.expanded;
let element_id = excerpt.element_id.clone();
CollapsibleContainer::new(element_id.clone(), expanded)
.start_slot(
h_flex()
.gap_1()
.child(Icon::new(IconName::File).color(Color::Muted))
.child(Label::new(excerpt.path.clone()).color(Color::Muted)),
)
.on_click(cx.listener(move |this, _, cx| {
this.toggle_expanded(element_id.clone(), cx);
}))
.child(
div()
.p_2()
.rounded_md()
.bg(cx.theme().colors().editor_background)
.child(
excerpt.text.clone(), // todo!(): Show as an editor block
),
)
}))
}
CodebaseContext::Done(Err(error)) => div().child(error.to_string()),
}
}
}
impl CodebaseContext {
#[allow(dead_code)]
fn new(
query: impl 'static + Future<Output = Result<String>>,
populated: oneshot::Sender<bool>,
project_index: Model<ProjectIndex>,
fs: Arc<dyn Fs>,
cx: &mut ViewContext<Self>,
) -> Self {
let query = query.boxed_local();
let _task = cx.spawn(|this, mut cx| async move {
let result = async {
let query = query.await?;
let results = this
.update(&mut cx, |_this, cx| {
project_index.read(cx).search(&query, 16, cx)
})?
.await;
let excerpts = results.into_iter().map(|result| {
let abs_path = result
.worktree
.read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
let fs = fs.clone();
async move {
let path = result.path.clone();
let text = fs.load(&abs_path?).await?;
// todo!("what should we do with stale ranges?");
let range = cmp::min(result.range.start, text.len())
..cmp::min(result.range.end, text.len());
let text = SharedString::from(text[range].to_string());
anyhow::Ok(CodebaseExcerpt {
element_id: ElementId::Name(nanoid::nanoid!().into()),
path: path.to_string_lossy().to_string().into(),
text,
score: result.score,
expanded: false,
})
}
});
anyhow::Ok(
futures::future::join_all(excerpts)
.await
.into_iter()
.filter_map(|result| result.log_err())
.collect(),
)
}
.await;
this.update(&mut cx, |this, cx| {
this.populate(result, populated, cx);
})
.ok();
});
Self::Pending { _task }
}
#[allow(dead_code)]
fn populate(
&mut self,
result: Result<Vec<CodebaseExcerpt>>,
populated: oneshot::Sender<bool>,
cx: &mut ViewContext<Self>,
) {
let success = result.is_ok();
*self = Self::Done(result);
populated.send(success).ok();
cx.notify();
}
fn completion_messages(&self) -> Vec<CompletionMessage> {
// One system message for the whole batch of excerpts:
// Semantic search results for user query:
//
// Excerpt from $path:
// ~~~
// `text`
// ~~~
//
// Excerpt from $path:
match self {
CodebaseContext::Done(Ok(excerpts)) => {
if excerpts.is_empty() {
return Vec::new();
}
let mut body = "Semantic search results for user query:\n".to_string();
for excerpt in excerpts {
body.push_str("Excerpt from ");
body.push_str(excerpt.path.as_ref());
body.push_str(", score ");
body.push_str(&excerpt.score.to_string());
body.push_str(":\n");
body.push_str("~~~\n");
body.push_str(excerpt.text.as_ref());
body.push_str("~~~\n");
}
vec![CompletionMessage::System { content: body }]
}
_ => vec![],
}
}
}

View file

@ -0,0 +1,26 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
#[derive(Default, Debug, Deserialize, Serialize, Clone)]
pub struct AssistantSettings {
pub enabled: bool,
}
#[derive(Default, Debug, Deserialize, Serialize, Clone, JsonSchema)]
pub struct AssistantSettingsContent {
pub enabled: Option<bool>,
}
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant_v2");
type FileContent = AssistantSettingsContent;
fn load(
sources: SettingsSources<Self::FileContent>,
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Ok(sources.json_merge().unwrap_or_else(|_| Default::default()))
}
}

View file

@ -0,0 +1,179 @@
use anyhow::Result;
use assistant_tooling::ToolFunctionDefinition;
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::Global;
use std::sync::Arc;
pub use open_ai::RequestMessage as CompletionMessage;
#[derive(Clone)]
pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);
impl CompletionProvider {
pub fn new(backend: impl CompletionProviderBackend) -> Self {
Self(Arc::new(backend))
}
pub fn default_model(&self) -> String {
self.0.default_model()
}
pub fn available_models(&self) -> Vec<String> {
self.0.available_models()
}
pub fn complete(
&self,
model: String,
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
tools: &[ToolFunctionDefinition],
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
{
self.0.complete(model, messages, stop, temperature, tools)
}
}
impl Global for CompletionProvider {}
pub trait CompletionProviderBackend: 'static {
fn default_model(&self) -> String;
fn available_models(&self) -> Vec<String>;
fn complete(
&self,
model: String,
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
tools: &[ToolFunctionDefinition],
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
}
pub struct CloudCompletionProvider {
client: Arc<Client>,
}
impl CloudCompletionProvider {
pub fn new(client: Arc<Client>) -> Self {
Self { client }
}
}
impl CompletionProviderBackend for CloudCompletionProvider {
fn default_model(&self) -> String {
"gpt-4-turbo".into()
}
fn available_models(&self) -> Vec<String> {
vec!["gpt-4-turbo".into(), "gpt-4".into(), "gpt-3.5-turbo".into()]
}
fn complete(
&self,
model: String,
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
tools: &[ToolFunctionDefinition],
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
{
let client = self.client.clone();
let tools: Vec<proto::ChatCompletionTool> = tools
.iter()
.filter_map(|tool| {
Some(proto::ChatCompletionTool {
variant: Some(proto::chat_completion_tool::Variant::Function(
proto::chat_completion_tool::FunctionObject {
name: tool.name.clone(),
description: Some(tool.description.clone()),
parameters: Some(serde_json::to_string(&tool.parameters).ok()?),
},
)),
})
})
.collect();
let tool_choice = match tools.is_empty() {
true => None,
false => Some("auto".into()),
};
async move {
let stream = client
.request_stream(proto::CompleteWithLanguageModel {
model,
messages: messages
.into_iter()
.map(|message| match message {
CompletionMessage::Assistant {
content,
tool_calls,
} => proto::LanguageModelRequestMessage {
role: proto::LanguageModelRole::LanguageModelAssistant as i32,
content: content.unwrap_or_default(),
tool_call_id: None,
tool_calls: tool_calls
.into_iter()
.map(|tool_call| match tool_call.content {
open_ai::ToolCallContent::Function { function } => {
proto::ToolCall {
id: tool_call.id,
variant: Some(proto::tool_call::Variant::Function(
proto::tool_call::FunctionCall {
name: function.name,
arguments: function.arguments,
},
)),
}
}
})
.collect(),
},
CompletionMessage::User { content } => {
proto::LanguageModelRequestMessage {
role: proto::LanguageModelRole::LanguageModelUser as i32,
content,
tool_call_id: None,
tool_calls: Vec::new(),
}
}
CompletionMessage::System { content } => {
proto::LanguageModelRequestMessage {
role: proto::LanguageModelRole::LanguageModelSystem as i32,
content,
tool_calls: Vec::new(),
tool_call_id: None,
}
}
CompletionMessage::Tool {
content,
tool_call_id,
} => proto::LanguageModelRequestMessage {
role: proto::LanguageModelRole::LanguageModelTool as i32,
content,
tool_call_id: Some(tool_call_id),
tool_calls: Vec::new(),
},
})
.collect(),
stop,
temperature,
tool_choice,
tools,
})
.await?;
Ok(stream
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?)),
Err(error) => Some(Err(error)),
}
})
.boxed())
}
.boxed()
}
}

View file

@ -0,0 +1,176 @@
use anyhow::Result;
use assistant_tooling::LanguageModelTool;
use gpui::{prelude::*, AnyElement, AppContext, Model, Task};
use project::Fs;
use schemars::JsonSchema;
use semantic_index::ProjectIndex;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::{
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
WindowContext,
};
use util::ResultExt as _;
const DEFAULT_SEARCH_LIMIT: usize = 20;
#[derive(Serialize, Clone)]
pub struct CodebaseExcerpt {
path: SharedString,
text: SharedString,
score: f32,
}
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
// Any changes or deletions to the `CodebaseQuery` comments will change model behavior.
#[derive(Deserialize, JsonSchema)]
pub struct CodebaseQuery {
/// Semantic search query
query: String,
/// Maximum number of results to return, defaults to 20
limit: Option<usize>,
}
pub struct ProjectIndexTool {
project_index: Model<ProjectIndex>,
fs: Arc<dyn Fs>,
}
impl ProjectIndexTool {
pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self {
// TODO: setup a better description based on the user's current codebase.
Self { project_index, fs }
}
}
impl LanguageModelTool for ProjectIndexTool {
type Input = CodebaseQuery;
type Output = Vec<CodebaseExcerpt>;
fn name(&self) -> String {
"query_codebase".to_string()
}
fn description(&self) -> String {
"Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of chunks and an embedding of the query".to_string()
}
fn execute(&self, query: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
let project_index = self.project_index.read(cx);
let results = project_index.search(
query.query.as_str(),
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
cx,
);
let fs = self.fs.clone();
cx.spawn(|cx| async move {
let results = results.await;
let excerpts = results.into_iter().map(|result| {
let abs_path = result
.worktree
.read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
let fs = fs.clone();
async move {
let path = result.path.clone();
let text = fs.load(&abs_path?).await?;
let mut start = result.range.start;
let mut end = result.range.end.min(text.len());
while !text.is_char_boundary(start) {
start += 1;
}
while !text.is_char_boundary(end) {
end -= 1;
}
anyhow::Ok(CodebaseExcerpt {
path: path.to_string_lossy().to_string().into(),
text: SharedString::from(text[start..end].to_string()),
score: result.score,
})
}
});
let excerpts = futures::future::join_all(excerpts)
.await
.into_iter()
.filter_map(|result| result.log_err())
.collect();
anyhow::Ok(excerpts)
})
}
fn render(
_tool_call_id: &str,
input: &Self::Input,
excerpts: &Self::Output,
cx: &mut WindowContext,
) -> AnyElement {
let query = input.query.clone();
div()
.v_flex()
.gap_2()
.child(
div()
.p_2()
.rounded_md()
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.child(Label::new("Query: ").color(Color::Modified))
.child(Label::new(query).color(Color::Muted)),
),
)
.children(excerpts.iter().map(|excerpt| {
// This render doesn't have state/model, so we can't use the listener
// let expanded = excerpt.expanded;
// let element_id = excerpt.element_id.clone();
let element_id = ElementId::Name(nanoid::nanoid!().into());
let expanded = false;
CollapsibleContainer::new(element_id.clone(), expanded)
.start_slot(
h_flex()
.gap_1()
.child(Icon::new(IconName::File).color(Color::Muted))
.child(Label::new(excerpt.path.clone()).color(Color::Muted)),
)
// .on_click(cx.listener(move |this, _, cx| {
// this.toggle_expanded(element_id.clone(), cx);
// }))
.child(
div()
.p_2()
.rounded_md()
.bg(cx.theme().colors().editor_background)
.child(
excerpt.text.clone(), // todo!(): Show as an editor block
),
)
}))
.into_any_element()
}
fn format(_input: &Self::Input, excerpts: &Self::Output) -> String {
let mut body = "Semantic search results:\n".to_string();
for excerpt in excerpts {
body.push_str("Excerpt from ");
body.push_str(excerpt.path.as_ref());
body.push_str(", score ");
body.push_str(&excerpt.score.to_string());
body.push_str(":\n");
body.push_str("~~~\n");
body.push_str(excerpt.text.as_ref());
body.push_str("~~~\n");
}
body
}
}

View file

@ -0,0 +1,22 @@
[package]
name = "assistant_tooling"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/assistant_tooling.rs"
[dependencies]
anyhow.workspace = true
gpui.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,208 @@
# Assistant Tooling
Bringing OpenAI compatible tool calling to GPUI.
This unlocks:
- **Structured Extraction** of model responses
- **Validation** of model inputs
- **Execution** of chosen toolsn
## Overview
Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When make a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call.
> **User**: "Hey I need help with implementing a collapsible panel in GPUI"
>
> **Assistant**: "Sure, I can help with that. Let me see what I can find."
>
> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]`
>
> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"`
>
> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you."
This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with a simple trait, `LanguageModelTool`.
## Example
Let's expose querying a semantic index directly by the model. First, we'll set up some _necessary_ imports
```rust
use anyhow::Result;
use assistant_tooling::{LanguageModelTool, ToolRegistry};
use gpui::{App, AppContext, Task};
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::json;
```
Then we'll define the query structure the model must fill in. This _must_ derive `Deserialize` from `serde` and `JsonSchema` from the `schemars` crate.
```rust
#[derive(Deserialize, JsonSchema)]
struct CodebaseQuery {
query: String,
}
```
After that we can define our tool, with the expectation that it will need a `ProjectIndex` to search against. For this example, the index uses the same interface as `semantic_index::ProjectIndex`.
```rust
struct ProjectIndex {}
impl ProjectIndex {
fn new() -> Self {
ProjectIndex {}
}
fn search(&self, _query: &str, _limit: usize, _cx: &AppContext) -> Task<Result<Vec<String>>> {
// Instead of hooking up a real index, we're going to fake it
if _query.contains("gpui") {
return Task::ready(Ok(vec![r#"// crates/gpui/src/gpui.rs
//! # Welcome to GPUI!
//!
//! GPUI is a hybrid immediate and retained mode, GPU accelerated, UI framework
//! for Rust, designed to support a wide variety of applications
"#
.to_string()]));
}
return Task::ready(Ok(vec![]));
}
}
struct ProjectIndexTool {
project_index: ProjectIndex,
}
```
Now we can implement the `LanguageModelTool` trait for our tool by:
- Defining the `Input` from the model, which is `CodebaseQuery`
- Defining the `Output`
- Implementing the `name` and `description` functions to provide the model information when it's choosing a tool
- Implementing the `execute` function to run the tool
```rust
impl LanguageModelTool for ProjectIndexTool {
type Input = CodebaseQuery;
type Output = String;
fn name(&self) -> String {
"query_codebase".to_string()
}
fn description(&self) -> String {
"Executes a query against the codebase, returning excerpts related to the query".to_string()
}
fn execute(&self, query: Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
let results = self.project_index.search(query.query.as_str(), 10, cx);
cx.spawn(|_cx| async move {
let results = results.await?;
if !results.is_empty() {
Ok(results.join("\n"))
} else {
Ok("No results".to_string())
}
})
}
}
```
For the sake of this example, let's look at the types that OpenAI will be passing to us
```rust
// OpenAI definitions, shown here for demonstration
#[derive(Deserialize)]
struct FunctionCall {
name: String,
args: String,
}
#[derive(Deserialize, Eq, PartialEq)]
enum ToolCallType {
#[serde(rename = "function")]
Function,
Other,
}
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
struct ToolCallId(String);
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ToolCall {
Function {
#[allow(dead_code)]
id: ToolCallId,
function: FunctionCall,
},
Other {
#[allow(dead_code)]
id: ToolCallId,
},
}
#[derive(Deserialize)]
struct AssistantMessage {
role: String,
content: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
}
```
When the model wants to call tools, it will pass a list of `ToolCall`s. When those are `function`s that we can handle, we'll pass them to our `ToolRegistry` to get a future that we can await.
```rust
// Inside `fn main()`
App::new().run(|cx: &mut AppContext| {
let tool = ProjectIndexTool {
project_index: ProjectIndex::new(),
};
let mut registry = ToolRegistry::new();
let registered = registry.register(tool);
assert!(registered.is_ok());
```
Let's pretend the model sent us back a message requesting
```rust
let model_response = json!({
"role": "assistant",
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "query_codebase",
"args": r#"{"query":"GPUI Task background_executor"}"#
},
"type": "function"
}
]
});
let message: AssistantMessage = serde_json::from_value(model_response).unwrap();
// We know there's a tool call, so let's skip straight to it for this example
let tool_calls = message.tool_calls.as_ref().unwrap();
let tool_call = tool_calls.get(0).unwrap();
```
We can now use our registry to call the tool.
```rust
let task = registry.call(
tool_call.name,
tool_call.args,
);
cx.spawn(|_cx| async move {
let result = task.await?;
println!("{}", result.unwrap());
Ok(())
})
```

View file

@ -0,0 +1,5 @@
pub mod registry;
pub mod tool;
pub use crate::registry::ToolRegistry;
pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition};

View file

@ -0,0 +1,298 @@
use anyhow::{anyhow, Result};
use gpui::{AnyElement, AppContext, Task, WindowContext};
use std::{any::Any, collections::HashMap};
use crate::tool::{
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
};
pub struct ToolRegistry {
tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
definitions: Vec<ToolFunctionDefinition>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
definitions: Vec::new(),
}
}
pub fn definitions(&self) -> &[ToolFunctionDefinition] {
&self.definitions
}
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
fn render<T: 'static + LanguageModelTool>(
tool_call_id: &str,
input: &Box<dyn Any>,
output: &Box<dyn Any>,
cx: &mut WindowContext,
) -> AnyElement {
T::render(
tool_call_id,
input.as_ref().downcast_ref::<T::Input>().unwrap(),
output.as_ref().downcast_ref::<T::Output>().unwrap(),
cx,
)
}
fn format<T: 'static + LanguageModelTool>(
input: &Box<dyn Any>,
output: &Box<dyn Any>,
) -> String {
T::format(
input.as_ref().downcast_ref::<T::Input>().unwrap(),
output.as_ref().downcast_ref::<T::Output>().unwrap(),
)
}
self.definitions.push(tool.definition());
let name = tool.name();
let previous = self.tools.insert(
name.clone(),
Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
return Task::ready(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ParsingFailed),
});
};
let result = tool.execute(&input, cx);
cx.spawn(move |_cx| async move {
match result.await {
Ok(result) => {
let result: T::Output = result;
ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::Finished {
input: Box::new(input),
output: Box::new(result),
render_fn: render::<T>,
format_fn: format::<T>,
}),
}
}
Err(_error) => ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ExecutionFailed {
input: Box::new(input),
}),
},
}
})
}),
);
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
Ok(())
}
pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
let tool = match self.tools.get(&name) {
Some(tool) => tool,
None => {
let name = name.clone();
return Task::ready(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::NoSuchTool),
});
}
};
tool(tool_call, cx)
}
}
#[cfg(test)]
mod test {
use super::*;
use schemars::schema_for;
use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize, Serialize, JsonSchema)]
struct WeatherQuery {
location: String,
unit: String,
}
struct WeatherTool {
current_weather: WeatherResult,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct WeatherResult {
location: String,
temperature: f64,
unit: String,
}
impl LanguageModelTool for WeatherTool {
type Input = WeatherQuery;
type Output = WeatherResult;
fn name(&self) -> String {
"get_current_weather".to_string()
}
fn description(&self) -> String {
"Fetches the current weather for a given location.".to_string()
}
fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
let _location = input.location.clone();
let _unit = input.unit.clone();
let weather = self.current_weather.clone();
Task::ready(Ok(weather))
}
fn render(
_tool_call_id: &str,
_input: &Self::Input,
output: &Self::Output,
_cx: &mut WindowContext,
) -> AnyElement {
div()
.child(format!(
"The current temperature in {} is {} {}",
output.location, output.temperature, output.unit
))
.into_any()
}
fn format(_input: &Self::Input, output: &Self::Output) -> String {
format!(
"The current temperature in {} is {} {}",
output.location, output.temperature, output.unit
)
}
}
#[gpui::test]
async fn test_function_registry(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
let mut registry = ToolRegistry::new();
let tool = WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
};
registry.register(tool).unwrap();
let _result = cx
.update(|cx| {
registry.call(
&ToolFunctionCall {
name: "get_current_weather".to_string(),
arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
.to_string(),
id: "test-123".to_string(),
result: None,
},
cx,
)
})
.await;
// assert!(result.is_ok());
// let result = result.unwrap();
// let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;
// todo!(): Put this back in after the interface is stabilized
// assert_eq!(result, expected);
}
#[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
let tool = WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
};
let tools = vec![tool.definition()];
assert_eq!(tools.len(), 1);
let expected = ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: schema_for!(WeatherQuery).schema,
};
assert_eq!(tools[0].name, expected.name);
assert_eq!(tools[0].description, expected.description);
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
assert_eq!(
expected_schema,
json!({
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
},
"required": ["location", "unit"]
})
);
let args = json!({
"location": "San Francisco",
"unit": "Celsius"
});
let query: WeatherQuery = serde_json::from_value(args).unwrap();
let result = cx.update(|cx| tool.execute(&query, cx)).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result, tool.current_weather);
}
}

View file

@ -0,0 +1,145 @@
use anyhow::Result;
use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
use schemars::{schema::SchemaObject, schema_for, JsonSchema};
use serde::Deserialize;
use std::{any::Any, fmt::Debug};
#[derive(Default, Deserialize)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
#[serde(skip)]
pub result: Option<ToolFunctionCallResult>,
}
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
ExecutionFailed {
input: Box<dyn Any>,
},
Finished {
input: Box<dyn Any>,
output: Box<dyn Any>,
render_fn: fn(
// tool_call_id
&str,
// LanguageModelTool::Input
&Box<dyn Any>,
// LanguageModelTool::Output
&Box<dyn Any>,
&mut WindowContext,
) -> AnyElement,
format_fn: fn(
// LanguageModelTool::Input
&Box<dyn Any>,
// LanguageModelTool::Output
&Box<dyn Any>,
) -> String,
},
}
impl ToolFunctionCallResult {
pub fn render(
&self,
tool_name: &str,
tool_call_id: &str,
cx: &mut WindowContext,
) -> AnyElement {
match self {
ToolFunctionCallResult::NoSuchTool => {
div().child(format!("no such tool {tool_name}")).into_any()
}
ToolFunctionCallResult::ParsingFailed => div()
.child(format!("failed to parse input for tool {tool_name}"))
.into_any(),
ToolFunctionCallResult::ExecutionFailed { .. } => div()
.child(format!("failed to execute tool {tool_name}"))
.into_any(),
ToolFunctionCallResult::Finished {
input,
output,
render_fn,
..
} => render_fn(tool_call_id, input, output, cx),
}
}
pub fn format(&self, tool: &str) -> String {
match self {
ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
ToolFunctionCallResult::ParsingFailed => {
format!("failed to parse input for tool {tool}")
}
ToolFunctionCallResult::ExecutionFailed { input: _input } => {
format!("failed to execute tool {tool}")
}
ToolFunctionCallResult::Finished {
input,
output,
format_fn,
..
} => format_fn(input, output),
}
}
}
#[derive(Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
pub parameters: SchemaObject,
}
impl Debug for ToolFunctionDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let schema = serde_json::to_string(&self.parameters).ok();
let schema = schema.unwrap_or("None".to_string());
f.debug_struct("ToolFunctionDefinition")
.field("name", &self.name)
.field("description", &self.description)
.field("parameters", &schema)
.finish()
}
}
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: for<'de> Deserialize<'de> + JsonSchema;
/// The output returned by executing the tool.
type Output: 'static;
/// The name of the tool is exposed to the language model to allow
/// the model to pick which tools to use. As this name is used to
/// identify the tool within a tool registry, it should be unique.
fn name(&self) -> String;
/// A description of the tool that can be used to _prompt_ the model
/// as to what the tool does.
fn description(&self) -> String;
/// The OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
ToolFunctionDefinition {
name: self.name(),
description: self.description(),
parameters: schema_for!(Self::Input).schema,
}
}
/// Execute the tool
fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
fn render(
tool_call_id: &str,
input: &Self::Input,
output: &Self::Output,
cx: &mut WindowContext,
) -> AnyElement;
fn format(input: &Self::Input, output: &Self::Output) -> String;
}

View file

@ -457,6 +457,14 @@ impl Client {
})
}
pub fn production(cx: &mut AppContext) -> Arc<Self> {
let clock = Arc::new(clock::RealSystemClock);
let http = Arc::new(HttpClientWithUrl::new(
&ClientSettings::get_global(cx).server_url,
));
Self::new(clock, http.clone(), cx)
}
pub fn id(&self) -> u64 {
self.id.load(Ordering::SeqCst)
}
@ -1119,6 +1127,8 @@ impl Client {
if let Some((login, token)) =
IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
{
eprintln!("authenticate as admin {login}, {token}");
return Self::authenticate_as_admin(http, login.clone(), token.clone())
.await;
}

View file

@ -5,7 +5,8 @@
"maxbrunsfeld",
"iamnbutler",
"mikayla-maki",
"JosephTLyons"
"JosephTLyons",
"rgbkrk"
],
"channels": ["zed"],
"number_of_users": 100

View file

@ -1,5 +1,6 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use rpc::proto;
use util::ResultExt as _;
pub fn language_model_request_to_open_ai(
request: proto::CompleteWithLanguageModel,
@ -9,24 +10,83 @@ pub fn language_model_request_to_open_ai(
messages: request
.messages
.into_iter()
.map(|message| {
.map(|message: proto::LanguageModelRequestMessage| {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
Ok(open_ai::RequestMessage {
role: match role {
proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => {
open_ai::Role::Assistant
}
proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
let openai_message = match role {
proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
content: message.content,
},
content: message.content,
})
proto::LanguageModelRole::LanguageModelAssistant => {
open_ai::RequestMessage::Assistant {
content: Some(message.content),
tool_calls: message
.tool_calls
.into_iter()
.filter_map(|call| {
Some(open_ai::ToolCall {
id: call.id,
content: match call.variant? {
proto::tool_call::Variant::Function(f) => {
open_ai::ToolCallContent::Function {
function: open_ai::FunctionContent {
name: f.name,
arguments: f.arguments,
},
}
}
},
})
})
.collect(),
}
}
proto::LanguageModelRole::LanguageModelSystem => {
open_ai::RequestMessage::System {
content: message.content,
}
}
proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
tool_call_id: message
.tool_call_id
.ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
content: message.content,
},
};
Ok(openai_message)
})
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
stream: true,
stop: request.stop,
temperature: request.temperature,
tools: request
.tools
.into_iter()
.filter_map(|tool| {
Some(match tool.variant? {
proto::chat_completion_tool::Variant::Function(f) => {
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
name: f.name,
description: f.description,
parameters: if let Some(params) = &f.parameters {
Some(
serde_json::from_str(params)
.context("failed to deserialize tool parameters")
.log_err()?,
)
} else {
None
},
},
}
}
})
})
.collect(),
tool_choice: request.tool_choice,
})
}
@ -58,6 +118,9 @@ pub fn language_model_request_message_to_google_ai(
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelTool => {
Err(anyhow!("we don't handle tool calls with google ai yet"))?
}
},
})
}

View file

@ -775,9 +775,7 @@ impl Server {
Box::new(move |envelope, session| {
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
let received_at = envelope.received_at;
tracing::info!(
"message received"
);
tracing::info!("message received");
let start_time = Instant::now();
let future = (handler)(*envelope, session);
async move {
@ -786,12 +784,24 @@ impl Server {
let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
let queue_duration_ms = total_duration_ms - processing_duration_ms;
let payload_type = M::NAME;
match result {
Err(error) => {
// todo!(), why isn't this logged inside the span?
tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, payload_type, "error handling message")
tracing::error!(
?error,
total_duration_ms,
processing_duration_ms,
queue_duration_ms,
payload_type,
"error handling message"
)
}
Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
Ok(()) => tracing::info!(
total_duration_ms,
processing_duration_ms,
queue_duration_ms,
"finished handling message"
),
}
}
.boxed()
@ -4098,7 +4108,7 @@ async fn complete_with_open_ai(
crate::ai::language_model_request_to_open_ai(request)?,
)
.await
.context("open_ai::stream_completion request failed")?;
.context("open_ai::stream_completion request failed within collab")?;
while let Some(event) = completion_stream.next().await {
let event = event?;
@ -4113,8 +4123,32 @@ async fn complete_with_open_ai(
open_ai::Role::User => LanguageModelRole::LanguageModelUser,
open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
} as i32),
content: choice.delta.content,
tool_calls: choice
.delta
.tool_calls
.into_iter()
.map(|delta| proto::ToolCallDelta {
index: delta.index as u32,
id: delta.id,
variant: match delta.function {
Some(function) => {
let name = function.name;
let arguments = function.arguments;
Some(proto::tool_call_delta::Variant::Function(
proto::tool_call_delta::FunctionCallDelta {
name,
arguments,
},
))
}
None => None,
},
})
.collect(),
}),
finish_reason: choice.finish_reason,
})
@ -4165,6 +4199,8 @@ async fn complete_with_google_ai(
})
.collect(),
),
// Tool calls are not supported for Google
tool_calls: Vec::new(),
}),
finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
})
@ -4187,24 +4223,28 @@ async fn complete_with_anthropic(
let messages = request
.messages
.into_iter()
.filter_map(|message| match message.role() {
LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
role: anthropic::Role::User,
content: message.content,
}),
LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
role: anthropic::Role::Assistant,
content: message.content,
}),
// Anthropic's API breaks system instructions out as a separate field rather
// than having a system message role.
LanguageModelRole::LanguageModelSystem => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
.filter_map(|message| {
match message.role() {
LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
role: anthropic::Role::User,
content: message.content,
}),
LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
role: anthropic::Role::Assistant,
content: message.content,
}),
// Anthropic's API breaks system instructions out as a separate field rather
// than having a system message role.
LanguageModelRole::LanguageModelSystem => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
None
None
}
// We don't yet support tool calls for Anthropic
LanguageModelRole::LanguageModelTool => None,
}
})
.collect();
@ -4248,6 +4288,7 @@ async fn complete_with_anthropic(
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
tool_calls: Vec::new(),
}),
finish_reason: None,
}],
@ -4264,6 +4305,7 @@ async fn complete_with_anthropic(
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
tool_calls: Vec::new(),
}),
finish_reason: None,
}],

View file

@ -234,10 +234,11 @@ impl ChatPanel {
let channel_id = chat.read(cx).channel_id;
{
self.markdown_data.clear();
let chat = chat.read(cx);
self.message_list.reset(chat.message_count());
let chat = chat.read(cx);
let channel_name = chat.channel(cx).map(|channel| channel.name.clone());
let message_count = chat.message_count();
self.message_list.reset(message_count);
self.message_editor.update(cx, |editor, cx| {
editor.set_channel(channel_id, channel_name, cx);
editor.clear_reply_to_message_id();
@ -766,7 +767,7 @@ impl ChatPanel {
body.push_str(MESSAGE_EDITED);
}
let mut rich_text = rich_text::render_rich_text(body, &mentions, language_registry, None);
let mut rich_text = RichText::new(body, &mentions, language_registry);
if message.edited_at.is_some() {
let range = (rich_text.text.len() - MESSAGE_EDITED.len())..rich_text.text.len();

View file

@ -2947,7 +2947,7 @@ impl Render for DraggedChannelView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Element {
let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone();
h_flex()
.font(ui_font)
.font_family(ui_font)
.bg(cx.theme().colors().background)
.w(self.width)
.p_1()

View file

@ -125,7 +125,7 @@ impl Render for IncomingCallNotification {
cx.set_rem_size(ui_font_size);
div().size_full().font(ui_font).child(
div().size_full().font_family(ui_font).child(
CollabNotification::new(
self.state.call.calling_user.avatar_uri.clone(),
Button::new("accept", "Accept").on_click({

View file

@ -129,7 +129,7 @@ impl Render for ProjectSharedNotification {
cx.set_rem_size(ui_font_size);
div().size_full().font(ui_font).child(
div().size_full().font_family(ui_font).child(
CollabNotification::new(
self.owner.avatar_uri.clone(),
Button::new("open", "Open").on_click(cx.listener(move |this, _event, cx| {

View file

@ -61,13 +61,13 @@ use fuzzy::{StringMatch, StringMatchCandidate};
use git::blame::GitBlame;
use git::diff_hunk_to_display;
use gpui::{
div, impl_actions, point, prelude::*, px, relative, rems, size, uniform_list, Action,
AnyElement, AppContext, AsyncWindowContext, AvailableSpace, BackgroundExecutor, Bounds,
ClipboardItem, Context, DispatchPhase, ElementId, EventEmitter, FocusHandle, FocusableView,
FontId, FontStyle, FontWeight, HighlightStyle, Hsla, InteractiveText, KeyContext, Model,
MouseButton, PaintQuad, ParentElement, Pixels, Render, SharedString, Size, StrikethroughStyle,
Styled, StyledText, Subscription, Task, TextStyle, UnderlineStyle, UniformListScrollHandle,
View, ViewContext, ViewInputHandler, VisualContext, WeakView, WhiteSpace, WindowContext,
div, impl_actions, point, prelude::*, px, relative, size, uniform_list, Action, AnyElement,
AppContext, AsyncWindowContext, AvailableSpace, BackgroundExecutor, Bounds, ClipboardItem,
Context, DispatchPhase, ElementId, EventEmitter, FocusHandle, FocusableView, FontId, FontStyle,
FontWeight, HighlightStyle, Hsla, InteractiveText, KeyContext, Model, MouseButton, PaintQuad,
ParentElement, Pixels, Render, SharedString, Size, StrikethroughStyle, Styled, StyledText,
Subscription, Task, TextStyle, UnderlineStyle, UniformListScrollHandle, View, ViewContext,
ViewInputHandler, VisualContext, WeakView, WhiteSpace, WindowContext,
};
use highlight_matching_bracket::refresh_matching_bracket_highlights;
use hover_popover::{hide_hover, HoverState};
@ -8885,7 +8885,6 @@ impl Editor {
self.style = Some(style);
}
#[cfg(any(test, feature = "test-support"))]
pub fn style(&self) -> Option<&EditorStyle> {
self.style.as_ref()
}
@ -10322,21 +10321,9 @@ impl FocusableView for Editor {
impl Render for Editor {
fn render<'a>(&mut self, cx: &mut ViewContext<'a, Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = match self.mode {
EditorMode::SingleLine | EditorMode::AutoHeight { .. } => TextStyle {
color: cx.theme().colors().editor_foreground,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features,
font_size: rems(0.875).into(),
font_weight: FontWeight::NORMAL,
font_style: FontStyle::Normal,
line_height: relative(settings.buffer_line_height.value()),
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
},
let text_style = match self.mode {
EditorMode::SingleLine | EditorMode::AutoHeight { .. } => cx.text_style(),
EditorMode::Full => TextStyle {
color: cx.theme().colors().editor_foreground,
font_family: settings.buffer_font.family.clone(),

View file

@ -3056,7 +3056,7 @@ fn render_inline_blame_entry(
h_flex()
.id("inline-blame")
.w_full()
.font(style.text.font().family)
.font_family(style.text.font().family)
.text_color(cx.theme().status().hint)
.line_height(style.text.line_height)
.child(Icon::new(IconName::FileGit).color(Color::Hint))
@ -3108,7 +3108,7 @@ fn render_blame_entry(
h_flex()
.w_full()
.font(style.text.font().family)
.font_family(style.text.font().family)
.line_height(style.text.line_height)
.id(("blame", ix))
.children([

View file

@ -1,7 +1,7 @@
use crate::{
self as gpui, hsla, point, px, relative, rems, AbsoluteLength, AlignItems, CursorStyle,
DefiniteLength, Fill, FlexDirection, FlexWrap, FontStyle, FontWeight, Hsla, JustifyContent,
Length, Position, SharedString, StyleRefinement, Visibility, WhiteSpace,
DefiniteLength, Fill, FlexDirection, FlexWrap, Font, FontStyle, FontWeight, Hsla,
JustifyContent, Length, Position, SharedString, StyleRefinement, Visibility, WhiteSpace,
};
use crate::{BoxShadow, TextStyleRefinement};
use smallvec::{smallvec, SmallVec};
@ -771,14 +771,32 @@ pub trait Styled: Sized {
self
}
/// Change the font on this element and its children.
fn font(mut self, family_name: impl Into<SharedString>) -> Self {
/// Change the font family on this element and its children.
fn font_family(mut self, family_name: impl Into<SharedString>) -> Self {
self.text_style()
.get_or_insert_with(Default::default)
.font_family = Some(family_name.into());
self
}
/// Change the font of this element and its children.
fn font(mut self, font: Font) -> Self {
let Font {
family,
features,
weight,
style,
} = font;
let text_style = self.text_style().get_or_insert_with(Default::default);
text_style.font_family = Some(family);
text_style.font_features = Some(features);
text_style.font_weight = Some(weight);
text_style.font_style = Some(style);
self
}
/// Set the line height on this element and its children.
fn line_height(mut self, line_height: impl Into<DefiniteLength>) -> Self {
self.text_style()

View file

@ -1,6 +1,7 @@
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::{convert::TryFrom, future::Future};
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@ -12,6 +13,7 @@ pub enum Role {
User,
Assistant,
System,
Tool,
}
impl TryFrom<String> for Role {
@ -22,6 +24,7 @@ impl TryFrom<String> for Role {
"user" => Ok(Self::User),
"assistant" => Ok(Self::Assistant),
"system" => Ok(Self::System),
"tool" => Ok(Self::Tool),
_ => Err(anyhow!("invalid role '{value}'")),
}
}
@ -33,6 +36,7 @@ impl From<Role> for String {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
Role::System => "system".to_owned(),
Role::Tool => "tool".to_owned(),
}
}
}
@ -91,18 +95,88 @@ pub struct Request {
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
}
#[derive(Debug, Serialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Map<String, Value>>,
}
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
#[allow(dead_code)]
Function { function: FunctionDefinition },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
Assistant {
content: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
User {
content: String,
},
System {
content: String,
},
Tool {
content: String,
tool_call_id: String,
},
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub struct ToolCall {
pub id: String,
#[serde(flatten)]
pub content: ToolCallContent,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
Function { function: FunctionContent },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
pub name: String,
pub arguments: String,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
pub role: Option<Role>,
pub content: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tool_calls: Vec<ToolCallChunk>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
pub index: usize,
pub id: Option<String>,
// There is also an optional `type` field that would determine if a
// function is there. Sometimes this streams in with the `function` before
// it streams in the `type`
pub function: Option<FunctionChunk>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
pub name: Option<String>,
pub arguments: Option<String>,
}
#[derive(Deserialize, Debug)]
@ -115,7 +189,7 @@ pub struct Usage {
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub delta: ResponseMessageDelta,
pub finish_reason: Option<String>,
}

View file

@ -1843,7 +1843,7 @@ impl Render for DraggedProjectEntryView {
let settings = ProjectPanelSettings::get_global(cx);
let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone();
h_flex()
.font(ui_font)
.font_family(ui_font)
.bg(cx.theme().colors().background)
.w(self.width)
.child(

View file

@ -507,7 +507,7 @@ impl RemoteProjects {
.my_1()
.py_0p5()
.px_3()
.font(ThemeSettings::get_global(cx).buffer_font.family.clone())
.font_family(ThemeSettings::get_global(cx).buffer_font.family.clone())
.child(Label::new(instructions))
)
.when(status == DevServerStatus::Offline, |this| {

View file

@ -43,6 +43,19 @@ pub struct RichText {
Option<Arc<dyn Fn(usize, Range<usize>, &mut WindowContext) -> Option<AnyView>>>,
}
impl Default for RichText {
fn default() -> Self {
Self {
text: SharedString::default(),
highlights: Vec::new(),
link_ranges: Vec::new(),
link_urls: Arc::from([]),
custom_ranges: Vec::new(),
custom_ranges_tooltip_fn: None,
}
}
}
/// Allows one to specify extra links to the rendered markdown, which can be used
/// for e.g. mentions.
#[derive(Debug)]
@ -52,6 +65,37 @@ pub struct Mention {
}
impl RichText {
pub fn new(
block: String,
mentions: &[Mention],
language_registry: &Arc<LanguageRegistry>,
) -> Self {
let mut text = String::new();
let mut highlights = Vec::new();
let mut link_ranges = Vec::new();
let mut link_urls = Vec::new();
render_markdown_mut(
&block,
mentions,
language_registry,
None,
&mut text,
&mut highlights,
&mut link_ranges,
&mut link_urls,
);
text.truncate(text.trim_end().len());
RichText {
text: SharedString::from(text),
link_urls: link_urls.into(),
link_ranges,
highlights,
custom_ranges: Vec::new(),
custom_ranges_tooltip_fn: None,
}
}
pub fn set_tooltip_builder_for_custom_ranges(
&mut self,
f: impl Fn(usize, Range<usize>, &mut WindowContext) -> Option<AnyView> + 'static,
@ -347,38 +391,6 @@ pub fn render_markdown_mut(
}
}
pub fn render_rich_text(
block: String,
mentions: &[Mention],
language_registry: &Arc<LanguageRegistry>,
language: Option<&Arc<Language>>,
) -> RichText {
let mut text = String::new();
let mut highlights = Vec::new();
let mut link_ranges = Vec::new();
let mut link_urls = Vec::new();
render_markdown_mut(
&block,
mentions,
language_registry,
language,
&mut text,
&mut highlights,
&mut link_ranges,
&mut link_urls,
);
text.truncate(text.trim_end().len());
RichText {
text: SharedString::from(text),
link_urls: link_urls.into(),
link_ranges,
highlights,
custom_ranges: Vec::new(),
custom_ranges_tooltip_fn: None,
}
}
pub fn render_code(
text: &mut String,
highlights: &mut Vec<(Range<usize>, Highlight)>,

View file

@ -1880,22 +1880,70 @@ message CompleteWithLanguageModel {
repeated LanguageModelRequestMessage messages = 2;
repeated string stop = 3;
float temperature = 4;
repeated ChatCompletionTool tools = 5;
optional string tool_choice = 6;
}
// A tool presented to the language model for its use
message ChatCompletionTool {
oneof variant {
FunctionObject function = 1;
}
message FunctionObject {
string name = 1;
optional string description = 2;
optional string parameters = 3;
}
}
// A message to the language model
message LanguageModelRequestMessage {
LanguageModelRole role = 1;
string content = 2;
optional string tool_call_id = 3;
repeated ToolCall tool_calls = 4;
}
enum LanguageModelRole {
LanguageModelUser = 0;
LanguageModelAssistant = 1;
LanguageModelSystem = 2;
LanguageModelTool = 3;
}
message LanguageModelResponseMessage {
optional LanguageModelRole role = 1;
optional string content = 2;
repeated ToolCallDelta tool_calls = 3;
}
// A request to call a tool, by the language model
message ToolCall {
string id = 1;
oneof variant {
FunctionCall function = 2;
}
message FunctionCall {
string name = 1;
string arguments = 2;
}
}
message ToolCallDelta {
uint32 index = 1;
optional string id = 2;
oneof variant {
FunctionCallDelta function = 3;
}
message FunctionCallDelta {
optional string name = 1;
optional string arguments = 2;
}
}
message LanguageModelResponse {

View file

@ -12,6 +12,11 @@ workspace = true
[lib]
path = "src/semantic_index.rs"
[[example]]
name = "index"
path = "examples/index.rs"
crate-type = ["bin"]
[dependencies]
anyhow.workspace = true
client.workspace = true

View file

@ -1,25 +1,16 @@
use client::Client;
use futures::channel::oneshot;
use gpui::{App, Global, TestAppContext};
use gpui::{App, Global};
use language::language_settings::AllLanguageSettings;
use project::Project;
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
use settings::SettingsStore;
use std::{path::Path, sync::Arc};
use std::{
path::{Path, PathBuf},
sync::Arc,
};
use util::http::HttpClientWithUrl;
pub fn init_test(cx: &mut TestAppContext) {
_ = cx.update(|cx| {
let store = SettingsStore::test(cx);
cx.set_global(store);
language::init(cx);
Project::init_settings(cx);
SettingsStore::update(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
});
}
fn main() {
env_logger::init();
@ -50,20 +41,21 @@ fn main() {
// let embedding_provider = semantic_index::FakeEmbeddingProvider;
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let embedding_provider = OpenAiEmbeddingProvider::new(
let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
http.clone(),
OpenAiEmbeddingModel::TextEmbedding3Small,
open_ai::OPEN_AI_API_URL.to_string(),
api_key,
);
let semantic_index = SemanticIndex::new(
Path::new("/tmp/semantic-index-db.mdb"),
Arc::new(embedding_provider),
cx,
);
));
cx.spawn(|mut cx| async move {
let semantic_index = SemanticIndex::new(
PathBuf::from("/tmp/semantic-index-db.mdb"),
embedding_provider,
&mut cx,
);
let mut semantic_index = semantic_index.await.unwrap();
let project_path = Path::new(&args[1]);

View file

@ -21,7 +21,7 @@ use std::{
cmp::Ordering,
future::Future,
ops::Range,
path::Path,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, SystemTime},
};
@ -37,30 +37,29 @@ pub struct SemanticIndex {
impl Global for SemanticIndex {}
impl SemanticIndex {
pub fn new(
db_path: &Path,
pub async fn new(
db_path: PathBuf,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Self>> {
let db_path = db_path.to_path_buf();
cx.spawn(|cx| async move {
let db_connection = cx
.background_executor()
.spawn(async move {
unsafe {
heed::EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024)
.max_dbs(3000)
.open(db_path)
}
})
.await?;
Ok(SemanticIndex {
db_connection,
embedding_provider,
project_indices: HashMap::default(),
cx: &mut AsyncAppContext,
) -> Result<Self> {
let db_connection = cx
.background_executor()
.spawn(async move {
std::fs::create_dir_all(&db_path)?;
unsafe {
heed::EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024)
.max_dbs(3000)
.open(db_path)
}
})
.await
.context("opening database connection")?;
Ok(SemanticIndex {
db_connection,
embedding_provider,
project_indices: HashMap::default(),
})
}
@ -91,7 +90,7 @@ pub struct ProjectIndex {
worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
last_status: Status,
pub last_status: Status,
embedding_provider: Arc<dyn EmbeddingProvider>,
_subscription: Subscription,
}
@ -397,7 +396,7 @@ impl WorktreeIndex {
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_updated_entries(worktree, updated_entries, cx);
let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = self.embed_files(chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
@ -498,7 +497,9 @@ impl WorktreeIndex {
| project::PathChange::Updated
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
updated_entries_tx.send(entry.clone()).await?;
if entry.is_file() {
updated_entries_tx.send(entry.clone()).await?;
}
}
}
project::PathChange::Removed => {
@ -539,7 +540,14 @@ impl WorktreeIndex {
cx.spawn(async {
while let Ok(entry) = entries.recv().await {
let entry_abs_path = worktree_abs_path.join(&entry.path);
let Some(text) = fs.load(&entry_abs_path).await.log_err() else {
let Some(text) = fs
.load(&entry_abs_path)
.await
.with_context(|| {
format!("failed to read path {entry_abs_path:?}")
})
.log_err()
else {
continue;
};
let language = language_registry
@ -683,7 +691,7 @@ impl WorktreeIndex {
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
let (_, db_embedded_file) = db_entry?;
let (_key, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
chunks_tx
.send((db_embedded_file.path.clone(), chunk))
@ -700,6 +708,7 @@ impl WorktreeIndex {
cx.spawn(|cx| async move {
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
log::info!("Searching for {query}");
let mut query_embeddings = embedding_provider
.embed(&[TextToEmbed::new(&query)])
@ -876,17 +885,13 @@ mod tests {
let temp_dir = tempfile::tempdir().unwrap();
let mut semantic_index = cx
.update(|cx| {
let semantic_index = SemanticIndex::new(
Path::new(temp_dir.path()),
Arc::new(TestEmbeddingProvider),
cx,
);
semantic_index
})
.await
.unwrap();
let mut semantic_index = SemanticIndex::new(
temp_dir.path().into(),
Arc::new(TestEmbeddingProvider),
&mut cx.to_async(),
)
.await
.unwrap();
let project_path = Path::new("./fixture");

View file

@ -2,6 +2,7 @@ mod keymap_file;
mod settings_file;
mod settings_store;
use gpui::AppContext;
use rust_embed::RustEmbed;
use std::{borrow::Cow, str};
use util::asset_str;
@ -19,6 +20,14 @@ pub use settings_store::{
#[exclude = "*.DS_Store"]
pub struct SettingsAssets;
pub fn init(cx: &mut AppContext) {
let mut settings = SettingsStore::default();
settings
.set_default_settings(&default_settings(), cx)
.unwrap();
cx.set_global(settings);
}
pub fn default_settings() -> Cow<'static, str> {
asset_str::<SettingsAssets>("settings/default.json")
}

View file

@ -29,14 +29,14 @@ pub enum ComponentStory {
ListHeader,
ListItem,
OverflowScroll,
Picker,
Scroll,
Tab,
TabBar,
Text,
TitleBar,
ToggleButton,
Text,
ViewportUnits,
Picker,
}
impl ComponentStory {

View file

@ -11,7 +11,7 @@ use gpui::{
};
use log::LevelFilter;
use project::Project;
use settings::{default_settings, KeymapFile, Settings, SettingsStore};
use settings::{KeymapFile, Settings};
use simplelog::SimpleLogger;
use strum::IntoEnumIterator;
use theme::{ThemeRegistry, ThemeSettings};
@ -64,12 +64,7 @@ fn main() {
gpui::App::new().with_assets(Assets).run(move |cx| {
load_embedded_fonts(cx).unwrap();
let mut store = SettingsStore::default();
store
.set_default_settings(default_settings().as_ref(), cx)
.unwrap();
cx.set_global(store);
settings::init(cx);
theme::init(theme::LoadThemes::All(Box::new(Assets)), cx);
let selector = story_selector;
@ -122,7 +117,7 @@ impl Render for StoryWrapper {
.flex()
.flex_col()
.size_full()
.font("Zed Mono")
.font_family("Zed Mono")
.child(self.story.clone())
}
}

View file

@ -1,6 +1,7 @@
mod avatar;
mod button;
mod checkbox;
mod collapsible_container;
mod context_menu;
mod disclosure;
mod divider;
@ -25,6 +26,7 @@ mod stories;
pub use avatar::*;
pub use button::*;
pub use checkbox::*;
pub use collapsible_container::*;
pub use context_menu::*;
pub use disclosure::*;
pub use divider::*;

View file

@ -0,0 +1,152 @@
use crate::{prelude::*, ButtonLike};
use smallvec::SmallVec;
use gpui::*;
#[derive(Default, Clone, Copy, Debug, PartialEq)]
pub enum ContainerStyle {
#[default]
None,
Card,
}
struct ContainerStyles {
pub background_color: Hsla,
pub border_color: Hsla,
pub text_color: Hsla,
}
#[derive(IntoElement)]
pub struct CollapsibleContainer {
id: ElementId,
base: ButtonLike,
toggle: bool,
/// A slot for content that appears before the label, like an icon or avatar.
start_slot: Option<AnyElement>,
/// A slot for content that appears after the label, usually on the other side of the header.
/// This might be a button, a disclosure arrow, a face pile, etc.
end_slot: Option<AnyElement>,
style: ContainerStyle,
children: SmallVec<[AnyElement; 1]>,
}
impl CollapsibleContainer {
pub fn new(id: impl Into<ElementId>, toggle: bool) -> Self {
Self {
id: id.into(),
base: ButtonLike::new("button_base"),
toggle,
start_slot: None,
end_slot: None,
style: ContainerStyle::Card,
children: SmallVec::new(),
}
}
pub fn start_slot<E: IntoElement>(mut self, start_slot: impl Into<Option<E>>) -> Self {
self.start_slot = start_slot.into().map(IntoElement::into_any_element);
self
}
pub fn end_slot<E: IntoElement>(mut self, end_slot: impl Into<Option<E>>) -> Self {
self.end_slot = end_slot.into().map(IntoElement::into_any_element);
self
}
pub fn child<E: IntoElement>(mut self, child: E) -> Self {
self.children.push(child.into_any_element());
self
}
}
impl Clickable for CollapsibleContainer {
fn on_click(mut self, handler: impl Fn(&ClickEvent, &mut WindowContext) + 'static) -> Self {
self.base = self.base.on_click(handler);
self
}
}
impl RenderOnce for CollapsibleContainer {
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
let color = cx.theme().colors();
let styles = match self.style {
ContainerStyle::None => ContainerStyles {
background_color: color.ghost_element_background,
border_color: color.border_transparent,
text_color: color.text,
},
ContainerStyle::Card => ContainerStyles {
background_color: color.elevated_surface_background,
border_color: color.border,
text_color: color.text,
},
};
v_flex()
.id(self.id)
.relative()
.rounded_md()
.bg(styles.background_color)
.border()
.border_color(styles.border_color)
.text_color(styles.text_color)
.overflow_hidden()
.child(
h_flex()
.overflow_hidden()
.w_full()
.group("toggleable_container_header")
.border_b()
.border_color(if self.toggle {
styles.border_color
} else {
color.border_transparent
})
.child(
self.base.full_width().style(ButtonStyle::Subtle).child(
div()
.h_7()
.p_1()
.flex()
.flex_1()
.items_center()
.justify_between()
.w_full()
.gap_1()
.cursor_pointer()
.group_hover("toggleable_container_header", |this| {
this.bg(color.element_hover)
})
.child(
h_flex()
.gap_1()
.child(
IconButton::new(
"toggle_icon",
match self.toggle {
true => IconName::ChevronDown,
false => IconName::ChevronRight,
},
)
.icon_color(Color::Muted)
.icon_size(IconSize::XSmall),
)
.child(
div()
.id("label_container")
.flex()
.gap_1()
.items_center()
.children(self.start_slot),
),
)
.child(h_flex().children(self.end_slot)),
),
),
)
.when(self.toggle, |this| {
this.child(h_flex().flex_1().w_full().p_1().children(self.children))
})
}
}

View file

@ -110,7 +110,7 @@ impl RenderOnce for WindowsCaptionButton {
.content_center()
.w(width)
.h_full()
.font("Segoe Fluent Icons")
.font_family("Segoe Fluent Icons")
.text_size(px(10.0))
.hover(|style| style.bg(self.hover_background_color))
.active(|style| {

View file

@ -95,7 +95,7 @@ pub fn tooltip_container<V>(
div().pl_2().pt_2p5().child(
v_flex()
.elevation_2(cx)
.font(ui_font)
.font_family(ui_font)
.text_ui()
.text_color(cx.theme().colors().text)
.py_1()

View file

@ -93,7 +93,7 @@ impl RenderOnce for Headline {
let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone();
div()
.font(ui_font)
.font_family(ui_font)
.line_height(self.size.line_height())
.text_size(self.size.size())
.text_color(cx.theme().colors().text)

View file

@ -2928,6 +2928,6 @@ impl Render for DraggedTab {
.selected(self.is_active)
.child(label)
.render(cx)
.font(ui_font)
.font_family(ui_font)
}
}

View file

@ -4004,7 +4004,7 @@ impl Render for Workspace {
.size_full()
.flex()
.flex_col()
.font(ui_font)
.font_family(ui_font)
.gap_0()
.justify_start()
.items_start()

View file

@ -19,6 +19,7 @@ activity_indicator.workspace = true
anyhow.workspace = true
assets.workspace = true
assistant.workspace = true
assistant2.workspace = true
audio.workspace = true
auto_update.workspace = true
backtrace = "0.3"

View file

@ -231,27 +231,18 @@ fn init_ui(args: Args) {
load_embedded_fonts(cx);
let mut store = SettingsStore::default();
store
.set_default_settings(default_settings().as_ref(), cx)
.unwrap();
cx.set_global(store);
settings::init(cx);
handle_settings_file_changes(user_settings_file_rx, cx);
handle_keymap_file_changes(user_keymap_file_rx, cx);
client::init_settings(cx);
let clock = Arc::new(clock::RealSystemClock);
let http = Arc::new(HttpClientWithUrl::new(
&client::ClientSettings::get_global(cx).server_url,
));
let client = client::Client::new(clock, http.clone(), cx);
let client = Client::production(cx);
let mut languages =
LanguageRegistry::new(login_shell_env_loaded, cx.background_executor().clone());
let copilot_language_server_id = languages.next_language_server_id();
languages.set_language_server_download_dir(paths::LANGUAGES_DIR.clone());
let languages = Arc::new(languages);
let node_runtime = RealNodeRuntime::new(http.clone());
let node_runtime = RealNodeRuntime::new(client.http_client());
language::init(cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
@ -271,11 +262,14 @@ fn init_ui(args: Args) {
diagnostics::init(cx);
copilot::init(
copilot_language_server_id,
http.clone(),
client.http_client(),
node_runtime.clone(),
cx,
);
assistant::init(client.clone(), cx);
assistant2::init(client.clone(), cx);
init_inline_completion_provider(client.telemetry().clone(), cx);
extension::init(
@ -297,7 +291,7 @@ fn init_ui(args: Args) {
cx.observe_global::<SettingsStore>({
let languages = languages.clone();
let http = http.clone();
let http = client.http_client();
let client = client.clone();
move |cx| {
@ -345,7 +339,7 @@ fn init_ui(args: Args) {
AppState::set_global(Arc::downgrade(&app_state), cx);
audio::init(Assets, cx);
auto_update::init(http.clone(), cx);
auto_update::init(client.http_client(), cx);
workspace::init(app_state.clone(), cx);
recent_projects::init(cx);
@ -378,7 +372,7 @@ fn init_ui(args: Args) {
initialize_workspace(app_state.clone(), cx);
// todo(linux): unblock this
upload_panics_and_crashes(http.clone(), cx);
upload_panics_and_crashes(client.http_client(), cx);
cx.activate(true);

View file

@ -3,7 +3,6 @@ mod only_instance;
mod open_listener;
pub use app_menus::*;
use assistant::AssistantPanel;
use breadcrumbs::Breadcrumbs;
use client::ZED_URL_SCHEME;
use collections::VecDeque;
@ -181,10 +180,12 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
})
});
}
cx.spawn(|workspace_handle, mut cx| async move {
let assistant_panel =
assistant::AssistantPanel::load(workspace_handle.clone(), cx.clone());
let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone());
let terminal_panel = TerminalPanel::load(workspace_handle.clone(), cx.clone());
let assistant_panel = AssistantPanel::load(workspace_handle.clone(), cx.clone());
let channels_panel =
collab_ui::collab_panel::CollabPanel::load(workspace_handle.clone(), cx.clone());
let chat_panel =
@ -193,6 +194,7 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
workspace_handle.clone(),
cx.clone(),
);
let (
project_panel,
terminal_panel,
@ -210,9 +212,9 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
)?;
workspace_handle.update(&mut cx, |workspace, cx| {
workspace.add_panel(assistant_panel, cx);
workspace.add_panel(project_panel, cx);
workspace.add_panel(terminal_panel, cx);
workspace.add_panel(assistant_panel, cx);
workspace.add_panel(channels_panel, cx);
workspace.add_panel(chat_panel, cx);
workspace.add_panel(notification_panel, cx);
@ -221,6 +223,30 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
})
.detach();
let mut current_user = app_state.user_store.read(cx).watch_current_user();
cx.spawn(|workspace_handle, mut cx| async move {
while let Some(user) = current_user.next().await {
if user.is_some() {
// User known now, can check feature flags / staff
// At this point, should have the user with staff status available
let use_assistant2 = cx.update(|cx| assistant2::enabled(cx))?;
if use_assistant2 {
let panel =
assistant2::AssistantPanel::load(workspace_handle.clone(), cx.clone())
.await?;
workspace_handle.update(&mut cx, |workspace, cx| {
workspace.add_panel(panel, cx);
})?;
}
break;
}
}
anyhow::Ok(())
})
.detach();
workspace
.register_action(about)
.register_action(|_, _: &Minimize, cx| {
@ -3028,11 +3054,7 @@ mod tests {
])
.unwrap();
let themes = ThemeRegistry::default();
let mut settings = SettingsStore::default();
settings
.set_default_settings(&settings::default_settings(), cx)
.unwrap();
cx.set_global(settings);
settings::init(cx);
theme::init(theme::LoadThemes::JustBase, cx);
let mut has_default_theme = false;

View file

@ -147,7 +147,7 @@ setTimeout(() => {
}
spawn(binaryPath, i == 0 ? args : [], {
stdio: "inherit",
env: {
env: Object.assign({}, process.env, {
ZED_IMPERSONATE: users[i],
ZED_WINDOW_POSITION: position,
ZED_STATELESS: isStateful && i == 0 ? "1" : "",
@ -157,9 +157,8 @@ setTimeout(() => {
ZED_ADMIN_API_TOKEN: "secret",
ZED_WINDOW_SIZE: size,
ZED_CLIENT_CHECKSUM_SEED: "development-checksum-seed",
PATH: process.env.PATH,
RUST_LOG: process.env.RUST_LOG || "info",
},
}),
});
}
}, 0.1);