diff --git a/Cargo.lock b/Cargo.lock index a5886a9466..202b511a6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1340,6 +1340,7 @@ dependencies = [ "client", "futures 0.3.25", "gpui", + "language", "log", "lsp", "serde", diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 301051a3b0..190a399475 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -10,6 +10,7 @@ doctest = false [dependencies] gpui = { path = "../gpui" } +language = { path = "../language" } settings = { path = "../settings" } lsp = { path = "../lsp" } util = { path = "../util" } diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index c768442611..4abdef0ab4 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -4,6 +4,7 @@ use anyhow::{anyhow, Result}; use async_compression::futures::bufread::GzipDecoder; use client::Client; use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task}; +use language::{Buffer, ToPointUtf16}; use lsp::LanguageServer; use smol::{fs, io::BufReader, stream::StreamExt}; use std::{ @@ -38,7 +39,7 @@ pub fn init(client: Arc, cx: &mut MutableAppContext) { enum CopilotServer { Downloading, - Error(String), + Error(Arc), Started { server: Arc, status: SignInStatus, @@ -59,6 +60,21 @@ pub enum Event { }, } +#[derive(Debug)] +pub enum Status { + Downloading, + Error(Arc), + SignedOut, + Unauthorized, + Authorized, +} + +impl Status { + fn is_authorized(&self) -> bool { + matches!(self, Status::Authorized) + } +} + struct Copilot { server: CopilotServer, } @@ -70,7 +86,12 @@ impl Entity for Copilot { impl Copilot { fn global(cx: &AppContext) -> Option> { if cx.has_global::>() { - Some(cx.global::>().clone()) + let copilot = cx.global::>().clone(); + if copilot.read(cx).status().is_authorized() { + Some(copilot) + } else { + None + } } else { None } @@ -103,7 +124,7 @@ impl Copilot { this.update_sign_in_status(status, cx); } Err(error) => { - this.server = CopilotServer::Error(error.to_string()); + this.server = CopilotServer::Error(error.to_string().into()); } } }) @@ -163,6 +184,35 @@ impl Copilot { } } + pub fn completions( + &self, + buffer: &ModelHandle, + position: T, + cx: &mut ModelContext, + ) -> Task> + where + T: ToPointUtf16, + { + let server = match self.authenticated_server() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; + + cx.spawn(|this, cx| async move { anyhow::Ok(()) }) + } + + pub fn status(&self) -> Status { + match &self.server { + CopilotServer::Downloading => Status::Downloading, + CopilotServer::Error(error) => Status::Error(error.clone()), + CopilotServer::Started { status, .. } => match status { + SignInStatus::Authorized { .. } => Status::Authorized, + SignInStatus::Unauthorized { .. } => Status::Unauthorized, + SignInStatus::SignedOut => Status::SignedOut, + }, + } + } + fn update_sign_in_status( &mut self, lsp_status: request::SignInStatus, @@ -181,6 +231,23 @@ impl Copilot { cx.notify(); } } + + fn authenticated_server(&self) -> Result> { + match &self.server { + CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")), + CopilotServer::Error(error) => Err(anyhow!( + "copilot was not started because of an error: {}", + error + )), + CopilotServer::Started { server, status } => { + if matches!(status, SignInStatus::Authorized { .. }) { + Ok(server.clone()) + } else { + Err(anyhow!("must sign in before using copilot")) + } + } + } + } } async fn get_lsp_binary(http: Arc) -> anyhow::Result {