Support for Unary-Get-Request

This commit is contained in:
vkill 2024-01-06 14:36:19 +08:00 committed by Alec Thilenius
parent 80d49d8ca5
commit 929027e321
5 changed files with 199 additions and 10 deletions

View file

@ -70,6 +70,8 @@ impl AxumConnectServiceGenerator {
} }
} }
} else { } else {
let method_name_unary_get = format_ident!("{}_unary_get", method.name);
quote! { quote! {
pub fn #method_name<T, H, S>( pub fn #method_name<T, H, S>(
handler: H handler: H
@ -91,6 +93,27 @@ impl AxumConnectServiceGenerator {
) )
} }
} }
pub fn #method_name_unary_get<T, H, S>(
handler: H
) -> impl FnOnce(axum::Router<S>) -> axum_connect::router::RpcRouter<S>
where
H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S>,
T: 'static,
S: Clone + Send + Sync + 'static,
{
move |router: axum::Router<S>| {
router.route(
#path,
axum::routing::get(|
axum::extract::State(state): axum::extract::State<S>,
request: axum::http::Request<axum::body::Body>
| async move {
handler.call(request, state).await
}),
)
}
}
} }
} }
} }

View file

@ -1,3 +1,10 @@
/*
cargo run -p axum-connect-example
curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%7D' -v
curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%22name%22%3A%22foo%22%7D' -v
*/
use async_stream::stream; use async_stream::stream;
use axum::{extract::Host, Router}; use axum::{extract::Host, Router};
use axum_connect::{futures::Stream, prelude::*}; use axum_connect::{futures::Stream, prelude::*};
@ -16,6 +23,7 @@ async fn main() {
// just a normal Rust function. Just like Axum, it also supports extractors! // just a normal Rust function. Just like Axum, it also supports extractors!
let app = Router::new() let app = Router::new()
.rpc(HelloWorldService::say_hello(say_hello_success)) .rpc(HelloWorldService::say_hello(say_hello_success))
.rpc(HelloWorldService::say_hello_unary_get(say_hello_success))
.rpc(HelloWorldService::say_hello_stream(say_hello_stream)); .rpc(HelloWorldService::say_hello_stream(say_hello_stream));
let listener = tokio::net::TcpListener::bind("127.0.0.1:3030") let listener = tokio::net::TcpListener::bind("127.0.0.1:3030")

View file

@ -24,3 +24,6 @@ pbjson-types = "0.6.0"
prost = "0.12.1" prost = "0.12.1"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
serde_qs = { version = "0.12.0" }
base64 = { version = "0.21.5" }

View file

@ -5,7 +5,7 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use prost::Message; use prost::Message;
use serde::de::DeserializeOwned; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::prelude::{RpcError, RpcErrorCode}; use crate::prelude::{RpcError, RpcErrorCode};
@ -59,6 +59,59 @@ pub(crate) fn encode_error_response(
} }
} }
#[derive(Deserialize, Serialize, Debug, Clone)]
pub(crate) struct UnaryGetQuery {
pub message: String,
pub encoding: String,
pub base64: Option<usize>,
pub compression: Option<String>,
pub connect: Option<String>,
}
pub(crate) fn decode_check_query(parts: &request::Parts) -> Result<ReqResInto, Response> {
let query_str = match parts.uri.query() {
Some(x) => x,
None => {
return Err(encode_error_response(
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
false,
false,
))
}
};
let query = match serde_qs::from_str::<UnaryGetQuery>(query_str) {
Ok(x) => x,
Err(err) => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong query, {}", err),
),
false,
false,
))
}
};
let binary = match query.encoding.as_str() {
"json" => false,
"proto" => true,
s => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong or unknown query.encoding: {}", s),
),
true,
true,
))
}
};
Ok(ReqResInto { binary })
}
pub(crate) fn decode_check_headers( pub(crate) fn decode_check_headers(
parts: &mut request::Parts, parts: &mut request::Parts,
for_streaming: bool, for_streaming: bool,
@ -123,6 +176,91 @@ pub(crate) fn decode_check_headers(
Ok(ReqResInto { binary }) Ok(ReqResInto { binary })
} }
pub(crate) fn decode_request_payload_from_query<M, S>(
parts: &request::Parts,
_state: &S,
as_binary: bool,
) -> Result<M, Response>
where
M: Message + DeserializeOwned + Default,
S: Send + Sync + 'static,
{
let for_streaming = false;
let query_str = match parts.uri.query() {
Some(x) => x,
None => {
return Err(encode_error_response(
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
false,
false,
))
}
};
let query = match serde_qs::from_str::<UnaryGetQuery>(query_str) {
Ok(x) => x,
Err(err) => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong query, {}", err),
),
false,
false,
))
}
};
let message = if query.base64 == Some(1) {
use base64::{engine::general_purpose, Engine as _};
match general_purpose::URL_SAFE.decode(&query.message) {
Ok(x) => String::from_utf8_lossy(x.as_slice()).to_string(),
Err(err) => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong query.message, {}", err),
),
false,
false,
))
}
}
} else {
query.message.into()
};
if as_binary {
let message: M = M::decode(message.as_bytes()).map_err(|e| {
encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Failed to decode binary protobuf. {}", e),
),
as_binary,
for_streaming,
)
})?;
Ok(message)
} else {
let message: M = serde_json::from_str(&message).map_err(|e| {
encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Failed to decode json. {}", e),
),
as_binary,
for_streaming,
)
})?;
Ok(message)
}
}
pub(crate) async fn decode_request_payload<M, S>( pub(crate) async fn decode_request_payload<M, S>(
req: Request<Body>, req: Request<Body>,
state: &S, state: &S,

View file

@ -2,7 +2,7 @@ use std::{convert::Infallible, pin::Pin};
use axum::{ use axum::{
body::Body, body::Body,
http::{header, Request, StatusCode}, http::{header, Method, Request, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use futures::Future; use futures::Future;
@ -17,7 +17,8 @@ use crate::{
}; };
use super::codec::{ use super::codec::{
decode_check_headers, decode_request_payload, encode_error_response, ReqResInto, decode_check_headers, decode_check_query, decode_request_payload,
decode_request_payload_from_query, encode_error_response, ReqResInto,
}; };
pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState>: Clone + Send + Sized + 'static { pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState>: Clone + Send + Sized + 'static {
@ -138,9 +139,16 @@ macro_rules! impl_handler {
Box::pin(async move { Box::pin(async move {
let (mut parts, body) = req.into_parts(); let (mut parts, body) = req.into_parts();
let ReqResInto { binary } = match decode_check_headers(&mut parts, false) { let ReqResInto { binary } = if parts.method == Method::GET {
Ok(binary) => binary, match decode_check_query(&parts) {
Err(e) => return e, Ok(binary) => binary,
Err(e) => return e,
}
} else {
match decode_check_headers(&mut parts, false) {
Ok(binary) => binary,
Err(e) => return e,
}
}; };
let state = &state; let state = &state;
@ -155,11 +163,20 @@ macro_rules! impl_handler {
}; };
)* )*
let req = Request::from_parts(parts, body);
let proto_req: TMReq = match decode_request_payload(req, state, binary, false).await {
Ok(value) => value, let proto_req: TMReq = if parts.method == Method::GET {
Err(e) => return e, match decode_request_payload_from_query(&parts, state, binary) {
Ok(value) => value,
Err(e) => return e,
}
} else {
let req = Request::from_parts(parts, body);
match decode_request_payload(req, state, binary, false).await {
Ok(value) => value,
Err(e) => return e,
}
}; };
let res = self($($ty,)* proto_req).await.rpc_into_response(); let res = self($($ty,)* proto_req).await.rpc_into_response();