diff --git a/sys_util/src/errno.rs b/sys_util/src/errno.rs index ff3226c9a9..b2c29779a8 100644 --- a/sys_util/src/errno.rs +++ b/sys_util/src/errno.rs @@ -57,3 +57,13 @@ impl Display for Error { pub fn errno_result() -> Result { Err(Error::last()) } + +/// Sets errno to given error code. +/// Only defined when we compile tests as normal code does not +/// normally need set errno. +#[cfg(test)] +pub fn set_errno(e: i32) { + unsafe { + *__errno_location() = e; + } +} diff --git a/sys_util/src/handle_eintr.rs b/sys_util/src/handle_eintr.rs index af177b7d72..4bf16dae34 100644 --- a/sys_util/src/handle_eintr.rs +++ b/sys_util/src/handle_eintr.rs @@ -15,12 +15,6 @@ pub trait InterruptibleResult { fn is_interrupted(&self) -> bool; } -impl InterruptibleResult for i32 { - fn is_interrupted(&self) -> bool { - *self == EINTR - } -} - impl InterruptibleResult for ::Result { fn is_interrupted(&self) -> bool { match self { @@ -40,15 +34,17 @@ impl InterruptibleResult for io::Result { } /// Macro that retries the given expression every time its result indicates it was interrupted (i.e. -/// returned `-EINTR`). This is useful for operations that are prone to being interrupted by +/// returned `EINTR`). This is useful for operations that are prone to being interrupted by /// signals, such as blocking syscalls. /// /// The given expression `$x` can return /// -/// * `i32` in which case the expression is retried if equal to `EINTR`. /// * `sys_util::Result` in which case the expression is retried if the `Error::errno()` is `EINTR`. /// * `std::io::Result` in which case the expression is retried if the `ErrorKind` is `ErrorKind::Interrupted`. /// +/// Note that if expression returns i32 (i.e. either -1 or error code), then handle_eintr_errno() +/// or handle_eintr_rc() should be used instead. +/// /// In all cases where the result does not indicate that the expression was interrupted, the result /// is returned verbatim to the caller of this macro. /// @@ -142,27 +138,84 @@ macro_rules! handle_eintr { ) } +/// Macro that retries the given expression every time its result indicates it was interrupted. +/// It is intended to use with system functions that return `EINTR` and other error codes +/// directly as their result. +/// Most of reentrant functions use this way of signalling errors. +#[macro_export] +macro_rules! handle_eintr_rc { + ($x:expr) => ( + { + use libc::EINTR; + let mut res; + loop { + res = $x; + if res != EINTR { + break; + } + } + res + } + ) +} + +/// Macro that retries the given expression every time its result indicates it was interrupted. +/// It is intended to use with system functions that signal error by returning `-1` and setting +/// `errno` to appropriate error code (`EINTR`, `EINVAL`, etc.) +/// Most of standard non-reentrant libc functions use this way of signalling errors. +#[macro_export] +macro_rules! handle_eintr_errno { + ($x:expr) => ( + { + use $crate::Error; + use libc::EINTR; + let mut res; + loop { + res = $x; + if res != -1 || Error::last() != Error::new(EINTR) { + break; + } + } + res + } + ) +} #[cfg(test)] mod tests { use super::*; + use errno::set_errno; use Error as SysError; #[test] - fn i32_eintr() { + fn i32_eintr_rc() { let mut count = 3; { let mut dummy = || { count -= 1; if count > 0 { EINTR } else { 0 } }; - let res = handle_eintr!(dummy()); + let res = handle_eintr_rc!(dummy()); assert_eq!(res, 0); } assert_eq!(count, 0); } + #[test] + fn i32_eintr_errno() { + let mut count = 3; + { + let mut dummy = || { + count -= 1; + if count > 0 { set_errno(EINTR); -1 } else { 56 } + }; + let res = handle_eintr_errno!(dummy()); + assert_eq!(res, 56); + } + assert_eq!(count, 0); + } + #[test] fn sys_eintr() { let mut count = 7; diff --git a/sys_util/src/passwd.rs b/sys_util/src/passwd.rs index 2f86946bd6..215cdd36af 100644 --- a/sys_util/src/passwd.rs +++ b/sys_util/src/passwd.rs @@ -29,11 +29,11 @@ pub fn get_user_id(user_name: &CStr) -> Result { // This call is safe as long as it behaves as described in the man page. We pass in valid // pointers to stack-allocated buffers, and the length check for the scratch buffer is correct. unsafe { - handle_eintr!(getpwnam_r(user_name.as_ptr(), - &mut passwd, - buf.as_mut_ptr(), - buf.len(), - &mut passwd_result)) + handle_eintr_rc!(getpwnam_r(user_name.as_ptr(), + &mut passwd, + buf.as_mut_ptr(), + buf.len(), + &mut passwd_result)) }; if passwd_result.is_null() { @@ -59,11 +59,11 @@ pub fn get_group_id(group_name: &CStr) -> Result { // This call is safe as long as it behaves as described in the man page. We pass in valid // pointers to stack-allocated buffers, and the length check for the scratch buffer is correct. unsafe { - handle_eintr!(getgrnam_r(group_name.as_ptr(), - &mut group, - buf.as_mut_ptr(), - buf.len(), - &mut group_result)) + handle_eintr_rc!(getgrnam_r(group_name.as_ptr(), + &mut group, + buf.as_mut_ptr(), + buf.len(), + &mut group_result)) }; if group_result.is_null() { diff --git a/sys_util/src/poll.rs b/sys_util/src/poll.rs index 2e285144ff..d6c3d4a732 100644 --- a/sys_util/src/poll.rs +++ b/sys_util/src/poll.rs @@ -115,11 +115,11 @@ impl Poller { // Safe because poll is given the correct length of properly initialized pollfds, and we // check the return result. let ret = unsafe { - handle_eintr!(ppoll(self.pollfds.as_mut_ptr(), - self.pollfds.len() as nfds_t, - &mut timeout_spec, - null(), - 0)) + handle_eintr_errno!(ppoll(self.pollfds.as_mut_ptr(), + self.pollfds.len() as nfds_t, + &mut timeout_spec, + null(), + 0)) }; *timeout = Duration::new(timeout_spec.tv_sec as u64, timeout_spec.tv_nsec as u32);