Support for Unary-Get-Request

This commit is contained in:
vkill 2024-01-06 14:36:19 +08:00 committed by Alec Thilenius
parent 80d49d8ca5
commit 929027e321
5 changed files with 199 additions and 10 deletions

View file

@ -70,6 +70,8 @@ impl AxumConnectServiceGenerator {
}
}
} else {
let method_name_unary_get = format_ident!("{}_unary_get", method.name);
quote! {
pub fn #method_name<T, H, S>(
handler: H
@ -91,6 +93,27 @@ impl AxumConnectServiceGenerator {
)
}
}
pub fn #method_name_unary_get<T, H, S>(
handler: H
) -> impl FnOnce(axum::Router<S>) -> axum_connect::router::RpcRouter<S>
where
H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S>,
T: 'static,
S: Clone + Send + Sync + 'static,
{
move |router: axum::Router<S>| {
router.route(
#path,
axum::routing::get(|
axum::extract::State(state): axum::extract::State<S>,
request: axum::http::Request<axum::body::Body>
| async move {
handler.call(request, state).await
}),
)
}
}
}
}
}

View file

@ -1,3 +1,10 @@
/*
cargo run -p axum-connect-example
curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%7D' -v
curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%22name%22%3A%22foo%22%7D' -v
*/
use async_stream::stream;
use axum::{extract::Host, Router};
use axum_connect::{futures::Stream, prelude::*};
@ -16,6 +23,7 @@ async fn main() {
// just a normal Rust function. Just like Axum, it also supports extractors!
let app = Router::new()
.rpc(HelloWorldService::say_hello(say_hello_success))
.rpc(HelloWorldService::say_hello_unary_get(say_hello_success))
.rpc(HelloWorldService::say_hello_stream(say_hello_stream));
let listener = tokio::net::TcpListener::bind("127.0.0.1:3030")

View file

@ -24,3 +24,6 @@ pbjson-types = "0.6.0"
prost = "0.12.1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_qs = { version = "0.12.0" }
base64 = { version = "0.21.5" }

View file

@ -5,7 +5,7 @@ use axum::{
response::{IntoResponse, Response},
};
use prost::Message;
use serde::de::DeserializeOwned;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::prelude::{RpcError, RpcErrorCode};
@ -59,6 +59,59 @@ pub(crate) fn encode_error_response(
}
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub(crate) struct UnaryGetQuery {
pub message: String,
pub encoding: String,
pub base64: Option<usize>,
pub compression: Option<String>,
pub connect: Option<String>,
}
pub(crate) fn decode_check_query(parts: &request::Parts) -> Result<ReqResInto, Response> {
let query_str = match parts.uri.query() {
Some(x) => x,
None => {
return Err(encode_error_response(
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
false,
false,
))
}
};
let query = match serde_qs::from_str::<UnaryGetQuery>(query_str) {
Ok(x) => x,
Err(err) => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong query, {}", err),
),
false,
false,
))
}
};
let binary = match query.encoding.as_str() {
"json" => false,
"proto" => true,
s => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong or unknown query.encoding: {}", s),
),
true,
true,
))
}
};
Ok(ReqResInto { binary })
}
pub(crate) fn decode_check_headers(
parts: &mut request::Parts,
for_streaming: bool,
@ -123,6 +176,91 @@ pub(crate) fn decode_check_headers(
Ok(ReqResInto { binary })
}
pub(crate) fn decode_request_payload_from_query<M, S>(
parts: &request::Parts,
_state: &S,
as_binary: bool,
) -> Result<M, Response>
where
M: Message + DeserializeOwned + Default,
S: Send + Sync + 'static,
{
let for_streaming = false;
let query_str = match parts.uri.query() {
Some(x) => x,
None => {
return Err(encode_error_response(
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
false,
false,
))
}
};
let query = match serde_qs::from_str::<UnaryGetQuery>(query_str) {
Ok(x) => x,
Err(err) => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong query, {}", err),
),
false,
false,
))
}
};
let message = if query.base64 == Some(1) {
use base64::{engine::general_purpose, Engine as _};
match general_purpose::URL_SAFE.decode(&query.message) {
Ok(x) => String::from_utf8_lossy(x.as_slice()).to_string(),
Err(err) => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong query.message, {}", err),
),
false,
false,
))
}
}
} else {
query.message.into()
};
if as_binary {
let message: M = M::decode(message.as_bytes()).map_err(|e| {
encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Failed to decode binary protobuf. {}", e),
),
as_binary,
for_streaming,
)
})?;
Ok(message)
} else {
let message: M = serde_json::from_str(&message).map_err(|e| {
encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Failed to decode json. {}", e),
),
as_binary,
for_streaming,
)
})?;
Ok(message)
}
}
pub(crate) async fn decode_request_payload<M, S>(
req: Request<Body>,
state: &S,

View file

@ -2,7 +2,7 @@ use std::{convert::Infallible, pin::Pin};
use axum::{
body::Body,
http::{header, Request, StatusCode},
http::{header, Method, Request, StatusCode},
response::{IntoResponse, Response},
};
use futures::Future;
@ -17,7 +17,8 @@ use crate::{
};
use super::codec::{
decode_check_headers, decode_request_payload, encode_error_response, ReqResInto,
decode_check_headers, decode_check_query, decode_request_payload,
decode_request_payload_from_query, encode_error_response, ReqResInto,
};
pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState>: Clone + Send + Sized + 'static {
@ -138,9 +139,16 @@ macro_rules! impl_handler {
Box::pin(async move {
let (mut parts, body) = req.into_parts();
let ReqResInto { binary } = match decode_check_headers(&mut parts, false) {
let ReqResInto { binary } = if parts.method == Method::GET {
match decode_check_query(&parts) {
Ok(binary) => binary,
Err(e) => return e,
}
} else {
match decode_check_headers(&mut parts, false) {
Ok(binary) => binary,
Err(e) => return e,
}
};
let state = &state;
@ -155,11 +163,20 @@ macro_rules! impl_handler {
};
)*
let req = Request::from_parts(parts, body);
let proto_req: TMReq = match decode_request_payload(req, state, binary, false).await {
let proto_req: TMReq = if parts.method == Method::GET {
match decode_request_payload_from_query(&parts, state, binary) {
Ok(value) => value,
Err(e) => return e,
}
} else {
let req = Request::from_parts(parts, body);
match decode_request_payload(req, state, binary, false).await {
Ok(value) => value,
Err(e) => return e,
}
};
let res = self($($ty,)* proto_req).await.rpc_into_response();