diff --git a/Cargo.lock b/Cargo.lock
index 8ea6f61da0..75f66163e3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6502,6 +6502,7 @@ dependencies = [
  "tree-sitter",
  "tree-sitter-cpp",
  "tree-sitter-elixir 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "tree-sitter-json 0.19.0",
  "tree-sitter-rust",
  "tree-sitter-toml 0.20.0",
  "tree-sitter-typescript 0.20.2 (registry+https://github.com/rust-lang/crates.io-index)",
diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs
index ec233716d6..e34358c7c5 100644
--- a/crates/language/src/language.rs
+++ b/crates/language/src/language.rs
@@ -526,7 +526,7 @@ pub struct OutlineConfig {
 pub struct EmbeddingConfig {
     pub query: Query,
     pub item_capture_ix: u32,
-    pub name_capture_ix: u32,
+    pub name_capture_ix: Option<u32>,
     pub context_capture_ix: Option<u32>,
     pub collapse_capture_ix: Option<u32>,
     pub keep_capture_ix: Option<u32>,
@@ -1263,7 +1263,7 @@ impl Language {
                 ("collapse", &mut collapse_capture_ix),
             ],
         );
-        if let Some((item_capture_ix, name_capture_ix)) = item_capture_ix.zip(name_capture_ix) {
+        if let Some(item_capture_ix) = item_capture_ix {
            grammar.embedding_config = Some(EmbeddingConfig {
                query,
                item_capture_ix,
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 1b3169bfe4..35b9724512 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -54,6 +54,7 @@ ctor.workspace = true
 env_logger.workspace = true
 
 tree-sitter-typescript = "*"
+tree-sitter-json = "*"
 tree-sitter-rust = "*"
 tree-sitter-toml = "*"
 tree-sitter-cpp = "*"
diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs
index 0d2aeb60fb..c952ef3a4e 100644
--- a/crates/semantic_index/src/parsing.rs
+++ b/crates/semantic_index/src/parsing.rs
@@ -1,6 +1,12 @@
 use anyhow::{anyhow, Ok, Result};
 use language::{Grammar, Language};
-use std::{cmp, collections::HashSet, ops::Range, path::Path, sync::Arc};
+use std::{
+    cmp::{self, Reverse},
+    collections::HashSet,
+    ops::Range,
+    path::Path,
+    sync::Arc,
+};
 use tree_sitter::{Parser, QueryCursor};
 
 #[derive(Debug, PartialEq, Clone)]
@@ -15,7 +21,7 @@ const CODE_CONTEXT_TEMPLATE: &str =
     "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 const ENTIRE_FILE_TEMPLATE: &str =
     "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
-pub const PARSEABLE_ENTIRE_FILE_TYPES: [&str; 4] = ["TOML", "YAML", "JSON", "CSS"];
+pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &["TOML", "YAML", "CSS"];
 
 pub struct CodeContextRetriever {
     pub parser: Parser,
@@ -30,8 +36,8 @@ pub struct CodeContextRetriever {
 #[derive(Debug, Clone)]
 pub struct CodeContextMatch {
     pub start_col: usize,
-    pub item_range: Range<usize>,
-    pub name_range: Range<usize>,
+    pub item_range: Option<Range<usize>>,
+    pub name_range: Option<Range<usize>>,
     pub context_ranges: Vec<Range<usize>>,
     pub collapse_ranges: Vec<Range<usize>>,
 }
@@ -44,7 +50,7 @@ impl CodeContextRetriever {
         }
     }
 
-    fn _parse_entire_file(
+    fn parse_entire_file(
         &self,
         relative_path: &Path,
         language_name: Arc<str>,
@@ -97,7 +103,7 @@ impl CodeContextRetriever {
                 if capture.index == embedding_config.item_capture_ix {
                     item_range = Some(capture.node.byte_range());
                     start_col = capture.node.start_position().column;
-                } else if capture.index == embedding_config.name_capture_ix {
+                } else if Some(capture.index) == embedding_config.name_capture_ix {
                     name_range = Some(capture.node.byte_range());
                 } else if Some(capture.index) == embedding_config.context_capture_ix {
                     context_ranges.push(capture.node.byte_range());
@@ -108,16 +114,13 @@ impl CodeContextRetriever {
                 }
             }
 
-            if item_range.is_some() && name_range.is_some() {
-                let item_range = item_range.unwrap();
-                captures.push(CodeContextMatch {
-                    start_col,
-                    item_range,
-                    name_range: name_range.unwrap(),
-                    context_ranges,
-                    collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
-                });
-            }
+            captures.push(CodeContextMatch {
+                start_col,
+                item_range,
+                name_range,
+                context_ranges,
+                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
+            });
         }
         Ok(captures)
     }
@@ -129,7 +132,12 @@ impl CodeContextRetriever {
         language: Arc<Language>,
     ) -> Result<Vec<Document>> {
         let language_name = language.name();
-        let mut documents = self.parse_file(relative_path, content, language)?;
+
+        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
+            return self.parse_entire_file(relative_path, language_name, &content);
+        }
+
+        let mut documents = self.parse_file(content, language)?;
         for document in &mut documents {
             document.content = CODE_CONTEXT_TEMPLATE
                 .replace("<path>", relative_path.to_string_lossy().as_ref())
@@ -139,16 +147,7 @@ impl CodeContextRetriever {
         Ok(documents)
     }
 
-    pub fn parse_file(
-        &mut self,
-        relative_path: &Path,
-        content: &str,
-        language: Arc<Language>,
-    ) -> Result<Vec<Document>> {
-        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) {
-            return self._parse_entire_file(relative_path, language.name(), &content);
-        }
-
+    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
         let grammar = language
             .grammar()
             .ok_or_else(|| anyhow!("no grammar for language"))?;
@@ -163,32 +162,49 @@ impl CodeContextRetriever {
         let mut collapsed_ranges_within = Vec::new();
         let mut parsed_name_ranges = HashSet::new();
         for (i, context_match) in matches.iter().enumerate() {
-            if parsed_name_ranges.contains(&context_match.name_range) {
+            // Items which are collapsible but not embeddable have no item range
+            let item_range = if let Some(item_range) = context_match.item_range.clone() {
+                item_range
+            } else {
                 continue;
+            };
+
+            // Checks for deduplication
+            let name;
+            if let Some(name_range) = context_match.name_range.clone() {
+                name = content
+                    .get(name_range.clone())
+                    .map_or(String::new(), |s| s.to_string());
+                if parsed_name_ranges.contains(&name_range) {
+                    continue;
+                }
+                parsed_name_ranges.insert(name_range);
+            } else {
+                name = String::new();
             }
 
             collapsed_ranges_within.clear();
-            for remaining_match in &matches[(i + 1)..] {
-                if context_match
-                    .item_range
-                    .contains(&remaining_match.item_range.start)
-                    && context_match
-                        .item_range
-                        .contains(&remaining_match.item_range.end)
-                {
-                    collapsed_ranges_within.extend(remaining_match.collapse_ranges.iter().cloned());
-                } else {
-                    break;
+            'outer: for remaining_match in &matches[(i + 1)..] {
+                for collapsed_range in &remaining_match.collapse_ranges {
+                    if item_range.start <= collapsed_range.start
+                        && item_range.end >= collapsed_range.end
+                    {
+                        collapsed_ranges_within.push(collapsed_range.clone());
+                    } else {
+                        break 'outer;
+                    }
                 }
             }
 
+            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
+
             let mut document_content = String::new();
             for context_range in &context_match.context_ranges {
                 document_content.push_str(&content[context_range.clone()]);
                 document_content.push_str("\n");
             }
 
-            let mut offset = context_match.item_range.start;
+            let mut offset = item_range.start;
             for collapsed_range in &collapsed_ranges_within {
                 if collapsed_range.start > offset {
                     add_content_from_range(
@@ -197,29 +213,30 @@ impl CodeContextRetriever {
                         &mut document_content,
                         content,
                         offset..collapsed_range.start,
                         context_match.start_col,
                     );
+                    offset = collapsed_range.start;
+                }
+
+                if collapsed_range.end > offset {
+                    document_content.push_str(placeholder);
+                    offset = collapsed_range.end;
                 }
-                document_content.push_str(placeholder);
-                offset = collapsed_range.end;
             }
 
-            if offset < context_match.item_range.end {
+            if offset < item_range.end {
                 add_content_from_range(
                     &mut document_content,
                     content,
-                    offset..context_match.item_range.end,
+                    offset..item_range.end,
                     context_match.start_col,
                 );
             }
 
-            if let Some(name) = content.get(context_match.name_range.clone()) {
-                parsed_name_ranges.insert(context_match.name_range.clone());
-                documents.push(Document {
-                    name: name.to_string(),
-                    content: document_content,
-                    range: context_match.item_range.clone(),
-                    embedding: vec![],
-                })
-            }
+            documents.push(Document {
+                name,
+                content: document_content,
+                range: item_range.clone(),
+                embedding: vec![],
+            })
         }
 
         return Ok(documents);
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 271fd741a6..6e04774915 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -33,7 +33,7 @@ use util::{
     ResultExt,
 };
 
-const SEMANTIC_INDEX_VERSION: usize = 4;
+const SEMANTIC_INDEX_VERSION: usize = 5;
 const EMBEDDINGS_BATCH_SIZE: usize = 80;
 
 pub fn init(
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index c54d5079d3..31c96ca207 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -170,9 +170,7 @@ async fn test_code_context_retrieval_rust() {
     "
     .unindent();
 
-    let documents = retriever
-        .parse_file(Path::new("foo.rs"), &text, language)
-        .unwrap();
+    let documents = retriever.parse_file(&text, language).unwrap();
 
     assert_documents_eq(
         &documents,
@@ -229,6 +227,76 @@ async fn test_code_context_retrieval_rust() {
     );
 }
 
+#[gpui::test]
+async fn test_code_context_retrieval_json() {
+    let language = json_lang();
+    let mut retriever = CodeContextRetriever::new();
+
+    let text = r#"
+        {
+            "array": [1, 2, 3, 4],
+            "string": "abcdefg",
+            "nested_object": {
+                "array_2": [5, 6, 7, 8],
+                "string_2": "hijklmnop",
+                "boolean": true,
+                "none": null
+            }
+        }
+    "#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+            {
+                "array": [],
+                "string": "",
+                "nested_object": {
+                    "array_2": [],
+                    "string_2": "",
+                    "boolean": true,
+                    "none": null
+                }
+            }"#
+            .unindent(),
+            text.find("{").unwrap(),
+        )],
+    );
+
+    let text = r#"
+        [
+            {
+                "name": "somebody",
+                "age": 42
+            },
+            {
+                "name": "somebody else",
+                "age": 43
+            }
+        ]
+    "#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+            [{
+                "name": "",
+                "age": 42
+            }]"#
+            .unindent(),
+            text.find("[").unwrap(),
+        )],
+    );
+}
+
 fn assert_documents_eq(
     documents: &[Document],
     expected_contents_and_start_offsets: &[(String, usize)],
@@ -913,6 +981,35 @@ fn rust_lang() -> Arc<Language> {
     )
 }
 
+fn json_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "JSON".into(),
+                path_suffixes: vec!["json".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_json::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (document) @item
+
+            (array
+                "[" @keep
+                .
+                (object)? @keep
+                "]" @keep) @collapse
+
+            (pair value: (string
+                "\"" @keep
+                "\"" @keep) @collapse)
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
 fn toml_lang() -> Arc<Language> {
     Arc::new(Language::new(
         LanguageConfig {
diff --git a/crates/zed/src/languages/json/embedding.scm b/crates/zed/src/languages/json/embedding.scm
new file mode 100644
index 0000000000..fa286e3880
--- /dev/null
+++ b/crates/zed/src/languages/json/embedding.scm
@@ -0,0 +1,14 @@
+; Only produce one embedding for the entire file.
+(document) @item
+
+; Collapse arrays, except for the first object.
+(array
+    "[" @keep
+    .
+    (object)? @keep
+    "]" @keep) @collapse
+
+; Collapse string values (but not keys).
+(pair value: (string
+    "\"" @keep
+    "\"" @keep) @collapse)
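
Note (not part of the patch): the sketch below is a minimal, dependency-free illustration of the reworked collapse pass in parsing.rs. The helper name collapse_item, the "..." placeholder, and the byte offsets in main are illustrative assumptions, not code from this change. Sorting collapsed ranges by (start, Reverse(end)) visits an enclosing range before anything nested inside it, and the `end > offset` guard skips ranges the cursor has already passed, which is what lets overlapping captures from the new JSON embedding query coexist. Indentation trimming (add_content_from_range's start_col handling) is omitted.

use std::cmp::Reverse;
use std::ops::Range;

// Illustrative sketch: replace each collapsed byte range inside `item_range`
// with `placeholder`, copying everything else verbatim.
fn collapse_item(
    content: &str,
    item_range: Range<usize>,
    mut collapsed: Vec<Range<usize>>,
    placeholder: &str,
) -> String {
    // Enclosing ranges first; nested/overlapping ranges are then skipped
    // because `offset` has already moved past them.
    collapsed.sort_by_key(|r| (r.start, Reverse(r.end)));

    let mut output = String::new();
    let mut offset = item_range.start;
    for range in collapsed {
        if range.start > offset {
            output.push_str(&content[offset..range.start]);
            offset = range.start;
        }
        if range.end > offset {
            output.push_str(placeholder);
            offset = range.end;
        }
    }
    if offset < item_range.end {
        output.push_str(&content[offset..item_range.end]);
    }
    output
}

fn main() {
    let json = r#"{"array": [1, 2, 3], "flag": true}"#;
    // Collapse the array contents (bytes 11..18 cover `1, 2, 3`), keeping the brackets.
    assert_eq!(
        collapse_item(json, 0..json.len(), vec![11..18], "..."),
        r#"{"array": [...], "flag": true}"#
    );
}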