diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index b6a4d8513e..86352d225d 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -20,10 +20,10 @@ use std::{ future::Future, io::Write, path::PathBuf, - str::FromStr, + str::{self, FromStr as _}, sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, - Arc, + Arc, Weak, }, }; use std::{path::Path, process::Stdio}; @@ -34,16 +34,18 @@ const CONTENT_LEN_HEADER: &str = "Content-Length: "; type NotificationHandler = Box, &str, AsyncAppContext)>; type ResponseHandler = Box)>; +type IoHandler = Box; pub struct LanguageServer { server_id: LanguageServerId, next_id: AtomicUsize, - outbound_tx: channel::Sender>, + outbound_tx: channel::Sender, name: String, capabilities: ServerCapabilities, code_action_kinds: Option>, notification_handlers: Arc>>, response_handlers: Arc>>>, + io_handlers: Arc>>, executor: Arc, #[allow(clippy::type_complexity)] io_tasks: Mutex>, Task>)>>, @@ -56,9 +58,16 @@ pub struct LanguageServer { #[repr(transparent)] pub struct LanguageServerId(pub usize); -pub struct Subscription { - method: &'static str, - notification_handlers: Arc>>, +pub enum Subscription { + Detached, + Notification { + method: &'static str, + notification_handlers: Arc>>, + }, + Io { + id: usize, + io_handlers: Weak>>, + }, } #[derive(Serialize, Deserialize)] @@ -177,33 +186,40 @@ impl LanguageServer { Stdout: AsyncRead + Unpin + Send + 'static, F: FnMut(AnyNotification) + 'static + Send, { - let (outbound_tx, outbound_rx) = channel::unbounded::>(); + let (outbound_tx, outbound_rx) = channel::unbounded::(); + let (output_done_tx, output_done_rx) = barrier::channel(); let notification_handlers = Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); + let io_handlers = Arc::new(Mutex::new(HashMap::default())); let input_task = cx.spawn(|cx| { - let notification_handlers = notification_handlers.clone(); - let response_handlers = response_handlers.clone(); Self::handle_input( stdout, on_unhandled_notification, - notification_handlers, - response_handlers, + notification_handlers.clone(), + response_handlers.clone(), + io_handlers.clone(), cx, ) .log_err() }); - let (output_done_tx, output_done_rx) = barrier::channel(); let output_task = cx.background().spawn({ - let response_handlers = response_handlers.clone(); - Self::handle_output(stdin, outbound_rx, output_done_tx, response_handlers).log_err() + Self::handle_output( + stdin, + outbound_rx, + output_done_tx, + response_handlers.clone(), + io_handlers.clone(), + ) + .log_err() }); Self { server_id, notification_handlers, response_handlers, + io_handlers, name: Default::default(), capabilities: Default::default(), code_action_kinds, @@ -226,6 +242,7 @@ impl LanguageServer { mut on_unhandled_notification: F, notification_handlers: Arc>>, response_handlers: Arc>>>, + io_handlers: Arc>>, cx: AsyncAppContext, ) -> anyhow::Result<()> where @@ -252,7 +269,13 @@ impl LanguageServer { buffer.resize(message_len, 0); stdout.read_exact(&mut buffer).await?; - log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer)); + + if let Ok(message) = str::from_utf8(&buffer) { + log::trace!("incoming message:{}", message); + for handler in io_handlers.lock().values_mut() { + handler(true, message); + } + } if let Ok(msg) = serde_json::from_slice::(&buffer) { if let Some(handler) = notification_handlers.lock().get_mut(msg.method) { @@ -291,9 +314,10 @@ impl LanguageServer { async fn handle_output( stdin: Stdin, - outbound_rx: channel::Receiver>, + outbound_rx: channel::Receiver, output_done_tx: barrier::Sender, response_handlers: Arc>>>, + io_handlers: Arc>>, ) -> anyhow::Result<()> where Stdin: AsyncWrite + Unpin + Send + 'static, @@ -307,13 +331,17 @@ impl LanguageServer { }); let mut content_len_buffer = Vec::new(); while let Ok(message) = outbound_rx.recv().await { - log::trace!("outgoing message:{}", String::from_utf8_lossy(&message)); + log::trace!("outgoing message:{}", message); + for handler in io_handlers.lock().values_mut() { + handler(false, &message); + } + content_len_buffer.clear(); write!(content_len_buffer, "{}", message.len()).unwrap(); stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?; stdin.write_all(&content_len_buffer).await?; stdin.write_all("\r\n\r\n".as_bytes()).await?; - stdin.write_all(&message).await?; + stdin.write_all(message.as_bytes()).await?; stdin.flush().await?; } drop(output_done_tx); @@ -464,6 +492,19 @@ impl LanguageServer { self.on_custom_request(T::METHOD, f) } + #[must_use] + pub fn on_io(&self, f: F) -> Subscription + where + F: 'static + Send + FnMut(bool, &str), + { + let id = self.next_id.fetch_add(1, SeqCst); + self.io_handlers.lock().insert(id, Box::new(f)); + Subscription::Io { + id, + io_handlers: Arc::downgrade(&self.io_handlers), + } + } + pub fn remove_request_handler(&self) { self.notification_handlers.lock().remove(T::METHOD); } @@ -490,7 +531,7 @@ impl LanguageServer { prev_handler.is_none(), "registered multiple handlers for the same LSP method" ); - Subscription { + Subscription::Notification { method, notification_handlers: self.notification_handlers.clone(), } @@ -537,7 +578,7 @@ impl LanguageServer { }, }; if let Some(response) = - serde_json::to_vec(&response).log_err() + serde_json::to_string(&response).log_err() { outbound_tx.try_send(response).ok(); } @@ -560,7 +601,7 @@ impl LanguageServer { message: error.to_string(), }), }; - if let Some(response) = serde_json::to_vec(&response).log_err() { + if let Some(response) = serde_json::to_string(&response).log_err() { outbound_tx.try_send(response).ok(); } } @@ -572,7 +613,7 @@ impl LanguageServer { prev_handler.is_none(), "registered multiple handlers for the same LSP method" ); - Subscription { + Subscription::Notification { method, notification_handlers: self.notification_handlers.clone(), } @@ -612,14 +653,14 @@ impl LanguageServer { fn request_internal( next_id: &AtomicUsize, response_handlers: &Mutex>>, - outbound_tx: &channel::Sender>, + outbound_tx: &channel::Sender, params: T::Params, ) -> impl 'static + Future> where T::Result: 'static + Send, { let id = next_id.fetch_add(1, SeqCst); - let message = serde_json::to_vec(&Request { + let message = serde_json::to_string(&Request { jsonrpc: JSON_RPC_VERSION, id, method: T::METHOD, @@ -662,10 +703,10 @@ impl LanguageServer { } fn notify_internal( - outbound_tx: &channel::Sender>, + outbound_tx: &channel::Sender, params: T::Params, ) -> Result<()> { - let message = serde_json::to_vec(&Notification { + let message = serde_json::to_string(&Notification { jsonrpc: JSON_RPC_VERSION, method: T::METHOD, params, @@ -686,7 +727,7 @@ impl Drop for LanguageServer { impl Subscription { pub fn detach(mut self) { - self.method = ""; + *(&mut self) = Self::Detached; } } @@ -698,7 +739,20 @@ impl fmt::Display for LanguageServerId { impl Drop for Subscription { fn drop(&mut self) { - self.notification_handlers.lock().remove(self.method); + match self { + Subscription::Detached => {} + Subscription::Notification { + method, + notification_handlers, + } => { + notification_handlers.lock().remove(method); + } + Subscription::Io { id, io_handlers } => { + if let Some(io_handlers) = io_handlers.upgrade() { + io_handlers.lock().remove(id); + } + } + } } }