mirror of
https://github.com/AThilenius/axum-connect.git
synced 2025-01-05 01:39:14 +00:00
Support for Unary-Get-Request
This commit is contained in:
parent
80d49d8ca5
commit
929027e321
5 changed files with 199 additions and 10 deletions
|
@ -70,6 +70,8 @@ impl AxumConnectServiceGenerator {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
let method_name_unary_get = format_ident!("{}_unary_get", method.name);
|
||||
|
||||
quote! {
|
||||
pub fn #method_name<T, H, S>(
|
||||
handler: H
|
||||
|
@ -91,6 +93,27 @@ impl AxumConnectServiceGenerator {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn #method_name_unary_get<T, H, S>(
|
||||
handler: H
|
||||
) -> impl FnOnce(axum::Router<S>) -> axum_connect::router::RpcRouter<S>
|
||||
where
|
||||
H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S>,
|
||||
T: 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
move |router: axum::Router<S>| {
|
||||
router.route(
|
||||
#path,
|
||||
axum::routing::get(|
|
||||
axum::extract::State(state): axum::extract::State<S>,
|
||||
request: axum::http::Request<axum::body::Body>
|
||||
| async move {
|
||||
handler.call(request, state).await
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,10 @@
|
|||
/*
|
||||
cargo run -p axum-connect-example
|
||||
|
||||
curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%7D' -v
|
||||
curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%22name%22%3A%22foo%22%7D' -v
|
||||
*/
|
||||
|
||||
use async_stream::stream;
|
||||
use axum::{extract::Host, Router};
|
||||
use axum_connect::{futures::Stream, prelude::*};
|
||||
|
@ -16,6 +23,7 @@ async fn main() {
|
|||
// just a normal Rust function. Just like Axum, it also supports extractors!
|
||||
let app = Router::new()
|
||||
.rpc(HelloWorldService::say_hello(say_hello_success))
|
||||
.rpc(HelloWorldService::say_hello_unary_get(say_hello_success))
|
||||
.rpc(HelloWorldService::say_hello_stream(say_hello_stream));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3030")
|
||||
|
|
|
@ -24,3 +24,6 @@ pbjson-types = "0.6.0"
|
|||
prost = "0.12.1"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
serde_qs = { version = "0.12.0" }
|
||||
base64 = { version = "0.21.5" }
|
||||
|
|
|
@ -5,7 +5,7 @@ use axum::{
|
|||
response::{IntoResponse, Response},
|
||||
};
|
||||
use prost::Message;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
|
||||
use crate::prelude::{RpcError, RpcErrorCode};
|
||||
|
||||
|
@ -59,6 +59,59 @@ pub(crate) fn encode_error_response(
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
pub(crate) struct UnaryGetQuery {
|
||||
pub message: String,
|
||||
pub encoding: String,
|
||||
pub base64: Option<usize>,
|
||||
pub compression: Option<String>,
|
||||
pub connect: Option<String>,
|
||||
}
|
||||
|
||||
pub(crate) fn decode_check_query(parts: &request::Parts) -> Result<ReqResInto, Response> {
|
||||
let query_str = match parts.uri.query() {
|
||||
Some(x) => x,
|
||||
None => {
|
||||
return Err(encode_error_response(
|
||||
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
|
||||
false,
|
||||
false,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let query = match serde_qs::from_str::<UnaryGetQuery>(query_str) {
|
||||
Ok(x) => x,
|
||||
Err(err) => {
|
||||
return Err(encode_error_response(
|
||||
&RpcError::new(
|
||||
RpcErrorCode::InvalidArgument,
|
||||
format!("Wrong query, {}", err),
|
||||
),
|
||||
false,
|
||||
false,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let binary = match query.encoding.as_str() {
|
||||
"json" => false,
|
||||
"proto" => true,
|
||||
s => {
|
||||
return Err(encode_error_response(
|
||||
&RpcError::new(
|
||||
RpcErrorCode::InvalidArgument,
|
||||
format!("Wrong or unknown query.encoding: {}", s),
|
||||
),
|
||||
true,
|
||||
true,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ReqResInto { binary })
|
||||
}
|
||||
|
||||
pub(crate) fn decode_check_headers(
|
||||
parts: &mut request::Parts,
|
||||
for_streaming: bool,
|
||||
|
@ -123,6 +176,91 @@ pub(crate) fn decode_check_headers(
|
|||
Ok(ReqResInto { binary })
|
||||
}
|
||||
|
||||
pub(crate) fn decode_request_payload_from_query<M, S>(
|
||||
parts: &request::Parts,
|
||||
_state: &S,
|
||||
as_binary: bool,
|
||||
) -> Result<M, Response>
|
||||
where
|
||||
M: Message + DeserializeOwned + Default,
|
||||
S: Send + Sync + 'static,
|
||||
{
|
||||
let for_streaming = false;
|
||||
|
||||
let query_str = match parts.uri.query() {
|
||||
Some(x) => x,
|
||||
None => {
|
||||
return Err(encode_error_response(
|
||||
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
|
||||
false,
|
||||
false,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let query = match serde_qs::from_str::<UnaryGetQuery>(query_str) {
|
||||
Ok(x) => x,
|
||||
Err(err) => {
|
||||
return Err(encode_error_response(
|
||||
&RpcError::new(
|
||||
RpcErrorCode::InvalidArgument,
|
||||
format!("Wrong query, {}", err),
|
||||
),
|
||||
false,
|
||||
false,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let message = if query.base64 == Some(1) {
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
|
||||
match general_purpose::URL_SAFE.decode(&query.message) {
|
||||
Ok(x) => String::from_utf8_lossy(x.as_slice()).to_string(),
|
||||
Err(err) => {
|
||||
return Err(encode_error_response(
|
||||
&RpcError::new(
|
||||
RpcErrorCode::InvalidArgument,
|
||||
format!("Wrong query.message, {}", err),
|
||||
),
|
||||
false,
|
||||
false,
|
||||
))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
query.message.into()
|
||||
};
|
||||
|
||||
if as_binary {
|
||||
let message: M = M::decode(message.as_bytes()).map_err(|e| {
|
||||
encode_error_response(
|
||||
&RpcError::new(
|
||||
RpcErrorCode::InvalidArgument,
|
||||
format!("Failed to decode binary protobuf. {}", e),
|
||||
),
|
||||
as_binary,
|
||||
for_streaming,
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(message)
|
||||
} else {
|
||||
let message: M = serde_json::from_str(&message).map_err(|e| {
|
||||
encode_error_response(
|
||||
&RpcError::new(
|
||||
RpcErrorCode::InvalidArgument,
|
||||
format!("Failed to decode json. {}", e),
|
||||
),
|
||||
as_binary,
|
||||
for_streaming,
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn decode_request_payload<M, S>(
|
||||
req: Request<Body>,
|
||||
state: &S,
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{convert::Infallible, pin::Pin};
|
|||
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{header, Request, StatusCode},
|
||||
http::{header, Method, Request, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures::Future;
|
||||
|
@ -17,7 +17,8 @@ use crate::{
|
|||
};
|
||||
|
||||
use super::codec::{
|
||||
decode_check_headers, decode_request_payload, encode_error_response, ReqResInto,
|
||||
decode_check_headers, decode_check_query, decode_request_payload,
|
||||
decode_request_payload_from_query, encode_error_response, ReqResInto,
|
||||
};
|
||||
|
||||
pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState>: Clone + Send + Sized + 'static {
|
||||
|
@ -138,9 +139,16 @@ macro_rules! impl_handler {
|
|||
Box::pin(async move {
|
||||
let (mut parts, body) = req.into_parts();
|
||||
|
||||
let ReqResInto { binary } = match decode_check_headers(&mut parts, false) {
|
||||
Ok(binary) => binary,
|
||||
Err(e) => return e,
|
||||
let ReqResInto { binary } = if parts.method == Method::GET {
|
||||
match decode_check_query(&parts) {
|
||||
Ok(binary) => binary,
|
||||
Err(e) => return e,
|
||||
}
|
||||
} else {
|
||||
match decode_check_headers(&mut parts, false) {
|
||||
Ok(binary) => binary,
|
||||
Err(e) => return e,
|
||||
}
|
||||
};
|
||||
|
||||
let state = &state;
|
||||
|
@ -155,11 +163,20 @@ macro_rules! impl_handler {
|
|||
};
|
||||
)*
|
||||
|
||||
let req = Request::from_parts(parts, body);
|
||||
|
||||
let proto_req: TMReq = match decode_request_payload(req, state, binary, false).await {
|
||||
Ok(value) => value,
|
||||
Err(e) => return e,
|
||||
|
||||
let proto_req: TMReq = if parts.method == Method::GET {
|
||||
match decode_request_payload_from_query(&parts, state, binary) {
|
||||
Ok(value) => value,
|
||||
Err(e) => return e,
|
||||
}
|
||||
} else {
|
||||
let req = Request::from_parts(parts, body);
|
||||
|
||||
match decode_request_payload(req, state, binary, false).await {
|
||||
Ok(value) => value,
|
||||
Err(e) => return e,
|
||||
}
|
||||
};
|
||||
|
||||
let res = self($($ty,)* proto_req).await.rpc_into_response();
|
||||
|
|
Loading…
Reference in a new issue