diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 9081fe1f1e..63f032f7e6 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -75,7 +75,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; -pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> Result { +pub async fn create_access_token(db: &db::DefaultDb, user_id: UserId) -> Result { let access_token = rpc::auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 58cec893e3..c3c74cc023 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,6 +1,5 @@ use crate::{Error, Result}; -use anyhow::{anyhow, Context}; -use async_trait::async_trait; +use anyhow::anyhow; use axum::http::StatusCode; use collections::HashMap; use futures::StreamExt; @@ -8,186 +7,20 @@ use serde::{Deserialize, Serialize}; use sqlx::{ migrate::{Migrate as _, Migration, MigrationSource}, types::Uuid, - FromRow, QueryBuilder, + Encode, FromRow, QueryBuilder, }; use std::{cmp, ops::Range, path::Path, time::Duration}; use time::{OffsetDateTime, PrimitiveDateTime}; -#[async_trait] -pub trait Db: Send + Sync { - async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result; - async fn get_all_users(&self, page: u32, limit: u32) -> Result>; - async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result>; - async fn get_user_by_id(&self, id: UserId) -> Result>; - async fn get_user_metrics_id(&self, id: UserId) -> Result; - async fn get_users_by_ids(&self, ids: Vec) -> Result>; - async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result>; - async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result>; - async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>; - async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>; - async fn destroy_user(&self, id: UserId) -> Result<()>; - - async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>; - async fn get_invite_code_for_user(&self, id: UserId) -> Result>; - async fn get_user_for_invite_code(&self, code: &str) -> Result; - async fn create_invite_from_code( - &self, - code: &str, - email_address: &str, - device_id: Option<&str>, - ) -> Result; - - async fn create_signup(&self, signup: Signup) -> Result<()>; - async fn get_waitlist_summary(&self) -> Result; - async fn get_unsent_invites(&self, count: usize) -> Result>; - async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>; - async fn create_user_from_invite( - &self, - invite: &Invite, - user: NewUserParams, - ) -> Result>; - - /// Registers a new project for the given user. - async fn register_project(&self, host_user_id: UserId) -> Result; - - /// Unregisters a project for the given project id. - async fn unregister_project(&self, project_id: ProjectId) -> Result<()>; - - /// Update file counts by extension for the given project and worktree. - async fn update_worktree_extensions( - &self, - project_id: ProjectId, - worktree_id: u64, - extensions: HashMap, - ) -> Result<()>; - - /// Get the file counts on the given project keyed by their worktree and extension. - async fn get_project_extensions( - &self, - project_id: ProjectId, - ) -> Result>>; - - /// Record which users have been active in which projects during - /// a given period of time. - async fn record_user_activity( - &self, - time_period: Range, - active_projects: &[(UserId, ProjectId)], - ) -> Result<()>; - - /// Get the number of users who have been active in the given - /// time period for at least the given time duration. - async fn get_active_user_count( - &self, - time_period: Range, - min_duration: Duration, - only_collaborative: bool, - ) -> Result; - - /// Get the users that have been most active during the given time period, - /// along with the amount of time they have been active in each project. - async fn get_top_users_activity_summary( - &self, - time_period: Range, - max_user_count: usize, - ) -> Result>; - - /// Get the project activity for the given user and time period. - async fn get_user_activity_timeline( - &self, - time_period: Range, - user_id: UserId, - ) -> Result>; - - async fn get_contacts(&self, id: UserId) -> Result>; - async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result; - async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; - async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; - async fn dismiss_contact_notification( - &self, - responder_id: UserId, - requester_id: UserId, - ) -> Result<()>; - async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()>; - - async fn create_access_token_hash( - &self, - user_id: UserId, - access_token_hash: &str, - max_access_token_count: usize, - ) -> Result<()>; - async fn get_access_token_hashes(&self, user_id: UserId) -> Result>; - - #[cfg(any(test, feature = "seed-support"))] - async fn find_org_by_slug(&self, slug: &str) -> Result>; - #[cfg(any(test, feature = "seed-support"))] - async fn create_org(&self, name: &str, slug: &str) -> Result; - #[cfg(any(test, feature = "seed-support"))] - async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>; - #[cfg(any(test, feature = "seed-support"))] - async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result; - #[cfg(any(test, feature = "seed-support"))] - - async fn get_org_channels(&self, org_id: OrgId) -> Result>; - async fn get_accessible_channels(&self, user_id: UserId) -> Result>; - async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId) - -> Result; - - #[cfg(any(test, feature = "seed-support"))] - async fn add_channel_member( - &self, - channel_id: ChannelId, - user_id: UserId, - is_admin: bool, - ) -> Result<()>; - async fn create_channel_message( - &self, - channel_id: ChannelId, - sender_id: UserId, - body: &str, - timestamp: OffsetDateTime, - nonce: u128, - ) -> Result; - async fn get_channel_messages( - &self, - channel_id: ChannelId, - count: usize, - before_id: Option, - ) -> Result>; - - #[cfg(test)] - async fn teardown(&self, url: &str); - - #[cfg(test)] - fn as_fake(&self) -> Option<&FakeDb>; -} - #[cfg(any(test, debug_assertions))] pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")); -pub const TEST_MIGRATIONS_PATH: Option<&'static str> = - Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite")); - #[cfg(not(any(test, debug_assertions)))] pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None; -pub struct RealDb { - pool: sqlx::SqlitePool, +pub struct Db { + pool: sqlx::Pool, } macro_rules! test_support { @@ -204,16 +37,45 @@ macro_rules! test_support { }}; } -impl RealDb { - pub async fn new(url: &str, max_connections: u32) -> Result { - eprintln!("{url}"); +impl Db { + #[cfg(test)] + pub async fn sqlite(url: &str) -> Result { let pool = sqlx::sqlite::SqlitePoolOptions::new() .max_connections(1) .connect(url) .await?; Ok(Self { pool }) } +} +impl Db { + pub async fn postgres(url: &str, max_connection: u32) -> Result { + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(1) + .connect(url) + .await?; + Ok(Self { pool }) + } +} + +impl Db +where + D: sqlx::Database + sqlx::migrate::MigrateDatabase, + for<'a> >::Arguments: sqlx::IntoArguments<'a, D>, + D: for<'r> sqlx::database::HasValueRef<'r>, + D: for<'r> sqlx::database::HasArguments<'r>, + for<'a> &'a mut D::Connection: sqlx::Executor<'a>, + String: sqlx::Type, + i32: sqlx::Type, + bool: sqlx::Type, + str: sqlx::Type, + for<'a> str: sqlx::Encode<'a, D>, + for<'a> &'a str: sqlx::Encode<'a, D>, + for<'a> String: sqlx::Encode<'a, D>, + for<'a> i32: sqlx::Encode<'a, D>, + for<'a> bool: sqlx::Encode<'a, D>, + for<'a> Option: sqlx::Encode<'a, D>, +{ pub async fn migrate( &self, migrations_path: &Path, @@ -266,13 +128,10 @@ impl RealDb { result.push('%'); result } -} -#[async_trait] -impl Db for RealDb { // users - async fn create_user( + pub async fn create_user( &self, email_address: &str, admin: bool, @@ -302,7 +161,7 @@ impl Db for RealDb { }) } - async fn get_all_users(&self, page: u32, limit: u32) -> Result> { + pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { test_support!(self, { let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; Ok(sqlx::query_as(query) @@ -313,7 +172,7 @@ impl Db for RealDb { }) } - async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { + pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { test_support!(self, { let like_string = Self::fuzzy_like_string(name_query); let query = " @@ -332,7 +191,7 @@ impl Db for RealDb { }) } - async fn get_user_by_id(&self, id: UserId) -> Result> { + pub async fn get_user_by_id(&self, id: UserId) -> Result> { test_support!(self, { let query = " SELECT users.* @@ -347,7 +206,7 @@ impl Db for RealDb { }) } - async fn get_user_metrics_id(&self, id: UserId) -> Result { + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { test_support!(self, { let query = " SELECT metrics_id::text @@ -361,7 +220,7 @@ impl Db for RealDb { }) } - async fn get_users_by_ids(&self, ids: Vec) -> Result> { + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { test_support!(self, { let query = " SELECT users.* @@ -375,7 +234,10 @@ impl Db for RealDb { }) } - async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result> { + pub async fn get_users_with_no_invites( + &self, + invited_by_another_user: bool, + ) -> Result> { test_support!(self, { let query = format!( " @@ -391,7 +253,7 @@ impl Db for RealDb { }) } - async fn get_user_by_github_account( + pub async fn get_user_by_github_account( &self, github_login: &str, github_user_id: Option, @@ -443,7 +305,7 @@ impl Db for RealDb { }) } - async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { + pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { test_support!(self, { let query = "UPDATE users SET admin = $1 WHERE id = $2"; Ok(sqlx::query(query) @@ -455,7 +317,7 @@ impl Db for RealDb { }) } - async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { + pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { test_support!(self, { let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; Ok(sqlx::query(query) @@ -467,7 +329,7 @@ impl Db for RealDb { }) } - async fn destroy_user(&self, id: UserId) -> Result<()> { + pub async fn destroy_user(&self, id: UserId) -> Result<()> { test_support!(self, { let query = "DELETE FROM access_tokens WHERE user_id = $1;"; sqlx::query(query) @@ -486,7 +348,7 @@ impl Db for RealDb { // signups - async fn create_signup(&self, signup: Signup) -> Result<()> { + pub async fn create_signup(&self, signup: Signup) -> Result<()> { test_support!(self, { sqlx::query( " @@ -522,7 +384,7 @@ impl Db for RealDb { }) } - async fn get_waitlist_summary(&self) -> Result { + pub async fn get_waitlist_summary(&self) -> Result { test_support!(self, { Ok(sqlx::query_as( " @@ -545,7 +407,7 @@ impl Db for RealDb { }) } - async fn get_unsent_invites(&self, count: usize) -> Result> { + pub async fn get_unsent_invites(&self, count: usize) -> Result> { test_support!(self, { Ok(sqlx::query_as( " @@ -564,28 +426,28 @@ impl Db for RealDb { }) } - async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { + pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { test_support!(self, { - // sqlx::query( - // " - // UPDATE signups - // SET email_confirmation_sent = TRUE - // WHERE email_address = ANY ($1) - // ", - // ) + sqlx::query( + " + UPDATE signups + SET email_confirmation_sent = TRUE + WHERE email_address = ANY ($1) + ", + ) // .bind( // &invites // .iter() // .map(|s| s.email_address.as_str()) // .collect::>(), // ) - // .execute(&self.pool) - // .await?; + .execute(&self.pool) + .await?; Ok(()) }) } - async fn create_user_from_invite( + pub async fn create_user_from_invite( &self, invite: &Invite, user: NewUserParams, @@ -697,7 +559,7 @@ impl Db for RealDb { // invite codes - async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { + pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { test_support!(self, { let mut tx = self.pool.begin().await?; if count > 0 { @@ -730,7 +592,7 @@ impl Db for RealDb { }) } - async fn get_invite_code_for_user(&self, id: UserId) -> Result> { + pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { test_support!(self, { let result: Option<(String, i32)> = sqlx::query_as( " @@ -750,7 +612,7 @@ impl Db for RealDb { }) } - async fn get_user_for_invite_code(&self, code: &str) -> Result { + pub async fn get_user_for_invite_code(&self, code: &str) -> Result { test_support!(self, { sqlx::query_as( " @@ -771,7 +633,7 @@ impl Db for RealDb { }) } - async fn create_invite_from_code( + pub async fn create_invite_from_code( &self, code: &str, email_address: &str, @@ -860,7 +722,8 @@ impl Db for RealDb { // projects - async fn register_project(&self, host_user_id: UserId) -> Result { + /// Registers a new project for the given user. + pub async fn register_project(&self, host_user_id: UserId) -> Result { test_support!(self, { Ok(sqlx::query_scalar( " @@ -876,7 +739,8 @@ impl Db for RealDb { }) } - async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { + /// Unregisters a project for the given project id. + pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { test_support!(self, { sqlx::query( " @@ -892,7 +756,8 @@ impl Db for RealDb { }) } - async fn update_worktree_extensions( + /// Update file counts by extension for the given project and worktree. + pub async fn update_worktree_extensions( &self, project_id: ProjectId, worktree_id: u64, @@ -925,7 +790,8 @@ impl Db for RealDb { }) } - async fn get_project_extensions( + /// Get the file counts on the given project keyed by their worktree and extension. + pub async fn get_project_extensions( &self, project_id: ProjectId, ) -> Result>> { @@ -958,7 +824,9 @@ impl Db for RealDb { }) } - async fn record_user_activity( + /// Record which users have been active in which projects during + /// a given period of time. + pub async fn record_user_activity( &self, time_period: Range, projects: &[(UserId, ProjectId)], @@ -989,7 +857,9 @@ impl Db for RealDb { }) } - async fn get_active_user_count( + /// Get the number of users who have been active in the given + /// time period for at least the given time duration. + pub async fn get_active_user_count( &self, time_period: Range, min_duration: Duration, @@ -1066,7 +936,9 @@ impl Db for RealDb { }) } - async fn get_top_users_activity_summary( + /// Get the users that have been most active during the given time period, + /// along with the amount of time they have been active in each project. + pub async fn get_top_users_activity_summary( &self, time_period: Range, max_user_count: usize, @@ -1135,7 +1007,8 @@ impl Db for RealDb { }) } - async fn get_user_activity_timeline( + /// Get the project activity for the given user and time period. + pub async fn get_user_activity_timeline( &self, time_period: Range, user_id: UserId, @@ -1224,7 +1097,7 @@ impl Db for RealDb { // contacts - async fn get_contacts(&self, user_id: UserId) -> Result> { + pub async fn get_contacts(&self, user_id: UserId) -> Result> { test_support!(self, { let query = " SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify @@ -1275,7 +1148,7 @@ impl Db for RealDb { }) } - async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { + pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { test_support!(self, { let (id_a, id_b) = if user_id_1 < user_id_2 { (user_id_1, user_id_2) @@ -1297,7 +1170,7 @@ impl Db for RealDb { }) } - async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { test_support!(self, { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) @@ -1331,7 +1204,7 @@ impl Db for RealDb { }) } - async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { + pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { test_support!(self, { let (id_a, id_b) = if responder_id < requester_id { (responder_id, requester_id) @@ -1356,7 +1229,7 @@ impl Db for RealDb { }) } - async fn dismiss_contact_notification( + pub async fn dismiss_contact_notification( &self, user_id: UserId, contact_user_id: UserId, @@ -1394,7 +1267,7 @@ impl Db for RealDb { }) } - async fn respond_to_contact_request( + pub async fn respond_to_contact_request( &self, responder_id: UserId, requester_id: UserId, @@ -1440,7 +1313,7 @@ impl Db for RealDb { // access tokens - async fn create_access_token_hash( + pub async fn create_access_token_hash( &self, user_id: UserId, access_token_hash: &str, @@ -1477,7 +1350,7 @@ impl Db for RealDb { }) } - async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { + pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { test_support!(self, { let query = " SELECT hash @@ -1496,7 +1369,7 @@ impl Db for RealDb { #[allow(unused)] // Help rust-analyzer #[cfg(any(test, feature = "seed-support"))] - async fn find_org_by_slug(&self, slug: &str) -> Result> { + pub async fn find_org_by_slug(&self, slug: &str) -> Result> { test_support!(self, { let query = " SELECT * @@ -1511,7 +1384,7 @@ impl Db for RealDb { } #[cfg(any(test, feature = "seed-support"))] - async fn create_org(&self, name: &str, slug: &str) -> Result { + pub async fn create_org(&self, name: &str, slug: &str) -> Result { test_support!(self, { let query = " INSERT INTO orgs (name, slug) @@ -1528,7 +1401,12 @@ impl Db for RealDb { } #[cfg(any(test, feature = "seed-support"))] - async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> { + pub async fn add_org_member( + &self, + org_id: OrgId, + user_id: UserId, + is_admin: bool, + ) -> Result<()> { test_support!(self, { let query = " INSERT INTO org_memberships (org_id, user_id, admin) @@ -1548,7 +1426,7 @@ impl Db for RealDb { // channels #[cfg(any(test, feature = "seed-support"))] - async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { + pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { test_support!(self, { let query = " INSERT INTO channels (owner_id, owner_is_user, name) @@ -1566,7 +1444,7 @@ impl Db for RealDb { #[allow(unused)] // Help rust-analyzer #[cfg(any(test, feature = "seed-support"))] - async fn get_org_channels(&self, org_id: OrgId) -> Result> { + pub async fn get_org_channels(&self, org_id: OrgId) -> Result> { test_support!(self, { let query = " SELECT * @@ -1582,7 +1460,7 @@ impl Db for RealDb { }) } - async fn get_accessible_channels(&self, user_id: UserId) -> Result> { + pub async fn get_accessible_channels(&self, user_id: UserId) -> Result> { test_support!(self, { let query = " SELECT @@ -1600,7 +1478,7 @@ impl Db for RealDb { }) } - async fn can_user_access_channel( + pub async fn can_user_access_channel( &self, user_id: UserId, channel_id: ChannelId, @@ -1622,7 +1500,7 @@ impl Db for RealDb { } #[cfg(any(test, feature = "seed-support"))] - async fn add_channel_member( + pub async fn add_channel_member( &self, channel_id: ChannelId, user_id: UserId, @@ -1646,7 +1524,7 @@ impl Db for RealDb { // messages - async fn create_channel_message( + pub async fn create_channel_message( &self, channel_id: ChannelId, sender_id: UserId, @@ -1673,7 +1551,7 @@ impl Db for RealDb { }) } - async fn get_channel_messages( + pub async fn get_channel_messages( &self, channel_id: ChannelId, count: usize, @@ -1704,9 +1582,7 @@ impl Db for RealDb { } #[cfg(test)] - async fn teardown(&self, url: &str) { - let start = std::time::Instant::now(); - eprintln!("tearing down database..."); + pub async fn teardown(&self, url: &str) { test_support!(self, { use util::ResultExt; @@ -1720,14 +1596,8 @@ impl Db for RealDb { ::drop_database(url) .await .log_err(); - eprintln!("tore down database: {:?}", start.elapsed()); }) } - - #[cfg(test)] - fn as_fake(&self) -> Option<&FakeDb> { - None - } } macro_rules! id_type { @@ -1937,661 +1807,13 @@ pub use test::*; #[cfg(test)] mod test { use super::*; - use anyhow::anyhow; - use collections::BTreeMap; use gpui::executor::Background; - use parking_lot::Mutex; use rand::prelude::*; use sqlx::{migrate::MigrateDatabase, Sqlite}; use std::sync::Arc; - use util::post_inc; - - pub struct FakeDb { - background: Arc, - pub users: Mutex>, - pub projects: Mutex>, - pub worktree_extensions: Mutex>, - pub orgs: Mutex>, - pub org_memberships: Mutex>, - pub channels: Mutex>, - pub channel_memberships: Mutex>, - pub channel_messages: Mutex>, - pub contacts: Mutex>, - next_channel_message_id: Mutex, - next_user_id: Mutex, - next_org_id: Mutex, - next_channel_id: Mutex, - next_project_id: Mutex, - } - - #[derive(Debug)] - pub struct FakeContact { - pub requester_id: UserId, - pub responder_id: UserId, - pub accepted: bool, - pub should_notify: bool, - } - - impl FakeDb { - pub fn new(background: Arc) -> Self { - Self { - background, - users: Default::default(), - next_user_id: Mutex::new(0), - projects: Default::default(), - worktree_extensions: Default::default(), - next_project_id: Mutex::new(1), - orgs: Default::default(), - next_org_id: Mutex::new(1), - org_memberships: Default::default(), - channels: Default::default(), - next_channel_id: Mutex::new(1), - channel_memberships: Default::default(), - channel_messages: Default::default(), - next_channel_message_id: Mutex::new(1), - contacts: Default::default(), - } - } - } - - #[async_trait] - impl Db for FakeDb { - async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - self.background.simulate_random_delay().await; - - let mut users = self.users.lock(); - let user_id = if let Some(user) = users - .values() - .find(|user| user.github_login == params.github_login) - { - user.id - } else { - let id = post_inc(&mut *self.next_user_id.lock()); - let user_id = UserId(id); - users.insert( - user_id, - User { - id: user_id, - github_login: params.github_login, - github_user_id: Some(params.github_user_id), - email_address: Some(email_address.to_string()), - admin, - invite_code: None, - invite_count: 0, - connected_once: false, - }, - ); - user_id - }; - Ok(NewUserResult { - user_id, - metrics_id: "the-metrics-id".to_string(), - inviting_user_id: None, - signup_device_id: None, - }) - } - - async fn get_all_users(&self, _page: u32, _limit: u32) -> Result> { - unimplemented!() - } - - async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result> { - unimplemented!() - } - - async fn get_user_by_id(&self, id: UserId) -> Result> { - self.background.simulate_random_delay().await; - Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next()) - } - - async fn get_user_metrics_id(&self, _id: UserId) -> Result { - Ok("the-metrics-id".to_string()) - } - - async fn get_users_by_ids(&self, ids: Vec) -> Result> { - self.background.simulate_random_delay().await; - let users = self.users.lock(); - Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect()) - } - - async fn get_users_with_no_invites(&self, _: bool) -> Result> { - unimplemented!() - } - - async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result> { - self.background.simulate_random_delay().await; - if let Some(github_user_id) = github_user_id { - for user in self.users.lock().values_mut() { - if user.github_user_id == Some(github_user_id) { - user.github_login = github_login.into(); - return Ok(Some(user.clone())); - } - if user.github_login == github_login { - user.github_user_id = Some(github_user_id); - return Ok(Some(user.clone())); - } - } - Ok(None) - } else { - Ok(self - .users - .lock() - .values() - .find(|user| user.github_login == github_login) - .cloned()) - } - } - - async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> { - unimplemented!() - } - - async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - self.background.simulate_random_delay().await; - let mut users = self.users.lock(); - let mut user = users - .get_mut(&id) - .ok_or_else(|| anyhow!("user not found"))?; - user.connected_once = connected_once; - Ok(()) - } - - async fn destroy_user(&self, _id: UserId) -> Result<()> { - unimplemented!() - } - - // signups - - async fn create_signup(&self, _signup: Signup) -> Result<()> { - unimplemented!() - } - - async fn get_waitlist_summary(&self) -> Result { - unimplemented!() - } - - async fn get_unsent_invites(&self, _count: usize) -> Result> { - unimplemented!() - } - - async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> { - unimplemented!() - } - - async fn create_user_from_invite( - &self, - _invite: &Invite, - _user: NewUserParams, - ) -> Result> { - unimplemented!() - } - - // invite codes - - async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> { - unimplemented!() - } - - async fn get_invite_code_for_user(&self, _id: UserId) -> Result> { - self.background.simulate_random_delay().await; - Ok(None) - } - - async fn get_user_for_invite_code(&self, _code: &str) -> Result { - unimplemented!() - } - - async fn create_invite_from_code( - &self, - _code: &str, - _email_address: &str, - _device_id: Option<&str>, - ) -> Result { - unimplemented!() - } - - // projects - - async fn register_project(&self, host_user_id: UserId) -> Result { - self.background.simulate_random_delay().await; - if !self.users.lock().contains_key(&host_user_id) { - Err(anyhow!("no such user"))?; - } - - let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock())); - self.projects.lock().insert( - project_id, - Project { - id: project_id, - host_user_id, - unregistered: false, - }, - ); - Ok(project_id) - } - - async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { - self.background.simulate_random_delay().await; - self.projects - .lock() - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))? - .unregistered = true; - Ok(()) - } - - async fn update_worktree_extensions( - &self, - project_id: ProjectId, - worktree_id: u64, - extensions: HashMap, - ) -> Result<()> { - self.background.simulate_random_delay().await; - if !self.projects.lock().contains_key(&project_id) { - Err(anyhow!("no such project"))?; - } - - for (extension, count) in extensions { - self.worktree_extensions - .lock() - .insert((project_id, worktree_id, extension), count); - } - - Ok(()) - } - - async fn get_project_extensions( - &self, - _project_id: ProjectId, - ) -> Result>> { - unimplemented!() - } - - async fn record_user_activity( - &self, - _time_period: Range, - _active_projects: &[(UserId, ProjectId)], - ) -> Result<()> { - unimplemented!() - } - - async fn get_active_user_count( - &self, - _time_period: Range, - _min_duration: Duration, - _only_collaborative: bool, - ) -> Result { - unimplemented!() - } - - async fn get_top_users_activity_summary( - &self, - _time_period: Range, - _limit: usize, - ) -> Result> { - unimplemented!() - } - - async fn get_user_activity_timeline( - &self, - _time_period: Range, - _user_id: UserId, - ) -> Result> { - unimplemented!() - } - - // contacts - - async fn get_contacts(&self, id: UserId) -> Result> { - self.background.simulate_random_delay().await; - let mut contacts = Vec::new(); - - for contact in self.contacts.lock().iter() { - if contact.requester_id == id { - if contact.accepted { - contacts.push(Contact::Accepted { - user_id: contact.responder_id, - should_notify: contact.should_notify, - }); - } else { - contacts.push(Contact::Outgoing { - user_id: contact.responder_id, - }); - } - } else if contact.responder_id == id { - if contact.accepted { - contacts.push(Contact::Accepted { - user_id: contact.requester_id, - should_notify: false, - }); - } else { - contacts.push(Contact::Incoming { - user_id: contact.requester_id, - should_notify: contact.should_notify, - }); - } - } - } - - contacts.sort_unstable_by_key(|contact| contact.user_id()); - Ok(contacts) - } - - async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result { - self.background.simulate_random_delay().await; - Ok(self.contacts.lock().iter().any(|contact| { - contact.accepted - && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b) - || (contact.requester_id == user_id_b && contact.responder_id == user_id_a)) - })) - } - - async fn send_contact_request( - &self, - requester_id: UserId, - responder_id: UserId, - ) -> Result<()> { - self.background.simulate_random_delay().await; - let mut contacts = self.contacts.lock(); - for contact in contacts.iter_mut() { - if contact.requester_id == requester_id && contact.responder_id == responder_id { - if contact.accepted { - Err(anyhow!("contact already exists"))?; - } else { - Err(anyhow!("contact already requested"))?; - } - } - if contact.responder_id == requester_id && contact.requester_id == responder_id { - if contact.accepted { - Err(anyhow!("contact already exists"))?; - } else { - contact.accepted = true; - contact.should_notify = false; - return Ok(()); - } - } - } - contacts.push(FakeContact { - requester_id, - responder_id, - accepted: false, - should_notify: true, - }); - Ok(()) - } - - async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - self.background.simulate_random_delay().await; - self.contacts.lock().retain(|contact| { - !(contact.requester_id == requester_id && contact.responder_id == responder_id) - }); - Ok(()) - } - - async fn dismiss_contact_notification( - &self, - user_id: UserId, - contact_user_id: UserId, - ) -> Result<()> { - self.background.simulate_random_delay().await; - let mut contacts = self.contacts.lock(); - for contact in contacts.iter_mut() { - if contact.requester_id == contact_user_id - && contact.responder_id == user_id - && !contact.accepted - { - contact.should_notify = false; - return Ok(()); - } - if contact.requester_id == user_id - && contact.responder_id == contact_user_id - && contact.accepted - { - contact.should_notify = false; - return Ok(()); - } - } - Err(anyhow!("no such notification"))? - } - - async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()> { - self.background.simulate_random_delay().await; - let mut contacts = self.contacts.lock(); - for (ix, contact) in contacts.iter_mut().enumerate() { - if contact.requester_id == requester_id && contact.responder_id == responder_id { - if contact.accepted { - Err(anyhow!("contact already confirmed"))?; - } - if accept { - contact.accepted = true; - contact.should_notify = true; - } else { - contacts.remove(ix); - } - return Ok(()); - } - } - Err(anyhow!("no such contact request"))? - } - - async fn create_access_token_hash( - &self, - _user_id: UserId, - _access_token_hash: &str, - _max_access_token_count: usize, - ) -> Result<()> { - unimplemented!() - } - - async fn get_access_token_hashes(&self, _user_id: UserId) -> Result> { - unimplemented!() - } - - async fn find_org_by_slug(&self, _slug: &str) -> Result> { - unimplemented!() - } - - async fn create_org(&self, name: &str, slug: &str) -> Result { - self.background.simulate_random_delay().await; - let mut orgs = self.orgs.lock(); - if orgs.values().any(|org| org.slug == slug) { - Err(anyhow!("org already exists"))? - } else { - let org_id = OrgId(post_inc(&mut *self.next_org_id.lock())); - orgs.insert( - org_id, - Org { - id: org_id, - name: name.to_string(), - slug: slug.to_string(), - }, - ); - Ok(org_id) - } - } - - async fn add_org_member( - &self, - org_id: OrgId, - user_id: UserId, - is_admin: bool, - ) -> Result<()> { - self.background.simulate_random_delay().await; - if !self.orgs.lock().contains_key(&org_id) { - Err(anyhow!("org does not exist"))?; - } - if !self.users.lock().contains_key(&user_id) { - Err(anyhow!("user does not exist"))?; - } - - self.org_memberships - .lock() - .entry((org_id, user_id)) - .or_insert(is_admin); - Ok(()) - } - - async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { - self.background.simulate_random_delay().await; - if !self.orgs.lock().contains_key(&org_id) { - Err(anyhow!("org does not exist"))?; - } - - let mut channels = self.channels.lock(); - let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock())); - channels.insert( - channel_id, - Channel { - id: channel_id, - name: name.to_string(), - owner_id: org_id.0, - owner_is_user: false, - }, - ); - Ok(channel_id) - } - - async fn get_org_channels(&self, org_id: OrgId) -> Result> { - self.background.simulate_random_delay().await; - Ok(self - .channels - .lock() - .values() - .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0) - .cloned() - .collect()) - } - - async fn get_accessible_channels(&self, user_id: UserId) -> Result> { - self.background.simulate_random_delay().await; - let channels = self.channels.lock(); - let memberships = self.channel_memberships.lock(); - Ok(channels - .values() - .filter(|channel| memberships.contains_key(&(channel.id, user_id))) - .cloned() - .collect()) - } - - async fn can_user_access_channel( - &self, - user_id: UserId, - channel_id: ChannelId, - ) -> Result { - self.background.simulate_random_delay().await; - Ok(self - .channel_memberships - .lock() - .contains_key(&(channel_id, user_id))) - } - - async fn add_channel_member( - &self, - channel_id: ChannelId, - user_id: UserId, - is_admin: bool, - ) -> Result<()> { - self.background.simulate_random_delay().await; - if !self.channels.lock().contains_key(&channel_id) { - Err(anyhow!("channel does not exist"))?; - } - if !self.users.lock().contains_key(&user_id) { - Err(anyhow!("user does not exist"))?; - } - - self.channel_memberships - .lock() - .entry((channel_id, user_id)) - .or_insert(is_admin); - Ok(()) - } - - async fn create_channel_message( - &self, - channel_id: ChannelId, - sender_id: UserId, - body: &str, - timestamp: OffsetDateTime, - nonce: u128, - ) -> Result { - self.background.simulate_random_delay().await; - if !self.channels.lock().contains_key(&channel_id) { - Err(anyhow!("channel does not exist"))?; - } - if !self.users.lock().contains_key(&sender_id) { - Err(anyhow!("user does not exist"))?; - } - - let mut messages = self.channel_messages.lock(); - if let Some(message) = messages - .values() - .find(|message| message.nonce.as_u128() == nonce) - { - Ok(message.id) - } else { - let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock())); - messages.insert( - message_id, - ChannelMessage { - id: message_id, - channel_id, - sender_id, - body: body.to_string(), - sent_at: timestamp, - nonce: Uuid::from_u128(nonce), - }, - ); - Ok(message_id) - } - } - - async fn get_channel_messages( - &self, - channel_id: ChannelId, - count: usize, - before_id: Option, - ) -> Result> { - self.background.simulate_random_delay().await; - let mut messages = self - .channel_messages - .lock() - .values() - .rev() - .filter(|message| { - message.channel_id == channel_id - && message.id < before_id.unwrap_or(MessageId::MAX) - }) - .take(count) - .cloned() - .collect::>(); - messages.sort_unstable_by_key(|message| message.id); - Ok(messages) - } - - async fn teardown(&self, _: &str) {} - - #[cfg(test)] - fn as_fake(&self) -> Option<&FakeDb> { - Some(self) - } - } pub struct TestDb { - pub db: Option>, + pub db: Option>, pub url: String, } @@ -2603,8 +1825,8 @@ mod test { let mut rng = StdRng::from_entropy(); let url = format!("/tmp/zed-test-{}", rng.gen::()); Sqlite::create_database(&url).await.unwrap(); - let db = RealDb::new(&url, 5).await.unwrap(); - db.migrate(Path::new(TEST_MIGRATIONS_PATH.unwrap()), false) + let db = Db::new(&url, 5).await.unwrap(); + db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false) .await .unwrap(); @@ -2615,14 +1837,23 @@ mod test { } } - pub fn fake(background: Arc) -> Self { + pub async fn fake(background: Arc) -> Self { + let start = std::time::Instant::now(); + let mut rng = StdRng::from_entropy(); + let url = format!("/tmp/zed-test-{}", rng.gen::()); + Sqlite::create_database(&url).await.unwrap(); + let db = Db::new(&url, 5).await.unwrap(); + db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false) + .await + .unwrap(); + Self { - db: Some(Arc::new(FakeDb::new(background))), - url: Default::default(), + db: Some(Arc::new(db)), + url, } } - pub fn db(&self) -> &Arc { + pub fn db(&self) -> &Arc { self.db.as_ref().unwrap() } } diff --git a/crates/collab/src/db_tests.rs b/crates/collab/src/db_tests.rs index 4a27f2752a..98aac44400 100644 --- a/crates/collab/src/db_tests.rs +++ b/crates/collab/src/db_tests.rs @@ -625,7 +625,7 @@ async fn test_fuzzy_search_users() { &["rhode-island", "colorado", "oregon"], ); - async fn fuzzy_search_user_names(db: &Arc, query: &str) -> Vec { + async fn fuzzy_search_user_names(db: &Arc, query: &str) -> Vec { db.fuzzy_search_users(query, 10) .await .unwrap() diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 350596cf31..fd37990674 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -13,7 +13,7 @@ use crate::rpc::ResultExt as _; use anyhow::anyhow; use axum::{routing::get, Router}; use collab::{Error, Result}; -use db::{Db, RealDb}; +use db::DefaultDb as Db; use serde::Deserialize; use std::{ env::args, @@ -49,14 +49,14 @@ pub struct MigrateConfig { } pub struct AppState { - db: Arc, + db: Arc, live_kit_client: Option>, config: Config, } impl AppState { async fn new(config: Config) -> Result> { - let db = RealDb::new(&config.database_url, 5).await?; + let db = Db::new(&config.database_url, 5).await?; let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() @@ -96,7 +96,7 @@ async fn main() -> Result<()> { } Some("migrate") => { let config = envy::from_env::().expect("error loading config"); - let db = RealDb::new(&config.database_url, 5).await?; + let db = Db::new(&config.database_url, 5).await?; let migrations_path = config .migrations_path