diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 35fad50e12..5b9467fc09 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -188,6 +188,20 @@ pub struct PostgresDb { pool: sqlx::PgPool, } +macro_rules! test_support { + ($self:ident, { $($token:tt)* }) => {{ + let body = async { + $($token)* + }; + + if cfg!(test) { + tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build().unwrap().block_on(body) + } else { + body.await + } + }}; +} + impl PostgresDb { pub async fn new(url: &str, max_connections: u32) -> Result { let pool = DbOptions::new() @@ -262,51 +276,58 @@ impl Db for PostgresDb { admin: bool, params: NewUserParams, ) -> Result { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin) - VALUES ($1, $2, $3, $4) - ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, metrics_id::text - "; - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(params.github_login) - .bind(params.github_user_id) - .bind(admin) - .fetch_one(&self.pool) - .await?; - Ok(NewUserResult { - user_id, - metrics_id, - signup_device_id: None, - inviting_user_id: None, + test_support!(self, { + let query = " + INSERT INTO users (email_address, github_login, github_user_id, admin) + VALUES ($1, $2, $3, $4) + ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login + RETURNING id, metrics_id::text + "; + + let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) + .bind(email_address) + .bind(params.github_login) + .bind(params.github_user_id) + .bind(admin) + .fetch_one(&self.pool) + .await?; + Ok(NewUserResult { + user_id, + metrics_id, + signup_device_id: None, + inviting_user_id: None, + }) }) } async fn get_all_users(&self, page: u32, limit: u32) -> Result> { - let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; - Ok(sqlx::query_as(query) - .bind(limit as i32) - .bind((page * limit) as i32) - .fetch_all(&self.pool) - .await?) + test_support!(self, { + let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; + Ok(sqlx::query_as(query) + .bind(limit as i32) + .bind((page * limit) as i32) + .fetch_all(&self.pool) + .await?) + }) } async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - let like_string = Self::fuzzy_like_string(name_query); - let query = " - SELECT users.* - FROM users - WHERE github_login ILIKE $1 - ORDER BY github_login <-> $2 - LIMIT $3 - "; - Ok(sqlx::query_as(query) - .bind(like_string) - .bind(name_query) - .bind(limit as i32) - .fetch_all(&self.pool) - .await?) + test_support!(self, { + let like_string = Self::fuzzy_like_string(name_query); + let query = " + SELECT users.* + FROM users + WHERE github_login ILIKE $1 + ORDER BY github_login <-> $2 + LIMIT $3 + "; + Ok(sqlx::query_as(query) + .bind(like_string) + .bind(name_query) + .bind(limit as i32) + .fetch_all(&self.pool) + .await?) + }) } async fn get_user_by_id(&self, id: UserId) -> Result> { @@ -315,42 +336,48 @@ impl Db for PostgresDb { } async fn get_user_metrics_id(&self, id: UserId) -> Result { - let query = " - SELECT metrics_id::text - FROM users - WHERE id = $1 - "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&self.pool) - .await?) + test_support!(self, { + let query = " + SELECT metrics_id::text + FROM users + WHERE id = $1 + "; + Ok(sqlx::query_scalar(query) + .bind(id) + .fetch_one(&self.pool) + .await?) + }) } async fn get_users_by_ids(&self, ids: Vec) -> Result> { - let ids = ids.into_iter().map(|id| id.0).collect::>(); - let query = " - SELECT users.* - FROM users - WHERE users.id = ANY ($1) - "; - Ok(sqlx::query_as(query) - .bind(&ids) - .fetch_all(&self.pool) - .await?) + test_support!(self, { + let ids = ids.into_iter().map(|id| id.0).collect::>(); + let query = " + SELECT users.* + FROM users + WHERE users.id = ANY ($1) + "; + Ok(sqlx::query_as(query) + .bind(&ids) + .fetch_all(&self.pool) + .await?) + }) } async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result> { - let query = format!( - " - SELECT users.* - FROM users - WHERE invite_count = 0 - AND inviter_id IS{} NULL - ", - if invited_by_another_user { " NOT" } else { "" } - ); + test_support!(self, { + let query = format!( + " + SELECT users.* + FROM users + WHERE invite_count = 0 + AND inviter_id IS{} NULL + ", + if invited_by_another_user { " NOT" } else { "" } + ); - Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?) + Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?) + }) } async fn get_user_by_github_account( @@ -358,176 +385,193 @@ impl Db for PostgresDb { github_login: &str, github_user_id: Option, ) -> Result> { - if let Some(github_user_id) = github_user_id { - let mut user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_login = $1 - WHERE github_user_id = $2 - RETURNING * - ", - ) - .bind(github_login) - .bind(github_user_id) - .fetch_optional(&self.pool) - .await?; - - if user.is_none() { - user = sqlx::query_as::<_, User>( + test_support!(self, { + if let Some(github_user_id) = github_user_id { + let mut user = sqlx::query_as::<_, User>( " UPDATE users - SET github_user_id = $1 - WHERE github_login = $2 + SET github_login = $1 + WHERE github_user_id = $2 RETURNING * ", ) + .bind(github_login) .bind(github_user_id) + .fetch_optional(&self.pool) + .await?; + + if user.is_none() { + user = sqlx::query_as::<_, User>( + " + UPDATE users + SET github_user_id = $1 + WHERE github_login = $2 + RETURNING * + ", + ) + .bind(github_user_id) + .bind(github_login) + .fetch_optional(&self.pool) + .await?; + } + + Ok(user) + } else { + let user = sqlx::query_as( + " + SELECT * FROM users + WHERE github_login = $1 + LIMIT 1 + ", + ) .bind(github_login) .fetch_optional(&self.pool) .await?; + Ok(user) } - - Ok(user) - } else { - Ok(sqlx::query_as( - " - SELECT * FROM users - WHERE github_login = $1 - LIMIT 1 - ", - ) - .bind(github_login) - .fetch_optional(&self.pool) - .await?) - } + }) } async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - let query = "UPDATE users SET admin = $1 WHERE id = $2"; - Ok(sqlx::query(query) - .bind(is_admin) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + test_support!(self, { + let query = "UPDATE users SET admin = $1 WHERE id = $2"; + Ok(sqlx::query(query) + .bind(is_admin) + .bind(id.0) + .execute(&self.pool) + .await + .map(drop)?) + }) } async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; - Ok(sqlx::query(query) - .bind(connected_once) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + test_support!(self, { + let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; + Ok(sqlx::query(query) + .bind(connected_once) + .bind(id.0) + .execute(&self.pool) + .await + .map(drop)?) + }) } async fn destroy_user(&self, id: UserId) -> Result<()> { - let query = "DELETE FROM access_tokens WHERE user_id = $1;"; - sqlx::query(query) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?; - let query = "DELETE FROM users WHERE id = $1;"; - Ok(sqlx::query(query) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + test_support!(self, { + let query = "DELETE FROM access_tokens WHERE user_id = $1;"; + sqlx::query(query) + .bind(id.0) + .execute(&self.pool) + .await + .map(drop)?; + let query = "DELETE FROM users WHERE id = $1;"; + Ok(sqlx::query(query) + .bind(id.0) + .execute(&self.pool) + .await + .map(drop)?) + }) } // signups async fn create_signup(&self, signup: Signup) -> Result<()> { - sqlx::query( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - editor_features, - programming_languages, - device_id + test_support!(self, { + sqlx::query( + " + INSERT INTO signups + ( + email_address, + email_confirmation_code, + email_confirmation_sent, + platform_linux, + platform_mac, + platform_windows, + platform_unknown, + editor_features, + programming_languages, + device_id + ) + VALUES + ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8) + RETURNING id + ", ) - VALUES - ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8) - RETURNING id - ", - ) - .bind(&signup.email_address) - .bind(&random_email_confirmation_code()) - .bind(&signup.platform_linux) - .bind(&signup.platform_mac) - .bind(&signup.platform_windows) - .bind(&signup.editor_features) - .bind(&signup.programming_languages) - .bind(&signup.device_id) - .execute(&self.pool) - .await?; - Ok(()) + .bind(&signup.email_address) + .bind(&random_email_confirmation_code()) + .bind(&signup.platform_linux) + .bind(&signup.platform_mac) + .bind(&signup.platform_windows) + .bind(&signup.editor_features) + .bind(&signup.programming_languages) + .bind(&signup.device_id) + .execute(&self.pool) + .await?; + Ok(()) + }) } async fn get_waitlist_summary(&self) -> Result { - Ok(sqlx::query_as( - " - SELECT - COUNT(*) as count, - COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, - COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, - COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, - COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count - FROM ( - SELECT * - FROM signups - WHERE - NOT email_confirmation_sent - ) AS unsent - ", - ) - .fetch_one(&self.pool) - .await?) + test_support!(self, { + Ok(sqlx::query_as( + " + SELECT + COUNT(*) as count, + COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, + COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, + COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, + COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count + FROM ( + SELECT * + FROM signups + WHERE + NOT email_confirmation_sent + ) AS unsent + ", + ) + .fetch_one(&self.pool) + .await?) + }) } async fn get_unsent_invites(&self, count: usize) -> Result> { - Ok(sqlx::query_as( - " - SELECT - email_address, email_confirmation_code - FROM signups - WHERE - NOT email_confirmation_sent AND - (platform_mac OR platform_unknown) - LIMIT $1 - ", - ) - .bind(count as i32) - .fetch_all(&self.pool) - .await?) + test_support!(self, { + Ok(sqlx::query_as( + " + SELECT + email_address, email_confirmation_code + FROM signups + WHERE + NOT email_confirmation_sent AND + (platform_mac OR platform_unknown) + LIMIT $1 + ", + ) + .bind(count as i32) + .fetch_all(&self.pool) + .await?) + }) } async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - sqlx::query( - " - UPDATE signups - SET email_confirmation_sent = 't' - WHERE email_address = ANY ($1) - ", - ) - .bind( - &invites - .iter() - .map(|s| s.email_address.as_str()) - .collect::>(), - ) - .execute(&self.pool) - .await?; - Ok(()) + test_support!(self, { + sqlx::query( + " + UPDATE signups + SET email_confirmation_sent = 't' + WHERE email_address = ANY ($1) + ", + ) + .bind( + &invites + .iter() + .map(|s| s.email_address.as_str()) + .collect::>(), + ) + .execute(&self.pool) + .await?; + Ok(()) + }) } async fn create_user_from_invite( @@ -535,176 +579,184 @@ impl Db for PostgresDb { invite: &Invite, user: NewUserParams, ) -> Result> { - let mut tx = self.pool.begin().await?; + test_support!(self, { + let mut tx = self.pool.begin().await?; - let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( - i32, - Option, - Option, - Option, - ) = sqlx::query_as( - " - SELECT id, user_id, inviting_user_id, device_id - FROM signups - WHERE - email_address = $1 AND - email_confirmation_code = $2 - ", - ) - .bind(&invite.email_address) - .bind(&invite.email_confirmation_code) - .fetch_optional(&mut tx) - .await? - .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - - if existing_user_id.is_some() { - return Ok(None); - } - - let (user_id, metrics_id): (UserId, String) = sqlx::query_as( - " - INSERT INTO users - (email_address, github_login, github_user_id, admin, invite_count, invite_code) - VALUES - ($1, $2, $3, 'f', $4, $5) - ON CONFLICT (github_login) DO UPDATE SET - email_address = excluded.email_address, - github_user_id = excluded.github_user_id, - admin = excluded.admin - RETURNING id, metrics_id::text - ", - ) - .bind(&invite.email_address) - .bind(&user.github_login) - .bind(&user.github_user_id) - .bind(&user.invite_count) - .bind(random_invite_code()) - .fetch_one(&mut tx) - .await?; - - sqlx::query( - " - UPDATE signups - SET user_id = $1 - WHERE id = $2 - ", - ) - .bind(&user_id) - .bind(&signup_id) - .execute(&mut tx) - .await?; - - if let Some(inviting_user_id) = inviting_user_id { - let id: Option = sqlx::query_scalar( + let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( + i32, + Option, + Option, + Option, + ) = sqlx::query_as( " - UPDATE users - SET invite_count = invite_count - 1 - WHERE id = $1 AND invite_count > 0 - RETURNING id + SELECT id, user_id, inviting_user_id, device_id + FROM signups + WHERE + email_address = $1 AND + email_confirmation_code = $2 ", ) - .bind(&inviting_user_id) + .bind(&invite.email_address) + .bind(&invite.email_confirmation_code) .fetch_optional(&mut tx) - .await?; + .await? + .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - if id.is_none() { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; + if existing_user_id.is_some() { + return Ok(None); } + let (user_id, metrics_id): (UserId, String) = sqlx::query_as( + " + INSERT INTO users + (email_address, github_login, github_user_id, admin, invite_count, invite_code) + VALUES + ($1, $2, $3, 'f', $4, $5) + ON CONFLICT (github_login) DO UPDATE SET + email_address = excluded.email_address, + github_user_id = excluded.github_user_id, + admin = excluded.admin + RETURNING id, metrics_id::text + ", + ) + .bind(&invite.email_address) + .bind(&user.github_login) + .bind(&user.github_user_id) + .bind(&user.invite_count) + .bind(random_invite_code()) + .fetch_one(&mut tx) + .await?; + sqlx::query( " - INSERT INTO contacts - (user_id_a, user_id_b, a_to_b, should_notify, accepted) - VALUES - ($1, $2, 't', 't', 't') - ON CONFLICT DO NOTHING + UPDATE signups + SET user_id = $1 + WHERE id = $2 ", ) - .bind(inviting_user_id) - .bind(user_id) + .bind(&user_id) + .bind(&signup_id) .execute(&mut tx) .await?; - } - tx.commit().await?; - Ok(Some(NewUserResult { - user_id, - metrics_id, - inviting_user_id, - signup_device_id, - })) + if let Some(inviting_user_id) = inviting_user_id { + let id: Option = sqlx::query_scalar( + " + UPDATE users + SET invite_count = invite_count - 1 + WHERE id = $1 AND invite_count > 0 + RETURNING id + ", + ) + .bind(&inviting_user_id) + .fetch_optional(&mut tx) + .await?; + + if id.is_none() { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + sqlx::query( + " + INSERT INTO contacts + (user_id_a, user_id_b, a_to_b, should_notify, accepted) + VALUES + ($1, $2, 't', 't', 't') + ON CONFLICT DO NOTHING + ", + ) + .bind(inviting_user_id) + .bind(user_id) + .execute(&mut tx) + .await?; + } + + tx.commit().await?; + Ok(Some(NewUserResult { + user_id, + metrics_id, + inviting_user_id, + signup_device_id, + })) + }) } // invite codes async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { - let mut tx = self.pool.begin().await?; - if count > 0 { + test_support!(self, { + let mut tx = self.pool.begin().await?; + if count > 0 { + sqlx::query( + " + UPDATE users + SET invite_code = $1 + WHERE id = $2 AND invite_code IS NULL + ", + ) + .bind(random_invite_code()) + .bind(id) + .execute(&mut tx) + .await?; + } + sqlx::query( " UPDATE users - SET invite_code = $1 - WHERE id = $2 AND invite_code IS NULL - ", + SET invite_count = $1 + WHERE id = $2 + ", ) - .bind(random_invite_code()) + .bind(count as i32) .bind(id) .execute(&mut tx) .await?; - } - - sqlx::query( - " - UPDATE users - SET invite_count = $1 - WHERE id = $2 - ", - ) - .bind(count as i32) - .bind(id) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) + tx.commit().await?; + Ok(()) + }) } async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - let result: Option<(String, i32)> = sqlx::query_as( - " - SELECT invite_code, invite_count - FROM users - WHERE id = $1 AND invite_code IS NOT NULL - ", - ) - .bind(id) - .fetch_optional(&self.pool) - .await?; - if let Some((code, count)) = result { - Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) - } else { - Ok(None) - } + test_support!(self, { + let result: Option<(String, i32)> = sqlx::query_as( + " + SELECT invite_code, invite_count + FROM users + WHERE id = $1 AND invite_code IS NOT NULL + ", + ) + .bind(id) + .fetch_optional(&self.pool) + .await?; + if let Some((code, count)) = result { + Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) + } else { + Ok(None) + } + }) } async fn get_user_for_invite_code(&self, code: &str) -> Result { - sqlx::query_as( - " - SELECT * - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| { - Error::Http( - StatusCode::NOT_FOUND, - "that invite code does not exist".to_string(), + test_support!(self, { + sqlx::query_as( + " + SELECT * + FROM users + WHERE invite_code = $1 + ", ) + .bind(code) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| { + Error::Http( + StatusCode::NOT_FOUND, + "that invite code does not exist".to_string(), + ) + }) }) } @@ -714,113 +766,119 @@ impl Db for PostgresDb { email_address: &str, device_id: Option<&str>, ) -> Result { - let mut tx = self.pool.begin().await?; + test_support!(self, { + let mut tx = self.pool.begin().await?; - let existing_user: Option = sqlx::query_scalar( - " - SELECT id - FROM users - WHERE email_address = $1 - ", - ) - .bind(email_address) - .fetch_optional(&mut tx) - .await?; - if existing_user.is_some() { - Err(anyhow!("email address is already in use"))?; - } - - let row: Option<(UserId, i32)> = sqlx::query_as( - " - SELECT id, invite_count - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&mut tx) - .await?; - - let (inviter_id, invite_count) = match row { - Some(row) => row, - None => Err(Error::Http( - StatusCode::NOT_FOUND, - "invite code not found".to_string(), - ))?, - }; - - if invite_count == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - let email_confirmation_code: String = sqlx::query_scalar( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - inviting_user_id, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - device_id + let existing_user: Option = sqlx::query_scalar( + " + SELECT id + FROM users + WHERE email_address = $1 + ", ) - VALUES - ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4) - ON CONFLICT (email_address) - DO UPDATE SET - inviting_user_id = excluded.inviting_user_id - RETURNING email_confirmation_code - ", - ) - .bind(&email_address) - .bind(&random_email_confirmation_code()) - .bind(&inviter_id) - .bind(&device_id) - .fetch_one(&mut tx) - .await?; + .bind(email_address) + .fetch_optional(&mut tx) + .await?; + if existing_user.is_some() { + Err(anyhow!("email address is already in use"))?; + } - tx.commit().await?; + let row: Option<(UserId, i32)> = sqlx::query_as( + " + SELECT id, invite_count + FROM users + WHERE invite_code = $1 + ", + ) + .bind(code) + .fetch_optional(&mut tx) + .await?; - Ok(Invite { - email_address: email_address.into(), - email_confirmation_code, + let (inviter_id, invite_count) = match row { + Some(row) => row, + None => Err(Error::Http( + StatusCode::NOT_FOUND, + "invite code not found".to_string(), + ))?, + }; + + if invite_count == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + let email_confirmation_code: String = sqlx::query_scalar( + " + INSERT INTO signups + ( + email_address, + email_confirmation_code, + email_confirmation_sent, + inviting_user_id, + platform_linux, + platform_mac, + platform_windows, + platform_unknown, + device_id + ) + VALUES + ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4) + ON CONFLICT (email_address) + DO UPDATE SET + inviting_user_id = excluded.inviting_user_id + RETURNING email_confirmation_code + ", + ) + .bind(&email_address) + .bind(&random_email_confirmation_code()) + .bind(&inviter_id) + .bind(&device_id) + .fetch_one(&mut tx) + .await?; + + tx.commit().await?; + + Ok(Invite { + email_address: email_address.into(), + email_confirmation_code, + }) }) } // projects async fn register_project(&self, host_user_id: UserId) -> Result { - Ok(sqlx::query_scalar( - " - INSERT INTO projects(host_user_id) - VALUES ($1) - RETURNING id - ", - ) - .bind(host_user_id) - .fetch_one(&self.pool) - .await - .map(ProjectId)?) + test_support!(self, { + Ok(sqlx::query_scalar( + " + INSERT INTO projects(host_user_id) + VALUES ($1) + RETURNING id + ", + ) + .bind(host_user_id) + .fetch_one(&self.pool) + .await + .map(ProjectId)?) + }) } async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { - sqlx::query( - " - UPDATE projects - SET unregistered = 't' - WHERE id = $1 - ", - ) - .bind(project_id) - .execute(&self.pool) - .await?; - Ok(()) + test_support!(self, { + sqlx::query( + " + UPDATE projects + SET unregistered = 't' + WHERE id = $1 + ", + ) + .bind(project_id) + .execute(&self.pool) + .await?; + Ok(()) + }) } async fn update_worktree_extensions( @@ -829,60 +887,64 @@ impl Db for PostgresDb { worktree_id: u64, extensions: HashMap, ) -> Result<()> { - if extensions.is_empty() { - return Ok(()); - } + test_support!(self, { + if extensions.is_empty() { + return Ok(()); + } - let mut query = QueryBuilder::new( - "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)", - ); - query.push_values(extensions, |mut query, (extension, count)| { - query - .push_bind(project_id) - .push_bind(worktree_id as i32) - .push_bind(extension) - .push_bind(count as i32); - }); - query.push( - " - ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET - count = excluded.count - ", - ); - query.build().execute(&self.pool).await?; + let mut query = QueryBuilder::new( + "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)", + ); + query.push_values(extensions, |mut query, (extension, count)| { + query + .push_bind(project_id) + .push_bind(worktree_id as i32) + .push_bind(extension) + .push_bind(count as i32); + }); + query.push( + " + ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET + count = excluded.count + ", + ); + query.build().execute(&self.pool).await?; - Ok(()) + Ok(()) + }) } async fn get_project_extensions( &self, project_id: ProjectId, ) -> Result>> { - #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] - struct WorktreeExtension { - worktree_id: i32, - extension: String, - count: i32, - } + test_support!(self, { + #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] + struct WorktreeExtension { + worktree_id: i32, + extension: String, + count: i32, + } - let query = " - SELECT worktree_id, extension, count - FROM worktree_extensions - WHERE project_id = $1 - "; - let counts = sqlx::query_as::<_, WorktreeExtension>(query) - .bind(&project_id) - .fetch_all(&self.pool) - .await?; + let query = " + SELECT worktree_id, extension, count + FROM worktree_extensions + WHERE project_id = $1 + "; + let counts = sqlx::query_as::<_, WorktreeExtension>(query) + .bind(&project_id) + .fetch_all(&self.pool) + .await?; - let mut extension_counts = HashMap::default(); - for count in counts { - extension_counts - .entry(count.worktree_id as u64) - .or_insert_with(HashMap::default) - .insert(count.extension, count.count as usize); - } - Ok(extension_counts) + let mut extension_counts = HashMap::default(); + for count in counts { + extension_counts + .entry(count.worktree_id as u64) + .or_insert_with(HashMap::default) + .insert(count.extension, count.count as usize); + } + Ok(extension_counts) + }) } async fn record_user_activity( @@ -890,28 +952,30 @@ impl Db for PostgresDb { time_period: Range, projects: &[(UserId, ProjectId)], ) -> Result<()> { - let query = " - INSERT INTO project_activity_periods - (ended_at, duration_millis, user_id, project_id) - VALUES - ($1, $2, $3, $4); - "; + test_support!(self, { + let query = " + INSERT INTO project_activity_periods + (ended_at, duration_millis, user_id, project_id) + VALUES + ($1, $2, $3, $4); + "; - let mut tx = self.pool.begin().await?; - let duration_millis = - ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32; - for (user_id, project_id) in projects { - sqlx::query(query) - .bind(time_period.end) - .bind(duration_millis) - .bind(user_id) - .bind(project_id) - .execute(&mut tx) - .await?; - } - tx.commit().await?; + let mut tx = self.pool.begin().await?; + let duration_millis = + ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32; + for (user_id, project_id) in projects { + sqlx::query(query) + .bind(time_period.end) + .bind(duration_millis) + .bind(user_id) + .bind(project_id) + .execute(&mut tx) + .await?; + } + tx.commit().await?; - Ok(()) + Ok(()) + }) } async fn get_active_user_count( @@ -920,73 +984,75 @@ impl Db for PostgresDb { min_duration: Duration, only_collaborative: bool, ) -> Result { - let mut with_clause = String::new(); - with_clause.push_str("WITH\n"); - with_clause.push_str( - " - project_durations AS ( - SELECT user_id, project_id, SUM(duration_millis) AS project_duration - FROM project_activity_periods - WHERE $1 < ended_at AND ended_at <= $2 - GROUP BY user_id, project_id - ), - ", - ); - with_clause.push_str( - " - project_collaborators as ( - SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators - FROM project_durations - GROUP BY project_id - ), - ", - ); - - if only_collaborative { + test_support!(self, { + let mut with_clause = String::new(); + with_clause.push_str("WITH\n"); with_clause.push_str( " - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration - FROM project_durations, project_collaborators - WHERE - project_durations.project_id = project_collaborators.project_id AND - max_collaborators > 1 - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ) + project_durations AS ( + SELECT user_id, project_id, SUM(duration_millis) AS project_duration + FROM project_activity_periods + WHERE $1 < ended_at AND ended_at <= $2 + GROUP BY user_id, project_id + ), ", ); - } else { with_clause.push_str( " - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration + project_collaborators as ( + SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators FROM project_durations - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ) + GROUP BY project_id + ), ", ); - } - let query = format!( - " - {with_clause} - SELECT count(user_durations.user_id) - FROM user_durations - WHERE user_durations.total_duration >= $3 - " - ); + if only_collaborative { + with_clause.push_str( + " + user_durations AS ( + SELECT user_id, SUM(project_duration) as total_duration + FROM project_durations, project_collaborators + WHERE + project_durations.project_id = project_collaborators.project_id AND + max_collaborators > 1 + GROUP BY user_id + ORDER BY total_duration DESC + LIMIT $3 + ) + ", + ); + } else { + with_clause.push_str( + " + user_durations AS ( + SELECT user_id, SUM(project_duration) as total_duration + FROM project_durations + GROUP BY user_id + ORDER BY total_duration DESC + LIMIT $3 + ) + ", + ); + } - let count: i64 = sqlx::query_scalar(&query) - .bind(time_period.start) - .bind(time_period.end) - .bind(min_duration.as_millis() as i64) - .fetch_one(&self.pool) - .await?; - Ok(count as usize) + let query = format!( + " + {with_clause} + SELECT count(user_durations.user_id) + FROM user_durations + WHERE user_durations.total_duration >= $3 + " + ); + + let count: i64 = sqlx::query_scalar(&query) + .bind(time_period.start) + .bind(time_period.end) + .bind(min_duration.as_millis() as i64) + .fetch_one(&self.pool) + .await?; + Ok(count as usize) + }) } async fn get_top_users_activity_summary( @@ -994,65 +1060,68 @@ impl Db for PostgresDb { time_period: Range, max_user_count: usize, ) -> Result> { - let query = " - WITH - project_durations AS ( - SELECT user_id, project_id, SUM(duration_millis) AS project_duration - FROM project_activity_periods - WHERE $1 < ended_at AND ended_at <= $2 - GROUP BY user_id, project_id - ), - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration - FROM project_durations - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ), - project_collaborators as ( - SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators - FROM project_durations - GROUP BY project_id - ) - SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators - FROM user_durations, project_durations, project_collaborators, users - WHERE - user_durations.user_id = project_durations.user_id AND - user_durations.user_id = users.id AND - project_durations.project_id = project_collaborators.project_id - ORDER BY total_duration DESC, user_id ASC, project_id ASC - "; + test_support!(self, { + let query = " + WITH + project_durations AS ( + SELECT user_id, project_id, SUM(duration_millis) AS project_duration + FROM project_activity_periods + WHERE $1 < ended_at AND ended_at <= $2 + GROUP BY user_id, project_id + ), + user_durations AS ( + SELECT user_id, SUM(project_duration) as total_duration + FROM project_durations + GROUP BY user_id + ORDER BY total_duration DESC + LIMIT $3 + ), + project_collaborators as ( + SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators + FROM project_durations + GROUP BY project_id + ) + SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators + FROM user_durations, project_durations, project_collaborators, users + WHERE + user_durations.user_id = project_durations.user_id AND + user_durations.user_id = users.id AND + project_durations.project_id = project_collaborators.project_id + ORDER BY total_duration DESC, user_id ASC, project_id ASC + "; - let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query) - .bind(time_period.start) - .bind(time_period.end) - .bind(max_user_count as i32) - .fetch(&self.pool); + let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query) + .bind(time_period.start) + .bind(time_period.end) + .bind(max_user_count as i32) + .fetch(&self.pool); - let mut result = Vec::::new(); - while let Some(row) = rows.next().await { - let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?; - let project_id = project_id; - let duration = Duration::from_millis(duration_millis as u64); - let project_activity = ProjectActivitySummary { - id: project_id, - duration, - max_collaborators: project_collaborators as usize, - }; - if let Some(last_summary) = result.last_mut() { - if last_summary.id == user_id { - last_summary.project_activity.push(project_activity); - continue; + let mut result = Vec::::new(); + while let Some(row) = rows.next().await { + let (user_id, github_login, project_id, duration_millis, project_collaborators) = + row?; + let project_id = project_id; + let duration = Duration::from_millis(duration_millis as u64); + let project_activity = ProjectActivitySummary { + id: project_id, + duration, + max_collaborators: project_collaborators as usize, + }; + if let Some(last_summary) = result.last_mut() { + if last_summary.id == user_id { + last_summary.project_activity.push(project_activity); + continue; + } } + result.push(UserActivitySummary { + id: user_id, + project_activity: vec![project_activity], + github_login, + }); } - result.push(UserActivitySummary { - id: user_id, - project_activity: vec![project_activity], - github_login, - }); - } - Ok(result) + Ok(result) + }) } async fn get_user_activity_timeline( @@ -1060,55 +1129,64 @@ impl Db for PostgresDb { time_period: Range, user_id: UserId, ) -> Result> { - const COALESCE_THRESHOLD: Duration = Duration::from_secs(30); + test_support!(self, { + const COALESCE_THRESHOLD: Duration = Duration::from_secs(30); - let query = " - SELECT - project_activity_periods.ended_at, - project_activity_periods.duration_millis, - project_activity_periods.project_id, - worktree_extensions.extension, - worktree_extensions.count - FROM project_activity_periods - LEFT OUTER JOIN - worktree_extensions - ON - project_activity_periods.project_id = worktree_extensions.project_id - WHERE - project_activity_periods.user_id = $1 AND - $2 < project_activity_periods.ended_at AND - project_activity_periods.ended_at <= $3 - ORDER BY project_activity_periods.id ASC - "; + let query = " + SELECT + project_activity_periods.ended_at, + project_activity_periods.duration_millis, + project_activity_periods.project_id, + worktree_extensions.extension, + worktree_extensions.count + FROM project_activity_periods + LEFT OUTER JOIN + worktree_extensions + ON + project_activity_periods.project_id = worktree_extensions.project_id + WHERE + project_activity_periods.user_id = $1 AND + $2 < project_activity_periods.ended_at AND + project_activity_periods.ended_at <= $3 + ORDER BY project_activity_periods.id ASC + "; - let mut rows = sqlx::query_as::< - _, - ( - PrimitiveDateTime, - i32, - ProjectId, - Option, - Option, - ), - >(query) - .bind(user_id) - .bind(time_period.start) - .bind(time_period.end) - .fetch(&self.pool); + let mut rows = sqlx::query_as::< + _, + ( + PrimitiveDateTime, + i32, + ProjectId, + Option, + Option, + ), + >(query) + .bind(user_id) + .bind(time_period.start) + .bind(time_period.end) + .fetch(&self.pool); - let mut time_periods: HashMap> = Default::default(); - while let Some(row) = rows.next().await { - let (ended_at, duration_millis, project_id, extension, extension_count) = row?; - let ended_at = ended_at.assume_utc(); - let duration = Duration::from_millis(duration_millis as u64); - let started_at = ended_at - duration; - let project_time_periods = time_periods.entry(project_id).or_default(); + let mut time_periods: HashMap> = Default::default(); + while let Some(row) = rows.next().await { + let (ended_at, duration_millis, project_id, extension, extension_count) = row?; + let ended_at = ended_at.assume_utc(); + let duration = Duration::from_millis(duration_millis as u64); + let started_at = ended_at - duration; + let project_time_periods = time_periods.entry(project_id).or_default(); - if let Some(prev_duration) = project_time_periods.last_mut() { - if started_at <= prev_duration.end + COALESCE_THRESHOLD - && ended_at >= prev_duration.start - { - prev_duration.end = cmp::max(prev_duration.end, ended_at); + if let Some(prev_duration) = project_time_periods.last_mut() { + if started_at <= prev_duration.end + COALESCE_THRESHOLD + && ended_at >= prev_duration.start + { + prev_duration.end = cmp::max(prev_duration.end, ended_at); + } else { + project_time_periods.push(UserActivityPeriod { + project_id, + start: started_at, + end: ended_at, + extensions: Default::default(), + }); + } } else { project_time_periods.push(UserActivityPeriod { project_id, @@ -1117,153 +1195,154 @@ impl Db for PostgresDb { extensions: Default::default(), }); } - } else { - project_time_periods.push(UserActivityPeriod { - project_id, - start: started_at, - end: ended_at, - extensions: Default::default(), - }); + + if let Some((extension, extension_count)) = extension.zip(extension_count) { + project_time_periods + .last_mut() + .unwrap() + .extensions + .insert(extension, extension_count as usize); + } } - if let Some((extension, extension_count)) = extension.zip(extension_count) { - project_time_periods - .last_mut() - .unwrap() - .extensions - .insert(extension, extension_count as usize); - } - } - - let mut durations = time_periods.into_values().flatten().collect::>(); - durations.sort_unstable_by_key(|duration| duration.start); - Ok(durations) + let mut durations = time_periods.into_values().flatten().collect::>(); + durations.sort_unstable_by_key(|duration| duration.start); + Ok(durations) + }) } // contacts async fn get_contacts(&self, user_id: UserId) -> Result> { - let query = " - SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify - FROM contacts - WHERE user_id_a = $1 OR user_id_b = $1; - "; + test_support!(self, { + let query = " + SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify + FROM contacts + WHERE user_id_a = $1 OR user_id_b = $1; + "; - let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query) - .bind(user_id) - .fetch(&self.pool); + let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query) + .bind(user_id) + .fetch(&self.pool); - let mut contacts = Vec::new(); - while let Some(row) = rows.next().await { - let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?; + let mut contacts = Vec::new(); + while let Some(row) = rows.next().await { + let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?; - if user_id_a == user_id { - if accepted { + if user_id_a == user_id { + if accepted { + contacts.push(Contact::Accepted { + user_id: user_id_b, + should_notify: should_notify && a_to_b, + }); + } else if a_to_b { + contacts.push(Contact::Outgoing { user_id: user_id_b }) + } else { + contacts.push(Contact::Incoming { + user_id: user_id_b, + should_notify, + }); + } + } else if accepted { contacts.push(Contact::Accepted { - user_id: user_id_b, - should_notify: should_notify && a_to_b, + user_id: user_id_a, + should_notify: should_notify && !a_to_b, }); } else if a_to_b { - contacts.push(Contact::Outgoing { user_id: user_id_b }) - } else { contacts.push(Contact::Incoming { - user_id: user_id_b, + user_id: user_id_a, should_notify, }); + } else { + contacts.push(Contact::Outgoing { user_id: user_id_a }); } - } else if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_a, - should_notify: should_notify && !a_to_b, - }); - } else if a_to_b { - contacts.push(Contact::Incoming { - user_id: user_id_a, - should_notify, - }); - } else { - contacts.push(Contact::Outgoing { user_id: user_id_a }); } - } - contacts.sort_unstable_by_key(|contact| contact.user_id()); + contacts.sort_unstable_by_key(|contact| contact.user_id()); - Ok(contacts) + Ok(contacts) + }) } async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - let (id_a, id_b) = if user_id_1 < user_id_2 { - (user_id_1, user_id_2) - } else { - (user_id_2, user_id_1) - }; + test_support!(self, { + let (id_a, id_b) = if user_id_1 < user_id_2 { + (user_id_1, user_id_2) + } else { + (user_id_2, user_id_1) + }; - let query = " - SELECT 1 FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't' - LIMIT 1 - "; - Ok(sqlx::query_scalar::<_, i32>(query) - .bind(id_a.0) - .bind(id_b.0) - .fetch_optional(&self.pool) - .await? - .is_some()) + let query = " + SELECT 1 FROM contacts + WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't' + LIMIT 1 + "; + Ok(sqlx::query_scalar::<_, i32>(query) + .bind(id_a.0) + .bind(id_b.0) + .fetch_optional(&self.pool) + .await? + .is_some()) + }) } async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - let (id_a, id_b, a_to_b) = if sender_id < receiver_id { - (sender_id, receiver_id, true) - } else { - (receiver_id, sender_id, false) - }; - let query = " - INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify) - VALUES ($1, $2, $3, 'f', 't') - ON CONFLICT (user_id_a, user_id_b) DO UPDATE - SET - accepted = 't', - should_notify = 'f' - WHERE - NOT contacts.accepted AND - ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR - (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a)); - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) - .await?; + test_support!(self, { + let (id_a, id_b, a_to_b) = if sender_id < receiver_id { + (sender_id, receiver_id, true) + } else { + (receiver_id, sender_id, false) + }; + let query = " + INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify) + VALUES ($1, $2, $3, 'f', 't') + ON CONFLICT (user_id_a, user_id_b) DO UPDATE + SET + accepted = 't', + should_notify = 'f' + WHERE + NOT contacts.accepted AND + ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR + (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a)); + "; + let result = sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await?; - if result.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow!("contact already requested"))? - } + if result.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow!("contact already requested"))? + } + }) } async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - let (id_a, id_b) = if responder_id < requester_id { - (responder_id, requester_id) - } else { - (requester_id, responder_id) - }; - let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2; - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .execute(&self.pool) - .await?; + test_support!(self, { + let (id_a, id_b) = if responder_id < requester_id { + (responder_id, requester_id) + } else { + (requester_id, responder_id) + }; + let query = " + DELETE FROM contacts + WHERE user_id_a = $1 AND user_id_b = $2; + "; + let result = sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .execute(&self.pool) + .await?; - if result.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow!("no such contact"))? - } + if result.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow!("no such contact"))? + } + }) } async fn dismiss_contact_notification( @@ -1271,35 +1350,37 @@ impl Db for PostgresDb { user_id: UserId, contact_user_id: UserId, ) -> Result<()> { - let (id_a, id_b, a_to_b) = if user_id < contact_user_id { - (user_id, contact_user_id, true) - } else { - (contact_user_id, user_id, false) - }; + test_support!(self, { + let (id_a, id_b, a_to_b) = if user_id < contact_user_id { + (user_id, contact_user_id, true) + } else { + (contact_user_id, user_id, false) + }; - let query = " - UPDATE contacts - SET should_notify = 'f' - WHERE - user_id_a = $1 AND user_id_b = $2 AND - ( - (a_to_b = $3 AND accepted) OR - (a_to_b != $3 AND NOT accepted) - ); - "; + let query = " + UPDATE contacts + SET should_notify = 'f' + WHERE + user_id_a = $1 AND user_id_b = $2 AND + ( + (a_to_b = $3 AND accepted) OR + (a_to_b != $3 AND NOT accepted) + ); + "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) - .await?; + let result = sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await?; - if result.rows_affected() == 0 { - Err(anyhow!("no such contact request"))?; - } + if result.rows_affected() == 0 { + Err(anyhow!("no such contact request"))?; + } - Ok(()) + Ok(()) + }) } async fn respond_to_contact_request( @@ -1308,40 +1389,42 @@ impl Db for PostgresDb { requester_id: UserId, accept: bool, ) -> Result<()> { - let (id_a, id_b, a_to_b) = if responder_id < requester_id { - (responder_id, requester_id, false) - } else { - (requester_id, responder_id, true) - }; - let result = if accept { - let query = " - UPDATE contacts - SET accepted = 't', should_notify = 't' - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; - "; - sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) - .await? - } else { - let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted; - "; - sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) - .await? - }; - if result.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow!("no such contact request"))? - } + test_support!(self, { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + let result = if accept { + let query = " + UPDATE contacts + SET accepted = 't', should_notify = 't' + WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; + "; + sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await? + } else { + let query = " + DELETE FROM contacts + WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted; + "; + sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await? + }; + if result.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow!("no such contact request"))? + } + }) } // access tokens @@ -1352,46 +1435,50 @@ impl Db for PostgresDb { access_token_hash: &str, max_access_token_count: usize, ) -> Result<()> { - let insert_query = " - INSERT INTO access_tokens (user_id, hash) - VALUES ($1, $2); - "; - let cleanup_query = " - DELETE FROM access_tokens - WHERE id IN ( - SELECT id from access_tokens - WHERE user_id = $1 - ORDER BY id DESC - OFFSET $3 - ) - "; + test_support!(self, { + let insert_query = " + INSERT INTO access_tokens (user_id, hash) + VALUES ($1, $2); + "; + let cleanup_query = " + DELETE FROM access_tokens + WHERE id IN ( + SELECT id from access_tokens + WHERE user_id = $1 + ORDER BY id DESC + OFFSET $3 + ) + "; - let mut tx = self.pool.begin().await?; - sqlx::query(insert_query) - .bind(user_id.0) - .bind(access_token_hash) - .execute(&mut tx) - .await?; - sqlx::query(cleanup_query) - .bind(user_id.0) - .bind(access_token_hash) - .bind(max_access_token_count as i32) - .execute(&mut tx) - .await?; - Ok(tx.commit().await?) + let mut tx = self.pool.begin().await?; + sqlx::query(insert_query) + .bind(user_id.0) + .bind(access_token_hash) + .execute(&mut tx) + .await?; + sqlx::query(cleanup_query) + .bind(user_id.0) + .bind(access_token_hash) + .bind(max_access_token_count as i32) + .execute(&mut tx) + .await?; + Ok(tx.commit().await?) + }) } async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - let query = " - SELECT hash - FROM access_tokens - WHERE user_id = $1 - ORDER BY id DESC - "; - Ok(sqlx::query_scalar(query) - .bind(user_id.0) - .fetch_all(&self.pool) - .await?) + test_support!(self, { + let query = " + SELECT hash + FROM access_tokens + WHERE user_id = $1 + ORDER BY id DESC + "; + Ok(sqlx::query_scalar(query) + .bind(user_id.0) + .fetch_all(&self.pool) + .await?) + }) } // orgs @@ -1399,95 +1486,107 @@ impl Db for PostgresDb { #[allow(unused)] // Help rust-analyzer #[cfg(any(test, feature = "seed-support"))] async fn find_org_by_slug(&self, slug: &str) -> Result> { - let query = " - SELECT * - FROM orgs - WHERE slug = $1 - "; - Ok(sqlx::query_as(query) - .bind(slug) - .fetch_optional(&self.pool) - .await?) + test_support!(self, { + let query = " + SELECT * + FROM orgs + WHERE slug = $1 + "; + Ok(sqlx::query_as(query) + .bind(slug) + .fetch_optional(&self.pool) + .await?) + }) } #[cfg(any(test, feature = "seed-support"))] async fn create_org(&self, name: &str, slug: &str) -> Result { - let query = " - INSERT INTO orgs (name, slug) - VALUES ($1, $2) - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(name) - .bind(slug) - .fetch_one(&self.pool) - .await - .map(OrgId)?) + test_support!(self, { + let query = " + INSERT INTO orgs (name, slug) + VALUES ($1, $2) + RETURNING id + "; + Ok(sqlx::query_scalar(query) + .bind(name) + .bind(slug) + .fetch_one(&self.pool) + .await + .map(OrgId)?) + }) } #[cfg(any(test, feature = "seed-support"))] async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> { - let query = " - INSERT INTO org_memberships (org_id, user_id, admin) - VALUES ($1, $2, $3) - ON CONFLICT DO NOTHING - "; - Ok(sqlx::query(query) - .bind(org_id.0) - .bind(user_id.0) - .bind(is_admin) - .execute(&self.pool) - .await - .map(drop)?) + test_support!(self, { + let query = " + INSERT INTO org_memberships (org_id, user_id, admin) + VALUES ($1, $2, $3) + ON CONFLICT DO NOTHING + "; + Ok(sqlx::query(query) + .bind(org_id.0) + .bind(user_id.0) + .bind(is_admin) + .execute(&self.pool) + .await + .map(drop)?) + }) } // channels #[cfg(any(test, feature = "seed-support"))] async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { - let query = " - INSERT INTO channels (owner_id, owner_is_user, name) - VALUES ($1, false, $2) - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(org_id.0) - .bind(name) - .fetch_one(&self.pool) - .await - .map(ChannelId)?) + test_support!(self, { + let query = " + INSERT INTO channels (owner_id, owner_is_user, name) + VALUES ($1, false, $2) + RETURNING id + "; + Ok(sqlx::query_scalar(query) + .bind(org_id.0) + .bind(name) + .fetch_one(&self.pool) + .await + .map(ChannelId)?) + }) } #[allow(unused)] // Help rust-analyzer #[cfg(any(test, feature = "seed-support"))] async fn get_org_channels(&self, org_id: OrgId) -> Result> { - let query = " - SELECT * - FROM channels - WHERE - channels.owner_is_user = false AND - channels.owner_id = $1 - "; - Ok(sqlx::query_as(query) - .bind(org_id.0) - .fetch_all(&self.pool) - .await?) + test_support!(self, { + let query = " + SELECT * + FROM channels + WHERE + channels.owner_is_user = false AND + channels.owner_id = $1 + "; + Ok(sqlx::query_as(query) + .bind(org_id.0) + .fetch_all(&self.pool) + .await?) + }) } async fn get_accessible_channels(&self, user_id: UserId) -> Result> { - let query = " - SELECT - channels.* - FROM - channel_memberships, channels - WHERE - channel_memberships.user_id = $1 AND - channel_memberships.channel_id = channels.id - "; - Ok(sqlx::query_as(query) - .bind(user_id.0) - .fetch_all(&self.pool) - .await?) + test_support!(self, { + let query = " + SELECT + channels.* + FROM + channel_memberships, channels + WHERE + channel_memberships.user_id = $1 AND + channel_memberships.channel_id = channels.id + "; + Ok(sqlx::query_as(query) + .bind(user_id.0) + .fetch_all(&self.pool) + .await?) + }) } async fn can_user_access_channel( @@ -1495,18 +1594,20 @@ impl Db for PostgresDb { user_id: UserId, channel_id: ChannelId, ) -> Result { - let query = " - SELECT id - FROM channel_memberships - WHERE user_id = $1 AND channel_id = $2 - LIMIT 1 - "; - Ok(sqlx::query_scalar::<_, i32>(query) - .bind(user_id.0) - .bind(channel_id.0) - .fetch_optional(&self.pool) - .await - .map(|e| e.is_some())?) + test_support!(self, { + let query = " + SELECT id + FROM channel_memberships + WHERE user_id = $1 AND channel_id = $2 + LIMIT 1 + "; + Ok(sqlx::query_scalar::<_, i32>(query) + .bind(user_id.0) + .bind(channel_id.0) + .fetch_optional(&self.pool) + .await + .map(|e| e.is_some())?) + }) } #[cfg(any(test, feature = "seed-support"))] @@ -1516,18 +1617,20 @@ impl Db for PostgresDb { user_id: UserId, is_admin: bool, ) -> Result<()> { - let query = " - INSERT INTO channel_memberships (channel_id, user_id, admin) - VALUES ($1, $2, $3) - ON CONFLICT DO NOTHING - "; - Ok(sqlx::query(query) - .bind(channel_id.0) - .bind(user_id.0) - .bind(is_admin) - .execute(&self.pool) - .await - .map(drop)?) + test_support!(self, { + let query = " + INSERT INTO channel_memberships (channel_id, user_id, admin) + VALUES ($1, $2, $3) + ON CONFLICT DO NOTHING + "; + Ok(sqlx::query(query) + .bind(channel_id.0) + .bind(user_id.0) + .bind(is_admin) + .execute(&self.pool) + .await + .map(drop)?) + }) } // messages @@ -1540,21 +1643,23 @@ impl Db for PostgresDb { timestamp: OffsetDateTime, nonce: u128, ) -> Result { - let query = " - INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(channel_id.0) - .bind(sender_id.0) - .bind(body) - .bind(timestamp) - .bind(Uuid::from_u128(nonce)) - .fetch_one(&self.pool) - .await - .map(MessageId)?) + test_support!(self, { + let query = " + INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce + RETURNING id + "; + Ok(sqlx::query_scalar(query) + .bind(channel_id.0) + .bind(sender_id.0) + .bind(body) + .bind(timestamp) + .bind(Uuid::from_u128(nonce)) + .fetch_one(&self.pool) + .await + .map(MessageId)?) + }) } async fn get_channel_messages( @@ -1563,42 +1668,49 @@ impl Db for PostgresDb { count: usize, before_id: Option, ) -> Result> { - let query = r#" - SELECT * FROM ( - SELECT - id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce - FROM - channel_messages - WHERE - channel_id = $1 AND - id < $2 - ORDER BY id DESC - LIMIT $3 - ) as recent_messages - ORDER BY id ASC - "#; - Ok(sqlx::query_as(query) - .bind(channel_id.0) - .bind(before_id.unwrap_or(MessageId::MAX)) - .bind(count as i64) - .fetch_all(&self.pool) - .await?) + test_support!(self, { + let query = r#" + SELECT * FROM ( + SELECT + id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce + FROM + channel_messages + WHERE + channel_id = $1 AND + id < $2 + ORDER BY id DESC + LIMIT $3 + ) as recent_messages + ORDER BY id ASC + "#; + Ok(sqlx::query_as(query) + .bind(channel_id.0) + .bind(before_id.unwrap_or(MessageId::MAX)) + .bind(count as i64) + .fetch_all(&self.pool) + .await?) + }) } #[cfg(test)] async fn teardown(&self, url: &str) { - use util::ResultExt; + let start = std::time::Instant::now(); + eprintln!("tearing down database..."); + test_support!(self, { + use util::ResultExt; - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid(); - "; - sqlx::query(query).execute(&self.pool).await.log_err(); - self.pool.close().await; - ::drop_database(url) - .await - .log_err(); + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid(); + "; + sqlx::query(query).execute(&self.pool).await.log_err(); + self.pool.close().await; + ::drop_database(url) + .await + .log_err(); + eprintln!("tore down database: {:?}", start.elapsed()); + }) } #[cfg(test)] @@ -2480,10 +2592,12 @@ mod test { static ref LOCK: Mutex<()> = Mutex::new(()); } + eprintln!("creating database..."); + let start = std::time::Instant::now(); let _guard = LOCK.lock(); let mut rng = StdRng::from_entropy(); let name = format!("zed-test-{}", rng.gen::()); - let url = format!("postgres://postgres@localhost/{}", name); + let url = format!("postgres://postgres@localhost:5433/{}", name); Postgres::create_database(&url) .await .expect("failed to create test db"); @@ -2491,6 +2605,8 @@ mod test { db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false) .await .unwrap(); + + eprintln!("created database: {:?}", start.elapsed()); Self { db: Some(Arc::new(db)), url, diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 2bf2701f23..c04d84f5db 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -53,6 +53,7 @@ use std::{ time::Duration, }; use theme::ThemeRegistry; +use tokio::runtime::{EnterGuard, Runtime}; use unindent::Unindent as _; use util::post_inc; use workspace::{shared_screen::SharedScreen, Item, SplitDirection, ToggleFollow, Workspace}; @@ -72,8 +73,15 @@ async fn test_basic_calls( cx_b2: &mut TestAppContext, cx_c: &mut TestAppContext, ) { + // let runtime = tokio::runtime::Runtime::new().unwrap(); + // let _enter_guard = runtime.enter(); + deterministic.forbid_parking(); let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + + let start = std::time::Instant::now(); + eprintln!("test_basic_calls"); + let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -259,6 +267,8 @@ async fn test_basic_calls( pending: Default::default() } ); + + eprintln!("finished test {:?}", start.elapsed()); } #[gpui::test(iterations = 10)] @@ -6091,7 +6101,12 @@ impl TestServer { ) -> Self { static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - let test_db = TestDb::fake(background.clone()); + let test_db = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap() + .block_on(TestDb::postgres()); let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); let live_kit_server = live_kit_client::TestServer::create( format!("http://livekit.{}.test", live_kit_server_id),