diff --git a/assets/icons/ai_anthropic.svg b/assets/icons/ai_anthropic.svg
new file mode 100644
index 0000000000..b0705c2774
--- /dev/null
+++ b/assets/icons/ai_anthropic.svg
@@ -0,0 +1,4 @@
+
diff --git a/assets/icons/ai_google.svg b/assets/icons/ai_google.svg
new file mode 100644
index 0000000000..953e585cb1
--- /dev/null
+++ b/assets/icons/ai_google.svg
@@ -0,0 +1,3 @@
+
diff --git a/assets/icons/ai_ollama.svg b/assets/icons/ai_ollama.svg
new file mode 100644
index 0000000000..b28f3788e8
--- /dev/null
+++ b/assets/icons/ai_ollama.svg
@@ -0,0 +1,5 @@
+
diff --git a/assets/icons/ai_open_ai.svg b/assets/icons/ai_open_ai.svg
new file mode 100644
index 0000000000..e659a472d8
--- /dev/null
+++ b/assets/icons/ai_open_ai.svg
@@ -0,0 +1,3 @@
+
diff --git a/assets/icons/ai_zed.svg b/assets/icons/ai_zed.svg
new file mode 100644
index 0000000000..1c6bb8ad63
--- /dev/null
+++ b/assets/icons/ai_zed.svg
@@ -0,0 +1,10 @@
+
diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs
index d3d23b148d..b3d0344e71 100644
--- a/crates/assistant/src/model_selector.rs
+++ b/crates/assistant/src/model_selector.rs
@@ -3,7 +3,8 @@ use std::sync::Arc;
use crate::assistant_settings::AssistantSettings;
use fs::Fs;
use gpui::SharedString;
-use language_model::LanguageModelRegistry;
+use language_model::{LanguageModelAvailability, LanguageModelRegistry};
+use proto::Plan;
use settings::update_settings_file;
use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
@@ -37,7 +38,7 @@ impl ModelSelector {
}
impl RenderOnce for ModelSelector {
- fn render(self, _: &mut WindowContext) -> impl IntoElement {
+ fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
let mut menu = PopoverMenu::new("model-switcher");
if let Some(handle) = self.handle {
menu = menu.with_handle(handle);
@@ -63,10 +64,25 @@ impl RenderOnce for ModelSelector {
.into_iter()
.enumerate()
{
+ let provider_icon = provider.icon();
+ let provider_name = provider.name().0.clone();
+
if index > 0 {
menu = menu.separator();
}
- menu = menu.header(provider.name().0);
+ menu = menu.custom_row(move |_| {
+ h_flex()
+ .pb_1()
+ .gap_1p5()
+ .w_full()
+ .child(
+ Icon::new(provider_icon)
+ .color(Color::Muted)
+ .size(IconSize::Small),
+ )
+ .child(Label::new(provider_name.clone()))
+ .into_any_element()
+ });
let available_models = provider.provided_models(cx);
if available_models.is_empty() {
@@ -109,19 +125,44 @@ impl RenderOnce for ModelSelector {
let id = available_model.id();
let provider_id = available_model.provider_id();
let model_name = available_model.name().0.clone();
- let _availability = available_model.availability();
+ let availability = available_model.availability();
let selected_model = selected_model.clone();
let selected_provider = selected_provider.clone();
- move |_| {
+ move |cx| {
h_flex()
.w_full()
.justify_between()
- .child(Label::new(model_name.clone()))
- .when(
+ .font_buffer(cx)
+ .min_w(px(260.))
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Label::new(model_name.clone()))
+ .children(match availability {
+ LanguageModelAvailability::Public => None,
+ LanguageModelAvailability::RequiresPlan(
+ Plan::Free,
+ ) => None,
+ LanguageModelAvailability::RequiresPlan(
+ Plan::ZedPro,
+ ) => Some(
+ Label::new("Pro")
+ .size(LabelSize::XSmall)
+ .color(Color::Muted),
+ ),
+ }),
+ )
+ .child(div().when(
selected_model.as_ref() == Some(&id)
&& selected_provider.as_ref() == Some(&provider_id),
- |this| this.child(Icon::new(IconName::Check)),
- )
+ |this| {
+ this.child(
+ Icon::new(IconName::Check)
+ .color(Color::Accent)
+ .size(IconSize::Small),
+ )
+ },
+ ))
.into_any()
}
},
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 2fa82197ab..0048d4c50d 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -22,6 +22,7 @@ pub use role::*;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use std::{future::Future, sync::Arc};
+use ui::IconName;
pub fn init(
user_store: Model,
@@ -102,6 +103,9 @@ pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
pub trait LanguageModelProvider: 'static {
fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName;
+ fn icon(&self) -> IconName {
+ IconName::ZedAssistant
+ }
fn provided_models(&self, cx: &AppContext) -> Vec>;
fn load_model(&self, _model: Arc, _cx: &AppContext) {}
fn is_authenticated(&self, cx: &AppContext) -> bool;
diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs
index 3b96035ffb..a9b9ec686e 100644
--- a/crates/language_model/src/provider/anthropic.rs
+++ b/crates/language_model/src/provider/anthropic.rs
@@ -115,6 +115,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
LanguageModelProviderName(PROVIDER_NAME.into())
}
+ fn icon(&self) -> IconName {
+ IconName::AiAnthropic
+ }
+
fn provided_models(&self, cx: &AppContext) -> Vec> {
let mut models = BTreeMap::default();
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index 3dda8b24e1..f0056cb0d8 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -23,7 +23,7 @@ use crate::{LanguageModelAvailability, LanguageModelProvider};
use super::anthropic::count_anthropic_tokens;
pub const PROVIDER_ID: &str = "zed.dev";
-pub const PROVIDER_NAME: &str = "Zed AI";
+pub const PROVIDER_NAME: &str = "Zed";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
@@ -128,6 +128,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
LanguageModelProviderName(PROVIDER_NAME.into())
}
+ fn icon(&self) -> IconName {
+ IconName::AiZed
+ }
+
fn provided_models(&self, cx: &AppContext) -> Vec> {
let mut models = BTreeMap::default();
diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs
index 7de8510389..fc0831a603 100644
--- a/crates/language_model/src/provider/copilot_chat.rs
+++ b/crates/language_model/src/provider/copilot_chat.rs
@@ -91,6 +91,10 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
LanguageModelProviderName(PROVIDER_NAME.into())
}
+ fn icon(&self) -> IconName {
+ IconName::Copilot
+ }
+
fn provided_models(&self, _cx: &AppContext) -> Vec> {
CopilotChatModel::iter()
.map(|model| {
diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs
index 0547d7e98c..b84bf87fdb 100644
--- a/crates/language_model/src/provider/google.rs
+++ b/crates/language_model/src/provider/google.rs
@@ -97,6 +97,10 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
LanguageModelProviderName(PROVIDER_NAME.into())
}
+ fn icon(&self) -> IconName {
+ IconName::AiGoogle
+ }
+
fn provided_models(&self, cx: &AppContext) -> Vec> {
let mut models = BTreeMap::default();
diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs
index c2aace0ba1..717feb5fad 100644
--- a/crates/language_model/src/provider/ollama.rs
+++ b/crates/language_model/src/provider/ollama.rs
@@ -108,6 +108,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
LanguageModelProviderName(PROVIDER_NAME.into())
}
+ fn icon(&self) -> IconName {
+ IconName::AiOllama
+ }
+
fn provided_models(&self, cx: &AppContext) -> Vec> {
self.state
.read(cx)
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index d8a683c7db..b7842dd72b 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -98,6 +98,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
LanguageModelProviderName(PROVIDER_NAME.into())
}
+ fn icon(&self) -> IconName {
+ IconName::AiOpenAi
+ }
+
fn provided_models(&self, cx: &AppContext) -> Vec> {
let mut models = BTreeMap::default();
diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs
index c56c041557..fdcdeafa01 100644
--- a/crates/ui/src/components/icon.rs
+++ b/crates/ui/src/components/icon.rs
@@ -106,6 +106,11 @@ impl IconSize {
)]
pub enum IconName {
Ai,
+ AiAnthropic,
+ AiOpenAi,
+ AiGoogle,
+ AiOllama,
+ AiZed,
ArrowCircle,
ArrowDown,
ArrowDownFromLine,
@@ -262,6 +267,11 @@ impl IconName {
pub fn path(self) -> &'static str {
match self {
IconName::Ai => "icons/ai.svg",
+ IconName::AiAnthropic => "icons/ai_anthropic.svg",
+ IconName::AiOpenAi => "icons/ai_open_ai.svg",
+ IconName::AiGoogle => "icons/ai_google.svg",
+ IconName::AiOllama => "icons/ai_ollama.svg",
+ IconName::AiZed => "icons/ai_zed.svg",
IconName::ArrowCircle => "icons/arrow_circle.svg",
IconName::ArrowDown => "icons/arrow_down.svg",
IconName::ArrowDownFromLine => "icons/arrow_down_from_line.svg",
diff --git a/crates/ui/src/styles/typography.rs b/crates/ui/src/styles/typography.rs
index 56f981f393..4afd3b9303 100644
--- a/crates/ui/src/styles/typography.rs
+++ b/crates/ui/src/styles/typography.rs
@@ -8,6 +8,22 @@ use crate::{rems_from_px, Color};
/// Extends [`gpui::Styled`] with typography-related styling methods.
pub trait StyledTypography: Styled + Sized {
+ /// Sets the font family to the buffer font.
+ fn font_buffer(self, cx: &WindowContext) -> Self {
+ let settings = ThemeSettings::get_global(cx);
+ let buffer_font_family = settings.buffer_font.family.clone();
+
+ self.font_family(buffer_font_family)
+ }
+
+ /// Sets the font family to the UI font.
+ fn font_ui(self, cx: &WindowContext) -> Self {
+ let settings = ThemeSettings::get_global(cx);
+ let ui_font_family = settings.ui_font.family.clone();
+
+ self.font_family(ui_font_family)
+ }
+
/// Sets the text size using a [`UiTextSize`].
fn text_ui_size(self, size: TextSize, cx: &WindowContext) -> Self {
self.text_size(size.rems(cx))