From 83e4e269896e84c43f64e523f02973a4c41f9673 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 19 Oct 2022 13:27:14 -0700 Subject: [PATCH] Allow setting ZED_SERVER_URL to URL of a collab server --- crates/client/src/client.rs | 82 ++++++++++++++++++++----------------- crates/collab/src/auth.rs | 11 ++--- 2 files changed, 50 insertions(+), 43 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 9cfccba37f..64075472cd 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -926,29 +926,34 @@ impl Client { } async fn get_rpc_url(http: Arc) -> Result { - let rpc_response = http - .get( - &(format!("{}/rpc", *ZED_SERVER_URL)), - Default::default(), - false, - ) - .await?; - if !rpc_response.status().is_redirection() { + let url = format!("{}/rpc", *ZED_SERVER_URL); + let response = http.get(&url, Default::default(), false).await?; + + // Normally, ZED_SERVER_URL is set to the URL of zed.dev website. + // The website's /rpc endpoint redirects to a collab server's /rpc endpoint, + // which requires authorization via an HTTP header. + // + // For testing purposes, ZED_SERVER_URL can also set to the direct URL of + // of a collab server. In that case, a request to the /rpc endpoint will + // return an 'unauthorized' response. + let collab_url = if response.status().is_redirection() { + response + .headers() + .get("Location") + .ok_or_else(|| anyhow!("missing location header in /rpc response"))? + .to_str() + .map_err(EstablishConnectionError::other)? + .to_string() + } else if response.status() == StatusCode::UNAUTHORIZED { + url + } else { Err(anyhow!( "unexpected /rpc response status {}", - rpc_response.status() + response.status() ))? - } + }; - let rpc_url = rpc_response - .headers() - .get("Location") - .ok_or_else(|| anyhow!("missing location header in /rpc response"))? - .to_str() - .map_err(EstablishConnectionError::other)? - .to_string(); - - Url::parse(&rpc_url).context("invalid rpc url") + Url::parse(&collab_url).context("invalid rpc url") } fn establish_websocket_connection( @@ -1105,25 +1110,6 @@ impl Client { login: String, mut api_token: String, ) -> Result { - let mut url = Self::get_rpc_url(http.clone()).await?; - url.set_path("/user"); - url.set_query(Some(&format!("github_login={login}"))); - let request = Request::get(url.as_str()) - .header("Authorization", format!("token {api_token}")) - .body("".into())?; - - let mut response = http.send(request).await?; - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - if !response.status().is_success() { - Err(anyhow!( - "admin user request failed {} - {}", - response.status().as_u16(), - body, - ))?; - } - #[derive(Deserialize)] struct AuthenticatedUserResponse { user: User, @@ -1134,8 +1120,28 @@ impl Client { id: u64, } + // Use the collab server's admin API to retrieve the id + // of the impersonated user. + let mut url = Self::get_rpc_url(http.clone()).await?; + url.set_path("/user"); + url.set_query(Some(&format!("github_login={login}"))); + let request = Request::get(url.as_str()) + .header("Authorization", format!("token {api_token}")) + .body("".into())?; + + let mut response = http.send(request).await?; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + if !response.status().is_success() { + Err(anyhow!( + "admin user request failed {} - {}", + response.status().as_u16(), + body, + ))?; + } let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; + // Use the admin API token to authenticate as the impersonated user. api_token.insert_str(0, "ADMIN_TOKEN:"); Ok(Credentials { user_id: response.user.id, diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index e9e2855f1c..9081fe1f1e 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,7 +1,7 @@ -use std::sync::Arc; - -use super::db::{self, UserId}; -use crate::{AppState, Error, Result}; +use crate::{ + db::{self, UserId}, + AppState, Error, Result, +}; use anyhow::{anyhow, Context}; use axum::{ http::{self, Request, StatusCode}, @@ -13,6 +13,7 @@ use scrypt::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Scrypt, }; +use std::sync::Arc; pub async fn validate_header(mut req: Request, next: Next) -> impl IntoResponse { let mut auth_header = req @@ -21,7 +22,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into .and_then(|header| header.to_str().ok()) .ok_or_else(|| { Error::Http( - StatusCode::BAD_REQUEST, + StatusCode::UNAUTHORIZED, "missing authorization header".to_string(), ) })?