moved from Boxes to Arcs for shared access of completion providers across the assistant panel and inline assistant

This commit is contained in:
KCaverly 2023-11-02 10:08:47 -04:00
parent b5fe0d72ee
commit d5b6300fd7
3 changed files with 28 additions and 17 deletions

View file

@ -153,10 +153,17 @@ impl FakeCompletionProvider {
pub fn send_completion(&self, completion: impl Into<String>) { pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock(); let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
println!("COMPLETION TX: {:?}", &tx);
let a = tx.as_mut().unwrap();
a.try_send(completion.into()).unwrap();
// tx.as_mut().unwrap().try_send(completion.into()).unwrap();
} }
pub fn finish_completion(&self) { pub fn finish_completion(&self) {
println!("FINISHING COMPLETION");
self.last_completion_tx.lock().take().unwrap(); self.last_completion_tx.lock().take().unwrap();
} }
} }
@ -181,8 +188,10 @@ impl CompletionProvider for FakeCompletionProvider {
&self, &self,
_prompt: Box<dyn CompletionRequest>, _prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> { ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
println!("COMPLETING");
let (tx, rx) = mpsc::channel(1); let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx); *self.last_completion_tx.lock() = Some(tx);
println!("TX: {:?}", *self.last_completion_tx.lock());
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
} }
fn box_clone(&self) -> Box<dyn CompletionProvider> { fn box_clone(&self) -> Box<dyn CompletionProvider> {

View file

@ -142,7 +142,7 @@ pub struct AssistantPanel {
zoomed: bool, zoomed: bool,
has_focus: bool, has_focus: bool,
toolbar: ViewHandle<Toolbar>, toolbar: ViewHandle<Toolbar>,
completion_provider: Box<dyn CompletionProvider>, completion_provider: Arc<dyn CompletionProvider>,
api_key_editor: Option<ViewHandle<Editor>>, api_key_editor: Option<ViewHandle<Editor>>,
languages: Arc<LanguageRegistry>, languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
@ -204,7 +204,7 @@ impl AssistantPanel {
let semantic_index = SemanticIndex::global(cx); let semantic_index = SemanticIndex::global(cx);
// Defaulting currently to GPT4, allow for this to be set via config. // Defaulting currently to GPT4, allow for this to be set via config.
let completion_provider = Box::new(OpenAICompletionProvider::new( let completion_provider = Arc::new(OpenAICompletionProvider::new(
"gpt-4", "gpt-4",
cx.background().clone(), cx.background().clone(),
)); ));
@ -1442,7 +1442,7 @@ struct Conversation {
pending_save: Task<Result<()>>, pending_save: Task<Result<()>>,
path: Option<PathBuf>, path: Option<PathBuf>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
completion_provider: Box<dyn CompletionProvider>, completion_provider: Arc<dyn CompletionProvider>,
} }
impl Entity for Conversation { impl Entity for Conversation {
@ -1453,7 +1453,7 @@ impl Conversation {
fn new( fn new(
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
completion_provider: Box<dyn CompletionProvider>, completion_provider: Arc<dyn CompletionProvider>,
) -> Self { ) -> Self {
let markdown = language_registry.language_for_name("Markdown"); let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| { let buffer = cx.add_model(|cx| {
@ -1547,7 +1547,7 @@ impl Conversation {
None => Some(Uuid::new_v4().to_string()), None => Some(Uuid::new_v4().to_string()),
}; };
let model = saved_conversation.model; let model = saved_conversation.model;
let completion_provider: Box<dyn CompletionProvider> = Box::new( let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
OpenAICompletionProvider::new(model.full_name(), cx.background().clone()), OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
); );
completion_provider.retrieve_credentials(cx); completion_provider.retrieve_credentials(cx);
@ -2204,7 +2204,7 @@ struct ConversationEditor {
impl ConversationEditor { impl ConversationEditor {
fn new( fn new(
completion_provider: Box<dyn CompletionProvider>, completion_provider: Arc<dyn CompletionProvider>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
workspace: WeakViewHandle<Workspace>, workspace: WeakViewHandle<Workspace>,
@ -3409,7 +3409,7 @@ mod tests {
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test()); let registry = Arc::new(LanguageRegistry::test());
let completion_provider = Box::new(FakeCompletionProvider::new()); let completion_provider = Arc::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
@ -3538,7 +3538,7 @@ mod tests {
cx.set_global(SettingsStore::test(cx)); cx.set_global(SettingsStore::test(cx));
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test()); let registry = Arc::new(LanguageRegistry::test());
let completion_provider = Box::new(FakeCompletionProvider::new()); let completion_provider = Arc::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
@ -3636,7 +3636,7 @@ mod tests {
cx.set_global(SettingsStore::test(cx)); cx.set_global(SettingsStore::test(cx));
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test()); let registry = Arc::new(LanguageRegistry::test());
let completion_provider = Box::new(FakeCompletionProvider::new()); let completion_provider = Arc::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
@ -3719,7 +3719,7 @@ mod tests {
cx.set_global(SettingsStore::test(cx)); cx.set_global(SettingsStore::test(cx));
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test()); let registry = Arc::new(LanguageRegistry::test());
let completion_provider = Box::new(FakeCompletionProvider::new()); let completion_provider = Arc::new(FakeCompletionProvider::new());
let conversation = let conversation =
cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider)); cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();

View file

@ -6,7 +6,7 @@ use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{Entity, ModelContext, ModelHandle, Task}; use gpui::{Entity, ModelContext, ModelHandle, Task};
use language::{Rope, TransactionId}; use language::{Rope, TransactionId};
use multi_buffer; use multi_buffer;
use std::{cmp, future, ops::Range}; use std::{cmp, future, ops::Range, sync::Arc};
pub enum Event { pub enum Event {
Finished, Finished,
@ -20,7 +20,7 @@ pub enum CodegenKind {
} }
pub struct Codegen { pub struct Codegen {
provider: Box<dyn CompletionProvider>, provider: Arc<dyn CompletionProvider>,
buffer: ModelHandle<MultiBuffer>, buffer: ModelHandle<MultiBuffer>,
snapshot: MultiBufferSnapshot, snapshot: MultiBufferSnapshot,
kind: CodegenKind, kind: CodegenKind,
@ -40,7 +40,7 @@ impl Codegen {
pub fn new( pub fn new(
buffer: ModelHandle<MultiBuffer>, buffer: ModelHandle<MultiBuffer>,
kind: CodegenKind, kind: CodegenKind,
provider: Box<dyn CompletionProvider>, provider: Arc<dyn CompletionProvider>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Self { ) -> Self {
let snapshot = buffer.read(cx).snapshot(cx); let snapshot = buffer.read(cx).snapshot(cx);
@ -414,7 +414,7 @@ mod tests {
let snapshot = buffer.snapshot(cx); let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
}); });
let provider = Box::new(FakeCompletionProvider::new()); let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| { let codegen = cx.add_model(|cx| {
Codegen::new( Codegen::new(
buffer.clone(), buffer.clone(),
@ -439,6 +439,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10); let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len); let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len); let (chunk, suffix) = new_text.split_at(len);
println!("CHUNK: {:?}", &chunk);
provider.send_completion(chunk); provider.send_completion(chunk);
new_text = suffix; new_text = suffix;
deterministic.run_until_parked(); deterministic.run_until_parked();
@ -480,7 +481,7 @@ mod tests {
let snapshot = buffer.snapshot(cx); let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6)) snapshot.anchor_before(Point::new(1, 6))
}); });
let provider = Box::new(FakeCompletionProvider::new()); let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| { let codegen = cx.add_model(|cx| {
Codegen::new( Codegen::new(
buffer.clone(), buffer.clone(),
@ -546,7 +547,7 @@ mod tests {
let snapshot = buffer.snapshot(cx); let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2)) snapshot.anchor_before(Point::new(1, 2))
}); });
let provider = Box::new(FakeCompletionProvider::new()); let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| { let codegen = cx.add_model(|cx| {
Codegen::new( Codegen::new(
buffer.clone(), buffer.clone(),
@ -571,6 +572,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10); let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len); let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len); let (chunk, suffix) = new_text.split_at(len);
println!("{:?}", &chunk);
provider.send_completion(chunk); provider.send_completion(chunk);
new_text = suffix; new_text = suffix;
deterministic.run_until_parked(); deterministic.run_until_parked();