mirror of
https://github.com/zed-industries/zed.git
synced 2025-02-10 04:09:37 +00:00
assistant: Limit model access for Zed AI users to Claude-3.5-sonnet (#15904)
This prevents users from accessing other models, such as OpenAI's GPT-4 or Google's Gemini-Pro. Staff members can still access all models. Co-authored-by: Thorsten <thorsten@zed.dev> Release Notes: - N/A --------- Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
parent
efbf7ada28
commit
3a52d6cc52
3 changed files with 130 additions and 5 deletions
|
@ -6,21 +6,40 @@ use crate::{Config, Error, Result};
|
|||
|
||||
pub fn authorize_access_to_language_model(
|
||||
config: &Config,
|
||||
_claims: &LlmTokenClaims,
|
||||
claims: &LlmTokenClaims,
|
||||
country_code: Option<String>,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<()> {
|
||||
authorize_access_for_country(config, country_code, provider, model)?;
|
||||
|
||||
authorize_access_for_country(config, country_code, provider)?;
|
||||
authorize_access_to_model(claims, provider, model)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn authorize_access_to_model(
|
||||
claims: &LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<()> {
|
||||
if claims.is_staff {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match (provider, model) {
|
||||
(LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to model {model:?} is not included in your plan"),
|
||||
))?,
|
||||
}
|
||||
}
|
||||
|
||||
fn authorize_access_for_country(
|
||||
config: &Config,
|
||||
country_code: Option<String>,
|
||||
provider: LanguageModelProvider,
|
||||
_model: &str,
|
||||
) -> Result<()> {
|
||||
// In development we won't have the `CF-IPCountry` header, so we can't check
|
||||
// the country code.
|
||||
|
@ -79,6 +98,7 @@ mod tests {
|
|||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
is_staff: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
@ -210,4 +230,101 @@ mod tests {
|
|||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_based_on_plan() {
|
||||
let config = Config::test();
|
||||
|
||||
let test_cases = vec![
|
||||
// Pro plan should have access to claude-3.5-sonnet
|
||||
(
|
||||
Plan::ZedPro,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3.5-sonnet",
|
||||
true,
|
||||
),
|
||||
// Free plan should have access to claude-3.5-sonnet
|
||||
(
|
||||
Plan::Free,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3.5-sonnet",
|
||||
true,
|
||||
),
|
||||
// Pro plan should NOT have access to other Anthropic models
|
||||
(
|
||||
Plan::ZedPro,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3-opus",
|
||||
false,
|
||||
),
|
||||
];
|
||||
|
||||
for (plan, provider, model, expected_access) in test_cases {
|
||||
let claims = LlmTokenClaims {
|
||||
plan,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some("US".into()),
|
||||
provider,
|
||||
model,
|
||||
);
|
||||
|
||||
if expected_access {
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
|
||||
plan,
|
||||
provider,
|
||||
model
|
||||
);
|
||||
} else {
|
||||
let error = result.expect_err(&format!(
|
||||
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
|
||||
plan, provider, model
|
||||
));
|
||||
let response = error.into_response();
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_for_staff() {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
is_staff: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Staff should have access to all models
|
||||
let test_cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
|
||||
(LanguageModelProvider::Anthropic, "claude-2"),
|
||||
(LanguageModelProvider::Anthropic, "claude-123-agi"),
|
||||
(LanguageModelProvider::OpenAi, "gpt-4"),
|
||||
(LanguageModelProvider::Google, "gemini-pro"),
|
||||
];
|
||||
|
||||
for (provider, model) in test_cases {
|
||||
let result = authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some("US".into()),
|
||||
provider,
|
||||
model,
|
||||
);
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Expected staff to have access to provider {:?}, model {}",
|
||||
provider,
|
||||
model
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,13 +13,19 @@ pub struct LlmTokenClaims {
|
|||
pub exp: u64,
|
||||
pub jti: String,
|
||||
pub user_id: u64,
|
||||
pub is_staff: bool,
|
||||
pub plan: rpc::proto::Plan,
|
||||
}
|
||||
|
||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
impl LlmTokenClaims {
|
||||
pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
|
||||
pub fn create(
|
||||
user_id: UserId,
|
||||
is_staff: bool,
|
||||
plan: rpc::proto::Plan,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
let secret = config
|
||||
.llm_api_secret
|
||||
.as_ref()
|
||||
|
@ -31,6 +37,7 @@ impl LlmTokenClaims {
|
|||
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user_id.to_proto(),
|
||||
is_staff,
|
||||
plan,
|
||||
};
|
||||
|
||||
|
|
|
@ -5164,6 +5164,7 @@ async fn get_llm_api_token(
|
|||
|
||||
let token = LlmTokenClaims::create(
|
||||
session.user_id(),
|
||||
session.is_staff(),
|
||||
session.current_plan().await?,
|
||||
&session.app_state.config,
|
||||
)?;
|
||||
|
|
Loading…
Reference in a new issue