diff --git a/axum-connect-examples/proto/hello.proto b/axum-connect-examples/proto/hello.proto index c7710f8..36689ed 100644 --- a/axum-connect-examples/proto/hello.proto +++ b/axum-connect-examples/proto/hello.proto @@ -1,22 +1,11 @@ syntax = "proto3"; -package package.path; +package blink.hello; -message HelloRequest { - message EmbeddedOne { - message EmbeddedTwo { - string name = 1; - } - EmbeddedTwo embedded_two = 1; - } - string name = 1; - EmbeddedOne embedded_one = 2; -} +message HelloRequest { string name = 1; } -message HelloResponse { - string message = 1; -} +message HelloResponse { string message = 1; } service HelloWorldService { - rpc SayHello(HelloRequest) returns (HelloResponse) {} + rpc SayHello(HelloRequest) returns (HelloResponse) {} } diff --git a/axum-connect/Cargo.toml b/axum-connect/Cargo.toml index 87f5b94..ccdc85c 100644 --- a/axum-connect/Cargo.toml +++ b/axum-connect/Cargo.toml @@ -1,8 +1,12 @@ [package] name = "axum-connect" -version = "0.1.2" +version = "0.1.3" edition = "2021" -categories = ["rpc", "connect-web", "axum", "http"] +categories = [ + "network-programming", + "web-programming", + "web-programming::http-server", +] description = "Connect-Web RPC for Axum" keywords = ["rpc", "axum", "protobuf", "connect"] license = "MIT OR Apache-2.0" diff --git a/axum-connect/src/parts.rs b/axum-connect/src/parts.rs index 26dc59e..d25bdda 100644 --- a/axum-connect/src/parts.rs +++ b/axum-connect/src/parts.rs @@ -1,7 +1,10 @@ use async_trait::async_trait; use axum::{ - extract::{FromRequestParts, Host, Query}, + extract::{ + connect_info::MockConnectInfo, ConnectInfo, FromRef, FromRequestParts, Host, Query, State, + }, http::{self}, + Extension, }; use protobuf::MessageFull; use serde::de::DeserializeOwned; @@ -25,49 +28,80 @@ where ) -> Result; } -/// Macro to convert standard Axum `FromRequestParts` into `RpcFromRequestParts` by transforming -/// their error type. -macro_rules! impl_rpc_from_request_parts { - ($t:ident, $code:expr) => { - #[async_trait] - impl RpcFromRequestParts for $t - where - M: MessageFull, - S: Send + Sync, - { - type Rejection = RpcError; +#[async_trait] +impl RpcFromRequestParts for Host +where + M: MessageFull, + S: Send + Sync, +{ + type Rejection = RpcError; - async fn rpc_from_request_parts( - parts: &mut http::request::Parts, - state: &S, - ) -> Result { - Ok($t::from_request_parts(parts, state) - .await - .map_err(|e| ($code, e.to_string()).rpc_into_error())?) - } - } - }; - ([$($tin:ident),*], $t:ident, $code:expr) => { - #[async_trait] - impl RpcFromRequestParts for $t<$($tin,)*> - where - M: MessageFull, - S: Send + Sync, - $( $tin: DeserializeOwned, )* - { - type Rejection = RpcError; - - async fn rpc_from_request_parts( - parts: &mut http::request::Parts, - state: &S, - ) -> Result { - Ok($t::from_request_parts(parts, state) - .await - .map_err(|e| ($code, e.to_string()).rpc_into_error())?) - } - } - }; + async fn rpc_from_request_parts( + parts: &mut http::request::Parts, + state: &S, + ) -> Result { + Ok(Host::from_request_parts(parts, state) + .await + .map_err(|e| (RpcErrorCode::Internal, e.to_string()).rpc_into_error())?) + } } -impl_rpc_from_request_parts!(Host, RpcErrorCode::Internal); -impl_rpc_from_request_parts!([T], Query, RpcErrorCode::Internal); +#[async_trait] +impl RpcFromRequestParts for Query +where + M: MessageFull, + S: Send + Sync, + T: DeserializeOwned, +{ + type Rejection = RpcError; + + async fn rpc_from_request_parts( + parts: &mut http::request::Parts, + state: &S, + ) -> Result { + Ok(Query::from_request_parts(parts, state) + .await + .map_err(|e| (RpcErrorCode::Internal, e.to_string()).rpc_into_error())?) + } +} + +#[async_trait] +impl RpcFromRequestParts for ConnectInfo +where + M: MessageFull, + S: Send + Sync, + T: Clone + Send + Sync + 'static, +{ + type Rejection = RpcError; + + async fn rpc_from_request_parts( + parts: &mut http::request::Parts, + state: &S, + ) -> Result { + match Extension::::from_request_parts(parts, state).await { + Ok(Extension(connect_info)) => Ok(connect_info), + Err(err) => match parts.extensions.get::>() { + Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())), + None => Err((RpcErrorCode::Internal, err.to_string()).rpc_into_error()), + }, + } + } +} + +#[async_trait] +impl RpcFromRequestParts for State +where + M: MessageFull, + InnerState: FromRef, + OuterState: Send + Sync, +{ + type Rejection = RpcError; + + async fn rpc_from_request_parts( + _parts: &mut http::request::Parts, + state: &OuterState, + ) -> Result { + let inner_state = InnerState::from_ref(state); + Ok(Self(inner_state)) + } +}