zeta: Report Fireworks request data to Snowflake (#22973)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Conrad <conrad@zed.dev>
This commit is contained in:
Thorsten Ball 2025-01-10 23:40:54 +01:00 committed by Antonio Scandurra
parent 8fb1d135ad
commit 0d03674def
7 changed files with 236 additions and 3 deletions

12
Cargo.lock generated
View file

@ -2666,6 +2666,7 @@ dependencies = [
"envy",
"extension",
"file_finder",
"fireworks",
"fs",
"futures 0.3.31",
"git",
@ -4590,6 +4591,17 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "fireworks"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.31",
"http_client",
"serde",
"serde_json",
]
[[package]]
name = "fixedbitset"
version = "0.4.2"

View file

@ -40,6 +40,7 @@ members = [
"crates/feedback",
"crates/file_finder",
"crates/file_icons",
"crates/fireworks",
"crates/fs",
"crates/fsevent",
"crates/fuzzy",
@ -222,6 +223,7 @@ feature_flags = { path = "crates/feature_flags" }
feedback = { path = "crates/feedback" }
file_finder = { path = "crates/file_finder" }
file_icons = { path = "crates/file_icons" }
fireworks = { path = "crates/fireworks" }
fs = { path = "crates/fs" }
fsevent = { path = "crates/fsevent" }
fuzzy = { path = "crates/fuzzy" }

View file

@ -34,6 +34,7 @@ collections.workspace = true
dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
fireworks.workspace = true
futures.workspace = true
google_ai.workspace = true
hex.workspace = true

View file

@ -470,23 +470,48 @@ async fn predict_edits(
.replace("<outline>", &outline_prefix)
.replace("<events>", &params.input_events)
.replace("<excerpt>", &params.input_excerpt);
let mut response = open_ai::complete_text(
let mut response = fireworks::complete(
&state.http_client,
api_url,
api_key,
open_ai::CompletionRequest {
fireworks::CompletionRequest {
model: model.to_string(),
prompt: prompt.clone(),
max_tokens: 2048,
temperature: 0.,
prediction: Some(open_ai::Prediction::Content {
prediction: Some(fireworks::Prediction::Content {
content: params.input_excerpt,
}),
rewrite_speculation: Some(true),
},
)
.await?;
state.executor.spawn_detached({
let kinesis_client = state.kinesis_client.clone();
let kinesis_stream = state.config.kinesis_stream.clone();
let headers = response.headers.clone();
let model = model.clone();
async move {
SnowflakeRow::new(
"Fireworks Completion Requested",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
json!({
"model": model.to_string(),
"headers": headers,
}),
)
.write(&kinesis_client, &kinesis_stream)
.await
.log_err();
}
});
let choice = response
.completion
.choices
.pop()
.context("no output from completion response")?;

View file

@ -0,0 +1,19 @@
[package]
name = "fireworks"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/fireworks.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
serde.workspace = true
serde_json.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,173 @@
use anyhow::{anyhow, Result};
use futures::AsyncReadExt;
use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
/// Base URL of the Fireworks OpenAI-compatible inference API.
///
/// Intended to be passed as the `api_url` argument of [`complete`], which appends
/// `/completions` to it. It previously pointed at `https://api.openai.com/v1`, so a
/// Fireworks API key would have been sent to OpenAI's endpoint.
pub const FIREWORKS_API_URL: &str = "https://api.fireworks.ai/inference/v1";
/// JSON body that [`complete`] serializes and POSTs to `{api_url}/completions`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionRequest {
    /// Identifier of the model that should produce the completion.
    pub model: String,
    /// Prompt text sent to the model.
    pub prompt: String,
    /// Value of the request's `max_tokens` field.
    pub max_tokens: u32,
    /// Value of the request's `temperature` field.
    pub temperature: f32,
    /// Optional [`Prediction`]; left out of the serialized JSON when `None`, and
    /// defaults to `None` when absent during deserialization.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prediction: Option<Prediction>,
    /// Optional `rewrite_speculation` flag; left out of the serialized JSON when
    /// `None`, and defaults to `None` when absent during deserialization.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub rewrite_speculation: Option<bool>,
}
/// Value of the `prediction` field of a [`CompletionRequest`].
///
/// Serialized as an internally tagged object: the snake_cased variant name is stored
/// in a `type` field next to the variant's own fields, e.g.
/// `{"type": "content", "content": "..."}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum Prediction {
    /// A prediction supplied as plain text.
    Content {
        /// The text carried by this prediction.
        content: String,
    },
}
/// Result of a successful [`complete`] call: the decoded response body together with
/// metadata taken from the response headers.
#[derive(Debug)]
pub struct Response {
    /// Completion payload deserialized from the JSON response body.
    pub completion: CompletionResponse,
    /// Values extracted from the HTTP response headers by [`Headers::parse`].
    pub headers: Headers,
}
/// Shape of the JSON body that [`complete`] deserializes from a successful response.
///
/// Every field is required: deserialization fails if any of them is missing.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Value of the response's `id` field.
    pub id: String,
    /// Value of the response's `object` field.
    pub object: String,
    /// Value of the response's `created` field.
    // NOTE(review): presumably a Unix timestamp in seconds; confirm against the API docs.
    pub created: u64,
    /// Value of the response's `model` field.
    pub model: String,
    /// Completion choices contained in the response.
    pub choices: Vec<CompletionChoice>,
    /// Token accounting reported in the response.
    pub usage: Usage,
}
/// One entry of the `choices` array in a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionChoice {
    /// Text of this choice.
    pub text: String,
}
/// Token counts found in the `usage` object of a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Value of the `prompt_tokens` field.
    pub prompt_tokens: u32,
    /// Value of the `completion_tokens` field.
    pub completion_tokens: u32,
    /// Value of the `total_tokens` field.
    pub total_tokens: u32,
}
/// Metadata read from the HTTP response headers of a completion request.
///
/// Populated by [`Headers::parse`]. A field is `None` when its header is missing or
/// its value could not be read or parsed. Units of the numeric fields are not
/// established by this crate.
#[derive(Clone, Debug, Default, Serialize)]
pub struct Headers {
    /// Parsed from `fireworks-server-processing-time`.
    pub server_processing_time: Option<f64>,
    /// Taken from `x-request-id`.
    pub request_id: Option<String>,
    /// Parsed from `fireworks-prompt-tokens`.
    pub prompt_tokens: Option<u32>,
    /// Parsed from `fireworks-speculation-generated-tokens`.
    pub speculation_generated_tokens: Option<u32>,
    /// Parsed from `fireworks-cached-prompt-tokens`.
    pub cached_prompt_tokens: Option<u32>,
    /// Taken from `fireworks-backend-host`.
    pub backend_host: Option<String>,
    /// Parsed from `fireworks-num-concurrent-requests`.
    pub num_concurrent_requests: Option<u32>,
    /// Taken from `fireworks-deployment`.
    pub deployment: Option<String>,
    /// Parsed from `fireworks-tokenizer-queue-duration`.
    pub tokenizer_queue_duration: Option<f64>,
    /// Parsed from `fireworks-tokenizer-duration`.
    pub tokenizer_duration: Option<f64>,
    /// Parsed from `fireworks-prefill-queue-duration`.
    pub prefill_queue_duration: Option<f64>,
    /// Parsed from `fireworks-prefill-duration`.
    pub prefill_duration: Option<f64>,
    /// Parsed from `fireworks-generation-queue-duration`.
    pub generation_queue_duration: Option<f64>,
}
impl Headers {
    /// Builds a [`Headers`] value from an HTTP response's header map.
    ///
    /// This never fails: a header that is absent, whose value cannot be viewed as a
    /// string, or (for numeric fields) does not parse as the field's type simply
    /// leaves that field as `None`.
    pub fn parse(headers: &HeaderMap) -> Self {
        /// Looks up `name` and returns its value as a string slice, if readable.
        fn text<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
            headers.get(name)?.to_str().ok()
        }

        /// Looks up `name` and returns an owned copy of its string value.
        fn owned(headers: &HeaderMap, name: &str) -> Option<String> {
            text(headers, name).map(String::from)
        }

        /// Looks up `name` and parses its string value into `T`.
        fn parsed<T: std::str::FromStr>(headers: &HeaderMap, name: &str) -> Option<T> {
            text(headers, name)?.parse().ok()
        }

        Headers {
            server_processing_time: parsed(headers, "fireworks-server-processing-time"),
            request_id: owned(headers, "x-request-id"),
            prompt_tokens: parsed(headers, "fireworks-prompt-tokens"),
            speculation_generated_tokens: parsed(
                headers,
                "fireworks-speculation-generated-tokens",
            ),
            cached_prompt_tokens: parsed(headers, "fireworks-cached-prompt-tokens"),
            backend_host: owned(headers, "fireworks-backend-host"),
            num_concurrent_requests: parsed(headers, "fireworks-num-concurrent-requests"),
            deployment: owned(headers, "fireworks-deployment"),
            tokenizer_queue_duration: parsed(headers, "fireworks-tokenizer-queue-duration"),
            tokenizer_duration: parsed(headers, "fireworks-tokenizer-duration"),
            prefill_queue_duration: parsed(headers, "fireworks-prefill-queue-duration"),
            prefill_duration: parsed(headers, "fireworks-prefill-duration"),
            generation_queue_duration: parsed(headers, "fireworks-generation-queue-duration"),
        }
    }
}
/// Sends a completion request and returns the decoded body plus response-header metadata.
///
/// `request` is serialized to JSON and POSTed to `{api_url}/completions`, with
/// `api_key` sent as a bearer token in the `Authorization` header.
///
/// # Errors
///
/// Returns an error when:
/// - the request cannot be serialized or built, or sending it fails;
/// - the response body cannot be read as a string;
/// - the status is successful but the body is not a valid [`CompletionResponse`];
/// - the status is not successful. The error then carries the API's `error.message`
///   when the body contains a non-empty one, otherwise the status and the raw body.
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: CompletionRequest,
) -> Result<Response> {
    /// Error body shape expected from a failed request: `{"error": {"message": ...}}`.
    #[derive(Deserialize)]
    struct FireworksResponse {
        error: FireworksError,
    }

    #[derive(Deserialize)]
    struct FireworksError {
        message: String,
    }

    let payload = serde_json::to_string(&request)?;
    let http_request = HttpRequest::builder()
        .method(Method::POST)
        .uri(format!("{api_url}/completions"))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(AsyncBody::from(payload))?;

    let mut response = client.send(http_request).await?;
    let status = response.status();

    // Both outcomes need the full body text, so read it once up front.
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if status.is_success() {
        let headers = Headers::parse(response.headers());
        return Ok(Response {
            completion: serde_json::from_str(&body)?,
            headers,
        });
    }

    // Prefer the API-provided message; fall back to the status and raw body when the
    // error payload is not in the expected shape or its message is empty.
    let api_message = serde_json::from_str::<FireworksResponse>(&body)
        .ok()
        .map(|envelope| envelope.error.message)
        .filter(|message| !message.is_empty());

    Err(match api_message {
        Some(message) => anyhow!("Failed to connect to Fireworks API: {}", message),
        None => anyhow!(
            "Failed to connect to Fireworks API: {} {}",
            status,
            body,
        ),
    })
}