diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 235ed66424..7191400f44 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -78,6 +78,7 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR struct AuthenticatedUserParams { github_user_id: Option, github_login: String, + github_email: Option, } #[derive(Debug, Serialize)] @@ -92,7 +93,11 @@ async fn get_authenticated_user( ) -> Result> { let user = app .db - .get_user_by_github_account(¶ms.github_login, params.github_user_id) + .get_or_create_user_by_github_account( + ¶ms.github_login, + params.github_user_id, + params.github_email.as_deref(), + ) .await? .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "user not found".into()))?; let metrics_id = app.db.get_user_metrics_id(user.id).await?; @@ -297,11 +302,7 @@ async fn create_access_token( let mut user_id = user.id; if let Some(impersonate) = params.impersonate { if user.admin { - if let Some(impersonated_user) = app - .db - .get_user_by_github_account(&impersonate, None) - .await? - { + if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? { user_id = impersonated_user.id; } else { return Err(Error::Http( diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 1d0fc377ab..3a711cbe29 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -295,10 +295,21 @@ impl Database { .await } - pub async fn get_user_by_github_account( + pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&*tx) + .await?) + }) + .await + } + + pub async fn get_or_create_user_by_github_account( &self, github_login: &str, github_user_id: Option, + github_email: Option<&str>, ) -> Result> { self.transaction(|tx| async move { let tx = &*tx; @@ -320,7 +331,19 @@ impl Database { user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); Ok(Some(user_by_github_login.update(tx).await?)) } else { - Ok(None) + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(github_email.map(|email| email.into())), + github_login: ActiveValue::set(github_login.into()), + github_user_id: ActiveValue::set(Some(github_user_id)), + admin: ActiveValue::set(false), + invite_count: ActiveValue::set(0), + invite_code: ActiveValue::set(None), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .exec_with_returning(&*tx) + .await?; + Ok(Some(user)) } } else { Ok(user::Entity::find() diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 1e27167545..9cd79ea6d1 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -92,8 +92,8 @@ test_both_dbs!( ); test_both_dbs!( - test_get_user_by_github_account_postgres, - test_get_user_by_github_account_sqlite, + test_get_or_create_user_by_github_account_postgres, + test_get_or_create_user_by_github_account_sqlite, db, { let user_id1 = db @@ -124,7 +124,7 @@ test_both_dbs!( .user_id; let user = db - .get_user_by_github_account("login1", None) + .get_or_create_user_by_github_account("login1", None, None) .await .unwrap() .unwrap(); @@ -133,19 +133,28 @@ test_both_dbs!( assert_eq!(user.github_user_id, Some(101)); assert!(db - .get_user_by_github_account("non-existent-login", None) + .get_or_create_user_by_github_account("non-existent-login", None, None) .await .unwrap() .is_none()); let user = db - .get_user_by_github_account("the-new-login2", Some(102)) + .get_or_create_user_by_github_account("the-new-login2", Some(102), None) .await .unwrap() .unwrap(); assert_eq!(user.id, user_id2); assert_eq!(&user.github_login, "the-new-login2"); assert_eq!(user.github_user_id, Some(102)); + + let user = db + .get_or_create_user_by_github_account("login3", Some(103), Some("user3@example.com")) + .await + .unwrap() + .unwrap(); + assert_eq!(&user.github_login, "login3"); + assert_eq!(user.github_user_id, Some(103)); + assert_eq!(user.email_address, Some("user3@example.com".into())); } ); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 52b5e80413..42a88d7d4c 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1875,7 +1875,7 @@ async fn fuzzy_search_users( 1 | 2 => session .db() .await - .get_user_by_github_account(&query, None) + .get_user_by_github_login(&query) .await? .into_iter() .collect(), diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 80df0ed6df..6ebfdc90b7 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -104,11 +104,7 @@ impl TestServer { }); let http = FakeHttpClient::with_404_response(); - let user_id = if let Ok(Some(user)) = self - .app_state - .db - .get_user_by_github_account(name, None) - .await + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await { user.id } else {