moved TestCompletionProvider into ai

This commit is contained in:
KCaverly 2023-10-27 12:26:01 +02:00
parent ec9d79b6fe
commit 7af77b1cf9
3 changed files with 41 additions and 37 deletions

View file

@ -4,9 +4,12 @@ use std::{
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
@ -125,3 +128,39 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
pub struct TestCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl TestCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CompletionProvider for TestCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
}

View file

@ -44,6 +44,7 @@ tiktoken-rs = "0.5"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
ai = { path = "../ai", features = ["test-support"]}
ctor.workspace = true
env_logger.workspace = true

View file

@ -335,7 +335,7 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
use ai::{models::LanguageModel, test::FakeLanguageModel};
use ai::test::TestCompletionProvider;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
@ -617,42 +617,6 @@ mod tests {
}
}
struct TestCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl TestCompletionProvider {
fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CompletionProvider for TestCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {