mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-25 01:34:02 +00:00
leverage embeddings len returned in construction matrix multiplication
This commit is contained in:
parent
3682751455
commit
0e6fd645fd
1 changed files with 13 additions and 3 deletions
|
@ -438,6 +438,13 @@ impl VectorDatabase {
|
||||||
.filter_map(|row| row.ok())
|
.filter_map(|row| row.ok())
|
||||||
.collect::<Vec<(usize, Embedding)>>();
|
.collect::<Vec<(usize, Embedding)>>();
|
||||||
|
|
||||||
|
if deserialized_rows.len() == 0 {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Length of Embeddings Returned
|
||||||
|
let embedding_len = deserialized_rows[0].1 .0.len();
|
||||||
|
|
||||||
let batch_n = 1000;
|
let batch_n = 1000;
|
||||||
let mut batches = Vec::new();
|
let mut batches = Vec::new();
|
||||||
let mut batch_ids = Vec::new();
|
let mut batch_ids = Vec::new();
|
||||||
|
@ -449,7 +456,8 @@ impl VectorDatabase {
|
||||||
if batch_ids.len() == batch_n {
|
if batch_ids.len() == batch_n {
|
||||||
let embeddings = std::mem::take(&mut batch_embeddings);
|
let embeddings = std::mem::take(&mut batch_embeddings);
|
||||||
let ids = std::mem::take(&mut batch_ids);
|
let ids = std::mem::take(&mut batch_ids);
|
||||||
let array = Array2::from_shape_vec((batch_ids.len(), 1536), embeddings);
|
let array =
|
||||||
|
Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
|
||||||
match array {
|
match array {
|
||||||
Ok(array) => {
|
Ok(array) => {
|
||||||
batches.push((ids, array));
|
batches.push((ids, array));
|
||||||
|
@ -460,8 +468,10 @@ impl VectorDatabase {
|
||||||
});
|
});
|
||||||
|
|
||||||
if batch_ids.len() > 0 {
|
if batch_ids.len() > 0 {
|
||||||
let array =
|
let array = Array2::from_shape_vec(
|
||||||
Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
|
(batch_ids.len(), embedding_len),
|
||||||
|
batch_embeddings.clone(),
|
||||||
|
);
|
||||||
match array {
|
match array {
|
||||||
Ok(array) => {
|
Ok(array) => {
|
||||||
batches.push((batch_ids.clone(), array));
|
batches.push((batch_ids.clone(), array));
|
||||||
|
|
Loading…
Reference in a new issue