From 7f6d0919c9828a76a98a60cd348af86adfaaaa83 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 6 Aug 2024 17:18:08 -0400 Subject: [PATCH] collab: Set up database for LLM service (#15882) This PR puts the initial infrastructure for the LLM service's database in place. The LLM service will be using a separate Postgres database, with its own set of migrations. Currently we only connect to the database in development, as we don't yet have the database set up for the staging/production environments. Release Notes: - N/A --- Dockerfile | 2 + crates/collab/.env.toml | 2 + .../20240806182921_test_schema.sql | 16 ++ ...0806182921_create_providers_and_models.sql | 16 ++ crates/collab/src/db.rs | 53 +------ crates/collab/src/db/ids.rs | 1 + crates/collab/src/db/tests.rs | 6 +- crates/collab/src/lib.rs | 7 + crates/collab/src/llm.rs | 25 ++- crates/collab/src/llm/db.rs | 118 ++++++++++++++ crates/collab/src/llm/db/ids.rs | 7 + crates/collab/src/llm/db/queries.rs | 3 + crates/collab/src/llm/db/queries/providers.rs | 67 ++++++++ crates/collab/src/llm/db/tables.rs | 2 + crates/collab/src/llm/db/tables/model.rs | 31 ++++ crates/collab/src/llm/db/tables/provider.rs | 26 ++++ crates/collab/src/llm/db/tests.rs | 147 ++++++++++++++++++ .../collab/src/llm/db/tests/provider_tests.rs | 30 ++++ crates/collab/src/main.rs | 53 ++++++- crates/collab/src/migrations.rs | 49 ++++++ crates/collab/src/tests/test_server.rs | 3 + docker-compose.sql | 1 + script/bootstrap | 18 ++- script/reset_db | 1 + script/sqlx | 17 -- 25 files changed, 627 insertions(+), 74 deletions(-) create mode 100644 crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql create mode 100644 crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql create mode 100644 crates/collab/src/llm/db.rs create mode 100644 crates/collab/src/llm/db/ids.rs create mode 100644 crates/collab/src/llm/db/queries.rs create mode 100644 crates/collab/src/llm/db/queries/providers.rs create mode 100644 
crates/collab/src/llm/db/tables.rs create mode 100644 crates/collab/src/llm/db/tables/model.rs create mode 100644 crates/collab/src/llm/db/tables/provider.rs create mode 100644 crates/collab/src/llm/db/tests.rs create mode 100644 crates/collab/src/llm/db/tests/provider_tests.rs create mode 100644 crates/collab/src/migrations.rs delete mode 100755 script/sqlx diff --git a/Dockerfile b/Dockerfile index 7d4b6b77c9..03f63aa39c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,5 +27,7 @@ RUN apt-get update; \ WORKDIR app COPY --from=builder /app/collab /app/collab COPY --from=builder /app/crates/collab/migrations /app/migrations +COPY --from=builder /app/crates/collab/migrations_llm /app/migrations_llm ENV MIGRATIONS_PATH=/app/migrations +ENV LLM_DATABASE_MIGRATIONS_PATH=/app/migrations_llm ENTRYPOINT ["/app/collab"] diff --git a/crates/collab/.env.toml b/crates/collab/.env.toml index 9646d0c921..f542422e95 100644 --- a/crates/collab/.env.toml +++ b/crates/collab/.env.toml @@ -15,6 +15,8 @@ BLOB_STORE_URL = "http://127.0.0.1:9000" BLOB_STORE_REGION = "the-region" ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed" SEED_PATH = "crates/collab/seed.default.json" +LLM_DATABASE_URL = "postgres://postgres@localhost/zed_llm" +LLM_DATABASE_MAX_CONNECTIONS = 5 LLM_API_SECRET = "llm-secret" # CLICKHOUSE_URL = "" diff --git a/crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql b/crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql new file mode 100644 index 0000000000..7b6feb0302 --- /dev/null +++ b/crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql @@ -0,0 +1,16 @@ +create table providers ( + id integer primary key autoincrement, + name text not null +); + +create unique index uix_providers_on_name on providers (name); + +create table models ( + id integer primary key autoincrement, + provider_id integer not null references providers (id) on delete cascade, + name text not null +); + +create unique index uix_models_on_provider_id_name 
on models (provider_id, name); +create index ix_models_on_provider_id on models (provider_id); +create index ix_models_on_name on models (name); diff --git a/crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql b/crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql new file mode 100644 index 0000000000..059e6059dc --- /dev/null +++ b/crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql @@ -0,0 +1,16 @@ +create table if not exists providers ( + id serial primary key, + name text not null +); + +create unique index uix_providers_on_name on providers (name); + +create table if not exists models ( + id serial primary key, + provider_id integer not null references providers (id) on delete cascade, + name text not null +); + +create unique index uix_models_on_provider_id_name on models (provider_id, name); +create index ix_models_on_provider_id on models (provider_id); +create index ix_models_on_name on models (name); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 5250bef6df..db45b48e2b 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -23,17 +23,12 @@ use sea_orm::{ }; use semantic_version::SemanticVersion; use serde::{Deserialize, Serialize}; -use sqlx::{ - migrate::{Migrate, Migration, MigrationSource}, - Connection, -}; use std::ops::RangeInclusive; use std::{ fmt::Write as _, future::Future, marker::PhantomData, ops::{Deref, DerefMut}, - path::Path, rc::Rc, sync::Arc, time::Duration, @@ -90,54 +85,16 @@ impl Database { }) } + pub fn options(&self) -> &ConnectOptions { + &self.options + } + #[cfg(test)] pub fn reset(&self) { self.rooms.clear(); self.projects.clear(); } - /// Runs the database migrations. 
- pub async fn migrate( - &self, - migrations_path: &Path, - ignore_checksum_mismatch: bool, - ) -> anyhow::Result> { - let migrations = MigrationSource::resolve(migrations_path) - .await - .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - - let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?; - - connection.ensure_migrations_table().await?; - let applied_migrations: HashMap<_, _> = connection - .list_applied_migrations() - .await? - .into_iter() - .map(|m| (m.version, m)) - .collect(); - - let mut new_migrations = Vec::new(); - for migration in migrations { - match applied_migrations.get(&migration.version) { - Some(applied_migration) => { - if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch - { - Err(anyhow!( - "checksum mismatch for applied migration {}", - migration.description - ))?; - } - } - None => { - let elapsed = connection.apply(&migration).await?; - new_migrations.push((migration, elapsed)); - } - } - } - - Ok(new_migrations) - } - /// Transaction runs things in a transaction. If you want to call other methods /// and pass the transaction around you need to reborrow the transaction at each /// call site with: `&*tx`. @@ -453,7 +410,7 @@ fn is_serialization_error(error: &Error) -> bool { } /// A handle to a [`DatabaseTransaction`]. -pub struct TransactionHandle(Arc>); +pub struct TransactionHandle(pub(crate) Arc>); impl Deref for TransactionHandle { type Target = DatabaseTransaction; diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 21206e377e..4443938844 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -3,6 +3,7 @@ use rpc::proto; use sea_orm::{entity::prelude::*, DbErr}; use serde::{Deserialize, Serialize}; +#[macro_export] macro_rules! 
id_type { ($name:ident) => { #[derive( diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 651bdaf624..c570e87aa6 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -11,6 +11,8 @@ mod feature_flag_tests; mod message_tests; mod processed_stripe_event_tests; +use crate::migrations::run_database_migrations; + use super::*; use gpui::BackgroundExecutor; use parking_lot::Mutex; @@ -91,7 +93,9 @@ impl TestDb { .await .unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); - db.migrate(Path::new(migrations_path), false).await.unwrap(); + run_database_migrations(db.options(), migrations_path, false) + .await + .unwrap(); db.initialize_notification_kinds().await.unwrap(); db }); diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index a265e11dda..a795b0e6ba 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -4,6 +4,7 @@ pub mod db; pub mod env; pub mod executor; pub mod llm; +pub mod migrations; mod rate_limiter; pub mod rpc; pub mod seed; @@ -150,6 +151,9 @@ pub struct Config { pub live_kit_server: Option, pub live_kit_key: Option, pub live_kit_secret: Option, + pub llm_database_url: Option, + pub llm_database_max_connections: Option, + pub llm_database_migrations_path: Option, pub llm_api_secret: Option, pub rust_log: Option, pub log_json: Option, @@ -197,6 +201,9 @@ impl Config { live_kit_server: None, live_kit_key: None, live_kit_secret: None, + llm_database_url: None, + llm_database_max_connections: None, + llm_database_migrations_path: None, llm_api_secret: None, rust_log: None, log_json: None, diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index bde9f87e12..4f11351695 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,10 +1,12 @@ mod authorization; +pub mod db; mod token; use crate::api::CloudflareIpCountryHeader; use crate::llm::authorization::authorize_access_to_language_model; +use 
crate::llm::db::LlmDatabase; use crate::{executor::Executor, Config, Error, Result}; -use anyhow::Context as _; +use anyhow::{anyhow, Context as _}; use axum::TypedHeader; use axum::{ body::Body, @@ -24,11 +26,31 @@ pub use token::*; pub struct LlmState { pub config: Config, pub executor: Executor, + pub db: Option>, pub http_client: IsahcHttpClient, } impl LlmState { pub async fn new(config: Config, executor: Executor) -> Result> { + // TODO: This is temporary until we have the LLM database stood up. + let db = if config.is_development() { + let database_url = config + .llm_database_url + .as_ref() + .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?; + let max_connections = config + .llm_database_max_connections + .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?; + + let mut db_options = db::ConnectOptions::new(database_url); + db_options.max_connections(max_connections); + let db = LlmDatabase::new(db_options, executor.clone()).await?; + + Some(Arc::new(db)) + } else { + None + }; + let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION")); let http_client = IsahcHttpClient::builder() .default_header("User-Agent", user_agent) @@ -38,6 +60,7 @@ impl LlmState { let this = Self { config, executor, + db, http_client, }; diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs new file mode 100644 index 0000000000..c0eff088bd --- /dev/null +++ b/crates/collab/src/llm/db.rs @@ -0,0 +1,118 @@ +mod ids; +mod queries; +mod tables; + +#[cfg(test)] +mod tests; + +pub use ids::*; +pub use tables::*; + +#[cfg(test)] +pub use tests::TestLlmDb; + +use std::future::Future; +use std::sync::Arc; + +use anyhow::anyhow; +use sea_orm::prelude::*; +pub use sea_orm::ConnectOptions; +use sea_orm::{ + ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait, +}; + +use crate::db::TransactionHandle; +use crate::executor::Executor; +use crate::Result; + +/// The database for the LLM service. 
+pub struct LlmDatabase { + options: ConnectOptions, + pool: DatabaseConnection, + #[allow(unused)] + executor: Executor, + #[cfg(test)] + runtime: Option, +} + +impl LlmDatabase { + /// Connects to the database with the given options + pub async fn new(options: ConnectOptions, executor: Executor) -> Result { + sqlx::any::install_default_drivers(); + Ok(Self { + options: options.clone(), + pool: sea_orm::Database::connect(options).await?, + executor, + #[cfg(test)] + runtime: None, + }) + } + + pub fn options(&self) -> &ConnectOptions { + &self.options + } + + pub async fn transaction(&self, f: F) -> Result + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let body = async { + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(result) => match tx.commit().await.map_err(Into::into) { + Ok(()) => return Ok(result), + Err(error) => { + return Err(error); + } + }, + Err(error) => { + tx.rollback().await?; + return Err(error); + } + } + }; + + self.run(body).await + } + + async fn with_transaction(&self, f: &F) -> Result<(DatabaseTransaction, Result)> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let tx = self + .pool + .begin_with_config(Some(IsolationLevel::ReadCommitted), None) + .await?; + + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else { + return Err(anyhow!( + "couldn't complete transaction because it's still in use" + ))?; + }; + + Ok((tx, result)) + } + + async fn run(&self, future: F) -> Result + where + F: Future>, + { + #[cfg(test)] + { + if let Executor::Deterministic(executor) = &self.executor { + executor.simulate_random_delay().await; + } + + self.runtime.as_ref().unwrap().block_on(future) + } + + #[cfg(not(test))] + { + future.await + } + } +} diff --git a/crates/collab/src/llm/db/ids.rs b/crates/collab/src/llm/db/ids.rs new file mode 100644 index 
0000000000..d4613e9c7f --- /dev/null +++ b/crates/collab/src/llm/db/ids.rs @@ -0,0 +1,7 @@ +use sea_orm::{entity::prelude::*, DbErr}; +use serde::{Deserialize, Serialize}; + +use crate::id_type; + +id_type!(ProviderId); +id_type!(ModelId); diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs new file mode 100644 index 0000000000..3e02c17a6a --- /dev/null +++ b/crates/collab/src/llm/db/queries.rs @@ -0,0 +1,3 @@ +use super::*; + +pub mod providers; diff --git a/crates/collab/src/llm/db/queries/providers.rs b/crates/collab/src/llm/db/queries/providers.rs new file mode 100644 index 0000000000..d96f8453e2 --- /dev/null +++ b/crates/collab/src/llm/db/queries/providers.rs @@ -0,0 +1,67 @@ +use sea_orm::sea_query::OnConflict; +use sea_orm::QueryOrder; + +use super::*; + +impl LlmDatabase { + pub async fn initialize_providers(&self) -> Result<()> { + self.transaction(|tx| async move { + let providers_and_models = vec![ + ("anthropic", "claude-3-5-sonnet"), + ("anthropic", "claude-3-opus"), + ("anthropic", "claude-3-sonnet"), + ("anthropic", "claude-3-haiku"), + ]; + + for (provider_name, model_name) in providers_and_models { + let insert_provider = provider::Entity::insert(provider::ActiveModel { + name: ActiveValue::set(provider_name.to_owned()), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([provider::Column::Name]) + .update_column(provider::Column::Name) + .to_owned(), + ); + + let provider = if tx.support_returning() { + insert_provider.exec_with_returning(&*tx).await? + } else { + insert_provider.exec_without_returning(&*tx).await?; + provider::Entity::find() + .filter(provider::Column::Name.eq(provider_name)) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("failed to insert provider"))? 
+ }; + + model::Entity::insert(model::ActiveModel { + provider_id: ActiveValue::set(provider.id), + name: ActiveValue::set(model_name.to_owned()), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([model::Column::ProviderId, model::Column::Name]) + .update_column(model::Column::Name) + .to_owned(), + ) + .exec_without_returning(&*tx) + .await?; + } + + Ok(()) + }) + .await + } + + /// Returns the list of LLM providers. + pub async fn list_providers(&self) -> Result> { + self.transaction(|tx| async move { + Ok(provider::Entity::find() + .order_by_asc(provider::Column::Name) + .all(&*tx) + .await?) + }) + .await + } +} diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs new file mode 100644 index 0000000000..89b42283de --- /dev/null +++ b/crates/collab/src/llm/db/tables.rs @@ -0,0 +1,2 @@ +pub mod model; +pub mod provider; diff --git a/crates/collab/src/llm/db/tables/model.rs b/crates/collab/src/llm/db/tables/model.rs new file mode 100644 index 0000000000..7242365acf --- /dev/null +++ b/crates/collab/src/llm/db/tables/model.rs @@ -0,0 +1,31 @@ +use sea_orm::entity::prelude::*; + +use crate::llm::db::{ModelId, ProviderId}; + +/// An LLM model. 
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "models")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ModelId, + pub provider_id: ProviderId, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::provider::Entity", + from = "Column::ProviderId", + to = "super::provider::Column::Id" + )] + Provider, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Provider.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/provider.rs b/crates/collab/src/llm/db/tables/provider.rs new file mode 100644 index 0000000000..7f9aa8ee0d --- /dev/null +++ b/crates/collab/src/llm/db/tables/provider.rs @@ -0,0 +1,26 @@ +use sea_orm::entity::prelude::*; + +use crate::llm::db::ProviderId; + +/// An LLM provider. +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "providers")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ProviderId, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::model::Entity")] + Models, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Models.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tests.rs b/crates/collab/src/llm/db/tests.rs new file mode 100644 index 0000000000..4fadc02cad --- /dev/null +++ b/crates/collab/src/llm/db/tests.rs @@ -0,0 +1,147 @@ +mod provider_tests; + +use gpui::BackgroundExecutor; +use parking_lot::Mutex; +use rand::prelude::*; +use sea_orm::ConnectionTrait; +use sqlx::migrate::MigrateDatabase; +use std::sync::Arc; +use std::time::Duration; + +use crate::migrations::run_database_migrations; + +use super::*; + +pub struct TestLlmDb { + pub db: Option>, + pub connection: Option, +} + +impl TestLlmDb { + pub fn sqlite(background: BackgroundExecutor) -> Self { + 
let url = "sqlite::memory:"; + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let mut db = runtime.block_on(async { + let mut options = ConnectOptions::new(url); + options.max_connections(5); + let db = LlmDatabase::new(options, Executor::Deterministic(background)) + .await + .unwrap(); + let sql = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations_llm.sqlite/20240806182921_test_schema.sql" + )); + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + sql, + )) + .await + .unwrap(); + db + }); + + db.runtime = Some(runtime); + + Self { + db: Some(Arc::new(db)), + connection: None, + } + } + + pub fn postgres(background: BackgroundExecutor) -> Self { + static LOCK: Mutex<()> = Mutex::new(()); + + let _guard = LOCK.lock(); + let mut rng = StdRng::from_entropy(); + let url = format!( + "postgres://postgres@localhost/zed-llm-test-{}", + rng.gen::() + ); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let mut db = runtime.block_on(async { + sqlx::Postgres::create_database(&url) + .await + .expect("failed to create test db"); + let mut options = ConnectOptions::new(url); + options + .max_connections(5) + .idle_timeout(Duration::from_secs(0)); + let db = LlmDatabase::new(options, Executor::Deterministic(background)) + .await + .unwrap(); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm"); + run_database_migrations(db.options(), migrations_path, false) + .await + .unwrap(); + db + }); + + db.runtime = Some(runtime); + + Self { + db: Some(Arc::new(db)), + connection: None, + } + } + + pub fn db(&self) -> &Arc { + self.db.as_ref().unwrap() + } +} + +#[macro_export] +macro_rules! 
test_both_llm_dbs { + ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => { + #[cfg(target_os = "macos")] + #[gpui::test] + async fn $postgres_test_name(cx: &mut gpui::TestAppContext) { + let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone()); + $test_name(test_db.db()).await; + } + + #[gpui::test] + async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) { + let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone()); + $test_name(test_db.db()).await; + } + }; +} + +impl Drop for TestLlmDb { + fn drop(&mut self) { + let db = self.db.take().unwrap(); + if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { + db.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE + pg_stat_activity.datname = current_database() AND + pid <> pg_backend_pid(); + "; + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + query, + )) + .await + .log_err(); + sqlx::Postgres::drop_database(db.options.get_url()) + .await + .log_err(); + }) + } + } +} diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs new file mode 100644 index 0000000000..2b0d692bb8 --- /dev/null +++ b/crates/collab/src/llm/db/tests/provider_tests.rs @@ -0,0 +1,30 @@ +use std::sync::Arc; + +use pretty_assertions::assert_eq; + +use crate::llm::db::LlmDatabase; +use crate::test_both_llm_dbs; + +test_both_llm_dbs!( + test_initialize_providers, + test_initialize_providers_postgres, + test_initialize_providers_sqlite +); + +async fn test_initialize_providers(db: &Arc) { + let initial_providers = db.list_providers().await.unwrap(); + assert_eq!(initial_providers, vec![]); + + db.initialize_providers().await.unwrap(); + + // Do it twice, to make sure the operation is idempotent. 
+ db.initialize_providers().await.unwrap(); + + let providers = db.list_providers().await.unwrap(); + + let provider_names = providers + .into_iter() + .map(|provider| provider.name) + .collect::>(); + assert_eq!(provider_names, vec!["anthropic".to_string()]); +} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 60a6967ca2..ebf6f9bb2d 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -5,6 +5,8 @@ use axum::{ routing::get, Extension, Router, }; +use collab::llm::db::LlmDatabase; +use collab::migrations::run_database_migrations; use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode}; use collab::{ api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, @@ -45,7 +47,7 @@ async fn main() -> Result<()> { } Some("migrate") => { let config = envy::from_env::().expect("error loading config"); - run_migrations(&config).await?; + setup_app_database(&config).await?; } Some("seed") => { let config = envy::from_env::().expect("error loading config"); @@ -81,6 +83,8 @@ async fn main() -> Result<()> { let mut on_shutdown = None; if mode.is_llm() { + setup_llm_database(&config).await?; + let state = LlmState::new(config.clone(), Executor::Production).await?; app = app @@ -89,7 +93,7 @@ async fn main() -> Result<()> { } if mode.is_collab() || mode.is_api() { - run_migrations(&config).await?; + setup_app_database(&config).await?; let state = AppState::new(config, Executor::Production).await?; @@ -203,7 +207,7 @@ async fn main() -> Result<()> { Ok(()) } -async fn run_migrations(config: &Config) -> Result<()> { +async fn setup_app_database(config: &Config) -> Result<()> { let db_options = db::ConnectOptions::new(config.database_url.clone()); let mut db = Database::new(db_options, Executor::Production).await?; @@ -216,7 +220,7 @@ async fn run_migrations(config: &Config) -> Result<()> { Path::new(default_migrations) }); - let migrations = db.migrate(&migrations_path, false).await?; + 
let migrations = run_database_migrations(db.options(), migrations_path, false).await?; for (migration, duration) in migrations { log::info!( "Migrated {} {} {:?}", @@ -232,7 +236,46 @@ async fn run_migrations(config: &Config) -> Result<()> { collab::seed::seed(&config, &db, false).await?; } - return Ok(()); + Ok(()) +} + +async fn setup_llm_database(config: &Config) -> Result<()> { + // TODO: This is temporary until we have the LLM database stood up. + if !config.is_development() { + return Ok(()); + } + + let database_url = config + .llm_database_url + .as_ref() + .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?; + + let db_options = db::ConnectOptions::new(database_url.clone()); + let db = LlmDatabase::new(db_options, Executor::Production).await?; + + let migrations_path = config + .llm_database_migrations_path + .as_deref() + .unwrap_or_else(|| { + #[cfg(feature = "sqlite")] + let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite"); + #[cfg(not(feature = "sqlite"))] + let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm"); + + Path::new(default_migrations) + }); + + let migrations = run_database_migrations(db.options(), migrations_path, false).await?; + for (migration, duration) in migrations { + log::info!( + "Migrated {} {} {:?}", + migration.version, + migration.description, + duration + ); + } + + Ok(()) } async fn handle_root(Extension(mode): Extension) -> String { diff --git a/crates/collab/src/migrations.rs b/crates/collab/src/migrations.rs new file mode 100644 index 0000000000..8887a4fb3e --- /dev/null +++ b/crates/collab/src/migrations.rs @@ -0,0 +1,49 @@ +use std::path::Path; +use std::time::Duration; + +use anyhow::{anyhow, Result}; +use collections::HashMap; +use sea_orm::ConnectOptions; +use sqlx::migrate::{Migrate, Migration, MigrationSource}; +use sqlx::Connection; + +/// Runs the database migrations for the specified database. 
+pub async fn run_database_migrations( + database_options: &ConnectOptions, + migrations_path: impl AsRef, + ignore_checksum_mismatch: bool, +) -> Result> { + let migrations = MigrationSource::resolve(migrations_path.as_ref()) + .await + .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; + + let mut connection = sqlx::AnyConnection::connect(database_options.get_url()).await?; + + connection.ensure_migrations_table().await?; + let applied_migrations: HashMap<_, _> = connection + .list_applied_migrations() + .await? + .into_iter() + .map(|migration| (migration.version, migration)) + .collect(); + + let mut new_migrations = Vec::new(); + for migration in migrations { + match applied_migrations.get(&migration.version) { + Some(applied_migration) => { + if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch { + Err(anyhow!( + "checksum mismatch for applied migration {}", + migration.description + ))?; + } + } + None => { + let elapsed = connection.apply(&migration).await?; + new_migrations.push((migration, elapsed)); + } + } + } + + Ok(new_migrations) +} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index bf8b031e5e..b420960122 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -651,6 +651,9 @@ impl TestServer { live_kit_server: None, live_kit_key: None, live_kit_secret: None, + llm_database_url: None, + llm_database_max_connections: None, + llm_database_migrations_path: None, llm_api_secret: None, rust_log: None, log_json: None, diff --git a/docker-compose.sql b/docker-compose.sql index 9cbd0bf0d1..7de55a1f98 100644 --- a/docker-compose.sql +++ b/docker-compose.sql @@ -1 +1,2 @@ create database zed; +create database zed_llm; diff --git a/script/bootstrap b/script/bootstrap index 0c91bfbc20..396e3f6c00 100755 --- a/script/bootstrap +++ b/script/bootstrap @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + if [[ "$OSTYPE" == "linux-gnu"* ]]; 
then echo "Linux dependencies..." script/linux @@ -8,5 +10,17 @@ else which foreman > /dev/null || brew install foreman fi -echo "creating database..." -script/sqlx database create +# Install sqlx-cli if needed +if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then + echo "sqlx-cli not found or not the required version, installing version 0.5.7..." + cargo install sqlx-cli --version 0.5.7 +fi + +cd crates/collab + +# Export contents of .env.toml +eval "$(cargo run --bin dotenv)" + +echo "creating databases..." +sqlx database create --database-url "$DATABASE_URL" +sqlx database create --database-url "$LLM_DATABASE_URL" diff --git a/script/reset_db b/script/reset_db index 87ce786aa7..27697bc2a0 100755 --- a/script/reset_db +++ b/script/reset_db @@ -1,2 +1,3 @@ psql -c "DROP DATABASE zed (FORCE);" +psql -c "DROP DATABASE zed_llm (FORCE);" script/bootstrap diff --git a/script/sqlx b/script/sqlx deleted file mode 100755 index 038218c465..0000000000 --- a/script/sqlx +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -set -e - -# Install sqlx-cli if needed -if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then - echo "sqlx-cli not found or not the required version, installing version 0.5.7..." - cargo install sqlx-cli --version 0.5.7 -fi - -cd crates/collab - -# Export contents of .env.toml -eval "$(cargo run --bin dotenv)" - -# Run sqlx command -sqlx $@