From aca3f025906a3236682534ebbbf0c50b9a26cefa Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 8 Dec 2022 12:14:12 +0100 Subject: [PATCH] Re-join room when client temporarily loses connection --- crates/call/src/participant.rs | 10 +- crates/call/src/room.rs | 114 ++++++++++++--- .../20221109000000_test_schema.sql | 1 + ...d_connection_lost_to_room_participants.sql | 2 + crates/collab/src/db.rs | 64 +++++++- crates/collab/src/db/room_participant.rs | 1 + crates/collab/src/integration_tests.rs | 30 +++- crates/collab/src/rpc.rs | 137 +++++++++--------- 8 files changed, 267 insertions(+), 92 deletions(-) create mode 100644 crates/collab/migrations/20221207165001_add_connection_lost_to_room_participants.sql diff --git a/crates/call/src/participant.rs b/crates/call/src/participant.rs index dfa456f734..d5c6d85154 100644 --- a/crates/call/src/participant.rs +++ b/crates/call/src/participant.rs @@ -4,7 +4,7 @@ use collections::HashMap; use gpui::WeakModelHandle; pub use live_kit_client::Frame; use project::Project; -use std::sync::Arc; +use std::{fmt, sync::Arc}; #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ParticipantLocation { @@ -36,7 +36,7 @@ pub struct LocalParticipant { pub active_project: Option>, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct RemoteParticipant { pub user: Arc, pub projects: Vec, @@ -49,6 +49,12 @@ pub struct RemoteVideoTrack { pub(crate) live_kit_track: Arc, } +impl fmt::Debug for RemoteVideoTrack { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RemoteVideoTrack").finish() + } +} + impl RemoteVideoTrack { pub fn frames(&self) -> async_broadcast::Receiver { self.live_kit_track.frames() diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index f8a55a3a93..828885e9bd 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -5,14 +5,18 @@ use crate::{ use anyhow::{anyhow, Result}; use client::{proto, Client, PeerId, TypedEnvelope, User, UserStore}; use collections::{BTreeMap, HashSet}; -use futures::StreamExt; -use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task}; +use futures::{FutureExt, StreamExt}; +use gpui::{ + AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle, +}; use live_kit_client::{LocalTrackPublication, LocalVideoTrack, RemoteVideoTrackUpdate}; use postage::stream::Stream; use project::Project; -use std::{mem, sync::Arc}; +use std::{mem, sync::Arc, time::Duration}; use util::{post_inc, ResultExt}; +pub const RECONNECTION_TIMEOUT: Duration = client::RECEIVE_TIMEOUT; + #[derive(Clone, Debug, PartialEq, Eq)] pub enum Event { ParticipantLocationChanged { @@ -46,6 +50,7 @@ pub struct Room { user_store: ModelHandle, subscriptions: Vec, pending_room_update: Option>, + _maintain_connection: Task>, } impl Entity for Room { @@ -66,21 +71,6 @@ impl Room { user_store: ModelHandle, cx: &mut ModelContext, ) -> Self { - let mut client_status = client.status(); - cx.spawn_weak(|this, mut cx| async move { - let is_connected = client_status - .next() - .await - .map_or(false, |s| s.is_connected()); - // Even if we're initially connected, any future change of the status means we momentarily disconnected. - if !is_connected || client_status.next().await.is_some() { - if let Some(this) = this.upgrade(&cx) { - let _ = this.update(&mut cx, |this, cx| this.leave(cx)); - } - } - }) - .detach(); - let live_kit_room = if let Some(connection_info) = live_kit_connection_info { let room = live_kit_client::Room::new(); let mut status = room.status(); @@ -131,6 +121,9 @@ impl Room { None }; + let _maintain_connection = + cx.spawn_weak(|this, cx| Self::maintain_connection(this, client.clone(), cx)); + Self { id, live_kit: live_kit_room, @@ -145,6 +138,7 @@ impl Room { pending_room_update: None, client, user_store, + _maintain_connection, } } @@ -245,6 +239,83 @@ impl Room { Ok(()) } + async fn maintain_connection( + this: WeakModelHandle, + client: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + let mut client_status = client.status(); + loop { + let is_connected = client_status + .next() + .await + .map_or(false, |s| s.is_connected()); + // Even if we're initially connected, any future change of the status means we momentarily disconnected. + if !is_connected || client_status.next().await.is_some() { + let room_id = this + .upgrade(&cx) + .ok_or_else(|| anyhow!("room was dropped"))? + .update(&mut cx, |this, cx| { + this.status = RoomStatus::Rejoining; + cx.notify(); + this.id + }); + + // Wait for client to re-establish a connection to the server. + let mut reconnection_timeout = cx.background().timer(RECONNECTION_TIMEOUT).fuse(); + let client_reconnection = async { + loop { + if let Some(status) = client_status.next().await { + if status.is_connected() { + return true; + } + } else { + return false; + } + } + } + .fuse(); + futures::pin_mut!(client_reconnection); + + futures::select_biased! { + reconnected = client_reconnection => { + if reconnected { + // Client managed to reconnect to the server. Now attempt to join the room. + let rejoin_room = async { + let response = client.request(proto::JoinRoom { id: room_id }).await?; + let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?; + this.upgrade(&cx) + .ok_or_else(|| anyhow!("room was dropped"))? + .update(&mut cx, |this, cx| { + this.status = RoomStatus::Online; + this.apply_room_update(room_proto, cx) + })?; + anyhow::Ok(()) + }; + + // If we successfully joined the room, go back around the loop + // waiting for future connection status changes. + if rejoin_room.await.log_err().is_some() { + continue; + } + } + } + _ = reconnection_timeout => {} + } + + // The client failed to re-establish a connection to the server + // or an error occurred while trying to re-join the room. Either way + // we leave the room and return an error. + if let Some(this) = this.upgrade(&cx) { + let _ = this.update(&mut cx, |this, cx| this.leave(cx)); + } + return Err(anyhow!( + "can't reconnect to room: client failed to re-establish connection" + )); + } + } + } + pub fn id(&self) -> u64 { self.id } @@ -325,9 +396,11 @@ impl Room { } if let Some(participants) = remote_participants.log_err() { + let mut participant_peer_ids = HashSet::default(); for (participant, user) in room.participants.into_iter().zip(participants) { let peer_id = PeerId(participant.peer_id); this.participant_user_ids.insert(participant.user_id); + participant_peer_ids.insert(peer_id); let old_projects = this .remote_participants @@ -394,8 +467,8 @@ impl Room { } } - this.remote_participants.retain(|_, participant| { - if this.participant_user_ids.contains(&participant.user.id) { + this.remote_participants.retain(|peer_id, participant| { + if participant_peer_ids.contains(peer_id) { true } else { for project in &participant.projects { @@ -751,6 +824,7 @@ impl Default for ScreenTrack { #[derive(Copy, Clone, PartialEq, Eq)] pub enum RoomStatus { Online, + Rejoining, Offline, } diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 68caf4fad7..4eba8d2302 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -118,6 +118,7 @@ CREATE TABLE "room_participants" ( "user_id" INTEGER NOT NULL REFERENCES users (id), "answering_connection_id" INTEGER, "answering_connection_epoch" TEXT, + "connection_lost" BOOLEAN NOT NULL, "location_kind" INTEGER, "location_project_id" INTEGER, "initial_project_id" INTEGER, diff --git a/crates/collab/migrations/20221207165001_add_connection_lost_to_room_participants.sql b/crates/collab/migrations/20221207165001_add_connection_lost_to_room_participants.sql new file mode 100644 index 0000000000..d49eda41b8 --- /dev/null +++ b/crates/collab/migrations/20221207165001_add_connection_lost_to_room_participants.sql @@ -0,0 +1,2 @@ +ALTER TABLE "room_participants" + ADD "connection_lost" BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index aae4d92964..063d82f932 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1034,6 +1034,7 @@ impl Database { user_id: ActiveValue::set(user_id), answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), answering_connection_epoch: ActiveValue::set(Some(self.epoch)), + connection_lost: ActiveValue::set(false), calling_user_id: ActiveValue::set(user_id), calling_connection_id: ActiveValue::set(connection_id.0 as i32), calling_connection_epoch: ActiveValue::set(self.epoch), @@ -1060,6 +1061,7 @@ impl Database { room_participant::ActiveModel { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(called_user_id), + connection_lost: ActiveValue::set(false), calling_user_id: ActiveValue::set(calling_user_id), calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32), calling_connection_epoch: ActiveValue::set(self.epoch), @@ -1175,11 +1177,16 @@ impl Database { room_participant::Column::RoomId .eq(room_id) .and(room_participant::Column::UserId.eq(user_id)) - .and(room_participant::Column::AnsweringConnectionId.is_null()), + .and( + room_participant::Column::AnsweringConnectionId + .is_null() + .or(room_participant::Column::ConnectionLost.eq(true)), + ), ) .set(room_participant::ActiveModel { answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), answering_connection_epoch: ActiveValue::set(Some(self.epoch)), + connection_lost: ActiveValue::set(false), ..Default::default() }) .exec(&*tx) @@ -1367,6 +1374,61 @@ impl Database { .await } + pub async fn connection_lost( + &self, + connection_id: ConnectionId, + ) -> Result>> { + self.room_transaction(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0 as i32)) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("not a participant in any room"))?; + let room_id = participant.room_id; + + room_participant::Entity::update(room_participant::ActiveModel { + connection_lost: ActiveValue::set(true), + ..participant.into_active_model() + }) + .exec(&*tx) + .await?; + + let collaborator_on_projects = project_collaborator::Entity::find() + .find_also_related(project::Entity) + .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0 as i32)) + .all(&*tx) + .await?; + project_collaborator::Entity::delete_many() + .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) + .exec(&*tx) + .await?; + + let mut left_projects = Vec::new(); + for (_, project) in collaborator_on_projects { + if let Some(project) = project { + let collaborators = project + .find_related(project_collaborator::Entity) + .all(&*tx) + .await?; + let connection_ids = collaborators + .into_iter() + .map(|collaborator| ConnectionId(collaborator.connection_id as u32)) + .collect(); + + left_projects.push(LeftProject { + id: project.id, + host_user_id: project.host_user_id, + host_connection_id: ConnectionId(project.host_connection_id as u32), + connection_ids, + }); + } + } + + Ok((room_id, left_projects)) + }) + .await + } + fn build_incoming_call( room: &proto::Room, called_user_id: UserId, diff --git a/crates/collab/src/db/room_participant.rs b/crates/collab/src/db/room_participant.rs index 783f45aa93..3ab3fbbdda 100644 --- a/crates/collab/src/db/room_participant.rs +++ b/crates/collab/src/db/room_participant.rs @@ -10,6 +10,7 @@ pub struct Model { pub user_id: UserId, pub answering_connection_id: Option, pub answering_connection_epoch: Option, + pub connection_lost: bool, pub location_kind: Option, pub location_project_id: Option, pub initial_project_id: Option, diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 96fed5887b..f31022afc4 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -365,7 +365,7 @@ async fn test_room_uniqueness( } #[gpui::test(iterations = 10)] -async fn test_leaving_room_on_disconnection( +async fn test_disconnecting_from_room( deterministic: Arc, cx_a: &mut TestAppContext, cx_b: &mut TestAppContext, @@ -414,9 +414,30 @@ async fn test_leaving_room_on_disconnection( } ); - // When user A disconnects, both client A and B clear their room on the active call. + // User A automatically reconnects to the room upon disconnection. server.disconnect_client(client_a.peer_id().unwrap()); deterministic.advance_clock(rpc::RECEIVE_TIMEOUT); + deterministic.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: Default::default() + } + ); + + // When user A disconnects, both client A and B clear their room on the active call. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + deterministic.advance_clock(rpc::RECEIVE_TIMEOUT + crate::rpc::RECONNECTION_TIMEOUT); + deterministic.run_until_parked(); active_call_a.read_with(cx_a, |call, _| assert!(call.room().is_none())); active_call_b.read_with(cx_b, |call, _| assert!(call.room().is_none())); assert_eq!( @@ -434,6 +455,11 @@ async fn test_leaving_room_on_disconnection( } ); + // Allow user A to reconnect to the server. + server.allow_connections(); + deterministic.advance_clock(rpc::RECEIVE_TIMEOUT); + deterministic.run_until_parked(); + // Call user B again from client A. active_call_a .update(cx_a, |call, cx| { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c1f9eb039b..3f70043bfb 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -51,11 +51,14 @@ use std::{ atomic::{AtomicBool, Ordering::SeqCst}, Arc, }, + time::Duration, }; use tokio::sync::{Mutex, MutexGuard}; use tower::ServiceBuilder; use tracing::{info_span, instrument, Instrument}; +pub const RECONNECTION_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT; + lazy_static! { static ref METRIC_CONNECTIONS: IntGauge = register_int_gauge!("connections", "number of connections").unwrap(); @@ -435,7 +438,7 @@ impl Server { drop(foreground_message_handlers); tracing::info!(%user_id, %login, %connection_id, %address, "signing out"); - if let Err(error) = sign_out(session).await { + if let Err(error) = sign_out(session, executor).await { tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out"); } @@ -636,29 +639,38 @@ pub async fn handle_metrics(Extension(server): Extension>) -> Result Ok(encoded_metrics) } -#[instrument(err)] -async fn sign_out(session: Session) -> Result<()> { +#[instrument(err, skip(executor))] +async fn sign_out(session: Session, executor: Executor) -> Result<()> { session.peer.disconnect(session.connection_id); - let decline_calls = { - let mut pool = session.connection_pool().await; - pool.remove_connection(session.connection_id)?; - let mut connections = pool.user_connection_ids(session.user_id); - connections.next().is_none() - }; + session + .connection_pool() + .await + .remove_connection(session.connection_id)?; - leave_room_for_session(&session).await.trace_err(); - if decline_calls { - if let Some(room) = session - .db() - .await - .decline_call(None, session.user_id) - .await - .trace_err() - { - room_updated(&room, &session); + if let Ok(mut left_projects) = session + .db() + .await + .connection_lost(session.connection_id) + .await + { + for left_project in mem::take(&mut *left_projects) { + project_left(&left_project, &session); } } + executor.sleep(RECONNECTION_TIMEOUT).await; + leave_room_for_session(&session).await.trace_err(); + + if !session + .connection_pool() + .await + .is_user_online(session.user_id) + { + let db = session.db().await; + if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() { + room_updated(&room, &session); + } + } update_user_contacts(session.user_id, &session).await?; Ok(()) @@ -1089,20 +1101,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result host_connection_id = %project.host_connection_id, "leave project" ); - - broadcast( - sender_id, - project.connection_ids.iter().copied(), - |conn_id| { - session.peer.send( - conn_id, - proto::RemoveProjectCollaborator { - project_id: project_id.to_proto(), - peer_id: sender_id.0, - }, - ) - }, - ); + project_left(&project, &session); Ok(()) } @@ -1833,40 +1832,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { contacts_to_update.insert(session.user_id); for project in left_room.left_projects.values() { - for connection_id in &project.connection_ids { - if project.host_user_id == session.user_id { - session - .peer - .send( - *connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); - } else { - session - .peer - .send( - *connection_id, - proto::RemoveProjectCollaborator { - project_id: project.id.to_proto(), - peer_id: session.connection_id.0, - }, - ) - .trace_err(); - } - } - - session - .peer - .send( - session.connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); + project_left(project, session); } room_updated(&left_room.room, &session); @@ -1906,6 +1872,43 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { Ok(()) } +fn project_left(project: &db::LeftProject, session: &Session) { + for connection_id in &project.connection_ids { + if project.host_user_id == session.user_id { + session + .peer + .send( + *connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } else { + session + .peer + .send( + *connection_id, + proto::RemoveProjectCollaborator { + project_id: project.id.to_proto(), + peer_id: session.connection_id.0, + }, + ) + .trace_err(); + } + } + + session + .peer + .send( + session.connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); +} + pub trait ResultExt { type Ok;