diff --git a/.zed/tasks.json b/.zed/tasks.json index 80465969e2..c95cf5ffb1 100644 --- a/.zed/tasks.json +++ b/.zed/tasks.json @@ -3,5 +3,10 @@ "label": "clippy", "command": "cargo", "args": ["xtask", "clippy"] + }, + { + "label": "assistant2", + "command": "cargo", + "args": ["run", "-p", "assistant2", "--example", "assistant_example"] } ] diff --git a/Cargo.lock b/Cargo.lock index 3d22a64359..ca625dd461 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,6 +371,50 @@ dependencies = [ "workspace", ] +[[package]] +name = "assistant2" +version = "0.1.0" +dependencies = [ + "anyhow", + "assets", + "assistant_tooling", + "client", + "editor", + "env_logger", + "feature_flags", + "futures 0.3.28", + "gpui", + "language", + "languages", + "log", + "nanoid", + "node_runtime", + "open_ai", + "project", + "release_channel", + "rich_text", + "schemars", + "semantic_index", + "serde", + "serde_json", + "settings", + "theme", + "ui", + "util", + "workspace", +] + +[[package]] +name = "assistant_tooling" +version = "0.1.0" +dependencies = [ + "anyhow", + "gpui", + "schemars", + "serde", + "serde_json", +] + [[package]] name = "async-broadcast" version = "0.7.0" @@ -643,7 +687,7 @@ checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -710,7 +754,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -741,7 +785,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -1385,7 +1429,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.48", + "syn 2.0.59", "which 4.4.2", ] @@ -1468,7 +1512,7 @@ source = "git+https://github.com/kvark/blade?rev=810ec594358aafea29a4a3d8ab601d2 dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -1634,7 +1678,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -2019,7 +2063,7 @@ dependencies = [ "heck 0.4.1", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -2959,7 +3003,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -3442,7 +3486,7 @@ checksum = "5c785274071b1b420972453b306eeca06acf4633829db4223b58a2a8c5953bc4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -3954,7 +3998,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -4194,7 +4238,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -5061,7 +5105,7 @@ checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -5682,7 +5726,7 @@ checksum = "ba125974b109d512fccbc6c0244e7580143e460895dfd6ea7f8bbb692fd94396" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -6643,7 +6687,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -6719,7 +6763,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -6799,7 +6843,7 @@ checksum = "e8890702dbec0bad9116041ae586f84805b13eecd1d8b1df27c29998a9969d6d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -6977,7 +7021,7 @@ dependencies = [ "phf_shared", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -7028,7 +7072,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -7252,7 +7296,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" dependencies = [ "proc-macro2", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -7309,9 +7353,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" dependencies = [ "unicode-ident", ] @@ -7332,7 +7376,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -8175,7 +8219,7 @@ dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "syn 2.0.48", + "syn 2.0.59", "walkdir", ] @@ -8449,7 +8493,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -8490,7 +8534,7 @@ dependencies = [ "proc-macro2", "quote", "sea-bae", - "syn 2.0.48", + "syn 2.0.59", "unicode-ident", ] @@ -8674,7 +8718,7 @@ checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -8739,7 +8783,7 @@ checksum = "8725e1dfadb3a50f7e5ce0b1a540466f6ed3fe7a0fca2ac2b8b831d31316bd00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -9505,7 +9549,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -9634,9 +9678,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "4a6531ffc7b071655e4ce2e04bd464c4830bb585a61cabb96cf808f05172615a" dependencies = [ "proc-macro2", "quote", @@ -10001,7 +10045,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -10180,7 +10224,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -10405,7 +10449,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -11172,7 +11216,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", "wasm-bindgen-shared", ] @@ -11206,7 +11250,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -11343,7 +11387,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", "wasmtime-component-util", "wasmtime-wit-bindgen", "wit-parser", @@ -11504,7 +11548,7 @@ checksum = "6d6d967f01032da7d4c6303da32f6a00d5efe1bac124b156e7342d8ace6ffdfc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -11784,7 +11828,7 @@ dependencies = [ "proc-macro2", "quote", "shellexpand", - "syn 2.0.48", + "syn 2.0.59", "witx", ] @@ -11796,7 +11840,7 @@ checksum = "512d816dbcd0113103b2eb2402ec9018e7f0755202a5b3e67db726f229d8dcae" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", "wiggle-generate", ] @@ -11914,7 +11958,7 @@ checksum = "942ac266be9249c84ca862f0a164a39533dc2f6f33dc98ec89c8da99b82ea0bd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -11925,7 +11969,7 @@ checksum = "da33557140a288fae4e1d5f8873aaf9eb6613a9cf82c3e070223ff177f598b60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -12242,7 +12286,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", "wit-bindgen-core", "wit-bindgen-rust", ] @@ -12567,6 +12611,7 @@ dependencies = [ "anyhow", "assets", "assistant", + "assistant2", "audio", "auto_update", "backtrace", @@ -12860,7 +12905,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] @@ -12880,7 +12925,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.59", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index d2ff0c5066..adb9f461a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,8 @@ members = [ "crates/anthropic", "crates/assets", "crates/assistant", + "crates/assistant_tooling", + "crates/assistant2", "crates/audio", "crates/auto_update", "crates/breadcrumbs", @@ -137,6 +139,8 @@ ai = { path = "crates/ai" } anthropic = { path = "crates/anthropic" } assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } +assistant2 = { path = "crates/assistant2" } +assistant_tooling = { path = "crates/assistant_tooling" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } base64 = "0.13" @@ -208,6 +212,7 @@ rpc = { path = "crates/rpc" } task = { path = "crates/task" } tasks_ui = { path = "crates/tasks_ui" } search = { path = "crates/search" } +semantic_index = { path = "crates/semantic_index" } semantic_version = { path = "crates/semantic_version" } settings = { path = "crates/settings" } snippet = { path = "crates/snippet" } diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index f909bd48c5..f4da3078ad 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -209,7 +209,14 @@ } }, { - "context": "AssistantPanel", + "context": "AssistantChat > Editor", // Used in the assistant2 crate + "bindings": { + "enter": ["assistant2::Submit", "Simple"], + "cmd-enter": ["assistant2::Submit", "Codebase"] + } + }, + { + "context": "AssistantPanel", // Used in the assistant crate, which we're replacing "bindings": { "cmd-g": "search::SelectNextMatch", "cmd-shift-g": "search::SelectPrevMatch" diff --git a/crates/assets/Cargo.toml b/crates/assets/Cargo.toml index 8fcb1f9cfe..06f91da59f 100644 --- a/crates/assets/Cargo.toml +++ b/crates/assets/Cargo.toml @@ -5,6 +5,9 @@ edition = "2021" publish = false license = "GPL-3.0-or-later" +[lib] +path = "src/assets.rs" + [lints] workspace = true diff --git a/crates/assets/src/lib.rs b/crates/assets/src/assets.rs similarity index 62% rename from crates/assets/src/lib.rs rename to crates/assets/src/assets.rs index 4f013dd5af..b0a32a9d9c 100644 --- a/crates/assets/src/lib.rs +++ b/crates/assets/src/assets.rs @@ -1,7 +1,7 @@ // This crate was essentially pulled out verbatim from main `zed` crate to avoid having to run RustEmbed macro whenever zed has to be rebuilt. It saves a second or two on an incremental build. use anyhow::anyhow; -use gpui::{AssetSource, Result, SharedString}; +use gpui::{AppContext, AssetSource, Result, SharedString}; use rust_embed::RustEmbed; #[derive(RustEmbed)] @@ -34,3 +34,19 @@ impl AssetSource for Assets { .collect()) } } + +impl Assets { + /// Populate the [`TextSystem`] of the given [`AppContext`] with all `.ttf` fonts in the `fonts` directory. + pub fn load_fonts(&self, cx: &AppContext) -> gpui::Result<()> { + let font_paths = self.list("fonts")?; + let mut embedded_fonts = Vec::new(); + for font_path in font_paths { + if font_path.ends_with(".ttf") { + let font_bytes = cx.asset_source().load(&font_path)?; + embedded_fonts.push(font_bytes); + } + } + + cx.text_system().add_fonts(embedded_fonts) + } +} diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 9d72b512a1..46eeb4c095 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -128,6 +128,8 @@ impl LanguageModelRequestMessage { Role::System => proto::LanguageModelRole::LanguageModelSystem, } as i32, content: self.content.clone(), + tool_calls: Vec::new(), + tool_call_id: None, } } } @@ -147,6 +149,8 @@ impl LanguageModelRequest { messages: self.messages.iter().map(|m| m.to_proto()).collect(), stop: self.stop.clone(), temperature: self.temperature, + tool_choice: None, + tools: Vec::new(), } } } diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/assistant/src/completion_provider/open_ai.rs index f4c29a47e8..458a3a9d25 100644 --- a/crates/assistant/src/completion_provider/open_ai.rs +++ b/crates/assistant/src/completion_provider/open_ai.rs @@ -140,14 +140,24 @@ impl OpenAiCompletionProvider { messages: request .messages .into_iter() - .map(|msg| RequestMessage { - role: msg.role.into(), - content: msg.content, + .map(|msg| match msg.role { + Role::User => RequestMessage::User { + content: msg.content, + }, + Role::Assistant => RequestMessage::Assistant { + content: Some(msg.content), + tool_calls: Vec::new(), + }, + Role::System => RequestMessage::System { + content: msg.content, + }, }) .collect(), stream: true, stop: request.stop, temperature: request.temperature, + tools: Vec::new(), + tool_choice: None, } } } diff --git a/crates/assistant/src/completion_provider/zed.rs b/crates/assistant/src/completion_provider/zed.rs index 1ec852da19..ed84f1f7c6 100644 --- a/crates/assistant/src/completion_provider/zed.rs +++ b/crates/assistant/src/completion_provider/zed.rs @@ -123,6 +123,8 @@ impl ZedDotDevCompletionProvider { .collect(), stop: request.stop, temperature: request.temperature, + tools: Vec::new(), + tool_choice: None, }; self.client diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml new file mode 100644 index 0000000000..060dbaa98b --- /dev/null +++ b/crates/assistant2/Cargo.toml @@ -0,0 +1,56 @@ +[package] +name = "assistant2" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lib] +path = "src/assistant2.rs" + +[[example]] +name = "assistant_example" +path = "examples/assistant_example.rs" +crate-type = ["bin"] + +[dependencies] +anyhow.workspace = true +assistant_tooling.workspace = true +client.workspace = true +editor.workspace = true +feature_flags.workspace = true +futures.workspace = true +gpui.workspace = true +language.workspace = true +log.workspace = true +open_ai.workspace = true +project.workspace = true +rich_text.workspace = true +semantic_index.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +theme.workspace = true +ui.workspace = true +util.workspace = true +workspace.workspace = true +nanoid = "0.4" + +[dev-dependencies] +assets.workspace = true +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +languages.workspace = true +node_runtime.workspace = true +project = { workspace = true, features = ["test-support"] } +release_channel.workspace = true +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +util = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } + +[lints] +workspace = true diff --git a/crates/assistant2/LICENSE-GPL b/crates/assistant2/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/assistant2/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant2/examples/assistant_example.rs b/crates/assistant2/examples/assistant_example.rs new file mode 100644 index 0000000000..260c3bc8f9 --- /dev/null +++ b/crates/assistant2/examples/assistant_example.rs @@ -0,0 +1,129 @@ +use anyhow::Context as _; +use assets::Assets; +use assistant2::{tools::ProjectIndexTool, AssistantPanel}; +use assistant_tooling::ToolRegistry; +use client::Client; +use gpui::{actions, App, AppContext, KeyBinding, Task, View, WindowOptions}; +use language::LanguageRegistry; +use project::Project; +use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex}; +use settings::{KeymapFile, DEFAULT_KEYMAP_PATH}; +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; +use theme::LoadThemes; +use ui::{div, prelude::*, Render}; +use util::{http::HttpClientWithUrl, ResultExt as _}; + +actions!(example, [Quit]); + +fn main() { + let args: Vec = std::env::args().collect(); + + env_logger::init(); + App::new().with_assets(Assets).run(|cx| { + cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None))); + cx.on_action(|_: &Quit, cx: &mut AppContext| { + cx.quit(); + }); + + if args.len() < 2 { + eprintln!( + "Usage: cargo run --example assistant_example -p assistant2 -- " + ); + cx.quit(); + return; + } + + settings::init(cx); + language::init(cx); + Project::init_settings(cx); + editor::init(cx); + theme::init(LoadThemes::JustBase, cx); + Assets.load_fonts(cx).unwrap(); + KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap(); + client::init_settings(cx); + release_channel::init("0.130.0", cx); + + let client = Client::production(cx); + { + let client = client.clone(); + cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await }) + .detach_and_log_err(cx); + } + assistant2::init(client.clone(), cx); + + let language_registry = Arc::new(LanguageRegistry::new( + Task::ready(()), + cx.background_executor().clone(), + )); + let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client()); + languages::init(language_registry.clone(), node_runtime, cx); + + let http = Arc::new(HttpClientWithUrl::new("http://localhost:11434")); + + let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let embedding_provider = OpenAiEmbeddingProvider::new( + http.clone(), + OpenAiEmbeddingModel::TextEmbedding3Small, + open_ai::OPEN_AI_API_URL.to_string(), + api_key, + ); + + cx.spawn(|mut cx| async move { + let mut semantic_index = SemanticIndex::new( + PathBuf::from("/tmp/semantic-index-db.mdb"), + Arc::new(embedding_provider), + &mut cx, + ) + .await?; + + let project_path = Path::new(&args[1]); + let project = Project::example([project_path], &mut cx).await; + + cx.update(|cx| { + let fs = project.read(cx).fs().clone(); + + let project_index = semantic_index.project_index(project.clone(), cx); + + let mut tool_registry = ToolRegistry::new(); + tool_registry + .register(ProjectIndexTool::new(project_index.clone(), fs.clone())) + .context("failed to register ProjectIndexTool") + .log_err(); + + let tool_registry = Arc::new(tool_registry); + + cx.open_window(WindowOptions::default(), |cx| { + cx.new_view(|cx| Example::new(language_registry, tool_registry, cx)) + }); + cx.activate(true); + }) + }) + .detach_and_log_err(cx); + }) +} + +struct Example { + assistant_panel: View, +} + +impl Example { + fn new( + language_registry: Arc, + tool_registry: Arc, + cx: &mut ViewContext, + ) -> Self { + Self { + assistant_panel: cx + .new_view(|cx| AssistantPanel::new(language_registry, tool_registry, cx)), + } + } +} + +impl Render for Example { + fn render(&mut self, _cx: &mut ViewContext) -> impl ui::prelude::IntoElement { + div().size_full().child(self.assistant_panel.clone()) + } +} diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs new file mode 100644 index 0000000000..5a9d6c8df6 --- /dev/null +++ b/crates/assistant2/src/assistant2.rs @@ -0,0 +1,952 @@ +mod assistant_settings; +mod completion_provider; +pub mod tools; + +use anyhow::{Context, Result}; +use assistant_tooling::{ToolFunctionCall, ToolRegistry}; +use client::{proto, Client}; +use completion_provider::*; +use editor::{Editor, EditorEvent}; +use feature_flags::FeatureFlagAppExt as _; +use futures::{channel::oneshot, future::join_all, Future, FutureExt, StreamExt}; +use gpui::{ + list, prelude::*, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, + FocusableView, Global, ListAlignment, ListState, Model, Render, Task, View, WeakView, +}; +use language::{language_settings::SoftWrap, LanguageRegistry}; +use open_ai::{FunctionContent, ToolCall, ToolCallContent}; +use project::Fs; +use rich_text::RichText; +use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex}; +use serde::Deserialize; +use settings::Settings; +use std::{cmp, sync::Arc}; +use theme::ThemeSettings; +use tools::ProjectIndexTool; +use ui::{popover_menu, prelude::*, ButtonLike, CollapsibleContainer, Color, ContextMenu, Tooltip}; +use util::{paths::EMBEDDINGS_DIR, ResultExt}; +use workspace::{ + dock::{DockPosition, Panel, PanelEvent}, + Workspace, +}; + +pub use assistant_settings::AssistantSettings; + +const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5; + +// gpui::actions!(assistant, [Submit]); + +#[derive(Eq, PartialEq, Copy, Clone, Deserialize)] +pub struct Submit(SubmitMode); + +/// There are multiple different ways to submit a model request, represented by this enum. +#[derive(Eq, PartialEq, Copy, Clone, Deserialize)] +pub enum SubmitMode { + /// Only include the conversation. + Simple, + /// Send the current file as context. + CurrentFile, + /// Search the codebase and send relevant excerpts. + Codebase, +} + +gpui::actions!(assistant2, [ToggleFocus]); +gpui::impl_actions!(assistant2, [Submit]); + +pub fn init(client: Arc, cx: &mut AppContext) { + AssistantSettings::register(cx); + + cx.spawn(|mut cx| { + let client = client.clone(); + async move { + let embedding_provider = CloudEmbeddingProvider::new(client.clone()); + let semantic_index = SemanticIndex::new( + EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"), + Arc::new(embedding_provider), + &mut cx, + ) + .await?; + cx.update(|cx| cx.set_global(semantic_index)) + } + }) + .detach(); + + cx.set_global(CompletionProvider::new(CloudCompletionProvider::new( + client, + ))); + + cx.observe_new_views( + |workspace: &mut Workspace, _cx: &mut ViewContext| { + workspace.register_action(|workspace, _: &ToggleFocus, cx| { + workspace.toggle_panel_focus::(cx); + }); + }, + ) + .detach(); +} + +pub fn enabled(cx: &AppContext) -> bool { + cx.is_staff() +} + +pub struct AssistantPanel { + chat: View, + width: Option, +} + +impl AssistantPanel { + pub fn load( + workspace: WeakView, + cx: AsyncWindowContext, + ) -> Task>> { + cx.spawn(|mut cx| async move { + let (app_state, project) = workspace.update(&mut cx, |workspace, _| { + (workspace.app_state().clone(), workspace.project().clone()) + })?; + + cx.new_view(|cx| { + // todo!("this will panic if the semantic index failed to load or has not loaded yet") + let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| { + semantic_index.project_index(project.clone(), cx) + }); + + let mut tool_registry = ToolRegistry::new(); + tool_registry + .register(ProjectIndexTool::new( + project_index.clone(), + app_state.fs.clone(), + )) + .context("failed to register ProjectIndexTool") + .log_err(); + + let tool_registry = Arc::new(tool_registry); + + Self::new(app_state.languages.clone(), tool_registry, cx) + }) + }) + } + + pub fn new( + language_registry: Arc, + tool_registry: Arc, + cx: &mut ViewContext, + ) -> Self { + let chat = cx.new_view(|cx| { + AssistantChat::new(language_registry.clone(), tool_registry.clone(), cx) + }); + + Self { width: None, chat } + } +} + +impl Render for AssistantPanel { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + div() + .size_full() + .v_flex() + .p_2() + .bg(cx.theme().colors().background) + .child(self.chat.clone()) + } +} + +impl Panel for AssistantPanel { + fn persistent_name() -> &'static str { + "AssistantPanelv2" + } + + fn position(&self, _cx: &WindowContext) -> workspace::dock::DockPosition { + // todo!("Add a setting / use assistant settings") + DockPosition::Right + } + + fn position_is_valid(&self, position: workspace::dock::DockPosition) -> bool { + matches!(position, DockPosition::Right) + } + + fn set_position(&mut self, _: workspace::dock::DockPosition, _: &mut ViewContext) { + // Do nothing until we have a setting for this + } + + fn size(&self, _cx: &WindowContext) -> Pixels { + self.width.unwrap_or(px(400.)) + } + + fn set_size(&mut self, size: Option, cx: &mut ViewContext) { + self.width = size; + cx.notify(); + } + + fn icon(&self, _cx: &WindowContext) -> Option { + Some(IconName::Ai) + } + + fn icon_tooltip(&self, _: &WindowContext) -> Option<&'static str> { + Some("Assistant Panel ✨") + } + + fn toggle_action(&self) -> Box { + Box::new(ToggleFocus) + } +} + +impl EventEmitter for AssistantPanel {} + +impl FocusableView for AssistantPanel { + fn focus_handle(&self, cx: &AppContext) -> FocusHandle { + self.chat + .read(cx) + .messages + .iter() + .rev() + .find_map(|msg| msg.focus_handle(cx)) + .expect("no user message in chat") + } +} + +struct AssistantChat { + model: String, + messages: Vec, + list_state: ListState, + language_registry: Arc, + next_message_id: MessageId, + pending_completion: Option>, + tool_registry: Arc, +} + +impl AssistantChat { + fn new( + language_registry: Arc, + tool_registry: Arc, + cx: &mut ViewContext, + ) -> Self { + let model = CompletionProvider::get(cx).default_model(); + let view = cx.view().downgrade(); + let list_state = ListState::new( + 0, + ListAlignment::Bottom, + px(1024.), + move |ix, cx: &mut WindowContext| { + view.update(cx, |this, cx| this.render_message(ix, cx)) + .unwrap() + }, + ); + + let mut this = Self { + model, + messages: Vec::new(), + list_state, + language_registry, + next_message_id: MessageId(0), + pending_completion: None, + tool_registry, + }; + this.push_new_user_message(true, cx); + this + } + + fn focused_message_id(&self, cx: &WindowContext) -> Option { + self.messages.iter().find_map(|message| match message { + ChatMessage::User(message) => message + .body + .focus_handle(cx) + .contains_focused(cx) + .then_some(message.id), + ChatMessage::Assistant(_) => None, + }) + } + + fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext) { + let Some(focused_message_id) = self.focused_message_id(cx) else { + log::error!("unexpected state: no user message editor is focused."); + return; + }; + + self.truncate_messages(focused_message_id, cx); + + let mode = *mode; + self.pending_completion = Some(cx.spawn(move |this, mut cx| async move { + Self::request_completion( + this.clone(), + mode, + MAX_COMPLETION_CALLS_PER_SUBMISSION, + &mut cx, + ) + .await + .log_err(); + + this.update(&mut cx, |this, cx| { + let focus = this + .user_message(focused_message_id) + .body + .focus_handle(cx) + .contains_focused(cx); + this.push_new_user_message(focus, cx); + }) + .context("Failed to push new user message") + .log_err(); + })); + } + + async fn request_completion( + this: WeakView, + mode: SubmitMode, + limit: usize, + cx: &mut AsyncWindowContext, + ) -> Result<()> { + let mut call_count = 0; + loop { + let complete = async { + let completion = this.update(cx, |this, cx| { + this.push_new_assistant_message(cx); + + let definitions = if call_count < limit && matches!(mode, SubmitMode::Codebase) + { + this.tool_registry.definitions() + } else { + &[] + }; + call_count += 1; + + CompletionProvider::get(cx).complete( + this.model.clone(), + this.completion_messages(cx), + Vec::new(), + 1.0, + definitions, + ) + }); + + let mut stream = completion?.await?; + let mut body = String::new(); + while let Some(delta) = stream.next().await { + let delta = delta?; + this.update(cx, |this, cx| { + if let Some(ChatMessage::Assistant(AssistantMessage { + body: message_body, + tool_calls: message_tool_calls, + .. + })) = this.messages.last_mut() + { + if let Some(content) = &delta.content { + body.push_str(content); + } + + for tool_call in delta.tool_calls { + let index = tool_call.index as usize; + if index >= message_tool_calls.len() { + message_tool_calls.resize_with(index + 1, Default::default); + } + let call = &mut message_tool_calls[index]; + + if let Some(id) = &tool_call.id { + call.id.push_str(id); + } + + match tool_call.variant { + Some(proto::tool_call_delta::Variant::Function(tool_call)) => { + if let Some(name) = &tool_call.name { + call.name.push_str(name); + } + if let Some(arguments) = &tool_call.arguments { + call.arguments.push_str(arguments); + } + } + None => {} + } + } + + *message_body = + RichText::new(body.clone(), &[], &this.language_registry); + cx.notify(); + } else { + unreachable!() + } + })?; + } + + anyhow::Ok(()) + } + .await; + + let mut tool_tasks = Vec::new(); + this.update(cx, |this, cx| { + if let Some(ChatMessage::Assistant(AssistantMessage { + error: message_error, + tool_calls, + .. + })) = this.messages.last_mut() + { + if let Err(error) = complete { + message_error.replace(SharedString::from(error.to_string())); + cx.notify(); + } else { + for tool_call in tool_calls.iter() { + tool_tasks.push(this.tool_registry.call(tool_call, cx)); + } + } + } + })?; + + if tool_tasks.is_empty() { + return Ok(()); + } + + let tools = join_all(tool_tasks.into_iter()).await; + this.update(cx, |this, cx| { + if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) = + this.messages.last_mut() + { + *tool_calls = tools; + cx.notify(); + } + })?; + } + } + + fn user_message(&mut self, message_id: MessageId) -> &mut UserMessage { + self.messages + .iter_mut() + .find_map(|message| match message { + ChatMessage::User(user_message) if user_message.id == message_id => { + Some(user_message) + } + _ => None, + }) + .expect("User message not found") + } + + fn push_new_user_message(&mut self, focus: bool, cx: &mut ViewContext) { + let id = self.next_message_id.post_inc(); + let body = cx.new_view(|cx| { + let mut editor = Editor::auto_height(80, cx); + editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx); + if focus { + cx.focus_self(); + } + editor + }); + let _subscription = cx.subscribe(&body, move |this, editor, event, cx| match event { + EditorEvent::SelectionsChanged { .. } => { + if editor.read(cx).is_focused(cx) { + let (message_ix, _message) = this + .messages + .iter() + .enumerate() + .find_map(|(ix, message)| match message { + ChatMessage::User(user_message) if user_message.id == id => { + Some((ix, user_message)) + } + _ => None, + }) + .expect("user message not found"); + + this.list_state.scroll_to_reveal_item(message_ix); + } + } + _ => {} + }); + let message = ChatMessage::User(UserMessage { + id, + body, + contexts: Vec::new(), + _subscription, + }); + self.push_message(message, cx); + } + + fn push_new_assistant_message(&mut self, cx: &mut ViewContext) { + let message = ChatMessage::Assistant(AssistantMessage { + id: self.next_message_id.post_inc(), + body: RichText::default(), + tool_calls: Vec::new(), + error: None, + }); + self.push_message(message, cx); + } + + fn push_message(&mut self, message: ChatMessage, cx: &mut ViewContext) { + let old_len = self.messages.len(); + let focus_handle = Some(message.focus_handle(cx)); + self.messages.push(message); + self.list_state + .splice_focusable(old_len..old_len, focus_handle); + cx.notify(); + } + + fn truncate_messages(&mut self, last_message_id: MessageId, cx: &mut ViewContext) { + if let Some(index) = self.messages.iter().position(|message| match message { + ChatMessage::User(message) => message.id == last_message_id, + ChatMessage::Assistant(message) => message.id == last_message_id, + }) { + self.list_state.splice(index + 1..self.messages.len(), 0); + self.messages.truncate(index + 1); + cx.notify(); + } + } + + fn render_error( + &self, + error: Option, + _ix: usize, + cx: &mut ViewContext, + ) -> AnyElement { + let theme = cx.theme(); + + if let Some(error) = error { + div() + .py_1() + .px_2() + .neg_mx_1() + .rounded_md() + .border() + .border_color(theme.status().error_border) + // .bg(theme.status().error_background) + .text_color(theme.status().error) + .child(error.clone()) + .into_any_element() + } else { + div().into_any_element() + } + } + + fn render_message(&self, ix: usize, cx: &mut ViewContext) -> AnyElement { + let is_last = ix == self.messages.len() - 1; + + match &self.messages[ix] { + ChatMessage::User(UserMessage { + body, + contexts: _contexts, + .. + }) => div() + .when(!is_last, |element| element.mb_2()) + .child(div().p_2().child(Label::new("You").color(Color::Default))) + .child( + div() + .on_action(cx.listener(Self::submit)) + .p_2() + .text_color(cx.theme().colors().editor_foreground) + .font(ThemeSettings::get_global(cx).buffer_font.clone()) + .bg(cx.theme().colors().editor_background) + .child(body.clone()), // .children(contexts.iter().map(|context| context.render(cx))), + ) + .into_any(), + ChatMessage::Assistant(AssistantMessage { + id, + body, + error, + tool_calls, + .. + }) => { + let assistant_body = if body.text.is_empty() && !tool_calls.is_empty() { + div() + } else { + div().p_2().child(body.element(ElementId::from(id.0), cx)) + }; + + div() + .when(!is_last, |element| element.mb_2()) + .child( + div() + .p_2() + .child(Label::new("Assistant").color(Color::Modified)), + ) + .child(assistant_body) + .child(self.render_error(error.clone(), ix, cx)) + .children(tool_calls.iter().map(|tool_call| { + let result = &tool_call.result; + let name = tool_call.name.clone(); + match result { + Some(result) => div() + .p_2() + .child(result.render(&name, &tool_call.id, cx)) + .into_any(), + None => div() + .p_2() + .child(Label::new(name).color(Color::Modified)) + .child("Running...") + .into_any(), + } + })) + .into_any() + } + } + } + + fn completion_messages(&self, cx: &WindowContext) -> Vec { + let mut completion_messages = Vec::new(); + + for message in &self.messages { + match message { + ChatMessage::User(UserMessage { body, contexts, .. }) => { + // setup context for model + contexts.iter().for_each(|context| { + completion_messages.extend(context.completion_messages(cx)) + }); + + // Show user's message last so that the assistant is grounded in the user's request + completion_messages.push(CompletionMessage::User { + content: body.read(cx).text(cx), + }); + } + ChatMessage::Assistant(AssistantMessage { + body, tool_calls, .. + }) => { + // In no case do we want to send an empty message. This shouldn't happen, but we might as well + // not break the Chat API if it does. + if body.text.is_empty() && tool_calls.is_empty() { + continue; + } + + let tool_calls_from_assistant = tool_calls + .iter() + .map(|tool_call| ToolCall { + content: ToolCallContent::Function { + function: FunctionContent { + name: tool_call.name.clone(), + arguments: tool_call.arguments.clone(), + }, + }, + id: tool_call.id.clone(), + }) + .collect(); + + completion_messages.push(CompletionMessage::Assistant { + content: Some(body.text.to_string()), + tool_calls: tool_calls_from_assistant, + }); + + for tool_call in tool_calls { + // todo!(): we should not be sending when the tool is still running / has no result + // For now I'm going to have to assume we send an empty string because otherwise + // the Chat API will break -- there is a required message for every tool call by ID + let content = match &tool_call.result { + Some(result) => result.format(&tool_call.name), + None => "".to_string(), + }; + + completion_messages.push(CompletionMessage::Tool { + content, + tool_call_id: tool_call.id.clone(), + }); + } + } + } + } + + completion_messages + } + + fn render_model_dropdown(&self, cx: &mut ViewContext) -> impl IntoElement { + let this = cx.view().downgrade(); + div().h_flex().justify_end().child( + div().w_32().child( + popover_menu("user-menu") + .menu(move |cx| { + ContextMenu::build(cx, |mut menu, cx| { + for model in CompletionProvider::get(cx).available_models() { + menu = menu.custom_entry( + { + let model = model.clone(); + move |_| Label::new(model.clone()).into_any_element() + }, + { + let this = this.clone(); + move |cx| { + _ = this.update(cx, |this, cx| { + this.model = model.clone(); + cx.notify(); + }); + } + }, + ); + } + menu + }) + .into() + }) + .trigger( + ButtonLike::new("active-model") + .child( + h_flex() + .w_full() + .gap_0p5() + .child( + div() + .overflow_x_hidden() + .flex_grow() + .whitespace_nowrap() + .child(Label::new(self.model.clone())), + ) + .child(div().child( + Icon::new(IconName::ChevronDown).color(Color::Muted), + )), + ) + .style(ButtonStyle::Subtle) + .tooltip(move |cx| Tooltip::text("Change Model", cx)), + ) + .anchor(gpui::AnchorCorner::TopRight), + ), + ) + } +} + +impl Render for AssistantChat { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + div() + .relative() + .flex_1() + .v_flex() + .key_context("AssistantChat") + .text_color(Color::Default.color(cx)) + .child(self.render_model_dropdown(cx)) + .child(list(self.list_state.clone()).flex_1()) + } +} + +#[derive(Copy, Clone, Eq, PartialEq)] +struct MessageId(usize); + +impl MessageId { + fn post_inc(&mut self) -> Self { + let id = *self; + self.0 += 1; + id + } +} + +enum ChatMessage { + User(UserMessage), + Assistant(AssistantMessage), +} + +impl ChatMessage { + fn focus_handle(&self, cx: &AppContext) -> Option { + match self { + ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)), + ChatMessage::Assistant(_) => None, + } + } +} + +struct UserMessage { + id: MessageId, + body: View, + contexts: Vec, + _subscription: gpui::Subscription, +} + +struct AssistantMessage { + id: MessageId, + body: RichText, + tool_calls: Vec, + error: Option, +} + +// Since we're swapping out for direct query usage, we might not need to use this injected context +// It will be useful though for when the user _definitely_ wants the model to see a specific file, +// query, error, etc. +#[allow(dead_code)] +enum AssistantContext { + Codebase(View), +} + +#[allow(dead_code)] +struct CodebaseExcerpt { + element_id: ElementId, + path: SharedString, + text: SharedString, + score: f32, + expanded: bool, +} + +impl AssistantContext { + #[allow(dead_code)] + fn render(&self, _cx: &mut ViewContext) -> AnyElement { + match self { + AssistantContext::Codebase(context) => context.clone().into_any_element(), + } + } + + fn completion_messages(&self, cx: &WindowContext) -> Vec { + match self { + AssistantContext::Codebase(context) => context.read(cx).completion_messages(), + } + } +} + +enum CodebaseContext { + Pending { _task: Task<()> }, + Done(Result>), +} + +impl CodebaseContext { + fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext) { + if let CodebaseContext::Done(Ok(excerpts)) = self { + if let Some(excerpt) = excerpts + .iter_mut() + .find(|excerpt| excerpt.element_id == element_id) + { + excerpt.expanded = !excerpt.expanded; + cx.notify(); + } + } + } +} + +impl Render for CodebaseContext { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + match self { + CodebaseContext::Pending { .. } => div() + .h_flex() + .items_center() + .gap_1() + .child(Icon::new(IconName::Ai).color(Color::Muted).into_element()) + .child("Searching codebase..."), + CodebaseContext::Done(Ok(excerpts)) => { + div() + .v_flex() + .gap_2() + .children(excerpts.iter().map(|excerpt| { + let expanded = excerpt.expanded; + let element_id = excerpt.element_id.clone(); + + CollapsibleContainer::new(element_id.clone(), expanded) + .start_slot( + h_flex() + .gap_1() + .child(Icon::new(IconName::File).color(Color::Muted)) + .child(Label::new(excerpt.path.clone()).color(Color::Muted)), + ) + .on_click(cx.listener(move |this, _, cx| { + this.toggle_expanded(element_id.clone(), cx); + })) + .child( + div() + .p_2() + .rounded_md() + .bg(cx.theme().colors().editor_background) + .child( + excerpt.text.clone(), // todo!(): Show as an editor block + ), + ) + })) + } + CodebaseContext::Done(Err(error)) => div().child(error.to_string()), + } + } +} + +impl CodebaseContext { + #[allow(dead_code)] + fn new( + query: impl 'static + Future>, + populated: oneshot::Sender, + project_index: Model, + fs: Arc, + cx: &mut ViewContext, + ) -> Self { + let query = query.boxed_local(); + let _task = cx.spawn(|this, mut cx| async move { + let result = async { + let query = query.await?; + let results = this + .update(&mut cx, |_this, cx| { + project_index.read(cx).search(&query, 16, cx) + })? + .await; + + let excerpts = results.into_iter().map(|result| { + let abs_path = result + .worktree + .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path)); + let fs = fs.clone(); + + async move { + let path = result.path.clone(); + let text = fs.load(&abs_path?).await?; + // todo!("what should we do with stale ranges?"); + let range = cmp::min(result.range.start, text.len()) + ..cmp::min(result.range.end, text.len()); + + let text = SharedString::from(text[range].to_string()); + + anyhow::Ok(CodebaseExcerpt { + element_id: ElementId::Name(nanoid::nanoid!().into()), + path: path.to_string_lossy().to_string().into(), + text, + score: result.score, + expanded: false, + }) + } + }); + + anyhow::Ok( + futures::future::join_all(excerpts) + .await + .into_iter() + .filter_map(|result| result.log_err()) + .collect(), + ) + } + .await; + + this.update(&mut cx, |this, cx| { + this.populate(result, populated, cx); + }) + .ok(); + }); + + Self::Pending { _task } + } + + #[allow(dead_code)] + fn populate( + &mut self, + result: Result>, + populated: oneshot::Sender, + cx: &mut ViewContext, + ) { + let success = result.is_ok(); + *self = Self::Done(result); + populated.send(success).ok(); + cx.notify(); + } + + fn completion_messages(&self) -> Vec { + // One system message for the whole batch of excerpts: + + // Semantic search results for user query: + // + // Excerpt from $path: + // ~~~ + // `text` + // ~~~ + // + // Excerpt from $path: + + match self { + CodebaseContext::Done(Ok(excerpts)) => { + if excerpts.is_empty() { + return Vec::new(); + } + + let mut body = "Semantic search results for user query:\n".to_string(); + + for excerpt in excerpts { + body.push_str("Excerpt from "); + body.push_str(excerpt.path.as_ref()); + body.push_str(", score "); + body.push_str(&excerpt.score.to_string()); + body.push_str(":\n"); + body.push_str("~~~\n"); + body.push_str(excerpt.text.as_ref()); + body.push_str("~~~\n"); + } + + vec![CompletionMessage::System { content: body }] + } + _ => vec![], + } + } +} diff --git a/crates/assistant2/src/assistant_settings.rs b/crates/assistant2/src/assistant_settings.rs new file mode 100644 index 0000000000..7d532faaeb --- /dev/null +++ b/crates/assistant2/src/assistant_settings.rs @@ -0,0 +1,26 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsSources}; + +#[derive(Default, Debug, Deserialize, Serialize, Clone)] +pub struct AssistantSettings { + pub enabled: bool, +} + +#[derive(Default, Debug, Deserialize, Serialize, Clone, JsonSchema)] +pub struct AssistantSettingsContent { + pub enabled: Option, +} + +impl Settings for AssistantSettings { + const KEY: Option<&'static str> = Some("assistant_v2"); + + type FileContent = AssistantSettingsContent; + + fn load( + sources: SettingsSources, + _: &mut gpui::AppContext, + ) -> anyhow::Result { + Ok(sources.json_merge().unwrap_or_else(|_| Default::default())) + } +} diff --git a/crates/assistant2/src/completion_provider.rs b/crates/assistant2/src/completion_provider.rs new file mode 100644 index 0000000000..01970c053e --- /dev/null +++ b/crates/assistant2/src/completion_provider.rs @@ -0,0 +1,179 @@ +use anyhow::Result; +use assistant_tooling::ToolFunctionDefinition; +use client::{proto, Client}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::Global; +use std::sync::Arc; + +pub use open_ai::RequestMessage as CompletionMessage; + +#[derive(Clone)] +pub struct CompletionProvider(Arc); + +impl CompletionProvider { + pub fn new(backend: impl CompletionProviderBackend) -> Self { + Self(Arc::new(backend)) + } + + pub fn default_model(&self) -> String { + self.0.default_model() + } + + pub fn available_models(&self) -> Vec { + self.0.available_models() + } + + pub fn complete( + &self, + model: String, + messages: Vec, + stop: Vec, + temperature: f32, + tools: &[ToolFunctionDefinition], + ) -> BoxFuture<'static, Result>>> + { + self.0.complete(model, messages, stop, temperature, tools) + } +} + +impl Global for CompletionProvider {} + +pub trait CompletionProviderBackend: 'static { + fn default_model(&self) -> String; + fn available_models(&self) -> Vec; + fn complete( + &self, + model: String, + messages: Vec, + stop: Vec, + temperature: f32, + tools: &[ToolFunctionDefinition], + ) -> BoxFuture<'static, Result>>>; +} + +pub struct CloudCompletionProvider { + client: Arc, +} + +impl CloudCompletionProvider { + pub fn new(client: Arc) -> Self { + Self { client } + } +} + +impl CompletionProviderBackend for CloudCompletionProvider { + fn default_model(&self) -> String { + "gpt-4-turbo".into() + } + + fn available_models(&self) -> Vec { + vec!["gpt-4-turbo".into(), "gpt-4".into(), "gpt-3.5-turbo".into()] + } + + fn complete( + &self, + model: String, + messages: Vec, + stop: Vec, + temperature: f32, + tools: &[ToolFunctionDefinition], + ) -> BoxFuture<'static, Result>>> + { + let client = self.client.clone(); + let tools: Vec = tools + .iter() + .filter_map(|tool| { + Some(proto::ChatCompletionTool { + variant: Some(proto::chat_completion_tool::Variant::Function( + proto::chat_completion_tool::FunctionObject { + name: tool.name.clone(), + description: Some(tool.description.clone()), + parameters: Some(serde_json::to_string(&tool.parameters).ok()?), + }, + )), + }) + }) + .collect(); + + let tool_choice = match tools.is_empty() { + true => None, + false => Some("auto".into()), + }; + + async move { + let stream = client + .request_stream(proto::CompleteWithLanguageModel { + model, + messages: messages + .into_iter() + .map(|message| match message { + CompletionMessage::Assistant { + content, + tool_calls, + } => proto::LanguageModelRequestMessage { + role: proto::LanguageModelRole::LanguageModelAssistant as i32, + content: content.unwrap_or_default(), + tool_call_id: None, + tool_calls: tool_calls + .into_iter() + .map(|tool_call| match tool_call.content { + open_ai::ToolCallContent::Function { function } => { + proto::ToolCall { + id: tool_call.id, + variant: Some(proto::tool_call::Variant::Function( + proto::tool_call::FunctionCall { + name: function.name, + arguments: function.arguments, + }, + )), + } + } + }) + .collect(), + }, + CompletionMessage::User { content } => { + proto::LanguageModelRequestMessage { + role: proto::LanguageModelRole::LanguageModelUser as i32, + content, + tool_call_id: None, + tool_calls: Vec::new(), + } + } + CompletionMessage::System { content } => { + proto::LanguageModelRequestMessage { + role: proto::LanguageModelRole::LanguageModelSystem as i32, + content, + tool_calls: Vec::new(), + tool_call_id: None, + } + } + CompletionMessage::Tool { + content, + tool_call_id, + } => proto::LanguageModelRequestMessage { + role: proto::LanguageModelRole::LanguageModelTool as i32, + content, + tool_call_id: Some(tool_call_id), + tool_calls: Vec::new(), + }, + }) + .collect(), + stop, + temperature, + tool_choice, + tools, + }) + .await?; + + Ok(stream + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta?)), + Err(error) => Some(Err(error)), + } + }) + .boxed()) + } + .boxed() + } +} diff --git a/crates/assistant2/src/tools.rs b/crates/assistant2/src/tools.rs new file mode 100644 index 0000000000..ffd5e42bfa --- /dev/null +++ b/crates/assistant2/src/tools.rs @@ -0,0 +1,176 @@ +use anyhow::Result; +use assistant_tooling::LanguageModelTool; +use gpui::{prelude::*, AnyElement, AppContext, Model, Task}; +use project::Fs; +use schemars::JsonSchema; +use semantic_index::ProjectIndex; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use ui::{ + div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString, + WindowContext, +}; +use util::ResultExt as _; + +const DEFAULT_SEARCH_LIMIT: usize = 20; + +#[derive(Serialize, Clone)] +pub struct CodebaseExcerpt { + path: SharedString, + text: SharedString, + score: f32, +} + +// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model. +// Any changes or deletions to the `CodebaseQuery` comments will change model behavior. + +#[derive(Deserialize, JsonSchema)] +pub struct CodebaseQuery { + /// Semantic search query + query: String, + /// Maximum number of results to return, defaults to 20 + limit: Option, +} + +pub struct ProjectIndexTool { + project_index: Model, + fs: Arc, +} + +impl ProjectIndexTool { + pub fn new(project_index: Model, fs: Arc) -> Self { + // TODO: setup a better description based on the user's current codebase. + Self { project_index, fs } + } +} + +impl LanguageModelTool for ProjectIndexTool { + type Input = CodebaseQuery; + type Output = Vec; + + fn name(&self) -> String { + "query_codebase".to_string() + } + + fn description(&self) -> String { + "Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of chunks and an embedding of the query".to_string() + } + + fn execute(&self, query: &Self::Input, cx: &AppContext) -> Task> { + let project_index = self.project_index.read(cx); + + let results = project_index.search( + query.query.as_str(), + query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT), + cx, + ); + + let fs = self.fs.clone(); + + cx.spawn(|cx| async move { + let results = results.await; + + let excerpts = results.into_iter().map(|result| { + let abs_path = result + .worktree + .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path)); + let fs = fs.clone(); + + async move { + let path = result.path.clone(); + let text = fs.load(&abs_path?).await?; + + let mut start = result.range.start; + let mut end = result.range.end.min(text.len()); + while !text.is_char_boundary(start) { + start += 1; + } + while !text.is_char_boundary(end) { + end -= 1; + } + + anyhow::Ok(CodebaseExcerpt { + path: path.to_string_lossy().to_string().into(), + text: SharedString::from(text[start..end].to_string()), + score: result.score, + }) + } + }); + + let excerpts = futures::future::join_all(excerpts) + .await + .into_iter() + .filter_map(|result| result.log_err()) + .collect(); + anyhow::Ok(excerpts) + }) + } + + fn render( + _tool_call_id: &str, + input: &Self::Input, + excerpts: &Self::Output, + cx: &mut WindowContext, + ) -> AnyElement { + let query = input.query.clone(); + + div() + .v_flex() + .gap_2() + .child( + div() + .p_2() + .rounded_md() + .bg(cx.theme().colors().editor_background) + .child( + h_flex() + .child(Label::new("Query: ").color(Color::Modified)) + .child(Label::new(query).color(Color::Muted)), + ), + ) + .children(excerpts.iter().map(|excerpt| { + // This render doesn't have state/model, so we can't use the listener + // let expanded = excerpt.expanded; + // let element_id = excerpt.element_id.clone(); + let element_id = ElementId::Name(nanoid::nanoid!().into()); + let expanded = false; + + CollapsibleContainer::new(element_id.clone(), expanded) + .start_slot( + h_flex() + .gap_1() + .child(Icon::new(IconName::File).color(Color::Muted)) + .child(Label::new(excerpt.path.clone()).color(Color::Muted)), + ) + // .on_click(cx.listener(move |this, _, cx| { + // this.toggle_expanded(element_id.clone(), cx); + // })) + .child( + div() + .p_2() + .rounded_md() + .bg(cx.theme().colors().editor_background) + .child( + excerpt.text.clone(), // todo!(): Show as an editor block + ), + ) + })) + .into_any_element() + } + + fn format(_input: &Self::Input, excerpts: &Self::Output) -> String { + let mut body = "Semantic search results:\n".to_string(); + + for excerpt in excerpts { + body.push_str("Excerpt from "); + body.push_str(excerpt.path.as_ref()); + body.push_str(", score "); + body.push_str(&excerpt.score.to_string()); + body.push_str(":\n"); + body.push_str("~~~\n"); + body.push_str(excerpt.text.as_ref()); + body.push_str("~~~\n"); + } + body + } +} diff --git a/crates/assistant_tooling/Cargo.toml b/crates/assistant_tooling/Cargo.toml new file mode 100644 index 0000000000..8a7e7ab185 --- /dev/null +++ b/crates/assistant_tooling/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "assistant_tooling" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/assistant_tooling.rs" + +[dependencies] +anyhow.workspace = true +gpui.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true + +[dev-dependencies] +gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tooling/LICENSE-GPL b/crates/assistant_tooling/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/assistant_tooling/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant_tooling/README.md b/crates/assistant_tooling/README.md new file mode 100644 index 0000000000..79064142ed --- /dev/null +++ b/crates/assistant_tooling/README.md @@ -0,0 +1,208 @@ +# Assistant Tooling + +Bringing OpenAI compatible tool calling to GPUI. + +This unlocks: + +- **Structured Extraction** of model responses +- **Validation** of model inputs +- **Execution** of chosen toolsn + +## Overview + +Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When make a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call. + +> **User**: "Hey I need help with implementing a collapsible panel in GPUI" +> +> **Assistant**: "Sure, I can help with that. Let me see what I can find." +> +> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]` +> +> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"` +> +> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you." + +This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with a simple trait, `LanguageModelTool`. + +## Example + +Let's expose querying a semantic index directly by the model. First, we'll set up some _necessary_ imports + +```rust +use anyhow::Result; +use assistant_tooling::{LanguageModelTool, ToolRegistry}; +use gpui::{App, AppContext, Task}; +use schemars::JsonSchema; +use serde::Deserialize; +use serde_json::json; +``` + +Then we'll define the query structure the model must fill in. This _must_ derive `Deserialize` from `serde` and `JsonSchema` from the `schemars` crate. + +```rust +#[derive(Deserialize, JsonSchema)] +struct CodebaseQuery { + query: String, +} +``` + +After that we can define our tool, with the expectation that it will need a `ProjectIndex` to search against. For this example, the index uses the same interface as `semantic_index::ProjectIndex`. + +```rust +struct ProjectIndex {} + +impl ProjectIndex { + fn new() -> Self { + ProjectIndex {} + } + + fn search(&self, _query: &str, _limit: usize, _cx: &AppContext) -> Task>> { + // Instead of hooking up a real index, we're going to fake it + if _query.contains("gpui") { + return Task::ready(Ok(vec![r#"// crates/gpui/src/gpui.rs + //! # Welcome to GPUI! + //! + //! GPUI is a hybrid immediate and retained mode, GPU accelerated, UI framework + //! for Rust, designed to support a wide variety of applications + "# + .to_string()])); + } + return Task::ready(Ok(vec![])); + } +} + +struct ProjectIndexTool { + project_index: ProjectIndex, +} +``` + +Now we can implement the `LanguageModelTool` trait for our tool by: + +- Defining the `Input` from the model, which is `CodebaseQuery` +- Defining the `Output` +- Implementing the `name` and `description` functions to provide the model information when it's choosing a tool +- Implementing the `execute` function to run the tool + +```rust +impl LanguageModelTool for ProjectIndexTool { + type Input = CodebaseQuery; + type Output = String; + + fn name(&self) -> String { + "query_codebase".to_string() + } + + fn description(&self) -> String { + "Executes a query against the codebase, returning excerpts related to the query".to_string() + } + + fn execute(&self, query: Self::Input, cx: &AppContext) -> Task> { + let results = self.project_index.search(query.query.as_str(), 10, cx); + + cx.spawn(|_cx| async move { + let results = results.await?; + + if !results.is_empty() { + Ok(results.join("\n")) + } else { + Ok("No results".to_string()) + } + }) + } +} +``` + +For the sake of this example, let's look at the types that OpenAI will be passing to us + +```rust +// OpenAI definitions, shown here for demonstration +#[derive(Deserialize)] +struct FunctionCall { + name: String, + args: String, +} + +#[derive(Deserialize, Eq, PartialEq)] +enum ToolCallType { + #[serde(rename = "function")] + Function, + Other, +} + +#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] +struct ToolCallId(String); + +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ToolCall { + Function { + #[allow(dead_code)] + id: ToolCallId, + function: FunctionCall, + }, + Other { + #[allow(dead_code)] + id: ToolCallId, + }, +} + +#[derive(Deserialize)] +struct AssistantMessage { + role: String, + content: Option, + tool_calls: Option>, +} +``` + +When the model wants to call tools, it will pass a list of `ToolCall`s. When those are `function`s that we can handle, we'll pass them to our `ToolRegistry` to get a future that we can await. + +```rust +// Inside `fn main()` +App::new().run(|cx: &mut AppContext| { + let tool = ProjectIndexTool { + project_index: ProjectIndex::new(), + }; + + let mut registry = ToolRegistry::new(); + let registered = registry.register(tool); + assert!(registered.is_ok()); +``` + +Let's pretend the model sent us back a message requesting + +```rust +let model_response = json!({ + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "query_codebase", + "args": r#"{"query":"GPUI Task background_executor"}"# + }, + "type": "function" + } + ] +}); + +let message: AssistantMessage = serde_json::from_value(model_response).unwrap(); + +// We know there's a tool call, so let's skip straight to it for this example +let tool_calls = message.tool_calls.as_ref().unwrap(); +let tool_call = tool_calls.get(0).unwrap(); +``` + +We can now use our registry to call the tool. + +```rust +let task = registry.call( + tool_call.name, + tool_call.args, +); + +cx.spawn(|_cx| async move { + let result = task.await?; + println!("{}", result.unwrap()); + Ok(()) +}) +``` diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs new file mode 100644 index 0000000000..93d81cbb9d --- /dev/null +++ b/crates/assistant_tooling/src/assistant_tooling.rs @@ -0,0 +1,5 @@ +pub mod registry; +pub mod tool; + +pub use crate::registry::ToolRegistry; +pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition}; diff --git a/crates/assistant_tooling/src/registry.rs b/crates/assistant_tooling/src/registry.rs new file mode 100644 index 0000000000..8c969c0d80 --- /dev/null +++ b/crates/assistant_tooling/src/registry.rs @@ -0,0 +1,298 @@ +use anyhow::{anyhow, Result}; +use gpui::{AnyElement, AppContext, Task, WindowContext}; +use std::{any::Any, collections::HashMap}; + +use crate::tool::{ + LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition, +}; + +pub struct ToolRegistry { + tools: HashMap Task>>, + definitions: Vec, +} + +impl ToolRegistry { + pub fn new() -> Self { + Self { + tools: HashMap::new(), + definitions: Vec::new(), + } + } + + pub fn definitions(&self) -> &[ToolFunctionDefinition] { + &self.definitions + } + + pub fn register(&mut self, tool: T) -> Result<()> { + fn render( + tool_call_id: &str, + input: &Box, + output: &Box, + cx: &mut WindowContext, + ) -> AnyElement { + T::render( + tool_call_id, + input.as_ref().downcast_ref::().unwrap(), + output.as_ref().downcast_ref::().unwrap(), + cx, + ) + } + + fn format( + input: &Box, + output: &Box, + ) -> String { + T::format( + input.as_ref().downcast_ref::().unwrap(), + output.as_ref().downcast_ref::().unwrap(), + ) + } + + self.definitions.push(tool.definition()); + let name = tool.name(); + let previous = self.tools.insert( + name.clone(), + Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| { + let name = tool_call.name.clone(); + let arguments = tool_call.arguments.clone(); + let id = tool_call.id.clone(); + + let Ok(input) = serde_json::from_str::(arguments.as_str()) else { + return Task::ready(ToolFunctionCall { + id, + name: name.clone(), + arguments, + result: Some(ToolFunctionCallResult::ParsingFailed), + }); + }; + + let result = tool.execute(&input, cx); + + cx.spawn(move |_cx| async move { + match result.await { + Ok(result) => { + let result: T::Output = result; + ToolFunctionCall { + id, + name: name.clone(), + arguments, + result: Some(ToolFunctionCallResult::Finished { + input: Box::new(input), + output: Box::new(result), + render_fn: render::, + format_fn: format::, + }), + } + } + Err(_error) => ToolFunctionCall { + id, + name: name.clone(), + arguments, + result: Some(ToolFunctionCallResult::ExecutionFailed { + input: Box::new(input), + }), + }, + } + }) + }), + ); + + if previous.is_some() { + return Err(anyhow!("already registered a tool with name {}", name)); + } + + Ok(()) + } + + pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task { + let name = tool_call.name.clone(); + let arguments = tool_call.arguments.clone(); + let id = tool_call.id.clone(); + + let tool = match self.tools.get(&name) { + Some(tool) => tool, + None => { + let name = name.clone(); + return Task::ready(ToolFunctionCall { + id, + name: name.clone(), + arguments, + result: Some(ToolFunctionCallResult::NoSuchTool), + }); + } + }; + + tool(tool_call, cx) + } +} + +#[cfg(test)] +mod test { + + use super::*; + + use schemars::schema_for; + + use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext}; + use schemars::JsonSchema; + use serde::{Deserialize, Serialize}; + use serde_json::json; + + #[derive(Deserialize, Serialize, JsonSchema)] + struct WeatherQuery { + location: String, + unit: String, + } + + struct WeatherTool { + current_weather: WeatherResult, + } + + #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] + struct WeatherResult { + location: String, + temperature: f64, + unit: String, + } + + impl LanguageModelTool for WeatherTool { + type Input = WeatherQuery; + type Output = WeatherResult; + + fn name(&self) -> String { + "get_current_weather".to_string() + } + + fn description(&self) -> String { + "Fetches the current weather for a given location.".to_string() + } + + fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task> { + let _location = input.location.clone(); + let _unit = input.unit.clone(); + + let weather = self.current_weather.clone(); + + Task::ready(Ok(weather)) + } + + fn render( + _tool_call_id: &str, + _input: &Self::Input, + output: &Self::Output, + _cx: &mut WindowContext, + ) -> AnyElement { + div() + .child(format!( + "The current temperature in {} is {} {}", + output.location, output.temperature, output.unit + )) + .into_any() + } + + fn format(_input: &Self::Input, output: &Self::Output) -> String { + format!( + "The current temperature in {} is {} {}", + output.location, output.temperature, output.unit + ) + } + } + + #[gpui::test] + async fn test_function_registry(cx: &mut TestAppContext) { + cx.background_executor.run_until_parked(); + + let mut registry = ToolRegistry::new(); + + let tool = WeatherTool { + current_weather: WeatherResult { + location: "San Francisco".to_string(), + temperature: 21.0, + unit: "Celsius".to_string(), + }, + }; + + registry.register(tool).unwrap(); + + let _result = cx + .update(|cx| { + registry.call( + &ToolFunctionCall { + name: "get_current_weather".to_string(), + arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"# + .to_string(), + id: "test-123".to_string(), + result: None, + }, + cx, + ) + }) + .await; + + // assert!(result.is_ok()); + // let result = result.unwrap(); + + // let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#; + + // todo!(): Put this back in after the interface is stabilized + // assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_openai_weather_example(cx: &mut TestAppContext) { + cx.background_executor.run_until_parked(); + + let tool = WeatherTool { + current_weather: WeatherResult { + location: "San Francisco".to_string(), + temperature: 21.0, + unit: "Celsius".to_string(), + }, + }; + + let tools = vec![tool.definition()]; + assert_eq!(tools.len(), 1); + + let expected = ToolFunctionDefinition { + name: "get_current_weather".to_string(), + description: "Fetches the current weather for a given location.".to_string(), + parameters: schema_for!(WeatherQuery).schema, + }; + + assert_eq!(tools[0].name, expected.name); + assert_eq!(tools[0].description, expected.description); + + let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap(); + + assert_eq!( + expected_schema, + json!({ + "title": "WeatherQuery", + "type": "object", + "properties": { + "location": { + "type": "string" + }, + "unit": { + "type": "string" + } + }, + "required": ["location", "unit"] + }) + ); + + let args = json!({ + "location": "San Francisco", + "unit": "Celsius" + }); + + let query: WeatherQuery = serde_json::from_value(args).unwrap(); + + let result = cx.update(|cx| tool.execute(&query, cx)).await; + + assert!(result.is_ok()); + let result = result.unwrap(); + + assert_eq!(result, tool.current_weather); + } +} diff --git a/crates/assistant_tooling/src/tool.rs b/crates/assistant_tooling/src/tool.rs new file mode 100644 index 0000000000..b63e2901c6 --- /dev/null +++ b/crates/assistant_tooling/src/tool.rs @@ -0,0 +1,145 @@ +use anyhow::Result; +use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext}; +use schemars::{schema::SchemaObject, schema_for, JsonSchema}; +use serde::Deserialize; +use std::{any::Any, fmt::Debug}; + +#[derive(Default, Deserialize)] +pub struct ToolFunctionCall { + pub id: String, + pub name: String, + pub arguments: String, + #[serde(skip)] + pub result: Option, +} + +pub enum ToolFunctionCallResult { + NoSuchTool, + ParsingFailed, + ExecutionFailed { + input: Box, + }, + Finished { + input: Box, + output: Box, + render_fn: fn( + // tool_call_id + &str, + // LanguageModelTool::Input + &Box, + // LanguageModelTool::Output + &Box, + &mut WindowContext, + ) -> AnyElement, + format_fn: fn( + // LanguageModelTool::Input + &Box, + // LanguageModelTool::Output + &Box, + ) -> String, + }, +} + +impl ToolFunctionCallResult { + pub fn render( + &self, + tool_name: &str, + tool_call_id: &str, + cx: &mut WindowContext, + ) -> AnyElement { + match self { + ToolFunctionCallResult::NoSuchTool => { + div().child(format!("no such tool {tool_name}")).into_any() + } + ToolFunctionCallResult::ParsingFailed => div() + .child(format!("failed to parse input for tool {tool_name}")) + .into_any(), + ToolFunctionCallResult::ExecutionFailed { .. } => div() + .child(format!("failed to execute tool {tool_name}")) + .into_any(), + ToolFunctionCallResult::Finished { + input, + output, + render_fn, + .. + } => render_fn(tool_call_id, input, output, cx), + } + } + + pub fn format(&self, tool: &str) -> String { + match self { + ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"), + ToolFunctionCallResult::ParsingFailed => { + format!("failed to parse input for tool {tool}") + } + ToolFunctionCallResult::ExecutionFailed { input: _input } => { + format!("failed to execute tool {tool}") + } + ToolFunctionCallResult::Finished { + input, + output, + format_fn, + .. + } => format_fn(input, output), + } + } +} + +#[derive(Clone)] +pub struct ToolFunctionDefinition { + pub name: String, + pub description: String, + pub parameters: SchemaObject, +} + +impl Debug for ToolFunctionDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let schema = serde_json::to_string(&self.parameters).ok(); + let schema = schema.unwrap_or("None".to_string()); + + f.debug_struct("ToolFunctionDefinition") + .field("name", &self.name) + .field("description", &self.description) + .field("parameters", &schema) + .finish() + } +} + +pub trait LanguageModelTool { + /// The input type that will be passed in to `execute` when the tool is called + /// by the language model. + type Input: for<'de> Deserialize<'de> + JsonSchema; + + /// The output returned by executing the tool. + type Output: 'static; + + /// The name of the tool is exposed to the language model to allow + /// the model to pick which tools to use. As this name is used to + /// identify the tool within a tool registry, it should be unique. + fn name(&self) -> String; + + /// A description of the tool that can be used to _prompt_ the model + /// as to what the tool does. + fn description(&self) -> String; + + /// The OpenAI Function definition for the tool, for direct use with OpenAI's API. + fn definition(&self) -> ToolFunctionDefinition { + ToolFunctionDefinition { + name: self.name(), + description: self.description(), + parameters: schema_for!(Self::Input).schema, + } + } + + /// Execute the tool + fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task>; + + fn render( + tool_call_id: &str, + input: &Self::Input, + output: &Self::Output, + cx: &mut WindowContext, + ) -> AnyElement; + + fn format(input: &Self::Input, output: &Self::Output) -> String; +} diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 7d18e5d2db..7787089568 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -457,6 +457,14 @@ impl Client { }) } + pub fn production(cx: &mut AppContext) -> Arc { + let clock = Arc::new(clock::RealSystemClock); + let http = Arc::new(HttpClientWithUrl::new( + &ClientSettings::get_global(cx).server_url, + )); + Self::new(clock, http.clone(), cx) + } + pub fn id(&self) -> u64 { self.id.load(Ordering::SeqCst) } @@ -1119,6 +1127,8 @@ impl Client { if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) { + eprintln!("authenticate as admin {login}, {token}"); + return Self::authenticate_as_admin(http, login.clone(), token.clone()) .await; } diff --git a/crates/collab/seed.default.json b/crates/collab/seed.default.json index ded1dc862b..1abec644be 100644 --- a/crates/collab/seed.default.json +++ b/crates/collab/seed.default.json @@ -5,7 +5,8 @@ "maxbrunsfeld", "iamnbutler", "mikayla-maki", - "JosephTLyons" + "JosephTLyons", + "rgbkrk" ], "channels": ["zed"], "number_of_users": 100 diff --git a/crates/collab/src/ai.rs b/crates/collab/src/ai.rs index 4634166799..06c6e77dfd 100644 --- a/crates/collab/src/ai.rs +++ b/crates/collab/src/ai.rs @@ -1,5 +1,6 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context as _, Result}; use rpc::proto; +use util::ResultExt as _; pub fn language_model_request_to_open_ai( request: proto::CompleteWithLanguageModel, @@ -9,24 +10,83 @@ pub fn language_model_request_to_open_ai( messages: request .messages .into_iter() - .map(|message| { + .map(|message: proto::LanguageModelRequestMessage| { let role = proto::LanguageModelRole::from_i32(message.role) .ok_or_else(|| anyhow!("invalid role {}", message.role))?; - Ok(open_ai::RequestMessage { - role: match role { - proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User, - proto::LanguageModelRole::LanguageModelAssistant => { - open_ai::Role::Assistant - } - proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System, + + let openai_message = match role { + proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User { + content: message.content, }, - content: message.content, - }) + proto::LanguageModelRole::LanguageModelAssistant => { + open_ai::RequestMessage::Assistant { + content: Some(message.content), + tool_calls: message + .tool_calls + .into_iter() + .filter_map(|call| { + Some(open_ai::ToolCall { + id: call.id, + content: match call.variant? { + proto::tool_call::Variant::Function(f) => { + open_ai::ToolCallContent::Function { + function: open_ai::FunctionContent { + name: f.name, + arguments: f.arguments, + }, + } + } + }, + }) + }) + .collect(), + } + } + proto::LanguageModelRole::LanguageModelSystem => { + open_ai::RequestMessage::System { + content: message.content, + } + } + proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool { + tool_call_id: message + .tool_call_id + .ok_or_else(|| anyhow!("tool message is missing tool call id"))?, + content: message.content, + }, + }; + + Ok(openai_message) }) .collect::>>()?, stream: true, stop: request.stop, temperature: request.temperature, + tools: request + .tools + .into_iter() + .filter_map(|tool| { + Some(match tool.variant? { + proto::chat_completion_tool::Variant::Function(f) => { + open_ai::ToolDefinition::Function { + function: open_ai::FunctionDefinition { + name: f.name, + description: f.description, + parameters: if let Some(params) = &f.parameters { + Some( + serde_json::from_str(params) + .context("failed to deserialize tool parameters") + .log_err()?, + ) + } else { + None + }, + }, + } + } + }) + }) + .collect(), + tool_choice: request.tool_choice, }) } @@ -58,6 +118,9 @@ pub fn language_model_request_message_to_google_ai( proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User, proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model, proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User, + proto::LanguageModelRole::LanguageModelTool => { + Err(anyhow!("we don't handle tool calls with google ai yet"))? + } }, }) } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 3cba88b543..b2588e6fb3 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -775,9 +775,7 @@ impl Server { Box::new(move |envelope, session| { let envelope = envelope.into_any().downcast::>().unwrap(); let received_at = envelope.received_at; - tracing::info!( - "message received" - ); + tracing::info!("message received"); let start_time = Instant::now(); let future = (handler)(*envelope, session); async move { @@ -786,12 +784,24 @@ impl Server { let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; let queue_duration_ms = total_duration_ms - processing_duration_ms; let payload_type = M::NAME; + match result { Err(error) => { - // todo!(), why isn't this logged inside the span? - tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, payload_type, "error handling message") + tracing::error!( + ?error, + total_duration_ms, + processing_duration_ms, + queue_duration_ms, + payload_type, + "error handling message" + ) } - Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"), + Ok(()) => tracing::info!( + total_duration_ms, + processing_duration_ms, + queue_duration_ms, + "finished handling message" + ), } } .boxed() @@ -4098,7 +4108,7 @@ async fn complete_with_open_ai( crate::ai::language_model_request_to_open_ai(request)?, ) .await - .context("open_ai::stream_completion request failed")?; + .context("open_ai::stream_completion request failed within collab")?; while let Some(event) = completion_stream.next().await { let event = event?; @@ -4113,8 +4123,32 @@ async fn complete_with_open_ai( open_ai::Role::User => LanguageModelRole::LanguageModelUser, open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant, open_ai::Role::System => LanguageModelRole::LanguageModelSystem, + open_ai::Role::Tool => LanguageModelRole::LanguageModelTool, } as i32), content: choice.delta.content, + tool_calls: choice + .delta + .tool_calls + .into_iter() + .map(|delta| proto::ToolCallDelta { + index: delta.index as u32, + id: delta.id, + variant: match delta.function { + Some(function) => { + let name = function.name; + let arguments = function.arguments; + + Some(proto::tool_call_delta::Variant::Function( + proto::tool_call_delta::FunctionCallDelta { + name, + arguments, + }, + )) + } + None => None, + }, + }) + .collect(), }), finish_reason: choice.finish_reason, }) @@ -4165,6 +4199,8 @@ async fn complete_with_google_ai( }) .collect(), ), + // Tool calls are not supported for Google + tool_calls: Vec::new(), }), finish_reason: candidate.finish_reason.map(|reason| reason.to_string()), }) @@ -4187,24 +4223,28 @@ async fn complete_with_anthropic( let messages = request .messages .into_iter() - .filter_map(|message| match message.role() { - LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage { - role: anthropic::Role::User, - content: message.content, - }), - LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage { - role: anthropic::Role::Assistant, - content: message.content, - }), - // Anthropic's API breaks system instructions out as a separate field rather - // than having a system message role. - LanguageModelRole::LanguageModelSystem => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); + .filter_map(|message| { + match message.role() { + LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage { + role: anthropic::Role::User, + content: message.content, + }), + LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage { + role: anthropic::Role::Assistant, + content: message.content, + }), + // Anthropic's API breaks system instructions out as a separate field rather + // than having a system message role. + LanguageModelRole::LanguageModelSystem => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.content); - None + None + } + // We don't yet support tool calls for Anthropic + LanguageModelRole::LanguageModelTool => None, } }) .collect(); @@ -4248,6 +4288,7 @@ async fn complete_with_anthropic( delta: Some(proto::LanguageModelResponseMessage { role: Some(current_role as i32), content: Some(text), + tool_calls: Vec::new(), }), finish_reason: None, }], @@ -4264,6 +4305,7 @@ async fn complete_with_anthropic( delta: Some(proto::LanguageModelResponseMessage { role: Some(current_role as i32), content: Some(text), + tool_calls: Vec::new(), }), finish_reason: None, }], diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 58384f5ee5..ef37ce653b 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -234,10 +234,11 @@ impl ChatPanel { let channel_id = chat.read(cx).channel_id; { self.markdown_data.clear(); - let chat = chat.read(cx); - self.message_list.reset(chat.message_count()); + let chat = chat.read(cx); let channel_name = chat.channel(cx).map(|channel| channel.name.clone()); + let message_count = chat.message_count(); + self.message_list.reset(message_count); self.message_editor.update(cx, |editor, cx| { editor.set_channel(channel_id, channel_name, cx); editor.clear_reply_to_message_id(); @@ -766,7 +767,7 @@ impl ChatPanel { body.push_str(MESSAGE_EDITED); } - let mut rich_text = rich_text::render_rich_text(body, &mentions, language_registry, None); + let mut rich_text = RichText::new(body, &mentions, language_registry); if message.edited_at.is_some() { let range = (rich_text.text.len() - MESSAGE_EDITED.len())..rich_text.text.len(); diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 8b5eed08d9..d9b3f1abbf 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2947,7 +2947,7 @@ impl Render for DraggedChannelView { fn render(&mut self, cx: &mut ViewContext) -> impl Element { let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone(); h_flex() - .font(ui_font) + .font_family(ui_font) .bg(cx.theme().colors().background) .w(self.width) .p_1() diff --git a/crates/collab_ui/src/notifications/incoming_call_notification.rs b/crates/collab_ui/src/notifications/incoming_call_notification.rs index a8ba20c1e5..385e903bf7 100644 --- a/crates/collab_ui/src/notifications/incoming_call_notification.rs +++ b/crates/collab_ui/src/notifications/incoming_call_notification.rs @@ -125,7 +125,7 @@ impl Render for IncomingCallNotification { cx.set_rem_size(ui_font_size); - div().size_full().font(ui_font).child( + div().size_full().font_family(ui_font).child( CollabNotification::new( self.state.call.calling_user.avatar_uri.clone(), Button::new("accept", "Accept").on_click({ diff --git a/crates/collab_ui/src/notifications/project_shared_notification.rs b/crates/collab_ui/src/notifications/project_shared_notification.rs index 407ff66d19..03001bc3ad 100644 --- a/crates/collab_ui/src/notifications/project_shared_notification.rs +++ b/crates/collab_ui/src/notifications/project_shared_notification.rs @@ -129,7 +129,7 @@ impl Render for ProjectSharedNotification { cx.set_rem_size(ui_font_size); - div().size_full().font(ui_font).child( + div().size_full().font_family(ui_font).child( CollabNotification::new( self.owner.avatar_uri.clone(), Button::new("open", "Open").on_click(cx.listener(move |this, _event, cx| { diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index d7dc3caed7..cfca895eff 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -61,13 +61,13 @@ use fuzzy::{StringMatch, StringMatchCandidate}; use git::blame::GitBlame; use git::diff_hunk_to_display; use gpui::{ - div, impl_actions, point, prelude::*, px, relative, rems, size, uniform_list, Action, - AnyElement, AppContext, AsyncWindowContext, AvailableSpace, BackgroundExecutor, Bounds, - ClipboardItem, Context, DispatchPhase, ElementId, EventEmitter, FocusHandle, FocusableView, - FontId, FontStyle, FontWeight, HighlightStyle, Hsla, InteractiveText, KeyContext, Model, - MouseButton, PaintQuad, ParentElement, Pixels, Render, SharedString, Size, StrikethroughStyle, - Styled, StyledText, Subscription, Task, TextStyle, UnderlineStyle, UniformListScrollHandle, - View, ViewContext, ViewInputHandler, VisualContext, WeakView, WhiteSpace, WindowContext, + div, impl_actions, point, prelude::*, px, relative, size, uniform_list, Action, AnyElement, + AppContext, AsyncWindowContext, AvailableSpace, BackgroundExecutor, Bounds, ClipboardItem, + Context, DispatchPhase, ElementId, EventEmitter, FocusHandle, FocusableView, FontId, FontStyle, + FontWeight, HighlightStyle, Hsla, InteractiveText, KeyContext, Model, MouseButton, PaintQuad, + ParentElement, Pixels, Render, SharedString, Size, StrikethroughStyle, Styled, StyledText, + Subscription, Task, TextStyle, UnderlineStyle, UniformListScrollHandle, View, ViewContext, + ViewInputHandler, VisualContext, WeakView, WhiteSpace, WindowContext, }; use highlight_matching_bracket::refresh_matching_bracket_highlights; use hover_popover::{hide_hover, HoverState}; @@ -8885,7 +8885,6 @@ impl Editor { self.style = Some(style); } - #[cfg(any(test, feature = "test-support"))] pub fn style(&self) -> Option<&EditorStyle> { self.style.as_ref() } @@ -10322,21 +10321,9 @@ impl FocusableView for Editor { impl Render for Editor { fn render<'a>(&mut self, cx: &mut ViewContext<'a, Self>) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); - let text_style = match self.mode { - EditorMode::SingleLine | EditorMode::AutoHeight { .. } => TextStyle { - color: cx.theme().colors().editor_foreground, - font_family: settings.ui_font.family.clone(), - font_features: settings.ui_font.features, - font_size: rems(0.875).into(), - font_weight: FontWeight::NORMAL, - font_style: FontStyle::Normal, - line_height: relative(settings.buffer_line_height.value()), - background_color: None, - underline: None, - strikethrough: None, - white_space: WhiteSpace::Normal, - }, + let text_style = match self.mode { + EditorMode::SingleLine | EditorMode::AutoHeight { .. } => cx.text_style(), EditorMode::Full => TextStyle { color: cx.theme().colors().editor_foreground, font_family: settings.buffer_font.family.clone(), diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index b9fed082fc..49917d7ade 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -3056,7 +3056,7 @@ fn render_inline_blame_entry( h_flex() .id("inline-blame") .w_full() - .font(style.text.font().family) + .font_family(style.text.font().family) .text_color(cx.theme().status().hint) .line_height(style.text.line_height) .child(Icon::new(IconName::FileGit).color(Color::Hint)) @@ -3108,7 +3108,7 @@ fn render_blame_entry( h_flex() .w_full() - .font(style.text.font().family) + .font_family(style.text.font().family) .line_height(style.text.line_height) .id(("blame", ix)) .children([ diff --git a/crates/gpui/src/styled.rs b/crates/gpui/src/styled.rs index 54adbb3891..9705f4dd13 100644 --- a/crates/gpui/src/styled.rs +++ b/crates/gpui/src/styled.rs @@ -1,7 +1,7 @@ use crate::{ self as gpui, hsla, point, px, relative, rems, AbsoluteLength, AlignItems, CursorStyle, - DefiniteLength, Fill, FlexDirection, FlexWrap, FontStyle, FontWeight, Hsla, JustifyContent, - Length, Position, SharedString, StyleRefinement, Visibility, WhiteSpace, + DefiniteLength, Fill, FlexDirection, FlexWrap, Font, FontStyle, FontWeight, Hsla, + JustifyContent, Length, Position, SharedString, StyleRefinement, Visibility, WhiteSpace, }; use crate::{BoxShadow, TextStyleRefinement}; use smallvec::{smallvec, SmallVec}; @@ -771,14 +771,32 @@ pub trait Styled: Sized { self } - /// Change the font on this element and its children. - fn font(mut self, family_name: impl Into) -> Self { + /// Change the font family on this element and its children. + fn font_family(mut self, family_name: impl Into) -> Self { self.text_style() .get_or_insert_with(Default::default) .font_family = Some(family_name.into()); self } + /// Change the font of this element and its children. + fn font(mut self, font: Font) -> Self { + let Font { + family, + features, + weight, + style, + } = font; + + let text_style = self.text_style().get_or_insert_with(Default::default); + text_style.font_family = Some(family); + text_style.font_features = Some(features); + text_style.font_weight = Some(weight); + text_style.font_style = Some(style); + + self + } + /// Set the line height on this element and its children. fn line_height(mut self, line_height: impl Into) -> Self { self.text_style() diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 97abb45dfc..bdc6d3cb9b 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, Context, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; use std::{convert::TryFrom, future::Future}; use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; @@ -12,6 +13,7 @@ pub enum Role { User, Assistant, System, + Tool, } impl TryFrom for Role { @@ -22,6 +24,7 @@ impl TryFrom for Role { "user" => Ok(Self::User), "assistant" => Ok(Self::Assistant), "system" => Ok(Self::System), + "tool" => Ok(Self::Tool), _ => Err(anyhow!("invalid role '{value}'")), } } @@ -33,6 +36,7 @@ impl From for String { Role::User => "user".to_owned(), Role::Assistant => "assistant".to_owned(), Role::System => "system".to_owned(), + Role::Tool => "tool".to_owned(), } } } @@ -91,18 +95,88 @@ pub struct Request { pub stream: bool, pub stop: Vec, pub temperature: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, +} + +#[derive(Debug, Serialize)] +pub struct FunctionDefinition { + pub name: String, + pub description: Option, + pub parameters: Option>, +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolDefinition { + #[allow(dead_code)] + Function { function: FunctionDefinition }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct RequestMessage { - pub role: Role, - pub content: String, +#[serde(tag = "role", rename_all = "lowercase")] +pub enum RequestMessage { + Assistant { + content: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + tool_calls: Vec, + }, + User { + content: String, + }, + System { + content: String, + }, + Tool { + content: String, + tool_call_id: String, + }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct ResponseMessage { +pub struct ToolCall { + pub id: String, + #[serde(flatten)] + pub content: ToolCallContent, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolCallContent { + Function { function: FunctionContent }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionContent { + pub name: String, + pub arguments: String, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessageDelta { pub role: Option, pub content: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tool_calls: Vec, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ToolCallChunk { + pub index: usize, + pub id: Option, + + // There is also an optional `type` field that would determine if a + // function is there. Sometimes this streams in with the `function` before + // it streams in the `type` + pub function: Option, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionChunk { + pub name: Option, + pub arguments: Option, } #[derive(Deserialize, Debug)] @@ -115,7 +189,7 @@ pub struct Usage { #[derive(Deserialize, Debug)] pub struct ChoiceDelta { pub index: u32, - pub delta: ResponseMessage, + pub delta: ResponseMessageDelta, pub finish_reason: Option, } diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index f592103b20..ddbd8429ff 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -1843,7 +1843,7 @@ impl Render for DraggedProjectEntryView { let settings = ProjectPanelSettings::get_global(cx); let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone(); h_flex() - .font(ui_font) + .font_family(ui_font) .bg(cx.theme().colors().background) .w(self.width) .child( diff --git a/crates/recent_projects/src/remote_projects.rs b/crates/recent_projects/src/remote_projects.rs index 2a2f13d945..8a548c747e 100644 --- a/crates/recent_projects/src/remote_projects.rs +++ b/crates/recent_projects/src/remote_projects.rs @@ -507,7 +507,7 @@ impl RemoteProjects { .my_1() .py_0p5() .px_3() - .font(ThemeSettings::get_global(cx).buffer_font.family.clone()) + .font_family(ThemeSettings::get_global(cx).buffer_font.family.clone()) .child(Label::new(instructions)) ) .when(status == DevServerStatus::Offline, |this| { diff --git a/crates/rich_text/src/rich_text.rs b/crates/rich_text/src/rich_text.rs index 78dabe0ca3..16c4473e07 100644 --- a/crates/rich_text/src/rich_text.rs +++ b/crates/rich_text/src/rich_text.rs @@ -43,6 +43,19 @@ pub struct RichText { Option, &mut WindowContext) -> Option>>, } +impl Default for RichText { + fn default() -> Self { + Self { + text: SharedString::default(), + highlights: Vec::new(), + link_ranges: Vec::new(), + link_urls: Arc::from([]), + custom_ranges: Vec::new(), + custom_ranges_tooltip_fn: None, + } + } +} + /// Allows one to specify extra links to the rendered markdown, which can be used /// for e.g. mentions. #[derive(Debug)] @@ -52,6 +65,37 @@ pub struct Mention { } impl RichText { + pub fn new( + block: String, + mentions: &[Mention], + language_registry: &Arc, + ) -> Self { + let mut text = String::new(); + let mut highlights = Vec::new(); + let mut link_ranges = Vec::new(); + let mut link_urls = Vec::new(); + render_markdown_mut( + &block, + mentions, + language_registry, + None, + &mut text, + &mut highlights, + &mut link_ranges, + &mut link_urls, + ); + text.truncate(text.trim_end().len()); + + RichText { + text: SharedString::from(text), + link_urls: link_urls.into(), + link_ranges, + highlights, + custom_ranges: Vec::new(), + custom_ranges_tooltip_fn: None, + } + } + pub fn set_tooltip_builder_for_custom_ranges( &mut self, f: impl Fn(usize, Range, &mut WindowContext) -> Option + 'static, @@ -347,38 +391,6 @@ pub fn render_markdown_mut( } } -pub fn render_rich_text( - block: String, - mentions: &[Mention], - language_registry: &Arc, - language: Option<&Arc>, -) -> RichText { - let mut text = String::new(); - let mut highlights = Vec::new(); - let mut link_ranges = Vec::new(); - let mut link_urls = Vec::new(); - render_markdown_mut( - &block, - mentions, - language_registry, - language, - &mut text, - &mut highlights, - &mut link_ranges, - &mut link_urls, - ); - text.truncate(text.trim_end().len()); - - RichText { - text: SharedString::from(text), - link_urls: link_urls.into(), - link_ranges, - highlights, - custom_ranges: Vec::new(), - custom_ranges_tooltip_fn: None, - } -} - pub fn render_code( text: &mut String, highlights: &mut Vec<(Range, Highlight)>, diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 7832af4b04..b3014d1748 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -1880,22 +1880,70 @@ message CompleteWithLanguageModel { repeated LanguageModelRequestMessage messages = 2; repeated string stop = 3; float temperature = 4; + repeated ChatCompletionTool tools = 5; + optional string tool_choice = 6; } +// A tool presented to the language model for its use +message ChatCompletionTool { + oneof variant { + FunctionObject function = 1; + } + + message FunctionObject { + string name = 1; + optional string description = 2; + optional string parameters = 3; + } +} + +// A message to the language model message LanguageModelRequestMessage { LanguageModelRole role = 1; string content = 2; + optional string tool_call_id = 3; + repeated ToolCall tool_calls = 4; } enum LanguageModelRole { LanguageModelUser = 0; LanguageModelAssistant = 1; LanguageModelSystem = 2; + LanguageModelTool = 3; } message LanguageModelResponseMessage { optional LanguageModelRole role = 1; optional string content = 2; + repeated ToolCallDelta tool_calls = 3; +} + +// A request to call a tool, by the language model +message ToolCall { + string id = 1; + + oneof variant { + FunctionCall function = 2; + } + + message FunctionCall { + string name = 1; + string arguments = 2; + } +} + +message ToolCallDelta { + uint32 index = 1; + optional string id = 2; + + oneof variant { + FunctionCallDelta function = 3; + } + + message FunctionCallDelta { + optional string name = 1; + optional string arguments = 2; + } } message LanguageModelResponse { diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index f50f17934d..5f06d4193f 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -12,6 +12,11 @@ workspace = true [lib] path = "src/semantic_index.rs" +[[example]] +name = "index" +path = "examples/index.rs" +crate-type = ["bin"] + [dependencies] anyhow.workspace = true client.workspace = true diff --git a/crates/semantic_index/examples/index.rs b/crates/semantic_index/examples/index.rs index 494d8a0f81..6783e07048 100644 --- a/crates/semantic_index/examples/index.rs +++ b/crates/semantic_index/examples/index.rs @@ -1,25 +1,16 @@ use client::Client; use futures::channel::oneshot; -use gpui::{App, Global, TestAppContext}; +use gpui::{App, Global}; use language::language_settings::AllLanguageSettings; use project::Project; use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex}; use settings::SettingsStore; -use std::{path::Path, sync::Arc}; +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; use util::http::HttpClientWithUrl; -pub fn init_test(cx: &mut TestAppContext) { - _ = cx.update(|cx| { - let store = SettingsStore::test(cx); - cx.set_global(store); - language::init(cx); - Project::init_settings(cx); - SettingsStore::update(cx, |store, cx| { - store.update_user_settings::(cx, |_| {}); - }); - }); -} - fn main() { env_logger::init(); @@ -50,20 +41,21 @@ fn main() { // let embedding_provider = semantic_index::FakeEmbeddingProvider; let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); - let embedding_provider = OpenAiEmbeddingProvider::new( + + let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new( http.clone(), OpenAiEmbeddingModel::TextEmbedding3Small, open_ai::OPEN_AI_API_URL.to_string(), api_key, - ); - - let semantic_index = SemanticIndex::new( - Path::new("/tmp/semantic-index-db.mdb"), - Arc::new(embedding_provider), - cx, - ); + )); cx.spawn(|mut cx| async move { + let semantic_index = SemanticIndex::new( + PathBuf::from("/tmp/semantic-index-db.mdb"), + embedding_provider, + &mut cx, + ); + let mut semantic_index = semantic_index.await.unwrap(); let project_path = Path::new(&args[1]); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index a43d9e177c..c3eccd95f6 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -21,7 +21,7 @@ use std::{ cmp::Ordering, future::Future, ops::Range, - path::Path, + path::{Path, PathBuf}, sync::Arc, time::{Duration, SystemTime}, }; @@ -37,30 +37,29 @@ pub struct SemanticIndex { impl Global for SemanticIndex {} impl SemanticIndex { - pub fn new( - db_path: &Path, + pub async fn new( + db_path: PathBuf, embedding_provider: Arc, - cx: &mut AppContext, - ) -> Task> { - let db_path = db_path.to_path_buf(); - cx.spawn(|cx| async move { - let db_connection = cx - .background_executor() - .spawn(async move { - unsafe { - heed::EnvOpenOptions::new() - .map_size(1024 * 1024 * 1024) - .max_dbs(3000) - .open(db_path) - } - }) - .await?; - - Ok(SemanticIndex { - db_connection, - embedding_provider, - project_indices: HashMap::default(), + cx: &mut AsyncAppContext, + ) -> Result { + let db_connection = cx + .background_executor() + .spawn(async move { + std::fs::create_dir_all(&db_path)?; + unsafe { + heed::EnvOpenOptions::new() + .map_size(1024 * 1024 * 1024) + .max_dbs(3000) + .open(db_path) + } }) + .await + .context("opening database connection")?; + + Ok(SemanticIndex { + db_connection, + embedding_provider, + project_indices: HashMap::default(), }) } @@ -91,7 +90,7 @@ pub struct ProjectIndex { worktree_indices: HashMap, language_registry: Arc, fs: Arc, - last_status: Status, + pub last_status: Status, embedding_provider: Arc, _subscription: Subscription, } @@ -397,7 +396,7 @@ impl WorktreeIndex { ) -> impl Future> { let worktree = self.worktree.read(cx).as_local().unwrap().snapshot(); let worktree_abs_path = worktree.abs_path().clone(); - let scan = self.scan_updated_entries(worktree, updated_entries, cx); + let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx); let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx); let embed = self.embed_files(chunk.files, cx); let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx); @@ -498,7 +497,9 @@ impl WorktreeIndex { | project::PathChange::Updated | project::PathChange::AddedOrUpdated => { if let Some(entry) = worktree.entry_for_id(*entry_id) { - updated_entries_tx.send(entry.clone()).await?; + if entry.is_file() { + updated_entries_tx.send(entry.clone()).await?; + } } } project::PathChange::Removed => { @@ -539,7 +540,14 @@ impl WorktreeIndex { cx.spawn(async { while let Ok(entry) = entries.recv().await { let entry_abs_path = worktree_abs_path.join(&entry.path); - let Some(text) = fs.load(&entry_abs_path).await.log_err() else { + let Some(text) = fs + .load(&entry_abs_path) + .await + .with_context(|| { + format!("failed to read path {entry_abs_path:?}") + }) + .log_err() + else { continue; }; let language = language_registry @@ -683,7 +691,7 @@ impl WorktreeIndex { .context("failed to create read transaction")?; let db_entries = db.iter(&txn).context("failed to iterate database")?; for db_entry in db_entries { - let (_, db_embedded_file) = db_entry?; + let (_key, db_embedded_file) = db_entry?; for chunk in db_embedded_file.chunks { chunks_tx .send((db_embedded_file.path.clone(), chunk)) @@ -700,6 +708,7 @@ impl WorktreeIndex { cx.spawn(|cx| async move { #[cfg(debug_assertions)] let embedding_query_start = std::time::Instant::now(); + log::info!("Searching for {query}"); let mut query_embeddings = embedding_provider .embed(&[TextToEmbed::new(&query)]) @@ -876,17 +885,13 @@ mod tests { let temp_dir = tempfile::tempdir().unwrap(); - let mut semantic_index = cx - .update(|cx| { - let semantic_index = SemanticIndex::new( - Path::new(temp_dir.path()), - Arc::new(TestEmbeddingProvider), - cx, - ); - semantic_index - }) - .await - .unwrap(); + let mut semantic_index = SemanticIndex::new( + temp_dir.path().into(), + Arc::new(TestEmbeddingProvider), + &mut cx.to_async(), + ) + .await + .unwrap(); let project_path = Path::new("./fixture"); diff --git a/crates/settings/src/settings.rs b/crates/settings/src/settings.rs index e646e37f2c..e716ef5b07 100644 --- a/crates/settings/src/settings.rs +++ b/crates/settings/src/settings.rs @@ -2,6 +2,7 @@ mod keymap_file; mod settings_file; mod settings_store; +use gpui::AppContext; use rust_embed::RustEmbed; use std::{borrow::Cow, str}; use util::asset_str; @@ -19,6 +20,14 @@ pub use settings_store::{ #[exclude = "*.DS_Store"] pub struct SettingsAssets; +pub fn init(cx: &mut AppContext) { + let mut settings = SettingsStore::default(); + settings + .set_default_settings(&default_settings(), cx) + .unwrap(); + cx.set_global(settings); +} + pub fn default_settings() -> Cow<'static, str> { asset_str::("settings/default.json") } diff --git a/crates/storybook/src/story_selector.rs b/crates/storybook/src/story_selector.rs index c238542478..dcc42546fe 100644 --- a/crates/storybook/src/story_selector.rs +++ b/crates/storybook/src/story_selector.rs @@ -29,14 +29,14 @@ pub enum ComponentStory { ListHeader, ListItem, OverflowScroll, + Picker, Scroll, Tab, TabBar, + Text, TitleBar, ToggleButton, - Text, ViewportUnits, - Picker, } impl ComponentStory { diff --git a/crates/storybook/src/storybook.rs b/crates/storybook/src/storybook.rs index 015b4765fb..70853267ca 100644 --- a/crates/storybook/src/storybook.rs +++ b/crates/storybook/src/storybook.rs @@ -11,7 +11,7 @@ use gpui::{ }; use log::LevelFilter; use project::Project; -use settings::{default_settings, KeymapFile, Settings, SettingsStore}; +use settings::{KeymapFile, Settings}; use simplelog::SimpleLogger; use strum::IntoEnumIterator; use theme::{ThemeRegistry, ThemeSettings}; @@ -64,12 +64,7 @@ fn main() { gpui::App::new().with_assets(Assets).run(move |cx| { load_embedded_fonts(cx).unwrap(); - let mut store = SettingsStore::default(); - store - .set_default_settings(default_settings().as_ref(), cx) - .unwrap(); - cx.set_global(store); - + settings::init(cx); theme::init(theme::LoadThemes::All(Box::new(Assets)), cx); let selector = story_selector; @@ -122,7 +117,7 @@ impl Render for StoryWrapper { .flex() .flex_col() .size_full() - .font("Zed Mono") + .font_family("Zed Mono") .child(self.story.clone()) } } diff --git a/crates/ui/src/components.rs b/crates/ui/src/components.rs index 2a38130720..b93a997fe7 100644 --- a/crates/ui/src/components.rs +++ b/crates/ui/src/components.rs @@ -1,6 +1,7 @@ mod avatar; mod button; mod checkbox; +mod collapsible_container; mod context_menu; mod disclosure; mod divider; @@ -25,6 +26,7 @@ mod stories; pub use avatar::*; pub use button::*; pub use checkbox::*; +pub use collapsible_container::*; pub use context_menu::*; pub use disclosure::*; pub use divider::*; diff --git a/crates/ui/src/components/collapsible_container.rs b/crates/ui/src/components/collapsible_container.rs new file mode 100644 index 0000000000..5136dbd13d --- /dev/null +++ b/crates/ui/src/components/collapsible_container.rs @@ -0,0 +1,152 @@ +use crate::{prelude::*, ButtonLike}; +use smallvec::SmallVec; + +use gpui::*; + +#[derive(Default, Clone, Copy, Debug, PartialEq)] +pub enum ContainerStyle { + #[default] + None, + Card, +} + +struct ContainerStyles { + pub background_color: Hsla, + pub border_color: Hsla, + pub text_color: Hsla, +} + +#[derive(IntoElement)] +pub struct CollapsibleContainer { + id: ElementId, + base: ButtonLike, + toggle: bool, + /// A slot for content that appears before the label, like an icon or avatar. + start_slot: Option, + /// A slot for content that appears after the label, usually on the other side of the header. + /// This might be a button, a disclosure arrow, a face pile, etc. + end_slot: Option, + style: ContainerStyle, + children: SmallVec<[AnyElement; 1]>, +} + +impl CollapsibleContainer { + pub fn new(id: impl Into, toggle: bool) -> Self { + Self { + id: id.into(), + base: ButtonLike::new("button_base"), + toggle, + start_slot: None, + end_slot: None, + style: ContainerStyle::Card, + children: SmallVec::new(), + } + } + + pub fn start_slot(mut self, start_slot: impl Into>) -> Self { + self.start_slot = start_slot.into().map(IntoElement::into_any_element); + self + } + + pub fn end_slot(mut self, end_slot: impl Into>) -> Self { + self.end_slot = end_slot.into().map(IntoElement::into_any_element); + self + } + + pub fn child(mut self, child: E) -> Self { + self.children.push(child.into_any_element()); + self + } +} + +impl Clickable for CollapsibleContainer { + fn on_click(mut self, handler: impl Fn(&ClickEvent, &mut WindowContext) + 'static) -> Self { + self.base = self.base.on_click(handler); + self + } +} + +impl RenderOnce for CollapsibleContainer { + fn render(self, cx: &mut WindowContext) -> impl IntoElement { + let color = cx.theme().colors(); + + let styles = match self.style { + ContainerStyle::None => ContainerStyles { + background_color: color.ghost_element_background, + border_color: color.border_transparent, + text_color: color.text, + }, + ContainerStyle::Card => ContainerStyles { + background_color: color.elevated_surface_background, + border_color: color.border, + text_color: color.text, + }, + }; + + v_flex() + .id(self.id) + .relative() + .rounded_md() + .bg(styles.background_color) + .border() + .border_color(styles.border_color) + .text_color(styles.text_color) + .overflow_hidden() + .child( + h_flex() + .overflow_hidden() + .w_full() + .group("toggleable_container_header") + .border_b() + .border_color(if self.toggle { + styles.border_color + } else { + color.border_transparent + }) + .child( + self.base.full_width().style(ButtonStyle::Subtle).child( + div() + .h_7() + .p_1() + .flex() + .flex_1() + .items_center() + .justify_between() + .w_full() + .gap_1() + .cursor_pointer() + .group_hover("toggleable_container_header", |this| { + this.bg(color.element_hover) + }) + .child( + h_flex() + .gap_1() + .child( + IconButton::new( + "toggle_icon", + match self.toggle { + true => IconName::ChevronDown, + false => IconName::ChevronRight, + }, + ) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall), + ) + .child( + div() + .id("label_container") + .flex() + .gap_1() + .items_center() + .children(self.start_slot), + ), + ) + .child(h_flex().children(self.end_slot)), + ), + ), + ) + .when(self.toggle, |this| { + this.child(h_flex().flex_1().w_full().p_1().children(self.children)) + }) + } +} diff --git a/crates/ui/src/components/title_bar/windows_window_controls.rs b/crates/ui/src/components/title_bar/windows_window_controls.rs index 8352bed678..7c12395168 100644 --- a/crates/ui/src/components/title_bar/windows_window_controls.rs +++ b/crates/ui/src/components/title_bar/windows_window_controls.rs @@ -110,7 +110,7 @@ impl RenderOnce for WindowsCaptionButton { .content_center() .w(width) .h_full() - .font("Segoe Fluent Icons") + .font_family("Segoe Fluent Icons") .text_size(px(10.0)) .hover(|style| style.bg(self.hover_background_color)) .active(|style| { diff --git a/crates/ui/src/components/tooltip.rs b/crates/ui/src/components/tooltip.rs index 1ce25129ff..5d07f6b341 100644 --- a/crates/ui/src/components/tooltip.rs +++ b/crates/ui/src/components/tooltip.rs @@ -95,7 +95,7 @@ pub fn tooltip_container( div().pl_2().pt_2p5().child( v_flex() .elevation_2(cx) - .font(ui_font) + .font_family(ui_font) .text_ui() .text_color(cx.theme().colors().text) .py_1() diff --git a/crates/ui/src/styles/typography.rs b/crates/ui/src/styles/typography.rs index b4b598a7c2..cd40cb1e99 100644 --- a/crates/ui/src/styles/typography.rs +++ b/crates/ui/src/styles/typography.rs @@ -93,7 +93,7 @@ impl RenderOnce for Headline { let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone(); div() - .font(ui_font) + .font_family(ui_font) .line_height(self.size.line_height()) .text_size(self.size.size()) .text_color(cx.theme().colors().text) diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index c4f62715b3..ddc81a3e12 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -2928,6 +2928,6 @@ impl Render for DraggedTab { .selected(self.is_active) .child(label) .render(cx) - .font(ui_font) + .font_family(ui_font) } } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 2a6ae60701..94890bc15c 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -4004,7 +4004,7 @@ impl Render for Workspace { .size_full() .flex() .flex_col() - .font(ui_font) + .font_family(ui_font) .gap_0() .justify_start() .items_start() diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index d005138dba..2eb188f768 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -19,6 +19,7 @@ activity_indicator.workspace = true anyhow.workspace = true assets.workspace = true assistant.workspace = true +assistant2.workspace = true audio.workspace = true auto_update.workspace = true backtrace = "0.3" diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 3bc06f9ac6..ea5aafcb66 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -231,27 +231,18 @@ fn init_ui(args: Args) { load_embedded_fonts(cx); - let mut store = SettingsStore::default(); - store - .set_default_settings(default_settings().as_ref(), cx) - .unwrap(); - cx.set_global(store); + settings::init(cx); handle_settings_file_changes(user_settings_file_rx, cx); handle_keymap_file_changes(user_keymap_file_rx, cx); + client::init_settings(cx); - - let clock = Arc::new(clock::RealSystemClock); - let http = Arc::new(HttpClientWithUrl::new( - &client::ClientSettings::get_global(cx).server_url, - )); - - let client = client::Client::new(clock, http.clone(), cx); + let client = Client::production(cx); let mut languages = LanguageRegistry::new(login_shell_env_loaded, cx.background_executor().clone()); let copilot_language_server_id = languages.next_language_server_id(); languages.set_language_server_download_dir(paths::LANGUAGES_DIR.clone()); let languages = Arc::new(languages); - let node_runtime = RealNodeRuntime::new(http.clone()); + let node_runtime = RealNodeRuntime::new(client.http_client()); language::init(cx); languages::init(languages.clone(), node_runtime.clone(), cx); @@ -271,11 +262,14 @@ fn init_ui(args: Args) { diagnostics::init(cx); copilot::init( copilot_language_server_id, - http.clone(), + client.http_client(), node_runtime.clone(), cx, ); + assistant::init(client.clone(), cx); + assistant2::init(client.clone(), cx); + init_inline_completion_provider(client.telemetry().clone(), cx); extension::init( @@ -297,7 +291,7 @@ fn init_ui(args: Args) { cx.observe_global::({ let languages = languages.clone(); - let http = http.clone(); + let http = client.http_client(); let client = client.clone(); move |cx| { @@ -345,7 +339,7 @@ fn init_ui(args: Args) { AppState::set_global(Arc::downgrade(&app_state), cx); audio::init(Assets, cx); - auto_update::init(http.clone(), cx); + auto_update::init(client.http_client(), cx); workspace::init(app_state.clone(), cx); recent_projects::init(cx); @@ -378,7 +372,7 @@ fn init_ui(args: Args) { initialize_workspace(app_state.clone(), cx); // todo(linux): unblock this - upload_panics_and_crashes(http.clone(), cx); + upload_panics_and_crashes(client.http_client(), cx); cx.activate(true); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 963b7c3237..fbbec18601 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -3,7 +3,6 @@ mod only_instance; mod open_listener; pub use app_menus::*; -use assistant::AssistantPanel; use breadcrumbs::Breadcrumbs; use client::ZED_URL_SCHEME; use collections::VecDeque; @@ -181,10 +180,12 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { }) }); } + cx.spawn(|workspace_handle, mut cx| async move { + let assistant_panel = + assistant::AssistantPanel::load(workspace_handle.clone(), cx.clone()); let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone()); let terminal_panel = TerminalPanel::load(workspace_handle.clone(), cx.clone()); - let assistant_panel = AssistantPanel::load(workspace_handle.clone(), cx.clone()); let channels_panel = collab_ui::collab_panel::CollabPanel::load(workspace_handle.clone(), cx.clone()); let chat_panel = @@ -193,6 +194,7 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { workspace_handle.clone(), cx.clone(), ); + let ( project_panel, terminal_panel, @@ -210,9 +212,9 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { )?; workspace_handle.update(&mut cx, |workspace, cx| { + workspace.add_panel(assistant_panel, cx); workspace.add_panel(project_panel, cx); workspace.add_panel(terminal_panel, cx); - workspace.add_panel(assistant_panel, cx); workspace.add_panel(channels_panel, cx); workspace.add_panel(chat_panel, cx); workspace.add_panel(notification_panel, cx); @@ -221,6 +223,30 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { }) .detach(); + let mut current_user = app_state.user_store.read(cx).watch_current_user(); + + cx.spawn(|workspace_handle, mut cx| async move { + while let Some(user) = current_user.next().await { + if user.is_some() { + // User known now, can check feature flags / staff + // At this point, should have the user with staff status available + let use_assistant2 = cx.update(|cx| assistant2::enabled(cx))?; + if use_assistant2 { + let panel = + assistant2::AssistantPanel::load(workspace_handle.clone(), cx.clone()) + .await?; + workspace_handle.update(&mut cx, |workspace, cx| { + workspace.add_panel(panel, cx); + })?; + } + + break; + } + } + anyhow::Ok(()) + }) + .detach(); + workspace .register_action(about) .register_action(|_, _: &Minimize, cx| { @@ -3028,11 +3054,7 @@ mod tests { ]) .unwrap(); let themes = ThemeRegistry::default(); - let mut settings = SettingsStore::default(); - settings - .set_default_settings(&settings::default_settings(), cx) - .unwrap(); - cx.set_global(settings); + settings::init(cx); theme::init(theme::LoadThemes::JustBase, cx); let mut has_default_theme = false; diff --git a/script/zed-local b/script/zed-local index 0ab6f0d0d1..69a44fe94a 100755 --- a/script/zed-local +++ b/script/zed-local @@ -147,7 +147,7 @@ setTimeout(() => { } spawn(binaryPath, i == 0 ? args : [], { stdio: "inherit", - env: { + env: Object.assign({}, process.env, { ZED_IMPERSONATE: users[i], ZED_WINDOW_POSITION: position, ZED_STATELESS: isStateful && i == 0 ? "1" : "", @@ -157,9 +157,8 @@ setTimeout(() => { ZED_ADMIN_API_TOKEN: "secret", ZED_WINDOW_SIZE: size, ZED_CLIENT_CHECKSUM_SEED: "development-checksum-seed", - PATH: process.env.PATH, RUST_LOG: process.env.RUST_LOG || "info", - }, + }), }); } }, 0.1);