From 97810471569618955c241e4137629b578c46285b Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 12:13:26 -0400 Subject: [PATCH] move truncation to parsing step leveraging the EmbeddingProvider trait --- crates/semantic_index/src/embedding.rs | 78 +++++++++---------- crates/semantic_index/src/parsing.rs | 4 + .../src/semantic_index_tests.rs | 4 + 3 files changed, 45 insertions(+), 41 deletions(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 3dd979f01b..cba34439c8 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -56,7 +56,7 @@ pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; fn count_tokens(&self, span: &str) -> usize; fn should_truncate(&self, span: &str) -> bool; - // fn truncate(&self, span: &str) -> Result<&str>; + fn truncate(&self, span: &str) -> String; } pub struct DummyEmbeddings {} @@ -78,36 +78,27 @@ impl EmbeddingProvider for DummyEmbeddings { fn should_truncate(&self, span: &str) -> bool { self.count_tokens(span) > OPENAI_INPUT_LIMIT + } - // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - // let Ok(output) = { - // if tokens.len() > OPENAI_INPUT_LIMIT { - // tokens.truncate(OPENAI_INPUT_LIMIT); - // OPENAI_BPE_TOKENIZER.decode(tokens) - // } else { - // Ok(span) - // } - // }; + fn truncate(&self, span: &str) -> String { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + let output = if tokens.len() > OPENAI_INPUT_LIMIT { + tokens.truncate(OPENAI_INPUT_LIMIT); + OPENAI_BPE_TOKENIZER + .decode(tokens) + .ok() + .unwrap_or_else(|| span.to_string()) + } else { + span.to_string() + }; + + output } } const OPENAI_INPUT_LIMIT: usize = 8190; impl OpenAIEmbeddings { - fn truncate(span: String) -> String { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref()); - if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); - if result.is_ok() { - let transformed = result.unwrap(); - return transformed; - } - } - - span - } - async fn send_request( &self, api_key: &str, @@ -144,6 +135,21 @@ impl EmbeddingProvider for OpenAIEmbeddings { self.count_tokens(span) > OPENAI_INPUT_LIMIT } + fn truncate(&self, span: &str) -> String { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + let output = if tokens.len() > OPENAI_INPUT_LIMIT { + tokens.truncate(OPENAI_INPUT_LIMIT); + OPENAI_BPE_TOKENIZER + .decode(tokens) + .ok() + .unwrap_or_else(|| span.to_string()) + } else { + span.to_string() + }; + + output + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -214,23 +220,13 @@ impl EmbeddingProvider for OpenAIEmbeddings { self.executor.timer(delay_duration).await; } _ => { - // TODO: Move this to parsing step - // Only truncate if it hasnt been truncated before - if !truncated { - for span in spans.iter_mut() { - *span = Self::truncate(span.clone()); - } - truncated = true; - } else { - // If failing once already truncated, log the error and break the loop - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!( - "open ai bad request: {:?} {:?}", - &response.status(), - body - )); - } + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); } } } diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index b106e5055b..00849580bb 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -73,6 +73,7 @@ impl CodeContextRetriever { sha1.update(&document_span); let token_count = self.embedding_provider.count_tokens(&document_span); + let document_span = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -93,6 +94,7 @@ impl CodeContextRetriever { sha1.update(&document_span); let token_count = self.embedding_provider.count_tokens(&document_span); + let document_span = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -182,6 +184,8 @@ impl CodeContextRetriever { .replace("item", &document.content); let token_count = self.embedding_provider.count_tokens(&document_content); + let document_content = self.embedding_provider.truncate(&document_content); + document.content = document_content; document.token_count = token_count; } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 48cefd93b1..7093cf9fcf 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1232,6 +1232,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider { false } + fn truncate(&self, span: &str) -> String { + span.to_string() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst);