Initial commit

This commit is contained in:
Alec Thilenius 2023-03-03 01:47:16 +00:00
commit ae7afcfbea
11 changed files with 473 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/target
/Cargo.lock

3
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"cSpell.words": ["codegen", "proto", "protobuf", "serde"]
}

3
Cargo.toml Normal file
View file

@ -0,0 +1,3 @@
[workspace]
resolver = "2"
members = ["axum-connect", "axum-connect-build", "axum-connect-examples"]

View file

@ -0,0 +1,11 @@
[package]
name = "axum-connect-build"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1.0"
convert_case = "0.6.0"
protobuf = { git = "https://github.com/AThilenius/rust-protobuf.git" }
protobuf-codegen = { git = "https://github.com/AThilenius/rust-protobuf.git" }
protobuf-parse = { git = "https://github.com/AThilenius/rust-protobuf.git" }

View file

@ -0,0 +1,153 @@
use std::path::Path;
use convert_case::{Case, Casing};
use protobuf::reflect::FileDescriptor;
use protobuf_codegen::{
gen::scope::{RootScope, WithScope},
Codegen,
};
use protobuf_parse::ProtobufAbsPath;
// TODO There is certainly a much easier way to do this, but I can't make sense of rust-protobuf.
pub fn axum_connect_codegen(
include: impl AsRef<Path>,
inputs: impl IntoIterator<Item = impl AsRef<Path>>,
) -> anyhow::Result<()> {
let results = Codegen::new()
.pure()
.cargo_out_dir("connect_proto_gen")
.inputs(inputs)
.include(include)
.run()?;
let file_descriptors =
FileDescriptor::new_dynamic_fds(results.parsed.file_descriptors.clone(), &[])?;
let root_scope = RootScope {
file_descriptors: &file_descriptors.as_slice(),
};
for path in results.parsed.relative_paths {
// Find the relative file descriptor
let file_descriptor = results
.parsed
.file_descriptors
.iter()
.find(|&fd| fd.name.clone().unwrap_or_default().ends_with(path.to_str()))
.expect(&format!(
"find a file descriptor matching the relative path {}",
path.to_str()
));
// TODO: This seems fragile.
let path = path.to_path().with_extension("rs");
let cargo_out_dir = std::env::var("OUT_DIR")?;
let out_dir = Path::new(&cargo_out_dir).join("connect_proto_gen");
let proto_rs_file_name = path.file_name().unwrap().to_str().unwrap();
let proto_rs_full_path = out_dir.join(&proto_rs_file_name);
// Replace all instances of "::protobuf::" with "::axum_connect::protobuf::" in the original
// generated file.
let rust = std::fs::read_to_string(&proto_rs_full_path)?;
let rust = rust.replace("::protobuf::", "::axum_connect::protobuf::");
// std::fs::write(&proto_rs_full_path, rust)?;
// Build up the service implementation file source.
let mut c = String::new();
c.push_str(FILE_PREAMBLE_TEMPLATE);
for service in &file_descriptor.service {
// Build up methods first
let mut m = String::new();
for method in &service.method {
let input_type = root_scope
.find_message(&ProtobufAbsPath {
path: method.input_type().to_string(),
})
.rust_name_with_file()
.to_path()
.to_string();
let output_type = root_scope
.find_message(&ProtobufAbsPath {
path: method.output_type().to_string(),
})
.rust_name_with_file()
.to_path()
.to_string();
m.push_str(
&METHOD_TEMPLATE
.replace("@@METHOD_NAME@@", &method.name().to_case(Case::Snake))
.replace("@@INPUT_TYPE@@", &input_type)
.replace("@@OUTPUT_TYPE@@", &output_type)
.replace(
"@@ROUTE@@",
&format!(
"/{}.{}/{}",
file_descriptor.package(),
service.name(),
method.name()
),
),
);
}
c.push_str(
&SERVICE_TEMPLATE
.replace("@@SERVICE_NAME@@", service.name())
.replace("@@SERVICE_METHODS@@", &m),
);
}
let mut final_file = String::new();
final_file.push_str(&rust);
final_file.push_str(&c);
std::fs::write(&proto_rs_full_path, &final_file)?;
}
Ok(())
}
const FILE_PREAMBLE_TEMPLATE: &str = "// Generated by axum-connect-build
use axum::{
body::HttpBody, extract::State, http::Request, response::IntoResponse, routing::post, BoxError,
Router,
};
use axum_connect::{HandlerFuture, RpcRouter};
";
const SERVICE_TEMPLATE: &str = "
pub struct @@SERVICE_NAME@@;
impl @@SERVICE_NAME@@ {
@@SERVICE_METHODS@@
}";
const METHOD_TEMPLATE: &str = "
pub fn @@METHOD_NAME@@<T, H, S, B>(handler: H) -> impl FnOnce(Router<S, B>) -> RpcRouter<S, B>
where
H: HandlerFuture<super::@@INPUT_TYPE@@, super::@@OUTPUT_TYPE@@, T, S, B>,
T: 'static,
S: Clone + Send + Sync + 'static,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
move |router: Router<S, B>| {
router.route(
\"@@ROUTE@@\",
post(|State(state): State<S>, request: Request<B>| async move {
let res = handler.call(request, state).await;
::axum_connect::protobuf_json_mapping::print_to_string(&res)
.unwrap()
.into_response()
}),
)
}
}
";

