From ab59f023164ed5b914762852247fc38c8b58a700 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 16 Feb 2022 13:54:00 -0800 Subject: [PATCH] Fix chat channel unit test Also, improve error in tests when FakeServer never receives a request, using the new `start_waiting` method on the DeterministicExecutor. --- crates/client/src/channel.rs | 4 ++ crates/client/src/test.rs | 126 +++++++++++++++++------------------ 2 files changed, 67 insertions(+), 63 deletions(-) diff --git a/crates/client/src/channel.rs b/crates/client/src/channel.rs index 24de1ff835..1b00d4daf6 100644 --- a/crates/client/src/channel.rs +++ b/crates/client/src/channel.rs @@ -598,10 +598,14 @@ mod tests { #[gpui::test] async fn test_channel_messages(mut cx: TestAppContext) { + cx.foreground().forbid_parking(); + let user_id = 5; let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) }); let mut client = Client::new(http_client.clone()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; + + Channel::init(&client); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx)); let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index c8aca79192..697bf3860c 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,25 +1,28 @@ -use super::Client; -use super::*; -use crate::http::{HttpClient, Request, Response, ServerResponse}; +use crate::{ + http::{HttpClient, Request, Response, ServerResponse}, + Client, Connection, Credentials, EstablishConnectionError, UserStore, +}; +use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt}; -use gpui::{ModelHandle, TestAppContext}; +use gpui::{executor, ModelHandle, TestAppContext}; use parking_lot::Mutex; use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; -use std::fmt; -use std::sync::atomic::Ordering::SeqCst; -use std::sync::{ - atomic::{AtomicBool, AtomicUsize}, - Arc, -}; +use std::{fmt, rc::Rc, sync::Arc}; pub struct FakeServer { peer: Arc, - incoming: Mutex>>>, - connection_id: Mutex>, - forbid_connections: AtomicBool, - auth_count: AtomicUsize, - access_token: AtomicUsize, + state: Arc>, user_id: u64, + executor: Rc, +} + +#[derive(Default)] +struct FakeServerState { + incoming: Option>>, + connection_id: Option, + forbid_connections: bool, + auth_count: usize, + access_token: usize, } impl FakeServer { @@ -27,24 +30,22 @@ impl FakeServer { client_user_id: u64, client: &mut Arc, cx: &TestAppContext, - ) -> Arc { - let server = Arc::new(Self { + ) -> Self { + let server = Self { peer: Peer::new(), - incoming: Default::default(), - connection_id: Default::default(), - forbid_connections: Default::default(), - auth_count: Default::default(), - access_token: Default::default(), + state: Default::default(), user_id: client_user_id, - }); + executor: cx.foreground(), + }; Arc::get_mut(client) .unwrap() .override_authenticate({ - let server = server.clone(); + let state = server.state.clone(); move |cx| { - server.auth_count.fetch_add(1, SeqCst); - let access_token = server.access_token.load(SeqCst).to_string(); + let mut state = state.lock(); + state.auth_count += 1; + let access_token = state.access_token.to_string(); cx.spawn(move |_| async move { Ok(Credentials { user_id: client_user_id, @@ -54,12 +55,32 @@ impl FakeServer { } }) .override_establish_connection({ - let server = server.clone(); + let peer = server.peer.clone(); + let state = server.state.clone(); move |credentials, cx| { + let peer = peer.clone(); + let state = state.clone(); let credentials = credentials.clone(); - cx.spawn({ - let server = server.clone(); - move |cx| async move { server.establish_connection(&credentials, &cx).await } + cx.spawn(move |cx| async move { + assert_eq!(credentials.user_id, client_user_id); + + if state.lock().forbid_connections { + Err(EstablishConnectionError::Other(anyhow!( + "server is forbidding connections" + )))? + } + + if credentials.access_token != state.lock().access_token.to_string() { + Err(EstablishConnectionError::Unauthorized)? + } + + let (client_conn, server_conn, _) = Connection::in_memory(cx.background()); + let (connection_id, io, incoming) = peer.add_connection(server_conn).await; + cx.background().spawn(io).detach(); + let mut state = state.lock(); + state.connection_id = Some(connection_id); + state.incoming = Some(incoming); + Ok(client_conn) }) } }); @@ -73,49 +94,25 @@ impl FakeServer { pub fn disconnect(&self) { self.peer.disconnect(self.connection_id()); - self.connection_id.lock().take(); - self.incoming.lock().take(); - } - - async fn establish_connection( - &self, - credentials: &Credentials, - cx: &AsyncAppContext, - ) -> Result { - assert_eq!(credentials.user_id, self.user_id); - - if self.forbid_connections.load(SeqCst) { - Err(EstablishConnectionError::Other(anyhow!( - "server is forbidding connections" - )))? - } - - if credentials.access_token != self.access_token.load(SeqCst).to_string() { - Err(EstablishConnectionError::Unauthorized)? - } - - let (client_conn, server_conn, _) = Connection::in_memory(cx.background()); - let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; - cx.background().spawn(io).detach(); - *self.incoming.lock() = Some(incoming); - *self.connection_id.lock() = Some(connection_id); - Ok(client_conn) + let mut state = self.state.lock(); + state.connection_id.take(); + state.incoming.take(); } pub fn auth_count(&self) -> usize { - self.auth_count.load(SeqCst) + self.state.lock().auth_count } pub fn roll_access_token(&self) { - self.access_token.fetch_add(1, SeqCst); + self.state.lock().access_token += 1; } pub fn forbid_connections(&self) { - self.forbid_connections.store(true, SeqCst); + self.state.lock().forbid_connections = true; } pub fn allow_connections(&self) { - self.forbid_connections.store(false, SeqCst); + self.state.lock().forbid_connections = false; } pub fn send(&self, message: T) { @@ -123,14 +120,17 @@ impl FakeServer { } pub async fn receive(&self) -> Result> { + self.executor.start_waiting(); let message = self - .incoming + .state .lock() + .incoming .as_mut() .expect("not connected") .next() .await .ok_or_else(|| anyhow!("other half hung up"))?; + self.executor.finish_waiting(); let type_name = message.payload_type_name(); Ok(*message .into_any() @@ -152,7 +152,7 @@ impl FakeServer { } fn connection_id(&self) -> ConnectionId { - self.connection_id.lock().expect("not connected") + self.state.lock().connection_id.expect("not connected") } pub async fn build_user_store(