mirror of
https://github.com/zed-industries/zed.git
synced 2024-10-26 00:19:46 +00:00
391ad489ff
Co-Authored-By: Nathan Sobo <nathan@zed.dev> Co-Authored-By: Max Brunsfeld <max@zed.dev>
653 lines
19 KiB
Rust
653 lines
19 KiB
Rust
use crate::auth::{self, UserId};
|
|
|
|
use super::{auth::PeerExt as _, AppState};
|
|
use anyhow::anyhow;
|
|
use async_std::task;
|
|
use async_tungstenite::{
|
|
tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
|
|
WebSocketStream,
|
|
};
|
|
use sha1::{Digest as _, Sha1};
|
|
use std::{
|
|
collections::{HashMap, HashSet},
|
|
future::Future,
|
|
mem,
|
|
sync::Arc,
|
|
time::Instant,
|
|
};
|
|
use surf::StatusCode;
|
|
use tide::log;
|
|
use tide::{
|
|
http::headers::{HeaderName, CONNECTION, UPGRADE},
|
|
Request, Response,
|
|
};
|
|
use zrpc::{
|
|
auth::random_token,
|
|
proto::{self, EnvelopedMessage},
|
|
ConnectionId, Peer, Router, TypedEnvelope,
|
|
};
|
|
|
|
type ReplicaId = u16;
|
|
|
|
#[derive(Default)]
|
|
pub struct State {
|
|
connections: HashMap<ConnectionId, ConnectionState>,
|
|
pub worktrees: HashMap<u64, WorktreeState>,
|
|
next_worktree_id: u64,
|
|
}
|
|
|
|
struct ConnectionState {
|
|
_user_id: i32,
|
|
worktrees: HashSet<u64>,
|
|
}
|
|
|
|
pub struct WorktreeState {
|
|
host_connection_id: Option<ConnectionId>,
|
|
guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
|
|
active_replica_ids: HashSet<ReplicaId>,
|
|
access_token: String,
|
|
root_name: String,
|
|
entries: HashMap<u64, proto::Entry>,
|
|
}
|
|
|
|
impl WorktreeState {
|
|
pub fn connection_ids(&self) -> Vec<ConnectionId> {
|
|
self.guest_connection_ids
|
|
.keys()
|
|
.copied()
|
|
.chain(self.host_connection_id)
|
|
.collect()
|
|
}
|
|
|
|
fn host_connection_id(&self) -> tide::Result<ConnectionId> {
|
|
Ok(self
|
|
.host_connection_id
|
|
.ok_or_else(|| anyhow!("host disconnected from worktree"))?)
|
|
}
|
|
}
|
|
|
|
impl State {
|
|
// Add a new connection associated with a given user.
|
|
pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: i32) {
|
|
self.connections.insert(
|
|
connection_id,
|
|
ConnectionState {
|
|
_user_id,
|
|
worktrees: Default::default(),
|
|
},
|
|
);
|
|
}
|
|
|
|
// Remove the given connection and its association with any worktrees.
|
|
pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
|
|
let mut worktree_ids = Vec::new();
|
|
if let Some(connection_state) = self.connections.remove(&connection_id) {
|
|
for worktree_id in connection_state.worktrees {
|
|
if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
|
|
if worktree.host_connection_id == Some(connection_id) {
|
|
worktree_ids.push(worktree_id);
|
|
} else if let Some(replica_id) =
|
|
worktree.guest_connection_ids.remove(&connection_id)
|
|
{
|
|
worktree.active_replica_ids.remove(&replica_id);
|
|
worktree_ids.push(worktree_id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
worktree_ids
|
|
}
|
|
|
|
// Add the given connection as a guest of the given worktree
|
|
pub fn join_worktree(
|
|
&mut self,
|
|
connection_id: ConnectionId,
|
|
worktree_id: u64,
|
|
access_token: &str,
|
|
) -> Option<(ReplicaId, &WorktreeState)> {
|
|
if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
|
|
if access_token == worktree_state.access_token {
|
|
if let Some(connection_state) = self.connections.get_mut(&connection_id) {
|
|
connection_state.worktrees.insert(worktree_id);
|
|
}
|
|
|
|
let mut replica_id = 1;
|
|
while worktree_state.active_replica_ids.contains(&replica_id) {
|
|
replica_id += 1;
|
|
}
|
|
worktree_state.active_replica_ids.insert(replica_id);
|
|
worktree_state
|
|
.guest_connection_ids
|
|
.insert(connection_id, replica_id);
|
|
Some((replica_id, worktree_state))
|
|
} else {
|
|
None
|
|
}
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
fn read_worktree(
|
|
&self,
|
|
worktree_id: u64,
|
|
connection_id: ConnectionId,
|
|
) -> tide::Result<&WorktreeState> {
|
|
let worktree = self
|
|
.worktrees
|
|
.get(&worktree_id)
|
|
.ok_or_else(|| anyhow!("worktree not found"))?;
|
|
|
|
if worktree.host_connection_id == Some(connection_id)
|
|
|| worktree.guest_connection_ids.contains_key(&connection_id)
|
|
{
|
|
Ok(worktree)
|
|
} else {
|
|
Err(anyhow!(
|
|
"{} is not a member of worktree {}",
|
|
connection_id,
|
|
worktree_id
|
|
))?
|
|
}
|
|
}
|
|
|
|
fn write_worktree(
|
|
&mut self,
|
|
worktree_id: u64,
|
|
connection_id: ConnectionId,
|
|
) -> tide::Result<&mut WorktreeState> {
|
|
let worktree = self
|
|
.worktrees
|
|
.get_mut(&worktree_id)
|
|
.ok_or_else(|| anyhow!("worktree not found"))?;
|
|
|
|
if worktree.host_connection_id == Some(connection_id)
|
|
|| worktree.guest_connection_ids.contains_key(&connection_id)
|
|
{
|
|
Ok(worktree)
|
|
} else {
|
|
Err(anyhow!(
|
|
"{} is not a member of worktree {}",
|
|
connection_id,
|
|
worktree_id
|
|
))?
|
|
}
|
|
}
|
|
}
|
|
|
|
trait MessageHandler<'a, M: proto::EnvelopedMessage> {
|
|
type Output: 'a + Send + Future<Output = tide::Result<()>>;
|
|
|
|
fn handle(
|
|
&self,
|
|
message: TypedEnvelope<M>,
|
|
rpc: &'a Arc<Peer>,
|
|
app_state: &'a Arc<AppState>,
|
|
) -> Self::Output;
|
|
}
|
|
|
|
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
|
|
where
|
|
M: proto::EnvelopedMessage,
|
|
F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
|
|
Fut: 'a + Send + Future<Output = tide::Result<()>>,
|
|
{
|
|
type Output = Fut;
|
|
|
|
fn handle(
|
|
&self,
|
|
message: TypedEnvelope<M>,
|
|
rpc: &'a Arc<Peer>,
|
|
app_state: &'a Arc<AppState>,
|
|
) -> Self::Output {
|
|
(self)(message, rpc, app_state)
|
|
}
|
|
}
|
|
|
|
fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
|
|
where
|
|
M: EnvelopedMessage,
|
|
H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
|
|
{
|
|
let rpc = rpc.clone();
|
|
let handler = handler.clone();
|
|
let app_state = app_state.clone();
|
|
router.add_message_handler(move |message| {
|
|
let rpc = rpc.clone();
|
|
let handler = handler.clone();
|
|
let app_state = app_state.clone();
|
|
async move {
|
|
let sender_id = message.sender_id;
|
|
let message_id = message.message_id;
|
|
let start_time = Instant::now();
|
|
log::info!(
|
|
"RPC message received. id: {}.{}, type:{}",
|
|
sender_id,
|
|
message_id,
|
|
M::NAME
|
|
);
|
|
if let Err(err) = handler.handle(message, &rpc, &app_state).await {
|
|
log::error!("error handling message: {:?}", err);
|
|
} else {
|
|
log::info!(
|
|
"RPC message handled. id:{}.{}, duration:{:?}",
|
|
sender_id,
|
|
message_id,
|
|
start_time.elapsed()
|
|
);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
});
|
|
}
|
|
|
|
pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
|
|
on_message(router, rpc, state, share_worktree);
|
|
on_message(router, rpc, state, join_worktree);
|
|
on_message(router, rpc, state, update_worktree);
|
|
on_message(router, rpc, state, close_worktree);
|
|
on_message(router, rpc, state, open_buffer);
|
|
on_message(router, rpc, state, close_buffer);
|
|
on_message(router, rpc, state, update_buffer);
|
|
on_message(router, rpc, state, buffer_saved);
|
|
on_message(router, rpc, state, save_buffer);
|
|
}
|
|
|
|
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
|
let mut router = Router::new();
|
|
add_rpc_routes(&mut router, app.state(), rpc);
|
|
let router = Arc::new(router);
|
|
|
|
let rpc = rpc.clone();
|
|
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
|
|
let user_id = request.ext::<UserId>().copied();
|
|
let rpc = rpc.clone();
|
|
let router = router.clone();
|
|
async move {
|
|
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
|
|
|
let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
|
|
let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
|
|
let upgrade_requested = connection_upgrade && upgrade_to_websocket;
|
|
|
|
if !upgrade_requested {
|
|
return Ok(Response::new(StatusCode::UpgradeRequired));
|
|
}
|
|
|
|
let header = match request.header("Sec-Websocket-Key") {
|
|
Some(h) => h.as_str(),
|
|
None => return Err(anyhow!("expected sec-websocket-key"))?,
|
|
};
|
|
|
|
let mut response = Response::new(StatusCode::SwitchingProtocols);
|
|
response.insert_header(UPGRADE, "websocket");
|
|
response.insert_header(CONNECTION, "Upgrade");
|
|
let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
|
|
response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
|
|
response.insert_header("Sec-Websocket-Version", "13");
|
|
|
|
let http_res: &mut tide::http::Response = response.as_mut();
|
|
let upgrade_receiver = http_res.recv_upgrade().await;
|
|
let addr = request.remote().unwrap_or("unknown").to_string();
|
|
let state = request.state().clone();
|
|
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?.0;
|
|
task::spawn(async move {
|
|
if let Some(stream) = upgrade_receiver.await {
|
|
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
|
|
handle_connection(rpc, router, state, addr, stream, user_id).await;
|
|
}
|
|
});
|
|
|
|
Ok(response)
|
|
}
|
|
});
|
|
}
|
|
|
|
pub async fn handle_connection<Conn>(
|
|
rpc: Arc<Peer>,
|
|
router: Arc<Router>,
|
|
state: Arc<AppState>,
|
|
addr: String,
|
|
stream: Conn,
|
|
user_id: i32,
|
|
) where
|
|
Conn: 'static
|
|
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
|
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
|
+ Send
|
|
+ Unpin,
|
|
{
|
|
log::info!("accepted rpc connection: {:?}", addr);
|
|
let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
|
|
state
|
|
.rpc
|
|
.write()
|
|
.await
|
|
.add_connection(connection_id, user_id);
|
|
|
|
let handle_messages = async move {
|
|
handle_messages.await;
|
|
Ok(())
|
|
};
|
|
|
|
if let Err(e) = futures::try_join!(handle_messages, handle_io) {
|
|
log::error!("error handling rpc connection {:?} - {:?}", addr, e);
|
|
}
|
|
|
|
log::info!("closing connection to {:?}", addr);
|
|
if let Err(e) = rpc.sign_out(connection_id, &state).await {
|
|
log::error!("error signing out connection {:?} - {:?}", addr, e);
|
|
}
|
|
}
|
|
|
|
async fn share_worktree(
|
|
mut request: TypedEnvelope<proto::ShareWorktree>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
let mut state = state.rpc.write().await;
|
|
let worktree_id = state.next_worktree_id;
|
|
state.next_worktree_id += 1;
|
|
let access_token = random_token();
|
|
let worktree = request
|
|
.payload
|
|
.worktree
|
|
.as_mut()
|
|
.ok_or_else(|| anyhow!("missing worktree"))?;
|
|
let entries = mem::take(&mut worktree.entries)
|
|
.into_iter()
|
|
.map(|entry| (entry.id, entry))
|
|
.collect();
|
|
state.worktrees.insert(
|
|
worktree_id,
|
|
WorktreeState {
|
|
host_connection_id: Some(request.sender_id),
|
|
guest_connection_ids: Default::default(),
|
|
active_replica_ids: Default::default(),
|
|
access_token: access_token.clone(),
|
|
root_name: mem::take(&mut worktree.root_name),
|
|
entries,
|
|
},
|
|
);
|
|
|
|
rpc.respond(
|
|
request.receipt(),
|
|
proto::ShareWorktreeResponse {
|
|
worktree_id,
|
|
access_token,
|
|
},
|
|
)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn join_worktree(
|
|
request: TypedEnvelope<proto::OpenWorktree>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
let worktree_id = request.payload.worktree_id;
|
|
let access_token = &request.payload.access_token;
|
|
|
|
let mut state = state.rpc.write().await;
|
|
if let Some((peer_replica_id, worktree)) =
|
|
state.join_worktree(request.sender_id, worktree_id, access_token)
|
|
{
|
|
let mut peers = Vec::new();
|
|
if let Some(host_connection_id) = worktree.host_connection_id {
|
|
peers.push(proto::Peer {
|
|
peer_id: host_connection_id.0,
|
|
replica_id: 0,
|
|
});
|
|
}
|
|
for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
|
|
if *peer_conn_id != request.sender_id {
|
|
peers.push(proto::Peer {
|
|
peer_id: peer_conn_id.0,
|
|
replica_id: *peer_replica_id as u32,
|
|
});
|
|
}
|
|
}
|
|
|
|
broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
|
|
rpc.send(
|
|
conn_id,
|
|
proto::AddPeer {
|
|
worktree_id,
|
|
peer: Some(proto::Peer {
|
|
peer_id: request.sender_id.0,
|
|
replica_id: peer_replica_id as u32,
|
|
}),
|
|
},
|
|
)
|
|
})
|
|
.await?;
|
|
rpc.respond(
|
|
request.receipt(),
|
|
proto::OpenWorktreeResponse {
|
|
worktree_id,
|
|
worktree: Some(proto::Worktree {
|
|
root_name: worktree.root_name.clone(),
|
|
entries: worktree.entries.values().cloned().collect(),
|
|
}),
|
|
replica_id: peer_replica_id as u32,
|
|
peers,
|
|
},
|
|
)
|
|
.await?;
|
|
} else {
|
|
rpc.respond(
|
|
request.receipt(),
|
|
proto::OpenWorktreeResponse {
|
|
worktree_id,
|
|
worktree: None,
|
|
replica_id: 0,
|
|
peers: Vec::new(),
|
|
},
|
|
)
|
|
.await?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn update_worktree(
|
|
request: TypedEnvelope<proto::UpdateWorktree>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
{
|
|
let mut state = state.rpc.write().await;
|
|
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
|
|
for entry_id in &request.payload.removed_entries {
|
|
worktree.entries.remove(&entry_id);
|
|
}
|
|
|
|
for entry in &request.payload.updated_entries {
|
|
worktree.entries.insert(entry.id, entry.clone());
|
|
}
|
|
}
|
|
|
|
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn close_worktree(
|
|
request: TypedEnvelope<proto::CloseWorktree>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
let connection_ids;
|
|
{
|
|
let mut state = state.rpc.write().await;
|
|
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
|
|
connection_ids = worktree.connection_ids();
|
|
if worktree.host_connection_id == Some(request.sender_id) {
|
|
worktree.host_connection_id = None;
|
|
} else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
|
|
worktree.active_replica_ids.remove(&replica_id);
|
|
}
|
|
}
|
|
|
|
broadcast(request.sender_id, connection_ids, |conn_id| {
|
|
rpc.send(
|
|
conn_id,
|
|
proto::RemovePeer {
|
|
worktree_id: request.payload.worktree_id,
|
|
peer_id: request.sender_id.0,
|
|
},
|
|
)
|
|
})
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn open_buffer(
|
|
request: TypedEnvelope<proto::OpenBuffer>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
let receipt = request.receipt();
|
|
let worktree_id = request.payload.worktree_id;
|
|
let host_connection_id = state
|
|
.rpc
|
|
.read()
|
|
.await
|
|
.read_worktree(worktree_id, request.sender_id)?
|
|
.host_connection_id()?;
|
|
|
|
let response = rpc
|
|
.forward_request(request.sender_id, host_connection_id, request.payload)
|
|
.await?;
|
|
rpc.respond(receipt, response).await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn close_buffer(
|
|
request: TypedEnvelope<proto::CloseBuffer>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
let host_connection_id = state
|
|
.rpc
|
|
.read()
|
|
.await
|
|
.read_worktree(request.payload.worktree_id, request.sender_id)?
|
|
.host_connection_id()?;
|
|
|
|
rpc.forward_send(request.sender_id, host_connection_id, request.payload)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn save_buffer(
|
|
request: TypedEnvelope<proto::SaveBuffer>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
let host;
|
|
let guests;
|
|
{
|
|
let state = state.rpc.read().await;
|
|
let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
|
|
host = worktree.host_connection_id()?;
|
|
guests = worktree
|
|
.guest_connection_ids
|
|
.keys()
|
|
.copied()
|
|
.collect::<Vec<_>>();
|
|
}
|
|
|
|
let sender = request.sender_id;
|
|
let receipt = request.receipt();
|
|
let response = rpc
|
|
.forward_request(sender, host, request.payload.clone())
|
|
.await?;
|
|
|
|
broadcast(host, guests, |conn_id| {
|
|
let response = response.clone();
|
|
async move {
|
|
if conn_id == sender {
|
|
rpc.respond(receipt, response).await
|
|
} else {
|
|
rpc.forward_send(host, conn_id, response).await
|
|
}
|
|
}
|
|
})
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn update_buffer(
|
|
request: TypedEnvelope<proto::UpdateBuffer>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
|
|
}
|
|
|
|
async fn buffer_saved(
|
|
request: TypedEnvelope<proto::BufferSaved>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
|
|
}
|
|
|
|
async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
|
|
worktree_id: u64,
|
|
request: TypedEnvelope<T>,
|
|
rpc: &Arc<Peer>,
|
|
state: &Arc<AppState>,
|
|
) -> tide::Result<()> {
|
|
let connection_ids = state
|
|
.rpc
|
|
.read()
|
|
.await
|
|
.read_worktree(worktree_id, request.sender_id)?
|
|
.connection_ids();
|
|
|
|
broadcast(request.sender_id, connection_ids, |conn_id| {
|
|
rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
|
|
})
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn broadcast<F, T>(
|
|
sender_id: ConnectionId,
|
|
receiver_ids: Vec<ConnectionId>,
|
|
mut f: F,
|
|
) -> anyhow::Result<()>
|
|
where
|
|
F: FnMut(ConnectionId) -> T,
|
|
T: Future<Output = anyhow::Result<()>>,
|
|
{
|
|
let futures = receiver_ids
|
|
.into_iter()
|
|
.filter(|id| *id != sender_id)
|
|
.map(|id| f(id));
|
|
futures::future::try_join_all(futures).await?;
|
|
Ok(())
|
|
}
|
|
|
|
fn header_contains_ignore_case<T>(
|
|
request: &tide::Request<T>,
|
|
header_name: HeaderName,
|
|
value: &str,
|
|
) -> bool {
|
|
request
|
|
.header(header_name)
|
|
.map(|h| {
|
|
h.as_str()
|
|
.split(',')
|
|
.any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
|
|
})
|
|
.unwrap_or(false)
|
|
}
|