diff --git a/axum-connect-build/src/gen.rs b/axum-connect-build/src/gen.rs index f417f52..f3f33fc 100644 --- a/axum-connect-build/src/gen.rs +++ b/axum-connect-build/src/gen.rs @@ -70,6 +70,8 @@ impl AxumConnectServiceGenerator { } } } else { + let method_name_unary_get = format_ident!("{}_unary_get", method.name); + quote! { pub fn #method_name( handler: H @@ -91,6 +93,27 @@ impl AxumConnectServiceGenerator { ) } } + + pub fn #method_name_unary_get( + handler: H + ) -> impl FnOnce(axum::Router) -> axum_connect::router::RpcRouter + where + H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S>, + T: 'static, + S: Clone + Send + Sync + 'static, + { + move |router: axum::Router| { + router.route( + #path, + axum::routing::get(| + axum::extract::State(state): axum::extract::State, + request: axum::http::Request + | async move { + handler.call(request, state).await + }), + ) + } + } } } } diff --git a/axum-connect-examples/src/main.rs b/axum-connect-examples/src/main.rs index 48d5228..cec9c87 100644 --- a/axum-connect-examples/src/main.rs +++ b/axum-connect-examples/src/main.rs @@ -1,3 +1,10 @@ +/* +cargo run -p axum-connect-example + +curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%7D' -v +curl 'http://127.0.0.1:3030/hello.HelloWorldService/SayHello?encoding=json&message=%7B%22name%22%3A%22foo%22%7D' -v +*/ + use async_stream::stream; use axum::{extract::Host, Router}; use axum_connect::{futures::Stream, prelude::*}; @@ -16,6 +23,7 @@ async fn main() { // just a normal Rust function. Just like Axum, it also supports extractors! let app = Router::new() .rpc(HelloWorldService::say_hello(say_hello_success)) + .rpc(HelloWorldService::say_hello_unary_get(say_hello_success)) .rpc(HelloWorldService::say_hello_stream(say_hello_stream)); let listener = tokio::net::TcpListener::bind("127.0.0.1:3030") diff --git a/axum-connect/Cargo.toml b/axum-connect/Cargo.toml index be04a25..85d5c85 100644 --- a/axum-connect/Cargo.toml +++ b/axum-connect/Cargo.toml @@ -24,3 +24,6 @@ pbjson-types = "0.6.0" prost = "0.12.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" + +serde_qs = { version = "0.12.0" } +base64 = { version = "0.21.5" } diff --git a/axum-connect/src/handler/codec.rs b/axum-connect/src/handler/codec.rs index 9997271..f23202c 100644 --- a/axum-connect/src/handler/codec.rs +++ b/axum-connect/src/handler/codec.rs @@ -5,7 +5,7 @@ use axum::{ response::{IntoResponse, Response}, }; use prost::Message; -use serde::de::DeserializeOwned; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::prelude::{RpcError, RpcErrorCode}; @@ -59,6 +59,59 @@ pub(crate) fn encode_error_response( } } +#[derive(Deserialize, Serialize, Debug, Clone)] +pub(crate) struct UnaryGetQuery { + pub message: String, + pub encoding: String, + pub base64: Option, + pub compression: Option, + pub connect: Option, +} + +pub(crate) fn decode_check_query(parts: &request::Parts) -> Result { + let query_str = match parts.uri.query() { + Some(x) => x, + None => { + return Err(encode_error_response( + &RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()), + false, + false, + )) + } + }; + + let query = match serde_qs::from_str::(query_str) { + Ok(x) => x, + Err(err) => { + return Err(encode_error_response( + &RpcError::new( + RpcErrorCode::InvalidArgument, + format!("Wrong query, {}", err), + ), + false, + false, + )) + } + }; + + let binary = match query.encoding.as_str() { + "json" => false, + "proto" => true, + s => { + return Err(encode_error_response( + &RpcError::new( + RpcErrorCode::InvalidArgument, + format!("Wrong or unknown query.encoding: {}", s), + ), + true, + true, + )) + } + }; + + Ok(ReqResInto { binary }) +} + pub(crate) fn decode_check_headers( parts: &mut request::Parts, for_streaming: bool, @@ -123,6 +176,91 @@ pub(crate) fn decode_check_headers( Ok(ReqResInto { binary }) } +pub(crate) fn decode_request_payload_from_query( + parts: &request::Parts, + _state: &S, + as_binary: bool, +) -> Result +where + M: Message + DeserializeOwned + Default, + S: Send + Sync + 'static, +{ + let for_streaming = false; + + let query_str = match parts.uri.query() { + Some(x) => x, + None => { + return Err(encode_error_response( + &RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()), + false, + false, + )) + } + }; + + let query = match serde_qs::from_str::(query_str) { + Ok(x) => x, + Err(err) => { + return Err(encode_error_response( + &RpcError::new( + RpcErrorCode::InvalidArgument, + format!("Wrong query, {}", err), + ), + false, + false, + )) + } + }; + + let message = if query.base64 == Some(1) { + use base64::{engine::general_purpose, Engine as _}; + + match general_purpose::URL_SAFE.decode(&query.message) { + Ok(x) => String::from_utf8_lossy(x.as_slice()).to_string(), + Err(err) => { + return Err(encode_error_response( + &RpcError::new( + RpcErrorCode::InvalidArgument, + format!("Wrong query.message, {}", err), + ), + false, + false, + )) + } + } + } else { + query.message.into() + }; + + if as_binary { + let message: M = M::decode(message.as_bytes()).map_err(|e| { + encode_error_response( + &RpcError::new( + RpcErrorCode::InvalidArgument, + format!("Failed to decode binary protobuf. {}", e), + ), + as_binary, + for_streaming, + ) + })?; + + Ok(message) + } else { + let message: M = serde_json::from_str(&message).map_err(|e| { + encode_error_response( + &RpcError::new( + RpcErrorCode::InvalidArgument, + format!("Failed to decode json. {}", e), + ), + as_binary, + for_streaming, + ) + })?; + + Ok(message) + } +} + pub(crate) async fn decode_request_payload( req: Request, state: &S, diff --git a/axum-connect/src/handler/handler_unary.rs b/axum-connect/src/handler/handler_unary.rs index 66d1a41..65cb0ef 100644 --- a/axum-connect/src/handler/handler_unary.rs +++ b/axum-connect/src/handler/handler_unary.rs @@ -2,7 +2,7 @@ use std::{convert::Infallible, pin::Pin}; use axum::{ body::Body, - http::{header, Request, StatusCode}, + http::{header, Method, Request, StatusCode}, response::{IntoResponse, Response}, }; use futures::Future; @@ -17,7 +17,8 @@ use crate::{ }; use super::codec::{ - decode_check_headers, decode_request_payload, encode_error_response, ReqResInto, + decode_check_headers, decode_check_query, decode_request_payload, + decode_request_payload_from_query, encode_error_response, ReqResInto, }; pub trait RpcHandlerUnary: Clone + Send + Sized + 'static { @@ -138,9 +139,16 @@ macro_rules! impl_handler { Box::pin(async move { let (mut parts, body) = req.into_parts(); - let ReqResInto { binary } = match decode_check_headers(&mut parts, false) { - Ok(binary) => binary, - Err(e) => return e, + let ReqResInto { binary } = if parts.method == Method::GET { + match decode_check_query(&parts) { + Ok(binary) => binary, + Err(e) => return e, + } + } else { + match decode_check_headers(&mut parts, false) { + Ok(binary) => binary, + Err(e) => return e, + } }; let state = &state; @@ -155,11 +163,20 @@ macro_rules! impl_handler { }; )* - let req = Request::from_parts(parts, body); - let proto_req: TMReq = match decode_request_payload(req, state, binary, false).await { - Ok(value) => value, - Err(e) => return e, + + let proto_req: TMReq = if parts.method == Method::GET { + match decode_request_payload_from_query(&parts, state, binary) { + Ok(value) => value, + Err(e) => return e, + } + } else { + let req = Request::from_parts(parts, body); + + match decode_request_payload(req, state, binary, false).await { + Ok(value) => value, + Err(e) => return e, + } }; let res = self($($ty,)* proto_req).await.rpc_into_response();