Mirror of https://github.com/zed-industries/zed.git (synced 2024-12-25 01:34:02 +00:00)
Make database interactions deterministic in test

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

parent 3ba530bca1
commit 98f691d16d

5 changed files with 338 additions and 285 deletions
@@ -14,7 +14,7 @@ use keymap::MatchResult;
 use parking_lot::{Mutex, RwLock};
 use pathfinder_geometry::{rect::RectF, vector::vec2f};
 use platform::Event;
-use postage::{mpsc, oneshot, sink::Sink as _, stream::Stream as _};
+use postage::{mpsc, sink::Sink as _, stream::Stream as _};
 use smol::prelude::*;
 use std::{
     any::{type_name, Any, TypeId},
@@ -2310,24 +2310,6 @@ impl<T: Entity> ModelHandle<T> {
         cx.update_model(self, update)
     }
 
-    pub fn next_notification(&self, cx: &TestAppContext) -> impl Future<Output = ()> {
-        let (tx, mut rx) = oneshot::channel();
-        let mut tx = Some(tx);
-
-        let mut cx = cx.cx.borrow_mut();
-        self.update(&mut *cx, |_, cx| {
-            cx.observe(self, move |_, _, _| {
-                if let Some(mut tx) = tx.take() {
-                    tx.blocking_send(()).ok();
-                }
-            });
-        });
-
-        async move {
-            rx.recv().await;
-        }
-    }
-
     pub fn condition(
         &self,
         cx: &TestAppContext,
@@ -122,9 +122,14 @@ impl Deterministic {
         smol::pin!(future);
 
         let unparker = self.parker.lock().unparker();
-        let waker = waker_fn(move || {
-            unparker.unpark();
-        });
+        let woken = Arc::new(AtomicBool::new(false));
+        let waker = {
+            let woken = woken.clone();
+            waker_fn(move || {
+                woken.store(true, SeqCst);
+                unparker.unpark();
+            })
+        };
 
         let mut cx = Context::from_waker(&waker);
         let mut trace = Trace::default();
@@ -166,10 +171,11 @@ impl Deterministic {
                 && state.scheduled_from_background.is_empty()
                 && state.spawned_from_foreground.is_empty()
             {
-                if state.forbid_parking {
+                if state.forbid_parking && !woken.load(SeqCst) {
                     panic!("deterministic executor parked after a call to forbid_parking");
                 }
                 drop(state);
+                woken.store(false, SeqCst);
                 self.parker.lock().park();
             }
 
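The two executor hunks above change when the deterministic executor may park: the waker now records that it fired, and a park under forbid_parking only panics if nothing woke the executor since the last poll. A minimal, self-contained sketch of that guard, assuming the waker-fn and parking crates; the real executor also tracks its scheduled foreground/background queues, which are omitted here:

use parking::Parker;
use std::{
    future::Future,
    sync::{
        atomic::{AtomicBool, Ordering::SeqCst},
        Arc,
    },
    task::{Context, Poll},
};
use waker_fn::waker_fn;

// Hypothetical helper, not the gpui implementation: block on `future`,
// panicking only if we would park while parking is forbidden and no wakeup
// has been recorded since the last poll.
fn block_on<T>(future: impl Future<Output = T>, forbid_parking: bool) -> T {
    let parker = Parker::new();
    let unparker = parker.unparker();

    // Remember whether the waker fired since the last poll.
    let woken = Arc::new(AtomicBool::new(false));
    let waker = {
        let woken = woken.clone();
        waker_fn(move || {
            woken.store(true, SeqCst);
            unparker.unpark();
        })
    };

    let mut cx = Context::from_waker(&waker);
    let mut future = std::pin::pin!(future);
    loop {
        match future.as_mut().poll(&mut cx) {
            Poll::Ready(value) => return value,
            Poll::Pending => {
                // If the waker already fired, park() returns immediately and the
                // future will make progress, so parking is not an error.
                if forbid_parking && !woken.load(SeqCst) {
                    panic!("parked after a call to forbid_parking");
                }
                woken.store(false, SeqCst);
                parker.park();
            }
        }
    }
}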
server/src/db.rs (466 changed lines)
@@ -1,3 +1,5 @@
+use anyhow::Context;
+use async_std::task::{block_on, yield_now};
 use serde::Serialize;
 use sqlx::{FromRow, Result};
 use time::OffsetDateTime;
@@ -5,7 +7,24 @@ use time::OffsetDateTime;
 pub use async_sqlx_session::PostgresSessionStore as SessionStore;
 pub use sqlx::postgres::PgPoolOptions as DbOptions;
 
-pub struct Db(pub sqlx::PgPool);
+macro_rules! test_support {
+    ($self:ident, { $($token:tt)* }) => {{
+        let body = async {
+            $($token)*
+        };
+        if $self.test_mode {
+            yield_now().await;
+            block_on(body)
+        } else {
+            body.await
+        }
+    }};
+}
+
+pub struct Db {
+    db: sqlx::PgPool,
+    test_mode: bool,
+}
 
 #[derive(Debug, FromRow, Serialize)]
 pub struct User {
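To make the intent of test_support! concrete, here is a self-contained illustration of the same pattern with a stub async body standing in for a sqlx query; the Db stub and the double method are hypothetical, and only the async-std crate is assumed. In test mode the wrapped body yields once, handing the test's deterministic executor a scheduling point, and then runs to completion inside block_on, so a database interaction can no longer interleave with other futures mid-query:

use async_std::task::{block_on, yield_now};

struct Db {
    test_mode: bool,
}

macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };
        if $self.test_mode {
            // One scheduling point for the deterministic executor, then run the
            // whole interaction to completion without interleaving.
            yield_now().await;
            block_on(body)
        } else {
            body.await
        }
    }};
}

impl Db {
    async fn double(&self, n: i32) -> i32 {
        test_support!(self, {
            // Stand-in for a database query.
            yield_now().await;
            n * 2
        })
    }
}

fn main() {
    let db = Db { test_mode: false };
    assert_eq!(block_on(db.double(21)), 42);
}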
@@ -37,6 +56,33 @@ pub struct ChannelMessage {
 }
 
 impl Db {
+    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
+        let db = DbOptions::new()
+            .max_connections(max_connections)
+            .connect(url)
+            .await
+            .context("failed to connect to postgres database")?;
+        Ok(Self {
+            db,
+            test_mode: false,
+        })
+    }
+
+    #[cfg(test)]
+    pub fn test(url: &str, max_connections: u32) -> Self {
+        let mut db = block_on(Self::new(url, max_connections)).unwrap();
+        db.test_mode = true;
+        db
+    }
+
+    #[cfg(test)]
+    pub fn migrate(&self, path: &std::path::Path) {
+        block_on(async {
+            let migrator = sqlx::migrate::Migrator::new(path).await.unwrap();
+            migrator.run(&self.db).await.unwrap();
+        });
+    }
+
     // signups
 
     pub async fn create_signup(
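A usage sketch for the test-only constructors above, mirroring how the test server later in this diff builds its database; the function name and URL are placeholders, and this is assumed to sit alongside Db in the same crate:

#[cfg(test)]
fn build_test_db() -> Db {
    // Synchronous on purpose: Db::test and Db::migrate block internally,
    // so test setup needs no executor of its own.
    let db = Db::test("postgres://postgres@localhost/zed_test", 5);
    db.migrate(std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
    db
}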
@@ -45,53 +91,63 @@ impl Db {
         email_address: &str,
         about: &str,
     ) -> Result<SignupId> {
-        let query = "
-            INSERT INTO signups (github_login, email_address, about)
-            VALUES ($1, $2, $3)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(github_login)
-            .bind(email_address)
-            .bind(about)
-            .fetch_one(&self.0)
-            .await
-            .map(SignupId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO signups (github_login, email_address, about)
+                VALUES ($1, $2, $3)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(github_login)
+                .bind(email_address)
+                .bind(about)
+                .fetch_one(&self.db)
+                .await
+                .map(SignupId)
+        })
     }
 
     pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
-        let query = "SELECT * FROM users ORDER BY github_login ASC";
-        sqlx::query_as(query).fetch_all(&self.0).await
+        test_support!(self, {
+            let query = "SELECT * FROM users ORDER BY github_login ASC";
+            sqlx::query_as(query).fetch_all(&self.db).await
+        })
     }
 
     pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
-        let query = "DELETE FROM signups WHERE id = $1";
-        sqlx::query(query)
-            .bind(id.0)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "DELETE FROM signups WHERE id = $1";
+            sqlx::query(query)
+                .bind(id.0)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     // users
 
     pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
-        let query = "
-            INSERT INTO users (github_login, admin)
-            VALUES ($1, $2)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(github_login)
-            .bind(admin)
-            .fetch_one(&self.0)
-            .await
-            .map(UserId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO users (github_login, admin)
+                VALUES ($1, $2)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(github_login)
+                .bind(admin)
+                .fetch_one(&self.db)
+                .await
+                .map(UserId)
+        })
     }
 
     pub async fn get_all_users(&self) -> Result<Vec<User>> {
-        let query = "SELECT * FROM users ORDER BY github_login ASC";
-        sqlx::query_as(query).fetch_all(&self.0).await
+        test_support!(self, {
+            let query = "SELECT * FROM users ORDER BY github_login ASC";
+            sqlx::query_as(query).fetch_all(&self.db).await
+        })
     }
 
     pub async fn get_users_by_ids(
@@ -99,53 +155,61 @@ impl Db {
         requester_id: UserId,
         ids: impl Iterator<Item = UserId>,
     ) -> Result<Vec<User>> {
-        // Only return users that are in a common channel with the requesting user.
-        let query = "
-            SELECT users.*
-            FROM
-                users, channel_memberships
-            WHERE
-                users.id IN $1 AND
-                channel_memberships.user_id = users.id AND
-                channel_memberships.channel_id IN (
-                    SELECT channel_id
-                    FROM channel_memberships
-                    WHERE channel_memberships.user_id = $2
-                )
-        ";
+        test_support!(self, {
+            // Only return users that are in a common channel with the requesting user.
+            let query = "
+                SELECT users.*
+                FROM
+                    users, channel_memberships
+                WHERE
+                    users.id IN $1 AND
+                    channel_memberships.user_id = users.id AND
+                    channel_memberships.channel_id IN (
+                        SELECT channel_id
+                        FROM channel_memberships
+                        WHERE channel_memberships.user_id = $2
+                    )
+            ";
 
-        sqlx::query_as(query)
-            .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
-            .bind(requester_id)
-            .fetch_all(&self.0)
-            .await
+            sqlx::query_as(query)
+                .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
+                .bind(requester_id)
+                .fetch_all(&self.db)
+                .await
+        })
     }
 
     pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
-        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
-        sqlx::query_as(query)
-            .bind(github_login)
-            .fetch_optional(&self.0)
-            .await
+        test_support!(self, {
+            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
+            sqlx::query_as(query)
+                .bind(github_login)
+                .fetch_optional(&self.db)
+                .await
+        })
     }
 
     pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
-        let query = "UPDATE users SET admin = $1 WHERE id = $2";
-        sqlx::query(query)
-            .bind(is_admin)
-            .bind(id.0)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "UPDATE users SET admin = $1 WHERE id = $2";
+            sqlx::query(query)
+                .bind(is_admin)
+                .bind(id.0)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     pub async fn delete_user(&self, id: UserId) -> Result<()> {
-        let query = "DELETE FROM users WHERE id = $1;";
-        sqlx::query(query)
-            .bind(id.0)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "DELETE FROM users WHERE id = $1;";
+            sqlx::query(query)
+                .bind(id.0)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     // access tokens
 
@@ -155,41 +219,47 @@ impl Db {
         user_id: UserId,
         access_token_hash: String,
     ) -> Result<()> {
-        let query = "
+        test_support!(self, {
+            let query = "
                 INSERT INTO access_tokens (user_id, hash)
                 VALUES ($1, $2)
             ";
-        sqlx::query(query)
-            .bind(user_id.0)
-            .bind(access_token_hash)
-            .execute(&self.0)
-            .await
-            .map(drop)
+            sqlx::query(query)
+                .bind(user_id.0)
+                .bind(access_token_hash)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
-        let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
-        sqlx::query_scalar(query)
-            .bind(user_id.0)
-            .fetch_all(&self.0)
-            .await
+        test_support!(self, {
+            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
+            sqlx::query_scalar(query)
+                .bind(user_id.0)
+                .fetch_all(&self.db)
+                .await
+        })
    }
 
     // orgs
 
     #[cfg(test)]
     pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
-        let query = "
-            INSERT INTO orgs (name, slug)
-            VALUES ($1, $2)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(name)
-            .bind(slug)
-            .fetch_one(&self.0)
-            .await
-            .map(OrgId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO orgs (name, slug)
+                VALUES ($1, $2)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(name)
+                .bind(slug)
+                .fetch_one(&self.db)
+                .await
+                .map(OrgId)
+        })
     }
 
     #[cfg(test)]
@@ -199,50 +269,56 @@ impl Db {
         user_id: UserId,
         is_admin: bool,
     ) -> Result<()> {
-        let query = "
-            INSERT INTO org_memberships (org_id, user_id, admin)
-            VALUES ($1, $2, $3)
-        ";
-        sqlx::query(query)
-            .bind(org_id.0)
-            .bind(user_id.0)
-            .bind(is_admin)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "
+                INSERT INTO org_memberships (org_id, user_id, admin)
+                VALUES ($1, $2, $3)
+            ";
+            sqlx::query(query)
+                .bind(org_id.0)
+                .bind(user_id.0)
+                .bind(is_admin)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     // channels
 
     #[cfg(test)]
     pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
-        let query = "
-            INSERT INTO channels (owner_id, owner_is_user, name)
-            VALUES ($1, false, $2)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(org_id.0)
-            .bind(name)
-            .fetch_one(&self.0)
-            .await
-            .map(ChannelId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO channels (owner_id, owner_is_user, name)
+                VALUES ($1, false, $2)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(org_id.0)
+                .bind(name)
+                .fetch_one(&self.db)
+                .await
+                .map(ChannelId)
+        })
     }
 
     pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
-        let query = "
-            SELECT
-                channels.id, channels.name
-            FROM
-                channel_memberships, channels
-            WHERE
-                channel_memberships.user_id = $1 AND
-                channel_memberships.channel_id = channels.id
-        ";
-        sqlx::query_as(query)
-            .bind(user_id.0)
-            .fetch_all(&self.0)
-            .await
+        test_support!(self, {
+            let query = "
+                SELECT
+                    channels.id, channels.name
+                FROM
+                    channel_memberships, channels
+                WHERE
+                    channel_memberships.user_id = $1 AND
+                    channel_memberships.channel_id = channels.id
+            ";
+            sqlx::query_as(query)
+                .bind(user_id.0)
+                .fetch_all(&self.db)
+                .await
+        })
     }
 
     pub async fn can_user_access_channel(
@@ -250,18 +326,20 @@ impl Db {
         user_id: UserId,
         channel_id: ChannelId,
     ) -> Result<bool> {
-        let query = "
-            SELECT id
-            FROM channel_memberships
-            WHERE user_id = $1 AND channel_id = $2
-            LIMIT 1
-        ";
-        sqlx::query_scalar::<_, i32>(query)
-            .bind(user_id.0)
-            .bind(channel_id.0)
-            .fetch_optional(&self.0)
-            .await
-            .map(|e| e.is_some())
+        test_support!(self, {
+            let query = "
+                SELECT id
+                FROM channel_memberships
+                WHERE user_id = $1 AND channel_id = $2
+                LIMIT 1
+            ";
+            sqlx::query_scalar::<_, i32>(query)
+                .bind(user_id.0)
+                .bind(channel_id.0)
+                .fetch_optional(&self.db)
+                .await
+                .map(|e| e.is_some())
+        })
     }
 
     #[cfg(test)]
@@ -271,17 +349,19 @@ impl Db {
         user_id: UserId,
         is_admin: bool,
     ) -> Result<()> {
-        let query = "
-            INSERT INTO channel_memberships (channel_id, user_id, admin)
-            VALUES ($1, $2, $3)
-        ";
-        sqlx::query(query)
-            .bind(channel_id.0)
-            .bind(user_id.0)
-            .bind(is_admin)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "
+                INSERT INTO channel_memberships (channel_id, user_id, admin)
+                VALUES ($1, $2, $3)
+            ";
+            sqlx::query(query)
+                .bind(channel_id.0)
+                .bind(user_id.0)
+                .bind(is_admin)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     // messages
@@ -293,19 +373,21 @@ impl Db {
         body: &str,
         timestamp: OffsetDateTime,
     ) -> Result<MessageId> {
-        let query = "
-            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
-            VALUES ($1, $2, $3, $4)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(channel_id.0)
-            .bind(sender_id.0)
-            .bind(body)
-            .bind(timestamp)
-            .fetch_one(&self.0)
-            .await
-            .map(MessageId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
+                VALUES ($1, $2, $3, $4)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(channel_id.0)
+                .bind(sender_id.0)
+                .bind(body)
+                .bind(timestamp)
+                .fetch_one(&self.db)
+                .await
+                .map(MessageId)
+        })
     }
 
     pub async fn get_recent_channel_messages(
@@ -313,35 +395,39 @@ impl Db {
         channel_id: ChannelId,
         count: usize,
     ) -> Result<Vec<ChannelMessage>> {
-        let query = r#"
-            SELECT
-                id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
-            FROM
-                channel_messages
-            WHERE
-                channel_id = $1
-            LIMIT $2
-        "#;
-        sqlx::query_as(query)
-            .bind(channel_id.0)
-            .bind(count as i64)
-            .fetch_all(&self.0)
-            .await
+        test_support!(self, {
+            let query = r#"
+                SELECT
+                    id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
+                FROM
+                    channel_messages
+                WHERE
+                    channel_id = $1
+                LIMIT $2
+            "#;
+            sqlx::query_as(query)
+                .bind(channel_id.0)
+                .bind(count as i64)
+                .fetch_all(&self.db)
+                .await
+        })
     }
 
     #[cfg(test)]
     pub async fn close(&self, db_name: &str) {
-        let query = "
-            SELECT pg_terminate_backend(pg_stat_activity.pid)
-            FROM pg_stat_activity
-            WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
-        ";
-        sqlx::query(query)
-            .bind(db_name)
-            .execute(&self.0)
-            .await
-            .unwrap();
-        self.0.close().await;
+        test_support!(self, {
+            let query = "
+                SELECT pg_terminate_backend(pg_stat_activity.pid)
+                FROM pg_stat_activity
+                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
+            ";
+            sqlx::query(query)
+                .bind(db_name)
+                .execute(&self.db)
+                .await
+                .unwrap();
+            self.db.close().await;
+        })
     }
 }
@@ -11,11 +11,11 @@ mod rpc;
 mod team;
 
 use self::errors::TideResultExt as _;
-use anyhow::{Context, Result};
+use anyhow::Result;
 use async_std::net::TcpListener;
 use async_trait::async_trait;
 use auth::RequestExt as _;
-use db::{Db, DbOptions};
+use db::Db;
 use handlebars::{Handlebars, TemplateRenderError};
 use parking_lot::RwLock;
 use rust_embed::RustEmbed;
@@ -54,12 +54,7 @@ pub struct AppState {
 
 impl AppState {
     async fn new(config: Config) -> tide::Result<Arc<Self>> {
-        let db = Db(DbOptions::new()
-            .max_connections(5)
-            .connect(&config.database_url)
-            .await
-            .context("failed to connect to postgres database")?);
-
+        let db = Db::new(&config.database_url, 5).await?;
         let github_client =
             github::AppClient::new(config.github_app_id, config.github_private_key.clone());
         let repo_client = github_client
@@ -922,16 +922,15 @@ mod tests {
         db::{self, UserId},
         github, AppState, Config,
     };
-    use async_std::{sync::RwLockReadGuard, task};
-    use gpui::{ModelHandle, TestAppContext};
+    use async_std::{
+        sync::RwLockReadGuard,
+        task::{self, block_on},
+    };
+    use gpui::TestAppContext;
     use postage::mpsc;
     use rand::prelude::*;
     use serde_json::json;
-    use sqlx::{
-        migrate::{MigrateDatabase, Migrator},
-        types::time::OffsetDateTime,
-        Postgres,
-    };
+    use sqlx::{migrate::MigrateDatabase, types::time::OffsetDateTime, Postgres};
     use std::{path::Path, sync::Arc, time::Duration};
     use zed::{
         channel::{Channel, ChannelDetails, ChannelList},
@@ -1400,6 +1399,8 @@ mod tests {
 
     #[gpui::test]
     async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
+        cx_a.foreground().forbid_parking();
+
         // Connect to a server as 2 clients.
         let mut server = TestServer::start().await;
         let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
@@ -1444,11 +1445,12 @@ mod tests {
             this.get_channel(channel_id.to_proto(), cx).unwrap()
         });
         channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
-        channel_a.next_notification(&cx_a).await;
-        assert_eq!(
-            channel_messages(&channel_a, &cx_a),
-            &[(user_id_b.to_proto(), "hello A, it's B.".to_string())]
-        );
+        channel_a
+            .condition(&cx_a, |channel, _| {
+                channel_messages(channel)
+                    == [(user_id_b.to_proto(), "hello A, it's B.".to_string())]
+            })
+            .await;
 
         let channels_b = ChannelList::new(client_b, &mut cx_b.to_async())
             .await
@@ -1462,15 +1464,17 @@ mod tests {
                 }]
             )
         });
 
         let channel_b = channels_b.update(&mut cx_b, |this, cx| {
             this.get_channel(channel_id.to_proto(), cx).unwrap()
         });
         channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
-        channel_b.next_notification(&cx_b).await;
-        assert_eq!(
-            channel_messages(&channel_b, &cx_b),
-            &[(user_id_b.to_proto(), "hello A, it's B.".to_string())]
-        );
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [(user_id_b.to_proto(), "hello A, it's B.".to_string())]
+            })
+            .await;
 
         channel_a.update(&mut cx_a, |channel, cx| {
             channel.send_message("oh, hi B.".to_string(), cx).unwrap();
@@ -1484,24 +1488,20 @@ mod tests {
                 &["oh, hi B.", "sup"]
             )
         });
-        channel_a.next_notification(&cx_a).await;
-        channel_a.read_with(&cx_a, |channel, _| {
-            assert_eq!(channel.pending_messages().len(), 1);
-        });
-        channel_a.next_notification(&cx_a).await;
-        channel_a.read_with(&cx_a, |channel, _| {
-            assert_eq!(channel.pending_messages().len(), 0);
-        });
+        channel_a
+            .condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
+            .await;
 
-        channel_b.next_notification(&cx_b).await;
-        assert_eq!(
-            channel_messages(&channel_b, &cx_b),
-            &[
-                (user_id_b.to_proto(), "hello A, it's B.".to_string()),
-                (user_id_a.to_proto(), "oh, hi B.".to_string()),
-                (user_id_a.to_proto(), "sup".to_string()),
-            ]
-        );
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [
+                        (user_id_b.to_proto(), "hello A, it's B.".to_string()),
+                        (user_id_a.to_proto(), "oh, hi B.".to_string()),
+                        (user_id_a.to_proto(), "sup".to_string()),
+                    ]
+            })
+            .await;
 
         assert_eq!(
             server.state().await.channels[&channel_id]
@@ -1519,17 +1519,12 @@ mod tests {
             .condition(|state| !state.channels.contains_key(&channel_id))
             .await;
 
-        fn channel_messages(
-            channel: &ModelHandle<Channel>,
-            cx: &TestAppContext,
-        ) -> Vec<(u64, String)> {
-            channel.read_with(cx, |channel, _| {
-                channel
-                    .messages()
-                    .iter()
-                    .map(|m| (m.sender_id, m.body.clone()))
-                    .collect()
-            })
+        fn channel_messages(channel: &Channel) -> Vec<(u64, String)> {
+            channel
+                .messages()
+                .iter()
+                .map(|m| (m.sender_id, m.body.clone()))
+                .collect()
         }
     }
 
@@ -1584,21 +1579,12 @@ mod tests {
             config.session_secret = "a".repeat(32);
             config.database_url = format!("postgres://postgres@localhost/{}", db_name);
 
-            Self::create_db(&config.database_url).await;
-            let db = db::Db(
-                db::DbOptions::new()
-                    .max_connections(5)
-                    .connect(&config.database_url)
-                    .await
-                    .expect("failed to connect to postgres database"),
-            );
-            let migrator = Migrator::new(Path::new(concat!(
+            Self::create_db(&config.database_url);
+            let db = db::Db::test(&config.database_url, 5);
+            db.migrate(Path::new(concat!(
                 env!("CARGO_MANIFEST_DIR"),
                 "/migrations"
-            )))
-            .await
-            .unwrap();
-            migrator.run(&db.0).await.unwrap();
+            )));
 
             let github_client = github::AppClient::test();
             Arc::new(AppState {
@@ -1611,16 +1597,14 @@ mod tests {
             })
         }
 
-        async fn create_db(url: &str) {
+        fn create_db(url: &str) {
             // Enable tests to run in parallel by serializing the creation of each test database.
             lazy_static::lazy_static! {
-                static ref DB_CREATION: async_std::sync::Mutex<()> = async_std::sync::Mutex::new(());
+                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
             }
 
-            let _lock = DB_CREATION.lock().await;
-            Postgres::create_database(url)
-                .await
-                .expect("failed to create test database");
+            let _lock = DB_CREATION.lock();
+            block_on(Postgres::create_database(url)).expect("failed to create test database");
         }
 
         async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
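Since create_db is now a plain synchronous function, its caller above invokes it without .await, and the process-wide lock becomes an ordinary std::sync::Mutex held across the blocking creation. A standalone sketch of the same serialization idea, assuming the lazy_static, async-std, and sqlx crates; the URL and main are placeholders:

use async_std::task::block_on;
use sqlx::{migrate::MigrateDatabase, Postgres};

fn create_db(url: &str) {
    // Serialize CREATE DATABASE across parallel tests with a process-wide lock.
    lazy_static::lazy_static! {
        static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
    }
    let _lock = DB_CREATION.lock().unwrap();
    block_on(Postgres::create_database(url)).expect("failed to create test database");
}

fn main() {
    create_db("postgres://postgres@localhost/zed_test_example");
}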