From ec9d79b6fec4e90f34367bd3a855ef11c58f75fd Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Fri, 27 Oct 2023 08:51:30 +0200
Subject: [PATCH] add concept of LanguageModel to CompletionProvider

---
 crates/ai/src/completion.rs                   |  3 +++
 crates/ai/src/providers/open_ai/completion.rs | 21 ++++++++++++++++---
 crates/ai/src/providers/open_ai/embedding.rs  |  1 -
 crates/assistant/src/assistant_panel.rs       |  1 +
 crates/assistant/src/codegen.rs               |  5 +++++
 5 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs
index ba89c869d2..da9ebd5a1d 100644
--- a/crates/ai/src/completion.rs
+++ b/crates/ai/src/completion.rs
@@ -1,11 +1,14 @@
 use anyhow::Result;
 use futures::{future::BoxFuture, stream::BoxStream};
 
+use crate::models::LanguageModel;
+
 pub trait CompletionRequest: Send + Sync {
     fn data(&self) -> serde_json::Result<String>;
 }
 
 pub trait CompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index 95ed13c0dd..20f72c0ff7 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -12,7 +12,12 @@ use std::{
     sync::Arc,
 };
 
-use crate::completion::{CompletionProvider, CompletionRequest};
+use crate::{
+    completion::{CompletionProvider, CompletionRequest},
+    models::LanguageModel,
+};
+
+use super::OpenAILanguageModel;
 
 pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 
@@ -180,17 +185,27 @@ pub async fn stream_completion(
 }
 
 pub struct OpenAICompletionProvider {
+    model: OpenAILanguageModel,
     api_key: String,
     executor: Arc<Background>,
 }
 
 impl OpenAICompletionProvider {
-    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
-        Self { api_key, executor }
+    pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
+        let model = OpenAILanguageModel::load(model_name);
+        Self {
+            model,
+            api_key,
+            executor,
+        }
     }
 }
 
 impl CompletionProvider for OpenAICompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
index 9806877660..64f568da1a 100644
--- a/crates/ai/src/providers/open_ai/embedding.rs
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -26,7 +26,6 @@ use crate::providers::open_ai::OpenAILanguageModel;
 use crate::providers::open_ai::auth::OpenAICredentialProvider;
 
 lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index ec16c8fd04..c899465ed2 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -328,6 +328,7 @@ impl AssistantPanel {
 
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let provider = Arc::new(OpenAICompletionProvider::new(
+            "gpt-4",
             api_key,
             cx.background().clone(),
         ));
diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs
index e71b1ae2cb..33adb2e570 100644
--- a/crates/assistant/src/codegen.rs
+++ b/crates/assistant/src/codegen.rs
@@ -335,6 +335,7 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use ai::{models::LanguageModel, test::FakeLanguageModel};
     use futures::{
         future::BoxFuture,
         stream::{self, BoxStream},
@@ -638,6 +639,10 @@ mod tests {
     }
 
     impl CompletionProvider for TestCompletionProvider {
+        fn base_model(&self) -> Box<dyn LanguageModel> {
+            let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+            model
+        }
         fn complete(
             &self,
             _prompt: Box<dyn CompletionRequest>,