diff --git a/Cargo.lock b/Cargo.lock
index 2b7d74578d..92d17ec0db 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -108,7 +108,7 @@ dependencies = [
  "rusqlite",
  "serde",
  "serde_json",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "util",
 ]
 
@@ -327,7 +327,7 @@ dependencies = [
  "settings",
  "smol",
  "theme",
- "tiktoken-rs 0.4.5",
+ "tiktoken-rs",
  "util",
  "uuid 1.4.1",
  "workspace",
@@ -6798,7 +6798,7 @@ dependencies = [
  "smol",
  "tempdir",
  "theme",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "tree-sitter",
  "tree-sitter-cpp",
  "tree-sitter-elixir",
@@ -7875,21 +7875,6 @@ dependencies = [
  "weezl",
 ]
 
-[[package]]
-name = "tiktoken-rs"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
-dependencies = [
- "anyhow",
- "base64 0.21.4",
- "bstr",
- "fancy-regex",
- "lazy_static",
- "parking_lot 0.12.1",
- "rustc-hash",
-]
-
 [[package]]
 name = "tiktoken-rs"
 version = "0.5.4"
diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml
index 8b69e82109..12f52eee02 100644
--- a/crates/assistant/Cargo.toml
+++ b/crates/assistant/Cargo.toml
@@ -38,7 +38,7 @@ schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
-tiktoken-rs = "0.4"
+tiktoken-rs = "0.5"
 
 [dev-dependencies]
 editor = { path = "../editor", features = ["test-support"] }
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 8cba4c4d9f..16d7ee6b81 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -437,8 +437,15 @@ impl AssistantPanel {
             InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation,
+                retrieve_context,
             } => {
-                self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
+                self.confirm_inline_assist(
+                    assist_id,
+                    prompt,
+                    *include_conversation,
+                    cx,
+                    *retrieve_context,
+                );
             }
             InlineAssistantEvent::Canceled => {
                 self.finish_inline_assist(assist_id, true, cx);
@@ -532,6 +539,7 @@ impl AssistantPanel {
         user_prompt: &str,
         include_conversation: bool,
         cx: &mut ViewContext<Self>,
+        retrieve_context: bool,
     ) {
         let conversation = if include_conversation {
             self.active_editor()
@@ -593,42 +601,49 @@ impl AssistantPanel {
         let codegen_kind = codegen.read(cx).kind().clone();
         let user_prompt = user_prompt.to_string();
-        let project = if let Some(workspace) = self.workspace.upgrade(cx) {
-            workspace.read(cx).project()
-        } else {
-            return;
-        };
+        let snippets = if retrieve_context {
+            let project = if let Some(workspace) = self.workspace.upgrade(cx) {
+                workspace.read(cx).project()
+            } else {
+                return;
+            };
 
-        let project = project.to_owned();
-        let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
-            let search_results = semantic_index.update(cx, |this, cx| {
-                this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
+            let project = project.to_owned();
+            let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
+                let search_results = semantic_index.update(cx, |this, cx| {
+                    this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
+                });
+
+                cx.background()
+                    .spawn(async move { search_results.await.unwrap_or_default() })
+            } else {
+                Task::ready(Vec::new())
+            };
+
+            let snippets = cx.spawn(|_, cx| async move {
+                let mut snippets = Vec::new();
+                for result in search_results.await {
+                    snippets.push(result.buffer.read_with(&cx, |buffer, _| {
+                        buffer
+                            .snapshot()
+                            .text_for_range(result.range)
+                            .collect::<String>()
+                    }));
+                }
+                snippets
             });
-
-            cx.background()
-                .spawn(async move { search_results.await.unwrap_or_default() })
+            snippets
         } else {
             Task::ready(Vec::new())
         };
 
-        let snippets = cx.spawn(|_, cx| async move {
-            let mut snippets = Vec::new();
-            for result in search_results.await {
-                snippets.push(result.buffer.read_with(&cx, |buffer, _| {
-                    buffer
-                        .snapshot()
-                        .text_for_range(result.range)
-                        .collect::<String>()
-                }));
-            }
-            snippets
-        });
+        let mut model = settings::get::<AssistantSettings>(cx)
+            .default_open_ai_model
+            .clone();
+        let model_name = model.full_name();
 
         let prompt = cx.background().spawn(async move {
             let snippets = snippets.await;
 
-            for snippet in &snippets {
-                println!("SNIPPET: \n{:?}", snippet);
-            }
             let language_name = language_name.as_deref();
 
             generate_content_prompt(
@@ -638,13 +653,11 @@ impl AssistantPanel {
                 range,
                 codegen_kind,
                 snippets,
+                model_name,
             )
         });
 
         let mut messages = Vec::new();
-        let mut model = settings::get::<AssistantSettings>(cx)
-            .default_open_ai_model
-            .clone();
         if let Some(conversation) = conversation {
             let conversation = conversation.read(cx);
             let buffer = conversation.buffer.read(cx);
@@ -1557,12 +1570,14 @@ impl Conversation {
                         Role::Assistant => "assistant".into(),
                         Role::System => "system".into(),
                     },
-                    content: self
-                        .buffer
-                        .read(cx)
-                        .text_for_range(message.offset_range)
-                        .collect(),
+                    content: Some(
+                        self.buffer
+                            .read(cx)
+                            .text_for_range(message.offset_range)
+                            .collect(),
+                    ),
                     name: None,
+                    function_call: None,
                 })
             })
             .collect::<Vec<_>>();
@@ -2681,6 +2696,7 @@ enum InlineAssistantEvent {
     Confirmed {
         prompt: String,
         include_conversation: bool,
+        retrieve_context: bool,
     },
     Canceled,
     Dismissed,
@@ -2922,6 +2938,7 @@ impl InlineAssistant {
         cx.emit(InlineAssistantEvent::Confirmed {
             prompt,
             include_conversation: self.include_conversation,
+            retrieve_context: self.retrieve_context,
         });
         self.confirmed = true;
         cx.notify();
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index 2301cd88ff..1e43833fea 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,8 +1,10 @@
 use crate::codegen::CodegenKind;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp;
+use std::fmt::Write;
+use std::iter;
 use std::ops::Range;
-use std::{fmt::Write, iter};
+use tiktoken_rs::ChatCompletionRequestMessage;
 
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
     #[derive(Debug)]
@@ -122,69 +124,103 @@ pub fn generate_content_prompt(
     range: Range<impl ToOffset>,
     kind: CodegenKind,
     search_results: Vec<String>,
+    model: &str,
 ) -> String {
-    let mut prompt = String::new();
+    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+
+    let mut prompts = Vec::new();
 
     // General Preamble
     if let Some(language_name) = language_name {
-        writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
+        prompts.push(format!("You're an expert {language_name} engineer.\n"));
     } else {
-        writeln!(prompt, "You're an expert engineer.\n").unwrap();
+        prompts.push("You're an expert engineer.\n".to_string());
     }
 
+    // Snippets
+    let mut snippet_position = prompts.len() - 1;
+
     let outline = summarize(buffer, range);
-    writeln!(
-        prompt,
-        "The file you are currently working on has the following outline:"
-    )
-    .unwrap();
+    prompts.push("The file you are currently working on has the following outline:".to_string());
     if let Some(language_name) = language_name {
         let language_name = language_name.to_lowercase();
-        writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
+        prompts.push(format!("```{language_name}\n{outline}\n```"));
     } else {
-        writeln!(prompt, "```\n{outline}\n```").unwrap();
+        prompts.push(format!("```\n{outline}\n```"));
     }
 
     match kind {
         CodegenKind::Generate { position: _ } => {
-            writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
-            writeln!(
-                prompt,
-                "Assume the cursor is located where the `<|START|` marker is."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
+            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
+            prompts
+                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
+            prompts.push(
                 "Text can't be replaced, so assume your answer will be inserted at the cursor."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
+                    .to_string(),
+            );
+            prompts.push(format!(
                 "Generate text based on the users prompt: {user_prompt}"
-            )
-            .unwrap();
+            ));
         }
         CodegenKind::Transform { range: _ } => {
-            writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
-            writeln!(
-                prompt,
+            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
+            prompts.push(format!(
                 "Modify the users code selected text based upon the users prompt: {user_prompt}"
-            )
-            .unwrap();
-            writeln!(
-                prompt,
-                "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
-            )
-            .unwrap();
+            ));
+            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
         }
     }
 
     if let Some(language_name) = language_name {
-        writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
+        prompts.push(format!("Your answer MUST always be valid {language_name}"));
     }
 
-    writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
-    writeln!(prompt, "Never make remarks about the output.").unwrap();
+    prompts.push("Always wrap your response in a Markdown codeblock".to_string());
+    prompts.push("Never make remarks about the output.".to_string());
+
+    let current_messages = [ChatCompletionRequestMessage {
+        role: "user".to_string(),
+        content: Some(prompts.join("\n")),
+        function_call: None,
+        name: None,
+    }];
+
+    let remaining_token_count = if let Ok(current_token_count) =
+        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
+    {
+        let max_token_count = tiktoken_rs::model::get_context_size(model);
+        max_token_count - current_token_count
+    } else {
+        // If tiktoken fails to count token count, assume we have no space remaining.
+        0
+    };
+
+    // TODO:
+    // - add repository name to snippet
+    // - add file path
+    // - add language
+    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
+        let template = "You are working inside a large repository, here are a few code snippets that may be useful";
+
+        for search_result in search_results {
+            let mut snippet_prompt = template.to_string();
+            writeln!(snippet_prompt, "```\n{search_result}\n```").unwrap();
+            let token_count = encoding
+                .encode_with_special_tokens(snippet_prompt.as_str())
+                .len();
+            if token_count <= remaining_token_count {
+                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+                    prompts.insert(snippet_position, snippet_prompt);
+                    snippet_position += 1;
+                }
+            } else {
+                break;
+            }
+        }
+    }
+
+    let prompt = prompts.join("\n");
+    println!("PROMPT: {:?}", prompt);
 
     prompt
 }