From a884bd77e11970b6094de04aa1774e9d672a55c0 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Tue, 18 Jul 2023 14:06:57 +0300 Subject: [PATCH] Slightly tidy up vector_db code Avoid panicking when truncating code with special chars --- crates/vector_store/src/embedding.rs | 12 +++++++----- crates/vector_store/src/parsing.rs | 7 ++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index ea349c8afa..2ddade6bb2 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -67,11 +67,13 @@ impl EmbeddingProvider for DummyEmbeddings { } } +const INPUT_LIMIT: usize = 8190; + impl OpenAIEmbeddings { - async fn truncate(span: String) -> String { + fn truncate(span: String) -> String { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref()); - if tokens.len() > 8190 { - tokens.truncate(8190); + if tokens.len() > INPUT_LIMIT { + tokens.truncate(INPUT_LIMIT); let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); if result.is_ok() { let transformed = result.unwrap(); @@ -80,7 +82,7 @@ impl OpenAIEmbeddings { } } - return span.to_string(); + span } async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result> { @@ -137,7 +139,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { // Don't worry about delaying bad request, as we can assume // we haven't been rate limited yet. for span in spans.iter_mut() { - *span = Self::truncate(span.to_string()).await; + *span = Self::truncate(span.to_string()); } } StatusCode::OK => { diff --git a/crates/vector_store/src/parsing.rs b/crates/vector_store/src/parsing.rs index 91dcf699f8..12e590b35f 100644 --- a/crates/vector_store/src/parsing.rs +++ b/crates/vector_store/src/parsing.rs @@ -63,7 +63,7 @@ impl CodeContextRetriever { ) { // log::info!("-----MATCH-----"); - let mut name: Vec<&str> = vec![]; + let mut name = Vec::new(); let mut item: Option<&str> = None; let mut offset: Option = None; for capture in mat.captures { @@ -91,11 +91,8 @@ impl CodeContextRetriever { .replace("", &pending_file.language.name().to_lowercase()) .replace("", item.unwrap()); - let mut truncated_span = context_span.clone(); - truncated_span.truncate(100); - // log::info!("Name: {:?}", name); - // log::info!("Span: {:?}", truncated_span); + // log::info!("Span: {:?}", util::truncate(&context_span, 100)); context_spans.push(context_span); documents.push(Document {