mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-12 21:32:40 +00:00
google_ai: Add Gemini 2.0 Flash support (#22665)
Release Notes: - Added support for Google's Gemini 2.0 Flash experimental model. Note: Weirdly enough the model is slow on small talk responses like 'hi' (in my tests) but very fast on things that need more tokens like 'write me a snake game in python'. Likely an API problem. TESTED ONLY ON WINDOWS! Would test further but don't have Linux installed and don't have an Mac. Will likely work everywhere. Why?: I think Gemini 2.0 Flash is incredibly good model at coding and following instructions. I think it would be nice to have it in the editor. I did as minimal changes as possible while adding the model and streaming validation. I think it's worth merging the commits as they bring good improvements. --------- Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
0d30bda740
commit
799e81ffe5
2 changed files with 22 additions and 2 deletions
|
@ -1,6 +1,6 @@
|
||||||
mod supported_countries;
|
mod supported_countries;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, bail, Result};
|
||||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -15,6 +15,20 @@ pub async fn stream_generate_content(
|
||||||
api_key: &str,
|
api_key: &str,
|
||||||
mut request: GenerateContentRequest,
|
mut request: GenerateContentRequest,
|
||||||
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
|
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
|
||||||
|
if request.contents.is_empty() {
|
||||||
|
bail!("Request must contain at least one content item");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(user_content) = request
|
||||||
|
.contents
|
||||||
|
.iter()
|
||||||
|
.find(|content| content.role == Role::User)
|
||||||
|
{
|
||||||
|
if user_content.parts.is_empty() {
|
||||||
|
bail!("User content must contain at least one part");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let uri = format!(
|
let uri = format!(
|
||||||
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
|
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
|
||||||
model = request.model
|
model = request.model
|
||||||
|
@ -140,7 +154,7 @@ pub struct Content {
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, PartialEq, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub enum Role {
|
pub enum Role {
|
||||||
User,
|
User,
|
||||||
|
@ -291,6 +305,8 @@ pub enum Model {
|
||||||
Gemini15Pro,
|
Gemini15Pro,
|
||||||
#[serde(rename = "gemini-1.5-flash")]
|
#[serde(rename = "gemini-1.5-flash")]
|
||||||
Gemini15Flash,
|
Gemini15Flash,
|
||||||
|
#[serde(rename = "gemini-2.0-flash-exp")]
|
||||||
|
Gemini20Flash,
|
||||||
#[serde(rename = "custom")]
|
#[serde(rename = "custom")]
|
||||||
Custom {
|
Custom {
|
||||||
name: String,
|
name: String,
|
||||||
|
@ -305,6 +321,7 @@ impl Model {
|
||||||
match self {
|
match self {
|
||||||
Model::Gemini15Pro => "gemini-1.5-pro",
|
Model::Gemini15Pro => "gemini-1.5-pro",
|
||||||
Model::Gemini15Flash => "gemini-1.5-flash",
|
Model::Gemini15Flash => "gemini-1.5-flash",
|
||||||
|
Model::Gemini20Flash => "gemini-2.0-flash-exp",
|
||||||
Model::Custom { name, .. } => name,
|
Model::Custom { name, .. } => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -313,6 +330,7 @@ impl Model {
|
||||||
match self {
|
match self {
|
||||||
Model::Gemini15Pro => "Gemini 1.5 Pro",
|
Model::Gemini15Pro => "Gemini 1.5 Pro",
|
||||||
Model::Gemini15Flash => "Gemini 1.5 Flash",
|
Model::Gemini15Flash => "Gemini 1.5 Flash",
|
||||||
|
Model::Gemini20Flash => "Gemini 2.0 Flash",
|
||||||
Self::Custom {
|
Self::Custom {
|
||||||
name, display_name, ..
|
name, display_name, ..
|
||||||
} => display_name.as_ref().unwrap_or(name),
|
} => display_name.as_ref().unwrap_or(name),
|
||||||
|
@ -323,6 +341,7 @@ impl Model {
|
||||||
match self {
|
match self {
|
||||||
Model::Gemini15Pro => 2_000_000,
|
Model::Gemini15Pro => 2_000_000,
|
||||||
Model::Gemini15Flash => 1_000_000,
|
Model::Gemini15Flash => 1_000_000,
|
||||||
|
Model::Gemini20Flash => 1_000_000,
|
||||||
Model::Custom { max_tokens, .. } => *max_tokens,
|
Model::Custom { max_tokens, .. } => *max_tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,6 +88,7 @@ impl CloudModel {
|
||||||
Self::Google(model) => match model {
|
Self::Google(model) => match model {
|
||||||
google_ai::Model::Gemini15Pro
|
google_ai::Model::Gemini15Pro
|
||||||
| google_ai::Model::Gemini15Flash
|
| google_ai::Model::Gemini15Flash
|
||||||
|
| google_ai::Model::Gemini20Flash
|
||||||
| google_ai::Model::Custom { .. } => {
|
| google_ai::Model::Custom { .. } => {
|
||||||
LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
|
LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue