Re-send pending messages after reconnecting

This commit is contained in:
Antonio Scandurra 2021-09-16 16:23:20 +02:00
parent 4a96a5c9ff
commit 8973e250ca
8 changed files with 211 additions and 43 deletions

13
Cargo.lock generated
View file

@ -836,7 +836,7 @@ dependencies = [
"target_build_utils",
"term",
"toml 0.4.10",
"uuid",
"uuid 0.5.1",
"walkdir",
]
@ -884,7 +884,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e7fb075b9b54e939006aa12e1f6cd2d3194041ff4ebe7f2efcbedf18f25b667"
dependencies = [
"byteorder",
"uuid",
"uuid 0.5.1",
]
[[package]]
@ -2963,7 +2963,7 @@ dependencies = [
"byteorder",
"cfb",
"encoding",
"uuid",
"uuid 0.5.1",
]
[[package]]
@ -4784,6 +4784,7 @@ dependencies = [
"thiserror",
"time 0.2.25",
"url",
"uuid 0.8.2",
"webpki",
"webpki-roots",
"whoami",
@ -5606,6 +5607,12 @@ dependencies = [
"sha1 0.2.0",
]
[[package]]
name = "uuid"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
[[package]]
name = "value-bag"
version = "1.0.0-alpha.7"

View file

@ -5,6 +5,9 @@ edition = "2018"
name = "zed-server"
version = "0.1.0"
[[bin]]
name = "zed-server"
[[bin]]
name = "seed"
required-features = ["seed-support"]
@ -47,7 +50,7 @@ default-features = false
[dependencies.sqlx]
version = "0.5.2"
features = ["runtime-async-std-rustls", "postgres", "time"]
features = ["runtime-async-std-rustls", "postgres", "time", "uuid"]
[dev-dependencies]
gpui = { path = "../gpui" }

View file

@ -73,7 +73,7 @@ async fn main() {
for timestamp in timestamps {
let sender_id = *zed_user_ids.choose(&mut rng).unwrap();
let body = lipsum::lipsum_words(rng.gen_range(1..=50));
db.create_channel_message(channel_id, sender_id, &body, timestamp)
db.create_channel_message(channel_id, sender_id, &body, timestamp, rng.gen())
.await
.expect("failed to insert message");
}

View file

@ -1,7 +1,7 @@
use anyhow::Context;
use async_std::task::{block_on, yield_now};
use serde::Serialize;
use sqlx::{FromRow, Result};
use sqlx::{types::Uuid, FromRow, Result};
use time::OffsetDateTime;
pub use async_sqlx_session::PostgresSessionStore as SessionStore;
@ -402,11 +402,13 @@ impl Db {
sender_id: UserId,
body: &str,
timestamp: OffsetDateTime,
nonce: u128,
) -> Result<MessageId> {
test_support!(self, {
let query = "
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
VALUES ($1, $2, $3, $4)
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
RETURNING id
";
sqlx::query_scalar(query)
@ -414,6 +416,7 @@ impl Db {
.bind(sender_id.0)
.bind(body)
.bind(timestamp)
.bind(Uuid::from_u128(nonce))
.fetch_one(&self.pool)
.await
.map(MessageId)
@ -430,7 +433,7 @@ impl Db {
let query = r#"
SELECT * FROM (
SELECT
id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
FROM
channel_messages
WHERE
@ -514,6 +517,7 @@ pub struct ChannelMessage {
pub sender_id: UserId,
pub body: String,
pub sent_at: OffsetDateTime,
pub nonce: Uuid,
}
#[cfg(test)]
@ -677,7 +681,7 @@ pub mod tests {
let org = db.create_org("org", "org").await.unwrap();
let channel = db.create_org_channel(org, "channel").await.unwrap();
for i in 0..10 {
db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
.await
.unwrap();
}
@ -697,4 +701,34 @@ pub mod tests {
["1", "2", "3", "4"]
);
}
#[gpui::test]
async fn test_channel_message_nonces() {
let test_db = TestDb::new();
let db = test_db.db();
let user = db.create_user("user", false).await.unwrap();
let org = db.create_org("org", "org").await.unwrap();
let channel = db.create_org_channel(org, "channel").await.unwrap();
let msg1_id = db
.create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
.await
.unwrap();
let msg2_id = db
.create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
.await
.unwrap();
let msg3_id = db
.create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
.await
.unwrap();
let msg4_id = db
.create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
.await
.unwrap();
assert_ne!(msg1_id, msg2_id);
assert_eq!(msg1_id, msg3_id);
assert_eq!(msg2_id, msg4_id);
}
}

View file

@ -602,6 +602,7 @@ impl Server {
body: msg.body,
timestamp: msg.sent_at.unix_timestamp() as u64,
sender_id: msg.sender_id.to_proto(),
nonce: Some(msg.nonce.as_u128().into()),
})
.collect::<Vec<_>>();
self.peer
@ -687,10 +688,24 @@ impl Server {
}
let timestamp = OffsetDateTime::now_utc();
let nonce = if let Some(nonce) = request.payload.nonce {
nonce
} else {
self.peer
.respond_with_error(
receipt,
proto::Error {
message: "nonce can't be blank".to_string(),
},
)
.await?;
return Ok(());
};
let message_id = self
.app_state
.db
.create_channel_message(channel_id, user_id, &body, timestamp)
.create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
.await?
.to_proto();
let message = proto::ChannelMessage {
@ -698,6 +713,7 @@ impl Server {
id: message_id,
body,
timestamp: timestamp.unix_timestamp() as u64,
nonce: Some(nonce),
};
broadcast(request.sender_id, connection_ids, |conn_id| {
self.peer.send(
@ -754,6 +770,7 @@ impl Server {
body: msg.body,
timestamp: msg.sent_at.unix_timestamp() as u64,
sender_id: msg.sender_id.to_proto(),
nonce: Some(msg.nonce.as_u128().into()),
})
.collect::<Vec<_>>();
self.peer
@ -1513,6 +1530,7 @@ mod tests {
current_user_id(&user_store_b),
"hello A, it's B.",
OffsetDateTime::now_utc(),
1,
)
.await
.unwrap();
@ -1707,6 +1725,7 @@ mod tests {
current_user_id(&user_store_b),
"hello A, it's B.",
OffsetDateTime::now_utc(),
2,
)
.await
.unwrap();
@ -1787,6 +1806,24 @@ mod tests {
)
});
// Send a message from client B while it is disconnected.
channel_b
.update(&mut cx_b, |channel, cx| {
let task = channel
.send_message("can you see this?".to_string(), cx)
.unwrap();
assert_eq!(
channel_messages(channel),
&[
("user_b".to_string(), "hello A, it's B.".to_string(), false),
("user_b".to_string(), "can you see this?".to_string(), true)
]
);
task
})
.await
.unwrap_err();
// Send a message from client A while B is disconnected.
channel_a
.update(&mut cx_a, |channel, cx| {
@ -1812,7 +1849,8 @@ mod tests {
server.allow_connections();
cx_b.foreground().advance_clock(Duration::from_secs(10));
// Verify that B sees the new messages upon reconnection.
// Verify that B sees the new messages upon reconnection, as well as the message client B
// sent while offline.
channel_b
.condition(&cx_b, |channel, _| {
channel_messages(channel)
@ -1820,6 +1858,7 @@ mod tests {
("user_b".to_string(), "hello A, it's B.".to_string(), false),
("user_a".to_string(), "oh, hi B.".to_string(), false),
("user_a".to_string(), "sup".to_string(), false),
("user_b".to_string(), "can you see this?".to_string(), false),
]
})
.await;
@ -1838,6 +1877,7 @@ mod tests {
("user_b".to_string(), "hello A, it's B.".to_string(), false),
("user_a".to_string(), "oh, hi B.".to_string(), false),
("user_a".to_string(), "sup".to_string(), false),
("user_b".to_string(), "can you see this?".to_string(), false),
("user_a".to_string(), "you online?".to_string(), false),
]
})
@ -1856,6 +1896,7 @@ mod tests {
("user_b".to_string(), "hello A, it's B.".to_string(), false),
("user_a".to_string(), "oh, hi B.".to_string(), false),
("user_a".to_string(), "sup".to_string(), false),
("user_b".to_string(), "can you see this?".to_string(), false),
("user_a".to_string(), "you online?".to_string(), false),
("user_b".to_string(), "yep".to_string(), false),
]

View file

@ -9,6 +9,7 @@ use gpui::{
Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
};
use postage::prelude::Stream;
use rand::prelude::*;
use std::{
collections::{HashMap, HashSet},
mem,
@ -42,6 +43,7 @@ pub struct Channel {
next_pending_message_id: usize,
user_store: Arc<UserStore>,
rpc: Arc<Client>,
rng: StdRng,
_subscription: rpc::Subscription,
}
@ -51,6 +53,7 @@ pub struct ChannelMessage {
pub body: String,
pub timestamp: OffsetDateTime,
pub sender: Arc<User>,
pub nonce: u128,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
@ -218,6 +221,7 @@ impl Channel {
messages: Default::default(),
loaded_all_messages: false,
next_pending_message_id: 0,
rng: StdRng::from_entropy(),
_subscription,
}
}
@ -242,6 +246,7 @@ impl Channel {
let channel_id = self.details.id;
let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
let nonce = self.rng.gen();
self.insert_messages(
SumTree::from_item(
ChannelMessage {
@ -249,6 +254,7 @@ impl Channel {
body: body.clone(),
sender: current_user,
timestamp: OffsetDateTime::now_utc(),
nonce,
},
&(),
),
@ -257,7 +263,11 @@ impl Channel {
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
Ok(cx.spawn(|this, mut cx| async move {
let request = rpc.request(proto::SendChannelMessage { channel_id, body });
let request = rpc.request(proto::SendChannelMessage {
channel_id,
body,
nonce: Some(nonce.into()),
});
let response = request.await?;
let message = ChannelMessage::from_proto(
response.message.ok_or_else(|| anyhow!("invalid message"))?,
@ -265,7 +275,6 @@ impl Channel {
)
.await?;
this.update(&mut cx, |this, cx| {
this.remove_message(pending_id, cx);
this.insert_messages(SumTree::from_item(message, &()), cx);
Ok(())
})
@ -312,32 +321,51 @@ impl Channel {
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
let channel_id = self.details.id;
cx.spawn(|channel, mut cx| {
cx.spawn(|this, mut cx| {
async move {
let response = rpc.request(proto::JoinChannel { channel_id }).await?;
let messages = messages_from_proto(response.messages, &user_store).await?;
let loaded_all_messages = response.done;
channel.update(&mut cx, |channel, cx| {
let pending_messages = this.update(&mut cx, |this, cx| {
if let Some((first_new_message, last_old_message)) =
messages.first().zip(channel.messages.last())
messages.first().zip(this.messages.last())
{
if first_new_message.id > last_old_message.id {
let old_messages = mem::take(&mut channel.messages);
let old_messages = mem::take(&mut this.messages);
cx.emit(ChannelEvent::MessagesUpdated {
old_range: 0..old_messages.summary().count,
new_count: 0,
});
channel.loaded_all_messages = loaded_all_messages;
this.loaded_all_messages = loaded_all_messages;
}
}
channel.insert_messages(messages, cx);
this.insert_messages(messages, cx);
if loaded_all_messages {
channel.loaded_all_messages = loaded_all_messages;
this.loaded_all_messages = loaded_all_messages;
}
this.pending_messages().cloned().collect::<Vec<_>>()
});
for pending_message in pending_messages {
let request = rpc.request(proto::SendChannelMessage {
channel_id,
body: pending_message.body,
nonce: Some(pending_message.nonce.into()),
});
let response = request.await?;
let message = ChannelMessage::from_proto(
response.message.ok_or_else(|| anyhow!("invalid message"))?,
&user_store,
)
.await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
});
}
Ok(())
}
.log_err()
@ -365,6 +393,12 @@ impl Channel {
cursor.take(range.len())
}
pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
let mut cursor = self.messages.cursor::<ChannelMessageId, ()>();
cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
cursor
}
fn handle_message_sent(
&mut self,
message: TypedEnvelope<ChannelMessageSent>,
@ -391,29 +425,13 @@ impl Channel {
Ok(())
}
fn remove_message(&mut self, message_id: ChannelMessageId, cx: &mut ModelContext<Self>) {
let mut old_cursor = self.messages.cursor::<ChannelMessageId, Count>();
let mut new_messages = old_cursor.slice(&message_id, Bias::Left, &());
let start_ix = old_cursor.sum_start().0;
let removed_messages = old_cursor.slice(&message_id, Bias::Right, &());
let removed_count = removed_messages.summary().count;
new_messages.push_tree(old_cursor.suffix(&()), &());
drop(old_cursor);
self.messages = new_messages;
if removed_count > 0 {
let end_ix = start_ix + removed_count;
cx.emit(ChannelEvent::MessagesUpdated {
old_range: start_ix..end_ix,
new_count: 0,
});
cx.notify();
}
}
fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
let nonces = messages
.cursor::<(), ()>()
.map(|m| m.nonce)
.collect::<HashSet<_>>();
let mut old_cursor = self.messages.cursor::<ChannelMessageId, Count>();
let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
let start_ix = old_cursor.sum_start().0;
@ -423,10 +441,40 @@ impl Channel {
let end_ix = start_ix + removed_count;
new_messages.push_tree(messages, &());
new_messages.push_tree(old_cursor.suffix(&()), &());
let mut ranges = Vec::<Range<usize>>::new();
if new_messages.last().unwrap().is_pending() {
new_messages.push_tree(old_cursor.suffix(&()), &());
} else {
new_messages.push_tree(
old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
&(),
);
while let Some(message) = old_cursor.item() {
let message_ix = old_cursor.sum_start().0;
if nonces.contains(&message.nonce) {
if ranges.last().map_or(false, |r| r.end == message_ix) {
ranges.last_mut().unwrap().end += 1;
} else {
ranges.push(message_ix..message_ix + 1);
}
} else {
new_messages.push(message.clone(), &());
}
old_cursor.next(&());
}
}
drop(old_cursor);
self.messages = new_messages;
for range in ranges.into_iter().rev() {
cx.emit(ChannelEvent::MessagesUpdated {
old_range: range,
new_count: 0,
});
}
cx.emit(ChannelEvent::MessagesUpdated {
old_range: start_ix..end_ix,
new_count,
@ -477,6 +525,10 @@ impl ChannelMessage {
body: message.body,
timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
sender,
nonce: message
.nonce
.ok_or_else(|| anyhow!("nonce is required"))?
.into(),
})
}
@ -606,12 +658,14 @@ mod tests {
body: "a".into(),
timestamp: 1000,
sender_id: 5,
nonce: Some(1.into()),
},
proto::ChannelMessage {
id: 11,
body: "b".into(),
timestamp: 1001,
sender_id: 6,
nonce: Some(2.into()),
},
],
done: false,
@ -665,6 +719,7 @@ mod tests {
body: "c".into(),
timestamp: 1002,
sender_id: 7,
nonce: Some(3.into()),
}),
})
.await;
@ -720,12 +775,14 @@ mod tests {
body: "y".into(),
timestamp: 998,
sender_id: 5,
nonce: Some(4.into()),
},
proto::ChannelMessage {
id: 9,
body: "z".into(),
timestamp: 999,
sender_id: 6,
nonce: Some(5.into()),
},
],
},

View file

@ -151,6 +151,7 @@ message GetUsersResponse {
message SendChannelMessage {
uint64 channel_id = 1;
string body = 2;
Nonce nonce = 3;
}
message SendChannelMessageResponse {
@ -296,6 +297,11 @@ message Range {
uint64 end = 2;
}
message Nonce {
uint64 upper_half = 1;
uint64 lower_half = 2;
}
message Channel {
uint64 id = 1;
string name = 2;
@ -306,4 +312,5 @@ message ChannelMessage {
string body = 2;
uint64 timestamp = 3;
uint64 sender_id = 4;
Nonce nonce = 5;
}

View file

@ -248,3 +248,22 @@ impl From<SystemTime> for Timestamp {
}
}
}
impl From<u128> for Nonce {
fn from(nonce: u128) -> Self {
let upper_half = (nonce >> 64) as u64;
let lower_half = nonce as u64;
Self {
upper_half,
lower_half,
}
}
}
impl From<Nonce> for u128 {
fn from(nonce: Nonce) -> Self {
let upper_half = (nonce.upper_half as u128) << 64;
let lower_half = nonce.lower_half as u128;
upper_half | lower_half
}
}