Change RpcFromRequestPart impl to non-macro, add State and ConnectInfo

This commit is contained in:
Alec Thilenius 2023-04-28 09:27:19 -07:00
parent 6ca78c66f9
commit 47b4b09162
3 changed files with 88 additions and 61 deletions

View file

@ -1,22 +1,11 @@
syntax = "proto3";
package package.path;
package blink.hello;
message HelloRequest {
message EmbeddedOne {
message EmbeddedTwo {
string name = 1;
}
EmbeddedTwo embedded_two = 1;
}
string name = 1;
EmbeddedOne embedded_one = 2;
}
message HelloRequest { string name = 1; }
message HelloResponse {
string message = 1;
}
message HelloResponse { string message = 1; }
service HelloWorldService {
rpc SayHello(HelloRequest) returns (HelloResponse) {}
rpc SayHello(HelloRequest) returns (HelloResponse) {}
}

View file

@ -1,8 +1,12 @@
[package]
name = "axum-connect"
version = "0.1.2"
version = "0.1.3"
edition = "2021"
categories = ["rpc", "connect-web", "axum", "http"]
categories = [
"network-programming",
"web-programming",
"web-programming::http-server",
]
description = "Connect-Web RPC for Axum"
keywords = ["rpc", "axum", "protobuf", "connect"]
license = "MIT OR Apache-2.0"

View file

@ -1,7 +1,10 @@
use async_trait::async_trait;
use axum::{
extract::{FromRequestParts, Host, Query},
extract::{
connect_info::MockConnectInfo, ConnectInfo, FromRef, FromRequestParts, Host, Query, State,
},
http::{self},
Extension,
};
use protobuf::MessageFull;
use serde::de::DeserializeOwned;
@ -25,49 +28,80 @@ where
) -> Result<Self, Self::Rejection>;
}
/// Macro to convert standard Axum `FromRequestParts` into `RpcFromRequestParts` by transforming
/// their error type.
macro_rules! impl_rpc_from_request_parts {
($t:ident, $code:expr) => {
#[async_trait]
impl<M, S> RpcFromRequestParts<M, S> for $t
where
M: MessageFull,
S: Send + Sync,
{
type Rejection = RpcError;
#[async_trait]
impl<M, S> RpcFromRequestParts<M, S> for Host
where
M: MessageFull,
S: Send + Sync,
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
parts: &mut http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
Ok($t::from_request_parts(parts, state)
.await
.map_err(|e| ($code, e.to_string()).rpc_into_error())?)
}
}
};
([$($tin:ident),*], $t:ident, $code:expr) => {
#[async_trait]
impl<M, S, $($tin,)*> RpcFromRequestParts<M, S> for $t<$($tin,)*>
where
M: MessageFull,
S: Send + Sync,
$( $tin: DeserializeOwned, )*
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
parts: &mut http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
Ok($t::from_request_parts(parts, state)
.await
.map_err(|e| ($code, e.to_string()).rpc_into_error())?)
}
}
};
async fn rpc_from_request_parts(
parts: &mut http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
Ok(Host::from_request_parts(parts, state)
.await
.map_err(|e| (RpcErrorCode::Internal, e.to_string()).rpc_into_error())?)
}
}
impl_rpc_from_request_parts!(Host, RpcErrorCode::Internal);
impl_rpc_from_request_parts!([T], Query, RpcErrorCode::Internal);
#[async_trait]
impl<M, S, T> RpcFromRequestParts<M, S> for Query<T>
where
M: MessageFull,
S: Send + Sync,
T: DeserializeOwned,
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
parts: &mut http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
Ok(Query::from_request_parts(parts, state)
.await
.map_err(|e| (RpcErrorCode::Internal, e.to_string()).rpc_into_error())?)
}
}
#[async_trait]
impl<M, S, T> RpcFromRequestParts<M, S> for ConnectInfo<T>
where
M: MessageFull,
S: Send + Sync,
T: Clone + Send + Sync + 'static,
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
parts: &mut http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
match Extension::<Self>::from_request_parts(parts, state).await {
Ok(Extension(connect_info)) => Ok(connect_info),
Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
None => Err((RpcErrorCode::Internal, err.to_string()).rpc_into_error()),
},
}
}
}
#[async_trait]
impl<M, OuterState, InnerState> RpcFromRequestParts<M, OuterState> for State<InnerState>
where
M: MessageFull,
InnerState: FromRef<OuterState>,
OuterState: Send + Sync,
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
_parts: &mut http::request::Parts,
state: &OuterState,
) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(state);
Ok(Self(inner_state))
}
}