Change RpcFromRequestPart impl to non-macro, add State and ConnectInfo

This commit is contained in:
Alec Thilenius 2023-04-28 09:27:19 -07:00
parent 6ca78c66f9
commit 47b4b09162
3 changed files with 88 additions and 61 deletions

View file

@ -1,22 +1,11 @@
syntax = "proto3"; syntax = "proto3";
package package.path; package blink.hello;
message HelloRequest { message HelloRequest { string name = 1; }
message EmbeddedOne {
message EmbeddedTwo {
string name = 1;
}
EmbeddedTwo embedded_two = 1;
}
string name = 1;
EmbeddedOne embedded_one = 2;
}
message HelloResponse { message HelloResponse { string message = 1; }
string message = 1;
}
service HelloWorldService { service HelloWorldService {
rpc SayHello(HelloRequest) returns (HelloResponse) {} rpc SayHello(HelloRequest) returns (HelloResponse) {}
} }

View file

@ -1,8 +1,12 @@
[package] [package]
name = "axum-connect" name = "axum-connect"
version = "0.1.2" version = "0.1.3"
edition = "2021" edition = "2021"
categories = ["rpc", "connect-web", "axum", "http"] categories = [
"network-programming",
"web-programming",
"web-programming::http-server",
]
description = "Connect-Web RPC for Axum" description = "Connect-Web RPC for Axum"
keywords = ["rpc", "axum", "protobuf", "connect"] keywords = ["rpc", "axum", "protobuf", "connect"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"

View file

@ -1,7 +1,10 @@
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
extract::{FromRequestParts, Host, Query}, extract::{
connect_info::MockConnectInfo, ConnectInfo, FromRef, FromRequestParts, Host, Query, State,
},
http::{self}, http::{self},
Extension,
}; };
use protobuf::MessageFull; use protobuf::MessageFull;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -25,49 +28,80 @@ where
) -> Result<Self, Self::Rejection>; ) -> Result<Self, Self::Rejection>;
} }
/// Macro to convert standard Axum `FromRequestParts` into `RpcFromRequestParts` by transforming #[async_trait]
/// their error type. impl<M, S> RpcFromRequestParts<M, S> for Host
macro_rules! impl_rpc_from_request_parts { where
($t:ident, $code:expr) => { M: MessageFull,
#[async_trait] S: Send + Sync,
impl<M, S> RpcFromRequestParts<M, S> for $t {
where type Rejection = RpcError;
M: MessageFull,
S: Send + Sync,
{
type Rejection = RpcError;
async fn rpc_from_request_parts( async fn rpc_from_request_parts(
parts: &mut http::request::Parts, parts: &mut http::request::Parts,
state: &S, state: &S,
) -> Result<Self, Self::Rejection> { ) -> Result<Self, Self::Rejection> {
Ok($t::from_request_parts(parts, state) Ok(Host::from_request_parts(parts, state)
.await .await
.map_err(|e| ($code, e.to_string()).rpc_into_error())?) .map_err(|e| (RpcErrorCode::Internal, e.to_string()).rpc_into_error())?)
} }
}
};
([$($tin:ident),*], $t:ident, $code:expr) => {
#[async_trait]
impl<M, S, $($tin,)*> RpcFromRequestParts<M, S> for $t<$($tin,)*>
where
M: MessageFull,
S: Send + Sync,
$( $tin: DeserializeOwned, )*
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
parts: &mut http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
Ok($t::from_request_parts(parts, state)
.await
.map_err(|e| ($code, e.to_string()).rpc_into_error())?)
}
}
};
} }
impl_rpc_from_request_parts!(Host, RpcErrorCode::Internal); #[async_trait]
impl_rpc_from_request_parts!([T], Query, RpcErrorCode::Internal); impl<M, S, T> RpcFromRequestParts<M, S> for Query<T>
where
M: MessageFull,
S: Send + Sync,
T: DeserializeOwned,
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
parts: &mut http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
Ok(Query::from_request_parts(parts, state)
.await
.map_err(|e| (RpcErrorCode::Internal, e.to_string()).rpc_into_error())?)
}
}
#[async_trait]
impl<M, S, T> RpcFromRequestParts<M, S> for ConnectInfo<T>
where
M: MessageFull,
S: Send + Sync,
T: Clone + Send + Sync + 'static,
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
parts: &mut http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
match Extension::<Self>::from_request_parts(parts, state).await {
Ok(Extension(connect_info)) => Ok(connect_info),
Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
None => Err((RpcErrorCode::Internal, err.to_string()).rpc_into_error()),
},
}
}
}
#[async_trait]
impl<M, OuterState, InnerState> RpcFromRequestParts<M, OuterState> for State<InnerState>
where
M: MessageFull,
InnerState: FromRef<OuterState>,
OuterState: Send + Sync,
{
type Rejection = RpcError;
async fn rpc_from_request_parts(
_parts: &mut http::request::Parts,
state: &OuterState,
) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(state);
Ok(Self(inner_state))
}
}