Summarize the contents of a file using the embedding query

This commit is contained in:
Antonio Scandurra 2023-10-02 14:32:13 +02:00
parent 53c25690f9
commit 64a55681e6
4 changed files with 264 additions and 235 deletions

View file

@ -578,7 +578,6 @@ impl AssistantPanel {
language_name,
&snapshot,
language_range,
cx,
codegen_kind,
);

View file

@ -1,86 +1,118 @@
use gpui::AppContext;
use crate::codegen::CodegenKind;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp;
use std::ops::Range;
use std::{fmt::Write, iter};
use crate::codegen::CodegenKind;
fn outline_for_prompt(
buffer: &BufferSnapshot,
range: Range<language::Anchor>,
cx: &AppContext,
) -> Option<String> {
let indent = buffer
.language_indent_size_at(0, cx)
.chars()
.collect::<String>();
let outline = buffer.outline(None)?;
let range = range.to_offset(buffer);
let mut text = String::new();
let mut items = outline.items.into_iter().peekable();
let mut intersected = false;
let mut intersection_indent = 0;
let mut extended_range = range.clone();
while let Some(item) = items.next() {
let item_range = item.range.to_offset(buffer);
if item_range.end < range.start || item_range.start > range.end {
text.extend(iter::repeat(indent.as_str()).take(item.depth));
text.push_str(&item.text);
text.push('\n');
} else {
intersected = true;
let is_terminal = items
.peek()
.map_or(true, |next_item| next_item.depth <= item.depth);
if is_terminal {
if item_range.start <= extended_range.start {
extended_range.start = item_range.start;
intersection_indent = item.depth;
}
extended_range.end = cmp::max(extended_range.end, item_range.end);
} else {
let name_start = item_range.start + item.name_ranges.first().unwrap().start;
let name_end = item_range.start + item.name_ranges.last().unwrap().end;
if range.start > name_end {
text.extend(iter::repeat(indent.as_str()).take(item.depth));
text.push_str(&item.text);
text.push('\n');
} else {
if name_start <= extended_range.start {
extended_range.start = item_range.start;
intersection_indent = item.depth;
}
extended_range.end = cmp::max(extended_range.end, name_end);
}
}
}
if intersected
&& items.peek().map_or(true, |next_item| {
next_item.range.start.to_offset(buffer) > range.end
})
{
intersected = false;
text.extend(iter::repeat(indent.as_str()).take(intersection_indent));
text.extend(buffer.text_for_range(extended_range.start..range.start));
text.push_str("<|START|");
text.extend(buffer.text_for_range(range.clone()));
if range.start != range.end {
text.push_str("|END|>");
} else {
text.push_str(">");
}
text.extend(buffer.text_for_range(range.end..extended_range.end));
text.push('\n');
}
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
#[derive(Debug)]
struct Match {
collapse: Range<usize>,
keep: Vec<Range<usize>>,
}
Some(text)
let selected_range = selected_range.to_offset(buffer);
let mut matches = buffer.matches(0..buffer.len(), |grammar| {
Some(&grammar.embedding_config.as_ref()?.query)
});
let configs = matches
.grammars()
.iter()
.map(|g| g.embedding_config.as_ref().unwrap())
.collect::<Vec<_>>();
let mut matches = iter::from_fn(move || {
while let Some(mat) = matches.peek() {
let config = &configs[mat.grammar_index];
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
if Some(cap.index) == config.collapse_capture_ix {
Some(cap.node.byte_range())
} else {
None
}
}) {
let mut keep = Vec::new();
for capture in mat.captures.iter() {
if Some(capture.index) == config.keep_capture_ix {
keep.push(capture.node.byte_range());
} else {
continue;
}
}
matches.advance();
return Some(Match { collapse, keep });
} else {
matches.advance();
}
}
None
})
.peekable();
let mut summary = String::new();
let mut offset = 0;
let mut flushed_selection = false;
while let Some(mut mat) = matches.next() {
// Keep extending the collapsed range if the next match surrounds
// the current one.
while let Some(next_mat) = matches.peek() {
if next_mat.collapse.start <= mat.collapse.start
&& next_mat.collapse.end >= mat.collapse.end
{
mat = matches.next().unwrap();
} else {
break;
}
}
if offset >= mat.collapse.start {
// Skip collapsed nodes that have already been summarized.
offset = cmp::max(offset, mat.collapse.end);
continue;
}
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
if !flushed_selection {
// The collapsed node ends after the selection starts, so we'll flush the selection first.
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|START|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|END|>");
}
offset = selected_range.end;
flushed_selection = true;
}
// If the selection intersects the collapsed node, we won't collapse it.
if selected_range.end >= mat.collapse.start {
continue;
}
}
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
for keep in mat.keep {
summary.extend(buffer.text_for_range(keep));
}
offset = mat.collapse.end;
}
// Flush selection if we haven't already done so.
if !flushed_selection && offset <= selected_range.start {
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|START|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|END|>");
}
offset = selected_range.end;
}
summary.extend(buffer.text_for_range(offset..buffer.len()));
summary
}
pub fn generate_content_prompt(
@ -88,7 +120,6 @@ pub fn generate_content_prompt(
language_name: Option<&str>,
buffer: &BufferSnapshot,
range: Range<language::Anchor>,
cx: &AppContext,
kind: CodegenKind,
) -> String {
let mut prompt = String::new();
@ -100,19 +131,17 @@ pub fn generate_content_prompt(
writeln!(prompt, "You're an expert engineer.\n").unwrap();
}
let outline = outline_for_prompt(buffer, range.clone(), cx);
if let Some(outline) = outline {
writeln!(
prompt,
"The file you are currently working on has the following outline:"
)
.unwrap();
if let Some(language_name) = language_name {
let language_name = language_name.to_lowercase();
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
} else {
writeln!(prompt, "```\n{outline}\n```").unwrap();
}
let outline = summarize(buffer, range.clone());
writeln!(
prompt,
"The file you are currently working on has the following outline:"
)
.unwrap();
if let Some(language_name) = language_name {
let language_name = language_name.to_lowercase();
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
} else {
writeln!(prompt, "```\n{outline}\n```").unwrap();
}
// Assume for now that we are just generating
@ -183,39 +212,37 @@ pub(crate) mod tests {
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
.with_embedding_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
.with_outline_query(
r#"
(struct_item
"struct" @context
name: (_) @name) @item
(enum_item
"enum" @context
name: (_) @name) @item
(enum_variant
name: (_) @name) @item
(field_declaration
name: (_) @name) @item
(impl_item
"impl" @context
trait: (_)? @name
"for"? @context
type: (_) @name) @item
(function_item
"fn" @context
name: (_) @name) @item
(mod_item
"mod" @context
name: (_) @name) @item
"#,
(
[(line_comment) (attribute_item)]* @context
.
[
(struct_item
name: (_) @name)
(enum_item
name: (_) @name)
(impl_item
trait: (_)? @name
"for"? @name
type: (_) @name)
(trait_item
name: (_) @name)
(function_item
name: (_) @name
body: (block
"{" @keep
"}" @keep) @collapse)
(macro_definition
name: (_) @name)
] @item
)
"#,
)
.unwrap()
}
@ -251,132 +278,133 @@ pub(crate) mod tests {
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_before(Point::new(1, 4)),
cx,
);
assert_eq!(
outline.as_deref(),
Some(indoc! {"
struct X
<|START|>a: usize
b
impl X
fn new
fn a
fn b
"})
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
indoc! {"
struct X {
<|START|>a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(8, 14)),
cx,
);
assert_eq!(
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let <|START|a |END|>= 1;
let b = 2;
Self { a, b }
}
fn a
fn b
"})
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(6, 0))..snapshot.anchor_before(Point::new(6, 0)),
cx,
);
assert_eq!(
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
<|START|>
fn new
fn a
fn b
"})
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(13, 9)),
cx,
);
assert_eq!(
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X
fn new() -> Self {
let <|START|a = 1;
let b = 2;
Self { a, b }
}
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
pub f|END|>n a(&self, param: bool) -> usize {
self.a
}
fn b
"})
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
<|START|>"}
);
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(5, 6))..snapshot.anchor_before(Point::new(12, 0)),
cx,
);
// Ensure nested functions get collapsed properly.
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
let a = 30;
fn nested() -> usize {
3
}
self.a + nested()
}
pub fn b(&self) -> usize {
self.b
}
}
"};
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X<|START| {
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
indoc! {"
<|START|>struct X {
a: usize,
b: usize,
}
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
|END|>
fn a
fn b
"})
);
impl X {
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(18, 8))..snapshot.anchor_before(Point::new(18, 8)),
cx,
);
assert_eq!(
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X
fn new
fn a
pub fn b(&self) -> usize {
<|START|>self.b
}
"})
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
}
}

View file

@ -8,8 +8,8 @@ use crate::{
language_settings::{language_settings, LanguageSettings},
outline::OutlineItem,
syntax_map::{
SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxSnapshot,
ToTreeSitterPoint,
SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxMapMatches,
SyntaxSnapshot, ToTreeSitterPoint,
},
CodeLabel, LanguageScope, Outline,
};
@ -2467,6 +2467,14 @@ impl BufferSnapshot {
Some(items)
}
pub fn matches(
&self,
range: Range<usize>,
query: fn(&Grammar) -> Option<&tree_sitter::Query>,
) -> SyntaxMapMatches {
self.syntax.matches(range, self, query)
}
/// Returns bracket range pairs overlapping or adjacent to `range`
pub fn bracket_ranges<'a, T: ToOffset>(
&'a self,

View file

@ -1,6 +0,0 @@
(function_item
body: (block
"{" @keep
"}" @keep) @collapse)
(use_declaration) @collapse