diff --git a/Cargo.lock b/Cargo.lock index 7b09775f2a..590835a49b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1159,7 +1159,6 @@ dependencies = [ "scrypt", "sea-orm", "sea-query", - "sea-query-binder", "serde", "serde_json", "settings", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 4cb91ad12d..66f426839c 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -37,8 +37,7 @@ rand = "0.8" reqwest = { version = "0.11", features = ["json"], optional = true } scrypt = "0.7" sea-orm = { version = "0.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls"] } -sea-query = { version = "0.27", features = ["derive"] } -sea-query-binder = { version = "0.2", features = ["sqlx-postgres"] } +sea-query = "0.27" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" sha-1 = "0.9" @@ -76,7 +75,6 @@ log = { version = "0.4.16", features = ["kv_unstable_serde"] } util = { path = "../util" } lazy_static = "1.4" sea-orm = { version = "0.10", features = ["sqlx-sqlite"] } -sea-query-binder = { version = "0.2", features = ["sqlx-sqlite"] } serde_json = { version = "1.0", features = ["preserve_order"] } sqlx = { version = "0.6", features = ["sqlite"] } unindent = "0.1" diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 5fcdc5fcfd..bf183edf54 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,6 +1,6 @@ use crate::{ auth, - db::{Invite, NewUserParams, Signup, User, UserId, WaitlistSummary}, + db::{Invite, NewSignup, NewUserParams, User, UserId, WaitlistSummary}, rpc::{self, ResultExt}, AppState, Error, Result, }; @@ -335,7 +335,7 @@ async fn get_user_for_invite_code( } async fn create_signup( - Json(params): Json, + Json(params): Json, Extension(app): Extension>, ) -> Result<()> { app.db.create_signup(params).await?; diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 63f032f7e6..0c9cf33a6b 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -75,7 +75,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; -pub async fn create_access_token(db: &db::DefaultDb, user_id: UserId) -> Result { +pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result { let access_token = rpc::auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 044d4ef8d7..d89d041f2a 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,42 +1,44 @@ -mod schema; +mod access_token; +mod contact; +mod project; +mod project_collaborator; +mod room; +mod room_participant; +mod signup; #[cfg(test)] mod tests; +mod user; +mod worktree; use crate::{Error, Result}; use anyhow::anyhow; -use axum::http::StatusCode; -use collections::{BTreeMap, HashMap, HashSet}; +use collections::HashMap; +pub use contact::Contact; use dashmap::DashMap; -use futures::{future::BoxFuture, FutureExt, StreamExt}; +use futures::StreamExt; +use hyper::StatusCode; use rpc::{proto, ConnectionId}; -use sea_query::{Expr, Query}; -use sea_query_binder::SqlxBinder; +pub use sea_orm::ConnectOptions; +use sea_orm::{ + entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseBackend, DatabaseConnection, + DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, JoinType, QueryOrder, + QuerySelect, Statement, TransactionTrait, +}; +use sea_query::{Alias, Expr, OnConflict, Query}; use serde::{Deserialize, Serialize}; -use sqlx::{ - migrate::{Migrate as _, Migration, MigrationSource}, - types::Uuid, - FromRow, -}; -use std::{ - future::Future, - marker::PhantomData, - ops::{Deref, DerefMut}, - path::Path, - rc::Rc, - sync::Arc, - time::Duration, -}; -use time::{OffsetDateTime, PrimitiveDateTime}; +pub use signup::{Invite, NewSignup, WaitlistSummary}; +use sqlx::migrate::{Migrate, Migration, MigrationSource}; +use sqlx::Connection; +use std::ops::{Deref, DerefMut}; +use std::path::Path; +use std::time::Duration; +use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; use tokio::sync::{Mutex, OwnedMutexGuard}; +pub use user::Model as User; -#[cfg(test)] -pub type DefaultDb = Db; - -#[cfg(not(test))] -pub type DefaultDb = Db; - -pub struct Db { - pool: sqlx::Pool, +pub struct Database { + options: ConnectOptions, + pool: DatabaseConnection, rooms: DashMap>>, #[cfg(test)] background: Option>, @@ -44,214 +46,11 @@ pub struct Db { runtime: Option, } -pub struct RoomGuard { - data: T, - _guard: OwnedMutexGuard<()>, - _not_send: PhantomData>, -} - -impl Deref for RoomGuard { - type Target = T; - - fn deref(&self) -> &T { - &self.data - } -} - -impl DerefMut for RoomGuard { - fn deref_mut(&mut self) -> &mut T { - &mut self.data - } -} - -pub trait BeginTransaction: Send + Sync { - type Database: sqlx::Database; - - fn begin_transaction(&self) -> BoxFuture>>; -} - -// In Postgres, serializable transactions are opt-in -impl BeginTransaction for Db { - type Database = sqlx::Postgres; - - fn begin_transaction(&self) -> BoxFuture>> { - async move { - let mut tx = self.pool.begin().await?; - sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;") - .await?; - Ok(tx) - } - .boxed() - } -} - -// In Sqlite, transactions are inherently serializable. -#[cfg(test)] -impl BeginTransaction for Db { - type Database = sqlx::Sqlite; - - fn begin_transaction(&self) -> BoxFuture>> { - async move { Ok(self.pool.begin().await?) }.boxed() - } -} - -pub trait BuildQuery { - fn build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues); -} - -impl BuildQuery for Db { - fn build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues) { - query.build_sqlx(sea_query::PostgresQueryBuilder) - } -} - -#[cfg(test)] -impl BuildQuery for Db { - fn build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues) { - query.build_sqlx(sea_query::SqliteQueryBuilder) - } -} - -pub trait RowsAffected { - fn rows_affected(&self) -> u64; -} - -#[cfg(test)] -impl RowsAffected for sqlx::sqlite::SqliteQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -impl RowsAffected for sqlx::postgres::PgQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -#[cfg(test)] -impl Db { - pub async fn new(url: &str, max_connections: u32) -> Result { - use std::str::FromStr as _; - let options = sqlx::sqlite::SqliteConnectOptions::from_str(url) - .unwrap() - .create_if_missing(true) - .shared_cache(true); - let pool = sqlx::sqlite::SqlitePoolOptions::new() - .min_connections(2) - .max_connections(max_connections) - .connect_with(options) - .await?; +impl Database { + pub async fn new(options: ConnectOptions) -> Result { Ok(Self { - pool, - rooms: Default::default(), - background: None, - runtime: None, - }) - } - - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - let query = " - SELECT users.* - FROM users - WHERE users.id IN (SELECT value from json_each($1)) - "; - Ok(sqlx::query_as(query) - .bind(&serde_json::json!(ids)) - .fetch_all(&mut tx) - .await?) - }) - .await - } - - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - self.transact(|mut tx| async move { - let query = " - SELECT metrics_id - FROM users - WHERE id = $1 - "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&mut tx) - .await?) - }) - .await - } - - pub async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - self.transact(|mut tx| async { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, metrics_id - "; - - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(¶ms.github_login) - .bind(¶ms.github_user_id) - .bind(admin) - .bind(Uuid::new_v4().to_string()) - .fetch_one(&mut tx) - .await?; - tx.commit().await?; - Ok(NewUserResult { - user_id, - metrics_id, - signup_device_id: None, - inviting_user_id: None, - }) - }) - .await - } - - pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result> { - unimplemented!() - } - - pub async fn create_user_from_invite( - &self, - _invite: &Invite, - _user: NewUserParams, - ) -> Result> { - unimplemented!() - } - - pub async fn create_signup(&self, _signup: Signup) -> Result<()> { - unimplemented!() - } - - pub async fn create_invite_from_code( - &self, - _code: &str, - _email_address: &str, - _device_id: Option<&str>, - ) -> Result { - unimplemented!() - } - - pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> { - unimplemented!() - } -} - -impl Db { - pub async fn new(url: &str, max_connections: u32) -> Result { - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(max_connections) - .connect(url) - .await?; - Ok(Self { - pool, + options: options.clone(), + pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), #[cfg(test)] background: None, @@ -260,396 +59,6 @@ impl Db { }) } - #[cfg(test)] - pub fn teardown(&self, url: &str) { - self.runtime.as_ref().unwrap().block_on(async { - use util::ResultExt; - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid(); - "; - sqlx::query(query).execute(&self.pool).await.log_err(); - self.pool.close().await; - ::drop_database(url) - .await - .log_err(); - }) - } - - pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - let like_string = Self::fuzzy_like_string(name_query); - let query = " - SELECT users.* - FROM users - WHERE github_login ILIKE $1 - ORDER BY github_login <-> $2 - LIMIT $3 - "; - Ok(sqlx::query_as(query) - .bind(like_string) - .bind(name_query) - .bind(limit as i32) - .fetch_all(&mut tx) - .await?) - }) - .await - } - - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - let ids = ids.iter().map(|id| id.0).collect::>(); - self.transact(|tx| async { - let mut tx = tx; - let query = " - SELECT users.* - FROM users - WHERE users.id = ANY ($1) - "; - Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?) - }) - .await - } - - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - self.transact(|mut tx| async move { - let query = " - SELECT metrics_id::text - FROM users - WHERE id = $1 - "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&mut tx) - .await?) - }) - .await - } - - pub async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - self.transact(|mut tx| async { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin) - VALUES ($1, $2, $3, $4) - ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, metrics_id::text - "; - - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(¶ms.github_login) - .bind(params.github_user_id) - .bind(admin) - .fetch_one(&mut tx) - .await?; - tx.commit().await?; - - Ok(NewUserResult { - user_id, - metrics_id, - signup_device_id: None, - inviting_user_id: None, - }) - }) - .await - } - - pub async fn create_user_from_invite( - &self, - invite: &Invite, - user: NewUserParams, - ) -> Result> { - self.transact(|mut tx| async { - let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( - i32, - Option, - Option, - Option, - ) = sqlx::query_as( - " - SELECT id, user_id, inviting_user_id, device_id - FROM signups - WHERE - email_address = $1 AND - email_confirmation_code = $2 - ", - ) - .bind(&invite.email_address) - .bind(&invite.email_confirmation_code) - .fetch_optional(&mut tx) - .await? - .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - - if existing_user_id.is_some() { - return Ok(None); - } - - let (user_id, metrics_id): (UserId, String) = sqlx::query_as( - " - INSERT INTO users - (email_address, github_login, github_user_id, admin, invite_count, invite_code) - VALUES - ($1, $2, $3, FALSE, $4, $5) - ON CONFLICT (github_login) DO UPDATE SET - email_address = excluded.email_address, - github_user_id = excluded.github_user_id, - admin = excluded.admin - RETURNING id, metrics_id::text - ", - ) - .bind(&invite.email_address) - .bind(&user.github_login) - .bind(&user.github_user_id) - .bind(&user.invite_count) - .bind(random_invite_code()) - .fetch_one(&mut tx) - .await?; - - sqlx::query( - " - UPDATE signups - SET user_id = $1 - WHERE id = $2 - ", - ) - .bind(&user_id) - .bind(&signup_id) - .execute(&mut tx) - .await?; - - if let Some(inviting_user_id) = inviting_user_id { - let id: Option = sqlx::query_scalar( - " - UPDATE users - SET invite_count = invite_count - 1 - WHERE id = $1 AND invite_count > 0 - RETURNING id - ", - ) - .bind(&inviting_user_id) - .fetch_optional(&mut tx) - .await?; - - if id.is_none() { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - sqlx::query( - " - INSERT INTO contacts - (user_id_a, user_id_b, a_to_b, should_notify, accepted) - VALUES - ($1, $2, TRUE, TRUE, TRUE) - ON CONFLICT DO NOTHING - ", - ) - .bind(inviting_user_id) - .bind(user_id) - .execute(&mut tx) - .await?; - } - - tx.commit().await?; - Ok(Some(NewUserResult { - user_id, - metrics_id, - inviting_user_id, - signup_device_id, - })) - }) - .await - } - - pub async fn create_signup(&self, signup: Signup) -> Result<()> { - self.transact(|mut tx| async { - sqlx::query( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - editor_features, - programming_languages, - device_id - ) - VALUES - ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8) - RETURNING id - ", - ) - .bind(&signup.email_address) - .bind(&random_email_confirmation_code()) - .bind(&signup.platform_linux) - .bind(&signup.platform_mac) - .bind(&signup.platform_windows) - .bind(&signup.editor_features) - .bind(&signup.programming_languages) - .bind(&signup.device_id) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn create_invite_from_code( - &self, - code: &str, - email_address: &str, - device_id: Option<&str>, - ) -> Result { - self.transact(|mut tx| async { - let existing_user: Option = sqlx::query_scalar( - " - SELECT id - FROM users - WHERE email_address = $1 - ", - ) - .bind(email_address) - .fetch_optional(&mut tx) - .await?; - if existing_user.is_some() { - Err(anyhow!("email address is already in use"))?; - } - - let row: Option<(UserId, i32)> = sqlx::query_as( - " - SELECT id, invite_count - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&mut tx) - .await?; - - let (inviter_id, invite_count) = match row { - Some(row) => row, - None => Err(Error::Http( - StatusCode::NOT_FOUND, - "invite code not found".to_string(), - ))?, - }; - - if invite_count == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - let email_confirmation_code: String = sqlx::query_scalar( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - inviting_user_id, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - device_id - ) - VALUES - ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4) - ON CONFLICT (email_address) - DO UPDATE SET - inviting_user_id = excluded.inviting_user_id - RETURNING email_confirmation_code - ", - ) - .bind(&email_address) - .bind(&random_email_confirmation_code()) - .bind(&inviter_id) - .bind(&device_id) - .fetch_one(&mut tx) - .await?; - - tx.commit().await?; - - Ok(Invite { - email_address: email_address.into(), - email_confirmation_code, - }) - }) - .await - } - - pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - self.transact(|mut tx| async { - let emails = invites - .iter() - .map(|s| s.email_address.as_str()) - .collect::>(); - sqlx::query( - " - UPDATE signups - SET email_confirmation_sent = TRUE - WHERE email_address = ANY ($1) - ", - ) - .bind(&emails) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } -} - -impl Db -where - Self: BeginTransaction + BuildQuery, - D: sqlx::Database + sqlx::migrate::MigrateDatabase, - D::Connection: sqlx::migrate::Migrate, - for<'a> >::Arguments: sqlx::IntoArguments<'a, D>, - for<'a> sea_query_binder::SqlxValues: sqlx::IntoArguments<'a, D>, - for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>, - for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>, - D::QueryResult: RowsAffected, - String: sqlx::Type, - i32: sqlx::Type, - i64: sqlx::Type, - bool: sqlx::Type, - str: sqlx::Type, - Uuid: sqlx::Type, - sqlx::types::Json: sqlx::Type, - OffsetDateTime: sqlx::Type, - PrimitiveDateTime: sqlx::Type, - usize: sqlx::ColumnIndex, - for<'a> &'a str: sqlx::ColumnIndex, - for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Option: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Option: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>, -{ pub async fn migrate( &self, migrations_path: &Path, @@ -659,10 +68,10 @@ where .await .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - let mut conn = self.pool.acquire().await?; + let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?; - conn.ensure_migrations_table().await?; - let applied_migrations: HashMap<_, _> = conn + connection.ensure_migrations_table().await?; + let applied_migrations: HashMap<_, _> = connection .list_applied_migrations() .await? .into_iter() @@ -682,7 +91,7 @@ where } } None => { - let elapsed = conn.apply(&migration).await?; + let elapsed = connection.apply(&migration).await?; new_migrations.push((migration, elapsed)); } } @@ -691,6 +100,457 @@ where Ok(new_migrations) } + // users + + pub async fn create_user( + &self, + email_address: &str, + admin: bool, + params: NewUserParams, + ) -> Result { + self.transact(|tx| async { + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(email_address.into())), + github_login: ActiveValue::set(params.github_login.clone()), + github_user_id: ActiveValue::set(Some(params.github_user_id)), + admin: ActiveValue::set(admin), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_column(user::Column::GithubLogin) + .to_owned(), + ) + .exec_with_returning(&tx) + .await?; + + tx.commit().await?; + + Ok(NewUserResult { + user_id: user.id, + metrics_id: user.metrics_id.to_string(), + signup_device_id: None, + inviting_user_id: None, + }) + }) + .await + } + + pub async fn get_user_by_id(&self, id: UserId) -> Result> { + self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) }) + .await + } + + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { + self.transact(|tx| async { + let tx = tx; + Ok(user::Entity::find() + .filter(user::Column::Id.is_in(ids.iter().copied())) + .all(&tx) + .await?) + }) + .await + } + + pub async fn get_user_by_github_account( + &self, + github_login: &str, + github_user_id: Option, + ) -> Result> { + self.transact(|tx| async { + let tx = tx; + if let Some(github_user_id) = github_user_id { + if let Some(user_by_github_user_id) = user::Entity::find() + .filter(user::Column::GithubUserId.eq(github_user_id)) + .one(&tx) + .await? + { + let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); + user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); + Ok(Some(user_by_github_user_id.update(&tx).await?)) + } else if let Some(user_by_github_login) = user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&tx) + .await? + { + let mut user_by_github_login = user_by_github_login.into_active_model(); + user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); + Ok(Some(user_by_github_login.update(&tx).await?)) + } else { + Ok(None) + } + } else { + Ok(user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&tx) + .await?) + } + }) + .await + } + + pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { + self.transact(|tx| async move { + Ok(user::Entity::find() + .order_by_asc(user::Column::GithubLogin) + .limit(limit as u64) + .offset(page as u64 * limit as u64) + .all(&tx) + .await?) + }) + .await + } + + pub async fn get_users_with_no_invites( + &self, + invited_by_another_user: bool, + ) -> Result> { + self.transact(|tx| async move { + Ok(user::Entity::find() + .filter( + user::Column::InviteCount + .eq(0) + .and(if invited_by_another_user { + user::Column::InviterId.is_not_null() + } else { + user::Column::InviterId.is_null() + }), + ) + .all(&tx) + .await?) + }) + .await + } + + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + MetricsId, + } + + self.transact(|tx| async move { + let metrics_id: Uuid = user::Entity::find_by_id(id) + .select_only() + .column(user::Column::MetricsId) + .into_values::<_, QueryAs>() + .one(&tx) + .await? + .ok_or_else(|| anyhow!("could not find user"))?; + Ok(metrics_id.to_string()) + }) + .await + } + + pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { + self.transact(|tx| async move { + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .col_expr(user::Column::Admin, is_admin.into()) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + + pub async fn destroy_user(&self, id: UserId) -> Result<()> { + self.transact(|tx| async move { + access_token::Entity::delete_many() + .filter(access_token::Column::UserId.eq(id)) + .exec(&tx) + .await?; + user::Entity::delete_by_id(id).exec(&tx).await?; + tx.commit().await?; + Ok(()) + }) + .await + } + + // contacts + + pub async fn get_contacts(&self, user_id: UserId) -> Result> { + #[derive(Debug, FromQueryResult)] + struct ContactWithUserBusyStatuses { + user_id_a: UserId, + user_id_b: UserId, + a_to_b: bool, + accepted: bool, + should_notify: bool, + user_a_busy: bool, + user_b_busy: bool, + } + + self.transact(|tx| async move { + let user_a_participant = Alias::new("user_a_participant"); + let user_b_participant = Alias::new("user_b_participant"); + let mut db_contacts = contact::Entity::find() + .column_as( + Expr::tbl(user_a_participant.clone(), room_participant::Column::Id) + .is_not_null(), + "user_a_busy", + ) + .column_as( + Expr::tbl(user_b_participant.clone(), room_participant::Column::Id) + .is_not_null(), + "user_b_busy", + ) + .filter( + contact::Column::UserIdA + .eq(user_id) + .or(contact::Column::UserIdB.eq(user_id)), + ) + .join_as( + JoinType::LeftJoin, + contact::Relation::UserARoomParticipant.def(), + user_a_participant, + ) + .join_as( + JoinType::LeftJoin, + contact::Relation::UserBRoomParticipant.def(), + user_b_participant, + ) + .into_model::() + .stream(&tx) + .await?; + + let mut contacts = Vec::new(); + while let Some(db_contact) = db_contacts.next().await { + let db_contact = db_contact?; + if db_contact.user_id_a == user_id { + if db_contact.accepted { + contacts.push(Contact::Accepted { + user_id: db_contact.user_id_b, + should_notify: db_contact.should_notify && db_contact.a_to_b, + busy: db_contact.user_b_busy, + }); + } else if db_contact.a_to_b { + contacts.push(Contact::Outgoing { + user_id: db_contact.user_id_b, + }) + } else { + contacts.push(Contact::Incoming { + user_id: db_contact.user_id_b, + should_notify: db_contact.should_notify, + }); + } + } else if db_contact.accepted { + contacts.push(Contact::Accepted { + user_id: db_contact.user_id_a, + should_notify: db_contact.should_notify && !db_contact.a_to_b, + busy: db_contact.user_a_busy, + }); + } else if db_contact.a_to_b { + contacts.push(Contact::Incoming { + user_id: db_contact.user_id_a, + should_notify: db_contact.should_notify, + }); + } else { + contacts.push(Contact::Outgoing { + user_id: db_contact.user_id_a, + }); + } + } + + contacts.sort_unstable_by_key(|contact| contact.user_id()); + + Ok(contacts) + }) + .await + } + + pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { + self.transact(|tx| async move { + let (id_a, id_b) = if user_id_1 < user_id_2 { + (user_id_1, user_id_2) + } else { + (user_id_2, user_id_1) + }; + + Ok(contact::Entity::find() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::Accepted.eq(true)), + ) + .one(&tx) + .await? + .is_some()) + }) + .await + } + + pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b, a_to_b) = if sender_id < receiver_id { + (sender_id, receiver_id, true) + } else { + (receiver_id, sender_id, false) + }; + + let rows_affected = contact::Entity::insert(contact::ActiveModel { + user_id_a: ActiveValue::set(id_a), + user_id_b: ActiveValue::set(id_b), + a_to_b: ActiveValue::set(a_to_b), + accepted: ActiveValue::set(false), + should_notify: ActiveValue::set(true), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB]) + .values([ + (contact::Column::Accepted, true.into()), + (contact::Column::ShouldNotify, false.into()), + ]) + .action_and_where( + contact::Column::Accepted.eq(false).and( + contact::Column::AToB + .eq(a_to_b) + .and(contact::Column::UserIdA.eq(id_b)) + .or(contact::Column::AToB + .ne(a_to_b) + .and(contact::Column::UserIdA.eq(id_a))), + ), + ) + .to_owned(), + ) + .exec_without_returning(&tx) + .await?; + + if rows_affected == 1 { + tx.commit().await?; + Ok(()) + } else { + Err(anyhow!("contact already requested"))? + } + }) + .await + } + + pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b) = if responder_id < requester_id { + (responder_id, requester_id) + } else { + (requester_id, responder_id) + }; + + let result = contact::Entity::delete_many() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)), + ) + .exec(&tx) + .await?; + + if result.rows_affected == 1 { + tx.commit().await?; + Ok(()) + } else { + Err(anyhow!("no such contact"))? + } + }) + .await + } + + pub async fn dismiss_contact_notification( + &self, + user_id: UserId, + contact_user_id: UserId, + ) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b, a_to_b) = if user_id < contact_user_id { + (user_id, contact_user_id, true) + } else { + (contact_user_id, user_id, false) + }; + + let result = contact::Entity::update_many() + .set(contact::ActiveModel { + should_notify: ActiveValue::set(false), + ..Default::default() + }) + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and( + contact::Column::AToB + .eq(a_to_b) + .and(contact::Column::Accepted.eq(true)) + .or(contact::Column::AToB + .ne(a_to_b) + .and(contact::Column::Accepted.eq(false))), + ), + ) + .exec(&tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("no such contact request"))? + } else { + tx.commit().await?; + Ok(()) + } + }) + .await + } + + pub async fn respond_to_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + accept: bool, + ) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + let rows_affected = if accept { + let result = contact::Entity::update_many() + .set(contact::ActiveModel { + accepted: ActiveValue::set(true), + should_notify: ActiveValue::set(true), + ..Default::default() + }) + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::AToB.eq(a_to_b)), + ) + .exec(&tx) + .await?; + result.rows_affected + } else { + let result = contact::Entity::delete_many() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::AToB.eq(a_to_b)) + .and(contact::Column::Accepted.eq(false)), + ) + .exec(&tx) + .await?; + + result.rows_affected + }; + + if rows_affected == 1 { + tx.commit().await?; + Ok(()) + } else { + Err(anyhow!("no such contact request"))? + } + }) + .await + } + pub fn fuzzy_like_string(string: &str) -> String { let mut result = String::with_capacity(string.len() * 2 + 1); for c in string.chars() { @@ -703,163 +563,58 @@ where result } - // users - - pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { + pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { self.transact(|tx| async { - let mut tx = tx; - let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; - Ok(sqlx::query_as(query) - .bind(limit as i32) - .bind((page * limit) as i32) - .fetch_all(&mut tx) - .await?) - }) - .await - } - - pub async fn get_user_by_id(&self, id: UserId) -> Result> { - self.transact(|tx| async { - let mut tx = tx; + let tx = tx; + let like_string = Self::fuzzy_like_string(name_query); let query = " SELECT users.* FROM users - WHERE id = $1 - LIMIT 1 + WHERE github_login ILIKE $1 + ORDER BY github_login <-> $2 + LIMIT $3 "; - Ok(sqlx::query_as(query) - .bind(&id) - .fetch_optional(&mut tx) + + Ok(user::Entity::find() + .from_raw_sql(Statement::from_sql_and_values( + self.pool.get_database_backend(), + query.into(), + vec![like_string.into(), name_query.into(), limit.into()], + )) + .all(&tx) .await?) }) .await } - pub async fn get_users_with_no_invites( - &self, - invited_by_another_user: bool, - ) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - let query = format!( - " - SELECT users.* - FROM users - WHERE invite_count = 0 - AND inviter_id IS{} NULL - ", - if invited_by_another_user { " NOT" } else { "" } - ); - - Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?) - }) - .await - } - - pub async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - if let Some(github_user_id) = github_user_id { - let mut user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_login = $1 - WHERE github_user_id = $2 - RETURNING * - ", - ) - .bind(github_login) - .bind(github_user_id) - .fetch_optional(&mut tx) - .await?; - - if user.is_none() { - user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_user_id = $1 - WHERE github_login = $2 - RETURNING * - ", - ) - .bind(github_user_id) - .bind(github_login) - .fetch_optional(&mut tx) - .await?; - } - - Ok(user) - } else { - let user = sqlx::query_as( - " - SELECT * FROM users - WHERE github_login = $1 - LIMIT 1 - ", - ) - .bind(github_login) - .fetch_optional(&mut tx) - .await?; - Ok(user) - } - }) - .await - } - - pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - self.transact(|mut tx| async { - let query = "UPDATE users SET admin = $1 WHERE id = $2"; - sqlx::query(query) - .bind(is_admin) - .bind(id.0) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - self.transact(|mut tx| async move { - let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; - sqlx::query(query) - .bind(connected_once) - .bind(id.0) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn destroy_user(&self, id: UserId) -> Result<()> { - self.transact(|mut tx| async move { - let query = "DELETE FROM access_tokens WHERE user_id = $1;"; - sqlx::query(query) - .bind(id.0) - .execute(&mut tx) - .await - .map(drop)?; - let query = "DELETE FROM users WHERE id = $1;"; - sqlx::query(query).bind(id.0).execute(&mut tx).await?; - tx.commit().await?; - Ok(()) - }) - .await - } - // signups + pub async fn create_signup(&self, signup: NewSignup) -> Result<()> { + self.transact(|tx| async { + signup::ActiveModel { + email_address: ActiveValue::set(signup.email_address.clone()), + email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), + email_confirmation_sent: ActiveValue::set(false), + platform_mac: ActiveValue::set(signup.platform_mac), + platform_windows: ActiveValue::set(signup.platform_windows), + platform_linux: ActiveValue::set(signup.platform_linux), + platform_unknown: ActiveValue::set(false), + editor_features: ActiveValue::set(Some(signup.editor_features.clone())), + programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())), + device_id: ActiveValue::set(signup.device_id.clone()), + ..Default::default() + } + .insert(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + pub async fn get_waitlist_summary(&self) -> Result { - self.transact(|mut tx| async move { - Ok(sqlx::query_as( - " + self.transact(|tx| async move { + let query = " SELECT COUNT(*) as count, COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, @@ -872,63 +627,241 @@ where WHERE NOT email_confirmation_sent ) AS unsent - ", + "; + Ok( + WaitlistSummary::find_by_statement(Statement::from_sql_and_values( + self.pool.get_database_backend(), + query.into(), + vec![], + )) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("invalid result"))?, ) - .fetch_one(&mut tx) - .await?) + }) + .await + } + + pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { + let emails = invites + .iter() + .map(|s| s.email_address.as_str()) + .collect::>(); + self.transact(|tx| async { + signup::Entity::update_many() + .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) + .col_expr(signup::Column::EmailConfirmationSent, true.into()) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) }) .await } pub async fn get_unsent_invites(&self, count: usize) -> Result> { - self.transact(|mut tx| async move { - Ok(sqlx::query_as( - " - SELECT - email_address, email_confirmation_code - FROM signups - WHERE - NOT email_confirmation_sent AND - (platform_mac OR platform_unknown) - LIMIT $1 - ", - ) - .bind(count as i32) - .fetch_all(&mut tx) - .await?) + self.transact(|tx| async move { + Ok(signup::Entity::find() + .select_only() + .column(signup::Column::EmailAddress) + .column(signup::Column::EmailConfirmationCode) + .filter( + signup::Column::EmailConfirmationSent.eq(false).and( + signup::Column::PlatformMac + .eq(true) + .or(signup::Column::PlatformUnknown.eq(true)), + ), + ) + .limit(count as u64) + .into_model() + .all(&tx) + .await?) }) .await } // invite codes - pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { - self.transact(|mut tx| async move { - if count > 0 { - sqlx::query( - " - UPDATE users - SET invite_code = $1 - WHERE id = $2 AND invite_code IS NULL - ", + pub async fn create_invite_from_code( + &self, + code: &str, + email_address: &str, + device_id: Option<&str>, + ) -> Result { + self.transact(|tx| async move { + let existing_user = user::Entity::find() + .filter(user::Column::EmailAddress.eq(email_address)) + .one(&tx) + .await?; + + if existing_user.is_some() { + Err(anyhow!("email address is already in use"))?; + } + + let inviter = match user::Entity::find() + .filter(user::Column::InviteCode.eq(code)) + .one(&tx) + .await? + { + Some(inviter) => inviter, + None => { + return Err(Error::Http( + StatusCode::NOT_FOUND, + "invite code not found".to_string(), + ))? + } + }; + + if inviter.invite_count == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + let signup = signup::Entity::insert(signup::ActiveModel { + email_address: ActiveValue::set(email_address.into()), + email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), + email_confirmation_sent: ActiveValue::set(false), + inviting_user_id: ActiveValue::set(Some(inviter.id)), + platform_linux: ActiveValue::set(false), + platform_mac: ActiveValue::set(false), + platform_windows: ActiveValue::set(false), + platform_unknown: ActiveValue::set(true), + device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())), + ..Default::default() + }) + .on_conflict( + OnConflict::column(signup::Column::EmailAddress) + .update_column(signup::Column::InvitingUserId) + .to_owned(), + ) + .exec_with_returning(&tx) + .await?; + tx.commit().await?; + + Ok(Invite { + email_address: signup.email_address, + email_confirmation_code: signup.email_confirmation_code, + }) + }) + .await + } + + pub async fn create_user_from_invite( + &self, + invite: &Invite, + user: NewUserParams, + ) -> Result> { + self.transact(|tx| async { + let tx = tx; + let signup = signup::Entity::find() + .filter( + signup::Column::EmailAddress + .eq(invite.email_address.as_str()) + .and( + signup::Column::EmailConfirmationCode + .eq(invite.email_confirmation_code.as_str()), + ), ) - .bind(random_invite_code()) - .bind(id) - .execute(&mut tx) + .one(&tx) + .await? + .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; + + if signup.user_id.is_some() { + return Ok(None); + } + + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(invite.email_address.clone())), + github_login: ActiveValue::set(user.github_login.clone()), + github_user_id: ActiveValue::set(Some(user.github_user_id)), + admin: ActiveValue::set(false), + invite_count: ActiveValue::set(user.invite_count), + invite_code: ActiveValue::set(Some(random_invite_code())), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_columns([ + user::Column::EmailAddress, + user::Column::GithubUserId, + user::Column::Admin, + ]) + .to_owned(), + ) + .exec_with_returning(&tx) + .await?; + + let mut signup = signup.into_active_model(); + signup.user_id = ActiveValue::set(Some(user.id)); + let signup = signup.update(&tx).await?; + + if let Some(inviting_user_id) = signup.inviting_user_id { + let result = user::Entity::update_many() + .filter( + user::Column::Id + .eq(inviting_user_id) + .and(user::Column::InviteCount.gt(0)), + ) + .col_expr( + user::Column::InviteCount, + Expr::col(user::Column::InviteCount).sub(1), + ) + .exec(&tx) + .await?; + + if result.rows_affected == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + contact::Entity::insert(contact::ActiveModel { + user_id_a: ActiveValue::set(inviting_user_id), + user_id_b: ActiveValue::set(user.id), + a_to_b: ActiveValue::set(true), + should_notify: ActiveValue::set(true), + accepted: ActiveValue::set(true), + ..Default::default() + }) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&tx) .await?; } - sqlx::query( - " - UPDATE users - SET invite_count = $1 - WHERE id = $2 - ", - ) - .bind(count as i32) - .bind(id) - .execute(&mut tx) - .await?; + tx.commit().await?; + Ok(Some(NewUserResult { + user_id: user.id, + metrics_id: user.metrics_id.to_string(), + inviting_user_id: signup.inviting_user_id, + signup_device_id: signup.device_id, + })) + }) + .await + } + + pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { + self.transact(|tx| async move { + if count > 0 { + user::Entity::update_many() + .filter( + user::Column::Id + .eq(id) + .and(user::Column::InviteCode.is_null()), + ) + .col_expr(user::Column::InviteCode, random_invite_code().into()) + .exec(&tx) + .await?; + } + + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .col_expr(user::Column::InviteCount, count.into()) + .exec(&tx) + .await?; tx.commit().await?; Ok(()) }) @@ -936,535 +869,109 @@ where } pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - self.transact(|mut tx| async move { - let result: Option<(String, i32)> = sqlx::query_as( - " - SELECT invite_code, invite_count - FROM users - WHERE id = $1 AND invite_code IS NOT NULL - ", - ) - .bind(id) - .fetch_optional(&mut tx) - .await?; - if let Some((code, count)) = result { - Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) - } else { - Ok(None) + self.transact(|tx| async move { + match user::Entity::find_by_id(id).one(&tx).await? { + Some(user) if user.invite_code.is_some() => { + Ok(Some((user.invite_code.unwrap(), user.invite_count as u32))) + } + _ => Ok(None), } }) .await } pub async fn get_user_for_invite_code(&self, code: &str) -> Result { - self.transact(|tx| async { - let mut tx = tx; - sqlx::query_as( - " - SELECT * - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&mut tx) - .await? - .ok_or_else(|| { - Error::Http( - StatusCode::NOT_FOUND, - "that invite code does not exist".to_string(), - ) - }) + self.transact(|tx| async move { + user::Entity::find() + .filter(user::Column::InviteCode.eq(code)) + .one(&tx) + .await? + .ok_or_else(|| { + Error::Http( + StatusCode::NOT_FOUND, + "that invite code does not exist".to_string(), + ) + }) }) .await } - async fn commit_room_transaction<'a, T>( - &'a self, - room_id: RoomId, - tx: sqlx::Transaction<'static, D>, - data: T, - ) -> Result> { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - tx.commit().await?; - Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }) - } + // projects - pub async fn create_room( + pub async fn share_project( &self, - user_id: UserId, + room_id: RoomId, connection_id: ConnectionId, - live_kit_room: &str, - ) -> Result> { - self.transact(|mut tx| async move { - let room_id = sqlx::query_scalar( - " - INSERT INTO rooms (live_kit_room) - VALUES ($1) - RETURNING id - ", - ) - .bind(&live_kit_room) - .fetch_one(&mut tx) - .await - .map(RoomId)?; + worktrees: &[proto::WorktreeMetadata], + ) -> Result> { + self.transact(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("could not find participant"))?; + if participant.room_id != room_id { + return Err(anyhow!("shared project on unexpected room"))?; + } - sqlx::query( - " - INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id) - VALUES ($1, $2, $3, $4, $5) - ", - ) - .bind(room_id) - .bind(user_id) - .bind(connection_id.0 as i32) - .bind(user_id) - .bind(connection_id.0 as i32) - .execute(&mut tx) + let project = project::ActiveModel { + room_id: ActiveValue::set(participant.room_id), + host_user_id: ActiveValue::set(participant.user_id), + host_connection_id: ActiveValue::set(connection_id.0 as i32), + ..Default::default() + } + .insert(&tx) .await?; - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }).await - } - - pub async fn call( - &self, - room_id: RoomId, - calling_user_id: UserId, - calling_connection_id: ConnectionId, - called_user_id: UserId, - initial_project_id: Option, - ) -> Result> { - self.transact(|mut tx| async move { - sqlx::query( - " - INSERT INTO room_participants ( - room_id, - user_id, - calling_user_id, - calling_connection_id, - initial_project_id - ) - VALUES ($1, $2, $3, $4, $5) - ", - ) - .bind(room_id) - .bind(called_user_id) - .bind(calling_user_id) - .bind(calling_connection_id.0 as i32) - .bind(initial_project_id) - .execute(&mut tx) + worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i32), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + })) + .exec(&tx) .await?; - let room = self.get_room(room_id, &mut tx).await?; - let incoming_call = Self::build_incoming_call(&room, called_user_id) - .ok_or_else(|| anyhow!("failed to build incoming call"))?; - self.commit_room_transaction(room_id, tx, (room, incoming_call)) + project_collaborator::ActiveModel { + project_id: ActiveValue::set(project.id), + connection_id: ActiveValue::set(connection_id.0 as i32), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(0), + is_host: ActiveValue::set(true), + ..Default::default() + } + .insert(&tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, (project.id, room)) .await }) .await } - pub async fn incoming_call_for_user( - &self, - user_id: UserId, - ) -> Result> { - self.transact(|mut tx| async move { - let room_id = sqlx::query_scalar::<_, RoomId>( - " - SELECT room_id - FROM room_participants - WHERE user_id = $1 AND answering_connection_id IS NULL - ", - ) - .bind(user_id) - .fetch_optional(&mut tx) + async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { + let db_room = room::Entity::find_by_id(room_id) + .one(tx) + .await? + .ok_or_else(|| anyhow!("could not find room"))?; + + let mut db_participants = db_room + .find_related(room_participant::Entity) + .stream(tx) .await?; - - if let Some(room_id) = room_id { - let room = self.get_room(room_id, &mut tx).await?; - Ok(Self::build_incoming_call(&room, user_id)) - } else { - Ok(None) - } - }) - .await - } - - fn build_incoming_call( - room: &proto::Room, - called_user_id: UserId, - ) -> Option { - let pending_participant = room - .pending_participants - .iter() - .find(|participant| participant.user_id == called_user_id.to_proto())?; - - Some(proto::IncomingCall { - room_id: room.id, - calling_user_id: pending_participant.calling_user_id, - participant_user_ids: room - .participants - .iter() - .map(|participant| participant.user_id) - .collect(), - initial_project: room.participants.iter().find_map(|participant| { - let initial_project_id = pending_participant.initial_project_id?; - participant - .projects - .iter() - .find(|project| project.id == initial_project_id) - .cloned() - }), - }) - } - - pub async fn call_failed( - &self, - room_id: RoomId, - called_user_id: UserId, - ) -> Result> { - self.transact(|mut tx| async move { - sqlx::query( - " - DELETE FROM room_participants - WHERE room_id = $1 AND user_id = $2 - ", - ) - .bind(room_id) - .bind(called_user_id) - .execute(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }) - .await - } - - pub async fn decline_call( - &self, - expected_room_id: Option, - user_id: UserId, - ) -> Result> { - self.transact(|mut tx| async move { - let room_id = sqlx::query_scalar( - " - DELETE FROM room_participants - WHERE user_id = $1 AND answering_connection_id IS NULL - RETURNING room_id - ", - ) - .bind(user_id) - .fetch_one(&mut tx) - .await?; - if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { - return Err(anyhow!("declining call on unexpected room"))?; - } - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }) - .await - } - - pub async fn cancel_call( - &self, - expected_room_id: Option, - calling_connection_id: ConnectionId, - called_user_id: UserId, - ) -> Result> { - self.transact(|mut tx| async move { - let room_id = sqlx::query_scalar( - " - DELETE FROM room_participants - WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL - RETURNING room_id - ", - ) - .bind(called_user_id) - .bind(calling_connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { - return Err(anyhow!("canceling call on unexpected room"))?; - } - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }).await - } - - pub async fn join_room( - &self, - room_id: RoomId, - user_id: UserId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - sqlx::query( - " - UPDATE room_participants - SET answering_connection_id = $1 - WHERE room_id = $2 AND user_id = $3 - RETURNING 1 - ", - ) - .bind(connection_id.0 as i32) - .bind(room_id) - .bind(user_id) - .fetch_one(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }) - .await - } - - pub async fn leave_room( - &self, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|mut tx| async move { - // Leave room. - let room_id = sqlx::query_scalar::<_, RoomId>( - " - DELETE FROM room_participants - WHERE answering_connection_id = $1 - RETURNING room_id - ", - ) - .bind(connection_id.0 as i32) - .fetch_optional(&mut tx) - .await?; - - if let Some(room_id) = room_id { - // Cancel pending calls initiated by the leaving user. - let canceled_calls_to_user_ids: Vec = sqlx::query_scalar( - " - DELETE FROM room_participants - WHERE calling_connection_id = $1 AND answering_connection_id IS NULL - RETURNING user_id - ", - ) - .bind(connection_id.0 as i32) - .fetch_all(&mut tx) - .await?; - - let project_ids = sqlx::query_scalar::<_, ProjectId>( - " - SELECT project_id - FROM project_collaborators - WHERE connection_id = $1 - ", - ) - .bind(connection_id.0 as i32) - .fetch_all(&mut tx) - .await?; - - // Leave projects. - let mut left_projects = HashMap::default(); - if !project_ids.is_empty() { - let mut params = "?,".repeat(project_ids.len()); - params.pop(); - let query = format!( - " - SELECT * - FROM project_collaborators - WHERE project_id IN ({params}) - " - ); - let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query); - for project_id in project_ids { - query = query.bind(project_id); - } - - let mut project_collaborators = query.fetch(&mut tx); - while let Some(collaborator) = project_collaborators.next().await { - let collaborator = collaborator?; - let left_project = - left_projects - .entry(collaborator.project_id) - .or_insert(LeftProject { - id: collaborator.project_id, - host_user_id: Default::default(), - connection_ids: Default::default(), - host_connection_id: Default::default(), - }); - - let collaborator_connection_id = - ConnectionId(collaborator.connection_id as u32); - if collaborator_connection_id != connection_id { - left_project.connection_ids.push(collaborator_connection_id); - } - - if collaborator.is_host { - left_project.host_user_id = collaborator.user_id; - left_project.host_connection_id = - ConnectionId(collaborator.connection_id as u32); - } - } - } - sqlx::query( - " - DELETE FROM project_collaborators - WHERE connection_id = $1 - ", - ) - .bind(connection_id.0 as i32) - .execute(&mut tx) - .await?; - - // Unshare projects. - sqlx::query( - " - DELETE FROM projects - WHERE room_id = $1 AND host_connection_id = $2 - ", - ) - .bind(room_id) - .bind(connection_id.0 as i32) - .execute(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - Ok(Some( - self.commit_room_transaction( - room_id, - tx, - LeftRoom { - room, - left_projects, - canceled_calls_to_user_ids, - }, - ) - .await?, - )) - } else { - Ok(None) - } - }) - .await - } - - pub async fn update_room_participant_location( - &self, - room_id: RoomId, - connection_id: ConnectionId, - location: proto::ParticipantLocation, - ) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - let location_kind; - let location_project_id; - match location - .variant - .as_ref() - .ok_or_else(|| anyhow!("invalid location"))? - { - proto::participant_location::Variant::SharedProject(project) => { - location_kind = 0; - location_project_id = Some(ProjectId::from_proto(project.id)); - } - proto::participant_location::Variant::UnsharedProject(_) => { - location_kind = 1; - location_project_id = None; - } - proto::participant_location::Variant::External(_) => { - location_kind = 2; - location_project_id = None; - } - } - - sqlx::query( - " - UPDATE room_participants - SET location_kind = $1, location_project_id = $2 - WHERE room_id = $3 AND answering_connection_id = $4 - RETURNING 1 - ", - ) - .bind(location_kind) - .bind(location_project_id) - .bind(room_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }) - .await - } - - async fn get_guest_connection_ids( - &self, - project_id: ProjectId, - tx: &mut sqlx::Transaction<'_, D>, - ) -> Result> { - let mut guest_connection_ids = Vec::new(); - let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 AND is_host = FALSE - ", - ) - .bind(project_id) - .fetch(tx); - while let Some(connection_id) = db_guest_connection_ids.next().await { - guest_connection_ids.push(ConnectionId(connection_id? as u32)); - } - Ok(guest_connection_ids) - } - - async fn get_room( - &self, - room_id: RoomId, - tx: &mut sqlx::Transaction<'_, D>, - ) -> Result { - let room: Room = sqlx::query_as( - " - SELECT * - FROM rooms - WHERE id = $1 - ", - ) - .bind(room_id) - .fetch_one(&mut *tx) - .await?; - - let mut db_participants = - sqlx::query_as::<_, (UserId, Option, Option, Option, UserId, Option)>( - " - SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id - FROM room_participants - WHERE room_id = $1 - ", - ) - .bind(room_id) - .fetch(&mut *tx); - let mut participants = HashMap::default(); let mut pending_participants = Vec::new(); - while let Some(participant) = db_participants.next().await { - let ( - user_id, - answering_connection_id, - location_kind, - location_project_id, - calling_user_id, - initial_project_id, - ) = participant?; - if let Some(answering_connection_id) = answering_connection_id { - let location = match (location_kind, location_project_id) { + while let Some(db_participant) = db_participants.next().await { + let db_participant = db_participant?; + if let Some(answering_connection_id) = db_participant.answering_connection_id { + let location = match ( + db_participant.location_kind, + db_participant.location_project_id, + ) { (Some(0), Some(project_id)) => { Some(proto::participant_location::Variant::SharedProject( proto::participant_location::SharedProject { @@ -1482,7 +989,7 @@ where participants.insert( answering_connection_id, proto::Participant { - user_id: user_id.to_proto(), + user_id: db_participant.user_id.to_proto(), peer_id: answering_connection_id as u32, projects: Default::default(), location: Some(proto::ParticipantLocation { variant: location }), @@ -1490,1054 +997,66 @@ where ); } else { pending_participants.push(proto::PendingParticipant { - user_id: user_id.to_proto(), - calling_user_id: calling_user_id.to_proto(), - initial_project_id: initial_project_id.map(|id| id.to_proto()), + user_id: db_participant.user_id.to_proto(), + calling_user_id: db_participant.calling_user_id.to_proto(), + initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()), }); } } - drop(db_participants); - let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option)>( - " - SELECT host_connection_id, projects.id, worktrees.root_name - FROM projects - LEFT JOIN worktrees ON projects.id = worktrees.project_id - WHERE room_id = $1 - ", - ) - .bind(room_id) - .fetch(&mut *tx); + let mut db_projects = db_room + .find_related(project::Entity) + .find_with_related(worktree::Entity) + .stream(tx) + .await?; - while let Some(row) = rows.next().await { - let (connection_id, project_id, worktree_root_name) = row?; - if let Some(participant) = participants.get_mut(&connection_id) { + while let Some(row) = db_projects.next().await { + let (db_project, db_worktree) = row?; + if let Some(participant) = participants.get_mut(&db_project.host_connection_id) { let project = if let Some(project) = participant .projects .iter_mut() - .find(|project| project.id == project_id.to_proto()) + .find(|project| project.id == db_project.id.to_proto()) { project } else { participant.projects.push(proto::ParticipantProject { - id: project_id.to_proto(), + id: db_project.id.to_proto(), worktree_root_names: Default::default(), }); participant.projects.last_mut().unwrap() }; - project.worktree_root_names.extend(worktree_root_name); + + if let Some(db_worktree) = db_worktree { + project.worktree_root_names.push(db_worktree.root_name); + } } } Ok(proto::Room { - id: room.id.to_proto(), - live_kit_room: room.live_kit_room, + id: db_room.id.to_proto(), + live_kit_room: db_room.live_kit_room, participants: participants.into_values().collect(), pending_participants, }) } - // projects - - pub async fn project_count_excluding_admins(&self) -> Result { - self.transact(|mut tx| async move { - Ok(sqlx::query_scalar::<_, i32>( - " - SELECT COUNT(*) - FROM projects, users - WHERE projects.host_user_id = users.id AND users.admin IS FALSE - ", - ) - .fetch_one(&mut tx) - .await? as usize) - }) - .await - } - - pub async fn share_project( + async fn commit_room_transaction( &self, - expected_room_id: RoomId, - connection_id: ConnectionId, - worktrees: &[proto::WorktreeMetadata], - ) -> Result> { - self.transact(|mut tx| async move { - let (sql, values) = self.build_query( - Query::select() - .columns([ - schema::room_participant::Definition::RoomId, - schema::room_participant::Definition::UserId, - ]) - .from(schema::room_participant::Definition::Table) - .and_where( - Expr::col(schema::room_participant::Definition::AnsweringConnectionId) - .eq(connection_id.0), - ), - ); - let (room_id, user_id) = sqlx::query_as_with::<_, (RoomId, UserId), _>(&sql, values) - .fetch_one(&mut tx) - .await?; - if room_id != expected_room_id { - return Err(anyhow!("shared project on unexpected room"))?; - } - - let (sql, values) = self.build_query( - Query::insert() - .into_table(schema::project::Definition::Table) - .columns([ - schema::project::Definition::RoomId, - schema::project::Definition::HostUserId, - schema::project::Definition::HostConnectionId, - ]) - .values_panic([room_id.into(), user_id.into(), connection_id.0.into()]) - .returning_col(schema::project::Definition::Id), - ); - let project_id: ProjectId = sqlx::query_scalar_with(&sql, values) - .fetch_one(&mut tx) - .await?; - - if !worktrees.is_empty() { - let mut query = Query::insert() - .into_table(schema::worktree::Definition::Table) - .columns([ - schema::worktree::Definition::ProjectId, - schema::worktree::Definition::Id, - schema::worktree::Definition::RootName, - schema::worktree::Definition::AbsPath, - schema::worktree::Definition::Visible, - schema::worktree::Definition::ScanId, - schema::worktree::Definition::IsComplete, - ]) - .to_owned(); - for worktree in worktrees { - query.values_panic([ - project_id.into(), - worktree.id.into(), - worktree.root_name.clone().into(), - worktree.abs_path.clone().into(), - worktree.visible.into(), - 0.into(), - false.into(), - ]); - } - let (sql, values) = self.build_query(&query); - sqlx::query_with(&sql, values).execute(&mut tx).await?; - } - - sqlx::query( - " - INSERT INTO project_collaborators ( - project_id, - connection_id, - user_id, - replica_id, - is_host - ) - VALUES ($1, $2, $3, $4, $5) - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .bind(user_id) - .bind(0) - .bind(true) - .execute(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, (project_id, room)) - .await + room_id: RoomId, + tx: DatabaseTransaction, + data: T, + ) -> Result> { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + tx.commit().await?; + Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, }) - .await } - pub async fn unshare_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result)>> { - self.transact(|mut tx| async move { - let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - let room_id: RoomId = sqlx::query_scalar( - " - DELETE FROM projects - WHERE id = $1 AND host_connection_id = $2 - RETURNING room_id - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) - .await - }) - .await - } - - pub async fn update_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - worktrees: &[proto::WorktreeMetadata], - ) -> Result)>> { - self.transact(|mut tx| async move { - let room_id: RoomId = sqlx::query_scalar( - " - SELECT room_id - FROM projects - WHERE id = $1 AND host_connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - if !worktrees.is_empty() { - let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); - params.pop(); - let query = format!( - " - INSERT INTO worktrees ( - project_id, - id, - root_name, - abs_path, - visible, - scan_id, - is_complete - ) - VALUES {params} - ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name - " - ); - - let mut query = sqlx::query(&query); - for worktree in worktrees { - query = query - .bind(project_id) - .bind(worktree.id as i32) - .bind(&worktree.root_name) - .bind(&worktree.abs_path) - .bind(worktree.visible) - .bind(0) - .bind(false) - } - query.execute(&mut tx).await?; - } - - let mut params = "?,".repeat(worktrees.len()); - if !worktrees.is_empty() { - params.pop(); - } - let query = format!( - " - DELETE FROM worktrees - WHERE project_id = ? AND id NOT IN ({params}) - ", - ); - - let mut query = sqlx::query(&query).bind(project_id); - for worktree in worktrees { - query = query.bind(WorktreeId(worktree.id as i32)); - } - query.execute(&mut tx).await?; - - let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) - .await - }) - .await - } - - pub async fn update_worktree( - &self, - update: &proto::UpdateWorktree, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|mut tx| async move { - let project_id = ProjectId::from_proto(update.project_id); - let worktree_id = WorktreeId::from_proto(update.worktree_id); - - // Ensure the update comes from the host. - let room_id: RoomId = sqlx::query_scalar( - " - SELECT room_id - FROM projects - WHERE id = $1 AND host_connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Update metadata. - sqlx::query( - " - UPDATE worktrees - SET - root_name = $1, - scan_id = $2, - is_complete = $3, - abs_path = $4 - WHERE project_id = $5 AND id = $6 - RETURNING 1 - ", - ) - .bind(&update.root_name) - .bind(update.scan_id as i64) - .bind(update.is_last_update) - .bind(&update.abs_path) - .bind(project_id) - .bind(worktree_id) - .fetch_one(&mut tx) - .await?; - - if !update.updated_entries.is_empty() { - let mut params = - "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len()); - params.pop(); - - let query = format!( - " - INSERT INTO worktree_entries ( - project_id, - worktree_id, - id, - is_dir, - path, - inode, - mtime_seconds, - mtime_nanos, - is_symlink, - is_ignored - ) - VALUES {params} - ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET - is_dir = excluded.is_dir, - path = excluded.path, - inode = excluded.inode, - mtime_seconds = excluded.mtime_seconds, - mtime_nanos = excluded.mtime_nanos, - is_symlink = excluded.is_symlink, - is_ignored = excluded.is_ignored - " - ); - let mut query = sqlx::query(&query); - for entry in &update.updated_entries { - let mtime = entry.mtime.clone().unwrap_or_default(); - query = query - .bind(project_id) - .bind(worktree_id) - .bind(entry.id as i64) - .bind(entry.is_dir) - .bind(&entry.path) - .bind(entry.inode as i64) - .bind(mtime.seconds as i64) - .bind(mtime.nanos as i32) - .bind(entry.is_symlink) - .bind(entry.is_ignored); - } - query.execute(&mut tx).await?; - } - - if !update.removed_entries.is_empty() { - let mut params = "?,".repeat(update.removed_entries.len()); - params.pop(); - let query = format!( - " - DELETE FROM worktree_entries - WHERE project_id = ? AND worktree_id = ? AND id IN ({params}) - " - ); - - let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id); - for entry_id in &update.removed_entries { - query = query.bind(*entry_id as i64); - } - query.execute(&mut tx).await?; - } - - let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, connection_ids) - .await - }) - .await - } - - pub async fn update_diagnostic_summary( - &self, - update: &proto::UpdateDiagnosticSummary, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|mut tx| async { - let project_id = ProjectId::from_proto(update.project_id); - let worktree_id = WorktreeId::from_proto(update.worktree_id); - let summary = update - .summary - .as_ref() - .ok_or_else(|| anyhow!("invalid summary"))?; - - // Ensure the update comes from the host. - let room_id: RoomId = sqlx::query_scalar( - " - SELECT room_id - FROM projects - WHERE id = $1 AND host_connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Update summary. - sqlx::query( - " - INSERT INTO worktree_diagnostic_summaries ( - project_id, - worktree_id, - path, - language_server_id, - error_count, - warning_count - ) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET - language_server_id = excluded.language_server_id, - error_count = excluded.error_count, - warning_count = excluded.warning_count - ", - ) - .bind(project_id) - .bind(worktree_id) - .bind(&summary.path) - .bind(summary.language_server_id as i64) - .bind(summary.error_count as i32) - .bind(summary.warning_count as i32) - .execute(&mut tx) - .await?; - - let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, connection_ids) - .await - }) - .await - } - - pub async fn start_language_server( - &self, - update: &proto::StartLanguageServer, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|mut tx| async { - let project_id = ProjectId::from_proto(update.project_id); - let server = update - .server - .as_ref() - .ok_or_else(|| anyhow!("invalid language server"))?; - - // Ensure the update comes from the host. - let room_id: RoomId = sqlx::query_scalar( - " - SELECT room_id - FROM projects - WHERE id = $1 AND host_connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Add the newly-started language server. - sqlx::query( - " - INSERT INTO language_servers (project_id, id, name) - VALUES ($1, $2, $3) - ON CONFLICT (project_id, id) DO UPDATE SET - name = excluded.name - ", - ) - .bind(project_id) - .bind(server.id as i64) - .bind(&server.name) - .execute(&mut tx) - .await?; - - let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, connection_ids) - .await - }) - .await - } - - pub async fn join_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( - " - SELECT room_id, user_id - FROM room_participants - WHERE answering_connection_id = $1 - ", - ) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Ensure project id was shared on this room. - sqlx::query( - " - SELECT 1 - FROM projects - WHERE id = $1 AND room_id = $2 - ", - ) - .bind(project_id) - .bind(room_id) - .fetch_one(&mut tx) - .await?; - - let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>( - " - SELECT * - FROM project_collaborators - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - let replica_ids = collaborators - .iter() - .map(|c| c.replica_id) - .collect::>(); - let mut replica_id = ReplicaId(1); - while replica_ids.contains(&replica_id) { - replica_id.0 += 1; - } - let new_collaborator = ProjectCollaborator { - project_id, - connection_id: connection_id.0 as i32, - user_id, - replica_id, - is_host: false, - }; - - sqlx::query( - " - INSERT INTO project_collaborators ( - project_id, - connection_id, - user_id, - replica_id, - is_host - ) - VALUES ($1, $2, $3, $4, $5) - ", - ) - .bind(new_collaborator.project_id) - .bind(new_collaborator.connection_id) - .bind(new_collaborator.user_id) - .bind(new_collaborator.replica_id) - .bind(new_collaborator.is_host) - .execute(&mut tx) - .await?; - collaborators.push(new_collaborator); - - let worktree_rows = sqlx::query_as::<_, WorktreeRow>( - " - SELECT * - FROM worktrees - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - let mut worktrees = worktree_rows - .into_iter() - .map(|worktree_row| { - ( - worktree_row.id, - Worktree { - id: worktree_row.id, - abs_path: worktree_row.abs_path, - root_name: worktree_row.root_name, - visible: worktree_row.visible, - entries: Default::default(), - diagnostic_summaries: Default::default(), - scan_id: worktree_row.scan_id as u64, - is_complete: worktree_row.is_complete, - }, - ) - }) - .collect::>(); - - // Populate worktree entries. - { - let mut entries = sqlx::query_as::<_, WorktreeEntry>( - " - SELECT * - FROM worktree_entries - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch(&mut tx); - while let Some(entry) = entries.next().await { - let entry = entry?; - if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) { - worktree.entries.push(proto::Entry { - id: entry.id as u64, - is_dir: entry.is_dir, - path: entry.path, - inode: entry.inode as u64, - mtime: Some(proto::Timestamp { - seconds: entry.mtime_seconds as u64, - nanos: entry.mtime_nanos as u32, - }), - is_symlink: entry.is_symlink, - is_ignored: entry.is_ignored, - }); - } - } - } - - // Populate worktree diagnostic summaries. - { - let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>( - " - SELECT * - FROM worktree_diagnostic_summaries - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch(&mut tx); - while let Some(summary) = summaries.next().await { - let summary = summary?; - if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) { - worktree - .diagnostic_summaries - .push(proto::DiagnosticSummary { - path: summary.path, - language_server_id: summary.language_server_id as u64, - error_count: summary.error_count as u32, - warning_count: summary.warning_count as u32, - }); - } - } - } - - // Populate language servers. - let language_servers = sqlx::query_as::<_, LanguageServer>( - " - SELECT * - FROM language_servers - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - - self.commit_room_transaction( - room_id, - tx, - ( - Project { - collaborators, - worktrees, - language_servers: language_servers - .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id.to_proto(), - name: language_server.name, - }) - .collect(), - }, - replica_id as ReplicaId, - ), - ) - .await - }) - .await - } - - pub async fn leave_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - let result = sqlx::query( - " - DELETE FROM project_collaborators - WHERE project_id = $1 AND connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .execute(&mut tx) - .await?; - - if result.rows_affected() == 0 { - Err(anyhow!("not a collaborator on this project"))?; - } - - let connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await? - .into_iter() - .map(|id| ConnectionId(id as u32)) - .collect(); - - let (room_id, host_user_id, host_connection_id) = - sqlx::query_as::<_, (RoomId, i32, i32)>( - " - SELECT room_id, host_user_id, host_connection_id - FROM projects - WHERE id = $1 - ", - ) - .bind(project_id) - .fetch_one(&mut tx) - .await?; - - self.commit_room_transaction( - room_id, - tx, - LeftProject { - id: project_id, - host_user_id: UserId(host_user_id), - host_connection_id: ConnectionId(host_connection_id as u32), - connection_ids, - }, - ) - .await - }) - .await - } - - pub async fn project_collaborators( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - let collaborators = sqlx::query_as::<_, ProjectCollaborator>( - " - SELECT * - FROM project_collaborators - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - - if collaborators - .iter() - .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) - { - Ok(collaborators) - } else { - Err(anyhow!("no such project"))? - } - }) - .await - } - - pub async fn project_connection_ids( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - let connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - - if connection_ids.contains(&(connection_id.0 as i32)) { - Ok(connection_ids - .into_iter() - .map(|connection_id| ConnectionId(connection_id as u32)) - .collect()) - } else { - Err(anyhow!("no such project"))? - } - }) - .await - } - - // contacts - - pub async fn get_contacts(&self, user_id: UserId) -> Result> { - self.transact(|mut tx| async move { - let query = " - SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy - FROM contacts - LEFT JOIN room_participants ON room_participants.user_id = $1 - WHERE user_id_a = $1 OR user_id_b = $1; - "; - - let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query) - .bind(user_id) - .fetch(&mut tx); - - let mut contacts = Vec::new(); - while let Some(row) = rows.next().await { - let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?; - if user_id_a == user_id { - if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_b, - should_notify: should_notify && a_to_b, - busy - }); - } else if a_to_b { - contacts.push(Contact::Outgoing { user_id: user_id_b }) - } else { - contacts.push(Contact::Incoming { - user_id: user_id_b, - should_notify, - }); - } - } else if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_a, - should_notify: should_notify && !a_to_b, - busy - }); - } else if a_to_b { - contacts.push(Contact::Incoming { - user_id: user_id_a, - should_notify, - }); - } else { - contacts.push(Contact::Outgoing { user_id: user_id_a }); - } - } - - contacts.sort_unstable_by_key(|contact| contact.user_id()); - - Ok(contacts) - }) - .await - } - - pub async fn is_user_busy(&self, user_id: UserId) -> Result { - self.transact(|mut tx| async move { - Ok(sqlx::query_scalar::<_, i32>( - " - SELECT 1 - FROM room_participants - WHERE room_participants.user_id = $1 - ", - ) - .bind(user_id) - .fetch_optional(&mut tx) - .await? - .is_some()) - }) - .await - } - - pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - self.transact(|mut tx| async move { - let (id_a, id_b) = if user_id_1 < user_id_2 { - (user_id_1, user_id_2) - } else { - (user_id_2, user_id_1) - }; - - let query = " - SELECT 1 FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE - LIMIT 1 - "; - Ok(sqlx::query_scalar::<_, i32>(query) - .bind(id_a.0) - .bind(id_b.0) - .fetch_optional(&mut tx) - .await? - .is_some()) - }) - .await - } - - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - self.transact(|mut tx| async move { - let (id_a, id_b, a_to_b) = if sender_id < receiver_id { - (sender_id, receiver_id, true) - } else { - (receiver_id, sender_id, false) - }; - let query = " - INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify) - VALUES ($1, $2, $3, FALSE, TRUE) - ON CONFLICT (user_id_a, user_id_b) DO UPDATE - SET - accepted = TRUE, - should_notify = FALSE - WHERE - NOT contacts.accepted AND - ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR - (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a)); - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&mut tx) - .await?; - - if result.rows_affected() == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("contact already requested"))? - } - }).await - } - - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - self.transact(|mut tx| async move { - let (id_a, id_b) = if responder_id < requester_id { - (responder_id, requester_id) - } else { - (requester_id, responder_id) - }; - let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2; - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .execute(&mut tx) - .await?; - - if result.rows_affected() == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("no such contact"))? - } - }) - .await - } - - pub async fn dismiss_contact_notification( - &self, - user_id: UserId, - contact_user_id: UserId, - ) -> Result<()> { - self.transact(|mut tx| async move { - let (id_a, id_b, a_to_b) = if user_id < contact_user_id { - (user_id, contact_user_id, true) - } else { - (contact_user_id, user_id, false) - }; - - let query = " - UPDATE contacts - SET should_notify = FALSE - WHERE - user_id_a = $1 AND user_id_b = $2 AND - ( - (a_to_b = $3 AND accepted) OR - (a_to_b != $3 AND NOT accepted) - ); - "; - - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&mut tx) - .await?; - - if result.rows_affected() == 0 { - Err(anyhow!("no such contact request"))? - } else { - tx.commit().await?; - Ok(()) - } - }) - .await - } - - pub async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()> { - self.transact(|mut tx| async move { - let (id_a, id_b, a_to_b) = if responder_id < requester_id { - (responder_id, requester_id, false) - } else { - (requester_id, responder_id, true) - }; - let result = if accept { - let query = " - UPDATE contacts - SET accepted = TRUE, should_notify = TRUE - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; - "; - sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&mut tx) - .await? - } else { - let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted; - "; - sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&mut tx) - .await? - }; - if result.rows_affected() == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("no such contact request"))? - } - }) - .await - } - - // access tokens - pub async fn create_access_token_hash( &self, user_id: UserId, @@ -2545,49 +1064,51 @@ where max_access_token_count: usize, ) -> Result<()> { self.transact(|tx| async { - let mut tx = tx; - let insert_query = " - INSERT INTO access_tokens (user_id, hash) - VALUES ($1, $2); - "; - let cleanup_query = " - DELETE FROM access_tokens - WHERE id IN ( - SELECT id from access_tokens - WHERE user_id = $1 - ORDER BY id DESC - LIMIT 10000 - OFFSET $3 - ) - "; + let tx = tx; - sqlx::query(insert_query) - .bind(user_id.0) - .bind(access_token_hash) - .execute(&mut tx) + access_token::ActiveModel { + user_id: ActiveValue::set(user_id), + hash: ActiveValue::set(access_token_hash.into()), + ..Default::default() + } + .insert(&tx) + .await?; + + access_token::Entity::delete_many() + .filter( + access_token::Column::Id.in_subquery( + Query::select() + .column(access_token::Column::Id) + .from(access_token::Entity) + .and_where(access_token::Column::UserId.eq(user_id)) + .order_by(access_token::Column::Id, sea_orm::Order::Desc) + .limit(10000) + .offset(max_access_token_count as u64) + .to_owned(), + ), + ) + .exec(&tx) .await?; - sqlx::query(cleanup_query) - .bind(user_id.0) - .bind(access_token_hash) - .bind(max_access_token_count as i32) - .execute(&mut tx) - .await?; - Ok(tx.commit().await?) + tx.commit().await?; + Ok(()) }) .await } pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - self.transact(|mut tx| async move { - let query = " - SELECT hash - FROM access_tokens - WHERE user_id = $1 - ORDER BY id DESC - "; - Ok(sqlx::query_scalar(query) - .bind(user_id.0) - .fetch_all(&mut tx) + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Hash, + } + + self.transact(|tx| async move { + Ok(access_token::Entity::find() + .select_only() + .column(access_token::Column::Hash) + .filter(access_token::Column::UserId.eq(user_id)) + .order_by_desc(access_token::Column::Id) + .into_values::<_, QueryAs>() + .all(&tx) .await?) }) .await @@ -2595,21 +1116,33 @@ where async fn transact(&self, f: F) -> Result where - F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut, + F: Send + Fn(DatabaseTransaction) -> Fut, Fut: Send + Future>, { let body = async { loop { - let tx = self.begin_transaction().await?; + let tx = self.pool.begin().await?; + + // In Postgres, serializable transactions are opt-in + if let DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(Statement::from_string( + DatabaseBackend::Postgres, + "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), + )) + .await?; + } + match f(tx).await { Ok(result) => return Ok(result), Err(error) => match error { - Error::Database(error) - if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => + Error::Database2( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some("40001") => { // Retry (don't break the loop) } @@ -2635,6 +1168,49 @@ where } } +pub struct RoomGuard { + data: T, + _guard: OwnedMutexGuard<()>, + _not_send: PhantomData>, +} + +impl Deref for RoomGuard { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl DerefMut for RoomGuard { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct NewUserParams { + pub github_login: String, + pub github_user_id: i32, + pub invite_count: i32, +} + +#[derive(Debug)] +pub struct NewUserResult { + pub user_id: UserId, + pub metrics_id: String, + pub inviting_user_id: Option, + pub signup_device_id: Option, +} + +fn random_invite_code() -> String { + nanoid::nanoid!(16) +} + +fn random_email_confirmation_code() -> String { + nanoid::nanoid!(64) +} + macro_rules! id_type { ($name:ident) => { #[derive( @@ -2681,196 +1257,90 @@ macro_rules! id_type { sea_query::Value::Int(Some(value.0)) } } + + impl sea_orm::TryGetable for $name { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> Result { + Ok(Self(i32::try_get(res, pre, col)?)) + } + } + + impl sea_query::ValueType for $name { + fn try_from(v: Value) -> Result { + match v { + Value::TinyInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::SmallInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::Int(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::BigInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::TinyUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::SmallUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::Unsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::BigUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + _ => Err(sea_query::ValueTypeErr), + } + } + + fn type_name() -> String { + stringify!($name).into() + } + + fn array_type() -> sea_query::ArrayType { + sea_query::ArrayType::Int + } + + fn column_type() -> sea_query::ColumnType { + sea_query::ColumnType::Integer(None) + } + } + + impl sea_orm::TryFromU64 for $name { + fn try_from_u64(n: u64) -> Result { + Ok(Self(n.try_into().map_err(|_| { + DbErr::ConvertFromU64(concat!( + "error converting ", + stringify!($name), + " to u64" + )) + })?)) + } + } + + impl sea_query::Nullable for $name { + fn null() -> Value { + Value::Int(None) + } + } }; } +id_type!(AccessTokenId); +id_type!(ContactId); id_type!(UserId); -#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] -pub struct User { - pub id: UserId, - pub github_login: String, - pub github_user_id: Option, - pub email_address: Option, - pub admin: bool, - pub invite_code: Option, - pub invite_count: i32, - pub connected_once: bool, -} - id_type!(RoomId); -#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] -pub struct Room { - pub id: RoomId, - pub live_kit_room: String, -} - +id_type!(RoomParticipantId); id_type!(ProjectId); -pub struct Project { - pub collaborators: Vec, - pub worktrees: BTreeMap, - pub language_servers: Vec, -} - -id_type!(ReplicaId); -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -pub struct ProjectCollaborator { - pub project_id: ProjectId, - pub connection_id: i32, - pub user_id: UserId, - pub replica_id: ReplicaId, - pub is_host: bool, -} - +id_type!(ProjectCollaboratorId); +id_type!(SignupId); id_type!(WorktreeId); -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -struct WorktreeRow { - pub id: WorktreeId, - pub project_id: ProjectId, - pub abs_path: String, - pub root_name: String, - pub visible: bool, - pub scan_id: i64, - pub is_complete: bool, -} - -pub struct Worktree { - pub id: WorktreeId, - pub abs_path: String, - pub root_name: String, - pub visible: bool, - pub entries: Vec, - pub diagnostic_summaries: Vec, - pub scan_id: u64, - pub is_complete: bool, -} - -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -struct WorktreeEntry { - id: i64, - worktree_id: WorktreeId, - is_dir: bool, - path: String, - inode: i64, - mtime_seconds: i64, - mtime_nanos: i32, - is_symlink: bool, - is_ignored: bool, -} - -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -struct WorktreeDiagnosticSummary { - worktree_id: WorktreeId, - path: String, - language_server_id: i64, - error_count: i32, - warning_count: i32, -} - -id_type!(LanguageServerId); -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -struct LanguageServer { - id: LanguageServerId, - name: String, -} - -pub struct LeftProject { - pub id: ProjectId, - pub host_user_id: UserId, - pub host_connection_id: ConnectionId, - pub connection_ids: Vec, -} - -pub struct LeftRoom { - pub room: proto::Room, - pub left_projects: HashMap, - pub canceled_calls_to_user_ids: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum Contact { - Accepted { - user_id: UserId, - should_notify: bool, - busy: bool, - }, - Outgoing { - user_id: UserId, - }, - Incoming { - user_id: UserId, - should_notify: bool, - }, -} - -impl Contact { - pub fn user_id(&self) -> UserId { - match self { - Contact::Accepted { user_id, .. } => *user_id, - Contact::Outgoing { user_id } => *user_id, - Contact::Incoming { user_id, .. } => *user_id, - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct IncomingContactRequest { - pub requester_id: UserId, - pub should_notify: bool, -} - -#[derive(Clone, Deserialize)] -pub struct Signup { - pub email_address: String, - pub platform_mac: bool, - pub platform_windows: bool, - pub platform_linux: bool, - pub editor_features: Vec, - pub programming_languages: Vec, - pub device_id: Option, -} - -#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)] -pub struct WaitlistSummary { - #[sqlx(default)] - pub count: i64, - #[sqlx(default)] - pub linux_count: i64, - #[sqlx(default)] - pub mac_count: i64, - #[sqlx(default)] - pub windows_count: i64, - #[sqlx(default)] - pub unknown_count: i64, -} - -#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)] -pub struct Invite { - pub email_address: String, - pub email_confirmation_code: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct NewUserParams { - pub github_login: String, - pub github_user_id: i32, - pub invite_count: i32, -} - -#[derive(Debug)] -pub struct NewUserResult { - pub user_id: UserId, - pub metrics_id: String, - pub inviting_user_id: Option, - pub signup_device_id: Option, -} - -fn random_invite_code() -> String { - nanoid::nanoid!(16) -} - -fn random_email_confirmation_code() -> String { - nanoid::nanoid!(64) -} #[cfg(test)] pub use test::*; @@ -2882,35 +1352,40 @@ mod test { use lazy_static::lazy_static; use parking_lot::Mutex; use rand::prelude::*; + use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; use std::sync::Arc; - pub struct SqliteTestDb { - pub db: Option>>, - pub conn: sqlx::sqlite::SqliteConnection, + pub struct TestDb { + pub db: Option>, + pub connection: Option, } - pub struct PostgresTestDb { - pub db: Option>>, - pub url: String, - } - - impl SqliteTestDb { - pub fn new(background: Arc) -> Self { - let mut rng = StdRng::from_entropy(); - let url = format!("file:zed-test-{}?mode=memory", rng.gen::()); + impl TestDb { + pub fn sqlite(background: Arc) -> Self { + let url = format!("sqlite::memory:"); let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() .build() .unwrap(); - let (mut db, conn) = runtime.block_on(async { - let db = Db::::new(&url, 5).await.unwrap(); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); - db.migrate(migrations_path.as_ref(), false).await.unwrap(); - let conn = db.pool.acquire().await.unwrap().detach(); - (db, conn) + let mut db = runtime.block_on(async { + let mut options = ConnectOptions::new(url); + options.max_connections(5); + let db = Database::new(options).await.unwrap(); + let sql = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations.sqlite/20221109000000_test_schema.sql" + )); + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + sql.into(), + )) + .await + .unwrap(); + db }); db.background = Some(background); @@ -2918,17 +1393,11 @@ mod test { Self { db: Some(Arc::new(db)), - conn, + connection: None, } } - pub fn db(&self) -> &Arc> { - self.db.as_ref().unwrap() - } - } - - impl PostgresTestDb { - pub fn new(background: Arc) -> Self { + pub fn postgres(background: Arc) -> Self { lazy_static! { static ref LOCK: Mutex<()> = Mutex::new(()); } @@ -2949,7 +1418,11 @@ mod test { sqlx::Postgres::create_database(&url) .await .expect("failed to create test db"); - let db = Db::::new(&url, 5).await.unwrap(); + let mut options = ConnectOptions::new(url); + options + .max_connections(5) + .idle_timeout(Duration::from_secs(0)); + let db = Database::new(options).await.unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); db @@ -2960,19 +1433,40 @@ mod test { Self { db: Some(Arc::new(db)), - url, + connection: None, } } - pub fn db(&self) -> &Arc> { + pub fn db(&self) -> &Arc { self.db.as_ref().unwrap() } } - impl Drop for PostgresTestDb { + impl Drop for TestDb { fn drop(&mut self) { let db = self.db.take().unwrap(); - db.teardown(&self.url); + if let DatabaseBackend::Postgres = db.pool.get_database_backend() { + db.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE + pg_stat_activity.datname = current_database() AND + pid <> pg_backend_pid(); + "; + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + query.into(), + )) + .await + .log_err(); + sqlx::Postgres::drop_database(db.options.get_url()) + .await + .log_err(); + }) + } } } } diff --git a/crates/collab/src/db2/access_token.rs b/crates/collab/src/db/access_token.rs similarity index 100% rename from crates/collab/src/db2/access_token.rs rename to crates/collab/src/db/access_token.rs diff --git a/crates/collab/src/db2/contact.rs b/crates/collab/src/db/contact.rs similarity index 100% rename from crates/collab/src/db2/contact.rs rename to crates/collab/src/db/contact.rs diff --git a/crates/collab/src/db2/project.rs b/crates/collab/src/db/project.rs similarity index 100% rename from crates/collab/src/db2/project.rs rename to crates/collab/src/db/project.rs diff --git a/crates/collab/src/db2/project_collaborator.rs b/crates/collab/src/db/project_collaborator.rs similarity index 100% rename from crates/collab/src/db2/project_collaborator.rs rename to crates/collab/src/db/project_collaborator.rs diff --git a/crates/collab/src/db2/room.rs b/crates/collab/src/db/room.rs similarity index 100% rename from crates/collab/src/db2/room.rs rename to crates/collab/src/db/room.rs diff --git a/crates/collab/src/db2/room_participant.rs b/crates/collab/src/db/room_participant.rs similarity index 100% rename from crates/collab/src/db2/room_participant.rs rename to crates/collab/src/db/room_participant.rs diff --git a/crates/collab/src/db/schema.rs b/crates/collab/src/db/schema.rs deleted file mode 100644 index 40a3e334d1..0000000000 --- a/crates/collab/src/db/schema.rs +++ /dev/null @@ -1,43 +0,0 @@ -pub mod project { - use sea_query::Iden; - - #[derive(Iden)] - pub enum Definition { - #[iden = "projects"] - Table, - Id, - RoomId, - HostUserId, - HostConnectionId, - } -} - -pub mod worktree { - use sea_query::Iden; - - #[derive(Iden)] - pub enum Definition { - #[iden = "worktrees"] - Table, - Id, - ProjectId, - AbsPath, - RootName, - Visible, - ScanId, - IsComplete, - } -} - -pub mod room_participant { - use sea_query::Iden; - - #[derive(Iden)] - pub enum Definition { - #[iden = "room_participants"] - Table, - RoomId, - UserId, - AnsweringConnectionId, - } -} diff --git a/crates/collab/src/db2/signup.rs b/crates/collab/src/db/signup.rs similarity index 95% rename from crates/collab/src/db2/signup.rs rename to crates/collab/src/db/signup.rs index 8fab8daa36..9857018a0c 100644 --- a/crates/collab/src/db2/signup.rs +++ b/crates/collab/src/db/signup.rs @@ -27,7 +27,7 @@ pub enum Relation {} impl ActiveModelBehavior for ActiveModel {} -#[derive(Debug, PartialEq, Eq, FromQueryResult)] +#[derive(Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] pub struct Invite { pub email_address: String, pub email_confirmation_code: String, diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 88488b10d2..b276bd5057 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -6,14 +6,14 @@ macro_rules! test_both_dbs { ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => { #[gpui::test] async fn $postgres_test_name() { - let test_db = PostgresTestDb::new(Deterministic::new(0).build_background()); + let test_db = TestDb::postgres(Deterministic::new(0).build_background()); let $db = test_db.db(); $body } #[gpui::test] async fn $sqlite_test_name() { - let test_db = SqliteTestDb::new(Deterministic::new(0).build_background()); + let test_db = TestDb::sqlite(Deterministic::new(0).build_background()); let $db = test_db.db(); $body } @@ -26,9 +26,10 @@ test_both_dbs!( db, { let mut user_ids = Vec::new(); + let mut user_metric_ids = Vec::new(); for i in 1..=4 { - user_ids.push( - db.create_user( + let user = db + .create_user( &format!("user{i}@example.com"), false, NewUserParams { @@ -38,9 +39,9 @@ test_both_dbs!( }, ) .await - .unwrap() - .user_id, - ); + .unwrap(); + user_ids.push(user.user_id); + user_metric_ids.push(user.metrics_id); } assert_eq!( @@ -52,6 +53,7 @@ test_both_dbs!( github_user_id: Some(1), email_address: Some("user1@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[0].parse().unwrap(), ..Default::default() }, User { @@ -60,6 +62,7 @@ test_both_dbs!( github_user_id: Some(2), email_address: Some("user2@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[1].parse().unwrap(), ..Default::default() }, User { @@ -68,6 +71,7 @@ test_both_dbs!( github_user_id: Some(3), email_address: Some("user3@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[2].parse().unwrap(), ..Default::default() }, User { @@ -76,6 +80,7 @@ test_both_dbs!( github_user_id: Some(4), email_address: Some("user4@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[3].parse().unwrap(), ..Default::default() } ] @@ -399,14 +404,14 @@ test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { #[test] fn test_fuzzy_like_string() { - assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); - assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); - assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); + assert_eq!(Database::fuzzy_like_string("abcd"), "%a%b%c%d%"); + assert_eq!(Database::fuzzy_like_string("x y"), "%x%y%"); + assert_eq!(Database::fuzzy_like_string(" z "), "%z%"); } #[gpui::test] async fn test_fuzzy_search_users() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); for (i, github_login) in [ "California", @@ -442,7 +447,7 @@ async fn test_fuzzy_search_users() { &["rhode-island", "colorado", "oregon"], ); - async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { + async fn fuzzy_search_user_names(db: &Database, query: &str) -> Vec { db.fuzzy_search_users(query, 10) .await .unwrap() @@ -454,7 +459,7 @@ async fn test_fuzzy_search_users() { #[gpui::test] async fn test_invite_codes() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); let NewUserResult { user_id: user1, .. } = db @@ -659,12 +664,12 @@ async fn test_invite_codes() { #[gpui::test] async fn test_signups() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); // people sign up on the waitlist for i in 0..8 { - db.create_signup(Signup { + db.create_signup(NewSignup { email_address: format!("person-{i}@example.com"), platform_mac: true, platform_linux: i % 2 == 0, diff --git a/crates/collab/src/db2/user.rs b/crates/collab/src/db/user.rs similarity index 93% rename from crates/collab/src/db2/user.rs rename to crates/collab/src/db/user.rs index f6bac9dc77..b6e096f667 100644 --- a/crates/collab/src/db2/user.rs +++ b/crates/collab/src/db/user.rs @@ -1,7 +1,8 @@ use super::UserId; use sea_orm::entity::prelude::*; +use serde::Serialize; -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel, Serialize)] #[sea_orm(table_name = "users")] pub struct Model { #[sea_orm(primary_key)] @@ -12,6 +13,7 @@ pub struct Model { pub admin: bool, pub invite_code: Option, pub invite_count: i32, + pub inviter_id: Option, pub connected_once: bool, pub metrics_id: Uuid, } diff --git a/crates/collab/src/db2/worktree.rs b/crates/collab/src/db/worktree.rs similarity index 100% rename from crates/collab/src/db2/worktree.rs rename to crates/collab/src/db/worktree.rs diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs deleted file mode 100644 index 3aa21c6059..0000000000 --- a/crates/collab/src/db2.rs +++ /dev/null @@ -1,1416 +0,0 @@ -mod access_token; -mod contact; -mod project; -mod project_collaborator; -mod room; -mod room_participant; -mod signup; -#[cfg(test)] -mod tests; -mod user; -mod worktree; - -use crate::{Error, Result}; -use anyhow::anyhow; -use collections::HashMap; -use dashmap::DashMap; -use futures::StreamExt; -use hyper::StatusCode; -use rpc::{proto, ConnectionId}; -use sea_orm::{ - entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, - TransactionTrait, -}; -use sea_orm::{ - ActiveValue, ConnectionTrait, DatabaseBackend, FromQueryResult, IntoActiveModel, JoinType, - QueryOrder, QuerySelect, Statement, -}; -use sea_query::{Alias, Expr, OnConflict, Query}; -use serde::{Deserialize, Serialize}; -use sqlx::migrate::{Migrate, Migration, MigrationSource}; -use sqlx::Connection; -use std::ops::{Deref, DerefMut}; -use std::path::Path; -use std::time::Duration; -use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; -use tokio::sync::{Mutex, OwnedMutexGuard}; - -pub use contact::Contact; -pub use signup::{Invite, NewSignup, WaitlistSummary}; -pub use user::Model as User; - -pub struct Database { - options: ConnectOptions, - pool: DatabaseConnection, - rooms: DashMap>>, - #[cfg(test)] - background: Option>, - #[cfg(test)] - runtime: Option, -} - -impl Database { - pub async fn new(options: ConnectOptions) -> Result { - Ok(Self { - options: options.clone(), - pool: sea_orm::Database::connect(options).await?, - rooms: DashMap::with_capacity(16384), - #[cfg(test)] - background: None, - #[cfg(test)] - runtime: None, - }) - } - - pub async fn migrate( - &self, - migrations_path: &Path, - ignore_checksum_mismatch: bool, - ) -> anyhow::Result> { - let migrations = MigrationSource::resolve(migrations_path) - .await - .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - - let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?; - - connection.ensure_migrations_table().await?; - let applied_migrations: HashMap<_, _> = connection - .list_applied_migrations() - .await? - .into_iter() - .map(|m| (m.version, m)) - .collect(); - - let mut new_migrations = Vec::new(); - for migration in migrations { - match applied_migrations.get(&migration.version) { - Some(applied_migration) => { - if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch - { - Err(anyhow!( - "checksum mismatch for applied migration {}", - migration.description - ))?; - } - } - None => { - let elapsed = connection.apply(&migration).await?; - new_migrations.push((migration, elapsed)); - } - } - } - - Ok(new_migrations) - } - - // users - - pub async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - self.transact(|tx| async { - let user = user::Entity::insert(user::ActiveModel { - email_address: ActiveValue::set(Some(email_address.into())), - github_login: ActiveValue::set(params.github_login.clone()), - github_user_id: ActiveValue::set(Some(params.github_user_id)), - admin: ActiveValue::set(admin), - metrics_id: ActiveValue::set(Uuid::new_v4()), - ..Default::default() - }) - .on_conflict( - OnConflict::column(user::Column::GithubLogin) - .update_column(user::Column::GithubLogin) - .to_owned(), - ) - .exec_with_returning(&tx) - .await?; - - tx.commit().await?; - - Ok(NewUserResult { - user_id: user.id, - metrics_id: user.metrics_id.to_string(), - signup_device_id: None, - inviting_user_id: None, - }) - }) - .await - } - - pub async fn get_user_by_id(&self, id: UserId) -> Result> { - self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) }) - .await - } - - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - self.transact(|tx| async { - let tx = tx; - Ok(user::Entity::find() - .filter(user::Column::Id.is_in(ids.iter().copied())) - .all(&tx) - .await?) - }) - .await - } - - pub async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result> { - self.transact(|tx| async { - let tx = tx; - if let Some(github_user_id) = github_user_id { - if let Some(user_by_github_user_id) = user::Entity::find() - .filter(user::Column::GithubUserId.eq(github_user_id)) - .one(&tx) - .await? - { - let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); - user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); - Ok(Some(user_by_github_user_id.update(&tx).await?)) - } else if let Some(user_by_github_login) = user::Entity::find() - .filter(user::Column::GithubLogin.eq(github_login)) - .one(&tx) - .await? - { - let mut user_by_github_login = user_by_github_login.into_active_model(); - user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); - Ok(Some(user_by_github_login.update(&tx).await?)) - } else { - Ok(None) - } - } else { - Ok(user::Entity::find() - .filter(user::Column::GithubLogin.eq(github_login)) - .one(&tx) - .await?) - } - }) - .await - } - - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { - MetricsId, - } - - self.transact(|tx| async move { - let metrics_id: Uuid = user::Entity::find_by_id(id) - .select_only() - .column(user::Column::MetricsId) - .into_values::<_, QueryAs>() - .one(&tx) - .await? - .ok_or_else(|| anyhow!("could not find user"))?; - Ok(metrics_id.to_string()) - }) - .await - } - - // contacts - - pub async fn get_contacts(&self, user_id: UserId) -> Result> { - #[derive(Debug, FromQueryResult)] - struct ContactWithUserBusyStatuses { - user_id_a: UserId, - user_id_b: UserId, - a_to_b: bool, - accepted: bool, - should_notify: bool, - user_a_busy: bool, - user_b_busy: bool, - } - - self.transact(|tx| async move { - let user_a_participant = Alias::new("user_a_participant"); - let user_b_participant = Alias::new("user_b_participant"); - let mut db_contacts = contact::Entity::find() - .column_as( - Expr::tbl(user_a_participant.clone(), room_participant::Column::Id) - .is_not_null(), - "user_a_busy", - ) - .column_as( - Expr::tbl(user_b_participant.clone(), room_participant::Column::Id) - .is_not_null(), - "user_b_busy", - ) - .filter( - contact::Column::UserIdA - .eq(user_id) - .or(contact::Column::UserIdB.eq(user_id)), - ) - .join_as( - JoinType::LeftJoin, - contact::Relation::UserARoomParticipant.def(), - user_a_participant, - ) - .join_as( - JoinType::LeftJoin, - contact::Relation::UserBRoomParticipant.def(), - user_b_participant, - ) - .into_model::() - .stream(&tx) - .await?; - - let mut contacts = Vec::new(); - while let Some(db_contact) = db_contacts.next().await { - let db_contact = db_contact?; - if db_contact.user_id_a == user_id { - if db_contact.accepted { - contacts.push(Contact::Accepted { - user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify && db_contact.a_to_b, - busy: db_contact.user_b_busy, - }); - } else if db_contact.a_to_b { - contacts.push(Contact::Outgoing { - user_id: db_contact.user_id_b, - }) - } else { - contacts.push(Contact::Incoming { - user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify, - }); - } - } else if db_contact.accepted { - contacts.push(Contact::Accepted { - user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify && !db_contact.a_to_b, - busy: db_contact.user_a_busy, - }); - } else if db_contact.a_to_b { - contacts.push(Contact::Incoming { - user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify, - }); - } else { - contacts.push(Contact::Outgoing { - user_id: db_contact.user_id_a, - }); - } - } - - contacts.sort_unstable_by_key(|contact| contact.user_id()); - - Ok(contacts) - }) - .await - } - - pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - self.transact(|tx| async move { - let (id_a, id_b) = if user_id_1 < user_id_2 { - (user_id_1, user_id_2) - } else { - (user_id_2, user_id_1) - }; - - Ok(contact::Entity::find() - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)) - .and(contact::Column::Accepted.eq(true)), - ) - .one(&tx) - .await? - .is_some()) - }) - .await - } - - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - self.transact(|tx| async move { - let (id_a, id_b, a_to_b) = if sender_id < receiver_id { - (sender_id, receiver_id, true) - } else { - (receiver_id, sender_id, false) - }; - - let rows_affected = contact::Entity::insert(contact::ActiveModel { - user_id_a: ActiveValue::set(id_a), - user_id_b: ActiveValue::set(id_b), - a_to_b: ActiveValue::set(a_to_b), - accepted: ActiveValue::set(false), - should_notify: ActiveValue::set(true), - ..Default::default() - }) - .on_conflict( - OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB]) - .values([ - (contact::Column::Accepted, true.into()), - (contact::Column::ShouldNotify, false.into()), - ]) - .action_and_where( - contact::Column::Accepted.eq(false).and( - contact::Column::AToB - .eq(a_to_b) - .and(contact::Column::UserIdA.eq(id_b)) - .or(contact::Column::AToB - .ne(a_to_b) - .and(contact::Column::UserIdA.eq(id_a))), - ), - ) - .to_owned(), - ) - .exec_without_returning(&tx) - .await?; - - if rows_affected == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("contact already requested"))? - } - }) - .await - } - - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - self.transact(|tx| async move { - let (id_a, id_b) = if responder_id < requester_id { - (responder_id, requester_id) - } else { - (requester_id, responder_id) - }; - - let result = contact::Entity::delete_many() - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)), - ) - .exec(&tx) - .await?; - - if result.rows_affected == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("no such contact"))? - } - }) - .await - } - - pub async fn dismiss_contact_notification( - &self, - user_id: UserId, - contact_user_id: UserId, - ) -> Result<()> { - self.transact(|tx| async move { - let (id_a, id_b, a_to_b) = if user_id < contact_user_id { - (user_id, contact_user_id, true) - } else { - (contact_user_id, user_id, false) - }; - - let result = contact::Entity::update_many() - .set(contact::ActiveModel { - should_notify: ActiveValue::set(false), - ..Default::default() - }) - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)) - .and( - contact::Column::AToB - .eq(a_to_b) - .and(contact::Column::Accepted.eq(true)) - .or(contact::Column::AToB - .ne(a_to_b) - .and(contact::Column::Accepted.eq(false))), - ), - ) - .exec(&tx) - .await?; - if result.rows_affected == 0 { - Err(anyhow!("no such contact request"))? - } else { - tx.commit().await?; - Ok(()) - } - }) - .await - } - - pub async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()> { - self.transact(|tx| async move { - let (id_a, id_b, a_to_b) = if responder_id < requester_id { - (responder_id, requester_id, false) - } else { - (requester_id, responder_id, true) - }; - let rows_affected = if accept { - let result = contact::Entity::update_many() - .set(contact::ActiveModel { - accepted: ActiveValue::set(true), - should_notify: ActiveValue::set(true), - ..Default::default() - }) - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)) - .and(contact::Column::AToB.eq(a_to_b)), - ) - .exec(&tx) - .await?; - result.rows_affected - } else { - let result = contact::Entity::delete_many() - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)) - .and(contact::Column::AToB.eq(a_to_b)) - .and(contact::Column::Accepted.eq(false)), - ) - .exec(&tx) - .await?; - - result.rows_affected - }; - - if rows_affected == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("no such contact request"))? - } - }) - .await - } - - pub fn fuzzy_like_string(string: &str) -> String { - let mut result = String::with_capacity(string.len() * 2 + 1); - for c in string.chars() { - if c.is_alphanumeric() { - result.push('%'); - result.push(c); - } - } - result.push('%'); - result - } - - pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - self.transact(|tx| async { - let tx = tx; - let like_string = Self::fuzzy_like_string(name_query); - let query = " - SELECT users.* - FROM users - WHERE github_login ILIKE $1 - ORDER BY github_login <-> $2 - LIMIT $3 - "; - - Ok(user::Entity::find() - .from_raw_sql(Statement::from_sql_and_values( - self.pool.get_database_backend(), - query.into(), - vec![like_string.into(), name_query.into(), limit.into()], - )) - .all(&tx) - .await?) - }) - .await - } - - // signups - - pub async fn create_signup(&self, signup: NewSignup) -> Result<()> { - self.transact(|tx| async { - signup::ActiveModel { - email_address: ActiveValue::set(signup.email_address.clone()), - email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), - email_confirmation_sent: ActiveValue::set(false), - platform_mac: ActiveValue::set(signup.platform_mac), - platform_windows: ActiveValue::set(signup.platform_windows), - platform_linux: ActiveValue::set(signup.platform_linux), - platform_unknown: ActiveValue::set(false), - editor_features: ActiveValue::set(Some(signup.editor_features.clone())), - programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())), - device_id: ActiveValue::set(signup.device_id.clone()), - ..Default::default() - } - .insert(&tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_waitlist_summary(&self) -> Result { - self.transact(|tx| async move { - let query = " - SELECT - COUNT(*) as count, - COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, - COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, - COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, - COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count - FROM ( - SELECT * - FROM signups - WHERE - NOT email_confirmation_sent - ) AS unsent - "; - Ok( - WaitlistSummary::find_by_statement(Statement::from_sql_and_values( - self.pool.get_database_backend(), - query.into(), - vec![], - )) - .one(&tx) - .await? - .ok_or_else(|| anyhow!("invalid result"))?, - ) - }) - .await - } - - pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - let emails = invites - .iter() - .map(|s| s.email_address.as_str()) - .collect::>(); - self.transact(|tx| async { - signup::Entity::update_many() - .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) - .col_expr(signup::Column::EmailConfirmationSent, true.into()) - .exec(&tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_unsent_invites(&self, count: usize) -> Result> { - self.transact(|tx| async move { - Ok(signup::Entity::find() - .select_only() - .column(signup::Column::EmailAddress) - .column(signup::Column::EmailConfirmationCode) - .filter( - signup::Column::EmailConfirmationSent.eq(false).and( - signup::Column::PlatformMac - .eq(true) - .or(signup::Column::PlatformUnknown.eq(true)), - ), - ) - .limit(count as u64) - .into_model() - .all(&tx) - .await?) - }) - .await - } - - // invite codes - - pub async fn create_invite_from_code( - &self, - code: &str, - email_address: &str, - device_id: Option<&str>, - ) -> Result { - self.transact(|tx| async move { - let existing_user = user::Entity::find() - .filter(user::Column::EmailAddress.eq(email_address)) - .one(&tx) - .await?; - - if existing_user.is_some() { - Err(anyhow!("email address is already in use"))?; - } - - let inviter = match user::Entity::find() - .filter(user::Column::InviteCode.eq(code)) - .one(&tx) - .await? - { - Some(inviter) => inviter, - None => { - return Err(Error::Http( - StatusCode::NOT_FOUND, - "invite code not found".to_string(), - ))? - } - }; - - if inviter.invite_count == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - let signup = signup::Entity::insert(signup::ActiveModel { - email_address: ActiveValue::set(email_address.into()), - email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), - email_confirmation_sent: ActiveValue::set(false), - inviting_user_id: ActiveValue::set(Some(inviter.id)), - platform_linux: ActiveValue::set(false), - platform_mac: ActiveValue::set(false), - platform_windows: ActiveValue::set(false), - platform_unknown: ActiveValue::set(true), - device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())), - ..Default::default() - }) - .on_conflict( - OnConflict::column(signup::Column::EmailAddress) - .update_column(signup::Column::InvitingUserId) - .to_owned(), - ) - .exec_with_returning(&tx) - .await?; - tx.commit().await?; - - Ok(Invite { - email_address: signup.email_address, - email_confirmation_code: signup.email_confirmation_code, - }) - }) - .await - } - - pub async fn create_user_from_invite( - &self, - invite: &Invite, - user: NewUserParams, - ) -> Result> { - self.transact(|tx| async { - let tx = tx; - let signup = signup::Entity::find() - .filter( - signup::Column::EmailAddress - .eq(invite.email_address.as_str()) - .and( - signup::Column::EmailConfirmationCode - .eq(invite.email_confirmation_code.as_str()), - ), - ) - .one(&tx) - .await? - .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - - if signup.user_id.is_some() { - return Ok(None); - } - - let user = user::Entity::insert(user::ActiveModel { - email_address: ActiveValue::set(Some(invite.email_address.clone())), - github_login: ActiveValue::set(user.github_login.clone()), - github_user_id: ActiveValue::set(Some(user.github_user_id)), - admin: ActiveValue::set(false), - invite_count: ActiveValue::set(user.invite_count), - invite_code: ActiveValue::set(Some(random_invite_code())), - metrics_id: ActiveValue::set(Uuid::new_v4()), - ..Default::default() - }) - .on_conflict( - OnConflict::column(user::Column::GithubLogin) - .update_columns([ - user::Column::EmailAddress, - user::Column::GithubUserId, - user::Column::Admin, - ]) - .to_owned(), - ) - .exec_with_returning(&tx) - .await?; - - let mut signup = signup.into_active_model(); - signup.user_id = ActiveValue::set(Some(user.id)); - let signup = signup.update(&tx).await?; - - if let Some(inviting_user_id) = signup.inviting_user_id { - let result = user::Entity::update_many() - .filter( - user::Column::Id - .eq(inviting_user_id) - .and(user::Column::InviteCount.gt(0)), - ) - .col_expr( - user::Column::InviteCount, - Expr::col(user::Column::InviteCount).sub(1), - ) - .exec(&tx) - .await?; - - if result.rows_affected == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - contact::Entity::insert(contact::ActiveModel { - user_id_a: ActiveValue::set(inviting_user_id), - user_id_b: ActiveValue::set(user.id), - a_to_b: ActiveValue::set(true), - should_notify: ActiveValue::set(true), - accepted: ActiveValue::set(true), - ..Default::default() - }) - .on_conflict(OnConflict::new().do_nothing().to_owned()) - .exec_without_returning(&tx) - .await?; - } - - tx.commit().await?; - Ok(Some(NewUserResult { - user_id: user.id, - metrics_id: user.metrics_id.to_string(), - inviting_user_id: signup.inviting_user_id, - signup_device_id: signup.device_id, - })) - }) - .await - } - - pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { - self.transact(|tx| async move { - if count > 0 { - user::Entity::update_many() - .filter( - user::Column::Id - .eq(id) - .and(user::Column::InviteCode.is_null()), - ) - .col_expr(user::Column::InviteCode, random_invite_code().into()) - .exec(&tx) - .await?; - } - - user::Entity::update_many() - .filter(user::Column::Id.eq(id)) - .col_expr(user::Column::InviteCount, count.into()) - .exec(&tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - self.transact(|tx| async move { - match user::Entity::find_by_id(id).one(&tx).await? { - Some(user) if user.invite_code.is_some() => { - Ok(Some((user.invite_code.unwrap(), user.invite_count as u32))) - } - _ => Ok(None), - } - }) - .await - } - - pub async fn get_user_for_invite_code(&self, code: &str) -> Result { - self.transact(|tx| async move { - user::Entity::find() - .filter(user::Column::InviteCode.eq(code)) - .one(&tx) - .await? - .ok_or_else(|| { - Error::Http( - StatusCode::NOT_FOUND, - "that invite code does not exist".to_string(), - ) - }) - }) - .await - } - - // projects - - pub async fn share_project( - &self, - room_id: RoomId, - connection_id: ConnectionId, - worktrees: &[proto::WorktreeMetadata], - ) -> Result> { - self.transact(|tx| async move { - let participant = room_participant::Entity::find() - .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) - .one(&tx) - .await? - .ok_or_else(|| anyhow!("could not find participant"))?; - if participant.room_id != room_id { - return Err(anyhow!("shared project on unexpected room"))?; - } - - let project = project::ActiveModel { - room_id: ActiveValue::set(participant.room_id), - host_user_id: ActiveValue::set(participant.user_id), - host_connection_id: ActiveValue::set(connection_id.0 as i32), - ..Default::default() - } - .insert(&tx) - .await?; - - worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(worktree.id as i32), - project_id: ActiveValue::set(project.id), - abs_path: ActiveValue::set(worktree.abs_path.clone()), - root_name: ActiveValue::set(worktree.root_name.clone()), - visible: ActiveValue::set(worktree.visible), - scan_id: ActiveValue::set(0), - is_complete: ActiveValue::set(false), - })) - .exec(&tx) - .await?; - - project_collaborator::ActiveModel { - project_id: ActiveValue::set(project.id), - connection_id: ActiveValue::set(connection_id.0 as i32), - user_id: ActiveValue::set(participant.user_id), - replica_id: ActiveValue::set(0), - is_host: ActiveValue::set(true), - ..Default::default() - } - .insert(&tx) - .await?; - - let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, (project.id, room)) - .await - }) - .await - } - - async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { - let db_room = room::Entity::find_by_id(room_id) - .one(tx) - .await? - .ok_or_else(|| anyhow!("could not find room"))?; - - let mut db_participants = db_room - .find_related(room_participant::Entity) - .stream(tx) - .await?; - let mut participants = HashMap::default(); - let mut pending_participants = Vec::new(); - while let Some(db_participant) = db_participants.next().await { - let db_participant = db_participant?; - if let Some(answering_connection_id) = db_participant.answering_connection_id { - let location = match ( - db_participant.location_kind, - db_participant.location_project_id, - ) { - (Some(0), Some(project_id)) => { - Some(proto::participant_location::Variant::SharedProject( - proto::participant_location::SharedProject { - id: project_id.to_proto(), - }, - )) - } - (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( - Default::default(), - )), - _ => Some(proto::participant_location::Variant::External( - Default::default(), - )), - }; - participants.insert( - answering_connection_id, - proto::Participant { - user_id: db_participant.user_id.to_proto(), - peer_id: answering_connection_id as u32, - projects: Default::default(), - location: Some(proto::ParticipantLocation { variant: location }), - }, - ); - } else { - pending_participants.push(proto::PendingParticipant { - user_id: db_participant.user_id.to_proto(), - calling_user_id: db_participant.calling_user_id.to_proto(), - initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()), - }); - } - } - - let mut db_projects = db_room - .find_related(project::Entity) - .find_with_related(worktree::Entity) - .stream(tx) - .await?; - - while let Some(row) = db_projects.next().await { - let (db_project, db_worktree) = row?; - if let Some(participant) = participants.get_mut(&db_project.host_connection_id) { - let project = if let Some(project) = participant - .projects - .iter_mut() - .find(|project| project.id == db_project.id.to_proto()) - { - project - } else { - participant.projects.push(proto::ParticipantProject { - id: db_project.id.to_proto(), - worktree_root_names: Default::default(), - }); - participant.projects.last_mut().unwrap() - }; - - if let Some(db_worktree) = db_worktree { - project.worktree_root_names.push(db_worktree.root_name); - } - } - } - - Ok(proto::Room { - id: db_room.id.to_proto(), - live_kit_room: db_room.live_kit_room, - participants: participants.into_values().collect(), - pending_participants, - }) - } - - async fn commit_room_transaction( - &self, - room_id: RoomId, - tx: DatabaseTransaction, - data: T, - ) -> Result> { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - tx.commit().await?; - Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }) - } - - pub async fn create_access_token_hash( - &self, - user_id: UserId, - access_token_hash: &str, - max_access_token_count: usize, - ) -> Result<()> { - self.transact(|tx| async { - let tx = tx; - - access_token::ActiveModel { - user_id: ActiveValue::set(user_id), - hash: ActiveValue::set(access_token_hash.into()), - ..Default::default() - } - .insert(&tx) - .await?; - - access_token::Entity::delete_many() - .filter( - access_token::Column::Id.in_subquery( - Query::select() - .column(access_token::Column::Id) - .from(access_token::Entity) - .and_where(access_token::Column::UserId.eq(user_id)) - .order_by(access_token::Column::Id, sea_orm::Order::Desc) - .limit(10000) - .offset(max_access_token_count as u64) - .to_owned(), - ), - ) - .exec(&tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { - Hash, - } - - self.transact(|tx| async move { - Ok(access_token::Entity::find() - .select_only() - .column(access_token::Column::Hash) - .filter(access_token::Column::UserId.eq(user_id)) - .order_by_desc(access_token::Column::Id) - .into_values::<_, QueryAs>() - .all(&tx) - .await?) - }) - .await - } - - async fn transact(&self, f: F) -> Result - where - F: Send + Fn(DatabaseTransaction) -> Fut, - Fut: Send + Future>, - { - let body = async { - loop { - let tx = self.pool.begin().await?; - - // In Postgres, serializable transactions are opt-in - if let DatabaseBackend::Postgres = self.pool.get_database_backend() { - tx.execute(Statement::from_string( - DatabaseBackend::Postgres, - "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), - )) - .await?; - } - - match f(tx).await { - Ok(result) => return Ok(result), - Err(error) => match error { - Error::Database2( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => - { - // Retry (don't break the loop) - } - error @ _ => return Err(error), - }, - } - } - }; - - #[cfg(test)] - { - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } - - self.runtime.as_ref().unwrap().block_on(body) - } - - #[cfg(not(test))] - { - body.await - } - } -} - -pub struct RoomGuard { - data: T, - _guard: OwnedMutexGuard<()>, - _not_send: PhantomData>, -} - -impl Deref for RoomGuard { - type Target = T; - - fn deref(&self) -> &T { - &self.data - } -} - -impl DerefMut for RoomGuard { - fn deref_mut(&mut self) -> &mut T { - &mut self.data - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct NewUserParams { - pub github_login: String, - pub github_user_id: i32, - pub invite_count: i32, -} - -#[derive(Debug)] -pub struct NewUserResult { - pub user_id: UserId, - pub metrics_id: String, - pub inviting_user_id: Option, - pub signup_device_id: Option, -} - -fn random_invite_code() -> String { - nanoid::nanoid!(16) -} - -fn random_email_confirmation_code() -> String { - nanoid::nanoid!(64) -} - -macro_rules! id_type { - ($name:ident) => { - #[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - sqlx::Type, - Serialize, - Deserialize, - )] - #[sqlx(transparent)] - #[serde(transparent)] - pub struct $name(pub i32); - - impl $name { - #[allow(unused)] - pub const MAX: Self = Self(i32::MAX); - - #[allow(unused)] - pub fn from_proto(value: u64) -> Self { - Self(value as i32) - } - - #[allow(unused)] - pub fn to_proto(self) -> u64 { - self.0 as u64 - } - } - - impl std::fmt::Display for $name { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - self.0.fmt(f) - } - } - - impl From<$name> for sea_query::Value { - fn from(value: $name) -> Self { - sea_query::Value::Int(Some(value.0)) - } - } - - impl sea_orm::TryGetable for $name { - fn try_get( - res: &sea_orm::QueryResult, - pre: &str, - col: &str, - ) -> Result { - Ok(Self(i32::try_get(res, pre, col)?)) - } - } - - impl sea_query::ValueType for $name { - fn try_from(v: Value) -> Result { - match v { - Value::TinyInt(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::SmallInt(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::Int(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::BigInt(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::TinyUnsigned(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::SmallUnsigned(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::Unsigned(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::BigUnsigned(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - _ => Err(sea_query::ValueTypeErr), - } - } - - fn type_name() -> String { - stringify!($name).into() - } - - fn array_type() -> sea_query::ArrayType { - sea_query::ArrayType::Int - } - - fn column_type() -> sea_query::ColumnType { - sea_query::ColumnType::Integer(None) - } - } - - impl sea_orm::TryFromU64 for $name { - fn try_from_u64(n: u64) -> Result { - Ok(Self(n.try_into().map_err(|_| { - DbErr::ConvertFromU64(concat!( - "error converting ", - stringify!($name), - " to u64" - )) - })?)) - } - } - - impl sea_query::Nullable for $name { - fn null() -> Value { - Value::Int(None) - } - } - }; -} - -id_type!(AccessTokenId); -id_type!(ContactId); -id_type!(UserId); -id_type!(RoomId); -id_type!(RoomParticipantId); -id_type!(ProjectId); -id_type!(ProjectCollaboratorId); -id_type!(SignupId); -id_type!(WorktreeId); - -#[cfg(test)] -pub use test::*; - -#[cfg(test)] -mod test { - use super::*; - use gpui::executor::Background; - use lazy_static::lazy_static; - use parking_lot::Mutex; - use rand::prelude::*; - use sea_orm::ConnectionTrait; - use sqlx::migrate::MigrateDatabase; - use std::sync::Arc; - - pub struct TestDb { - pub db: Option>, - pub connection: Option, - } - - impl TestDb { - pub fn sqlite(background: Arc) -> Self { - let url = format!("sqlite::memory:"); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap(); - - let mut db = runtime.block_on(async { - let mut options = ConnectOptions::new(url); - options.max_connections(5); - let db = Database::new(options).await.unwrap(); - let sql = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/migrations.sqlite/20221109000000_test_schema.sql" - )); - db.pool - .execute(sea_orm::Statement::from_string( - db.pool.get_database_backend(), - sql.into(), - )) - .await - .unwrap(); - db - }); - - db.background = Some(background); - db.runtime = Some(runtime); - - Self { - db: Some(Arc::new(db)), - connection: None, - } - } - - pub fn postgres(background: Arc) -> Self { - lazy_static! { - static ref LOCK: Mutex<()> = Mutex::new(()); - } - - let _guard = LOCK.lock(); - let mut rng = StdRng::from_entropy(); - let url = format!( - "postgres://postgres@localhost/zed-test-{}", - rng.gen::() - ); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap(); - - let mut db = runtime.block_on(async { - sqlx::Postgres::create_database(&url) - .await - .expect("failed to create test db"); - let mut options = ConnectOptions::new(url); - options - .max_connections(5) - .idle_timeout(Duration::from_secs(0)); - let db = Database::new(options).await.unwrap(); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); - db.migrate(Path::new(migrations_path), false).await.unwrap(); - db - }); - - db.background = Some(background); - db.runtime = Some(runtime); - - Self { - db: Some(Arc::new(db)), - connection: None, - } - } - - pub fn db(&self) -> &Arc { - self.db.as_ref().unwrap() - } - } - - impl Drop for TestDb { - fn drop(&mut self) { - let db = self.db.take().unwrap(); - if let DatabaseBackend::Postgres = db.pool.get_database_backend() { - db.runtime.as_ref().unwrap().block_on(async { - use util::ResultExt; - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE - pg_stat_activity.datname = current_database() AND - pid <> pg_backend_pid(); - "; - db.pool - .execute(sea_orm::Statement::from_string( - db.pool.get_database_backend(), - query.into(), - )) - .await - .log_err(); - sqlx::Postgres::drop_database(db.options.get_url()) - .await - .log_err(); - }) - } - } - } -} diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs deleted file mode 100644 index b276bd5057..0000000000 --- a/crates/collab/src/db2/tests.rs +++ /dev/null @@ -1,813 +0,0 @@ -use super::*; -use gpui::executor::{Background, Deterministic}; -use std::sync::Arc; - -macro_rules! test_both_dbs { - ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => { - #[gpui::test] - async fn $postgres_test_name() { - let test_db = TestDb::postgres(Deterministic::new(0).build_background()); - let $db = test_db.db(); - $body - } - - #[gpui::test] - async fn $sqlite_test_name() { - let test_db = TestDb::sqlite(Deterministic::new(0).build_background()); - let $db = test_db.db(); - $body - } - }; -} - -test_both_dbs!( - test_get_users_by_ids_postgres, - test_get_users_by_ids_sqlite, - db, - { - let mut user_ids = Vec::new(); - let mut user_metric_ids = Vec::new(); - for i in 1..=4 { - let user = db - .create_user( - &format!("user{i}@example.com"), - false, - NewUserParams { - github_login: format!("user{i}"), - github_user_id: i, - invite_count: 0, - }, - ) - .await - .unwrap(); - user_ids.push(user.user_id); - user_metric_ids.push(user.metrics_id); - } - - assert_eq!( - db.get_users_by_ids(user_ids.clone()).await.unwrap(), - vec![ - User { - id: user_ids[0], - github_login: "user1".to_string(), - github_user_id: Some(1), - email_address: Some("user1@example.com".to_string()), - admin: false, - metrics_id: user_metric_ids[0].parse().unwrap(), - ..Default::default() - }, - User { - id: user_ids[1], - github_login: "user2".to_string(), - github_user_id: Some(2), - email_address: Some("user2@example.com".to_string()), - admin: false, - metrics_id: user_metric_ids[1].parse().unwrap(), - ..Default::default() - }, - User { - id: user_ids[2], - github_login: "user3".to_string(), - github_user_id: Some(3), - email_address: Some("user3@example.com".to_string()), - admin: false, - metrics_id: user_metric_ids[2].parse().unwrap(), - ..Default::default() - }, - User { - id: user_ids[3], - github_login: "user4".to_string(), - github_user_id: Some(4), - email_address: Some("user4@example.com".to_string()), - admin: false, - metrics_id: user_metric_ids[3].parse().unwrap(), - ..Default::default() - } - ] - ); - } -); - -test_both_dbs!( - test_get_user_by_github_account_postgres, - test_get_user_by_github_account_sqlite, - db, - { - let user_id1 = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "login1".into(), - github_user_id: 101, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let user_id2 = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "login2".into(), - github_user_id: 102, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let user = db - .get_user_by_github_account("login1", None) - .await - .unwrap() - .unwrap(); - assert_eq!(user.id, user_id1); - assert_eq!(&user.github_login, "login1"); - assert_eq!(user.github_user_id, Some(101)); - - assert!(db - .get_user_by_github_account("non-existent-login", None) - .await - .unwrap() - .is_none()); - - let user = db - .get_user_by_github_account("the-new-login2", Some(102)) - .await - .unwrap() - .unwrap(); - assert_eq!(user.id, user_id2); - assert_eq!(&user.github_login, "the-new-login2"); - assert_eq!(user.github_user_id, Some(102)); - } -); - -test_both_dbs!( - test_create_access_tokens_postgres, - test_create_access_tokens_sqlite, - db, - { - let user = db - .create_user( - "u1@example.com", - false, - NewUserParams { - github_login: "u1".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - db.create_access_token_hash(user, "h1", 3).await.unwrap(); - db.create_access_token_hash(user, "h2", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h2".to_string(), "h1".to_string()] - ); - - db.create_access_token_hash(user, "h3", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h3".to_string(), "h2".to_string(), "h1".to_string(),] - ); - - db.create_access_token_hash(user, "h4", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h4".to_string(), "h3".to_string(), "h2".to_string(),] - ); - - db.create_access_token_hash(user, "h5", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h5".to_string(), "h4".to_string(), "h3".to_string()] - ); - } -); - -test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { - let mut user_ids = Vec::new(); - for i in 0..3 { - user_ids.push( - db.create_user( - &format!("user{i}@example.com"), - false, - NewUserParams { - github_login: format!("user{i}"), - github_user_id: i, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id, - ); - } - - let user_1 = user_ids[0]; - let user_2 = user_ids[1]; - let user_3 = user_ids[2]; - - // User starts with no contacts - assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); - - // User requests a contact. Both users see the pending request. - db.send_contact_request(user_1, user_2).await.unwrap(); - assert!(!db.has_contact(user_1, user_2).await.unwrap()); - assert!(!db.has_contact(user_2, user_1).await.unwrap()); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Outgoing { user_id: user_2 }], - ); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: true - }] - ); - - // User 2 dismisses the contact request notification without accepting or rejecting. - // We shouldn't notify them again. - db.dismiss_contact_notification(user_1, user_2) - .await - .unwrap_err(); - db.dismiss_contact_notification(user_2, user_1) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: false - }] - ); - - // User can't accept their own contact request - db.respond_to_contact_request(user_1, user_2, true) - .await - .unwrap_err(); - - // User accepts a contact request. Both users see the contact. - db.respond_to_contact_request(user_2, user_1, true) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: true, - busy: false, - }], - ); - assert!(db.has_contact(user_1, user_2).await.unwrap()); - assert!(db.has_contact(user_2, user_1).await.unwrap()); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }] - ); - - // Users cannot re-request existing contacts. - db.send_contact_request(user_1, user_2).await.unwrap_err(); - db.send_contact_request(user_2, user_1).await.unwrap_err(); - - // Users can't dismiss notifications of them accepting other users' requests. - db.dismiss_contact_notification(user_2, user_1) - .await - .unwrap_err(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: true, - busy: false, - }] - ); - - // Users can dismiss notifications of other users accepting their requests. - db.dismiss_contact_notification(user_1, user_2) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: false, - busy: false, - }] - ); - - // Users send each other concurrent contact requests and - // see that they are immediately accepted. - db.send_contact_request(user_1, user_3).await.unwrap(); - db.send_contact_request(user_3, user_1).await.unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[ - Contact::Accepted { - user_id: user_2, - should_notify: false, - busy: false, - }, - Contact::Accepted { - user_id: user_3, - should_notify: false, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user_3).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }], - ); - - // User declines a contact request. Both users see that it is gone. - db.send_contact_request(user_2, user_3).await.unwrap(); - db.respond_to_contact_request(user_3, user_2, false) - .await - .unwrap(); - assert!(!db.has_contact(user_2, user_3).await.unwrap()); - assert!(!db.has_contact(user_3, user_2).await.unwrap()); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_contacts(user_3).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }], - ); -}); - -test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { - let NewUserResult { - user_id: user1, - metrics_id: metrics_id1, - .. - } = db - .create_user( - "person1@example.com", - false, - NewUserParams { - github_login: "person1".into(), - github_user_id: 101, - invite_count: 5, - }, - ) - .await - .unwrap(); - let NewUserResult { - user_id: user2, - metrics_id: metrics_id2, - .. - } = db - .create_user( - "person2@example.com", - false, - NewUserParams { - github_login: "person2".into(), - github_user_id: 102, - invite_count: 5, - }, - ) - .await - .unwrap(); - - assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); - assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); - assert_eq!(metrics_id1.len(), 36); - assert_eq!(metrics_id2.len(), 36); - assert_ne!(metrics_id1, metrics_id2); -}); - -#[test] -fn test_fuzzy_like_string() { - assert_eq!(Database::fuzzy_like_string("abcd"), "%a%b%c%d%"); - assert_eq!(Database::fuzzy_like_string("x y"), "%x%y%"); - assert_eq!(Database::fuzzy_like_string(" z "), "%z%"); -} - -#[gpui::test] -async fn test_fuzzy_search_users() { - let test_db = TestDb::postgres(build_background_executor()); - let db = test_db.db(); - for (i, github_login) in [ - "California", - "colorado", - "oregon", - "washington", - "florida", - "delaware", - "rhode-island", - ] - .into_iter() - .enumerate() - { - db.create_user( - &format!("{github_login}@example.com"), - false, - NewUserParams { - github_login: github_login.into(), - github_user_id: i as i32, - invite_count: 0, - }, - ) - .await - .unwrap(); - } - - assert_eq!( - fuzzy_search_user_names(db, "clr").await, - &["colorado", "California"] - ); - assert_eq!( - fuzzy_search_user_names(db, "ro").await, - &["rhode-island", "colorado", "oregon"], - ); - - async fn fuzzy_search_user_names(db: &Database, query: &str) -> Vec { - db.fuzzy_search_users(query, 10) - .await - .unwrap() - .into_iter() - .map(|user| user.github_login) - .collect::>() - } -} - -#[gpui::test] -async fn test_invite_codes() { - let test_db = TestDb::postgres(build_background_executor()); - let db = test_db.db(); - - let NewUserResult { user_id: user1, .. } = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 0, - invite_count: 0, - }, - ) - .await - .unwrap(); - - // Initially, user 1 has no invite code - assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None); - - // Setting invite count to 0 when no code is assigned does not assign a new code - db.set_invite_count_for_user(user1, 0).await.unwrap(); - assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none()); - - // User 1 creates an invite code that can be used twice. - db.set_invite_count_for_user(user1, 2).await.unwrap(); - let (invite_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 2); - - // User 2 redeems the invite code and becomes a contact of user 1. - let user2_invite = db - .create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) - .await - .unwrap(); - let NewUserResult { - user_id: user2, - inviting_user_id, - signup_device_id, - metrics_id, - } = db - .create_user_from_invite( - &user2_invite, - NewUserParams { - github_login: "user2".into(), - github_user_id: 2, - invite_count: 7, - }, - ) - .await - .unwrap() - .unwrap(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); - assert_eq!(inviting_user_id, Some(user1)); - assert_eq!(signup_device_id.unwrap(), "user-2-device-id"); - assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }] - ); - assert_eq!( - db.get_contacts(user2).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user2).await.unwrap().unwrap().1, - 7 - ); - - // User 3 redeems the invite code and becomes a contact of user 1. - let user3_invite = db - .create_invite_from_code(&invite_code, "user3@example.com", None) - .await - .unwrap(); - let NewUserResult { - user_id: user3, - inviting_user_id, - signup_device_id, - .. - } = db - .create_user_from_invite( - &user3_invite, - NewUserParams { - github_login: "user-3".into(), - github_user_id: 3, - invite_count: 3, - }, - ) - .await - .unwrap() - .unwrap(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 0); - assert_eq!(inviting_user_id, Some(user1)); - assert!(signup_device_id.is_none()); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [ - Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user3, - should_notify: true, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user3).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user3).await.unwrap().unwrap().1, - 3 - ); - - // Trying to reedem the code for the third time results in an error. - db.create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) - .await - .unwrap_err(); - - // Invite count can be updated after the code has been created. - db.set_invite_count_for_user(user1, 2).await.unwrap(); - let (latest_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0 - assert_eq!(invite_count, 2); - - // User 4 can now redeem the invite code and becomes a contact of user 1. - let user4_invite = db - .create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) - .await - .unwrap(); - let user4 = db - .create_user_from_invite( - &user4_invite, - NewUserParams { - github_login: "user-4".into(), - github_user_id: 4, - invite_count: 5, - }, - ) - .await - .unwrap() - .unwrap() - .user_id; - - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [ - Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user3, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user4, - should_notify: true, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user4).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user4).await.unwrap().unwrap().1, - 5 - ); - - // An existing user cannot redeem invite codes. - db.create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) - .await - .unwrap_err(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); -} - -#[gpui::test] -async fn test_signups() { - let test_db = TestDb::postgres(build_background_executor()); - let db = test_db.db(); - - // people sign up on the waitlist - for i in 0..8 { - db.create_signup(NewSignup { - email_address: format!("person-{i}@example.com"), - platform_mac: true, - platform_linux: i % 2 == 0, - platform_windows: i % 4 == 0, - editor_features: vec!["speed".into()], - programming_languages: vec!["rust".into(), "c".into()], - device_id: Some(format!("device_id_{i}")), - }) - .await - .unwrap(); - } - - assert_eq!( - db.get_waitlist_summary().await.unwrap(), - WaitlistSummary { - count: 8, - mac_count: 8, - linux_count: 4, - windows_count: 2, - unknown_count: 0, - } - ); - - // retrieve the next batch of signup emails to send - let signups_batch1 = db.get_unsent_invites(3).await.unwrap(); - let addresses = signups_batch1 - .iter() - .map(|s| &s.email_address) - .collect::>(); - assert_eq!( - addresses, - &[ - "person-0@example.com", - "person-1@example.com", - "person-2@example.com" - ] - ); - assert_ne!( - signups_batch1[0].email_confirmation_code, - signups_batch1[1].email_confirmation_code - ); - - // the waitlist isn't updated until we record that the emails - // were successfully sent. - let signups_batch = db.get_unsent_invites(3).await.unwrap(); - assert_eq!(signups_batch, signups_batch1); - - // once the emails go out, we can retrieve the next batch - // of signups. - db.record_sent_invites(&signups_batch1).await.unwrap(); - let signups_batch2 = db.get_unsent_invites(3).await.unwrap(); - let addresses = signups_batch2 - .iter() - .map(|s| &s.email_address) - .collect::>(); - assert_eq!( - addresses, - &[ - "person-3@example.com", - "person-4@example.com", - "person-5@example.com" - ] - ); - - // the sent invites are excluded from the summary. - assert_eq!( - db.get_waitlist_summary().await.unwrap(), - WaitlistSummary { - count: 5, - mac_count: 5, - linux_count: 2, - windows_count: 1, - unknown_count: 0, - } - ); - - // user completes the signup process by providing their - // github account. - let NewUserResult { - user_id, - inviting_user_id, - signup_device_id, - .. - } = db - .create_user_from_invite( - &Invite { - email_address: signups_batch1[0].email_address.clone(), - email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), - }, - NewUserParams { - github_login: "person-0".into(), - github_user_id: 0, - invite_count: 5, - }, - ) - .await - .unwrap() - .unwrap(); - let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); - assert!(inviting_user_id.is_none()); - assert_eq!(user.github_login, "person-0"); - assert_eq!(user.email_address.as_deref(), Some("person-0@example.com")); - assert_eq!(user.invite_count, 5); - assert_eq!(signup_device_id.unwrap(), "device_id_0"); - - // cannot redeem the same signup again. - assert!(db - .create_user_from_invite( - &Invite { - email_address: signups_batch1[0].email_address.clone(), - email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), - }, - NewUserParams { - github_login: "some-other-github_account".into(), - github_user_id: 1, - invite_count: 5, - }, - ) - .await - .unwrap() - .is_none()); - - // cannot redeem a signup with the wrong confirmation code. - db.create_user_from_invite( - &Invite { - email_address: signups_batch1[1].email_address.clone(), - email_confirmation_code: "the-wrong-code".to_string(), - }, - NewUserParams { - github_login: "person-1".into(), - github_user_id: 2, - invite_count: 5, - }, - ) - .await - .unwrap_err(); -} - -fn build_background_executor() -> Arc { - Deterministic::new(0).build_background() -} diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 93ff73fc83..225501c71d 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,5 +1,5 @@ use crate::{ - db::{self, NewUserParams, SqliteTestDb as TestDb, UserId}, + db::{self, NewUserParams, TestDb, UserId}, rpc::{Executor, Server}, AppState, }; @@ -5665,7 +5665,7 @@ impl TestServer { async fn start(background: Arc) -> Self { static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - let test_db = TestDb::new(background.clone()); + let test_db = TestDb::sqlite(background.clone()); let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); let live_kit_server = live_kit_client::TestServer::create( format!("http://livekit.{}.test", live_kit_server_id), diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 8a2cdc980f..4802fd82b4 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -1,7 +1,6 @@ mod api; mod auth; mod db; -mod db2; mod env; mod rpc; @@ -11,7 +10,7 @@ mod integration_tests; use anyhow::anyhow; use axum::{routing::get, Router}; use collab::{Error, Result}; -use db::DefaultDb as Db; +use db::Database; use serde::Deserialize; use std::{ env::args, @@ -45,14 +44,16 @@ pub struct MigrateConfig { } pub struct AppState { - db: Arc, + db: Arc, live_kit_client: Option>, config: Config, } impl AppState { async fn new(config: Config) -> Result> { - let db = Db::new(&config.database_url, 5).await?; + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(5); + let db = Database::new(db_options).await?; let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() @@ -92,7 +93,9 @@ async fn main() -> Result<()> { } Some("migrate") => { let config = envy::from_env::().expect("error loading config"); - let db = Db::new(&config.database_url, 5).await?; + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(5); + let db = Database::new(db_options).await?; let migrations_path = config .migrations_path diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 07b9891480..beefe54a9d 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,7 +2,7 @@ mod connection_pool; use crate::{ auth, - db::{self, DefaultDb, ProjectId, RoomId, User, UserId}, + db::{self, Database, ProjectId, RoomId, User, UserId}, AppState, Result, }; use anyhow::anyhow; @@ -128,10 +128,10 @@ impl fmt::Debug for Session { } } -struct DbHandle(Arc); +struct DbHandle(Arc); impl Deref for DbHandle { - type Target = DefaultDb; + type Target = Database; fn deref(&self) -> &Self::Target { self.0.as_ref()