diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs index 8a9945d739..5c0295d9fb 100644 --- a/crates/collab/src/llm/authorization.rs +++ b/crates/collab/src/llm/authorization.rs @@ -6,21 +6,40 @@ use crate::{Config, Error, Result}; pub fn authorize_access_to_language_model( config: &Config, - _claims: &LlmTokenClaims, + claims: &LlmTokenClaims, country_code: Option, provider: LanguageModelProvider, model: &str, ) -> Result<()> { - authorize_access_for_country(config, country_code, provider, model)?; - + authorize_access_for_country(config, country_code, provider)?; + authorize_access_to_model(claims, provider, model)?; Ok(()) } +fn authorize_access_to_model( + claims: &LlmTokenClaims, + provider: LanguageModelProvider, + model: &str, +) -> Result<()> { + if claims.is_staff { + return Ok(()); + } + + match (provider, model) { + (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => { + Ok(()) + } + _ => Err(Error::http( + StatusCode::FORBIDDEN, + format!("access to model {model:?} is not included in your plan"), + ))?, + } +} + fn authorize_access_for_country( config: &Config, country_code: Option, provider: LanguageModelProvider, - _model: &str, ) -> Result<()> { // In development we won't have the `CF-IPCountry` header, so we can't check // the country code. @@ -79,6 +98,7 @@ mod tests { let claims = LlmTokenClaims { user_id: 99, plan: Plan::ZedPro, + is_staff: true, ..Default::default() }; @@ -210,4 +230,101 @@ mod tests { ); } } + + #[gpui::test] + async fn test_authorize_access_to_language_model_based_on_plan() { + let config = Config::test(); + + let test_cases = vec![ + // Pro plan should have access to claude-3.5-sonnet + ( + Plan::ZedPro, + LanguageModelProvider::Anthropic, + "claude-3.5-sonnet", + true, + ), + // Free plan should have access to claude-3.5-sonnet + ( + Plan::Free, + LanguageModelProvider::Anthropic, + "claude-3.5-sonnet", + true, + ), + // Pro plan should NOT have access to other Anthropic models + ( + Plan::ZedPro, + LanguageModelProvider::Anthropic, + "claude-3-opus", + false, + ), + ]; + + for (plan, provider, model, expected_access) in test_cases { + let claims = LlmTokenClaims { + plan, + ..Default::default() + }; + + let result = authorize_access_to_language_model( + &config, + &claims, + Some("US".into()), + provider, + model, + ); + + if expected_access { + assert!( + result.is_ok(), + "Expected access to be granted for plan {:?}, provider {:?}, model {}", + plan, + provider, + model + ); + } else { + let error = result.expect_err(&format!( + "Expected access to be denied for plan {:?}, provider {:?}, model {}", + plan, provider, model + )); + let response = error.into_response(); + assert_eq!(response.status(), StatusCode::FORBIDDEN); + } + } + } + + #[gpui::test] + async fn test_authorize_access_to_language_model_for_staff() { + let config = Config::test(); + + let claims = LlmTokenClaims { + is_staff: true, + ..Default::default() + }; + + // Staff should have access to all models + let test_cases = vec![ + (LanguageModelProvider::Anthropic, "claude-3.5-sonnet"), + (LanguageModelProvider::Anthropic, "claude-2"), + (LanguageModelProvider::Anthropic, "claude-123-agi"), + (LanguageModelProvider::OpenAi, "gpt-4"), + (LanguageModelProvider::Google, "gemini-pro"), + ]; + + for (provider, model) in test_cases { + let result = authorize_access_to_language_model( + &config, + &claims, + Some("US".into()), + provider, + model, + ); + + assert!( + result.is_ok(), + "Expected staff to have access to provider {:?}, model {}", + provider, + model + ); + } + } } diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index 99386443eb..e2350a853a 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -13,13 +13,19 @@ pub struct LlmTokenClaims { pub exp: u64, pub jti: String, pub user_id: u64, + pub is_staff: bool, pub plan: rpc::proto::Plan, } const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); impl LlmTokenClaims { - pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result { + pub fn create( + user_id: UserId, + is_staff: bool, + plan: rpc::proto::Plan, + config: &Config, + ) -> Result { let secret = config .llm_api_secret .as_ref() @@ -31,6 +37,7 @@ impl LlmTokenClaims { exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, jti: uuid::Uuid::new_v4().to_string(), user_id: user_id.to_proto(), + is_staff, plan, }; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 361e4fe237..2b34546ba2 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -5164,6 +5164,7 @@ async fn get_llm_api_token( let token = LlmTokenClaims::create( session.user_id(), + session.is_staff(), session.current_plan().await?, &session.app_state.config, )?;