This commit is contained in:
Alec Thilenius 2024-02-19 18:57:27 -08:00
parent eb39acaf30
commit 8b87e63438
2 changed files with 67 additions and 38 deletions

View file

@ -10,7 +10,8 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::prelude::{RpcError, RpcErrorCode};
pub(crate) struct ReqResInto {
pub binary: bool,
pub streaming_req: bool,
pub binary_res: bool,
}
pub(crate) fn encode_error(e: &RpcError, for_streaming: bool) -> Vec<u8> {
@ -94,7 +95,7 @@ pub(crate) fn decode_check_query(parts: &request::Parts) -> Result<ReqResInto, R
}
};
let binary = match query.encoding.as_str() {
let binary_res = match query.encoding.as_str() {
"json" => false,
"proto" => true,
s => {
@ -109,12 +110,15 @@ pub(crate) fn decode_check_query(parts: &request::Parts) -> Result<ReqResInto, R
}
};
Ok(ReqResInto { binary })
Ok(ReqResInto {
binary_res,
streaming_req: false,
})
}
pub(crate) fn decode_check_headers(
parts: &mut request::Parts,
for_streaming: bool,
streaming_res: bool,
) -> Result<ReqResInto, Response> {
// Check the version header, if specified.
if let Some(version) = parts.headers.get("connect-protocol-version") {
@ -126,7 +130,7 @@ pub(crate) fn decode_check_headers(
format!("Unsupported protocol version: {}", version),
),
true,
for_streaming,
streaming_res,
));
}
}
@ -134,7 +138,7 @@ pub(crate) fn decode_check_headers(
// Decode the content type (binary or JSON).
// TODO: I'm not sure if this is correct. The Spec doesn't say what content type will be set for
// server-streaming responses.
let binary = match parts.headers.get("content-type") {
let (binary_res, streaming_req) = match parts.headers.get("content-type") {
Some(content_type) => match (
content_type
.to_str()
@ -144,20 +148,20 @@ pub(crate) fn decode_check_headers(
.next()
.unwrap_or_default()
.trim(),
for_streaming,
streaming_res,
) {
("application/json", false) => false,
("application/proto", false) => true,
("application/connect+json", true) => false,
("application/connect+proto", true) => true,
("application/json", false) => (false, false),
("application/proto", false) => (true, false),
("application/connect+json", true) => (false, true),
("application/connect+proto", true) => (true, true),
(s, _) => {
return Err(encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Wrong or unknown Content-Type: {}", s),
),
true,
true,
false,
streaming_res,
))
}
},
@ -167,13 +171,16 @@ pub(crate) fn decode_check_headers(
RpcErrorCode::InvalidArgument,
"Missing Content-Type header".to_string(),
),
true,
true,
false,
streaming_res,
))
}
};
Ok(ReqResInto { binary })
Ok(ReqResInto {
binary_res,
streaming_req,
})
}
pub(crate) fn decode_request_payload_from_query<M, S>(
@ -185,15 +192,15 @@ where
M: Message + DeserializeOwned + Default,
S: Send + Sync + 'static,
{
let for_streaming = false;
let streaming_res = false;
let query_str = match parts.uri.query() {
Some(x) => x,
None => {
return Err(encode_error_response(
&RpcError::new(RpcErrorCode::InvalidArgument, "Missing query".into()),
false,
false,
as_binary,
streaming_res,
))
}
};
@ -206,8 +213,8 @@ where
RpcErrorCode::InvalidArgument,
format!("Wrong query, {}", err),
),
false,
false,
as_binary,
streaming_res,
))
}
};
@ -223,8 +230,8 @@ where
RpcErrorCode::InvalidArgument,
format!("Wrong query.message, {}", err),
),
false,
false,
as_binary,
streaming_res,
))
}
}
@ -240,7 +247,7 @@ where
format!("Failed to decode binary protobuf. {}", e),
),
as_binary,
for_streaming,
streaming_res,
)
})?;
@ -253,7 +260,7 @@ where
format!("Failed to decode json. {}", e),
),
as_binary,
for_streaming,
streaming_res,
)
})?;
@ -264,15 +271,37 @@ where
pub(crate) async fn decode_request_payload<M, S>(
req: Request<Body>,
state: &S,
as_binary: bool,
for_streaming: bool,
binary_res: bool,
streaming_req: bool,
) -> Result<M, Response>
where
M: Message + DeserializeOwned + Default,
S: Send + Sync + 'static,
{
// Axum-connect only supports unary request types, so we can ignore for_streaming.
if as_binary {
let bytes = body::to_bytes(req.into_body(), usize::MAX)
.await
.map_err(|e| {
encode_error_response(
&RpcError::new(
RpcErrorCode::InvalidArgument,
format!("Failed to read request body. {}", e),
),
binary_res,
streaming_req,
)
})?;
// TODO: I need an answer to https://github.com/connectrpc/connect-es/issues/1024
// The spec doesn't seem to imply that a server-streaming response is allowed to treat the
// request as streaming (I guess with a single message?) if content-type is set to
// application/connect+*. That does seem to be how connect-es works though.
let json_bytes = if streaming_req {
// Strip and validate the envelope (the first 5 bytes).
let mut buf = [0; 5];
} else {
};
if binary_res {
let bytes = body::to_bytes(req.into_body(), usize::MAX)
.await
.map_err(|e| {
@ -281,8 +310,8 @@ where
RpcErrorCode::InvalidArgument,
format!("Failed to read request body. {}", e),
),
as_binary,
for_streaming,
binary_res,
streaming_req,
)
})?;
@ -292,8 +321,8 @@ where
RpcErrorCode::InvalidArgument,
format!("Failed to decode binary protobuf. {}", e),
),
as_binary,
for_streaming,
binary_res,
streaming_req,
)
})?;
@ -307,8 +336,8 @@ where
RpcErrorCode::InvalidArgument,
format!("Failed to read request body. {}", e),
),
as_binary,
for_streaming,
binary_res,
streaming_req,
));
}
};
@ -319,8 +348,8 @@ where
RpcErrorCode::InvalidArgument,
format!("Failed to decode JSON protobuf. {}", e),
),
as_binary,
for_streaming,
binary_res,
streaming_req,
)
})?;

View file

@ -161,7 +161,7 @@ macro_rules! impl_handler {
Box::pin(async move {
let (mut parts, body) = req.into_parts();
let ReqResInto { binary } = match decode_check_headers(&mut parts, true) {
let ReqResInto { binary_res, streaming_req } = match decode_check_headers(&mut parts, true) {
Ok(binary) => binary,
Err(e) => return e,
};