mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-23 18:32:17 +00:00
working inline with semantic index, this should never merge
This commit is contained in:
parent
3cc499ee7c
commit
a29b70552f
10 changed files with 698 additions and 350 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -323,6 +323,7 @@ dependencies = [
|
|||
"regex",
|
||||
"schemars",
|
||||
"search",
|
||||
"semantic_index",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod function_calling;
|
||||
pub mod skills;
|
||||
|
||||
use core::fmt;
|
||||
use std::fmt::Display;
|
||||
|
@ -35,7 +36,7 @@ impl Display for Role {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
|
|
|
@ -13,6 +13,7 @@ pub trait OpenAIFunction: erased_serde::Serialize {
|
|||
fn description(&self) -> String;
|
||||
fn system_prompt(&self) -> String;
|
||||
fn parameters(&self) -> serde_json::Value;
|
||||
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String>;
|
||||
}
|
||||
serialize_trait_object!(OpenAIFunction);
|
||||
|
||||
|
@ -83,6 +84,7 @@ pub struct FunctionCallDetails {
|
|||
pub arguments: serde_json::Value, // json object respresenting provided arguments
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIFunctionCallingProvider {
|
||||
api_key: String,
|
||||
}
|
||||
|
|
50
crates/ai/src/skills.rs
Normal file
50
crates/ai/src/skills.rs
Normal file
|
@ -0,0 +1,50 @@
|
|||
use crate::function_calling::OpenAIFunction;
|
||||
use gpui::{AppContext, ModelHandle};
|
||||
use project::Project;
|
||||
use serde::{Serialize, Serializer};
|
||||
use serde_json::json;
|
||||
|
||||
pub struct RewritePrompt;
|
||||
impl Serialize for RewritePrompt {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
json!({"name": self.name(),
|
||||
"description": self.description(),
|
||||
"parameters": self.parameters()})
|
||||
.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl RewritePrompt {
|
||||
pub fn load() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIFunction for RewritePrompt {
|
||||
fn name(&self) -> String {
|
||||
"rewrite_prompt".to_string()
|
||||
}
|
||||
fn description(&self) -> String {
|
||||
"Rewrite prompt given prompt from user".to_string()
|
||||
}
|
||||
fn system_prompt(&self) -> String {
|
||||
"'rewrite_prompt':
|
||||
If all information is available in the above prompt, and you need no further information.
|
||||
Rewrite the entire prompt to clarify what should be generated, do not actually complete the users request.
|
||||
Assume this rewritten message will be passed to another completion agent, to fulfill the users request.".to_string()
|
||||
}
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {}
|
||||
}
|
||||
})
|
||||
}
|
||||
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
|
||||
Ok(arguments.get("prompt").unwrap().to_string())
|
||||
}
|
||||
}
|
|
@ -23,6 +23,8 @@ theme = { path = "../theme" }
|
|||
util = { path = "../util" }
|
||||
uuid = { version = "1.1.2", features = ["v4"] }
|
||||
workspace = { path = "../workspace" }
|
||||
semantic_index = { path = "../semantic_index" }
|
||||
project = { path = "../project" }
|
||||
|
||||
anyhow.workspace = true
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
use crate::{
|
||||
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
|
||||
codegen::{self, Codegen, CodegenKind},
|
||||
prompts::generate_content_prompt,
|
||||
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
|
||||
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
|
||||
SavedMessage,
|
||||
};
|
||||
use ai::completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL};
|
||||
use ai::RequestMessage;
|
||||
use ai::{
|
||||
completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL},
|
||||
function_calling::OpenAIFunctionCallingProvider,
|
||||
skills::RewritePrompt,
|
||||
};
|
||||
use ai::{function_calling::OpenAIFunction, RequestMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Local};
|
||||
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
|
||||
|
@ -34,7 +38,9 @@ use gpui::{
|
|||
WindowContext,
|
||||
};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
|
||||
use project::Project;
|
||||
use search::BufferSearchBar;
|
||||
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
cell::{Cell, RefCell},
|
||||
|
@ -144,6 +150,8 @@ pub struct AssistantPanel {
|
|||
include_conversation_in_next_inline_assist: bool,
|
||||
inline_prompt_history: VecDeque<String>,
|
||||
_watch_saved_conversations: Task<Result<()>>,
|
||||
semantic_index: ModelHandle<SemanticIndex>,
|
||||
project: ModelHandle<Project>,
|
||||
}
|
||||
|
||||
impl AssistantPanel {
|
||||
|
@ -153,6 +161,7 @@ impl AssistantPanel {
|
|||
workspace: WeakViewHandle<Workspace>,
|
||||
cx: AsyncAppContext,
|
||||
) -> Task<Result<ViewHandle<Self>>> {
|
||||
let index = cx.read(|cx| SemanticIndex::global(cx).unwrap());
|
||||
cx.spawn(|mut cx| async move {
|
||||
let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
|
||||
let saved_conversations = SavedConversationMetadata::list(fs.clone())
|
||||
|
@ -190,6 +199,9 @@ impl AssistantPanel {
|
|||
toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
|
||||
toolbar
|
||||
});
|
||||
|
||||
let project = workspace.project().clone();
|
||||
|
||||
let mut this = Self {
|
||||
workspace: workspace_handle,
|
||||
active_editor_index: Default::default(),
|
||||
|
@ -214,6 +226,8 @@ impl AssistantPanel {
|
|||
include_conversation_in_next_inline_assist: false,
|
||||
inline_prompt_history: Default::default(),
|
||||
_watch_saved_conversations,
|
||||
semantic_index: index,
|
||||
project,
|
||||
};
|
||||
|
||||
let mut old_dock_position = this.position(cx);
|
||||
|
@ -276,9 +290,10 @@ impl AssistantPanel {
|
|||
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
|
||||
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let provider = Arc::new(OpenAICompletionProvider::new(
|
||||
api_key,
|
||||
api_key.clone(),
|
||||
cx.background().clone(),
|
||||
));
|
||||
let fc_provider = OpenAIFunctionCallingProvider::new(api_key);
|
||||
let selection = editor.read(cx).selections.newest_anchor().clone();
|
||||
let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
|
||||
CodegenKind::Generate {
|
||||
|
@ -289,8 +304,18 @@ impl AssistantPanel {
|
|||
range: selection.start..selection.end,
|
||||
}
|
||||
};
|
||||
|
||||
let project = self.project.clone();
|
||||
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
|
||||
Codegen::new(
|
||||
editor.read(cx).buffer().clone(),
|
||||
codegen_kind,
|
||||
provider,
|
||||
fc_provider,
|
||||
cx,
|
||||
project.clone(),
|
||||
)
|
||||
});
|
||||
|
||||
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
|
||||
|
@ -572,42 +597,74 @@ impl AssistantPanel {
|
|||
let language_name = language_name.as_deref();
|
||||
|
||||
let codegen_kind = pending_assist.codegen.read(cx).kind().clone();
|
||||
let prompt = generate_content_prompt(
|
||||
user_prompt.to_string(),
|
||||
language_name,
|
||||
&snapshot,
|
||||
language_range,
|
||||
cx,
|
||||
codegen_kind,
|
||||
);
|
||||
let index = self.semantic_index.clone();
|
||||
|
||||
let mut messages = Vec::new();
|
||||
let mut model = settings::get::<AssistantSettings>(cx)
|
||||
.default_open_ai_model
|
||||
.clone();
|
||||
if let Some(conversation) = conversation {
|
||||
let conversation = conversation.read(cx);
|
||||
let buffer = conversation.buffer.read(cx);
|
||||
messages.extend(
|
||||
conversation
|
||||
.messages(cx)
|
||||
.map(|message| message.to_open_ai_message(buffer)),
|
||||
);
|
||||
model = conversation.model.clone();
|
||||
}
|
||||
|
||||
messages.push(RequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
pending_assist.codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
user_prompt.to_string(),
|
||||
cx,
|
||||
language_name,
|
||||
snapshot,
|
||||
language_range.clone(),
|
||||
codegen_kind.clone(),
|
||||
index,
|
||||
)
|
||||
});
|
||||
let request = OpenAIRequest {
|
||||
model: model.full_name().into(),
|
||||
messages,
|
||||
stream: true,
|
||||
};
|
||||
pending_assist
|
||||
.codegen
|
||||
.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
// let api_key = self.api_key.as_ref().clone().into_inner().clone().unwrap();
|
||||
// let function_provider = OpenAIFunctionCallingProvider::new(api_key);
|
||||
|
||||
// let planning_messages = vec![RequestMessage {
|
||||
// role: Role::User,
|
||||
// content: planning_prompt,
|
||||
// }];
|
||||
|
||||
// println!("GETTING HERE");
|
||||
|
||||
// let function_call = cx
|
||||
// .spawn(|this, mut cx| async move {
|
||||
// let result = function_provider
|
||||
// .complete("gpt-4".to_string(), planning_messages, functions)
|
||||
// .await;
|
||||
// dbg!(&result);
|
||||
// result
|
||||
// })
|
||||
// .detach();
|
||||
|
||||
// let function_name = function_call.name.as_str();
|
||||
// let prompt = match function_name {
|
||||
// "rewrite_prompt" => {
|
||||
// let user_prompt = RewritePrompt::load()
|
||||
// .complete(function_call.arguments)
|
||||
// .unwrap();
|
||||
// generate_content_prompt(
|
||||
// user_prompt.to_string(),
|
||||
// language_name,
|
||||
// &snapshot,
|
||||
// language_range,
|
||||
// cx,
|
||||
// codegen_kind,
|
||||
// )
|
||||
// }
|
||||
// _ => {
|
||||
// todo!();
|
||||
// }
|
||||
// };
|
||||
|
||||
// let mut messages = Vec::new();
|
||||
// let mut model = settings::get::<AssistantSettings>(cx)
|
||||
// .default_open_ai_model
|
||||
// .clone();
|
||||
// if let Some(conversation) = conversation {
|
||||
// let conversation = conversation.read(cx);
|
||||
// let buffer = conversation.buffer.read(cx);
|
||||
// messages.extend(
|
||||
// conversation
|
||||
// .messages(cx)
|
||||
// .map(|message| message.to_open_ai_message(buffer)),
|
||||
// );
|
||||
// model = conversation.model.clone();
|
||||
// }
|
||||
}
|
||||
|
||||
fn update_highlights_for_editor(
|
||||
|
|
|
@ -1,12 +1,22 @@
|
|||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||
use ai::completion::{CompletionProvider, OpenAIRequest};
|
||||
use crate::{
|
||||
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
|
||||
streaming_diff::{Hunk, StreamingDiff},
|
||||
};
|
||||
use ai::{
|
||||
completion::{CompletionProvider, OpenAIRequest},
|
||||
function_calling::{OpenAIFunction, OpenAIFunctionCallingProvider},
|
||||
skills::RewritePrompt,
|
||||
RequestMessage, Role,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use editor::{
|
||||
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
|
||||
};
|
||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||
use gpui::{Entity, ModelContext, ModelHandle, Task};
|
||||
use language::{Rope, TransactionId};
|
||||
use gpui::{BorrowAppContext, Entity, ModelContext, ModelHandle, Task};
|
||||
use language::{BufferSnapshot, Rope, TransactionId};
|
||||
use project::Project;
|
||||
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
|
||||
use std::{cmp, future, ops::Range, sync::Arc};
|
||||
|
||||
pub enum Event {
|
||||
|
@ -22,6 +32,7 @@ pub enum CodegenKind {
|
|||
|
||||
pub struct Codegen {
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
fc_provider: OpenAIFunctionCallingProvider,
|
||||
buffer: ModelHandle<MultiBuffer>,
|
||||
snapshot: MultiBufferSnapshot,
|
||||
kind: CodegenKind,
|
||||
|
@ -31,6 +42,7 @@ pub struct Codegen {
|
|||
generation: Task<()>,
|
||||
idle: bool,
|
||||
_subscription: gpui::Subscription,
|
||||
project: ModelHandle<Project>,
|
||||
}
|
||||
|
||||
impl Entity for Codegen {
|
||||
|
@ -42,7 +54,9 @@ impl Codegen {
|
|||
buffer: ModelHandle<MultiBuffer>,
|
||||
mut kind: CodegenKind,
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
fc_provider: OpenAIFunctionCallingProvider,
|
||||
cx: &mut ModelContext<Self>,
|
||||
project: ModelHandle<Project>,
|
||||
) -> Self {
|
||||
let snapshot = buffer.read(cx).snapshot(cx);
|
||||
match &mut kind {
|
||||
|
@ -62,6 +76,7 @@ impl Codegen {
|
|||
|
||||
Self {
|
||||
provider,
|
||||
fc_provider,
|
||||
buffer: buffer.clone(),
|
||||
snapshot,
|
||||
kind,
|
||||
|
@ -71,6 +86,7 @@ impl Codegen {
|
|||
idle: true,
|
||||
generation: Task::ready(()),
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
project,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,7 +128,17 @@ impl Codegen {
|
|||
self.error.as_ref()
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
|
||||
pub fn start(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
cx: &mut ModelContext<Self>,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<language::Anchor>,
|
||||
kind: CodegenKind,
|
||||
index: ModelHandle<SemanticIndex>,
|
||||
) {
|
||||
let language_range = range.clone();
|
||||
let range = self.range();
|
||||
let snapshot = self.snapshot.clone();
|
||||
let selected_text = snapshot
|
||||
|
@ -126,9 +152,101 @@ impl Codegen {
|
|||
.next()
|
||||
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
|
||||
|
||||
let response = self.provider.complete(prompt);
|
||||
let messages = vec![RequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt.clone(),
|
||||
}];
|
||||
|
||||
let request = OpenAIRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: messages.clone(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let (planning_prompt, outline) = generate_codegen_planning_prompt(
|
||||
prompt.clone(),
|
||||
language_name.clone(),
|
||||
&buffer,
|
||||
language_range.clone(),
|
||||
cx,
|
||||
kind.clone(),
|
||||
);
|
||||
|
||||
let project = self.project.clone();
|
||||
|
||||
self.generation = cx.spawn_weak(|this, mut cx| {
|
||||
// Plan Ahead
|
||||
let planning_messages = vec![RequestMessage {
|
||||
role: Role::User,
|
||||
content: planning_prompt,
|
||||
}];
|
||||
|
||||
let repo_retriever = RepositoryContextRetriever::load(index, project);
|
||||
let functions: Vec<Box<dyn OpenAIFunction>> = vec![
|
||||
Box::new(RewritePrompt::load()),
|
||||
Box::new(repo_retriever.clone()),
|
||||
];
|
||||
|
||||
let completion_provider = self.provider.clone();
|
||||
let fc_provider = self.fc_provider.clone();
|
||||
let language_name = language_name.clone();
|
||||
let language_name = if let Some(language_name) = language_name.clone() {
|
||||
Some(language_name.to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let kind = kind.clone();
|
||||
async move {
|
||||
let mut user_prompt = prompt.clone();
|
||||
let user_prompt = if let Ok(function_call) = fc_provider
|
||||
.complete("gpt-4".to_string(), planning_messages, functions)
|
||||
.await
|
||||
{
|
||||
let function_name = function_call.name.as_str();
|
||||
println!("FUNCTION NAME: {:?}", function_name);
|
||||
let user_prompt = match function_name {
|
||||
"rewrite_prompt" => {
|
||||
let user_prompt = RewritePrompt::load()
|
||||
.complete(function_call.arguments)
|
||||
.unwrap();
|
||||
generate_content_prompt(
|
||||
user_prompt,
|
||||
language_name,
|
||||
outline,
|
||||
kind,
|
||||
vec![],
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
let arguments = function_call.arguments.clone();
|
||||
let snippet = repo_retriever
|
||||
.complete_test(arguments, &mut cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let snippet = vec![snippet];
|
||||
|
||||
generate_content_prompt(prompt, language_name, outline, kind, snippet)
|
||||
}
|
||||
};
|
||||
user_prompt
|
||||
} else {
|
||||
user_prompt
|
||||
};
|
||||
|
||||
println!("{:?}", user_prompt.clone());
|
||||
|
||||
let messages = vec![RequestMessage {
|
||||
role: Role::User,
|
||||
content: user_prompt.clone(),
|
||||
}];
|
||||
|
||||
let request = OpenAIRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: messages.clone(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let response = completion_provider.complete(request);
|
||||
let generate = async {
|
||||
let mut edit_start = range.start.to_offset(&snapshot);
|
||||
|
||||
|
@ -349,315 +467,317 @@ fn strip_markdown_codeblock(
|
|||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::{
|
||||
future::BoxFuture,
|
||||
stream::{self, BoxStream},
|
||||
};
|
||||
use gpui::{executor::Deterministic, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||
use parking_lot::Mutex;
|
||||
use rand::prelude::*;
|
||||
use settings::SettingsStore;
|
||||
use smol::future::FutureExt;
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
// use futures::{
|
||||
// future::BoxFuture,
|
||||
// stream::{self, BoxStream},
|
||||
// };
|
||||
// use gpui::{executor::Deterministic, TestAppContext};
|
||||
// use indoc::indoc;
|
||||
// use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||
// use parking_lot::Mutex;
|
||||
// use rand::prelude::*;
|
||||
// use settings::SettingsStore;
|
||||
// use smol::future::FutureExt;
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
deterministic: Arc<Deterministic>,
|
||||
) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
// #[gpui::test(iterations = 10)]
|
||||
// async fn test_transform_autoindent(
|
||||
// cx: &mut TestAppContext,
|
||||
// mut rng: StdRng,
|
||||
// deterministic: Arc<Deterministic>,
|
||||
// ) {
|
||||
// cx.set_global(cx.read(SettingsStore::test));
|
||||
// cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
let x = 0;
|
||||
for _ in 0..10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Transform { range },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
// let text = indoc! {"
|
||||
// fn main() {
|
||||
// let x = 0;
|
||||
// for _ in 0..10 {
|
||||
// x += 1;
|
||||
// }
|
||||
// }
|
||||
// "};
|
||||
// let buffer =
|
||||
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
// let range = buffer.read_with(cx, |buffer, cx| {
|
||||
// let snapshot = buffer.snapshot(cx);
|
||||
// snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
|
||||
// });
|
||||
// let provider = Arc::new(TestCompletionProvider::new());
|
||||
// let fc_provider = OpenAIFunctionCallingProvider::new("".to_string());
|
||||
// let codegen = cx.add_model(|cx| {
|
||||
// Codegen::new(
|
||||
// buffer.clone(),
|
||||
// CodegenKind::Transform { range },
|
||||
// provider.clone(),
|
||||
// fc_provider,
|
||||
// cx,
|
||||
// )
|
||||
// });
|
||||
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
" let mut x = 0;\n",
|
||||
" while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
" }",
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
deterministic.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
deterministic.run_until_parked();
|
||||
// let mut new_text = concat!(
|
||||
// " let mut x = 0;\n",
|
||||
// " while x < 10 {\n",
|
||||
// " x += 1;\n",
|
||||
// " }",
|
||||
// );
|
||||
// while !new_text.is_empty() {
|
||||
// let max_len = cmp::min(new_text.len(), 10);
|
||||
// let len = rng.gen_range(1..=max_len);
|
||||
// let (chunk, suffix) = new_text.split_at(len);
|
||||
// provider.send_completion(chunk);
|
||||
// new_text = suffix;
|
||||
// deterministic.run_until_parked();
|
||||
// }
|
||||
// provider.finish_completion();
|
||||
// deterministic.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
// assert_eq!(
|
||||
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
// indoc! {"
|
||||
// fn main() {
|
||||
// let mut x = 0;
|
||||
// while x < 10 {
|
||||
// x += 1;
|
||||
// }
|
||||
// }
|
||||
// "}
|
||||
// );
|
||||
// }
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_when_generating_past_indentation(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
deterministic: Arc<Deterministic>,
|
||||
) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
// #[gpui::test(iterations = 10)]
|
||||
// async fn test_autoindent_when_generating_past_indentation(
|
||||
// cx: &mut TestAppContext,
|
||||
// mut rng: StdRng,
|
||||
// deterministic: Arc<Deterministic>,
|
||||
// ) {
|
||||
// cx.set_global(cx.read(SettingsStore::test));
|
||||
// cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
le
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let position = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
// let text = indoc! {"
|
||||
// fn main() {
|
||||
// le
|
||||
// }
|
||||
// "};
|
||||
// let buffer =
|
||||
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
// let position = buffer.read_with(cx, |buffer, cx| {
|
||||
// let snapshot = buffer.snapshot(cx);
|
||||
// snapshot.anchor_before(Point::new(1, 6))
|
||||
// });
|
||||
// let provider = Arc::new(TestCompletionProvider::new());
|
||||
// let codegen = cx.add_model(|cx| {
|
||||
// Codegen::new(
|
||||
// buffer.clone(),
|
||||
// CodegenKind::Generate { position },
|
||||
// provider.clone(),
|
||||
// cx,
|
||||
// )
|
||||
// });
|
||||
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
"t mut x = 0;\n",
|
||||
"while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
"}", //
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
deterministic.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
deterministic.run_until_parked();
|
||||
// let mut new_text = concat!(
|
||||
// "t mut x = 0;\n",
|
||||
// "while x < 10 {\n",
|
||||
// " x += 1;\n",
|
||||
// "}", //
|
||||
// );
|
||||
// while !new_text.is_empty() {
|
||||
// let max_len = cmp::min(new_text.len(), 10);
|
||||
// let len = rng.gen_range(1..=max_len);
|
||||
// let (chunk, suffix) = new_text.split_at(len);
|
||||
// provider.send_completion(chunk);
|
||||
// new_text = suffix;
|
||||
// deterministic.run_until_parked();
|
||||
// }
|
||||
// provider.finish_completion();
|
||||
// deterministic.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
// assert_eq!(
|
||||
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
// indoc! {"
|
||||
// fn main() {
|
||||
// let mut x = 0;
|
||||
// while x < 10 {
|
||||
// x += 1;
|
||||
// }
|
||||
// }
|
||||
// "}
|
||||
// );
|
||||
// }
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_when_generating_before_indentation(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
deterministic: Arc<Deterministic>,
|
||||
) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
// #[gpui::test(iterations = 10)]
|
||||
// async fn test_autoindent_when_generating_before_indentation(
|
||||
// cx: &mut TestAppContext,
|
||||
// mut rng: StdRng,
|
||||
// deterministic: Arc<Deterministic>,
|
||||
// ) {
|
||||
// cx.set_global(cx.read(SettingsStore::test));
|
||||
// cx.update(language_settings::init);
|
||||
|
||||
let text = concat!(
|
||||
"fn main() {\n",
|
||||
" \n",
|
||||
"}\n" //
|
||||
);
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let position = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
// let text = concat!(
|
||||
// "fn main() {\n",
|
||||
// " \n",
|
||||
// "}\n" //
|
||||
// );
|
||||
// let buffer =
|
||||
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
// let position = buffer.read_with(cx, |buffer, cx| {
|
||||
// let snapshot = buffer.snapshot(cx);
|
||||
// snapshot.anchor_before(Point::new(1, 2))
|
||||
// });
|
||||
// let provider = Arc::new(TestCompletionProvider::new());
|
||||
// let codegen = cx.add_model(|cx| {
|
||||
// Codegen::new(
|
||||
// buffer.clone(),
|
||||
// CodegenKind::Generate { position },
|
||||
// provider.clone(),
|
||||
// cx,
|
||||
// )
|
||||
// });
|
||||
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
"let mut x = 0;\n",
|
||||
"while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
"}", //
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
deterministic.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
deterministic.run_until_parked();
|
||||
// let mut new_text = concat!(
|
||||
// "let mut x = 0;\n",
|
||||
// "while x < 10 {\n",
|
||||
// " x += 1;\n",
|
||||
// "}", //
|
||||
// );
|
||||
// while !new_text.is_empty() {
|
||||
// let max_len = cmp::min(new_text.len(), 10);
|
||||
// let len = rng.gen_range(1..=max_len);
|
||||
// let (chunk, suffix) = new_text.split_at(len);
|
||||
// provider.send_completion(chunk);
|
||||
// new_text = suffix;
|
||||
// deterministic.run_until_parked();
|
||||
// }
|
||||
// provider.finish_completion();
|
||||
// deterministic.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
// assert_eq!(
|
||||
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
// indoc! {"
|
||||
// fn main() {
|
||||
// let mut x = 0;
|
||||
// while x < 10 {
|
||||
// x += 1;
|
||||
// }
|
||||
// }
|
||||
// "}
|
||||
// );
|
||||
// }
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_strip_markdown_codeblock() {
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"```js\nLorem ipsum dolor\n```"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"``\nLorem ipsum dolor\n```"
|
||||
);
|
||||
// #[gpui::test]
|
||||
// async fn test_strip_markdown_codeblock() {
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "Lorem ipsum dolor"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "Lorem ipsum dolor"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "Lorem ipsum dolor"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "Lorem ipsum dolor"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "```js\nLorem ipsum dolor\n```"
|
||||
// );
|
||||
// assert_eq!(
|
||||
// strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
|
||||
// .map(|chunk| chunk.unwrap())
|
||||
// .collect::<String>()
|
||||
// .await,
|
||||
// "``\nLorem ipsum dolor\n```"
|
||||
// );
|
||||
|
||||
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
|
||||
stream::iter(
|
||||
text.chars()
|
||||
.collect::<Vec<_>>()
|
||||
.chunks(size)
|
||||
.map(|chunk| Ok(chunk.iter().collect::<String>()))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
// fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
|
||||
// stream::iter(
|
||||
// text.chars()
|
||||
// .collect::<Vec<_>>()
|
||||
// .chunks(size)
|
||||
// .map(|chunk| Ok(chunk.iter().collect::<String>()))
|
||||
// .collect::<Vec<_>>(),
|
||||
// )
|
||||
// }
|
||||
// }
|
||||
|
||||
struct TestCompletionProvider {
|
||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
}
|
||||
// struct TestCompletionProvider {
|
||||
// last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
// }
|
||||
|
||||
impl TestCompletionProvider {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
// impl TestCompletionProvider {
|
||||
// fn new() -> Self {
|
||||
// Self {
|
||||
// last_completion_tx: Mutex::new(None),
|
||||
// }
|
||||
// }
|
||||
|
||||
fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
// fn send_completion(&self, completion: impl Into<String>) {
|
||||
// let mut tx = self.last_completion_tx.lock();
|
||||
// tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
// }
|
||||
|
||||
fn finish_completion(&self) {
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
// fn finish_completion(&self) {
|
||||
// self.last_completion_tx.lock().take().unwrap();
|
||||
// }
|
||||
// }
|
||||
|
||||
impl CompletionProvider for TestCompletionProvider {
|
||||
fn complete(
|
||||
&self,
|
||||
_prompt: OpenAIRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
// impl CompletionProvider for TestCompletionProvider {
|
||||
// fn complete(
|
||||
// &self,
|
||||
// _prompt: OpenAIRequest,
|
||||
// ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
// let (tx, rx) = mpsc::channel(1);
|
||||
// *self.last_completion_tx.lock() = Some(tx);
|
||||
// async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
// }
|
||||
// }
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::language()),
|
||||
)
|
||||
.with_indents_query(
|
||||
r#"
|
||||
(call_expression) @indent
|
||||
(field_expression) @indent
|
||||
(_ "(" ")" @end) @indent
|
||||
(_ "{" "}" @end) @indent
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
// fn rust_lang() -> Language {
|
||||
// Language::new(
|
||||
// LanguageConfig {
|
||||
// name: "Rust".into(),
|
||||
// path_suffixes: vec!["rs".to_string()],
|
||||
// ..Default::default()
|
||||
// },
|
||||
// Some(tree_sitter_rust::language()),
|
||||
// )
|
||||
// .with_indents_query(
|
||||
// r#"
|
||||
// (call_expression) @indent
|
||||
// (field_expression) @indent
|
||||
// (_ "(" ")" @end) @indent
|
||||
// (_ "{" "}" @end) @indent
|
||||
// "#,
|
||||
// )
|
||||
// .unwrap()
|
||||
// }
|
||||
// }
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use gpui::AppContext;
|
||||
use gpui::{AppContext, AsyncAppContext};
|
||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||
use std::cmp;
|
||||
use std::ops::Range;
|
||||
|
@ -83,14 +83,14 @@ fn outline_for_prompt(
|
|||
Some(text)
|
||||
}
|
||||
|
||||
fn generate_codegen_planning_prompt(
|
||||
pub fn generate_codegen_planning_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: &BufferSnapshot,
|
||||
range: Range<language::Anchor>,
|
||||
cx: &AppContext,
|
||||
kind: CodegenKind,
|
||||
) -> String {
|
||||
) -> (String, Option<String>) {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// General Preamble
|
||||
|
@ -101,7 +101,7 @@ fn generate_codegen_planning_prompt(
|
|||
}
|
||||
|
||||
let outline = outline_for_prompt(buffer, range.clone(), cx);
|
||||
if let Some(outline) = outline {
|
||||
if let Some(outline) = outline.clone() {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You're currently working inside the Zed editor on a file with the following outline:"
|
||||
|
@ -135,33 +135,41 @@ fn generate_codegen_planning_prompt(
|
|||
)
|
||||
.unwrap();
|
||||
|
||||
prompt
|
||||
(prompt, outline)
|
||||
}
|
||||
pub fn generate_content_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: &BufferSnapshot,
|
||||
range: Range<language::Anchor>,
|
||||
cx: &AppContext,
|
||||
language_name: Option<String>,
|
||||
outline: Option<String>,
|
||||
kind: CodegenKind,
|
||||
snippet: Vec<String>,
|
||||
) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// General Preamble
|
||||
if let Some(language_name) = language_name {
|
||||
if let Some(language_name) = language_name.clone() {
|
||||
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "You're an expert software engineer.\n").unwrap();
|
||||
}
|
||||
|
||||
let outline = outline_for_prompt(buffer, range.clone(), cx);
|
||||
if snippet.len() > 0 {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Here are a few snippets from the codebase which may help: "
|
||||
);
|
||||
}
|
||||
for snip in snippet {
|
||||
writeln!(prompt, "{snip}");
|
||||
}
|
||||
|
||||
if let Some(outline) = outline {
|
||||
writeln!(
|
||||
prompt,
|
||||
"The file you are currently working on has the following outline:"
|
||||
)
|
||||
.unwrap();
|
||||
if let Some(language_name) = language_name {
|
||||
if let Some(language_name) = language_name.clone() {
|
||||
let language_name = language_name.to_lowercase();
|
||||
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
|
||||
} else {
|
||||
|
|
|
@ -2,6 +2,7 @@ mod db;
|
|||
mod embedding_queue;
|
||||
mod parsing;
|
||||
pub mod semantic_index_settings;
|
||||
pub mod skills;
|
||||
|
||||
#[cfg(test)]
|
||||
mod semantic_index_tests;
|
||||
|
|
106
crates/semantic_index/src/skills.rs
Normal file
106
crates/semantic_index/src/skills.rs
Normal file
|
@ -0,0 +1,106 @@
|
|||
use ai::function_calling::OpenAIFunction;
|
||||
use anyhow::anyhow;
|
||||
use gpui::{AppContext, AsyncAppContext, ModelHandle};
|
||||
use project::Project;
|
||||
use serde::{Serialize, Serializer};
|
||||
use serde_json::json;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::SemanticIndex;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RepositoryContextRetriever {
|
||||
index: ModelHandle<SemanticIndex>,
|
||||
project: ModelHandle<Project>,
|
||||
}
|
||||
|
||||
impl RepositoryContextRetriever {
|
||||
pub fn load(index: ModelHandle<SemanticIndex>, project: ModelHandle<Project>) -> Self {
|
||||
Self { index, project }
|
||||
}
|
||||
pub async fn complete_test(
|
||||
&self,
|
||||
arguments: serde_json::Value,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> anyhow::Result<String> {
|
||||
let queries = arguments.get("queries").unwrap().as_array().unwrap();
|
||||
let mut prompt = String::new();
|
||||
let query = queries
|
||||
.iter()
|
||||
.map(|query| query.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(";");
|
||||
let project = self.project.clone();
|
||||
let results = self
|
||||
.index
|
||||
.update(cx, |this, cx| {
|
||||
this.search_project(project, query, 10, vec![], vec![], cx)
|
||||
})
|
||||
.await?;
|
||||
|
||||
for result in results {
|
||||
result.buffer.read_with(cx, |buffer, cx| {
|
||||
let text = buffer.text_for_range(result.range).collect::<String>();
|
||||
let file_path = buffer.file().unwrap().path().to_string_lossy();
|
||||
let language = buffer.language();
|
||||
|
||||
writeln!(
|
||||
prompt,
|
||||
"The following is a relevant snippet from file ({}):",
|
||||
file_path
|
||||
)
|
||||
.unwrap();
|
||||
if let Some(language) = language {
|
||||
writeln!(prompt, "```{}\n{text}\n```", language.name().to_lowercase()).unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "```\n{text}\n```").unwrap();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIFunction for RepositoryContextRetriever {
|
||||
fn name(&self) -> String {
|
||||
"retrieve_context_from_repository".to_string()
|
||||
}
|
||||
fn description(&self) -> String {
|
||||
"Retrieve relevant content from repository with natural language".to_string()
|
||||
}
|
||||
fn system_prompt(&self) -> String {
|
||||
"'retrieve_context_from_repository'
|
||||
If more information is needed from the repository, to complete the users prompt reliably, pass up to 3 queries describing pieces of code or text you would like additional context upon.
|
||||
Do not make these queries general about programming, include very specific lexical references to the pieces of code you need more information on.
|
||||
We are passing these into a semantic similarity retrieval engine, with all the information in the current codebase included.
|
||||
As such, these should be phrased as descriptions of code of interest as opposed to questions".to_string()
|
||||
}
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"queries": {
|
||||
"title": "queries",
|
||||
"type": "array",
|
||||
"items": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"required": ["queries"]
|
||||
})
|
||||
}
|
||||
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
impl Serialize for RepositoryContextRetriever {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
json!({"name": self.name(),
|
||||
"description": self.description(),
|
||||
"parameters": self.parameters()})
|
||||
.serialize(serializer)
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue