From 672cf6b8c789a9882effab185ad120e6adff42ca Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 19 Apr 2023 12:19:24 +0200 Subject: [PATCH] Relay buffer change events to Copilot --- Cargo.lock | 1 + crates/copilot/src/copilot.rs | 338 ++++++++++++++++++++++++---------- crates/copilot/src/request.rs | 3 - crates/project/Cargo.toml | 1 + crates/project/src/project.rs | 24 +++ 5 files changed, 268 insertions(+), 99 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f05512b76..bb931853fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4687,6 +4687,7 @@ dependencies = [ "client", "clock", "collections", + "copilot", "ctor", "db", "env_logger", diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 1967c3cd14..57abd08939 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -6,8 +6,13 @@ use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use collections::HashMap; use futures::{future::Shared, Future, FutureExt, TryFutureExt}; -use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; -use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16}; +use gpui::{ + actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle, +}; +use language::{ + point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, + ToPointUtf16, +}; use log::{debug, error}; use lsp::LanguageServer; use node_runtime::NodeRuntime; @@ -105,7 +110,7 @@ enum CopilotServer { Started { server: Arc, status: SignInStatus, - subscriptions_by_buffer_id: HashMap, + registered_buffers: HashMap, }, } @@ -141,6 +146,66 @@ impl Status { } } +struct RegisteredBuffer { + uri: lsp::Url, + snapshot: Option<(i32, BufferSnapshot)>, + _subscriptions: [gpui::Subscription; 2], +} + +impl RegisteredBuffer { + fn report_changes( + &mut self, + buffer: &ModelHandle, + server: &LanguageServer, + cx: &AppContext, + ) -> Result<(i32, BufferSnapshot)> { + let buffer = buffer.read(cx); + let (version, prev_snapshot) = self + .snapshot + .as_ref() + .ok_or_else(|| anyhow!("expected at least one snapshot"))?; + let next_snapshot = buffer.snapshot(); + + let content_changes = buffer + .edits_since::<(PointUtf16, usize)>(prev_snapshot.version()) + .map(|edit| { + let edit_start = edit.new.start.0; + let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); + let new_text = next_snapshot + .text_for_range(edit.new.start.1..edit.new.end.1) + .collect(); + lsp::TextDocumentContentChangeEvent { + range: Some(lsp::Range::new( + point_to_lsp(edit_start), + point_to_lsp(edit_end), + )), + range_length: None, + text: new_text, + } + }) + .collect::>(); + + if content_changes.is_empty() { + Ok((*version, prev_snapshot.clone())) + } else { + let next_version = version + 1; + self.snapshot = Some((next_version, next_snapshot.clone())); + + server.notify::( + lsp::DidChangeTextDocumentParams { + text_document: lsp::VersionedTextDocumentIdentifier::new( + self.uri.clone(), + next_version, + ), + content_changes, + }, + )?; + + Ok((next_version, next_snapshot)) + } + } +} + #[derive(Debug, PartialEq, Eq)] pub struct Completion { pub range: Range, @@ -151,6 +216,7 @@ pub struct Copilot { http: Arc, node_runtime: Arc, server: CopilotServer, + buffers: HashMap>, } impl Entity for Copilot { @@ -212,12 +278,14 @@ impl Copilot { http, node_runtime, server: CopilotServer::Starting { task: start_task }, + buffers: Default::default(), } } else { Self { http, node_runtime, server: CopilotServer::Disabled, + buffers: Default::default(), } } } @@ -233,8 +301,9 @@ impl Copilot { server: CopilotServer::Started { server: Arc::new(server), status: SignInStatus::Authorized, - subscriptions_by_buffer_id: Default::default(), + registered_buffers: Default::default(), }, + buffers: Default::default(), }); (this, fake_server) } @@ -297,7 +366,7 @@ impl Copilot { this.server = CopilotServer::Started { server, status: SignInStatus::SignedOut, - subscriptions_by_buffer_id: Default::default(), + registered_buffers: Default::default(), }; this.update_sign_in_status(status, cx); } @@ -396,10 +465,8 @@ impl Copilot { } fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { - if let CopilotServer::Started { server, status, .. } = &mut self.server { - *status = SignInStatus::SignedOut; - cx.notify(); - + self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); + if let CopilotServer::Started { server, .. } = &self.server { let server = server.clone(); cx.background().spawn(async move { server @@ -433,6 +500,108 @@ impl Copilot { cx.foreground().spawn(start_task) } + pub fn register_buffer(&mut self, buffer: &ModelHandle, cx: &mut ModelContext) { + let buffer_id = buffer.id(); + self.buffers.insert(buffer_id, buffer.downgrade()); + + if let CopilotServer::Started { + server, + status, + registered_buffers, + .. + } = &mut self.server + { + if !matches!(status, SignInStatus::Authorized { .. }) { + return; + } + + let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap(); + registered_buffers.entry(buffer.id()).or_insert_with(|| { + let snapshot = buffer.read(cx).snapshot(); + server + .notify::( + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem { + uri: uri.clone(), + language_id: id_for_language(buffer.read(cx).language()), + version: 0, + text: snapshot.text(), + }, + }, + ) + .log_err(); + + RegisteredBuffer { + uri, + snapshot: Some((0, snapshot)), + _subscriptions: [ + cx.subscribe(buffer, |this, buffer, event, cx| { + this.handle_buffer_event(buffer, event, cx).log_err(); + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + this.buffers.remove(&buffer_id); + this.unregister_buffer(buffer_id); + }), + ], + } + }); + } + } + + fn handle_buffer_event( + &mut self, + buffer: ModelHandle, + event: &language::Event, + cx: &mut ModelContext, + ) -> Result<()> { + if let CopilotServer::Started { + server, + registered_buffers, + .. + } = &mut self.server + { + if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) { + match event { + language::Event::Edited => { + registered_buffer.report_changes(&buffer, server, cx)?; + } + language::Event::Saved => { + server.notify::( + lsp::DidSaveTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new( + registered_buffer.uri.clone(), + ), + text: None, + }, + )?; + } + _ => {} + } + } + } + + Ok(()) + } + + fn unregister_buffer(&mut self, buffer_id: usize) { + if let CopilotServer::Started { + server, + registered_buffers, + .. + } = &mut self.server + { + if let Some(buffer) = registered_buffers.remove(&buffer_id) { + server + .notify::( + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer.uri), + }, + ) + .log_err(); + } + } + } + pub fn completions( &mut self, buffer: &ModelHandle, @@ -464,16 +633,14 @@ impl Copilot { cx: &mut ModelContext, ) -> Task>> where - R: lsp::request::Request< - Params = request::GetCompletionsParams, - Result = request::GetCompletionsResult, - >, + R: 'static + + lsp::request::Request< + Params = request::GetCompletionsParams, + Result = request::GetCompletionsResult, + >, T: ToPointUtf16, { - let buffer_id = buffer.id(); - let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap(); - let snapshot = buffer.read(cx).snapshot(); - let server = match &mut self.server { + let (server, registered_buffer) = match &mut self.server { CopilotServer::Starting { .. } => { return Task::ready(Err(anyhow!("copilot is still starting"))) } @@ -487,56 +654,28 @@ impl Copilot { CopilotServer::Started { server, status, - subscriptions_by_buffer_id, + registered_buffers, + .. } => { if matches!(status, SignInStatus::Authorized { .. }) { - subscriptions_by_buffer_id - .entry(buffer_id) - .or_insert_with(|| { - server - .notify::( - lsp::DidOpenTextDocumentParams { - text_document: lsp::TextDocumentItem { - uri: uri.clone(), - language_id: id_for_language( - buffer.read(cx).language(), - ), - version: 0, - text: snapshot.text(), - }, - }, - ) - .log_err(); - - let uri = uri.clone(); - cx.observe_release(buffer, move |this, _, _| { - if let CopilotServer::Started { - server, - subscriptions_by_buffer_id, - .. - } = &mut this.server - { - server - .notify::( - lsp::DidCloseTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new( - uri.clone(), - ), - }, - ) - .log_err(); - subscriptions_by_buffer_id.remove(&buffer_id); - } - }) - }); - - server.clone() + if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) { + (server.clone(), registered_buffer) + } else { + return Task::ready(Err(anyhow!( + "requested completions for an unregistered buffer" + ))); + } } else { return Task::ready(Err(anyhow!("must sign in before using copilot"))); } } }; + let (version, snapshot) = match registered_buffer.report_changes(buffer, &server, cx) { + Ok((version, snapshot)) => (version, snapshot), + Err(error) => return Task::ready(Err(error)), + }; + let uri = registered_buffer.uri.clone(); let settings = cx.global::(); let position = position.to_point_utf16(&snapshot); let language = snapshot.language_at(position); @@ -544,39 +683,23 @@ impl Copilot { let language_name = language_name.as_deref(); let tab_size = settings.tab_size(language_name); let hard_tabs = settings.hard_tabs(language_name); - let language_id = id_for_language(language); - - let path; - let relative_path; - if let Some(file) = snapshot.file() { - if let Some(file) = file.as_local() { - path = file.abs_path(cx); - } else { - path = file.full_path(cx); - } - relative_path = file.path().to_path_buf(); - } else { - path = PathBuf::new(); - relative_path = PathBuf::new(); - } - + let relative_path = snapshot + .file() + .map(|file| file.path().to_path_buf()) + .unwrap_or_default(); + let request = server.request::(request::GetCompletionsParams { + doc: request::GetCompletionsDocument { + uri, + tab_size: tab_size.into(), + indent_size: 1, + insert_spaces: !hard_tabs, + relative_path: relative_path.to_string_lossy().into(), + position: point_to_lsp(position), + version: version.try_into().unwrap(), + }, + }); cx.background().spawn(async move { - let result = server - .request::(request::GetCompletionsParams { - doc: request::GetCompletionsDocument { - source: snapshot.text(), - tab_size: tab_size.into(), - indent_size: 1, - insert_spaces: !hard_tabs, - uri, - path: path.to_string_lossy().into(), - relative_path: relative_path.to_string_lossy().into(), - language_id, - position: point_to_lsp(position), - version: 0, - }, - }) - .await?; + let result = request.await?; let completions = result .completions .into_iter() @@ -616,14 +739,37 @@ impl Copilot { lsp_status: request::SignInStatus, cx: &mut ModelContext, ) { + self.buffers.retain(|_, buffer| buffer.is_upgradable(cx)); + if let CopilotServer::Started { status, .. } = &mut self.server { - *status = match lsp_status { + match lsp_status { request::SignInStatus::Ok { .. } | request::SignInStatus::MaybeOk { .. } - | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized, - request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized, - request::SignInStatus::NotSignedIn => SignInStatus::SignedOut, - }; + | request::SignInStatus::AlreadySignedIn { .. } => { + *status = SignInStatus::Authorized; + + for buffer in self.buffers.values().cloned().collect::>() { + if let Some(buffer) = buffer.upgrade(cx) { + self.register_buffer(&buffer, cx); + } + } + } + request::SignInStatus::NotAuthorized { .. } => { + *status = SignInStatus::Unauthorized; + + for buffer_id in self.buffers.keys().copied().collect::>() { + self.unregister_buffer(buffer_id); + } + } + request::SignInStatus::NotSignedIn => { + *status = SignInStatus::SignedOut; + + for buffer_id in self.buffers.keys().copied().collect::>() { + self.unregister_buffer(buffer_id); + } + } + } + cx.notify(); } } diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 415f160ea3..08173c413a 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -99,14 +99,11 @@ pub struct GetCompletionsParams { #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GetCompletionsDocument { - pub source: String, pub tab_size: u32, pub indent_size: u32, pub insert_spaces: bool, pub uri: lsp::Url, - pub path: String, pub relative_path: String, - pub language_id: String, pub position: lsp::Position, pub version: usize, } diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index f5c144a3ad..e30ab56e45 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -19,6 +19,7 @@ test-support = [ [dependencies] text = { path = "../text" } +copilot = { path = "../copilot" } client = { path = "../client" } clock = { path = "../clock" } collections = { path = "../collections" } diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index d126cb4994..d5b7ac3f3f 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -12,6 +12,7 @@ use anyhow::{anyhow, Context, Result}; use client::{proto, Client, TypedEnvelope, UserStore}; use clock::ReplicaId; use collections::{hash_map, BTreeMap, HashMap, HashSet}; +use copilot::Copilot; use futures::{ channel::mpsc::{self, UnboundedReceiver}, future::{try_join_all, Shared}, @@ -129,6 +130,7 @@ pub struct Project { _maintain_buffer_languages: Task<()>, _maintain_workspace_config: Task<()>, terminals: Terminals, + copilot_enabled: bool, } enum BufferMessage { @@ -472,6 +474,7 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, + copilot_enabled: Copilot::global(cx).is_some(), } }) } @@ -559,6 +562,7 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, + copilot_enabled: Copilot::global(cx).is_some(), }; for worktree in worktrees { let _ = this.add_worktree(&worktree, cx); @@ -664,6 +668,15 @@ impl Project { self.start_language_server(worktree_id, worktree_path, language, cx); } + if !self.copilot_enabled && Copilot::global(cx).is_some() { + self.copilot_enabled = true; + for buffer in self.opened_buffers.values() { + if let Some(buffer) = buffer.upgrade(cx) { + self.register_buffer_with_copilot(&buffer, cx); + } + } + } + cx.notify(); } @@ -1616,6 +1629,7 @@ impl Project { self.detect_language_for_buffer(buffer, cx); self.register_buffer_with_language_server(buffer, cx); + self.register_buffer_with_copilot(buffer, cx); cx.observe_release(buffer, |this, buffer, cx| { if let Some(file) = File::from_dyn(buffer.file()) { if file.is_local() { @@ -1731,6 +1745,16 @@ impl Project { }); } + fn register_buffer_with_copilot( + &self, + buffer_handle: &ModelHandle, + cx: &mut ModelContext, + ) { + if let Some(copilot) = Copilot::global(cx) { + copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); + } + } + async fn send_buffer_messages( this: WeakModelHandle, rx: UnboundedReceiver,