sys_util: Add scoped_signal_handler.

This adds a scoped signal handler struct that performs some basic checks
to avoid accidental replacement of an existing handler and cleanup when
the handler goes out of scope.

Also, adds scoped_signal_handler::wait_for_interrupt() which blocks
until SIGINT is received (e.g. via Ctrl-C).

BUG=None
TEST=cargo test --workspace -- --test-threads=1

Change-Id: I5e649eac3d3ee0b842b200fc553acac44b2dfe94
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2807306
Reviewed-by: Zach Reizner <zachr@chromium.org>
Reviewed-by: Chirantan Ekbote <chirantan@chromium.org>
Tested-by: Allen Webb <allenwebb@google.com>
Tested-by: kokoro <noreply+kokoro@google.com>
Commit-Queue: Allen Webb <allenwebb@google.com>
This commit is contained in:
Allen Webb 2021-04-06 10:20:57 -05:00 committed by Commit Bot
parent a09d09c6dc
commit 5e185834fc
3 changed files with 523 additions and 3 deletions

View file

@ -36,6 +36,7 @@ mod poll;
mod priority;
mod raw_fd;
pub mod sched;
pub mod scoped_signal_handler;
mod seek_hole;
mod shm;
pub mod signal;
@ -62,6 +63,7 @@ pub use crate::poll::*;
pub use crate::priority::*;
pub use crate::raw_fd::*;
pub use crate::sched::*;
pub use crate::scoped_signal_handler::*;
pub use crate::shm::*;
pub use crate::signal::*;
pub use crate::signalfd::*;

View file

