mirror of
https://github.com/AThilenius/axum-connect.git
synced 2025-01-05 01:39:14 +00:00
Change RpcFromRequestPart impl to non-macro, add State and ConnectInfo
This commit is contained in:
parent
6ca78c66f9
commit
47b4b09162
3 changed files with 88 additions and 61 deletions
|
@ -1,22 +1,11 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package package.path;
|
||||
package blink.hello;
|
||||
|
||||
message HelloRequest {
|
||||
message EmbeddedOne {
|
||||
message EmbeddedTwo {
|
||||
string name = 1;
|
||||
}
|
||||
EmbeddedTwo embedded_two = 1;
|
||||
}
|
||||
string name = 1;
|
||||
EmbeddedOne embedded_one = 2;
|
||||
}
|
||||
message HelloRequest { string name = 1; }
|
||||
|
||||
message HelloResponse {
|
||||
string message = 1;
|
||||
}
|
||||
message HelloResponse { string message = 1; }
|
||||
|
||||
service HelloWorldService {
|
||||
rpc SayHello(HelloRequest) returns (HelloResponse) {}
|
||||
rpc SayHello(HelloRequest) returns (HelloResponse) {}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
[package]
|
||||
name = "axum-connect"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
edition = "2021"
|
||||
categories = ["rpc", "connect-web", "axum", "http"]
|
||||
categories = [
|
||||
"network-programming",
|
||||
"web-programming",
|
||||
"web-programming::http-server",
|
||||
]
|
||||
description = "Connect-Web RPC for Axum"
|
||||
keywords = ["rpc", "axum", "protobuf", "connect"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
extract::{FromRequestParts, Host, Query},
|
||||
extract::{
|
||||
connect_info::MockConnectInfo, ConnectInfo, FromRef, FromRequestParts, Host, Query, State,
|
||||
},
|
||||
http::{self},
|
||||
Extension,
|
||||
};
|
||||
use protobuf::MessageFull;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
@ -25,49 +28,80 @@ where
|
|||
) -> Result<Self, Self::Rejection>;
|
||||
}
|
||||
|
||||
/// Macro to convert standard Axum `FromRequestParts` into `RpcFromRequestParts` by transforming
|
||||
/// their error type.
|
||||
macro_rules! impl_rpc_from_request_parts {
|
||||
($t:ident, $code:expr) => {
|
||||
#[async_trait]
|
||||
impl<M, S> RpcFromRequestParts<M, S> for $t
|
||||
where
|
||||
M: MessageFull,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = RpcError;
|
||||
#[async_trait]
|
||||
impl<M, S> RpcFromRequestParts<M, S> for Host
|
||||
where
|
||||
M: MessageFull,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = RpcError;
|
||||
|
||||
async fn rpc_from_request_parts(
|
||||
parts: &mut http::request::Parts,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
Ok($t::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|e| ($code, e.to_string()).rpc_into_error())?)
|
||||
}
|
||||
}
|
||||
};
|
||||
([$($tin:ident),*], $t:ident, $code:expr) => {
|
||||
#[async_trait]
|
||||
impl<M, S, $($tin,)*> RpcFromRequestParts<M, S> for $t<$($tin,)*>
|
||||
where
|
||||
M: MessageFull,
|
||||
S: Send + Sync,
|
||||
$( $tin: DeserializeOwned, )*
|
||||
{
|
||||
type Rejection = RpcError;
|
||||
|
||||
async fn rpc_from_request_parts(
|
||||
parts: &mut http::request::Parts,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
Ok($t::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|e| ($code, e.to_string()).rpc_into_error())?)
|
||||
}
|
||||
}
|
||||
};
|
||||
async fn rpc_from_request_parts(
|
||||
parts: &mut http::request::Parts,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
Ok(Host::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|e| (RpcErrorCode::Internal, e.to_string()).rpc_into_error())?)
|
||||
}
|
||||
}
|
||||
|
||||
impl_rpc_from_request_parts!(Host, RpcErrorCode::Internal);
|
||||
impl_rpc_from_request_parts!([T], Query, RpcErrorCode::Internal);
|
||||
#[async_trait]
|
||||
impl<M, S, T> RpcFromRequestParts<M, S> for Query<T>
|
||||
where
|
||||
M: MessageFull,
|
||||
S: Send + Sync,
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
type Rejection = RpcError;
|
||||
|
||||
async fn rpc_from_request_parts(
|
||||
parts: &mut http::request::Parts,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
Ok(Query::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|e| (RpcErrorCode::Internal, e.to_string()).rpc_into_error())?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<M, S, T> RpcFromRequestParts<M, S> for ConnectInfo<T>
|
||||
where
|
||||
M: MessageFull,
|
||||
S: Send + Sync,
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
type Rejection = RpcError;
|
||||
|
||||
async fn rpc_from_request_parts(
|
||||
parts: &mut http::request::Parts,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
match Extension::<Self>::from_request_parts(parts, state).await {
|
||||
Ok(Extension(connect_info)) => Ok(connect_info),
|
||||
Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
|
||||
Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
|
||||
None => Err((RpcErrorCode::Internal, err.to_string()).rpc_into_error()),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<M, OuterState, InnerState> RpcFromRequestParts<M, OuterState> for State<InnerState>
|
||||
where
|
||||
M: MessageFull,
|
||||
InnerState: FromRef<OuterState>,
|
||||
OuterState: Send + Sync,
|
||||
{
|
||||
type Rejection = RpcError;
|
||||
|
||||
async fn rpc_from_request_parts(
|
||||
_parts: &mut http::request::Parts,
|
||||
state: &OuterState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let inner_state = InnerState::from_ref(state);
|
||||
Ok(Self(inner_state))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue