use crate::{ http::{HttpClient, Request, Response, ServerResponse}, Client, Connection, Credentials, EstablishConnectionError, UserStore, }; use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt}; use gpui::{executor, ModelHandle, TestAppContext}; use parking_lot::Mutex; use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; use std::{fmt, rc::Rc, sync::Arc}; pub struct FakeServer { peer: Arc, state: Arc>, user_id: u64, executor: Rc, } #[derive(Default)] struct FakeServerState { incoming: Option>>, connection_id: Option, forbid_connections: bool, auth_count: usize, access_token: usize, } impl FakeServer { pub async fn for_client( client_user_id: u64, client: &mut Arc, cx: &TestAppContext, ) -> Self { let server = Self { peer: Peer::new(), state: Default::default(), user_id: client_user_id, executor: cx.foreground(), }; Arc::get_mut(client) .unwrap() .override_authenticate({ let state = server.state.clone(); move |cx| { let mut state = state.lock(); state.auth_count += 1; let access_token = state.access_token.to_string(); cx.spawn(move |_| async move { Ok(Credentials { user_id: client_user_id, access_token, }) }) } }) .override_establish_connection({ let peer = server.peer.clone(); let state = server.state.clone(); move |credentials, cx| { let peer = peer.clone(); let state = state.clone(); let credentials = credentials.clone(); cx.spawn(move |cx| async move { assert_eq!(credentials.user_id, client_user_id); if state.lock().forbid_connections { Err(EstablishConnectionError::Other(anyhow!( "server is forbidding connections" )))? } if credentials.access_token != state.lock().access_token.to_string() { Err(EstablishConnectionError::Unauthorized)? } let (client_conn, server_conn, _) = Connection::in_memory(cx.background()); let (connection_id, io, incoming) = peer.add_test_connection(server_conn, cx.background()).await; cx.background().spawn(io).detach(); let mut state = state.lock(); state.connection_id = Some(connection_id); state.incoming = Some(incoming); Ok(client_conn) }) } }); client .authenticate_and_connect(false, &cx.to_async()) .await .unwrap(); server } pub fn disconnect(&self) { self.peer.disconnect(self.connection_id()); let mut state = self.state.lock(); state.connection_id.take(); state.incoming.take(); } pub fn auth_count(&self) -> usize { self.state.lock().auth_count } pub fn roll_access_token(&self) { self.state.lock().access_token += 1; } pub fn forbid_connections(&self) { self.state.lock().forbid_connections = true; } pub fn allow_connections(&self) { self.state.lock().forbid_connections = false; } pub fn send(&self, message: T) { self.peer.send(self.connection_id(), message).unwrap(); } pub async fn receive(&self) -> Result> { self.executor.start_waiting(); let message = self .state .lock() .incoming .as_mut() .expect("not connected") .next() .await .ok_or_else(|| anyhow!("other half hung up"))?; self.executor.finish_waiting(); let type_name = message.payload_type_name(); Ok(*message .into_any() .downcast::>() .unwrap_or_else(|_| { panic!( "fake server received unexpected message type: {:?}", type_name ); })) } pub async fn respond( &self, receipt: Receipt, response: T::Response, ) { self.peer.respond(receipt, response).unwrap() } fn connection_id(&self) -> ConnectionId { self.state.lock().connection_id.expect("not connected") } pub async fn build_user_store( &self, client: Arc, cx: &mut TestAppContext, ) -> ModelHandle { let http_client = FakeHttpClient::with_404_response(); let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx)); assert_eq!( self.receive::() .await .unwrap() .payload .user_ids, &[self.user_id] ); user_store } } pub struct FakeHttpClient { handler: Box BoxFuture<'static, Result>>, } impl FakeHttpClient { pub fn new(handler: F) -> Arc where Fut: 'static + Send + Future>, F: 'static + Send + Sync + Fn(Request) -> Fut, { Arc::new(Self { handler: Box::new(move |req| Box::pin(handler(req))), }) } pub fn with_404_response() -> Arc { Self::new(|_| async move { Ok(ServerResponse::new(404)) }) } } impl fmt::Debug for FakeHttpClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FakeHttpClient").finish() } } impl HttpClient for FakeHttpClient { fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result> { let future = (self.handler)(req); Box::pin(async move { future.await.map(Into::into) }) } }