diff --git a/examples/basic-streaming-response/Cargo.toml b/examples/basic-streaming-response/Cargo.toml new file mode 100644 index 00000000..fc284674 --- /dev/null +++ b/examples/basic-streaming-response/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "basic-streaming-response" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +hyper = { version = "0.14", features = [ + "http1", + "client", + "stream", +] } +lambda_runtime = { path = "../../lambda-runtime" } +tokio = { version = "1", features = ["macros"] } +tracing = { version = "0.1", features = ["log"] } +tracing-subscriber = { version = "0.3", default-features = false, features = ["ansi", "fmt"] } +serde_json = "1.0" \ No newline at end of file diff --git a/examples/basic-streaming-response/README.md b/examples/basic-streaming-response/README.md new file mode 100644 index 00000000..3b68f518 --- /dev/null +++ b/examples/basic-streaming-response/README.md @@ -0,0 +1,13 @@ +# AWS Lambda Function example + +## Build & Deploy + +1. Install [cargo-lambda](https://github.com/cargo-lambda/cargo-lambda#installation) +2. Build the function with `cargo lambda build --release` +3. Deploy the function to AWS Lambda with `cargo lambda deploy --enable-function-url --iam-role YOUR_ROLE` +4. Enable Lambda streaming response on Lambda console: change the function url's invoke mode to `RESPONSE_STREAM` +5. Verify the function works: `curl <FUNCTION_URL>`. The results should be streamed back with a 0.5 second pause between each word. 
+ +## Build for ARM 64 + +Build the function with `cargo lambda build --release --arm64` diff --git a/examples/basic-streaming-response/src/main.rs b/examples/basic-streaming-response/src/main.rs new file mode 100644 index 00000000..04c7f8ec --- /dev/null +++ b/examples/basic-streaming-response/src/main.rs @@ -0,0 +1,42 @@ +use hyper::{body::Body, Response}; +use lambda_runtime::{service_fn, Error, LambdaEvent}; +use serde_json::Value; +use std::{thread, time::Duration}; + +async fn func(_event: LambdaEvent) -> Result, Error> { + let messages = vec!["Hello", "world", "from", "Lambda!"]; + + let (mut tx, rx) = Body::channel(); + + tokio::spawn(async move { + for message in messages.iter() { + tx.send_data((message.to_string() + "\n").into()).await.unwrap(); + thread::sleep(Duration::from_millis(500)); + } + }); + + let resp = Response::builder() + .header("content-type", "text/html") + .header("CustomHeader", "outerspace") + .body(rx)?; + + Ok(resp) +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + // required to enable CloudWatch error logging by the runtime + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + // disable printing the name of the module in every log line. + .with_target(false) + // this needs to be set to false, otherwise ANSI color codes will + // show up in a confusing manner in CloudWatch logs. + .with_ansi(false) + // disabling time is handy because CloudWatch will add the ingestion time. 
+ .without_time() + .init(); + + lambda_runtime::run_with_streaming_response(service_fn(func)).await?; + Ok(()) +} diff --git a/lambda-http/Cargo.toml b/lambda-http/Cargo.toml index a2ac6250..aacf739b 100644 --- a/lambda-http/Cargo.toml +++ b/lambda-http/Cargo.toml @@ -23,19 +23,20 @@ apigw_websockets = [] alb = [] [dependencies] -base64 = "0.13.0" -bytes = "1" +base64 = "0.21" +bytes = "1.4" +futures = "0.3" http = "0.2" http-body = "0.4" -hyper = "0.14.20" +hyper = "0.14" lambda_runtime = { path = "../lambda-runtime", version = "0.7" } -serde = { version = "^1", features = ["derive"] } -serde_json = "^1" -serde_urlencoded = "0.7.0" -mime = "0.3.16" -encoding_rs = "0.8.31" -url = "2.2.2" -percent-encoding = "2.2.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +serde_urlencoded = "0.7" +mime = "0.3" +encoding_rs = "0.8" +url = "2.2" +percent-encoding = "2.2" [dependencies.aws_lambda_events] version = "^0.7.2" diff --git a/lambda-http/src/lib.rs b/lambda-http/src/lib.rs index 8d030a75..b4d9c5bd 100644 --- a/lambda-http/src/lib.rs +++ b/lambda-http/src/lib.rs @@ -92,6 +92,9 @@ use std::{ task::{Context as TaskContext, Poll}, }; +mod streaming; +pub use streaming::run_with_streaming_response; + /// Type alias for `http::Request`s with a fixed [`Body`](enum.Body.html) type pub type Request = http::Request; diff --git a/lambda-http/src/streaming.rs b/lambda-http/src/streaming.rs new file mode 100644 index 00000000..150002be --- /dev/null +++ b/lambda-http/src/streaming.rs @@ -0,0 +1,34 @@ +use crate::request::LambdaRequest; +use crate::tower::ServiceBuilder; +use crate::{Request, RequestExt}; +pub use aws_lambda_events::encodings::Body as LambdaEventBody; +use bytes::Bytes; +pub use http::{self, Response}; +use http_body::Body; +use lambda_runtime::LambdaEvent; +pub use lambda_runtime::{self, service_fn, tower, Context, Error, Service}; +use std::fmt::{Debug, Display}; + +/// Starts the Lambda Rust runtime and stream response back [Configure 
Lambda +/// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html). +/// +/// This takes care of transforming the LambdaEvent into a [`Request`] and +/// accepts [`http::Response`] as response. +pub async fn run_with_streaming_response<'a, S, B, E>(handler: S) -> Result<(), Error> +where + S: Service, Error = E>, + S::Future: Send + 'a, + E: Debug + Display, + B: Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + let svc = ServiceBuilder::new() + .map_request(|req: LambdaEvent| { + let event: Request = req.payload.into(); + event.with_lambda_context(req.context) + }) + .service(handler); + + lambda_runtime::run_with_streaming_response(svc).await +} diff --git a/lambda-runtime-api-client/src/lib.rs b/lambda-runtime-api-client/src/lib.rs index 42a4c54b..4b082aba 100644 --- a/lambda-runtime-api-client/src/lib.rs +++ b/lambda-runtime-api-client/src/lib.rs @@ -53,7 +53,9 @@ where /// Create a new client with a given base URI and HTTP connector. pub fn with(base: Uri, connector: C) -> Self { - let client = hyper::Client::builder().build(connector); + let client = hyper::Client::builder() + .http1_max_buf_size(1024 * 1024) + .build(connector); Self { base, client } } diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index cf03664e..31c9297c 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -34,6 +34,9 @@ mod simulated; /// Types available to a Lambda function. 
mod types; +mod streaming; +pub use streaming::run_with_streaming_response; + use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest}; pub use types::{Context, LambdaEvent}; diff --git a/lambda-runtime/src/streaming.rs b/lambda-runtime/src/streaming.rs new file mode 100644 index 00000000..85af784e --- /dev/null +++ b/lambda-runtime/src/streaming.rs @@ -0,0 +1,258 @@ +use crate::{ + build_event_error_request, incoming, type_name_of_val, Config, Context, Error, EventErrorRequest, IntoRequest, + LambdaEvent, Runtime, +}; +use bytes::Bytes; +use futures::FutureExt; +use http::header::{CONTENT_TYPE, SET_COOKIE}; +use http::{Method, Request, Response, Uri}; +use hyper::body::HttpBody; +use hyper::{client::connect::Connection, Body}; +use lambda_runtime_api_client::{build_request, Client}; +use serde::Deserialize; +use serde_json::json; +use std::collections::HashMap; +use std::str::FromStr; +use std::{ + env, + fmt::{self, Debug, Display}, + future::Future, + panic, +}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_stream::{Stream, StreamExt}; +use tower::{Service, ServiceExt}; +use tracing::{error, trace, Instrument}; + +/// Starts the Lambda Rust runtime and streams the response back [Configure Lambda +/// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html). 
+/// +/// # Example +/// ```no_run +/// use hyper::{body::Body, Response}; +/// use lambda_runtime::{service_fn, Error, LambdaEvent}; +/// use std::{thread, time::Duration}; +/// use serde_json::Value; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Error> { +/// lambda_runtime::run_with_streaming_response(service_fn(func)).await?; +/// Ok(()) +/// } +/// async fn func(_event: LambdaEvent) -> Result, Error> { +/// let messages = vec!["Hello ", "world ", "from ", "Lambda!"]; +/// +/// let (mut tx, rx) = Body::channel(); +/// +/// tokio::spawn(async move { +/// for message in messages.iter() { +/// tx.send_data((*message).into()).await.unwrap(); +/// thread::sleep(Duration::from_millis(500)); +/// } +/// }); +/// +/// let resp = Response::builder() +/// .header("content-type", "text/plain") +/// .header("CustomHeader", "outerspace") +/// .body(rx)?; +/// +/// Ok(resp) +/// } +/// ``` +pub async fn run_with_streaming_response(handler: F) -> Result<(), Error> +where + F: Service>, + F::Future: Future, F::Error>>, + F::Error: Debug + Display, + A: for<'de> Deserialize<'de>, + B: HttpBody + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + trace!("Loading config from env"); + let config = Config::from_env()?; + let client = Client::builder().build().expect("Unable to create a runtime client"); + let runtime = Runtime { client }; + + let client = &runtime.client; + let incoming = incoming(client); + runtime.run_with_streaming_response(incoming, handler, &config).await +} + +impl Runtime +where + C: Service + Clone + Send + Sync + Unpin + 'static, + C::Future: Unpin + Send, + C::Error: Into>, + C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, +{ + pub async fn run_with_streaming_response( + &self, + incoming: impl Stream, Error>> + Send, + mut handler: F, + config: &Config, + ) -> Result<(), Error> + where + F: Service>, + F::Future: Future, F::Error>>, + F::Error: fmt::Debug + fmt::Display, + A: 
for<'de> Deserialize<'de>, + B: HttpBody + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, + { + let client = &self.client; + tokio::pin!(incoming); + while let Some(next_event_response) = incoming.next().await { + trace!("New event arrived (run loop)"); + let event = next_event_response?; + let (parts, body) = event.into_parts(); + + let ctx: Context = Context::try_from(parts.headers)?; + let ctx: Context = ctx.with_config(config); + let request_id = &ctx.request_id.clone(); + + let request_span = match &ctx.xray_trace_id { + Some(trace_id) => { + env::set_var("_X_AMZN_TRACE_ID", trace_id); + tracing::info_span!("Lambda runtime invoke", requestId = request_id, xrayTraceId = trace_id) + } + None => { + env::remove_var("_X_AMZN_TRACE_ID"); + tracing::info_span!("Lambda runtime invoke", requestId = request_id) + } + }; + + // Group the handling in one future and instrument it with the span + async { + let body = hyper::body::to_bytes(body).await?; + trace!("incoming request payload - {}", std::str::from_utf8(&body)?); + + let body = match serde_json::from_slice(&body) { + Ok(body) => body, + Err(err) => { + let req = build_event_error_request(request_id, err)?; + client.call(req).await.expect("Unable to send response to Runtime APIs"); + return Ok(()); + } + }; + + let req = match handler.ready().await { + Ok(handler) => { + // Catches panics outside of a `Future` + let task = + panic::catch_unwind(panic::AssertUnwindSafe(|| handler.call(LambdaEvent::new(body, ctx)))); + + let task = match task { + // Catches panics inside of the `Future` + Ok(task) => panic::AssertUnwindSafe(task).catch_unwind().await, + Err(err) => Err(err), + }; + + match task { + Ok(response) => match response { + Ok(response) => { + trace!("Ok response from handler (run loop)"); + EventCompletionStreamingRequest { + request_id, + body: response, + } + .into_req() + } + Err(err) => build_event_error_request(request_id, err), + }, + Err(err) => { + error!("{:?}", 
err); + let error_type = type_name_of_val(&err); + let msg = if let Some(msg) = err.downcast_ref::<&str>() { + format!("Lambda panicked: {msg}") + } else { + "Lambda panicked".to_string() + }; + EventErrorRequest::new(request_id, error_type, &msg).into_req() + } + } + } + Err(err) => build_event_error_request(request_id, err), + }?; + + client.call(req).await.expect("Unable to send response to Runtime APIs"); + Ok::<(), Error>(()) + } + .instrument(request_span) + .await?; + } + Ok(()) + } +} + +pub(crate) struct EventCompletionStreamingRequest<'a, B> { + pub(crate) request_id: &'a str, + pub(crate) body: Response, +} + +impl<'a, B> EventCompletionStreamingRequest<'a, B> +where + B: HttpBody + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + fn into_req(self) -> Result, Error> { + let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id); + let uri = Uri::from_str(&uri)?; + + let (parts, mut body) = self.body.into_parts(); + + let mut builder = build_request().method(Method::POST).uri(uri); + let headers = builder.headers_mut().unwrap(); + + headers.insert("Transfer-Encoding", "chunked".parse()?); + headers.insert("Lambda-Runtime-Function-Response-Mode", "streaming".parse()?); + headers.insert( + "Content-Type", + "application/vnd.awslambda.http-integration-response".parse()?, + ); + + let (mut tx, rx) = Body::channel(); + + tokio::spawn(async move { + let mut header_map = parts.headers; + // default Content-Type + header_map + .entry(CONTENT_TYPE) + .or_insert("application/octet-stream".parse().unwrap()); + + let cookies = header_map.get_all(SET_COOKIE); + let cookies = cookies + .iter() + .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string()) + .collect::>(); + + let headers = header_map + .iter() + .filter(|(k, _)| *k != SET_COOKIE) + .map(|(k, v)| (k.as_str(), String::from_utf8_lossy(v.as_bytes()).to_string())) + .collect::>(); + + let metadata_prelude = json!({ + "statusCode": 
parts.status.as_u16(), + "headers": headers, + "cookies": cookies, + }) + .to_string(); + + trace!("metadata_prelude: {}", metadata_prelude); + + tx.send_data(metadata_prelude.into()).await.unwrap(); + tx.send_data("\u{0}".repeat(8).into()).await.unwrap(); + + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + tx.send_data(chunk.into()).await.unwrap(); + } + }); + + let req = builder.body(rx)?; + Ok(req) + } +}