mirror of
https://github.com/zed-industries/zed.git
synced 2024-10-22 22:46:27 +00:00
collab: Setup database for LLM service (#15882)
This PR puts the initial infrastructure for the LLM service's database in place. The LLM service will be using a separate Postgres database, with its own set of migrations. Currently we only connect to the database in development, as we don't yet have the database setup for the staging/production environments. Release Notes: - N/A
This commit is contained in:
parent
a64906779b
commit
7f6d0919c9
25 changed files with 627 additions and 74 deletions
|
@ -27,5 +27,7 @@ RUN apt-get update; \
|
|||
WORKDIR app
|
||||
COPY --from=builder /app/collab /app/collab
|
||||
COPY --from=builder /app/crates/collab/migrations /app/migrations
|
||||
COPY --from=builder /app/crates/collab/migrations_llm /app/migrations_llm
|
||||
ENV MIGRATIONS_PATH=/app/migrations
|
||||
ENV LLM_DATABASE_MIGRATIONS_PATH=/app/migrations_llm
|
||||
ENTRYPOINT ["/app/collab"]
|
||||
|
|
|
@ -15,6 +15,8 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
|
|||
BLOB_STORE_REGION = "the-region"
|
||||
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
|
||||
SEED_PATH = "crates/collab/seed.default.json"
|
||||
LLM_DATABASE_URL = "postgres://postgres@localhost/zed_llm"
|
||||
LLM_DATABASE_MAX_CONNECTIONS = 5
|
||||
LLM_API_SECRET = "llm-secret"
|
||||
|
||||
# CLICKHOUSE_URL = ""
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
create table providers (
|
||||
id integer primary key autoincrement,
|
||||
name text not null
|
||||
);
|
||||
|
||||
create unique index uix_providers_on_name on providers (name);
|
||||
|
||||
create table models (
|
||||
id integer primary key autoincrement,
|
||||
provider_id integer not null references providers (id) on delete cascade,
|
||||
name text not null
|
||||
);
|
||||
|
||||
create unique index uix_models_on_provider_id_name on models (provider_id, name);
|
||||
create index ix_models_on_provider_id on models (provider_id);
|
||||
create index ix_models_on_name on models (name);
|
|
@ -0,0 +1,16 @@
|
|||
create table if not exists providers (
|
||||
id serial primary key,
|
||||
name text not null
|
||||
);
|
||||
|
||||
create unique index uix_providers_on_name on providers (name);
|
||||
|
||||
create table if not exists models (
|
||||
id serial primary key,
|
||||
provider_id integer not null references providers (id) on delete cascade,
|
||||
name text not null
|
||||
);
|
||||
|
||||
create unique index uix_models_on_provider_id_name on models (provider_id, name);
|
||||
create index ix_models_on_provider_id on models (provider_id);
|
||||
create index ix_models_on_name on models (name);
|
|
@ -23,17 +23,12 @@ use sea_orm::{
|
|||
};
|
||||
use semantic_version::SemanticVersion;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{
|
||||
migrate::{Migrate, Migration, MigrationSource},
|
||||
Connection,
|
||||
};
|
||||
use std::ops::RangeInclusive;
|
||||
use std::{
|
||||
fmt::Write as _,
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
ops::{Deref, DerefMut},
|
||||
path::Path,
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
|
@ -90,54 +85,16 @@ impl Database {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.options
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn reset(&self) {
|
||||
self.rooms.clear();
|
||||
self.projects.clear();
|
||||
}
|
||||
|
||||
/// Runs the database migrations.
|
||||
pub async fn migrate(
|
||||
&self,
|
||||
migrations_path: &Path,
|
||||
ignore_checksum_mismatch: bool,
|
||||
) -> anyhow::Result<Vec<(Migration, Duration)>> {
|
||||
let migrations = MigrationSource::resolve(migrations_path)
|
||||
.await
|
||||
.map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
|
||||
|
||||
let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
|
||||
|
||||
connection.ensure_migrations_table().await?;
|
||||
let applied_migrations: HashMap<_, _> = connection
|
||||
.list_applied_migrations()
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|m| (m.version, m))
|
||||
.collect();
|
||||
|
||||
let mut new_migrations = Vec::new();
|
||||
for migration in migrations {
|
||||
match applied_migrations.get(&migration.version) {
|
||||
Some(applied_migration) => {
|
||||
if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
|
||||
{
|
||||
Err(anyhow!(
|
||||
"checksum mismatch for applied migration {}",
|
||||
migration.description
|
||||
))?;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let elapsed = connection.apply(&migration).await?;
|
||||
new_migrations.push((migration, elapsed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(new_migrations)
|
||||
}
|
||||
|
||||
/// Transaction runs things in a transaction. If you want to call other methods
|
||||
/// and pass the transaction around you need to reborrow the transaction at each
|
||||
/// call site with: `&*tx`.
|
||||
|
@ -453,7 +410,7 @@ fn is_serialization_error(error: &Error) -> bool {
|
|||
}
|
||||
|
||||
/// A handle to a [`DatabaseTransaction`].
|
||||
pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
|
||||
pub struct TransactionHandle(pub(crate) Arc<Option<DatabaseTransaction>>);
|
||||
|
||||
impl Deref for TransactionHandle {
|
||||
type Target = DatabaseTransaction;
|
||||
|
|
|
@ -3,6 +3,7 @@ use rpc::proto;
|
|||
use sea_orm::{entity::prelude::*, DbErr};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! id_type {
|
||||
($name:ident) => {
|
||||
#[derive(
|
||||
|
|
|
@ -11,6 +11,8 @@ mod feature_flag_tests;
|
|||
mod message_tests;
|
||||
mod processed_stripe_event_tests;
|
||||
|
||||
use crate::migrations::run_database_migrations;
|
||||
|
||||
use super::*;
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
|
@ -91,7 +93,9 @@ impl TestDb {
|
|||
.await
|
||||
.unwrap();
|
||||
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
|
||||
db.migrate(Path::new(migrations_path), false).await.unwrap();
|
||||
run_database_migrations(db.options(), migrations_path, false)
|
||||
.await
|
||||
.unwrap();
|
||||
db.initialize_notification_kinds().await.unwrap();
|
||||
db
|
||||
});
|
||||
|
|
|
@ -4,6 +4,7 @@ pub mod db;
|
|||
pub mod env;
|
||||
pub mod executor;
|
||||
pub mod llm;
|
||||
pub mod migrations;
|
||||
mod rate_limiter;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
|
@ -150,6 +151,9 @@ pub struct Config {
|
|||
pub live_kit_server: Option<String>,
|
||||
pub live_kit_key: Option<String>,
|
||||
pub live_kit_secret: Option<String>,
|
||||
pub llm_database_url: Option<String>,
|
||||
pub llm_database_max_connections: Option<u32>,
|
||||
pub llm_database_migrations_path: Option<PathBuf>,
|
||||
pub llm_api_secret: Option<String>,
|
||||
pub rust_log: Option<String>,
|
||||
pub log_json: Option<bool>,
|
||||
|
@ -197,6 +201,9 @@ impl Config {
|
|||
live_kit_server: None,
|
||||
live_kit_key: None,
|
||||
live_kit_secret: None,
|
||||
llm_database_url: None,
|
||||
llm_database_max_connections: None,
|
||||
llm_database_migrations_path: None,
|
||||
llm_api_secret: None,
|
||||
rust_log: None,
|
||||
log_json: None,
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
mod authorization;
|
||||
pub mod db;
|
||||
mod token;
|
||||
|
||||
use crate::api::CloudflareIpCountryHeader;
|
||||
use crate::llm::authorization::authorize_access_to_language_model;
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::{executor::Executor, Config, Error, Result};
|
||||
use anyhow::Context as _;
|
||||
use anyhow::{anyhow, Context as _};
|
||||
use axum::TypedHeader;
|
||||
use axum::{
|
||||
body::Body,
|
||||
|
@ -24,11 +26,31 @@ pub use token::*;
|
|||
pub struct LlmState {
|
||||
pub config: Config,
|
||||
pub executor: Executor,
|
||||
pub db: Option<Arc<LlmDatabase>>,
|
||||
pub http_client: IsahcHttpClient,
|
||||
}
|
||||
|
||||
impl LlmState {
|
||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||
// TODO: This is temporary until we have the LLM database stood up.
|
||||
let db = if config.is_development() {
|
||||
let database_url = config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
||||
let max_connections = config
|
||||
.llm_database_max_connections
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
|
||||
|
||||
let mut db_options = db::ConnectOptions::new(database_url);
|
||||
db_options.max_connections(max_connections);
|
||||
let db = LlmDatabase::new(db_options, executor.clone()).await?;
|
||||
|
||||
Some(Arc::new(db))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client = IsahcHttpClient::builder()
|
||||
.default_header("User-Agent", user_agent)
|
||||
|
@ -38,6 +60,7 @@ impl LlmState {
|
|||
let this = Self {
|
||||
config,
|
||||
executor,
|
||||
db,
|
||||
http_client,
|
||||
};
|
||||
|
||||
|
|
118
crates/collab/src/llm/db.rs
Normal file
118
crates/collab/src/llm/db.rs
Normal file
|
@ -0,0 +1,118 @@
|
|||
mod ids;
|
||||
mod queries;
|
||||
mod tables;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use ids::*;
|
||||
pub use tables::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use tests::TestLlmDb;
|
||||
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use sea_orm::prelude::*;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::{
|
||||
ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
|
||||
};
|
||||
|
||||
use crate::db::TransactionHandle;
|
||||
use crate::executor::Executor;
|
||||
use crate::Result;
|
||||
|
||||
/// The database for the LLM service.
|
||||
pub struct LlmDatabase {
|
||||
options: ConnectOptions,
|
||||
pool: DatabaseConnection,
|
||||
#[allow(unused)]
|
||||
executor: Executor,
|
||||
#[cfg(test)]
|
||||
runtime: Option<tokio::runtime::Runtime>,
|
||||
}
|
||||
|
||||
impl LlmDatabase {
|
||||
/// Connects to the database with the given options
|
||||
pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
|
||||
sqlx::any::install_default_drivers();
|
||||
Ok(Self {
|
||||
options: options.clone(),
|
||||
pool: sea_orm::Database::connect(options).await?,
|
||||
executor,
|
||||
#[cfg(test)]
|
||||
runtime: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.options
|
||||
}
|
||||
|
||||
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
|
||||
where
|
||||
F: Send + Fn(TransactionHandle) -> Fut,
|
||||
Fut: Send + Future<Output = Result<T>>,
|
||||
{
|
||||
let body = async {
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(result) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => return Ok(result),
|
||||
Err(error) => {
|
||||
return Err(error);
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
tx.rollback().await?;
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
self.run(body).await
|
||||
}
|
||||
|
||||
async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
|
||||
where
|
||||
F: Send + Fn(TransactionHandle) -> Fut,
|
||||
Fut: Send + Future<Output = Result<T>>,
|
||||
{
|
||||
let tx = self
|
||||
.pool
|
||||
.begin_with_config(Some(IsolationLevel::ReadCommitted), None)
|
||||
.await?;
|
||||
|
||||
let mut tx = Arc::new(Some(tx));
|
||||
let result = f(TransactionHandle(tx.clone())).await;
|
||||
let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
|
||||
return Err(anyhow!(
|
||||
"couldn't complete transaction because it's still in use"
|
||||
))?;
|
||||
};
|
||||
|
||||
Ok((tx, result))
|
||||
}
|
||||
|
||||
async fn run<F, T>(&self, future: F) -> Result<T>
|
||||
where
|
||||
F: Future<Output = Result<T>>,
|
||||
{
|
||||
#[cfg(test)]
|
||||
{
|
||||
if let Executor::Deterministic(executor) = &self.executor {
|
||||
executor.simulate_random_delay().await;
|
||||
}
|
||||
|
||||
self.runtime.as_ref().unwrap().block_on(future)
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
{
|
||||
future.await
|
||||
}
|
||||
}
|
||||
}
|
7
crates/collab/src/llm/db/ids.rs
Normal file
7
crates/collab/src/llm/db/ids.rs
Normal file
|
@ -0,0 +1,7 @@
|
|||
use sea_orm::{entity::prelude::*, DbErr};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::id_type;
|
||||
|
||||
id_type!(ProviderId);
|
||||
id_type!(ModelId);
|
3
crates/collab/src/llm/db/queries.rs
Normal file
3
crates/collab/src/llm/db/queries.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
use super::*;
|
||||
|
||||
pub mod providers;
|
67
crates/collab/src/llm/db/queries/providers.rs
Normal file
67
crates/collab/src/llm/db/queries/providers.rs
Normal file
|
@ -0,0 +1,67 @@
|
|||
use sea_orm::sea_query::OnConflict;
|
||||
use sea_orm::QueryOrder;
|
||||
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn initialize_providers(&self) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
let providers_and_models = vec![
|
||||
("anthropic", "claude-3-5-sonnet"),
|
||||
("anthropic", "claude-3-opus"),
|
||||
("anthropic", "claude-3-sonnet"),
|
||||
("anthropic", "claude-3-haiku"),
|
||||
];
|
||||
|
||||
for (provider_name, model_name) in providers_and_models {
|
||||
let insert_provider = provider::Entity::insert(provider::ActiveModel {
|
||||
name: ActiveValue::set(provider_name.to_owned()),
|
||||
..Default::default()
|
||||
})
|
||||
.on_conflict(
|
||||
OnConflict::columns([provider::Column::Name])
|
||||
.update_column(provider::Column::Name)
|
||||
.to_owned(),
|
||||
);
|
||||
|
||||
let provider = if tx.support_returning() {
|
||||
insert_provider.exec_with_returning(&*tx).await?
|
||||
} else {
|
||||
insert_provider.exec_without_returning(&*tx).await?;
|
||||
provider::Entity::find()
|
||||
.filter(provider::Column::Name.eq(provider_name))
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("failed to insert provider"))?
|
||||
};
|
||||
|
||||
model::Entity::insert(model::ActiveModel {
|
||||
provider_id: ActiveValue::set(provider.id),
|
||||
name: ActiveValue::set(model_name.to_owned()),
|
||||
..Default::default()
|
||||
})
|
||||
.on_conflict(
|
||||
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
|
||||
.update_column(model::Column::Name)
|
||||
.to_owned(),
|
||||
)
|
||||
.exec_without_returning(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the list of LLM providers.
|
||||
pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(provider::Entity::find()
|
||||
.order_by_asc(provider::Column::Name)
|
||||
.all(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
2
crates/collab/src/llm/db/tables.rs
Normal file
2
crates/collab/src/llm/db/tables.rs
Normal file
|
@ -0,0 +1,2 @@
|
|||
pub mod model;
|
||||
pub mod provider;
|
31
crates/collab/src/llm/db/tables/model.rs
Normal file
31
crates/collab/src/llm/db/tables/model.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::{ModelId, ProviderId};
|
||||
|
||||
/// An LLM model.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "models")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: ModelId,
|
||||
pub provider_id: ProviderId,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::provider::Entity",
|
||||
from = "Column::ProviderId",
|
||||
to = "super::provider::Column::Id"
|
||||
)]
|
||||
Provider,
|
||||
}
|
||||
|
||||
impl Related<super::provider::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Provider.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
26
crates/collab/src/llm/db/tables/provider.rs
Normal file
26
crates/collab/src/llm/db/tables/provider.rs
Normal file
|
@ -0,0 +1,26 @@
|
|||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::ProviderId;
|
||||
|
||||
/// An LLM provider.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "providers")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: ProviderId,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(has_many = "super::model::Entity")]
|
||||
Models,
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Models.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
147
crates/collab/src/llm/db/tests.rs
Normal file
147
crates/collab/src/llm/db/tests.rs
Normal file
|
@ -0,0 +1,147 @@
|
|||
mod provider_tests;
|
||||
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
use rand::prelude::*;
|
||||
use sea_orm::ConnectionTrait;
|
||||
use sqlx::migrate::MigrateDatabase;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::migrations::run_database_migrations;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub struct TestLlmDb {
|
||||
pub db: Option<Arc<LlmDatabase>>,
|
||||
pub connection: Option<sqlx::AnyConnection>,
|
||||
}
|
||||
|
||||
impl TestLlmDb {
|
||||
pub fn sqlite(background: BackgroundExecutor) -> Self {
|
||||
let url = "sqlite::memory:";
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let mut db = runtime.block_on(async {
|
||||
let mut options = ConnectOptions::new(url);
|
||||
options.max_connections(5);
|
||||
let db = LlmDatabase::new(options, Executor::Deterministic(background))
|
||||
.await
|
||||
.unwrap();
|
||||
let sql = include_str!(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/migrations_llm.sqlite/20240806182921_test_schema.sql"
|
||||
));
|
||||
db.pool
|
||||
.execute(sea_orm::Statement::from_string(
|
||||
db.pool.get_database_backend(),
|
||||
sql,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
db
|
||||
});
|
||||
|
||||
db.runtime = Some(runtime);
|
||||
|
||||
Self {
|
||||
db: Some(Arc::new(db)),
|
||||
connection: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn postgres(background: BackgroundExecutor) -> Self {
|
||||
static LOCK: Mutex<()> = Mutex::new(());
|
||||
|
||||
let _guard = LOCK.lock();
|
||||
let mut rng = StdRng::from_entropy();
|
||||
let url = format!(
|
||||
"postgres://postgres@localhost/zed-llm-test-{}",
|
||||
rng.gen::<u128>()
|
||||
);
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let mut db = runtime.block_on(async {
|
||||
sqlx::Postgres::create_database(&url)
|
||||
.await
|
||||
.expect("failed to create test db");
|
||||
let mut options = ConnectOptions::new(url);
|
||||
options
|
||||
.max_connections(5)
|
||||
.idle_timeout(Duration::from_secs(0));
|
||||
let db = LlmDatabase::new(options, Executor::Deterministic(background))
|
||||
.await
|
||||
.unwrap();
|
||||
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
|
||||
run_database_migrations(db.options(), migrations_path, false)
|
||||
.await
|
||||
.unwrap();
|
||||
db
|
||||
});
|
||||
|
||||
db.runtime = Some(runtime);
|
||||
|
||||
Self {
|
||||
db: Some(Arc::new(db)),
|
||||
connection: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn db(&self) -> &Arc<LlmDatabase> {
|
||||
self.db.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! test_both_llm_dbs {
|
||||
($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
|
||||
#[cfg(target_os = "macos")]
|
||||
#[gpui::test]
|
||||
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
|
||||
let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
|
||||
$test_name(test_db.db()).await;
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
|
||||
let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone());
|
||||
$test_name(test_db.db()).await;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl Drop for TestLlmDb {
|
||||
fn drop(&mut self) {
|
||||
let db = self.db.take().unwrap();
|
||||
if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
|
||||
db.runtime.as_ref().unwrap().block_on(async {
|
||||
use util::ResultExt;
|
||||
let query = "
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE
|
||||
pg_stat_activity.datname = current_database() AND
|
||||
pid <> pg_backend_pid();
|
||||
";
|
||||
db.pool
|
||||
.execute(sea_orm::Statement::from_string(
|
||||
db.pool.get_database_backend(),
|
||||
query,
|
||||
))
|
||||
.await
|
||||
.log_err();
|
||||
sqlx::Postgres::drop_database(db.options.get_url())
|
||||
.await
|
||||
.log_err();
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
30
crates/collab/src/llm/db/tests/provider_tests.rs
Normal file
30
crates/collab/src/llm/db/tests/provider_tests.rs
Normal file
|
@ -0,0 +1,30 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::test_both_llm_dbs;
|
||||
|
||||
test_both_llm_dbs!(
|
||||
test_initialize_providers,
|
||||
test_initialize_providers_postgres,
|
||||
test_initialize_providers_sqlite
|
||||
);
|
||||
|
||||
async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
|
||||
let initial_providers = db.list_providers().await.unwrap();
|
||||
assert_eq!(initial_providers, vec![]);
|
||||
|
||||
db.initialize_providers().await.unwrap();
|
||||
|
||||
// Do it twice, to make sure the operation is idempotent.
|
||||
db.initialize_providers().await.unwrap();
|
||||
|
||||
let providers = db.list_providers().await.unwrap();
|
||||
|
||||
let provider_names = providers
|
||||
.into_iter()
|
||||
.map(|provider| provider.name)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(provider_names, vec!["anthropic".to_string()]);
|
||||
}
|
|
@ -5,6 +5,8 @@ use axum::{
|
|||
routing::get,
|
||||
Extension, Router,
|
||||
};
|
||||
use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
|
||||
use collab::{
|
||||
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
|
||||
|
@ -45,7 +47,7 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
Some("migrate") => {
|
||||
let config = envy::from_env::<Config>().expect("error loading config");
|
||||
run_migrations(&config).await?;
|
||||
setup_app_database(&config).await?;
|
||||
}
|
||||
Some("seed") => {
|
||||
let config = envy::from_env::<Config>().expect("error loading config");
|
||||
|
@ -81,6 +83,8 @@ async fn main() -> Result<()> {
|
|||
let mut on_shutdown = None;
|
||||
|
||||
if mode.is_llm() {
|
||||
setup_llm_database(&config).await?;
|
||||
|
||||
let state = LlmState::new(config.clone(), Executor::Production).await?;
|
||||
|
||||
app = app
|
||||
|
@ -89,7 +93,7 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
if mode.is_collab() || mode.is_api() {
|
||||
run_migrations(&config).await?;
|
||||
setup_app_database(&config).await?;
|
||||
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
|
@ -203,7 +207,7 @@ async fn main() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_migrations(config: &Config) -> Result<()> {
|
||||
async fn setup_app_database(config: &Config) -> Result<()> {
|
||||
let db_options = db::ConnectOptions::new(config.database_url.clone());
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
|
||||
|
@ -216,7 +220,7 @@ async fn run_migrations(config: &Config) -> Result<()> {
|
|||
Path::new(default_migrations)
|
||||
});
|
||||
|
||||
let migrations = db.migrate(&migrations_path, false).await?;
|
||||
let migrations = run_database_migrations(db.options(), migrations_path, false).await?;
|
||||
for (migration, duration) in migrations {
|
||||
log::info!(
|
||||
"Migrated {} {} {:?}",
|
||||
|
@ -232,7 +236,46 @@ async fn run_migrations(config: &Config) -> Result<()> {
|
|||
collab::seed::seed(&config, &db, false).await?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn setup_llm_database(config: &Config) -> Result<()> {
|
||||
// TODO: This is temporary until we have the LLM database stood up.
|
||||
if !config.is_development() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let database_url = config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
||||
|
||||
let db_options = db::ConnectOptions::new(database_url.clone());
|
||||
let db = LlmDatabase::new(db_options, Executor::Production).await?;
|
||||
|
||||
let migrations_path = config
|
||||
.llm_database_migrations_path
|
||||
.as_deref()
|
||||
.unwrap_or_else(|| {
|
||||
#[cfg(feature = "sqlite")]
|
||||
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
|
||||
|
||||
Path::new(default_migrations)
|
||||
});
|
||||
|
||||
let migrations = run_database_migrations(db.options(), migrations_path, false).await?;
|
||||
for (migration, duration) in migrations {
|
||||
log::info!(
|
||||
"Migrated {} {} {:?}",
|
||||
migration.version,
|
||||
migration.description,
|
||||
duration
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
|
||||
|
|
49
crates/collab/src/migrations.rs
Normal file
49
crates/collab/src/migrations.rs
Normal file
|
@ -0,0 +1,49 @@
|
|||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use sea_orm::ConnectOptions;
|
||||
use sqlx::migrate::{Migrate, Migration, MigrationSource};
|
||||
use sqlx::Connection;
|
||||
|
||||
/// Runs the database migrations for the specified database.
|
||||
pub async fn run_database_migrations(
|
||||
database_options: &ConnectOptions,
|
||||
migrations_path: impl AsRef<Path>,
|
||||
ignore_checksum_mismatch: bool,
|
||||
) -> Result<Vec<(Migration, Duration)>> {
|
||||
let migrations = MigrationSource::resolve(migrations_path.as_ref())
|
||||
.await
|
||||
.map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
|
||||
|
||||
let mut connection = sqlx::AnyConnection::connect(database_options.get_url()).await?;
|
||||
|
||||
connection.ensure_migrations_table().await?;
|
||||
let applied_migrations: HashMap<_, _> = connection
|
||||
.list_applied_migrations()
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|migration| (migration.version, migration))
|
||||
.collect();
|
||||
|
||||
let mut new_migrations = Vec::new();
|
||||
for migration in migrations {
|
||||
match applied_migrations.get(&migration.version) {
|
||||
Some(applied_migration) => {
|
||||
if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch {
|
||||
Err(anyhow!(
|
||||
"checksum mismatch for applied migration {}",
|
||||
migration.description
|
||||
))?;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let elapsed = connection.apply(&migration).await?;
|
||||
new_migrations.push((migration, elapsed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(new_migrations)
|
||||
}
|
|
@ -651,6 +651,9 @@ impl TestServer {
|
|||
live_kit_server: None,
|
||||
live_kit_key: None,
|
||||
live_kit_secret: None,
|
||||
llm_database_url: None,
|
||||
llm_database_max_connections: None,
|
||||
llm_database_migrations_path: None,
|
||||
llm_api_secret: None,
|
||||
rust_log: None,
|
||||
log_json: None,
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
create database zed;
|
||||
create database zed_llm;
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
|
||||
echo "Linux dependencies..."
|
||||
script/linux
|
||||
|
@ -8,5 +10,17 @@ else
|
|||
which foreman > /dev/null || brew install foreman
|
||||
fi
|
||||
|
||||
echo "creating database..."
|
||||
script/sqlx database create
|
||||
# Install sqlx-cli if needed
|
||||
if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
|
||||
echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
|
||||
cargo install sqlx-cli --version 0.5.7
|
||||
fi
|
||||
|
||||
cd crates/collab
|
||||
|
||||
# Export contents of .env.toml
|
||||
eval "$(cargo run --bin dotenv)"
|
||||
|
||||
echo "creating databases..."
|
||||
sqlx database create --database-url "$DATABASE_URL"
|
||||
sqlx database create --database-url "$LLM_DATABASE_URL"
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
psql -c "DROP DATABASE zed (FORCE);"
|
||||
psql -c "DROP DATABASE zed_llm (FORCE);"
|
||||
script/bootstrap
|
||||
|
|
17
script/sqlx
17
script/sqlx
|
@ -1,17 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Install sqlx-cli if needed
|
||||
if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
|
||||
echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
|
||||
cargo install sqlx-cli --version 0.5.7
|
||||
fi
|
||||
|
||||
cd crates/collab
|
||||
|
||||
# Export contents of .env.toml
|
||||
eval "$(cargo run --bin dotenv)"
|
||||
|
||||
# Run sqlx command
|
||||
sqlx $@
|
Loading…
Reference in a new issue