From beea9b68ff53b347f32e3085590ae612b8344b93 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 12 Dec 2022 16:03:21 +0100 Subject: [PATCH] Allow re-joining room after server restarts --- crates/call/src/room.rs | 4 +- crates/collab/src/db.rs | 76 +++++++++++++++-------- crates/collab/src/integration_tests.rs | 85 ++++++++++++++++++++++++-- crates/collab/src/main.rs | 2 +- crates/collab/src/rpc.rs | 11 +++- 5 files changed, 141 insertions(+), 37 deletions(-) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 824ec49054..44a3653c5d 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -15,7 +15,7 @@ use project::Project; use std::{mem, sync::Arc, time::Duration}; use util::{post_inc, ResultExt}; -pub const RECONNECTION_TIMEOUT: Duration = client::RECEIVE_TIMEOUT; +pub const RECONNECT_TIMEOUT: Duration = client::RECEIVE_TIMEOUT; #[derive(Clone, Debug, PartialEq, Eq)] pub enum Event { @@ -262,7 +262,7 @@ impl Room { }); // Wait for client to re-establish a connection to the server. - let mut reconnection_timeout = cx.background().timer(RECONNECTION_TIMEOUT).fuse(); + let mut reconnection_timeout = cx.background().timer(RECONNECT_TIMEOUT).fuse(); let client_reconnection = async { loop { if let Some(status) = client_status.next().await { diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 4a920841e8..64a95b2300 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -21,6 +21,7 @@ use dashmap::DashMap; use futures::StreamExt; use hyper::StatusCode; use rpc::{proto, ConnectionId}; +use sea_orm::Condition; pub use sea_orm::ConnectOptions; use sea_orm::{ entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseConnection, DatabaseTransaction, @@ -47,7 +48,7 @@ pub struct Database { background: Option>, #[cfg(test)] runtime: Option, - epoch: Uuid, + epoch: parking_lot::RwLock, } impl Database { @@ -60,10 +61,20 @@ impl Database { background: None, #[cfg(test)] runtime: None, - epoch: Uuid::new_v4(), + epoch: parking_lot::RwLock::new(Uuid::new_v4()), }) } + #[cfg(test)] + pub fn reset(&self) { + self.rooms.clear(); + *self.epoch.write() = Uuid::new_v4(); + } + + fn epoch(&self) -> Uuid { + *self.epoch.read() + } + pub async fn migrate( &self, migrations_path: &Path, @@ -105,22 +116,29 @@ impl Database { Ok(new_migrations) } - pub async fn clear_stale_data(&self) -> Result<()> { + pub async fn delete_stale_projects(&self) -> Result<()> { self.transaction(|tx| async move { project_collaborator::Entity::delete_many() - .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch)) - .exec(&*tx) - .await?; - room_participant::Entity::delete_many() - .filter( - room_participant::Column::AnsweringConnectionEpoch - .ne(self.epoch) - .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch)), - ) + .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch())) .exec(&*tx) .await?; project::Entity::delete_many() - .filter(project::Column::HostConnectionEpoch.ne(self.epoch)) + .filter(project::Column::HostConnectionEpoch.ne(self.epoch())) + .exec(&*tx) + .await?; + Ok(()) + }) + .await + } + + pub async fn delete_stale_rooms(&self) -> Result<()> { + self.transaction(|tx| async move { + room_participant::Entity::delete_many() + .filter( + room_participant::Column::AnsweringConnectionEpoch + .ne(self.epoch()) + .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch())), + ) .exec(&*tx) .await?; room::Entity::delete_many() @@ -1033,11 +1051,11 @@ impl Database { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(user_id), answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), - answering_connection_epoch: ActiveValue::set(Some(self.epoch)), + answering_connection_epoch: ActiveValue::set(Some(self.epoch())), answering_connection_lost: ActiveValue::set(false), calling_user_id: ActiveValue::set(user_id), calling_connection_id: ActiveValue::set(connection_id.0 as i32), - calling_connection_epoch: ActiveValue::set(self.epoch), + calling_connection_epoch: ActiveValue::set(self.epoch()), ..Default::default() } .insert(&*tx) @@ -1064,7 +1082,7 @@ impl Database { answering_connection_lost: ActiveValue::set(false), calling_user_id: ActiveValue::set(calling_user_id), calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32), - calling_connection_epoch: ActiveValue::set(self.epoch), + calling_connection_epoch: ActiveValue::set(self.epoch()), initial_project_id: ActiveValue::set(initial_project_id), ..Default::default() } @@ -1174,18 +1192,22 @@ impl Database { self.room_transaction(|tx| async move { let result = room_participant::Entity::update_many() .filter( - room_participant::Column::RoomId - .eq(room_id) - .and(room_participant::Column::UserId.eq(user_id)) - .and( - room_participant::Column::AnsweringConnectionId - .is_null() - .or(room_participant::Column::AnsweringConnectionLost.eq(true)), + Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add(room_participant::Column::UserId.eq(user_id)) + .add( + Condition::any() + .add(room_participant::Column::AnsweringConnectionId.is_null()) + .add(room_participant::Column::AnsweringConnectionLost.eq(true)) + .add( + room_participant::Column::AnsweringConnectionEpoch + .ne(self.epoch()), + ), ), ) .set(room_participant::ActiveModel { answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), - answering_connection_epoch: ActiveValue::set(Some(self.epoch)), + answering_connection_epoch: ActiveValue::set(Some(self.epoch())), answering_connection_lost: ActiveValue::set(false), ..Default::default() }) @@ -1591,7 +1613,7 @@ impl Database { room_id: ActiveValue::set(participant.room_id), host_user_id: ActiveValue::set(participant.user_id), host_connection_id: ActiveValue::set(connection_id.0 as i32), - host_connection_epoch: ActiveValue::set(self.epoch), + host_connection_epoch: ActiveValue::set(self.epoch()), ..Default::default() } .insert(&*tx) @@ -1616,7 +1638,7 @@ impl Database { project_collaborator::ActiveModel { project_id: ActiveValue::set(project.id), connection_id: ActiveValue::set(connection_id.0 as i32), - connection_epoch: ActiveValue::set(self.epoch), + connection_epoch: ActiveValue::set(self.epoch()), user_id: ActiveValue::set(participant.user_id), replica_id: ActiveValue::set(ReplicaId(0)), is_host: ActiveValue::set(true), @@ -1930,7 +1952,7 @@ impl Database { let new_collaborator = project_collaborator::ActiveModel { project_id: ActiveValue::set(project_id), connection_id: ActiveValue::set(connection_id.0 as i32), - connection_epoch: ActiveValue::set(self.epoch), + connection_epoch: ActiveValue::set(self.epoch()), user_id: ActiveValue::set(participant.user_id), replica_id: ActiveValue::set(replica_id), is_host: ActiveValue::set(false), diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index aca5f77fe9..84e7954c33 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -4,7 +4,6 @@ use crate::{ rpc::{Server, RECONNECT_TIMEOUT}, AppState, }; -use ::rpc::Peer; use anyhow::anyhow; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{ @@ -365,7 +364,7 @@ async fn test_room_uniqueness( } #[gpui::test(iterations = 10)] -async fn test_disconnecting_from_room( +async fn test_client_disconnecting_from_room( deterministic: Arc, cx_a: &mut TestAppContext, cx_b: &mut TestAppContext, @@ -516,6 +515,75 @@ async fn test_disconnecting_from_room( ); } +#[gpui::test(iterations = 10)] +async fn test_server_restarts( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(cx_a.background()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + + // Call user B from client A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + + // User B receives the call and joins the room. + let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); + incoming_call_b.next().await.unwrap().unwrap(); + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone()); + deterministic.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: Default::default() + } + ); + + // User A automatically reconnects to the room when the server restarts. + server.restart().await; + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: Default::default() + } + ); +} + #[gpui::test(iterations = 10)] async fn test_calls_on_multiple_connections( deterministic: Arc, @@ -5933,7 +6001,6 @@ async fn test_random_collaboration( } struct TestServer { - peer: Arc, app_state: Arc, server: Arc, connection_killers: Arc>>>, @@ -5962,10 +6029,9 @@ impl TestServer { ) .unwrap(); let app_state = Self::build_app_state(&test_db, &live_kit_server).await; - let peer = Peer::new(); let server = Server::new(app_state.clone()); + server.start().await.unwrap(); Self { - peer, app_state, server, connection_killers: Default::default(), @@ -5975,6 +6041,14 @@ impl TestServer { } } + async fn restart(&self) { + self.forbid_connections(); + self.server.teardown(); + self.app_state.db.reset(); + self.server.start().await.unwrap(); + self.allow_connections(); + } + async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { cx.update(|cx| { cx.set_global(HomeDir(Path::new("/tmp/").to_path_buf())); @@ -6192,7 +6266,6 @@ impl Deref for TestServer { impl Drop for TestServer { fn drop(&mut self) { - self.peer.reset(); self.server.teardown(); self.test_live_kit_server.teardown().unwrap(); } diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index a288e0f3ce..384789b7c2 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -52,12 +52,12 @@ async fn main() -> Result<()> { init_tracing(&config); let state = AppState::new(config).await?; - state.db.clear_stale_data().await?; let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) .expect("failed to bind TCP listener"); let rpc_server = collab::rpc::Server::new(state.clone()); + rpc_server.start().await?; let app = collab::api::routes(rpc_server.clone(), state.clone()) .merge(collab::rpc::routes(rpc_server.clone())) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index a799837ad4..18bd96c536 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -237,7 +237,15 @@ impl Server { Arc::new(server) } + pub async fn start(&self) -> Result<()> { + self.app_state.db.delete_stale_projects().await?; + // TODO: delete stale rooms after timeout. + // self.app_state.db.delete_stale_rooms().await?; + Ok(()) + } + pub fn teardown(&self) { + self.peer.reset(); let _ = self.teardown.send(()); } @@ -339,7 +347,7 @@ impl Server { let user_id = user.id; let login = user.github_login; let span = info_span!("handle connection", %user_id, %login, %address); - let teardown = self.teardown.subscribe(); + let mut teardown = self.teardown.subscribe(); async move { let (connection_id, handle_io, mut incoming_rx) = this .peer @@ -409,6 +417,7 @@ impl Server { let next_message = incoming_rx.next().fuse(); futures::pin_mut!(next_message); futures::select_biased! { + _ = teardown.changed().fuse() => return Ok(()), result = handle_io => { if let Err(error) = result { tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");