diff --git a/zed-rpc/src/auth.rs b/zed-rpc/src/auth.rs index 9136c754e3..5254ef9ca4 100644 --- a/zed-rpc/src/auth.rs +++ b/zed-rpc/src/auth.rs @@ -1,8 +1,7 @@ -use std::convert::{TryFrom, TryInto}; - use anyhow::{Context, Result}; -use rand::{rngs::OsRng, Rng as _}; +use rand::{thread_rng, Rng as _}; use rsa::{PublicKey as _, PublicKeyEncoding, RSAPrivateKey, RSAPublicKey}; +use std::convert::TryFrom; pub struct PublicKey(RSAPublicKey); @@ -10,7 +9,7 @@ pub struct PrivateKey(RSAPrivateKey); /// Generate a public and private key for asymmetric encryption. pub fn keypair() -> Result<(PublicKey, PrivateKey)> { - let mut rng = OsRng; + let mut rng = thread_rng(); let bits = 1024; let private_key = RSAPrivateKey::new(&mut rng, bits)?; let public_key = RSAPublicKey::from(&private_key); @@ -19,25 +18,25 @@ pub fn keypair() -> Result<(PublicKey, PrivateKey)> { /// Generate a random 64-character base64 string. pub fn random_token() -> String { - let mut rng = OsRng; + let mut rng = thread_rng(); let mut token_bytes = [0; 48]; for byte in token_bytes.iter_mut() { *byte = rng.gen(); } - base64::encode(&token_bytes) + base64::encode_config(&token_bytes, base64::URL_SAFE) } impl PublicKey { /// Convert a string to a base64-encoded string that can only be decoded with the corresponding /// private key. pub fn encrypt_string(&self, string: &str) -> Result { - let mut rng = OsRng; + let mut rng = thread_rng(); let bytes = string.as_bytes(); let encrypted_bytes = self .0 .encrypt(&mut rng, PADDING_SCHEME, bytes) .context("failed to encrypt string with public key")?; - let encrypted_string = base64::encode(&encrypted_bytes); + let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE); Ok(encrypted_string) } } @@ -45,8 +44,8 @@ impl PublicKey { impl PrivateKey { /// Decrypt a base64-encoded string that was encrypted by the correspoding public key. pub fn decrypt_string(&self, encrypted_string: &str) -> Result { - let encrypted_bytes = - base64::decode(encrypted_string).context("failed to base64-decode encrypted string")?; + let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE) + .context("failed to base64-decode encrypted string")?; let bytes = self .0 .decrypt(PADDING_SCHEME, &encrypted_bytes) @@ -56,14 +55,11 @@ impl PrivateKey { } } -impl TryInto for PublicKey { +impl TryFrom for String { type Error = anyhow::Error; - fn try_into(self) -> Result { - let bytes = self - .0 - .to_pkcs1() - .context("failed to serialize public key")?; - let string = base64::encode(&bytes); + fn try_from(key: PublicKey) -> Result { + let bytes = key.0.to_pkcs1().context("failed to serialize public key")?; + let string = base64::encode_config(&bytes, base64::URL_SAFE); Ok(string) } } @@ -71,7 +67,8 @@ impl TryInto for PublicKey { impl TryFrom for PublicKey { type Error = anyhow::Error; fn try_from(value: String) -> Result { - let bytes = base64::decode(&value).context("failed to base64-decode public key string")?; + let bytes = base64::decode_config(&value, base64::URL_SAFE) + .context("failed to base64-decode public key string")?; let key = Self(RSAPublicKey::from_pkcs1(&bytes).context("failed to parse public key")?); Ok(key) } @@ -89,13 +86,14 @@ mod tests { // * generate a keypair for asymmetric encryption // * serialize the public key to send it to the server. let (public, private) = keypair().unwrap(); - let public_string: String = public.try_into().unwrap(); + let public_string = String::try_from(public).unwrap(); + assert_printable(&public_string); // SERVER: // * parse the public key // * generate a random token. // * encrypt the token using the public key. - let public: PublicKey = public_string.try_into().unwrap(); + let public = PublicKey::try_from(public_string).unwrap(); let token = random_token(); let encrypted_token = public.encrypt_string(&token).unwrap(); assert_eq!(token.len(), 64); @@ -109,6 +107,20 @@ mod tests { assert_eq!(decrypted_token, token); } + #[test] + fn test_tokens_are_always_url_safe() { + for _ in 0..5 { + let token = random_token(); + let (public_key, _) = keypair().unwrap(); + let encrypted_token = public_key.encrypt_string(&token).unwrap(); + let public_key_str = String::try_from(public_key).unwrap(); + + assert_printable(&token); + assert_printable(&public_key_str); + assert_printable(&encrypted_token); + } + } + fn assert_printable(token: &str) { for c in token.chars() { assert!( @@ -117,6 +129,8 @@ mod tests { token, c ); + assert_ne!(c, '/', "token {:?} is not URL-safe", token); + assert_ne!(c, '&', "token {:?} is not URL-safe", token); } } }