diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 0e4c85907b..c3e9b29ee6 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -169,9 +169,7 @@ struct ManageBillingSubscriptionBody { github_user_id: i32, intent: ManageSubscriptionIntent, /// The ID of the subscription to manage. - /// - /// If not provided, we will try to use the active subscription (if there is only one). - subscription_id: Option, + subscription_id: BillingSubscriptionId, } #[derive(Debug, Serialize)] @@ -206,23 +204,11 @@ async fn manage_billing_subscription( let customer_id = CustomerId::from_str(&customer.stripe_customer_id) .context("failed to parse customer ID")?; - let subscription = if let Some(subscription_id) = body.subscription_id { - app.db - .get_billing_subscription_by_id(subscription_id) - .await? - .ok_or_else(|| anyhow!("subscription not found"))? - } else { - // If no subscription ID was provided, try to find the only active subscription ID. - let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?; - if subscriptions.len() > 1 { - Err(anyhow!("user has multiple active subscriptions"))?; - } - - subscriptions - .into_iter() - .next() - .ok_or_else(|| anyhow!("user has no active subscriptions"))? - }; + let subscription = app + .db + .get_billing_subscription_by_id(body.subscription_id) + .await? + .ok_or_else(|| anyhow!("subscription not found"))?; let flow = match body.intent { ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData { diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index e2fc3f4f1b..72494d1b3a 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -110,13 +110,15 @@ impl Database { .await } - /// Returns all of the active billing subscriptions for the user with the specified ID. - pub async fn get_active_billing_subscriptions( - &self, - user_id: UserId, - ) -> Result> { + /// Returns whether the user has an active billing subscription. + pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result { + Ok(self.count_active_billing_subscriptions(user_id).await? > 0) + } + + /// Returns the count of the active billing subscriptions for the user with the specified ID. + pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result { self.transaction(|tx| async move { - let subscriptions = billing_subscription::Entity::find() + let count = billing_subscription::Entity::find() .inner_join(billing_customer::Entity) .filter( billing_customer::Column::UserId.eq(user_id).and( @@ -124,11 +126,10 @@ impl Database { .eq(StripeSubscriptionStatus::Active), ), ) - .order_by_asc(billing_subscription::Column::Id) - .all(&*tx) + .count(&*tx) .await?; - Ok(subscriptions) + Ok(count as usize) }) .await } diff --git a/crates/collab/src/db/tests/billing_subscription_tests.rs b/crates/collab/src/db/tests/billing_subscription_tests.rs index 19f5463ac2..a1973e3fbb 100644 --- a/crates/collab/src/db/tests/billing_subscription_tests.rs +++ b/crates/collab/src/db/tests/billing_subscription_tests.rs @@ -17,9 +17,12 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { // A user with no subscription has no active billing subscriptions. { let user_id = new_test_user(db, "no-subscription-user@example.com").await; - let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap(); + let subscription_count = db + .count_active_billing_subscriptions(user_id) + .await + .unwrap(); - assert_eq!(subscriptions.len(), 0); + assert_eq!(subscription_count, 0); } // A user with an active subscription has one active billing subscription. @@ -42,7 +45,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { .await .unwrap(); - let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap(); + let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap(); assert_eq!(subscriptions.len(), 1); let subscription = &subscriptions[0]; @@ -76,7 +79,10 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { .await .unwrap(); - let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap(); - assert_eq!(subscriptions.len(), 0); + let subscription_count = db + .count_active_billing_subscriptions(user_id) + .await + .unwrap(); + assert_eq!(subscription_count, 0); } } diff --git a/crates/collab/src/rate_limiter.rs b/crates/collab/src/rate_limiter.rs index 844a3af949..5e619f7db2 100644 --- a/crates/collab/src/rate_limiter.rs +++ b/crates/collab/src/rate_limiter.rs @@ -6,10 +6,10 @@ use sea_orm::prelude::DateTimeUtc; use std::sync::Arc; use util::ResultExt; -pub trait RateLimit: 'static { - fn capacity() -> usize; - fn refill_duration() -> Duration; - fn db_name() -> &'static str; +pub trait RateLimit: Send + Sync { + fn capacity(&self) -> usize; + fn refill_duration(&self) -> Duration; + fn db_name(&self) -> &'static str; } /// Used to enforce per-user rate limits @@ -42,18 +42,23 @@ impl RateLimiter { /// Returns an error if the user has exceeded the specified `RateLimit`. /// Attempts to read the from the database if no cached RateBucket currently exists. - pub async fn check(&self, user_id: UserId) -> Result<()> { - self.check_internal::(user_id, Utc::now()).await + pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> { + self.check_internal(limit, user_id, Utc::now()).await } - async fn check_internal(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> { - let bucket_key = (user_id, T::db_name().to_string()); + async fn check_internal( + &self, + limit: &dyn RateLimit, + user_id: UserId, + now: DateTimeUtc, + ) -> Result<()> { + let bucket_key = (user_id, limit.db_name().to_string()); // Attempt to fetch the bucket from the database if it hasn't been cached. // For now, we keep buckets in memory for the lifetime of the process rather than expiring them, // but this enforces limits across restarts so long as the database is reachable. if !self.buckets.contains_key(&bucket_key) { - if let Some(bucket) = self.load_bucket::(user_id).await.log_err().flatten() { + if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() { self.buckets.insert(bucket_key.clone(), bucket); self.dirty_buckets.insert(bucket_key.clone()); } @@ -62,7 +67,7 @@ impl RateLimiter { let mut bucket = self .buckets .entry(bucket_key.clone()) - .or_insert_with(|| RateBucket::new::(now)); + .or_insert_with(|| RateBucket::new(limit, now)); if bucket.value_mut().allow(now) { self.dirty_buckets.insert(bucket_key); @@ -72,16 +77,18 @@ impl RateLimiter { } } - async fn load_bucket( + async fn load_bucket( &self, + limit: &dyn RateLimit, user_id: UserId, ) -> Result, Error> { Ok(self .db - .get_rate_bucket(user_id, T::db_name()) + .get_rate_bucket(user_id, limit.db_name()) .await? .map(|saved_bucket| { - RateBucket::from_db::( + RateBucket::from_db( + limit, saved_bucket.token_count as usize, DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc), ) @@ -124,20 +131,20 @@ struct RateBucket { } impl RateBucket { - fn new(now: DateTimeUtc) -> Self { + fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self { Self { - capacity: T::capacity(), - token_count: T::capacity(), - refill_time_per_token: T::refill_duration() / T::capacity() as i32, + capacity: limit.capacity(), + token_count: limit.capacity(), + refill_time_per_token: limit.refill_duration() / limit.capacity() as i32, last_refill: now, } } - fn from_db(token_count: usize, last_refill: DateTimeUtc) -> Self { + fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self { Self { - capacity: T::capacity(), + capacity: limit.capacity(), token_count, - refill_time_per_token: T::refill_duration() / T::capacity() as i32, + refill_time_per_token: limit.refill_duration() / limit.capacity() as i32, last_refill, } } @@ -205,50 +212,52 @@ mod tests { let mut now = Utc::now(); let rate_limiter = RateLimiter::new(db.clone()); + let rate_limit_a = Box::new(RateLimitA); + let rate_limit_b = Box::new(RateLimitB); // User 1 can access resource A two times before being rate-limited. rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap(); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap(); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap_err(); // User 2 can access resource A and user 1 can access resource B. rate_limiter - .check_internal::(user_2, now) + .check_internal(&*rate_limit_b, user_2, now) .await .unwrap(); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_b, user_1, now) .await .unwrap(); // After 1.5s, user 1 can make another request before being rate-limited again. now += Duration::milliseconds(1500); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap(); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap_err(); // After 500ms, user 1 can make another request before being rate-limited again. now += Duration::milliseconds(500); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap(); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap_err(); @@ -258,18 +267,18 @@ mod tests { // for resource A. let rate_limiter = RateLimiter::new(db.clone()); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap_err(); // After 1s, user 1 can make another request before being rate-limited again. now += Duration::seconds(1); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap(); rate_limiter - .check_internal::(user_1, now) + .check_internal(&*rate_limit_a, user_1, now) .await .unwrap_err(); } @@ -277,15 +286,15 @@ mod tests { struct RateLimitA; impl RateLimit for RateLimitA { - fn capacity() -> usize { + fn capacity(&self) -> usize { 2 } - fn refill_duration() -> Duration { + fn refill_duration(&self) -> Duration { Duration::seconds(2) } - fn db_name() -> &'static str { + fn db_name(&self) -> &'static str { "rate-limit-a" } } @@ -293,15 +302,15 @@ mod tests { struct RateLimitB; impl RateLimit for RateLimitB { - fn capacity() -> usize { + fn capacity(&self) -> usize { 10 } - fn refill_duration() -> Duration { + fn refill_duration(&self) -> Duration { Duration::seconds(3) } - fn db_name() -> &'static str { + fn db_name(&self) -> &'static str { "rate-limit-b" } } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 81ce57561a..3683cdc5c8 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -199,6 +199,23 @@ impl Session { } } + pub async fn current_plan(&self) -> anyhow::Result { + if self.is_staff() { + return Ok(proto::Plan::ZedPro); + } + + let Some(user_id) = self.user_id() else { + return Ok(proto::Plan::Free); + }; + + let db = self.db().await; + if db.has_active_billing_subscription(user_id).await? { + Ok(proto::Plan::ZedPro) + } else { + Ok(proto::Plan::Free) + } + } + fn dev_server_id(&self) -> Option { match &self.principal { Principal::User(_) | Principal::Impersonated { .. } => None, @@ -3537,15 +3554,8 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { version.0.minor() < 139 } -async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> { - let db = session.db().await; - let active_subscriptions = db.get_active_billing_subscriptions(user_id).await?; - - let plan = if session.is_staff() || !active_subscriptions.is_empty() { - proto::Plan::ZedPro - } else { - proto::Plan::Free - }; +async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> { + let plan = session.current_plan().await?; session .peer @@ -4532,22 +4542,41 @@ async fn acknowledge_buffer_version( Ok(()) } -struct CompleteWithLanguageModelRateLimit; +struct ZedProCompleteWithLanguageModelRateLimit; -impl RateLimit for CompleteWithLanguageModelRateLimit { - fn capacity() -> usize { +impl RateLimit for ZedProCompleteWithLanguageModelRateLimit { + fn capacity(&self) -> usize { std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(120) // Picked arbitrarily } - fn refill_duration() -> chrono::Duration { + fn refill_duration(&self) -> chrono::Duration { chrono::Duration::hours(1) } - fn db_name() -> &'static str { - "complete-with-language-model" + fn db_name(&self) -> &'static str { + "zed-pro:complete-with-language-model" + } +} + +struct FreeCompleteWithLanguageModelRateLimit; + +impl RateLimit for FreeCompleteWithLanguageModelRateLimit { + fn capacity(&self) -> usize { + std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(120 / 10) // Picked arbitrarily + } + + fn refill_duration(&self) -> chrono::Duration { + chrono::Duration::hours(1) + } + + fn db_name(&self) -> &'static str { + "free:complete-with-language-model" } } @@ -4562,9 +4591,14 @@ async fn complete_with_language_model( }; authorize_access_to_language_models(&session).await?; + let rate_limit: Box = match session.current_plan().await? { + proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit), + proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit), + }; + session .rate_limiter - .check::(session.user_id()) + .check(&*rate_limit, session.user_id()) .await?; let result = match proto::LanguageModelProvider::from_i32(request.provider) { @@ -4602,9 +4636,14 @@ async fn stream_complete_with_language_model( }; authorize_access_to_language_models(&session).await?; + let rate_limit: Box = match session.current_plan().await? { + proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit), + proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit), + }; + session .rate_limiter - .check::(session.user_id()) + .check(&*rate_limit, session.user_id()) .await?; match proto::LanguageModelProvider::from_i32(request.provider) { @@ -4684,9 +4723,14 @@ async fn count_language_model_tokens( }; authorize_access_to_language_models(&session).await?; + let rate_limit: Box = match session.current_plan().await? { + proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit), + proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit), + }; + session .rate_limiter - .check::(session.user_id()) + .check(&*rate_limit, session.user_id()) .await?; let result = match proto::LanguageModelProvider::from_i32(request.provider) { @@ -4713,41 +4757,79 @@ async fn count_language_model_tokens( Ok(()) } -struct CountLanguageModelTokensRateLimit; +struct ZedProCountLanguageModelTokensRateLimit; -impl RateLimit for CountLanguageModelTokensRateLimit { - fn capacity() -> usize { +impl RateLimit for ZedProCountLanguageModelTokensRateLimit { + fn capacity(&self) -> usize { std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(600) // Picked arbitrarily } - fn refill_duration() -> chrono::Duration { + fn refill_duration(&self) -> chrono::Duration { chrono::Duration::hours(1) } - fn db_name() -> &'static str { - "count-language-model-tokens" + fn db_name(&self) -> &'static str { + "zed-pro:count-language-model-tokens" } } -struct ComputeEmbeddingsRateLimit; +struct FreeCountLanguageModelTokensRateLimit; -impl RateLimit for ComputeEmbeddingsRateLimit { - fn capacity() -> usize { +impl RateLimit for FreeCountLanguageModelTokensRateLimit { + fn capacity(&self) -> usize { + std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(600 / 10) // Picked arbitrarily + } + + fn refill_duration(&self) -> chrono::Duration { + chrono::Duration::hours(1) + } + + fn db_name(&self) -> &'static str { + "free:count-language-model-tokens" + } +} + +struct ZedProComputeEmbeddingsRateLimit; + +impl RateLimit for ZedProComputeEmbeddingsRateLimit { + fn capacity(&self) -> usize { std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(5000) // Picked arbitrarily } - fn refill_duration() -> chrono::Duration { + fn refill_duration(&self) -> chrono::Duration { chrono::Duration::hours(1) } - fn db_name() -> &'static str { - "compute-embeddings" + fn db_name(&self) -> &'static str { + "zed-pro:compute-embeddings" + } +} + +struct FreeComputeEmbeddingsRateLimit; + +impl RateLimit for FreeComputeEmbeddingsRateLimit { + fn capacity(&self) -> usize { + std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(5000 / 10) // Picked arbitrarily + } + + fn refill_duration(&self) -> chrono::Duration { + chrono::Duration::hours(1) + } + + fn db_name(&self) -> &'static str { + "free:compute-embeddings" } } @@ -4760,9 +4842,14 @@ async fn compute_embeddings( let api_key = api_key.context("no OpenAI API key configured on the server")?; authorize_access_to_language_models(&session).await?; + let rate_limit: Box = match session.current_plan().await? { + proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit), + proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit), + }; + session .rate_limiter - .check::(session.user_id()) + .check(&*rate_limit, session.user_id()) .await?; let embeddings = match request.model.as_str() { @@ -4834,10 +4921,10 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<() let db = session.db().await; let flags = db.get_user_flags(session.user_id()).await?; if flags.iter().any(|flag| flag == "language-models") { - Ok(()) - } else { - Err(anyhow!("permission denied"))? + return Ok(()); } + + Err(anyhow!("permission denied"))? } /// Get a Supermaven API key for the user