diff --git a/assets/icons/ai_anthropic.svg b/assets/icons/ai_anthropic.svg new file mode 100644 index 0000000000..b0705c2774 --- /dev/null +++ b/assets/icons/ai_anthropic.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/ai_google.svg b/assets/icons/ai_google.svg new file mode 100644 index 0000000000..953e585cb1 --- /dev/null +++ b/assets/icons/ai_google.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/ai_ollama.svg b/assets/icons/ai_ollama.svg new file mode 100644 index 0000000000..b28f3788e8 --- /dev/null +++ b/assets/icons/ai_ollama.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/ai_open_ai.svg b/assets/icons/ai_open_ai.svg new file mode 100644 index 0000000000..e659a472d8 --- /dev/null +++ b/assets/icons/ai_open_ai.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/ai_zed.svg b/assets/icons/ai_zed.svg new file mode 100644 index 0000000000..1c6bb8ad63 --- /dev/null +++ b/assets/icons/ai_zed.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index d3d23b148d..b3d0344e71 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use crate::assistant_settings::AssistantSettings; use fs::Fs; use gpui::SharedString; -use language_model::LanguageModelRegistry; +use language_model::{LanguageModelAvailability, LanguageModelRegistry}; +use proto::Plan; use settings::update_settings_file; use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger}; @@ -37,7 +38,7 @@ impl ModelSelector { } impl RenderOnce for ModelSelector { - fn render(self, _: &mut WindowContext) -> impl IntoElement { + fn render(self, _cx: &mut WindowContext) -> impl IntoElement { let mut menu = PopoverMenu::new("model-switcher"); if let Some(handle) = self.handle { menu = menu.with_handle(handle); @@ -63,10 +64,25 @@ impl RenderOnce for ModelSelector { .into_iter() .enumerate() { + let provider_icon = provider.icon(); + let provider_name = provider.name().0.clone(); + if index > 0 { menu = menu.separator(); } - menu = menu.header(provider.name().0); + menu = menu.custom_row(move |_| { + h_flex() + .pb_1() + .gap_1p5() + .w_full() + .child( + Icon::new(provider_icon) + .color(Color::Muted) + .size(IconSize::Small), + ) + .child(Label::new(provider_name.clone())) + .into_any_element() + }); let available_models = provider.provided_models(cx); if available_models.is_empty() { @@ -109,19 +125,44 @@ impl RenderOnce for ModelSelector { let id = available_model.id(); let provider_id = available_model.provider_id(); let model_name = available_model.name().0.clone(); - let _availability = available_model.availability(); + let availability = available_model.availability(); let selected_model = selected_model.clone(); let selected_provider = selected_provider.clone(); - move |_| { + move |cx| { h_flex() .w_full() .justify_between() - .child(Label::new(model_name.clone())) - .when( + .font_buffer(cx) + .min_w(px(260.)) + .child( + h_flex() + .gap_2() + .child(Label::new(model_name.clone())) + .children(match availability { + LanguageModelAvailability::Public => None, + LanguageModelAvailability::RequiresPlan( + Plan::Free, + ) => None, + LanguageModelAvailability::RequiresPlan( + Plan::ZedPro, + ) => Some( + Label::new("Pro") + .size(LabelSize::XSmall) + .color(Color::Muted), + ), + }), + ) + .child(div().when( selected_model.as_ref() == Some(&id) && selected_provider.as_ref() == Some(&provider_id), - |this| this.child(Icon::new(IconName::Check)), - ) + |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) + .size(IconSize::Small), + ) + }, + )) .into_any() } }, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 2fa82197ab..0048d4c50d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -22,6 +22,7 @@ pub use role::*; use schemars::JsonSchema; use serde::de::DeserializeOwned; use std::{future::Future, sync::Arc}; +use ui::IconName; pub fn init( user_store: Model, @@ -102,6 +103,9 @@ pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema { pub trait LanguageModelProvider: 'static { fn id(&self) -> LanguageModelProviderId; fn name(&self) -> LanguageModelProviderName; + fn icon(&self) -> IconName { + IconName::ZedAssistant + } fn provided_models(&self, cx: &AppContext) -> Vec>; fn load_model(&self, _model: Arc, _cx: &AppContext) {} fn is_authenticated(&self, cx: &AppContext) -> bool; diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 3b96035ffb..a9b9ec686e 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -115,6 +115,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { LanguageModelProviderName(PROVIDER_NAME.into()) } + fn icon(&self) -> IconName { + IconName::AiAnthropic + } + fn provided_models(&self, cx: &AppContext) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 3dda8b24e1..f0056cb0d8 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -23,7 +23,7 @@ use crate::{LanguageModelAvailability, LanguageModelProvider}; use super::anthropic::count_anthropic_tokens; pub const PROVIDER_ID: &str = "zed.dev"; -pub const PROVIDER_NAME: &str = "Zed AI"; +pub const PROVIDER_NAME: &str = "Zed"; #[derive(Default, Clone, Debug, PartialEq)] pub struct ZedDotDevSettings { @@ -128,6 +128,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider { LanguageModelProviderName(PROVIDER_NAME.into()) } + fn icon(&self) -> IconName { + IconName::AiZed + } + fn provided_models(&self, cx: &AppContext) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs index 7de8510389..fc0831a603 100644 --- a/crates/language_model/src/provider/copilot_chat.rs +++ b/crates/language_model/src/provider/copilot_chat.rs @@ -91,6 +91,10 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { LanguageModelProviderName(PROVIDER_NAME.into()) } + fn icon(&self) -> IconName { + IconName::Copilot + } + fn provided_models(&self, _cx: &AppContext) -> Vec> { CopilotChatModel::iter() .map(|model| { diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs index 0547d7e98c..b84bf87fdb 100644 --- a/crates/language_model/src/provider/google.rs +++ b/crates/language_model/src/provider/google.rs @@ -97,6 +97,10 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { LanguageModelProviderName(PROVIDER_NAME.into()) } + fn icon(&self) -> IconName { + IconName::AiGoogle + } + fn provided_models(&self, cx: &AppContext) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs index c2aace0ba1..717feb5fad 100644 --- a/crates/language_model/src/provider/ollama.rs +++ b/crates/language_model/src/provider/ollama.rs @@ -108,6 +108,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { LanguageModelProviderName(PROVIDER_NAME.into()) } + fn icon(&self) -> IconName { + IconName::AiOllama + } + fn provided_models(&self, cx: &AppContext) -> Vec> { self.state .read(cx) diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index d8a683c7db..b7842dd72b 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -98,6 +98,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { LanguageModelProviderName(PROVIDER_NAME.into()) } + fn icon(&self) -> IconName { + IconName::AiOpenAi + } + fn provided_models(&self, cx: &AppContext) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index c56c041557..fdcdeafa01 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -106,6 +106,11 @@ impl IconSize { )] pub enum IconName { Ai, + AiAnthropic, + AiOpenAi, + AiGoogle, + AiOllama, + AiZed, ArrowCircle, ArrowDown, ArrowDownFromLine, @@ -262,6 +267,11 @@ impl IconName { pub fn path(self) -> &'static str { match self { IconName::Ai => "icons/ai.svg", + IconName::AiAnthropic => "icons/ai_anthropic.svg", + IconName::AiOpenAi => "icons/ai_open_ai.svg", + IconName::AiGoogle => "icons/ai_google.svg", + IconName::AiOllama => "icons/ai_ollama.svg", + IconName::AiZed => "icons/ai_zed.svg", IconName::ArrowCircle => "icons/arrow_circle.svg", IconName::ArrowDown => "icons/arrow_down.svg", IconName::ArrowDownFromLine => "icons/arrow_down_from_line.svg", diff --git a/crates/ui/src/styles/typography.rs b/crates/ui/src/styles/typography.rs index 56f981f393..4afd3b9303 100644 --- a/crates/ui/src/styles/typography.rs +++ b/crates/ui/src/styles/typography.rs @@ -8,6 +8,22 @@ use crate::{rems_from_px, Color}; /// Extends [`gpui::Styled`] with typography-related styling methods. pub trait StyledTypography: Styled + Sized { + /// Sets the font family to the buffer font. + fn font_buffer(self, cx: &WindowContext) -> Self { + let settings = ThemeSettings::get_global(cx); + let buffer_font_family = settings.buffer_font.family.clone(); + + self.font_family(buffer_font_family) + } + + /// Sets the font family to the UI font. + fn font_ui(self, cx: &WindowContext) -> Self { + let settings = ThemeSettings::get_global(cx); + let ui_font_family = settings.ui_font.family.clone(); + + self.font_family(ui_font_family) + } + /// Sets the text size using a [`UiTextSize`]. fn text_ui_size(self, size: TextSize, cx: &WindowContext) -> Self { self.text_size(size.rems(cx))