mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-24 17:28:40 +00:00
move embedding provider to ai crate
This commit is contained in:
parent
48e151495f
commit
68c37ca2a4
11 changed files with 78 additions and 35 deletions
19
Cargo.lock
generated
19
Cargo.lock
generated
|
@ -91,13 +91,25 @@ name = "ai"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"ctor",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"isahc",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
"ordered-float",
|
||||
"parking_lot 0.11.2",
|
||||
"parse_duration",
|
||||
"postage",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tiktoken-rs 0.5.4",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -6725,9 +6737,9 @@ dependencies = [
|
|||
name = "semantic_index"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ai",
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
|
@ -6736,15 +6748,12 @@ dependencies = [
|
|||
"futures 0.3.28",
|
||||
"globset",
|
||||
"gpui",
|
||||
"isahc",
|
||||
"language",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
"node_runtime",
|
||||
"ordered-float",
|
||||
"parking_lot 0.11.2",
|
||||
"parse_duration",
|
||||
"picker",
|
||||
"postage",
|
||||
"pretty_assertions",
|
||||
|
|
|
@ -10,12 +10,25 @@ doctest = false
|
|||
|
||||
[dependencies]
|
||||
gpui = { path = "../gpui" }
|
||||
util = { path = "../util" }
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
lazy_static.workspace = true
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
isahc.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
postage.workspace = true
|
||||
rand.workspace = true
|
||||
log.workspace = true
|
||||
parse_duration = "2.1.1"
|
||||
tiktoken-rs = "0.5.0"
|
||||
matrixmultiply = "0.3.7"
|
||||
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
bincode = "1.3.3"
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
pub mod completion;
|
||||
pub mod embedding;
|
||||
|
|
|
@ -27,8 +27,30 @@ lazy_static! {
|
|||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Embedding(Vec<f32>);
|
||||
pub struct Embedding(pub Vec<f32>);
|
||||
|
||||
// This is needed for semantic index functionality
|
||||
// Unfortunately it has to live wherever the "Embedding" struct is created.
|
||||
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
|
||||
// which is less than ideal
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if embedding.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
}
|
||||
Ok(Embedding(embedding.unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for Embedding {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
let bytes = bincode::serialize(&self.0)
|
||||
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||
}
|
||||
}
|
||||
impl From<Vec<f32>> for Embedding {
|
||||
fn from(value: Vec<f32>) -> Self {
|
||||
Embedding(value)
|
||||
|
@ -63,24 +85,24 @@ impl Embedding {
|
|||
}
|
||||
}
|
||||
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if embedding.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
}
|
||||
Ok(Embedding(embedding.unwrap()))
|
||||
}
|
||||
}
|
||||
// impl FromSql for Embedding {
|
||||
// fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
// let bytes = value.as_blob()?;
|
||||
// let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
// if embedding.is_err() {
|
||||
// return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
// }
|
||||
// Ok(Embedding(embedding.unwrap()))
|
||||
// }
|
||||
// }
|
||||
|
||||
impl ToSql for Embedding {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
let bytes = bincode::serialize(&self.0)
|
||||
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||
}
|
||||
}
|
||||
// impl ToSql for Embedding {
|
||||
// fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
// let bytes = bincode::serialize(&self.0)
|
||||
// .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||
// Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||
// }
|
||||
// }
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIEmbeddings {
|
|
@ -9,6 +9,7 @@ path = "src/semantic_index.rs"
|
|||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
ai = { path = "../ai" }
|
||||
collections = { path = "../collections" }
|
||||
gpui = { path = "../gpui" }
|
||||
language = { path = "../language" }
|
||||
|
@ -26,22 +27,18 @@ futures.workspace = true
|
|||
ordered-float.workspace = true
|
||||
smol.workspace = true
|
||||
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
isahc.workspace = true
|
||||
log.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
lazy_static.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
async-trait.workspace = true
|
||||
bincode = "1.3.3"
|
||||
matrixmultiply = "0.3.7"
|
||||
tiktoken-rs = "0.5.0"
|
||||
parking_lot.workspace = true
|
||||
rand.workspace = true
|
||||
schemars.workspace = true
|
||||
globset.workspace = true
|
||||
sha1 = "0.10.5"
|
||||
parse_duration = "2.1.1"
|
||||
|
||||
[dev-dependencies]
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use ai::embedding::OpenAIEmbeddings;
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{self, UserStore};
|
||||
use gpui::{AsyncAppContext, ModelHandle, Task};
|
||||
use language::LanguageRegistry;
|
||||
use node_runtime::RealNodeRuntime;
|
||||
use project::{Project, RealFs};
|
||||
use semantic_index::embedding::OpenAIEmbeddings;
|
||||
use semantic_index::semantic_index_settings::SemanticIndexSettings;
|
||||
use semantic_index::{SearchResult, SemanticIndex};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::{
|
||||
embedding::Embedding,
|
||||
parsing::{Span, SpanDigest},
|
||||
SEMANTIC_INDEX_VERSION,
|
||||
};
|
||||
use ai::embedding::Embedding;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle};
|
||||
use crate::{parsing::Span, JobHandle};
|
||||
use ai::embedding::EmbeddingProvider;
|
||||
use gpui::executor::Background;
|
||||
use parking_lot::Mutex;
|
||||
use smol::channel;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||
use anyhow::{anyhow, Result};
|
||||
use language::{Grammar, Language};
|
||||
use rusqlite::{
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
mod db;
|
||||
pub mod embedding;
|
||||
// pub mod embedding;
|
||||
mod embedding_queue;
|
||||
mod parsing;
|
||||
pub mod semantic_index_settings;
|
||||
|
@ -11,7 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings;
|
|||
use anyhow::{anyhow, Result};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use db::VectorDatabase;
|
||||
use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
|
||||
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
|
||||
use embedding_queue::{EmbeddingQueue, FileToEmbed};
|
||||
use futures::{future, FutureExt, StreamExt};
|
||||
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use crate::{
|
||||
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
|
||||
embedding_queue::EmbeddingQueue,
|
||||
parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
|
||||
semantic_index_settings::SemanticIndexSettings,
|
||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
||||
};
|
||||
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||
|
|
Loading…
Reference in a new issue