From 59d9277a74f32e9b2871fd09de05a9495e732247 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 23 Mar 2023 14:17:15 +0100 Subject: [PATCH] Implement Copilot sign in and sign out --- Cargo.lock | 1 + crates/copilot/Cargo.toml | 1 + crates/copilot/src/copilot.rs | 186 ++++++++++++++++++++++++++-------- crates/copilot/src/request.rs | 83 +++++++++++++-- 4 files changed, 223 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 028e20d984..a5886a9466 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1340,6 +1340,7 @@ dependencies = [ "client", "futures 0.3.25", "gpui", + "log", "lsp", "serde", "serde_derive", diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 4e1339ee50..301051a3b0 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -17,6 +17,7 @@ client = { path = "../client" } workspace = { path = "../workspace" } async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] } anyhow = "1.0" +log = "0.4" serde = { workspace = true } serde_derive = { workspace = true } smol = "1.2.5" diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 22d1246550..8bea11b1f7 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -3,7 +3,7 @@ mod request; use anyhow::{anyhow, Result}; use async_compression::futures::bufread::GzipDecoder; use client::Client; -use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext}; +use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task}; use lsp::LanguageServer; use smol::{fs, io::BufReader, stream::StreamExt}; use std::{ @@ -15,11 +15,32 @@ use util::{ fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt, }; -actions!(copilot, [SignIn]); +actions!(copilot, [SignIn, SignOut]); pub fn init(client: Arc, cx: &mut MutableAppContext) { - let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx)); + let (copilot, task) = Copilot::start(client.http_client(), cx); cx.set_global(copilot); + cx.spawn(|mut cx| async move { + task.await?; + cx.update(|cx| { + cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| { + if let Some(copilot) = Copilot::global(cx) { + copilot + .update(cx, |copilot, cx| copilot.sign_in(cx)) + .detach_and_log_err(cx); + } + }); + cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| { + if let Some(copilot) = Copilot::global(cx) { + copilot + .update(cx, |copilot, cx| copilot.sign_out(cx)) + .detach_and_log_err(cx); + } + }); + }); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); } enum CopilotServer { @@ -31,18 +52,26 @@ enum CopilotServer { }, } +#[derive(Clone, Debug, PartialEq, Eq)] enum SignInStatus { - Authorized, - Unauthorized, + Authorized { user: String }, + Unauthorized { user: String }, SignedOut, } +pub enum Event { + PromptUserDeviceFlow { + user_code: String, + verification_uri: String, + }, +} + struct Copilot { server: CopilotServer, } impl Entity for Copilot { - type Event = (); + type Event = Event; } impl Copilot { @@ -54,46 +83,123 @@ impl Copilot { } } - fn start(http: Arc, cx: &mut ModelContext) -> Self { - let copilot = Self { + fn start( + http: Arc, + cx: &mut MutableAppContext, + ) -> (ModelHandle, Task>) { + let this = cx.add_model(|_| Self { server: CopilotServer::Downloading, - }; - cx.spawn(|this, mut cx| async move { - let start_language_server = async { - let server_path = get_lsp_binary(http).await?; - let server = - LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?; - let server = server.initialize(Default::default()).await?; - let status = server - .request::(request::CheckStatusParams { - local_checks_only: false, - }) - .await?; - let status = match status.status.as_str() { - "OK" | "MaybeOk" => SignInStatus::Authorized, - "NotAuthorized" => SignInStatus::Unauthorized, - _ => SignInStatus::SignedOut, + }); + let task = cx.spawn({ + let this = this.clone(); + |mut cx| async move { + let start_language_server = async { + let server_path = get_lsp_binary(http).await?; + let server = LanguageServer::new( + 0, + &server_path, + &["--stdio"], + Path::new("/"), + cx.clone(), + )?; + let server = server.initialize(Default::default()).await?; + let status = server + .request::(request::CheckStatusParams { + local_checks_only: false, + }) + .await?; + anyhow::Ok((server, status)) }; - anyhow::Ok((server, status)) - }; - let server = start_language_server.await; - this.update(&mut cx, |this, cx| { - cx.notify(); - match server { - Ok((server, status)) => { - this.server = CopilotServer::Started { server, status }; - Ok(()) - } - Err(error) => { - this.server = CopilotServer::Error(error.to_string()); - Err(error) + let server = start_language_server.await; + this.update(&mut cx, |this, cx| { + cx.notify(); + match server { + Ok((server, status)) => { + this.server = CopilotServer::Started { + server, + status: SignInStatus::SignedOut, + }; + this.update_sign_in_status(status, cx); + Ok(()) + } + Err(error) => { + this.server = CopilotServer::Error(error.to_string()); + Err(error) + } } + }) + } + }); + (this, task) + } + + fn sign_in(&mut self, cx: &mut ModelContext) -> Task> { + if let CopilotServer::Started { server, .. } = &self.server { + let server = server.clone(); + cx.spawn(|this, mut cx| async move { + let sign_in = server + .request::(request::SignInInitiateParams {}) + .await?; + if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in { + this.update(&mut cx, |_, cx| { + cx.emit(Event::PromptUserDeviceFlow { + user_code: flow.user_code.clone(), + verification_uri: flow.verification_uri, + }); + }); + let response = server + .request::(request::SignInConfirmParams { + user_code: flow.user_code, + }) + .await?; + this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx)); } + anyhow::Ok(()) }) - }) - .detach_and_log_err(cx); - copilot + } else { + Task::ready(Err(anyhow!("copilot hasn't started yet"))) + } + } + + fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { + if let CopilotServer::Started { server, .. } = &self.server { + let server = server.clone(); + cx.spawn(|this, mut cx| async move { + server + .request::(request::SignOutParams {}) + .await?; + this.update(&mut cx, |this, cx| { + if let CopilotServer::Started { status, .. } = &mut this.server { + *status = SignInStatus::SignedOut; + cx.notify(); + } + }); + + anyhow::Ok(()) + }) + } else { + Task::ready(Err(anyhow!("copilot hasn't started yet"))) + } + } + + fn update_sign_in_status( + &mut self, + lsp_status: request::SignInStatus, + cx: &mut ModelContext, + ) { + if let CopilotServer::Started { status, .. } = &mut self.server { + *status = match lsp_status { + request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => { + SignInStatus::Authorized { user } + } + request::SignInStatus::NotAuthorized { user } => { + SignInStatus::Unauthorized { user } + } + _ => SignInStatus::SignedOut, + }; + cx.notify(); + } } } diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 3f1f66482e..1b02227273 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -8,15 +8,82 @@ pub struct CheckStatusParams { pub local_checks_only: bool, } -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CheckStatusResult { - pub status: String, - pub user: Option, -} - impl lsp::request::Request for CheckStatus { type Params = CheckStatusParams; - type Result = CheckStatusResult; + type Result = SignInStatus; const METHOD: &'static str = "checkStatus"; } + +pub enum SignInInitiate {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SignInInitiateParams {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "status")] +pub enum SignInInitiateResult { + AlreadySignedIn { user: String }, + PromptUserDeviceFlow(PromptUserDeviceFlow), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptUserDeviceFlow { + pub user_code: String, + pub verification_uri: String, +} + +impl lsp::request::Request for SignInInitiate { + type Params = SignInInitiateParams; + type Result = SignInInitiateResult; + const METHOD: &'static str = "signInInitiate"; +} + +pub enum SignInConfirm {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignInConfirmParams { + pub user_code: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "status")] +pub enum SignInStatus { + #[serde(rename = "OK")] + Ok { + user: String, + }, + MaybeOk { + user: String, + }, + AlreadySignedIn { + user: String, + }, + NotAuthorized { + user: String, + }, + NotSignedIn, +} + +impl lsp::request::Request for SignInConfirm { + type Params = SignInConfirmParams; + type Result = SignInStatus; + const METHOD: &'static str = "signInConfirm"; +} + +pub enum SignOut {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignOutParams {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignOutResult {} + +impl lsp::request::Request for SignOut { + type Params = SignOutParams; + type Result = SignOutResult; + const METHOD: &'static str = "signOut"; +}