add concept of LanguageModel to CompletionProvider

This commit is contained in:
KCaverly 2023-10-27 08:51:30 +02:00
parent 6c8bb4b05e
commit ec9d79b6fe
5 changed files with 27 additions and 4 deletions

View file

@ -1,11 +1,14 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use crate::models::LanguageModel;
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,

View file

@ -12,7 +12,12 @@ use std::{
sync::Arc,
};
use crate::completion::{CompletionProvider, CompletionRequest};
use crate::{
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
use super::OpenAILanguageModel;
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@ -180,17 +185,27 @@ pub async fn stream_completion(
}
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
api_key: String,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor }
pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
let model = OpenAILanguageModel::load(model_name);
Self {
model,
api_key,
executor,
}
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,

View file

@ -26,7 +26,6 @@ use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::auth::OpenAICredentialProvider;
lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

View file

@ -328,6 +328,7 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new(
"gpt-4",
api_key,
cx.background().clone(),
));

View file

@ -335,6 +335,7 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
use ai::{models::LanguageModel, test::FakeLanguageModel};
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
@ -638,6 +639,10 @@ mod tests {
}
impl CompletionProvider for TestCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,