Store the impersonator id on access tokens created via ZED_IMPERSONATE

* Use the impersonator id to prevent these tokens from counting
  against the impersonated user when limiting the users' total
  of access tokens.
* When connecting using an access token with an impersonator
  add the impersonator as a field to the tracing span that wraps
  the task for that connection.
* Disallow impersonating users via the admin API token in production,
  because when using the admin API token, we aren't able to identify
  the impersonator.

Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-01-17 15:46:36 -08:00
parent 9521f49160
commit ab1bea515c
9 changed files with 198 additions and 39 deletions

View file

@ -19,9 +19,11 @@ CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id");
CREATE TABLE "access_tokens" ( CREATE TABLE "access_tokens" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT, "id" INTEGER PRIMARY KEY AUTOINCREMENT,
"user_id" INTEGER REFERENCES users (id), "user_id" INTEGER REFERENCES users (id),
"impersonator_id" INTEGER REFERENCES users (id),
"hash" VARCHAR(128) "hash" VARCHAR(128)
); );
CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id"); CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id");
CREATE INDEX "index_access_tokens_impersonator_id" ON "access_tokens" ("impersonator_id");
CREATE TABLE "contacts" ( CREATE TABLE "contacts" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT, "id" INTEGER PRIMARY KEY AUTOINCREMENT,

View file

@ -0,0 +1,3 @@
ALTER TABLE access_tokens ADD COLUMN impersonator_id integer;
CREATE INDEX "index_access_tokens_impersonator_id" ON "access_tokens" ("impersonator_id");

View file

@ -157,9 +157,11 @@ async fn create_access_token(
.ok_or_else(|| anyhow!("user not found"))?; .ok_or_else(|| anyhow!("user not found"))?;
let mut user_id = user.id; let mut user_id = user.id;
let mut impersonator_id = None;
if let Some(impersonate) = params.impersonate { if let Some(impersonate) = params.impersonate {
if user.admin { if user.admin {
if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? { if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
impersonator_id = Some(user_id);
user_id = impersonated_user.id; user_id = impersonated_user.id;
} else { } else {
return Err(Error::Http( return Err(Error::Http(
@ -175,7 +177,7 @@ async fn create_access_token(
} }
} }
let access_token = auth::create_access_token(app.db.as_ref(), user_id).await?; let access_token = auth::create_access_token(app.db.as_ref(), user_id, impersonator_id).await?;
let encrypted_access_token = let encrypted_access_token =
auth::encrypt_access_token(&access_token, params.public_key.clone())?; auth::encrypt_access_token(&access_token, params.public_key.clone())?;

View file

@ -27,6 +27,9 @@ lazy_static! {
.unwrap(); .unwrap();
} }
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Impersonator(pub Option<db::User>);
/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN /// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
/// and one for the access tokens that we issue. /// and one for the access tokens that we issue.
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse { pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
@ -57,28 +60,50 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
})?; })?;
let state = req.extensions().get::<Arc<AppState>>().unwrap(); let state = req.extensions().get::<Arc<AppState>>().unwrap();
let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
state.config.api_token == admin_token // In development, allow impersonation using the admin API token.
// Don't allow this in production because we can't tell who is doing
// the impersonating.
let validate_result = if let (Some(admin_token), true) = (
access_token.strip_prefix("ADMIN_TOKEN:"),
state.config.is_development(),
) {
Ok(VerifyAccessTokenResult {
is_valid: state.config.api_token == admin_token,
impersonator_id: None,
})
} else { } else {
verify_access_token(&access_token, user_id, &state.db) verify_access_token(&access_token, user_id, &state.db).await
.await
.unwrap_or(false)
}; };
if credentials_valid { if let Ok(validate_result) = validate_result {
let user = state if validate_result.is_valid {
.db let user = state
.get_user_by_id(user_id) .db
.await? .get_user_by_id(user_id)
.ok_or_else(|| anyhow!("user {} not found", user_id))?; .await?
req.extensions_mut().insert(user); .ok_or_else(|| anyhow!("user {} not found", user_id))?;
Ok::<_, Error>(next.run(req).await)
} else { let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
Err(Error::Http( let impersonator = state
StatusCode::UNAUTHORIZED, .db
"invalid credentials".to_string(), .get_user_by_id(impersonator_id)
)) .await?
.ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
Some(impersonator)
} else {
None
};
req.extensions_mut().insert(user);
req.extensions_mut().insert(Impersonator(impersonator));
return Ok::<_, Error>(next.run(req).await);
}
} }
Err(Error::Http(
StatusCode::UNAUTHORIZED,
"invalid credentials".to_string(),
))
} }
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
@ -92,13 +117,22 @@ struct AccessTokenJson {
/// Creates a new access token to identify the given user. before returning it, you should /// Creates a new access token to identify the given user. before returning it, you should
/// encrypt it with the user's public key. /// encrypt it with the user's public key.
pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> { pub async fn create_access_token(
db: &db::Database,
user_id: UserId,
impersonator_id: Option<UserId>,
) -> Result<String> {
const VERSION: usize = 1; const VERSION: usize = 1;
let access_token = rpc::auth::random_token(); let access_token = rpc::auth::random_token();
let access_token_hash = let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?; hash_access_token(&access_token).context("failed to hash access token")?;
let id = db let id = db
.create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) .create_access_token(
user_id,
impersonator_id,
&access_token_hash,
MAX_ACCESS_TOKENS_TO_STORE,
)
.await?; .await?;
Ok(serde_json::to_string(&AccessTokenJson { Ok(serde_json::to_string(&AccessTokenJson {
version: VERSION, version: VERSION,
@ -137,8 +171,17 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<St
Ok(encrypted_access_token) Ok(encrypted_access_token)
} }
pub struct VerifyAccessTokenResult {
pub is_valid: bool,
pub impersonator_id: Option<UserId>,
}
/// verify access token returns true if the given token is valid for the given user. /// verify access token returns true if the given token is valid for the given user.
pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database>) -> Result<bool> { pub async fn verify_access_token(
token: &str,
user_id: UserId,
db: &Arc<Database>,
) -> Result<VerifyAccessTokenResult> {
let token: AccessTokenJson = serde_json::from_str(&token)?; let token: AccessTokenJson = serde_json::from_str(&token)?;
let db_token = db.get_access_token(token.id).await?; let db_token = db.get_access_token(token.id).await?;
@ -154,5 +197,8 @@ pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database
let duration = t0.elapsed(); let duration = t0.elapsed();
log::info!("hashed access token in {:?}", duration); log::info!("hashed access token in {:?}", duration);
METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64); METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
Ok(is_valid) Ok(VerifyAccessTokenResult {
is_valid,
impersonator_id: db_token.impersonator_id,
})
} }

View file

@ -6,6 +6,7 @@ impl Database {
pub async fn create_access_token( pub async fn create_access_token(
&self, &self,
user_id: UserId, user_id: UserId,
impersonator_id: Option<UserId>,
access_token_hash: &str, access_token_hash: &str,
max_access_token_count: usize, max_access_token_count: usize,
) -> Result<AccessTokenId> { ) -> Result<AccessTokenId> {
@ -14,19 +15,28 @@ impl Database {
let token = access_token::ActiveModel { let token = access_token::ActiveModel {
user_id: ActiveValue::set(user_id), user_id: ActiveValue::set(user_id),
impersonator_id: ActiveValue::set(impersonator_id),
hash: ActiveValue::set(access_token_hash.into()), hash: ActiveValue::set(access_token_hash.into()),
..Default::default() ..Default::default()
} }
.insert(&*tx) .insert(&*tx)
.await?; .await?;
let existing_token_filter = if let Some(impersonator_id) = impersonator_id {
access_token::Column::ImpersonatorId.eq(impersonator_id)
} else {
access_token::Column::UserId
.eq(user_id)
.and(access_token::Column::ImpersonatorId.is_null())
};
access_token::Entity::delete_many() access_token::Entity::delete_many()
.filter( .filter(
access_token::Column::Id.in_subquery( access_token::Column::Id.in_subquery(
Query::select() Query::select()
.column(access_token::Column::Id) .column(access_token::Column::Id)
.from(access_token::Entity) .from(access_token::Entity)
.and_where(access_token::Column::UserId.eq(user_id)) .cond_where(existing_token_filter)
.order_by(access_token::Column::Id, sea_orm::Order::Desc) .order_by(access_token::Column::Id, sea_orm::Order::Desc)
.limit(10000) .limit(10000)
.offset(max_access_token_count as u64) .offset(max_access_token_count as u64)

View file

@ -7,6 +7,7 @@ pub struct Model {
#[sea_orm(primary_key)] #[sea_orm(primary_key)]
pub id: AccessTokenId, pub id: AccessTokenId,
pub user_id: UserId, pub user_id: UserId,
pub impersonator_id: Option<UserId>,
pub hash: String, pub hash: String,
} }

View file

@ -146,7 +146,7 @@ test_both_dbs!(
); );
async fn test_create_access_tokens(db: &Arc<Database>) { async fn test_create_access_tokens(db: &Arc<Database>) {
let user = db let user_1 = db
.create_user( .create_user(
"u1@example.com", "u1@example.com",
false, false,
@ -158,14 +158,27 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
.await .await
.unwrap() .unwrap()
.user_id; .user_id;
let user_2 = db
.create_user(
"u2@example.com",
false,
NewUserParams {
github_login: "u2".into(),
github_user_id: 2,
},
)
.await
.unwrap()
.user_id;
let token_1 = db.create_access_token(user, "h1", 2).await.unwrap(); let token_1 = db.create_access_token(user_1, None, "h1", 2).await.unwrap();
let token_2 = db.create_access_token(user, "h2", 2).await.unwrap(); let token_2 = db.create_access_token(user_1, None, "h2", 2).await.unwrap();
assert_eq!( assert_eq!(
db.get_access_token(token_1).await.unwrap(), db.get_access_token(token_1).await.unwrap(),
access_token::Model { access_token::Model {
id: token_1, id: token_1,
user_id: user, user_id: user_1,
impersonator_id: None,
hash: "h1".into(), hash: "h1".into(),
} }
); );
@ -173,17 +186,19 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
db.get_access_token(token_2).await.unwrap(), db.get_access_token(token_2).await.unwrap(),
access_token::Model { access_token::Model {
id: token_2, id: token_2,
user_id: user, user_id: user_1,
impersonator_id: None,
hash: "h2".into() hash: "h2".into()
} }
); );
let token_3 = db.create_access_token(user, "h3", 2).await.unwrap(); let token_3 = db.create_access_token(user_1, None, "h3", 2).await.unwrap();
assert_eq!( assert_eq!(
db.get_access_token(token_3).await.unwrap(), db.get_access_token(token_3).await.unwrap(),
access_token::Model { access_token::Model {
id: token_3, id: token_3,
user_id: user, user_id: user_1,
impersonator_id: None,
hash: "h3".into() hash: "h3".into()
} }
); );
@ -191,18 +206,20 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
db.get_access_token(token_2).await.unwrap(), db.get_access_token(token_2).await.unwrap(),
access_token::Model { access_token::Model {
id: token_2, id: token_2,
user_id: user, user_id: user_1,
impersonator_id: None,
hash: "h2".into() hash: "h2".into()
} }
); );
assert!(db.get_access_token(token_1).await.is_err()); assert!(db.get_access_token(token_1).await.is_err());
let token_4 = db.create_access_token(user, "h4", 2).await.unwrap(); let token_4 = db.create_access_token(user_1, None, "h4", 2).await.unwrap();
assert_eq!( assert_eq!(
db.get_access_token(token_4).await.unwrap(), db.get_access_token(token_4).await.unwrap(),
access_token::Model { access_token::Model {
id: token_4, id: token_4,
user_id: user, user_id: user_1,
impersonator_id: None,
hash: "h4".into() hash: "h4".into()
} }
); );
@ -210,12 +227,77 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
db.get_access_token(token_3).await.unwrap(), db.get_access_token(token_3).await.unwrap(),
access_token::Model { access_token::Model {
id: token_3, id: token_3,
user_id: user, user_id: user_1,
impersonator_id: None,
hash: "h3".into() hash: "h3".into()
} }
); );
assert!(db.get_access_token(token_2).await.is_err()); assert!(db.get_access_token(token_2).await.is_err());
assert!(db.get_access_token(token_1).await.is_err()); assert!(db.get_access_token(token_1).await.is_err());
// An access token for user 2 impersonating user 1 does not
// count against user 1's access token limit (of 2).
let token_5 = db
.create_access_token(user_1, Some(user_2), "h5", 2)
.await
.unwrap();
assert_eq!(
db.get_access_token(token_5).await.unwrap(),
access_token::Model {
id: token_5,
user_id: user_1,
impersonator_id: Some(user_2),
hash: "h5".into()
}
);
assert_eq!(
db.get_access_token(token_3).await.unwrap(),
access_token::Model {
id: token_3,
user_id: user_1,
impersonator_id: None,
hash: "h3".into()
}
);
// Only a limited number (2) of access tokens are stored for user 2
// impersonating other users.
let token_6 = db
.create_access_token(user_1, Some(user_2), "h6", 2)
.await
.unwrap();
let token_7 = db
.create_access_token(user_1, Some(user_2), "h7", 2)
.await
.unwrap();
assert_eq!(
db.get_access_token(token_6).await.unwrap(),
access_token::Model {
id: token_6,
user_id: user_1,
impersonator_id: Some(user_2),
hash: "h6".into()
}
);
assert_eq!(
db.get_access_token(token_7).await.unwrap(),
access_token::Model {
id: token_7,
user_id: user_1,
impersonator_id: Some(user_2),
hash: "h7".into()
}
);
assert!(db.get_access_token(token_5).await.is_err());
assert_eq!(
db.get_access_token(token_3).await.unwrap(),
access_token::Model {
id: token_3,
user_id: user_1,
impersonator_id: None,
hash: "h3".into()
}
);
} }
test_both_dbs!( test_both_dbs!(

View file

@ -1,7 +1,7 @@
mod connection_pool; mod connection_pool;
use crate::{ use crate::{
auth, auth::{self, Impersonator},
db::{ db::{
self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
@ -65,7 +65,7 @@ use std::{
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore}; use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tracing::{info_span, instrument, Instrument}; use tracing::{field, info_span, instrument, Instrument};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
@ -561,13 +561,17 @@ impl Server {
connection: Connection, connection: Connection,
address: String, address: String,
user: User, user: User,
impersonator: Option<User>,
mut send_connection_id: Option<oneshot::Sender<ConnectionId>>, mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor, executor: Executor,
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
let user_id = user.id; let user_id = user.id;
let login = user.github_login; let login = user.github_login;
let span = info_span!("handle connection", %user_id, %login, %address); let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
if let Some(impersonator) = impersonator {
span.record("impersonator", &impersonator.github_login);
}
let mut teardown = self.teardown.subscribe(); let mut teardown = self.teardown.subscribe();
async move { async move {
let (connection_id, handle_io, mut incoming_rx) = this let (connection_id, handle_io, mut incoming_rx) = this
@ -839,6 +843,7 @@ pub async fn handle_websocket_request(
ConnectInfo(socket_address): ConnectInfo<SocketAddr>, ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>, Extension(server): Extension<Arc<Server>>,
Extension(user): Extension<User>, Extension(user): Extension<User>,
Extension(impersonator): Extension<Impersonator>,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
) -> axum::response::Response { ) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION { if protocol_version != rpc::PROTOCOL_VERSION {
@ -858,7 +863,14 @@ pub async fn handle_websocket_request(
let connection = Connection::new(Box::pin(socket)); let connection = Connection::new(Box::pin(socket));
async move { async move {
server server
.handle_connection(connection, socket_address, user, None, Executor::Production) .handle_connection(
connection,
socket_address,
user,
impersonator.0,
None,
Executor::Production,
)
.await .await
.log_err(); .log_err();
} }

View file

@ -213,6 +213,7 @@ impl TestServer {
server_conn, server_conn,
client_name, client_name,
user, user,
None,
Some(connection_id_tx), Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()), Executor::Deterministic(cx.background_executor().clone()),
)) ))