From a29b70552f82772944f754e8b24922df8b16dcd4 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Fri, 29 Sep 2023 14:02:41 -0400
Subject: [PATCH] working inline with semantic index, this should never merge

---
 Cargo.lock                                  |   1 +
 crates/ai/src/ai.rs                         |   3 +-
 crates/ai/src/function_calling.rs           |   2 +
 crates/ai/src/skills.rs                     |  50 ++
 crates/assistant/Cargo.toml                 |   2 +
 crates/assistant/src/assistant_panel.rs     | 135 ++--
 crates/assistant/src/codegen.rs             | 716 ++++++++++++--------
 crates/assistant/src/prompts.rs             |  32 +-
 crates/semantic_index/src/semantic_index.rs |   1 +
 crates/semantic_index/src/skills.rs         | 106 +++
 10 files changed, 698 insertions(+), 350 deletions(-)
 create mode 100644 crates/ai/src/skills.rs
 create mode 100644 crates/semantic_index/src/skills.rs

diff --git a/Cargo.lock b/Cargo.lock
index f28e88edff..9d4d6143f0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -323,6 +323,7 @@ dependencies = [
  "regex",
  "schemars",
  "search",
+ "semantic_index",
  "serde",
  "serde_json",
  "settings",
diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index 6a2d2a816e..ab5d71c880 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -1,6 +1,7 @@
 pub mod completion;
 pub mod embedding;
 pub mod function_calling;
+pub mod skills;
 
 use core::fmt;
 use std::fmt::Display;
@@ -35,7 +36,7 @@ impl Display for Role {
     }
 }
 
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
 pub struct RequestMessage {
     pub role: Role,
     pub content: String,
diff --git a/crates/ai/src/function_calling.rs b/crates/ai/src/function_calling.rs
index dded170611..fd227d9982 100644
--- a/crates/ai/src/function_calling.rs
+++ b/crates/ai/src/function_calling.rs
@@ -13,6 +13,7 @@ pub trait OpenAIFunction: erased_serde::Serialize {
     fn description(&self) -> String;
     fn system_prompt(&self) -> String;
     fn parameters(&self) -> serde_json::Value;
+    fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String>;
 }
 
 serialize_trait_object!(OpenAIFunction);
@@ -83,6 +84,7 @@ pub struct FunctionCallDetails {
    pub arguments: serde_json::Value, // json object representing provided arguments
 }
 
+#[derive(Clone)]
 pub struct OpenAIFunctionCallingProvider {
     api_key: String,
 }
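Editorial note, not part of the patch: the new `complete` method is what lets a function call chosen by the model be resolved locally. A minimal sketch of how a `FunctionCallDetails` result is meant to be dispatched, assuming the shapes shown above; the `dispatch` helper itself is hypothetical:

```rust
use ai::function_calling::{FunctionCallDetails, OpenAIFunction};
use anyhow::anyhow;

// Route the function call chosen by the model back to the matching
// implementor, and let that implementor resolve the JSON arguments.
fn dispatch(
    call: FunctionCallDetails,
    functions: &[Box<dyn OpenAIFunction>],
) -> anyhow::Result<String> {
    functions
        .iter()
        .find(|function| function.name() == call.name)
        .ok_or_else(|| anyhow!("no function named {:?}", call.name))?
        .complete(call.arguments)
}
```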
diff --git a/crates/ai/src/skills.rs b/crates/ai/src/skills.rs
new file mode 100644
index 0000000000..69126b040e
--- /dev/null
+++ b/crates/ai/src/skills.rs
@@ -0,0 +1,50 @@
+use crate::function_calling::OpenAIFunction;
+use gpui::{AppContext, ModelHandle};
+use project::Project;
+use serde::{Serialize, Serializer};
+use serde_json::json;
+
+pub struct RewritePrompt;
+impl Serialize for RewritePrompt {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        json!({"name": self.name(),
+            "description": self.description(),
+            "parameters": self.parameters()})
+        .serialize(serializer)
+    }
+}
+
+impl RewritePrompt {
+    pub fn load() -> Self {
+        Self {}
+    }
+}
+
+impl OpenAIFunction for RewritePrompt {
+    fn name(&self) -> String {
+        "rewrite_prompt".to_string()
+    }
+    fn description(&self) -> String {
+        "Rewrite the prompt given the user's prompt".to_string()
+    }
+    fn system_prompt(&self) -> String {
+        "'rewrite_prompt':
+        If all information is available in the above prompt and you need no further information,
+        rewrite the entire prompt to clarify what should be generated; do not actually complete the user's request.
+        Assume this rewritten message will be passed to another completion agent to fulfill the user's request."
+            .to_string()
+    }
+    fn parameters(&self) -> serde_json::Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "prompt": {}
+            }
+        })
+    }
+    fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
+        Ok(arguments.get("prompt").unwrap().to_string())
+    }
+}
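Editorial note, not part of the patch: a sketch of the payload `rewrite_prompt` receives. One caveat worth flagging: `Value::to_string()` returns the JSON-encoded form, so the rewritten prompt comes back wrapped in quotes; `as_str()` would yield the bare text. Illustrative values only:

```rust
use serde_json::json;

fn main() {
    // What the model would send for rewrite_prompt, per the schema above.
    let arguments = json!({ "prompt": "Rewrite selection to use iterator adapters" });

    // The patch's RewritePrompt::complete does this, which keeps JSON quotes:
    let quoted = arguments.get("prompt").unwrap().to_string();
    assert_eq!(quoted, "\"Rewrite selection to use iterator adapters\"");

    // An as_str round-trip yields the bare string instead:
    let bare = arguments.get("prompt").unwrap().as_str().unwrap();
    assert_eq!(bare, "Rewrite selection to use iterator adapters");
}
```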
diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml
index 5d141b32d5..e51a49b4a7 100644
--- a/crates/assistant/Cargo.toml
+++ b/crates/assistant/Cargo.toml
@@ -23,6 +23,8 @@ theme = { path = "../theme" }
 util = { path = "../util" }
 uuid = { version = "1.1.2", features = ["v4"] }
 workspace = { path = "../workspace" }
+semantic_index = { path = "../semantic_index" }
+project = { path = "../project" }
 
 anyhow.workspace = true
 chrono = { version = "0.4", features = ["serde"] }
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 5454f1673c..8b799c2987 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -1,12 +1,16 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     codegen::{self, Codegen, CodegenKind},
-    prompts::generate_content_prompt,
+    prompts::{generate_codegen_planning_prompt, generate_content_prompt},
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
-use ai::completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL};
-use ai::RequestMessage;
+use ai::{
+    completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL},
+    function_calling::OpenAIFunctionCallingProvider,
+    skills::RewritePrompt,
+};
+use ai::{function_calling::OpenAIFunction, RequestMessage};
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
 use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@@ -34,7 +38,9 @@ use gpui::{
     WindowContext,
 };
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
+use project::Project;
 use search::BufferSearchBar;
+use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
 use settings::SettingsStore;
 use std::{
     cell::{Cell, RefCell},
@@ -144,6 +150,8 @@ pub struct AssistantPanel {
     include_conversation_in_next_inline_assist: bool,
     inline_prompt_history: VecDeque<String>,
     _watch_saved_conversations: Task<Result<()>>,
+    semantic_index: ModelHandle<SemanticIndex>,
+    project: ModelHandle<Project>,
 }
 
 impl AssistantPanel {
     pub fn load(
         workspace: WeakViewHandle<Workspace>,
         cx: AsyncAppContext,
     ) -> Task<Result<ViewHandle<AssistantPanel>>> {
+        let index = cx.read(|cx| SemanticIndex::global(cx).unwrap());
         cx.spawn(|mut cx| async move {
             let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
             let saved_conversations = SavedConversationMetadata::list(fs.clone())
@@ -190,6 +199,9 @@
             toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
             toolbar
         });
+
+        let project = workspace.project().clone();
+
         let mut this = Self {
             workspace: workspace_handle,
             active_editor_index: Default::default(),
@@ -214,6 +226,8 @@
             include_conversation_in_next_inline_assist: false,
             inline_prompt_history: Default::default(),
             _watch_saved_conversations,
+            semantic_index: index,
+            project,
         };
 
         let mut old_dock_position = this.position(cx);
@@ -276,9 +290,10 @@
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
         let provider = Arc::new(OpenAICompletionProvider::new(
-            api_key,
+            api_key.clone(),
             cx.background().clone(),
         ));
+        let fc_provider = OpenAIFunctionCallingProvider::new(api_key);
         let selection = editor.read(cx).selections.newest_anchor().clone();
         let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
             CodegenKind::Generate {
                 position: selection.start,
             }
         } else {
             CodegenKind::Transform {
                 range: selection.start..selection.end,
             }
         };
+
+        let project = self.project.clone();
+
         let codegen = cx.add_model(|cx| {
-            Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
+            Codegen::new(
+                editor.read(cx).buffer().clone(),
+                codegen_kind,
+                provider,
+                fc_provider,
+                cx,
+                project.clone(),
+            )
         });
 
         let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
@@ -572,42 +597,74 @@
         let language_name = language_name.as_deref();
         let codegen_kind = pending_assist.codegen.read(cx).kind().clone();
-        let prompt = generate_content_prompt(
-            user_prompt.to_string(),
-            language_name,
-            &snapshot,
-            language_range,
-            cx,
-            codegen_kind,
-        );
+        let index = self.semantic_index.clone();
 
-        let mut messages = Vec::new();
-        let mut model = settings::get::<AssistantSettings>(cx)
-            .default_open_ai_model
-            .clone();
-        if let Some(conversation) = conversation {
-            let conversation = conversation.read(cx);
-            let buffer = conversation.buffer.read(cx);
-            messages.extend(
-                conversation
-                    .messages(cx)
-                    .map(|message| message.to_open_ai_message(buffer)),
-            );
-            model = conversation.model.clone();
-        }
-
-        messages.push(RequestMessage {
-            role: Role::User,
-            content: prompt,
+        pending_assist.codegen.update(cx, |codegen, cx| {
+            codegen.start(
+                user_prompt.to_string(),
+                cx,
+                language_name,
+                snapshot,
+                language_range.clone(),
+                codegen_kind.clone(),
+                index,
+            )
         });
-        let request = OpenAIRequest {
-            model: model.full_name().into(),
-            messages,
-            stream: true,
-        };
-        pending_assist
-            .codegen
-            .update(cx, |codegen, cx| codegen.start(request, cx));
+
+        // let api_key = self.api_key.as_ref().clone().into_inner().clone().unwrap();
+        // let function_provider = OpenAIFunctionCallingProvider::new(api_key);
+
+        // let planning_messages = vec![RequestMessage {
+        //     role: Role::User,
+        //     content: planning_prompt,
+        // }];
+
+        // println!("GETTING HERE");
+
+        // let function_call = cx
+        //     .spawn(|this, mut cx| async move {
+        //         let result = function_provider
+        //             .complete("gpt-4".to_string(), planning_messages, functions)
+        //             .await;
+        //         dbg!(&result);
+        //         result
+        //     })
+        //     .detach();
+
+        // let function_name = function_call.name.as_str();
+        // let prompt = match function_name {
+        //     "rewrite_prompt" => {
+        //         let user_prompt = RewritePrompt::load()
+        //             .complete(function_call.arguments)
+        //             .unwrap();
+        //         generate_content_prompt(
+        //             user_prompt.to_string(),
+        //             language_name,
+        //             &snapshot,
+        //             language_range,
+        //             cx,
+        //             codegen_kind,
+        //         )
+        //     }
+        //     _ => {
+        //         todo!();
+        //     }
+        // };
+
+        // let mut messages = Vec::new();
+        // let mut model = settings::get::<AssistantSettings>(cx)
+        //     .default_open_ai_model
+        //     .clone();
+        // if let Some(conversation) = conversation {
+        //     let conversation = conversation.read(cx);
+        //     let buffer = conversation.buffer.read(cx);
+        //     messages.extend(
+        //         conversation
+        //             .messages(cx)
+        //             .map(|message| message.to_open_ai_message(buffer)),
+        //     );
+        //     model = conversation.model.clone();
+        // }
     }
 
     fn update_highlights_for_editor(
diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs
index e956d72260..7eeae20201 100644
--- a/crates/assistant/src/codegen.rs
+++ b/crates/assistant/src/codegen.rs
@@ -1,12 +1,22 @@
-use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, OpenAIRequest};
+use crate::{
+    prompts::{generate_codegen_planning_prompt, generate_content_prompt},
+    streaming_diff::{Hunk, StreamingDiff},
+};
+use ai::{
+    completion::{CompletionProvider, OpenAIRequest},
+    function_calling::{OpenAIFunction, OpenAIFunctionCallingProvider},
+    skills::RewritePrompt,
+    RequestMessage, Role,
+};
 use anyhow::Result;
 use editor::{
     multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
 };
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
-use gpui::{Entity, ModelContext, ModelHandle, Task};
-use language::{Rope, TransactionId};
+use gpui::{BorrowAppContext, Entity, ModelContext, ModelHandle, Task};
+use language::{BufferSnapshot, Rope, TransactionId};
+use project::Project;
+use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
 use std::{cmp, future, ops::Range, sync::Arc};
 
 pub enum Event {
@@ -22,6 +32,7 @@ pub enum CodegenKind {
 
 pub struct Codegen {
     provider: Arc<dyn CompletionProvider>,
+    fc_provider: OpenAIFunctionCallingProvider,
     buffer: ModelHandle<MultiBuffer>,
     snapshot: MultiBufferSnapshot,
     kind: CodegenKind,
@@ -31,6 +42,7 @@ pub struct Codegen {
     generation: Task<()>,
     idle: bool,
     _subscription: gpui::Subscription,
+    project: ModelHandle<Project>,
 }
 
 impl Entity for Codegen {
@@ -42,7 +54,9 @@ impl Codegen {
         buffer: ModelHandle<MultiBuffer>,
         mut kind: CodegenKind,
         provider: Arc<dyn CompletionProvider>,
+        fc_provider: OpenAIFunctionCallingProvider,
         cx: &mut ModelContext<Self>,
+        project: ModelHandle<Project>,
     ) -> Self {
         let snapshot = buffer.read(cx).snapshot(cx);
         match &mut kind {
@@ -62,6 +76,7 @@ impl Codegen {
 
         Self {
             provider,
+            fc_provider,
             buffer: buffer.clone(),
             snapshot,
             kind,
@@ -71,6 +86,7 @@ impl Codegen {
             idle: true,
             generation: Task::ready(()),
             _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
+            project,
         }
     }
@@ -112,7 +128,17 @@ impl Codegen {
         self.error.as_ref()
     }
 
-    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+    pub fn start(
+        &mut self,
+        prompt: String,
+        cx: &mut ModelContext<Self>,
+        language_name: Option<&str>,
+        buffer: BufferSnapshot,
+        range: Range<language::Anchor>,
+        kind: CodegenKind,
+        index: ModelHandle<SemanticIndex>,
+    ) {
+        let language_range = range.clone();
         let range = self.range();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
@@ -126,9 +152,101 @@ impl Codegen {
             .next()
             .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
 
-        let response = self.provider.complete(prompt);
+        let messages = vec![RequestMessage {
+            role: Role::User,
+            content: prompt.clone(),
+        }];
+
+        let request = OpenAIRequest {
+            model: "gpt-4".to_string(),
+            messages: messages.clone(),
+            stream: true,
+        };
+
+        let (planning_prompt, outline) = generate_codegen_planning_prompt(
+            prompt.clone(),
+            language_name.clone(),
+            &buffer,
+            language_range.clone(),
+            cx,
+            kind.clone(),
+        );
+
+        let project = self.project.clone();
         self.generation = cx.spawn_weak(|this, mut cx| {
+            // Plan ahead
+            let planning_messages = vec![RequestMessage {
+                role: Role::User,
+                content: planning_prompt,
+            }];
+
+            let repo_retriever = RepositoryContextRetriever::load(index, project);
+            let functions: Vec<Box<dyn OpenAIFunction>> = vec![
+                Box::new(RewritePrompt::load()),
+                Box::new(repo_retriever.clone()),
+            ];
+
+            let completion_provider = self.provider.clone();
+            let fc_provider = self.fc_provider.clone();
+            let language_name = language_name.map(|language_name| language_name.to_string());
+            let kind = kind.clone();
             async move {
+                let user_prompt = prompt.clone();
+                let user_prompt = if let Ok(function_call) = fc_provider
+                    .complete("gpt-4".to_string(), planning_messages, functions)
+                    .await
+                {
+                    let function_name = function_call.name.as_str();
+                    println!("FUNCTION NAME: {:?}", function_name);
+                    let user_prompt = match function_name {
+                        "rewrite_prompt" => {
+                            let user_prompt = RewritePrompt::load()
+                                .complete(function_call.arguments)
+                                .unwrap();
+                            generate_content_prompt(
+                                user_prompt,
+                                language_name,
+                                outline,
+                                kind,
+                                vec![],
+                            )
+                        }
+                        _ => {
+                            let arguments = function_call.arguments.clone();
+                            let snippet = repo_retriever
+                                .complete_test(arguments, &mut cx)
+                                .await
+                                .unwrap();
+                            let snippet = vec![snippet];
+
+                            generate_content_prompt(prompt, language_name, outline, kind, snippet)
+                        }
+                    };
+                    user_prompt
+                } else {
+                    user_prompt
+                };
+
+                println!("{:?}", user_prompt.clone());
+
+                let messages = vec![RequestMessage {
+                    role: Role::User,
+                    content: user_prompt.clone(),
+                }];
+
+                let request = OpenAIRequest {
+                    model: "gpt-4".to_string(),
+                    messages: messages.clone(),
+                    stream: true,
+                };
+
+                let response = completion_provider.complete(request);
                 let generate = async {
                     let mut edit_start = range.start.to_offset(&snapshot);
@@ -349,315 +467,317 @@ fn strip_markdown_codeblock(
     })
 }
 
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use futures::{
-        future::BoxFuture,
-        stream::{self, BoxStream},
-    };
-    use gpui::{executor::Deterministic, TestAppContext};
-    use indoc::indoc;
-    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
-    use parking_lot::Mutex;
-    use rand::prelude::*;
-    use settings::SettingsStore;
-    use smol::future::FutureExt;
+// #[cfg(test)]
+// mod tests {
+//     use super::*;
+//     use futures::{
+//         future::BoxFuture,
+//         stream::{self, BoxStream},
+//     };
+//     use gpui::{executor::Deterministic, TestAppContext};
+//     use indoc::indoc;
+//     use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
+//     use parking_lot::Mutex;
+//     use rand::prelude::*;
+//     use settings::SettingsStore;
+//     use smol::future::FutureExt;
 
-    #[gpui::test(iterations = 10)]
-    async fn test_transform_autoindent(
-        cx: &mut TestAppContext,
-        mut rng: StdRng,
-        deterministic: Arc<Deterministic>,
-    ) {
-        cx.set_global(cx.read(SettingsStore::test));
-        cx.update(language_settings::init);
+//     #[gpui::test(iterations = 10)]
+//     async fn test_transform_autoindent(
+//         cx: &mut TestAppContext,
+//         mut rng: StdRng,
+//         deterministic: Arc<Deterministic>,
+//     ) {
+//         cx.set_global(cx.read(SettingsStore::test));
+//         cx.update(language_settings::init);
 
-        let text = indoc! {"
-            fn main() {
-                let x = 0;
-                for _ in 0..10 {
-                    x += 1;
-                }
-            }
-        "};
-        let buffer =
-            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
-        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
-        let range = buffer.read_with(cx, |buffer, cx| {
-            let snapshot = buffer.snapshot(cx);
-            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
-        });
-        let provider = Arc::new(TestCompletionProvider::new());
-        let codegen = cx.add_model(|cx| {
-            Codegen::new(
-                buffer.clone(),
-                CodegenKind::Transform { range },
-                provider.clone(),
-                cx,
-            )
-        });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
{" +// fn main() { +// let x = 0; +// for _ in 0..10 { +// x += 1; +// } +// } +// "}; +// let buffer = +// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); +// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); +// let range = buffer.read_with(cx, |buffer, cx| { +// let snapshot = buffer.snapshot(cx); +// snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4)) +// }); +// let provider = Arc::new(TestCompletionProvider::new()); +// let fc_provider = OpenAIFunctionCallingProvider::new("".to_string()); +// let codegen = cx.add_model(|cx| { +// Codegen::new( +// buffer.clone(), +// CodegenKind::Transform { range }, +// provider.clone(), +// fc_provider, +// cx, +// ) +// }); +// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); - let mut new_text = concat!( - " let mut x = 0;\n", - " while x < 10 {\n", - " x += 1;\n", - " }", - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - provider.send_completion(chunk); - new_text = suffix; - deterministic.run_until_parked(); - } - provider.finish_completion(); - deterministic.run_until_parked(); +// let mut new_text = concat!( +// " let mut x = 0;\n", +// " while x < 10 {\n", +// " x += 1;\n", +// " }", +// ); +// while !new_text.is_empty() { +// let max_len = cmp::min(new_text.len(), 10); +// let len = rng.gen_range(1..=max_len); +// let (chunk, suffix) = new_text.split_at(len); +// provider.send_completion(chunk); +// new_text = suffix; +// deterministic.run_until_parked(); +// } +// provider.finish_completion(); +// deterministic.run_until_parked(); - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } +// assert_eq!( +// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), +// indoc! {" +// fn main() { +// let mut x = 0; +// while x < 10 { +// x += 1; +// } +// } +// "} +// ); +// } - #[gpui::test(iterations = 10)] - async fn test_autoindent_when_generating_past_indentation( - cx: &mut TestAppContext, - mut rng: StdRng, - deterministic: Arc, - ) { - cx.set_global(cx.read(SettingsStore::test)); - cx.update(language_settings::init); +// #[gpui::test(iterations = 10)] +// async fn test_autoindent_when_generating_past_indentation( +// cx: &mut TestAppContext, +// mut rng: StdRng, +// deterministic: Arc, +// ) { +// cx.set_global(cx.read(SettingsStore::test)); +// cx.update(language_settings::init); - let text = indoc! {" - fn main() { - le - } - "}; - let buffer = - cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); - let position = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 6)) - }); - let provider = Arc::new(TestCompletionProvider::new()); - let codegen = cx.add_model(|cx| { - Codegen::new( - buffer.clone(), - CodegenKind::Generate { position }, - provider.clone(), - cx, - ) - }); - codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); +// let text = indoc! 
{" +// fn main() { +// le +// } +// "}; +// let buffer = +// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); +// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); +// let position = buffer.read_with(cx, |buffer, cx| { +// let snapshot = buffer.snapshot(cx); +// snapshot.anchor_before(Point::new(1, 6)) +// }); +// let provider = Arc::new(TestCompletionProvider::new()); +// let codegen = cx.add_model(|cx| { +// Codegen::new( +// buffer.clone(), +// CodegenKind::Generate { position }, +// provider.clone(), +// cx, +// ) +// }); +// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); - let mut new_text = concat!( - "t mut x = 0;\n", - "while x < 10 {\n", - " x += 1;\n", - "}", // - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - provider.send_completion(chunk); - new_text = suffix; - deterministic.run_until_parked(); - } - provider.finish_completion(); - deterministic.run_until_parked(); +// let mut new_text = concat!( +// "t mut x = 0;\n", +// "while x < 10 {\n", +// " x += 1;\n", +// "}", // +// ); +// while !new_text.is_empty() { +// let max_len = cmp::min(new_text.len(), 10); +// let len = rng.gen_range(1..=max_len); +// let (chunk, suffix) = new_text.split_at(len); +// provider.send_completion(chunk); +// new_text = suffix; +// deterministic.run_until_parked(); +// } +// provider.finish_completion(); +// deterministic.run_until_parked(); - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } +// assert_eq!( +// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), +// indoc! 
{" +// fn main() { +// let mut x = 0; +// while x < 10 { +// x += 1; +// } +// } +// "} +// ); +// } - #[gpui::test(iterations = 10)] - async fn test_autoindent_when_generating_before_indentation( - cx: &mut TestAppContext, - mut rng: StdRng, - deterministic: Arc, - ) { - cx.set_global(cx.read(SettingsStore::test)); - cx.update(language_settings::init); +// #[gpui::test(iterations = 10)] +// async fn test_autoindent_when_generating_before_indentation( +// cx: &mut TestAppContext, +// mut rng: StdRng, +// deterministic: Arc, +// ) { +// cx.set_global(cx.read(SettingsStore::test)); +// cx.update(language_settings::init); - let text = concat!( - "fn main() {\n", - " \n", - "}\n" // - ); - let buffer = - cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); - let position = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 2)) - }); - let provider = Arc::new(TestCompletionProvider::new()); - let codegen = cx.add_model(|cx| { - Codegen::new( - buffer.clone(), - CodegenKind::Generate { position }, - provider.clone(), - cx, - ) - }); - codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); +// let text = concat!( +// "fn main() {\n", +// " \n", +// "}\n" // +// ); +// let buffer = +// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); +// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); +// let position = buffer.read_with(cx, |buffer, cx| { +// let snapshot = buffer.snapshot(cx); +// snapshot.anchor_before(Point::new(1, 2)) +// }); +// let provider = Arc::new(TestCompletionProvider::new()); +// let codegen = cx.add_model(|cx| { +// Codegen::new( +// buffer.clone(), +// CodegenKind::Generate { position }, +// provider.clone(), +// cx, +// ) +// }); +// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); - let mut new_text = concat!( - "let mut x = 0;\n", - "while x < 10 {\n", - " x += 1;\n", - "}", // - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - provider.send_completion(chunk); - new_text = suffix; - deterministic.run_until_parked(); - } - provider.finish_completion(); - deterministic.run_until_parked(); +// let mut new_text = concat!( +// "let mut x = 0;\n", +// "while x < 10 {\n", +// " x += 1;\n", +// "}", // +// ); +// while !new_text.is_empty() { +// let max_len = cmp::min(new_text.len(), 10); +// let len = rng.gen_range(1..=max_len); +// let (chunk, suffix) = new_text.split_at(len); +// provider.send_completion(chunk); +// new_text = suffix; +// deterministic.run_until_parked(); +// } +// provider.finish_completion(); +// deterministic.run_until_parked(); - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } +// assert_eq!( +// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), +// indoc! 
{" +// fn main() { +// let mut x = 0; +// while x < 10 { +// x += 1; +// } +// } +// "} +// ); +// } - #[gpui::test] - async fn test_strip_markdown_codeblock() { - assert_eq!( - strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "```js\nLorem ipsum dolor\n```" - ); - assert_eq!( - strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "``\nLorem ipsum dolor\n```" - ); +// #[gpui::test] +// async fn test_strip_markdown_codeblock() { +// assert_eq!( +// strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) +// .map(|chunk| chunk.unwrap()) +// .collect::() +// .await, +// "Lorem ipsum dolor" +// ); +// assert_eq!( +// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) +// .map(|chunk| chunk.unwrap()) +// .collect::() +// .await, +// "Lorem ipsum dolor" +// ); +// assert_eq!( +// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) +// .map(|chunk| chunk.unwrap()) +// .collect::() +// .await, +// "Lorem ipsum dolor" +// ); +// assert_eq!( +// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) +// .map(|chunk| chunk.unwrap()) +// .collect::() +// .await, +// "Lorem ipsum dolor" +// ); +// assert_eq!( +// strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) +// .map(|chunk| chunk.unwrap()) +// .collect::() +// .await, +// "```js\nLorem ipsum dolor\n```" +// ); +// assert_eq!( +// strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) +// .map(|chunk| chunk.unwrap()) +// .collect::() +// .await, +// "``\nLorem ipsum dolor\n```" +// ); - fn chunks(text: &str, size: usize) -> impl Stream> { - stream::iter( - text.chars() - .collect::>() - .chunks(size) - .map(|chunk| Ok(chunk.iter().collect::())) - .collect::>(), - ) - } - } +// fn chunks(text: &str, size: usize) -> impl Stream> { +// stream::iter( +// text.chars() +// .collect::>() +// .chunks(size) +// .map(|chunk| Ok(chunk.iter().collect::())) +// .collect::>(), +// ) +// } +// } - struct TestCompletionProvider { - last_completion_tx: Mutex>>, - } +// struct TestCompletionProvider { +// last_completion_tx: Mutex>>, +// } - impl TestCompletionProvider { - fn new() -> Self { - Self { - last_completion_tx: Mutex::new(None), - } - } +// impl TestCompletionProvider { +// fn new() -> Self { +// Self { +// last_completion_tx: Mutex::new(None), +// } +// } - fn send_completion(&self, completion: impl Into) { - let mut tx = self.last_completion_tx.lock(); - tx.as_mut().unwrap().try_send(completion.into()).unwrap(); - } +// fn send_completion(&self, completion: impl Into) { +// let mut tx = self.last_completion_tx.lock(); +// tx.as_mut().unwrap().try_send(completion.into()).unwrap(); +// } - fn finish_completion(&self) { - self.last_completion_tx.lock().take().unwrap(); - } - 
-    impl CompletionProvider for TestCompletionProvider {
-        fn complete(
-            &self,
-            _prompt: OpenAIRequest,
-        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-            let (tx, rx) = mpsc::channel(1);
-            *self.last_completion_tx.lock() = Some(tx);
-            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
-        }
-    }
+//     impl CompletionProvider for TestCompletionProvider {
+//         fn complete(
+//             &self,
+//             _prompt: OpenAIRequest,
+//         ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+//             let (tx, rx) = mpsc::channel(1);
+//             *self.last_completion_tx.lock() = Some(tx);
+//             async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
+//         }
+//     }
 
-    fn rust_lang() -> Language {
-        Language::new(
-            LanguageConfig {
-                name: "Rust".into(),
-                path_suffixes: vec!["rs".to_string()],
-                ..Default::default()
-            },
-            Some(tree_sitter_rust::language()),
-        )
-        .with_indents_query(
-            r#"
-            (call_expression) @indent
-            (field_expression) @indent
-            (_ "(" ")" @end) @indent
-            (_ "{" "}" @end) @indent
-            "#,
-        )
-        .unwrap()
-    }
-}
+//     fn rust_lang() -> Language {
+//         Language::new(
+//             LanguageConfig {
+//                 name: "Rust".into(),
+//                 path_suffixes: vec!["rs".to_string()],
+//                 ..Default::default()
+//             },
+//             Some(tree_sitter_rust::language()),
+//         )
+//         .with_indents_query(
+//             r#"
+//             (call_expression) @indent
+//             (field_expression) @indent
+//             (_ "(" ")" @end) @indent
+//             (_ "{" "}" @end) @indent
+//             "#,
+//         )
+//         .unwrap()
+//     }
+// }
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index 203d2f3e89..b0ec19763a 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,4 +1,4 @@
-use gpui::AppContext;
+use gpui::{AppContext, AsyncAppContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp;
 use std::ops::Range;
@@ -83,14 +83,14 @@ fn outline_for_prompt(
     Some(text)
 }
 
-fn generate_codegen_planning_prompt(
+pub fn generate_codegen_planning_prompt(
     user_prompt: String,
     language_name: Option<&str>,
     buffer: &BufferSnapshot,
     range: Range<language::Anchor>,
     cx: &AppContext,
     kind: CodegenKind,
-) -> String {
+) -> (String, Option<String>) {
     let mut prompt = String::new();
 
     // General Preamble
@@ -101,7 +101,7 @@ fn generate_codegen_planning_prompt(
     }
 
     let outline = outline_for_prompt(buffer, range.clone(), cx);
-    if let Some(outline) = outline {
+    if let Some(outline) = outline.clone() {
         writeln!(
             prompt,
             "You're currently working inside the Zed editor on a file with the following outline:"
@@ -135,33 +135,41 @@
     )
     .unwrap();
 
-    prompt
+    (prompt, outline)
 }
 
 pub fn generate_content_prompt(
     user_prompt: String,
-    language_name: Option<&str>,
-    buffer: &BufferSnapshot,
-    range: Range<language::Anchor>,
-    cx: &AppContext,
+    language_name: Option<String>,
+    outline: Option<String>,
     kind: CodegenKind,
+    snippet: Vec<String>,
 ) -> String {
     let mut prompt = String::new();
 
     // General Preamble
-    if let Some(language_name) = language_name {
+    if let Some(language_name) = language_name.clone() {
         writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
     } else {
         writeln!(prompt, "You're an expert software engineer.\n").unwrap();
     }
 
-    let outline = outline_for_prompt(buffer, range.clone(), cx);
+    if !snippet.is_empty() {
+        writeln!(
+            prompt,
+            "Here are a few snippets from the codebase which may help:"
+        )
+        .unwrap();
+    }
+    for snip in snippet {
+        writeln!(prompt, "{snip}").unwrap();
+    }
+
     if let Some(outline) = outline {
         writeln!(
             prompt,
             "The file you are currently working on has the following outline:"
         )
         .unwrap();
-        if let Some(language_name) = language_name {
+        if let Some(language_name) = language_name.clone() {
             let language_name = language_name.to_lowercase();
             writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
         } else {
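Editorial note, not part of the patch: the reshaped `generate_content_prompt` no longer needs a buffer, range, or `AppContext`; the planning step supplies the outline and any retrieved snippets as plain strings. An illustrative call, with made-up values and a hypothetical `build_prompt` wrapper:

```rust
// Inside crates/assistant, given some `kind: CodegenKind` from the editor:
fn build_prompt(kind: CodegenKind) -> String {
    generate_content_prompt(
        // The user prompt, possibly already rewritten by the `rewrite_prompt` skill.
        "Add a unit test covering the empty-input case".to_string(),
        // language_name is now an owned Option<String>.
        Some("Rust".to_string()),
        // Outline string returned by generate_codegen_planning_prompt.
        Some("fn parse(input: &str) -> Result<Ast>".to_string()),
        kind,
        // Snippets retrieved by retrieve_context_from_repository (may be empty).
        vec!["fn parse_empty() { /* … */ }".to_string()],
    )
}
```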
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index ecdba43643..e606e69113 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -2,6 +2,7 @@ mod db;
 mod embedding_queue;
 mod parsing;
 pub mod semantic_index_settings;
+pub mod skills;
 
 #[cfg(test)]
 mod semantic_index_tests;
diff --git a/crates/semantic_index/src/skills.rs b/crates/semantic_index/src/skills.rs
new file mode 100644
index 0000000000..589b1dc340
--- /dev/null
+++ b/crates/semantic_index/src/skills.rs
@@ -0,0 +1,106 @@
+use ai::function_calling::OpenAIFunction;
+use anyhow::anyhow;
+use gpui::{AppContext, AsyncAppContext, ModelHandle};
+use project::Project;
+use serde::{Serialize, Serializer};
+use serde_json::json;
+use std::fmt::Write;
+
+use crate::SemanticIndex;
+
+#[derive(Clone)]
+pub struct RepositoryContextRetriever {
+    index: ModelHandle<SemanticIndex>,
+    project: ModelHandle<Project>,
+}
+
+impl RepositoryContextRetriever {
+    pub fn load(index: ModelHandle<SemanticIndex>, project: ModelHandle<Project>) -> Self {
+        Self { index, project }
+    }
+    pub async fn complete_test(
+        &self,
+        arguments: serde_json::Value,
+        cx: &mut AsyncAppContext,
+    ) -> anyhow::Result<String> {
+        let queries = arguments.get("queries").unwrap().as_array().unwrap();
+        let mut prompt = String::new();
+        let query = queries
+            .iter()
+            .map(|query| query.to_string())
+            .collect::<Vec<String>>()
+            .join(";");
+        let project = self.project.clone();
+        let results = self
+            .index
+            .update(cx, |this, cx| {
+                this.search_project(project, query, 10, vec![], vec![], cx)
+            })
+            .await?;
+
+        for result in results {
+            result.buffer.read_with(cx, |buffer, cx| {
+                let text = buffer.text_for_range(result.range).collect::<String>();
+                let file_path = buffer.file().unwrap().path().to_string_lossy();
+                let language = buffer.language();
+
+                writeln!(
+                    prompt,
+                    "The following is a relevant snippet from file ({}):",
+                    file_path
+                )
+                .unwrap();
+                if let Some(language) = language {
+                    writeln!(prompt, "```{}\n{text}\n```", language.name().to_lowercase()).unwrap();
+                } else {
+                    writeln!(prompt, "```\n{text}\n```").unwrap();
+                }
+            });
+        }
+
+        Ok(prompt)
+    }
+}
+
+impl OpenAIFunction for RepositoryContextRetriever {
+    fn name(&self) -> String {
+        "retrieve_context_from_repository".to_string()
+    }
+    fn description(&self) -> String {
+        "Retrieve relevant content from the repository with natural language".to_string()
+    }
+    fn system_prompt(&self) -> String {
+        "'retrieve_context_from_repository':
+        If more information from the repository is needed to complete the user's prompt reliably, pass up to 3 queries describing the pieces of code or text you would like additional context on.
+        Do not make these queries general questions about programming; include very specific lexical references to the pieces of code you need more information on.
+        We are passing these into a semantic similarity retrieval engine, with all the information in the current codebase included.
+        As such, they should be phrased as descriptions of the code of interest rather than as questions."
+            .to_string()
+    }
+    fn parameters(&self) -> serde_json::Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "queries": {
+                    "title": "queries",
+                    "type": "array",
+                    "items": {"type": "string"}
+                }
+            },
+            "required": ["queries"]
+        })
+    }
+    fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
+        todo!();
+    }
+}
+impl Serialize for RepositoryContextRetriever {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        json!({"name": self.name(),
+            "description": self.description(),
+            "parameters": self.parameters()})
+        .serialize(serializer)
+    }
+}
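Editorial note, not part of the patch: the arguments shape `retrieve_context_from_repository` declares, and what `complete_test` does with it. Illustrative values only:

```rust
use serde_json::json;

fn main() {
    // A payload matching the "queries" schema above (up to 3 queries).
    let arguments = json!({
        "queries": [
            "function that streams completions from the OpenAI API",
            "prompt template used for inline code generation",
        ]
    });

    // complete_test stringifies each entry with Value::to_string (JSON quotes
    // included) and joins them with ';' into a single search string, then runs
    // one search_project call with limit 10 and no include/exclude filters.
    let query = arguments["queries"]
        .as_array()
        .unwrap()
        .iter()
        .map(|query| query.to_string())
        .collect::<Vec<String>>()
        .join(";");

    assert!(query.contains(';'));
    println!("{query}");
}
```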