collab: Add separate billing_customers table (#15457)

This PR adds a new `billing_customers` table to hold the billing
customers.

Previously we were storing both the `stripe_customer_id` and
`stripe_subscription_id` in the `billable_subscriptions` table. However,
this creates problems when we need to correlate subscription events back
to the subscription record, as we don't know the user that the Stripe
event corresponds to.

By moving the `stripe_customer_id` to a separate table we can create the
Stripe customer earlier in the flow—before we create the Stripe Checkout
session—and associate that customer with a user. This way when we
receive events down the line we can use the Stripe customer ID to
correlate it back to the user.

We're doing some destructive actions to the `billing_subscriptions`
table, but this is fine, as we haven't started using them yet.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-07-29 22:48:21 -04:00 committed by GitHub
parent 66121fa0e8
commit 28c14cdee4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 183 additions and 52 deletions

View file

@ -420,12 +420,20 @@ CREATE TABLE dev_server_projects (
CREATE TABLE IF NOT EXISTS billing_subscriptions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id INTEGER NOT NULL REFERENCES users(id),
stripe_customer_id TEXT NOT NULL,
billing_customer_id INTEGER NOT NULL REFERENCES billing_customers(id),
stripe_subscription_id TEXT NOT NULL,
stripe_subscription_status TEXT NOT NULL
);
CREATE INDEX "ix_billing_subscriptions_on_user_id" ON billing_subscriptions (user_id);
CREATE INDEX "ix_billing_subscriptions_on_stripe_customer_id" ON billing_subscriptions (stripe_customer_id);
CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);
CREATE UNIQUE INDEX "uix_billing_subscriptions_on_stripe_subscription_id" ON billing_subscriptions (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS billing_customers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id INTEGER NOT NULL REFERENCES users(id),
stripe_customer_id TEXT NOT NULL
);
CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id);

View file

@ -0,0 +1,18 @@
CREATE TABLE IF NOT EXISTS billing_customers (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT now(),
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
stripe_customer_id TEXT NOT NULL
);
CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id);
-- Make `billing_subscriptions` reference `billing_customers` instead of having its
-- own `user_id` and `stripe_customer_id`.
DROP INDEX IF EXISTS "ix_billing_subscriptions_on_user_id";
DROP INDEX IF EXISTS "ix_billing_subscriptions_on_stripe_customer_id";
ALTER TABLE billing_subscriptions DROP COLUMN user_id;
ALTER TABLE billing_subscriptions DROP COLUMN stripe_customer_id;
ALTER TABLE billing_subscriptions ADD COLUMN billing_customer_id INTEGER NOT NULL REFERENCES billing_customers (id) ON DELETE CASCADE;
CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);

View file

@ -3,7 +3,6 @@ use std::sync::Arc;
use anyhow::{anyhow, Context};
use axum::{extract, routing::post, Extension, Json, Router};
use collections::HashSet;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use stripe::{
@ -11,7 +10,7 @@ use stripe::{
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
CustomerId,
CreateCustomer, Customer, CustomerId,
};
use crate::db::BillingSubscriptionId;
@ -59,28 +58,27 @@ async fn create_billing_subscription(
))?
};
let existing_customer_id = {
let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
let distinct_customer_ids = existing_subscriptions
.iter()
.map(|subscription| subscription.stripe_customer_id.as_str())
.collect::<HashSet<_>>();
// Sanity: Make sure we can determine a single Stripe customer ID for the user.
if distinct_customer_ids.len() > 1 {
Err(anyhow!("user has multiple existing customer IDs"))?;
}
let customer_id =
if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
} else {
let customer = Customer::create(
&stripe_client,
CreateCustomer {
email: user.email_address.as_deref(),
..Default::default()
},
)
.await?;
distinct_customer_ids
.into_iter()
.next()
.map(|id| CustomerId::from_str(id).context("failed to parse customer ID"))
.transpose()
}?;
customer.id
};
let checkout_session = {
let mut params = CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = existing_customer_id;
params.customer = Some(customer_id);
params.client_reference_id = Some(user.github_login.as_str());
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
price: Some(stripe_price_id.to_string()),
@ -140,6 +138,14 @@ async fn manage_billing_subscription(
))?
};
let customer = app
.db
.get_billing_customer_by_user_id(user.id)
.await?
.ok_or_else(|| anyhow!("billing customer not found"))?;
let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
.context("failed to parse customer ID")?;
let subscription = if let Some(subscription_id) = body.subscription_id {
app.db
.get_billing_subscription_by_id(subscription_id)
@ -158,9 +164,6 @@ async fn manage_billing_subscription(
.ok_or_else(|| anyhow!("user has no active subscriptions"))?
};
let customer_id = CustomerId::from_str(&subscription.stripe_customer_id)
.context("failed to parse customer ID")?;
let flow = match body.intent {
ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,

View file

@ -45,6 +45,7 @@ use tokio::sync::{Mutex, OwnedMutexGuard};
pub use tests::TestDb;
pub use ids::*;
pub use queries::billing_customers::CreateBillingCustomerParams;
pub use queries::billing_subscriptions::CreateBillingSubscriptionParams;
pub use queries::contributors::ContributorSelector;
pub use sea_orm::ConnectOptions;

View file

@ -68,6 +68,7 @@ macro_rules! id_type {
}
id_type!(AccessTokenId);
id_type!(BillingCustomerId);
id_type!(BillingSubscriptionId);
id_type!(BufferId);
id_type!(ChannelBufferCollaboratorId);

View file

@ -1,6 +1,7 @@
use super::*;
pub mod access_tokens;
pub mod billing_customers;
pub mod billing_subscriptions;
pub mod buffers;
pub mod channels;

View file

@ -0,0 +1,42 @@
use super::*;
#[derive(Debug)]
pub struct CreateBillingCustomerParams {
pub user_id: UserId,
pub stripe_customer_id: String,
}
impl Database {
/// Creates a new billing customer.
pub async fn create_billing_customer(
&self,
params: &CreateBillingCustomerParams,
) -> Result<billing_customer::Model> {
self.transaction(|tx| async move {
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
user_id: ActiveValue::set(params.user_id),
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
..Default::default()
})
.exec_with_returning(&*tx)
.await?;
Ok(customer)
})
.await
}
/// Returns the billing customer for the user with the specified ID.
pub async fn get_billing_customer_by_user_id(
&self,
user_id: UserId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::UserId.eq(user_id))
.one(&*tx)
.await?)
})
.await
}
}

