ssh: Do not cancel connection process if user is typing password (#18812)

Previously, the connection process would be cancelled after 10 seconds,
even if the connection was established successfully but the user was
still typing in a password.
We know recognize when the user is prompted for a password, and cancel
the timeout task.

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-10-07 15:53:32 +02:00 committed by GitHub
parent 65c9b15796
commit a3b63448df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -171,28 +171,6 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
)) ))
} }
} }
#[cfg(unix)]
async fn read_with_timeout(
stdout: &mut process::ChildStdout,
timeout: Duration,
output: &mut Vec<u8>,
) -> Result<(), std::io::Error> {
smol::future::or(
async {
stdout.read_to_end(output).await?;
Ok::<_, std::io::Error>(())
},
async {
smol::Timer::after(timeout).await;
Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Read operation timed out",
))
},
)
.await
}
struct ChannelForwarder { struct ChannelForwarder {
quit_tx: UnboundedSender<()>, quit_tx: UnboundedSender<()>,
@ -725,13 +703,19 @@ impl SshRemoteConnection {
// Create a domain socket listener to handle requests from the askpass program. // Create a domain socket listener to handle requests from the askpass program.
let askpass_socket = temp_dir.path().join("askpass.sock"); let askpass_socket = temp_dir.path().join("askpass.sock");
let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
let listener = let listener =
UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?; UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
let askpass_task = cx.spawn({ let askpass_task = cx.spawn({
let delegate = delegate.clone(); let delegate = delegate.clone();
|mut cx| async move { |mut cx| async move {
let mut askpass_opened_tx = Some(askpass_opened_tx);
while let Ok((mut stream, _)) = listener.accept().await { while let Ok((mut stream, _)) = listener.accept().await {
if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
askpass_opened_tx.send(()).ok();
}
let mut buffer = Vec::new(); let mut buffer = Vec::new();
let mut reader = BufReader::new(&mut stream); let mut reader = BufReader::new(&mut stream);
if reader.read_until(b'\0', &mut buffer).await.is_err() { if reader.read_until(b'\0', &mut buffer).await.is_err() {
@ -782,19 +766,28 @@ impl SshRemoteConnection {
let stdout = master_process.stdout.as_mut().unwrap(); let stdout = master_process.stdout.as_mut().unwrap();
let mut output = Vec::new(); let mut output = Vec::new();
let connection_timeout = Duration::from_secs(10); let connection_timeout = Duration::from_secs(10);
let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
if let Err(e) = result { let result = select_biased! {
let error_message = if e.kind() == std::io::ErrorKind::TimedOut { _ = askpass_opened_rx.fuse() => {
format!( // If the askpass script has opened, that means the user is typing
"Failed to connect to host. Timed out after {:?}.", // their password, in which case we don't want to timeout anymore,
connection_timeout // since we know a connection has been established.
) stdout.read_to_end(&mut output).await?;
} else { Ok(())
format!("Failed to connect to host: {}.", e) }
result = stdout.read_to_end(&mut output).fuse() => {
result?;
Ok(())
}
_ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
}
}; };
if let Err(e) = result {
let error_message = format!("Failed to connect to host: {}.", e);
delegate.set_error(error_message, cx); delegate.set_error(error_message, cx);
return Err(e.into()); return Err(e);
} }
drop(askpass_task); drop(askpass_task);
@ -803,10 +796,10 @@ impl SshRemoteConnection {
output.clear(); output.clear();
let mut stderr = master_process.stderr.take().unwrap(); let mut stderr = master_process.stderr.take().unwrap();
stderr.read_to_end(&mut output).await?; stderr.read_to_end(&mut output).await?;
Err(anyhow!(
"failed to connect: {}", let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
String::from_utf8_lossy(&output) delegate.set_error(error_message.clone(), cx);
))?; Err(anyhow!(error_message))?;
} }
Ok(Self { Ok(Self {