mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-23 18:32:17 +00:00
zeta: Report Fireworks request data to Snowflake (#22973)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Conrad <conrad@zed.dev>
This commit is contained in:
parent
8fb1d135ad
commit
0d03674def
7 changed files with 236 additions and 3 deletions
12
Cargo.lock
generated
12
Cargo.lock
generated
|
@ -2666,6 +2666,7 @@ dependencies = [
|
||||||
"envy",
|
"envy",
|
||||||
"extension",
|
"extension",
|
||||||
"file_finder",
|
"file_finder",
|
||||||
|
"fireworks",
|
||||||
"fs",
|
"fs",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"git",
|
"git",
|
||||||
|
@ -4590,6 +4591,17 @@ dependencies = [
|
||||||
"windows-sys 0.59.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fireworks"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"futures 0.3.31",
|
||||||
|
"http_client",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fixedbitset"
|
name = "fixedbitset"
|
||||||
version = "0.4.2"
|
version = "0.4.2"
|
||||||
|
|
|
@ -40,6 +40,7 @@ members = [
|
||||||
"crates/feedback",
|
"crates/feedback",
|
||||||
"crates/file_finder",
|
"crates/file_finder",
|
||||||
"crates/file_icons",
|
"crates/file_icons",
|
||||||
|
"crates/fireworks",
|
||||||
"crates/fs",
|
"crates/fs",
|
||||||
"crates/fsevent",
|
"crates/fsevent",
|
||||||
"crates/fuzzy",
|
"crates/fuzzy",
|
||||||
|
@ -222,6 +223,7 @@ feature_flags = { path = "crates/feature_flags" }
|
||||||
feedback = { path = "crates/feedback" }
|
feedback = { path = "crates/feedback" }
|
||||||
file_finder = { path = "crates/file_finder" }
|
file_finder = { path = "crates/file_finder" }
|
||||||
file_icons = { path = "crates/file_icons" }
|
file_icons = { path = "crates/file_icons" }
|
||||||
|
fireworks = { path = "crates/fireworks" }
|
||||||
fs = { path = "crates/fs" }
|
fs = { path = "crates/fs" }
|
||||||
fsevent = { path = "crates/fsevent" }
|
fsevent = { path = "crates/fsevent" }
|
||||||
fuzzy = { path = "crates/fuzzy" }
|
fuzzy = { path = "crates/fuzzy" }
|
||||||
|
|
|
@ -34,6 +34,7 @@ collections.workspace = true
|
||||||
dashmap.workspace = true
|
dashmap.workspace = true
|
||||||
derive_more.workspace = true
|
derive_more.workspace = true
|
||||||
envy = "0.4.2"
|
envy = "0.4.2"
|
||||||
|
fireworks.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
google_ai.workspace = true
|
google_ai.workspace = true
|
||||||
hex.workspace = true
|
hex.workspace = true
|
||||||
|
|
|
@ -470,23 +470,48 @@ async fn predict_edits(
|
||||||
.replace("<outline>", &outline_prefix)
|
.replace("<outline>", &outline_prefix)
|
||||||
.replace("<events>", &params.input_events)
|
.replace("<events>", &params.input_events)
|
||||||
.replace("<excerpt>", &params.input_excerpt);
|
.replace("<excerpt>", &params.input_excerpt);
|
||||||
let mut response = open_ai::complete_text(
|
let mut response = fireworks::complete(
|
||||||
&state.http_client,
|
&state.http_client,
|
||||||
api_url,
|
api_url,
|
||||||
api_key,
|
api_key,
|
||||||
open_ai::CompletionRequest {
|
fireworks::CompletionRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
prompt: prompt.clone(),
|
prompt: prompt.clone(),
|
||||||
max_tokens: 2048,
|
max_tokens: 2048,
|
||||||
temperature: 0.,
|
temperature: 0.,
|
||||||
prediction: Some(open_ai::Prediction::Content {
|
prediction: Some(fireworks::Prediction::Content {
|
||||||
content: params.input_excerpt,
|
content: params.input_excerpt,
|
||||||
}),
|
}),
|
||||||
rewrite_speculation: Some(true),
|
rewrite_speculation: Some(true),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
state.executor.spawn_detached({
|
||||||
|
let kinesis_client = state.kinesis_client.clone();
|
||||||
|
let kinesis_stream = state.config.kinesis_stream.clone();
|
||||||
|
let headers = response.headers.clone();
|
||||||
|
let model = model.clone();
|
||||||
|
|
||||||
|
async move {
|
||||||
|
SnowflakeRow::new(
|
||||||
|
"Fireworks Completion Requested",
|
||||||
|
claims.metrics_id,
|
||||||
|
claims.is_staff,
|
||||||
|
claims.system_id.clone(),
|
||||||
|
json!({
|
||||||
|
"model": model.to_string(),
|
||||||
|
"headers": headers,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.write(&kinesis_client, &kinesis_stream)
|
||||||
|
.await
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let choice = response
|
let choice = response
|
||||||
|
.completion
|
||||||
.choices
|
.choices
|
||||||
.pop()
|
.pop()
|
||||||
.context("no output from completion response")?;
|
.context("no output from completion response")?;
|
||||||
|
|
19
crates/fireworks/Cargo.toml
Normal file
19
crates/fireworks/Cargo.toml
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
[package]
|
||||||
|
name = "fireworks"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
license = "GPL-3.0-or-later"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/fireworks.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
http_client.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
1
crates/fireworks/LICENSE-GPL
Symbolic link
1
crates/fireworks/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../../LICENSE-GPL
|
173
crates/fireworks/src/fireworks.rs
Normal file
173
crates/fireworks/src/fireworks.rs
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use futures::AsyncReadExt;
|
||||||
|
use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Base URL of the Fireworks inference API.
///
/// This previously held the OpenAI base URL (`https://api.openai.com/v1`),
/// a leftover from the crate this one was modelled on; requests built from
/// it would have been sent to the wrong service.
pub const FIREWORKS_API_URL: &str = "https://api.fireworks.ai/inference/v1";
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct CompletionRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub prompt: String,
|
||||||
|
pub max_tokens: u32,
|
||||||
|
pub temperature: f32,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prediction: Option<Prediction>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub rewrite_speculation: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum Prediction {
|
||||||
|
Content { content: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Response {
|
||||||
|
pub completion: CompletionResponse,
|
||||||
|
pub headers: Headers,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct CompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<CompletionChoice>,
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct CompletionChoice {
|
||||||
|
pub text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize)]
|
||||||
|
pub struct Headers {
|
||||||
|
pub server_processing_time: Option<f64>,
|
||||||
|
pub request_id: Option<String>,
|
||||||
|
pub prompt_tokens: Option<u32>,
|
||||||
|
pub speculation_generated_tokens: Option<u32>,
|
||||||
|
pub cached_prompt_tokens: Option<u32>,
|
||||||
|
pub backend_host: Option<String>,
|
||||||
|
pub num_concurrent_requests: Option<u32>,
|
||||||
|
pub deployment: Option<String>,
|
||||||
|
pub tokenizer_queue_duration: Option<f64>,
|
||||||
|
pub tokenizer_duration: Option<f64>,
|
||||||
|
pub prefill_queue_duration: Option<f64>,
|
||||||
|
pub prefill_duration: Option<f64>,
|
||||||
|
pub generation_queue_duration: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Headers {
|
||||||
|
pub fn parse(headers: &HeaderMap) -> Self {
|
||||||
|
Headers {
|
||||||
|
request_id: headers
|
||||||
|
.get("x-request-id")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(String::from),
|
||||||
|
server_processing_time: headers
|
||||||
|
.get("fireworks-server-processing-time")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
prompt_tokens: headers
|
||||||
|
.get("fireworks-prompt-tokens")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
speculation_generated_tokens: headers
|
||||||
|
.get("fireworks-speculation-generated-tokens")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
cached_prompt_tokens: headers
|
||||||
|
.get("fireworks-cached-prompt-tokens")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
backend_host: headers
|
||||||
|
.get("fireworks-backend-host")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(String::from),
|
||||||
|
num_concurrent_requests: headers
|
||||||
|
.get("fireworks-num-concurrent-requests")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
deployment: headers
|
||||||
|
.get("fireworks-deployment")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(String::from),
|
||||||
|
tokenizer_queue_duration: headers
|
||||||
|
.get("fireworks-tokenizer-queue-duration")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
tokenizer_duration: headers
|
||||||
|
.get("fireworks-tokenizer-duration")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
prefill_queue_duration: headers
|
||||||
|
.get("fireworks-prefill-queue-duration")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
prefill_duration: headers
|
||||||
|
.get("fireworks-prefill-duration")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
generation_queue_duration: headers
|
||||||
|
.get("fireworks-generation-queue-duration")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn complete(
|
||||||
|
client: &dyn HttpClient,
|
||||||
|
api_url: &str,
|
||||||
|
api_key: &str,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<Response> {
|
||||||
|
let uri = format!("{api_url}/completions");
|
||||||
|
let request_builder = HttpRequest::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.uri(uri)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", api_key));
|
||||||
|
|
||||||
|
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||||
|
let mut response = client.send(request).await?;
|
||||||
|
|
||||||
|
if response.status().is_success() {
|
||||||
|
let headers = Headers::parse(response.headers());
|
||||||
|
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
Ok(Response {
|
||||||
|
completion: serde_json::from_str(&body)?,
|
||||||
|
headers,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct FireworksResponse {
|
||||||
|
error: FireworksError,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct FireworksError {
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<FireworksResponse>(&body) {
|
||||||
|
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||||
|
"Failed to connect to Fireworks API: {}",
|
||||||
|
response.error.message,
|
||||||
|
)),
|
||||||
|
|
||||||
|
_ => Err(anyhow!(
|
||||||
|
"Failed to connect to Fireworks API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue