Fix crash when dropping a task that is awaiting a call to Background::scoped

Co-authored-by: Keith Simmons <keith@zed.dev>
This commit is contained in:
Max Brunsfeld 2022-04-15 10:48:01 -07:00
parent 20657566b3
commit d0413ac0e1

View file

@ -8,7 +8,7 @@ use std::{
mem, mem,
pin::Pin, pin::Pin,
rc::Rc, rc::Rc,
sync::Arc, sync::{mpsc, Arc},
task::{Context, Poll}, task::{Context, Poll},
thread, thread,
time::Duration, time::Duration,
@ -625,13 +625,9 @@ impl Background {
where where
F: FnOnce(&mut Scope<'scope>), F: FnOnce(&mut Scope<'scope>),
{ {
let mut scope = Scope { let mut scope = Scope::new();
futures: Default::default(),
_phantom: PhantomData,
};
(scheduler)(&mut scope); (scheduler)(&mut scope);
let spawned = scope let spawned = mem::take(&mut scope.futures)
.futures
.into_iter() .into_iter()
.map(|f| self.spawn(f)) .map(|f| self.spawn(f))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -669,24 +665,53 @@ impl Background {
pub struct Scope<'a> { pub struct Scope<'a> {
futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>, futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
tx: Option<mpsc::Sender<()>>,
rx: mpsc::Receiver<()>,
_phantom: PhantomData<&'a ()>, _phantom: PhantomData<&'a ()>,
} }
impl<'a> Scope<'a> { impl<'a> Scope<'a> {
fn new() -> Self {
let (tx, rx) = mpsc::channel();
Self {
tx: Some(tx),
rx,
futures: Default::default(),
_phantom: PhantomData,
}
}
pub fn spawn<F>(&mut self, f: F) pub fn spawn<F>(&mut self, f: F)
where where
F: Future<Output = ()> + Send + 'a, F: Future<Output = ()> + Send + 'a,
{ {
let tx = self.tx.clone().unwrap();
// Safety: The 'a lifetime is guaranteed to outlive any of these futures because
// dropping this `Scope` blocks until all of the futures have resolved.
let f = unsafe { let f = unsafe {
mem::transmute::< mem::transmute::<
Pin<Box<dyn Future<Output = ()> + Send + 'a>>, Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
Pin<Box<dyn Future<Output = ()> + Send + 'static>>, Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
>(Box::pin(f)) >(Box::pin(async move {
f.await;
drop(tx);
}))
}; };
self.futures.push(f); self.futures.push(f);
} }
} }
impl<'a> Drop for Scope<'a> {
fn drop(&mut self) {
self.tx.take().unwrap();
// Wait until the channel is closed, which means that all of the spawned
// futures have resolved.
self.rx.recv().ok();
}
}
impl<T> Task<T> { impl<T> Task<T> {
pub fn ready(value: T) -> Self { pub fn ready(value: T) -> Self {
Self::Ready(Some(value)) Self::Ready(Some(value))