Fix possibility of extra mention insertion on nonce collision

This commit is contained in:
Max Brunsfeld 2023-10-18 18:04:54 -07:00
parent b07f9fe3b5
commit ac54d2b927
6 changed files with 204 additions and 171 deletions

View file

@ -214,7 +214,7 @@ CREATE TABLE IF NOT EXISTS "channel_messages" (
"nonce" BLOB NOT NULL "nonce" BLOB NOT NULL
); );
CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce");
CREATE TABLE "channel_message_mentions" ( CREATE TABLE "channel_message_mentions" (
"message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE,

View file

@ -5,3 +5,7 @@ CREATE TABLE "channel_message_mentions" (
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
PRIMARY KEY(message_id, start_offset) PRIMARY KEY(message_id, start_offset)
); );
-- We use 'on conflict update' with this index, so it should be per-user.
CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce");
DROP INDEX "index_channel_messages_on_nonce";

View file

@ -1,4 +1,5 @@
use super::*; use super::*;
use sea_orm::TryInsertResult;
use time::OffsetDateTime; use time::OffsetDateTime;
impl Database { impl Database {
@ -184,7 +185,7 @@ impl Database {
let timestamp = timestamp.to_offset(time::UtcOffset::UTC); let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time()); let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
let message_id = channel_message::Entity::insert(channel_message::ActiveModel { let result = channel_message::Entity::insert(channel_message::ActiveModel {
channel_id: ActiveValue::Set(channel_id), channel_id: ActiveValue::Set(channel_id),
sender_id: ActiveValue::Set(user_id), sender_id: ActiveValue::Set(user_id),
body: ActiveValue::Set(body.to_string()), body: ActiveValue::Set(body.to_string()),
@ -193,46 +194,57 @@ impl Database {
id: ActiveValue::NotSet, id: ActiveValue::NotSet,
}) })
.on_conflict( .on_conflict(
OnConflict::column(channel_message::Column::Nonce) OnConflict::columns([
.update_column(channel_message::Column::Nonce) channel_message::Column::SenderId,
.to_owned(), channel_message::Column::Nonce,
])
.do_nothing()
.to_owned(),
) )
.do_nothing()
.exec(&*tx) .exec(&*tx)
.await? .await?;
.last_insert_id;
let models = mentions let message_id;
.iter() match result {
.filter_map(|mention| { TryInsertResult::Inserted(result) => {
let range = mention.range.as_ref()?; message_id = result.last_insert_id;
if !body.is_char_boundary(range.start as usize) let models = mentions
|| !body.is_char_boundary(range.end as usize) .iter()
{ .filter_map(|mention| {
return None; let range = mention.range.as_ref()?;
if !body.is_char_boundary(range.start as usize)
|| !body.is_char_boundary(range.end as usize)
{
return None;
}
Some(channel_message_mention::ActiveModel {
message_id: ActiveValue::Set(message_id),
start_offset: ActiveValue::Set(range.start as i32),
end_offset: ActiveValue::Set(range.end as i32),
user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
})
})
.collect::<Vec<_>>();
if !models.is_empty() {
channel_message_mention::Entity::insert_many(models)
.exec(&*tx)
.await?;
} }
Some(channel_message_mention::ActiveModel {
message_id: ActiveValue::Set(message_id),
start_offset: ActiveValue::Set(range.start as i32),
end_offset: ActiveValue::Set(range.end as i32),
user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
})
})
.collect::<Vec<_>>();
if !models.is_empty() {
channel_message_mention::Entity::insert_many(models)
.exec(&*tx)
.await?;
}
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
enum QueryConnectionId { .await?;
ConnectionId, }
_ => {
message_id = channel_message::Entity::find()
.filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("failed to insert message"))?
.id;
}
} }
// Observe this message for the sender
self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
.await?;
let mut channel_members = self let mut channel_members = self
.get_channel_participants_internal(channel_id, &*tx) .get_channel_participants_internal(channel_id, &*tx)
.await?; .await?;

View file

@ -10,7 +10,10 @@ use parking_lot::Mutex;
use rpc::proto::ChannelEdge; use rpc::proto::ChannelEdge;
use sea_orm::ConnectionTrait; use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase; use sqlx::migrate::MigrateDatabase;
use std::sync::Arc; use std::sync::{
atomic::{AtomicI32, Ordering::SeqCst},
Arc,
};
const TEST_RELEASE_CHANNEL: &'static str = "test"; const TEST_RELEASE_CHANNEL: &'static str = "test";
@ -174,3 +177,19 @@ fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId)
graph graph
} }
static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
db.create_user(
email,
false,
NewUserParams {
github_login: email[0..email.find("@").unwrap()].to_string(),
github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
},
)
.await
.unwrap()
.user_id
}

