move truncation to parsing step leveraging the EmbeddingProvider trait

KCaverly 2023-08-30 12:13:26 -04:00
parent 76caea80f7
commit 9781047156
3 changed files with 45 additions and 41 deletions
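The shape of the change: truncate moves onto the EmbeddingProvider trait (fn truncate(&self, span: &str) -> String), so the parsing step can clamp a span to the model's input limit up front instead of embed_batch retrying with a truncated payload after a bad request. Below is a minimal sketch of that flow, not the real implementation: it uses a toy whitespace tokenizer in place of OPENAI_BPE_TOKENIZER, a small illustrative LIMIT in place of OPENAI_INPUT_LIMIT (8190), and omits the async embed_batch method; WordLimitProvider is an invented name for illustration only.

// Simplified stand-in for the EmbeddingProvider trait: only the methods
// involved in truncation, with whitespace "tokens" instead of BPE tokens.
trait EmbeddingProvider {
    fn count_tokens(&self, span: &str) -> usize;
    fn should_truncate(&self, span: &str) -> bool;
    fn truncate(&self, span: &str) -> String;
}

const LIMIT: usize = 8; // illustrative stand-in for OPENAI_INPUT_LIMIT

struct WordLimitProvider;

impl EmbeddingProvider for WordLimitProvider {
    fn count_tokens(&self, span: &str) -> usize {
        span.split_whitespace().count()
    }

    fn should_truncate(&self, span: &str) -> bool {
        self.count_tokens(span) > LIMIT
    }

    fn truncate(&self, span: &str) -> String {
        // Keep at most LIMIT "tokens", mirroring tokens.truncate(OPENAI_INPUT_LIMIT).
        span.split_whitespace()
            .take(LIMIT)
            .collect::<Vec<_>>()
            .join(" ")
    }
}

fn main() {
    let provider = WordLimitProvider;
    let span = "one two three four five six seven eight nine ten";

    // Parsing step: count and truncate before the span ever reaches embed_batch.
    let token_count = provider.count_tokens(span);
    let document_span = provider.truncate(span);

    assert!(provider.should_truncate(span));
    assert_eq!(provider.count_tokens(&document_span), LIMIT);
    println!("{token_count} tokens -> {document_span:?}");
}

In the diff below, the real truncate implementations for DummyEmbeddings and OpenAIEmbeddings are built on OPENAI_BPE_TOKENIZER with OPENAI_INPUT_LIMIT, CodeContextRetriever calls truncate while building each Document, and the test FakeEmbeddingProvider returns the span unchanged.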

View file

@@ -56,7 +56,7 @@ pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
    fn count_tokens(&self, span: &str) -> usize;
    fn should_truncate(&self, span: &str) -> bool;
    // fn truncate(&self, span: &str) -> Result<&str>;
    fn truncate(&self, span: &str) -> String;
}
pub struct DummyEmbeddings {}
@@ -78,36 +78,27 @@ impl EmbeddingProvider for DummyEmbeddings {
    fn should_truncate(&self, span: &str) -> bool {
        self.count_tokens(span) > OPENAI_INPUT_LIMIT
    }
    // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
    // let Ok(output) = {
    //     if tokens.len() > OPENAI_INPUT_LIMIT {
    //         tokens.truncate(OPENAI_INPUT_LIMIT);
    //         OPENAI_BPE_TOKENIZER.decode(tokens)
    //     } else {
    //         Ok(span)
    //     }
    // };
    fn truncate(&self, span: &str) -> String {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens)
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };
        output
    }
}
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
    fn truncate(span: String) -> String {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
        if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
            if result.is_ok() {
                let transformed = result.unwrap();
                return transformed;
            }
        }
        span
    }
    async fn send_request(
        &self,
        api_key: &str,
@@ -144,6 +135,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
        self.count_tokens(span) > OPENAI_INPUT_LIMIT
    }
    fn truncate(&self, span: &str) -> String {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens)
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };
        output
    }
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;
@@ -214,23 +220,13 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                        self.executor.timer(delay_duration).await;
                    }
                    _ => {
                        // TODO: Move this to parsing step
                        // Only truncate if it hasn't been truncated before
                        if !truncated {
                            for span in spans.iter_mut() {
                                *span = Self::truncate(span.clone());
                            }
                            truncated = true;
                        } else {
                            // If failing once already truncated, log the error and break the loop
                            let mut body = String::new();
                            response.body_mut().read_to_string(&mut body).await?;
                            return Err(anyhow!(
                                "open ai bad request: {:?} {:?}",
                                &response.status(),
                                body
                            ));
                        }
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        return Err(anyhow!(
                            "open ai bad request: {:?} {:?}",
                            &response.status(),
                            body
                        ));
                    }
                }
            }
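The embed_batch change above drops the in-loop truncation fallback: since spans now arrive pre-truncated from the parser, a non-rate-limit failure just reads the response body and returns an error, while rate limiting still backs off per BACKOFF_SECONDS. The following is a simplified, synchronous sketch of that retry shape, not the real code: the Status enum, the send function, and thread::sleep are stand-ins for the actual HTTP response type, request call, and executor.timer.

use std::{thread, time::Duration};

const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;

// Illustrative stand-ins for the HTTP layer.
#[allow(dead_code)]
enum Status {
    Ok(Vec<Vec<f32>>),
    TooManyRequests,
    BadRequest(String),
}

fn send(_spans: &[String], attempt: usize) -> Status {
    // Pretend the first attempt is rate limited and later ones succeed.
    match attempt {
        0 => Status::TooManyRequests,
        _ => Status::Ok(vec![vec![0.0; 3]]),
    }
}

fn embed_batch(spans: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
    for attempt in 0..MAX_RETRIES {
        match send(&spans, attempt) {
            Status::Ok(embeddings) => return Ok(embeddings),
            Status::TooManyRequests => {
                // Back off and retry; no truncation happens here anymore.
                thread::sleep(Duration::from_secs(BACKOFF_SECONDS[attempt]));
            }
            Status::BadRequest(body) => {
                // Previously this arm truncated the spans and retried once;
                // now it fails fast because truncation happened at parse time.
                return Err(format!("open ai bad request: {body}"));
            }
        }
    }
    Err("no embeddings returned".to_string())
}

fn main() {
    let result = embed_batch(vec!["already truncated span".to_string()]);
    println!("{result:?}");
}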

View file

@@ -73,6 +73,7 @@ impl CodeContextRetriever {
        sha1.update(&document_span);
        let token_count = self.embedding_provider.count_tokens(&document_span);
        let document_span = self.embedding_provider.truncate(&document_span);
        Ok(vec![Document {
            range: 0..content.len(),
@@ -93,6 +94,7 @@ impl CodeContextRetriever {
        sha1.update(&document_span);
        let token_count = self.embedding_provider.count_tokens(&document_span);
        let document_span = self.embedding_provider.truncate(&document_span);
        Ok(vec![Document {
            range: 0..content.len(),
@@ -182,6 +184,8 @@ impl CodeContextRetriever {
                .replace("item", &document.content);
            let token_count = self.embedding_provider.count_tokens(&document_content);
            let document_content = self.embedding_provider.truncate(&document_content);
            document.content = document_content;
            document.token_count = token_count;
        }

View file

@@ -1232,6 +1232,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
        false
    }
    fn truncate(&self, span: &str) -> String {
        span.to_string()
    }
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        self.embedding_count
            .fetch_add(spans.len(), atomic::Ordering::SeqCst);