From 8b87e63438d6bddc6adc74824edefcc7ea51e7e5 Mon Sep 17 00:00:00 2001 From: Alec Thilenius Date: Mon, 19 Feb 2024 18:57:27 -0800 Subject: [PATCH] WIP --- axum-connect/src/handler/codec.rs | 103 +++++++++++++-------- axum-connect/src/handler/handler_stream.rs | 2 +- 2 files changed, 67 insertions(+), 38 deletions(-) diff --git a/axum-connect/src/handler/codec.rs b/axum-connect/src/handler/codec.rs index 6139a10..291fdba 100644 --- a/axum-connect/src/handler/codec.rs +++ b/axum-connect/src/handler/codec.rs @@ -10,7 +10,8 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::prelude::{RpcError, RpcErrorCode}; pub(crate) struct ReqResInto { - pub binary: bool, + pub streaming_req: bool, + pub binary_res: bool, } pub(crate) fn encode_error(e: &RpcError, for_streaming: bool) -> Vec { @@ -94,7 +95,7 @@ pub(crate) fn decode_check_query(parts: &request::Parts) -> Result false, "proto" => true, s => { @@ -109,12 +110,15 @@ pub(crate) fn decode_check_query(parts: &request::Parts) -> Result Result { // Check the version header, if specified. if let Some(version) = parts.headers.get("connect-protocol-version") { @@ -126,7 +130,7 @@ pub(crate) fn decode_check_headers( format!("Unsupported protocol version: {}", version), ), true, - for_streaming, + streaming_res, )); } } @@ -134,7 +138,7 @@ pub(crate) fn decode_check_headers( // Decode the content type (binary or JSON). // TODO: I'm not sure if this is correct. The Spec doesn't say what content type will be set for // server-streaming responses. - let binary = match parts.headers.get("content-type") { + let (binary_res, streaming_req) = match parts.headers.get("content-type") { Some(content_type) => match ( content_type .to_str() @@ -144,20 +148,20 @@ pub(crate) fn decode_check_headers( .next() .unwrap_or_default() .trim(), - for_streaming, + streaming_res, ) { - ("application/json", false) => false, - ("application/proto", false) => true, - ("application/connect+json", true) => false, - ("application/connect+proto", true) => true, + ("application/json", false) => (false, false), + ("application/proto", false) => (true, false), + ("application/connect+json", true) => (false, true), + ("application/connect+proto", true) => (true, true), (s, _) => { return Err(encode_error_response( &RpcError::new( RpcErrorCode::InvalidArgument, format!("Wrong or unknown Content-Type: {}", s), ), - true, - true, + false, + streaming_res, )) } }, @@ -167,13 +171,16 @@ pub(crate) fn decode_check_headers( RpcErrorCode::InvalidArgument, "Missing Content-Type header".to_string(), ), - true, - true, + false, + streaming_res, )) } }; - Ok(ReqResInto { binary }) + Ok(ReqResInto { + binary_res, + streaming_req, + }) } pub(crate) fn decode_request_payload_from_query( @@ -185,15 +192,15 @@ where M: Message + DeserializeOwned + Default, S: Send + Sync + 'static, { - let for_streaming = false; + let streaming_res = false; let query_str = match parts.uri.query() { Some(x) => x, None => { return Err(encode_error_response( &RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()), - false, - false, + as_binary, + streaming_res, )) } }; @@ -206,8 +213,8 @@ where RpcErrorCode::InvalidArgument, format!("Wrong query, {}", err), ), - false, - false, + as_binary, + streaming_res, )) } }; @@ -223,8 +230,8 @@ where RpcErrorCode::InvalidArgument, format!("Wrong query.message, {}", err), ), - false, - false, + as_binary, + streaming_res, )) } } @@ -240,7 +247,7 @@ where format!("Failed to decode binary protobuf. {}", e), ), as_binary, - for_streaming, + streaming_res, ) })?; @@ -253,7 +260,7 @@ where format!("Failed to decode json. {}", e), ), as_binary, - for_streaming, + streaming_res, ) })?; @@ -264,15 +271,37 @@ where pub(crate) async fn decode_request_payload( req: Request, state: &S, - as_binary: bool, - for_streaming: bool, + binary_res: bool, + streaming_req: bool, ) -> Result where M: Message + DeserializeOwned + Default, S: Send + Sync + 'static, { - // Axum-connect only supports unary request types, so we can ignore for_streaming. - if as_binary { + let bytes = body::to_bytes(req.into_body(), usize::MAX) + .await + .map_err(|e| { + encode_error_response( + &RpcError::new( + RpcErrorCode::InvalidArgument, + format!("Failed to read request body. {}", e), + ), + binary_res, + streaming_req, + ) + })?; + + // TODO: I need an answer to https://github.com/connectrpc/connect-es/issues/1024 + // The spec doesn't seem to imply that a server-streaming response is allowed to treat the + // request as streaming (I guess with a single message?) if content-type is set to + // application/connect+*. That does seem to be how connect-es works though. + let json_bytes = if streaming_req { + // Strip and validate the envelope (the first 5 bytes). + let mut buf = [0; 5]; + } else { + }; + + if binary_res { let bytes = body::to_bytes(req.into_body(), usize::MAX) .await .map_err(|e| { @@ -281,8 +310,8 @@ where RpcErrorCode::InvalidArgument, format!("Failed to read request body. {}", e), ), - as_binary, - for_streaming, + binary_res, + streaming_req, ) })?; @@ -292,8 +321,8 @@ where RpcErrorCode::InvalidArgument, format!("Failed to decode binary protobuf. {}", e), ), - as_binary, - for_streaming, + binary_res, + streaming_req, ) })?; @@ -307,8 +336,8 @@ where RpcErrorCode::InvalidArgument, format!("Failed to read request body. {}", e), ), - as_binary, - for_streaming, + binary_res, + streaming_req, )); } }; @@ -319,8 +348,8 @@ where RpcErrorCode::InvalidArgument, format!("Failed to decode JSON protobuf. {}", e), ), - as_binary, - for_streaming, + binary_res, + streaming_req, ) })?; diff --git a/axum-connect/src/handler/handler_stream.rs b/axum-connect/src/handler/handler_stream.rs index b1d2893..1e3d3cb 100644 --- a/axum-connect/src/handler/handler_stream.rs +++ b/axum-connect/src/handler/handler_stream.rs @@ -161,7 +161,7 @@ macro_rules! impl_handler { Box::pin(async move { let (mut parts, body) = req.into_parts(); - let ReqResInto { binary } = match decode_check_headers(&mut parts, true) { + let ReqResInto { binary_res, streaming_req } = match decode_check_headers(&mut parts, true) { Ok(binary) => binary, Err(e) => return e, };