Working inline with the semantic index; this should never merge

KCaverly 2023-09-29 14:02:41 -04:00
parent 3cc499ee7c
commit a29b70552f
10 changed files with 698 additions and 350 deletions

Cargo.lock (generated, 1 addition)

@@ -323,6 +323,7 @@ dependencies = [
"regex",
"schemars",
"search",
"semantic_index",
"serde",
"serde_json",
"settings",


@@ -1,6 +1,7 @@
pub mod completion;
pub mod embedding;
pub mod function_calling;
pub mod skills;
use core::fmt;
use std::fmt::Display;
@@ -35,7 +36,7 @@ impl Display for Role {
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct RequestMessage {
pub role: Role,
pub content: String,


@@ -13,6 +13,7 @@ pub trait OpenAIFunction: erased_serde::Serialize {
fn description(&self) -> String;
fn system_prompt(&self) -> String;
fn parameters(&self) -> serde_json::Value;
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String>;
}
serialize_trait_object!(OpenAIFunction);
@@ -83,6 +84,7 @@ pub struct FunctionCallDetails {
pub arguments: serde_json::Value, // JSON object representing the provided arguments
}
#[derive(Clone)]
pub struct OpenAIFunctionCallingProvider {
api_key: String,
}

crates/ai/src/skills.rs (new file, 50 additions)

@@ -0,0 +1,50 @@
use crate::function_calling::OpenAIFunction;
use gpui::{AppContext, ModelHandle};
use project::Project;
use serde::{Serialize, Serializer};
use serde_json::json;
pub struct RewritePrompt;
impl Serialize for RewritePrompt {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
json!({"name": self.name(),
"description": self.description(),
"parameters": self.parameters()})
.serialize(serializer)
}
}
impl RewritePrompt {
pub fn load() -> Self {
Self {}
}
}
impl OpenAIFunction for RewritePrompt {
fn name(&self) -> String {
"rewrite_prompt".to_string()
}
fn description(&self) -> String {
"Rewrite prompt given prompt from user".to_string()
}
fn system_prompt(&self) -> String {
"'rewrite_prompt':
If all information is available in the above prompt, and you need no further information.
Rewrite the entire prompt to clarify what should be generated, do not actually complete the users request.
Assume this rewritten message will be passed to another completion agent, to fulfill the users request.".to_string()
}
fn parameters(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"prompt": {}
}
})
}
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
Ok(arguments.get("prompt").unwrap().to_string())
}
}
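
Taken together with the complete method added to OpenAIFunction above, the intended flow is: send a planning request through the function-calling provider, then dispatch on the returned function name so the chosen skill can turn the model's arguments back into prompt text. A minimal sketch of that flow, mirroring the call sites later in this commit; the helper function, its error handling, and the hard-coded model name are illustrative and not part of this change:

use ai::function_calling::{OpenAIFunction, OpenAIFunctionCallingProvider};
use ai::skills::RewritePrompt;
use ai::RequestMessage;

// Hypothetical helper: ask the model to pick a skill, then let that skill
// convert the returned arguments into the prompt that will be completed.
async fn plan_prompt(
    fc_provider: OpenAIFunctionCallingProvider,
    planning_messages: Vec<RequestMessage>,
) -> anyhow::Result<String> {
    let functions: Vec<Box<dyn OpenAIFunction>> = vec![Box::new(RewritePrompt::load())];
    let call = fc_provider
        .complete("gpt-4".to_string(), planning_messages, functions)
        .await?;
    match call.name.as_str() {
        // The skill itself knows how to turn the model-provided arguments into text.
        "rewrite_prompt" => RewritePrompt::load().complete(call.arguments),
        other => Err(anyhow::anyhow!("unhandled function call: {other}")),
    }
}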


@@ -23,6 +23,8 @@ theme = { path = "../theme" }
util = { path = "../util" }
uuid = { version = "1.1.2", features = ["v4"] }
workspace = { path = "../workspace" }
semantic_index = { path = "../semantic_index" }
project = { path = "../project" }
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }


@@ -1,12 +1,16 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind},
prompts::generate_content_prompt,
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
use ai::completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL};
use ai::RequestMessage;
use ai::{
completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL},
function_calling::OpenAIFunctionCallingProvider,
skills::RewritePrompt,
};
use ai::{function_calling::OpenAIFunction, RequestMessage};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@@ -34,7 +38,9 @@ use gpui::{
WindowContext,
};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use project::Project;
use search::BufferSearchBar;
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
@@ -144,6 +150,8 @@ pub struct AssistantPanel {
include_conversation_in_next_inline_assist: bool,
inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>,
semantic_index: ModelHandle<SemanticIndex>,
project: ModelHandle<Project>,
}
impl AssistantPanel {
@@ -153,6 +161,7 @@ impl AssistantPanel {
workspace: WeakViewHandle<Workspace>,
cx: AsyncAppContext,
) -> Task<Result<ViewHandle<Self>>> {
let index = cx.read(|cx| SemanticIndex::global(cx).unwrap());
cx.spawn(|mut cx| async move {
let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
let saved_conversations = SavedConversationMetadata::list(fs.clone())
@@ -190,6 +199,9 @@ impl AssistantPanel {
toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
toolbar
});
let project = workspace.project().clone();
let mut this = Self {
workspace: workspace_handle,
active_editor_index: Default::default(),
@@ -214,6 +226,8 @@ impl AssistantPanel {
include_conversation_in_next_inline_assist: false,
inline_prompt_history: Default::default(),
_watch_saved_conversations,
semantic_index: index,
project,
};
let mut old_dock_position = this.position(cx);
@@ -276,9 +290,10 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let provider = Arc::new(OpenAICompletionProvider::new(
api_key,
api_key.clone(),
cx.background().clone(),
));
let fc_provider = OpenAIFunctionCallingProvider::new(api_key);
let selection = editor.read(cx).selections.newest_anchor().clone();
let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
CodegenKind::Generate {
@@ -289,8 +304,18 @@ impl AssistantPanel {
range: selection.start..selection.end,
}
};
let project = self.project.clone();
let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
Codegen::new(
editor.read(cx).buffer().clone(),
codegen_kind,
provider,
fc_provider,
cx,
project.clone(),
)
});
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
@@ -572,42 +597,74 @@ impl AssistantPanel {
let language_name = language_name.as_deref();
let codegen_kind = pending_assist.codegen.read(cx).kind().clone();
let prompt = generate_content_prompt(
let index = self.semantic_index.clone();
pending_assist.codegen.update(cx, |codegen, cx| {
codegen.start(
user_prompt.to_string(),
language_name,
&snapshot,
language_range,
cx,
codegen_kind,
);
let mut messages = Vec::new();
let mut model = settings::get::<AssistantSettings>(cx)
.default_open_ai_model
.clone();
if let Some(conversation) = conversation {
let conversation = conversation.read(cx);
let buffer = conversation.buffer.read(cx);
messages.extend(
conversation
.messages(cx)
.map(|message| message.to_open_ai_message(buffer)),
);
model = conversation.model.clone();
}
messages.push(RequestMessage {
role: Role::User,
content: prompt,
language_name,
snapshot,
language_range.clone(),
codegen_kind.clone(),
index,
)
});
let request = OpenAIRequest {
model: model.full_name().into(),
messages,
stream: true,
};
pending_assist
.codegen
.update(cx, |codegen, cx| codegen.start(request, cx));
// let api_key = self.api_key.as_ref().clone().into_inner().clone().unwrap();
// let function_provider = OpenAIFunctionCallingProvider::new(api_key);
// let planning_messages = vec![RequestMessage {
// role: Role::User,
// content: planning_prompt,
// }];
// println!("GETTING HERE");
// let function_call = cx
// .spawn(|this, mut cx| async move {
// let result = function_provider
// .complete("gpt-4".to_string(), planning_messages, functions)
// .await;
// dbg!(&result);
// result
// })
// .detach();
// let function_name = function_call.name.as_str();
// let prompt = match function_name {
// "rewrite_prompt" => {
// let user_prompt = RewritePrompt::load()
// .complete(function_call.arguments)
// .unwrap();
// generate_content_prompt(
// user_prompt.to_string(),
// language_name,
// &snapshot,
// language_range,
// cx,
// codegen_kind,
// )
// }
// _ => {
// todo!();
// }
// };
// let mut messages = Vec::new();
// let mut model = settings::get::<AssistantSettings>(cx)
// .default_open_ai_model
// .clone();
// if let Some(conversation) = conversation {
// let conversation = conversation.read(cx);
// let buffer = conversation.buffer.read(cx);
// messages.extend(
// conversation
// .messages(cx)
// .map(|message| message.to_open_ai_message(buffer)),
// );
// model = conversation.model.clone();
// }
}
fn update_highlights_for_editor(


@@ -1,12 +1,22 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, OpenAIRequest};
use crate::{
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
streaming_diff::{Hunk, StreamingDiff},
};
use ai::{
completion::{CompletionProvider, OpenAIRequest},
function_calling::{OpenAIFunction, OpenAIFunctionCallingProvider},
skills::RewritePrompt,
RequestMessage, Role,
};
use anyhow::Result;
use editor::{
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{Entity, ModelContext, ModelHandle, Task};
use language::{Rope, TransactionId};
use gpui::{BorrowAppContext, Entity, ModelContext, ModelHandle, Task};
use language::{BufferSnapshot, Rope, TransactionId};
use project::Project;
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
use std::{cmp, future, ops::Range, sync::Arc};
pub enum Event {
@@ -22,6 +32,7 @@ pub enum CodegenKind {
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
fc_provider: OpenAIFunctionCallingProvider,
buffer: ModelHandle<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
@@ -31,6 +42,7 @@ pub struct Codegen {
generation: Task<()>,
idle: bool,
_subscription: gpui::Subscription,
project: ModelHandle<Project>,
}
impl Entity for Codegen {
@@ -42,7 +54,9 @@ impl Codegen {
buffer: ModelHandle<MultiBuffer>,
mut kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
fc_provider: OpenAIFunctionCallingProvider,
cx: &mut ModelContext<Self>,
project: ModelHandle<Project>,
) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
match &mut kind {
@@ -62,6 +76,7 @@ impl Codegen {
Self {
provider,
fc_provider,
buffer: buffer.clone(),
snapshot,
kind,
@@ -71,6 +86,7 @@ impl Codegen {
idle: true,
generation: Task::ready(()),
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
project,
}
}
@@ -112,7 +128,17 @@ impl Codegen {
self.error.as_ref()
}
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
pub fn start(
&mut self,
prompt: String,
cx: &mut ModelContext<Self>,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<language::Anchor>,
kind: CodegenKind,
index: ModelHandle<SemanticIndex>,
) {
let language_range = range.clone();
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@@ -126,9 +152,101 @@ impl Codegen {
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
let messages = vec![RequestMessage {
role: Role::User,
content: prompt.clone(),
}];
let request = OpenAIRequest {
model: "gpt-4".to_string(),
messages: messages.clone(),
stream: true,
};
let (planning_prompt, outline) = generate_codegen_planning_prompt(
prompt.clone(),
language_name.clone(),
&buffer,
language_range.clone(),
cx,
kind.clone(),
);
let project = self.project.clone();
self.generation = cx.spawn_weak(|this, mut cx| {
// Plan Ahead
let planning_messages = vec![RequestMessage {
role: Role::User,
content: planning_prompt,
}];
let repo_retriever = RepositoryContextRetriever::load(index, project);
let functions: Vec<Box<dyn OpenAIFunction>> = vec![
Box::new(RewritePrompt::load()),
Box::new(repo_retriever.clone()),
];
let completion_provider = self.provider.clone();
let fc_provider = self.fc_provider.clone();
let language_name = language_name.map(|language_name| language_name.to_string());
let kind = kind.clone();
async move {
let mut user_prompt = prompt.clone();
let user_prompt = if let Ok(function_call) = fc_provider
.complete("gpt-4".to_string(), planning_messages, functions)
.await
{
let function_name = function_call.name.as_str();
println!("FUNCTION NAME: {:?}", function_name);
let user_prompt = match function_name {
"rewrite_prompt" => {
let user_prompt = RewritePrompt::load()
.complete(function_call.arguments)
.unwrap();
generate_content_prompt(
user_prompt,
language_name,
outline,
kind,
vec![],
)
}
_ => {
let arguments = function_call.arguments.clone();
let snippet = repo_retriever
.complete_test(arguments, &mut cx)
.await
.unwrap();
let snippet = vec![snippet];
generate_content_prompt(prompt, language_name, outline, kind, snippet)
}
};
user_prompt
} else {
user_prompt
};
println!("{:?}", user_prompt.clone());
let messages = vec![RequestMessage {
role: Role::User,
content: user_prompt.clone(),
}];
let request = OpenAIRequest {
model: "gpt-4".to_string(),
messages: messages.clone(),
stream: true,
};
let response = completion_provider.complete(request);
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
@@ -349,315 +467,317 @@ fn strip_markdown_codeblock(
})
}
#[cfg(test)]
mod tests {
use super::*;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
};
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
use settings::SettingsStore;
use smol::future::FutureExt;
// #[cfg(test)]
// mod tests {
// use super::*;
// use futures::{
// future::BoxFuture,
// stream::{self, BoxStream},
// };
// use gpui::{executor::Deterministic, TestAppContext};
// use indoc::indoc;
// use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
// use parking_lot::Mutex;
// use rand::prelude::*;
// use settings::SettingsStore;
// use smol::future::FutureExt;
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_transform_autoindent(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
for _ in 0..10 {
x += 1;
}
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = indoc! {"
// fn main() {
// let x = 0;
// for _ in 0..10 {
// x += 1;
// }
// }
// "};
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let range = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let fc_provider = OpenAIFunctionCallingProvider::new("".to_string());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Transform { range },
// provider.clone(),
// fc_provider,
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// " let mut x = 0;\n",
// " while x < 10 {\n",
// " x += 1;\n",
// " }",
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_autoindent_when_generating_past_indentation(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = indoc! {"
// fn main() {
// le
// }
// "};
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let position = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 6))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Generate { position },
// provider.clone(),
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// "t mut x = 0;\n",
// "while x < 10 {\n",
// " x += 1;\n",
// "}", //
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_autoindent_when_generating_before_indentation(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = concat!(
// "fn main() {\n",
// " \n",
// "}\n" //
// );
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let position = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 2))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Generate { position },
// provider.clone(),
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// "let mut x = 0;\n",
// "while x < 10 {\n",
// " x += 1;\n",
// "}", //
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test]
async fn test_strip_markdown_codeblock() {
assert_eq!(
strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"```js\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"``\nLorem ipsum dolor\n```"
);
// #[gpui::test]
// async fn test_strip_markdown_codeblock() {
// assert_eq!(
// strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "```js\nLorem ipsum dolor\n```"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "``\nLorem ipsum dolor\n```"
// );
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
stream::iter(
text.chars()
.collect::<Vec<_>>()
.chunks(size)
.map(|chunk| Ok(chunk.iter().collect::<String>()))
.collect::<Vec<_>>(),
)
}
}
// fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
// stream::iter(
// text.chars()
// .collect::<Vec<_>>()
// .chunks(size)
// .map(|chunk| Ok(chunk.iter().collect::<String>()))
// .collect::<Vec<_>>(),
// )
// }
// }
struct TestCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
// struct TestCompletionProvider {
// last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
// }
impl TestCompletionProvider {
fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
// impl TestCompletionProvider {
// fn new() -> Self {
// Self {
// last_completion_tx: Mutex::new(None),
// }
// }
fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
// fn send_completion(&self, completion: impl Into<String>) {
// let mut tx = self.last_completion_tx.lock();
// tx.as_mut().unwrap().try_send(completion.into()).unwrap();
// }
fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
// fn finish_completion(&self) {
// self.last_completion_tx.lock().take().unwrap();
// }
// }
impl CompletionProvider for TestCompletionProvider {
fn complete(
&self,
_prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
}
// impl CompletionProvider for TestCompletionProvider {
// fn complete(
// &self,
// _prompt: OpenAIRequest,
// ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// let (tx, rx) = mpsc::channel(1);
// *self.last_completion_tx.lock() = Some(tx);
// async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
// }
// }
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}
// fn rust_lang() -> Language {
// Language::new(
// LanguageConfig {
// name: "Rust".into(),
// path_suffixes: vec!["rs".to_string()],
// ..Default::default()
// },
// Some(tree_sitter_rust::language()),
// )
// .with_indents_query(
// r#"
// (call_expression) @indent
// (field_expression) @indent
// (_ "(" ")" @end) @indent
// (_ "{" "}" @end) @indent
// "#,
// )
// .unwrap()
// }
// }


@@ -1,4 +1,4 @@
use gpui::AppContext;
use gpui::{AppContext, AsyncAppContext};
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp;
use std::ops::Range;
@@ -83,14 +83,14 @@ fn outline_for_prompt(
Some(text)
}
fn generate_codegen_planning_prompt(
pub fn generate_codegen_planning_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: &BufferSnapshot,
range: Range<language::Anchor>,
cx: &AppContext,
kind: CodegenKind,
) -> String {
) -> (String, Option<String>) {
let mut prompt = String::new();
// General Preamble
@@ -101,7 +101,7 @@ fn generate_codegen_planning_prompt(
}
let outline = outline_for_prompt(buffer, range.clone(), cx);
if let Some(outline) = outline {
if let Some(outline) = outline.clone() {
writeln!(
prompt,
"You're currently working inside the Zed editor on a file with the following outline:"
@@ -135,33 +135,41 @@ fn generate_codegen_planning_prompt(
)
.unwrap();
prompt
(prompt, outline)
}
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: &BufferSnapshot,
range: Range<language::Anchor>,
cx: &AppContext,
language_name: Option<String>,
outline: Option<String>,
kind: CodegenKind,
snippet: Vec<String>,
) -> String {
let mut prompt = String::new();
// General Preamble
if let Some(language_name) = language_name {
if let Some(language_name) = language_name.clone() {
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
} else {
writeln!(prompt, "You're an expert software engineer.\n").unwrap();
}
let outline = outline_for_prompt(buffer, range.clone(), cx);
if !snippet.is_empty() {
    writeln!(
        prompt,
        "Here are a few snippets from the codebase which may help:"
    )
    .unwrap();
}
for snip in snippet {
    writeln!(prompt, "{snip}").unwrap();
}
if let Some(outline) = outline {
writeln!(
prompt,
"The file you are currently working on has the following outline:"
)
.unwrap();
if let Some(language_name) = language_name {
if let Some(language_name) = language_name.clone() {
let language_name = language_name.to_lowercase();
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
} else {


@@ -2,6 +2,7 @@ mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;
pub mod skills;
#[cfg(test)]
mod semantic_index_tests;


@@ -0,0 +1,106 @@
use ai::function_calling::OpenAIFunction;
use anyhow::anyhow;
use gpui::{AppContext, AsyncAppContext, ModelHandle};
use project::Project;
use serde::{Serialize, Serializer};
use serde_json::json;
use std::fmt::Write;
use crate::SemanticIndex;
#[derive(Clone)]
pub struct RepositoryContextRetriever {
index: ModelHandle<SemanticIndex>,
project: ModelHandle<Project>,
}
impl RepositoryContextRetriever {
pub fn load(index: ModelHandle<SemanticIndex>, project: ModelHandle<Project>) -> Self {
Self { index, project }
}
pub async fn complete_test(
&self,
arguments: serde_json::Value,
cx: &mut AsyncAppContext,
) -> anyhow::Result<String> {
let queries = arguments
    .get("queries")
    .and_then(|queries| queries.as_array())
    .ok_or_else(|| anyhow!("missing 'queries' argument"))?;
let mut prompt = String::new();
let query = queries
    .iter()
    .filter_map(|query| query.as_str())
    .collect::<Vec<_>>()
    .join(";");
let project = self.project.clone();
let results = self
.index
.update(cx, |this, cx| {
this.search_project(project, query, 10, vec![], vec![], cx)
})
.await?;
for result in results {
result.buffer.read_with(cx, |buffer, cx| {
let text = buffer.text_for_range(result.range).collect::<String>();
let file_path = buffer.file().unwrap().path().to_string_lossy();
let language = buffer.language();
writeln!(
prompt,
"The following is a relevant snippet from file ({}):",
file_path
)
.unwrap();
if let Some(language) = language {
writeln!(prompt, "```{}\n{text}\n```", language.name().to_lowercase()).unwrap();
} else {
writeln!(prompt, "```\n{text}\n```").unwrap();
}
});
}
Ok(prompt)
}
}
impl OpenAIFunction for RepositoryContextRetriever {
fn name(&self) -> String {
"retrieve_context_from_repository".to_string()
}
fn description(&self) -> String {
"Retrieve relevant content from repository with natural language".to_string()
}
fn system_prompt(&self) -> String {
"'retrieve_context_from_repository'
If more information is needed from the repository, to complete the users prompt reliably, pass up to 3 queries describing pieces of code or text you would like additional context upon.
Do not make these queries general about programming, include very specific lexical references to the pieces of code you need more information on.
We are passing these into a semantic similarity retrieval engine, with all the information in the current codebase included.
As such, these should be phrased as descriptions of code of interest as opposed to questions".to_string()
}
fn parameters(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"queries": {
"title": "queries",
"type": "array",
"items": {"type": "string"}
}
},
"required": ["queries"]
})
}
fn complete(&self, _arguments: serde_json::Value) -> anyhow::Result<String> {
todo!();
}
}
impl Serialize for RepositoryContextRetriever {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
json!({"name": self.name(),
"description": self.description(),
"parameters": self.parameters()})
.serialize(serializer)
}
}
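
For reference, the arguments value the model is expected to produce for this skill follows the schema declared in parameters() above; the query text below is made up. complete_test joins the queries with ';' and passes the result to SemanticIndex::search_project:

// Illustrative arguments payload for retrieve_context_from_repository.
let arguments = serde_json::json!({
    "queries": [
        "prompt template used for codegen planning",
        "semantic index search_project implementation"
    ]
});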