diff --git a/Cargo.toml b/Cargo.toml index 16f57a7b..51d57ff4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,3 +10,19 @@ members = [ ] exclude = ["examples"] + +[workspace.dependencies] +base64 = "0.21" +bytes = "1" +futures = "0.3" +futures-channel = "0.3" +futures-util = "0.3" +http = "1.0" +http-body = "1.0" +http-body-util = "0.1" +http-serde = "2.0" +hyper = "1.0" +hyper-util = "0.1.1" +pin-project-lite = "0.2" +tower = "0.4" +tower-service = "0.3" diff --git a/examples/basic-streaming-response/Cargo.toml b/examples/basic-streaming-response/Cargo.toml index 4bbe66f4..e9b7499c 100644 --- a/examples/basic-streaming-response/Cargo.toml +++ b/examples/basic-streaming-response/Cargo.toml @@ -6,11 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -hyper = { version = "0.14", features = [ - "http1", - "client", - "stream", -] } lambda_runtime = { path = "../../lambda-runtime" } tokio = { version = "1", features = ["macros"] } tracing = { version = "0.1", features = ["log"] } diff --git a/examples/basic-streaming-response/src/main.rs b/examples/basic-streaming-response/src/main.rs index 9d505206..c8932554 100644 --- a/examples/basic-streaming-response/src/main.rs +++ b/examples/basic-streaming-response/src/main.rs @@ -1,12 +1,15 @@ -use hyper::body::Body; -use lambda_runtime::{service_fn, Error, LambdaEvent, StreamResponse}; +use lambda_runtime::{ + service_fn, + streaming::{channel, Body, Response}, + Error, LambdaEvent, +}; use serde_json::Value; use std::{thread, time::Duration}; -async fn func(_event: LambdaEvent) -> Result, Error> { +async fn func(_event: LambdaEvent) -> Result, Error> { let messages = vec!["Hello", "world", "from", "Lambda!"]; - let (mut tx, rx) = Body::channel(); + let (mut tx, rx) = channel(); tokio::spawn(async move { for message in messages.iter() { @@ -15,10 +18,7 @@ async fn func(_event: LambdaEvent) -> Result, Error> } }); - Ok(StreamResponse { - metadata_prelude: Default::default(), - stream: rx, - }) + Ok(Response::from(rx)) } #[tokio::main] diff --git a/examples/http-axum-diesel-ssl/Cargo.toml b/examples/http-axum-diesel-ssl/Cargo.toml index cdcdd4ef..006a82ce 100755 --- a/examples/http-axum-diesel-ssl/Cargo.toml +++ b/examples/http-axum-diesel-ssl/Cargo.toml @@ -11,7 +11,7 @@ edition = "2021" # and it will keep the alphabetic ordering for you. [dependencies] -axum = "0.6.4" +axum = "0.7" bb8 = "0.8.0" diesel = "2.0.3" diesel-async = { version = "0.2.1", features = ["postgres", "bb8"] } diff --git a/examples/http-axum-diesel/Cargo.toml b/examples/http-axum-diesel/Cargo.toml index 5a97cfab..0366f32d 100644 --- a/examples/http-axum-diesel/Cargo.toml +++ b/examples/http-axum-diesel/Cargo.toml @@ -11,7 +11,7 @@ edition = "2021" # and it will keep the alphabetic ordering for you. [dependencies] -axum = "0.6.4" +axum = "0.7" bb8 = "0.8.0" diesel = "2.0.3" diesel-async = { version = "0.2.1", features = ["postgres", "bb8"] } diff --git a/examples/http-axum/Cargo.toml b/examples/http-axum/Cargo.toml index 50db3ebf..88df4140 100644 --- a/examples/http-axum/Cargo.toml +++ b/examples/http-axum/Cargo.toml @@ -11,11 +11,10 @@ edition = "2021" # and it will keep the alphabetic ordering for you. [dependencies] +axum = "0.7" lambda_http = { path = "../../lambda-http" } lambda_runtime = { path = "../../lambda-runtime" } +serde_json = "1.0" tokio = { version = "1", features = ["macros"] } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] } - -axum = "0.6.4" -serde_json = "1.0" diff --git a/examples/http-cors/Cargo.toml b/examples/http-cors/Cargo.toml index 9fd7f25b..059a3f63 100644 --- a/examples/http-cors/Cargo.toml +++ b/examples/http-cors/Cargo.toml @@ -14,7 +14,7 @@ edition = "2021" lambda_http = { path = "../../lambda-http" } lambda_runtime = { path = "../../lambda-runtime" } tokio = { version = "1", features = ["macros"] } -tower-http = { version = "0.3.3", features = ["cors"] } +tower-http = { version = "0.5", features = ["cors"] } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] } diff --git a/examples/http-tower-trace/Cargo.toml b/examples/http-tower-trace/Cargo.toml index 2b8f7a60..0b0c46a9 100644 --- a/examples/http-tower-trace/Cargo.toml +++ b/examples/http-tower-trace/Cargo.toml @@ -14,6 +14,6 @@ edition = "2021" lambda_http = { path = "../../lambda-http" } lambda_runtime = "0.5.1" tokio = { version = "1", features = ["macros"] } -tower-http = { version = "0.3.4", features = ["trace"] } +tower-http = { version = "0.5", features = ["trace"] } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] } diff --git a/lambda-events/Cargo.toml b/lambda-events/Cargo.toml index b35809a2..29d4e191 100644 --- a/lambda-events/Cargo.toml +++ b/lambda-events/Cargo.toml @@ -16,25 +16,25 @@ categories = ["api-bindings", "encoding", "web-programming"] edition = "2021" [dependencies] -base64 = "0.21" -http = { version = "0.2", optional = true } -http-body = { version = "0.4", optional = true } -http-serde = { version = "^1", optional = true } -serde = { version = "^1", features = ["derive"] } -serde_with = { version = "^3", features = ["json"], optional = true } -serde_json = "^1" -serde_dynamo = { version = "^4.1", optional = true } -bytes = { version = "1", features = ["serde"], optional = true } +base64 = { workspace = true } +bytes = { workspace = true, features = ["serde"], optional = true } chrono = { version = "0.4.31", default-features = false, features = [ "clock", "serde", "std", ], optional = true } +flate2 = { version = "1.0.24", optional = true } +http = { workspace = true, optional = true } +http-body = { workspace = true, optional = true } +http-serde = { workspace = true, optional = true } query_map = { version = "^0.7", features = [ "serde", "url-query", ], optional = true } -flate2 = { version = "1.0.24", optional = true } +serde = { version = "^1", features = ["derive"] } +serde_with = { version = "^3", features = ["json"], optional = true } +serde_json = "^1" +serde_dynamo = { version = "^4.1", optional = true } [features] default = [ diff --git a/lambda-events/src/encodings/http.rs b/lambda-events/src/encodings/http.rs index effb48f4..1cb10c81 100644 --- a/lambda-events/src/encodings/http.rs +++ b/lambda-events/src/encodings/http.rs @@ -218,25 +218,6 @@ impl HttpBody for Body { type Data = Bytes; type Error = super::Error; - fn poll_data( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll>> { - let body = take(self.get_mut()); - Poll::Ready(match body { - Body::Empty => None, - Body::Text(s) => Some(Ok(s.into())), - Body::Binary(b) => Some(Ok(b.into())), - }) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } - fn is_end_stream(&self) -> bool { matches!(self, Body::Empty) } @@ -248,6 +229,18 @@ impl HttpBody for Body { Body::Binary(ref b) => SizeHint::with_exact(b.len() as u64), } } + + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll, Self::Error>>> { + let body = take(self.get_mut()); + Poll::Ready(match body { + Body::Empty => None, + Body::Text(s) => Some(Ok(http_body::Frame::data(s.into()))), + Body::Binary(b) => Some(Ok(http_body::Frame::data(b.into()))), + }) + } } #[cfg(test)] diff --git a/lambda-events/src/event/bedrock_agent_runtime/mod.rs b/lambda-events/src/event/bedrock_agent_runtime/mod.rs index 836d2f95..cf84d4d3 100644 --- a/lambda-events/src/event/bedrock_agent_runtime/mod.rs +++ b/lambda-events/src/event/bedrock_agent_runtime/mod.rs @@ -82,8 +82,6 @@ pub struct Agent { #[cfg(test)] mod tests { - use serde_json; - #[test] #[cfg(feature = "bedrock-agent-runtime")] fn example_bedrock_agent__runtime_event() { diff --git a/lambda-extension/Cargo.toml b/lambda-extension/Cargo.toml index 5d9d54b4..667b866f 100644 --- a/lambda-extension/Cargo.toml +++ b/lambda-extension/Cargo.toml @@ -15,15 +15,21 @@ readme = "README.md" [dependencies] async-stream = "0.3" -bytes = "1.0" +bytes = { workspace = true } chrono = { version = "0.4", features = ["serde"] } -http = "0.2" -hyper = { version = "0.14.20", features = ["http1", "client", "server", "stream", "runtime"] } +http = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["http1", "client", "server"] } +hyper-util = { workspace = true } lambda_runtime_api_client = { version = "0.8", path = "../lambda-runtime-api-client" } serde = { version = "1", features = ["derive"] } serde_json = "^1" tracing = { version = "0.1", features = ["log"] } -tokio = { version = "1.0", features = ["macros", "io-util", "sync", "rt-multi-thread"] } +tokio = { version = "1.0", features = [ + "macros", + "io-util", + "sync", + "rt-multi-thread", +] } tokio-stream = "0.1.2" tower = { version = "0.4", features = ["make", "util"] } - diff --git a/lambda-extension/src/error.rs b/lambda-extension/src/error.rs index 2c3e23b3..4f6a9909 100644 --- a/lambda-extension/src/error.rs +++ b/lambda-extension/src/error.rs @@ -1,5 +1,5 @@ /// Error type that extensions may result in -pub type Error = lambda_runtime_api_client::Error; +pub type Error = lambda_runtime_api_client::BoxError; /// Simple error that encapsulates human readable descriptions #[derive(Clone, Debug, PartialEq, Eq)] diff --git a/lambda-extension/src/extension.rs b/lambda-extension/src/extension.rs index d653e0dc..cac1c7ec 100644 --- a/lambda-extension/src/extension.rs +++ b/lambda-extension/src/extension.rs @@ -1,13 +1,18 @@ +use http::Request; +use http_body_util::BodyExt; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::service_fn; + +use hyper_util::rt::tokio::TokioIo; +use lambda_runtime_api_client::Client; use std::{ convert::Infallible, fmt, future::ready, future::Future, net::SocketAddr, path::PathBuf, pin::Pin, sync::Arc, }; - -use hyper::{server::conn::AddrStream, Server}; -use lambda_runtime_api_client::Client; -use tokio::sync::Mutex; +use tokio::{net::TcpListener, sync::Mutex}; use tokio_stream::StreamExt; -use tower::{service_fn, MakeService, Service, ServiceExt}; -use tracing::{error, trace}; +use tower::{MakeService, Service, ServiceExt}; +use tracing::trace; use crate::{ logs::*, @@ -64,22 +69,22 @@ impl<'a, E, L, T> Extension<'a, E, L, T> where E: Service, E::Future: Future>, - E::Error: Into> + fmt::Display + fmt::Debug, + E::Error: Into + fmt::Display + fmt::Debug, // Fixme: 'static bound might be too restrictive L: MakeService<(), Vec, Response = ()> + Send + Sync + 'static, L::Service: Service, Response = ()> + Send + Sync, >>::Future: Send + 'a, - L::Error: Into> + fmt::Debug, - L::MakeError: Into> + fmt::Debug, + L::Error: Into + fmt::Debug, + L::MakeError: Into + fmt::Debug, L::Future: Send, // Fixme: 'static bound might be too restrictive T: MakeService<(), Vec, Response = ()> + Send + Sync + 'static, T::Service: Service, Response = ()> + Send + Sync, >>::Future: Send + 'a, - T::Error: Into> + fmt::Debug, - T::MakeError: Into> + fmt::Debug, + T::Error: Into + fmt::Debug, + T::MakeError: Into + fmt::Debug, T::Future: Send, { /// Create a new [`Extension`] with a given extension name @@ -104,7 +109,7 @@ where where N: Service, N::Future: Future>, - N::Error: Into> + fmt::Display, + N::Error: Into + fmt::Display, { Extension { events_processor: ep, @@ -126,7 +131,7 @@ where where N: Service<()>, N::Future: Future>, - N::Error: Into> + fmt::Display, + N::Error: Into + fmt::Display, { Extension { logs_processor: Some(lp), @@ -173,7 +178,7 @@ where where N: Service<()>, N::Future: Future>, - N::Error: Into> + fmt::Display, + N::Error: Into + fmt::Display, { Extension { telemetry_processor: Some(lp), @@ -235,22 +240,27 @@ where validate_buffering_configuration(self.log_buffering)?; - // Spawn task to run processor let addr = SocketAddr::from(([0, 0, 0, 0], self.log_port_number)); - let make_service = service_fn(move |_socket: &AddrStream| { - trace!("Creating new log processor Service"); - let service = log_processor.make_service(()); - async move { - let service = Arc::new(Mutex::new(service.await?)); - Ok::<_, L::MakeError>(service_fn(move |req| log_wrapper(service.clone(), req))) - } - }); - let server = Server::bind(&addr).serve(make_service); - tokio::spawn(async move { - if let Err(e) = server.await { - error!("Error while running log processor: {}", e); + let service = log_processor.make_service(()); + let service = Arc::new(Mutex::new(service.await.unwrap())); + tokio::task::spawn(async move { + trace!("Creating new logs processor Service"); + + loop { + let service: Arc> = service.clone(); + let make_service = service_fn(move |req: Request| log_wrapper(service.clone(), req)); + + let listener = TcpListener::bind(addr).await.unwrap(); + let (tcp, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(tcp); + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await { + println!("Error serving connection: {:?}", err); + } + }); } }); + trace!("Log processor started"); // Call Logs API to start receiving events @@ -276,22 +286,27 @@ where validate_buffering_configuration(self.telemetry_buffering)?; - // Spawn task to run processor let addr = SocketAddr::from(([0, 0, 0, 0], self.telemetry_port_number)); - let make_service = service_fn(move |_socket: &AddrStream| { + let service = telemetry_processor.make_service(()); + let service = Arc::new(Mutex::new(service.await.unwrap())); + tokio::task::spawn(async move { trace!("Creating new telemetry processor Service"); - let service = telemetry_processor.make_service(()); - async move { - let service = Arc::new(Mutex::new(service.await?)); - Ok::<_, T::MakeError>(service_fn(move |req| telemetry_wrapper(service.clone(), req))) - } - }); - let server = Server::bind(&addr).serve(make_service); - tokio::spawn(async move { - if let Err(e) = server.await { - error!("Error while running telemetry processor: {}", e); + + loop { + let service = service.clone(); + let make_service = service_fn(move |req| telemetry_wrapper(service.clone(), req)); + + let listener = TcpListener::bind(addr).await.unwrap(); + let (tcp, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(tcp); + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await { + println!("Error serving connection: {:?}", err); + } + }); } }); + trace!("Telemetry processor started"); // Call Telemetry API to start receiving events @@ -361,7 +376,7 @@ where let event = event?; let (_parts, body) = event.into_parts(); - let body = hyper::body::to_bytes(body).await?; + let body = body.collect().await?.to_bytes(); trace!("{}", std::str::from_utf8(&body)?); // this may be very verbose let event: NextEvent = serde_json::from_slice(&body)?; let is_invoke = event.is_invoke(); diff --git a/lambda-extension/src/logs.rs b/lambda-extension/src/logs.rs index c453c951..4d1948a0 100644 --- a/lambda-extension/src/logs.rs +++ b/lambda-extension/src/logs.rs @@ -1,6 +1,10 @@ use chrono::{DateTime, Utc}; +use http::{Request, Response}; +use http_body_util::BodyExt; +use hyper::body::Incoming; +use lambda_runtime_api_client::body::Body; use serde::{Deserialize, Serialize}; -use std::{boxed::Box, fmt, sync::Arc}; +use std::{fmt, sync::Arc}; use tokio::sync::Mutex; use tower::Service; use tracing::{error, trace}; @@ -186,34 +190,31 @@ pub(crate) fn validate_buffering_configuration(log_buffering: Option` for the /// underlying `Service` to process. -pub(crate) async fn log_wrapper( - service: Arc>, - req: hyper::Request, -) -> Result, Box> +pub(crate) async fn log_wrapper(service: Arc>, req: Request) -> Result, Error> where S: Service, Response = ()>, - S::Error: Into> + fmt::Debug, + S::Error: Into + fmt::Debug, S::Future: Send, { trace!("Received logs request"); // Parse the request body as a Vec - let body = match hyper::body::to_bytes(req.into_body()).await { + let body = match req.into_body().collect().await { Ok(body) => body, Err(e) => { error!("Error reading logs request body: {}", e); return Ok(hyper::Response::builder() .status(hyper::StatusCode::BAD_REQUEST) - .body(hyper::Body::empty()) + .body(Body::empty()) .unwrap()); } }; - let logs: Vec = match serde_json::from_slice(&body) { + let logs: Vec = match serde_json::from_slice(&body.to_bytes()) { Ok(logs) => logs, Err(e) => { error!("Error parsing logs: {}", e); return Ok(hyper::Response::builder() .status(hyper::StatusCode::BAD_REQUEST) - .body(hyper::Body::empty()) + .body(Body::empty()) .unwrap()); } }; @@ -226,7 +227,7 @@ where } } - Ok(hyper::Response::new(hyper::Body::empty())) + Ok(hyper::Response::new(Body::empty())) } #[cfg(test)] diff --git a/lambda-extension/src/requests.rs b/lambda-extension/src/requests.rs index 75c24a0f..4d5f1527 100644 --- a/lambda-extension/src/requests.rs +++ b/lambda-extension/src/requests.rs @@ -1,7 +1,6 @@ use crate::{Error, LogBuffering}; use http::{Method, Request}; -use hyper::Body; -use lambda_runtime_api_client::build_request; +use lambda_runtime_api_client::{body::Body, build_request}; use serde::Serialize; const EXTENSION_NAME_HEADER: &str = "Lambda-Extension-Name"; diff --git a/lambda-extension/src/telemetry.rs b/lambda-extension/src/telemetry.rs index b3131338..1e83ee8e 100644 --- a/lambda-extension/src/telemetry.rs +++ b/lambda-extension/src/telemetry.rs @@ -1,4 +1,8 @@ use chrono::{DateTime, Utc}; +use http::{Request, Response}; +use http_body_util::BodyExt; +use hyper::body::Incoming; +use lambda_runtime_api_client::body::Body; use serde::Deserialize; use std::{boxed::Box, fmt, sync::Arc}; use tokio::sync::Mutex; @@ -256,8 +260,8 @@ pub struct RuntimeDoneMetrics { /// underlying `Service` to process. pub(crate) async fn telemetry_wrapper( service: Arc>, - req: hyper::Request, -) -> Result, Box> + req: Request, +) -> Result, Box> where S: Service, Response = ()>, S::Error: Into> + fmt::Debug, @@ -265,24 +269,24 @@ where { trace!("Received telemetry request"); // Parse the request body as a Vec - let body = match hyper::body::to_bytes(req.into_body()).await { + let body = match req.into_body().collect().await { Ok(body) => body, Err(e) => { error!("Error reading telemetry request body: {}", e); return Ok(hyper::Response::builder() .status(hyper::StatusCode::BAD_REQUEST) - .body(hyper::Body::empty()) + .body(Body::empty()) .unwrap()); } }; - let telemetry: Vec = match serde_json::from_slice(&body) { + let telemetry: Vec = match serde_json::from_slice(&body.to_bytes()) { Ok(telemetry) => telemetry, Err(e) => { error!("Error parsing telemetry: {}", e); return Ok(hyper::Response::builder() .status(hyper::StatusCode::BAD_REQUEST) - .body(hyper::Body::empty()) + .body(Body::empty()) .unwrap()); } }; @@ -295,7 +299,7 @@ where } } - Ok(hyper::Response::new(hyper::Body::empty())) + Ok(hyper::Response::new(Body::empty())) } #[cfg(test)] diff --git a/lambda-http/Cargo.toml b/lambda-http/Cargo.toml index fc93d88f..057134a6 100644 --- a/lambda-http/Cargo.toml +++ b/lambda-http/Cargo.toml @@ -23,21 +23,24 @@ apigw_websockets = [] alb = [] [dependencies] -base64 = "0.21" -bytes = "1.4" -futures = "0.3" -http = "0.2" -http-body = "0.4" -hyper = "0.14" +base64 = { workspace = true } +bytes = { workspace = true } +encoding_rs = "0.8" +futures = { workspace = true } +futures-util = { workspace = true } +http = { workspace = true } +http-body = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true } lambda_runtime = { path = "../lambda-runtime", version = "0.8.3" } +mime = "0.3" +percent-encoding = "2.2" +pin-project-lite = { workspace = true } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_urlencoded = "0.7" tokio-stream = "0.1.2" -mime = "0.3" -encoding_rs = "0.8" url = "2.2" -percent-encoding = "2.2" [dependencies.aws_lambda_events] path = "../lambda-events" @@ -46,6 +49,7 @@ default-features = false features = ["alb", "apigw"] [dev-dependencies] +lambda_runtime_api_client = { version = "0.8", path = "../lambda-runtime-api-client" } log = "^0.4" maplit = "1.0" tokio = { version = "1.0", features = ["macros"] } diff --git a/lambda-http/src/ext/extensions.rs b/lambda-http/src/ext/extensions.rs index 313090c6..cfbdaec2 100644 --- a/lambda-http/src/ext/extensions.rs +++ b/lambda-http/src/ext/extensions.rs @@ -7,20 +7,24 @@ use lambda_runtime::Context; use crate::request::RequestContext; /// ALB/API gateway pre-parsed http query string parameters +#[derive(Clone)] pub(crate) struct QueryStringParameters(pub(crate) QueryMap); /// API gateway pre-extracted url path parameters /// /// These will always be empty for ALB requests +#[derive(Clone)] pub(crate) struct PathParameters(pub(crate) QueryMap); /// API gateway configured /// [stage variables](https://docs.aws.amazon.com/apigateway/latest/developerguide/stage-variables.html) /// /// These will always be empty for ALB requests +#[derive(Clone)] pub(crate) struct StageVariables(pub(crate) QueryMap); /// ALB/API gateway raw http path without any stage information +#[derive(Clone)] pub(crate) struct RawHttpPath(pub(crate) String); /// Extensions for [`lambda_http::Request`], `http::request::Parts`, and `http::Extensions` structs diff --git a/lambda-http/src/response.rs b/lambda-http/src/response.rs index e77ec181..d26ef838 100644 --- a/lambda-http/src/response.rs +++ b/lambda-http/src/response.rs @@ -13,7 +13,7 @@ use http::header::CONTENT_ENCODING; use http::HeaderMap; use http::{header::CONTENT_TYPE, Response, StatusCode}; use http_body::Body as HttpBody; -use hyper::body::to_bytes; +use http_body_util::BodyExt; use mime::{Mime, CHARSET}; use serde::Serialize; use std::borrow::Cow; @@ -305,7 +305,15 @@ where B::Data: Send, B::Error: fmt::Debug, { - Box::pin(async move { Body::from(to_bytes(body).await.expect("unable to read bytes from body").to_vec()) }) + Box::pin(async move { + Body::from( + body.collect() + .await + .expect("unable to read bytes from body") + .to_bytes() + .to_vec(), + ) + }) } fn convert_to_text(body: B, content_type: &str) -> BodyFuture @@ -326,7 +334,7 @@ where // assumes utf-8 Box::pin(async move { - let bytes = to_bytes(body).await.expect("unable to read bytes from body"); + let bytes = body.collect().await.expect("unable to read bytes from body").to_bytes(); let (content, _, _) = encoding.decode(&bytes); match content { @@ -345,7 +353,7 @@ mod tests { header::{CONTENT_ENCODING, CONTENT_TYPE}, Response, StatusCode, }; - use hyper::Body as HyperBody; + use lambda_runtime_api_client::body::Body as HyperBody; use serde_json::{self, json}; const SVG_LOGO: &str = include_str!("../tests/data/svg_logo.svg"); diff --git a/lambda-http/src/streaming.rs b/lambda-http/src/streaming.rs index a59cf700..601e699b 100644 --- a/lambda-http/src/streaming.rs +++ b/lambda-http/src/streaming.rs @@ -63,19 +63,11 @@ where lambda_runtime::run(svc).await } +pin_project_lite::pin_project! { pub struct BodyStream { + #[pin] pub(crate) body: B, } - -impl BodyStream -where - B: Body + Unpin + Send + 'static, - B::Data: Into + Send, - B::Error: Into + Send + Debug, -{ - fn project(self: Pin<&mut Self>) -> Pin<&mut B> { - unsafe { self.map_unchecked_mut(|s| &mut s.body) } - } } impl Stream for BodyStream @@ -86,8 +78,14 @@ where { type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let body = self.project(); - body.poll_data(cx) + #[inline] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match futures_util::ready!(self.as_mut().project().body.poll_frame(cx)?) { + Some(frame) => match frame.into_data() { + Ok(data) => Poll::Ready(Some(Ok(data))), + Err(_frame) => Poll::Ready(None), + }, + None => Poll::Ready(None), + } } } diff --git a/lambda-runtime-api-client/Cargo.toml b/lambda-runtime-api-client/Cargo.toml index ae3e0f70..8868b145 100644 --- a/lambda-runtime-api-client/Cargo.toml +++ b/lambda-runtime-api-client/Cargo.toml @@ -4,7 +4,7 @@ version = "0.8.0" edition = "2021" authors = [ "David Calavera ", - "Harold Sun " + "Harold Sun ", ] description = "AWS Lambda Runtime interaction API" license = "Apache-2.0" @@ -14,7 +14,19 @@ keywords = ["AWS", "Lambda", "API"] readme = "README.md" [dependencies] -http = "0.2" -hyper = { version = "0.14.20", features = ["http1", "client", "stream", "tcp"] } -tower-service = "0.3" +bytes = { workspace = true } +futures-channel = { workspace = true } +futures-util = { workspace = true } +http = { workspace = true } +http-body = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["http1", "client"] } +hyper-util = { workspace = true, features = [ + "client", + "client-legacy", + "http1", + "tokio", +] } +tower = { workspace = true, features = ["util"] } +tower-service = { workspace = true } tokio = { version = "1.0", features = ["io-util"] } diff --git a/lambda-runtime-api-client/src/body/channel.rs b/lambda-runtime-api-client/src/body/channel.rs new file mode 100644 index 00000000..815de5f2 --- /dev/null +++ b/lambda-runtime-api-client/src/body/channel.rs @@ -0,0 +1,110 @@ +//! Body::channel utilities. Extracted from Hyper under MIT license. +//! https://github.com/hyperium/hyper/blob/master/LICENSE + +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use crate::body::{sender, watch}; +use bytes::Bytes; +use futures_channel::mpsc; +use futures_channel::oneshot; +use futures_util::{stream::FusedStream, Future, Stream}; +use http::HeaderMap; +use http_body::Body; +use http_body::Frame; +use http_body::SizeHint; +pub use sender::Sender; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub(crate) struct DecodedLength(u64); + +impl DecodedLength { + pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); + pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); + pub(crate) const ZERO: DecodedLength = DecodedLength(0); + + pub(crate) fn sub_if(&mut self, amt: u64) { + match *self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), + DecodedLength(ref mut known) => { + *known -= amt; + } + } + } + + /// Converts to an Option representing a Known or Unknown length. + pub(crate) fn into_opt(self) -> Option { + match self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None, + DecodedLength(known) => Some(known), + } + } +} + +pub struct ChannelBody { + content_length: DecodedLength, + want_tx: watch::Sender, + data_rx: mpsc::Receiver>, + trailers_rx: oneshot::Receiver, +} + +pub fn channel() -> (Sender, ChannelBody) { + let (data_tx, data_rx) = mpsc::channel(0); + let (trailers_tx, trailers_rx) = oneshot::channel(); + + let (want_tx, want_rx) = watch::channel(sender::WANT_READY); + + let tx = Sender { + want_rx, + data_tx, + trailers_tx: Some(trailers_tx), + }; + let rx = ChannelBody { + content_length: DecodedLength::CHUNKED, + want_tx, + data_rx, + trailers_rx, + }; + + (tx, rx) +} + +impl Body for ChannelBody { + type Data = Bytes; + type Error = crate::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + self.want_tx.send(sender::WANT_READY); + + if !self.data_rx.is_terminated() { + if let Some(chunk) = ready!(Pin::new(&mut self.data_rx).poll_next(cx)?) { + self.content_length.sub_if(chunk.len() as u64); + return Poll::Ready(Some(Ok(Frame::data(chunk)))); + } + } + + // check trailers after data is terminated + match ready!(Pin::new(&mut self.trailers_rx).poll(cx)) { + Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))), + Err(_) => Poll::Ready(None), + } + } + + fn is_end_stream(&self) -> bool { + self.content_length == DecodedLength::ZERO + } + + fn size_hint(&self) -> SizeHint { + let mut hint = SizeHint::default(); + + if let Some(content_length) = self.content_length.into_opt() { + hint.set_exact(content_length); + } + + hint + } +} diff --git a/lambda-runtime-api-client/src/body/mod.rs b/lambda-runtime-api-client/src/body/mod.rs new file mode 100644 index 00000000..7e2d597c --- /dev/null +++ b/lambda-runtime-api-client/src/body/mod.rs @@ -0,0 +1,143 @@ +//! HTTP body utilities. Extracted from Axum under MIT license. +//! https://github.com/tokio-rs/axum/blob/main/axum/LICENSE + +use crate::{BoxError, Error}; +use bytes::Bytes; +use futures_util::stream::Stream; +use http_body::{Body as _, Frame}; +use http_body_util::{BodyExt, Collected}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use self::channel::Sender; + +macro_rules! ready { + ($e:expr) => { + match $e { + std::task::Poll::Ready(v) => v, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} + +mod channel; +pub mod sender; +mod watch; + +type BoxBody = http_body_util::combinators::UnsyncBoxBody; + +fn boxed(body: B) -> BoxBody +where + B: http_body::Body + Send + 'static, + B::Error: Into, +{ + try_downcast(body).unwrap_or_else(|body| body.map_err(Error::new).boxed_unsync()) +} + +pub(crate) fn try_downcast(k: K) -> Result +where + T: 'static, + K: Send + 'static, +{ + let mut k = Some(k); + if let Some(k) = ::downcast_mut::>(&mut k) { + Ok(k.take().unwrap()) + } else { + Err(k.unwrap()) + } +} + +/// The body type used in axum requests and responses. +#[derive(Debug)] +pub struct Body(BoxBody); + +impl Body { + /// Create a new `Body` that wraps another [`http_body::Body`]. + pub fn new(body: B) -> Self + where + B: http_body::Body + Send + 'static, + B::Error: Into, + { + try_downcast(body).unwrap_or_else(|body| Self(boxed(body))) + } + + /// Create an empty body. + pub fn empty() -> Self { + Self::new(http_body_util::Empty::new()) + } + + /// Create a new `Body` stream with associated Sender half. + pub fn channel() -> (Sender, Body) { + let (sender, body) = channel::channel(); + (sender, Body::new(body)) + } + + /// Collect the body into `Bytes` + pub async fn collect(self) -> Result, Error> { + self.0.collect().await + } +} + +impl Default for Body { + fn default() -> Self { + Self::empty() + } +} + +macro_rules! body_from_impl { + ($ty:ty) => { + impl From<$ty> for Body { + fn from(buf: $ty) -> Self { + Self::new(http_body_util::Full::from(buf)) + } + } + }; +} + +body_from_impl!(&'static [u8]); +body_from_impl!(std::borrow::Cow<'static, [u8]>); +body_from_impl!(Vec); + +body_from_impl!(&'static str); +body_from_impl!(std::borrow::Cow<'static, str>); +body_from_impl!(String); + +body_from_impl!(Bytes); + +impl http_body::Body for Body { + type Data = Bytes; + type Error = Error; + + #[inline] + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Pin::new(&mut self.0).poll_frame(cx) + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.0.size_hint() + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.0.is_end_stream() + } +} + +impl Stream for Body { + type Item = Result; + + #[inline] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match futures_util::ready!(Pin::new(&mut self).poll_frame(cx)?) { + Some(frame) => match frame.into_data() { + Ok(data) => Poll::Ready(Some(Ok(data))), + Err(_frame) => Poll::Ready(None), + }, + None => Poll::Ready(None), + } + } +} diff --git a/lambda-runtime-api-client/src/body/sender.rs b/lambda-runtime-api-client/src/body/sender.rs new file mode 100644 index 00000000..0e008454 --- /dev/null +++ b/lambda-runtime-api-client/src/body/sender.rs @@ -0,0 +1,135 @@ +//! Body::channel utilities. Extracted from Hyper under MIT license. +//! https://github.com/hyperium/hyper/blob/master/LICENSE + +use crate::Error; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use futures_channel::{mpsc, oneshot}; +use http::HeaderMap; + +use super::watch; + +type BodySender = mpsc::Sender>; +type TrailersSender = oneshot::Sender; + +pub(crate) const WANT_PENDING: usize = 1; +pub(crate) const WANT_READY: usize = 2; + +/// A sender half created through [`Body::channel()`]. +/// +/// Useful when wanting to stream chunks from another thread. +/// +/// ## Body Closing +/// +/// Note that the request body will always be closed normally when the sender is dropped (meaning +/// that the empty terminating chunk will be sent to the remote). If you desire to close the +/// connection with an incomplete response (e.g. in the case of an error during asynchronous +/// processing), call the [`Sender::abort()`] method to abort the body in an abnormal fashion. +/// +/// [`Body::channel()`]: struct.Body.html#method.channel +/// [`Sender::abort()`]: struct.Sender.html#method.abort +#[must_use = "Sender does nothing unless sent on"] +pub struct Sender { + pub(crate) want_rx: watch::Receiver, + pub(crate) data_tx: BodySender, + pub(crate) trailers_tx: Option, +} + +impl Sender { + /// Check to see if this `Sender` can send more data. + pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // Check if the receiver end has tried polling for the body yet + ready!(self.poll_want(cx)?); + self.data_tx + .poll_ready(cx) + .map_err(|_| Error::new(SenderError::ChannelClosed)) + } + + fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.want_rx.load(cx) { + WANT_READY => Poll::Ready(Ok(())), + WANT_PENDING => Poll::Pending, + watch::CLOSED => Poll::Ready(Err(Error::new(SenderError::ChannelClosed))), + unexpected => unreachable!("want_rx value: {}", unexpected), + } + } + + async fn ready(&mut self) -> Result<(), Error> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + + /// Send data on data channel when it is ready. + #[allow(unused)] + pub async fn send_data(&mut self, chunk: Bytes) -> Result<(), Error> { + self.ready().await?; + self.data_tx + .try_send(Ok(chunk)) + .map_err(|_| Error::new(SenderError::ChannelClosed)) + } + + /// Send trailers on trailers channel. + #[allow(unused)] + pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> { + let tx = match self.trailers_tx.take() { + Some(tx) => tx, + None => return Err(Error::new(SenderError::ChannelClosed)), + }; + tx.send(trailers).map_err(|_| Error::new(SenderError::ChannelClosed)) + } + + /// Try to send data on this channel. + /// + /// # Errors + /// + /// Returns `Err(Bytes)` if the channel could not (currently) accept + /// another `Bytes`. + /// + /// # Note + /// + /// This is mostly useful for when trying to send from some other thread + /// that doesn't have an async context. If in an async context, prefer + /// `send_data()` instead. + pub fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { + self.data_tx + .try_send(Ok(chunk)) + .map_err(|err| err.into_inner().expect("just sent Ok")) + } + + /// Send a `SenderError::BodyWriteAborted` error and terminate the stream. + #[allow(unused)] + pub fn abort(mut self) { + self.send_error(Error::new(SenderError::BodyWriteAborted)); + } + + /// Terminate the stream with an error. + pub fn send_error(&mut self, err: Error) { + let _ = self + .data_tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(err)); + } +} + +#[derive(Debug)] +enum SenderError { + ChannelClosed, + BodyWriteAborted, +} + +impl SenderError { + fn description(&self) -> &str { + match self { + SenderError::BodyWriteAborted => "user body write aborted", + SenderError::ChannelClosed => "channel closed", + } + } +} + +impl std::fmt::Display for SenderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.description()) + } +} +impl std::error::Error for SenderError {} diff --git a/lambda-runtime-api-client/src/body/watch.rs b/lambda-runtime-api-client/src/body/watch.rs new file mode 100644 index 00000000..ac0bd4ee --- /dev/null +++ b/lambda-runtime-api-client/src/body/watch.rs @@ -0,0 +1,69 @@ +//! Body::channel utilities. Extracted from Hyper under MIT license. +//! https://github.com/hyperium/hyper/blob/master/LICENSE + +//! An SPSC broadcast channel. +//! +//! - The value can only be a `usize`. +//! - The consumer is only notified if the value is different. +//! - The value `0` is reserved for closed. + +use futures_util::task::AtomicWaker; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::task; + +type Value = usize; + +pub(crate) const CLOSED: usize = 0; + +pub(crate) fn channel(initial: Value) -> (Sender, Receiver) { + debug_assert!(initial != CLOSED, "watch::channel initial state of 0 is reserved"); + + let shared = Arc::new(Shared { + value: AtomicUsize::new(initial), + waker: AtomicWaker::new(), + }); + + (Sender { shared: shared.clone() }, Receiver { shared }) +} + +pub(crate) struct Sender { + shared: Arc, +} + +pub(crate) struct Receiver { + shared: Arc, +} + +struct Shared { + value: AtomicUsize, + waker: AtomicWaker, +} + +impl Sender { + pub(crate) fn send(&mut self, value: Value) { + if self.shared.value.swap(value, Ordering::SeqCst) != value { + self.shared.waker.wake(); + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.send(CLOSED); + } +} + +impl Receiver { + pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { + self.shared.waker.register(cx.waker()); + self.shared.value.load(Ordering::SeqCst) + } + + #[allow(unused)] + pub(crate) fn peek(&self) -> Value { + self.shared.value.load(Ordering::Relaxed) + } +} diff --git a/lambda-runtime-api-client/src/error.rs b/lambda-runtime-api-client/src/error.rs new file mode 100644 index 00000000..dbb87b64 --- /dev/null +++ b/lambda-runtime-api-client/src/error.rs @@ -0,0 +1,33 @@ +//! Extracted from Axum under MIT license. +//! https://github.com/tokio-rs/axum/blob/main/axum/LICENSE +use std::{error::Error as StdError, fmt}; +pub use tower::BoxError; +/// Errors that can happen when using axum. +#[derive(Debug)] +pub struct Error { + inner: BoxError, +} + +impl Error { + /// Create a new `Error` from a boxable error. + pub fn new(error: impl Into) -> Self { + Self { inner: error.into() } + } + + /// Convert an `Error` back into the underlying boxed trait object. + pub fn into_inner(self) -> BoxError { + self.inner + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(&*self.inner) + } +} diff --git a/lambda-runtime-api-client/src/lib.rs b/lambda-runtime-api-client/src/lib.rs index 4b082aba..15185f81 100644 --- a/lambda-runtime-api-client/src/lib.rs +++ b/lambda-runtime-api-client/src/lib.rs @@ -5,20 +5,18 @@ //! This crate includes a base HTTP client to interact with //! the AWS Lambda Runtime API. use http::{uri::PathAndQuery, uri::Scheme, Request, Response, Uri}; -use hyper::{ - client::{connect::Connection, HttpConnector}, - Body, -}; +use hyper::body::Incoming; +use hyper_util::client::legacy::connect::{Connect, Connection, HttpConnector}; use std::{convert::TryInto, fmt::Debug}; -use tokio::io::{AsyncRead, AsyncWrite}; use tower_service::Service; const USER_AGENT_HEADER: &str = "User-Agent"; const DEFAULT_USER_AGENT: &str = concat!("aws-lambda-rust/", env!("CARGO_PKG_VERSION")); const CUSTOM_USER_AGENT: Option<&str> = option_env!("LAMBDA_RUNTIME_USER_AGENT"); -/// Error type that lambdas may result in -pub type Error = Box; +mod error; +pub use error::*; +pub mod body; /// API client to interact with the AWS Lambda Runtime API. #[derive(Debug)] @@ -26,7 +24,7 @@ pub struct Client { /// The runtime API URI pub base: Uri, /// The client that manages the API connections - pub client: hyper::Client, + pub client: hyper_util::client::legacy::Client, } impl Client { @@ -41,25 +39,24 @@ impl Client { impl Client where - C: hyper::client::connect::Connect + Sync + Send + Clone + 'static, + C: Connect + Sync + Send + Clone + 'static, { /// Send a given request to the Runtime API. /// Use the client's base URI to ensure the API endpoint is correct. - pub async fn call(&self, req: Request) -> Result, Error> { + pub async fn call(&self, req: Request) -> Result, BoxError> { let req = self.set_origin(req)?; - let response = self.client.request(req).await?; - Ok(response) + self.client.request(req).await.map_err(Into::into) } /// Create a new client with a given base URI and HTTP connector. pub fn with(base: Uri, connector: C) -> Self { - let client = hyper::Client::builder() + let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) .http1_max_buf_size(1024 * 1024) .build(connector); Self { base, client } } - fn set_origin(&self, req: Request) -> Result, Error> { + fn set_origin(&self, req: Request) -> Result, BoxError> { let (mut parts, body) = req.into_parts(); let (scheme, authority, base_path) = { let scheme = self.base.scheme().unwrap_or(&Scheme::HTTP); @@ -83,7 +80,7 @@ where } /// Builder implementation to construct any Runtime API clients. -pub struct ClientBuilder = hyper::client::HttpConnector> { +pub struct ClientBuilder = HttpConnector> { connector: C, uri: Option, } @@ -93,7 +90,7 @@ where C: Service + Clone + Send + Sync + Unpin + 'static, >::Future: Unpin + Send, >::Error: Into>, - >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + >::Response: Connection + Unpin + Send + 'static, { /// Create a new builder with a given HTTP connector. pub fn with_connector(self, connector: C2) -> ClientBuilder @@ -101,7 +98,7 @@ where C2: Service + Clone + Send + Sync + Unpin + 'static, >::Future: Unpin + Send, >::Error: Into>, - >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + >::Response: Connection + Unpin + Send + 'static, { ClientBuilder { connector, @@ -116,7 +113,10 @@ where } /// Create the new client to interact with the Runtime API. - pub fn build(self) -> Result, Error> { + pub fn build(self) -> Result, Error> + where + C: Connect + Sync + Send + Clone + 'static, + { let uri = match self.uri { Some(uri) => uri, None => { diff --git a/lambda-runtime/Cargo.toml b/lambda-runtime/Cargo.toml index 335b5482..8a579d99 100644 --- a/lambda-runtime/Cargo.toml +++ b/lambda-runtime/Cargo.toml @@ -18,30 +18,45 @@ default = ["simulated"] simulated = [] [dependencies] +async-stream = "0.3" +base64 = { workspace = true } +bytes = { workspace = true } +futures = { workspace = true } +http = { workspace = true } +http-body = { workspace = true } +http-body-util = { workspace = true } +http-serde = { workspace = true } +hyper = { workspace = true, features = [ + "http1", + "client", +] } +hyper-util = { workspace = true, features = [ + "client", + "client-legacy", + "http1", + "tokio", +] } +lambda_runtime_api_client = { version = "0.8", path = "../lambda-runtime-api-client" } +serde = { version = "1", features = ["derive", "rc"] } +serde_json = "^1" +serde_path_to_error = "0.1.11" tokio = { version = "1.0", features = [ "macros", "io-util", "sync", "rt-multi-thread", ] } -# Hyper requires the `server` feature to work on nightly -hyper = { version = "0.14.20", features = [ - "http1", +tokio-stream = "0.1.2" +tower = { workspace = true, features = ["util"] } +tracing = { version = "0.1", features = ["log"] } + +[dev-dependencies] +hyper-util = { workspace = true, features = [ "client", - "stream", + "client-legacy", + "http1", "server", + "server-auto", + "tokio", ] } -futures = "0.3" -serde = { version = "1", features = ["derive", "rc"] } -serde_json = "^1" -bytes = "1.0" -http = "0.2" -async-stream = "0.3" -tracing = { version = "0.1.37", features = ["log"] } -tower = { version = "0.4", features = ["util"] } -tokio-stream = "0.1.2" -lambda_runtime_api_client = { version = "0.8", path = "../lambda-runtime-api-client" } -serde_path_to_error = "0.1.11" -http-serde = "1.1.3" -base64 = "0.21.0" -http-body = "0.4" +pin-project-lite = { workspace = true } \ No newline at end of file diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index ccd35ab0..33679d2a 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -9,22 +9,18 @@ //! and runs the Lambda runtime. use bytes::Bytes; use futures::FutureExt; -use hyper::{ - client::{connect::Connection, HttpConnector}, - http::Request, - Body, -}; -use lambda_runtime_api_client::Client; +use http_body_util::BodyExt; +use hyper::{body::Incoming, http::Request}; +use hyper_util::client::legacy::connect::{Connect, Connection, HttpConnector}; +use lambda_runtime_api_client::{body::Body, BoxError, Client}; use serde::{Deserialize, Serialize}; use std::{ env, fmt::{self, Debug, Display}, future::Future, - marker::PhantomData, panic, sync::Arc, }; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio_stream::{Stream, StreamExt}; pub use tower::{self, service_fn, Service}; use tower::{util::ServiceFn, ServiceExt}; @@ -34,6 +30,8 @@ mod deserializer; mod requests; #[cfg(test)] mod simulated; +/// Utilities for Lambda Streaming functions. +pub mod streaming; /// Types available to a Lambda function. mod types; @@ -43,7 +41,7 @@ pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, Me use types::invoke_request_id; /// Error type that lambdas may result in -pub type Error = lambda_runtime_api_client::Error; +pub type Error = lambda_runtime_api_client::BoxError; /// Configuration derived from environment variables. #[derive(Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize)] @@ -95,16 +93,16 @@ struct Runtime = HttpConnector> { impl Runtime where - C: Service + Clone + Send + Sync + Unpin + 'static, + C: Service + Connect + Clone + Send + Sync + Unpin + 'static, C::Future: Unpin + Send, C::Error: Into>, - C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + C::Response: Connection + Unpin + Send + 'static, { async fn run( &self, - incoming: impl Stream, Error>> + Send, + incoming: impl Stream, Error>> + Send, mut handler: F, - ) -> Result<(), Error> + ) -> Result<(), BoxError> where F: Service>, F::Future: Future>, @@ -137,7 +135,7 @@ where // Group the handling in one future and instrument it with the span async { - let body = hyper::body::to_bytes(body).await?; + let body = body.collect().await?.to_bytes(); trace!("response body - {}", std::str::from_utf8(&body)?); #[cfg(debug_assertions)] @@ -170,13 +168,7 @@ where Ok(response) => match response { Ok(response) => { trace!("Ok response from handler (run loop)"); - EventCompletionRequest { - request_id, - body: response, - _unused_b: PhantomData, - _unused_s: PhantomData, - } - .into_req() + EventCompletionRequest::new(request_id, response).into_req() } Err(err) => build_event_error_request(request_id, err), }, @@ -205,12 +197,12 @@ where } } -fn incoming(client: &Client) -> impl Stream, Error>> + Send + '_ +fn incoming(client: &Client) -> impl Stream, Error>> + Send + '_ where - C: Service + Clone + Send + Sync + Unpin + 'static, + C: Service + Connect + Clone + Send + Sync + Unpin + 'static, >::Future: Unpin + Send, >::Error: Into>, - >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + >::Response: Connection + Unpin + Send + 'static, { async_stream::stream! { loop { @@ -294,20 +286,23 @@ mod endpoint_tests { }; use futures::future::BoxFuture; use http::{uri::PathAndQuery, HeaderValue, Method, Request, Response, StatusCode, Uri}; - use hyper::{server::conn::Http, service::service_fn, Body}; - use lambda_runtime_api_client::Client; + use hyper::body::Incoming; + use hyper::rt::{Read, Write}; + use hyper::service::service_fn; + + use hyper_util::server::conn::auto::Builder; + use lambda_runtime_api_client::{body::Body, Client}; use serde_json::json; use simulated::DuplexStreamWrapper; - use std::{convert::TryFrom, env, marker::PhantomData, sync::Arc}; + use std::{convert::TryFrom, env, sync::Arc}; use tokio::{ - io::{self, AsyncRead, AsyncWrite}, - select, + io, select, sync::{self, oneshot}, }; use tokio_stream::StreamExt; #[cfg(test)] - async fn next_event(req: &Request) -> Result, Error> { + async fn next_event(req: &Request) -> Result, Error> { let path = "/2018-06-01/runtime/invocation/next"; assert_eq!(req.method(), Method::GET); assert_eq!(req.uri().path_and_query().unwrap(), &PathAndQuery::from_static(path)); @@ -324,7 +319,7 @@ mod endpoint_tests { } #[cfg(test)] - async fn complete_event(req: &Request, id: &str) -> Result, Error> { + async fn complete_event(req: &Request, id: &str) -> Result, Error> { assert_eq!(Method::POST, req.method()); let rsp = Response::builder() .status(StatusCode::ACCEPTED) @@ -338,7 +333,7 @@ mod endpoint_tests { } #[cfg(test)] - async fn event_err(req: &Request, id: &str) -> Result, Error> { + async fn event_err(req: &Request, id: &str) -> Result, Error> { let expected = format!("/2018-06-01/runtime/invocation/{id}/error"); assert_eq!(expected, req.uri().path()); @@ -352,7 +347,7 @@ mod endpoint_tests { } #[cfg(test)] - async fn handle_incoming(req: Request) -> Result, Error> { + async fn handle_incoming(req: Request) -> Result, Error> { let path: Vec<&str> = req .uri() .path_and_query() @@ -370,11 +365,14 @@ mod endpoint_tests { } #[cfg(test)] - async fn handle(io: I, rx: oneshot::Receiver<()>) -> Result<(), hyper::Error> + async fn handle(io: I, rx: oneshot::Receiver<()>) -> Result<(), Error> where - I: AsyncRead + AsyncWrite + Unpin + 'static, + I: Read + Write + Unpin + 'static, { - let conn = Http::new().serve_connection(io, service_fn(handle_incoming)); + use hyper_util::rt::TokioExecutor; + + let builder = Builder::new(TokioExecutor::new()); + let conn = builder.serve_connection(io, service_fn(handle_incoming)); select! { _ = rx => { Ok(()) @@ -397,7 +395,9 @@ mod endpoint_tests { let (tx, rx) = sync::oneshot::channel(); let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + handle(DuplexStreamWrapper::new(server), rx) + .await + .expect("Unable to handle request"); }); let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; @@ -426,18 +426,15 @@ mod endpoint_tests { let base = Uri::from_static("http://localhost:9001"); let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + handle(DuplexStreamWrapper::new(server), rx) + .await + .expect("Unable to handle request"); }); let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; let client = Client::with(base, conn); - let req = EventCompletionRequest { - request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", - body: "done", - _unused_b: PhantomData::<&str>, - _unused_s: PhantomData::, - }; + let req = EventCompletionRequest::new("156cb537-e2d4-11e8-9b34-d36013741fb9", "done"); let req = req.into_req()?; let rsp = client.call(req).await?; @@ -459,7 +456,9 @@ mod endpoint_tests { let base = Uri::from_static("http://localhost:9001"); let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + handle(DuplexStreamWrapper::new(server), rx) + .await + .expect("Unable to handle request"); }); let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; @@ -492,7 +491,9 @@ mod endpoint_tests { let base = Uri::from_static("http://localhost:9001"); let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + handle(DuplexStreamWrapper::new(server), rx) + .await + .expect("Unable to handle request"); }); let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; @@ -555,7 +556,9 @@ mod endpoint_tests { let base = Uri::from_static("http://localhost:9001"); let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + handle(DuplexStreamWrapper::new(server), rx) + .await + .expect("Unable to handle request"); }); let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; diff --git a/lambda-runtime/src/requests.rs b/lambda-runtime/src/requests.rs index 8e72fc2d..c9274cf4 100644 --- a/lambda-runtime/src/requests.rs +++ b/lambda-runtime/src/requests.rs @@ -3,8 +3,7 @@ use crate::{types::Diagnostic, Error, FunctionResponse, IntoFunctionResponse}; use bytes::Bytes; use http::header::CONTENT_TYPE; use http::{Method, Request, Response, Uri}; -use hyper::Body; -use lambda_runtime_api_client::build_request; +use lambda_runtime_api_client::{body::Body, build_request}; use serde::Serialize; use std::fmt::Debug; use std::marker::PhantomData; @@ -28,7 +27,7 @@ impl IntoRequest for NextEventRequest { let req = build_request() .method(Method::GET) .uri(Uri::from_static("/2018-06-01/runtime/invocation/next")) - .body(Body::empty())?; + .body(Default::default())?; Ok(req) } } @@ -49,6 +48,7 @@ pub struct NextEventResponse<'a> { impl<'a> IntoResponse for NextEventResponse<'a> { fn into_rsp(self) -> Result, Error> { + // let body: BoxyBody< = BoxBody::new(); let rsp = Response::builder() .header("lambda-runtime-aws-request-id", self.request_id) .header("lambda-runtime-deadline-ms", self.deadline) @@ -85,6 +85,25 @@ where pub(crate) _unused_s: PhantomData, } +impl<'a, R, B, D, E, S> EventCompletionRequest<'a, R, B, S, D, E> +where + R: IntoFunctionResponse, + B: Serialize, + S: Stream> + Unpin + Send + 'static, + D: Into + Send, + E: Into + Send + Debug, +{ + /// Initialize a new EventCompletionRequest + pub(crate) fn new(request_id: &'a str, body: R) -> EventCompletionRequest<'a, R, B, S, D, E> { + EventCompletionRequest { + request_id, + body, + _unused_b: PhantomData::, + _unused_s: PhantomData::, + } + } +} + impl<'a, R, B, S, D, E> IntoRequest for EventCompletionRequest<'a, R, B, S, D, E> where R: IntoFunctionResponse, @@ -157,12 +176,7 @@ where #[test] fn test_event_completion_request() { - let req = EventCompletionRequest { - request_id: "id", - body: "hello, world!", - _unused_b: PhantomData::<&str>, - _unused_s: PhantomData::, - }; + let req = EventCompletionRequest::new("id", "hello, world!"); let req = req.into_req().unwrap(); let expected = Uri::from_static("/2018-06-01/runtime/invocation/id/response"); assert_eq!(req.method(), Method::POST); diff --git a/lambda-runtime/src/simulated.rs b/lambda-runtime/src/simulated.rs index f6a06bca..018664fe 100644 --- a/lambda-runtime/src/simulated.rs +++ b/lambda-runtime/src/simulated.rs @@ -1,14 +1,15 @@ use http::Uri; -use hyper::client::connect::Connection; +use hyper::rt::{Read, Write}; +use hyper_util::client::legacy::connect::{Connected, Connection}; +use pin_project_lite::pin_project; use std::{ collections::HashMap, future::Future, - io::Result as IoResult, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf}; +use tokio::io::DuplexStream; use crate::Error; @@ -17,11 +18,16 @@ pub struct Connector { inner: Arc>>, } -pub struct DuplexStreamWrapper(DuplexStream); +pin_project! { +pub struct DuplexStreamWrapper { + #[pin] + inner: DuplexStream, +} +} impl DuplexStreamWrapper { - pub(crate) fn new(stream: DuplexStream) -> DuplexStreamWrapper { - DuplexStreamWrapper(stream) + pub(crate) fn new(inner: DuplexStream) -> DuplexStreamWrapper { + DuplexStreamWrapper { inner } } } @@ -53,16 +59,12 @@ impl Connector { } } -impl hyper::service::Service for Connector { +impl tower::Service for Connector { type Response = DuplexStreamWrapper; type Error = crate::Error; #[allow(clippy::type_complexity)] type Future = Pin> + Send>>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - fn call(&mut self, uri: Uri) -> Self::Future { let res = match self.inner.lock() { Ok(mut map) if map.contains_key(&uri) => Ok(map.remove(&uri).unwrap()), @@ -71,30 +73,61 @@ impl hyper::service::Service for Connector { }; Box::pin(async move { res }) } + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } } impl Connection for DuplexStreamWrapper { - fn connected(&self) -> hyper::client::connect::Connected { - hyper::client::connect::Connected::new() + fn connected(&self) -> Connected { + Connected::new() } } -impl AsyncRead for DuplexStreamWrapper { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) +impl Read for DuplexStreamWrapper { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) } } -impl AsyncWrite for DuplexStreamWrapper { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) +impl Write for DuplexStreamWrapper { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) } - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) } } diff --git a/lambda-runtime/src/streaming.rs b/lambda-runtime/src/streaming.rs new file mode 100644 index 00000000..4f0c8083 --- /dev/null +++ b/lambda-runtime/src/streaming.rs @@ -0,0 +1,35 @@ +pub use lambda_runtime_api_client::body::{sender::Sender, Body}; + +pub use crate::types::StreamResponse as Response; + +/// Create a new `Body` stream with associated Sender half. +/// +/// Examples +/// +/// ``` +/// use lambda_runtime::{ +/// streaming::{channel, Body, Response}, +/// Error, LambdaEvent, +/// }; +/// use std::{thread, time::Duration}; +/// +/// async fn func(_event: LambdaEvent) -> Result, Error> { +/// let messages = vec!["Hello", "world", "from", "Lambda!"]; +/// +/// let (mut tx, rx) = channel(); +/// +/// tokio::spawn(async move { +/// for message in messages.iter() { +/// tx.send_data((message.to_string() + "\n").into()).await.unwrap(); +/// thread::sleep(Duration::from_millis(500)); +/// } +/// }); +/// +/// Ok(Response::from(rx)) +/// } +/// ``` +#[allow(unused)] +#[inline] +pub fn channel() -> (Sender, Body) { + Body::channel() +} diff --git a/lambda-runtime/src/types.rs b/lambda-runtime/src/types.rs index 8b70ce80..f2a36073 100644 --- a/lambda-runtime/src/types.rs +++ b/lambda-runtime/src/types.rs @@ -2,6 +2,7 @@ use crate::{Error, RefConfig}; use base64::prelude::*; use bytes::Bytes; use http::{header::ToStrError, HeaderMap, HeaderValue, StatusCode}; +use lambda_runtime_api_client::body::Body; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, @@ -251,11 +252,11 @@ impl IntoFunctionResponse for FunctionResponse { } } -impl IntoFunctionResponse for B +impl IntoFunctionResponse for B where B: Serialize, { - fn into_response(self) -> FunctionResponse { + fn into_response(self) -> FunctionResponse { FunctionResponse::BufferedResponse(self) } } @@ -271,6 +272,20 @@ where } } +impl From for StreamResponse +where + S: Stream> + Unpin + Send + 'static, + D: Into + Send, + E: Into + Send + Debug, +{ + fn from(value: S) -> Self { + StreamResponse { + metadata_prelude: Default::default(), + stream: value, + } + } +} + #[cfg(test)] mod test { use super::*;