@ -0,0 +1,421 @@
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Provides a struct for registering signal handlers that get cleared on drop.
use std::convert::TryFrom;
use std::fmt::{self, Display};
use std::io::{Cursor, Write};
use std::panic::catch_unwind;
use std::result;
use libc::{c_int, c_void, STDERR_FILENO};
use crate::errno;
use crate::signal::{
clear_signal_handler, has_default_signal_handler, register_signal_handler, wait_for_signal,
Signal,
};
#[derive(Debug)]
pub enum Error {
/// Sigaction failed.
Sigaction(Signal, errno::Error),
/// Failed to check if signal has the default signal handler.
HasDefaultSignalHandler(Signal, errno::Error),
/// Failed to register a signal handler.
RegisterSignalHandler(Signal, errno::Error),
/// Signal already has a handler.
HandlerAlreadySet(Signal),
/// Already waiting for interrupt.
AlreadyWaiting,
/// Failed to wait for signal.
WaitForSignal(errno::Error),
}
impl Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use self::Error::*;
match self {
Sigaction(s, e) => write!(f, "sigaction failed for {0:?}: {1}", s, e),
HasDefaultSignalHandler(s, e) => {
write!(f, "failed to check the signal handler for {0:?}: {1}", s, e)
}
RegisterSignalHandler(s, e) => write!(
f,
"failed to register a signal handler for {0:?}: {1}",
s, e
),
HandlerAlreadySet(s) => write!(f, "signal handler already set for {0:?}", s),
AlreadyWaiting => write!(f, "already waiting for interrupt."),
WaitForSignal(e) => write!(f, "wait_for_signal failed: {0}", e),
}
}
}
pub type Result<T> = result::Result<T, Error>;
/// The interface used by Scoped Signal handler.
///
/// # Safety
/// The implementation of handle_signal needs to be async signal-safe.
///
/// NOTE: panics are caught when possible because a panic inside ffi is undefined behavior.
pub unsafe trait SignalHandler {
/// A function that is called to handle the passed signal.
fn handle_signal(signal: Signal);
}
/// Wrap the handler with an extern "C" function.
extern "C" fn call_handler<H: SignalHandler>(signum: c_int) {
// Make an effort to surface an error.
if catch_unwind(|| H::handle_signal(Signal::try_from(signum).unwrap())).is_err() {
// Note the following cannot be used:
// eprintln! - uses std::io which has locks that may be held.
// format! - uses the allocator which enforces mutual exclusion.
// Get the debug representation of signum.
let signal: Signal;
let signal_debug: &dyn fmt::Debug = match Signal::try_from(signum) {
Ok(s) => {
signal = s;
&signal as &dyn fmt::Debug
}
Err(_) => &signum as &dyn fmt::Debug,
};
// Buffer the output, so a single call to write can be used.
// The message accounts for 29 chars, that leaves 35 for the string representation of the
// signal which is more than enough.
let mut buffer = [0u8; 64];
let mut cursor = Cursor::new(buffer.as_mut());
if writeln!(cursor, "signal handler got error for: {:?}", signal_debug).is_ok() {
let len = cursor.position() as usize;
// Safe in the sense that buffer is owned and the length is checked. This may print in
// the middle of an existing write, but that is considered better than dropping the
// error.
unsafe {
libc::write(
STDERR_FILENO,
cursor.get_ref().as_ptr() as *const c_void,
len,
)
};
} else {
// This should never happen, but write an error message just in case.
const ERROR_DROPPED: &str = "Error dropped by signal handler.";
let bytes = ERROR_DROPPED.as_bytes();
unsafe { libc::write(STDERR_FILENO, bytes.as_ptr() as *const c_void, bytes.len()) };
}
}
}
/// Represents a signal handler that is registered with a set of signals that unregistered when the
/// struct goes out of scope. Prefer a signalfd based solution before using this.
pub struct ScopedSignalHandler {
signals: Vec<Signal>,
}
impl ScopedSignalHandler {
/// Attempts to register `handler` with the provided `signals`. It will fail if there is already
/// an existing handler on any of `signals`.
///
/// # Safety
/// This is safe if H::handle_signal is async-signal safe.
pub fn new<H: SignalHandler>(signals: &[Signal]) -> Result<Self> {
let mut scoped_handler = ScopedSignalHandler {
signals: Vec::with_capacity(signals.len()),
};
for &signal in signals {
if !has_default_signal_handler((signal).into())
.map_err(|err| Error::HasDefaultSignalHandler(signal, err))?
{
return Err(Error::HandlerAlreadySet(signal));
}
// Requires an async-safe callback.
unsafe {
register_signal_handler((signal).into(), call_handler::<H>)
.map_err(|err| Error::RegisterSignalHandler(signal, err))?
};
scoped_handler.signals.push(signal);
}
Ok(scoped_handler)
}
}
/// Clears the signal handler for any of the associated signals.
impl Drop for ScopedSignalHandler {
fn drop(&mut self) {
for signal in &self.signals {
if let Err(err) = clear_signal_handler((*signal).into()) {
eprintln!("Error: failed to clear signal handler: {:?}", err);
}
}
}
}
/// A signal handler that does nothing.
///
/// This is useful in cases where wait_for_signal is used since it will never trigger if the signal
/// is blocked and the default handler may have undesired effects like terminating the process.
pub struct EmptySignalHandler;
/// # Safety
/// Safe because handle_signal is async-signal safe.
unsafe impl SignalHandler for EmptySignalHandler {
fn handle_signal(_: Signal) {}
}
/// Blocks until SIGINT is received, which often happens because Ctrl-C was pressed in an
/// interactive terminal.
///
/// Note: if you are using a multi-threaded application you need to block SIGINT on all other
/// threads or they may receive the signal instead of the desired thread.
pub fn wait_for_interrupt() -> Result<()> {
// Register a signal handler if there is not one already so the thread is not killed.
let ret = ScopedSignalHandler::new::<EmptySignalHandler>(&[Signal::Interrupt]);
if !matches!(&ret, Ok(_) | Err(Error::HandlerAlreadySet(_))) {
ret?;
}
match wait_for_signal(&[Signal::Interrupt.into()], None) {
Ok(_) => Ok(()),
Err(err) => Err(Error::WaitForSignal(err)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::mem::zeroed;
use std::ptr::{null, null_mut};
use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, Once};
use std::thread::{sleep, spawn};
use std::time::{Duration, Instant};
use libc::sigaction;
use crate::{gettid, kill, Pid};
const TEST_SIGNAL: Signal = Signal::User1;
const TEST_SIGNALS: &[Signal] = &[Signal::User1, Signal::User2];
static TEST_SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Only allows one test case to execute at a time.
fn get_mutex() -> MutexGuard<'static, ()> {
static INIT: Once = Once::new();
static mut VAL: Option<Arc<Mutex<()>>> = None;
INIT.call_once(|| {
let val = Some(Arc::new(Mutex::new(())));
// Safe because the mutation is protected by the Once.
unsafe { VAL = val }
});
// Safe mutation only happens in the Once.
unsafe { VAL.as_ref() }.unwrap().lock().unwrap()
}
fn reset_counter() {
TEST_SIGNAL_COUNTER.swap(0, Ordering::SeqCst);
}
fn get_sigaction(signal: Signal) -> Result<sigaction> {
// Safe because sigaction is owned and expected to be initialized ot zeros.
let mut sigact: sigaction = unsafe { zeroed() };
if unsafe { sigaction(signal.into(), null(), &mut sigact) } < 0 {
Err(Error::Sigaction(signal, errno::Error::last()))
} else {
Ok(sigact)
}
}
/// Safety:
/// This is only safe if the signal handler set in sigaction is safe.
unsafe fn restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction> {
if sigaction(signal.into(), &sigact, null_mut()) < 0 {
Err(Error::Sigaction(signal, errno::Error::last()))
} else {
Ok(sigact)
}
}
/// Safety:
/// Safe if the signal handler for Signal::User1 is safe.
unsafe fn send_test_signal() {
kill(gettid(), Signal::User1.into()).unwrap()
}
macro_rules! assert_counter_eq {
($compare_to:expr) => {{
let expected: usize = $compare_to;
let got: usize = TEST_SIGNAL_COUNTER.load(Ordering::SeqCst);
if got != expected {
panic!(
"wrong signal counter value: got {}; expected {}",
got, expected
);
}
}};
}
struct TestHandler;
/// # Safety
/// Safe because handle_signal is async-signal safe.
unsafe impl SignalHandler for TestHandler {
fn handle_signal(signal: Signal) {
if TEST_SIGNAL == signal {
TEST_SIGNAL_COUNTER.fetch_add(1, Ordering::SeqCst);
}
}
}
#[test]
fn scopedsignalhandler_success() {
// Prevent other test cases from running concurrently since the signal
// handlers are shared for the process.
let _guard = get_mutex();
reset_counter();
assert_counter_eq!(0);
assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
// Safe because test_handler is safe.
unsafe { send_test_signal() };
// Give the handler time to run in case it is on a different thread.
for _ in 1..40 {
if TEST_SIGNAL_COUNTER.load(Ordering::SeqCst) > 0 {
break;
}
sleep(Duration::from_millis(250));
}
assert_counter_eq!(1);
drop(handler);
assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
}
#[test]
fn scopedsignalhandler_handleralreadyset() {
// Prevent other test cases from running concurrently since the signal
// handlers are shared for the process.
let _guard = get_mutex();
reset_counter();
assert_counter_eq!(0);
assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
// Safe because TestHandler is async-signal safe.
let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
// Safe because TestHandler is async-signal safe.
assert!(matches!(
ScopedSignalHandler::new::<TestHandler>(&TEST_SIGNALS),
Err(Error::HandlerAlreadySet(Signal::User1))
));
assert_counter_eq!(0);
drop(handler);
assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
}
/// Stores the thread used by WaitForInterruptHandler.
static WAIT_FOR_INTERRUPT_THREAD_ID: AtomicI32 = AtomicI32::new(0);
/// Forwards SIGINT to the appropriate thread.
struct WaitForInterruptHandler;
/// # Safety
/// Safe because handle_signal is async-signal safe.
unsafe impl SignalHandler for WaitForInterruptHandler {
fn handle_signal(_: Signal) {
let tid = WAIT_FOR_INTERRUPT_THREAD_ID.load(Ordering::SeqCst);
// If the thread ID is set and executed on the wrong thread, forward the signal.
if tid != 0 && gettid() != tid {
// Safe because the handler is safe and the target thread id is expecting the signal.
unsafe { kill(tid, Signal::Interrupt.into()) }.unwrap();
}
}
}
/// Query /proc/${tid}/status for its State and check if it is either S (sleeping) or in
/// D (disk sleep).
fn thread_is_sleeping(tid: Pid) -> result::Result<bool, errno::Error> {
const PREFIX: &str = "State:";
let mut status_reader = BufReader::new(File::open(format!("/proc/{}/status", tid))?);
let mut line = String::new();
loop {
let count = status_reader.read_line(&mut line)?;
if count == 0 {
return Err(errno::Error::new(libc::EIO));
}
if line.starts_with(PREFIX) {
return Ok(matches!(
line[PREFIX.len()..].trim_start().chars().next(),
Some('S') | Some('D')
));
}
line.clear();
}
}
/// Wait for a process to block either in a sleeping or disk sleep state.
fn wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> result::Result<(), errno::Error> {
let start = Instant::now();
loop {
if thread_is_sleeping(tid)? {
return Ok(());
}
if start.elapsed() > timeout {
return Err(errno::Error::new(libc::EAGAIN));
}
sleep(Duration::from_millis(50));
}
}
#[test]
fn waitforinterrupt_success() {
// Prevent other test cases from running concurrently since the signal
// handlers are shared for the process.
let _guard = get_mutex();
let to_restore = get_sigaction(Signal::Interrupt).unwrap();
clear_signal_handler(Signal::Interrupt.into()).unwrap();
// Safe because TestHandler is async-signal safe.
let handler =
ScopedSignalHandler::new::<WaitForInterruptHandler>(&[Signal::Interrupt]).unwrap();
let tid = gettid();
WAIT_FOR_INTERRUPT_THREAD_ID.store(tid, Ordering::SeqCst);
let join_handle = spawn(move || -> result::Result<(), errno::Error> {
// Wait unitl the thread is ready to receive the signal.
wait_for_thread_to_sleep(tid, Duration::from_secs(10)).unwrap();
// Safe because the SIGINT handler is safe.
unsafe { kill(tid, Signal::Interrupt.into()) }
});
let wait_ret = wait_for_interrupt();
let join_ret = join_handle.join();
drop(handler);
// Safe because we are restoring the previous SIGINT handler.
unsafe { restore_sigaction(Signal::Interrupt, to_restore) }.unwrap();
wait_ret.unwrap();
join_ret.unwrap().unwrap();
}
}

