From 28c14cdee4fbae2054b41f0c5d6216d08164caa6 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 29 Jul 2024 22:48:21 -0400 Subject: [PATCH] collab: Add separate `billing_customers` table (#15457) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a new `billing_customers` table to hold the billing customers. Previously we were storing both the `stripe_customer_id` and `stripe_subscription_id` in the `billable_subscriptions` table. However, this creates problems when we need to correlate subscription events back to the subscription record, as we don't know the user that the Stripe event corresponds to. By moving the `stripe_customer_id` to a separate table we can create the Stripe customer earlier in the flow—before we create the Stripe Checkout session—and associate that customer with a user. This way when we receive events down the line we can use the Stripe customer ID to correlate it back to the user. We're doing some destructive actions to the `billing_subscriptions` table, but this is fine, as we haven't started using them yet. Release Notes: - N/A --- .../20221109000000_test_schema.sql | 16 +++++-- .../20240730014107_add_billing_customer.sql | 18 +++++++ crates/collab/src/api/billing.rs | 47 ++++++++++--------- crates/collab/src/db.rs | 1 + crates/collab/src/db/ids.rs | 1 + crates/collab/src/db/queries.rs | 1 + .../src/db/queries/billing_customers.rs | 42 +++++++++++++++++ .../src/db/queries/billing_subscriptions.rs | 12 ++--- crates/collab/src/db/tables.rs | 1 + .../collab/src/db/tables/billing_customer.rs | 39 +++++++++++++++ .../src/db/tables/billing_subscription.rs | 17 ++++--- crates/collab/src/db/tables/user.rs | 10 +++- .../db/tests/billing_subscription_tests.rs | 30 ++++++++---- 13 files changed, 183 insertions(+), 52 deletions(-) create mode 100644 crates/collab/migrations/20240730014107_add_billing_customer.sql create mode 100644 crates/collab/src/db/queries/billing_customers.rs create mode 100644 crates/collab/src/db/tables/billing_customer.rs diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 161ab2c03b..b19eaecd63 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -420,12 +420,20 @@ CREATE TABLE dev_server_projects ( CREATE TABLE IF NOT EXISTS billing_subscriptions ( id INTEGER PRIMARY KEY AUTOINCREMENT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - user_id INTEGER NOT NULL REFERENCES users(id), - stripe_customer_id TEXT NOT NULL, + billing_customer_id INTEGER NOT NULL REFERENCES billing_customers(id), stripe_subscription_id TEXT NOT NULL, stripe_subscription_status TEXT NOT NULL ); -CREATE INDEX "ix_billing_subscriptions_on_user_id" ON billing_subscriptions (user_id); -CREATE INDEX "ix_billing_subscriptions_on_stripe_customer_id" ON billing_subscriptions (stripe_customer_id); +CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id); CREATE UNIQUE INDEX "uix_billing_subscriptions_on_stripe_subscription_id" ON billing_subscriptions (stripe_subscription_id); + +CREATE TABLE IF NOT EXISTS billing_customers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + user_id INTEGER NOT NULL REFERENCES users(id), + stripe_customer_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id); +CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id); diff --git a/crates/collab/migrations/20240730014107_add_billing_customer.sql b/crates/collab/migrations/20240730014107_add_billing_customer.sql new file mode 100644 index 0000000000..7f7d4a0f85 --- /dev/null +++ b/crates/collab/migrations/20240730014107_add_billing_customer.sql @@ -0,0 +1,18 @@ +CREATE TABLE IF NOT EXISTS billing_customers ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT now(), + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + stripe_customer_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id); +CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id); + +-- Make `billing_subscriptions` reference `billing_customers` instead of having its +-- own `user_id` and `stripe_customer_id`. +DROP INDEX IF EXISTS "ix_billing_subscriptions_on_user_id"; +DROP INDEX IF EXISTS "ix_billing_subscriptions_on_stripe_customer_id"; +ALTER TABLE billing_subscriptions DROP COLUMN user_id; +ALTER TABLE billing_subscriptions DROP COLUMN stripe_customer_id; +ALTER TABLE billing_subscriptions ADD COLUMN billing_customer_id INTEGER NOT NULL REFERENCES billing_customers (id) ON DELETE CASCADE; +CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id); diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index bc7052252a..0db4cf062b 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use anyhow::{anyhow, Context}; use axum::{extract, routing::post, Extension, Json, Router}; -use collections::HashSet; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use stripe::{ @@ -11,7 +10,7 @@ use stripe::{ CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, CreateBillingPortalSessionFlowDataAfterCompletionRedirect, CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CustomerId, + CreateCustomer, Customer, CustomerId, }; use crate::db::BillingSubscriptionId; @@ -59,28 +58,27 @@ async fn create_billing_subscription( ))? }; - let existing_customer_id = { - let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?; - let distinct_customer_ids = existing_subscriptions - .iter() - .map(|subscription| subscription.stripe_customer_id.as_str()) - .collect::>(); - // Sanity: Make sure we can determine a single Stripe customer ID for the user. - if distinct_customer_ids.len() > 1 { - Err(anyhow!("user has multiple existing customer IDs"))?; - } + let customer_id = + if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? { + CustomerId::from_str(&existing_customer.stripe_customer_id) + .context("failed to parse customer ID")? + } else { + let customer = Customer::create( + &stripe_client, + CreateCustomer { + email: user.email_address.as_deref(), + ..Default::default() + }, + ) + .await?; - distinct_customer_ids - .into_iter() - .next() - .map(|id| CustomerId::from_str(id).context("failed to parse customer ID")) - .transpose() - }?; + customer.id + }; let checkout_session = { let mut params = CreateCheckoutSession::new(); params.mode = Some(stripe::CheckoutSessionMode::Subscription); - params.customer = existing_customer_id; + params.customer = Some(customer_id); params.client_reference_id = Some(user.github_login.as_str()); params.line_items = Some(vec![CreateCheckoutSessionLineItems { price: Some(stripe_price_id.to_string()), @@ -140,6 +138,14 @@ async fn manage_billing_subscription( ))? }; + let customer = app + .db + .get_billing_customer_by_user_id(user.id) + .await? + .ok_or_else(|| anyhow!("billing customer not found"))?; + let customer_id = CustomerId::from_str(&customer.stripe_customer_id) + .context("failed to parse customer ID")?; + let subscription = if let Some(subscription_id) = body.subscription_id { app.db .get_billing_subscription_by_id(subscription_id) @@ -158,9 +164,6 @@ async fn manage_billing_subscription( .ok_or_else(|| anyhow!("user has no active subscriptions"))? }; - let customer_id = CustomerId::from_str(&subscription.stripe_customer_id) - .context("failed to parse customer ID")?; - let flow = match body.intent { ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData { type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel, diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 58f08827ec..b34de6b326 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -45,6 +45,7 @@ use tokio::sync::{Mutex, OwnedMutexGuard}; pub use tests::TestDb; pub use ids::*; +pub use queries::billing_customers::CreateBillingCustomerParams; pub use queries::billing_subscriptions::CreateBillingSubscriptionParams; pub use queries::contributors::ContributorSelector; pub use sea_orm::ConnectOptions; diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 67c833ab26..21206e377e 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -68,6 +68,7 @@ macro_rules! id_type { } id_type!(AccessTokenId); +id_type!(BillingCustomerId); id_type!(BillingSubscriptionId); id_type!(BufferId); id_type!(ChannelBufferCollaboratorId); diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 45e259ceb2..fb1d6b24f2 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -1,6 +1,7 @@ use super::*; pub mod access_tokens; +pub mod billing_customers; pub mod billing_subscriptions; pub mod buffers; pub mod channels; diff --git a/crates/collab/src/db/queries/billing_customers.rs b/crates/collab/src/db/queries/billing_customers.rs new file mode 100644 index 0000000000..9a0c00a0d3 --- /dev/null +++ b/crates/collab/src/db/queries/billing_customers.rs @@ -0,0 +1,42 @@ +use super::*; + +#[derive(Debug)] +pub struct CreateBillingCustomerParams { + pub user_id: UserId, + pub stripe_customer_id: String, +} + +impl Database { + /// Creates a new billing customer. + pub async fn create_billing_customer( + &self, + params: &CreateBillingCustomerParams, + ) -> Result { + self.transaction(|tx| async move { + let customer = billing_customer::Entity::insert(billing_customer::ActiveModel { + user_id: ActiveValue::set(params.user_id), + stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()), + ..Default::default() + }) + .exec_with_returning(&*tx) + .await?; + + Ok(customer) + }) + .await + } + + /// Returns the billing customer for the user with the specified ID. + pub async fn get_billing_customer_by_user_id( + &self, + user_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + Ok(billing_customer::Entity::find() + .filter(billing_customer::Column::UserId.eq(user_id)) + .one(&*tx) + .await?) + }) + .await + } +} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 42d1a4f180..0b11f25aa3 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -4,8 +4,7 @@ use super::*; #[derive(Debug)] pub struct CreateBillingSubscriptionParams { - pub user_id: UserId, - pub stripe_customer_id: String, + pub billing_customer_id: BillingCustomerId, pub stripe_subscription_id: String, pub stripe_subscription_status: StripeSubscriptionStatus, } @@ -18,8 +17,7 @@ impl Database { ) -> Result<()> { self.transaction(|tx| async move { billing_subscription::Entity::insert(billing_subscription::ActiveModel { - user_id: ActiveValue::set(params.user_id), - stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()), + billing_customer_id: ActiveValue::set(params.billing_customer_id), stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status), ..Default::default() @@ -56,7 +54,8 @@ impl Database { ) -> Result> { self.transaction(|tx| async move { let subscriptions = billing_subscription::Entity::find() - .filter(billing_subscription::Column::UserId.eq(user_id)) + .inner_join(billing_customer::Entity) + .filter(billing_customer::Column::UserId.eq(user_id)) .order_by_asc(billing_subscription::Column::Id) .all(&*tx) .await?; @@ -73,8 +72,9 @@ impl Database { ) -> Result> { self.transaction(|tx| async move { let subscriptions = billing_subscription::Entity::find() + .inner_join(billing_customer::Entity) .filter( - billing_subscription::Column::UserId.eq(user_id).and( + billing_customer::Column::UserId.eq(user_id).and( billing_subscription::Column::StripeSubscriptionStatus .eq(StripeSubscriptionStatus::Active), ), diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index 2e5a9bebc1..d3105ede76 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -1,4 +1,5 @@ pub mod access_token; +pub mod billing_customer; pub mod billing_subscription; pub mod buffer; pub mod buffer_operation; diff --git a/crates/collab/src/db/tables/billing_customer.rs b/crates/collab/src/db/tables/billing_customer.rs new file mode 100644 index 0000000000..258a7e0c0c --- /dev/null +++ b/crates/collab/src/db/tables/billing_customer.rs @@ -0,0 +1,39 @@ +use crate::db::{BillingCustomerId, UserId}; +use sea_orm::entity::prelude::*; + +/// A billing customer. +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "billing_customers")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: BillingCustomerId, + pub user_id: UserId, + pub stripe_customer_id: String, + pub created_at: DateTime, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, + #[sea_orm(has_many = "super::billing_subscription::Entity")] + BillingSubscription, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::BillingSubscription.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 936352ff1a..4cbde6bec0 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -1,4 +1,4 @@ -use crate::db::{BillingSubscriptionId, UserId}; +use crate::db::{BillingCustomerId, BillingSubscriptionId}; use sea_orm::entity::prelude::*; /// A billing subscription. @@ -7,8 +7,7 @@ use sea_orm::entity::prelude::*; pub struct Model { #[sea_orm(primary_key)] pub id: BillingSubscriptionId, - pub user_id: UserId, - pub stripe_customer_id: String, + pub billing_customer_id: BillingCustomerId, pub stripe_subscription_id: String, pub stripe_subscription_status: StripeSubscriptionStatus, pub created_at: DateTime, @@ -17,16 +16,16 @@ pub struct Model { #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { #[sea_orm( - belongs_to = "super::user::Entity", - from = "Column::UserId", - to = "super::user::Column::Id" + belongs_to = "super::billing_customer::Entity", + from = "Column::BillingCustomerId", + to = "super::billing_customer::Column::Id" )] - User, + BillingCustomer, } -impl Related for Entity { +impl Related for Entity { fn to() -> RelationDef { - Relation::User.def() + Relation::BillingCustomer.def() } } diff --git a/crates/collab/src/db/tables/user.rs b/crates/collab/src/db/tables/user.rs index 979abe9299..a801e6383e 100644 --- a/crates/collab/src/db/tables/user.rs +++ b/crates/collab/src/db/tables/user.rs @@ -24,8 +24,8 @@ pub struct Model { pub enum Relation { #[sea_orm(has_many = "super::access_token::Entity")] AccessToken, - #[sea_orm(has_many = "super::billing_subscription::Entity")] - BillingSubscription, + #[sea_orm(has_one = "super::billing_customer::Entity")] + BillingCustomer, #[sea_orm(has_one = "super::room_participant::Entity")] RoomParticipant, #[sea_orm(has_many = "super::project::Entity")] @@ -44,6 +44,12 @@ impl Related for Entity { } } +impl Related for Entity { + fn to() -> RelationDef { + Relation::BillingCustomer.def() + } +} + impl Related for Entity { fn to() -> RelationDef { Relation::RoomParticipant.def() diff --git a/crates/collab/src/db/tests/billing_subscription_tests.rs b/crates/collab/src/db/tests/billing_subscription_tests.rs index 26a86fe449..19f5463ac2 100644 --- a/crates/collab/src/db/tests/billing_subscription_tests.rs +++ b/crates/collab/src/db/tests/billing_subscription_tests.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use crate::db::billing_subscription::StripeSubscriptionStatus; use crate::db::tests::new_test_user; -use crate::db::CreateBillingSubscriptionParams; +use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams}; use crate::test_both_dbs; use super::Database; @@ -25,9 +25,17 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { // A user with an active subscription has one active billing subscription. { let user_id = new_test_user(db, "active-user@example.com").await; + let customer = db + .create_billing_customer(&CreateBillingCustomerParams { + user_id, + stripe_customer_id: "cus_active_user".into(), + }) + .await + .unwrap(); + assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string()); + db.create_billing_subscription(&CreateBillingSubscriptionParams { - user_id, - stripe_customer_id: "cus_active_user".into(), + billing_customer_id: customer.id, stripe_subscription_id: "sub_active_user".into(), stripe_subscription_status: StripeSubscriptionStatus::Active, }) @@ -38,10 +46,6 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { assert_eq!(subscriptions.len(), 1); let subscription = &subscriptions[0]; - assert_eq!( - subscription.stripe_customer_id, - "cus_active_user".to_string() - ); assert_eq!( subscription.stripe_subscription_id, "sub_active_user".to_string() @@ -55,9 +59,17 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { // A user with a past-due subscription has no active billing subscriptions. { let user_id = new_test_user(db, "past-due-user@example.com").await; + let customer = db + .create_billing_customer(&CreateBillingCustomerParams { + user_id, + stripe_customer_id: "cus_past_due_user".into(), + }) + .await + .unwrap(); + assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string()); + db.create_billing_subscription(&CreateBillingSubscriptionParams { - user_id, - stripe_customer_id: "cus_past_due_user".into(), + billing_customer_id: customer.id, stripe_subscription_id: "sub_past_due_user".into(), stripe_subscription_status: StripeSubscriptionStatus::PastDue, })