Update to axum 0.7; remove TBody template types

This commit is contained in:
Alec Thilenius 2023-12-28 11:50:13 -05:00
parent a7cec357ba
commit 48752280e4
10 changed files with 53 additions and 89 deletions

View file

@ -6,9 +6,8 @@ framework](https://connect.build/docs/introduction) to Rust via idiomatic
# Axum Version # Axum Version
> Axum `0.7` isn't yet supported, because I haven't figured out how to handle - `axum-connect:0.3` works with `axum:0.7`
> streaming responses with it. Use Axum `0.6` for now. PRs welcome, and I'll get - `axum-connect:0.2` works with `axum:0.6`
> it updated when I can.
# Features 🔍 # Features 🔍

View file

@ -1,6 +1,6 @@
[package] [package]
name = "axum-connect-build" name = "axum-connect-build"
version = "0.2.0" version = "0.3.0"
authors = ["Alec Thilenius <alec@thilenius.com>"] authors = ["Alec Thilenius <alec@thilenius.com>"]
edition = "2021" edition = "2021"
categories = [ categories = [

View file

@ -48,23 +48,20 @@ impl AxumConnectServiceGenerator {
if method.server_streaming { if method.server_streaming {
quote! { quote! {
pub fn #method_name<T, H, S, B>( pub fn #method_name<T, H, S>(
handler: H handler: H
) -> impl FnOnce(axum::Router<S, B>) -> axum_connect::router::RpcRouter<S, B> ) -> impl FnOnce(axum::Router<S>) -> axum_connect::router::RpcRouter<S>
where where
H: axum_connect::handler::RpcHandlerStream<#input_type, #output_type, T, S, B>, H: axum_connect::handler::RpcHandlerStream<#input_type, #output_type, T, S>,
T: 'static, T: 'static,
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
B: axum::body::HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<axum::BoxError>,
{ {
move |router: axum::Router<S, B>| { move |router: axum::Router<S>| {
router.route( router.route(
#path, #path,
axum::routing::post(| axum::routing::post(|
axum::extract::State(state): axum::extract::State<S>, axum::extract::State(state): axum::extract::State<S>,
request: axum::http::Request<B> request: axum::http::Request<axum::body::Body>
| async move { | async move {
handler.call(request, state).await handler.call(request, state).await
}), }),
@ -74,23 +71,20 @@ impl AxumConnectServiceGenerator {
} }
} else { } else {
quote! { quote! {
pub fn #method_name<T, H, S, B>( pub fn #method_name<T, H, S>(
handler: H handler: H
) -> impl FnOnce(axum::Router<S, B>) -> axum_connect::router::RpcRouter<S, B> ) -> impl FnOnce(axum::Router<S>) -> axum_connect::router::RpcRouter<S>
where where
H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S, B>, H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S>,
T: 'static, T: 'static,
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
B: axum::body::HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<axum::BoxError>,
{ {
move |router: axum::Router<S, B>| { move |router: axum::Router<S>| {
router.route( router.route(
#path, #path,
axum::routing::post(| axum::routing::post(|
axum::extract::State(state): axum::extract::State<S>, axum::extract::State(state): axum::extract::State<S>,
request: axum::http::Request<B> request: axum::http::Request<axum::body::Body>
| async move { | async move {
handler.call(request, state).await handler.call(request, state).await
}), }),

View file

@ -5,7 +5,7 @@ edition = "2021"
[dependencies] [dependencies]
async-stream = "0.3.5" async-stream = "0.3.5"
axum = "0.6.9" axum = "0.7.2"
axum-connect = { path = "../axum-connect" } axum-connect = { path = "../axum-connect" }
prost = "0.12.1" prost = "0.12.1"
tokio = { version = "1.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] }

View file

@ -1,5 +1,3 @@
use std::net::SocketAddr;
use async_stream::stream; use async_stream::stream;
use axum::{extract::Host, Router}; use axum::{extract::Host, Router};
use axum_connect::{futures::Stream, prelude::*}; use axum_connect::{futures::Stream, prelude::*};
@ -20,13 +18,11 @@ async fn main() {
.rpc(HelloWorldService::say_hello(say_hello_success)) .rpc(HelloWorldService::say_hello(say_hello_success))
.rpc(HelloWorldService::say_hello_stream(say_hello_stream)); .rpc(HelloWorldService::say_hello_stream(say_hello_stream));
// Axum boilerplate to start the server. let listener = tokio::net::TcpListener::bind("127.0.0.1:3030")
let addr = SocketAddr::from(([127, 0, 0, 1], 3030));
println!("listening on http://{}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await .await
.unwrap(); .unwrap();
println!("listening on http://{:?}", listener.local_addr());
axum::serve(listener, app).await.unwrap();
} }
async fn say_hello_success(Host(host): Host, request: HelloRequest) -> HelloResponse { async fn say_hello_success(Host(host): Host, request: HelloRequest) -> HelloResponse {

View file

@ -1,6 +1,6 @@
[package] [package]
name = "axum-connect" name = "axum-connect"
version = "0.2.0" version = "0.3.0"
authors = ["Alec Thilenius <alec@thilenius.com>"] authors = ["Alec Thilenius <alec@thilenius.com>"]
edition = "2021" edition = "2021"
categories = [ categories = [
@ -17,7 +17,7 @@ repository = "https://github.com/AThilenius/axum-connect"
[dependencies] [dependencies]
async-stream = "0.3.5" async-stream = "0.3.5"
async-trait = "0.1.64" async-trait = "0.1.64"
axum = "0.6.9" axum = { version = "0.7.2", features = ["multipart"] }
futures = "0.3.26" futures = "0.3.26"
pbjson = "0.6.0" pbjson = "0.6.0"
pbjson-types = "0.6.0" pbjson-types = "0.6.0"

View file

@ -1,9 +1,8 @@
use axum::{ use axum::{
body::{Bytes, HttpBody}, body::{self, Body},
extract::FromRequest, extract::FromRequest,
http::{header, request, Request, StatusCode}, http::{header, request, Request, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
BoxError,
}; };
use prost::Message; use prost::Message;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -124,8 +123,8 @@ pub(crate) fn decode_check_headers(
Ok(ReqResInto { binary }) Ok(ReqResInto { binary })
} }
pub(crate) async fn decode_request_payload<M, S, B>( pub(crate) async fn decode_request_payload<M, S>(
req: Request<B>, req: Request<Body>,
state: &S, state: &S,
as_binary: bool, as_binary: bool,
for_streaming: bool, for_streaming: bool,
@ -133,26 +132,21 @@ pub(crate) async fn decode_request_payload<M, S, B>(
where where
M: Message + DeserializeOwned + Default, M: Message + DeserializeOwned + Default,
S: Send + Sync + 'static, S: Send + Sync + 'static,
B: Send + Sync + 'static,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{ {
// Axum-connect only supports unary request types, so we can ignore for_streaming. // Axum-connect only supports unary request types, so we can ignore for_streaming.
if as_binary { if as_binary {
let bytes = match Bytes::from_request(req, state).await { let bytes = body::to_bytes(req.into_body(), usize::MAX)
Ok(bytes) => bytes, .await
Err(e) => { .map_err(|e| {
return Err(encode_error_response( encode_error_response(
&RpcError::new( &RpcError::new(
RpcErrorCode::InvalidArgument, RpcErrorCode::InvalidArgument,
format!("Failed to read request body. {}", e), format!("Failed to read request body. {}", e),
), ),
as_binary, as_binary,
for_streaming, for_streaming,
)) )
} })?;
};
let message: M = M::decode(bytes).map_err(|e| { let message: M = M::decode(bytes).map_err(|e| {
encode_error_response( encode_error_response(

View file

@ -2,10 +2,9 @@ use std::{convert::Infallible, pin::Pin};
use async_stream::stream; use async_stream::stream;
use axum::{ use axum::{
body::{HttpBody, StreamBody}, body::Body,
http::{header, Request, StatusCode}, http::{header, Request, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
BoxError,
}; };
use futures::{Future, Stream, StreamExt}; use futures::{Future, Stream, StreamExt};
use prost::Message; use prost::Message;
@ -22,12 +21,10 @@ use super::codec::{
decode_check_headers, decode_request_payload, encode_error, encode_error_response, ReqResInto, decode_check_headers, decode_request_payload, encode_error, encode_error_response, ReqResInto,
}; };
pub trait RpcHandlerStream<TMReq, TMRes, TUid, TState, TBody>: pub trait RpcHandlerStream<TMReq, TMRes, TUid, TState>: Clone + Send + Sized + 'static {
Clone + Send + Sized + 'static
{
type Future: Future<Output = Response> + Send + 'static; type Future: Future<Output = Response> + Send + 'static;
fn call(self, req: Request<TBody>, state: TState) -> Self::Future; fn call(self, req: Request<Body>, state: TState) -> Self::Future;
} }
// TODO: Get "connect-timeout-ms" (number as string) and apply timeout. // TODO: Get "connect-timeout-ms" (number as string) and apply timeout.
@ -38,8 +35,8 @@ pub trait RpcHandlerStream<TMReq, TMRes, TUid, TState, TBody>:
// This is here because writing Rust macros sucks a**. So I uncomment this when I'm trying to modify // This is here because writing Rust macros sucks a**. So I uncomment this when I'm trying to modify
// the below macro. // the below macro.
// #[allow(unused_parens, non_snake_case, unused_mut)] // #[allow(unused_parens, non_snake_case, unused_mut)]
// impl<TMReq, TMRes, TInto, TFnItem, TFnFut, TFn, TState, TBody, T1> // impl<TMReq, TMRes, TInto, TFnItem, TFnFut, TFn, TState, T1>
// RpcHandlerStream<TMReq, TMRes, (T1, TMReq), TState, TBody> for TFn // RpcHandlerStream<TMReq, TMRes, (T1, TMReq), TState> for TFn
// where // where
// TMReq: Message + DeserializeOwned + Default + Send + 'static, // TMReq: Message + DeserializeOwned + Default + Send + 'static,
// TMRes: Message + Serialize + Send + 'static, // TMRes: Message + Serialize + Send + 'static,
@ -47,15 +44,12 @@ pub trait RpcHandlerStream<TMReq, TMRes, TUid, TState, TBody>:
// TFnItem: Stream<Item = TInto> + Send + Sized + 'static, // TFnItem: Stream<Item = TInto> + Send + Sized + 'static,
// TFnFut: Future<Output = TFnItem> + Send + Sync, // TFnFut: Future<Output = TFnItem> + Send + Sync,
// TFn: FnOnce(T1, TMReq) -> TFnFut + Clone + Send + Sync + 'static, // TFn: FnOnce(T1, TMReq) -> TFnFut + Clone + Send + Sync + 'static,
// TBody: HttpBody + Send + Sync + 'static,
// TBody::Data: Send,
// TBody::Error: Into<BoxError>,
// TState: Send + Sync + 'static, // TState: Send + Sync + 'static,
// T1: RpcFromRequestParts<TMRes, TState> + Send, // T1: RpcFromRequestParts<TMRes, TState> + Send,
// { // {
// type Future = Pin<Box<dyn Future<Output = Response> + Send>>; // type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
// fn call(self, req: Request<TBody>, state: TState) -> Self::Future { // fn call(self, req: Request<Body>, state: TState) -> Self::Future {
// Box::pin(async move { // Box::pin(async move {
// let (mut parts, body) = req.into_parts(); // let (mut parts, body) = req.into_parts();
@ -136,7 +130,7 @@ pub trait RpcHandlerStream<TMReq, TMRes, TUid, TState, TBody>:
// "application/connect+json" // "application/connect+json"
// }, // },
// )], // )],
// StreamBody::new(res), // Body::from_stream(res),
// ) // )
// .into_response() // .into_response()
// }) // })
@ -148,8 +142,8 @@ macro_rules! impl_handler {
[$($ty:ident),*] [$($ty:ident),*]
) => { ) => {
#[allow(unused_parens, non_snake_case, unused_mut)] #[allow(unused_parens, non_snake_case, unused_mut)]
impl<TMReq, TMRes, TInto, TFnItem, TFnFut, TFn, TState, TBody, $($ty,)*> impl<TMReq, TMRes, TInto, TFnItem, TFnFut, TFn, TState, $($ty,)*>
RpcHandlerStream<TMReq, TMRes, ($($ty,)* TMReq), TState, TBody> for TFn RpcHandlerStream<TMReq, TMRes, ($($ty,)* TMReq), TState> for TFn
where where
TMReq: Message + DeserializeOwned + Default + Send + 'static, TMReq: Message + DeserializeOwned + Default + Send + 'static,
TMRes: Message + Serialize + Send + 'static, TMRes: Message + Serialize + Send + 'static,
@ -157,16 +151,13 @@ macro_rules! impl_handler {
TFnItem: Stream<Item = TInto> + Send + Sized + 'static, TFnItem: Stream<Item = TInto> + Send + Sized + 'static,
TFnFut: Future<Output = TFnItem> + Send + Sync, TFnFut: Future<Output = TFnItem> + Send + Sync,
TFn: FnOnce($($ty,)* TMReq) -> TFnFut + Clone + Send + Sync + 'static, TFn: FnOnce($($ty,)* TMReq) -> TFnFut + Clone + Send + Sync + 'static,
TBody: HttpBody + Send + Sync + 'static,
TBody::Data: Send,
TBody::Error: Into<BoxError>,
TState: Send + Sync + 'static, TState: Send + Sync + 'static,
$( $ty: RpcFromRequestParts<TMRes, TState> + Send, )* $( $ty: RpcFromRequestParts<TMRes, TState> + Send, )*
{ {
type Future = Pin<Box<dyn Future<Output = Response> + Send>>; type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, req: Request<TBody>, state: TState) -> Self::Future { fn call(self, req: Request<Body>, state: TState) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let (mut parts, body) = req.into_parts(); let (mut parts, body) = req.into_parts();
@ -249,7 +240,7 @@ macro_rules! impl_handler {
"application/connect+json" "application/connect+json"
}, },
)], )],
StreamBody::new(res), Body::from_stream(res),
) )
.into_response() .into_response()
}) })

View file

@ -1,10 +1,9 @@
use std::{convert::Infallible, pin::Pin}; use std::{convert::Infallible, pin::Pin};
use axum::{ use axum::{
body::HttpBody, body::Body,
http::{header, Request, StatusCode}, http::{header, Request, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
BoxError,
}; };
use futures::Future; use futures::Future;
use prost::Message; use prost::Message;
@ -21,12 +20,10 @@ use super::codec::{
decode_check_headers, decode_request_payload, encode_error_response, ReqResInto, decode_check_headers, decode_request_payload, encode_error_response, ReqResInto,
}; };
pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState, TBody>: pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState>: Clone + Send + Sized + 'static {
Clone + Send + Sized + 'static
{
type Future: Future<Output = Response> + Send + 'static; type Future: Future<Output = Response> + Send + 'static;
fn call(self, req: Request<TBody>, state: TState) -> Self::Future; fn call(self, req: Request<Body>, state: TState) -> Self::Future;
} }
// This is for Unary. // This is for Unary.
@ -40,23 +37,20 @@ pub trait RpcHandlerUnary<TMReq, TMRes, TUid, TState, TBody>:
// This is here because writing Rust macros sucks a**. So I uncomment this when I'm trying to modify // This is here because writing Rust macros sucks a**. So I uncomment this when I'm trying to modify
// the below macro. // the below macro.
// #[allow(unused_parens, non_snake_case, unused_mut)] // #[allow(unused_parens, non_snake_case, unused_mut)]
// impl<TMReq, TMRes, TInto, TFnFut, TFn, TState, TBody, T1> // impl<TMReq, TMRes, TInto, TFnFut, TFn, TState, T1>
// RpcHandlerUnary<TMReq, TMRes, (T1, TMReq), TState, TBody> for TFn // RpcHandlerUnary<TMReq, TMRes, (T1, TMReq), TState> for TFn
// where // where
// TMReq: Message + DeserializeOwned + Default + Send + 'static, // TMReq: Message + DeserializeOwned + Default + Send + 'static,
// TMRes: Message + Serialize + Send + 'static, // TMRes: Message + Serialize + Send + 'static,
// TInto: RpcIntoResponse<TMRes>, // TInto: RpcIntoResponse<TMRes>,
// TFnFut: Future<Output = TInto> + Send, // TFnFut: Future<Output = TInto> + Send,
// TFn: FnOnce(T1, TMReq) -> TFnFut + Clone + Send + 'static, // TFn: FnOnce(T1, TMReq) -> TFnFut + Clone + Send + 'static,
// TBody: HttpBody + Send + Sync + 'static,
// TBody::Data: Send,
// TBody::Error: Into<BoxError>,
// TState: Send + Sync + 'static, // TState: Send + Sync + 'static,
// T1: RpcFromRequestParts<TMRes, TState> + Send, // T1: RpcFromRequestParts<TMRes, TState> + Send,
// { // {
// type Future = Pin<Box<dyn Future<Output = Response> + Send>>; // type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
// fn call(self, req: Request<TBody>, state: TState) -> Self::Future { // fn call(self, req: Request<Body>, state: TState) -> Self::Future {
// Box::pin(async move { // Box::pin(async move {
// let (mut parts, body) = req.into_parts(); // let (mut parts, body) = req.into_parts();
@ -127,23 +121,20 @@ macro_rules! impl_handler {
[$($ty:ident),*] [$($ty:ident),*]
) => { ) => {
#[allow(unused_parens, non_snake_case, unused_mut)] #[allow(unused_parens, non_snake_case, unused_mut)]
impl<TMReq, TMRes, TInto, TFnFut, TFn, TState, TBody, $($ty,)*> impl<TMReq, TMRes, TInto, TFnFut, TFn, TState, $($ty,)*>
RpcHandlerUnary<TMReq, TMRes, ($($ty,)* TMReq), TState, TBody> for TFn RpcHandlerUnary<TMReq, TMRes, ($($ty,)* TMReq), TState> for TFn
where where
TMReq: Message + DeserializeOwned + Default + Send + 'static, TMReq: Message + DeserializeOwned + Default + Send + 'static,
TMRes: Message + Serialize + Send + 'static, TMRes: Message + Serialize + Send + 'static,
TInto: RpcIntoResponse<TMRes>, TInto: RpcIntoResponse<TMRes>,
TFnFut: Future<Output = TInto> + Send, TFnFut: Future<Output = TInto> + Send,
TFn: FnOnce($($ty,)* TMReq) -> TFnFut + Clone + Send + 'static, TFn: FnOnce($($ty,)* TMReq) -> TFnFut + Clone + Send + 'static,
TBody: HttpBody + Send + Sync + 'static,
TBody::Data: Send,
TBody::Error: Into<BoxError>,
TState: Send + Sync + 'static, TState: Send + Sync + 'static,
$( $ty: RpcFromRequestParts<TMRes, TState> + Send, )* $( $ty: RpcFromRequestParts<TMRes, TState> + Send, )*
{ {
type Future = Pin<Box<dyn Future<Output = Response> + Send>>; type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, req: Request<TBody>, state: TState) -> Self::Future { fn call(self, req: Request<Body>, state: TState) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let (mut parts, body) = req.into_parts(); let (mut parts, body) = req.into_parts();

View file

@ -1,19 +1,18 @@
use axum::Router; use axum::Router;
pub trait RpcRouterExt<S, B>: Sized { pub trait RpcRouterExt<S>: Sized {
fn rpc<F>(self, register: F) -> Self fn rpc<F>(self, register: F) -> Self
where where
F: FnOnce(Self) -> RpcRouter<S, B>; F: FnOnce(Self) -> RpcRouter<S>;
} }
impl<S, B> RpcRouterExt<S, B> for Router<S, B> { impl<S> RpcRouterExt<S> for Router<S> {
fn rpc<F>(self, register: F) -> Self fn rpc<F>(self, register: F) -> Self
where where
F: FnOnce(Self) -> RpcRouter<S, B>, F: FnOnce(Self) -> RpcRouter<S>,
{ {
register(self) register(self)
// unsafe { std::mem::transmute::<RpcRouter<S, B>, Router<S, B>>(register(self)) }
} }
} }
pub type RpcRouter<S, B> = Router<S, B>; pub type RpcRouter<S> = Router<S>;