mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-24 17:28:40 +00:00
add retrieve context button to inline assistant
parent e9637267ef
commit bfe76467b0
4 changed files with 131 additions and 93 deletions
Cargo.lock (generated): 21 lines changed
@@ -108,7 +108,7 @@ dependencies = [
  "rusqlite",
  "serde",
  "serde_json",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "util",
 ]
 
@@ -327,7 +327,7 @@ dependencies = [
  "settings",
  "smol",
  "theme",
- "tiktoken-rs 0.4.5",
+ "tiktoken-rs",
  "util",
  "uuid 1.4.1",
  "workspace",
@@ -6798,7 +6798,7 @@ dependencies = [
  "smol",
  "tempdir",
  "theme",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "tree-sitter",
  "tree-sitter-cpp",
  "tree-sitter-elixir",
@@ -7875,21 +7875,6 @@ dependencies = [
  "weezl",
 ]
 
-[[package]]
-name = "tiktoken-rs"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
-dependencies = [
- "anyhow",
- "base64 0.21.4",
- "bstr",
- "fancy-regex",
- "lazy_static",
- "parking_lot 0.12.1",
- "rustc-hash",
-]
-
 [[package]]
 name = "tiktoken-rs"
 version = "0.5.4"

@@ -38,7 +38,7 @@ schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
-tiktoken-rs = "0.4"
+tiktoken-rs = "0.5"
 
 [dev-dependencies]
 editor = { path = "../editor", features = ["test-support"] }
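
Note on the version bump: tiktoken-rs 0.5 reworks ChatCompletionRequestMessage to model OpenAI function calling, which is why content becomes Option<String> in the Conversation change further down, and why the Cargo.lock entries above collapse to a single unsuffixed "tiktoken-rs" dependency. The manifest touched here is presumably the assistant crate's Cargo.toml, given the ../editor dev-dependency. A minimal sketch of the 0.5 surface this commit leans on; the model name and message text are illustrative:

use tiktoken_rs::ChatCompletionRequestMessage;

fn main() {
    // tiktoken-rs 0.5 models OpenAI function-calling messages, so
    // `content` is Option<String> and a `function_call` field exists.
    let messages = [ChatCompletionRequestMessage {
        role: "user".to_string(),
        content: Some("You're an expert engineer.".to_string()),
        name: None,
        function_call: None,
    }];

    // Count the tokens the fixed prompt consumes for a given model,
    // then subtract from the context window to get the space left.
    let used = tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).unwrap_or(0);
    let window = tiktoken_rs::model::get_context_size("gpt-4");
    println!("{used} used, {} remaining", window.saturating_sub(used));
}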
@@ -437,8 +437,15 @@ impl AssistantPanel {
             InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation,
+                retrieve_context,
             } => {
-                self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
+                self.confirm_inline_assist(
+                    assist_id,
+                    prompt,
+                    *include_conversation,
+                    cx,
+                    *retrieve_context,
+                );
             }
             InlineAssistantEvent::Canceled => {
                 self.finish_inline_assist(assist_id, true, cx);
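
A note on the dereferences in the new call: the handler receives the event by reference, so the match binds Confirmed's fields by reference, and the Copy booleans are copied out with *. A self-contained sketch of the same pattern, with illustrative names:

enum Event {
    Confirmed {
        prompt: String,
        include_conversation: bool,
        retrieve_context: bool,
    },
}

fn handle(event: &Event) {
    match event {
        // Matching through `&Event` binds the fields by reference:
        // `prompt: &String`, the booleans as `&bool`.
        Event::Confirmed {
            prompt,
            include_conversation,
            retrieve_context,
        } => {
            // `bool` is `Copy`, so `*` copies the values out.
            confirm(prompt, *include_conversation, *retrieve_context);
        }
    }
}

fn confirm(prompt: &str, include_conversation: bool, retrieve_context: bool) {
    println!("{prompt}: conversation={include_conversation}, context={retrieve_context}");
}

fn main() {
    handle(&Event::Confirmed {
        prompt: "add error handling".into(),
        include_conversation: true,
        retrieve_context: true,
    });
}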
@@ -532,6 +539,7 @@ impl AssistantPanel {
         user_prompt: &str,
         include_conversation: bool,
         cx: &mut ViewContext<Self>,
+        retrieve_context: bool,
     ) {
         let conversation = if include_conversation {
             self.active_editor()
@@ -593,42 +601,49 @@ impl AssistantPanel {
         let codegen_kind = codegen.read(cx).kind().clone();
         let user_prompt = user_prompt.to_string();
 
-        let project = if let Some(workspace) = self.workspace.upgrade(cx) {
-            workspace.read(cx).project()
-        } else {
-            return;
-        };
+        let snippets = if retrieve_context {
+            let project = if let Some(workspace) = self.workspace.upgrade(cx) {
+                workspace.read(cx).project()
+            } else {
+                return;
+            };
 
-        let project = project.to_owned();
-        let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
-            let search_results = semantic_index.update(cx, |this, cx| {
-                this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
-            });
+            let project = project.to_owned();
+            let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
+                let search_results = semantic_index.update(cx, |this, cx| {
+                    this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
+                });
 
-            cx.background()
-                .spawn(async move { search_results.await.unwrap_or_default() })
-        } else {
-            Task::ready(Vec::new())
-        };
+                cx.background()
+                    .spawn(async move { search_results.await.unwrap_or_default() })
+            } else {
+                Task::ready(Vec::new())
+            };
 
-        let snippets = cx.spawn(|_, cx| async move {
-            let mut snippets = Vec::new();
-            for result in search_results.await {
-                snippets.push(result.buffer.read_with(&cx, |buffer, _| {
-                    buffer
-                        .snapshot()
-                        .text_for_range(result.range)
-                        .collect::<String>()
-                }));
-            }
-            snippets
-        });
+            let snippets = cx.spawn(|_, cx| async move {
+                let mut snippets = Vec::new();
+                for result in search_results.await {
+                    snippets.push(result.buffer.read_with(&cx, |buffer, _| {
+                        buffer
+                            .snapshot()
+                            .text_for_range(result.range)
+                            .collect::<String>()
+                    }));
+                }
+                snippets
+            });
+            snippets
+        } else {
+            Task::ready(Vec::new())
+        };
+
+        let mut model = settings::get::<AssistantSettings>(cx)
+            .default_open_ai_model
+            .clone();
+        let model_name = model.full_name();
 
         let prompt = cx.background().spawn(async move {
+            let snippets = snippets.await;
+            for snippet in &snippets {
+                println!("SNIPPET: \n{:?}", snippet);
+            }
+
             let language_name = language_name.as_deref();
             generate_content_prompt(
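
The shape of the new code above: the entire semantic-search pipeline now sits behind the retrieve_context flag, and both branches yield a Task<Vec<String>>, so the prompt-building code that follows can await snippets without caring whether retrieval ran; Task::ready(Vec::new()) resolves immediately with no snippets. A rough standalone analogue of that gating, with gpui's Task swapped for the futures crate's Either and a hypothetical fetch_snippets standing in for the semantic index search:

use futures::future::{ready, Either};

// Hypothetical stand-in for the semantic index search.
async fn fetch_snippets() -> Vec<String> {
    vec!["fn example() {}".to_string()]
}

async fn snippets_for(retrieve_context: bool) -> Vec<String> {
    let task = if retrieve_context {
        Either::Left(fetch_snippets())
    } else {
        // Counterpart of Task::ready(Vec::new()): resolves immediately.
        Either::Right(ready(Vec::new()))
    };
    task.await
}

fn main() {
    assert_eq!(futures::executor::block_on(snippets_for(true)).len(), 1);
    assert!(futures::executor::block_on(snippets_for(false)).is_empty());
}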
@@ -638,13 +653,11 @@ impl AssistantPanel {
                 range,
                 codegen_kind,
+                snippets,
+                model_name,
             )
         });
 
         let mut messages = Vec::new();
-        let mut model = settings::get::<AssistantSettings>(cx)
-            .default_open_ai_model
-            .clone();
         if let Some(conversation) = conversation {
             let conversation = conversation.read(cx);
             let buffer = conversation.buffer.read(cx);
@@ -1557,12 +1570,14 @@ impl Conversation {
                     Role::Assistant => "assistant".into(),
                     Role::System => "system".into(),
                 },
-                content: self
-                    .buffer
-                    .read(cx)
-                    .text_for_range(message.offset_range)
-                    .collect(),
+                content: Some(
+                    self.buffer
+                        .read(cx)
+                        .text_for_range(message.offset_range)
+                        .collect(),
+                ),
                 name: None,
                 function_call: None,
             })
         })
         .collect::<Vec<_>>();
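
This Conversation change is forced by the tiktoken-rs upgrade: with function calling, a chat message may carry a function_call and no text, so 0.5 makes content an Option. A small sketch, assuming tiktoken-rs 0.5's FunctionCall type with name and arguments fields; the call shown is invented for illustration:

use tiktoken_rs::{ChatCompletionRequestMessage, FunctionCall};

fn main() {
    // An assistant turn that only invokes a tool: no text, hence
    // `content: None`. The function name and arguments are invented.
    let call_only = ChatCompletionRequestMessage {
        role: "assistant".to_string(),
        content: None,
        name: None,
        function_call: Some(FunctionCall {
            name: "search_project".to_string(),
            arguments: "{\"query\": \"inline assistant\"}".to_string(),
        }),
    };
    assert!(call_only.content.is_none());
}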
@@ -2681,6 +2696,7 @@ enum InlineAssistantEvent {
     Confirmed {
         prompt: String,
         include_conversation: bool,
+        retrieve_context: bool,
     },
     Canceled,
     Dismissed,
@@ -2922,6 +2938,7 @@ impl InlineAssistant {
         cx.emit(InlineAssistantEvent::Confirmed {
             prompt,
             include_conversation: self.include_conversation,
+            retrieve_context: self.retrieve_context,
         });
         self.confirmed = true;
         cx.notify();

@@ -1,8 +1,10 @@
 use crate::codegen::CodegenKind;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp;
-use std::fmt::Write;
-use std::iter;
 use std::ops::Range;
+use std::{fmt::Write, iter};
+use tiktoken_rs::ChatCompletionRequestMessage;
 
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
     #[derive(Debug)]
@@ -122,69 +124,103 @@ pub fn generate_content_prompt(
     range: Range<impl ToOffset>,
     kind: CodegenKind,
+    search_results: Vec<String>,
+    model: &str,
 ) -> String {
-    let mut prompt = String::new();
+    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+
+    let mut prompts = Vec::new();
 
     // General Preamble
     if let Some(language_name) = language_name {
-        writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
+        prompts.push(format!("You're an expert {language_name} engineer.\n"));
     } else {
-        writeln!(prompt, "You're an expert engineer.\n").unwrap();
+        prompts.push("You're an expert engineer.\n".to_string());
     }
 
+    // Snippets
+    let mut snippet_position = prompts.len() - 1;
+
     let outline = summarize(buffer, range);
-    writeln!(
-        prompt,
-        "The file you are currently working on has the following outline:"
-    )
-    .unwrap();
+    prompts.push("The file you are currently working on has the following outline:".to_string());
     if let Some(language_name) = language_name {
         let language_name = language_name.to_lowercase();
-        writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
+        prompts.push(format!("```{language_name}\n{outline}\n```"));
     } else {
-        writeln!(prompt, "```\n{outline}\n```").unwrap();
+        prompts.push(format!("```\n{outline}\n```"));
    }
 
     match kind {
         CodegenKind::Generate { position: _ } => {
-            writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
-            writeln!(
-                prompt,
-                "Assume the cursor is located where the `<|START|` marker is."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
+            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
+            prompts
+                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
+            prompts.push(
                 "Text can't be replaced, so assume your answer will be inserted at the cursor."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
+                    .to_string(),
+            );
+            prompts.push(format!(
                 "Generate text based on the users prompt: {user_prompt}"
-            )
-            .unwrap();
+            ));
         }
         CodegenKind::Transform { range: _ } => {
-            writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
-            writeln!(
-                prompt,
+            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
+            prompts.push(format!(
                 "Modify the users code selected text based upon the users prompt: {user_prompt}"
-            )
-            .unwrap();
-            writeln!(
-                prompt,
-                "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
-            )
-            .unwrap();
+            ));
+            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
         }
     }
 
     if let Some(language_name) = language_name {
-        writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
+        prompts.push(format!("Your answer MUST always be valid {language_name}"));
     }
-    writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
-    writeln!(prompt, "Never make remarks about the output.").unwrap();
+    prompts.push("Always wrap your response in a Markdown codeblock".to_string());
+    prompts.push("Never make remarks about the output.".to_string());
+
+    let current_messages = [ChatCompletionRequestMessage {
+        role: "user".to_string(),
+        content: Some(prompts.join("\n")),
+        function_call: None,
+        name: None,
+    }];
+
+    let remaining_token_count = if let Ok(current_token_count) =
+        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
+    {
+        let max_token_count = tiktoken_rs::model::get_context_size(model);
+        max_token_count - current_token_count
+    } else {
+        // If tiktoken fails to count token count, assume we have no space remaining.
+        0
+    };
+
+    // TODO:
+    //   - add repository name to snippet
+    //   - add file path
+    //   - add language
+    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
+        let template = "You are working inside a large repository, here are a few code snippets that may be useful";
+
+        for search_result in search_results {
+            let mut snippet_prompt = template.to_string();
+            writeln!(snippet_prompt, "```\n{search_result}\n```").unwrap();
+
+            let token_count = encoding
+                .encode_with_special_tokens(snippet_prompt.as_str())
+                .len();
+            if token_count <= remaining_token_count {
+                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+                    prompts.insert(snippet_position, snippet_prompt);
+                    snippet_position += 1;
+                }
+            } else {
+                break;
+            }
+        }
+    }
 
+    let prompt = prompts.join("\n");
+    println!("PROMPT: {:?}", prompt);
     prompt
 }
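
The budgeting logic added to generate_content_prompt (presumably in prompts.rs) is the heart of the change: the prompt is now assembled as a Vec<String>, the fixed sections are measured with num_tokens_from_messages, and each retrieved snippet is admitted only if it is under the per-snippet cap and still fits the model's remaining context, spliced in at the remembered snippet_position rather than appended after the instructions. A dependency-free sketch of that loop, with whitespace word count standing in for BPE tokens; note the committed loop compares every snippet against an unchanged remaining_token_count, while the sketch decrements the budget as it inserts, which appears closer to the intent:

const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;

// Stand-in for BPE encoding; the real code measures with tiktoken.
fn count_tokens(text: &str) -> usize {
    text.split_whitespace().count()
}

fn assemble_prompt(snippets: &[&str], mut remaining_token_count: usize) -> String {
    let mut prompts = vec![
        "You're an expert engineer.".to_string(),
        "Never make remarks about the output.".to_string(),
    ];
    // Snippets are spliced in near the top, not appended after the
    // instructions, so the rules stay at the end of the prompt.
    let mut snippet_position = 1;

    for snippet in snippets {
        let snippet_prompt = format!("Possibly useful snippet:\n```\n{snippet}\n```");
        let token_count = count_tokens(&snippet_prompt);
        if token_count <= remaining_token_count {
            if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
                prompts.insert(snippet_position, snippet_prompt);
                snippet_position += 1;
                remaining_token_count -= token_count;
            }
        } else {
            break;
        }
    }
    prompts.join("\n")
}

fn main() {
    let prompt = assemble_prompt(&["fn add(a: i32, b: i32) -> i32 { a + b }"], 64);
    println!("{prompt}");
}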