diff --git a/Cargo.lock b/Cargo.lock index c6bd889ea7..033c99d1cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12562,6 +12562,7 @@ dependencies = [ "call", "channel", "chrono", + "clap 4.4.4", "cli", "client", "clock", diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 5abd530579..4c84935584 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -27,8 +27,8 @@ use release_channel::{AppVersion, ReleaseChannel}; use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; - use settings::{Settings, SettingsStore}; +use std::fmt; use std::{ any::TypeId, convert::TryFrom, @@ -52,6 +52,15 @@ pub use rpc::*; pub use telemetry_events::Event; pub use user::*; +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct DevServerToken(pub String); + +impl fmt::Display for DevServerToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + lazy_static! { static ref ZED_SERVER_URL: Option = std::env::var("ZED_SERVER_URL").ok(); static ref ZED_RPC_URL: Option = std::env::var("ZED_RPC_URL").ok(); @@ -277,10 +286,22 @@ enum WeakSubscriber { Pending(Vec>), } -#[derive(Clone, Debug)] -pub struct Credentials { - pub user_id: u64, - pub access_token: String, +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Credentials { + DevServer { token: DevServerToken }, + User { user_id: u64, access_token: String }, +} + +impl Credentials { + pub fn authorization_header(&self) -> String { + match self { + Credentials::DevServer { token } => format!("dev-server-token {}", token), + Credentials::User { + user_id, + access_token, + } => format!("{} {}", user_id, access_token), + } + } } impl Default for ClientState { @@ -497,11 +518,11 @@ impl Client { } pub fn user_id(&self) -> Option { - self.state - .read() - .credentials - .as_ref() - .map(|credentials| credentials.user_id) + if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() { + Some(*user_id) + } else { + None + } } pub fn peer_id(&self) -> Option { @@ -746,6 +767,10 @@ impl Client { read_credentials_from_keychain(cx).await.is_some() } + pub fn set_dev_server_token(&self, token: DevServerToken) { + self.state.write().credentials = Some(Credentials::DevServer { token }); + } + #[async_recursion(?Send)] pub async fn authenticate_and_connect( self: &Arc, @@ -796,7 +821,9 @@ impl Client { } } let credentials = credentials.unwrap(); - self.set_id(credentials.user_id); + if let Credentials::User { user_id, .. } = &credentials { + self.set_id(*user_id); + } if was_disconnected { self.set_status(Status::Connecting, cx); @@ -812,7 +839,9 @@ impl Client { Ok(conn) => { self.state.write().credentials = Some(credentials.clone()); if !read_from_keychain && IMPERSONATE_LOGIN.is_none() { - write_credentials_to_keychain(credentials, cx).await.log_err(); + if let Credentials::User{user_id, access_token} = credentials { + write_credentials_to_keychain(user_id, access_token, cx).await.log_err(); + } } futures::select_biased! { @@ -1020,10 +1049,7 @@ impl Client { .unwrap_or_default(); let request = Request::builder() - .header( - "Authorization", - format!("{} {}", credentials.user_id, credentials.access_token), - ) + .header("Authorization", credentials.authorization_header()) .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION) .header("x-zed-app-version", app_version) .header( @@ -1176,7 +1202,7 @@ impl Client { .decrypt_string(&access_token) .context("failed to decrypt access token")?; - Ok(Credentials { + Ok(Credentials::User { user_id: user_id.parse()?, access_token, }) @@ -1226,7 +1252,7 @@ impl Client { // Use the admin API token to authenticate as the impersonated user. api_token.insert_str(0, "ADMIN_TOKEN:"); - Ok(Credentials { + Ok(Credentials::User { user_id: response.user.id, access_token: api_token, }) @@ -1439,21 +1465,22 @@ async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option Result<()> { cx.update(move |cx| { cx.write_credentials( &ClientSettings::get_global(cx).server_url, - &credentials.user_id.to_string(), - credentials.access_token.as_bytes(), + &user_id.to_string(), + access_token.as_bytes(), ) })? .await @@ -1558,7 +1585,7 @@ mod tests { // Time out when client tries to connect. client.override_authenticate(move |cx| { cx.background_executor().spawn(async move { - Ok(Credentials { + Ok(Credentials::User { user_id, access_token: "token".into(), }) diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 9338e8cb91..5e8ad2181c 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -48,7 +48,7 @@ impl FakeServer { let mut state = state.lock(); state.auth_count += 1; let access_token = state.access_token.to_string(); - Ok(Credentials { + Ok(Credentials::User { user_id: client_user_id, access_token, }) @@ -71,9 +71,12 @@ impl FakeServer { )))? } - assert_eq!(credentials.user_id, client_user_id); - - if credentials.access_token != state.lock().access_token.to_string() { + if credentials + != (Credentials::User { + user_id: client_user_id, + access_token: state.lock().access_token.to_string(), + }) + { Err(EstablishConnectionError::Unauthorized)? } diff --git a/crates/collab/README.md b/crates/collab/README.md index bb3c76b15b..1af0b55d47 100644 --- a/crates/collab/README.md +++ b/crates/collab/README.md @@ -29,7 +29,7 @@ You can tell what is currently deployed with `./script/what-is-deployed`. To create a new migration: ``` -./script/sqlx migrate add +./script/create-migration ``` Migrations are run automatically on service start, so run `foreman start` again. The service will crash if the migrations fail. diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 29758d7eb1..d82ef75813 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -400,3 +400,11 @@ CREATE TABLE hosted_projects ( ); CREATE INDEX idx_hosted_projects_on_channel_id ON hosted_projects (channel_id); CREATE UNIQUE INDEX uix_hosted_projects_on_channel_id_and_name ON hosted_projects (channel_id, name) WHERE (deleted_at IS NULL); + +CREATE TABLE dev_servers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL REFERENCES channels(id), + name TEXT NOT NULL, + hashed_token TEXT NOT NULL +); +CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id); diff --git a/crates/collab/migrations/20240321162658_add_devservers.sql b/crates/collab/migrations/20240321162658_add_devservers.sql new file mode 100644 index 0000000000..cb1ff4df40 --- /dev/null +++ b/crates/collab/migrations/20240321162658_add_devservers.sql @@ -0,0 +1,7 @@ +CREATE TABLE dev_servers ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + channel_id INT NOT NULL REFERENCES channels(id), + name TEXT NOT NULL, + hashed_token TEXT NOT NULL +); +CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id); diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 26f6ede3d3..5daf6e6186 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,5 +1,6 @@ use crate::{ - db::{self, AccessTokenId, Database, UserId}, + db::{self, dev_server, AccessTokenId, Database, DevServerId, UserId}, + rpc::Principal, AppState, Error, Result, }; use anyhow::{anyhow, Context}; @@ -19,11 +20,11 @@ use std::sync::OnceLock; use std::{sync::Arc, time::Instant}; use subtle::ConstantTimeEq; -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct Impersonator(pub Option); - -/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN -/// and one for the access tokens that we issue. +/// Validates the authorization header and adds an Extension to the request. +/// Authorization: +/// can be an access_token attached to that user, or an access token of an admin +/// or (in development) the string ADMIN:. +/// Authorization: "dev-server-token" pub async fn validate_header(mut req: Request, next: Next) -> impl IntoResponse { let mut auth_header = req .headers() @@ -37,7 +38,26 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into })? .split_whitespace(); - let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| { + let state = req.extensions().get::>().unwrap(); + + let first = auth_header.next().unwrap_or(""); + if first == "dev-server-token" { + let dev_server_token = auth_header.next().ok_or_else(|| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing dev-server-token token in authorization header".to_string(), + ) + })?; + let dev_server = verify_dev_server_token(dev_server_token, &state.db) + .await + .map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?; + + req.extensions_mut() + .insert(Principal::DevServer(dev_server)); + return Ok::<_, Error>(next.run(req).await); + } + + let user_id = UserId(first.parse().map_err(|_| { Error::Http( StatusCode::BAD_REQUEST, "missing user id in authorization header".to_string(), @@ -51,8 +71,6 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into ) })?; - let state = req.extensions().get::>().unwrap(); - // In development, allow impersonation using the admin API token. // Don't allow this in production because we can't tell who is doing // the impersonating. @@ -76,18 +94,17 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into .await? .ok_or_else(|| anyhow!("user {} not found", user_id))?; - let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id { - let impersonator = state + if let Some(impersonator_id) = validate_result.impersonator_id { + let admin = state .db .get_user_by_id(impersonator_id) .await? .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?; - Some(impersonator) + req.extensions_mut() + .insert(Principal::Impersonated { user, admin }); } else { - None + req.extensions_mut().insert(Principal::User(user)); }; - req.extensions_mut().insert(user); - req.extensions_mut().insert(Impersonator(impersonator)); return Ok::<_, Error>(next.run(req).await); } } @@ -213,6 +230,33 @@ pub async fn verify_access_token( }) } +// a dev_server_token has the format .. This is to make them +// relatively easy to copy/paste around. +pub async fn verify_dev_server_token( + dev_server_token: &str, + db: &Arc, +) -> anyhow::Result { + let mut parts = dev_server_token.splitn(2, '.'); + let id = DevServerId(parts.next().unwrap_or_default().parse()?); + let token = parts + .next() + .ok_or_else(|| anyhow!("invalid dev server token format"))?; + + let token_hash = hash_access_token(&token); + let server = db.get_dev_server(id).await?; + + if server + .hashed_token + .as_bytes() + .ct_eq(token_hash.as_ref()) + .into() + { + Ok(server) + } else { + Err(anyhow!("wrong token for dev server")) + } +} + #[cfg(test)] mod test { use rand::thread_rng; diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index f465d3812a..91c0c440a5 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -67,28 +67,29 @@ macro_rules! id_type { }; } -id_type!(BufferId); id_type!(AccessTokenId); +id_type!(BufferId); +id_type!(ChannelBufferCollaboratorId); id_type!(ChannelChatParticipantId); id_type!(ChannelId); id_type!(ChannelMemberId); -id_type!(MessageId); id_type!(ContactId); +id_type!(DevServerId); +id_type!(ExtensionId); +id_type!(FlagId); id_type!(FollowerId); +id_type!(HostedProjectId); +id_type!(MessageId); +id_type!(NotificationId); +id_type!(NotificationKindId); +id_type!(ProjectCollaboratorId); +id_type!(ProjectId); +id_type!(ReplicaId); id_type!(RoomId); id_type!(RoomParticipantId); -id_type!(ProjectId); -id_type!(ProjectCollaboratorId); -id_type!(ReplicaId); id_type!(ServerId); id_type!(SignupId); id_type!(UserId); -id_type!(ChannelBufferCollaboratorId); -id_type!(FlagId); -id_type!(ExtensionId); -id_type!(NotificationId); -id_type!(NotificationKindId); -id_type!(HostedProjectId); /// ChannelRole gives you permissions for both channels and calls. #[derive( diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 7f2e345a59..0582b8f256 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -5,6 +5,7 @@ pub mod buffers; pub mod channels; pub mod contacts; pub mod contributors; +pub mod dev_servers; pub mod extensions; pub mod hosted_projects; pub mod messages; diff --git a/crates/collab/src/db/queries/dev_servers.rs b/crates/collab/src/db/queries/dev_servers.rs new file mode 100644 index 0000000000..d95897b51e --- /dev/null +++ b/crates/collab/src/db/queries/dev_servers.rs @@ -0,0 +1,18 @@ +use sea_orm::EntityTrait; + +use super::{dev_server, Database, DevServerId}; + +impl Database { + pub async fn get_dev_server( + &self, + dev_server_id: DevServerId, + ) -> crate::Result { + self.transaction(|tx| async move { + Ok(dev_server::Entity::find_by_id(dev_server_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow::anyhow!("no dev server with id {}", dev_server_id))?) + }) + .await + } +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index 6864cc3782..b679337943 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -10,6 +10,7 @@ pub mod channel_message; pub mod channel_message_mention; pub mod contact; pub mod contributor; +pub mod dev_server; pub mod extension; pub mod extension_version; pub mod feature_flag; diff --git a/crates/collab/src/db/tables/dev_server.rs b/crates/collab/src/db/tables/dev_server.rs new file mode 100644 index 0000000000..94b1d4dc00 --- /dev/null +++ b/crates/collab/src/db/tables/dev_server.rs @@ -0,0 +1,17 @@ +use crate::db::{ChannelId, DevServerId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "dev_servers")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: DevServerId, + pub name: String, + pub channel_id: ChannelId, + pub hashed_token: String, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 735e1d3c50..9545b0c2e4 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,12 +1,12 @@ mod connection_pool; use crate::{ - auth::{self, Impersonator}, + auth::{self}, db::{ - self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, - Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project, - ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId, - UpdatedChannelMessage, User, UserId, + self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, + CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, + NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId, + RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId, }, executor::Executor, AppState, Error, RateLimit, RateLimiter, Result, @@ -64,7 +64,10 @@ use std::{ use time::OffsetDateTime; use tokio::sync::{watch, Semaphore}; use tower::ServiceBuilder; -use tracing::{field, info_span, instrument, Instrument}; +use tracing::{ + field::{self}, + info_span, instrument, Instrument, +}; use util::{http::IsahcHttpClient, SemanticVersion}; pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); @@ -105,9 +108,35 @@ impl StreamingResponse { } } +#[derive(Clone, Debug)] +pub enum Principal { + User(User), + Impersonated { user: User, admin: User }, + DevServer(dev_server::Model), +} + +impl Principal { + fn update_span(&self, span: &tracing::Span) { + match &self { + Principal::User(user) => { + span.record("user_id", &user.id.0); + span.record("login", &user.github_login); + } + Principal::Impersonated { user, admin } => { + span.record("user_id", &user.id.0); + span.record("login", &user.github_login); + span.record("impersonator", &admin.github_login); + } + Principal::DevServer(dev_server) => { + span.record("dev_server_id", &dev_server.id.0); + } + } + } +} + #[derive(Clone)] struct Session { - user_id: UserId, + principal: Principal, connection_id: ConnectionId, db: Arc>, peer: Arc, @@ -137,14 +166,98 @@ impl Session { _not_send: PhantomData, } } + + fn for_user(self) -> Option { + UserSession::new(self) + } + + fn user_id(&self) -> Option { + match &self.principal { + Principal::User(user) => Some(user.id), + Principal::Impersonated { user, .. } => Some(user.id), + Principal::DevServer(_) => None, + } + } } impl Debug for Session { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("Session") - .field("user_id", &self.user_id) - .field("connection_id", &self.connection_id) - .finish() + let mut result = f.debug_struct("Session"); + match &self.principal { + Principal::User(user) => { + result.field("user", &user.github_login); + } + Principal::Impersonated { user, admin } => { + result.field("user", &user.github_login); + result.field("impersonator", &admin.github_login); + } + Principal::DevServer(dev_server) => { + result.field("dev_server", &dev_server.id); + } + } + result.field("connection_id", &self.connection_id).finish() + } +} + +struct UserSession(Session); + +impl UserSession { + pub fn new(s: Session) -> Option { + s.user_id().map(|_| UserSession(s)) + } + pub fn user_id(&self) -> UserId { + self.0.user_id().unwrap() + } +} + +impl Deref for UserSession { + type Target = Session; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for UserSession { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +fn user_handler( + handler: impl 'static + Send + Sync + Fn(M, Response, UserSession) -> Fut, +) -> impl 'static + Send + Sync + Fn(M, Response, Session) -> BoxFuture<'static, Result<()>> +where + Fut: Send + Future>, +{ + let handler = Arc::new(handler); + move |message, response, session| { + let handler = handler.clone(); + Box::pin(async move { + if let Some(user_session) = session.for_user() { + Ok(handler(message, response, user_session).await?) + } else { + Err(Error::Internal(anyhow!("must be a user"))) + } + }) + } +} + +fn user_message_handler( + handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut, +) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>> +where + InnertRetFut: Send + Future>, +{ + let handler = Arc::new(handler); + move |message, session| { + let handler = handler.clone(); + Box::pin(async move { + if let Some(user_session) = session.for_user() { + Ok(handler(message, user_session).await?) + } else { + Err(Error::Internal(anyhow!("must be a user"))) + } + }) } } @@ -201,20 +314,20 @@ impl Server { server .add_request_handler(ping) - .add_request_handler(create_room) - .add_request_handler(join_room) - .add_request_handler(rejoin_room) - .add_request_handler(leave_room) - .add_request_handler(set_room_participant_role) - .add_request_handler(call) - .add_request_handler(cancel_call) - .add_message_handler(decline_call) - .add_request_handler(update_participant_location) + .add_request_handler(user_handler(create_room)) + .add_request_handler(user_handler(join_room)) + .add_request_handler(user_handler(rejoin_room)) + .add_request_handler(user_handler(leave_room)) + .add_request_handler(user_handler(set_room_participant_role)) + .add_request_handler(user_handler(call)) + .add_request_handler(user_handler(cancel_call)) + .add_message_handler(user_message_handler(decline_call)) + .add_request_handler(user_handler(update_participant_location)) .add_request_handler(share_project) .add_message_handler(unshare_project) - .add_request_handler(join_project) - .add_request_handler(join_hosted_project) - .add_message_handler(leave_project) + .add_request_handler(user_handler(join_project)) + .add_request_handler(user_handler(join_hosted_project)) + .add_message_handler(user_message_handler(leave_project)) .add_request_handler(update_project) .add_request_handler(update_worktree) .add_message_handler(start_language_server) @@ -261,40 +374,40 @@ impl Server { .add_message_handler(broadcast_project_message_from_host::) .add_message_handler(broadcast_project_message_from_host::) .add_request_handler(get_users) - .add_request_handler(fuzzy_search_users) - .add_request_handler(request_contact) - .add_request_handler(remove_contact) - .add_request_handler(respond_to_contact_request) - .add_request_handler(create_channel) - .add_request_handler(delete_channel) - .add_request_handler(invite_channel_member) - .add_request_handler(remove_channel_member) - .add_request_handler(set_channel_member_role) - .add_request_handler(set_channel_visibility) - .add_request_handler(rename_channel) - .add_request_handler(join_channel_buffer) - .add_request_handler(leave_channel_buffer) - .add_message_handler(update_channel_buffer) - .add_request_handler(rejoin_channel_buffers) - .add_request_handler(get_channel_members) - .add_request_handler(respond_to_channel_invite) - .add_request_handler(join_channel) - .add_request_handler(join_channel_chat) - .add_message_handler(leave_channel_chat) - .add_request_handler(send_channel_message) - .add_request_handler(remove_channel_message) - .add_request_handler(update_channel_message) - .add_request_handler(get_channel_messages) - .add_request_handler(get_channel_messages_by_id) - .add_request_handler(get_notifications) - .add_request_handler(mark_notification_as_read) - .add_request_handler(move_channel) - .add_request_handler(follow) - .add_message_handler(unfollow) - .add_message_handler(update_followers) - .add_request_handler(get_private_user_info) - .add_message_handler(acknowledge_channel_message) - .add_message_handler(acknowledge_buffer_version) + .add_request_handler(user_handler(fuzzy_search_users)) + .add_request_handler(user_handler(request_contact)) + .add_request_handler(user_handler(remove_contact)) + .add_request_handler(user_handler(respond_to_contact_request)) + .add_request_handler(user_handler(create_channel)) + .add_request_handler(user_handler(delete_channel)) + .add_request_handler(user_handler(invite_channel_member)) + .add_request_handler(user_handler(remove_channel_member)) + .add_request_handler(user_handler(set_channel_member_role)) + .add_request_handler(user_handler(set_channel_visibility)) + .add_request_handler(user_handler(rename_channel)) + .add_request_handler(user_handler(join_channel_buffer)) + .add_request_handler(user_handler(leave_channel_buffer)) + .add_message_handler(user_message_handler(update_channel_buffer)) + .add_request_handler(user_handler(rejoin_channel_buffers)) + .add_request_handler(user_handler(get_channel_members)) + .add_request_handler(user_handler(respond_to_channel_invite)) + .add_request_handler(user_handler(join_channel)) + .add_request_handler(user_handler(join_channel_chat)) + .add_message_handler(user_message_handler(leave_channel_chat)) + .add_request_handler(user_handler(send_channel_message)) + .add_request_handler(user_handler(remove_channel_message)) + .add_request_handler(user_handler(update_channel_message)) + .add_request_handler(user_handler(get_channel_messages)) + .add_request_handler(user_handler(get_channel_messages_by_id)) + .add_request_handler(user_handler(get_notifications)) + .add_request_handler(user_handler(mark_notification_as_read)) + .add_request_handler(user_handler(move_channel)) + .add_request_handler(user_handler(follow)) + .add_message_handler(user_message_handler(unfollow)) + .add_message_handler(user_message_handler(update_followers)) + .add_request_handler(user_handler(get_private_user_info)) + .add_message_handler(user_message_handler(acknowledge_channel_message)) + .add_message_handler(user_message_handler(acknowledge_buffer_version)) .add_streaming_request_handler({ let app_state = app_state.clone(); move |request, response, session| { @@ -309,14 +422,14 @@ impl Server { }) .add_request_handler({ let app_state = app_state.clone(); - move |request, response, session| { + user_handler(move |request, response, session| { count_tokens_with_language_model( request, response, session, app_state.config.google_ai_api_key.clone(), ) - } + }) }); Arc::new(server) @@ -612,19 +725,15 @@ impl Server { self: &Arc, connection: Connection, address: String, - user: User, + principal: Principal, zed_version: ZedVersion, - impersonator: Option, send_connection_id: Option>, executor: Executor, ) -> impl Future { let this = self.clone(); - let user_id = user.id; - let login = user.github_login.clone(); - let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty, connection_id = field::Empty); - if let Some(impersonator) = impersonator { - span.record("impersonator", &impersonator.github_login); - } + let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty); + principal.update_span(&span); + let mut teardown = self.teardown.subscribe(); async move { if *teardown.borrow() { @@ -649,7 +758,7 @@ impl Server { }; let session = Session { - user_id, + principal: principal.clone(), connection_id, db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))), peer: this.peer.clone(), @@ -660,7 +769,7 @@ impl Server { _executor: executor.clone(), }; - if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await { + if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await { tracing::error!(?error, "failed to send initial client update"); return; } @@ -700,7 +809,8 @@ impl Server { let type_name = message.payload_type_name(); // note: we copy all the fields from the parent span so we can query them in the logs. // (https://github.com/tokio-rs/tracing/issues/2670). - let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name); + let span = tracing::info_span!("receive message", %connection_id, %address, type_name); + principal.update_span(&span); let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); @@ -739,7 +849,7 @@ impl Server { async fn send_initial_client_update( &self, connection_id: ConnectionId, - user: User, + principal: &Principal, zed_version: ZedVersion, mut send_connection_id: Option>, session: &Session, @@ -752,6 +862,10 @@ impl Server { )?; tracing::info!("sent hello message"); + let Principal::User(user) = principal else { + return Ok(()); + }; + if let Some(send_connection_id) = send_connection_id.take() { let _ = send_connection_id.send(connection_id); } @@ -970,8 +1084,7 @@ pub async fn handle_websocket_request( app_version_header: Option>, ConnectInfo(socket_address): ConnectInfo, Extension(server): Extension>, - Extension(user): Extension, - Extension(impersonator): Extension, + Extension(principal): Extension, ws: WebSocketUpgrade, ) -> axum::response::Response { if protocol_version != rpc::PROTOCOL_VERSION { @@ -1010,9 +1123,8 @@ pub async fn handle_websocket_request( .handle_connection( connection, socket_address, - user, + principal, version, - impersonator.0, None, Executor::Production, ) @@ -1075,24 +1187,26 @@ async fn connection_lost( futures::select_biased! { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { - log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id); - leave_room_for_session(&session).await.trace_err(); - leave_channel_buffers_for_session(&session) - .await - .trace_err(); + if let Some(session) = session.for_user() { + log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id); + leave_room_for_session(&session).await.trace_err(); + leave_channel_buffers_for_session(&session) + .await + .trace_err(); - if !session - .connection_pool() - .await - .is_user_online(session.user_id) - { - let db = session.db().await; - if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() { - room_updated(&room, &session.peer); + if !session + .connection_pool() + .await + .is_user_online(session.user_id()) + { + let db = session.db().await; + if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() { + room_updated(&room, &session.peer); + } } - } - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; + } } _ = teardown.changed().fuse() => {} } @@ -1110,19 +1224,20 @@ async fn ping(_: proto::Ping, response: Response, _session: Session async fn create_room( _request: proto::CreateRoom, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let live_kit_room = nanoid::nanoid!(30); let live_kit_connection_info = { let live_kit_room = live_kit_room.clone(); let live_kit = session.live_kit_client.as_ref(); + let user_id = session.user_id().to_string(); util::async_maybe!({ let live_kit = live_kit?; let token = live_kit - .room_token(&live_kit_room, &session.user_id.to_string()) + .room_token(&live_kit_room, &user_id.to_string()) .trace_err()?; Some(proto::LiveKitConnectionInfo { @@ -1137,7 +1252,7 @@ async fn create_room( let room = session .db() .await - .create_room(session.user_id, session.connection_id, &live_kit_room) + .create_room(session.user_id(), session.connection_id, &live_kit_room) .await?; response.send(proto::CreateRoomResponse { @@ -1145,7 +1260,7 @@ async fn create_room( live_kit_connection_info, })?; - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -1153,7 +1268,7 @@ async fn create_room( async fn join_room( request: proto::JoinRoom, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let room_id = RoomId::from_proto(request.id); @@ -1167,7 +1282,7 @@ async fn join_room( let room = session .db() .await - .join_room(room_id, session.user_id, session.connection_id) + .join_room(room_id, session.user_id(), session.connection_id) .await?; room_updated(&room.room, &session.peer); room.into_inner() @@ -1176,7 +1291,7 @@ async fn join_room( for connection_id in session .connection_pool() .await - .user_connection_ids(session.user_id) + .user_connection_ids(session.user_id()) { session .peer @@ -1193,7 +1308,7 @@ async fn join_room( if let Some(token) = live_kit .room_token( &joined_room.room.live_kit_room, - &session.user_id.to_string(), + &session.user_id().to_string(), ) .trace_err() { @@ -1215,7 +1330,7 @@ async fn join_room( live_kit_connection_info, })?; - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -1223,7 +1338,7 @@ async fn join_room( async fn rejoin_room( request: proto::RejoinRoom, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let room; let channel; @@ -1231,7 +1346,7 @@ async fn rejoin_room( let mut rejoined_room = session .db() .await - .rejoin_room(request, session.user_id, session.connection_id) + .rejoin_room(request, session.user_id(), session.connection_id) .await?; response.send(proto::RejoinRoomResponse { @@ -1404,7 +1519,7 @@ async fn rejoin_room( ); } - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -1412,7 +1527,7 @@ async fn rejoin_room( async fn leave_room( _: proto::LeaveRoom, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { leave_room_for_session(&session).await?; response.send(proto::Ack {})?; @@ -1423,7 +1538,7 @@ async fn leave_room( async fn set_room_participant_role( request: proto::SetRoomParticipantRole, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let user_id = UserId::from_proto(request.user_id); let role = ChannelRole::from(request.role()); @@ -1433,7 +1548,7 @@ async fn set_room_participant_role( .db() .await .set_room_participant_role( - session.user_id, + session.user_id(), RoomId::from_proto(request.room_id), user_id, role, @@ -1471,10 +1586,10 @@ async fn set_room_participant_role( async fn call( request: proto::Call, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); - let calling_user_id = session.user_id; + let calling_user_id = session.user_id(); let calling_connection_id = session.connection_id; let called_user_id = UserId::from_proto(request.called_user_id); let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); @@ -1540,7 +1655,7 @@ async fn call( async fn cancel_call( request: proto::CancelCall, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let called_user_id = UserId::from_proto(request.called_user_id); let room_id = RoomId::from_proto(request.room_id); @@ -1575,13 +1690,13 @@ async fn cancel_call( } /// Decline an incoming call. -async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { +async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> { let room_id = RoomId::from_proto(message.room_id); { let room = session .db() .await - .decline_call(Some(room_id), session.user_id) + .decline_call(Some(room_id), session.user_id()) .await? .ok_or_else(|| anyhow!("failed to decline call"))?; room_updated(&room, &session.peer); @@ -1590,7 +1705,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( for connection_id in session .connection_pool() .await - .user_connection_ids(session.user_id) + .user_connection_ids(session.user_id()) { session .peer @@ -1602,7 +1717,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( ) .trace_err(); } - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -1610,7 +1725,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( async fn update_participant_location( request: proto::UpdateParticipantLocation, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let location = request @@ -1674,7 +1789,7 @@ async fn unshare_project(message: proto::UnshareProject, session: Session) -> Re async fn join_project( request: proto::JoinProject, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); @@ -1705,7 +1820,7 @@ impl JoinProjectInternalResponse for Response { fn join_project_internal( response: impl JoinProjectInternalResponse, - session: Session, + session: UserSession, project: &mut Project, replica_id: &ReplicaId, ) -> Result<()> { @@ -1716,7 +1831,7 @@ fn join_project_internal( .map(|collaborator| collaborator.to_proto()) .collect::>(); let project_id = project.id; - let guest_user_id = session.user_id; + let guest_user_id = session.user_id(); let worktrees = project .worktrees @@ -1823,7 +1938,7 @@ fn join_project_internal( } /// Leave someone elses shared project. -async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { +async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> { let sender_id = session.connection_id; let project_id = ProjectId::from_proto(request.project_id); let db = session.db().await; @@ -1850,14 +1965,14 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result async fn join_hosted_project( request: proto::JoinHostedProject, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let (mut project, replica_id) = session .db() .await .join_hosted_project( ProjectId(request.project_id as i32), - session.user_id, + session.user_id(), session.connection_id, ) .await?; @@ -2168,7 +2283,7 @@ async fn broadcast_project_message_from_host, - session: Session, + session: UserSession, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); @@ -2203,7 +2318,7 @@ async fn follow( } /// Stop following another user in a call. -async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { +async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); let leader_id = request @@ -2235,7 +2350,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { } /// Notify everyone following you of your current location. -async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { +async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let database = session.db.lock().await; @@ -2297,7 +2412,7 @@ async fn get_users( async fn fuzzy_search_users( request: proto::FuzzySearchUsers, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let query = request.query; let users = match query.len() { @@ -2313,7 +2428,7 @@ async fn fuzzy_search_users( }; let users = users .into_iter() - .filter(|user| user.id != session.user_id) + .filter(|user| user.id != session.user_id()) .map(|user| proto::User { id: user.id.to_proto(), avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), @@ -2328,9 +2443,9 @@ async fn fuzzy_search_users( async fn request_contact( request: proto::RequestContact, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { - let requester_id = session.user_id; + let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.responder_id); if requester_id == responder_id { return Err(anyhow!("cannot add yourself as a contact"))?; @@ -2375,9 +2490,9 @@ async fn request_contact( async fn respond_to_contact_request( request: proto::RespondToContactRequest, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { - let responder_id = session.user_id; + let responder_id = session.user_id(); let requester_id = UserId::from_proto(request.requester_id); let db = session.db().await; if request.response == proto::ContactRequestResponse::Dismiss as i32 { @@ -2433,9 +2548,9 @@ async fn respond_to_contact_request( async fn remove_contact( request: proto::RemoveContact, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { - let requester_id = session.user_id; + let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.user_id); let db = session.db().await; let (contact_accepted, deleted_notification_id) = @@ -2484,13 +2599,13 @@ async fn remove_contact( async fn create_channel( request: proto::CreateChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); let (channel, membership) = db - .create_channel(&request.name, parent_id, session.user_id) + .create_channel(&request.name, parent_id, session.user_id()) .await?; let root_id = channel.root_id(); @@ -2539,13 +2654,13 @@ async fn create_channel( async fn delete_channel( request: proto::DeleteChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = request.channel_id; let (root_channel, removed_channels) = db - .delete_channel(ChannelId::from_proto(channel_id), session.user_id) + .delete_channel(ChannelId::from_proto(channel_id), session.user_id()) .await?; response.send(proto::Ack {})?; @@ -2567,7 +2682,7 @@ async fn delete_channel( async fn invite_channel_member( request: proto::InviteChannelMember, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -2579,7 +2694,7 @@ async fn invite_channel_member( .invite_channel_member( channel_id, invitee_id, - session.user_id, + session.user_id(), request.role().into(), ) .await?; @@ -2604,7 +2719,7 @@ async fn invite_channel_member( async fn remove_channel_member( request: proto::RemoveChannelMember, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -2614,7 +2729,7 @@ async fn remove_channel_member( membership_update, notification_id, } = db - .remove_channel_member(channel_id, member_id, session.user_id) + .remove_channel_member(channel_id, member_id, session.user_id()) .await?; let mut connection_pool = session.connection_pool().await; @@ -2648,14 +2763,14 @@ async fn remove_channel_member( async fn set_channel_visibility( request: proto::SetChannelVisibility, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let visibility = request.visibility().into(); let channel_model = db - .set_channel_visibility(channel_id, visibility, session.user_id) + .set_channel_visibility(channel_id, visibility, session.user_id()) .await?; let root_id = channel_model.root_id(); let channel = Channel::from_model(channel_model); @@ -2693,7 +2808,7 @@ async fn set_channel_visibility( async fn set_channel_member_role( request: proto::SetChannelMemberRole, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -2701,7 +2816,7 @@ async fn set_channel_member_role( let result = db .set_channel_member_role( channel_id, - session.user_id, + session.user_id(), member_id, request.role().into(), ) @@ -2741,12 +2856,12 @@ async fn set_channel_member_role( async fn rename_channel( request: proto::RenameChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let channel_model = db - .rename_channel(channel_id, session.user_id, &request.name) + .rename_channel(channel_id, session.user_id(), &request.name) .await?; let root_id = channel_model.root_id(); let channel = Channel::from_model(channel_model); @@ -2773,7 +2888,7 @@ async fn rename_channel( async fn move_channel( request: proto::MoveChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let to = ChannelId::from_proto(request.to); @@ -2781,7 +2896,7 @@ async fn move_channel( let (root_id, channels) = session .db() .await - .move_channel(channel_id, to, session.user_id) + .move_channel(channel_id, to, session.user_id()) .await?; let connection_pool = session.connection_pool().await; @@ -2816,12 +2931,12 @@ async fn move_channel( async fn get_channel_members( request: proto::GetChannelMembers, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let members = db - .get_channel_participant_details(channel_id, session.user_id) + .get_channel_participant_details(channel_id, session.user_id()) .await?; response.send(proto::GetChannelMembersResponse { members })?; Ok(()) @@ -2831,7 +2946,7 @@ async fn get_channel_members( async fn respond_to_channel_invite( request: proto::RespondToChannelInvite, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -2839,7 +2954,7 @@ async fn respond_to_channel_invite( membership_update, notifications, } = db - .respond_to_channel_invite(channel_id, session.user_id, request.accept) + .respond_to_channel_invite(channel_id, session.user_id(), request.accept) .await?; let mut connection_pool = session.connection_pool().await; @@ -2847,7 +2962,7 @@ async fn respond_to_channel_invite( notify_membership_updated( &mut connection_pool, membership_update, - session.user_id, + session.user_id(), &session.peer, ); } else { @@ -2856,7 +2971,7 @@ async fn respond_to_channel_invite( ..Default::default() }; - for connection_id in connection_pool.user_connection_ids(session.user_id) { + for connection_id in connection_pool.user_connection_ids(session.user_id()) { session.peer.send(connection_id, update.clone())?; } }; @@ -2872,7 +2987,7 @@ async fn respond_to_channel_invite( async fn join_channel( request: proto::JoinChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); join_channel_internal(channel_id, Box::new(response), session).await @@ -2895,14 +3010,14 @@ impl JoinChannelInternalResponse for Response { async fn join_channel_internal( channel_id: ChannelId, response: Box, - session: Session, + session: UserSession, ) -> Result<()> { let joined_room = { leave_room_for_session(&session).await?; let db = session.db().await; let (joined_room, membership_updated, role) = db - .join_channel(channel_id, session.user_id, session.connection_id) + .join_channel(channel_id, session.user_id(), session.connection_id) .await?; let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| { @@ -2912,7 +3027,7 @@ async fn join_channel_internal( live_kit .guest_token( &joined_room.room.live_kit_room, - &session.user_id.to_string(), + &session.user_id().to_string(), ) .trace_err()?, ) @@ -2922,7 +3037,7 @@ async fn join_channel_internal( live_kit .room_token( &joined_room.room.live_kit_room, - &session.user_id.to_string(), + &session.user_id().to_string(), ) .trace_err()?, ) @@ -2949,7 +3064,7 @@ async fn join_channel_internal( notify_membership_updated( &mut connection_pool, membership_updated, - session.user_id, + session.user_id(), &session.peer, ); } @@ -2968,7 +3083,7 @@ async fn join_channel_internal( &*session.connection_pool().await, ); - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -2976,13 +3091,13 @@ async fn join_channel_internal( async fn join_channel_buffer( request: proto::JoinChannelBuffer, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let open_response = db - .join_channel_buffer(channel_id, session.user_id, session.connection_id) + .join_channel_buffer(channel_id, session.user_id(), session.connection_id) .await?; let collaborators = open_response.collaborators.clone(); @@ -3007,13 +3122,13 @@ async fn join_channel_buffer( /// Edit the channel notes async fn update_channel_buffer( request: proto::UpdateChannelBuffer, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let (collaborators, non_collaborators, epoch, version) = db - .update_channel_buffer(channel_id, session.user_id, &request.operations) + .update_channel_buffer(channel_id, session.user_id(), &request.operations) .await?; channel_buffer_updated( @@ -3055,11 +3170,11 @@ async fn update_channel_buffer( async fn rejoin_channel_buffers( request: proto::RejoinChannelBuffers, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let buffers = db - .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id) + .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id) .await?; for rejoined_buffer in &buffers { @@ -3090,7 +3205,7 @@ async fn rejoin_channel_buffers( async fn leave_channel_buffer( request: proto::LeaveChannelBuffer, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3152,7 +3267,7 @@ fn send_notifications( async fn send_channel_message( request: proto::SendChannelMessage, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { // Validate the message body. let body = request.body.trim().to_string(); @@ -3181,7 +3296,7 @@ async fn send_channel_message( .await .create_channel_message( channel_id, - session.user_id, + session.user_id(), &body, &request.mentions, timestamp, @@ -3194,7 +3309,7 @@ async fn send_channel_message( .await?; let message = proto::ChannelMessage { - sender_id: session.user_id.to_proto(), + sender_id: session.user_id().to_proto(), id: message_id.to_proto(), body, mentions: request.mentions, @@ -3248,14 +3363,14 @@ async fn send_channel_message( async fn remove_channel_message( request: proto::RemoveChannelMessage, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); let connection_ids = session .db() .await - .remove_channel_message(channel_id, message_id, session.user_id) + .remove_channel_message(channel_id, message_id, session.user_id()) .await?; broadcast(Some(session.connection_id), connection_ids, |connection| { session.peer.send(connection, request.clone()) @@ -3267,7 +3382,7 @@ async fn remove_channel_message( async fn update_channel_message( request: proto::UpdateChannelMessage, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -3284,7 +3399,7 @@ async fn update_channel_message( .update_channel_message( channel_id, message_id, - session.user_id, + session.user_id(), request.body.as_str(), &request.mentions, updated_at, @@ -3297,7 +3412,7 @@ async fn update_channel_message( .ok_or_else(|| anyhow!("nonce can't be blank"))?; let message = proto::ChannelMessage { - sender_id: session.user_id.to_proto(), + sender_id: session.user_id().to_proto(), id: message_id.to_proto(), body: request.body.clone(), mentions: request.mentions.clone(), @@ -3332,14 +3447,14 @@ async fn update_channel_message( /// Mark a channel message as read async fn acknowledge_channel_message( request: proto::AckChannelMessage, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); let notifications = session .db() .await - .observe_channel_message(channel_id, session.user_id, message_id) + .observe_channel_message(channel_id, session.user_id(), message_id) .await?; send_notifications( &*session.connection_pool().await, @@ -3352,7 +3467,7 @@ async fn acknowledge_channel_message( /// Mark a buffer version as synced async fn acknowledge_buffer_version( request: proto::AckBufferOperation, - session: Session, + session: UserSession, ) -> Result<()> { let buffer_id = BufferId::from_proto(request.buffer_id); session @@ -3360,7 +3475,7 @@ async fn acknowledge_buffer_version( .await .observe_buffer_version( buffer_id, - session.user_id, + session.user_id(), request.epoch as i32, &request.version, ) @@ -3394,10 +3509,13 @@ async fn complete_with_language_model( open_ai_api_key: Option>, google_ai_api_key: Option>, ) -> Result<()> { + let Some(session) = session.for_user() else { + return Err(anyhow!("user not found"))?; + }; authorize_access_to_language_models(&session).await?; session .rate_limiter - .check::(session.user_id) + .check::(session.user_id()) .await?; if request.model.starts_with("gpt") { @@ -3416,7 +3534,7 @@ async fn complete_with_language_model( async fn complete_with_open_ai( request: proto::CompleteWithLanguageModel, response: StreamingResponse, - session: Session, + session: UserSession, api_key: Arc, ) -> Result<()> { const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; @@ -3458,7 +3576,7 @@ async fn complete_with_open_ai( async fn complete_with_google_ai( request: proto::CompleteWithLanguageModel, response: StreamingResponse, - session: Session, + session: UserSession, api_key: Arc, ) -> Result<()> { let mut stream = google_ai::stream_generate_content( @@ -3527,7 +3645,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit { async fn count_tokens_with_language_model( request: proto::CountTokensWithLanguageModel, response: Response, - session: Session, + session: UserSession, google_ai_api_key: Option>, ) -> Result<()> { authorize_access_to_language_models(&session).await?; @@ -3541,7 +3659,7 @@ async fn count_tokens_with_language_model( session .rate_limiter - .check::(session.user_id) + .check::(session.user_id()) .await?; let api_key = google_ai_api_key @@ -3559,9 +3677,9 @@ async fn count_tokens_with_language_model( Ok(()) } -async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> { +async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> { let db = session.db().await; - let flags = db.get_user_flags(session.user_id).await?; + let flags = db.get_user_flags(session.user_id()).await?; if flags.iter().any(|flag| flag == "language-models") { Ok(()) } else { @@ -3573,15 +3691,15 @@ async fn authorize_access_to_language_models(session: &Session) -> Result<(), Er async fn join_channel_chat( request: proto::JoinChannelChat, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let db = session.db().await; - db.join_channel_chat(channel_id, session.connection_id, session.user_id) + db.join_channel_chat(channel_id, session.connection_id, session.user_id()) .await?; let messages = db - .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None) + .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None) .await?; response.send(proto::JoinChannelChatResponse { done: messages.len() < MESSAGE_COUNT_PER_PAGE, @@ -3591,12 +3709,12 @@ async fn join_channel_chat( } /// Stop receiving chat updates for a channel -async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> { +async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); session .db() .await - .leave_channel_chat(channel_id, session.connection_id, session.user_id) + .leave_channel_chat(channel_id, session.connection_id, session.user_id()) .await?; Ok(()) } @@ -3605,7 +3723,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) async fn get_channel_messages( request: proto::GetChannelMessages, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let messages = session @@ -3613,7 +3731,7 @@ async fn get_channel_messages( .await .get_channel_messages( channel_id, - session.user_id, + session.user_id(), MESSAGE_COUNT_PER_PAGE, Some(MessageId::from_proto(request.before_message_id)), ) @@ -3629,7 +3747,7 @@ async fn get_channel_messages( async fn get_channel_messages_by_id( request: proto::GetChannelMessagesById, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let message_ids = request .message_ids @@ -3639,7 +3757,7 @@ async fn get_channel_messages_by_id( let messages = session .db() .await - .get_channel_messages_by_id(session.user_id, &message_ids) + .get_channel_messages_by_id(session.user_id(), &message_ids) .await?; response.send(proto::GetChannelMessagesResponse { done: messages.len() < MESSAGE_COUNT_PER_PAGE, @@ -3652,13 +3770,13 @@ async fn get_channel_messages_by_id( async fn get_notifications( request: proto::GetNotifications, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let notifications = session .db() .await .get_notifications( - session.user_id, + session.user_id(), NOTIFICATION_COUNT_PER_PAGE, request .before_id @@ -3676,12 +3794,12 @@ async fn get_notifications( async fn mark_notification_as_read( request: proto::MarkNotificationRead, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let database = &session.db().await; let notifications = database .mark_notification_as_read_by_id( - session.user_id, + session.user_id(), NotificationId::from_proto(request.notification_id), ) .await?; @@ -3698,16 +3816,16 @@ async fn mark_notification_as_read( async fn get_private_user_info( _request: proto::GetPrivateUserInfo, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; - let metrics_id = db.get_user_metrics_id(session.user_id).await?; + let metrics_id = db.get_user_metrics_id(session.user_id()).await?; let user = db - .get_user_by_id(session.user_id) + .get_user_by_id(session.user_id()) .await? .ok_or_else(|| anyhow!("user not found"))?; - let flags = db.get_user_flags(session.user_id).await?; + let flags = db.get_user_flags(session.user_id()).await?; response.send(proto::GetPrivateUserInfoResponse { metrics_id, @@ -3951,7 +4069,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> Ok(()) } -async fn leave_room_for_session(session: &Session) -> Result<()> { +async fn leave_room_for_session(session: &UserSession) -> Result<()> { let mut contacts_to_update = HashSet::default(); let room_id; @@ -3962,7 +4080,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { let channel; if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? { - contacts_to_update.insert(session.user_id); + contacts_to_update.insert(session.user_id()); for project in left_room.left_projects.values() { project_left(project, session); @@ -4013,7 +4131,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { if let Some(live_kit) = session.live_kit_client.as_ref() { live_kit - .remove_participant(live_kit_room.clone(), session.user_id.to_string()) + .remove_participant(live_kit_room.clone(), session.user_id().to_string()) .await .trace_err(); @@ -4047,9 +4165,9 @@ async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> { Ok(()) } -fn project_left(project: &db::LeftProject, session: &Session) { +fn project_left(project: &db::LeftProject, session: &UserSession) { for connection_id in &project.connection_ids { - if project.host_user_id == Some(session.user_id) { + if project.host_user_id == Some(session.user_id()) { session .peer .send( diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index e5ca052a2f..3027848b2b 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -1,7 +1,7 @@ use crate::{ db::{tests::TestDb, NewUserParams, UserId}, executor::Executor, - rpc::{Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + rpc::{Principal, Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, AppState, Config, RateLimiter, }; use anyhow::anyhow; @@ -197,15 +197,20 @@ impl TestServer { .override_authenticate(move |cx| { cx.spawn(|_| async move { let access_token = "the-token".to_string(); - Ok(Credentials { + Ok(Credentials::User { user_id: user_id.to_proto(), access_token, }) }) }) .override_establish_connection(move |credentials, cx| { - assert_eq!(credentials.user_id, user_id.0 as u64); - assert_eq!(credentials.access_token, "the-token"); + assert_eq!( + credentials, + &Credentials::User { + user_id: user_id.0 as u64, + access_token: "the-token".into() + } + ); let server = server.clone(); let db = db.clone(); @@ -230,9 +235,8 @@ impl TestServer { .spawn(server.handle_connection( server_conn, client_name, - user, + Principal::User(user), ZedVersion(SemanticVersion::new(1, 0, 0)), - None, Some(connection_id_tx), Executor::Deterministic(cx.background_executor().clone()), )) diff --git a/crates/theme/theme.md b/crates/theme/theme.md index f9a7a58178..d19d147597 100644 --- a/crates/theme/theme.md +++ b/crates/theme/theme.md @@ -1,15 +1,15 @@ - # Theme +# Theme - This crate provides the theme system for Zed. +This crate provides the theme system for Zed. - ## Overview +## Overview - A theme is a collection of colors used to build a consistent appearance for UI components across the application. - To produce a theme in Zed, +A theme is a collection of colors used to build a consistent appearance for UI components across the application. +To produce a theme in Zed, - A theme is made of of two parts: A [ThemeFamily] and one or more [Theme]s. +A theme is made of of two parts: A [ThemeFamily] and one or more [Theme]s. // - A [ThemeFamily] contains metadata like theme name, author, and theme-specific [ColorScales] as well as a series of themes. +A [ThemeFamily] contains metadata like theme name, author, and theme-specific [ColorScales] as well as a series of themes. - - [ThemeColors] - A set of colors that are used to style the UI. Refer to the [ThemeColors] documentation for more information. +- [ThemeColors] - A set of colors that are used to style the UI. Refer to the [ThemeColors] documentation for more information. diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index b5cf2d9e54..8195a2bf93 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -26,6 +26,7 @@ breadcrumbs.workspace = true call.workspace = true channel.workspace = true chrono.workspace = true +clap.workspace = true cli.workspace = true client.workspace = true clock.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index bb6a2ca6e8..a5e959d757 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -6,8 +6,9 @@ mod zed; use anyhow::{anyhow, Context as _, Result}; use backtrace::Backtrace; use chrono::Utc; +use clap::{command, Parser}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{parse_zed_link, Client, UserStore}; +use client::{parse_zed_link, Client, ClientSettings, DevServerToken, UserStore}; use collab_ui::channel_view::ChannelView; use db::kvp::KEY_VALUE_STORE; use editor::Editor; @@ -270,9 +271,28 @@ fn main() { cx.activate(true); - let urls = collect_url_args(cx); - if !urls.is_empty() { - listener.open_urls(urls) + let mut args = Args::parse(); + if let Some(dev_server_token) = args.dev_server_token.take() { + let dev_server_token = DevServerToken(dev_server_token); + let server_url = ClientSettings::get_global(&cx).server_url.clone(); + let client = client.clone(); + client.set_dev_server_token(dev_server_token); + cx.spawn(|cx| async move { + client.authenticate_and_connect(false, &cx).await?; + log::info!("Connected to {}", server_url); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } else { + let urls: Vec<_> = args + .paths_or_urls + .iter() + .filter_map(|arg| parse_url_arg(arg, cx).log_err()) + .collect(); + + if !urls.is_empty() { + listener.open_urls(urls) + } } let mut triggered_authentication = false; @@ -898,23 +918,35 @@ fn stdout_is_a_pty() -> bool { std::env::var(FORCE_CLI_MODE_ENV_VAR_NAME).ok().is_none() && std::io::stdout().is_terminal() } -fn collect_url_args(cx: &AppContext) -> Vec { - env::args() - .skip(1) - .filter_map(|arg| match std::fs::canonicalize(Path::new(&arg)) { - Ok(path) => Some(format!("file://{}", path.to_string_lossy())), - Err(error) => { - if arg.starts_with("file://") || arg.starts_with("zed-cli://") { - Some(arg) - } else if let Some(_) = parse_zed_link(&arg, cx) { - Some(arg) - } else { - log::error!("error parsing path argument: {}", error); - None - } +#[derive(Parser, Debug)] +#[command(name = "zed", disable_version_flag = true)] +struct Args { + /// A sequence of space-separated paths or urls that you want to open. + /// + /// Use `path:line:row` syntax to open a file at a specific location. + /// Non-existing paths and directories will ignore `:line:row` suffix. + /// + /// URLs can either be file:// or zed:// scheme, or relative to https://zed.dev. + paths_or_urls: Vec, + + /// Instructs zed to run as a dev server on this machine. (not implemented) + #[arg(long)] + dev_server_token: Option, +} + +fn parse_url_arg(arg: &str, cx: &AppContext) -> Result { + match std::fs::canonicalize(Path::new(&arg)) { + Ok(path) => Ok(format!("file://{}", path.to_string_lossy())), + Err(error) => { + if arg.starts_with("file://") || arg.starts_with("zed-cli://") { + Ok(arg.into()) + } else if let Some(_) = parse_zed_link(&arg, cx) { + Ok(arg.into()) + } else { + Err(anyhow!("error parsing path argument: {}", error)) } - }) - .collect() + } + } } fn load_embedded_fonts(cx: &AppContext) { diff --git a/script/create-migration b/script/create-migration new file mode 100755 index 0000000000..187336be19 --- /dev/null +++ b/script/create-migration @@ -0,0 +1,3 @@ +zed . \ + "crates/collab/migrations.sqlite/20221109000000_test_schema.sql" \ + "crates/collab/migrations/$(date -u +%Y%m%d%H%M%S)_$(echo $1 | sed 's/[^a-z0-9]/_/g').sql" diff --git a/script/eula/eula.rtf b/script/eula/eula.rtf index 3bdbb463bb..6feaff789c 100644 --- a/script/eula/eula.rtf +++ b/script/eula/eula.rtf @@ -182,4 +182,4 @@ \f0\b \cf2 DATE: April 5, 2023 \f1\b0 \ -} \ No newline at end of file +}