Re-join room when client temporarily loses connection

Antonio Scandurra 2022-12-08 12:14:12 +01:00
parent d74fb97158
commit aca3f02590
8 changed files with 267 additions and 92 deletions
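
In short: the client's Room now spawns a maintain_connection task that, after a dropped connection, marks the room Rejoining, waits up to RECONNECTION_TIMEOUT for the connection status stream to report a reconnect, and then re-requests proto::JoinRoom, leaving the room only if reconnection or the rejoin fails; the server, in turn, flags a dropped participant as connection_lost and waits the same grace period before treating it as having left. Below is a minimal, self-contained sketch of the client-side wait-with-timeout pattern; wait_for_reconnect and the channel standing in for the client status stream are illustrative names, not part of this codebase.

```rust
// Sketch only: mirrors the select_biased-based wait that Room::maintain_connection
// (added below) performs, with a plain futures channel standing in for the
// client's connection-status stream.
use futures::{channel::mpsc, FutureExt, StreamExt};

/// Returns true if the status stream reports a reconnect before `timeout`
/// fires; false on timeout or if the stream ends (the client went away).
async fn wait_for_reconnect(
    mut status: mpsc::UnboundedReceiver<bool>, // item: true = connected
    timeout: impl std::future::Future<Output = ()>,
) -> bool {
    let timeout = timeout.fuse();
    let reconnected = async {
        while let Some(connected) = status.next().await {
            if connected {
                return true;
            }
        }
        // Status stream ended: the client itself was dropped.
        false
    }
    .fuse();
    futures::pin_mut!(timeout, reconnected);
    // Bias toward the reconnection branch if both futures are ready at once.
    futures::select_biased! {
        reconnected = reconnected => reconnected,
        _ = timeout => false,
    }
}
```

In the actual change the timer side is cx.background().timer(RECONNECTION_TIMEOUT) and a successful reconnect is followed by the JoinRoom request, as the room.rs hunks below show.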

View file

@@ -4,7 +4,7 @@ use collections::HashMap;
 use gpui::WeakModelHandle;
 pub use live_kit_client::Frame;
 use project::Project;
-use std::sync::Arc;
+use std::{fmt, sync::Arc};
 
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum ParticipantLocation {
@@ -36,7 +36,7 @@ pub struct LocalParticipant {
     pub active_project: Option<WeakModelHandle<Project>>,
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct RemoteParticipant {
     pub user: Arc<User>,
     pub projects: Vec<proto::ParticipantProject>,
@@ -49,6 +49,12 @@ pub struct RemoteVideoTrack {
     pub(crate) live_kit_track: Arc<live_kit_client::RemoteVideoTrack>,
 }
 
+impl fmt::Debug for RemoteVideoTrack {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("RemoteVideoTrack").finish()
+    }
+}
+
 impl RemoteVideoTrack {
     pub fn frames(&self) -> async_broadcast::Receiver<Frame> {
         self.live_kit_track.frames()

View file

@@ -5,14 +5,18 @@ use crate::{
 use anyhow::{anyhow, Result};
 use client::{proto, Client, PeerId, TypedEnvelope, User, UserStore};
 use collections::{BTreeMap, HashSet};
-use futures::StreamExt;
-use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
+use futures::{FutureExt, StreamExt};
+use gpui::{
+    AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
+};
 use live_kit_client::{LocalTrackPublication, LocalVideoTrack, RemoteVideoTrackUpdate};
 use postage::stream::Stream;
 use project::Project;
-use std::{mem, sync::Arc};
+use std::{mem, sync::Arc, time::Duration};
 use util::{post_inc, ResultExt};
 
+pub const RECONNECTION_TIMEOUT: Duration = client::RECEIVE_TIMEOUT;
+
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub enum Event {
     ParticipantLocationChanged {
@@ -46,6 +50,7 @@ pub struct Room {
     user_store: ModelHandle<UserStore>,
     subscriptions: Vec<client::Subscription>,
     pending_room_update: Option<Task<()>>,
+    _maintain_connection: Task<Result<()>>,
 }
 
 impl Entity for Room {
@@ -66,21 +71,6 @@ impl Room {
         user_store: ModelHandle<UserStore>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        let mut client_status = client.status();
-        cx.spawn_weak(|this, mut cx| async move {
-            let is_connected = client_status
-                .next()
-                .await
-                .map_or(false, |s| s.is_connected());
-            // Even if we're initially connected, any future change of the status means we momentarily disconnected.
-            if !is_connected || client_status.next().await.is_some() {
-                if let Some(this) = this.upgrade(&cx) {
-                    let _ = this.update(&mut cx, |this, cx| this.leave(cx));
-                }
-            }
-        })
-        .detach();
-
         let live_kit_room = if let Some(connection_info) = live_kit_connection_info {
             let room = live_kit_client::Room::new();
             let mut status = room.status();
@@ -131,6 +121,9 @@ impl Room {
             None
         };
 
+        let _maintain_connection =
+            cx.spawn_weak(|this, cx| Self::maintain_connection(this, client.clone(), cx));
+
         Self {
             id,
             live_kit: live_kit_room,
@@ -145,6 +138,7 @@
             pending_room_update: None,
             client,
             user_store,
+            _maintain_connection,
         }
     }
@@ -245,6 +239,83 @@ impl Room {
         Ok(())
     }
 
+    async fn maintain_connection(
+        this: WeakModelHandle<Self>,
+        client: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        let mut client_status = client.status();
+        loop {
+            let is_connected = client_status
+                .next()
+                .await
+                .map_or(false, |s| s.is_connected());
+            // Even if we're initially connected, any future change of the status means we momentarily disconnected.
+            if !is_connected || client_status.next().await.is_some() {
+                let room_id = this
+                    .upgrade(&cx)
+                    .ok_or_else(|| anyhow!("room was dropped"))?
+                    .update(&mut cx, |this, cx| {
+                        this.status = RoomStatus::Rejoining;
+                        cx.notify();
+                        this.id
+                    });
+
+                // Wait for client to re-establish a connection to the server.
+                let mut reconnection_timeout = cx.background().timer(RECONNECTION_TIMEOUT).fuse();
+                let client_reconnection = async {
+                    loop {
+                        if let Some(status) = client_status.next().await {
+                            if status.is_connected() {
+                                return true;
+                            }
+                        } else {
+                            return false;
+                        }
+                    }
+                }
+                .fuse();
+                futures::pin_mut!(client_reconnection);
+
+                futures::select_biased! {
+                    reconnected = client_reconnection => {
+                        if reconnected {
+                            // Client managed to reconnect to the server. Now attempt to join the room.
+                            let rejoin_room = async {
+                                let response = client.request(proto::JoinRoom { id: room_id }).await?;
+                                let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
+                                this.upgrade(&cx)
+                                    .ok_or_else(|| anyhow!("room was dropped"))?
+                                    .update(&mut cx, |this, cx| {
+                                        this.status = RoomStatus::Online;
+                                        this.apply_room_update(room_proto, cx)
+                                    })?;
+                                anyhow::Ok(())
+                            };
+
+                            // If we successfully joined the room, go back around the loop
+                            // waiting for future connection status changes.
+                            if rejoin_room.await.log_err().is_some() {
+                                continue;
+                            }
+                        }
+                    }
+                    _ = reconnection_timeout => {}
+                }
+
+                // The client failed to re-establish a connection to the server
+                // or an error occurred while trying to re-join the room. Either way
+                // we leave the room and return an error.
+                if let Some(this) = this.upgrade(&cx) {
+                    let _ = this.update(&mut cx, |this, cx| this.leave(cx));
+                }
+                return Err(anyhow!(
+                    "can't reconnect to room: client failed to re-establish connection"
+                ));
+            }
+        }
+    }
+
     pub fn id(&self) -> u64 {
         self.id
     }
@@ -325,9 +396,11 @@ impl Room {
             }
 
             if let Some(participants) = remote_participants.log_err() {
+                let mut participant_peer_ids = HashSet::default();
                 for (participant, user) in room.participants.into_iter().zip(participants) {
                     let peer_id = PeerId(participant.peer_id);
                     this.participant_user_ids.insert(participant.user_id);
+                    participant_peer_ids.insert(peer_id);
 
                     let old_projects = this
                         .remote_participants
@@ -394,8 +467,8 @@
                     }
                 }
 
-                this.remote_participants.retain(|_, participant| {
-                    if this.participant_user_ids.contains(&participant.user.id) {
+                this.remote_participants.retain(|peer_id, participant| {
+                    if participant_peer_ids.contains(peer_id) {
                         true
                     } else {
                         for project in &participant.projects {
@@ -751,6 +824,7 @@ impl Default for ScreenTrack {
 #[derive(Copy, Clone, PartialEq, Eq)]
 pub enum RoomStatus {
     Online,
+    Rejoining,
     Offline,
 }

View file

@@ -118,6 +118,7 @@ CREATE TABLE "room_participants" (
     "user_id" INTEGER NOT NULL REFERENCES users (id),
     "answering_connection_id" INTEGER,
     "answering_connection_epoch" TEXT,
+    "connection_lost" BOOLEAN NOT NULL,
    "location_kind" INTEGER,
     "location_project_id" INTEGER,
     "initial_project_id" INTEGER,

View file

@@ -0,0 +1,2 @@
+ALTER TABLE "room_participants"
+ADD "connection_lost" BOOLEAN NOT NULL DEFAULT FALSE;

View file

@@ -1034,6 +1034,7 @@ impl Database {
                 user_id: ActiveValue::set(user_id),
                 answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)),
                 answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
+                connection_lost: ActiveValue::set(false),
                 calling_user_id: ActiveValue::set(user_id),
                 calling_connection_id: ActiveValue::set(connection_id.0 as i32),
                 calling_connection_epoch: ActiveValue::set(self.epoch),
@@ -1060,6 +1061,7 @@ impl Database {
             room_participant::ActiveModel {
                 room_id: ActiveValue::set(room_id),
                 user_id: ActiveValue::set(called_user_id),
+                connection_lost: ActiveValue::set(false),
                 calling_user_id: ActiveValue::set(calling_user_id),
                 calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32),
                 calling_connection_epoch: ActiveValue::set(self.epoch),
@@ -1175,11 +1177,16 @@ impl Database {
                     room_participant::Column::RoomId
                         .eq(room_id)
                         .and(room_participant::Column::UserId.eq(user_id))
-                        .and(room_participant::Column::AnsweringConnectionId.is_null()),
+                        .and(
+                            room_participant::Column::AnsweringConnectionId
+                                .is_null()
+                                .or(room_participant::Column::ConnectionLost.eq(true)),
+                        ),
                 )
                 .set(room_participant::ActiveModel {
                     answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)),
                     answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
+                    connection_lost: ActiveValue::set(false),
                     ..Default::default()
                 })
                 .exec(&*tx)
@@ -1367,6 +1374,61 @@ impl Database {
         .await
     }
 
+    pub async fn connection_lost(
+        &self,
+        connection_id: ConnectionId,
+    ) -> Result<RoomGuard<Vec<LeftProject>>> {
+        self.room_transaction(|tx| async move {
+            let participant = room_participant::Entity::find()
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0 as i32))
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("not a participant in any room"))?;
+            let room_id = participant.room_id;
+
+            room_participant::Entity::update(room_participant::ActiveModel {
+                connection_lost: ActiveValue::set(true),
+                ..participant.into_active_model()
+            })
+            .exec(&*tx)
+            .await?;
+
+            let collaborator_on_projects = project_collaborator::Entity::find()
+                .find_also_related(project::Entity)
+                .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0 as i32))
+                .all(&*tx)
+                .await?;
+            project_collaborator::Entity::delete_many()
+                .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0))
+                .exec(&*tx)
+                .await?;
+
+            let mut left_projects = Vec::new();
+            for (_, project) in collaborator_on_projects {
+                if let Some(project) = project {
+                    let collaborators = project
+                        .find_related(project_collaborator::Entity)
+                        .all(&*tx)
+                        .await?;
+                    let connection_ids = collaborators
+                        .into_iter()
+                        .map(|collaborator| ConnectionId(collaborator.connection_id as u32))
+                        .collect();
+
+                    left_projects.push(LeftProject {
+                        id: project.id,
+                        host_user_id: project.host_user_id,
+                        host_connection_id: ConnectionId(project.host_connection_id as u32),
+                        connection_ids,
+                    });
+                }
+            }
+
+            Ok((room_id, left_projects))
+        })
+        .await
+    }
+
     fn build_incoming_call(
         room: &proto::Room,
         called_user_id: UserId,

View file

@@ -10,6 +10,7 @@ pub struct Model {
     pub user_id: UserId,
     pub answering_connection_id: Option<i32>,
     pub answering_connection_epoch: Option<Uuid>,
+    pub connection_lost: bool,
     pub location_kind: Option<i32>,
     pub location_project_id: Option<ProjectId>,
     pub initial_project_id: Option<ProjectId>,

View file

@@ -365,7 +365,7 @@ async fn test_room_uniqueness(
 }
 
 #[gpui::test(iterations = 10)]
-async fn test_leaving_room_on_disconnection(
+async fn test_disconnecting_from_room(
     deterministic: Arc<Deterministic>,
     cx_a: &mut TestAppContext,
     cx_b: &mut TestAppContext,
@@ -414,9 +414,30 @@ async fn test_leaving_room_on_disconnection(
         }
     );
 
-    // When user A disconnects, both client A and B clear their room on the active call.
+    // User A automatically reconnects to the room upon disconnection.
     server.disconnect_client(client_a.peer_id().unwrap());
     deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.run_until_parked();
+    assert_eq!(
+        room_participants(&room_a, cx_a),
+        RoomParticipants {
+            remote: vec!["user_b".to_string()],
+            pending: Default::default()
+        }
+    );
+    assert_eq!(
+        room_participants(&room_b, cx_b),
+        RoomParticipants {
+            remote: vec!["user_a".to_string()],
+            pending: Default::default()
+        }
+    );
+
+    // When user A disconnects, both client A and B clear their room on the active call.
+    server.forbid_connections();
+    server.disconnect_client(client_a.peer_id().unwrap());
+    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT + crate::rpc::RECONNECTION_TIMEOUT);
+    deterministic.run_until_parked();
     active_call_a.read_with(cx_a, |call, _| assert!(call.room().is_none()));
     active_call_b.read_with(cx_b, |call, _| assert!(call.room().is_none()));
     assert_eq!(
@@ -434,6 +455,11 @@ async fn test_leaving_room_on_disconnection(
         }
     );
 
+    // Allow user A to reconnect to the server.
+    server.allow_connections();
+    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.run_until_parked();
+
     // Call user B again from client A.
     active_call_a
         .update(cx_a, |call, cx| {

View file

@@ -51,11 +51,14 @@ use std::{
         atomic::{AtomicBool, Ordering::SeqCst},
         Arc,
     },
+    time::Duration,
 };
 use tokio::sync::{Mutex, MutexGuard};
 use tower::ServiceBuilder;
 use tracing::{info_span, instrument, Instrument};
 
+pub const RECONNECTION_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;
+
 lazy_static! {
     static ref METRIC_CONNECTIONS: IntGauge =
         register_int_gauge!("connections", "number of connections").unwrap();
@@ -435,7 +438,7 @@ impl Server {
             drop(foreground_message_handlers);
             tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
-            if let Err(error) = sign_out(session).await {
+            if let Err(error) = sign_out(session, executor).await {
                 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
             }
@@ -636,29 +639,38 @@ pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result
     Ok(encoded_metrics)
 }
 
-#[instrument(err)]
-async fn sign_out(session: Session) -> Result<()> {
+#[instrument(err, skip(executor))]
+async fn sign_out(session: Session, executor: Executor) -> Result<()> {
     session.peer.disconnect(session.connection_id);
-    let decline_calls = {
-        let mut pool = session.connection_pool().await;
-        pool.remove_connection(session.connection_id)?;
-        let mut connections = pool.user_connection_ids(session.user_id);
-        connections.next().is_none()
-    };
-
-    leave_room_for_session(&session).await.trace_err();
-
-    if decline_calls {
-        if let Some(room) = session
-            .db()
-            .await
-            .decline_call(None, session.user_id)
-            .await
-            .trace_err()
-        {
-            room_updated(&room, &session);
-        }
-    }
+    session
+        .connection_pool()
+        .await
+        .remove_connection(session.connection_id)?;
+
+    if let Ok(mut left_projects) = session
+        .db()
+        .await
+        .connection_lost(session.connection_id)
+        .await
+    {
+        for left_project in mem::take(&mut *left_projects) {
+            project_left(&left_project, &session);
+        }
+    }
+
+    executor.sleep(RECONNECTION_TIMEOUT).await;
+    leave_room_for_session(&session).await.trace_err();
+
+    if !session
+        .connection_pool()
+        .await
+        .is_user_online(session.user_id)
+    {
+        let db = session.db().await;
+        if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
+            room_updated(&room, &session);
+        }
+    }
 
     update_user_contacts(session.user_id, &session).await?;
 
     Ok(())
@@ -1089,20 +1101,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
         host_connection_id = %project.host_connection_id,
         "leave project"
     );
 
-    broadcast(
-        sender_id,
-        project.connection_ids.iter().copied(),
-        |conn_id| {
-            session.peer.send(
-                conn_id,
-                proto::RemoveProjectCollaborator {
-                    project_id: project_id.to_proto(),
-                    peer_id: sender_id.0,
-                },
-            )
-        },
-    );
+    project_left(&project, &session);
 
     Ok(())
 }
@@ -1833,40 +1832,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
         contacts_to_update.insert(session.user_id);
 
         for project in left_room.left_projects.values() {
-            for connection_id in &project.connection_ids {
-                if project.host_user_id == session.user_id {
-                    session
-                        .peer
-                        .send(
-                            *connection_id,
-                            proto::UnshareProject {
-                                project_id: project.id.to_proto(),
-                            },
-                        )
-                        .trace_err();
-                } else {
-                    session
-                        .peer
-                        .send(
-                            *connection_id,
-                            proto::RemoveProjectCollaborator {
-                                project_id: project.id.to_proto(),
-                                peer_id: session.connection_id.0,
-                            },
-                        )
-                        .trace_err();
-                }
-            }
-
-            session
-                .peer
-                .send(
-                    session.connection_id,
-                    proto::UnshareProject {
-                        project_id: project.id.to_proto(),
-                    },
-                )
-                .trace_err();
+            project_left(project, session);
         }
 
         room_updated(&left_room.room, &session);
@@ -1906,6 +1872,43 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
     Ok(())
 }
 
+fn project_left(project: &db::LeftProject, session: &Session) {
+    for connection_id in &project.connection_ids {
+        if project.host_user_id == session.user_id {
+            session
+                .peer
+                .send(
+                    *connection_id,
+                    proto::UnshareProject {
+                        project_id: project.id.to_proto(),
+                    },
+                )
+                .trace_err();
+        } else {
+            session
+                .peer
+                .send(
+                    *connection_id,
+                    proto::RemoveProjectCollaborator {
+                        project_id: project.id.to_proto(),
+                        peer_id: session.connection_id.0,
+                    },
+                )
+                .trace_err();
+        }
+    }
+
+    session
+        .peer
+        .send(
+            session.connection_id,
+            proto::UnshareProject {
+                project_id: project.id.to_proto(),
+            },
+        )
+        .trace_err();
+}
+
 pub trait ResultExt {
     type Ok;