Make database interactions deterministic in test

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2021-08-20 16:24:33 +02:00
parent 3ba530bca1
commit 98f691d16d
5 changed files with 338 additions and 285 deletions

View file

@ -14,7 +14,7 @@ use keymap::MatchResult;
use parking_lot::{Mutex, RwLock};
use pathfinder_geometry::{rect::RectF, vector::vec2f};
use platform::Event;
use postage::{mpsc, oneshot, sink::Sink as _, stream::Stream as _};
use postage::{mpsc, sink::Sink as _, stream::Stream as _};
use smol::prelude::*;
use std::{
any::{type_name, Any, TypeId},
@ -2310,24 +2310,6 @@ impl<T: Entity> ModelHandle<T> {
cx.update_model(self, update)
}
pub fn next_notification(&self, cx: &TestAppContext) -> impl Future<Output = ()> {
let (tx, mut rx) = oneshot::channel();
let mut tx = Some(tx);
let mut cx = cx.cx.borrow_mut();
self.update(&mut *cx, |_, cx| {
cx.observe(self, move |_, _, _| {
if let Some(mut tx) = tx.take() {
tx.blocking_send(()).ok();
}
});
});
async move {
rx.recv().await;
}
}
pub fn condition(
&self,
cx: &TestAppContext,

View file

@ -122,9 +122,14 @@ impl Deterministic {
smol::pin!(future);
let unparker = self.parker.lock().unparker();
let waker = waker_fn(move || {
unparker.unpark();
});
let woken = Arc::new(AtomicBool::new(false));
let waker = {
let woken = woken.clone();
waker_fn(move || {
woken.store(true, SeqCst);
unparker.unpark();
})
};
let mut cx = Context::from_waker(&waker);
let mut trace = Trace::default();
@ -166,10 +171,11 @@ impl Deterministic {
&& state.scheduled_from_background.is_empty()
&& state.spawned_from_foreground.is_empty()
{
if state.forbid_parking {
if state.forbid_parking && !woken.load(SeqCst) {
panic!("deterministic executor parked after a call to forbid_parking");
}
drop(state);
woken.store(false, SeqCst);
self.parker.lock().park();
}

View file

@ -1,3 +1,5 @@
use anyhow::Context;
use async_std::task::{block_on, yield_now};
use serde::Serialize;
use sqlx::{FromRow, Result};
use time::OffsetDateTime;
@ -5,7 +7,24 @@ use time::OffsetDateTime;
pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
pub struct Db(pub sqlx::PgPool);
macro_rules! test_support {
($self:ident, { $($token:tt)* }) => {{
let body = async {
$($token)*
};
if $self.test_mode {
yield_now().await;
block_on(body)
} else {
body.await
}
}};
}
pub struct Db {
db: sqlx::PgPool,
test_mode: bool,
}
#[derive(Debug, FromRow, Serialize)]
pub struct User {
@ -37,6 +56,33 @@ pub struct ChannelMessage {
}
impl Db {
pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
let db = DbOptions::new()
.max_connections(max_connections)
.connect(url)
.await
.context("failed to connect to postgres database")?;
Ok(Self {
db,
test_mode: false,
})
}
#[cfg(test)]
pub fn test(url: &str, max_connections: u32) -> Self {
let mut db = block_on(Self::new(url, max_connections)).unwrap();
db.test_mode = true;
db
}
#[cfg(test)]
pub fn migrate(&self, path: &std::path::Path) {
block_on(async {
let migrator = sqlx::migrate::Migrator::new(path).await.unwrap();
migrator.run(&self.db).await.unwrap();
});
}
// signups
pub async fn create_signup(
@ -45,53 +91,63 @@ impl Db {
email_address: &str,
about: &str,
) -> Result<SignupId> {
let query = "
INSERT INTO signups (github_login, email_address, about)
VALUES ($1, $2, $3)
RETURNING id
";
sqlx::query_scalar(query)
.bind(github_login)
.bind(email_address)
.bind(about)
.fetch_one(&self.0)
.await
.map(SignupId)
test_support!(self, {
let query = "
INSERT INTO signups (github_login, email_address, about)
VALUES ($1, $2, $3)
RETURNING id
";
sqlx::query_scalar(query)
.bind(github_login)
.bind(email_address)
.bind(about)
.fetch_one(&self.db)
.await
.map(SignupId)
})
}
pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
let query = "SELECT * FROM users ORDER BY github_login ASC";
sqlx::query_as(query).fetch_all(&self.0).await
test_support!(self, {
let query = "SELECT * FROM users ORDER BY github_login ASC";
sqlx::query_as(query).fetch_all(&self.db).await
})
}
pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
let query = "DELETE FROM signups WHERE id = $1";
sqlx::query(query)
.bind(id.0)
.execute(&self.0)
.await
.map(drop)
test_support!(self, {
let query = "DELETE FROM signups WHERE id = $1";
sqlx::query(query)
.bind(id.0)
.execute(&self.db)
.await
.map(drop)
})
}
// users
pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
let query = "
INSERT INTO users (github_login, admin)
VALUES ($1, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(github_login)
.bind(admin)
.fetch_one(&self.0)
.await
.map(UserId)
test_support!(self, {
let query = "
INSERT INTO users (github_login, admin)
VALUES ($1, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(github_login)
.bind(admin)
.fetch_one(&self.db)
.await
.map(UserId)
})
}
pub async fn get_all_users(&self) -> Result<Vec<User>> {
let query = "SELECT * FROM users ORDER BY github_login ASC";
sqlx::query_as(query).fetch_all(&self.0).await
test_support!(self, {
let query = "SELECT * FROM users ORDER BY github_login ASC";
sqlx::query_as(query).fetch_all(&self.db).await
})
}
pub async fn get_users_by_ids(
@ -99,53 +155,61 @@ impl Db {
requester_id: UserId,
ids: impl Iterator<Item = UserId>,
) -> Result<Vec<User>> {
// Only return users that are in a common channel with the requesting user.
let query = "
SELECT users.*
FROM
users, channel_memberships
WHERE
users.id IN $1 AND
channel_memberships.user_id = users.id AND
channel_memberships.channel_id IN (
SELECT channel_id
FROM channel_memberships
WHERE channel_memberships.user_id = $2
)
";
test_support!(self, {
// Only return users that are in a common channel with the requesting user.
let query = "
SELECT users.*
FROM
users, channel_memberships
WHERE
users.id IN $1 AND
channel_memberships.user_id = users.id AND
channel_memberships.channel_id IN (
SELECT channel_id
FROM channel_memberships
WHERE channel_memberships.user_id = $2
)
";
sqlx::query_as(query)
.bind(&ids.map(|id| id.0).collect::<Vec<_>>())
.bind(requester_id)
.fetch_all(&self.0)
.await
sqlx::query_as(query)
.bind(&ids.map(|id| id.0).collect::<Vec<_>>())
.bind(requester_id)
.fetch_all(&self.db)
.await
})
}
pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
sqlx::query_as(query)
.bind(github_login)
.fetch_optional(&self.0)
.await
test_support!(self, {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
sqlx::query_as(query)
.bind(github_login)
.fetch_optional(&self.db)
.await
})
}
pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
let query = "UPDATE users SET admin = $1 WHERE id = $2";
sqlx::query(query)
.bind(is_admin)
.bind(id.0)
.execute(&self.0)
.await
.map(drop)
test_support!(self, {
let query = "UPDATE users SET admin = $1 WHERE id = $2";
sqlx::query(query)
.bind(is_admin)
.bind(id.0)
.execute(&self.db)
.await
.map(drop)
})
}
pub async fn delete_user(&self, id: UserId) -> Result<()> {
let query = "DELETE FROM users WHERE id = $1;";
sqlx::query(query)
.bind(id.0)
.execute(&self.0)
.await
.map(drop)
test_support!(self, {
let query = "DELETE FROM users WHERE id = $1;";
sqlx::query(query)
.bind(id.0)
.execute(&self.db)
.await
.map(drop)
})
}
// access tokens
@ -155,41 +219,47 @@ impl Db {
user_id: UserId,
access_token_hash: String,
) -> Result<()> {
let query = "
test_support!(self, {
let query = "
INSERT INTO access_tokens (user_id, hash)
VALUES ($1, $2)
";
sqlx::query(query)
.bind(user_id.0)
.bind(access_token_hash)
.execute(&self.0)
.await
.map(drop)
sqlx::query(query)
.bind(user_id.0)
.bind(access_token_hash)
.execute(&self.db)
.await
.map(drop)
})
}
pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
sqlx::query_scalar(query)
.bind(user_id.0)
.fetch_all(&self.0)
.await
test_support!(self, {
let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
sqlx::query_scalar(query)
.bind(user_id.0)
.fetch_all(&self.db)
.await
})
}
// orgs
#[cfg(test)]
pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
let query = "
INSERT INTO orgs (name, slug)
VALUES ($1, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(name)
.bind(slug)
.fetch_one(&self.0)
.await
.map(OrgId)
test_support!(self, {
let query = "
INSERT INTO orgs (name, slug)
VALUES ($1, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(name)
.bind(slug)
.fetch_one(&self.db)
.await
.map(OrgId)
})
}
#[cfg(test)]
@ -199,50 +269,56 @@ impl Db {
user_id: UserId,
is_admin: bool,
) -> Result<()> {
let query = "
INSERT INTO org_memberships (org_id, user_id, admin)
VALUES ($1, $2, $3)
";
sqlx::query(query)
.bind(org_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.0)
.await
.map(drop)
test_support!(self, {
let query = "
INSERT INTO org_memberships (org_id, user_id, admin)
VALUES ($1, $2, $3)
";
sqlx::query(query)
.bind(org_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.db)
.await
.map(drop)
})
}
// channels
#[cfg(test)]
pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
let query = "
INSERT INTO channels (owner_id, owner_is_user, name)
VALUES ($1, false, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(org_id.0)
.bind(name)
.fetch_one(&self.0)
.await
.map(ChannelId)
test_support!(self, {
let query = "
INSERT INTO channels (owner_id, owner_is_user, name)
VALUES ($1, false, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(org_id.0)
.bind(name)
.fetch_one(&self.db)
.await
.map(ChannelId)
})
}
pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
let query = "
SELECT
channels.id, channels.name
FROM
channel_memberships, channels
WHERE
channel_memberships.user_id = $1 AND
channel_memberships.channel_id = channels.id
";
sqlx::query_as(query)
.bind(user_id.0)
.fetch_all(&self.0)
.await
test_support!(self, {
let query = "
SELECT
channels.id, channels.name
FROM
channel_memberships, channels
WHERE
channel_memberships.user_id = $1 AND
channel_memberships.channel_id = channels.id
";
sqlx::query_as(query)
.bind(user_id.0)
.fetch_all(&self.db)
.await
})
}
pub async fn can_user_access_channel(
@ -250,18 +326,20 @@ impl Db {
user_id: UserId,
channel_id: ChannelId,
) -> Result<bool> {
let query = "
SELECT id
FROM channel_memberships
WHERE user_id = $1 AND channel_id = $2
LIMIT 1
";
sqlx::query_scalar::<_, i32>(query)
.bind(user_id.0)
.bind(channel_id.0)
.fetch_optional(&self.0)
.await
.map(|e| e.is_some())
test_support!(self, {
let query = "
SELECT id
FROM channel_memberships
WHERE user_id = $1 AND channel_id = $2
LIMIT 1
";
sqlx::query_scalar::<_, i32>(query)
.bind(user_id.0)
.bind(channel_id.0)
.fetch_optional(&self.db)
.await
.map(|e| e.is_some())
})
}
#[cfg(test)]
@ -271,17 +349,19 @@ impl Db {
user_id: UserId,
is_admin: bool,
) -> Result<()> {
let query = "
INSERT INTO channel_memberships (channel_id, user_id, admin)
VALUES ($1, $2, $3)
";
sqlx::query(query)
.bind(channel_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.0)
.await
.map(drop)
test_support!(self, {
let query = "
INSERT INTO channel_memberships (channel_id, user_id, admin)
VALUES ($1, $2, $3)
";
sqlx::query(query)
.bind(channel_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.db)
.await
.map(drop)
})
}
// messages
@ -293,19 +373,21 @@ impl Db {
body: &str,
timestamp: OffsetDateTime,
) -> Result<MessageId> {
let query = "
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
VALUES ($1, $2, $3, $4)
RETURNING id
";
sqlx::query_scalar(query)
.bind(channel_id.0)
.bind(sender_id.0)
.bind(body)
.bind(timestamp)
.fetch_one(&self.0)
.await
.map(MessageId)
test_support!(self, {
let query = "
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
VALUES ($1, $2, $3, $4)
RETURNING id
";
sqlx::query_scalar(query)
.bind(channel_id.0)
.bind(sender_id.0)
.bind(body)
.bind(timestamp)
.fetch_one(&self.db)
.await
.map(MessageId)
})
}
pub async fn get_recent_channel_messages(
@ -313,35 +395,39 @@ impl Db {
channel_id: ChannelId,
count: usize,
) -> Result<Vec<ChannelMessage>> {
let query = r#"
SELECT
id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
FROM
channel_messages
WHERE
channel_id = $1
LIMIT $2
"#;
sqlx::query_as(query)
.bind(channel_id.0)
.bind(count as i64)
.fetch_all(&self.0)
.await
test_support!(self, {
let query = r#"
SELECT
id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
FROM
channel_messages
WHERE
channel_id = $1
LIMIT $2
"#;
sqlx::query_as(query)
.bind(channel_id.0)
.bind(count as i64)
.fetch_all(&self.db)
.await
})
}
#[cfg(test)]
pub async fn close(&self, db_name: &str) {
let query = "
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
";
sqlx::query(query)
.bind(db_name)
.execute(&self.0)
.await
.unwrap();
self.0.close().await;
test_support!(self, {
let query = "
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
";
sqlx::query(query)
.bind(db_name)
.execute(&self.db)
.await
.unwrap();
self.db.close().await;
})
}
}

View file

@ -11,11 +11,11 @@ mod rpc;
mod team;
use self::errors::TideResultExt as _;
use anyhow::{Context, Result};
use anyhow::Result;
use async_std::net::TcpListener;
use async_trait::async_trait;
use auth::RequestExt as _;
use db::{Db, DbOptions};
use db::Db;
use handlebars::{Handlebars, TemplateRenderError};
use parking_lot::RwLock;
use rust_embed::RustEmbed;
@ -54,12 +54,7 @@ pub struct AppState {
impl AppState {
async fn new(config: Config) -> tide::Result<Arc<Self>> {
let db = Db(DbOptions::new()
.max_connections(5)
.connect(&config.database_url)
.await
.context("failed to connect to postgres database")?);
let db = Db::new(&config.database_url, 5).await?;
let github_client =
github::AppClient::new(config.github_app_id, config.github_private_key.clone());
let repo_client = github_client

View file

@ -922,16 +922,15 @@ mod tests {
db::{self, UserId},
github, AppState, Config,
};
use async_std::{sync::RwLockReadGuard, task};
use gpui::{ModelHandle, TestAppContext};
use async_std::{
sync::RwLockReadGuard,
task::{self, block_on},
};
use gpui::TestAppContext;
use postage::mpsc;
use rand::prelude::*;
use serde_json::json;
use sqlx::{
migrate::{MigrateDatabase, Migrator},
types::time::OffsetDateTime,
Postgres,
};
use sqlx::{migrate::MigrateDatabase, types::time::OffsetDateTime, Postgres};
use std::{path::Path, sync::Arc, time::Duration};
use zed::{
channel::{Channel, ChannelDetails, ChannelList},
@ -1400,6 +1399,8 @@ mod tests {
#[gpui::test]
async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
let mut server = TestServer::start().await;
let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
@ -1444,11 +1445,12 @@ mod tests {
this.get_channel(channel_id.to_proto(), cx).unwrap()
});
channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
channel_a.next_notification(&cx_a).await;
assert_eq!(
channel_messages(&channel_a, &cx_a),
&[(user_id_b.to_proto(), "hello A, it's B.".to_string())]
);
channel_a
.condition(&cx_a, |channel, _| {
channel_messages(channel)
== [(user_id_b.to_proto(), "hello A, it's B.".to_string())]
})
.await;
let channels_b = ChannelList::new(client_b, &mut cx_b.to_async())
.await
@ -1462,15 +1464,17 @@ mod tests {
}]
)
});
let channel_b = channels_b.update(&mut cx_b, |this, cx| {
this.get_channel(channel_id.to_proto(), cx).unwrap()
});
channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
channel_b.next_notification(&cx_b).await;
assert_eq!(
channel_messages(&channel_b, &cx_b),
&[(user_id_b.to_proto(), "hello A, it's B.".to_string())]
);
channel_b
.condition(&cx_b, |channel, _| {
channel_messages(channel)
== [(user_id_b.to_proto(), "hello A, it's B.".to_string())]
})
.await;
channel_a.update(&mut cx_a, |channel, cx| {
channel.send_message("oh, hi B.".to_string(), cx).unwrap();
@ -1484,24 +1488,20 @@ mod tests {
&["oh, hi B.", "sup"]
)
});
channel_a.next_notification(&cx_a).await;
channel_a.read_with(&cx_a, |channel, _| {
assert_eq!(channel.pending_messages().len(), 1);
});
channel_a.next_notification(&cx_a).await;
channel_a.read_with(&cx_a, |channel, _| {
assert_eq!(channel.pending_messages().len(), 0);
});
channel_b.next_notification(&cx_b).await;
assert_eq!(
channel_messages(&channel_b, &cx_b),
&[
(user_id_b.to_proto(), "hello A, it's B.".to_string()),
(user_id_a.to_proto(), "oh, hi B.".to_string()),
(user_id_a.to_proto(), "sup".to_string()),
]
);
channel_a
.condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
.await;
channel_b
.condition(&cx_b, |channel, _| {
channel_messages(channel)
== [
(user_id_b.to_proto(), "hello A, it's B.".to_string()),
(user_id_a.to_proto(), "oh, hi B.".to_string()),
(user_id_a.to_proto(), "sup".to_string()),
]
})
.await;
assert_eq!(
server.state().await.channels[&channel_id]
@ -1519,17 +1519,12 @@ mod tests {
.condition(|state| !state.channels.contains_key(&channel_id))
.await;
fn channel_messages(
channel: &ModelHandle<Channel>,
cx: &TestAppContext,
) -> Vec<(u64, String)> {
channel.read_with(cx, |channel, _| {
channel
.messages()
.iter()
.map(|m| (m.sender_id, m.body.clone()))
.collect()
})
fn channel_messages(channel: &Channel) -> Vec<(u64, String)> {
channel
.messages()
.iter()
.map(|m| (m.sender_id, m.body.clone()))
.collect()
}
}
@ -1584,21 +1579,12 @@ mod tests {
config.session_secret = "a".repeat(32);
config.database_url = format!("postgres://postgres@localhost/{}", db_name);
Self::create_db(&config.database_url).await;
let db = db::Db(
db::DbOptions::new()
.max_connections(5)
.connect(&config.database_url)
.await
.expect("failed to connect to postgres database"),
);
let migrator = Migrator::new(Path::new(concat!(
Self::create_db(&config.database_url);
let db = db::Db::test(&config.database_url, 5);
db.migrate(Path::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/migrations"
)))
.await
.unwrap();
migrator.run(&db.0).await.unwrap();
)));
let github_client = github::AppClient::test();
Arc::new(AppState {
@ -1611,16 +1597,14 @@ mod tests {
})
}
async fn create_db(url: &str) {
fn create_db(url: &str) {
// Enable tests to run in parallel by serializing the creation of each test database.
lazy_static::lazy_static! {
static ref DB_CREATION: async_std::sync::Mutex<()> = async_std::sync::Mutex::new(());
static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
}
let _lock = DB_CREATION.lock().await;
Postgres::create_database(url)
.await
.expect("failed to create test database");
let _lock = DB_CREATION.lock();
block_on(Postgres::create_database(url)).expect("failed to create test database");
}
async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {