From d45b830412a9a3099c77a00bea1f9fc11de57580 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Thu, 24 Oct 2024 14:37:54 -0600 Subject: [PATCH] SSH connection pooling (#19692) Co-Authored-By: Max Closes #ISSUE Release Notes: - SSH Remoting: Reuse connections across hosts --------- Co-authored-by: Max --- .../remote_editing_collaboration_tests.rs | 4 +- crates/recent_projects/src/remote_servers.rs | 11 +- crates/remote/src/ssh_session.rs | 1208 +++++++++-------- .../remote_server/src/remote_editing_tests.rs | 4 +- 4 files changed, 686 insertions(+), 541 deletions(-) diff --git a/crates/collab/src/tests/remote_editing_collaboration_tests.rs b/crates/collab/src/tests/remote_editing_collaboration_tests.rs index 52086c856c..0e13c88d94 100644 --- a/crates/collab/src/tests/remote_editing_collaboration_tests.rs +++ b/crates/collab/src/tests/remote_editing_collaboration_tests.rs @@ -26,7 +26,7 @@ async fn test_sharing_an_ssh_remote_project( .await; // Set up project on remote FS - let (port, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); remote_fs .insert_tree( @@ -67,7 +67,7 @@ async fn test_sharing_an_ssh_remote_project( ) }); - let client_ssh = SshRemoteClient::fake_client(port, cx_a).await; + let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; let (project_a, worktree_id) = client_a .build_ssh_project("/code/project1", client_ssh, cx_a) .await; diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index 7081afc903..d7f3beccb2 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -17,6 +17,7 @@ use gpui::{ use picker::Picker; use project::Project; use remote::SshConnectionOptions; +use remote::SshRemoteClient; use settings::update_settings_file; use settings::Settings; use ui::{ @@ -46,6 +47,7 @@ pub struct RemoteServerProjects { scroll_handle: ScrollHandle, workspace: WeakView, selectable_items: SelectableItemList, + retained_connections: Vec>, } struct CreateRemoteServer { @@ -355,6 +357,7 @@ impl RemoteServerProjects { scroll_handle: ScrollHandle::new(), workspace, selectable_items: Default::default(), + retained_connections: Vec::new(), } } @@ -424,7 +427,7 @@ impl RemoteServerProjects { let address_editor = editor.clone(); let creating = cx.spawn(move |this, mut cx| async move { match connection.await { - Some(_) => this + Some(Some(client)) => this .update(&mut cx, |this, cx| { let _ = this.workspace.update(cx, |workspace, _| { workspace @@ -432,14 +435,14 @@ impl RemoteServerProjects { .telemetry() .report_app_event("create ssh server".to_string()) }); - + this.retained_connections.push(client); this.add_ssh_server(connection_options, cx); this.mode = Mode::default_mode(); this.selectable_items.reset_selection(); cx.notify() }) .log_err(), - None => this + _ => this .update(&mut cx, |this, cx| { address_editor.update(cx, |this, _| { this.set_read_only(false); @@ -1056,7 +1059,7 @@ impl RemoteServerProjects { ); cx.spawn(|mut cx| async move { - if confirmation.await.ok() == Some(1) { + if confirmation.await.ok() == Some(0) { remote_servers .update(&mut cx, |this, cx| { this.delete_ssh_server(index, cx); diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index f3baa5a286..d47e0375ea 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -13,17 +13,18 @@ use futures::{ mpsc::{self, Sender, UnboundedReceiver, UnboundedSender}, oneshot, }, - future::BoxFuture, + future::{BoxFuture, Shared}, select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _, }; use gpui::{ - AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task, - WeakModel, + AppContext, AsyncAppContext, BorrowAppContext, Context, EventEmitter, Global, Model, + ModelContext, SemanticVersion, Task, WeakModel, }; use parking_lot::Mutex; use rpc::{ proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage}, - AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError, + AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet, + RpcError, }; use smol::{ fs, @@ -56,7 +57,7 @@ pub struct SshSocket { socket_path: PathBuf, } -#[derive(Debug, Default, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] pub struct SshConnectionOptions { pub host: String, pub username: Option, @@ -290,7 +291,7 @@ const MAX_RECONNECT_ATTEMPTS: usize = 3; enum State { Connecting, Connected { - ssh_connection: Box, + ssh_connection: Arc, delegate: Arc, multiplex_task: Task>, @@ -299,7 +300,7 @@ enum State { HeartbeatMissed { missed_heartbeats: usize, - ssh_connection: Box, + ssh_connection: Arc, delegate: Arc, multiplex_task: Task>, @@ -307,7 +308,7 @@ enum State { }, Reconnecting, ReconnectFailed { - ssh_connection: Box, + ssh_connection: Arc, delegate: Arc, error: anyhow::Error, @@ -332,7 +333,7 @@ impl fmt::Display for State { } impl State { - fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> { + fn ssh_connection(&self) -> Option<&dyn RemoteConnection> { match self { Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()), Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()), @@ -462,7 +463,7 @@ impl SshRemoteClient { connection_options: SshConnectionOptions, cancellation: oneshot::Receiver<()>, delegate: Arc, - cx: &AppContext, + cx: &mut AppContext, ) -> Task>>> { cx.spawn(|mut cx| async move { let success = Box::pin(async move { @@ -479,17 +480,28 @@ impl SshRemoteClient { state: Arc::new(Mutex::new(Some(State::Connecting))), })?; - let (ssh_connection, io_task) = Self::establish_connection( + let ssh_connection = cx + .update(|cx| { + cx.update_default_global(|pool: &mut ConnectionPool, cx| { + pool.connect(connection_options, &delegate, cx) + }) + })? + .await + .map_err(|e| e.cloned())?; + let remote_binary_path = ssh_connection + .get_remote_binary_path(&delegate, false, &mut cx) + .await?; + + let io_task = ssh_connection.start_proxy( + remote_binary_path, unique_identifier, false, - connection_options, incoming_tx, outgoing_rx, connection_activity_tx, delegate.clone(), &mut cx, - ) - .await?; + ); let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx); @@ -578,7 +590,7 @@ impl SshRemoteClient { } let state = lock.take().unwrap(); - let (attempts, mut ssh_connection, delegate) = match state { + let (attempts, ssh_connection, delegate) = match state { State::Connected { ssh_connection, delegate, @@ -624,7 +636,7 @@ impl SshRemoteClient { log::info!("Trying to reconnect to ssh server... Attempt {}", attempts); - let identifier = self.unique_identifier.clone(); + let unique_identifier = self.unique_identifier.clone(); let client = self.client.clone(); let reconnect_task = cx.spawn(|this, mut cx| async move { macro_rules! failed { @@ -652,19 +664,33 @@ impl SshRemoteClient { let (incoming_tx, incoming_rx) = mpsc::unbounded::(); let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1); - let (ssh_connection, io_task) = match Self::establish_connection( - identifier, - true, - connection_options, - incoming_tx, - outgoing_rx, - connection_activity_tx, - delegate.clone(), - &mut cx, - ) + let (ssh_connection, io_task) = match async { + let ssh_connection = cx + .update_global(|pool: &mut ConnectionPool, cx| { + pool.connect(connection_options, &delegate, cx) + })? + .await + .map_err(|error| error.cloned())?; + + let remote_binary_path = ssh_connection + .get_remote_binary_path(&delegate, true, &mut cx) + .await?; + + let io_task = ssh_connection.start_proxy( + remote_binary_path, + unique_identifier, + true, + incoming_tx, + outgoing_rx, + connection_activity_tx, + delegate.clone(), + &mut cx, + ); + anyhow::Ok((ssh_connection, io_task)) + } .await { - Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process), + Ok((ssh_connection, io_task)) => (ssh_connection, io_task), Err(error) => { failed!(error, attempts, ssh_connection, delegate); } @@ -834,6 +860,563 @@ impl SshRemoteClient { } } + fn monitor( + this: WeakModel, + io_task: Task>, + cx: &AsyncAppContext, + ) -> Task> { + cx.spawn(|mut cx| async move { + let result = io_task.await; + + match result { + Ok(exit_code) => { + if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) { + match error { + ProxyLaunchError::ServerNotRunning => { + log::error!("failed to reconnect because server is not running"); + this.update(&mut cx, |this, cx| { + this.set_state(State::ServerNotRunning, cx); + })?; + } + } + } else if exit_code > 0 { + log::error!("proxy process terminated unexpectedly"); + this.update(&mut cx, |this, cx| { + this.reconnect(cx).ok(); + })?; + } + } + Err(error) => { + log::warn!("ssh io task died with error: {:?}. reconnecting...", error); + this.update(&mut cx, |this, cx| { + this.reconnect(cx).ok(); + })?; + } + } + + Ok(()) + }) + } + + fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool { + self.state.lock().as_ref().map_or(false, check) + } + + fn try_set_state( + &self, + cx: &mut ModelContext, + map: impl FnOnce(&State) -> Option, + ) { + let mut lock = self.state.lock(); + let new_state = lock.as_ref().and_then(map); + + if let Some(new_state) = new_state { + lock.replace(new_state); + cx.notify(); + } + } + + fn set_state(&self, state: State, cx: &mut ModelContext) { + log::info!("setting state to '{}'", &state); + + let is_reconnect_exhausted = state.is_reconnect_exhausted(); + let is_server_not_running = state.is_server_not_running(); + self.state.lock().replace(state); + + if is_reconnect_exhausted || is_server_not_running { + cx.emit(SshRemoteEvent::Disconnected); + } + cx.notify(); + } + + pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { + self.client.subscribe_to_entity(remote_id, entity); + } + + pub fn ssh_args(&self) -> Option> { + self.state + .lock() + .as_ref() + .and_then(|state| state.ssh_connection()) + .map(|ssh_connection| ssh_connection.ssh_args()) + } + + pub fn proto_client(&self) -> AnyProtoClient { + self.client.clone().into() + } + + pub fn connection_string(&self) -> String { + self.connection_options.connection_string() + } + + pub fn connection_options(&self) -> SshConnectionOptions { + self.connection_options.clone() + } + + pub fn connection_state(&self) -> ConnectionState { + self.state + .lock() + .as_ref() + .map(ConnectionState::from) + .unwrap_or(ConnectionState::Disconnected) + } + + pub fn is_disconnected(&self) -> bool { + self.connection_state() == ConnectionState::Disconnected + } + + #[cfg(any(test, feature = "test-support"))] + pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> { + let opts = self.connection_options(); + client_cx.spawn(|cx| async move { + let connection = cx + .update_global(|c: &mut ConnectionPool, _| { + if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) { + c.clone() + } else { + panic!("missing test connection") + } + }) + .unwrap() + .await + .unwrap(); + + connection.simulate_disconnect(&cx); + }) + } + + #[cfg(any(test, feature = "test-support"))] + pub fn fake_server( + client_cx: &mut gpui::TestAppContext, + server_cx: &mut gpui::TestAppContext, + ) -> (SshConnectionOptions, Arc) { + let port = client_cx + .update(|cx| cx.default_global::().connections.len() as u16 + 1); + let opts = SshConnectionOptions { + host: "".to_string(), + port: Some(port), + ..Default::default() + }; + let (outgoing_tx, _) = mpsc::unbounded::(); + let (_, incoming_rx) = mpsc::unbounded::(); + let server_client = + server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server")); + let connection: Arc = Arc::new(fake::FakeRemoteConnection { + connection_options: opts.clone(), + server_cx: fake::SendableCx::new(server_cx.to_async()), + server_channel: server_client.clone(), + }); + + client_cx.update(|cx| { + cx.update_default_global(|c: &mut ConnectionPool, cx| { + c.connections.insert( + opts.clone(), + ConnectionPoolEntry::Connecting( + cx.foreground_executor() + .spawn({ + let connection = connection.clone(); + async move { Ok(connection.clone()) } + }) + .shared(), + ), + ); + }) + }); + + (opts, server_client) + } + + #[cfg(any(test, feature = "test-support"))] + pub async fn fake_client( + opts: SshConnectionOptions, + client_cx: &mut gpui::TestAppContext, + ) -> Model { + let (_tx, rx) = oneshot::channel(); + client_cx + .update(|cx| Self::new("fake".to_string(), opts, rx, Arc::new(fake::Delegate), cx)) + .await + .unwrap() + .unwrap() + } +} + +enum ConnectionPoolEntry { + Connecting(Shared, Arc>>>), + Connected(Weak), +} + +#[derive(Default)] +struct ConnectionPool { + connections: HashMap, +} + +impl Global for ConnectionPool {} + +impl ConnectionPool { + pub fn connect( + &mut self, + opts: SshConnectionOptions, + delegate: &Arc, + cx: &mut AppContext, + ) -> Shared, Arc>>> { + let connection = self.connections.get(&opts); + match connection { + Some(ConnectionPoolEntry::Connecting(task)) => { + let delegate = delegate.clone(); + cx.spawn(|mut cx| async move { + delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx); + }) + .detach(); + return task.clone(); + } + Some(ConnectionPoolEntry::Connected(ssh)) => { + if let Some(ssh) = ssh.upgrade() { + if !ssh.has_been_killed() { + return Task::ready(Ok(ssh)).shared(); + } + } + self.connections.remove(&opts); + } + None => {} + } + + let task = cx + .spawn({ + let opts = opts.clone(); + let delegate = delegate.clone(); + |mut cx| async move { + let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx) + .await + .map(|connection| Arc::new(connection) as Arc); + + cx.update_global(|pool: &mut Self, _| { + debug_assert!(matches!( + pool.connections.get(&opts), + Some(ConnectionPoolEntry::Connecting(_)) + )); + match connection { + Ok(connection) => { + pool.connections.insert( + opts.clone(), + ConnectionPoolEntry::Connected(Arc::downgrade(&connection)), + ); + Ok(connection) + } + Err(error) => { + pool.connections.remove(&opts); + Err(Arc::new(error)) + } + } + })? + } + }) + .shared(); + + self.connections + .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone())); + task + } +} + +impl From for AnyProtoClient { + fn from(client: SshRemoteClient) -> Self { + AnyProtoClient::new(client.client.clone()) + } +} + +#[async_trait(?Send)] +trait RemoteConnection: Send + Sync { + #[allow(clippy::too_many_arguments)] + fn start_proxy( + &self, + remote_binary_path: PathBuf, + unique_identifier: String, + reconnect: bool, + incoming_tx: UnboundedSender, + outgoing_rx: UnboundedReceiver, + connection_activity_tx: Sender<()>, + delegate: Arc, + cx: &mut AsyncAppContext, + ) -> Task>; + async fn get_remote_binary_path( + &self, + delegate: &Arc, + reconnect: bool, + cx: &mut AsyncAppContext, + ) -> Result; + async fn kill(&self) -> Result<()>; + fn has_been_killed(&self) -> bool; + fn ssh_args(&self) -> Vec; + fn connection_options(&self) -> SshConnectionOptions; + + #[cfg(any(test, feature = "test-support"))] + fn simulate_disconnect(&self, _: &AsyncAppContext) {} +} + +struct SshRemoteConnection { + socket: SshSocket, + master_process: Mutex>, + platform: SshPlatform, + _temp_dir: TempDir, +} + +#[async_trait(?Send)] +impl RemoteConnection for SshRemoteConnection { + async fn kill(&self) -> Result<()> { + let Some(mut process) = self.master_process.lock().take() else { + return Ok(()); + }; + process.kill().ok(); + process.status().await?; + Ok(()) + } + + fn has_been_killed(&self) -> bool { + self.master_process.lock().is_none() + } + + fn ssh_args(&self) -> Vec { + self.socket.ssh_args() + } + + fn connection_options(&self) -> SshConnectionOptions { + self.socket.connection_options.clone() + } + + async fn get_remote_binary_path( + &self, + delegate: &Arc, + reconnect: bool, + cx: &mut AsyncAppContext, + ) -> Result { + let platform = self.platform; + let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?; + if !reconnect { + self.ensure_server_binary(&delegate, &remote_binary_path, platform, cx) + .await?; + } + + let socket = self.socket.clone(); + run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; + Ok(remote_binary_path) + } + + fn start_proxy( + &self, + remote_binary_path: PathBuf, + unique_identifier: String, + reconnect: bool, + incoming_tx: UnboundedSender, + outgoing_rx: UnboundedReceiver, + connection_activity_tx: Sender<()>, + delegate: Arc, + cx: &mut AsyncAppContext, + ) -> Task> { + delegate.set_status(Some("Starting proxy"), cx); + + let mut start_proxy_command = format!( + "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}", + std::env::var("RUST_LOG").unwrap_or_default(), + std::env::var("RUST_BACKTRACE").unwrap_or_default(), + remote_binary_path, + unique_identifier, + ); + if reconnect { + start_proxy_command.push_str(" --reconnect"); + } + + let ssh_proxy_process = match self + .socket + .ssh_command(start_proxy_command) + // IMPORTANT: we kill this process when we drop the task that uses it. + .kill_on_drop(true) + .spawn() + { + Ok(process) => process, + Err(error) => { + return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error))) + } + }; + + Self::multiplex( + ssh_proxy_process, + incoming_tx, + outgoing_rx, + connection_activity_tx, + &cx, + ) + } +} + +impl SshRemoteConnection { + #[cfg(not(unix))] + async fn new( + _connection_options: SshConnectionOptions, + _delegate: Arc, + _cx: &mut AsyncAppContext, + ) -> Result { + Err(anyhow!("ssh is not supported on this platform")) + } + + #[cfg(unix)] + async fn new( + connection_options: SshConnectionOptions, + delegate: Arc, + cx: &mut AsyncAppContext, + ) -> Result { + use futures::AsyncWriteExt as _; + use futures::{io::BufReader, AsyncBufReadExt as _}; + use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener}; + use util::ResultExt as _; + + delegate.set_status(Some("Connecting"), cx); + + let url = connection_options.ssh_url(); + let temp_dir = tempfile::Builder::new() + .prefix("zed-ssh-session") + .tempdir()?; + + // Create a domain socket listener to handle requests from the askpass program. + let askpass_socket = temp_dir.path().join("askpass.sock"); + let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>(); + let listener = + UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?; + + let askpass_task = cx.spawn({ + let delegate = delegate.clone(); + |mut cx| async move { + let mut askpass_opened_tx = Some(askpass_opened_tx); + + while let Ok((mut stream, _)) = listener.accept().await { + if let Some(askpass_opened_tx) = askpass_opened_tx.take() { + askpass_opened_tx.send(()).ok(); + } + let mut buffer = Vec::new(); + let mut reader = BufReader::new(&mut stream); + if reader.read_until(b'\0', &mut buffer).await.is_err() { + buffer.clear(); + } + let password_prompt = String::from_utf8_lossy(&buffer); + if let Some(password) = delegate + .ask_password(password_prompt.to_string(), &mut cx) + .await + .context("failed to get ssh password") + .and_then(|p| p) + .log_err() + { + stream.write_all(password.as_bytes()).await.log_err(); + } + } + } + }); + + // Create an askpass script that communicates back to this process. + let askpass_script = format!( + "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n", + askpass_socket = askpass_socket.display(), + print_args = "printf '%s\\0' \"$@\"", + shebang = "#!/bin/sh", + ); + let askpass_script_path = temp_dir.path().join("askpass.sh"); + fs::write(&askpass_script_path, askpass_script).await?; + fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?; + + // Start the master SSH process, which does not do anything except for establish + // the connection and keep it open, allowing other ssh commands to reuse it + // via a control socket. + let socket_path = temp_dir.path().join("ssh.sock"); + let mut master_process = process::Command::new("ssh") + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .env("SSH_ASKPASS_REQUIRE", "force") + .env("SSH_ASKPASS", &askpass_script_path) + .args(connection_options.additional_args().unwrap_or(&Vec::new())) + .args([ + "-N", + "-o", + "ControlPersist=no", + "-o", + "ControlMaster=yes", + "-o", + ]) + .arg(format!("ControlPath={}", socket_path.display())) + .arg(&url) + .kill_on_drop(true) + .spawn()?; + + // Wait for this ssh process to close its stdout, indicating that authentication + // has completed. + let stdout = master_process.stdout.as_mut().unwrap(); + let mut output = Vec::new(); + let connection_timeout = Duration::from_secs(10); + + let result = select_biased! { + _ = askpass_opened_rx.fuse() => { + // If the askpass script has opened, that means the user is typing + // their password, in which case we don't want to timeout anymore, + // since we know a connection has been established. + stdout.read_to_end(&mut output).await?; + Ok(()) + } + result = stdout.read_to_end(&mut output).fuse() => { + result?; + Ok(()) + } + _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => { + Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout)) + } + }; + + if let Err(e) = result { + return Err(e.context("Failed to connect to host")); + } + + drop(askpass_task); + + if master_process.try_status()?.is_some() { + output.clear(); + let mut stderr = master_process.stderr.take().unwrap(); + stderr.read_to_end(&mut output).await?; + + let error_message = format!( + "failed to connect: {}", + String::from_utf8_lossy(&output).trim() + ); + Err(anyhow!(error_message))?; + } + + let socket = SshSocket { + connection_options, + socket_path, + }; + + let os = run_cmd(socket.ssh_command("uname").arg("-s")).await?; + let arch = run_cmd(socket.ssh_command("uname").arg("-m")).await?; + + let os = match os.trim() { + "Darwin" => "macos", + "Linux" => "linux", + _ => Err(anyhow!("unknown uname os {os:?}"))?, + }; + let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") { + "aarch64" + } else if arch.starts_with("x86") || arch.starts_with("i686") { + "x86_64" + } else { + Err(anyhow!("unknown uname architecture {arch:?}"))? + }; + + let platform = SshPlatform { os, arch }; + + Ok(Self { + socket, + master_process: Mutex::new(Some(master_process)), + platform, + _temp_dir: temp_dir, + }) + } + fn multiplex( mut ssh_proxy_process: Child, incoming_tx: UnboundedSender, @@ -936,428 +1519,6 @@ impl SshRemoteClient { }) } - fn monitor( - this: WeakModel, - io_task: Task>, - cx: &AsyncAppContext, - ) -> Task> { - cx.spawn(|mut cx| async move { - let result = io_task.await; - - match result { - Ok(exit_code) => { - if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) { - match error { - ProxyLaunchError::ServerNotRunning => { - log::error!("failed to reconnect because server is not running"); - this.update(&mut cx, |this, cx| { - this.set_state(State::ServerNotRunning, cx); - })?; - } - } - } else if exit_code > 0 { - log::error!("proxy process terminated unexpectedly"); - this.update(&mut cx, |this, cx| { - this.reconnect(cx).ok(); - })?; - } - } - Err(error) => { - log::warn!("ssh io task died with error: {:?}. reconnecting...", error); - this.update(&mut cx, |this, cx| { - this.reconnect(cx).ok(); - })?; - } - } - - Ok(()) - }) - } - - fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool { - self.state.lock().as_ref().map_or(false, check) - } - - fn try_set_state( - &self, - cx: &mut ModelContext, - map: impl FnOnce(&State) -> Option, - ) { - let mut lock = self.state.lock(); - let new_state = lock.as_ref().and_then(map); - - if let Some(new_state) = new_state { - lock.replace(new_state); - cx.notify(); - } - } - - fn set_state(&self, state: State, cx: &mut ModelContext) { - log::info!("setting state to '{}'", &state); - - let is_reconnect_exhausted = state.is_reconnect_exhausted(); - let is_server_not_running = state.is_server_not_running(); - self.state.lock().replace(state); - - if is_reconnect_exhausted || is_server_not_running { - cx.emit(SshRemoteEvent::Disconnected); - } - cx.notify(); - } - - #[allow(clippy::too_many_arguments)] - async fn establish_connection( - unique_identifier: String, - reconnect: bool, - connection_options: SshConnectionOptions, - incoming_tx: UnboundedSender, - outgoing_rx: UnboundedReceiver, - connection_activity_tx: Sender<()>, - delegate: Arc, - cx: &mut AsyncAppContext, - ) -> Result<(Box, Task>)> { - #[cfg(any(test, feature = "test-support"))] - if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) { - let io_task = fake::SshRemoteConnection::multiplex( - fake.connection_options(), - incoming_tx, - outgoing_rx, - connection_activity_tx, - cx, - ) - .await; - return Ok((fake, io_task)); - } - - let ssh_connection = - SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?; - - let platform = ssh_connection.query_platform().await?; - let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?; - if !reconnect { - ssh_connection - .ensure_server_binary(&delegate, &remote_binary_path, platform, cx) - .await?; - } - - let socket = ssh_connection.socket.clone(); - run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; - - delegate.set_status(Some("Starting proxy"), cx); - - let mut start_proxy_command = format!( - "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}", - std::env::var("RUST_LOG").unwrap_or_default(), - std::env::var("RUST_BACKTRACE").unwrap_or_default(), - remote_binary_path, - unique_identifier, - ); - if reconnect { - start_proxy_command.push_str(" --reconnect"); - } - - let ssh_proxy_process = socket - .ssh_command(start_proxy_command) - // IMPORTANT: we kill this process when we drop the task that uses it. - .kill_on_drop(true) - .spawn() - .context("failed to spawn remote server")?; - - let io_task = Self::multiplex( - ssh_proxy_process, - incoming_tx, - outgoing_rx, - connection_activity_tx, - &cx, - ); - - Ok((Box::new(ssh_connection), io_task)) - } - - pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { - self.client.subscribe_to_entity(remote_id, entity); - } - - pub fn ssh_args(&self) -> Option> { - self.state - .lock() - .as_ref() - .and_then(|state| state.ssh_connection()) - .map(|ssh_connection| ssh_connection.ssh_args()) - } - - pub fn proto_client(&self) -> AnyProtoClient { - self.client.clone().into() - } - - pub fn connection_string(&self) -> String { - self.connection_options.connection_string() - } - - pub fn connection_options(&self) -> SshConnectionOptions { - self.connection_options.clone() - } - - pub fn connection_state(&self) -> ConnectionState { - self.state - .lock() - .as_ref() - .map(ConnectionState::from) - .unwrap_or(ConnectionState::Disconnected) - } - - pub fn is_disconnected(&self) -> bool { - self.connection_state() == ConnectionState::Disconnected - } - - #[cfg(any(test, feature = "test-support"))] - pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> { - let port = self.connection_options().port.unwrap(); - client_cx.spawn(|cx| async move { - let (channel, server_cx) = cx - .update_global(|c: &mut fake::ServerConnections, _| c.get(port)) - .unwrap(); - - let (outgoing_tx, _) = mpsc::unbounded::(); - let (_, incoming_rx) = mpsc::unbounded::(); - channel.reconnect(incoming_rx, outgoing_tx, &server_cx); - }) - } - - #[cfg(any(test, feature = "test-support"))] - pub fn fake_server( - client_cx: &mut gpui::TestAppContext, - server_cx: &mut gpui::TestAppContext, - ) -> (u16, Arc) { - use gpui::BorrowAppContext; - let (outgoing_tx, _) = mpsc::unbounded::(); - let (_, incoming_rx) = mpsc::unbounded::(); - let server_client = - server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server")); - let port = client_cx.update(|cx| { - cx.update_default_global(|c: &mut fake::ServerConnections, _| { - c.push(server_client.clone(), server_cx.to_async()) - }) - }); - (port, server_client) - } - - #[cfg(any(test, feature = "test-support"))] - pub async fn fake_client(port: u16, client_cx: &mut gpui::TestAppContext) -> Model { - let (_tx, rx) = oneshot::channel(); - client_cx - .update(|cx| { - Self::new( - "fake".to_string(), - SshConnectionOptions { - host: "".to_string(), - port: Some(port), - ..Default::default() - }, - rx, - Arc::new(fake::Delegate), - cx, - ) - }) - .await - .unwrap() - .unwrap() - } -} - -impl From for AnyProtoClient { - fn from(client: SshRemoteClient) -> Self { - AnyProtoClient::new(client.client.clone()) - } -} - -#[async_trait] -trait SshRemoteProcess: Send + Sync { - async fn kill(&mut self) -> Result<()>; - fn ssh_args(&self) -> Vec; - fn connection_options(&self) -> SshConnectionOptions; -} - -struct SshRemoteConnection { - socket: SshSocket, - master_process: process::Child, - _temp_dir: TempDir, -} - -impl Drop for SshRemoteConnection { - fn drop(&mut self) { - if let Err(error) = self.master_process.kill() { - log::error!("failed to kill SSH master process: {}", error); - } - } -} - -#[async_trait] -impl SshRemoteProcess for SshRemoteConnection { - async fn kill(&mut self) -> Result<()> { - self.master_process.kill()?; - - self.master_process.status().await?; - - Ok(()) - } - - fn ssh_args(&self) -> Vec { - self.socket.ssh_args() - } - - fn connection_options(&self) -> SshConnectionOptions { - self.socket.connection_options.clone() - } -} - -impl SshRemoteConnection { - #[cfg(not(unix))] - async fn new( - _connection_options: SshConnectionOptions, - _delegate: Arc, - _cx: &mut AsyncAppContext, - ) -> Result { - Err(anyhow!("ssh is not supported on this platform")) - } - - #[cfg(unix)] - async fn new( - connection_options: SshConnectionOptions, - delegate: Arc, - cx: &mut AsyncAppContext, - ) -> Result { - use futures::AsyncWriteExt as _; - use futures::{io::BufReader, AsyncBufReadExt as _}; - use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener}; - use util::ResultExt as _; - - delegate.set_status(Some("Connecting"), cx); - - let url = connection_options.ssh_url(); - let temp_dir = tempfile::Builder::new() - .prefix("zed-ssh-session") - .tempdir()?; - - // Create a domain socket listener to handle requests from the askpass program. - let askpass_socket = temp_dir.path().join("askpass.sock"); - let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>(); - let listener = - UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?; - - let askpass_task = cx.spawn({ - let delegate = delegate.clone(); - |mut cx| async move { - let mut askpass_opened_tx = Some(askpass_opened_tx); - - while let Ok((mut stream, _)) = listener.accept().await { - if let Some(askpass_opened_tx) = askpass_opened_tx.take() { - askpass_opened_tx.send(()).ok(); - } - let mut buffer = Vec::new(); - let mut reader = BufReader::new(&mut stream); - if reader.read_until(b'\0', &mut buffer).await.is_err() { - buffer.clear(); - } - let password_prompt = String::from_utf8_lossy(&buffer); - if let Some(password) = delegate - .ask_password(password_prompt.to_string(), &mut cx) - .await - .context("failed to get ssh password") - .and_then(|p| p) - .log_err() - { - stream.write_all(password.as_bytes()).await.log_err(); - } - } - } - }); - - // Create an askpass script that communicates back to this process. - let askpass_script = format!( - "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n", - askpass_socket = askpass_socket.display(), - print_args = "printf '%s\\0' \"$@\"", - shebang = "#!/bin/sh", - ); - let askpass_script_path = temp_dir.path().join("askpass.sh"); - fs::write(&askpass_script_path, askpass_script).await?; - fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?; - - // Start the master SSH process, which does not do anything except for establish - // the connection and keep it open, allowing other ssh commands to reuse it - // via a control socket. - let socket_path = temp_dir.path().join("ssh.sock"); - let mut master_process = process::Command::new("ssh") - .stdin(Stdio::null()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .env("SSH_ASKPASS_REQUIRE", "force") - .env("SSH_ASKPASS", &askpass_script_path) - .args(connection_options.additional_args().unwrap_or(&Vec::new())) - .args([ - "-N", - "-o", - "ControlPersist=no", - "-o", - "ControlMaster=yes", - "-o", - ]) - .arg(format!("ControlPath={}", socket_path.display())) - .arg(&url) - .spawn()?; - - // Wait for this ssh process to close its stdout, indicating that authentication - // has completed. - let stdout = master_process.stdout.as_mut().unwrap(); - let mut output = Vec::new(); - let connection_timeout = Duration::from_secs(10); - - let result = select_biased! { - _ = askpass_opened_rx.fuse() => { - // If the askpass script has opened, that means the user is typing - // their password, in which case we don't want to timeout anymore, - // since we know a connection has been established. - stdout.read_to_end(&mut output).await?; - Ok(()) - } - result = stdout.read_to_end(&mut output).fuse() => { - result?; - Ok(()) - } - _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => { - Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout)) - } - }; - - if let Err(e) = result { - return Err(e.context("Failed to connect to host")); - } - - drop(askpass_task); - - if master_process.try_status()?.is_some() { - output.clear(); - let mut stderr = master_process.stderr.take().unwrap(); - stderr.read_to_end(&mut output).await?; - - let error_message = format!( - "failed to connect: {}", - String::from_utf8_lossy(&output).trim() - ); - Err(anyhow!(error_message))?; - } - - Ok(Self { - socket: SshSocket { - connection_options, - socket_path, - }, - master_process, - _temp_dir: temp_dir, - }) - } - async fn ensure_server_binary( &self, delegate: &Arc, @@ -1621,26 +1782,6 @@ impl SshRemoteConnection { Ok(()) } - async fn query_platform(&self) -> Result { - let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?; - let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?; - - let os = match os.trim() { - "Darwin" => "macos", - "Linux" => "linux", - _ => Err(anyhow!("unknown uname os {os:?}"))?, - }; - let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") { - "aarch64" - } else if arch.starts_with("x86") || arch.starts_with("i686") { - "x86_64" - } else { - Err(anyhow!("unknown uname architecture {arch:?}"))? - }; - - Ok(SshPlatform { os, arch }) - } - async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> { let mut command = process::Command::new("scp"); let output = self @@ -1974,50 +2115,86 @@ mod fake { }, select_biased, FutureExt, SinkExt, StreamExt, }; - use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task}; + use gpui::{AsyncAppContext, SemanticVersion, Task}; use rpc::proto::Envelope; use super::{ - ChannelClient, ServerBinary, SshClientDelegate, SshConnectionOptions, SshPlatform, - SshRemoteProcess, + ChannelClient, RemoteConnection, ServerBinary, SshClientDelegate, SshConnectionOptions, + SshPlatform, }; - pub(super) struct SshRemoteConnection { - connection_options: SshConnectionOptions, + pub(super) struct FakeRemoteConnection { + pub(super) connection_options: SshConnectionOptions, + pub(super) server_channel: Arc, + pub(super) server_cx: SendableCx, } - impl SshRemoteConnection { - pub(super) fn new( - connection_options: &SshConnectionOptions, - ) -> Option> { - if connection_options.host == "" { - return Some(Box::new(Self { - connection_options: connection_options.clone(), - })); - } - return None; + pub(super) struct SendableCx(AsyncAppContext); + // safety: you can only get the other cx on the main thread. + impl SendableCx { + pub(super) fn new(cx: AsyncAppContext) -> Self { + Self(cx) } - pub(super) async fn multiplex( - connection_options: SshConnectionOptions, + fn get(&self, _: &AsyncAppContext) -> AsyncAppContext { + self.0.clone() + } + } + unsafe impl Send for SendableCx {} + unsafe impl Sync for SendableCx {} + + #[async_trait(?Send)] + impl RemoteConnection for FakeRemoteConnection { + async fn kill(&self) -> Result<()> { + Ok(()) + } + + fn has_been_killed(&self) -> bool { + false + } + + fn ssh_args(&self) -> Vec { + Vec::new() + } + + fn connection_options(&self) -> SshConnectionOptions { + self.connection_options.clone() + } + + fn simulate_disconnect(&self, cx: &AsyncAppContext) { + let (outgoing_tx, _) = mpsc::unbounded::(); + let (_, incoming_rx) = mpsc::unbounded::(); + self.server_channel + .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx)); + } + + async fn get_remote_binary_path( + &self, + _delegate: &Arc, + _reconnect: bool, + _cx: &mut AsyncAppContext, + ) -> Result { + Ok(PathBuf::new()) + } + + fn start_proxy( + &self, + _remote_binary_path: PathBuf, + _unique_identifier: String, + _reconnect: bool, mut client_incoming_tx: mpsc::UnboundedSender, mut client_outgoing_rx: mpsc::UnboundedReceiver, mut connection_activity_tx: Sender<()>, + _delegate: Arc, cx: &mut AsyncAppContext, ) -> Task> { let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::(); let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::(); - let (channel, server_cx) = cx - .update(|cx| { - cx.update_global(|conns: &mut ServerConnections, _| { - conns.get(connection_options.port.unwrap()) - }) - }) - .unwrap(); - channel.reconnect(server_incoming_rx, server_outgoing_tx, &server_cx); - - // send to proxy_tx to get to the server. - // receive from + self.server_channel.reconnect( + server_incoming_rx, + server_outgoing_tx, + &self.server_cx.get(cx), + ); cx.background_executor().spawn(async move { loop { @@ -2041,39 +2218,6 @@ mod fake { } } - #[async_trait] - impl SshRemoteProcess for SshRemoteConnection { - async fn kill(&mut self) -> Result<()> { - Ok(()) - } - - fn ssh_args(&self) -> Vec { - Vec::new() - } - - fn connection_options(&self) -> SshConnectionOptions { - self.connection_options.clone() - } - } - - #[derive(Default)] - pub(super) struct ServerConnections(Vec<(Arc, AsyncAppContext)>); - impl Global for ServerConnections {} - - impl ServerConnections { - pub(super) fn push(&mut self, server: Arc, cx: AsyncAppContext) -> u16 { - self.0.push((server.clone(), cx)); - self.0.len() as u16 - 1 - } - - pub(super) fn get(&mut self, port: u16) -> (Arc, AsyncAppContext) { - self.0 - .get(port as usize) - .expect("no fake server for port") - .clone() - } - } - pub(super) struct Delegate; impl SshClientDelegate for Delegate { @@ -2099,8 +2243,6 @@ mod fake { unreachable!() } - fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) { - unreachable!() - } + fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {} } } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 32333def7f..f7420ef5b0 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -702,7 +702,7 @@ async fn init_test( ) -> (Model, Model, Arc) { init_logger(); - let (forwarder, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx); + let (opts, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx); let fs = FakeFs::new(server_cx.executor()); fs.insert_tree( "/code", @@ -744,7 +744,7 @@ async fn init_test( ) }); - let ssh = SshRemoteClient::fake_client(forwarder, cx).await; + let ssh = SshRemoteClient::fake_client(opts, cx).await; let project = build_project(ssh, cx); project .update(cx, {