mirror of
https://github.com/AThilenius/axum-connect.git
synced 2025-01-06 18:18:42 +00:00
Support for Unary-Get-Request
This commit is contained in:
parent
80d49d8ca5
commit
929027e321
5 changed files with 199 additions and 10 deletions
|
@ -70,6 +70,8 @@ impl AxumConnectServiceGenerator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
let method_name_unary_get = format_ident!("{}_unary_get", method.name);
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
pub fn #method_name<T, H, S>(
|
pub fn #method_name<T, H, S>(
|
||||||
handler: H
|
handler: H
|
||||||
|
@ -91,6 +93,27 @@ impl AxumConnectServiceGenerator {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn #method_name_unary_get<T, H, S>(
|
||||||
|
handler: H
|
||||||
|
) -> impl FnOnce(axum::Router<S>) -> axum_connect::router::RpcRouter<S>
|
||||||
|
where
|
||||||
|
H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S>,
|
||||||
|
T: 'static,
|
||||||
|
S: Clone + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
move |router: axum::Router<S>| {
|
||||||
|
router.route(
|
||||||
|
#path,
|
||||||
|
axum::routing::get(|
|
||||||
|
axum::extract::State(state): axum::extract::State<S>,
|
||||||
|
request: axum::http::Request<axum::body::Body>
|
||||||
|
| async move {
|
||||||
|
handler.call(request, state).await
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,10 @@
|
||||||
|
/*
|
||||||
|
cargo run -p axum-connect-example
|
||||||
|
|
||||||
|
curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%7D' -v
|
||||||
|
curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%22name%22%3A%22foo%22%7D' -v
|
||||||
|
*/
|
||||||
|
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use axum::{extract::Host, Router};
|
use axum::{extract::Host, Router};
|
||||||
use axum_connect::{futures::Stream, prelude::*};
|
use axum_connect::{futures::Stream, prelude::*};
|
||||||
|
@ -16,6 +23,7 @@ async fn main() {
|
||||||
// just a normal Rust function. Just like Axum, it also supports extractors!
|
// just a normal Rust function. Just like Axum, it also supports extractors!
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.rpc(HelloWorldService::say_hello(say_hello_success))
|
.rpc(HelloWorldService::say_hello(say_hello_success))
|
||||||
|
.rpc(HelloWorldService::say_hello_unary_get(say_hello_success))
|
||||||
.rpc(HelloWorldService::say_hello_stream(say_hello_stream));
|
.rpc(HelloWorldService::say_hello_stream(say_hello_stream));
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3030")
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:3030")
|
||||||
|
|
|
@ -24,3 +24,6 @@ pbjson-types = "0.6.0"
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
|
||||||
|
serde_qs = { version = "0.12.0" }
|
||||||
|
base64 = { version = "0.21.5" }
|
||||||
|
|
|
@ -5,7 +5,7 @@ use axum::{
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use prost::Message;
|
use prost::Message;
|
||||||
use serde::de::DeserializeOwned;
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::prelude::{RpcError, RpcErrorCode};
|
use crate::prelude::{RpcError, RpcErrorCode};
|
||||||
|
|
||||||
|
@ -59,6 +59,59 @@ pub(crate) fn encode_error_response(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||||
|
pub(crate) struct UnaryGetQuery {
|
||||||
|
pub message: String,
|
||||||
|
pub encoding: String,
|
||||||
|
pub base64: Option<usize>,
|
||||||
|
pub compression: Option<String>,
|
||||||
|
pub connect: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn decode_check_query(parts: &request::Parts) -> Result<ReqResInto, Response> {
|
||||||
|
let query_str = match parts.uri.query() {
|
||||||
|
Some(x) => x,
|
||||||
|
None => {
|
||||||
|
return Err(encode_error_response(
|
||||||
|
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let query = match serde_qs::from_str::<UnaryGetQuery>(query_str) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(err) => {
|
||||||
|
return Err(encode_error_response(
|
||||||
|
&RpcError::new(
|
||||||
|
RpcErrorCode::InvalidArgument,
|
||||||
|
format!("Wrong query, {}", err),
|
||||||
|
),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let binary = match query.encoding.as_str() {
|
||||||
|
"json" => false,
|
||||||
|
"proto" => true,
|
||||||
|
s => {
|
||||||
|
return Err(encode_error_response(
|
||||||
|
&RpcError::new(
|
||||||
|
RpcErrorCode::InvalidArgument,
|
||||||
|
format!("Wrong or unknown query.encoding: {}", s),
|
||||||
|
),
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ReqResInto { binary })
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn decode_check_headers(
|
pub(crate) fn decode_check_headers(
|
||||||
parts: &mut request::Parts,
|
parts: &mut request::Parts,
|
||||||
for_streaming: bool,
|
for_streaming: bool,
|
||||||
|
@ -123,6 +176,91 @@ pub(crate) fn decode_check_headers(
|
||||||
Ok(ReqResInto { binary })
|
Ok(ReqResInto { binary })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn decode_request_payload_from_query<M, S>(
|
||||||
|
parts: &request::Parts,
|
||||||
|
_state: &S,
|
||||||
|
as_binary: bool,
|
||||||
|
) -> Result<M, Response>
|
||||||
|
where
|
||||||
|
M: Message + DeserializeOwned + Default,
|
||||||
|
S: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let for_streaming = false;
|
||||||
|
|
||||||
|
let query_str = match parts.uri.query() {
|
||||||
|
Some(x) => x,
|
||||||
|
None => {
|
||||||
|
return Err(encode_error_response(
|
||||||
|
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let query = match serde_qs::from_str::<UnaryGetQuery>(query_str) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(err) => {
|
||||||
|
return Err(encode_error_response(
|
||||||
|
&RpcError::new(
|
||||||
|
RpcErrorCode::InvalidArgument,
|
||||||
|
format!("Wrong query, {}", err),
|
||||||
|
),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let message = if query.base64 == Some(1) {
|
||||||
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
|
|
||||||
|
match general_purpose::URL_SAFE.decode(&query.message) {
|
||||||
|
Ok(x) => String::from_utf8_lossy(x.as_slice()).to_string(),
|
||||||
|
Err(err) => {
|
||||||
|
return Err(encode_error_response(
|
||||||
|
&RpcError::new(
|
||||||
|
RpcErrorCode::InvalidArgument,
|
||||||
|
format!("Wrong query.message, {}", err),
|
||||||
|
),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
query.message.into()
|
||||||
|
};
|
||||||
|
|
||||||
|
if as_binary {
|
||||||
|
let message: M = M::decode(message.as_bytes()).map_err(|e| {
|
||||||
|
encode_error_response(
|
||||||
|
&RpcError::new(
|
||||||
|
RpcErrorCode::InvalidArgument,
|
||||||
|
format!("Failed to decode binary protobuf. {}", e),
|
||||||
|
),
|
||||||
|
as_binary,
|
||||||
|
for_streaming,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(message)
|
||||||
|
} else {
|
||||||
|
let message: M = serde_json::from_str(&message).map_err(|e| {
|
||||||
|
encode_error_response(
|
||||||
|
&RpcError::new(
|
||||||
|
RpcErrorCode::InvalidArgument,
|
||||||
|
format!("Failed to decode json. {}", e),
|
||||||
|
),
|
||||||
|
as_binary,
|
||||||
|
for_streaming,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) async fn decode_request_payload<M, S>(
|
pub(crate) async fn decode_request_payload<M, S>(
|
||||||
req: Request<Body>,
|
req: Request<Body>,
|
||||||
state: &S,
|
state: &S,
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::{convert::Infallible, pin::Pin};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
http::{header, Request, StatusCode},
|
http::{header, Method, Request, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use futures::Future;
|
use futures::Future;
|
||||||
|
@ -17,7 +17,8 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::codec::{
|
use super::codec::{
|
||||||
decode_check_headers, decode_request_payload, encode_error_response, ReqResInto,
|
decode_check_headers, decode_check_query, decode_request_payload,
|
||||||
|
decode_request_payload_from_query, encode_error_response, ReqResInto,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState>: Clone + Send + Sized + 'static {
|
pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState>: Clone + Send + Sized + 'static {
|
||||||
|
@ -138,9 +139,16 @@ macro_rules! impl_handler {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let (mut parts, body) = req.into_parts();
|
let (mut parts, body) = req.into_parts();
|
||||||
|
|
||||||
let ReqResInto { binary } = match decode_check_headers(&mut parts, false) {
|
let ReqResInto { binary } = if parts.method == Method::GET {
|
||||||
Ok(binary) => binary,
|
match decode_check_query(&parts) {
|
||||||
Err(e) => return e,
|
Ok(binary) => binary,
|
||||||
|
Err(e) => return e,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match decode_check_headers(&mut parts, false) {
|
||||||
|
Ok(binary) => binary,
|
||||||
|
Err(e) => return e,
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let state = &state;
|
let state = &state;
|
||||||
|
@ -155,11 +163,20 @@ macro_rules! impl_handler {
|
||||||
};
|
};
|
||||||
)*
|
)*
|
||||||
|
|
||||||
let req = Request::from_parts(parts, body);
|
|
||||||
|
|
||||||
let proto_req: TMReq = match decode_request_payload(req, state, binary, false).await {
|
|
||||||
Ok(value) => value,
|
let proto_req: TMReq = if parts.method == Method::GET {
|
||||||
Err(e) => return e,
|
match decode_request_payload_from_query(&parts, state, binary) {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(e) => return e,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let req = Request::from_parts(parts, body);
|
||||||
|
|
||||||
|
match decode_request_payload(req, state, binary, false).await {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(e) => return e,
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let res = self($($ty,)* proto_req).await.rpc_into_response();
|
let res = self($($ty,)* proto_req).await.rpc_into_response();
|
||||||
|
|
Loading…
Reference in a new issue