This commit is contained in:
Max Brunsfeld 2024-11-12 22:09:15 -08:00
parent 55cd99cdc4
commit 1adb9f4c45
5 changed files with 133 additions and 65 deletions

View file

@ -157,8 +157,7 @@ impl ContextStore {
// I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
//
// We should find a more elegant way to do this.
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
let context_server_factory_registry = ContextServerFactoryRegistry::global(cx);
cx.spawn(|context_store, mut cx| async move {
loop {
let mut servers_to_register = Vec::new();

View file

@ -19,7 +19,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut AppContext) {
ContextServerSettings::register(cx);
ContextServerFactoryRegistry::default_global(cx);
ContextServerFactoryRegistry::global(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);

View file

@ -21,17 +21,19 @@ use std::sync::Arc;
use anyhow::{bail, Result};
use async_trait::async_trait;
use collections::{HashMap, HashSet};
use futures::{Future, FutureExt};
use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
use futures::{channel::mpsc, Future, FutureExt};
use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Task};
use log;
use parking_lot::RwLock;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use settings::{Settings, SettingsSources, SettingsStore};
use smol::stream::StreamExt;
use crate::{
client::{self, Client},
types,
types, ContextServerFactoryRegistry,
};
#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
@ -160,8 +162,12 @@ impl ContextServer for NativeContextServer {
/// must go through the `GlobalContextServerManager` which holds
/// a model to the ContextServerManager.
pub struct ContextServerManager {
servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
project: Model<Project>,
servers: HashMap<Arc<str>, (ServerCommand, Arc<dyn ContextServer>)>,
pending_servers: HashSet<Arc<str>>,
registry: Model<ContextServerFactoryRegistry>,
_maintain_context_servers: Task<Result<()>>,
_subscriptions: [gpui::Subscription; 2],
}
pub enum Event {
@ -171,17 +177,40 @@ pub enum Event {
impl EventEmitter<Event> for ContextServerManager {}
impl Default for ContextServerManager {
fn default() -> Self {
Self::new()
}
}
impl ContextServerManager {
pub fn new() -> Self {
pub fn new(
project: Model<Project>,
registry: Model<ContextServerFactoryRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
let (tx, mut rx) = mpsc::unbounded::<()>();
Self {
servers: HashMap::default(),
project,
servers: Default::default(),
pending_servers: HashSet::default(),
_subscriptions: [
cx.observe(&registry, {
let tx = tx.clone();
move |_this, _, _cx| {
tx.unbounded_send(()).ok();
}
}),
cx.observe_global::<SettingsStore>({
let tx = tx.clone();
move |_this, _cx| {
tx.unbounded_send(()).ok();
}
}),
],
registry,
_maintain_context_servers: cx.spawn(|this, mut cx| async move {
while let Some(_) = rx.next().await {
this.update(&mut cx, |this, cx| {
this.registered_servers_changed(cx);
})?;
}
Ok(())
}),
}
}
@ -271,6 +300,45 @@ impl ContextServerManager {
self.servers.values().cloned().collect()
}
fn registered_servers_changed(
&mut self,
cx: &mut ModelContext<ContextServerManager>,
) -> Task<()> {
let worktree_id = self
.project
.read(cx)
.visible_worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).id());
let settings = ContextServerSettings::get(
worktree_id.map(|worktree_id| settings::SettingsLocation {
worktree_id,
path: Path::new(""),
}),
cx,
);
let registry = self.registry.read(cx);
let mut settings_iter = settings.context_servers.iter().peekable();
let mut registry_iter = registry.context_servers.iter().peekable();
// loop {
// let mut setting_command = None;
// let mut registered_command = None;
// let mut server_id;
// match (settings_iter.peek(), registry_iter.peek()) {
// (None, None) => break,
// (None, Some((id, value))) => {
// server_id = id.clone();
// registered_command = value;
// }
// (Some(_), None) => continue,
// (Some(_), Some(_)) => continue,
// }
// }
}
pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
let current_servers = self
.servers()

View file

@ -1,76 +1,59 @@
use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
use gpui::{AppContext, AsyncAppContext, Global, Model, ReadGlobal, Task};
use parking_lot::RwLock;
use collections::BTreeMap;
use futures::future::BoxFuture;
use gpui::{AppContext, AsyncAppContext, Context, Global, Model, ReadGlobal, Task};
use project::Project;
use crate::ContextServer;
use crate::manager::ServerCommand;
pub type ContextServerFactory = Arc<
dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<Arc<dyn ContextServer>>>
dyn Fn(Model<Project>, &AsyncAppContext) -> BoxFuture<Result<ServerCommand>>
+ Send
+ Sync
+ 'static,
>;
#[derive(Default)]
struct GlobalContextServerFactoryRegistry(Arc<ContextServerFactoryRegistry>);
struct GlobalContextServerFactoryRegistry(Model<ContextServerFactoryRegistry>);
impl Global for GlobalContextServerFactoryRegistry {}
#[derive(Default)]
struct ContextServerFactoryRegistryState {
context_servers: HashMap<Arc<str>, ContextServerFactory>,
}
#[derive(Default)]
pub struct ContextServerFactoryRegistry {
state: RwLock<ContextServerFactoryRegistryState>,
pub context_servers: BTreeMap<Arc<str>, ContextServerFactory>,
}
impl ContextServerFactoryRegistry {
/// Returns the global [`ContextServerFactoryRegistry`].
pub fn global(cx: &AppContext) -> Arc<Self> {
pub fn global(cx: &mut AppContext) -> Model<Self> {
if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
let registry = cx.new_model(|_| ContextServerFactoryRegistry::new());
cx.set_global(GlobalContextServerFactoryRegistry(registry));
}
GlobalContextServerFactoryRegistry::global(cx).0.clone()
}
/// Returns the global [`ContextServerFactoryRegistry`].
///
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
cx.default_global::<GlobalContextServerFactoryRegistry>()
.0
.clone()
}
pub fn new() -> Arc<Self> {
Arc::new(Self {
state: RwLock::new(ContextServerFactoryRegistryState {
context_servers: HashMap::default(),
}),
})
pub fn new() -> Self {
Self {
context_servers: Default::default(),
}
}
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
self.state
.read()
.context_servers
self.context_servers
.iter()
.map(|(id, factory)| (id.clone(), factory.clone()))
.collect()
}
/// Registers the provided [`ContextServerFactory`].
pub fn register_server_factory(&self, id: Arc<str>, factory: ContextServerFactory) {
let mut state = self.state.write();
state.context_servers.insert(id, factory);
pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
self.context_servers.insert(id, factory);
}
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
pub fn unregister_server_factory_by_id(&self, server_id: &str) {
let mut state = self.state.write();
state.context_servers.remove(server_id);
pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
self.context_servers.remove(server_id);
}
}

View file

@ -5,6 +5,7 @@ use assistant_slash_command::SlashCommandRegistry;
use context_servers::ContextServerFactoryRegistry;
use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host};
use fs::Fs;
use futures::FutureExt;
use gpui::{AppContext, BackgroundExecutor, Task};
use indexed_docs::{IndexedDocsRegistry, ProviderId};
use language::{LanguageRegistry, LanguageServerBinaryStatus, LoadedLanguage};
@ -83,18 +84,35 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
self.context_server_factory_registry
.register_server_factory(
id.clone(),
Arc::new({
move |project, cx| {
let id = id.clone();
let extension = extension.clone();
let host = host.clone();
cx.spawn(|cx| async move {
let context_server =
ExtensionContextServer::new(extension, host, id, project, cx)
.await?;
anyhow::Ok(Arc::new(context_server) as _)
})
Arc::new(move |project, cx| {
async move {
let extension_project =
project.update(&mut cx, |project, cx| ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})?;
let command = extension
.call({
let id = id.clone();
|extension, store| {
async move {
let project =
store.data_mut().table().push(extension_project)?;
let command = extension
.call_context_server_command(store, id.clone(), project)
.await?
.map_err(|e| anyhow!("{}", e))?;
anyhow::Ok(command)
}
.boxed()
}
})
.await?;
Ok(command)
}
.boxed()
}),
);
}