mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-06 08:07:42 +00:00
authenticate with completion provider on new inline assists (#3209)
authenticate with completion provider on new inline assists Release Notes: - Fixed bug which lead the inline assist functionality to never authenticate
This commit is contained in:
commit
5ee2b01102
3 changed files with 33 additions and 17 deletions
|
@ -153,10 +153,17 @@ impl FakeCompletionProvider {
|
|||
|
||||
pub fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
|
||||
println!("COMPLETION TX: {:?}", &tx);
|
||||
|
||||
let a = tx.as_mut().unwrap();
|
||||
a.try_send(completion.into()).unwrap();
|
||||
|
||||
// tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self) {
|
||||
println!("FINISHING COMPLETION");
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
|
@ -181,8 +188,10 @@ impl CompletionProvider for FakeCompletionProvider {
|
|||
&self,
|
||||
_prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
||||
println!("COMPLETING");
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
println!("TX: {:?}", *self.last_completion_tx.lock());
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||
|
|
|
@ -142,7 +142,7 @@ pub struct AssistantPanel {
|
|||
zoomed: bool,
|
||||
has_focus: bool,
|
||||
toolbar: ViewHandle<Toolbar>,
|
||||
completion_provider: Box<dyn CompletionProvider>,
|
||||
completion_provider: Arc<dyn CompletionProvider>,
|
||||
api_key_editor: Option<ViewHandle<Editor>>,
|
||||
languages: Arc<LanguageRegistry>,
|
||||
fs: Arc<dyn Fs>,
|
||||
|
@ -204,7 +204,7 @@ impl AssistantPanel {
|
|||
|
||||
let semantic_index = SemanticIndex::global(cx);
|
||||
// Defaulting currently to GPT4, allow for this to be set via config.
|
||||
let completion_provider = Box::new(OpenAICompletionProvider::new(
|
||||
let completion_provider = Arc::new(OpenAICompletionProvider::new(
|
||||
"gpt-4",
|
||||
cx.background().clone(),
|
||||
));
|
||||
|
@ -259,7 +259,13 @@ impl AssistantPanel {
|
|||
cx: &mut ViewContext<Workspace>,
|
||||
) {
|
||||
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if this.update(cx, |assistant, _| assistant.has_credentials()) {
|
||||
if this.update(cx, |assistant, cx| {
|
||||
if !assistant.has_credentials() {
|
||||
assistant.load_credentials(cx);
|
||||
};
|
||||
|
||||
assistant.has_credentials()
|
||||
}) {
|
||||
this
|
||||
} else {
|
||||
workspace.focus_panel::<AssistantPanel>(cx);
|
||||
|
@ -320,13 +326,10 @@ impl AssistantPanel {
|
|||
};
|
||||
|
||||
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
|
||||
let provider = Arc::new(OpenAICompletionProvider::new(
|
||||
"gpt-4",
|
||||
cx.background().clone(),
|
||||
));
|
||||
let provider = self.completion_provider.clone();
|
||||
|
||||
// Retrieve Credentials Authenticates the Provider
|
||||
// provider.retrieve_credentials(cx);
|
||||
provider.retrieve_credentials(cx);
|
||||
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
|
||||
|
@ -1439,7 +1442,7 @@ struct Conversation {
|
|||
pending_save: Task<Result<()>>,
|
||||
path: Option<PathBuf>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
completion_provider: Box<dyn CompletionProvider>,
|
||||
completion_provider: Arc<dyn CompletionProvider>,
|
||||
}
|
||||
|
||||
impl Entity for Conversation {
|
||||
|
@ -1450,7 +1453,7 @@ impl Conversation {
|
|||
fn new(
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
completion_provider: Box<dyn CompletionProvider>,
|
||||
completion_provider: Arc<dyn CompletionProvider>,
|
||||
) -> Self {
|
||||
let markdown = language_registry.language_for_name("Markdown");
|
||||
let buffer = cx.add_model(|cx| {
|
||||
|
@ -1544,7 +1547,7 @@ impl Conversation {
|
|||
None => Some(Uuid::new_v4().to_string()),
|
||||
};
|
||||
let model = saved_conversation.model;
|
||||
let completion_provider: Box<dyn CompletionProvider> = Box::new(
|
||||
let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
|
||||
OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
|
||||
);
|
||||
completion_provider.retrieve_credentials(cx);
|
||||
|
@ -2201,7 +2204,7 @@ struct ConversationEditor {
|
|||
|
||||
impl ConversationEditor {
|
||||
fn new(
|
||||
completion_provider: Box<dyn CompletionProvider>,
|
||||
completion_provider: Arc<dyn CompletionProvider>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakViewHandle<Workspace>,
|
||||
|
@ -3406,7 +3409,7 @@ mod tests {
|
|||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
|
||||
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||
let completion_provider = Arc::new(FakeCompletionProvider::new());
|
||||
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
||||
|
@ -3535,7 +3538,7 @@ mod tests {
|
|||
cx.set_global(SettingsStore::test(cx));
|
||||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||
let completion_provider = Arc::new(FakeCompletionProvider::new());
|
||||
|
||||
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
@ -3633,7 +3636,7 @@ mod tests {
|
|||
cx.set_global(SettingsStore::test(cx));
|
||||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||
let completion_provider = Arc::new(FakeCompletionProvider::new());
|
||||
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
||||
|
@ -3716,7 +3719,7 @@ mod tests {
|
|||
cx.set_global(SettingsStore::test(cx));
|
||||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||
let completion_provider = Arc::new(FakeCompletionProvider::new());
|
||||
let conversation =
|
||||
cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
|
|
@ -367,6 +367,8 @@ fn strip_invalid_spans_from_codeblock(
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use ai::test::FakeCompletionProvider;
|
||||
use futures::stream::{self};
|
||||
|
@ -437,6 +439,7 @@ mod tests {
|
|||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
println!("CHUNK: {:?}", &chunk);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
deterministic.run_until_parked();
|
||||
|
@ -569,6 +572,7 @@ mod tests {
|
|||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
println!("{:?}", &chunk);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
deterministic.run_until_parked();
|
||||
|
|
Loading…
Reference in a new issue