View file

@ -9,6 +9,7 @@ use libc::{
};
use std::cmp::Ordering;
use std::convert::TryFrom;
use std::fmt::{self, Display};
use std::io;
use std::mem;
@ -55,6 +56,10 @@ pub enum Error {
WaitPid(errno::Error),
/// Timeout reached.
TimedOut,
/// Failed to convert signum to Signal.
UnrecognizedSignum(i32),
/// Converted signum greater than SIGRTMAX.
RTSignumGreaterThanMax(Signal),
}
impl Display for Error {
@ -88,6 +93,10 @@ impl Display for Error {
WaitForSignal(e) => write!(f, "failed to wait for signal: {}", e),
WaitPid(e) => write!(f, "failed to wait for process: {}", e),
TimedOut => write!(f, "timeout reached."),
UnrecognizedSignum(signum) => write!(f, "unrecoginized signal number: {}", signum),
RTSignumGreaterThanMax(signal) => {
write!(f, "got RT signal greater than max: {:?}", signal)
}
}
}
}
@ -162,9 +171,9 @@ pub enum Signal {
Rt31,
}
impl Into<i32> for Signal {
fn into(self) -> i32 {
let num = self as libc::c_int;
impl From<Signal> for c_int {
fn from(signal: Signal) -> c_int {
let num = signal as libc::c_int;
if num >= Signal::Rt0 as libc::c_int {
return num - (Signal::Rt0 as libc::c_int) + SIGRTMIN();
}
@ -172,6 +181,94 @@ impl Into<i32> for Signal {
}
}
impl TryFrom<c_int> for Signal {
type Error = Error;
fn try_from(value: c_int) -> result::Result<Self, Self::Error> {
use Signal::*;
Ok(match value {
libc::SIGABRT => Abort,
libc::SIGALRM => Alarm,
libc::SIGBUS => Bus,
libc::SIGCHLD => Child,
libc::SIGCONT => Continue,
libc::SIGXFSZ => ExceededFileSize,
libc::SIGFPE => FloatingPointException,
libc::SIGHUP => HangUp,
libc::SIGILL => IllegalInstruction,
libc::SIGINT => Interrupt,
libc::SIGIO => IO,
libc::SIGKILL => Kill,
libc::SIGPIPE => Pipe,
libc::SIGPWR => Power,
libc::SIGPROF => Profile,
libc::SIGQUIT => Quit,
libc::SIGSEGV => SegmentationViolation,
libc::SIGSTKFLT => StackFault,
libc::SIGSTOP => Stop,
libc::SIGSYS => Sys,
libc::SIGTRAP => Trap,
libc::SIGTERM => Terminate,
libc::SIGTTIN => TTYIn,
libc::SIGTTOU => TTYOut,
libc::SIGTSTP => TTYStop,
libc::SIGURG => Urgent,
libc::SIGUSR1 => User1,
libc::SIGUSR2 => User2,
libc::SIGVTALRM => VTAlarm,
libc::SIGWINCH => Winch,
libc::SIGXCPU => XCPU,
_ => {
if value < SIGRTMIN() {
return Err(Error::UnrecognizedSignum(value));
}
let signal = match value - SIGRTMIN() {
0 => Rt0,
1 => Rt1,
2 => Rt2,
3 => Rt3,
4 => Rt4,
5 => Rt5,
6 => Rt6,
7 => Rt7,
8 => Rt8,
9 => Rt9,
10 => Rt10,
11 => Rt11,
12 => Rt12,
13 => Rt13,
14 => Rt14,
15 => Rt15,
16 => Rt16,
17 => Rt17,
18 => Rt18,
19 => Rt19,
20 => Rt20,
21 => Rt21,
22 => Rt22,
23 => Rt23,
24 => Rt24,
25 => Rt25,
26 => Rt26,
27 => Rt27,
28 => Rt28,
29 => Rt29,
30 => Rt30,
31 => Rt31,
_ => {
return Err(Error::UnrecognizedSignum(value));
}
};
if value > SIGRTMAX() {
return Err(Error::RTSignumGreaterThanMax(signal));
}
signal
}
})
}
}
pub type SignalResult<T> = result::Result<T, Error>;
#[link(name = "c")]