From 4f8b95cf0d99955555b6b086bed7c3153cd5bc92 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 29 Aug 2023 15:44:51 -0400 Subject: [PATCH] add proper handling for open ai rate limit delays --- Cargo.lock | 65 ++++++++++++++++- crates/semantic_index/Cargo.toml | 1 + crates/semantic_index/src/embedding.rs | 96 ++++++++++++++++---------- 3 files changed, 124 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 347976691d..e0eb1947e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3532,7 +3532,7 @@ dependencies = [ "gif", "jpeg-decoder", "num-iter", - "num-rational", + "num-rational 0.3.2", "num-traits", "png", "scoped_threadpool", @@ -4625,6 +4625,31 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "num" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" +dependencies = [ + "num-bigint 0.2.6", + "num-complex", + "num-integer", + "num-iter", + "num-rational 0.2.4", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.4" @@ -4653,6 +4678,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -4685,6 +4720,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef" +dependencies = [ + "autocfg", + "num-bigint 0.2.6", + "num-integer", + "num-traits", +] + [[package]] name = "num-rational" version = "0.3.2" @@ -5001,6 +5048,17 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "parse_duration" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d" +dependencies = [ + "lazy_static", + "num", + "regex", +] + [[package]] name = "password-hash" version = "0.2.3" @@ -6667,6 +6725,7 @@ dependencies = [ "log", "matrixmultiply", "parking_lot 0.11.2", + "parse_duration", "picker", "postage", "pretty_assertions", @@ -6998,7 +7057,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80" dependencies = [ "chrono", - "num-bigint", + "num-bigint 0.4.4", "num-traits", "thiserror", ] @@ -7230,7 +7289,7 @@ dependencies = [ "log", "md-5", "memchr", - "num-bigint", + "num-bigint 0.4.4", "once_cell", "paste", "percent-encoding", diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 4e817fcbe2..d46346e0ab 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -39,6 +39,7 @@ rand.workspace = true schemars.workspace = true globset.workspace = true sha1 = "0.10.5" +parse_duration = "2.1.1" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index f2269a786a..a9cb0245c4 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -7,6 +7,7 @@ use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; +use parse_duration::parse; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; @@ -84,10 +85,15 @@ impl OpenAIEmbeddings { span } - async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result> { + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { let request = Request::post("https://api.openai.com/v1/embeddings") .redirect_policy(isahc::config::RedirectPolicy::Follow) - .timeout(Duration::from_secs(4)) + .timeout(Duration::from_secs(request_timeout)) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body( @@ -114,45 +120,23 @@ impl EmbeddingProvider for OpenAIEmbeddings { .ok_or_else(|| anyhow!("no api key"))?; let mut request_number = 0; + let mut request_timeout: u64 = 10; let mut truncated = false; let mut response: Response; let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); while request_number < MAX_RETRIES { response = self - .send_request(api_key, spans.iter().map(|x| &**x).collect()) + .send_request( + api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) .await?; request_number += 1; - if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK { - return Err(anyhow!( - "openai max retries, error: {:?}", - &response.status() - )); - } - match response.status() { - StatusCode::TOO_MANY_REQUESTS => { - let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); - log::trace!( - "open ai rate limiting, delaying request by {:?} seconds", - delay.as_secs() - ); - self.executor.timer(delay).await; - } - StatusCode::BAD_REQUEST => { - // Only truncate if it hasnt been truncated before - if !truncated { - for span in spans.iter_mut() { - *span = Self::truncate(span.clone()); - } - truncated = true; - } else { - // If failing once already truncated, log the error and break the loop - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - log::trace!("open ai bad request: {:?} {:?}", &response.status(), body); - break; - } + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; } StatusCode::OK => { let mut body = String::new(); @@ -163,18 +147,60 @@ impl EmbeddingProvider for OpenAIEmbeddings { "openai embedding completed. tokens: {:?}", response.usage.total_tokens ); + return Ok(response .data .into_iter() .map(|embedding| embedding.embedding) .collect()); } + StatusCode::TOO_MANY_REQUESTS => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } _ => { - return Err(anyhow!("openai embedding failed {}", response.status())); + // TODO: Move this to parsing step + // Only truncate if it hasnt been truncated before + if !truncated { + for span in spans.iter_mut() { + *span = Self::truncate(span.clone()); + } + truncated = true; + } else { + // If failing once already truncated, log the error and break the loop + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } } } } - - Err(anyhow!("openai embedding failed")) + Err(anyhow!("openai max retries")) } }