From b5fe0d72ee4328d9d7a6d9cf9ec0e371a7d0a3f1 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 2 Nov 2023 09:34:18 -0400 Subject: [PATCH] authenticate with completion provider on new inline assists --- crates/assistant/src/assistant_panel.rs | 15 +++++++++------ crates/assistant/src/codegen.rs | 14 ++++++++------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 03eb3c238f..022c228790 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -259,7 +259,13 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this.update(cx, |assistant, _| assistant.has_credentials()) { + if this.update(cx, |assistant, cx| { + if !assistant.has_credentials() { + assistant.load_credentials(cx); + }; + + assistant.has_credentials() + }) { this } else { workspace.focus_panel::(cx); @@ -320,13 +326,10 @@ impl AssistantPanel { }; let inline_assist_id = post_inc(&mut self.next_inline_assist_id); - let provider = Arc::new(OpenAICompletionProvider::new( - "gpt-4", - cx.background().clone(), - )); + let provider = self.completion_provider.clone(); // Retrieve Credentials Authenticates the Provider - // provider.retrieve_credentials(cx); + provider.retrieve_credentials(cx); let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index f62c91fcb7..da7beda2dc 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -6,7 +6,7 @@ use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{Entity, ModelContext, ModelHandle, Task}; use language::{Rope, TransactionId}; use multi_buffer; -use std::{cmp, future, ops::Range, sync::Arc}; +use std::{cmp, future, ops::Range}; pub enum Event { Finished, @@ -20,7 +20,7 @@ pub enum CodegenKind { } pub struct Codegen { - provider: Arc, + provider: Box, buffer: ModelHandle, snapshot: MultiBufferSnapshot, kind: CodegenKind, @@ -40,7 +40,7 @@ impl Codegen { pub fn new( buffer: ModelHandle, kind: CodegenKind, - provider: Arc, + provider: Box, cx: &mut ModelContext, ) -> Self { let snapshot = buffer.read(cx).snapshot(cx); @@ -367,6 +367,8 @@ fn strip_invalid_spans_from_codeblock( #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use ai::test::FakeCompletionProvider; use futures::stream::{self}; @@ -412,7 +414,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let provider = Arc::new(FakeCompletionProvider::new()); + let provider = Box::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -478,7 +480,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6)) }); - let provider = Arc::new(FakeCompletionProvider::new()); + let provider = Box::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -544,7 +546,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 2)) }); - let provider = Arc::new(FakeCompletionProvider::new()); + let provider = Box::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(),