diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 0a1cd00992..26ef8626ec 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -171,28 +171,6 @@ async fn run_cmd(command: &mut process::Command) -> Result { )) } } -#[cfg(unix)] -async fn read_with_timeout( - stdout: &mut process::ChildStdout, - timeout: Duration, - output: &mut Vec, -) -> Result<(), std::io::Error> { - smol::future::or( - async { - stdout.read_to_end(output).await?; - Ok::<_, std::io::Error>(()) - }, - async { - smol::Timer::after(timeout).await; - - Err(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "Read operation timed out", - )) - }, - ) - .await -} struct ChannelForwarder { quit_tx: UnboundedSender<()>, @@ -725,13 +703,19 @@ impl SshRemoteConnection { // Create a domain socket listener to handle requests from the askpass program. let askpass_socket = temp_dir.path().join("askpass.sock"); + let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>(); let listener = UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?; let askpass_task = cx.spawn({ let delegate = delegate.clone(); |mut cx| async move { + let mut askpass_opened_tx = Some(askpass_opened_tx); + while let Ok((mut stream, _)) = listener.accept().await { + if let Some(askpass_opened_tx) = askpass_opened_tx.take() { + askpass_opened_tx.send(()).ok(); + } let mut buffer = Vec::new(); let mut reader = BufReader::new(&mut stream); if reader.read_until(b'\0', &mut buffer).await.is_err() { @@ -782,19 +766,28 @@ impl SshRemoteConnection { let stdout = master_process.stdout.as_mut().unwrap(); let mut output = Vec::new(); let connection_timeout = Duration::from_secs(10); - let result = read_with_timeout(stdout, connection_timeout, &mut output).await; - if let Err(e) = result { - let error_message = if e.kind() == std::io::ErrorKind::TimedOut { - format!( - "Failed to connect to host. Timed out after {:?}.", - connection_timeout - ) - } else { - format!("Failed to connect to host: {}.", e) - }; + let result = select_biased! { + _ = askpass_opened_rx.fuse() => { + // If the askpass script has opened, that means the user is typing + // their password, in which case we don't want to timeout anymore, + // since we know a connection has been established. + stdout.read_to_end(&mut output).await?; + Ok(()) + } + result = stdout.read_to_end(&mut output).fuse() => { + result?; + Ok(()) + } + _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => { + Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout)) + } + }; + + if let Err(e) = result { + let error_message = format!("Failed to connect to host: {}.", e); delegate.set_error(error_message, cx); - return Err(e.into()); + return Err(e); } drop(askpass_task); @@ -803,10 +796,10 @@ impl SshRemoteConnection { output.clear(); let mut stderr = master_process.stderr.take().unwrap(); stderr.read_to_end(&mut output).await?; - Err(anyhow!( - "failed to connect: {}", - String::from_utf8_lossy(&output) - ))?; + + let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output)); + delegate.set_error(error_message.clone(), cx); + Err(anyhow!(error_message))?; } Ok(Self {