add cross encoder model to semantic_index

This commit is contained in:
KCaverly 2023-09-11 14:54:15 -04:00
parent 65add70a37
commit 82760d6d1a
4 changed files with 322 additions and 26 deletions

246
Cargo.lock generated
View file

@ -974,7 +974,7 @@ dependencies = [
"collections",
"editor",
"gpui",
"itertools",
"itertools 0.10.5",
"language",
"outline",
"project",
@ -1924,7 +1924,7 @@ dependencies = [
"cranelift-codegen",
"cranelift-entity",
"cranelift-frontend",
"itertools",
"itertools 0.10.5",
"log",
"smallvec",
"wasmparser",
@ -2074,6 +2074,41 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "darling"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 1.0.109",
]
[[package]]
name = "darling_macro"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
dependencies = [
"darling_core",
"quote",
"syn 1.0.109",
]
[[package]]
name = "dashmap"
version = "5.5.1"
@ -2143,6 +2178,37 @@ dependencies = [
"serde",
]
[[package]]
name = "derive_builder"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "derive_builder_macro"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
dependencies = [
"derive_builder_core",
"syn 1.0.109",
]
[[package]]
name = "derive_more"
version = "0.99.17"
@ -2351,7 +2417,7 @@ dependencies = [
"git",
"gpui",
"indoc",
"itertools",
"itertools 0.10.5",
"language",
"lazy_static",
"log",
@ -2480,6 +2546,12 @@ dependencies = [
"libc",
]
[[package]]
name = "esaxx-rs"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f748b253ceca9fed5f42f8b5ceb3851e93102199bc25b64b65369f76e5c0a35"
[[package]]
name = "etagere"
version = "0.2.8"
@ -3141,7 +3213,7 @@ dependencies = [
"futures 0.3.28",
"gpui_macros",
"image",
"itertools",
"itertools 0.10.5",
"lazy_static",
"log",
"media",
@ -3517,6 +3589,12 @@ dependencies = [
"cc",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.4.0"
@ -3725,6 +3803,24 @@ dependencies = [
"waker-fn",
]
[[package]]
name = "itertools"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.10.5"
@ -4221,6 +4317,22 @@ dependencies = [
"libc",
]
[[package]]
name = "macro_rules_attribute"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862"
dependencies = [
"macro_rules_attribute-proc_macro",
"paste",
]
[[package]]
name = "macro_rules_attribute-proc_macro"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]]
name = "malloc_buf"
version = "0.0.6"
@ -4490,6 +4602,27 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "monostate"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15f370ae88093ec6b11a710dec51321a61d420fafd1bad6e30d01bd9c920e8ee"
dependencies = [
"monostate-impl",
"serde",
]
[[package]]
name = "monostate-impl"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "371717c0a5543d6a800cac822eac735aa7d2d2fbb41002e9856a4089532dbdce"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.29",
]
[[package]]
name = "more-asserts"
version = "0.2.2"
@ -4914,6 +5047,28 @@ version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
[[package]]
name = "onig"
version = "6.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f"
dependencies = [
"bitflags 1.3.2",
"libc",
"once_cell",
"onig_sys",
]
[[package]]
name = "onig_sys"
version = "69.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7"
dependencies = [
"cc",
"pkg-config",
]
[[package]]
name = "opaque-debug"
version = "0.3.0"
@ -5532,7 +5687,7 @@ dependencies = [
"globset",
"gpui",
"ignore",
"itertools",
"itertools 0.10.5",
"language",
"lazy_static",
"log",
@ -5656,7 +5811,7 @@ checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5"
dependencies = [
"bytes 1.4.0",
"heck 0.3.3",
"itertools",
"itertools 0.10.5",
"lazy_static",
"log",
"multimap",
@ -5675,7 +5830,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "600d2f334aa05acb02a755e217ef1ab6dea4d51b58b7846588b747edec04efba"
dependencies = [
"anyhow",
"itertools",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn 1.0.109",
@ -5688,7 +5843,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9cc1a3263e07e0bf68e96268f37665207b49560d98739662cdfaae215c720fe"
dependencies = [
"anyhow",
"itertools",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn 1.0.109",
@ -5917,6 +6072,17 @@ dependencies = [
"rayon-core",
]
[[package]]
name = "rayon-cond"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7"
dependencies = [
"either",
"itertools 0.8.2",
"rayon",
]
[[package]]
name = "rayon-core"
version = "1.11.0"
@ -6824,6 +6990,7 @@ dependencies = [
"lazy_static",
"log",
"matrixmultiply",
"ndarray",
"ort",
"parking_lot 0.11.2",
"parse_duration",
@ -6843,6 +7010,7 @@ dependencies = [
"tempdir",
"theme",
"tiktoken-rs 0.5.1",
"tokenizers",
"tree-sitter",
"tree-sitter-cpp",
"tree-sitter-elixir",
@ -7300,6 +7468,18 @@ dependencies = [
"lock_api",
]
[[package]]
name = "spm_precompiled"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326"
dependencies = [
"base64 0.13.1",
"nom",
"serde",
"unicode-segmentation",
]
[[package]]
name = "spsc-buffer"
version = "0.1.1"
@ -7339,7 +7519,7 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c12bc9199d1db8234678b7051747c07f517cdcf019262d1847b94ec8b1aee3e"
dependencies = [
"itertools",
"itertools 0.10.5",
"nom",
"unicode_categories",
]
@ -7736,7 +7916,7 @@ dependencies = [
"dirs 4.0.0",
"futures 0.3.28",
"gpui",
"itertools",
"itertools 0.10.5",
"lazy_static",
"libc",
"mio-extras",
@ -7767,7 +7947,7 @@ dependencies = [
"editor",
"futures 0.3.28",
"gpui",
"itertools",
"itertools 0.10.5",
"language",
"lazy_static",
"libc",
@ -8015,6 +8195,37 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12b515a66453a4d68f03398054f7204fd0dde6b93d3f20ea90b08025ab49b499"
dependencies = [
"aho-corasick 0.7.20",
"derive_builder",
"esaxx-rs",
"getrandom 0.2.10",
"itertools 0.9.0",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand 0.8.5",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.7.4",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokio"
version = "1.32.0"
@ -8717,6 +8928,15 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-normalization-alignments"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de"
dependencies = [
"smallvec",
]
[[package]]
name = "unicode-script"
version = "0.5.5"
@ -8948,7 +9168,7 @@ dependencies = [
"editor",
"gpui",
"indoc",
"itertools",
"itertools 0.10.5",
"language",
"language_selector",
"log",
@ -9809,7 +10029,7 @@ dependencies = [
"gpui",
"indoc",
"install_cli",
"itertools",
"itertools 0.10.5",
"language",
"lazy_static",
"log",

View file

@ -42,6 +42,8 @@ globset.workspace = true
sha1 = "0.10.5"
parse_duration = "2.1.1"
ort = { version = "1.15.2", features = ["coreml"]}
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
ndarray = "0.15"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }

View file

@ -1,26 +1,99 @@
use ort::{Environment, ExecutionProvider, GraphOptimizationLevel};
use ndarray::{Array1, Array2, Axis, CowArray};
use ort::{Environment, ExecutionProvider, GraphOptimizationLevel, Session, SessionBuilder, Value};
use tokenizers::Tokenizer;
use util::paths::MODELS_DIR;
struct CrossEncoder {}
// NOTE(review): `CrossEncoder` appears twice in this span (the empty definition above and
// the full one below). This looks like the removed and added sides of a diff rendered
// without +/- markers; presumably only the definition below is live — confirm.

/// Bundles the two pieces of state that `CrossEncoder::score` reads: the model session it
/// runs inference on and the tokenizer it encodes text with.
struct CrossEncoder {
// `ort` session that `score` calls `run` on; built in `load` from `model.onnx`.
session: Session,
// `tokenizers` tokenizer that `score` calls `encode_batch` on; read in `load` from
// `tokenizer.json`.
tokenizer: Tokenizer,
}
/// Logistic function: maps any real `val` into the range `[0, 1]`.
///
/// Computed as the reciprocal of `1 + e^(-val)`, so very negative inputs saturate to `0.0`
/// and very positive inputs saturate to `1.0`.
fn sigmoid(val: f32) -> f32 {
    let decay = (-val).exp();
    (1.0 + decay).recip()
}
impl CrossEncoder {
/// Builds a `CrossEncoder` from the files under `MODELS_DIR/cross-encoder`.
///
/// Reads `model.onnx` (the ONNX model) and `tokenizer.json` (the tokenizer definition) from
/// that directory, using an `ort` environment named "cross-encoder" with the CoreML
/// execution provider registered.
///
/// NOTE(review): this span reads like a rendered diff with removed and added lines
/// interleaved and no +/- markers — `session` is bound twice and there are two `Ok(...)`
/// tail expressions. Presumably only the `SessionBuilder` chain and the final
/// `Ok(Self { session, tokenizer })` are live; confirm against the real file.
///
/// # Errors
///
/// Failures from building the environment and from the `SessionBuilder` chain are
/// propagated with `?`.
///
/// # Panics
///
/// `Tokenizer::from_file(..).unwrap()` panics if the tokenizer file cannot be loaded, even
/// though this function returns a `Result`. The `new_session_builder()` chain also unwraps
/// at every step.
pub fn load() -> anyhow::Result<Self> {
// Both artifacts live side by side in the `cross-encoder` models directory.
let model_path = MODELS_DIR.join("cross-encoder").join("model.onnx");
let tokenizer_path = MODELS_DIR.join("cross-encoder").join("tokenizer.json");
let environment = Environment::builder()
.with_name("cross-encoder")
.with_execution_providers([ExecutionProvider::CoreML(Default::default())])
.build()?
.into_arc();
// NOTE(review): relative model path used only by the `new_session_builder()` chain below;
// likely the pre-change side of the diff — confirm.
let model = "../models/cross-encoder.onnx";
let mut session = environment
.new_session_builder()
.unwrap()
.with_optimization_level(GraphOptimizationLevel::Basic)
.unwrap()
.with_number_threads(1)
.unwrap()
.with_model_from_file(model)
.unwrap();
// Session built from `model_path` with a single intra-op thread; every step propagates
// its error with `?` instead of unwrapping.
let session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.with_model_from_file(model_path)?;
Ok(Self {})
// NOTE(review): panics on a missing or malformed tokenizer file; consider mapping the
// error into the returned `anyhow::Result`.
let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
Ok(Self { session, tokenizer })
}
pub fn score(&self, query: &str, candidates: Vec<&str>) -> anyhow::Result<Vec<f32>> {
let spans = candidates
.into_iter()
.map(|candidate| format!("{}. {}", query, candidate))
.collect::<Vec<_>>();
let encodings = self.tokenizer.encode_batch(spans, true).unwrap();
let mut results = Vec::new();
for encoding in encodings {
// Get Input Variables Individually
let input_ids = encoding.get_ids();
let attention_mask = encoding.get_attention_mask();
let token_type_ids = encoding.get_type_ids();
let length = input_ids.len();
// Convert to Arrays
let inputs_ids_array = CowArray::from(ndarray::Array::from_shape_vec(
(1, length),
input_ids.iter().map(|&x| x as i64).collect(),
)?);
let attention_mask_array = CowArray::from(ndarray::Array::from_shape_vec(
(1, length),
attention_mask.iter().map(|&x| x as i64).collect(),
)?)
.into_dyn();
let token_type_ids_array = CowArray::from(ndarray::Array::from_shape_vec(
(1, length),
token_type_ids.iter().map(|&x| x as i64).collect(),
)?)
.into_dyn();
let outputs = self.session.run(vec![
Value::from_array(self.session.allocator(), &inputs_ids_array.into_dyn())?,
Value::from_array(self.session.allocator(), &attention_mask_array)?,
Value::from_array(self.session.allocator(), &token_type_ids_array)?,
]);
let output = outputs.unwrap()[0].try_extract::<f32>().unwrap();
let value = output.view().to_owned();
let val = value.as_slice().unwrap()[0];
results.push(sigmoid(val))
}
Ok(results)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: loading the encoder and scoring two candidates yields exactly one score
    /// per candidate. Calls `CrossEncoder::load`, which reads the model and tokenizer files
    /// under `MODELS_DIR`.
    #[test]
    fn test_cross_encoder() {
        let encoder = CrossEncoder::load().unwrap();

        let candidates = vec!["I love you.", "I hate you."];
        let expected_len = candidates.len();

        let scores = encoder.score("I like you", candidates).unwrap();
        assert_eq!(scores.len(), expected_len);
    }
}

View file

@ -18,6 +18,7 @@ lazy_static::lazy_static! {
pub static ref LOG: PathBuf = LOGS_DIR.join("Zed.log");
pub static ref OLD_LOG: PathBuf = LOGS_DIR.join("Zed.log.old");
pub static ref LOCAL_SETTINGS_RELATIVE_PATH: &'static Path = Path::new(".zed/settings.json");
pub static ref MODELS_DIR: PathBuf = HOME.join("Library/Application Support/Zed/models");
}
pub mod legacy {