View file

@ -4,8 +4,7 @@ use super::*;
#[derive(Debug)]
pub struct CreateBillingSubscriptionParams {
pub user_id: UserId,
pub stripe_customer_id: String,
pub billing_customer_id: BillingCustomerId,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
}
@ -18,8 +17,7 @@ impl Database {
) -> Result<()> {
self.transaction(|tx| async move {
billing_subscription::Entity::insert(billing_subscription::ActiveModel {
user_id: ActiveValue::set(params.user_id),
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
billing_customer_id: ActiveValue::set(params.billing_customer_id),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
..Default::default()
@ -56,7 +54,8 @@ impl Database {
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.filter(billing_subscription::Column::UserId.eq(user_id))
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
.order_by_asc(billing_subscription::Column::Id)
.all(&*tx)
.await?;
@ -73,8 +72,9 @@ impl Database {
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(
billing_subscription::Column::UserId.eq(user_id).and(
billing_customer::Column::UserId.eq(user_id).and(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
),

View file

@ -1,4 +1,5 @@
pub mod access_token;
pub mod billing_customer;
pub mod billing_subscription;
pub mod buffer;
pub mod buffer_operation;

View file

@ -0,0 +1,39 @@
use crate::db::{BillingCustomerId, UserId};
use sea_orm::entity::prelude::*;
/// A billing customer.
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_customers")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingCustomerId,
pub user_id: UserId,
pub stripe_customer_id: String,
pub created_at: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
#[sea_orm(has_many = "super::billing_subscription::Entity")]
BillingSubscription,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl Related<super::billing_subscription::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingSubscription.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,4 +1,4 @@
use crate::db::{BillingSubscriptionId, UserId};
use crate::db::{BillingCustomerId, BillingSubscriptionId};
use sea_orm::entity::prelude::*;
/// A billing subscription.
@ -7,8 +7,7 @@ use sea_orm::entity::prelude::*;
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingSubscriptionId,
pub user_id: UserId,
pub stripe_customer_id: String,
pub billing_customer_id: BillingCustomerId,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
pub created_at: DateTime,
@ -17,16 +16,16 @@ pub struct Model {
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
belongs_to = "super::billing_customer::Entity",
from = "Column::BillingCustomerId",
to = "super::billing_customer::Column::Id"
)]
User,
BillingCustomer,
}
impl Related<super::user::Entity> for Entity {
impl Related<super::billing_customer::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
Relation::BillingCustomer.def()
}
}

View file

@ -24,8 +24,8 @@ pub struct Model {
pub enum Relation {
#[sea_orm(has_many = "super::access_token::Entity")]
AccessToken,
#[sea_orm(has_many = "super::billing_subscription::Entity")]
BillingSubscription,
#[sea_orm(has_one = "super::billing_customer::Entity")]
BillingCustomer,
#[sea_orm(has_one = "super::room_participant::Entity")]
RoomParticipant,
#[sea_orm(has_many = "super::project::Entity")]
@ -44,6 +44,12 @@ impl Related<super::access_token::Entity> for Entity {
}
}
impl Related<super::billing_customer::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingCustomer.def()
}
}
impl Related<super::room_participant::Entity> for Entity {
fn to() -> RelationDef {
Relation::RoomParticipant.def()

View file

@ -2,7 +2,7 @@ use std::sync::Arc;
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::tests::new_test_user;
use crate::db::CreateBillingSubscriptionParams;
use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
use crate::test_both_dbs;
use super::Database;
@ -25,9 +25,17 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
// A user with an active subscription has one active billing subscription.
{
let user_id = new_test_user(db, "active-user@example.com").await;
let customer = db
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_active_user".into(),
})
.await
.unwrap();
assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
db.create_billing_subscription(&CreateBillingSubscriptionParams {
user_id,
stripe_customer_id: "cus_active_user".into(),
billing_customer_id: customer.id,
stripe_subscription_id: "sub_active_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::Active,
})
@ -38,10 +46,6 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
assert_eq!(subscriptions.len(), 1);
let subscription = &subscriptions[0];
assert_eq!(
subscription.stripe_customer_id,
"cus_active_user".to_string()
);
assert_eq!(
subscription.stripe_subscription_id,
"sub_active_user".to_string()
@ -55,9 +59,17 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
// A user with a past-due subscription has no active billing subscriptions.
{
let user_id = new_test_user(db, "past-due-user@example.com").await;
let customer = db
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_past_due_user".into(),
})
.await
.unwrap();
assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
db.create_billing_subscription(&CreateBillingSubscriptionParams {
user_id,
stripe_customer_id: "cus_past_due_user".into(),
billing_customer_id: customer.id,
stripe_subscription_id: "sub_past_due_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::PastDue,
})