View file

@ -0,0 +1,12 @@
[package]
name = "hello_world"
version = "0.1.0"
edition = "2021"
[dependencies]
axum = "0.6.9"
axum-connect = { path = "../axum-connect" }
tokio = { version = "1.0", features = ["full"] }
[build-dependencies]
axum-connect-build = { path = "../axum-connect-build" }

View file

@ -0,0 +1,5 @@
use axum_connect_build::axum_connect_codegen;
fn main() {
axum_connect_codegen("proto", &["proto/hello.proto"]).unwrap();
}

View file

@ -0,0 +1,15 @@
syntax = "proto3";
package axum_connect.examples.hello_world;
message HelloRequest {
string name = 1;
}
message HelloResponse {
string message = 1;
}
service HelloWorldService {
rpc SayHello(HelloRequest) returns (HelloResponse) {}
}

View file

@ -0,0 +1,33 @@
use std::net::SocketAddr;
use axum::{extract::Host, Router};
use axum_connect::*;
use proto::hello::{HelloRequest, HelloResponse, HelloWorldService};
mod proto {
include!(concat!(env!("OUT_DIR"), "/connect_proto_gen/mod.rs"));
}
#[tokio::main]
async fn main() {
// Build our application with a route
let app = Router::new().rpc(HelloWorldService::say_hello(say_hello_handler));
// Run the Axum server.
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("listening on http://{}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
async fn say_hello_handler(Host(host): Host, request: HelloRequest) -> HelloResponse {
HelloResponse {
message: format!(
"Hello {}! You're addressing the hostname: {}.",
request.name, host
),
special_fields: Default::default(),
}
}

12
axum-connect/Cargo.toml Normal file
View file

@ -0,0 +1,12 @@
[package]
name = "axum-connect"
version = "0.1.0"
edition = "2021"
[dependencies]
axum = "0.6.9"
futures = "0.3.26"
protobuf = "3.2.0"
protobuf-json-mapping = "3.2.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

224
axum-connect/src/lib.rs Normal file
View file

@ -0,0 +1,224 @@
use std::pin::Pin;
use axum::{
body::{Body, HttpBody},
extract::{FromRequest, FromRequestParts},
http::{Request, StatusCode},
response::{IntoResponse, Response},
BoxError, Router,
};
use futures::Future;
use protobuf::MessageFull;
use serde::Serialize;
pub use protobuf;
pub use protobuf_json_mapping;
pub trait RpcRouterExt<S, B>: Sized {
fn rpc<F>(self, register: F) -> Self
where
F: FnOnce(Self) -> RpcRouter<S, B>;
}
impl<S, B> RpcRouterExt<S, B> for Router<S, B> {
fn rpc<F>(self, register: F) -> Self
where
F: FnOnce(Self) -> RpcRouter<S, B>,
{
register(self)
}
}
pub type RpcRouter<S, B> = Router<S, B>;
pub trait RegisterRpcService<S, B>: Sized {
fn register(self, router: Router<S, B>) -> Self;
}
pub trait IntoRpcResponse<T>
where
T: MessageFull,
{
fn into_response(self) -> Response;
}
#[derive(Clone, Serialize)]
pub struct RpcError {
pub code: RpcErrorCode,
pub message: String,
pub details: Vec<RpcErrorDetail>,
}
impl RpcError {
pub fn new(code: RpcErrorCode, message: String) -> Self {
Self {
code,
message,
details: vec![],
}
}
}
#[derive(Clone, Serialize)]
pub struct RpcErrorDetail {
#[serde(rename = "type")]
pub proto_type: String,
#[serde(rename = "value")]
pub proto_b62_value: String,
}
#[derive(Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum RpcErrorCode {
Canceled,
Unknown,
InvalidArgument,
DeadlineExceeded,
NotFound,
AlreadyExists,
PermissionDenied,
ResourceExhausted,
FailedPrecondition,
Aborted,
OutOfRange,
Unimplemented,
Internal,
Unavailable,
DataLoss,
Unauthenticated,
}
impl From<RpcErrorCode> for StatusCode {
fn from(val: RpcErrorCode) -> Self {
match val {
// Spec: https://connect.build/docs/protocol/#error-codes
RpcErrorCode::Canceled => StatusCode::REQUEST_TIMEOUT,
RpcErrorCode::Unknown => StatusCode::INTERNAL_SERVER_ERROR,
RpcErrorCode::InvalidArgument => StatusCode::BAD_REQUEST,
RpcErrorCode::DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
RpcErrorCode::NotFound => StatusCode::NOT_FOUND,
RpcErrorCode::AlreadyExists => StatusCode::CONFLICT,
RpcErrorCode::PermissionDenied => StatusCode::FORBIDDEN,
RpcErrorCode::ResourceExhausted => StatusCode::TOO_MANY_REQUESTS,
RpcErrorCode::FailedPrecondition => StatusCode::PRECONDITION_FAILED,
RpcErrorCode::Aborted => StatusCode::CONFLICT,
RpcErrorCode::OutOfRange => StatusCode::BAD_REQUEST,
RpcErrorCode::Unimplemented => StatusCode::NOT_FOUND,
RpcErrorCode::Internal => StatusCode::INTERNAL_SERVER_ERROR,
RpcErrorCode::Unavailable => StatusCode::SERVICE_UNAVAILABLE,
RpcErrorCode::DataLoss => StatusCode::INTERNAL_SERVER_ERROR,
RpcErrorCode::Unauthenticated => StatusCode::UNAUTHORIZED,
}
}
}
impl IntoResponse for RpcError {
fn into_response(self) -> Response {
let status_code = StatusCode::from(self.code.clone());
let json = serde_json::to_string(&self).expect("serialize error type");
(status_code, json).into_response()
}
}
impl<T, E> IntoRpcResponse<T> for Result<T, E>
where
T: MessageFull,
E: Into<RpcError>,
{
fn into_response(self) -> Response {
match self {
Ok(res) => rpc_to_response(res),
Err(err) => err.into().into_response(),
}
}
}
pub trait HandlerFuture<TReq, TRes, T, S, B = Body>: Clone + Send + Sized + 'static {
type Future: Future<Output = TRes> + Send + 'static;
fn call(self, req: Request<B>, state: S) -> Self::Future;
}
fn rpc_to_response<T>(res: T) -> Response
where
T: MessageFull,
{
protobuf_json_mapping::print_to_string(&res)
.map_err(|_e| {
RpcError::new(
RpcErrorCode::Internal,
"Failed to serialize response".to_string(),
)
})
.into_response()
}
macro_rules! impl_handler {
(
[$($ty:ident),*]
) => {
#[allow(unused_parens, non_snake_case, unused_mut)]
impl<TReq, TRes, F, Fut, S, B, $($ty,)*> HandlerFuture<TReq, TRes, ($($ty,)* TReq), S, B> for F
where
TReq: MessageFull + Send + 'static,
TRes: MessageFull + Send + 'static,
F: FnOnce($($ty,)* TReq) -> Fut + Clone + Send + 'static,
Fut: Future<Output = TRes> + Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync + 'static,
$( $ty: FromRequestParts<S> + Send, )*
{
type Future = Pin<Box<dyn Future<Output = TRes> + Send>>;
fn call(self, req: Request<B>, state: S) -> Self::Future {
Box::pin(async move {
let (mut parts, body) = req.into_parts();
let state = &state;
// This would be done by macro expansion. It also wouldn't be unwrapped, but
// there is no error union so I can't return a rejection.
$(
let $ty = match $ty::from_request_parts(&mut parts, state).await {
Ok(value) => value,
Err(_e) => unreachable!(),
};
)*
let req = Request::from_parts(parts, body);
let body = match String::from_request(req, state).await {
Ok(value) => value,
Err(_e) => unreachable!(),
};
let proto_req: TReq = match protobuf_json_mapping::parse_from_str(&body) {
Ok(value) => value,
Err(_e) => unreachable!(),
};
let res = self($($ty,)* proto_req).await;
res
})
}
}
};
}
impl_handler!([]);
impl_handler!([T1]);
impl_handler!([T1, T2]);
impl_handler!([T1, T2, T3]);
impl_handler!([T1, T2, T3, T4]);
impl_handler!([T1, T2, T3, T4, T5]);
impl_handler!([T1, T2, T3, T4, T5, T6]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]);