mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-08 10:02:56 +00:00
assistant panel: Fix panel not reloading after entering credentials (#15531)
This is the revised version of #15527. We also added new events to notify subscribers when new providers are added or removed. Co-Authored-by: Thorsten <thorsten@zed.dev> Release Notes: - N/A --------- Co-authored-by: Thorsten <thorsten@zed.dev> Co-authored-by: Thorsten Ball <mrnugget@gmail.com>
This commit is contained in:
parent
a31dba9fc1
commit
821ce2fc7c
11 changed files with 119 additions and 68 deletions
|
@ -223,9 +223,17 @@ fn init_language_model_settings(cx: &mut AppContext) {
|
|||
|
||||
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
|
||||
.detach();
|
||||
cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
|
||||
update_active_language_model_from_settings(cx)
|
||||
})
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
|_, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
update_active_language_model_from_settings(cx);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
|
||||
|
|
|
@ -394,8 +394,15 @@ impl AssistantPanel {
|
|||
cx.subscribe(&context_store, Self::handle_context_store_event),
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
|this, _, _: &language_model::ActiveModelChanged, cx| {
|
||||
this.completion_provider_changed(cx);
|
||||
|this, _, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ActiveModelChanged => {
|
||||
this.completion_provider_changed(cx);
|
||||
}
|
||||
language_model::Event::ProviderStateChanged
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
this.ensure_authenticated(cx);
|
||||
}
|
||||
},
|
||||
),
|
||||
];
|
||||
|
@ -588,6 +595,11 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
fn ensure_authenticated(&mut self, cx: &mut ViewContext<Self>) {
|
||||
if self.is_authenticated(cx) {
|
||||
self.set_authentication_prompt(None, cx);
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(provider_id) = LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map(|p| p.id())
|
||||
|
@ -596,29 +608,35 @@ impl AssistantPanel {
|
|||
};
|
||||
|
||||
let load_credentials = self.authenticate(cx);
|
||||
let task = cx.spawn(|this, mut cx| async move {
|
||||
let _ = load_credentials.await;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.show_authentication_prompt(cx);
|
||||
})
|
||||
.log_err();
|
||||
});
|
||||
|
||||
self.authenticate_provider_task = Some((provider_id, task));
|
||||
self.authenticate_provider_task = Some((
|
||||
provider_id,
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let _ = load_credentials.await;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.show_authentication_prompt(cx);
|
||||
this.authenticate_provider_task = None;
|
||||
})
|
||||
.log_err();
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
fn show_authentication_prompt(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let prompt = Self::authentication_prompt(cx);
|
||||
self.set_authentication_prompt(prompt, cx);
|
||||
}
|
||||
|
||||
fn set_authentication_prompt(&mut self, prompt: Option<AnyView>, cx: &mut ViewContext<Self>) {
|
||||
if self.active_context_editor(cx).is_none() {
|
||||
self.new_context(cx);
|
||||
}
|
||||
|
||||
let authentication_prompt = Self::authentication_prompt(cx);
|
||||
for context_editor in self.context_editors(cx) {
|
||||
context_editor.update(cx, |editor, cx| {
|
||||
editor.set_authentication_prompt(authentication_prompt.clone(), cx);
|
||||
editor.set_authentication_prompt(prompt.clone(), cx);
|
||||
});
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
|
|
|
@ -89,7 +89,20 @@ pub trait LanguageModelProvider: 'static {
|
|||
}
|
||||
|
||||
pub trait LanguageModelProviderState: 'static {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
|
||||
type ObservableEntity;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
|
||||
|
||||
fn subscribe<T: 'static>(
|
||||
&self,
|
||||
cx: &mut gpui::ModelContext<T>,
|
||||
callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
|
||||
) -> Option<gpui::Subscription> {
|
||||
let entity = self.observable_entity()?;
|
||||
Some(cx.observe(&entity, move |this, _, cx| {
|
||||
callback(this, cx);
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
|
|
|
@ -44,7 +44,7 @@ pub struct AnthropicLanguageModelProvider {
|
|||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
@ -61,11 +61,12 @@ impl AnthropicLanguageModelProvider {
|
|||
Self { http_client, state }
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ use anyhow::{anyhow, Context as _, Result};
|
|||
use client::Client;
|
||||
use collections::BTreeMap;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
|
@ -50,16 +50,19 @@ pub struct CloudLanguageModelProvider {
|
|||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
pub struct State {
|
||||
client: Arc<Client>,
|
||||
status: client::Status,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||
cx.spawn(move |this, mut cx| async move {
|
||||
client.authenticate_and_connect(true, &cx).await?;
|
||||
this.update(&mut cx, |_, cx| cx.notify())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,10 +102,10 @@ impl CloudLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProviderState for CloudLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -11,8 +11,8 @@ use futures::future::BoxFuture;
|
|||
use futures::stream::BoxStream;
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use gpui::{
|
||||
percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model,
|
||||
ModelContext, Render, Subscription, Task, Transformation,
|
||||
percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render,
|
||||
Subscription, Task, Transformation,
|
||||
};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::time::Duration;
|
||||
|
@ -67,10 +67,10 @@ impl CopilotChatLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut ModelContext<T>) -> Option<Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -36,7 +36,9 @@ pub struct FakeLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProviderState for FakeLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
type ObservableEntity = ();
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ pub struct GoogleLanguageModelProvider {
|
|||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
@ -63,10 +63,10 @@ impl GoogleLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProviderState for GoogleLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ pub struct OllamaLanguageModelProvider {
|
|||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
pub struct State {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
available_models: Vec<ollama::Model>,
|
||||
_subscription: Subscription,
|
||||
|
@ -87,10 +87,10 @@ impl OllamaLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ pub struct OpenAiLanguageModelProvider {
|
|||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
@ -64,10 +64,10 @@ impl OpenAiLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -54,9 +54,7 @@ fn register_language_model_providers(
|
|||
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
|
||||
} else {
|
||||
registry.unregister_provider(
|
||||
&LanguageModelProviderId::from(
|
||||
crate::provider::cloud::PROVIDER_NAME.to_string(),
|
||||
),
|
||||
LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
@ -80,9 +78,14 @@ pub struct ActiveModel {
|
|||
model: Option<Arc<dyn LanguageModel>>,
|
||||
}
|
||||
|
||||
pub struct ActiveModelChanged;
|
||||
pub enum Event {
|
||||
ActiveModelChanged,
|
||||
ProviderStateChanged,
|
||||
AddedProvider(LanguageModelProviderId),
|
||||
RemovedProvider(LanguageModelProviderId),
|
||||
}
|
||||
|
||||
impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
|
||||
impl EventEmitter<Event> for LanguageModelRegistry {}
|
||||
|
||||
impl LanguageModelRegistry {
|
||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||
|
@ -112,23 +115,26 @@ impl LanguageModelRegistry {
|
|||
provider: T,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let name = provider.id();
|
||||
let id = provider.id();
|
||||
|
||||
if let Some(subscription) = provider.subscribe(cx) {
|
||||
let subscription = provider.subscribe(cx, |_, cx| {
|
||||
cx.emit(Event::ProviderStateChanged);
|
||||
});
|
||||
if let Some(subscription) = subscription {
|
||||
subscription.detach();
|
||||
}
|
||||
|
||||
self.providers.insert(name, Arc::new(provider));
|
||||
cx.notify();
|
||||
self.providers.insert(id.clone(), Arc::new(provider));
|
||||
cx.emit(Event::AddedProvider(id));
|
||||
}
|
||||
|
||||
pub fn unregister_provider(
|
||||
&mut self,
|
||||
name: &LanguageModelProviderId,
|
||||
id: LanguageModelProviderId,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if self.providers.remove(name).is_some() {
|
||||
cx.notify();
|
||||
if self.providers.remove(&id).is_some() {
|
||||
cx.emit(Event::RemovedProvider(id));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -187,7 +193,7 @@ impl LanguageModelRegistry {
|
|||
provider,
|
||||
model: None,
|
||||
});
|
||||
cx.emit(ActiveModelChanged);
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
}
|
||||
|
||||
pub fn set_active_model(
|
||||
|
@ -202,13 +208,13 @@ impl LanguageModelRegistry {
|
|||
provider,
|
||||
model: Some(model),
|
||||
});
|
||||
cx.emit(ActiveModelChanged);
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
} else {
|
||||
log::warn!("Active model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.active_model = None;
|
||||
cx.emit(ActiveModelChanged);
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -239,7 +245,7 @@ mod tests {
|
|||
assert_eq!(providers[0].id(), crate::provider::fake::provider_id());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
|
||||
registry.unregister_provider(crate::provider::fake::provider_id(), cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers();
|
||||
|
|
Loading…
Reference in a new issue