View file

@ -1,21 +1,17 @@
use crate::{
db::{
queries::channels::ChannelGraph,
tests::{graph, new_test_user, TEST_RELEASE_CHANNEL},
ChannelId, ChannelRole, Database, NewUserParams, RoomId,
},
test_both_dbs,
};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use rpc::{ use rpc::{
proto::{self}, proto::{self},
ConnectionId, ConnectionId,
}; };
use std::sync::Arc;
use crate::{
db::{
queries::channels::ChannelGraph,
tests::{graph, TEST_RELEASE_CHANNEL},
ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId,
},
test_both_dbs,
};
use std::sync::{
atomic::{AtomicI32, Ordering},
Arc,
};
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite); test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
@ -1105,19 +1101,3 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option<ChannelId>)])
pretty_assertions::assert_eq!(actual_map, expected_map) pretty_assertions::assert_eq!(actual_map, expected_map)
} }
static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
db.create_user(
email,
false,
NewUserParams {
github_login: email[0..email.find("@").unwrap()].to_string(),
github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst),
},
)
.await
.unwrap()
.user_id
}

View file

@ -1,5 +1,6 @@
use super::new_test_user;
use crate::{ use crate::{
db::{ChannelRole, Database, MessageId, NewUserParams}, db::{ChannelRole, Database, MessageId},
test_both_dbs, test_both_dbs,
}; };
use channel::mentions_to_proto; use channel::mentions_to_proto;
@ -13,18 +14,7 @@ test_both_dbs!(
); );
async fn test_channel_message_retrieval(db: &Arc<Database>) { async fn test_channel_message_retrieval(db: &Arc<Database>) {
let user = db let user = new_test_user(db, "user@example.com").await;
.create_user(
"user@example.com",
false,
NewUserParams {
github_login: "user".into(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let channel = db.create_channel("channel", None, user).await.unwrap(); let channel = db.create_channel("channel", None, user).await.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32; let owner_id = db.create_server("test").await.unwrap().0 as u32;
@ -81,46 +71,129 @@ test_both_dbs!(
); );
async fn test_channel_message_nonces(db: &Arc<Database>) { async fn test_channel_message_nonces(db: &Arc<Database>) {
let user = db let user_a = new_test_user(db, "user_a@example.com").await;
.create_user( let user_b = new_test_user(db, "user_b@example.com").await;
"user@example.com", let user_c = new_test_user(db, "user_c@example.com").await;
false, let channel = db.create_channel("channel", None, user_a).await.unwrap();
NewUserParams { db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
github_login: "user".into(), .await
github_user_id: 1, .unwrap();
}, db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member)
.await
.unwrap();
db.respond_to_channel_invite(channel, user_b, true)
.await
.unwrap();
db.respond_to_channel_invite(channel, user_c, true)
.await
.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a)
.await
.unwrap();
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b)
.await
.unwrap();
// As user A, create messages that re-use the same nonces. The requests
// succeed, but return the same ids.
let id1 = db
.create_channel_message(
channel,
user_a,
"hi @user_b",
&mentions_to_proto(&[(3..10, user_b.to_proto())]),
OffsetDateTime::now_utc(),
100,
) )
.await .await
.unwrap() .unwrap()
.user_id; .0;
let channel = db.create_channel("channel", None, user).await.unwrap(); let id2 = db
.create_channel_message(
channel,
user_a,
"hello, fellow users",
&mentions_to_proto(&[]),
OffsetDateTime::now_utc(),
200,
)
.await
.unwrap()
.0;
let id3 = db
.create_channel_message(
channel,
user_a,
"bye @user_c (same nonce as first message)",
&mentions_to_proto(&[(4..11, user_c.to_proto())]),
OffsetDateTime::now_utc(),
100,
)
.await
.unwrap()
.0;
let id4 = db
.create_channel_message(
channel,
user_a,
"omg (same nonce as second message)",
&mentions_to_proto(&[]),
OffsetDateTime::now_utc(),
200,
)
.await
.unwrap()
.0;
let owner_id = db.create_server("test").await.unwrap().0 as u32; // As a different user, reuse one of the same nonces. This request succeeds
// and returns a different id.
let id5 = db
.create_channel_message(
channel,
user_b,
"omg @user_a (same nonce as user_a's first message)",
&mentions_to_proto(&[(4..11, user_a.to_proto())]),
OffsetDateTime::now_utc(),
100,
)
.await
.unwrap()
.0;
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) assert_ne!(id1, id2);
.await assert_eq!(id1, id3);
.unwrap(); assert_eq!(id2, id4);
assert_ne!(id5, id1);
let msg1_id = db let messages = db
.create_channel_message(channel, user, "1", &[], OffsetDateTime::now_utc(), 1) .get_channel_messages(channel, user_a, 5, None)
.await .await
.unwrap(); .unwrap()
let msg2_id = db .into_iter()
.create_channel_message(channel, user, "2", &[], OffsetDateTime::now_utc(), 2) .map(|m| (m.id, m.body, m.mentions))
.await .collect::<Vec<_>>();
.unwrap(); assert_eq!(
let msg3_id = db messages,
.create_channel_message(channel, user, "3", &[], OffsetDateTime::now_utc(), 1) &[
.await (
.unwrap(); id1.to_proto(),
let msg4_id = db "hi @user_b".into(),
.create_channel_message(channel, user, "4", &[], OffsetDateTime::now_utc(), 2) mentions_to_proto(&[(3..10, user_b.to_proto())]),
.await ),
.unwrap(); (
id2.to_proto(),
assert_ne!(msg1_id, msg2_id); "hello, fellow users".into(),
assert_eq!(msg1_id, msg3_id); mentions_to_proto(&[])
assert_eq!(msg2_id, msg4_id); ),
(
id5.to_proto(),
"omg @user_a (same nonce as user_a's first message)".into(),
mentions_to_proto(&[(4..11, user_a.to_proto())]),
),
]
);
} }
test_both_dbs!( test_both_dbs!(
@ -130,30 +203,8 @@ test_both_dbs!(
); );
async fn test_unseen_channel_messages(db: &Arc<Database>) { async fn test_unseen_channel_messages(db: &Arc<Database>) {
let user = db let user = new_test_user(db, "user_a@example.com").await;
.create_user( let observer = new_test_user(db, "user_b@example.com").await;
"user_a@example.com",
false,
NewUserParams {
github_login: "user_a".into(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let observer = db
.create_user(
"user_b@example.com",
false,
NewUserParams {
github_login: "user_b".into(),
github_user_id: 2,
},
)
.await
.unwrap()
.user_id;
let channel_1 = db.create_channel("channel", None, user).await.unwrap(); let channel_1 = db.create_channel("channel", None, user).await.unwrap();
let channel_2 = db.create_channel("channel-2", None, user).await.unwrap(); let channel_2 = db.create_channel("channel-2", None, user).await.unwrap();
@ -304,42 +355,9 @@ test_both_dbs!(
); );
async fn test_channel_message_mentions(db: &Arc<Database>) { async fn test_channel_message_mentions(db: &Arc<Database>) {
let user_a = db let user_a = new_test_user(db, "user_a@example.com").await;
.create_user( let user_b = new_test_user(db, "user_b@example.com").await;
"user_a@example.com", let user_c = new_test_user(db, "user_c@example.com").await;
false,
NewUserParams {
github_login: "user_a".into(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let user_b = db
.create_user(
"user_b@example.com",
false,
NewUserParams {
github_login: "user_b".into(),
github_user_id: 2,
},
)
.await
.unwrap()
.user_id;
let user_c = db
.create_user(
"user_b@example.com",
false,
NewUserParams {
github_login: "user_c".into(),
github_user_id: 3,
},
)
.await
.unwrap()
.user_id;
let channel = db.create_channel("channel", None, user_a).await.unwrap(); let channel = db.create_channel("channel", None, user_a).await.unwrap();
db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)