commit ae7afcfbea9bcaecac315371135e027ca425a737 Author: Alec Thilenius Date: Fri Mar 3 01:47:16 2023 +0000 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..a698d54 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "cSpell.words": ["codegen", "proto", "protobuf", "serde"] +} diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8cc3417 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,3 @@ +[workspace] +resolver = "2" +members = ["axum-connect", "axum-connect-build", "axum-connect-examples"] diff --git a/axum-connect-build/Cargo.toml b/axum-connect-build/Cargo.toml new file mode 100644 index 0000000..1427ba5 --- /dev/null +++ b/axum-connect-build/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "axum-connect-build" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0" +convert_case = "0.6.0" +protobuf = { git = "https://github.com/AThilenius/rust-protobuf.git" } +protobuf-codegen = { git = "https://github.com/AThilenius/rust-protobuf.git" } +protobuf-parse = { git = "https://github.com/AThilenius/rust-protobuf.git" } diff --git a/axum-connect-build/src/lib.rs b/axum-connect-build/src/lib.rs new file mode 100644 index 0000000..842ae7e --- /dev/null +++ b/axum-connect-build/src/lib.rs @@ -0,0 +1,153 @@ +use std::path::Path; + +use convert_case::{Case, Casing}; +use protobuf::reflect::FileDescriptor; +use protobuf_codegen::{ + gen::scope::{RootScope, WithScope}, + Codegen, +}; +use protobuf_parse::ProtobufAbsPath; + +// TODO There is certainly a much easier way to do this, but I can't make sense of rust-protobuf. +pub fn axum_connect_codegen( + include: impl AsRef, + inputs: impl IntoIterator>, +) -> anyhow::Result<()> { + let results = Codegen::new() + .pure() + .cargo_out_dir("connect_proto_gen") + .inputs(inputs) + .include(include) + .run()?; + + let file_descriptors = + FileDescriptor::new_dynamic_fds(results.parsed.file_descriptors.clone(), &[])?; + + let root_scope = RootScope { + file_descriptors: &file_descriptors.as_slice(), + }; + + for path in results.parsed.relative_paths { + // Find the relative file descriptor + let file_descriptor = results + .parsed + .file_descriptors + .iter() + .find(|&fd| fd.name.clone().unwrap_or_default().ends_with(path.to_str())) + .expect(&format!( + "find a file descriptor matching the relative path {}", + path.to_str() + )); + + // TODO: This seems fragile. + let path = path.to_path().with_extension("rs"); + let cargo_out_dir = std::env::var("OUT_DIR")?; + let out_dir = Path::new(&cargo_out_dir).join("connect_proto_gen"); + let proto_rs_file_name = path.file_name().unwrap().to_str().unwrap(); + let proto_rs_full_path = out_dir.join(&proto_rs_file_name); + + // Replace all instances of "::protobuf::" with "::axum_connect::protobuf::" in the original + // generated file. + let rust = std::fs::read_to_string(&proto_rs_full_path)?; + let rust = rust.replace("::protobuf::", "::axum_connect::protobuf::"); + // std::fs::write(&proto_rs_full_path, rust)?; + + // Build up the service implementation file source. + let mut c = String::new(); + + c.push_str(FILE_PREAMBLE_TEMPLATE); + + for service in &file_descriptor.service { + // Build up methods first + let mut m = String::new(); + + for method in &service.method { + let input_type = root_scope + .find_message(&ProtobufAbsPath { + path: method.input_type().to_string(), + }) + .rust_name_with_file() + .to_path() + .to_string(); + + let output_type = root_scope + .find_message(&ProtobufAbsPath { + path: method.output_type().to_string(), + }) + .rust_name_with_file() + .to_path() + .to_string(); + + m.push_str( + &METHOD_TEMPLATE + .replace("@@METHOD_NAME@@", &method.name().to_case(Case::Snake)) + .replace("@@INPUT_TYPE@@", &input_type) + .replace("@@OUTPUT_TYPE@@", &output_type) + .replace( + "@@ROUTE@@", + &format!( + "/{}.{}/{}", + file_descriptor.package(), + service.name(), + method.name() + ), + ), + ); + } + + c.push_str( + &SERVICE_TEMPLATE + .replace("@@SERVICE_NAME@@", service.name()) + .replace("@@SERVICE_METHODS@@", &m), + ); + } + + let mut final_file = String::new(); + final_file.push_str(&rust); + final_file.push_str(&c); + + std::fs::write(&proto_rs_full_path, &final_file)?; + } + + Ok(()) +} + +const FILE_PREAMBLE_TEMPLATE: &str = "// Generated by axum-connect-build +use axum::{ + body::HttpBody, extract::State, http::Request, response::IntoResponse, routing::post, BoxError, + Router, +}; + +use axum_connect::{HandlerFuture, RpcRouter}; +"; + +const SERVICE_TEMPLATE: &str = " +pub struct @@SERVICE_NAME@@; + +impl @@SERVICE_NAME@@ { +@@SERVICE_METHODS@@ +}"; + +const METHOD_TEMPLATE: &str = " + pub fn @@METHOD_NAME@@(handler: H) -> impl FnOnce(Router) -> RpcRouter + where + H: HandlerFuture, + T: 'static, + S: Clone + Send + Sync + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into, + { + move |router: Router| { + router.route( + \"@@ROUTE@@\", + post(|State(state): State, request: Request| async move { + let res = handler.call(request, state).await; + ::axum_connect::protobuf_json_mapping::print_to_string(&res) + .unwrap() + .into_response() + }), + ) + } + } +"; diff --git a/axum-connect-examples/Cargo.toml b/axum-connect-examples/Cargo.toml new file mode 100644 index 0000000..dd681be --- /dev/null +++ b/axum-connect-examples/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "hello_world" +version = "0.1.0" +edition = "2021" + +[dependencies] +axum = "0.6.9" +axum-connect = { path = "../axum-connect" } +tokio = { version = "1.0", features = ["full"] } + +[build-dependencies] +axum-connect-build = { path = "../axum-connect-build" } diff --git a/axum-connect-examples/build.rs b/axum-connect-examples/build.rs new file mode 100644 index 0000000..653de51 --- /dev/null +++ b/axum-connect-examples/build.rs @@ -0,0 +1,5 @@ +use axum_connect_build::axum_connect_codegen; + +fn main() { + axum_connect_codegen("proto", &["proto/hello.proto"]).unwrap(); +} diff --git a/axum-connect-examples/proto/hello.proto b/axum-connect-examples/proto/hello.proto new file mode 100644 index 0000000..25b228e --- /dev/null +++ b/axum-connect-examples/proto/hello.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package axum_connect.examples.hello_world; + +message HelloRequest { + string name = 1; +} + +message HelloResponse { + string message = 1; +} + +service HelloWorldService { + rpc SayHello(HelloRequest) returns (HelloResponse) {} +} diff --git a/axum-connect-examples/src/main.rs b/axum-connect-examples/src/main.rs new file mode 100644 index 0000000..78c9b16 --- /dev/null +++ b/axum-connect-examples/src/main.rs @@ -0,0 +1,33 @@ +use std::net::SocketAddr; + +use axum::{extract::Host, Router}; +use axum_connect::*; +use proto::hello::{HelloRequest, HelloResponse, HelloWorldService}; + +mod proto { + include!(concat!(env!("OUT_DIR"), "/connect_proto_gen/mod.rs")); +} + +#[tokio::main] +async fn main() { + // Build our application with a route + let app = Router::new().rpc(HelloWorldService::say_hello(say_hello_handler)); + + // Run the Axum server. + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + println!("listening on http://{}", addr); + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +async fn say_hello_handler(Host(host): Host, request: HelloRequest) -> HelloResponse { + HelloResponse { + message: format!( + "Hello {}! You're addressing the hostname: {}.", + request.name, host + ), + special_fields: Default::default(), + } +} diff --git a/axum-connect/Cargo.toml b/axum-connect/Cargo.toml new file mode 100644 index 0000000..6e4fc34 --- /dev/null +++ b/axum-connect/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "axum-connect" +version = "0.1.0" +edition = "2021" + +[dependencies] +axum = "0.6.9" +futures = "0.3.26" +protobuf = "3.2.0" +protobuf-json-mapping = "3.2.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" diff --git a/axum-connect/src/lib.rs b/axum-connect/src/lib.rs new file mode 100644 index 0000000..d704b59 --- /dev/null +++ b/axum-connect/src/lib.rs @@ -0,0 +1,224 @@ +use std::pin::Pin; + +use axum::{ + body::{Body, HttpBody}, + extract::{FromRequest, FromRequestParts}, + http::{Request, StatusCode}, + response::{IntoResponse, Response}, + BoxError, Router, +}; +use futures::Future; +use protobuf::MessageFull; +use serde::Serialize; + +pub use protobuf; +pub use protobuf_json_mapping; + +pub trait RpcRouterExt: Sized { + fn rpc(self, register: F) -> Self + where + F: FnOnce(Self) -> RpcRouter; +} + +impl RpcRouterExt for Router { + fn rpc(self, register: F) -> Self + where + F: FnOnce(Self) -> RpcRouter, + { + register(self) + } +} + +pub type RpcRouter = Router; + +pub trait RegisterRpcService: Sized { + fn register(self, router: Router) -> Self; +} + +pub trait IntoRpcResponse +where + T: MessageFull, +{ + fn into_response(self) -> Response; +} + +#[derive(Clone, Serialize)] +pub struct RpcError { + pub code: RpcErrorCode, + pub message: String, + pub details: Vec, +} + +impl RpcError { + pub fn new(code: RpcErrorCode, message: String) -> Self { + Self { + code, + message, + details: vec![], + } + } +} + +#[derive(Clone, Serialize)] +pub struct RpcErrorDetail { + #[serde(rename = "type")] + pub proto_type: String, + #[serde(rename = "value")] + pub proto_b62_value: String, +} + +#[derive(Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum RpcErrorCode { + Canceled, + Unknown, + InvalidArgument, + DeadlineExceeded, + NotFound, + AlreadyExists, + PermissionDenied, + ResourceExhausted, + FailedPrecondition, + Aborted, + OutOfRange, + Unimplemented, + Internal, + Unavailable, + DataLoss, + Unauthenticated, +} + +impl From for StatusCode { + fn from(val: RpcErrorCode) -> Self { + match val { + // Spec: https://connect.build/docs/protocol/#error-codes + RpcErrorCode::Canceled => StatusCode::REQUEST_TIMEOUT, + RpcErrorCode::Unknown => StatusCode::INTERNAL_SERVER_ERROR, + RpcErrorCode::InvalidArgument => StatusCode::BAD_REQUEST, + RpcErrorCode::DeadlineExceeded => StatusCode::REQUEST_TIMEOUT, + RpcErrorCode::NotFound => StatusCode::NOT_FOUND, + RpcErrorCode::AlreadyExists => StatusCode::CONFLICT, + RpcErrorCode::PermissionDenied => StatusCode::FORBIDDEN, + RpcErrorCode::ResourceExhausted => StatusCode::TOO_MANY_REQUESTS, + RpcErrorCode::FailedPrecondition => StatusCode::PRECONDITION_FAILED, + RpcErrorCode::Aborted => StatusCode::CONFLICT, + RpcErrorCode::OutOfRange => StatusCode::BAD_REQUEST, + RpcErrorCode::Unimplemented => StatusCode::NOT_FOUND, + RpcErrorCode::Internal => StatusCode::INTERNAL_SERVER_ERROR, + RpcErrorCode::Unavailable => StatusCode::SERVICE_UNAVAILABLE, + RpcErrorCode::DataLoss => StatusCode::INTERNAL_SERVER_ERROR, + RpcErrorCode::Unauthenticated => StatusCode::UNAUTHORIZED, + } + } +} + +impl IntoResponse for RpcError { + fn into_response(self) -> Response { + let status_code = StatusCode::from(self.code.clone()); + let json = serde_json::to_string(&self).expect("serialize error type"); + (status_code, json).into_response() + } +} + +impl IntoRpcResponse for Result +where + T: MessageFull, + E: Into, +{ + fn into_response(self) -> Response { + match self { + Ok(res) => rpc_to_response(res), + Err(err) => err.into().into_response(), + } + } +} + +pub trait HandlerFuture: Clone + Send + Sized + 'static { + type Future: Future + Send + 'static; + + fn call(self, req: Request, state: S) -> Self::Future; +} + +fn rpc_to_response(res: T) -> Response +where + T: MessageFull, +{ + protobuf_json_mapping::print_to_string(&res) + .map_err(|_e| { + RpcError::new( + RpcErrorCode::Internal, + "Failed to serialize response".to_string(), + ) + }) + .into_response() +} + +macro_rules! impl_handler { + ( + [$($ty:ident),*] + ) => { + #[allow(unused_parens, non_snake_case, unused_mut)] + impl HandlerFuture for F + where + TReq: MessageFull + Send + 'static, + TRes: MessageFull + Send + 'static, + F: FnOnce($($ty,)* TReq) -> Fut + Clone + Send + 'static, + Fut: Future + Send, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into, + S: Send + Sync + 'static, + $( $ty: FromRequestParts + Send, )* + { + type Future = Pin + Send>>; + + fn call(self, req: Request, state: S) -> Self::Future { + Box::pin(async move { + let (mut parts, body) = req.into_parts(); + let state = &state; + + // This would be done by macro expansion. It also wouldn't be unwrapped, but + // there is no error union so I can't return a rejection. + $( + let $ty = match $ty::from_request_parts(&mut parts, state).await { + Ok(value) => value, + Err(_e) => unreachable!(), + }; + )* + + let req = Request::from_parts(parts, body); + + let body = match String::from_request(req, state).await { + Ok(value) => value, + Err(_e) => unreachable!(), + }; + + let proto_req: TReq = match protobuf_json_mapping::parse_from_str(&body) { + Ok(value) => value, + Err(_e) => unreachable!(), + }; + + let res = self($($ty,)* proto_req).await; + res + }) + } + } + }; +} + +impl_handler!([]); +impl_handler!([T1]); +impl_handler!([T1, T2]); +impl_handler!([T1, T2, T3]); +impl_handler!([T1, T2, T3, T4]); +impl_handler!([T1, T2, T3, T4, T5]); +impl_handler!([T1, T2, T3, T4, T5, T6]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]); +impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]);