Relay buffer change events to Copilot

This commit is contained in:
Antonio Scandurra 2023-04-19 12:19:24 +02:00
parent ce8442a3d8
commit 672cf6b8c7
5 changed files with 268 additions and 99 deletions

1
Cargo.lock generated
View file

@ -4687,6 +4687,7 @@ dependencies = [
"client", "client",
"clock", "clock",
"collections", "collections",
"copilot",
"ctor", "ctor",
"db", "db",
"env_logger", "env_logger",

View file

@ -6,8 +6,13 @@ use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive; use async_tar::Archive;
use collections::HashMap; use collections::HashMap;
use futures::{future::Shared, Future, FutureExt, TryFutureExt}; use futures::{future::Shared, Future, FutureExt, TryFutureExt};
use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; use gpui::{
use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16}; actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
};
use language::{
point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
ToPointUtf16,
};
use log::{debug, error}; use log::{debug, error};
use lsp::LanguageServer; use lsp::LanguageServer;
use node_runtime::NodeRuntime; use node_runtime::NodeRuntime;
@ -105,7 +110,7 @@ enum CopilotServer {
Started { Started {
server: Arc<LanguageServer>, server: Arc<LanguageServer>,
status: SignInStatus, status: SignInStatus,
subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>, registered_buffers: HashMap<usize, RegisteredBuffer>,
}, },
} }
@ -141,6 +146,66 @@ impl Status {
} }
} }
struct RegisteredBuffer {
uri: lsp::Url,
snapshot: Option<(i32, BufferSnapshot)>,
_subscriptions: [gpui::Subscription; 2],
}
impl RegisteredBuffer {
fn report_changes(
&mut self,
buffer: &ModelHandle<Buffer>,
server: &LanguageServer,
cx: &AppContext,
) -> Result<(i32, BufferSnapshot)> {
let buffer = buffer.read(cx);
let (version, prev_snapshot) = self
.snapshot
.as_ref()
.ok_or_else(|| anyhow!("expected at least one snapshot"))?;
let next_snapshot = buffer.snapshot();
let content_changes = buffer
.edits_since::<(PointUtf16, usize)>(prev_snapshot.version())
.map(|edit| {
let edit_start = edit.new.start.0;
let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
let new_text = next_snapshot
.text_for_range(edit.new.start.1..edit.new.end.1)
.collect();
lsp::TextDocumentContentChangeEvent {
range: Some(lsp::Range::new(
point_to_lsp(edit_start),
point_to_lsp(edit_end),
)),
range_length: None,
text: new_text,
}
})
.collect::<Vec<_>>();
if content_changes.is_empty() {
Ok((*version, prev_snapshot.clone()))
} else {
let next_version = version + 1;
self.snapshot = Some((next_version, next_snapshot.clone()));
server.notify::<lsp::notification::DidChangeTextDocument>(
lsp::DidChangeTextDocumentParams {
text_document: lsp::VersionedTextDocumentIdentifier::new(
self.uri.clone(),
next_version,
),
content_changes,
},
)?;
Ok((next_version, next_snapshot))
}
}
}
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct Completion { pub struct Completion {
pub range: Range<Anchor>, pub range: Range<Anchor>,
@ -151,6 +216,7 @@ pub struct Copilot {
http: Arc<dyn HttpClient>, http: Arc<dyn HttpClient>,
node_runtime: Arc<NodeRuntime>, node_runtime: Arc<NodeRuntime>,
server: CopilotServer, server: CopilotServer,
buffers: HashMap<usize, WeakModelHandle<Buffer>>,
} }
impl Entity for Copilot { impl Entity for Copilot {
@ -212,12 +278,14 @@ impl Copilot {
http, http,
node_runtime, node_runtime,
server: CopilotServer::Starting { task: start_task }, server: CopilotServer::Starting { task: start_task },
buffers: Default::default(),
} }
} else { } else {
Self { Self {
http, http,
node_runtime, node_runtime,
server: CopilotServer::Disabled, server: CopilotServer::Disabled,
buffers: Default::default(),
} }
} }
} }
@ -233,8 +301,9 @@ impl Copilot {
server: CopilotServer::Started { server: CopilotServer::Started {
server: Arc::new(server), server: Arc::new(server),
status: SignInStatus::Authorized, status: SignInStatus::Authorized,
subscriptions_by_buffer_id: Default::default(), registered_buffers: Default::default(),
}, },
buffers: Default::default(),
}); });
(this, fake_server) (this, fake_server)
} }
@ -297,7 +366,7 @@ impl Copilot {
this.server = CopilotServer::Started { this.server = CopilotServer::Started {
server, server,
status: SignInStatus::SignedOut, status: SignInStatus::SignedOut,
subscriptions_by_buffer_id: Default::default(), registered_buffers: Default::default(),
}; };
this.update_sign_in_status(status, cx); this.update_sign_in_status(status, cx);
} }
@ -396,10 +465,8 @@ impl Copilot {
} }
fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> { fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
if let CopilotServer::Started { server, status, .. } = &mut self.server { self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
*status = SignInStatus::SignedOut; if let CopilotServer::Started { server, .. } = &self.server {
cx.notify();
let server = server.clone(); let server = server.clone();
cx.background().spawn(async move { cx.background().spawn(async move {
server server
@ -433,6 +500,108 @@ impl Copilot {
cx.foreground().spawn(start_task) cx.foreground().spawn(start_task)
} }
pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
let buffer_id = buffer.id();
self.buffers.insert(buffer_id, buffer.downgrade());
if let CopilotServer::Started {
server,
status,
registered_buffers,
..
} = &mut self.server
{
if !matches!(status, SignInStatus::Authorized { .. }) {
return;
}
let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
registered_buffers.entry(buffer.id()).or_insert_with(|| {
let snapshot = buffer.read(cx).snapshot();
server
.notify::<lsp::notification::DidOpenTextDocument>(
lsp::DidOpenTextDocumentParams {
text_document: lsp::TextDocumentItem {
uri: uri.clone(),
language_id: id_for_language(buffer.read(cx).language()),
version: 0,
text: snapshot.text(),
},
},
)
.log_err();
RegisteredBuffer {
uri,
snapshot: Some((0, snapshot)),
_subscriptions: [
cx.subscribe(buffer, |this, buffer, event, cx| {
this.handle_buffer_event(buffer, event, cx).log_err();
}),
cx.observe_release(buffer, move |this, _buffer, _cx| {
this.buffers.remove(&buffer_id);
this.unregister_buffer(buffer_id);
}),
],
}
});
}
}
fn handle_buffer_event(
&mut self,
buffer: ModelHandle<Buffer>,
event: &language::Event,
cx: &mut ModelContext<Self>,
) -> Result<()> {
if let CopilotServer::Started {
server,
registered_buffers,
..
} = &mut self.server
{
if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) {
match event {
language::Event::Edited => {
registered_buffer.report_changes(&buffer, server, cx)?;
}
language::Event::Saved => {
server.notify::<lsp::notification::DidSaveTextDocument>(
lsp::DidSaveTextDocumentParams {
text_document: lsp::TextDocumentIdentifier::new(
registered_buffer.uri.clone(),
),
text: None,
},
)?;
}
_ => {}
}
}
}
Ok(())
}
fn unregister_buffer(&mut self, buffer_id: usize) {
if let CopilotServer::Started {
server,
registered_buffers,
..
} = &mut self.server
{
if let Some(buffer) = registered_buffers.remove(&buffer_id) {
server
.notify::<lsp::notification::DidCloseTextDocument>(
lsp::DidCloseTextDocumentParams {
text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
},
)
.log_err();
}
}
}
pub fn completions<T>( pub fn completions<T>(
&mut self, &mut self,
buffer: &ModelHandle<Buffer>, buffer: &ModelHandle<Buffer>,
@ -464,16 +633,14 @@ impl Copilot {
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>> ) -> Task<Result<Vec<Completion>>>
where where
R: lsp::request::Request< R: 'static
Params = request::GetCompletionsParams, + lsp::request::Request<
Result = request::GetCompletionsResult, Params = request::GetCompletionsParams,
>, Result = request::GetCompletionsResult,
>,
T: ToPointUtf16, T: ToPointUtf16,
{ {
let buffer_id = buffer.id(); let (server, registered_buffer) = match &mut self.server {
let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
let snapshot = buffer.read(cx).snapshot();
let server = match &mut self.server {
CopilotServer::Starting { .. } => { CopilotServer::Starting { .. } => {
return Task::ready(Err(anyhow!("copilot is still starting"))) return Task::ready(Err(anyhow!("copilot is still starting")))
} }
@ -487,56 +654,28 @@ impl Copilot {
CopilotServer::Started { CopilotServer::Started {
server, server,
status, status,
subscriptions_by_buffer_id, registered_buffers,
..
} => { } => {
if matches!(status, SignInStatus::Authorized { .. }) { if matches!(status, SignInStatus::Authorized { .. }) {
subscriptions_by_buffer_id if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) {
.entry(buffer_id) (server.clone(), registered_buffer)
.or_insert_with(|| { } else {
server return Task::ready(Err(anyhow!(
.notify::<lsp::notification::DidOpenTextDocument>( "requested completions for an unregistered buffer"
lsp::DidOpenTextDocumentParams { )));
text_document: lsp::TextDocumentItem { }
uri: uri.clone(),
language_id: id_for_language(
buffer.read(cx).language(),
),
version: 0,
text: snapshot.text(),
},
},
)
.log_err();
let uri = uri.clone();
cx.observe_release(buffer, move |this, _, _| {
if let CopilotServer::Started {
server,
subscriptions_by_buffer_id,
..
} = &mut this.server
{
server
.notify::<lsp::notification::DidCloseTextDocument>(
lsp::DidCloseTextDocumentParams {
text_document: lsp::TextDocumentIdentifier::new(
uri.clone(),
),
},
)
.log_err();
subscriptions_by_buffer_id.remove(&buffer_id);
}
})
});
server.clone()
} else { } else {
return Task::ready(Err(anyhow!("must sign in before using copilot"))); return Task::ready(Err(anyhow!("must sign in before using copilot")));
} }
} }
}; };
let (version, snapshot) = match registered_buffer.report_changes(buffer, &server, cx) {
Ok((version, snapshot)) => (version, snapshot),
Err(error) => return Task::ready(Err(error)),
};
let uri = registered_buffer.uri.clone();
let settings = cx.global::<Settings>(); let settings = cx.global::<Settings>();
let position = position.to_point_utf16(&snapshot); let position = position.to_point_utf16(&snapshot);
let language = snapshot.language_at(position); let language = snapshot.language_at(position);
@ -544,39 +683,23 @@ impl Copilot {
let language_name = language_name.as_deref(); let language_name = language_name.as_deref();
let tab_size = settings.tab_size(language_name); let tab_size = settings.tab_size(language_name);
let hard_tabs = settings.hard_tabs(language_name); let hard_tabs = settings.hard_tabs(language_name);
let language_id = id_for_language(language); let relative_path = snapshot
.file()
let path; .map(|file| file.path().to_path_buf())
let relative_path; .unwrap_or_default();
if let Some(file) = snapshot.file() { let request = server.request::<R>(request::GetCompletionsParams {
if let Some(file) = file.as_local() { doc: request::GetCompletionsDocument {
path = file.abs_path(cx); uri,
} else { tab_size: tab_size.into(),
path = file.full_path(cx); indent_size: 1,
} insert_spaces: !hard_tabs,
relative_path = file.path().to_path_buf(); relative_path: relative_path.to_string_lossy().into(),
} else { position: point_to_lsp(position),
path = PathBuf::new(); version: version.try_into().unwrap(),
relative_path = PathBuf::new(); },
} });
cx.background().spawn(async move { cx.background().spawn(async move {
let result = server let result = request.await?;
.request::<R>(request::GetCompletionsParams {
doc: request::GetCompletionsDocument {
source: snapshot.text(),
tab_size: tab_size.into(),
indent_size: 1,
insert_spaces: !hard_tabs,
uri,
path: path.to_string_lossy().into(),
relative_path: relative_path.to_string_lossy().into(),
language_id,
position: point_to_lsp(position),
version: 0,
},
})
.await?;
let completions = result let completions = result
.completions .completions
.into_iter() .into_iter()
@ -616,14 +739,37 @@ impl Copilot {
lsp_status: request::SignInStatus, lsp_status: request::SignInStatus,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) { ) {
self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
if let CopilotServer::Started { status, .. } = &mut self.server { if let CopilotServer::Started { status, .. } = &mut self.server {
*status = match lsp_status { match lsp_status {
request::SignInStatus::Ok { .. } request::SignInStatus::Ok { .. }
| request::SignInStatus::MaybeOk { .. } | request::SignInStatus::MaybeOk { .. }
| request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized, | request::SignInStatus::AlreadySignedIn { .. } => {
request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized, *status = SignInStatus::Authorized;
request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
}; for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
if let Some(buffer) = buffer.upgrade(cx) {
self.register_buffer(&buffer, cx);
}
}
}
request::SignInStatus::NotAuthorized { .. } => {
*status = SignInStatus::Unauthorized;
for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
self.unregister_buffer(buffer_id);
}
}
request::SignInStatus::NotSignedIn => {
*status = SignInStatus::SignedOut;
for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
self.unregister_buffer(buffer_id);
}
}
}
cx.notify(); cx.notify();
} }
} }

View file

@ -99,14 +99,11 @@ pub struct GetCompletionsParams {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GetCompletionsDocument { pub struct GetCompletionsDocument {
pub source: String,
pub tab_size: u32, pub tab_size: u32,
pub indent_size: u32, pub indent_size: u32,
pub insert_spaces: bool, pub insert_spaces: bool,
pub uri: lsp::Url, pub uri: lsp::Url,
pub path: String,
pub relative_path: String, pub relative_path: String,
pub language_id: String,
pub position: lsp::Position, pub position: lsp::Position,
pub version: usize, pub version: usize,
} }

View file

@ -19,6 +19,7 @@ test-support = [
[dependencies] [dependencies]
text = { path = "../text" } text = { path = "../text" }
copilot = { path = "../copilot" }
client = { path = "../client" } client = { path = "../client" }
clock = { path = "../clock" } clock = { path = "../clock" }
collections = { path = "../collections" } collections = { path = "../collections" }

View file

@ -12,6 +12,7 @@ use anyhow::{anyhow, Context, Result};
use client::{proto, Client, TypedEnvelope, UserStore}; use client::{proto, Client, TypedEnvelope, UserStore};
use clock::ReplicaId; use clock::ReplicaId;
use collections::{hash_map, BTreeMap, HashMap, HashSet}; use collections::{hash_map, BTreeMap, HashMap, HashSet};
use copilot::Copilot;
use futures::{ use futures::{
channel::mpsc::{self, UnboundedReceiver}, channel::mpsc::{self, UnboundedReceiver},
future::{try_join_all, Shared}, future::{try_join_all, Shared},
@ -129,6 +130,7 @@ pub struct Project {
_maintain_buffer_languages: Task<()>, _maintain_buffer_languages: Task<()>,
_maintain_workspace_config: Task<()>, _maintain_workspace_config: Task<()>,
terminals: Terminals, terminals: Terminals,
copilot_enabled: bool,
} }
enum BufferMessage { enum BufferMessage {
@ -472,6 +474,7 @@ impl Project {
terminals: Terminals { terminals: Terminals {
local_handles: Vec::new(), local_handles: Vec::new(),
}, },
copilot_enabled: Copilot::global(cx).is_some(),
} }
}) })
} }
@ -559,6 +562,7 @@ impl Project {
terminals: Terminals { terminals: Terminals {
local_handles: Vec::new(), local_handles: Vec::new(),
}, },
copilot_enabled: Copilot::global(cx).is_some(),
}; };
for worktree in worktrees { for worktree in worktrees {
let _ = this.add_worktree(&worktree, cx); let _ = this.add_worktree(&worktree, cx);
@ -664,6 +668,15 @@ impl Project {
self.start_language_server(worktree_id, worktree_path, language, cx); self.start_language_server(worktree_id, worktree_path, language, cx);
} }
if !self.copilot_enabled && Copilot::global(cx).is_some() {
self.copilot_enabled = true;
for buffer in self.opened_buffers.values() {
if let Some(buffer) = buffer.upgrade(cx) {
self.register_buffer_with_copilot(&buffer, cx);
}
}
}
cx.notify(); cx.notify();
} }
@ -1616,6 +1629,7 @@ impl Project {
self.detect_language_for_buffer(buffer, cx); self.detect_language_for_buffer(buffer, cx);
self.register_buffer_with_language_server(buffer, cx); self.register_buffer_with_language_server(buffer, cx);
self.register_buffer_with_copilot(buffer, cx);
cx.observe_release(buffer, |this, buffer, cx| { cx.observe_release(buffer, |this, buffer, cx| {
if let Some(file) = File::from_dyn(buffer.file()) { if let Some(file) = File::from_dyn(buffer.file()) {
if file.is_local() { if file.is_local() {
@ -1731,6 +1745,16 @@ impl Project {
}); });
} }
fn register_buffer_with_copilot(
&self,
buffer_handle: &ModelHandle<Buffer>,
cx: &mut ModelContext<Self>,
) {
if let Some(copilot) = Copilot::global(cx) {
copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx));
}
}
async fn send_buffer_messages( async fn send_buffer_messages(
this: WeakModelHandle<Self>, this: WeakModelHandle<Self>,
rx: UnboundedReceiver<BufferMessage>, rx: UnboundedReceiver<BufferMessage>,