diff --git a/Cargo.lock b/Cargo.lock index 1ff9de0d5a..ec21ac18e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -974,7 +974,7 @@ dependencies = [ "collections", "editor", "gpui", - "itertools", + "itertools 0.10.5", "language", "outline", "project", @@ -1924,7 +1924,7 @@ dependencies = [ "cranelift-codegen", "cranelift-entity", "cranelift-frontend", - "itertools", + "itertools 0.10.5", "log", "smallvec", "wasmparser", @@ -2074,6 +2074,41 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + [[package]] name = "dashmap" version = "5.5.1" @@ -2143,6 +2178,37 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_builder" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +dependencies = [ + "derive_builder_core", + "syn 1.0.109", +] + [[package]] name = "derive_more" version = "0.99.17" @@ -2351,7 +2417,7 @@ dependencies = [ "git", "gpui", "indoc", - "itertools", + "itertools 0.10.5", "language", "lazy_static", "log", @@ -2480,6 +2546,12 @@ dependencies = [ "libc", ] +[[package]] +name = "esaxx-rs" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f748b253ceca9fed5f42f8b5ceb3851e93102199bc25b64b65369f76e5c0a35" + [[package]] name = "etagere" version = "0.2.8" @@ -3141,7 +3213,7 @@ dependencies = [ "futures 0.3.28", "gpui_macros", "image", - "itertools", + "itertools 0.10.5", "lazy_static", "log", "media", @@ -3517,6 +3589,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.4.0" @@ -3725,6 +3803,24 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "itertools" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.10.5" @@ -4221,6 +4317,22 @@ dependencies = [ "libc", ] +[[package]] +name = "macro_rules_attribute" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -4490,6 +4602,27 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "monostate" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15f370ae88093ec6b11a710dec51321a61d420fafd1bad6e30d01bd9c920e8ee" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "371717c0a5543d6a800cac822eac735aa7d2d2fbb41002e9856a4089532dbdce" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.29", +] + [[package]] name = "more-asserts" version = "0.2.2" @@ -4914,6 +5047,28 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +[[package]] +name = "onig" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" +dependencies = [ + "bitflags 1.3.2", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "opaque-debug" version = "0.3.0" @@ -5532,7 +5687,7 @@ dependencies = [ "globset", "gpui", "ignore", - "itertools", + "itertools 0.10.5", "language", "lazy_static", "log", @@ -5656,7 +5811,7 @@ checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5" dependencies = [ "bytes 1.4.0", "heck 0.3.3", - "itertools", + "itertools 0.10.5", "lazy_static", "log", "multimap", @@ -5675,7 +5830,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "600d2f334aa05acb02a755e217ef1ab6dea4d51b58b7846588b747edec04efba" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.5", "proc-macro2", "quote", "syn 1.0.109", @@ -5688,7 +5843,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9cc1a3263e07e0bf68e96268f37665207b49560d98739662cdfaae215c720fe" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.5", "proc-macro2", "quote", "syn 1.0.109", @@ -5917,6 +6072,17 @@ dependencies = [ "rayon-core", ] +[[package]] +name = "rayon-cond" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7" +dependencies = [ + "either", + "itertools 0.8.2", + "rayon", +] + [[package]] name = "rayon-core" version = "1.11.0" @@ -6824,6 +6990,7 @@ dependencies = [ "lazy_static", "log", "matrixmultiply", + "ndarray", "ort", "parking_lot 0.11.2", "parse_duration", @@ -6843,6 +7010,7 @@ dependencies = [ "tempdir", "theme", "tiktoken-rs 0.5.1", + "tokenizers", "tree-sitter", "tree-sitter-cpp", "tree-sitter-elixir", @@ -7300,6 +7468,18 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "spsc-buffer" version = "0.1.1" @@ -7339,7 +7519,7 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c12bc9199d1db8234678b7051747c07f517cdcf019262d1847b94ec8b1aee3e" dependencies = [ - "itertools", + "itertools 0.10.5", "nom", "unicode_categories", ] @@ -7736,7 +7916,7 @@ dependencies = [ "dirs 4.0.0", "futures 0.3.28", "gpui", - "itertools", + "itertools 0.10.5", "lazy_static", "libc", "mio-extras", @@ -7767,7 +7947,7 @@ dependencies = [ "editor", "futures 0.3.28", "gpui", - "itertools", + "itertools 0.10.5", "language", "lazy_static", "libc", @@ -8015,6 +8195,37 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12b515a66453a4d68f03398054f7204fd0dde6b93d3f20ea90b08025ab49b499" +dependencies = [ + "aho-corasick 0.7.20", + "derive_builder", + "esaxx-rs", + "getrandom 0.2.10", + "itertools 0.9.0", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.8.5", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.7.4", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.32.0" @@ -8717,6 +8928,15 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + [[package]] name = "unicode-script" version = "0.5.5" @@ -8948,7 +9168,7 @@ dependencies = [ "editor", "gpui", "indoc", - "itertools", + "itertools 0.10.5", "language", "language_selector", "log", @@ -9809,7 +10029,7 @@ dependencies = [ "gpui", "indoc", "install_cli", - "itertools", + "itertools 0.10.5", "language", "lazy_static", "log", diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 93658b1c3f..ea393a6927 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -42,6 +42,8 @@ globset.workspace = true sha1 = "0.10.5" parse_duration = "2.1.1" ort = { version = "1.15.2", features = ["coreml"]} +tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] } +ndarray = "0.15" [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } diff --git a/crates/semantic_index/src/cross_encoder.rs b/crates/semantic_index/src/cross_encoder.rs index 9fb928e329..7aaf8a3bbe 100644 --- a/crates/semantic_index/src/cross_encoder.rs +++ b/crates/semantic_index/src/cross_encoder.rs @@ -1,26 +1,99 @@ -use ort::{Environment, ExecutionProvider, GraphOptimizationLevel}; +use ndarray::{Array1, Array2, Axis, CowArray}; +use ort::{Environment, ExecutionProvider, GraphOptimizationLevel, Session, SessionBuilder, Value}; +use tokenizers::Tokenizer; +use util::paths::MODELS_DIR; -struct CrossEncoder {} +struct CrossEncoder { + session: Session, + tokenizer: Tokenizer, +} + +fn sigmoid(val: f32) -> f32 { + 1.0 / (1.0 + (-val).exp()) +} impl CrossEncoder { pub fn load() -> anyhow::Result { + let model_path = MODELS_DIR.join("cross-encoder").join("model.onnx"); + let tokenizer_path = MODELS_DIR.join("cross-encoder").join("tokenizer.json"); + let environment = Environment::builder() .with_name("cross-encoder") .with_execution_providers([ExecutionProvider::CoreML(Default::default())]) .build()? .into_arc(); - let model = "../models/cross-encoder.onnx"; - let mut session = environment - .new_session_builder() - .unwrap() - .with_optimization_level(GraphOptimizationLevel::Basic) - .unwrap() - .with_number_threads(1) - .unwrap() - .with_model_from_file(model) - .unwrap(); + let session = SessionBuilder::new(&environment)? + .with_optimization_level(GraphOptimizationLevel::Level1)? + .with_intra_threads(1)? + .with_model_from_file(model_path)?; - Ok(Self {}) + let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap(); + + Ok(Self { session, tokenizer }) + } + + pub fn score(&self, query: &str, candidates: Vec<&str>) -> anyhow::Result> { + let spans = candidates + .into_iter() + .map(|candidate| format!("{}. {}", query, candidate)) + .collect::>(); + + let encodings = self.tokenizer.encode_batch(spans, true).unwrap(); + + let mut results = Vec::new(); + for encoding in encodings { + // Get Input Variables Individually + let input_ids = encoding.get_ids(); + let attention_mask = encoding.get_attention_mask(); + let token_type_ids = encoding.get_type_ids(); + let length = input_ids.len(); + + // Convert to Arrays + let inputs_ids_array = CowArray::from(ndarray::Array::from_shape_vec( + (1, length), + input_ids.iter().map(|&x| x as i64).collect(), + )?); + + let attention_mask_array = CowArray::from(ndarray::Array::from_shape_vec( + (1, length), + attention_mask.iter().map(|&x| x as i64).collect(), + )?) + .into_dyn(); + + let token_type_ids_array = CowArray::from(ndarray::Array::from_shape_vec( + (1, length), + token_type_ids.iter().map(|&x| x as i64).collect(), + )?) + .into_dyn(); + + let outputs = self.session.run(vec![ + Value::from_array(self.session.allocator(), &inputs_ids_array.into_dyn())?, + Value::from_array(self.session.allocator(), &attention_mask_array)?, + Value::from_array(self.session.allocator(), &token_type_ids_array)?, + ]); + + let output = outputs.unwrap()[0].try_extract::().unwrap(); + let value = output.view().to_owned(); + + let val = value.as_slice().unwrap()[0]; + results.push(sigmoid(val)) + } + + Ok(results) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cross_encoder() { + let cross_encoder = CrossEncoder::load().unwrap(); + + let sample_candidates = vec!["I love you.", "I hate you."]; + let results = cross_encoder.score("I like you", sample_candidates.clone()); + assert_eq!(results.unwrap().len(), sample_candidates.len()); } } diff --git a/crates/util/src/paths.rs b/crates/util/src/paths.rs index e7e6e0ac72..927f484e1d 100644 --- a/crates/util/src/paths.rs +++ b/crates/util/src/paths.rs @@ -18,6 +18,7 @@ lazy_static::lazy_static! { pub static ref LOG: PathBuf = LOGS_DIR.join("Zed.log"); pub static ref OLD_LOG: PathBuf = LOGS_DIR.join("Zed.log.old"); pub static ref LOCAL_SETTINGS_RELATIVE_PATH: &'static Path = Path::new(".zed/settings.json"); + pub static ref MODELS_DIR: PathBuf = HOME.join("Library/Application Support/Zed/models"); } pub mod legacy {