diff --git a/Cargo.lock b/Cargo.lock index 0b0ca33a5..8ad4c51c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -724,7 +724,9 @@ dependencies = [ "insta", "itertools", "jujutsu-lib", + "libc", "maplit", + "once_cell", "pest", "pest_derive", "predicates", @@ -732,6 +734,7 @@ dependencies = [ "regex", "rpassword", "serde", + "slab", "tempfile", "test-case", "textwrap 0.16.0", @@ -780,9 +783,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.132" +version = "0.2.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8371e4e5341c3a96db127eb2465ac681ced4c433e01dd0e938adbef26ba93ba5" +checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" [[package]] name = "libgit2-sys" @@ -1450,6 +1453,15 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62ac7f900db32bf3fd12e0117dd3dc4da74bc52ebaac97f39668446d89694803" +[[package]] +name = "slab" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 179792213..c8ecd6a4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ git2 = "0.15.0" hex = "0.4.3" itertools = "0.10.5" jujutsu-lib = { version = "=0.5.1", path = "lib"} +once_cell = "1.15.0" maplit = "1.0.2" pest = "2.4.0" pest_derive = "2.4" @@ -53,10 +54,14 @@ rand = "0.8.5" regex = "1.6.0" rpassword = "7.1.0" serde = { version = "1.0", features = ["derive"] } +slab = "0.4.7" tempfile = "3.3.0" textwrap = "0.16.0" thiserror = "1.0.37" +[target.'cfg(unix)'.dependencies] +libc = { version = "0.2.137" } + [dev-dependencies] assert_cmd = "2.0.5" criterion = "0.4.0" diff --git a/examples/custom-backend/main.rs b/examples/custom-backend/main.rs index dac126c3a..9c911b582 100644 --- a/examples/custom-backend/main.rs +++ b/examples/custom-backend/main.rs @@ -61,6 +61,7 @@ fn run(ui: &mut Ui) -> Result<(), CommandError> { } fn main() { + jujutsu::cleanup_guard::init(); let (mut ui, result) = create_ui(); let result = result.and_then(|()| run(&mut ui)); let exit_code = handle_command_result(&mut ui, result); diff --git a/examples/custom-command/main.rs b/examples/custom-command/main.rs index 3f898cb9d..ccd136b19 100644 --- a/examples/custom-command/main.rs +++ b/examples/custom-command/main.rs @@ -59,6 +59,7 @@ fn run(ui: &mut Ui) -> Result<(), CommandError> { } fn main() { + jujutsu::cleanup_guard::init(); let (mut ui, result) = create_ui(); let result = result.and_then(|()| run(&mut ui)); let exit_code = handle_command_result(&mut ui, result); diff --git a/src/cleanup_guard.rs b/src/cleanup_guard.rs new file mode 100644 index 000000000..841d11b36 --- /dev/null +++ b/src/cleanup_guard.rs @@ -0,0 +1,117 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Mutex, Once}; +use std::{io, thread}; + +use once_cell::sync::Lazy; +use slab::Slab; + +/// Contains the callbacks passed to currently-live [`CleanupGuard`]s +static LIVE_GUARDS: Lazy> = Lazy::new(|| Mutex::new(Slab::new())); + +type GuardTable = Slab>; + +/// Prepare to run [`CleanupGuard`]s on `SIGINT`/`SIGTERM` +pub fn init() { + // Safety: `` ensures at most one call + static CALLED: Once = Once::new(); + CALLED.call_once(|| { + if let Err(ref e) = unsafe { platform::init() } { + eprintln!("couldn't register signal handler: {}", e); + } + }); +} + +/// A drop guard that also runs on `SIGINT`/`SIGTERM` +pub struct CleanupGuard { + slot: usize, +} + +impl CleanupGuard { + /// Invoke `f` when dropped or killed by `SIGINT`/`SIGTERM` + pub fn new(f: F) -> Self { + let guards = &mut *LIVE_GUARDS.lock().unwrap(); + Self { + slot: guards.insert(Box::new(f)), + } + } +} + +impl Drop for CleanupGuard { + fn drop(&mut self) { + let guards = &mut *LIVE_GUARDS.lock().unwrap(); + let f = guards.remove(self.slot); + f(); + } +} + +// Invoked on a background thread. Process exits after this returns. +fn on_signal(guards: &mut GuardTable) { + for guard in guards.drain() { + guard(); + } +} + +#[cfg(unix)] +mod platform { + use std::os::unix::io::{IntoRawFd as _, RawFd}; + use std::os::unix::net::UnixDatagram; + use std::panic::AssertUnwindSafe; + + use libc::{c_int, SIGINT, SIGTERM}; + + use super::*; + + /// Safety: Must be called at most once + pub unsafe fn init() -> io::Result<()> { + let (send, recv) = UnixDatagram::pair()?; + + // Spawn a background thread that waits for the signal handler to write a signal + // into it + thread::spawn(move || { + let mut buf = [0]; + let signal = match recv.recv(&mut buf) { + Ok(1) => c_int::from(buf[0]), + _ => unreachable!(), + }; + // We must hold the lock for the remainder of the process's lifetime to avoid a + // race where a guard is created between `on_signal` and `raise`. + let guards = &mut *LIVE_GUARDS.lock().unwrap(); + if let Err(e) = std::panic::catch_unwind(AssertUnwindSafe(|| on_signal(guards))) { + match e.downcast::() { + Ok(s) => eprintln!("signal handler panicked: {}", s), + Err(_) => eprintln!("signal handler panicked"), + } + } + libc::signal(signal, libc::SIG_DFL); + libc::raise(signal); + }); + + SIGNAL_SEND = send.into_raw_fd(); + libc::signal(SIGINT, handler as libc::sighandler_t); + libc::signal(SIGTERM, handler as libc::sighandler_t); + Ok(()) + } + + unsafe extern "C" fn handler(signal: c_int) { + // Treat the second signal as instantly fatal. + static SIGNALED: AtomicBool = AtomicBool::new(false); + if SIGNALED.swap(true, Ordering::Relaxed) { + libc::signal(signal, libc::SIG_DFL); + libc::raise(signal); + } + + let buf = [signal as u8]; + libc::write(SIGNAL_SEND, buf.as_ptr().cast(), buf.len()); + } + + static mut SIGNAL_SEND: RawFd = 0; +} + +#[cfg(not(unix))] +mod platform { + use super::*; + + pub fn init() -> io::Result<()> { + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 2a9c94b4a..b8d0872d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ #![deny(unused_must_use)] +pub mod cleanup_guard; pub mod cli_util; pub mod commands; pub mod config; diff --git a/src/main.rs b/src/main.rs index 19b84d04b..242b7e170 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,6 +23,7 @@ fn run(ui: &mut Ui) -> Result<(), CommandError> { } fn main() { + jujutsu::cleanup_guard::init(); let (mut ui, result) = create_ui(); let result = result.and_then(|()| run(&mut ui)); let exit_code = handle_command_result(&mut ui, result);