From 48752280e403171ff8a14a2d150902e9d0126710 Mon Sep 17 00:00:00 2001 From: Alec Thilenius Date: Thu, 28 Dec 2023 11:50:13 -0500 Subject: [PATCH] Update to axum 0.7; remove TBody template types --- README.md | 5 ++-- axum-connect-build/Cargo.toml | 2 +- axum-connect-build/src/gen.rs | 26 +++++++----------- axum-connect-examples/Cargo.toml | 2 +- axum-connect-examples/src/main.rs | 10 +++---- axum-connect/Cargo.toml | 4 +-- axum-connect/src/handler/codec.rs | 24 +++++++---------- axum-connect/src/handler/handler_stream.rs | 31 ++++++++-------------- axum-connect/src/handler/handler_unary.rs | 27 +++++++------------ axum-connect/src/router.rs | 11 ++++---- 10 files changed, 53 insertions(+), 89 deletions(-) diff --git a/README.md b/README.md index 9e0d5c5..2774cd4 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,8 @@ framework](https://connect.build/docs/introduction) to Rust via idiomatic # Axum Version -> Axum `0.7` isn't yet supported, because I haven't figured out how to handle -> streaming responses with it. Use Axum `0.6` for now. PRs welcome, and I'll get -> it updated when I can. +- `axum-connect:0.3` works with `axum:0.7` +- `axum-connect:0.2` works with `axum:0.6` # Features 🔍 diff --git a/axum-connect-build/Cargo.toml b/axum-connect-build/Cargo.toml index 14c2aca..b1202a9 100644 --- a/axum-connect-build/Cargo.toml +++ b/axum-connect-build/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "axum-connect-build" -version = "0.2.0" +version = "0.3.0" authors = ["Alec Thilenius "] edition = "2021" categories = [ diff --git a/axum-connect-build/src/gen.rs b/axum-connect-build/src/gen.rs index eb596cc..f417f52 100644 --- a/axum-connect-build/src/gen.rs +++ b/axum-connect-build/src/gen.rs @@ -48,23 +48,20 @@ impl AxumConnectServiceGenerator { if method.server_streaming { quote! { - pub fn #method_name( + pub fn #method_name( handler: H - ) -> impl FnOnce(axum::Router) -> axum_connect::router::RpcRouter + ) -> impl FnOnce(axum::Router) -> axum_connect::router::RpcRouter where - H: axum_connect::handler::RpcHandlerStream<#input_type, #output_type, T, S, B>, + H: axum_connect::handler::RpcHandlerStream<#input_type, #output_type, T, S>, T: 'static, S: Clone + Send + Sync + 'static, - B: axum::body::HttpBody + Send + 'static, - B::Data: Send, - B::Error: Into, { - move |router: axum::Router| { + move |router: axum::Router| { router.route( #path, axum::routing::post(| axum::extract::State(state): axum::extract::State, - request: axum::http::Request + request: axum::http::Request | async move { handler.call(request, state).await }), @@ -74,23 +71,20 @@ impl AxumConnectServiceGenerator { } } else { quote! { - pub fn #method_name( + pub fn #method_name( handler: H - ) -> impl FnOnce(axum::Router) -> axum_connect::router::RpcRouter + ) -> impl FnOnce(axum::Router) -> axum_connect::router::RpcRouter where - H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S, B>, + H: axum_connect::handler::RpcHandlerUnary<#input_type, #output_type, T, S>, T: 'static, S: Clone + Send + Sync + 'static, - B: axum::body::HttpBody + Send + 'static, - B::Data: Send, - B::Error: Into, { - move |router: axum::Router| { + move |router: axum::Router| { router.route( #path, axum::routing::post(| axum::extract::State(state): axum::extract::State, - request: axum::http::Request + request: axum::http::Request | async move { handler.call(request, state).await }), diff --git a/axum-connect-examples/Cargo.toml b/axum-connect-examples/Cargo.toml index 5f42262..3883a74 100644 --- a/axum-connect-examples/Cargo.toml +++ b/axum-connect-examples/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] async-stream = "0.3.5" -axum = "0.6.9" +axum = "0.7.2" axum-connect = { path = "../axum-connect" } prost = "0.12.1" tokio = { version = "1.0", features = ["full"] } diff --git a/axum-connect-examples/src/main.rs b/axum-connect-examples/src/main.rs index 2c12262..48d5228 100644 --- a/axum-connect-examples/src/main.rs +++ b/axum-connect-examples/src/main.rs @@ -1,5 +1,3 @@ -use std::net::SocketAddr; - use async_stream::stream; use axum::{extract::Host, Router}; use axum_connect::{futures::Stream, prelude::*}; @@ -20,13 +18,11 @@ async fn main() { .rpc(HelloWorldService::say_hello(say_hello_success)) .rpc(HelloWorldService::say_hello_stream(say_hello_stream)); - // Axum boilerplate to start the server. - let addr = SocketAddr::from(([127, 0, 0, 1], 3030)); - println!("listening on http://{}", addr); - axum::Server::bind(&addr) - .serve(app.into_make_service()) + let listener = tokio::net::TcpListener::bind("127.0.0.1:3030") .await .unwrap(); + println!("listening on http://{:?}", listener.local_addr()); + axum::serve(listener, app).await.unwrap(); } async fn say_hello_success(Host(host): Host, request: HelloRequest) -> HelloResponse { diff --git a/axum-connect/Cargo.toml b/axum-connect/Cargo.toml index 44562b2..be04a25 100644 --- a/axum-connect/Cargo.toml +++ b/axum-connect/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "axum-connect" -version = "0.2.0" +version = "0.3.0" authors = ["Alec Thilenius "] edition = "2021" categories = [ @@ -17,7 +17,7 @@ repository = "https://github.com/AThilenius/axum-connect" [dependencies] async-stream = "0.3.5" async-trait = "0.1.64" -axum = "0.6.9" +axum = { version = "0.7.2", features = ["multipart"] } futures = "0.3.26" pbjson = "0.6.0" pbjson-types = "0.6.0" diff --git a/axum-connect/src/handler/codec.rs b/axum-connect/src/handler/codec.rs index e78e5f3..9997271 100644 --- a/axum-connect/src/handler/codec.rs +++ b/axum-connect/src/handler/codec.rs @@ -1,9 +1,8 @@ use axum::{ - body::{Bytes, HttpBody}, + body::{self, Body}, extract::FromRequest, http::{header, request, Request, StatusCode}, response::{IntoResponse, Response}, - BoxError, }; use prost::Message; use serde::de::DeserializeOwned; @@ -124,8 +123,8 @@ pub(crate) fn decode_check_headers( Ok(ReqResInto { binary }) } -pub(crate) async fn decode_request_payload( - req: Request, +pub(crate) async fn decode_request_payload( + req: Request, state: &S, as_binary: bool, for_streaming: bool, @@ -133,26 +132,21 @@ pub(crate) async fn decode_request_payload( where M: Message + DeserializeOwned + Default, S: Send + Sync + 'static, - B: Send + Sync + 'static, - B: HttpBody + Send + 'static, - B::Data: Send, - B::Error: Into, { // Axum-connect only supports unary request types, so we can ignore for_streaming. if as_binary { - let bytes = match Bytes::from_request(req, state).await { - Ok(bytes) => bytes, - Err(e) => { - return Err(encode_error_response( + let bytes = body::to_bytes(req.into_body(), usize::MAX) + .await + .map_err(|e| { + encode_error_response( &RpcError::new( RpcErrorCode::InvalidArgument, format!("Failed to read request body. {}", e), ), as_binary, for_streaming, - )) - } - }; + ) + })?; let message: M = M::decode(bytes).map_err(|e| { encode_error_response( diff --git a/axum-connect/src/handler/handler_stream.rs b/axum-connect/src/handler/handler_stream.rs index 940bda2..b1d2893 100644 --- a/axum-connect/src/handler/handler_stream.rs +++ b/axum-connect/src/handler/handler_stream.rs @@ -2,10 +2,9 @@ use std::{convert::Infallible, pin::Pin}; use async_stream::stream; use axum::{ - body::{HttpBody, StreamBody}, + body::Body, http::{header, Request, StatusCode}, response::{IntoResponse, Response}, - BoxError, }; use futures::{Future, Stream, StreamExt}; use prost::Message; @@ -22,12 +21,10 @@ use super::codec::{ decode_check_headers, decode_request_payload, encode_error, encode_error_response, ReqResInto, }; -pub trait RpcHandlerStream: - Clone + Send + Sized + 'static -{ +pub trait RpcHandlerStream: Clone + Send + Sized + 'static { type Future: Future + Send + 'static; - fn call(self, req: Request, state: TState) -> Self::Future; + fn call(self, req: Request, state: TState) -> Self::Future; } // TODO: Get "connect-timeout-ms" (number as string) and apply timeout. @@ -38,8 +35,8 @@ pub trait RpcHandlerStream: // This is here because writing Rust macros sucks a**. So I uncomment this when I'm trying to modify // the below macro. // #[allow(unused_parens, non_snake_case, unused_mut)] -// impl -// RpcHandlerStream for TFn +// impl +// RpcHandlerStream for TFn // where // TMReq: Message + DeserializeOwned + Default + Send + 'static, // TMRes: Message + Serialize + Send + 'static, @@ -47,15 +44,12 @@ pub trait RpcHandlerStream: // TFnItem: Stream + Send + Sized + 'static, // TFnFut: Future + Send + Sync, // TFn: FnOnce(T1, TMReq) -> TFnFut + Clone + Send + Sync + 'static, -// TBody: HttpBody + Send + Sync + 'static, -// TBody::Data: Send, -// TBody::Error: Into, // TState: Send + Sync + 'static, // T1: RpcFromRequestParts + Send, // { // type Future = Pin + Send>>; -// fn call(self, req: Request, state: TState) -> Self::Future { +// fn call(self, req: Request, state: TState) -> Self::Future { // Box::pin(async move { // let (mut parts, body) = req.into_parts(); @@ -136,7 +130,7 @@ pub trait RpcHandlerStream: // "application/connect+json" // }, // )], -// StreamBody::new(res), +// Body::from_stream(res), // ) // .into_response() // }) @@ -148,8 +142,8 @@ macro_rules! impl_handler { [$($ty:ident),*] ) => { #[allow(unused_parens, non_snake_case, unused_mut)] - impl - RpcHandlerStream for TFn + impl + RpcHandlerStream for TFn where TMReq: Message + DeserializeOwned + Default + Send + 'static, TMRes: Message + Serialize + Send + 'static, @@ -157,16 +151,13 @@ macro_rules! impl_handler { TFnItem: Stream + Send + Sized + 'static, TFnFut: Future + Send + Sync, TFn: FnOnce($($ty,)* TMReq) -> TFnFut + Clone + Send + Sync + 'static, - TBody: HttpBody + Send + Sync + 'static, - TBody::Data: Send, - TBody::Error: Into, TState: Send + Sync + 'static, $( $ty: RpcFromRequestParts + Send, )* { type Future = Pin + Send>>; - fn call(self, req: Request, state: TState) -> Self::Future { + fn call(self, req: Request, state: TState) -> Self::Future { Box::pin(async move { let (mut parts, body) = req.into_parts(); @@ -249,7 +240,7 @@ macro_rules! impl_handler { "application/connect+json" }, )], - StreamBody::new(res), + Body::from_stream(res), ) .into_response() }) diff --git a/axum-connect/src/handler/handler_unary.rs b/axum-connect/src/handler/handler_unary.rs index 34931da..66d1a41 100644 --- a/axum-connect/src/handler/handler_unary.rs +++ b/axum-connect/src/handler/handler_unary.rs @@ -1,10 +1,9 @@ use std::{convert::Infallible, pin::Pin}; use axum::{ - body::HttpBody, + body::Body, http::{header, Request, StatusCode}, response::{IntoResponse, Response}, - BoxError, }; use futures::Future; use prost::Message; @@ -21,12 +20,10 @@ use super::codec::{ decode_check_headers, decode_request_payload, encode_error_response, ReqResInto, }; -pub trait RpcHandlerUnary: - Clone + Send + Sized + 'static -{ +pub trait RpcHandlerUnary: Clone + Send + Sized + 'static { type Future: Future + Send + 'static; - fn call(self, req: Request, state: TState) -> Self::Future; + fn call(self, req: Request, state: TState) -> Self::Future; } // This is for Unary. @@ -40,23 +37,20 @@ pub trait RpcHandlerUnary: // This is here because writing Rust macros sucks a**. So I uncomment this when I'm trying to modify // the below macro. // #[allow(unused_parens, non_snake_case, unused_mut)] -// impl -// RpcHandlerUnary for TFn +// impl +// RpcHandlerUnary for TFn // where // TMReq: Message + DeserializeOwned + Default + Send + 'static, // TMRes: Message + Serialize + Send + 'static, // TInto: RpcIntoResponse, // TFnFut: Future + Send, // TFn: FnOnce(T1, TMReq) -> TFnFut + Clone + Send + 'static, -// TBody: HttpBody + Send + Sync + 'static, -// TBody::Data: Send, -// TBody::Error: Into, // TState: Send + Sync + 'static, // T1: RpcFromRequestParts + Send, // { // type Future = Pin + Send>>; -// fn call(self, req: Request, state: TState) -> Self::Future { +// fn call(self, req: Request, state: TState) -> Self::Future { // Box::pin(async move { // let (mut parts, body) = req.into_parts(); @@ -127,23 +121,20 @@ macro_rules! impl_handler { [$($ty:ident),*] ) => { #[allow(unused_parens, non_snake_case, unused_mut)] - impl - RpcHandlerUnary for TFn + impl + RpcHandlerUnary for TFn where TMReq: Message + DeserializeOwned + Default + Send + 'static, TMRes: Message + Serialize + Send + 'static, TInto: RpcIntoResponse, TFnFut: Future + Send, TFn: FnOnce($($ty,)* TMReq) -> TFnFut + Clone + Send + 'static, - TBody: HttpBody + Send + Sync + 'static, - TBody::Data: Send, - TBody::Error: Into, TState: Send + Sync + 'static, $( $ty: RpcFromRequestParts + Send, )* { type Future = Pin + Send>>; - fn call(self, req: Request, state: TState) -> Self::Future { + fn call(self, req: Request, state: TState) -> Self::Future { Box::pin(async move { let (mut parts, body) = req.into_parts(); diff --git a/axum-connect/src/router.rs b/axum-connect/src/router.rs index 1c0bb65..85e6deb 100644 --- a/axum-connect/src/router.rs +++ b/axum-connect/src/router.rs @@ -1,19 +1,18 @@ use axum::Router; -pub trait RpcRouterExt: Sized { +pub trait RpcRouterExt: Sized { fn rpc(self, register: F) -> Self where - F: FnOnce(Self) -> RpcRouter; + F: FnOnce(Self) -> RpcRouter; } -impl RpcRouterExt for Router { +impl RpcRouterExt for Router { fn rpc(self, register: F) -> Self where - F: FnOnce(Self) -> RpcRouter, + F: FnOnce(Self) -> RpcRouter, { register(self) - // unsafe { std::mem::transmute::, Router>(register(self)) } } } -pub type RpcRouter = Router; +pub type RpcRouter = Router;