Make some db tests pass against the new sea-orm implementation

This commit is contained in:
Antonio Scandurra 2022-11-30 12:06:25 +01:00
parent b7294887c7
commit d9a892a423
9 changed files with 849 additions and 711 deletions

View file

@ -8,7 +8,7 @@ CREATE TABLE "users" (
"inviter_id" INTEGER REFERENCES users (id),
"connected_once" BOOLEAN NOT NULL DEFAULT false,
"created_at" TIMESTAMP NOT NULL DEFAULT now,
"metrics_id" VARCHAR(255),
"metrics_id" TEXT,
"github_user_id" INTEGER
);
CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");

View file

@ -18,6 +18,7 @@ use sea_orm::{
entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
TransactionTrait,
};
use sea_query::OnConflict;
use serde::{Deserialize, Serialize};
use sqlx::migrate::{Migrate, Migration, MigrationSource};
use sqlx::Connection;
@ -42,7 +43,7 @@ pub struct Database {
impl Database {
pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
let mut options = ConnectOptions::new(url.into());
options.min_connections(1).max_connections(max_connections);
options.max_connections(max_connections);
Ok(Self {
url: url.into(),
pool: sea_orm::Database::connect(options).await?,
@ -58,7 +59,7 @@ impl Database {
&self,
migrations_path: &Path,
ignore_checksum_mismatch: bool,
) -> anyhow::Result<Vec<(Migration, Duration)>> {
) -> anyhow::Result<(sqlx::AnyConnection, Vec<(Migration, Duration)>)> {
let migrations = MigrationSource::resolve(migrations_path)
.await
.map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
@ -92,11 +93,45 @@ impl Database {
}
}
Ok(new_migrations)
Ok((connection, new_migrations))
}
pub async fn create_user(
&self,
email_address: &str,
admin: bool,
params: NewUserParams,
) -> Result<NewUserResult> {
self.transact(|tx| async {
let user = user::Entity::insert(user::ActiveModel {
email_address: ActiveValue::set(Some(email_address.into())),
github_login: ActiveValue::set(params.github_login.clone()),
github_user_id: ActiveValue::set(Some(params.github_user_id)),
admin: ActiveValue::set(admin),
metrics_id: ActiveValue::set(Uuid::new_v4()),
..Default::default()
})
.on_conflict(
OnConflict::column(user::Column::GithubLogin)
.update_column(user::Column::GithubLogin)
.to_owned(),
)
.exec_with_returning(&tx)
.await?;
tx.commit().await?;
Ok(NewUserResult {
user_id: user.id,
metrics_id: user.metrics_id.to_string(),
signup_device_id: None,
inviting_user_id: None,
})
})
.await
}
pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
self.transact(|tx| async {
let tx = tx;
Ok(user::Entity::find()
@ -119,7 +154,7 @@ impl Database {
.one(&tx)
.await?
.ok_or_else(|| anyhow!("could not find participant"))?;
if participant.room_id != room_id.0 {
if participant.room_id != room_id {
return Err(anyhow!("shared project on unexpected room"))?;
}
@ -156,14 +191,14 @@ impl Database {
.await?;
let room = self.get_room(room_id, &tx).await?;
self.commit_room_transaction(room_id, tx, (ProjectId(project.id), room))
self.commit_room_transaction(room_id, tx, (project.id, room))
.await
})
.await
}
async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
let db_room = room::Entity::find_by_id(room_id.0)
let db_room = room::Entity::find_by_id(room_id)
.one(tx)
.await?
.ok_or_else(|| anyhow!("could not find room"))?;
@ -184,7 +219,7 @@ impl Database {
(Some(0), Some(project_id)) => {
Some(proto::participant_location::Variant::SharedProject(
proto::participant_location::SharedProject {
id: project_id as u64,
id: project_id.to_proto(),
},
))
}
@ -198,7 +233,7 @@ impl Database {
participants.insert(
answering_connection_id,
proto::Participant {
user_id: db_participant.user_id as u64,
user_id: db_participant.user_id.to_proto(),
peer_id: answering_connection_id as u32,
projects: Default::default(),
location: Some(proto::ParticipantLocation { variant: location }),
@ -206,9 +241,9 @@ impl Database {
);
} else {
pending_participants.push(proto::PendingParticipant {
user_id: db_participant.user_id as u64,
calling_user_id: db_participant.calling_user_id as u64,
initial_project_id: db_participant.initial_project_id.map(|id| id as u64),
user_id: db_participant.user_id.to_proto(),
calling_user_id: db_participant.calling_user_id.to_proto(),
initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
});
}
}
@ -225,12 +260,12 @@ impl Database {
let project = if let Some(project) = participant
.projects
.iter_mut()
.find(|project| project.id as i32 == db_project.id)
.find(|project| project.id == db_project.id.to_proto())
{
project
} else {
participant.projects.push(proto::ParticipantProject {
id: db_project.id as u64,
id: db_project.id.to_proto(),
worktree_root_names: Default::default(),
});
participant.projects.last_mut().unwrap()
@ -243,7 +278,7 @@ impl Database {
}
Ok(proto::Room {
id: db_room.id as u64,
id: db_room.id.to_proto(),
live_kit_room: db_room.live_kit_room,
participants: participants.into_values().collect(),
pending_participants,
@ -393,6 +428,84 @@ macro_rules! id_type {
self.0.fmt(f)
}
}
impl From<$name> for sea_query::Value {
fn from(value: $name) -> Self {
sea_query::Value::Int(Some(value.0))
}
}
impl sea_orm::TryGetable for $name {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
Ok(Self(i32::try_get(res, pre, col)?))
}
}
impl sea_query::ValueType for $name {
fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
match v {
Value::TinyInt(Some(int)) => {
Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
}
Value::SmallInt(Some(int)) => {
Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
}
Value::Int(Some(int)) => {
Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
}
Value::BigInt(Some(int)) => {
Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
}
Value::TinyUnsigned(Some(int)) => {
Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
}
Value::SmallUnsigned(Some(int)) => {
Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
}
Value::Unsigned(Some(int)) => {
Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
}
Value::BigUnsigned(Some(int)) => {
Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
}
_ => Err(sea_query::ValueTypeErr),
}
}
fn type_name() -> String {
stringify!($name).into()
}
fn array_type() -> sea_query::ArrayType {
sea_query::ArrayType::Int
}
fn column_type() -> sea_query::ColumnType {
sea_query::ColumnType::Integer(None)
}
}
impl sea_orm::TryFromU64 for $name {
fn try_from_u64(n: u64) -> Result<Self, DbErr> {
Ok(Self(n.try_into().map_err(|_| {
DbErr::ConvertFromU64(concat!(
"error converting ",
stringify!($name),
" to u64"
))
})?))
}
}
impl sea_query::Nullable for $name {
fn null() -> Value {
Value::Int(None)
}
}
};
}
@ -400,6 +513,7 @@ id_type!(UserId);
id_type!(RoomId);
id_type!(RoomParticipantId);
id_type!(ProjectId);
id_type!(ProjectCollaboratorId);
id_type!(WorktreeId);
#[cfg(test)]
@ -412,17 +526,18 @@ mod test {
use lazy_static::lazy_static;
use parking_lot::Mutex;
use rand::prelude::*;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
use std::sync::Arc;
pub struct TestDb {
pub db: Option<Arc<Database>>,
pub connection: Option<sqlx::AnyConnection>,
}
impl TestDb {
pub fn sqlite(background: Arc<Background>) -> Self {
let mut rng = StdRng::from_entropy();
let url = format!("sqlite://file:zed-test-{}?mode=memory", rng.gen::<u128>());
let url = format!("sqlite::memory:");
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
@ -431,8 +546,17 @@ mod test {
let mut db = runtime.block_on(async {
let db = Database::new(&url, 5).await.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
db.migrate(migrations_path.as_ref(), false).await.unwrap();
let sql = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/migrations.sqlite/20221109000000_test_schema.sql"
));
db.pool
.execute(sea_orm::Statement::from_string(
db.pool.get_database_backend(),
sql.into(),
))
.await
.unwrap();
db
});
@ -441,6 +565,7 @@ mod test {
Self {
db: Some(Arc::new(db)),
connection: None,
}
}
@ -476,6 +601,7 @@ mod test {
Self {
db: Some(Arc::new(db)),
connection: None,
}
}

View file

@ -1,12 +1,13 @@
use super::{ProjectId, RoomId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "projects")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub room_id: i32,
pub host_user_id: i32,
pub id: ProjectId,
pub room_id: RoomId,
pub host_user_id: UserId,
pub host_connection_id: i32,
}

View file

@ -1,13 +1,14 @@
use super::{ProjectCollaboratorId, ProjectId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "project_collaborators")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub project_id: i32,
pub id: ProjectCollaboratorId,
pub project_id: ProjectId,
pub connection_id: i32,
pub user_id: i32,
pub user_id: UserId,
pub replica_id: i32,
pub is_host: bool,
}

View file

@ -1,10 +1,11 @@
use super::RoomId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "room_participants")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub id: RoomId,
pub live_kit_room: String,
}

View file

@ -1,17 +1,18 @@
use super::{ProjectId, RoomId, RoomParticipantId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "room_participants")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub room_id: i32,
pub user_id: i32,
pub id: RoomParticipantId,
pub room_id: RoomId,
pub user_id: UserId,
pub answering_connection_id: Option<i32>,
pub location_kind: Option<i32>,
pub location_project_id: Option<i32>,
pub initial_project_id: Option<i32>,
pub calling_user_id: i32,
pub location_project_id: Option<ProjectId>,
pub initial_project_id: Option<ProjectId>,
pub calling_user_id: UserId,
pub calling_connection_id: i32,
}

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,7 @@
use super::UserId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "users")]
pub struct Model {
#[sea_orm(primary_key)]
@ -13,6 +13,7 @@ pub struct Model {
pub invite_code: Option<String>,
pub invite_count: i32,
pub connected_once: bool,
pub metrics_id: Uuid,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View file

@ -1,12 +1,14 @@
use sea_orm::entity::prelude::*;
use super::ProjectId;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "worktrees")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
#[sea_orm(primary_key)]
pub project_id: i32,
pub project_id: ProjectId,
pub abs_path: String,
pub root_name: String,
pub visible: bool,