From 9824e40878e0ba7ccaa899e8e87c3dd49eb01e26 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Wed, 5 Jun 2024 23:06:44 +0200 Subject: [PATCH] lsp: Handle responses in background thread (#12640) Release Notes: - Improved performance when handling large responses from language servers --------- Co-authored-by: Piotr --- crates/lsp/src/input_handler.rs | 159 +++++++++++++++++++++++++++++++ crates/lsp/src/lsp.rs | 161 ++++++-------------------------- 2 files changed, 189 insertions(+), 131 deletions(-) create mode 100644 crates/lsp/src/input_handler.rs diff --git a/crates/lsp/src/input_handler.rs b/crates/lsp/src/input_handler.rs new file mode 100644 index 0000000000..9b599c6afd --- /dev/null +++ b/crates/lsp/src/input_handler.rs @@ -0,0 +1,159 @@ +use std::str; +use std::sync::Arc; + +use anyhow::{anyhow, Result}; +use collections::HashMap; +use futures::{ + channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}, + AsyncBufReadExt, AsyncRead, AsyncReadExt as _, +}; +use gpui::{BackgroundExecutor, Task}; +use log::warn; +use parking_lot::Mutex; +use smol::io::BufReader; + +use crate::{ + AnyNotification, AnyResponse, IoHandler, IoKind, RequestId, ResponseHandler, CONTENT_LEN_HEADER, +}; + +const HEADER_DELIMITER: &'static [u8; 4] = b"\r\n\r\n"; +/// Handler for stdout of language server. +pub struct LspStdoutHandler { + pub(super) loop_handle: Task>, + pub(super) notifications_channel: UnboundedReceiver, +} + +pub(self) async fn read_headers( + reader: &mut BufReader, + buffer: &mut Vec, +) -> Result<()> +where + Stdout: AsyncRead + Unpin + Send + 'static, +{ + loop { + if buffer.len() >= HEADER_DELIMITER.len() + && buffer[(buffer.len() - HEADER_DELIMITER.len())..] == HEADER_DELIMITER[..] + { + return Ok(()); + } + + if reader.read_until(b'\n', buffer).await? == 0 { + return Err(anyhow!("cannot read LSP message headers")); + } + } +} + +impl LspStdoutHandler { + pub fn new( + stdout: Input, + response_handlers: Arc>>>, + io_handlers: Arc>>, + cx: BackgroundExecutor, + ) -> Self + where + Input: AsyncRead + Unpin + Send + 'static, + { + let (tx, notifications_channel) = unbounded(); + let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers)); + Self { + loop_handle, + notifications_channel, + } + } + + async fn handler( + stdout: Input, + notifications_sender: UnboundedSender, + response_handlers: Arc>>>, + io_handlers: Arc>>, + ) -> anyhow::Result<()> + where + Input: AsyncRead + Unpin + Send + 'static, + { + let mut stdout = BufReader::new(stdout); + + let mut buffer = Vec::new(); + + loop { + buffer.clear(); + + read_headers(&mut stdout, &mut buffer).await?; + + let headers = std::str::from_utf8(&buffer)?; + + let message_len = headers + .split('\n') + .find(|line| line.starts_with(CONTENT_LEN_HEADER)) + .and_then(|line| line.strip_prefix(CONTENT_LEN_HEADER)) + .ok_or_else(|| anyhow!("invalid LSP message header {headers:?}"))? + .trim_end() + .parse()?; + + buffer.resize(message_len, 0); + stdout.read_exact(&mut buffer).await?; + + if let Ok(message) = str::from_utf8(&buffer) { + log::trace!("incoming message: {message}"); + for handler in io_handlers.lock().values_mut() { + handler(IoKind::StdOut, message); + } + } + + if let Ok(msg) = serde_json::from_slice::(&buffer) { + notifications_sender.unbounded_send(msg)?; + } else if let Ok(AnyResponse { + id, error, result, .. + }) = serde_json::from_slice(&buffer) + { + let mut response_handlers = response_handlers.lock(); + if let Some(handler) = response_handlers + .as_mut() + .and_then(|handlers| handlers.remove(&id)) + { + drop(response_handlers); + if let Some(error) = error { + handler(Err(error)); + } else if let Some(result) = result { + handler(Ok(result.get().into())); + } else { + handler(Ok("null".into())); + } + } + } else { + warn!( + "failed to deserialize LSP message:\n{}", + std::str::from_utf8(&buffer)? + ); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[gpui::test] + async fn test_read_headers() { + let mut buf = Vec::new(); + let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n\r\n" as &[u8]); + read_headers(&mut reader, &mut buf).await.unwrap(); + assert_eq!(buf, b"Content-Length: 123\r\n\r\n"); + + let mut buf = Vec::new(); + let mut reader = smol::io::BufReader::new(b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n{\"somecontent\":123}" as &[u8]); + read_headers(&mut reader, &mut buf).await.unwrap(); + assert_eq!( + buf, + b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n" + ); + + let mut buf = Vec::new(); + let mut reader = smol::io::BufReader::new(b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n{\"somecontent\":true}" as &[u8]); + read_headers(&mut reader, &mut buf).await.unwrap(); + assert_eq!( + buf, + b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n" + ); + } +} diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index 73a7129ba9..d5051e4766 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -1,4 +1,5 @@ -use log::warn; +mod input_handler; + pub use lsp_types::request::*; pub use lsp_types::*; @@ -12,7 +13,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::{json, value::RawValue, Value}; use smol::{ channel, - io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, process::{self, Child}, }; @@ -25,7 +26,6 @@ use std::{ io::Write, path::PathBuf, pin::Pin, - str::{self, FromStr as _}, sync::{ atomic::{AtomicI32, Ordering::SeqCst}, Arc, Weak, @@ -36,13 +36,13 @@ use std::{ use std::{path::Path, process::Stdio}; use util::{ResultExt, TryFutureExt}; -const HEADER_DELIMITER: &'static [u8; 4] = b"\r\n\r\n"; const JSON_RPC_VERSION: &str = "2.0"; const CONTENT_LEN_HEADER: &str = "Content-Length: "; + const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2); const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); -type NotificationHandler = Box, &str, AsyncAppContext)>; +type NotificationHandler = Box, Value, AsyncAppContext)>; type ResponseHandler = Box)>; type IoHandler = Box; @@ -164,13 +164,12 @@ struct Notification<'a, T> { /// Language server RPC notification message before it is deserialized into a concrete type. #[derive(Debug, Clone, Deserialize)] -struct AnyNotification<'a> { +struct AnyNotification { #[serde(default)] id: Option, - #[serde(borrow)] - method: &'a str, - #[serde(borrow, default)] - params: Option<&'a RawValue>, + method: String, + #[serde(default)] + params: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -297,13 +296,7 @@ impl LanguageServer { "Language server with id {} sent unhandled notification {}:\n{}", server_id, notification.method, - serde_json::to_string_pretty( - ¬ification - .params - .and_then(|params| Value::from_str(params.get()).ok()) - .unwrap_or(Value::Null) - ) - .unwrap(), + serde_json::to_string_pretty(¬ification.params).unwrap(), ); }, ); @@ -418,79 +411,36 @@ impl LanguageServer { Stdout: AsyncRead + Unpin + Send + 'static, F: FnMut(AnyNotification) + 'static + Send, { - let mut stdout = BufReader::new(stdout); + use smol::stream::StreamExt; + let stdout = BufReader::new(stdout); let _clear_response_handlers = util::defer({ let response_handlers = response_handlers.clone(); move || { response_handlers.lock().take(); } }); - let mut buffer = Vec::new(); - loop { - buffer.clear(); + let mut input_handler = input_handler::LspStdoutHandler::new( + stdout, + response_handlers, + io_handlers, + cx.background_executor().clone(), + ); - read_headers(&mut stdout, &mut buffer).await?; - - let headers = std::str::from_utf8(&buffer)?; - - let message_len = headers - .split('\n') - .find(|line| line.starts_with(CONTENT_LEN_HEADER)) - .and_then(|line| line.strip_prefix(CONTENT_LEN_HEADER)) - .ok_or_else(|| anyhow!("invalid LSP message header {headers:?}"))? - .trim_end() - .parse()?; - - buffer.resize(message_len, 0); - stdout.read_exact(&mut buffer).await?; - - if let Ok(message) = str::from_utf8(&buffer) { - log::trace!("incoming message: {message}"); - for handler in io_handlers.lock().values_mut() { - handler(IoKind::StdOut, message); - } - } - - if let Ok(msg) = serde_json::from_slice::(&buffer) { + while let Some(msg) = input_handler.notifications_channel.next().await { + { let mut notification_handlers = notification_handlers.lock(); - if let Some(handler) = notification_handlers.get_mut(msg.method) { - handler( - msg.id, - msg.params.map(|params| params.get()).unwrap_or("null"), - cx.clone(), - ); + if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) { + handler(msg.id, msg.params.unwrap_or(Value::Null), cx.clone()); } else { drop(notification_handlers); on_unhandled_notification(msg); } - } else if let Ok(AnyResponse { - id, error, result, .. - }) = serde_json::from_slice(&buffer) - { - let mut response_handlers = response_handlers.lock(); - if let Some(handler) = response_handlers - .as_mut() - .and_then(|handlers| handlers.remove(&id)) - { - drop(response_handlers); - if let Some(error) = error { - handler(Err(error)); - } else if let Some(result) = result { - handler(Ok(result.get().into())); - } else { - handler(Ok("null".into())); - } - } - } else { - warn!( - "failed to deserialize LSP message:\n{}", - std::str::from_utf8(&buffer)? - ); } - // Don't starve the main thread when receiving lots of messages at once. + // Don't starve the main thread when receiving lots of notifications at once. smol::future::yield_now().await; } + input_handler.loop_handle.await } async fn handle_stderr( @@ -512,7 +462,7 @@ impl LanguageServer { return Ok(()); } - if let Ok(message) = str::from_utf8(&buffer) { + if let Ok(message) = std::str::from_utf8(&buffer) { log::trace!("incoming stderr message:{message}"); for handler in io_handlers.lock().values_mut() { handler(IoKind::StdErr, message); @@ -850,7 +800,7 @@ impl LanguageServer { let prev_handler = self.notification_handlers.lock().insert( method, Box::new(move |_, params, cx| { - if let Some(params) = serde_json::from_str(params).log_err() { + if let Some(params) = serde_json::from_value(params).log_err() { f(params, cx); } }), @@ -878,7 +828,7 @@ impl LanguageServer { method, Box::new(move |id, params, cx| { if let Some(id) = id { - match serde_json::from_str(params) { + match serde_json::from_value(params) { Ok(params) => { let response = f(params, cx.clone()); cx.foreground_executor() @@ -910,12 +860,7 @@ impl LanguageServer { } Err(error) => { - log::error!( - "error deserializing {} request: {:?}, message: {:?}", - method, - error, - params - ); + log::error!("error deserializing {} request: {:?}", method, error); let response = AnyResponse { jsonrpc: JSON_RPC_VERSION, id, @@ -1202,10 +1147,7 @@ impl FakeLanguageServer { notifications_tx .try_send(( msg.method.to_string(), - msg.params - .map(|raw_value| raw_value.get()) - .unwrap_or("null") - .to_string(), + msg.params.unwrap_or(Value::Null).to_string(), )) .ok(); }, @@ -1372,30 +1314,11 @@ impl FakeLanguageServer { } } -pub(self) async fn read_headers( - reader: &mut BufReader, - buffer: &mut Vec, -) -> Result<()> -where - Stdout: AsyncRead + Unpin + Send + 'static, -{ - loop { - if buffer.len() >= HEADER_DELIMITER.len() - && buffer[(buffer.len() - HEADER_DELIMITER.len())..] == HEADER_DELIMITER[..] - { - return Ok(()); - } - - if reader.read_until(b'\n', buffer).await? == 0 { - return Err(anyhow!("cannot read LSP message headers")); - } - } -} - #[cfg(test)] mod tests { use super::*; use gpui::TestAppContext; + use std::str::FromStr; #[ctor::ctor] fn init_logger() { @@ -1475,30 +1398,6 @@ mod tests { fake.receive_notification::().await; } - #[gpui::test] - async fn test_read_headers() { - let mut buf = Vec::new(); - let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n\r\n" as &[u8]); - read_headers(&mut reader, &mut buf).await.unwrap(); - assert_eq!(buf, b"Content-Length: 123\r\n\r\n"); - - let mut buf = Vec::new(); - let mut reader = smol::io::BufReader::new(b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n{\"somecontent\":123}" as &[u8]); - read_headers(&mut reader, &mut buf).await.unwrap(); - assert_eq!( - buf, - b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n" - ); - - let mut buf = Vec::new(); - let mut reader = smol::io::BufReader::new(b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n{\"somecontent\":true}" as &[u8]); - read_headers(&mut reader, &mut buf).await.unwrap(); - assert_eq!( - buf, - b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n" - ); - } - #[gpui::test] fn test_deserialize_string_digit_id() { let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;