cros_async: Make raw waker Send + Sync

std::task::Waker unconditionally implements Send + Sync so the raw waker
that we provide also must implement those traits.  Switch to using an
Arc<AtomicBool>.

This also fixes an inconsistency where the waker was defined to be an
Rc<Cell<bool>> but all the vtable functions were treating as an
Rc<AtomicBool>.

To reduce the vtable boilerplate use the ArcWake trait from the futures
crate.

BUG=none
TEST=unit tests

Change-Id: I3870e4d7f6ce0de9f6ac3313a2f4474ae29018b2
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2287079
Reviewed-by: Daniel Verkamp <dverkamp@chromium.org>
Reviewed-by: Dylan Reid <dgreid@chromium.org>
Tested-by: Chirantan Ekbote <chirantan@chromium.org>
Commit-Queue: Chirantan Ekbote <chirantan@chromium.org>
This commit is contained in:
Chirantan Ekbote 2020-07-08 19:24:04 +09:00 committed by Commit Bot
parent f62b22c7bf
commit 906b43c5c9
5 changed files with 48 additions and 58 deletions

View file

@ -15,3 +15,4 @@ syscall_defines = { path = "../syscall_defines" }
[dependencies.futures]
version = "*"
default-features = false
features = ["alloc"]

View file

@ -10,6 +10,7 @@ use std::pin::Pin;
use std::task::Context;
use futures::future::{maybe_done, FutureExt, MaybeDone};
use futures::task::waker_ref;
use crate::executor::{FutureList, FutureState, UnitFutures};
@ -55,8 +56,9 @@ macro_rules! generate {
let mut complete = true;
$(
if self.[<$Fut _state>].needs_poll.replace(false) {
let mut ctx = Context::from_waker(&self.[<$Fut _state>].waker);
if self.[<$Fut _state>].needs_poll.swap(false) {
let waker = waker_ref(&self.[<$Fut _state>].needs_poll);
let mut ctx = Context::from_waker(&waker);
// The future impls `Unpin`, use `poll_unpin` to avoid wrapping it in
// `Pin` to call `poll`.
complete &= self.$Fut.poll_unpin(&mut ctx).is_ready();

View file

@ -2,17 +2,16 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::cell::Cell;
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::Waker;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::future::FutureExt;
use futures::task::waker_ref;
use crate::waker::create_waker;
use crate::waker::NeedsPoll;
/// Represents a future executor that can be run. Implementers of the trait will take a list of
/// futures and poll them until completed.
@ -28,22 +27,14 @@ pub trait Executor {
// Tracks if a future needs to be polled and the waker to use.
pub(crate) struct FutureState {
pub needs_poll: Rc<Cell<bool>>,
pub waker: Waker,
pub needs_poll: Arc<NeedsPoll>,
}
impl FutureState {
pub fn new() -> FutureState {
let needs_poll = Rc::new(Cell::new(true));
// Safe because a valid pointer is passed to `create_waker` and the valid result is
// passed to `Waker::from_raw`. And because the reference count to needs_poll is
// incremented by cloning it so it can't be dropped before the waker.
let waker = unsafe {
let clone = needs_poll.clone();
let raw_waker = create_waker(Rc::into_raw(clone) as *const _);
Waker::from_raw(raw_waker)
};
FutureState { needs_poll, waker }
FutureState {
needs_poll: NeedsPoll::new(),
}
}
}
@ -69,7 +60,8 @@ impl<T> ExecutableFuture<T> {
// Polls the future if needed and returns the result.
// Covers setting up the waker and context before calling the future.
fn poll(&mut self) -> Poll<T> {
let mut ctx = Context::from_waker(&self.state.waker);
let waker = waker_ref(&self.state.needs_poll);
let mut ctx = Context::from_waker(&waker);
let f = self.future.as_mut();
f.poll(&mut ctx)
}
@ -116,7 +108,7 @@ impl UnitFutures {
let mut i = 0;
while i < self.futures.len() {
let fut = &mut self.futures[i];
let remove = if fut.state.needs_poll.replace(false) {
let remove = if fut.state.needs_poll.swap(false) {
fut.poll().is_ready()
} else {
false
@ -178,8 +170,9 @@ impl<F: Future + Unpin> FutureList for RunOne<F> {
fn poll_results(&mut self) -> Option<Self::Output> {
let _ = self.added_futures.poll_results();
if self.fut_state.needs_poll.replace(false) {
let mut ctx = Context::from_waker(&self.fut_state.waker);
if self.fut_state.needs_poll.swap(false) {
let waker = waker_ref(&self.fut_state.needs_poll);
let mut ctx = Context::from_waker(&waker);
// The future impls `Unpin`, use `poll_unpin` to avoid wrapping it in
// `Pin` to call `poll`.
if let Poll::Ready(o) = self.fut.poll_unpin(&mut ctx) {
@ -197,6 +190,7 @@ impl<F: Future + Unpin> FutureList for RunOne<F> {
#[cfg(test)]
mod tests {
use super::*;
use std::rc::Rc;
use std::sync::atomic::{AtomicUsize, Ordering};
#[test]

View file

@ -10,6 +10,7 @@ use std::pin::Pin;
use std::task::Context;
use futures::future::{maybe_done, FutureExt, MaybeDone};
use futures::task::waker_ref;
use crate::executor::{FutureList, FutureState, UnitFutures};
@ -61,8 +62,9 @@ macro_rules! generate {
let mut complete = false;
$(
let $Fut = Pin::new(&mut self.$Fut);
if self.[<$Fut _state>].needs_poll.replace(false) {
let mut ctx = Context::from_waker(&self.[<$Fut _state>].waker);
if self.[<$Fut _state>].needs_poll.swap(false) {
let waker = waker_ref(&self.[<$Fut _state>].needs_poll);
let mut ctx = Context::from_waker(&waker);
// The future impls `Unpin`, use `poll_unpin` to avoid wrapping it in
// `Pin` to call `poll`.
complete |= self.$Fut.poll_unpin(&mut ctx).is_ready();

View file

@ -2,47 +2,38 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{RawWaker, RawWakerVTable};
use std::sync::Arc;
use futures::task::ArcWake;
/// Wrapper around a u64 used as a token to uniquely identify a pending waker.
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)]
pub(crate) struct WakerToken(pub(crate) u64);
// Boiler-plate for creating a waker with function pointers.
// This waker sets the atomic bool it is passed to true.
// The bool will be used by the executor to know which futures to poll
/// Raw waker used by executors. Associated with a single future and used to indicate whether that
/// future needs to be polled.
pub(crate) struct NeedsPoll(AtomicBool);
// Convert the pointer back to the Rc it was created from and drop it.
unsafe fn waker_drop(data_ptr: *const ()) {
// from_raw, then drop
let _rc_bool = Rc::<AtomicBool>::from_raw(data_ptr as *const _);
impl NeedsPoll {
/// Creates a new `NeedsPoll` initialized to `true`.
pub fn new() -> Arc<NeedsPoll> {
Arc::new(NeedsPoll(AtomicBool::new(true)))
}
/// Returns the current value of this `NeedsPoll`.
pub fn get(&self) -> bool {
self.0.load(Ordering::Acquire)
}
/// Changes the internal value to `val` and returns the old value.
pub fn swap(&self, val: bool) -> bool {
self.0.swap(val, Ordering::AcqRel)
}
}
unsafe fn waker_wake(data_ptr: *const ()) {
waker_wake_by_ref(data_ptr)
}
// Called when the bool should be set to true to wake the waker.
unsafe fn waker_wake_by_ref(data_ptr: *const ()) {
let bool_atomic_ptr = data_ptr as *const AtomicBool;
let bool_atomic_ref = bool_atomic_ptr.as_ref().unwrap();
bool_atomic_ref.store(true, Ordering::Relaxed);
}
// The data_ptr will be a pointer to an Rc<AtomicBool>.
unsafe fn waker_clone(data_ptr: *const ()) -> RawWaker {
let rc_bool = Rc::<AtomicBool>::from_raw(data_ptr as *const _);
let new_ptr = rc_bool.clone();
Rc::into_raw(rc_bool); // Don't decrement the ref count of the original, so back to raw.
create_waker(Rc::into_raw(new_ptr) as *const _)
}
static WAKER_VTABLE: RawWakerVTable =
RawWakerVTable::new(waker_clone, waker_wake, waker_wake_by_ref, waker_drop);
/// To use safely, data_ptr must be from Rc<AtomicBool>::from_raw().
pub unsafe fn create_waker(data_ptr: *const ()) -> RawWaker {
RawWaker::new(data_ptr, &WAKER_VTABLE)
impl ArcWake for NeedsPoll {
fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.0.store(true, Ordering::Release);
}
}