diff --git a/.github/workflows/release_actions.yml b/.github/workflows/release_actions.yml
index c1df24a8e5..550eda882b 100644
--- a/.github/workflows/release_actions.yml
+++ b/.github/workflows/release_actions.yml
@@ -20,9 +20,7 @@ jobs:
id: get-content
with:
stringToTruncate: |
- 📣 Zed ${{ github.event.release.tag_name }} was just released!
-
- Restart your Zed or head to ${{ steps.get-release-url.outputs.URL }} to grab it.
+ 📣 Zed [${{ github.event.release.tag_name }}](${{ steps.get-release-url.outputs.URL }}) was just released!
${{ github.event.release.body }}
maxLength: 2000
diff --git a/Cargo.lock b/Cargo.lock
index 3961450c6e..85014b42fe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -103,7 +103,7 @@ dependencies = [
"rusqlite",
"serde",
"serde_json",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
"util",
]
@@ -316,12 +316,13 @@ dependencies = [
"regex",
"schemars",
"search",
+ "semantic_index",
"serde",
"serde_json",
"settings",
"smol",
"theme",
- "tiktoken-rs 0.4.5",
+ "tiktoken-rs",
"util",
"uuid 1.4.1",
"workspace",
@@ -1466,7 +1467,7 @@ dependencies = [
[[package]]
name = "collab"
-version = "0.24.0"
+version = "0.25.0"
dependencies = [
"anyhow",
"async-trait",
@@ -1629,6 +1630,7 @@ dependencies = [
"theme",
"util",
"workspace",
+ "zed-actions",
]
[[package]]
@@ -6975,7 +6977,7 @@ dependencies = [
"smol",
"tempdir",
"theme",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
"tree-sitter",
"tree-sitter-cpp",
"tree-sitter-elixir",
@@ -8166,21 +8168,6 @@ dependencies = [
"weezl",
]
-[[package]]
-name = "tiktoken-rs"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
-dependencies = [
- "anyhow",
- "base64 0.21.4",
- "bstr",
- "fancy-regex",
- "lazy_static",
- "parking_lot 0.12.1",
- "rustc-hash",
-]
-
[[package]]
name = "tiktoken-rs"
version = "0.5.4"
@@ -10103,7 +10090,7 @@ dependencies = [
[[package]]
name = "zed"
-version = "0.109.0"
+version = "0.110.0"
dependencies = [
"activity_indicator",
"ai",
@@ -10238,6 +10225,7 @@ name = "zed-actions"
version = "0.1.0"
dependencies = [
"gpui",
+ "serde",
]
[[package]]
diff --git a/Procfile b/Procfile
index 2eb7de20fb..3f42c3a967 100644
--- a/Procfile
+++ b/Procfile
@@ -1,4 +1,4 @@
web: cd ../zed.dev && PORT=3000 npm run dev
-collab: cd crates/collab && RUST_LOG=${RUST_LOG:-collab=info} cargo run serve
+collab: cd crates/collab && RUST_LOG=${RUST_LOG:-warn,collab=info} cargo run serve
livekit: livekit-server --dev
postgrest: postgrest crates/collab/admin_api.conf
diff --git a/assets/icons/link.svg b/assets/icons/link.svg
new file mode 100644
index 0000000000..4925bd8e00
--- /dev/null
+++ b/assets/icons/link.svg
@@ -0,0 +1,3 @@
+
diff --git a/assets/icons/public.svg b/assets/icons/public.svg
new file mode 100644
index 0000000000..38278cdaba
--- /dev/null
+++ b/assets/icons/public.svg
@@ -0,0 +1,3 @@
+
diff --git a/assets/icons/update.svg b/assets/icons/update.svg
new file mode 100644
index 0000000000..b529b2b08b
--- /dev/null
+++ b/assets/icons/update.svg
@@ -0,0 +1,8 @@
+
diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
index 332470aa54..4587ece0a2 100644
--- a/crates/ai/src/embedding.rs
+++ b/crates/ai/src/embedding.rs
@@ -85,25 +85,6 @@ impl Embedding {
}
}
-// impl FromSql for Embedding {
-// fn column_result(value: ValueRef) -> FromSqlResult {
-// let bytes = value.as_blob()?;
-// let embedding: Result, Box> = bincode::deserialize(bytes);
-// if embedding.is_err() {
-// return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
-// }
-// Ok(Embedding(embedding.unwrap()))
-// }
-// }
-
-// impl ToSql for Embedding {
-// fn to_sql(&self) -> rusqlite::Result {
-// let bytes = bincode::serialize(&self.0)
-// .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
-// Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
-// }
-// }
-
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc,
@@ -300,6 +281,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
request_timeout,
)
.await?;
+
request_number += 1;
match response.status() {
diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml
index f1daf47bab..9cfdd3301a 100644
--- a/crates/assistant/Cargo.toml
+++ b/crates/assistant/Cargo.toml
@@ -22,8 +22,11 @@ settings = { path = "../settings" }
theme = { path = "../theme" }
util = { path = "../util" }
workspace = { path = "../workspace" }
-uuid.workspace = true
+semantic_index = { path = "../semantic_index" }
+project = { path = "../project" }
+uuid.workspace = true
+log.workspace = true
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
@@ -36,7 +39,7 @@ schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
-tiktoken-rs = "0.4"
+tiktoken-rs = "0.5"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index b1c6038602..65edb1832f 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -1,7 +1,7 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind},
- prompts::generate_content_prompt,
+ prompts::{generate_content_prompt, PromptCodeSnippet},
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
@@ -29,13 +29,15 @@ use gpui::{
},
fonts::HighlightStyle,
geometry::vector::{vec2f, Vector2F},
- platform::{CursorStyle, MouseButton},
+ platform::{CursorStyle, MouseButton, PromptLevel},
Action, AnyElement, AppContext, AsyncAppContext, ClipboardItem, Element, Entity, ModelContext,
- ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
- WindowContext,
+ ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle,
+ WeakModelHandle, WeakViewHandle, WindowContext,
};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
+use project::Project;
use search::BufferSearchBar;
+use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
@@ -46,7 +48,7 @@ use std::{
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
- time::Duration,
+ time::{Duration, Instant},
};
use theme::{
components::{action_button::Button, ComponentExt},
@@ -72,6 +74,7 @@ actions!(
ResetKey,
InlineAssist,
ToggleIncludeConversation,
+ ToggleRetrieveContext,
]
);
@@ -108,6 +111,7 @@ pub fn init(cx: &mut AppContext) {
cx.add_action(InlineAssistant::confirm);
cx.add_action(InlineAssistant::cancel);
cx.add_action(InlineAssistant::toggle_include_conversation);
+ cx.add_action(InlineAssistant::toggle_retrieve_context);
cx.add_action(InlineAssistant::move_up);
cx.add_action(InlineAssistant::move_down);
}
@@ -145,6 +149,8 @@ pub struct AssistantPanel {
include_conversation_in_next_inline_assist: bool,
inline_prompt_history: VecDeque,
_watch_saved_conversations: Task>,
+ semantic_index: Option>,
+ retrieve_context_in_next_inline_assist: bool,
}
impl AssistantPanel {
@@ -191,6 +197,9 @@ impl AssistantPanel {
toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
toolbar
});
+
+ let semantic_index = SemanticIndex::global(cx);
+
let mut this = Self {
workspace: workspace_handle,
active_editor_index: Default::default(),
@@ -215,6 +224,8 @@ impl AssistantPanel {
include_conversation_in_next_inline_assist: false,
inline_prompt_history: Default::default(),
_watch_saved_conversations,
+ semantic_index,
+ retrieve_context_in_next_inline_assist: false,
};
let mut old_dock_position = this.position(cx);
@@ -262,12 +273,19 @@ impl AssistantPanel {
return;
};
+ let project = workspace.project();
+
this.update(cx, |assistant, cx| {
- assistant.new_inline_assist(&active_editor, cx)
+ assistant.new_inline_assist(&active_editor, cx, project)
});
}
- fn new_inline_assist(&mut self, editor: &ViewHandle, cx: &mut ViewContext) {
+ fn new_inline_assist(
+ &mut self,
+ editor: &ViewHandle,
+ cx: &mut ViewContext,
+ project: &ModelHandle,
+ ) {
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
api_key
} else {
@@ -312,6 +330,27 @@ impl AssistantPanel {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
});
+ if let Some(semantic_index) = self.semantic_index.clone() {
+ let project = project.clone();
+ cx.spawn(|_, mut cx| async move {
+ let previously_indexed = semantic_index
+ .update(&mut cx, |index, cx| {
+ index.project_previously_indexed(&project, cx)
+ })
+ .await
+ .unwrap_or(false);
+ if previously_indexed {
+ let _ = semantic_index
+ .update(&mut cx, |index, cx| {
+ index.index_project(project.clone(), cx)
+ })
+ .await;
+ }
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
+
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
let inline_assistant = cx.add_view(|cx| {
let assistant = InlineAssistant::new(
@@ -322,6 +361,9 @@ impl AssistantPanel {
codegen.clone(),
self.workspace.clone(),
cx,
+ self.retrieve_context_in_next_inline_assist,
+ self.semantic_index.clone(),
+ project.clone(),
);
cx.focus_self();
assistant
@@ -362,6 +404,7 @@ impl AssistantPanel {
editor: editor.downgrade(),
inline_assistant: Some((block_id, inline_assistant.clone())),
codegen: codegen.clone(),
+ project: project.downgrade(),
_subscriptions: vec![
cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
cx.subscribe(editor, {
@@ -440,8 +483,15 @@ impl AssistantPanel {
InlineAssistantEvent::Confirmed {
prompt,
include_conversation,
+ retrieve_context,
} => {
- self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
+ self.confirm_inline_assist(
+ assist_id,
+ prompt,
+ *include_conversation,
+ cx,
+ *retrieve_context,
+ );
}
InlineAssistantEvent::Canceled => {
self.finish_inline_assist(assist_id, true, cx);
@@ -454,6 +504,9 @@ impl AssistantPanel {
} => {
self.include_conversation_in_next_inline_assist = *include_conversation;
}
+ InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => {
+ self.retrieve_context_in_next_inline_assist = *retrieve_context
+ }
}
}
@@ -532,6 +585,7 @@ impl AssistantPanel {
user_prompt: &str,
include_conversation: bool,
cx: &mut ViewContext,
+ retrieve_context: bool,
) {
let conversation = if include_conversation {
self.active_editor()
@@ -553,6 +607,8 @@ impl AssistantPanel {
return;
};
+ let project = pending_assist.project.clone();
+
self.inline_prompt_history
.retain(|prompt| prompt != user_prompt);
self.inline_prompt_history.push_back(user_prompt.into());
@@ -593,10 +649,62 @@ impl AssistantPanel {
let codegen_kind = codegen.read(cx).kind().clone();
let user_prompt = user_prompt.to_string();
- let mut messages = Vec::new();
+ let snippets = if retrieve_context {
+ let Some(project) = project.upgrade(cx) else {
+ return;
+ };
+
+ let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
+ let search_results = semantic_index.update(cx, |this, cx| {
+ this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
+ });
+
+ cx.background()
+ .spawn(async move { search_results.await.unwrap_or_default() })
+ } else {
+ Task::ready(Vec::new())
+ };
+
+ let snippets = cx.spawn(|_, cx| async move {
+ let mut snippets = Vec::new();
+ for result in search_results.await {
+ snippets.push(PromptCodeSnippet::new(result, &cx));
+
+ // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
+ // buffer
+ // .snapshot()
+ // .text_for_range(result.range)
+ // .collect::()
+ // }));
+ }
+ snippets
+ });
+ snippets
+ } else {
+ Task::ready(Vec::new())
+ };
+
let mut model = settings::get::(cx)
.default_open_ai_model
.clone();
+ let model_name = model.full_name();
+
+ let prompt = cx.background().spawn(async move {
+ let snippets = snippets.await;
+
+ let language_name = language_name.as_deref();
+ generate_content_prompt(
+ user_prompt,
+ language_name,
+ &buffer,
+ range,
+ codegen_kind,
+ snippets,
+ model_name,
+ )
+ });
+
+ let mut messages = Vec::new();
if let Some(conversation) = conversation {
let conversation = conversation.read(cx);
let buffer = conversation.buffer.read(cx);
@@ -608,11 +716,6 @@ impl AssistantPanel {
model = conversation.model.clone();
}
- let prompt = cx.background().spawn(async move {
- let language_name = language_name.as_deref();
- generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind)
- });
-
cx.spawn(|_, mut cx| async move {
let prompt = prompt.await;
@@ -1514,12 +1617,14 @@ impl Conversation {
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
- content: self
- .buffer
- .read(cx)
- .text_for_range(message.offset_range)
- .collect(),
+ content: Some(
+ self.buffer
+ .read(cx)
+ .text_for_range(message.offset_range)
+ .collect(),
+ ),
name: None,
+ function_call: None,
})
})
.collect::>();
@@ -2638,12 +2743,16 @@ enum InlineAssistantEvent {
Confirmed {
prompt: String,
include_conversation: bool,
+ retrieve_context: bool,
},
Canceled,
Dismissed,
IncludeConversationToggled {
include_conversation: bool,
},
+ RetrieveContextToggled {
+ retrieve_context: bool,
+ },
}
struct InlineAssistant {
@@ -2659,6 +2768,11 @@ struct InlineAssistant {
pending_prompt: String,
codegen: ModelHandle,
_subscriptions: Vec,
+ retrieve_context: bool,
+ semantic_index: Option>,
+ semantic_permissioned: Option,
+ project: WeakModelHandle,
+ maintain_rate_limit: Option>,
}
impl Entity for InlineAssistant {
@@ -2675,51 +2789,65 @@ impl View for InlineAssistant {
let theme = theme::current(cx);
Flex::row()
- .with_child(
- Flex::row()
- .with_child(
- Button::action(ToggleIncludeConversation)
- .with_tooltip("Include Conversation", theme.tooltip.clone())
+ .with_children([Flex::row()
+ .with_child(
+ Button::action(ToggleIncludeConversation)
+ .with_tooltip("Include Conversation", theme.tooltip.clone())
+ .with_id(self.id)
+ .with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
+ .toggleable(self.include_conversation)
+ .with_style(theme.assistant.inline.include_conversation.clone())
+ .element()
+ .aligned(),
+ )
+ .with_children(if SemanticIndex::enabled(cx) {
+ Some(
+ Button::action(ToggleRetrieveContext)
+ .with_tooltip("Retrieve Context", theme.tooltip.clone())
.with_id(self.id)
- .with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
- .toggleable(self.include_conversation)
- .with_style(theme.assistant.inline.include_conversation.clone())
+ .with_contents(theme::components::svg::Svg::new(
+ "icons/magnifying_glass.svg",
+ ))
+ .toggleable(self.retrieve_context)
+ .with_style(theme.assistant.inline.retrieve_context.clone())
.element()
.aligned(),
)
- .with_children(if let Some(error) = self.codegen.read(cx).error() {
- Some(
- Svg::new("icons/error.svg")
- .with_color(theme.assistant.error_icon.color)
- .constrained()
- .with_width(theme.assistant.error_icon.width)
- .contained()
- .with_style(theme.assistant.error_icon.container)
- .with_tooltip::(
- self.id,
- error.to_string(),
- None,
- theme.tooltip.clone(),
- cx,
- )
- .aligned(),
- )
- } else {
- None
- })
- .aligned()
- .constrained()
- .dynamically({
- let measurements = self.measurements.clone();
- move |constraint, _, _| {
- let measurements = measurements.get();
- SizeConstraint {
- min: vec2f(measurements.gutter_width, constraint.min.y()),
- max: vec2f(measurements.gutter_width, constraint.max.y()),
- }
+ } else {
+ None
+ })
+ .with_children(if let Some(error) = self.codegen.read(cx).error() {
+ Some(
+ Svg::new("icons/error.svg")
+ .with_color(theme.assistant.error_icon.color)
+ .constrained()
+ .with_width(theme.assistant.error_icon.width)
+ .contained()
+ .with_style(theme.assistant.error_icon.container)
+ .with_tooltip::(
+ self.id,
+ error.to_string(),
+ None,
+ theme.tooltip.clone(),
+ cx,
+ )
+ .aligned(),
+ )
+ } else {
+ None
+ })
+ .aligned()
+ .constrained()
+ .dynamically({
+ let measurements = self.measurements.clone();
+ move |constraint, _, _| {
+ let measurements = measurements.get();
+ SizeConstraint {
+ min: vec2f(measurements.gutter_width, constraint.min.y()),
+ max: vec2f(measurements.gutter_width, constraint.max.y()),
}
- }),
- )
+ }
+ })])
.with_child(Empty::new().constrained().dynamically({
let measurements = self.measurements.clone();
move |constraint, _, _| {
@@ -2742,6 +2870,16 @@ impl View for InlineAssistant {
.left()
.flex(1., true),
)
+ .with_children(if self.retrieve_context {
+ Some(
+ Flex::row()
+ .with_children(self.retrieve_context_status(cx))
+ .flex(1., true)
+ .aligned(),
+ )
+ } else {
+ None
+ })
.contained()
.with_style(theme.assistant.inline.container)
.into_any()
@@ -2767,6 +2905,9 @@ impl InlineAssistant {
codegen: ModelHandle,
workspace: WeakViewHandle,
cx: &mut ViewContext,
+ retrieve_context: bool,
+ semantic_index: Option>,
+ project: ModelHandle,
) -> Self {
let prompt_editor = cx.add_view(|cx| {
let mut editor = Editor::single_line(
@@ -2780,11 +2921,16 @@ impl InlineAssistant {
editor.set_placeholder_text(placeholder, cx);
editor
});
- let subscriptions = vec![
+ let mut subscriptions = vec![
cx.observe(&codegen, Self::handle_codegen_changed),
cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
];
- Self {
+
+ if let Some(semantic_index) = semantic_index.clone() {
+ subscriptions.push(cx.observe(&semantic_index, Self::semantic_index_changed));
+ }
+
+ let assistant = Self {
id,
prompt_editor,
workspace,
@@ -2797,7 +2943,33 @@ impl InlineAssistant {
pending_prompt: String::new(),
codegen,
_subscriptions: subscriptions,
+ retrieve_context,
+ semantic_permissioned: None,
+ semantic_index,
+ project: project.downgrade(),
+ maintain_rate_limit: None,
+ };
+
+ assistant.index_project(cx).log_err();
+
+ assistant
+ }
+
+ fn semantic_permissioned(&self, cx: &mut ViewContext) -> Task> {
+ if let Some(value) = self.semantic_permissioned {
+ return Task::ready(Ok(value));
}
+
+ let Some(project) = self.project.upgrade(cx) else {
+ return Task::ready(Err(anyhow!("project was dropped")));
+ };
+
+ self.semantic_index
+ .as_ref()
+ .map(|semantic| {
+ semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
+ })
+ .unwrap_or(Task::ready(Ok(false)))
}
fn handle_prompt_editor_events(
@@ -2812,6 +2984,37 @@ impl InlineAssistant {
}
}
+ fn semantic_index_changed(
+ &mut self,
+ semantic_index: ModelHandle,
+ cx: &mut ViewContext,
+ ) {
+ let Some(project) = self.project.upgrade(cx) else {
+ return;
+ };
+
+ let status = semantic_index.read(cx).status(&project);
+ match status {
+ SemanticIndexStatus::Indexing {
+ rate_limit_expiry: Some(_),
+ ..
+ } => {
+ if self.maintain_rate_limit.is_none() {
+ self.maintain_rate_limit = Some(cx.spawn(|this, mut cx| async move {
+ loop {
+ cx.background().timer(Duration::from_secs(1)).await;
+ this.update(&mut cx, |_, cx| cx.notify()).log_err();
+ }
+ }));
+ }
+ return;
+ }
+ _ => {
+ self.maintain_rate_limit = None;
+ }
+ }
+ }
+
fn handle_codegen_changed(&mut self, _: ModelHandle, cx: &mut ViewContext) {
let is_read_only = !self.codegen.read(cx).idle();
self.prompt_editor.update(cx, |editor, cx| {
@@ -2861,12 +3064,241 @@ impl InlineAssistant {
cx.emit(InlineAssistantEvent::Confirmed {
prompt,
include_conversation: self.include_conversation,
+ retrieve_context: self.retrieve_context,
});
self.confirmed = true;
cx.notify();
}
}
+ fn toggle_retrieve_context(&mut self, _: &ToggleRetrieveContext, cx: &mut ViewContext) {
+ let semantic_permissioned = self.semantic_permissioned(cx);
+
+ let Some(project) = self.project.upgrade(cx) else {
+ return;
+ };
+
+ let project_name = project
+ .read(cx)
+ .worktree_root_names(cx)
+ .collect::>()
+ .join("/");
+ let is_plural = project_name.chars().filter(|letter| *letter == '/').count() > 0;
+ let prompt_text = format!("Would you like to index the '{}' project{} for context retrieval? This requires sending code to the OpenAI API", project_name,
+ if is_plural {
+ "s"
+ } else {""});
+
+ cx.spawn(|this, mut cx| async move {
+ // If Necessary prompt user
+ if !semantic_permissioned.await.unwrap_or(false) {
+ let mut answer = this.update(&mut cx, |_, cx| {
+ cx.prompt(
+ PromptLevel::Info,
+ prompt_text.as_str(),
+ &["Continue", "Cancel"],
+ )
+ })?;
+
+ if answer.next().await == Some(0) {
+ this.update(&mut cx, |this, _| {
+ this.semantic_permissioned = Some(true);
+ })?;
+ } else {
+ return anyhow::Ok(());
+ }
+ }
+
+ // If permissioned, update context appropriately
+ this.update(&mut cx, |this, cx| {
+ this.retrieve_context = !this.retrieve_context;
+
+ cx.emit(InlineAssistantEvent::RetrieveContextToggled {
+ retrieve_context: this.retrieve_context,
+ });
+
+ if this.retrieve_context {
+ this.index_project(cx).log_err();
+ }
+
+ cx.notify();
+ })?;
+
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn index_project(&self, cx: &mut ViewContext) -> anyhow::Result<()> {
+ let Some(project) = self.project.upgrade(cx) else {
+ return Err(anyhow!("project was dropped!"));
+ };
+
+ let semantic_permissioned = self.semantic_permissioned(cx);
+ if let Some(semantic_index) = SemanticIndex::global(cx) {
+ cx.spawn(|_, mut cx| async move {
+ // This has to be updated to accomodate for semantic_permissions
+ if semantic_permissioned.await.unwrap_or(false) {
+ semantic_index
+ .update(&mut cx, |index, cx| index.index_project(project, cx))
+ .await
+ } else {
+ Err(anyhow!("project is not permissioned for semantic indexing"))
+ }
+ })
+ .detach_and_log_err(cx);
+ }
+
+ anyhow::Ok(())
+ }
+
+ fn retrieve_context_status(
+ &self,
+ cx: &mut ViewContext,
+ ) -> Option> {
+ enum ContextStatusIcon {}
+
+ let Some(project) = self.project.upgrade(cx) else {
+ return None;
+ };
+
+ if let Some(semantic_index) = SemanticIndex::global(cx) {
+ let status = semantic_index.update(cx, |index, _| index.status(&project));
+ let theme = theme::current(cx);
+ match status {
+ SemanticIndexStatus::NotAuthenticated {} => Some(
+ Svg::new("icons/error.svg")
+ .with_color(theme.assistant.error_icon.color)
+ .constrained()
+ .with_width(theme.assistant.error_icon.width)
+ .contained()
+ .with_style(theme.assistant.error_icon.container)
+ .with_tooltip::(
+ self.id,
+ "Not Authenticated. Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.",
+ None,
+ theme.tooltip.clone(),
+ cx,
+ )
+ .aligned()
+ .into_any(),
+ ),
+ SemanticIndexStatus::NotIndexed {} => Some(
+ Svg::new("icons/error.svg")
+ .with_color(theme.assistant.inline.context_status.error_icon.color)
+ .constrained()
+ .with_width(theme.assistant.inline.context_status.error_icon.width)
+ .contained()
+ .with_style(theme.assistant.inline.context_status.error_icon.container)
+ .with_tooltip::(
+ self.id,
+ "Not Indexed",
+ None,
+ theme.tooltip.clone(),
+ cx,
+ )
+ .aligned()
+ .into_any(),
+ ),
+ SemanticIndexStatus::Indexing {
+ remaining_files,
+ rate_limit_expiry,
+ } => {
+
+ let mut status_text = if remaining_files == 0 {
+ "Indexing...".to_string()
+ } else {
+ format!("Remaining files to index: {remaining_files}")
+ };
+
+ if let Some(rate_limit_expiry) = rate_limit_expiry {
+ let remaining_seconds = rate_limit_expiry.duration_since(Instant::now());
+ if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 {
+ write!(
+ status_text,
+ " (rate limit expires in {}s)",
+ remaining_seconds.as_secs()
+ )
+ .unwrap();
+ }
+ }
+ Some(
+ Svg::new("icons/update.svg")
+ .with_color(theme.assistant.inline.context_status.in_progress_icon.color)
+ .constrained()
+ .with_width(theme.assistant.inline.context_status.in_progress_icon.width)
+ .contained()
+ .with_style(theme.assistant.inline.context_status.in_progress_icon.container)
+ .with_tooltip::(
+ self.id,
+ status_text,
+ None,
+ theme.tooltip.clone(),
+ cx,
+ )
+ .aligned()
+ .into_any(),
+ )
+ }
+ SemanticIndexStatus::Indexed {} => Some(
+ Svg::new("icons/check.svg")
+ .with_color(theme.assistant.inline.context_status.complete_icon.color)
+ .constrained()
+ .with_width(theme.assistant.inline.context_status.complete_icon.width)
+ .contained()
+ .with_style(theme.assistant.inline.context_status.complete_icon.container)
+ .with_tooltip::(
+ self.id,
+ "Index up to date",
+ None,
+ theme.tooltip.clone(),
+ cx,
+ )
+ .aligned()
+ .into_any(),
+ ),
+ }
+ } else {
+ None
+ }
+ }
+
+ // fn retrieve_context_status(&self, cx: &mut ViewContext) -> String {
+ // let project = self.project.clone();
+ // if let Some(semantic_index) = self.semantic_index.clone() {
+ // let status = semantic_index.update(cx, |index, cx| index.status(&project));
+ // return match status {
+ // // This theoretically shouldnt be a valid code path
+ // // As the inline assistant cant be launched without an API key
+ // // We keep it here for safety
+ // semantic_index::SemanticIndexStatus::NotAuthenticated => {
+ // "Not Authenticated!\nPlease ensure you have an `OPENAI_API_KEY` in your environment variables.".to_string()
+ // }
+ // semantic_index::SemanticIndexStatus::Indexed => {
+ // "Indexing Complete!".to_string()
+ // }
+ // semantic_index::SemanticIndexStatus::Indexing { remaining_files, rate_limit_expiry } => {
+
+ // let mut status = format!("Remaining files to index for Context Retrieval: {remaining_files}");
+
+ // if let Some(rate_limit_expiry) = rate_limit_expiry {
+ // let remaining_seconds =
+ // rate_limit_expiry.duration_since(Instant::now());
+ // if remaining_seconds > Duration::from_secs(0) {
+ // write!(status, " (rate limit resets in {}s)", remaining_seconds.as_secs()).unwrap();
+ // }
+ // }
+ // status
+ // }
+ // semantic_index::SemanticIndexStatus::NotIndexed => {
+ // "Not Indexed for Context Retrieval".to_string()
+ // }
+ // };
+ // }
+
+ // "".to_string()
+ // }
+
fn toggle_include_conversation(
&mut self,
_: &ToggleIncludeConversation,
@@ -2929,6 +3361,7 @@ struct PendingInlineAssist {
inline_assistant: Option<(BlockId, ViewHandle)>,
codegen: ModelHandle,
_subscriptions: Vec,
+ project: WeakModelHandle,
}
fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) {
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index d326a7f445..18e9e18f7d 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,8 +1,60 @@
use crate::codegen::CodegenKind;
+use gpui::AsyncAppContext;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
+use semantic_index::SearchResult;
use std::cmp::{self, Reverse};
use std::fmt::Write;
use std::ops::Range;
+use std::path::PathBuf;
+use tiktoken_rs::ChatCompletionRequestMessage;
+
+pub struct PromptCodeSnippet {
+ path: Option,
+ language_name: Option,
+ content: String,
+}
+
+impl PromptCodeSnippet {
+ pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
+ let (content, language_name, file_path) =
+ search_result.buffer.read_with(cx, |buffer, _| {
+ let snapshot = buffer.snapshot();
+ let content = snapshot
+ .text_for_range(search_result.range.clone())
+ .collect::();
+
+ let language_name = buffer
+ .language()
+ .and_then(|language| Some(language.name().to_string()));
+
+ let file_path = buffer
+ .file()
+ .and_then(|file| Some(file.path().to_path_buf()));
+
+ (content, language_name, file_path)
+ });
+
+ PromptCodeSnippet {
+ path: file_path,
+ language_name,
+ content,
+ }
+ }
+}
+
+impl ToString for PromptCodeSnippet {
+ fn to_string(&self) -> String {
+ let path = self
+ .path
+ .as_ref()
+ .and_then(|path| Some(path.to_string_lossy().to_string()))
+ .unwrap_or("".to_string());
+ let language_name = self.language_name.clone().unwrap_or("".to_string());
+ let content = self.content.clone();
+
+ format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+ }
+}
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String {
@@ -121,17 +173,25 @@ pub fn generate_content_prompt(
buffer: &BufferSnapshot,
range: Range,
kind: CodegenKind,
+ search_results: Vec,
+ model: &str,
) -> String {
+ const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+ const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
+
+ let mut prompts = Vec::new();
let range = range.to_offset(buffer);
- let mut prompt = String::new();
// General Preamble
if let Some(language_name) = language_name {
- writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
+ prompts.push(format!("You're an expert {language_name} engineer.\n"));
} else {
- writeln!(prompt, "You're an expert engineer.\n").unwrap();
+ prompts.push("You're an expert engineer.\n".to_string());
}
+ // Snippets
+ let mut snippet_position = prompts.len() - 1;
+
let mut content = String::new();
content.extend(buffer.text_for_range(0..range.start));
if range.start == range.end {
@@ -145,59 +205,103 @@ pub fn generate_content_prompt(
}
content.extend(buffer.text_for_range(range.end..buffer.len()));
- writeln!(
- prompt,
- "The file you are currently working on has the following content:"
- )
- .unwrap();
+ prompts.push("The file you are currently working on has the following content:\n".to_string());
+
if let Some(language_name) = language_name {
let language_name = language_name.to_lowercase();
- writeln!(prompt, "```{language_name}\n{content}\n```").unwrap();
+ prompts.push(format!("```{language_name}\n{content}\n```"));
} else {
- writeln!(prompt, "```\n{content}\n```").unwrap();
+ prompts.push(format!("```\n{content}\n```"));
}
match kind {
CodegenKind::Generate { position: _ } => {
- writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
- writeln!(
- prompt,
- "Assume the cursor is located where the `<|START|` marker is."
- )
- .unwrap();
- writeln!(
- prompt,
+ prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
+ prompts
+ .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
+ prompts.push(
"Text can't be replaced, so assume your answer will be inserted at the cursor."
- )
- .unwrap();
- writeln!(
- prompt,
+ .to_string(),
+ );
+ prompts.push(format!(
"Generate text based on the users prompt: {user_prompt}"
- )
- .unwrap();
+ ));
}
CodegenKind::Transform { range: _ } => {
- writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
- writeln!(
- prompt,
- "Modify the users code selected text based upon the users prompt: {user_prompt}"
- )
- .unwrap();
- writeln!(
- prompt,
- "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
- )
- .unwrap();
+ prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
+ prompts.push(format!(
+ "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
+ ));
+ prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
}
}
if let Some(language_name) = language_name {
- writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
+ prompts.push(format!(
+ "Your answer MUST always and only be valid {language_name}"
+ ));
}
- writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
- writeln!(prompt, "Never make remarks about the output.").unwrap();
+ prompts.push("Never make remarks about the output.".to_string());
+ prompts.push("Do not return any text, except the generated code.".to_string());
+ prompts.push("Always wrap your code in a Markdown block".to_string());
- prompt
+ let current_messages = [ChatCompletionRequestMessage {
+ role: "user".to_string(),
+ content: Some(prompts.join("\n")),
+ function_call: None,
+ name: None,
+ }];
+
+ let mut remaining_token_count = if let Ok(current_token_count) =
+ tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
+ {
+ let max_token_count = tiktoken_rs::model::get_context_size(model);
+ let intermediate_token_count = if max_token_count > current_token_count {
+ max_token_count - current_token_count
+ } else {
+ 0
+ };
+
+ if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
+ 0
+ } else {
+ intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
+ }
+ } else {
+ // If tiktoken fails to count token count, assume we have no space remaining.
+ 0
+ };
+
+ // TODO:
+ // - add repository name to snippet
+ // - add file path
+ // - add language
+ if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
+ let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
+
+ for search_result in search_results {
+ let mut snippet_prompt = template.to_string();
+ let snippet = search_result.to_string();
+ writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
+
+ let token_count = encoding
+ .encode_with_special_tokens(snippet_prompt.as_str())
+ .len();
+ if token_count <= remaining_token_count {
+ if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+ prompts.insert(snippet_position, snippet_prompt);
+ snippet_position += 1;
+ remaining_token_count -= token_count;
+ // If you have already added the template to the prompt, remove the template.
+ template = "";
+ }
+ } else {
+ break;
+ }
+ }
+ }
+
+ prompts.join("\n")
}
#[cfg(test)]
diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs
index ae8a797d06..221b845297 100644
--- a/crates/channel/src/channel_store.rs
+++ b/crates/channel/src/channel_store.rs
@@ -9,7 +9,7 @@ use db::RELEASE_CHANNEL;
use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use rpc::{
- proto::{self, ChannelEdge, ChannelPermission},
+ proto::{self, ChannelEdge, ChannelPermission, ChannelRole, ChannelVisibility},
TypedEnvelope,
};
use serde_derive::{Deserialize, Serialize};
@@ -49,6 +49,7 @@ pub type ChannelData = (Channel, ChannelPath);
pub struct Channel {
pub id: ChannelId,
pub name: String,
+ pub visibility: proto::ChannelVisibility,
pub unseen_note_version: Option<(u64, clock::Global)>,
pub unseen_message_id: Option,
}
@@ -79,7 +80,32 @@ pub struct ChannelPath(Arc<[ChannelId]>);
pub struct ChannelMembership {
pub user: Arc,
pub kind: proto::channel_member::Kind,
- pub admin: bool,
+ pub role: proto::ChannelRole,
+}
+impl ChannelMembership {
+ pub fn sort_key(&self) -> MembershipSortKey {
+ MembershipSortKey {
+ role_order: match self.role {
+ proto::ChannelRole::Admin => 0,
+ proto::ChannelRole::Member => 1,
+ proto::ChannelRole::Banned => 2,
+ proto::ChannelRole::Guest => 3,
+ },
+ kind_order: match self.kind {
+ proto::channel_member::Kind::Member => 0,
+ proto::channel_member::Kind::AncestorMember => 1,
+ proto::channel_member::Kind::Invitee => 2,
+ },
+ username_order: self.user.github_login.as_str(),
+ }
+ }
+}
+
+#[derive(PartialOrd, Ord, PartialEq, Eq)]
+pub struct MembershipSortKey<'a> {
+ role_order: u8,
+ kind_order: u8,
+ username_order: &'a str,
}
pub enum ChannelEvent {
@@ -475,7 +501,7 @@ impl ChannelStore {
insert_edge: parent_edge,
channel_permissions: vec![ChannelPermission {
channel_id,
- is_admin: true,
+ role: ChannelRole::Admin.into(),
}],
..Default::default()
},
@@ -547,11 +573,30 @@ impl ChannelStore {
})
}
+ pub fn set_channel_visibility(
+ &mut self,
+ channel_id: ChannelId,
+ visibility: ChannelVisibility,
+ cx: &mut ModelContext,
+ ) -> Task> {
+ let client = self.client.clone();
+ cx.spawn(|_, _| async move {
+ let _ = client
+ .request(proto::SetChannelVisibility {
+ channel_id,
+ visibility: visibility.into(),
+ })
+ .await?;
+
+ Ok(())
+ })
+ }
+
pub fn invite_member(
&mut self,
channel_id: ChannelId,
user_id: UserId,
- admin: bool,
+ role: proto::ChannelRole,
cx: &mut ModelContext,
) -> Task> {
if !self.outgoing_invites.insert((channel_id, user_id)) {
@@ -565,7 +610,7 @@ impl ChannelStore {
.request(proto::InviteChannelMember {
channel_id,
user_id,
- admin,
+ role: role.into(),
})
.await;
@@ -609,11 +654,11 @@ impl ChannelStore {
})
}
- pub fn set_member_admin(
+ pub fn set_member_role(
&mut self,
channel_id: ChannelId,
user_id: UserId,
- admin: bool,
+ role: proto::ChannelRole,
cx: &mut ModelContext,
) -> Task> {
if !self.outgoing_invites.insert((channel_id, user_id)) {
@@ -624,10 +669,10 @@ impl ChannelStore {
let client = self.client.clone();
cx.spawn(|this, mut cx| async move {
let result = client
- .request(proto::SetChannelMemberAdmin {
+ .request(proto::SetChannelMemberRole {
channel_id,
user_id,
- admin,
+ role: role.into(),
})
.await;
@@ -716,8 +761,8 @@ impl ChannelStore {
.filter_map(|(user, member)| {
Some(ChannelMembership {
user,
- admin: member.admin,
- kind: proto::channel_member::Kind::from_i32(member.kind)?,
+ role: member.role(),
+ kind: member.kind(),
})
})
.collect())
@@ -912,6 +957,7 @@ impl ChannelStore {
ix,
Arc::new(Channel {
id: channel.id,
+ visibility: channel.visibility(),
name: channel.name,
unseen_note_version: None,
unseen_message_id: None,
@@ -978,7 +1024,7 @@ impl ChannelStore {
}
for permission in payload.channel_permissions {
- if permission.is_admin {
+ if permission.role() == proto::ChannelRole::Admin {
self.channels_with_admin_privileges
.insert(permission.channel_id);
} else {
diff --git a/crates/channel/src/channel_store/channel_index.rs b/crates/channel/src/channel_store/channel_index.rs
index bf0de1b644..36379a3942 100644
--- a/crates/channel/src/channel_store/channel_index.rs
+++ b/crates/channel/src/channel_store/channel_index.rs
@@ -123,12 +123,15 @@ impl<'a> ChannelPathsInsertGuard<'a> {
pub fn insert(&mut self, channel_proto: proto::Channel) {
if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
- Arc::make_mut(existing_channel).name = channel_proto.name;
+ let existing_channel = Arc::make_mut(existing_channel);
+ existing_channel.visibility = channel_proto.visibility();
+ existing_channel.name = channel_proto.name;
} else {
self.channels_by_id.insert(
channel_proto.id,
Arc::new(Channel {
id: channel_proto.id,
+ visibility: channel_proto.visibility(),
name: channel_proto.name,
unseen_note_version: None,
unseen_message_id: None,
diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs
index 1a762b85cb..8ad8f21224 100644
--- a/crates/channel/src/channel_store_tests.rs
+++ b/crates/channel/src/channel_store_tests.rs
@@ -3,7 +3,7 @@ use crate::channel_chat::ChannelChatEvent;
use super::*;
use client::{test::FakeServer, Client, UserStore};
use gpui::{AppContext, ModelHandle, TestAppContext};
-use rpc::proto;
+use rpc::proto::{self};
use settings::SettingsStore;
use util::http::FakeHttpClient;
@@ -18,15 +18,17 @@ fn test_update_channels(cx: &mut AppContext) {
proto::Channel {
id: 1,
name: "b".to_string(),
+ visibility: proto::ChannelVisibility::Members as i32,
},
proto::Channel {
id: 2,
name: "a".to_string(),
+ visibility: proto::ChannelVisibility::Members as i32,
},
],
channel_permissions: vec![proto::ChannelPermission {
channel_id: 1,
- is_admin: true,
+ role: proto::ChannelRole::Admin.into(),
}],
..Default::default()
},
@@ -49,10 +51,12 @@ fn test_update_channels(cx: &mut AppContext) {
proto::Channel {
id: 3,
name: "x".to_string(),
+ visibility: proto::ChannelVisibility::Members as i32,
},
proto::Channel {
id: 4,
name: "y".to_string(),
+ visibility: proto::ChannelVisibility::Members as i32,
},
],
insert_edge: vec![
@@ -92,14 +96,17 @@ fn test_dangling_channel_paths(cx: &mut AppContext) {
proto::Channel {
id: 0,
name: "a".to_string(),
+ visibility: proto::ChannelVisibility::Members as i32,
},
proto::Channel {
id: 1,
name: "b".to_string(),
+ visibility: proto::ChannelVisibility::Members as i32,
},
proto::Channel {
id: 2,
name: "c".to_string(),
+ visibility: proto::ChannelVisibility::Members as i32,
},
],
insert_edge: vec![
@@ -114,7 +121,7 @@ fn test_dangling_channel_paths(cx: &mut AppContext) {
],
channel_permissions: vec![proto::ChannelPermission {
channel_id: 0,
- is_admin: true,
+ role: proto::ChannelRole::Admin.into(),
}],
..Default::default()
},
@@ -158,6 +165,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
channels: vec![proto::Channel {
id: channel_id,
name: "the-channel".to_string(),
+ visibility: proto::ChannelVisibility::Members as i32,
}],
..Default::default()
});
diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml
index c139da831e..64bc191b21 100644
--- a/crates/collab/Cargo.toml
+++ b/crates/collab/Cargo.toml
@@ -3,7 +3,7 @@ authors = ["Nathan Sobo "]
default-run = "collab"
edition = "2021"
name = "collab"
-version = "0.24.0"
+version = "0.25.0"
publish = false
[[bin]]
diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql
index c835fff60d..fb2f276b8b 100644
--- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql
+++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql
@@ -44,7 +44,7 @@ CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id");
CREATE TABLE "projects" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
- "room_id" INTEGER REFERENCES rooms (id) NOT NULL,
+ "room_id" INTEGER REFERENCES rooms (id) ON DELETE CASCADE NOT NULL,
"host_user_id" INTEGER REFERENCES users (id) NOT NULL,
"host_connection_id" INTEGER,
"host_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE,
@@ -192,7 +192,8 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id");
CREATE TABLE "channels" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"name" VARCHAR NOT NULL,
- "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+ "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ "visibility" VARCHAR NOT NULL
);
CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
@@ -234,6 +235,7 @@ CREATE TABLE "channel_members" (
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"admin" BOOLEAN NOT NULL DEFAULT false,
+ "role" VARCHAR,
"accepted" BOOLEAN NOT NULL DEFAULT false,
"updated_at" TIMESTAMP NOT NULL DEFAULT now
);
diff --git a/crates/collab/migrations/20231011214412_add_guest_role.sql b/crates/collab/migrations/20231011214412_add_guest_role.sql
new file mode 100644
index 0000000000..1713547158
--- /dev/null
+++ b/crates/collab/migrations/20231011214412_add_guest_role.sql
@@ -0,0 +1,4 @@
+ALTER TABLE channel_members ADD COLUMN role TEXT;
+UPDATE channel_members SET role = CASE WHEN admin THEN 'admin' ELSE 'member' END;
+
+ALTER TABLE channels ADD COLUMN visibility TEXT NOT NULL DEFAULT 'members';
diff --git a/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql b/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql
new file mode 100644
index 0000000000..be535ff7fa
--- /dev/null
+++ b/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql
@@ -0,0 +1,8 @@
+-- Add migration script here
+
+ALTER TABLE projects
+ DROP CONSTRAINT projects_room_id_fkey,
+ ADD CONSTRAINT projects_room_id_fkey
+ FOREIGN KEY (room_id)
+ REFERENCES rooms (id)
+ ON DELETE CASCADE;
diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs
index 4dfb21ba03..2eb229046a 100644
--- a/crates/collab/src/db.rs
+++ b/crates/collab/src/db.rs
@@ -432,6 +432,7 @@ pub struct NewUserResult {
pub struct Channel {
pub id: ChannelId,
pub name: String,
+ pub visibility: ChannelVisibility,
}
#[derive(Debug, PartialEq)]
diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs
index bd07af8a35..433444de67 100644
--- a/crates/collab/src/db/ids.rs
+++ b/crates/collab/src/db/ids.rs
@@ -1,4 +1,5 @@
use crate::Result;
+use rpc::proto;
use sea_orm::{entity::prelude::*, DbErr};
use serde::{Deserialize, Serialize};
@@ -82,3 +83,101 @@ id_type!(ChannelBufferCollaboratorId);
id_type!(FlagId);
id_type!(NotificationId);
id_type!(NotificationKindId);
+
+#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default)]
+#[sea_orm(rs_type = "String", db_type = "String(None)")]
+pub enum ChannelRole {
+ #[sea_orm(string_value = "admin")]
+ Admin,
+ #[sea_orm(string_value = "member")]
+ #[default]
+ Member,
+ #[sea_orm(string_value = "guest")]
+ Guest,
+ #[sea_orm(string_value = "banned")]
+ Banned,
+}
+
+impl ChannelRole {
+ pub fn should_override(&self, other: Self) -> bool {
+ use ChannelRole::*;
+ match self {
+ Admin => matches!(other, Member | Banned | Guest),
+ Member => matches!(other, Banned | Guest),
+ Banned => matches!(other, Guest),
+ Guest => false,
+ }
+ }
+
+ pub fn max(&self, other: Self) -> Self {
+ if self.should_override(other) {
+ *self
+ } else {
+ other
+ }
+ }
+}
+
+impl From for ChannelRole {
+ fn from(value: proto::ChannelRole) -> Self {
+ match value {
+ proto::ChannelRole::Admin => ChannelRole::Admin,
+ proto::ChannelRole::Member => ChannelRole::Member,
+ proto::ChannelRole::Guest => ChannelRole::Guest,
+ proto::ChannelRole::Banned => ChannelRole::Banned,
+ }
+ }
+}
+
+impl Into for ChannelRole {
+ fn into(self) -> proto::ChannelRole {
+ match self {
+ ChannelRole::Admin => proto::ChannelRole::Admin,
+ ChannelRole::Member => proto::ChannelRole::Member,
+ ChannelRole::Guest => proto::ChannelRole::Guest,
+ ChannelRole::Banned => proto::ChannelRole::Banned,
+ }
+ }
+}
+
+impl Into for ChannelRole {
+ fn into(self) -> i32 {
+ let proto: proto::ChannelRole = self.into();
+ proto.into()
+ }
+}
+
+#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)]
+#[sea_orm(rs_type = "String", db_type = "String(None)")]
+pub enum ChannelVisibility {
+ #[sea_orm(string_value = "public")]
+ Public,
+ #[sea_orm(string_value = "members")]
+ #[default]
+ Members,
+}
+
+impl From for ChannelVisibility {
+ fn from(value: proto::ChannelVisibility) -> Self {
+ match value {
+ proto::ChannelVisibility::Public => ChannelVisibility::Public,
+ proto::ChannelVisibility::Members => ChannelVisibility::Members,
+ }
+ }
+}
+
+impl Into for ChannelVisibility {
+ fn into(self) -> proto::ChannelVisibility {
+ match self {
+ ChannelVisibility::Public => proto::ChannelVisibility::Public,
+ ChannelVisibility::Members => proto::ChannelVisibility::Members,
+ }
+ }
+}
+
+impl Into for ChannelVisibility {
+ fn into(self) -> i32 {
+ let proto: proto::ChannelVisibility = self.into();
+ proto.into()
+ }
+}
diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs
index c85432f2bb..69f100e6b8 100644
--- a/crates/collab/src/db/queries/buffers.rs
+++ b/crates/collab/src/db/queries/buffers.rs
@@ -482,7 +482,9 @@ impl Database {
)
.await?;
- channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
+ channel_members = self
+ .get_channel_participants_internal(channel_id, &*tx)
+ .await?;
let collaborators = self
.get_channel_buffer_collaborators_internal(channel_id, &*tx)
.await?;
diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs
index 1ca38b2e3c..56b1d3b7ff 100644
--- a/crates/collab/src/db/queries/channels.rs
+++ b/crates/collab/src/db/queries/channels.rs
@@ -1,8 +1,5 @@
use super::*;
-use rpc::proto::ChannelEdge;
-use smallvec::SmallVec;
-
-type ChannelDescendants = HashMap>;
+use rpc::proto::{channel_member::Kind, ChannelEdge};
impl Database {
#[cfg(test)]
@@ -37,8 +34,9 @@ impl Database {
}
let channel = channel::ActiveModel {
+ id: ActiveValue::NotSet,
name: ActiveValue::Set(name.to_string()),
- ..Default::default()
+ visibility: ActiveValue::Set(ChannelVisibility::Members),
}
.insert(&*tx)
.await?;
@@ -74,11 +72,11 @@ impl Database {
}
channel_member::ActiveModel {
+ id: ActiveValue::NotSet,
channel_id: ActiveValue::Set(channel.id),
user_id: ActiveValue::Set(creator_id),
accepted: ActiveValue::Set(true),
- admin: ActiveValue::Set(true),
- ..Default::default()
+ role: ActiveValue::Set(ChannelRole::Admin),
}
.insert(&*tx)
.await?;
@@ -88,6 +86,116 @@ impl Database {
.await
}
+ pub async fn join_channel(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ connection: ConnectionId,
+ environment: &str,
+ ) -> Result<(JoinRoom, Option)> {
+ self.transaction(move |tx| async move {
+ let mut joined_channel_id = None;
+
+ let channel = channel::Entity::find()
+ .filter(channel::Column::Id.eq(channel_id))
+ .one(&*tx)
+ .await?;
+
+ let mut role = self
+ .channel_role_for_user(channel_id, user_id, &*tx)
+ .await?;
+
+ if role.is_none() && channel.is_some() {
+ if let Some(invitation) = self
+ .pending_invite_for_channel(channel_id, user_id, &*tx)
+ .await?
+ {
+ // note, this may be a parent channel
+ joined_channel_id = Some(invitation.channel_id);
+ role = Some(invitation.role);
+
+ channel_member::Entity::update(channel_member::ActiveModel {
+ accepted: ActiveValue::Set(true),
+ ..invitation.into_active_model()
+ })
+ .exec(&*tx)
+ .await?;
+
+ debug_assert!(
+ self.channel_role_for_user(channel_id, user_id, &*tx)
+ .await?
+ == role
+ );
+ }
+ }
+ if role.is_none()
+ && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
+ {
+ let channel_id_to_join = self
+ .most_public_ancestor_for_channel(channel_id, &*tx)
+ .await?
+ .unwrap_or(channel_id);
+ // TODO: change this back to Guest.
+ role = Some(ChannelRole::Member);
+ joined_channel_id = Some(channel_id_to_join);
+
+ channel_member::Entity::insert(channel_member::ActiveModel {
+ id: ActiveValue::NotSet,
+ channel_id: ActiveValue::Set(channel_id_to_join),
+ user_id: ActiveValue::Set(user_id),
+ accepted: ActiveValue::Set(true),
+ // TODO: change this back to Guest.
+ role: ActiveValue::Set(ChannelRole::Member),
+ })
+ .exec(&*tx)
+ .await?;
+
+ debug_assert!(
+ self.channel_role_for_user(channel_id, user_id, &*tx)
+ .await?
+ == role
+ );
+ }
+
+ if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
+ Err(anyhow!("no such channel, or not allowed"))?
+ }
+
+ let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
+ let room_id = self
+ .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
+ .await?;
+
+ self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
+ .await
+ .map(|jr| (jr, joined_channel_id))
+ })
+ .await
+ }
+
+ pub async fn set_channel_visibility(
+ &self,
+ channel_id: ChannelId,
+ visibility: ChannelVisibility,
+ user_id: UserId,
+ ) -> Result {
+ self.transaction(move |tx| async move {
+ self.check_user_is_channel_admin(channel_id, user_id, &*tx)
+ .await?;
+
+ let channel = channel::ActiveModel {
+ id: ActiveValue::Unchanged(channel_id),
+ visibility: ActiveValue::Set(visibility),
+ ..Default::default()
+ }
+ .update(&*tx)
+ .await?;
+
+ Ok(channel)
+ })
+ .await
+ }
+
pub async fn delete_channel(
&self,
channel_id: ChannelId,
@@ -98,17 +206,19 @@ impl Database {
.await?;
// Don't remove descendant channels that have additional parents.
- let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
+ let mut channels_to_remove: HashSet = HashSet::default();
+ channels_to_remove.insert(channel_id);
+
+ let graph = self.get_channel_descendants([channel_id], &*tx).await?;
+ for edge in graph.iter() {
+ channels_to_remove.insert(ChannelId::from_proto(edge.channel_id));
+ }
+
{
let mut channels_to_keep = channel_path::Entity::find()
.filter(
channel_path::Column::ChannelId
- .is_in(
- channels_to_remove
- .keys()
- .copied()
- .filter(|&id| id != channel_id),
- )
+ .is_in(channels_to_remove.iter().copied())
.and(
channel_path::Column::IdPath
.not_like(&format!("%/{}/%", channel_id)),
@@ -133,7 +243,7 @@ impl Database {
.await?;
channel::Entity::delete_many()
- .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
+ .filter(channel::Column::Id.is_in(channels_to_remove.iter().copied()))
.exec(&*tx)
.await?;
@@ -150,7 +260,7 @@ impl Database {
);
tx.execute(channel_paths_stmt).await?;
- Ok((channels_to_remove.into_keys().collect(), members_to_notify))
+ Ok((channels_to_remove.into_iter().collect(), members_to_notify))
})
.await
}
@@ -160,7 +270,7 @@ impl Database {
channel_id: ChannelId,
invitee_id: UserId,
inviter_id: UserId,
- is_admin: bool,
+ role: ChannelRole,
) -> Result {
self.transaction(move |tx| async move {
self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
@@ -172,11 +282,11 @@ impl Database {
.ok_or_else(|| anyhow!("no such channel"))?;
channel_member::ActiveModel {
+ id: ActiveValue::NotSet,
channel_id: ActiveValue::Set(channel_id),
user_id: ActiveValue::Set(invitee_id),
accepted: ActiveValue::Set(false),
- admin: ActiveValue::Set(is_admin),
- ..Default::default()
+ role: ActiveValue::Set(role),
}
.insert(&*tx)
.await?;
@@ -212,14 +322,14 @@ impl Database {
channel_id: ChannelId,
user_id: UserId,
new_name: &str,
- ) -> Result {
+ ) -> Result {
self.transaction(move |tx| async move {
let new_name = Self::sanitize_channel_name(new_name)?.to_string();
self.check_user_is_channel_admin(channel_id, user_id, &*tx)
.await?;
- channel::ActiveModel {
+ let channel = channel::ActiveModel {
id: ActiveValue::Unchanged(channel_id),
name: ActiveValue::Set(new_name.clone()),
..Default::default()
@@ -227,7 +337,11 @@ impl Database {
.update(&*tx)
.await?;
- Ok(new_name)
+ Ok(Channel {
+ id: channel.id,
+ name: channel.name,
+ visibility: channel.visibility,
+ })
})
.await
}
@@ -293,10 +407,10 @@ impl Database {
&self,
channel_id: ChannelId,
member_id: UserId,
- remover_id: UserId,
+ admin_id: UserId,
) -> Result