Implement initial RPC endpoints for chat

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Max Brunsfeld 2021-08-06 13:43:06 -07:00
parent 13ee9c2286
commit 4a32bd6bb0
7 changed files with 258 additions and 81 deletions

2
Cargo.lock generated
View file

@ -4699,6 +4699,7 @@ dependencies = [
"sqlx-rt 0.5.5",
"stringprep",
"thiserror",
"time 0.2.25",
"url",
"webpki",
"webpki-roots",
@ -5866,6 +5867,7 @@ dependencies = [
"surf",
"tide",
"tide-compress",
"time 0.2.25",
"toml 0.5.8",
"zed",
"zrpc",

View file

@ -31,6 +31,7 @@ sha-1 = "0.9"
surf = "2.2.0"
tide = "0.16.0"
tide-compress = "0.9.0"
time = "0.2"
toml = "0.5.8"
zrpc = { path = "../zrpc" }
@ -41,7 +42,7 @@ default-features = false
[dependencies.sqlx]
version = "0.5.2"
features = ["runtime-async-std-rustls", "postgres"]
features = ["runtime-async-std-rustls", "postgres", "time"]
[dev-dependencies]
gpui = { path = "../gpui" }

View file

@ -257,7 +257,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
// When signing in from the native app, generate a new access token for the current user. Return
// a redirect so that the user's browser sends this access token to the locally-running app.
if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
let access_token = create_access_token(request.db(), user.id()).await?;
let access_token = create_access_token(request.db(), user.id).await?;
let native_app_public_key =
zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
.context("failed to parse app public key")?;
@ -267,9 +267,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
return Ok(tide::Redirect::new(&format!(
"http://127.0.0.1:{}?user_id={}&access_token={}",
app_sign_in_params.native_app_port,
user.id().0,
encrypted_access_token,
app_sign_in_params.native_app_port, user.id.0, encrypted_access_token,
))
.into());
}

View file

@ -1,5 +1,6 @@
use serde::Serialize;
use sqlx::{FromRow, Result};
use time::OffsetDateTime;
pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
@ -8,14 +9,14 @@ pub struct Db(pub sqlx::PgPool);
#[derive(Debug, FromRow, Serialize)]
pub struct User {
id: i32,
pub id: UserId,
pub github_login: String,
pub admin: bool,
}
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
id: i32,
pub id: SignupId,
pub github_login: String,
pub email_address: String,
pub about: String,
@ -23,33 +24,18 @@ pub struct Signup {
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
id: i32,
pub id: ChannelId,
pub name: String,
}
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
id: i32,
sender_id: i32,
body: String,
sent_at: i64,
pub id: MessageId,
pub sender_id: UserId,
pub body: String,
pub sent_at: OffsetDateTime,
}
#[derive(Clone, Copy)]
pub struct UserId(pub i32);
#[derive(Clone, Copy)]
pub struct OrgId(pub i32);
#[derive(Clone, Copy)]
pub struct ChannelId(pub i32);
#[derive(Clone, Copy)]
pub struct SignupId(pub i32);
#[derive(Clone, Copy)]
pub struct MessageId(pub i32);
impl Db {
// signups
@ -108,6 +94,33 @@ impl Db {
sqlx::query_as(query).fetch_all(&self.0).await
}
pub async fn get_users_by_ids(
&self,
requester_id: UserId,
ids: impl Iterator<Item = UserId>,
) -> Result<Vec<User>> {
// Only return users that are in a common channel with the requesting user.
let query = "
SELECT users.*
FROM
users, channel_memberships
WHERE
users.id IN $1 AND
channel_memberships.user_id = users.id AND
channel_memberships.channel_id IN (
SELECT channel_id
FROM channel_memberships
WHERE channel_memberships.user_id = $2
)
";
sqlx::query_as(query)
.bind(&ids.map(|id| id.0).collect::<Vec<_>>())
.bind(requester_id)
.fetch_all(&self.0)
.await
}
pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
sqlx::query_as(query)
@ -147,7 +160,7 @@ impl Db {
VALUES ($1, $2)
";
sqlx::query(query)
.bind(user_id.0 as i32)
.bind(user_id.0)
.bind(access_token_hash)
.execute(&self.0)
.await
@ -156,8 +169,8 @@ impl Db {
pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
sqlx::query_scalar::<_, String>(query)
.bind(user_id.0 as i32)
sqlx::query_scalar(query)
.bind(user_id.0)
.fetch_all(&self.0)
.await
}
@ -180,14 +193,20 @@ impl Db {
}
#[cfg(test)]
pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
pub async fn add_org_member(
&self,
org_id: OrgId,
user_id: UserId,
is_admin: bool,
) -> Result<()> {
let query = "
INSERT INTO org_memberships (org_id, user_id)
VALUES ($1, $2)
INSERT INTO org_memberships (org_id, user_id, admin)
VALUES ($1, $2, $3)
";
sqlx::query(query)
.bind(org_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.0)
.await
.map(drop)
@ -272,16 +291,18 @@ impl Db {
channel_id: ChannelId,
sender_id: UserId,
body: &str,
timestamp: OffsetDateTime,
) -> Result<MessageId> {
let query = "
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
VALUES ($1, $2, $3, NOW()::timestamp)
VALUES ($1, $2, $3, $4)
RETURNING id
";
sqlx::query_scalar(query)
.bind(channel_id.0)
.bind(sender_id.0)
.bind(body)
.bind(timestamp)
.fetch_one(&self.0)
.await
.map(MessageId)
@ -292,12 +313,15 @@ impl Db {
channel_id: ChannelId,
count: usize,
) -> Result<Vec<ChannelMessage>> {
let query = "
SELECT id, sender_id, body, sent_at
FROM channel_messages
WHERE channel_id = $1
let query = r#"
SELECT
id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
FROM
channel_messages
WHERE
channel_id = $1
LIMIT $2
";
"#;
sqlx::query_as(query)
.bind(channel_id.0)
.bind(count as i64)
@ -314,14 +338,29 @@ impl std::ops::Deref for Db {
}
}
impl Channel {
pub fn id(&self) -> ChannelId {
ChannelId(self.id)
}
macro_rules! id_type {
($name:ident) => {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
#[sqlx(transparent)]
#[serde(transparent)]
pub struct $name(pub i32);
impl $name {
#[allow(unused)]
pub fn from_proto(value: u64) -> Self {
Self(value as i32)
}
#[allow(unused)]
pub fn to_proto(&self) -> u64 {
self.0 as u64
}
}
};
}
impl User {
pub fn id(&self) -> UserId {
UserId(self.id)
}
}
id_type!(UserId);
id_type!(OrgId);
id_type!(ChannelId);
id_type!(SignupId);
id_type!(MessageId);

View file

@ -23,6 +23,7 @@ use tide::{
http::headers::{HeaderName, CONNECTION, UPGRADE},
Request, Response,
};
use time::OffsetDateTime;
use zrpc::{
auth::random_token,
proto::{self, EnvelopedMessage},
@ -33,17 +34,19 @@ type ReplicaId = u16;
#[derive(Default)]
pub struct State {
connections: HashMap<ConnectionId, ConnectionState>,
pub worktrees: HashMap<u64, WorktreeState>,
connections: HashMap<ConnectionId, Connection>,
pub worktrees: HashMap<u64, Worktree>,
channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64,
}
struct ConnectionState {
struct Connection {
user_id: UserId,
worktrees: HashSet<u64>,
channels: HashSet<ChannelId>,
}
pub struct WorktreeState {
pub struct Worktree {
host_connection_id: Option<ConnectionId>,
guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
active_replica_ids: HashSet<ReplicaId>,
@ -52,7 +55,12 @@ pub struct WorktreeState {
entries: HashMap<u64, proto::Entry>,
}
impl WorktreeState {
#[derive(Default)]
struct Channel {
connection_ids: HashSet<ConnectionId>,
}
impl Worktree {
pub fn connection_ids(&self) -> Vec<ConnectionId> {
self.guest_connection_ids
.keys()
@ -68,14 +76,21 @@ impl WorktreeState {
}
}
impl Channel {
fn connection_ids(&self) -> Vec<ConnectionId> {
self.connection_ids.iter().copied().collect()
}
}
impl State {
// Add a new connection associated with a given user.
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
self.connections.insert(
connection_id,
ConnectionState {
Connection {
user_id,
worktrees: Default::default(),
channels: Default::default(),
},
);
}
@ -83,8 +98,13 @@ impl State {
// Remove the given connection and its association with any worktrees.
pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
let mut worktree_ids = Vec::new();
if let Some(connection_state) = self.connections.remove(&connection_id) {
for worktree_id in connection_state.worktrees {
if let Some(connection) = self.connections.remove(&connection_id) {
for channel_id in connection.channels {
if let Some(channel) = self.channels.get_mut(&channel_id) {
channel.connection_ids.remove(&connection_id);
}
}
for worktree_id in connection.worktrees {
if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
if worktree.host_connection_id == Some(connection_id) {
worktree_ids.push(worktree_id);
@ -100,28 +120,39 @@ impl State {
worktree_ids
}
fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.insert(channel_id);
self.channels
.entry(channel_id)
.or_default()
.connection_ids
.insert(connection_id);
}
}
// Add the given connection as a guest of the given worktree
pub fn join_worktree(
&mut self,
connection_id: ConnectionId,
worktree_id: u64,
access_token: &str,
) -> Option<(ReplicaId, &WorktreeState)> {
if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
if access_token == worktree_state.access_token {
if let Some(connection_state) = self.connections.get_mut(&connection_id) {
connection_state.worktrees.insert(worktree_id);
) -> Option<(ReplicaId, &Worktree)> {
if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
if access_token == worktree.access_token {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.worktrees.insert(worktree_id);
}
let mut replica_id = 1;
while worktree_state.active_replica_ids.contains(&replica_id) {
while worktree.active_replica_ids.contains(&replica_id) {
replica_id += 1;
}
worktree_state.active_replica_ids.insert(replica_id);
worktree_state
worktree.active_replica_ids.insert(replica_id);
worktree
.guest_connection_ids
.insert(connection_id, replica_id);
Some((replica_id, worktree_state))
Some((replica_id, worktree))
} else {
None
}
@ -142,7 +173,7 @@ impl State {
&self,
worktree_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&WorktreeState> {
) -> tide::Result<&Worktree> {
let worktree = self
.worktrees
.get(&worktree_id)
@ -165,7 +196,7 @@ impl State {
&mut self,
worktree_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&mut WorktreeState> {
) -> tide::Result<&mut Worktree> {
let worktree = self
.worktrees
.get_mut(&worktree_id)
@ -263,7 +294,9 @@ pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer
on_message(router, rpc, state, buffer_saved);
on_message(router, rpc, state, save_buffer);
on_message(router, rpc, state, get_channels);
on_message(router, rpc, state, get_users);
on_message(router, rpc, state, join_channel);
on_message(router, rpc, state, send_channel_message);
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
@ -373,7 +406,7 @@ async fn share_worktree(
.collect();
state.worktrees.insert(
worktree_id,
WorktreeState {
Worktree {
host_connection_id: Some(request.sender_id),
guest_connection_ids: Default::default(),
active_replica_ids: Default::default(),
@ -627,7 +660,7 @@ async fn get_channels(
channels: channels
.into_iter()
.map(|chan| proto::Channel {
id: chan.id().0 as u64,
id: chan.id.to_proto(),
name: chan.name,
})
.collect(),
@ -637,6 +670,34 @@ async fn get_channels(
Ok(())
}
async fn get_users(
request: TypedEnvelope<proto::GetUsers>,
rpc: &Arc<Peer>,
state: &Arc<AppState>,
) -> tide::Result<()> {
let user_id = state
.rpc
.read()
.await
.user_id_for_connection(request.sender_id)?;
let receipt = request.receipt();
let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
let users = state
.db
.get_users_by_ids(user_id, user_ids)
.await?
.into_iter()
.map(|user| proto::User {
id: user.id.to_proto(),
github_login: user.github_login,
avatar_url: String::new(),
})
.collect();
rpc.respond(receipt, proto::GetUsersResponse { users })
.await?;
Ok(())
}
async fn join_channel(
request: TypedEnvelope<proto::JoinChannel>,
rpc: &Arc<Peer>,
@ -647,14 +708,74 @@ async fn join_channel(
.read()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !state
.db
.can_user_access_channel(user_id, ChannelId(request.payload.channel_id as i32))
.can_user_access_channel(user_id, channel_id)
.await?
{
Err(anyhow!("access denied"))?;
}
state
.rpc
.write()
.await
.join_channel(request.sender_id, channel_id);
let messages = state
.db
.get_recent_channel_messages(channel_id, 50)
.await?
.into_iter()
.map(|msg| proto::ChannelMessage {
id: msg.id.to_proto(),
body: msg.body,
timestamp: msg.sent_at.unix_timestamp() as u64,
sender_id: msg.sender_id.to_proto(),
})
.collect();
rpc.respond(request.receipt(), proto::JoinChannelResponse { messages })
.await?;
Ok(())
}
async fn send_channel_message(
request: TypedEnvelope<proto::SendChannelMessage>,
peer: &Arc<Peer>,
app: &Arc<AppState>,
) -> tide::Result<()> {
let channel_id = ChannelId::from_proto(request.payload.channel_id);
let user_id;
let connection_ids;
{
let state = app.rpc.read().await;
user_id = state.user_id_for_connection(request.sender_id)?;
if let Some(channel) = state.channels.get(&channel_id) {
connection_ids = channel.connection_ids();
} else {
return Ok(());
}
}
let timestamp = OffsetDateTime::now_utc();
let message_id = app
.db
.create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
.await?;
let message = proto::ChannelMessageSent {
channel_id: channel_id.to_proto(),
message: Some(proto::ChannelMessage {
sender_id: user_id.to_proto(),
id: message_id.to_proto(),
body: request.payload.body,
timestamp: timestamp.unix_timestamp() as u64,
}),
};
broadcast(request.sender_id, connection_ids, |conn_id| {
peer.send(conn_id, message.clone())
})
.await?;
Ok(())
}

View file

@ -11,9 +11,10 @@ use rand::prelude::*;
use serde_json::json;
use sqlx::{
migrate::{MigrateDatabase, Migrator},
types::time::OffsetDateTime,
Executor as _, Postgres,
};
use std::{path::Path, sync::Arc};
use std::{path::Path, sync::Arc, time::SystemTime};
use zed::{
editor::Editor,
fs::{FakeFs, Fs as _},
@ -485,10 +486,15 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
let (user_id_b, client_b) = server.create_client(&mut cx_a, "user_b").await;
// Create a channel that includes these 2 users and 1 other user.
// Create an org that includes these 2 users and 1 other user.
let db = &server.app_state.db;
let user_id_c = db.create_user("user_c", false).await.unwrap();
let org_id = db.create_org("Test Org", "test-org").await.unwrap();
db.add_org_member(org_id, user_id_a, false).await.unwrap();
db.add_org_member(org_id, user_id_b, false).await.unwrap();
db.add_org_member(org_id, user_id_c, false).await.unwrap();
// Create a channel that includes all the users.
let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
db.add_channel_member(channel_id, user_id_a, false)
.await
@ -499,11 +505,21 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
db.add_channel_member(channel_id, user_id_c, false)
.await
.unwrap();
db.create_channel_message(channel_id, user_id_c, "first message!")
.await
.unwrap();
// let chatroom_a = ChatRoom::
db.create_channel_message(
channel_id,
user_id_c,
"first message!",
OffsetDateTime::now_utc(),
)
.await
.unwrap();
assert_eq!(
db.get_recent_channel_messages(channel_id, 50)
.await
.unwrap()[0]
.body,
"first message!"
);
}
struct TestServer {

View file

@ -24,10 +24,10 @@ message Envelope {
RemovePeer remove_peer = 19;
GetChannels get_channels = 20;
GetChannelsResponse get_channels_response = 21;
JoinChannel join_channel = 22;
JoinChannelResponse join_channel_response = 23;
GetUsers get_users = 24;
GetUsersResponse get_users_response = 25;
GetUsers get_users = 22;
GetUsersResponse get_users_response = 23;
JoinChannel join_channel = 24;
JoinChannelResponse join_channel_response = 25;
SendChannelMessage send_channel_message = 26;
ChannelMessageSent channel_message_sent = 27;
}