mirror of https://github.com/zed-industries/zed.git

commit 9781047156 (parent 76caea80f7)

move truncation to parsing step leveraging the EmbeddingProvider trait
3 changed files with 45 additions and 41 deletions
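In short: truncation moves out of the OpenAI request/retry loop and becomes part of the EmbeddingProvider trait, so the parsing step can truncate each span once, before it is ever batched for embedding. Below is a minimal, synchronous sketch of that shape for illustration only; WordCountProvider and its whitespace "tokenizer" are stand-ins (the real trait's embed_batch is async and the real providers use the OpenAI BPE tokenizer).

// Illustrative sketch only: a toy provider with the same truncation surface
// the commit adds to the trait. Not the zed implementation.
trait EmbeddingProvider {
    fn count_tokens(&self, span: &str) -> usize;
    fn should_truncate(&self, span: &str) -> bool;
    fn truncate(&self, span: &str) -> String;
}

struct WordCountProvider {
    limit: usize,
}

impl EmbeddingProvider for WordCountProvider {
    fn count_tokens(&self, span: &str) -> usize {
        span.split_whitespace().count()
    }

    fn should_truncate(&self, span: &str) -> bool {
        self.count_tokens(span) > self.limit
    }

    fn truncate(&self, span: &str) -> String {
        // Keep at most `limit` tokens, mirroring how the real providers
        // truncate the BPE token stream and decode it back into a String.
        span.split_whitespace()
            .take(self.limit)
            .collect::<Vec<_>>()
            .join(" ")
    }
}

fn main() {
    let provider = WordCountProvider { limit: 4 };
    let document_span = "one two three four five six";

    // Parsing step: record the token count, then truncate before batching.
    let token_count = provider.count_tokens(document_span);
    let document_span = provider.truncate(document_span);

    assert!(provider.should_truncate("a b c d e"));
    println!("{token_count} tokens -> {document_span:?}");
}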
@@ -56,7 +56,7 @@ pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
     fn count_tokens(&self, span: &str) -> usize;
     fn should_truncate(&self, span: &str) -> bool;
-    // fn truncate(&self, span: &str) -> Result<&str>;
+    fn truncate(&self, span: &str) -> String;
 }
 
 pub struct DummyEmbeddings {}
@@ -78,36 +78,27 @@ impl EmbeddingProvider for DummyEmbeddings {
 
     fn should_truncate(&self, span: &str) -> bool {
         self.count_tokens(span) > OPENAI_INPUT_LIMIT
-        // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        // let Ok(output) = {
-        //     if tokens.len() > OPENAI_INPUT_LIMIT {
-        //         tokens.truncate(OPENAI_INPUT_LIMIT);
-        //         OPENAI_BPE_TOKENIZER.decode(tokens)
-        //     } else {
-        //         Ok(span)
-        //     }
-        // };
+    }
+
+    fn truncate(&self, span: &str) -> String {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens)
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        output
     }
 }
 
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    fn truncate(span: String) -> String {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
-        if tokens.len() > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            if result.is_ok() {
-                let transformed = result.unwrap();
-                return transformed;
-            }
-        }
-
-        span
-    }
-
     async fn send_request(
         &self,
         api_key: &str,
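Note the fallback in the new truncate: if decoding the shortened token stream fails, the original span is returned untouched instead of propagating an error, so a span that cannot be re-decoded goes through at its full length. A toy illustration of that shape, with placeholder encode/decode closures rather than the real tokenizer:

// Toy illustration (not the zed code) of the fallback in the new truncate():
// if decoding the shortened token stream fails, the original span is kept.
fn truncate_with_fallback<E, D>(span: &str, limit: usize, encode: E, decode: D) -> String
where
    E: Fn(&str) -> Vec<u32>,
    D: Fn(&[u32]) -> Option<String>,
{
    let mut tokens = encode(span);
    if tokens.len() > limit {
        tokens.truncate(limit);
        // Same shape as `.decode(tokens).ok().unwrap_or_else(...)` above.
        decode(&tokens).unwrap_or_else(|| span.to_string())
    } else {
        span.to_string()
    }
}

fn main() {
    // Placeholder byte-level "tokenizer" for the demonstration.
    let encode = |s: &str| s.bytes().map(u32::from).collect::<Vec<u32>>();
    let decode = |t: &[u32]| t.iter().map(|&b| char::from_u32(b)).collect::<Option<String>>();
    assert_eq!(truncate_with_fallback("abcdef", 3, encode, decode), "abc");
    assert_eq!(truncate_with_fallback("ab", 3, encode, decode), "ab");
}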
@@ -144,6 +135,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         self.count_tokens(span) > OPENAI_INPUT_LIMIT
     }
 
+    fn truncate(&self, span: &str) -> String {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens)
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        output
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
@@ -214,23 +220,13 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                         self.executor.timer(delay_duration).await;
                     }
                     _ => {
-                        // TODO: Move this to parsing step
-                        // Only truncate if it hasnt been truncated before
-                        if !truncated {
-                            for span in spans.iter_mut() {
-                                *span = Self::truncate(span.clone());
-                            }
-                            truncated = true;
-                        } else {
-                            // If failing once already truncated, log the error and break the loop
-                            let mut body = String::new();
-                            response.body_mut().read_to_string(&mut body).await?;
-                            return Err(anyhow!(
-                                "open ai bad request: {:?} {:?}",
-                                &response.status(),
-                                body
-                            ));
-                        }
+                        let mut body = String::new();
+                        response.body_mut().read_to_string(&mut body).await?;
+                        return Err(anyhow!(
+                            "open ai bad request: {:?} {:?}",
+                            &response.status(),
+                            body
+                        ));
                     }
                 }
             }
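With spans already truncated at parse time, the retry loop in embed_batch no longer needs the truncated flag or the in-loop re-truncation: a rate-limit response still backs off and retries, while any other failing status is read and returned as an error immediately. A rough, self-contained sketch of that control flow, using plain status-code/body tuples as stand-ins for the real HTTP client and executor:

// Rough sketch of the simplified retry loop: back off on rate limiting,
// fail immediately on any other non-success status.
fn embed_with_retries(mut send: impl FnMut() -> (u16, String)) -> Result<(), String> {
    const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];
    for backoff in BACKOFF_SECONDS {
        match send() {
            (200, _) => return Ok(()),
            (429, _) => {
                // Rate limited: wait and retry (the real code uses an executor timer).
                std::thread::sleep(std::time::Duration::from_secs(backoff));
            }
            (status, body) => {
                // No more in-loop truncation: spans were already truncated at parse time.
                return Err(format!("open ai bad request: {status:?} {body:?}"));
            }
        }
    }
    Err("exhausted retries".to_string())
}

fn main() {
    let err = embed_with_retries(|| (400, "bad request".to_string())).unwrap_err();
    assert!(err.contains("open ai bad request"));
}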
@@ -73,6 +73,7 @@ impl CodeContextRetriever {
         sha1.update(&document_span);
 
         let token_count = self.embedding_provider.count_tokens(&document_span);
+        let document_span = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
@@ -93,6 +94,7 @@ impl CodeContextRetriever {
         sha1.update(&document_span);
 
         let token_count = self.embedding_provider.count_tokens(&document_span);
+        let document_span = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
@@ -182,6 +184,8 @@ impl CodeContextRetriever {
                 .replace("item", &document.content);
 
             let token_count = self.embedding_provider.count_tokens(&document_content);
+            let document_content = self.embedding_provider.truncate(&document_content);
+
             document.content = document_content;
             document.token_count = token_count;
         }
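Each of the retriever hunks applies the same two-line pattern at parse time: record the token count, then pass the span through the provider's truncate before it lands in the Document. A condensed sketch of that flow; the Document fields and the parse_span helper here are trimmed and hypothetical, keeping only what the hunks show:

// Condensed sketch of the parsing-step pattern added in these hunks:
// count tokens first (for Document::token_count), then truncate through the
// provider so oversized spans never reach embed_batch.
struct Document {
    range: std::ops::Range<usize>,
    content: String,
    token_count: usize,
}

trait EmbeddingProvider {
    fn count_tokens(&self, span: &str) -> usize;
    fn truncate(&self, span: &str) -> String;
}

fn parse_span(provider: &dyn EmbeddingProvider, content: &str) -> Document {
    let document_span = content.to_string();

    let token_count = provider.count_tokens(&document_span);
    let document_span = provider.truncate(&document_span);

    Document {
        range: 0..content.len(),
        content: document_span,
        token_count,
    }
}

fn main() {
    struct Fake;
    impl EmbeddingProvider for Fake {
        fn count_tokens(&self, span: &str) -> usize { span.len() }
        fn truncate(&self, span: &str) -> String { span.to_string() }
    }
    let doc = parse_span(&Fake, "fn main() {}");
    assert_eq!(doc.token_count, doc.content.len());
    assert_eq!(doc.range, 0..doc.content.len());
}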
@@ -1232,6 +1232,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         false
     }
 
+    fn truncate(&self, span: &str) -> String {
+        span.to_string()
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);