assistant: Limit model access for Zed AI users to Claude-3.5-sonnet (#15904)

This prevents users from accessing other models, such as OpenAI's GPT-4
or Google's Gemini-Pro.
Staff members can still access all models.

Co-authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-08-07 16:26:56 +02:00 committed by GitHub
parent efbf7ada28
commit 3a52d6cc52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 130 additions and 5 deletions

View file

@ -6,21 +6,40 @@ use crate::{Config, Error, Result};
pub fn authorize_access_to_language_model(
config: &Config,
_claims: &LlmTokenClaims,
claims: &LlmTokenClaims,
country_code: Option<String>,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
authorize_access_for_country(config, country_code, provider, model)?;
authorize_access_for_country(config, country_code, provider)?;
authorize_access_to_model(claims, provider, model)?;
Ok(())
}
fn authorize_access_to_model(
claims: &LlmTokenClaims,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
if claims.is_staff {
return Ok(());
}
match (provider, model) {
(LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
Ok(())
}
_ => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
))?,
}
}
fn authorize_access_for_country(
config: &Config,
country_code: Option<String>,
provider: LanguageModelProvider,
_model: &str,
) -> Result<()> {
// In development we won't have the `CF-IPCountry` header, so we can't check
// the country code.
@ -79,6 +98,7 @@ mod tests {
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
is_staff: true,
..Default::default()
};
@ -210,4 +230,101 @@ mod tests {
);
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_based_on_plan() {
let config = Config::test();
let test_cases = vec![
// Pro plan should have access to claude-3.5-sonnet
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
"claude-3.5-sonnet",
true,
),
// Free plan should have access to claude-3.5-sonnet
(
Plan::Free,
LanguageModelProvider::Anthropic,
"claude-3.5-sonnet",
true,
),
// Pro plan should NOT have access to other Anthropic models
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
"claude-3-opus",
false,
),
];
for (plan, provider, model, expected_access) in test_cases {
let claims = LlmTokenClaims {
plan,
..Default::default()
};
let result = authorize_access_to_language_model(
&config,
&claims,
Some("US".into()),
provider,
model,
);
if expected_access {
assert!(
result.is_ok(),
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
plan,
provider,
model
);
} else {
let error = result.expect_err(&format!(
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
plan, provider, model
));
let response = error.into_response();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_for_staff() {
let config = Config::test();
let claims = LlmTokenClaims {
is_staff: true,
..Default::default()
};
// Staff should have access to all models
let test_cases = vec![
(LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
(LanguageModelProvider::Anthropic, "claude-2"),
(LanguageModelProvider::Anthropic, "claude-123-agi"),
(LanguageModelProvider::OpenAi, "gpt-4"),
(LanguageModelProvider::Google, "gemini-pro"),
];
for (provider, model) in test_cases {
let result = authorize_access_to_language_model(
&config,
&claims,
Some("US".into()),
provider,
model,
);
assert!(
result.is_ok(),
"Expected staff to have access to provider {:?}, model {}",
provider,
model
);
}
}
}

View file

@ -13,13 +13,19 @@ pub struct LlmTokenClaims {
pub exp: u64,
pub jti: String,
pub user_id: u64,
pub is_staff: bool,
pub plan: rpc::proto::Plan,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
impl LlmTokenClaims {
pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
pub fn create(
user_id: UserId,
is_staff: bool,
plan: rpc::proto::Plan,
config: &Config,
) -> Result<String> {
let secret = config
.llm_api_secret
.as_ref()
@ -31,6 +37,7 @@ impl LlmTokenClaims {
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user_id.to_proto(),
is_staff,
plan,
};

View file

@ -5164,6 +5164,7 @@ async fn get_llm_api_token(
let token = LlmTokenClaims::create(
session.user_id(),
session.is_staff(),
session.current_plan().await?,
&session.app_state.config,
)?;