From 98f691d16d01c5680ecd41021c7853b5a2dc22a7 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 20 Aug 2021 16:24:33 +0200 Subject: [PATCH] Make database interactions deterministic in test Co-Authored-By: Nathan Sobo --- gpui/src/app.rs | 20 +- gpui/src/executor.rs | 14 +- server/src/db.rs | 466 +++++++++++++++++++++++++------------------ server/src/main.rs | 11 +- server/src/rpc.rs | 112 +++++------ 5 files changed, 338 insertions(+), 285 deletions(-) diff --git a/gpui/src/app.rs b/gpui/src/app.rs index 221419341e..a2fa818824 100644 --- a/gpui/src/app.rs +++ b/gpui/src/app.rs @@ -14,7 +14,7 @@ use keymap::MatchResult; use parking_lot::{Mutex, RwLock}; use pathfinder_geometry::{rect::RectF, vector::vec2f}; use platform::Event; -use postage::{mpsc, oneshot, sink::Sink as _, stream::Stream as _}; +use postage::{mpsc, sink::Sink as _, stream::Stream as _}; use smol::prelude::*; use std::{ any::{type_name, Any, TypeId}, @@ -2310,24 +2310,6 @@ impl ModelHandle { cx.update_model(self, update) } - pub fn next_notification(&self, cx: &TestAppContext) -> impl Future { - let (tx, mut rx) = oneshot::channel(); - let mut tx = Some(tx); - - let mut cx = cx.cx.borrow_mut(); - self.update(&mut *cx, |_, cx| { - cx.observe(self, move |_, _, _| { - if let Some(mut tx) = tx.take() { - tx.blocking_send(()).ok(); - } - }); - }); - - async move { - rx.recv().await; - } - } - pub fn condition( &self, cx: &TestAppContext, diff --git a/gpui/src/executor.rs b/gpui/src/executor.rs index 78cb77c6b9..c848cff9c5 100644 --- a/gpui/src/executor.rs +++ b/gpui/src/executor.rs @@ -122,9 +122,14 @@ impl Deterministic { smol::pin!(future); let unparker = self.parker.lock().unparker(); - let waker = waker_fn(move || { - unparker.unpark(); - }); + let woken = Arc::new(AtomicBool::new(false)); + let waker = { + let woken = woken.clone(); + waker_fn(move || { + woken.store(true, SeqCst); + unparker.unpark(); + }) + }; let mut cx = Context::from_waker(&waker); let mut trace = Trace::default(); @@ -166,10 +171,11 @@ impl Deterministic { && state.scheduled_from_background.is_empty() && state.spawned_from_foreground.is_empty() { - if state.forbid_parking { + if state.forbid_parking && !woken.load(SeqCst) { panic!("deterministic executor parked after a call to forbid_parking"); } drop(state); + woken.store(false, SeqCst); self.parker.lock().park(); } diff --git a/server/src/db.rs b/server/src/db.rs index 8351b9d224..1e489aae36 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -1,3 +1,5 @@ +use anyhow::Context; +use async_std::task::{block_on, yield_now}; use serde::Serialize; use sqlx::{FromRow, Result}; use time::OffsetDateTime; @@ -5,7 +7,24 @@ use time::OffsetDateTime; pub use async_sqlx_session::PostgresSessionStore as SessionStore; pub use sqlx::postgres::PgPoolOptions as DbOptions; -pub struct Db(pub sqlx::PgPool); +macro_rules! test_support { + ($self:ident, { $($token:tt)* }) => {{ + let body = async { + $($token)* + }; + if $self.test_mode { + yield_now().await; + block_on(body) + } else { + body.await + } + }}; +} + +pub struct Db { + db: sqlx::PgPool, + test_mode: bool, +} #[derive(Debug, FromRow, Serialize)] pub struct User { @@ -37,6 +56,33 @@ pub struct ChannelMessage { } impl Db { + pub async fn new(url: &str, max_connections: u32) -> tide::Result { + let db = DbOptions::new() + .max_connections(max_connections) + .connect(url) + .await + .context("failed to connect to postgres database")?; + Ok(Self { + db, + test_mode: false, + }) + } + + #[cfg(test)] + pub fn test(url: &str, max_connections: u32) -> Self { + let mut db = block_on(Self::new(url, max_connections)).unwrap(); + db.test_mode = true; + db + } + + #[cfg(test)] + pub fn migrate(&self, path: &std::path::Path) { + block_on(async { + let migrator = sqlx::migrate::Migrator::new(path).await.unwrap(); + migrator.run(&self.db).await.unwrap(); + }); + } + // signups pub async fn create_signup( @@ -45,53 +91,63 @@ impl Db { email_address: &str, about: &str, ) -> Result { - let query = " - INSERT INTO signups (github_login, email_address, about) - VALUES ($1, $2, $3) - RETURNING id - "; - sqlx::query_scalar(query) - .bind(github_login) - .bind(email_address) - .bind(about) - .fetch_one(&self.0) - .await - .map(SignupId) + test_support!(self, { + let query = " + INSERT INTO signups (github_login, email_address, about) + VALUES ($1, $2, $3) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(github_login) + .bind(email_address) + .bind(about) + .fetch_one(&self.db) + .await + .map(SignupId) + }) } pub async fn get_all_signups(&self) -> Result> { - let query = "SELECT * FROM users ORDER BY github_login ASC"; - sqlx::query_as(query).fetch_all(&self.0).await + test_support!(self, { + let query = "SELECT * FROM users ORDER BY github_login ASC"; + sqlx::query_as(query).fetch_all(&self.db).await + }) } pub async fn delete_signup(&self, id: SignupId) -> Result<()> { - let query = "DELETE FROM signups WHERE id = $1"; - sqlx::query(query) - .bind(id.0) - .execute(&self.0) - .await - .map(drop) + test_support!(self, { + let query = "DELETE FROM signups WHERE id = $1"; + sqlx::query(query) + .bind(id.0) + .execute(&self.db) + .await + .map(drop) + }) } // users pub async fn create_user(&self, github_login: &str, admin: bool) -> Result { - let query = " - INSERT INTO users (github_login, admin) - VALUES ($1, $2) - RETURNING id - "; - sqlx::query_scalar(query) - .bind(github_login) - .bind(admin) - .fetch_one(&self.0) - .await - .map(UserId) + test_support!(self, { + let query = " + INSERT INTO users (github_login, admin) + VALUES ($1, $2) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(github_login) + .bind(admin) + .fetch_one(&self.db) + .await + .map(UserId) + }) } pub async fn get_all_users(&self) -> Result> { - let query = "SELECT * FROM users ORDER BY github_login ASC"; - sqlx::query_as(query).fetch_all(&self.0).await + test_support!(self, { + let query = "SELECT * FROM users ORDER BY github_login ASC"; + sqlx::query_as(query).fetch_all(&self.db).await + }) } pub async fn get_users_by_ids( @@ -99,53 +155,61 @@ impl Db { requester_id: UserId, ids: impl Iterator, ) -> Result> { - // Only return users that are in a common channel with the requesting user. - let query = " - SELECT users.* - FROM - users, channel_memberships - WHERE - users.id IN $1 AND - channel_memberships.user_id = users.id AND - channel_memberships.channel_id IN ( - SELECT channel_id - FROM channel_memberships - WHERE channel_memberships.user_id = $2 - ) - "; + test_support!(self, { + // Only return users that are in a common channel with the requesting user. + let query = " + SELECT users.* + FROM + users, channel_memberships + WHERE + users.id IN $1 AND + channel_memberships.user_id = users.id AND + channel_memberships.channel_id IN ( + SELECT channel_id + FROM channel_memberships + WHERE channel_memberships.user_id = $2 + ) + "; - sqlx::query_as(query) - .bind(&ids.map(|id| id.0).collect::>()) - .bind(requester_id) - .fetch_all(&self.0) - .await + sqlx::query_as(query) + .bind(&ids.map(|id| id.0).collect::>()) + .bind(requester_id) + .fetch_all(&self.db) + .await + }) } pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { - let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1"; - sqlx::query_as(query) - .bind(github_login) - .fetch_optional(&self.0) - .await + test_support!(self, { + let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1"; + sqlx::query_as(query) + .bind(github_login) + .fetch_optional(&self.db) + .await + }) } pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - let query = "UPDATE users SET admin = $1 WHERE id = $2"; - sqlx::query(query) - .bind(is_admin) - .bind(id.0) - .execute(&self.0) - .await - .map(drop) + test_support!(self, { + let query = "UPDATE users SET admin = $1 WHERE id = $2"; + sqlx::query(query) + .bind(is_admin) + .bind(id.0) + .execute(&self.db) + .await + .map(drop) + }) } pub async fn delete_user(&self, id: UserId) -> Result<()> { - let query = "DELETE FROM users WHERE id = $1;"; - sqlx::query(query) - .bind(id.0) - .execute(&self.0) - .await - .map(drop) + test_support!(self, { + let query = "DELETE FROM users WHERE id = $1;"; + sqlx::query(query) + .bind(id.0) + .execute(&self.db) + .await + .map(drop) + }) } // access tokens @@ -155,41 +219,47 @@ impl Db { user_id: UserId, access_token_hash: String, ) -> Result<()> { - let query = " + test_support!(self, { + let query = " INSERT INTO access_tokens (user_id, hash) VALUES ($1, $2) "; - sqlx::query(query) - .bind(user_id.0) - .bind(access_token_hash) - .execute(&self.0) - .await - .map(drop) + sqlx::query(query) + .bind(user_id.0) + .bind(access_token_hash) + .execute(&self.db) + .await + .map(drop) + }) } pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - let query = "SELECT hash FROM access_tokens WHERE user_id = $1"; - sqlx::query_scalar(query) - .bind(user_id.0) - .fetch_all(&self.0) - .await + test_support!(self, { + let query = "SELECT hash FROM access_tokens WHERE user_id = $1"; + sqlx::query_scalar(query) + .bind(user_id.0) + .fetch_all(&self.db) + .await + }) } // orgs #[cfg(test)] pub async fn create_org(&self, name: &str, slug: &str) -> Result { - let query = " - INSERT INTO orgs (name, slug) - VALUES ($1, $2) - RETURNING id - "; - sqlx::query_scalar(query) - .bind(name) - .bind(slug) - .fetch_one(&self.0) - .await - .map(OrgId) + test_support!(self, { + let query = " + INSERT INTO orgs (name, slug) + VALUES ($1, $2) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(name) + .bind(slug) + .fetch_one(&self.db) + .await + .map(OrgId) + }) } #[cfg(test)] @@ -199,50 +269,56 @@ impl Db { user_id: UserId, is_admin: bool, ) -> Result<()> { - let query = " - INSERT INTO org_memberships (org_id, user_id, admin) - VALUES ($1, $2, $3) - "; - sqlx::query(query) - .bind(org_id.0) - .bind(user_id.0) - .bind(is_admin) - .execute(&self.0) - .await - .map(drop) + test_support!(self, { + let query = " + INSERT INTO org_memberships (org_id, user_id, admin) + VALUES ($1, $2, $3) + "; + sqlx::query(query) + .bind(org_id.0) + .bind(user_id.0) + .bind(is_admin) + .execute(&self.db) + .await + .map(drop) + }) } // channels #[cfg(test)] pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { - let query = " - INSERT INTO channels (owner_id, owner_is_user, name) - VALUES ($1, false, $2) - RETURNING id - "; - sqlx::query_scalar(query) - .bind(org_id.0) - .bind(name) - .fetch_one(&self.0) - .await - .map(ChannelId) + test_support!(self, { + let query = " + INSERT INTO channels (owner_id, owner_is_user, name) + VALUES ($1, false, $2) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(org_id.0) + .bind(name) + .fetch_one(&self.db) + .await + .map(ChannelId) + }) } pub async fn get_channels_for_user(&self, user_id: UserId) -> Result> { - let query = " - SELECT - channels.id, channels.name - FROM - channel_memberships, channels - WHERE - channel_memberships.user_id = $1 AND - channel_memberships.channel_id = channels.id - "; - sqlx::query_as(query) - .bind(user_id.0) - .fetch_all(&self.0) - .await + test_support!(self, { + let query = " + SELECT + channels.id, channels.name + FROM + channel_memberships, channels + WHERE + channel_memberships.user_id = $1 AND + channel_memberships.channel_id = channels.id + "; + sqlx::query_as(query) + .bind(user_id.0) + .fetch_all(&self.db) + .await + }) } pub async fn can_user_access_channel( @@ -250,18 +326,20 @@ impl Db { user_id: UserId, channel_id: ChannelId, ) -> Result { - let query = " - SELECT id - FROM channel_memberships - WHERE user_id = $1 AND channel_id = $2 - LIMIT 1 - "; - sqlx::query_scalar::<_, i32>(query) - .bind(user_id.0) - .bind(channel_id.0) - .fetch_optional(&self.0) - .await - .map(|e| e.is_some()) + test_support!(self, { + let query = " + SELECT id + FROM channel_memberships + WHERE user_id = $1 AND channel_id = $2 + LIMIT 1 + "; + sqlx::query_scalar::<_, i32>(query) + .bind(user_id.0) + .bind(channel_id.0) + .fetch_optional(&self.db) + .await + .map(|e| e.is_some()) + }) } #[cfg(test)] @@ -271,17 +349,19 @@ impl Db { user_id: UserId, is_admin: bool, ) -> Result<()> { - let query = " - INSERT INTO channel_memberships (channel_id, user_id, admin) - VALUES ($1, $2, $3) - "; - sqlx::query(query) - .bind(channel_id.0) - .bind(user_id.0) - .bind(is_admin) - .execute(&self.0) - .await - .map(drop) + test_support!(self, { + let query = " + INSERT INTO channel_memberships (channel_id, user_id, admin) + VALUES ($1, $2, $3) + "; + sqlx::query(query) + .bind(channel_id.0) + .bind(user_id.0) + .bind(is_admin) + .execute(&self.db) + .await + .map(drop) + }) } // messages @@ -293,19 +373,21 @@ impl Db { body: &str, timestamp: OffsetDateTime, ) -> Result { - let query = " - INSERT INTO channel_messages (channel_id, sender_id, body, sent_at) - VALUES ($1, $2, $3, $4) - RETURNING id - "; - sqlx::query_scalar(query) - .bind(channel_id.0) - .bind(sender_id.0) - .bind(body) - .bind(timestamp) - .fetch_one(&self.0) - .await - .map(MessageId) + test_support!(self, { + let query = " + INSERT INTO channel_messages (channel_id, sender_id, body, sent_at) + VALUES ($1, $2, $3, $4) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(channel_id.0) + .bind(sender_id.0) + .bind(body) + .bind(timestamp) + .fetch_one(&self.db) + .await + .map(MessageId) + }) } pub async fn get_recent_channel_messages( @@ -313,35 +395,39 @@ impl Db { channel_id: ChannelId, count: usize, ) -> Result> { - let query = r#" - SELECT - id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at - FROM - channel_messages - WHERE - channel_id = $1 - LIMIT $2 - "#; - sqlx::query_as(query) - .bind(channel_id.0) - .bind(count as i64) - .fetch_all(&self.0) - .await + test_support!(self, { + let query = r#" + SELECT + id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at + FROM + channel_messages + WHERE + channel_id = $1 + LIMIT $2 + "#; + sqlx::query_as(query) + .bind(channel_id.0) + .bind(count as i64) + .fetch_all(&self.db) + .await + }) } #[cfg(test)] pub async fn close(&self, db_name: &str) { - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid(); - "; - sqlx::query(query) - .bind(db_name) - .execute(&self.0) - .await - .unwrap(); - self.0.close().await; + test_support!(self, { + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid(); + "; + sqlx::query(query) + .bind(db_name) + .execute(&self.db) + .await + .unwrap(); + self.db.close().await; + }) } } diff --git a/server/src/main.rs b/server/src/main.rs index a49705e249..41f2638027 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -11,11 +11,11 @@ mod rpc; mod team; use self::errors::TideResultExt as _; -use anyhow::{Context, Result}; +use anyhow::Result; use async_std::net::TcpListener; use async_trait::async_trait; use auth::RequestExt as _; -use db::{Db, DbOptions}; +use db::Db; use handlebars::{Handlebars, TemplateRenderError}; use parking_lot::RwLock; use rust_embed::RustEmbed; @@ -54,12 +54,7 @@ pub struct AppState { impl AppState { async fn new(config: Config) -> tide::Result> { - let db = Db(DbOptions::new() - .max_connections(5) - .connect(&config.database_url) - .await - .context("failed to connect to postgres database")?); - + let db = Db::new(&config.database_url, 5).await?; let github_client = github::AppClient::new(config.github_app_id, config.github_private_key.clone()); let repo_client = github_client diff --git a/server/src/rpc.rs b/server/src/rpc.rs index e5a8d848e4..1dc372af01 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -922,16 +922,15 @@ mod tests { db::{self, UserId}, github, AppState, Config, }; - use async_std::{sync::RwLockReadGuard, task}; - use gpui::{ModelHandle, TestAppContext}; + use async_std::{ + sync::RwLockReadGuard, + task::{self, block_on}, + }; + use gpui::TestAppContext; use postage::mpsc; use rand::prelude::*; use serde_json::json; - use sqlx::{ - migrate::{MigrateDatabase, Migrator}, - types::time::OffsetDateTime, - Postgres, - }; + use sqlx::{migrate::MigrateDatabase, types::time::OffsetDateTime, Postgres}; use std::{path::Path, sync::Arc, time::Duration}; use zed::{ channel::{Channel, ChannelDetails, ChannelList}, @@ -1400,6 +1399,8 @@ mod tests { #[gpui::test] async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { + cx_a.foreground().forbid_parking(); + // Connect to a server as 2 clients. let mut server = TestServer::start().await; let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; @@ -1444,11 +1445,12 @@ mod tests { this.get_channel(channel_id.to_proto(), cx).unwrap() }); channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty())); - channel_a.next_notification(&cx_a).await; - assert_eq!( - channel_messages(&channel_a, &cx_a), - &[(user_id_b.to_proto(), "hello A, it's B.".to_string())] - ); + channel_a + .condition(&cx_a, |channel, _| { + channel_messages(channel) + == [(user_id_b.to_proto(), "hello A, it's B.".to_string())] + }) + .await; let channels_b = ChannelList::new(client_b, &mut cx_b.to_async()) .await @@ -1462,15 +1464,17 @@ mod tests { }] ) }); + let channel_b = channels_b.update(&mut cx_b, |this, cx| { this.get_channel(channel_id.to_proto(), cx).unwrap() }); channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty())); - channel_b.next_notification(&cx_b).await; - assert_eq!( - channel_messages(&channel_b, &cx_b), - &[(user_id_b.to_proto(), "hello A, it's B.".to_string())] - ); + channel_b + .condition(&cx_b, |channel, _| { + channel_messages(channel) + == [(user_id_b.to_proto(), "hello A, it's B.".to_string())] + }) + .await; channel_a.update(&mut cx_a, |channel, cx| { channel.send_message("oh, hi B.".to_string(), cx).unwrap(); @@ -1484,24 +1488,20 @@ mod tests { &["oh, hi B.", "sup"] ) }); - channel_a.next_notification(&cx_a).await; - channel_a.read_with(&cx_a, |channel, _| { - assert_eq!(channel.pending_messages().len(), 1); - }); - channel_a.next_notification(&cx_a).await; - channel_a.read_with(&cx_a, |channel, _| { - assert_eq!(channel.pending_messages().len(), 0); - }); - channel_b.next_notification(&cx_b).await; - assert_eq!( - channel_messages(&channel_b, &cx_b), - &[ - (user_id_b.to_proto(), "hello A, it's B.".to_string()), - (user_id_a.to_proto(), "oh, hi B.".to_string()), - (user_id_a.to_proto(), "sup".to_string()), - ] - ); + channel_a + .condition(&cx_a, |channel, _| channel.pending_messages().is_empty()) + .await; + channel_b + .condition(&cx_b, |channel, _| { + channel_messages(channel) + == [ + (user_id_b.to_proto(), "hello A, it's B.".to_string()), + (user_id_a.to_proto(), "oh, hi B.".to_string()), + (user_id_a.to_proto(), "sup".to_string()), + ] + }) + .await; assert_eq!( server.state().await.channels[&channel_id] @@ -1519,17 +1519,12 @@ mod tests { .condition(|state| !state.channels.contains_key(&channel_id)) .await; - fn channel_messages( - channel: &ModelHandle, - cx: &TestAppContext, - ) -> Vec<(u64, String)> { - channel.read_with(cx, |channel, _| { - channel - .messages() - .iter() - .map(|m| (m.sender_id, m.body.clone())) - .collect() - }) + fn channel_messages(channel: &Channel) -> Vec<(u64, String)> { + channel + .messages() + .iter() + .map(|m| (m.sender_id, m.body.clone())) + .collect() } } @@ -1584,21 +1579,12 @@ mod tests { config.session_secret = "a".repeat(32); config.database_url = format!("postgres://postgres@localhost/{}", db_name); - Self::create_db(&config.database_url).await; - let db = db::Db( - db::DbOptions::new() - .max_connections(5) - .connect(&config.database_url) - .await - .expect("failed to connect to postgres database"), - ); - let migrator = Migrator::new(Path::new(concat!( + Self::create_db(&config.database_url); + let db = db::Db::test(&config.database_url, 5); + db.migrate(Path::new(concat!( env!("CARGO_MANIFEST_DIR"), "/migrations" - ))) - .await - .unwrap(); - migrator.run(&db.0).await.unwrap(); + ))); let github_client = github::AppClient::test(); Arc::new(AppState { @@ -1611,16 +1597,14 @@ mod tests { }) } - async fn create_db(url: &str) { + fn create_db(url: &str) { // Enable tests to run in parallel by serializing the creation of each test database. lazy_static::lazy_static! { - static ref DB_CREATION: async_std::sync::Mutex<()> = async_std::sync::Mutex::new(()); + static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(()); } - let _lock = DB_CREATION.lock().await; - Postgres::create_database(url) - .await - .expect("failed to create test database"); + let _lock = DB_CREATION.lock(); + block_on(Postgres::create_database(url)).expect("failed to create test database"); } async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {