mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-29 04:20:46 +00:00
moved TestCompletionProvider into ai
This commit is contained in:
parent
ec9d79b6fe
commit
7af77b1cf9
3 changed files with 41 additions and 37 deletions
|
@ -4,9 +4,12 @@ use std::{
|
|||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::{
|
||||
auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::{LanguageModel, TruncationDirection},
|
||||
};
|
||||
|
@ -125,3 +128,39 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
|||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TestCompletionProvider {
|
||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
}
|
||||
|
||||
impl TestCompletionProvider {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self) {
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for TestCompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
||||
model
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
_prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ tiktoken-rs = "0.5"
|
|||
[dev-dependencies]
|
||||
editor = { path = "../editor", features = ["test-support"] }
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
ai = { path = "../ai", features = ["test-support"]}
|
||||
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
|
|
@ -335,7 +335,7 @@ fn strip_markdown_codeblock(
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ai::{models::LanguageModel, test::FakeLanguageModel};
|
||||
use ai::test::TestCompletionProvider;
|
||||
use futures::{
|
||||
future::BoxFuture,
|
||||
stream::{self, BoxStream},
|
||||
|
@ -617,42 +617,6 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
struct TestCompletionProvider {
|
||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
}
|
||||
|
||||
impl TestCompletionProvider {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
|
||||
fn finish_completion(&self) {
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for TestCompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
||||
model
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
_prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
|
|
Loading…
Reference in a new issue